diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..c6045d33907a 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,7 +17,7 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -83,12 +83,17 @@
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
+    sage_attn_func_hub = sage_interface_hub.sageattn
+
 else:
     flash_attn_3_func_hub = None
+    sage_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -162,10 +167,6 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-
 
 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -190,6 +191,7 @@ class AttentionBackendName(str, Enum):
 
     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
@@ -404,14 +406,14 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             )
 
     # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.SAGE_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Attention backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Attention backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [
@@ -1756,6 +1758,31 @@ def _sage_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    if _parallel_config is None:
+        out = sage_attn_func_hub(q=query, k=key, v=value)
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972fb7..61201b847b74 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -6,18 +6,25 @@
 
 
 _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
+_KERNEL_REVISION = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
+    _DEFAULT_HUB_ID_SAGE: None,
+}
 
 
-def _get_fa3_from_hub():
+def _get_kernel_from_hub(kernel_id):
     if not is_kernels_available():
         return None
     else:
         from kernels import get_kernel
 
         try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
+            if kernel_id not in _KERNEL_REVISION:
+                raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
+            kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
+            return kernel_hub
         except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
             raise
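
Not part of the diff: a minimal usage sketch of the new backend. It assumes the `set_attention_backend` helper exposed on diffusers transformer models and uses `FluxPipeline` purely as an example checkpoint; the "sage_hub" string comes from the `AttentionBackendName.SAGE_HUB` entry added above.

import os

# Hub-kernel backends are opt-in, as enforced by _check_attention_backend_requirements.
# The env var must be set before diffusers is imported, since it is read at import time.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline

# Example model only; any diffusers transformer that routes attention through
# attention_dispatch should work. bfloat16 satisfies _check_qkv_dtype_bf16_or_fp16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Assumed helper: set_attention_backend takes the registered backend name.
pipe.transformer.set_attention_backend("sage_hub")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sage_hub.png")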