feat: added low VRAM flash attention backend #314
base: main
Conversation
How much performance improvement does flex attention offer in comparison?

I ran a comparison on 8xH200 and added it to the benchmarks section. I saw a slight improvement over flex attention in both VRAM usage and speed (25% faster). We could probably push it further by supporting [...]. I was not able to get flex-attention to compile on B200, which was one of the core motivations for this feature.
Force-pushed from 0279e03 to c9dbc1f.
Thanks! I was not able to use flex-attention on B200 either. In the meantime, could you run pre-commit on your code?
Force-pushed from d660579 to 0e0bba9.
There is still a conflict with the main branch.
torch.manual_seed(0)

def assert_similar(ref, out):
It looks like there are more significant numeric differences between the two approaches. How much of a difference is there if we do a point-to-point comparison between ref and out?
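(For reference, a point-to-point comparison of the kind asked about here could look like the sketch below; `ref` and `out` stand for the SDPA and flash-attention outputs from the test, and the snippet is illustrative rather than anyone's actual code.)

```python
# ref, out: same-shaped attention outputs from the SDPA reference and the flash backend
diff = (ref.float() - out.float()).abs()
rel = diff / ref.float().abs().clamp_min(1e-6)
print(f"max abs: {diff.max().item():.3e}  "
      f"mean abs: {diff.mean().item():.3e}  "
      f"max rel: {rel.max().item():.3e}")
```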
I trained qwen2.5-vl-7B-eagle3 using the latest specforge 0.1.1 and sglang 0.5.5, and encountered "AttributeError: 'Qwen2_5_VLForConditionalGeneration' object has no attribute 'set_aux_hidden_states_layers'". I didn't have this issue with the version before the fix. What could be the reason?
@Abigbigbig This looks like a different issue from this PR. Let's move it to a separate issue. I can point you to the fix.
Motivation
Both existing attention backends exhibit inefficiencies that hurt the training experience:

- The `sdpa` backend materializes the full `bsz x num_heads x q_len x kv_len` attention score matrix in VRAM, severely limiting the maximum sequence length (for example, with 32 heads at a 32K sequence length in bf16, that matrix alone is about 64 GiB per sample).
- The `flex_attention` backend is very particular about the Linux environment and often requires different compilation flags depending on package versions. We were not able to get this kernel to compile reliably on `torch==2.8.0`.

Using a log-sum-exp trick, we can avoid materializing any attention matrix while handling the TTT KV cache with very minimal overhead. We support this via the flash attention backend, since it readily provides an LSE tensor alongside the O tensor. Flash Attention 4 is also SOTA for training on Blackwell, and while porting FA4 is out of scope for this PR, supporting the flash attention interface is a first step.
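For illustration, here is a minimal sketch of the log-sum-exp merge described above. It is not the code in this PR; the function name, tensor shapes, and layout are assumptions, and the real backend obtains the partial O and LSE tensors directly from the flash attention kernel rather than computing them in PyTorch.

```python
import torch

def merge_attention(o_a, lse_a, o_b, lse_b):
    """Combine attention computed independently over two disjoint KV segments.

    o_*   : [bsz, num_heads, q_len, head_dim] partial attention outputs
    lse_* : [bsz, num_heads, q_len] log-sum-exp of the attention logits,
            as flash attention returns it alongside the O tensor
    """
    lse = torch.logaddexp(lse_a, lse_b)           # normalizer over both segments
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)    # softmax mass falling in segment A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)    # softmax mass falling in segment B
    return w_a * o_a + w_b * o_b, lse             # exact full-softmax output and its LSE
```

Because only the per-query LSE vector is carried between segments, nothing of size `q_len x kv_len` is ever materialized.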
Modifications
- Added a new `LlamaFlashAttention` module which has the same API as `LlamaAttention` (using a manual hidden cache). Within the forward pass, we combine the per-segment flash attention outputs through their LSE tensors, as described in the Motivation, so that no attention score matrix is materialized.
- Added a test file `test_flash_attention.py` which verifies equivalence with the SDPA backend (up to bf16 numerical stability).

Related Issues
Accuracy Test
Ran `python -m tests.test_utils.test_flash_attention`:
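(The test's console output is not reproduced here. As a rough sketch, the check compares the flash-attention output against the SDPA reference with bf16-scale tolerances, along the lines below; the helper name matches the diff snippet above, but the tolerance values are illustrative assumptions rather than the ones used in the test.)

```python
import torch

def assert_similar(ref, out, rtol=1.6e-2, atol=1e-3):
    # bf16 keeps only ~8 mantissa bits, so tolerances are necessarily loose;
    # these thresholds are placeholders, not the test's actual values.
    torch.testing.assert_close(out.float(), ref.float(), rtol=rtol, atol=atol)
```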
Benchmark & Profiling
Trained a speculator on custom data for GLM 4.5 on 8xH200 with a per-GPU batch size of 1 and a sequence length of 32K. Here are the performance comparisons against flex attention:
We also trained for one epoch on perfectblend and achieved an accept length of 3 on GSM8K with a chain spec of 3 steps.
GLM 4.5 support was added in a custom branch built on top of this PR here.
Checklist