JAX Flash Attention — JAX bindings for Flash Attention v4 (CuTeDSL) on NVIDIA Blackwell GPUs.
Wraps the CuTeDSL kernels from flash-attn and exposes them as a jax.custom_vjp-decorated function with forward + backward support.
- NVIDIA Blackwell GPU (SM100 / SM120)
- nvidia-cutlass-dsl >= 4.2.0
- jax-tvm-ffi
- flash-attn (with flash_attn.cute modules)
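pip install jafa

import jax.numpy as jnp
from jafa import flash_attention
# Illustrative sizes only; substitute your own.
batch, q_len, kv_len, num_heads, head_dim = 2, 128, 128, 8, 128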
q = jnp.ones((batch, q_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jnp.ones((batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jnp.ones((batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
# Dense path
out = flash_attention(q, k, v)
# Varlen path via right-padding mask (..., heads, q_len, kv_len)
out = flash_attention(q, k, v, mask=mask)

Kernels are compiled via cute.compile on first call for each (dtype, head_dim, head_dim_v, qhead_per_kvhead, arch, varlen) combination and cached in-process. Call register_flash_attn_ops(...) ahead of time to pre-warm.
The mask argument is treated as a right-padding mask only: it is
converted to cu_seqlens via jnp.any(mask, axis=...). Causal,
sliding-window, or arbitrary masks will be silently miscomputed. For
those patterns, pass mask=None and handle the masking yourself, or
extend register_flash_attn_ops with is_causal=True etc.
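To see why only right-padding masks survive this conversion, here is a minimal sketch of one way a mask could be reduced to cumulative sequence lengths. It is an illustration under stated assumptions, not jafa's implementation: the function name `padding_mask_to_cu_seqlens`, the fixed (batch, heads, q_len, kv_len) layout, the choice of reduction axes, and the use of head 0 are all assumptions. The NumPy fallback exists only so the sketch runs without JAX installed.

```python
try:
    import jax.numpy as xp
except ImportError:  # same function names, so NumPy works for this sketch
    import numpy as xp


def padding_mask_to_cu_seqlens(mask):
    """Reduce a right-padding mask to cumulative sequence lengths.

    mask: boolean array, assumed shape (batch, heads, q_len, kv_len).
    Returns (cu_seqlens_q, cu_seqlens_k), each of shape (batch + 1,).
    """
    m = mask[:, 0]                 # assumes padding is identical across heads
    q_valid = xp.any(m, axis=-1)   # (batch, q_len): query sees at least one key
    kv_valid = xp.any(m, axis=-2)  # (batch, kv_len): key is seen by some query
    q_lens = xp.sum(q_valid, axis=-1).astype(xp.int32)
    kv_lens = xp.sum(kv_valid, axis=-1).astype(xp.int32)
    zero = xp.zeros((1,), dtype=xp.int32)
    cu_q = xp.concatenate([zero, xp.cumsum(q_lens).astype(xp.int32)])
    cu_k = xp.concatenate([zero, xp.cumsum(kv_lens).astype(xp.int32)])
    return cu_q, cu_k


# A causal (lower-triangular) mask has a True in every row and every column,
# so the reduction sees a full-length sequence and the causal structure is lost.
causal = xp.tril(xp.ones((4, 4), dtype=bool))[None, None]
cu_q, cu_k = padding_mask_to_cu_seqlens(causal)
print(cu_q.tolist())  # [0, 4]
```

Any mask whose rows and columns each contain at least one True collapses to the dense case in this sketch, which matches the silent miscomputation described above.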
- Unpin the upstream stack. pyproject.toml currently hard-pins jax==0.9.2, nvidia-cutlass-dsl[cu13]==4.5.0.dev0, jax-tvm-ffi==0.1.2, and flash-attn-4==4.0.0b8. This is the only combination found to work end-to-end: newer cutlass-dsl releases changed the nvvm.atomicrmw signature in a way that breaks flash_bwd, and the [cu13] extra is required because its _nvvm_ops_gen.py overrides the libs-base version with the signature the in-source monkey-patch expects. Revisit once flash-attn-4 ships a release compatible with cutlass-dsl 4.5.0 GA (or later) and we can drop both pins and the monkey-patch.
- Drop the torch dependency. flash-attn-4 lists torch as a runtime requirement and pulls in torch==2.11.0 + the full CUDA 13 wheel stack (~3 GB). On jafa's call path torch is only incidentally loaded: we import assume_tensor_aligned from flash_attn.cute.cute_dsl_utils, and that module does import torch at the top to serve unrelated helpers (get_device_capacity, to_cute_tensor, get_broadcast_dims). The kernel classes themselves operate on cute.Tensor and are torch-free. Plan: vendor assume_tensor_aligned (5 lines) and drop the flash_attn.cute.cute_dsl_utils import, then override flash-attn-4's declared deps so torch isn't installed. Saves several GB and a long install.
- Expose causal / sliding-window / local masks. Currently every kernel is registered with is_causal=False, is_local=False. The mask argument is a right-padding mask only; arbitrary attention patterns are silently miscomputed. Add a kwarg that selects the kernel variant at registration time.
- fp8 support. flash-attn-4 supports Float8E4M3FN / Float8E5M2, but jafa's _JAX_TO_CUTLASS table only maps fp16/bf16/fp32 and JAX's fp8 DLPack story is still rough; this needs investigation.
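As a rough picture of the fp8 item, the sketch below shows a dtype lookup table that rejects fp8 until entries are added. The layout of jafa's real _JAX_TO_CUTLASS table is not shown in this README, so `SUPPORTED`, `FP8_CANDIDATES`, and `to_cutlass_name` are hypothetical, and plain strings stand in for the JAX and CUTLASS type objects.

```python
# Hypothetical stand-in for a JAX-to-CUTLASS dtype table (strings only).
SUPPORTED = {
    "float16": "Float16",
    "bfloat16": "BFloat16",
    "float32": "Float32",
}

# Entries the fp8 item would add; type names taken from the text above.
FP8_CANDIDATES = {
    "float8_e4m3fn": "Float8E4M3FN",
    "float8_e5m2": "Float8E5M2",
}


def to_cutlass_name(dtype_name, table=SUPPORTED):
    """Look up the CUTLASS type name, failing loudly for unmapped dtypes."""
    try:
        return table[dtype_name]
    except KeyError:
        raise TypeError(
            f"unsupported dtype {dtype_name!r}; supported: {sorted(table)}"
        ) from None


print(to_cutlass_name("bfloat16"))  # BFloat16

# With the fp8 rows merged in, the same lookup succeeds.
extended = {**SUPPORTED, **FP8_CANDIDATES}
print(to_cutlass_name("float8_e4m3fn", extended))  # Float8E4M3FN
```

Extending the table is only the first step in this sketch; the DLPack hand-off for fp8 arrays noted above would still need to be checked separately.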
MIT