giovannic/jafa

jafa

JAX Flash Attention — JAX bindings for Flash Attention v4 (CuTeDSL) on NVIDIA Blackwell GPUs.

Wraps the CuTeDSL kernels from flash-attn and exposes them as a jax.custom_vjp-decorated function with forward + backward support.

Requirements

  • NVIDIA Blackwell GPU (SM100 / SM120)
  • nvidia-cutlass-dsl >= 4.2.0
  • jax-tvm-ffi
  • flash-attn (with flash_attn.cute modules)

Install

pip install jafa

Usage

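import jax.numpy as jnp
from jafa import flash_attention

# Illustrative sizes
batch, q_len, kv_len, num_heads, head_dim = 2, 128, 128, 8, 64

q = jnp.ones((batch, q_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jnp.ones((batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jnp.ones((batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)

# Dense path
out = flash_attention(q, k, v)

# Varlen path via right-padding mask (..., heads, q_len, kv_len).
# True marks real tokens; each sequence's padding sits at the end.
lens = jnp.array([128, 96])
valid = jnp.arange(kv_len) < lens[:, None]  # (batch, kv_len)
# Self-attention (q_len == kv_len): a pair is valid if both tokens are real.
mask = valid[:, None, :, None] & valid[:, None, None, :]
mask = jnp.broadcast_to(mask, (batch, num_heads, q_len, kv_len))
out = flash_attention(q, k, v, mask=mask)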

Kernels are compiled via cute.compile on first call for each (dtype, head_dim, head_dim_v, qhead_per_kvhead, arch, varlen) combination and cached in-process. Call register_flash_attn_ops(...) ahead of time to pre-warm.
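The per-combination caching described above can be sketched with `functools.lru_cache`. This is a minimal illustration, not jafa's code. `KernelKey`, `kernel_key`, `get_kernel` and `compile_log` are hypothetical names, `arch` is assumed to be an integer such as 100 for SM100, and the string return value stands in for the object `cute.compile` would produce.

```python
import functools
from typing import NamedTuple


class KernelKey(NamedTuple):
    # One compiled kernel per distinct combination of these fields.
    dtype: str
    head_dim: int
    head_dim_v: int
    qhead_per_kvhead: int
    arch: int
    varlen: bool


def kernel_key(q_shape, k_shape, v_shape, dtype, arch, varlen):
    """Derive the cache key from (batch, seq, heads, dim) shapes."""
    num_q_heads, head_dim = q_shape[2], q_shape[3]
    num_kv_heads, head_dim_v = k_shape[2], v_shape[3]
    if num_q_heads % num_kv_heads:
        raise ValueError("q heads must be a multiple of kv heads")
    return KernelKey(str(dtype), head_dim, head_dim_v,
                     num_q_heads // num_kv_heads, arch, bool(varlen))


compile_log = []  # records each key that triggered a (fake) compile


@functools.lru_cache(maxsize=None)
def get_kernel(key):
    # Stand-in for the expensive cute.compile(...) call: runs once per key.
    compile_log.append(key)
    return f"kernel<{key.dtype},d={key.head_dim},varlen={key.varlen}>"
```

Batch size and sequence lengths are not part of the key in this sketch, so changing them reuses an already compiled kernel. Pre-warming then amounts to calling `get_kernel` for every key you expect to hit.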

A note on masks

The mask argument is treated as a right-padding mask only: it's converted to cu_seqlens via jnp.any(mask, axis=...). Causal, sliding-window, or arbitrary masks will be silently miscomputed — pass mask=None and handle those yourself, or extend register_flash_attn_ops with is_causal=True etc.
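The mask-to-`cu_seqlens` conversion, and why it breaks for non-padding masks, can be sketched as follows. The exact axes jafa reduces over are not spelled out above, so the reduction axes, the `(batch, heads, q_len, kv_len)` layout and the name `mask_to_cu_seqlens` are assumptions. The output follows the usual flash-attn varlen layout `[0, len0, len0 + len1, ...]`.

```python
import jax.numpy as jnp


def mask_to_cu_seqlens(mask):
    """mask: bool (batch, heads, q_len, kv_len), True = attend (assumed layout)."""
    # A query row is "real" if it attends to any key in any head.
    q_valid = jnp.any(mask, axis=(1, 3))   # (batch, q_len)
    # A key column is "real" if any query in any head attends to it.
    kv_valid = jnp.any(mask, axis=(1, 2))  # (batch, kv_len)
    # Counting valid positions equals the length only for right padding.
    seqlens_q = jnp.sum(q_valid, axis=-1).astype(jnp.int32)
    seqlens_k = jnp.sum(kv_valid, axis=-1).astype(jnp.int32)
    zero = jnp.zeros((1,), dtype=jnp.int32)
    cu_seqlens_q = jnp.concatenate([zero, jnp.cumsum(seqlens_q, dtype=jnp.int32)])
    cu_seqlens_k = jnp.concatenate([zero, jnp.cumsum(seqlens_k, dtype=jnp.int32)])
    return cu_seqlens_q, cu_seqlens_k
```

Feed it a causal mask (`jnp.tril`) and every row and column still contains a True somewhere. Each sequence therefore comes out at full length, and the kernel runs dense, non-causal attention with no error raised.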

Roadmap

  • Unpin the upstream stack. pyproject.toml currently hard-pins jax==0.9.2, nvidia-cutlass-dsl[cu13]==4.5.0.dev0, jax-tvm-ffi==0.1.2, and flash-attn-4==4.0.0b8. This is the only combination found to work end-to-end: newer cutlass-dsl releases changed the nvvm.atomicrmw signature in a way that breaks flash_bwd, and the [cu13] extra is required because its _nvvm_ops_gen.py overrides the libs-base version with the signature the in-source monkey-patch expects. Revisit once flash-attn-4 ships a release compatible with cutlass-dsl 4.5.0 GA (or later) and we can drop both pins and the monkey-patch.
  • Drop the torch dependency. flash-attn-4 lists torch as a runtime requirement and pulls in torch==2.11.0 plus the full CUDA 13 wheel stack (~3 GB). On jafa's call path torch is loaded only incidentally: we import assume_tensor_aligned from flash_attn.cute.cute_dsl_utils, and that module imports torch at the top level to serve unrelated helpers (get_device_capacity, to_cute_tensor, get_broadcast_dims). The kernel classes themselves operate on cute.Tensor and are torch-free. Plan: vendor assume_tensor_aligned (5 lines), drop the flash_attn.cute.cute_dsl_utils import, then override flash-attn-4's declared deps so torch isn't installed. That saves roughly 3 GB of wheels and a long install.
  • Expose causal / sliding-window / local masks. Currently every kernel is registered with is_causal=False, is_local=False. The mask argument is a right-padding mask only; arbitrary attention patterns are silently miscomputed. Add a kwarg that selects the kernel variant at registration time.
  • fp8 support. flash-attn-4 supports Float8E4M3FN / Float8E5M2, but jafa's _JAX_TO_CUTLASS table only maps fp16/bf16/fp32 and JAX's fp8 DLPack story is still rough — needs investigation.
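For the torch item, one way to confirm that an import path no longer loads torch is to import it in a fresh interpreter and inspect `sys.modules`. This is a minimal sketch. The helper name `import_pulls_in` is hypothetical, and it only detects imports triggered at module import time, not ones deferred into function calls.

```python
import subprocess
import sys


def import_pulls_in(module: str, dependency: str = "torch") -> bool:
    """Import `module` in a fresh interpreter and report whether
    `dependency` ended up in sys.modules as a side effect."""
    code = (
        "import importlib, sys; "
        f"importlib.import_module({module!r}); "
        f"print({dependency!r} in sys.modules)"
    )
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, check=True
    )
    return result.stdout.strip() == "True"
```

Once the plan lands, `import_pulls_in("jafa")` would be expected to return `False`. Running the check in a subprocess keeps modules already imported by the calling process from skewing the result.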
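For the causal / sliding-window item, a plain-JAX reference is one way to pin down the intended semantics and to check new kernel variants against. This is a sketch, not flash-attn's definition. `reference_attention` and its `window` convention (query i sees keys j with i - window <= j <= i when combined with `is_causal`) are assumptions, and positions are aligned top-left, which matches self-attention with q_len == kv_len.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, *, is_causal=False, window=None):
    """q: (batch, q_len, heads, d); k, v: (batch, kv_len, heads, d)."""
    scale = q.shape[-1] ** -0.5
    s = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    q_pos = jnp.arange(q.shape[1])[:, None]
    k_pos = jnp.arange(k.shape[1])[None, :]
    allowed = jnp.ones((q.shape[1], k.shape[1]), dtype=bool)
    if is_causal:
        # Key j is visible to query i only if j <= i.
        allowed = allowed & (k_pos <= q_pos)
    if window is not None:
        # Hide keys more than `window` positions behind the query.
        allowed = allowed & (k_pos >= q_pos - window)
    s = jnp.where(allowed, s, -jnp.inf)  # masked logits get zero weight
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", p, v)
```

Comparing a new kernel variant's output against this reference on small random inputs, with `jnp.allclose` and a tolerance suited to bf16, would catch a mask pattern that is silently ignored.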

License

MIT

About

JAX bridge for Flash Attention v4
