scaled_dot_product_attention support for additional Torch backends (Flash Attention)
#365
base: main
Conversation
…h NestedTensor
…ion backends. Added pytests that test each attention backend and compatible data type
…offsets instead of joffsets.
Pull request overview
This PR enables support for additional PyTorch attention backends (Flash Attention, Efficient Attention, Math) in the scaled_dot_product_attention function by wrapping JaggedTensor data in PyTorch nested tensors. The implementation intelligently selects between zero-copy views and tensor copies based on backend requirements to optimize performance.
Key Changes:
- Added backend-aware nested tensor creation logic in C++ that chooses between zero-copy views (make_nested_view) and tensor copies (make_nested_tensor) based on the selected attention backend (a rough Python sketch of this decision follows the list)
- Expanded test coverage to parameterize tests across all supported backends (Flash, Efficient, Math) with their compatible data types
- Adjusted dimension handling in tests to accommodate Flash Attention's requirement that q, k, v have matching last dimensions
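The following is a minimal, hypothetical Python sketch of that view-versus-copy decision, not fvdb's actual API: the real helpers (make_nested_view, make_nested_tensor) live in C++ in src/fvdb/FVDB.cpp, and which backends accept zero-copy views is an assumption here rather than something stated in this PR.

```python
import torch

def make_nested_for_sdpa(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: `values` is the flat jagged data buffer and
    # `offsets` the per-sequence start indices (int64, same device as values).
    # Assumption: the flash / memory-efficient CUDA backends can consume a
    # jagged nested-tensor *view* over the existing buffer, while the math
    # fallback gets a packed copy.
    if torch.backends.cuda.flash_sdp_enabled() or torch.backends.cuda.mem_efficient_sdp_enabled():
        # Zero-copy path: wrap the existing buffer using the jagged offsets.
        return torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
    # Copy path: split the flat buffer into per-sequence tensors and repack them.
    lengths = (offsets[1:] - offsets[:-1]).tolist()
    return torch.nested.nested_tensor(list(values.split(lengths, dim=0)))
```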
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/unit/test_jagged_tensor.py | Parameterized SDPA tests to cover Flash, Efficient, and Math backends with appropriate dtypes; added Flash Attention dimension constraint |
| src/fvdb/FVDB.cpp | Implemented backend-aware nested tensor creation with two helper functions and runtime backend detection logic to optimize tensor creation |
… wrong creation function Use enums instead of ints
blackencino left a comment
I have three big questions after going through this more carefully.
- Why do we need to do any of it in C++? Could we do the same thing in Python, within the JaggedTensor frontend, and work towards the goal of having less and less C++ code where we can avoid it?
- We're deferring to the torch backend specification mechanism as the way users would go about selecting the attention machinery, which aligns with how they'd use torch (a minimal example follows this list). I like that the default torch behavior is to try to choose the best backend based on your usage. Does fVDB need or want to have any stronger opinions about which backends to use? I think we probably don't, but I get cautious around API decisions that expose expert-level algorithm internals.
- Our tests seem mostly like smoke tests at this point; we don't have a validation against what SDPA is supposed to mean, or any conceptual explanation. Plus, we're permuting inputs and outputs when comparing to torch, which makes it hard to say that what we're computing is what we expect to see. We should be wrapping the SDPA API so that we consume and produce the same dimensional ordering as torch, except where that's impossible.
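For reference on the second point, this is roughly what torch's backend selection mechanism looks like from the caller's side: a sketch based on the PyTorch docs, using dense placeholder tensors; nothing here is fvdb-specific.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Default behavior: torch picks the best available backend for these inputs.
out_default = F.scaled_dot_product_attention(q, k, v)

# Explicit behavior: the caller restricts SDPA to a set of allowed backends
# for the duration of this block.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
    out_restricted = F.scaled_dot_product_attention(q, k, v)
```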
```python
# Torch -- For-loop approach (always use MATH for reference to ensure consistency)
out_jagged_torch_forloop_list = []
for b in range(batch_size):
    # From LHE to NHLE / SHV to NHSV
```
LHE, NHLE, SHV, NHSV - these are not meaningful abbreviations. This is potentially mission-critical code; we should be explaining our assumptions and what we're actually testing against.
Permuting our data to work with torch APIs, or to test the results, is something we worked hard to get rid of in the convolution code. Is there a way we can wrap our use of attention such that the dimensional ordering in and out matches what torch would consume or produce? These permute lines look very much like what we've just gotten rid of in convolution.
We absolutely could do this in Python. Is having less and less C++ code the goal? For the longest time the goal was to have as thin a Python layer as we could for binding and to keep all the meaningful logic in C++. One potential use-case of having this in C++ is portability to other systems (such as inference systems like ONNX) where perhaps the user needs a C++ equivalent to
On that one I don't really know if we have any stronger opinions; perhaps @heiwang1997 would have thoughts. I think keeping backend choice consistent with Torch is probably a good goal… if people are using the same configurations in a network that uses our operator and Torch's for different things, I think they'd expect to see the same attention backend selection used.
I looked around PyTorch's tests for their operator; they largely compare results for random inputs against reference implementations. For example, here they have a reference SDPA function implemented entirely from basic PyTorch operations and compare its outputs to the SDPA operators: https://github.com/pytorch/pytorch/blob/main/test/test_transformers.py#L1122 Not saying we shouldn't do as you suggest, just providing info on how PyTorch validates these operators.
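For concreteness, a reference SDPA in that style (basic PyTorch ops only, dense [batch, heads, seq, head_dim] inputs, no masking or dropout) is something like:

```python
import math
import torch

def sdpa_reference(q, k, v, scale=None):
    # softmax(q @ k^T / sqrt(d)) @ v, computed with plain tensor ops so the
    # fused backends have something simple to be compared against.
    scale = 1.0 / math.sqrt(q.size(-1)) if scale is None else scale
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v
```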
This PR makes it possible to use other available Torch attention backends (flash, efficient, math) by wrapping the input JaggedTensors' data in Torch nested Tensors. The mechanism for selecting these other backends is the same as for Torch's built-in attention, sdpa_kernel (see examples at https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).

Some backends allow us to just take views of the JaggedTensor data to create the nested Tensor, while others require copies; this PR tries to be smart by checking which backend PyTorch is configured to use and allows backends that can run on nested Tensors created from views to do so.

Also changed the SDPA tests to run all the backends with their available data types.
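As a usage sketch of the mechanism described above, written with plain torch calls rather than fvdb's entry point (the per-sequence packing and the flash backend's support for jagged layout are assumptions on my part):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Three sequences of different lengths, each [seq_len, heads, head_dim].
lengths = [5, 9, 3]

def make_jagged():
    parts = [torch.randn(n, 8, 64, device="cuda", dtype=torch.float16) for n in lengths]
    # Pack into a jagged-layout nested tensor and move heads before the
    # ragged sequence dim: [batch, heads, seq*, head_dim].
    return torch.nested.nested_tensor(parts, layout=torch.jagged).transpose(1, 2)

q, k, v = make_jagged(), make_jagged(), make_jagged()

# Backend selection works exactly as with dense tensors.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)  # nested output, same layout as q
```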
fixes #363