
Conversation


@swahtz swahtz commented Dec 3, 2025

This PR makes it possible to use the other available Torch attention backends (flash, efficient, math) by wrapping the input JaggedTensor data in Torch nested Tensors. The mechanism for selecting these backends is the same as for Torch's built-in attention: the sdpa_kernel context manager (see the examples at https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html):

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
    fvdb.scaled_dot_product_attention(query, key, value, 1)

Some backends can operate on nested Tensors created as zero-copy views of the JaggedTensor data, while others require copies. This PR checks which backend PyTorch is configured to use and only takes the view path for backends that can run on nested Tensors created that way.
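
In Python terms the idea is roughly the following (a sketch only: the PR implements this in C++, and the helper names and the backend check shown here are illustrative, not the actual code):

import torch

def nested_view_from_jagged(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Zero-copy path: a jagged-layout nested tensor that is a view over `values`.
    return torch.nested.nested_tensor_from_jagged(values, offsets=offsets)

def nested_copy_from_jagged(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Copy path: slice out each batch element and copy it into a new nested tensor.
    offs = offsets.tolist()
    pieces = [values[offs[i]:offs[i + 1]] for i in range(len(offs) - 1)]
    return torch.nested.nested_tensor(pieces)

def nested_for_enabled_backends(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Placeholder decision: the real code checks which SDPA backends are currently
    # enabled (the state toggled by torch.nn.attention.sdpa_kernel) and only takes
    # the zero-copy path when every enabled backend can consume such a view.
    view_ok = (
        torch.backends.cuda.flash_sdp_enabled()
        or torch.backends.cuda.mem_efficient_sdp_enabled()
    ) and not torch.backends.cuda.math_sdp_enabled()
    return nested_view_from_jagged(values, offsets) if view_ok else nested_copy_from_jagged(values, offsets)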

Also changed the SDPA tests to run all of the backends with their supported data types.
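
Roughly, the new parameterization is along these lines (an illustrative sketch, not the actual test code; the JaggedTensor construction from a list of per-batch tensors and the per-backend dtype lists are assumptions):

import pytest
import torch
import fvdb
from torch.nn.attention import SDPBackend, sdpa_kernel

BACKENDS_AND_DTYPES = [
    (SDPBackend.FLASH_ATTENTION, [torch.float16, torch.bfloat16]),
    (SDPBackend.EFFICIENT_ATTENTION, [torch.float16, torch.bfloat16, torch.float32]),
    (SDPBackend.MATH, [torch.float16, torch.float32, torch.float64]),
]

@pytest.mark.parametrize("backend,dtypes", BACKENDS_AND_DTYPES)
def test_sdpa_backend(backend, dtypes):
    lengths = [7, 13]  # jagged sequence lengths per batch element
    for dtype in dtypes:
        # q, k, v share the same last dimension, which flash attention requires
        q, k, v = (
            fvdb.JaggedTensor([torch.randn(n, 4, 64, device="cuda", dtype=dtype) for n in lengths])
            for _ in range(3)
        )
        with sdpa_kernel([backend]):
            out = fvdb.scaled_dot_product_attention(q, k, v, 1)
        # (the real tests compare `out` against a reference result)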

fixes #363

swahtz and others added 12 commits December 3, 2025 13:41
…ion backends.

Added pytests that test each attention backend and compatible data type

…offsets instead of joffsets.

@swahtz swahtz requested a review from a team as a code owner December 3, 2025 07:35
@swahtz swahtz requested review from blackencino, Copilot and phapalova and removed request for blackencino and phapalova December 3, 2025 07:35
@swahtz swahtz added the "core library" label (Core fVDB library, i.e. anything in the _Cpp module (C++) or the fvdb python module) Dec 3, 2025
Copilot AI left a comment


Pull request overview

This PR enables support for additional PyTorch attention backends (Flash Attention, Efficient Attention, Math) in the scaled_dot_product_attention function by wrapping JaggedTensor data in PyTorch nested tensors. The implementation intelligently selects between zero-copy views and tensor copies based on backend requirements to optimize performance.

Key Changes:

  • Added backend-aware nested tensor creation logic in C++ that chooses between zero-copy views (make_nested_view) and tensor copies (make_nested_tensor) based on the selected attention backend
  • Expanded test coverage to parameterize tests across all supported backends (Flash, Efficient, Math) with their compatible data types
  • Adjusted dimension handling in tests to accommodate Flash Attention's requirement that q, k, v have matching last dimensions

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File descriptions:

  • tests/unit/test_jagged_tensor.py: Parameterized SDPA tests to cover Flash, Efficient, and Math backends with appropriate dtypes; added Flash Attention dimension constraint
  • src/fvdb/FVDB.cpp: Implemented backend-aware nested tensor creation with two helper functions and runtime backend detection logic to optimize tensor creation


… wrong creation function

Use enums instead of ints


@blackencino blackencino left a comment


I have three big questions after going through this more carefully.

  1. Why do we need to do any of it in C++? Could we do the same thing in Python, within the JaggedTensor frontend, and work towards the goal of having less and less C++ code where we can avoid it?

  2. We're deferring to the torch backend specification mechanism as the way users would go about defining the attention machinery, which aligns with how they'd use torch. I like that the default torch behavior is to try to choose the best backend based on your usage. Does fVDB need/want to have any stronger opinions about which backends to use? I think we probably don't, but I get cautious around API decisions that expose expert-level algorithm internals.

  3. Our tests seem mostly like smoke tests at this point; we don't have validation against what SDPA is supposed to mean, or any conceptual explanation. Plus, we're doing permutation of inputs and outputs when comparing to torch, which makes it hard to say that what we're computing is what we expect to see. We should be wrapping the SDPA API so that we consume and produce the same dimensional ordering as torch, except where that's impossible.

# Torch -- For-loop approach (always use MATH for reference to ensure consistency)
out_jagged_torch_forloop_list = []
for b in range(batch_size):
    # From LHE to NHLE / SHV to NHSV

LHE, NHLE, SHV, NHSV: these are not meaningful abbreviations. This is potentially mission-critical code; we should be explaining our assumptions and what we're actually testing against.

Permuting our data to work with torch APIs, or to test the results, is something we worked hard to get rid of in the convolution code. Is there a way we can wrap our use of attention such that the dimensional ordering in and out matches what torch would consume or produce? These permute lines look very much like what we've just gotten rid of in convolution.
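
For what it's worth, if the abbreviations follow torch's usual SDPA naming (N = batch, H = heads, L = query length, S = key/value length, E/V = embedding dims), then each permute reshapes one jagged element from fvdb's layout into the layout torch.nn.functional.scaled_dot_product_attention expects; a tiny sketch of that reading (which is an assumption, not confirmed by the test code shown here):

import torch

# Assumed meaning: one jagged query element stored as (L, H, E); torch's SDPA
# wants (N, H, L, E), so move heads in front of length and add a batch dim.
# The value element would go (S, H, V) -> (N, H, S, V) the same way.
q_lhe = torch.randn(10, 4, 64)                 # (L, H, E)
q_nhle = q_lhe.permute(1, 0, 2).unsqueeze(0)   # (N=1, H, L, E)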


swahtz commented Dec 11, 2025

I have three big questions after going through this more carefully.

  1. Why do we need to do any of it in C++? Could we do the same thing in Python, within the JaggedTensor frontend, and work towards the goal of having less and less C++ code where we can avoid it?

We absolutely could do this in Python. Is having less and less C++ code the goal? For the longest time the goal was to have as thin a Python layer as we could for the bindings and to keep all the meaningful logic in C++. One potential use-case for having this in C++ is portability to other systems (such as inference systems like ONNX), where the user may need a C++ equivalent of scaled_dot_product_attention that matches what they called from PyTorch in Python.

  2. We're deferring to the torch backend specification mechanism as the way users would go about defining the attention machinery, which aligns with how they'd use torch. I like that the default torch behavior is to try to choose the best backend based on your usage. Does fVDB need/want to have any stronger opinions about which backends to use? I think we probably don't, but I get cautious around API decisions that expose expert-level algorithm internals.

I don't really know whether we have any stronger opinions; perhaps @heiwang1997 would have thoughts. I think keeping backend selection consistent with Torch is probably a good goal… if people use the same configuration in a network that uses both our operator and Torch's for different things, I'd expect them to see the same attention backend selected in both cases.

  3. Our tests seem mostly like smoke tests at this point; we don't have validation against what SDPA is supposed to mean, or any conceptual explanation. Plus, we're doing permutation of inputs and outputs when comparing to torch, which makes it hard to say that what we're computing is what we expect to see. We should be wrapping the SDPA API so that we consume and produce the same dimensional ordering as torch, except where that's impossible.

I looked around PyTorch's tests for their operator; their tests largely compare the results of random inputs against reference implementations. For example, here they have a reference SDPA function implemented entirely from basic PyTorch operations and compare its outputs to the SDPA operators:

https://github.com/pytorch/pytorch/blob/main/test/test_transformers.py#L1122

I'm not saying we shouldn't do as you suggest; just providing info on how PyTorch validates these operators.
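
For context, such a reference is only a few lines of basic ops; a minimal sketch (not the code PyTorch's test suite actually uses):

import math
import torch

def reference_sdpa(q, k, v, attn_mask=None, scale=None):
    # softmax(q @ k^T * scale) @ v, written with basic PyTorch ops only, so it
    # can serve as a ground truth for the fused backends.
    scale = 1.0 / math.sqrt(q.size(-1)) if scale is None else scale
    attn = q @ k.transpose(-2, -1) * scale
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return attn @ v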



Development

Successfully merging this pull request may close these issues.

Evaluate using PyTorch's FlashAttention
