WIP: Lazy stencil expressions for kernel fusion#3401
Conversation
Return BinaryExpr types, perform derivatives in Z using CoordinateAccessor to get dz. Appear to inline in Hasegawa-Wakatani example.
| int nz{0}; | ||
|
|
||
| template <typename LView, typename RView> | ||
| BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, |
There was a problem hiding this comment.
warning: no header providing "BOUT_FORCEINLINE" is directly included [misc-include-cleaner]
BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,
^|
|
||
| template <typename LView, typename RView> | ||
| BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, | ||
| const RView&) const { |
There was a problem hiding this comment.
warning: all parameters should be named in a function [readability-named-parameter]
| const RView&) const { | |
| const RView& /*unused*/) const { |
|
|
||
| template <typename LView, typename RView> | ||
| BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, | ||
| const RView&) const { |
There was a problem hiding this comment.
warning: all parameters should be named in a function [readability-named-parameter]
| const RView&) const { | |
| const RView& /*unused*/) const { |
| inline bout::stencil::DDZExprC2 DDZ_C2(const Field3D& f) { | ||
| checkData(f); | ||
|
|
||
| const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); |
There was a problem hiding this comment.
warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]
const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY");
^Additional context
include/bout/field_data.hxx:47: forward declaration of 'Mesh'
class Mesh;
^| f.getLocation(), | ||
| f.getDirections(), | ||
| region_id, | ||
| f.getMesh()->getRegion("RGN_NOBNDRY")}; |
There was a problem hiding this comment.
warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]
f.getMesh()->getRegion("RGN_NOBNDRY")};
^Additional context
include/bout/field_data.hxx:47: forward declaration of 'Mesh'
class Mesh;
^| inline bout::stencil::DDZExprC4 DDZ_C4(const Field3D& f) { | ||
| checkData(f); | ||
|
|
||
| const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); |
There was a problem hiding this comment.
warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]
const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY");
^Additional context
include/bout/field_data.hxx:47: forward declaration of 'Mesh'
class Mesh;
^| f.getLocation(), | ||
| f.getDirections(), | ||
| region_id, | ||
| f.getMesh()->getRegion("RGN_NOBNDRY")}; |
There was a problem hiding this comment.
warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]
f.getMesh()->getRegion("RGN_NOBNDRY")};
^Additional context
include/bout/field_data.hxx:47: forward declaration of 'Mesh'
class Mesh;
^Implements Arakawa bracket in X-Z as a BinaryExpr.
Testing the performance impact of runtime dispatch. The hope is to preserve the ability to switch method at runtime. The conditional that selects the method is the same for all iterations, so hopefully good branch prediction and no warp divergence. DDZ_Dispatch(f, method) selects between DDZ_C2 and DDZ_C4 at runtime.
|
|
||
| ddt(n) = | ||
| -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi); | ||
| ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) |
There was a problem hiding this comment.
warning: no header providing "ddt" is directly included [misc-include-cleaner]
ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n)
^| ddt(n) = | ||
| -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi); | ||
| ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) | ||
| - kappa * DDZ_Dispatch(phi, DIFF_C2); |
There was a problem hiding this comment.
warning: no header providing "DIFF_C2" is directly included [misc-include-cleaner]
- kappa * DDZ_Dispatch(phi, DIFF_C2);
^| struct DDZ_Dispatch_Op { | ||
| CoordinatesAccessor coords; | ||
| int nz{0}; | ||
| DIFF_METHOD method{DIFF_DEFAULT}; |
There was a problem hiding this comment.
warning: no header providing "DIFF_DEFAULT" is directly included [misc-include-cleaner]
DIFF_METHOD method{DIFF_DEFAULT};
^| struct DDZ_Dispatch_Op { | ||
| CoordinatesAccessor coords; | ||
| int nz{0}; | ||
| DIFF_METHOD method{DIFF_DEFAULT}; |
There was a problem hiding this comment.
warning: no header providing "DIFF_METHOD" is directly included [misc-include-cleaner]
DIFF_METHOD method{DIFF_DEFAULT};
^|
|
||
| template <typename LView, typename RView> | ||
| BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, | ||
| const RView&) const { |
There was a problem hiding this comment.
warning: all parameters should be named in a function [readability-named-parameter]
| const RView&) const { | |
| const RView& /*unused*/) const { |
| BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, | ||
| const RView&) const { | ||
| switch (method) { | ||
| case DIFF_C2: |
There was a problem hiding this comment.
warning: no header providing "DIFF_C2" is directly included [misc-include-cleaner]
case DIFF_C2:
^| switch (method) { | ||
| case DIFF_C2: | ||
| return apply_c2(idx, lhs); | ||
| case DIFF_C4: |
There was a problem hiding this comment.
warning: no header providing "DIFF_C4" is directly included [misc-include-cleaner]
case DIFF_C4:
^|
|
||
| if ((method != DIFF_C2) && (method != DIFF_C4)) { | ||
| throw BoutException("DDZ_Dispatch only supports DIFF_C2 and DIFF_C4, got {:s}", | ||
| toString(method)); |
There was a problem hiding this comment.
warning: no header providing "toString" is directly included [misc-include-cleaner]
toString(method));
^| const Field3D& g) { | ||
| checkData(f); | ||
| checkData(g); | ||
| ASSERT1_FIELDS_COMPATIBLE(f, g); |
There was a problem hiding this comment.
warning: no header providing "ASSERT1_FIELDS_COMPATIBLE" is directly included [misc-include-cleaner]
include/bout/stencil_expr.hxx:1:
- #ifndef BOUT_STENCIL_EXPR_HXX
+ #include "bout/field2d.hxx"
+ #ifndef BOUT_STENCIL_EXPR_HXXRetains the runtime dispatch of DDZ, while building a BinaryExpr lazy expression.
Stores the default numerical methods as DIFF_METHOD enums. This enables runtime dispatch with minimal overhead in expressions.
Retains runtime choice of method, and handling of staggered inputs or outputs. Only C2 and C4 methods are supported.
| ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a);); | ||
|
|
||
| ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2");); | ||
| ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2);); |
There was a problem hiding this comment.
warning: no header providing "CELL_DEFAULT" is directly included [misc-include-cleaner]
ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););
^| ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a);); | ||
|
|
||
| ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2");); | ||
| ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2);); |
There was a problem hiding this comment.
warning: no header providing "DIFF_C2" is directly included [misc-include-cleaner]
ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););
^| ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a);); | ||
|
|
||
| ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2");); | ||
| ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2);); |
There was a problem hiding this comment.
warning: variable 'start' of type 'SteadyClock' (aka 'time_pointstd::chrono::steady_clock') can be declared 'const' [misc-const-correctness]
ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););
^Additional context
examples/performance/ddz/ddz.cxx:27: expanded from macro 'ITERATOR_TEST_BLOCK'
SteadyClock start = steady_clock::now(); \
^| ITERATOR_TEST_BLOCK("DDZ W3", result = DDZ(a, CELL_DEFAULT, "DIFF_W3");); | ||
|
|
||
| ITERATOR_TEST_BLOCK("DDZ FFT", result = DDZ(a, CELL_DEFAULT, "DIFF_FFT");); | ||
| ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4);); |
There was a problem hiding this comment.
warning: no header providing "DIFF_C4" is directly included [misc-include-cleaner]
ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4););
^| ITERATOR_TEST_BLOCK("DDZ W3", result = DDZ(a, CELL_DEFAULT, "DIFF_W3");); | ||
|
|
||
| ITERATOR_TEST_BLOCK("DDZ FFT", result = DDZ(a, CELL_DEFAULT, "DIFF_FFT");); | ||
| ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4);); |
There was a problem hiding this comment.
warning: variable 'start' of type 'SteadyClock' (aka 'time_pointstd::chrono::steady_clock') can be declared 'const' [misc-const-correctness]
ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4););
^Additional context
examples/performance/ddz/ddz.cxx:27: expanded from macro 'ITERATOR_TEST_BLOCK'
SteadyClock start = steady_clock::now(); \
^| static constexpr int num_staggers = 3; | ||
| static constexpr int num_deriv_kinds = 5; | ||
|
|
||
| std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{}; |
There was a problem hiding this comment.
warning: performing an implicit widening conversion to type 'std::size_t' (aka 'unsigned long') of a multiplication performed in type 'int' [bugprone-implicit-widening-of-multiplication-result]
std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};
^Additional context
include/bout/mesh.hxx:103: make conversion explicit to silence this warning
std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};
^include/bout/mesh.hxx:103: perform multiplication in a wider type
std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};
^|
|
||
| DIFF_METHOD get(DIRECTION direction, DERIV deriv, | ||
| STAGGER stagger = STAGGER::None) const { | ||
| return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) |
There was a problem hiding this comment.
warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]
return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))
^There was a problem hiding this comment.
I think values.at() suppresses this warning
| DIFF_METHOD get(DIRECTION direction, DERIV deriv, | ||
| STAGGER stagger = STAGGER::None) const { | ||
| return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) | ||
| * num_deriv_kinds |
There was a problem hiding this comment.
warning: '*' has higher precedence than '+'; add parentheses to explicitly specify the order of operations [readability-math-missing-parentheses]
| * num_deriv_kinds | |
| return values[((directionIndex(direction) * num_staggers + staggerIndex(stagger)) | |
| * num_deriv_kinds) |
|
|
||
| void set(DIRECTION direction, DERIV deriv, DIFF_METHOD method, | ||
| STAGGER stagger = STAGGER::None) { | ||
| values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) |
There was a problem hiding this comment.
warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]
values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))
^| void set(DIRECTION direction, DERIV deriv, DIFF_METHOD method, | ||
| STAGGER stagger = STAGGER::None) { | ||
| values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) | ||
| * num_deriv_kinds |
There was a problem hiding this comment.
warning: '*' has higher precedence than '+'; add parentheses to explicitly specify the order of operations [readability-math-missing-parentheses]
| * num_deriv_kinds | |
| values[((directionIndex(direction) * num_staggers + staggerIndex(stagger)) | |
| * num_deriv_kinds) |
ZedThree
left a comment
There was a problem hiding this comment.
Is this intended as a replacement for the existing derivatives generation machinery?
| template <typename ResT, typename L, typename R, typename Func> | ||
| struct expression_result<BinaryExpr<ResT, L, R, Func>> { | ||
| using type = ResT; | ||
| }; |
There was a problem hiding this comment.
Might be clearer to consistently just use Result instead of ResT? It took me a minute to decipher that
| static constexpr int num_directions = 3; | ||
| static constexpr int num_staggers = 3; | ||
| static constexpr int num_deriv_kinds = 5; |
There was a problem hiding this comment.
Making these std::size_t might fix a bunch of the clang-tidy warnings?
| } | ||
| } | ||
|
|
||
| static DIFF_METHOD builtinDefaultMethod(DERIV deriv) { |
There was a problem hiding this comment.
I think these can be also be constexpr
| case DERIV::Flux: | ||
| return DIFF_U1; | ||
| } | ||
| throw BoutException("Unhandled derivative kind in builtinDefaultMethod"); |
There was a problem hiding this comment.
We could use something like C++23's std::unreachable instead here
|
|
||
| DIFF_METHOD get(DIRECTION direction, DERIV deriv, | ||
| STAGGER stagger = STAGGER::None) const { | ||
| return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) |
There was a problem hiding this comment.
I think values.at() suppresses this warning
| case STAGGER::L2C: | ||
| return apply_c2_l2c(idx, lhs); | ||
| } | ||
| return 0.0; |
There was a problem hiding this comment.
This could also use an unreachable() function
Hi @ZedThree ! Yes, this would replace the index_derivs machinery. The current kernels are optimized for CPUs but are too small to run efficiently on GPUs. To have any chance of running on GPUs efficiently, kernels need to be merged into larger expressions. The single index operators were one way, but these template expressions hide a lot of the ceremony of creating field accessors, capturing variables etc. The interface stays essentially the same, apart from explicitly converting BinaryExpr to fields in some places. |
Operators return BinaryExpr types that perform derivatives 'lazily'. Like the single index operators they capture a CoordinatesAccessor to provide access to metrics etc. The difference is that the user code doesn't need to create field accessors or explicitly write the loop, capture variables etc. All that is done automatically if the compiler is smart/sufficiently aggressively optimizing.
Implemented DDZ_C2 and DDZ_C4 operators for Z derivatives. These appear to inline as expected in the Hasegawa-Wakatani example.
The DDZ(Field3D) operator now returns a lazy BinaryExpr type. It still allows runtime choice of numerical method (C2 / C4) and staggered fields.