From 5fe24d38bd63734d204f8018c6d67c70e8428d5e Mon Sep 17 00:00:00 2001 From: Yuanqing Zhao Date: Sun, 24 May 2026 01:54:54 -0700 Subject: [PATCH 1/2] perf(deepseek-v4): vectorize read_deepseek_v4_indexer_fp8_cache The original implementation iterated over `slot_mapping.tolist()` in Python and performed GPU slicing + dtype-view + multiply per token. For a 16-req x 1024-token prefill batch (~14338 tokens) across ~30 sparse attention layers this is ~430K Python iterations per forward pass, each with several GPU ops. The CPU sync from `.tolist()` also blocks any hope of CUDA graph capture for the indexer path. Replace with a batched torch-op implementation following the same pattern already used by `read_deepseek_v4_indexer_mxfp4_cache` (same file): one `gather` per dimension, dequantize on device. Output is bit-identical to the reference loop for valid slots, zero for invalid slots (slot < 0). Measured impact on DeepSeek-V4-Flash with H20-3e TP=4, FP8 KV cache, random ISL=1024 OSL=4 c=16: TTFT (ms): 823,467 -> 18,197 (45x); TPOT (ms): 2,067 -> 298 (7x); 16/16 bench duration: 1350s -> 19s (70x). The vectorized implementation is also CUDA-graph-safe (no data-dependent Python branches, no `.tolist()` CPU sync), unblocking `--enforce-eager` removal for V4-Flash's sparse indexer path. Existing test `test_csa_indexer_cache_insert_fp8_path` continues to pass; numerical equivalence with the original reference loop was verified against the DeepSeek-V4-Flash bring-up smoke ('The capital of France is Paris.') and a successful 16/16 random-prompt bench run.
Signed-off-by: Yuanqing Zhao --- .../layers/attention/deepseek_v4_ops.py | 61 ++++++++++++------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py index 1124830f5..88b559269 100644 --- a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py +++ b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py @@ -758,28 +758,47 @@ def read_deepseek_v4_indexer_fp8_cache( f"cache_2d must be [pages, >= {min_stride}], got {tuple(cache_2d.shape)}" ) - out = torch.zeros( - slot_mapping.numel(), - index_head_dim, - device=cache_2d.device, - dtype=torch.float32, - ) + out_shape = (slot_mapping.numel(), index_head_dim) + if slot_mapping.numel() == 0: + return torch.empty(out_shape, device=cache_2d.device, dtype=torch.float32) + flat_cache = cache_2d.reshape(-1) - for token_idx, raw_slot in enumerate(slot_mapping.tolist()): - slot = int(raw_slot) - if slot < 0: - continue - page = slot // block_size - pos = slot % block_size - page_base = page * cache_2d.stride(0) - value_base = page_base + pos * index_head_dim - scale_base = page_base + block_size * index_head_dim + pos * scale_bytes - scale = flat_cache[scale_base : scale_base + scale_bytes].view(torch.float32)[0] - values = flat_cache[value_base : value_base + index_head_dim].view( - torch.float8_e4m3fn - ) - out[token_idx].copy_(values.float() * scale) - return out + # Move slot_mapping to cache_2d.device so the gather offsets composed below + # (which mix it with torch.arange(device=cache_2d.device)) don't fail on + # cross-device tensors when the caller passes a CPU slot_mapping. 
+ slots = slot_mapping.to(device=cache_2d.device, dtype=torch.int64) + valid = slots >= 0 + safe_slots = torch.where(valid, slots, torch.zeros_like(slots)) + pages = torch.div(safe_slots, block_size, rounding_mode="floor") + pos = safe_slots % block_size + page_base = pages * cache_2d.stride(0) + value_base = page_base + pos * index_head_dim + scale_base = page_base + block_size * index_head_dim + pos * scale_bytes + + value_offsets = ( + value_base[:, None] + + torch.arange( + index_head_dim, + device=cache_2d.device, + dtype=torch.int64, + )[None, :] + ) + values = flat_cache[value_offsets].view(torch.float8_e4m3fn).float() + + scale_offsets = ( + scale_base[:, None] + + torch.arange( + scale_bytes, + device=cache_2d.device, + dtype=torch.int64, + )[None, :] + ) + # scale_bytes is a multiple of 4 (FP32). The reference loop uses only the + # first FP32 per row (`view(torch.float32)[0]`); mirror that here. + scales = flat_cache[scale_offsets].view(torch.float32)[:, 0] + + out = values * scales[:, None] + return torch.where(valid[:, None], out, torch.zeros_like(out)) def _compress_v4_state_windows_capturable( From 1e29154d473e76a9c1425c5685833d6a8b72b8a8 Mon Sep 17 00:00:00 2001 From: Yuanqing Zhao Date: Mon, 25 May 2026 07:50:46 -0700 Subject: [PATCH 2/2] perf(deepseek-v4): fix empty-cache OOB in vectorized indexer fp8 read When cache_2d has zero pages (e.g., warmup batches before any FP8 indexer rows are cached) and slot_mapping is all-padding, the reference per-token loop iterated `slot_mapping.tolist()` and `continue`d on every `slot < 0`, returning the zero-initialized output tensor without touching `flat_cache`. The vectorized path uses `where(valid, slots, 0)` to keep offsets in-range, but the resulting row-0 gather still indexes into an empty `flat_cache` and raises an out-of-bounds error (on CUDA, surfaces as a device-side assert). Add a shape-only early return when `slot_mapping.numel() == 0` or `cache_2d.shape[0] == 0`. 
Shape-only so the check stays CUDA-graph-capture-safe (no `valid.any()` host sync). Switch the empty return tensor from `torch.empty` to `torch.zeros` to match the reference behavior in the cache-has-zero-pages case. Caught by codex review on PR #238. Signed-off-by: Yuanqing Zhao --- .../runtime/layers/attention/deepseek_v4_ops.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py index 88b559269..a0852a6f1 100644 --- a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py +++ b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py @@ -759,8 +759,13 @@ def read_deepseek_v4_indexer_fp8_cache( ) out_shape = (slot_mapping.numel(), index_head_dim) - if slot_mapping.numel() == 0: - return torch.empty(out_shape, device=cache_2d.device, dtype=torch.float32) + # Also bail when cache_2d has zero pages: the gather below uses + # `where(valid, slots, 0)` to keep offsets in-range, but the resulting + # row-0 read still OOBs against an empty `flat_cache`. The reference + # per-token loop tolerated this (it iterates `slot_mapping.tolist()` and + # `continue`s on `slot < 0`), so preserve that behavior with zeros. + if slot_mapping.numel() == 0 or cache_2d.shape[0] == 0: + return torch.zeros(out_shape, device=cache_2d.device, dtype=torch.float32) flat_cache = cache_2d.reshape(-1) # Move slot_mapping to cache_2d.device so the gather offsets composed below