@@ -129,24 +129,31 @@
)


def _validate_decode_inputs(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen_kv: torch.Tensor | None,
) -> None:
assert seqlen_kv is not None, "seqlen_kv must be provided for decode"
tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv}

for name, tensor in tensors.items():
# assert tensor.is_contiguous(), f"{name} is not contiguous"
assert tensor.is_cuda, f"{name} must be on GPU"


def _cutlass_blackwell_fmha_gen(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen_kv: torch.Tensor,
batch_idx: torch.Tensor,
kernel_type: GenKernelType = GenKernelType.UMMA_I,
window_left: int = -1,
window_right: int = -1,
) -> torch.Tensor:
assert q.is_contiguous(), "q is not contiguous"
assert k.is_contiguous(), "k is not contiguous"
assert v.is_contiguous(), "v is not contiguous"
assert seqlen_kv.is_contiguous(), "seqlen_kv is not contiguous"
assert batch_idx.is_contiguous(), "batch_idx is not contiguous"
assert q.is_cuda, "q must be on GPU"
assert k.is_cuda, "k must be on GPU"
assert v.is_cuda, "v must be on GPU"
assert seqlen_kv.is_cuda, "seqlen_kv must be on GPU"
assert batch_idx.is_cuda, "batch_idx must be on GPU"
_validate_decode_inputs(q, k, v, seqlen_kv)
return torch.ops.fbgemm.fmha_gen_fwd(
q,
k,
@@ -157,6 +164,118 @@
)


def _prepare_decode_inputs(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]:
"""
Prepare inputs for decode kernel by handling both varlen and batch formats.

Returns:
- Reshaped q, k, v tensors in batch format [B, 1, H, D]
- batch_size
- needs_reshape_output flag
- original_shape of q
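
Illustrative example (hypothetical shapes; not part of the kernel contract):
    q = torch.randn(4, 8, 128, device="cuda")        # varlen decode: [total_queries=B, H, D]
    k = torch.randn(4, 2048, 8, 128, device="cuda")  # batch format:  [B, Skv, H, D]
    v = torch.randn(4, 2048, 8, 128, device="cuda")
    q, k, v, bs, reshape_out, orig_shape = _prepare_decode_inputs(q, k, v)
    # q is now [4, 1, 8, 128]; reshape_out is True; orig_shape == (4, 8, 128)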
"""
original_shape = tuple(q.shape)
needs_reshape_output = False
batch_size = q.shape[0]

if q.dim() == 3:
# Varlen format: [total_queries, num_heads, head_dim]
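# With one decode query per sequence, total_queries equals batch_size, so dim 0 doubles as the batch dim.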
q = q.view(batch_size, 1, q.shape[1], q.shape[2])
needs_reshape_output = True

if q.dim() != 4:
raise ValueError(
f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]"
)
assert q.shape[1] == 1, "Gen kernel requires sq == 1"

k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k
v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v
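# Note: the views above assume varlen K/V ([total_kv, H, D]) pack equal-length sequences, i.e. total_kv is divisible by batch_size.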

return q, k, v, batch_size, needs_reshape_output, original_shape


def _create_decode_lse(
out: torch.Tensor,
batch_size: int,
needs_reshape_output: bool,
q_shape: tuple[int, ...],
) -> torch.Tensor:
"""
Create dummy LSE tensor for decode output compatibility.
Gen kernel doesn't return LSE, so we create a zero tensor.
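Returning zeros keeps the (out, lse) return signature of the regular forward path, so callers do not need to special-case decode.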
"""
if needs_reshape_output:
# For varlen output format
lse_shape = [batch_size, q_shape[-2]] # [B, H] (q_shape is [total_queries, H, D] in varlen format)
else:
# For batch output format
lse_shape = [batch_size, q_shape[-2], q_shape[1]] # [B, H, 1]

return torch.zeros(*lse_shape, dtype=torch.float32, device=out.device)


def _cutlass_blackwell_fmha_decode_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen_kv: torch.Tensor | None = None,
cu_seqlens_q: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
max_seq_len_q: int | None = None,
max_seq_len_k: int | None = None,
softmax_scale: float | None = None,
causal: bool = False,
window_left: int = -1,
window_right: int = -1,
bottom_right: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Decode-optimized forward pass using the gen kernel.
This wrapper adapts the variable-length/batch interface to the gen kernel, which is
optimized for decode (query length = 1).

Accepts inputs in two formats:
- Varlen format: [total_queries, num_heads, head_dim] (3D)
- Batch format: [batch_size, 1, num_heads, head_dim] (4D)
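
Illustrative call (hypothetical shapes; the int32 dtype for seqlen_kv is an assumption):
    B, H, D, max_kv = 4, 8, 128, 2048
    q = torch.randn(B, 1, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, max_kv, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, max_kv, H, D, device="cuda", dtype=torch.bfloat16)
    seqlen_kv = torch.full((B,), max_kv, dtype=torch.int32, device="cuda")
    out, lse = _cutlass_blackwell_fmha_decode_forward(q, k, v, seqlen_kv=seqlen_kv)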
"""
_validate_decode_inputs(q, k, v, seqlen_kv)
# Handle window size for causal attention
if causal and window_left >= 0:
window_right = 0
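# Causal attention never attends to the right of the current position, so the right window is clamped to 0.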

[CI annotation — GitHub Actions / run-lint (3.13), line 247 of cutlass_blackwell_fmha_interface.py: F841 local variable 'window_right' is assigned to but never used]

# Prepare inputs and handle format conversion
q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs(
q, k, v
)

# Create batch_idx tensor
batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
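# Identity mapping: decode query i belongs to batch entry i (one query per sequence).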

# Call the gen kernel (optimized for decode)
out = torch.ops.fbgemm.fmha_gen_fwd(
q,
k,
v,
seqlen_kv,
batch_idx,
kernel_type=GenKernelType.UMMA_I,
# window_left=window_left,
# window_right=window_right,
)

# Reshape output back to original format if needed
if needs_reshape_output:
out = out.view(*original_shape)

# Create dummy LSE for compatibility
lse = _create_decode_lse(out, batch_size, needs_reshape_output, original_shape)

return out, lse


class CutlassBlackwellFmhaFunc(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
@@ -175,67 +294,66 @@
bottom_right: bool = True,
deterministic: bool = False,
) -> torch.Tensor:
window_left, window_right = window_size
# Check if this is generation phase (sq = 1)
sq = q.shape[1]
# Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
if cu_seqlens_q is not None and cu_seqlens_k is not None:
assert (
cu_seqlens_q.dtype == torch.int32
and cu_seqlens_q.dtype == cu_seqlens_k.dtype
), "cu_seqlens_q and cu_seqlens_k must be int32"

# handle window_size
window_left, window_right = window_size
if causal and window_left >= 0:
window_right = 0

if q.dim() == 4 and sq == 1:
batch_size = q.shape[0]

# Use provided seqlen_kv
assert (
seqlen_kv is not None
), "seqlen_kv must be provided for generation phase"

# Create batch_idx tensor
batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)

# Use gen forward (no backward needed for generation)
out = _cutlass_blackwell_fmha_gen(
q, k, v, seqlen_kv, batch_idx, kernel_type=GenKernelType.UMMA_I
)
# For gen case, we don't need to save tensors for backward
ctx.is_gen = True
return out
else:
# Use regular FMHA for non-generation case
out, softmax_lse = _cutlass_blackwell_fmha_forward(
out, _ = _cutlass_blackwell_fmha_decode_forward(
q,
k,
v,
seqlen_kv,
cu_seqlens_q,
cu_seqlens_k,
max_seq_len_q,
max_seq_len_k,
softmax_scale,
causal,
seqlen_kv,
window_left,
window_right,
bottom_right,
)
ctx.save_for_backward(q, k, v, out, softmax_lse)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.max_seq_len_q = max_seq_len_q
ctx.max_seq_len_k = max_seq_len_k
ctx.cu_seqlens_q = cu_seqlens_q
ctx.cu_seqlens_k = cu_seqlens_k
ctx.is_gen = False
ctx.bottom_right = bottom_right
ctx.deterministic = deterministic
return out
# Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
if cu_seqlens_q is not None and cu_seqlens_k is not None:
assert (
cu_seqlens_q.dtype == torch.int32
and cu_seqlens_q.dtype == cu_seqlens_k.dtype
), "cu_seqlens_q and cu_seqlens_k must be int32"

# handle window_size
if causal and window_left >= 0:
window_right = 0
# Use regular FMHA for non-generation case
out, softmax_lse = _cutlass_blackwell_fmha_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seq_len_q,
max_seq_len_k,
softmax_scale,
causal,
seqlen_kv,
window_left,
window_right,
bottom_right,
)
ctx.save_for_backward(q, k, v, out, softmax_lse)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.max_seq_len_q = max_seq_len_q
ctx.max_seq_len_k = max_seq_len_k
ctx.cu_seqlens_q = cu_seqlens_q
ctx.cu_seqlens_k = cu_seqlens_k
ctx.is_gen = False
ctx.bottom_right = bottom_right
ctx.deterministic = deterministic
return out

@staticmethod
def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[ # type: ignore