65 changes: 31 additions & 34 deletions vllm/attention/layer.py
@@ -65,7 +65,7 @@ def check_xformers_availability():
return USE_XFORMERS_OPS


def check_upstream_fa_availability(dtype: torch.dtype):
def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
if (
dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
@@ -80,26 +80,40 @@ def check_upstream_fa_availability(dtype: torch.dtype):
return find_spec("flash_attn") is not None
return False

def is_fa_backend(backend: _Backend) -> bool:
return backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}

def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend, use_upstream_fa: bool
) -> tuple[_Backend, Callable]:
if (
attn_backend != _Backend.FLASH_ATTN
and attn_backend != _Backend.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype())
):
attn_backend: _Backend,
try_switch_to_fa: bool = False,
force_upstream_fa: bool = False) -> tuple[_Backend, Callable]:

upstream_fa_available = check_upstream_fa_availability(torch.get_default_dtype())
if force_upstream_fa:
assert upstream_fa_available, \
"Upstream FlashAttn is not available."

use_upstream_fa = force_upstream_fa
if try_switch_to_fa and not is_fa_backend(attn_backend) and upstream_fa_available:
attn_backend = _Backend.FLASH_ATTN
logger.info_once("maybe_get_vit_flash_attn_backend: ", \
"auto-switching to upstream FlashAttn.")
use_upstream_fa = True

if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:

if current_platform.is_rocm() and \
attn_backend == _Backend.FLASH_ATTN:
# Always use upstream FA on ROCm.
logger.info_once("maybe_get_vit_flash_attn_backend: ", \
"ROCM backend is now FLASH_ATTN, forcing upstream FA.")
use_upstream_fa = True

if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
if is_fa_backend(attn_backend):
if attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
assert upstream_fa_available, \
"Upstream FlashAttn is not available."
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -108,7 +122,6 @@ def maybe_get_vit_flash_attn_backend(

return attn_backend, flash_attn_varlen_func


class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.

@@ -428,11 +441,6 @@ def __init__(
# Determine the attention backend
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False

if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
@@ -450,30 +458,19 @@ def __init__(
else _Backend.TORCH_SDPA
)

self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend, self._flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
try_switch_to_fa=False,
)
)

if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
self.attn_backend = _Backend.TORCH_SDPA

self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
}

# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True
self.is_flash_attn_backend = is_fa_backend(self.attn_backend)

logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
f"use_upstream_fa: {use_upstream_fa}"
)
f"MultiHeadAttention attn_backend: {self.attn_backend}")

def forward(
self,
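The helper above centralizes the ViT attention-backend choice. Below is a minimal, standalone sketch of its decision order under the assumptions visible in this diff; Backend, pick_vit_fa_backend, and the upstream_fa_available / is_rocm flags are illustrative stand-ins, not vLLM names.

# Standalone sketch of the selection rules introduced above. The enum and the
# availability/platform flags are mocked; only the control flow is taken from
# the diff (maybe_get_vit_flash_attn_backend / is_fa_backend).
from enum import Enum, auto


class Backend(Enum):  # stand-in for vllm's _Backend
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()


def pick_vit_fa_backend(
    backend: Backend,
    try_switch_to_fa: bool = False,
    force_upstream_fa: bool = False,
    upstream_fa_available: bool = False,
    is_rocm: bool = False,
) -> tuple[Backend, bool]:
    """Return (backend, use_upstream_fa), mirroring the diff's decision order."""
    if force_upstream_fa:
        # The real helper asserts that upstream flash_attn is importable.
        assert upstream_fa_available, "Upstream FlashAttn is not available."
    use_upstream_fa = force_upstream_fa

    is_fa = backend in {Backend.FLASH_ATTN, Backend.ROCM_AITER_FA}
    if try_switch_to_fa and not is_fa and upstream_fa_available:
        backend = Backend.FLASH_ATTN  # auto-switch to FlashAttention
        use_upstream_fa = True

    if is_rocm and backend == Backend.FLASH_ATTN:
        use_upstream_fa = True  # ROCm always uses the upstream flash_attn package

    return backend, use_upstream_fa


# TORCH_SDPA opts into the upgrade and upstream FA is installed -> FLASH_ATTN.
print(pick_vit_fa_backend(Backend.TORCH_SDPA,
                          try_switch_to_fa=True,
                          upstream_fa_available=True))

With try_switch_to_fa=True, a non-FA backend is upgraded only when upstream flash_attn is importable; force_upstream_fa=True instead asserts that it is.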
3 changes: 1 addition & 2 deletions vllm/model_executor/models/dots_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,11 @@ def __init__(
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head, torch.get_default_dtype()
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
try_switch_to_fa=True,
)
)
if self.attn_backend not in {
4 changes: 1 addition & 3 deletions vllm/model_executor/models/ernie45_vl.py
@@ -198,12 +198,10 @@ def __init__(
dtype=torch.get_default_dtype(),
)

self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
try_switch_to_fa=True,
)
)

3 changes: 1 addition & 2 deletions vllm/model_executor/models/glm4_1v.py
@@ -288,12 +288,11 @@ def __init__(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
try_switch_to_fa=True,
)
)

5 changes: 3 additions & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -341,11 +341,12 @@ def __init__(
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
try_switch_to_fa=True,
force_upstream_fa=use_upstream_fa,
)
)
self.is_flash_attn_backend = self.attn_backend in {
3 changes: 1 addition & 2 deletions vllm/model_executor/models/qwen2_vl.py
@@ -356,12 +356,11 @@ def __init__(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
try_switch_to_fa=True,
)
)

3 changes: 1 addition & 2 deletions vllm/model_executor/models/siglip2navit.py
@@ -250,12 +250,11 @@ def __init__(
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype()
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
try_switch_to_fa=True,
)
)

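Across the model files above, the call-site change has the same shape; a hedged before/after sketch follows, with the helper stubbed out so the snippet runs on its own (the stub body and the string backend value are placeholders, not vLLM code).

# Before/after of the model-side call, as seen in the diffs above. The helper
# is stubbed here purely so the example is self-contained.
def maybe_get_vit_flash_attn_backend(attn_backend,
                                     try_switch_to_fa=False,
                                     force_upstream_fa=False):
    return attn_backend, None  # stub for vllm.attention.layer's helper


attn_backend = "TORCH_SDPA"  # stand-in for a _Backend member held by the model

# Old call (removed by this PR):
#   attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
#       attn_backend, use_upstream_fa)

# New call: vision models opt into the auto-upgrade; qwen2_5_vl.py additionally
# forwards its existing flag as force_upstream_fa=use_upstream_fa.
attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
    attn_backend,
    try_switch_to_fa=True,
)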