From 3405c66c57c4bc866eede2c2648ee541ea287e30 Mon Sep 17 00:00:00 2001
From: Bradley Davis
Date: Fri, 17 Oct 2025 14:27:39 -0700
Subject: [PATCH] make flash_attn ViT upgrade opt-in (#27124)

Summary:
In https://github.com/vllm-project/vllm/pull/26104, some changes were made in
layer.py that resulted in always trying to switch to the FlashAttention (FA)
backend for ViT, even when `VLLM_ATTENTION_BACKEND` is set. This broke Meta's
internal AMD pipelines; the automatic switch is neither desired nor expected
behavior there.

With this change, the models that were touched in the offending PR explicitly
opt in to this behavior, while every other caller keeps the backend it
selected.

Differential Revision: D84946967
---
 vllm/attention/layer.py                    | 65 +++++++++++-----------
 vllm/model_executor/models/dots_ocr.py     |  3 +-
 vllm/model_executor/models/ernie45_vl.py   |  4 +-
 vllm/model_executor/models/glm4_1v.py      |  3 +-
 vllm/model_executor/models/qwen2_5_vl.py   |  5 +-
 vllm/model_executor/models/qwen2_vl.py     |  3 +-
 vllm/model_executor/models/siglip2navit.py |  3 +-
 7 files changed, 39 insertions(+), 47 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index e288770f2fcb..2e637e700620 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -65,7 +65,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
-def check_upstream_fa_availability(dtype: torch.dtype):
+def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
     if (
         dtype in (torch.float16, torch.bfloat16)
         and current_platform.is_cuda()
@@ -80,26 +80,40 @@ def check_upstream_fa_availability(dtype: torch.dtype):
         return find_spec("flash_attn") is not None
     return False
 
 
+def is_fa_backend(backend: _Backend) -> bool:
+    return backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}
+
+
 def maybe_get_vit_flash_attn_backend(
-    attn_backend: _Backend, use_upstream_fa: bool
+    attn_backend: _Backend,
+    try_switch_to_fa: bool = False,
+    force_upstream_fa: bool = False,
 ) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-    ):
+    upstream_fa_available = check_upstream_fa_availability(torch.get_default_dtype())
+    if force_upstream_fa:
+        assert upstream_fa_available, "Upstream FlashAttn is not available."
+
+    use_upstream_fa = force_upstream_fa
+    if try_switch_to_fa and not is_fa_backend(attn_backend) and upstream_fa_available:
         attn_backend = _Backend.FLASH_ATTN
+        logger.info_once(
+            "maybe_get_vit_flash_attn_backend: "
+            "auto-switching to upstream FlashAttn."
+        )
         use_upstream_fa = True
 
     if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
+        # ROCm always uses upstream FA for the FLASH_ATTN backend.
+        logger.info_once(
+            "maybe_get_vit_flash_attn_backend: "
+            "ROCm backend is FLASH_ATTN, forcing upstream FA."
+        )
         use_upstream_fa = True
 
-    if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+    if is_fa_backend(attn_backend):
         if attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
             if use_upstream_fa:
+                assert upstream_fa_available, (
+                    "Upstream FlashAttn is not available."
+                )
                 from flash_attn import flash_attn_varlen_func
             else:
                 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -108,7 +122,6 @@ def maybe_get_vit_flash_attn_backend(
 
     return attn_backend, flash_attn_varlen_func
 
-
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
@@ -428,11 +441,6 @@ def __init__(
         # Determine the attention backend
         backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
 
-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         if current_platform.is_xpu():
             # currently, only torch_sdpa is supported on xpu
             self.attn_backend = _Backend.TORCH_SDPA
@@ -450,30 +458,19 @@ def __init__(
             else _Backend.TORCH_SDPA
         )
 
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
+                try_switch_to_fa=False,
             )
         )
 
         if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
             self.attn_backend = _Backend.TORCH_SDPA
 
-        self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN,
-            _Backend.ROCM_AITER_FA,
-        }
-
-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
-            use_upstream_fa = True
+        self.is_flash_attn_backend = is_fa_backend(self.attn_backend)
 
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"MultiHeadAttention attn_backend: {self.attn_backend}"
         )
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index bd7f37b07de3..704a0c8d6a9c 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -290,12 +290,11 @@ def __init__(
         self.attn_backend = get_vit_attn_backend(
             self.hidden_size_per_attention_head, torch.get_default_dtype()
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
             )
         )
         if self.attn_backend not in {
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index e5badc0a28f6..5bb09cc5a8a0 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -198,12 +198,10 @@ def __init__(
             dtype=torch.get_default_dtype(),
         )
 
-        self.use_upstream_fa = False
-
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
             )
         )
 
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 132f26253b36..103a4dbbe678 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -288,12 +288,11 @@ def __init__(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
             )
         )
 
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index f05f130e1c44..2117e4956fe7 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -341,11 +341,12 @@ def __init__(
             disable_tp=use_data_parallel,
         )
         self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
+
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                force_upstream_fa=use_upstream_fa,
             )
         )
         self.is_flash_attn_backend = self.attn_backend in {
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 821a9d13dc6f..33f957c86797 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -356,12 +356,11 @@ def __init__(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
             )
         )
 
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index e7af0e7a7ae4..b7b1dbcd13b1 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -250,12 +250,11 @@ def __init__(
         self.attn_backend = get_vit_attn_backend(
             head_size=self.head_dim, dtype=torch.get_default_dtype()
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
             )
         )
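
Reviewer note (illustrative, not part of the diff): the sketch below shows the
call contract the new keyword arguments establish. `resolve_vit_attention` and
`opt_in_to_fa_upgrade` are hypothetical names introduced only for this example;
just `maybe_get_vit_flash_attn_backend` and its arguments come from the patch
above.

```python
# Sketch only, assuming the signature added in vllm/attention/layer.py above.
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


def resolve_vit_attention(attn_backend, opt_in_to_fa_upgrade: bool = False):
    """Resolve a ViT attention backend under the new opt-in semantics.

    attn_backend: the backend already chosen upstream (e.g. honoring
        VLLM_ATTENTION_BACKEND).
    opt_in_to_fa_upgrade: False (the new default) keeps the chosen backend
        untouched; True allows the auto-upgrade to upstream FlashAttention
        when it is importable, which is what the ViT models in this patch
        now request explicitly via try_switch_to_fa=True.
    """
    backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        try_switch_to_fa=opt_in_to_fa_upgrade,
    )
    return backend, flash_attn_varlen_func
```

With `try_switch_to_fa=False` the helper only resolves `flash_attn_varlen_func`
for callers that already picked an FA backend; with `try_switch_to_fa=True` the
auto-upgrade introduced in #26104 is preserved for that caller only, and
`force_upstream_fa=True` (passed by qwen2_5_vl) additionally forces the upstream
`flash_attn` package.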