WIP: Lazy stencil expressions for kernel fusion #3401

Open
bendudson wants to merge 7 commits into next from lazy-stencil-operators

@bendudson bendudson commented Jun 21, 2026

Operators return BinaryExpr types that evaluate derivatives lazily. Like the single-index operators, they capture a CoordinatesAccessor to provide access to metrics etc. The difference is that user code no longer needs to create field accessors, write the loop explicitly, or capture variables. All of that happens automatically, provided the compiler optimizes aggressively enough to inline the expression.

Implemented DDZ_C2 and DDZ_C4 operators for Z derivatives. These appear to inline as expected in the Hasegawa-Wakatani example.

The DDZ(Field3D) operator now returns a lazy BinaryExpr type. It still allows runtime choice of numerical method (C2 / C4) and staggered fields.
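
For readers unfamiliar with the pattern, the following is a minimal sketch of a lazy Z-derivative operator. It is not the code in this PR: `ZField`, `LazyDDZC2` and `lazy_ddz_c2` are hypothetical names, the field is reduced to a 1D periodic array, and a plain `dz` member stands in for whatever CoordinatesAccessor provides.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a field that is periodic in Z (not Field3D).
struct ZField {
  std::vector<double> data;
  double dz{1.0};
};

// Lazy second-order central difference. The object only stores a view of
// the input and the grid spacing; it computes one value per call and
// never allocates a result field.
struct LazyDDZC2 {
  const double* f;  // non-owning view of the input data
  std::size_t nz;   // number of points in Z (periodic)
  double inv_2dz;   // 1 / (2 dz), captured once at construction

  double operator()(std::size_t k) const {
    assert(k < nz);
    const std::size_t kp = (k + 1) % nz;       // periodic neighbour k+1
    const std::size_t km = (k + nz - 1) % nz;  // periodic neighbour k-1
    return (f[kp] - f[km]) * inv_2dz;
  }
};

// Building the expression does no arithmetic on the field data.
// The returned object must not outlive `field`.
inline LazyDDZC2 lazy_ddz_c2(const ZField& field) {
  return LazyDDZC2{field.data.data(), field.data.size(), 0.5 / field.dz};
}
```

Because `operator()` is a small inline function, a caller that evaluates `lazy_ddz_c2(f)(k)` inside its own loop gives the compiler the chance to fold the stencil into that loop body, which is the inlining behaviour described above.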

Return BinaryExpr types, perform derivatives in Z
using CoordinatesAccessor to get dz.

Appear to inline in Hasegawa-Wakatani example.
@bendudson bendudson added the work in progress Not ready for merging label Jun 21, 2026

@github-actions github-actions Bot left a comment

clang-tidy made some suggestions

Comment thread examples/hasegawa-wakatani/hw.cxx Outdated
int nz{0};

template <typename LView, typename RView>
BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,

warning: no header providing "BOUT_FORCEINLINE" is directly included [misc-include-cleaner]

  BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,
                   ^

Comment thread include/bout/stencil_expr.hxx

template <typename LView, typename RView>
BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,
const RView&) const {

warning: all parameters should be named in a function [readability-named-parameter]

Suggested change
const RView&) const {
const RView& /*unused*/) const {


template <typename LView, typename RView>
BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,
const RView&) const {

warning: all parameters should be named in a function [readability-named-parameter]

Suggested change
const RView&) const {
const RView& /*unused*/) const {

inline bout::stencil::DDZExprC2 DDZ_C2(const Field3D& f) {
checkData(f);

const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY");

warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]

  const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY");
                                    ^
Additional context

include/bout/field_data.hxx:47: forward declaration of 'Mesh'

class Mesh;
      ^

f.getLocation(),
f.getDirections(),
region_id,
f.getMesh()->getRegion("RGN_NOBNDRY")};

warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]

      f.getMesh()->getRegion("RGN_NOBNDRY")};
                 ^
Additional context

include/bout/field_data.hxx:47: forward declaration of 'Mesh'

class Mesh;
      ^

inline bout::stencil::DDZExprC4 DDZ_C4(const Field3D& f) {
checkData(f);

const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY");

warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]

  const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY");
                                    ^
Additional context

include/bout/field_data.hxx:47: forward declaration of 'Mesh'

class Mesh;
      ^

f.getLocation(),
f.getDirections(),
region_id,
f.getMesh()->getRegion("RGN_NOBNDRY")};

warning: member access into incomplete type 'Mesh' [clang-diagnostic-error]

      f.getMesh()->getRegion("RGN_NOBNDRY")};
                 ^
Additional context

include/bout/field_data.hxx:47: forward declaration of 'Mesh'

class Mesh;
      ^

Implements Arakawa bracket in X-Z as a BinaryExpr.

Testing the performance impact of runtime dispatch. The hope is to
preserve the ability to switch method at runtime. The conditional
that selects the method is the same for all iterations, so branch
prediction should hopefully be good and there should be no warp divergence.

DDZ_Dispatch(f, method) selects between DDZ_C2 and DDZ_C4 at runtime.
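
The dispatch idea can be sketched as follows. This is an illustrative stand-in, not the PR's DDZ_Dispatch_Op: `Method`, `DispatchDDZ` and `make_dispatch` are hypothetical names, and the field is a 1D periodic array with uniform spacing. The method is validated once when the expression is built, and the `switch` inside `operator()` takes the same branch on every iteration.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical method tag; the real DIFF_METHOD enum has more values.
enum class Method { C2, C4 };

struct DispatchDDZ {
  const double* f;  // non-owning view of periodic data
  std::size_t nz;
  double dz;
  Method method;    // loop-invariant: fixed when the expression is built

  // Second-order central difference.
  double c2(std::size_t k) const {
    return (f[(k + 1) % nz] - f[(k + nz - 1) % nz]) / (2.0 * dz);
  }
  // Fourth-order central difference.
  double c4(std::size_t k) const {
    return (8.0 * (f[(k + 1) % nz] - f[(k + nz - 1) % nz])
            - (f[(k + 2) % nz] - f[(k + nz - 2) % nz]))
           / (12.0 * dz);
  }

  double operator()(std::size_t k) const {
    assert(k < nz);
    switch (method) {
    case Method::C2:
      return c2(k);
    case Method::C4:
      return c4(k);
    }
    return 0.0;  // not reached for valid Method values
  }
};

// Validate once, outside the hot loop, so operator() never has to throw.
inline DispatchDDZ make_dispatch(const double* f, std::size_t nz, double dz,
                                 Method method) {
  if (f == nullptr || nz < 4) {
    throw std::invalid_argument("make_dispatch: need at least 4 points");
  }
  return DispatchDDZ{f, nz, dz, method};
}
```

Whether the branch is actually free in practice is exactly what the commit above sets out to measure; the sketch only shows why the condition is uniform across iterations.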

@github-actions github-actions Bot left a comment

clang-tidy made some suggestions

Comment thread examples/hasegawa-wakatani/hw.cxx Outdated

ddt(n) =
-bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi);
ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n)

warning: no header providing "ddt" is directly included [misc-include-cleaner]

    ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n)
    ^

Comment thread examples/hasegawa-wakatani/hw.cxx Outdated
ddt(n) =
-bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi);
ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n)
- kappa * DDZ_Dispatch(phi, DIFF_C2);

warning: no header providing "DIFF_C2" is directly included [misc-include-cleaner]

             - kappa * DDZ_Dispatch(phi, DIFF_C2);
                                         ^

struct DDZ_Dispatch_Op {
CoordinatesAccessor coords;
int nz{0};
DIFF_METHOD method{DIFF_DEFAULT};

warning: no header providing "DIFF_DEFAULT" is directly included [misc-include-cleaner]

  DIFF_METHOD method{DIFF_DEFAULT};
                     ^

struct DDZ_Dispatch_Op {
CoordinatesAccessor coords;
int nz{0};
DIFF_METHOD method{DIFF_DEFAULT};

warning: no header providing "DIFF_METHOD" is directly included [misc-include-cleaner]

  DIFF_METHOD method{DIFF_DEFAULT};
  ^


template <typename LView, typename RView>
BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,
const RView&) const {

warning: all parameters should be named in a function [readability-named-parameter]

Suggested change
const RView&) const {
const RView& /*unused*/) const {

BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs,
const RView&) const {
switch (method) {
case DIFF_C2:

warning: no header providing "DIFF_C2" is directly included [misc-include-cleaner]

    case DIFF_C2:
         ^

switch (method) {
case DIFF_C2:
return apply_c2(idx, lhs);
case DIFF_C4:

warning: no header providing "DIFF_C4" is directly included [misc-include-cleaner]

    case DIFF_C4:
         ^


if ((method != DIFF_C2) && (method != DIFF_C4)) {
throw BoutException("DDZ_Dispatch only supports DIFF_C2 and DIFF_C4, got {:s}",
toString(method));

warning: no header providing "toString" is directly included [misc-include-cleaner]

                        toString(method));
                        ^

const Field3D& g) {
checkData(f);
checkData(g);
ASSERT1_FIELDS_COMPATIBLE(f, g);

warning: no header providing "ASSERT1_FIELDS_COMPATIBLE" is directly included [misc-include-cleaner]

include/bout/stencil_expr.hxx:1:

- #ifndef BOUT_STENCIL_EXPR_HXX
+ #include "bout/field2d.hxx"
+ #ifndef BOUT_STENCIL_EXPR_HXX

Retains the runtime dispatch of DDZ, while building a BinaryExpr
lazy expression.

Stores the default numerical methods as DIFF_METHOD enums. This
enables runtime dispatch with minimal overhead in expressions.

Retains runtime choice of method, and handling of staggered inputs or
outputs. Only C2 and C4 methods are supported.

@github-actions github-actions Bot left a comment

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 25 out of 52. Check the log or trigger a new build to see more.

ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a););

ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2"););
ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););

warning: no header providing "CELL_DEFAULT" is directly included [misc-include-cleaner]

  ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););
                                                ^

ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a););

ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2"););
ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););

warning: no header providing "DIFF_C2" is directly included [misc-include-cleaner]

  ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););
                                                              ^

ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a););

ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2"););
ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););

warning: variable 'start' of type 'SteadyClock' (aka 'time_point<std::chrono::steady_clock>') can be declared 'const' [misc-const-correctness]

  ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2););
  ^
Additional context

examples/performance/ddz/ddz.cxx:27: expanded from macro 'ITERATOR_TEST_BLOCK'

    SteadyClock start = steady_clock::now();                                        \
    ^

ITERATOR_TEST_BLOCK("DDZ W3", result = DDZ(a, CELL_DEFAULT, "DIFF_W3"););

ITERATOR_TEST_BLOCK("DDZ FFT", result = DDZ(a, CELL_DEFAULT, "DIFF_FFT"););
ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4););

warning: no header providing "DIFF_C4" is directly included [misc-include-cleaner]

  ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4););
                                                              ^

ITERATOR_TEST_BLOCK("DDZ W3", result = DDZ(a, CELL_DEFAULT, "DIFF_W3"););

ITERATOR_TEST_BLOCK("DDZ FFT", result = DDZ(a, CELL_DEFAULT, "DIFF_FFT"););
ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4););

warning: variable 'start' of type 'SteadyClock' (aka 'time_point<std::chrono::steady_clock>') can be declared 'const' [misc-const-correctness]

  ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4););
  ^
Additional context

examples/performance/ddz/ddz.cxx:27: expanded from macro 'ITERATOR_TEST_BLOCK'

    SteadyClock start = steady_clock::now();                                        \
    ^

Comment thread include/bout/mesh.hxx
static constexpr int num_staggers = 3;
static constexpr int num_deriv_kinds = 5;

std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};

warning: performing an implicit widening conversion to type 'std::size_t' (aka 'unsigned long') of a multiplication performed in type 'int' [bugprone-implicit-widening-of-multiplication-result]

    std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};
                            ^
Additional context

include/bout/mesh.hxx:103: make conversion explicit to silence this warning

    std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};
                            ^

include/bout/mesh.hxx:103: perform multiplication in a wider type

    std::array<DIFF_METHOD, num_directions * num_staggers * num_deriv_kinds> values{};
                            ^

Comment thread include/bout/mesh.hxx

DIFF_METHOD get(DIRECTION direction, DERIV deriv,
STAGGER stagger = STAGGER::None) const {
return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))

warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]

      return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))
             ^

@ZedThree (Member) replied:

I think values.at() suppresses this warning

Comment thread include/bout/mesh.hxx
DIFF_METHOD get(DIRECTION direction, DERIV deriv,
STAGGER stagger = STAGGER::None) const {
return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))
* num_deriv_kinds

warning: '*' has higher precedence than '+'; add parentheses to explicitly specify the order of operations [readability-math-missing-parentheses]

Suggested change
* num_deriv_kinds
return values[((directionIndex(direction) * num_staggers + staggerIndex(stagger))
* num_deriv_kinds)

Comment thread include/bout/mesh.hxx

void set(DIRECTION direction, DERIV deriv, DIFF_METHOD method,
STAGGER stagger = STAGGER::None) {
values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))

warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]

      values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))
      ^

Comment thread include/bout/mesh.hxx
void set(DIRECTION direction, DERIV deriv, DIFF_METHOD method,
STAGGER stagger = STAGGER::None) {
values[(directionIndex(direction) * num_staggers + staggerIndex(stagger))
* num_deriv_kinds

warning: '*' has higher precedence than '+'; add parentheses to explicitly specify the order of operations [readability-math-missing-parentheses]

Suggested change
* num_deriv_kinds
values[((directionIndex(direction) * num_staggers + staggerIndex(stagger))
* num_deriv_kinds)

@ZedThree ZedThree left a comment

Is this intended as a replacement for the existing derivatives generation machinery?

Comment thread include/bout/field.hxx
Comment on lines +551 to +554
template <typename ResT, typename L, typename R, typename Func>
struct expression_result<BinaryExpr<ResT, L, R, Func>> {
using type = ResT;
};

Might be clearer to consistently just use Result instead of ResT? It took me a minute to decipher that

Comment thread include/bout/mesh.hxx
Comment on lines +100 to +102
static constexpr int num_directions = 3;
static constexpr int num_staggers = 3;
static constexpr int num_deriv_kinds = 5;

Making these std::size_t might fix a bunch of the clang-tidy warnings?
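
A sketch of what that could look like, combined with the `values.at()` and parenthesization suggestions from the clang-tidy threads. Everything here is a simplified assumption rather than the PR's code: `Dir`, `Stag`, `Kind`, `Method` and `MethodTable` are hypothetical stand-ins, and the final `+ kind` term of the index is assumed because the quoted snippet is cut off before that point.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical enums standing in for DIRECTION, STAGGER, DERIV, DIFF_METHOD.
enum class Dir : std::size_t { X = 0, Y = 1, Z = 2 };
enum class Stag : std::size_t { None = 0, C2L = 1, L2C = 2 };
enum class Kind : std::size_t {
  Standard = 0, StandardSecond = 1, StandardFourth = 2, Upwind = 3, Flux = 4
};
enum class Method { C2, C4, U1 };

class MethodTable {
public:
  // std::size_t extents: the array size is computed without an
  // int-to-size_t widening of the multiplication result.
  static constexpr std::size_t num_directions = 3;
  static constexpr std::size_t num_staggers = 3;
  static constexpr std::size_t num_deriv_kinds = 5;

  // Row-major flattening of (direction, stagger, kind), fully parenthesized.
  static constexpr std::size_t flatIndex(Dir dir, Kind kind, Stag stag) {
    return (((static_cast<std::size_t>(dir) * num_staggers)
             + static_cast<std::size_t>(stag))
            * num_deriv_kinds)
           + static_cast<std::size_t>(kind);
  }

  // at() is bounds-checked, unlike operator[] with a runtime index.
  Method get(Dir dir, Kind kind, Stag stag = Stag::None) const {
    return values.at(flatIndex(dir, kind, stag));
  }
  void set(Dir dir, Kind kind, Method method, Stag stag = Stag::None) {
    values.at(flatIndex(dir, kind, stag)) = method;
  }

private:
  std::array<Method, num_directions * num_staggers * num_deriv_kinds> values{};
};

// Usage sketch: override one entry and read it back.
inline void usageExample() {
  MethodTable table;
  table.set(Dir::Z, Kind::Standard, Method::C4);
  assert(table.get(Dir::Z, Kind::Standard) == Method::C4);
  assert(table.get(Dir::X, Kind::Standard) == Method::C2);  // untouched entry
}
```

Note that `at()` adds a bounds check and can throw, so it is a host-side choice; whether it is acceptable in code that also has to compile for a device is a separate question.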

Comment thread include/bout/mesh.hxx
}
}

static DIFF_METHOD builtinDefaultMethod(DERIV deriv) {

I think these can also be constexpr

Comment thread include/bout/mesh.hxx
case DERIV::Flux:
return DIFF_U1;
}
throw BoutException("Unhandled derivative kind in builtinDefaultMethod");

We could use something like C++23's std::unreachable instead here
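
For codebases that cannot yet rely on C++23, a common approach is a small helper built on compiler intrinsics. The sketch below assumes GCC, Clang or MSVC; `Deriv`, `Method`, `unreachable` and `builtinDefault` are hypothetical names, and only the Flux to U1 mapping is taken from the snippet above, the other cases are illustrative.

```cpp
#include <cassert>

// Hypothetical stand-ins for DERIV and DIFF_METHOD.
enum class Deriv { Standard, StandardSecond, StandardFourth, Upwind, Flux };
enum class Method { C2, U1 };

// Pre-C++23 substitute for std::unreachable(). Reaching this call is
// undefined behaviour in release builds; debug builds stop at the assert.
[[noreturn]] inline void unreachable() {
  assert(false && "unreachable() was reached");
#if defined(_MSC_VER) && !defined(__clang__)
  __assume(false);
#else
  __builtin_unreachable();
#endif
}

inline Method builtinDefault(Deriv deriv) {
  switch (deriv) {
  case Deriv::Standard:
  case Deriv::StandardSecond:
  case Deriv::StandardFourth:
    return Method::C2;
  case Deriv::Upwind:
  case Deriv::Flux:
    return Method::U1;
  }
  // Every enumerator is handled above, so control never gets here
  // unless an out-of-range value was cast to Deriv.
  unreachable();
}
```

The trade-off against the existing `throw` is that an out-of-range enum value becomes undefined behaviour instead of a reported error, so this only fits switches that are exhaustive by construction.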

case STAGGER::L2C:
return apply_c2_l2c(idx, lhs);
}
return 0.0;

This could also use an unreachable() function

@bendudson

> Is this intended as a replacement for the existing derivatives generation machinery?

Hi @ZedThree! Yes, this would replace the index_derivs machinery. The current kernels are optimized for CPUs but are too small to run efficiently on GPUs. To have any chance of running efficiently on GPUs, kernels need to be merged into larger expressions. The single-index operators were one way to do that, but these template expressions hide much of the ceremony of creating field accessors, capturing variables, etc. The interface stays essentially the same, apart from explicitly converting BinaryExpr to fields in some places.
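
As a rough illustration of the fusion and of the explicit conversion step, here is a minimal expression-template sketch. It does not reproduce the PR's BinaryExpr: `Field`, `View`, `Binary`, `add`, `mul` and `toField` are hypothetical names, and fields are plain 1D arrays.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal field type; stands in for Field3D.
struct Field {
  std::vector<double> data;
};

// Non-owning, cheap-to-copy view of a field. Must not outlive the field.
struct View {
  const double* ptr;
  std::size_t n;
  double operator()(std::size_t i) const { return ptr[i]; }
  std::size_t size() const { return n; }
};

inline View view(const Field& f) { return View{f.data.data(), f.data.size()}; }

// Lazy node: stores its operands and a functor. No loop runs here.
template <typename L, typename R, typename Func>
struct Binary {
  L lhs;
  R rhs;
  Func func;
  double operator()(std::size_t i) const { return func(lhs(i), rhs(i)); }
  std::size_t size() const { return lhs.size(); }
};

struct Add {
  double operator()(double a, double b) const { return a + b; }
};
struct Mul {
  double operator()(double a, double b) const { return a * b; }
};

template <typename L, typename R>
Binary<L, R, Add> add(const L& lhs, const R& rhs) {
  assert(lhs.size() == rhs.size());
  return Binary<L, R, Add>{lhs, rhs, Add{}};
}

template <typename L, typename R>
Binary<L, R, Mul> mul(const L& lhs, const R& rhs) {
  assert(lhs.size() == rhs.size());
  return Binary<L, R, Mul>{lhs, rhs, Mul{}};
}

// The explicit conversion: the whole expression tree is evaluated in one
// loop, with no temporary field for the intermediate product.
template <typename Expr>
Field toField(const Expr& expr) {
  Field out;
  out.data.resize(expr.size());
  for (std::size_t i = 0; i < expr.size(); ++i) {
    out.data[i] = expr(i);
  }
  return out;
}
```

With this sketch, `toField(add(view(a), mul(view(b), view(c))))` computes `a + b * c` in a single pass, whereas eager operators would first materialize `b * c` as a temporary field and then run a second loop for the addition.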
