
Conversation

bradleyhd

Summary:
In #26104, changes to layer.py resulted in always trying to switch to the FlashAttention (FA) backend for ViT, even when VLLM_ATTENTION_BACKEND is set.

This broke Meta's internal AMD pipelines, as it is neither desired nor expected behavior. With this change, the models modified in the offending PR can explicitly opt in to this behavior.

Differential Revision: D84946967
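
For context, a minimal sketch of what the opt-in could look like from a ViT call site. Only the function name and the try_switch_to_fa / try_use_upstream_fa parameters are taken from this PR; the import paths, the remaining argument, and the return handling are assumptions for illustration:

```python
# Hypothetical call-site sketch, not the actual vLLM code. Only
# maybe_get_vit_flash_attn_backend, try_switch_to_fa and try_use_upstream_fa
# come from this PR; import paths and everything else are assumptions.
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.platforms import _Backend  # assumed export; the enum may live elsewhere

selected_backend = _Backend.TORCH_SDPA  # e.g. whatever VLLM_ATTENTION_BACKEND resolved to
result = maybe_get_vit_flash_attn_backend(
    selected_backend,
    try_use_upstream_fa=False,
    try_switch_to_fa=True,  # explicit opt-in; omitting it keeps the configured backend
)
```

The point is the last keyword argument: the models changed in #26104 pass try_switch_to_fa=True, while everything else keeps the backend selected via VLLM_ATTENTION_BACKEND.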


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify bot added the qwen (Related to Qwen models) label on Oct 17, 2025
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request makes the FlashAttention backend upgrade for Vision Transformer (ViT) models an opt-in behavior, addressing an issue where it was unconditionally attempted, causing problems on AMD platforms. The change is implemented by introducing a try_switch_to_fa flag in maybe_get_vit_flash_attn_backend and updating the call sites in various models.

The overall approach is sound and correctly addresses the reported issue. However, I've identified a critical bug in the new implementation that could lead to crashes on platforms not supporting FlashAttention, like XPU. I've also pointed out a high-severity maintainability issue regarding the modification of function parameters, which could make the code harder to reason about. Addressing these points will improve the robustness and clarity of the code.

Comment on lines 91 to 92

```python
if try_switch_to_fa and not is_fa_backend(attn_backend):
    attn_backend = _Backend.FLASH_ATTN
```
Contributor


critical

The current logic unconditionally switches the backend to FLASH_ATTN if try_switch_to_fa is true. This can cause a crash on platforms that do not support FlashAttention, such as XPU, because the subsequent import of vllm.vllm_flash_attn will fail. The switch should be guarded to only occur on supported platforms (CUDA and ROCm).

Suggested change

```diff
-if try_switch_to_fa and not is_fa_backend(attn_backend):
-    attn_backend = _Backend.FLASH_ATTN
+if try_switch_to_fa and not is_fa_backend(attn_backend) and (
+        current_platform.is_cuda() or current_platform.is_rocm()):
+    attn_backend = _Backend.FLASH_ATTN
```

```python
if attn_backend == _Backend.FLASH_ATTN:
    # Always try upstream on ROCM.
    logger.info_once(
        "maybe_get_vit_flash_attn_backend: forcing upstream FlashAttn on ROCM."
    )
    try_use_upstream_fa = True
```
Contributor


high

Modifying the input parameter `try_use_upstream_fa` directly is confusing and can lead to unexpected side effects. It's better to use a local variable to track this state within the function. For example, you could introduce `use_upstream_fa = try_use_upstream_fa` at the beginning of the function and then modify and read `use_upstream_fa` instead.
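
A minimal sketch of that pattern, reusing the helper names visible in the diff above (is_fa_backend, current_platform, _Backend). This is illustrative only, not the actual layer.py implementation:

```python
# Illustrative sketch of the reviewer's suggestion; the signature mirrors the
# parameters named in this PR, and is_fa_backend / current_platform / _Backend
# are assumed to come from the surrounding module.
def maybe_get_vit_flash_attn_backend(attn_backend, try_use_upstream_fa,
                                     try_switch_to_fa=False):
    # Copy the argument into a local once; the caller's value is never mutated.
    use_upstream_fa = try_use_upstream_fa

    if try_switch_to_fa and not is_fa_backend(attn_backend) and (
            current_platform.is_cuda() or current_platform.is_rocm()):
        attn_backend = _Backend.FLASH_ATTN

    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        # Always try upstream FlashAttention on ROCm.
        use_upstream_fa = True

    ...  # rest of the resolution, reading use_upstream_fa from here on
```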


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@bradleyhd
Author

Updated to mimic #26104 as closely as possible so that this is an equivalent change. Not sure the behavior in the original PR is good / should be preserved, though.

@zhewenl added the rocm (Related to AMD ROCm), ci/build, and ci-failure (Issue about an unexpected test failure in CI) labels on Oct 17, 2025
@zhewenl
Collaborator

zhewenl commented Oct 17, 2025

This PR also fixes existing AMD failures (example):

```
(EngineCore_DP0 pid=50574)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layers/cross_attention.py", line 168, in __init__
(EngineCore_DP0 pid=50574)     super().__init__(
(EngineCore_DP0 pid=50574)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 236, in __init__
(EngineCore_DP0 pid=50574)     self.impl = impl_cls(
(EngineCore_DP0 pid=50574)                 ^^^^^^^^^
(EngineCore_DP0 pid=50574)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/triton_attn.py", line 248, in __init__
(EngineCore_DP0 pid=50574)     raise NotImplementedError(
(EngineCore_DP0 pid=50574) NotImplementedError: Encoder self-attention and encoder/decoder cross-attention are not implemented for TritonAttentionImpl
[rank0]:[W1017 04:54:49.728888986 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```

cc @Alexei-V-Ivanov-AMD
