
Conversation

@Daisy-Ma-coder (Contributor) commented Oct 17, 2025

Bugfix for Flash Attention MLA with full CUDA graph: illegal memory access (IMA) following PR #25490

Ran into an illegal memory access error when testing some prompts with prefix caching enabled on the Flash Attention MLA backend.

The log below was generated with CUDA_LAUNCH_BLOCKING=1, which indicates the failure is in Flash Attention MLA.

INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:CUDA error (../../.deps/vllm-flash-attn-src/hopper/flash_fwd_combine_launch_template.h:60): an illegal memory access was encountered
INFO:/scripts/vllm_scripts/utils.py:[1;36m(EngineCore_0 pid=481)[0;0m ERROR 10-13 10:51:40 [multiproc_executor.py:146] Worker proc VllmWorker-5 died unexpectedly, shutting down executor.
...

It turned out to be the same root cause as #25490: get_scheduler_metadata was being called with a different max_num_splits than the one passed to FlashAttnMLAMetadata.
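For context, a minimal Python sketch of the fixed pattern (the names below, such as fake_get_scheduler_metadata and FakeMLAMetadata, are illustrative stand-ins, not the actual vLLM code): decide max_num_splits once, up front, and hand that single value to both the scheduler-metadata call and the attention metadata object, so the kernel never sizes its split buffers from one value while being told another at launch time.

from dataclasses import dataclass

# Hypothetical stand-ins for the real vLLM pieces; names and signatures
# here are illustrative only, not the actual vLLM API.

@dataclass
class FakeMLAMetadata:
    scheduler_metadata: dict
    max_num_splits: int

def fake_get_scheduler_metadata(num_reqs: int, max_num_splits: int) -> dict:
    # The scheduler metadata sizes the kernel's intermediate split buffers
    # from this value, so it must match what the attention metadata reports.
    return {"num_reqs": num_reqs, "max_num_splits": max_num_splits}

def build_metadata(num_reqs: int, use_full_cudagraph: bool,
                   cudagraph_max_splits: int) -> FakeMLAMetadata:
    # Fixed pattern: compute max_num_splits once, before it is consumed,
    # then pass the same value to both consumers. The bug was deciding it
    # in two places that could disagree when full CUDA graphs were enabled.
    max_num_splits = cudagraph_max_splits if use_full_cudagraph else 0
    sched = fake_get_scheduler_metadata(num_reqs, max_num_splits=max_num_splits)
    return FakeMLAMetadata(scheduler_metadata=sched, max_num_splits=max_num_splits)

md = build_metadata(num_reqs=8, use_full_cudagraph=True, cudagraph_max_splits=16)
assert md.scheduler_metadata["max_num_splits"] == md.max_num_splits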

@gemini-code-assist (bot) left a comment

Code Review

This pull request aims to fix an illegal memory access error in Flash Attention MLA with full CUDA graph support by ensuring get_scheduler_metadata and FlashAttnMLAMetadata receive the same max_num_splits value. The changes correctly refactor the logic to calculate max_num_splits before it's used. However, I've identified a remaining logic issue where a similar discrepancy can occur when vllm_is_batch_invariant() is true, which could lead to the same bug under different conditions. I've provided a suggestion to fully resolve this.
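To illustrate the reviewer's remaining concern, here is a hedged sketch (the helper below is a hypothetical stand-in, not the actual vLLM code): if batch-invariant mode overrides the split count, that override has to be applied once, before either consumer reads the value; otherwise the same scheduler-metadata/FlashAttnMLAMetadata mismatch can reappear under that mode.

# Hypothetical sketch of the remaining concern; not the actual vLLM code.
def fake_is_batch_invariant() -> bool:
    # Stand-in for the batch-invariance feature flag.
    return True

def resolve_max_num_splits(requested: int) -> int:
    # Apply the batch-invariant override once, up front, so that both the
    # scheduler-metadata call and the attention metadata object see the
    # identical value instead of one applying the override independently.
    return 1 if fake_is_batch_invariant() else requested

max_num_splits = resolve_max_num_splits(16)
assert max_num_splits == 1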

qqma added 2 commits October 17, 2025 15:36
Signed-off-by: qqma <[email protected]>
Signed-off-by: qqma <[email protected]>
@LucasWilkinson (Collaborator) left a comment

LGTM; thanks!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) October 17, 2025 23:00
@github-actions bot added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Oct 17, 2025
@Daisy-Ma-coder (Contributor, Author) commented
Seems like the failed tests are unrelated; is it fine to still merge it?

@LucasWilkinson LucasWilkinson merged commit 5beacce into vllm-project:main Oct 22, 2025
47 checks passed

Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1