diff --git a/benchmarks/ops/bench_fp8_lighting_indexer.py b/benchmarks/ops/bench_fp8_lighting_indexer.py index 60e66ce3..9f80f59e 100644 --- a/benchmarks/ops/bench_fp8_lighting_indexer.py +++ b/benchmarks/ops/bench_fp8_lighting_indexer.py @@ -20,10 +20,10 @@ def calculate_memory(self) -> Optional[float]: accum_dtype = torch.float32 index_dtype = torch.int32 - index_q_memory = t.seq_len * t.heads * t.index_dim * dtype.itemsize - index_k_memory = t.seq_len_kv * t.index_dim * dtype.itemsize - index_k_scale_memory = t.seq_len_kv * accum_dtype.itemsize - logits_memory = t.seq_len * t.seq_len_kv * accum_dtype.itemsize + index_q_memory = t.batch * t.seq_len * t.heads * t.index_dim * dtype.itemsize + index_k_memory = t.batch * t.seq_len_kv * t.index_dim * t.kv_group * dtype.itemsize + index_k_scale_memory = t.batch * t.seq_len_kv * t.kv_group * accum_dtype.itemsize + logits_memory = t.batch * t.seq_len * t.seq_len_kv * t.kv_group * accum_dtype.itemsize weights_memory = t.seq_len * t.heads * accum_dtype.itemsize cu_seqlens_ks_memory = t.seq_len * index_dtype.itemsize cu_seqlens_ke_memory = t.seq_len * index_dtype.itemsize diff --git a/benchmarks/ops/bench_fp8_quant.py b/benchmarks/ops/bench_fp8_quant.py index 02198e69..ee74d606 100644 --- a/benchmarks/ops/bench_fp8_quant.py +++ b/benchmarks/ops/bench_fp8_quant.py @@ -12,11 +12,12 @@ class Fp8QuantBenchmark(BenchmarkBase): def calculate_flops(self) -> Optional[float]: t = self.test - return 2 * t.seq_len_kv * t.index_dim + t.seq_len_kv + 4 * t.seq_len_kv * t.index_dim + return (2 * t.batch * t.seq_len_kv * t.kv_group * t.index_dim + + t.batch * t.seq_len_kv * t.kv_group + 4 * t.batch * t.seq_len_kv * t.kv_group * t.index_dim) def calculate_memory(self) -> Optional[float]: t = self.test - return t.seq_len_kv * t.index_dim * t.in_dtype.itemsize + return t.batch * t.seq_len_kv * t.kv_group * t.index_dim * t.in_dtype.itemsize @Fp8QuantFixture diff --git a/benchmarks/ops/bench_topk_selector.py 
b/benchmarks/ops/bench_topk_selector.py index e4875012..4fd00841 100644 --- a/benchmarks/ops/bench_topk_selector.py +++ b/benchmarks/ops/bench_topk_selector.py @@ -16,10 +16,10 @@ def calculate_flops(self) -> Optional[float]: def calculate_memory(self) -> Optional[float]: t = self.test - index_score_memory = t.batch * t.seq_len * t.in_dtype.itemsize - index_memory = t.batch * t.topk * t.out_dtype.itemsize - starts_memory = t.batch * t.out_dtype.itemsize - ends_memory = t.batch * t.out_dtype.itemsize + index_score_memory = (t.batch * t.seq_len * t.seq_len_kv * t.kv_group * t.in_dtype.itemsize) + index_memory = t.batch * t.seq_len * t.topk * t.kv_group * t.out_dtype.itemsize + starts_memory = t.batch * t.seq_len * t.out_dtype.itemsize + ends_memory = t.batch * t.seq_len * t.out_dtype.itemsize return index_score_memory + index_memory + starts_memory + ends_memory diff --git a/tests/ops/test_fp8_lighting_indexer.py b/tests/ops/test_fp8_lighting_indexer.py index 9cf6903e..174d4e88 100644 --- a/tests/ops/test_fp8_lighting_indexer.py +++ b/tests/ops/test_fp8_lighting_indexer.py @@ -9,20 +9,29 @@ class Fp8LightingIndexerFixture(FixtureBase): PARAMS = [ - ("seq_len, heads, index_dim, seq_len_kv, clean_logits, config, tune", [ - (4096, 32, 64, 8192, True, None, False), + ("batch, seq_len, heads, index_dim, seq_len_kv, kv_group, clean_logits, config, tune", [ + (1, 4096, 32, 64, 8192, 1, True, None, False), ]), ] class Fp8LightingIndexerTest(TestBase): - def __init__(self, seq_len: int, heads: int, index_dim: int, seq_len_kv: int, - clean_logits: bool = True, config: Optional[dict] = None): + def __init__(self, + batch: int, + seq_len: int, + heads: int, + index_dim: int, + seq_len_kv: int, + kv_group: int, + clean_logits: bool = True, + config: Optional[dict] = None): + self.batch = batch self.seq_len = seq_len self.heads = heads self.index_dim = index_dim self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.clean_logits = clean_logits self.config = config self.dtype 
= torch.float8_e4m3fn @@ -147,31 +156,54 @@ def generate_random_cu_seqlens(self, assert len(ks) == len(ke) == self.seq_len return ks, ke - def gen_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + def gen_inputs( + self, + params=None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: IndexQ = torch.randn( - self.seq_len, self.heads, self.index_dim, device='cuda', dtype=torch.bfloat16) - IndexK = torch.randn(self.seq_len_kv, self.index_dim, device='cuda', dtype=torch.bfloat16) + self.batch, + self.seq_len, + self.heads, + self.index_dim, + device='cuda', + dtype=torch.bfloat16) + IndexK = torch.randn( + self.batch, + self.seq_len_kv, + self.kv_group, + self.index_dim, + device='cuda', + dtype=torch.bfloat16) Weights = torch.randn(self.seq_len, self.heads, device='cuda', dtype=self.accum_dtype) CuSeqLenKS, CuSeqLenKE = self.generate_random_cu_seqlens( cp_size=4, cp_rank=3, kv_stride=1, average_q_len=2048) return IndexQ, IndexK, Weights, CuSeqLenKS, CuSeqLenKE def ref_program(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor) -> Tuple[torch.Tensor]: + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor) -> Tuple[torch.Tensor]: k = kv q = q.float() k = k.float() + batch, seq_len, heads, index_dim = q.shape + seq_len_kv = self.seq_len_kv + kv_group = self.kv_group + heads_per_group = heads // kv_group + + k = k.view(batch, seq_len_kv, kv_group, index_dim) + q = q.view(batch, seq_len, kv_group, heads_per_group, index_dim) - seq_len_kv = kv.shape[0] mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi - score = torch.einsum("mhd,nd->hmn", q, k) - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) + 
score = torch.einsum("bsghd,bngd->bghsn", q, k) + weights = weights.view(seq_len, kv_group, heads_per_group) + weights = weights.permute(1, 2, 0).unsqueeze(0).unsqueeze(-1) + score = score.relu() * weights + logits = score.sum(dim=2) + logits = logits.permute(0, 2, 3, 1) + mask_expanded = mask.unsqueeze(0).unsqueeze(-1) + logits = logits.masked_fill(~mask_expanded, float("-inf")) return (logits,) @staticmethod @@ -181,7 +213,9 @@ def _compute_correlation(a: torch.Tensor, b: torch.Tensor) -> float: return 2 * (a * b).sum() / norm_sum @staticmethod - def _validate_tensor_match(output: torch.Tensor, output_ref: torch.Tensor, tolerance: float = 1e-3) -> None: + def _validate_tensor_match(output: torch.Tensor, + output_ref: torch.Tensor, + tolerance: float = 1e-3) -> None: if isinstance(output, tuple): output = output[0] if isinstance(output_ref, tuple): @@ -206,11 +240,20 @@ def _validate_tensor_match(output: torch.Tensor, output_ref: torch.Tensor, toler @Fp8LightingIndexerFixture -def test_indexer(seq_len: int, heads: int, index_dim: int, seq_len_kv: int, clean_logits: bool, - config: Optional[dict], tune: bool) -> None: - test = Fp8LightingIndexerTest(seq_len, heads, index_dim, seq_len_kv, clean_logits, config) +def test_indexer(batch: int, seq_len: int, heads: int, index_dim: int, seq_len_kv: int, + kv_group: int, clean_logits: bool, config: Optional[dict], tune: bool) -> None: + test = Fp8LightingIndexerTest(batch, seq_len, heads, index_dim, seq_len_kv, kv_group, + clean_logits, config) op = Fp8LightingIndexerOp( - seq_len, heads, index_dim, seq_len_kv, clean_logits, config, tune=tune) + batch=batch, + seq_len=seq_len, + heads=heads, + index_dim=index_dim, + seq_len_kv=seq_len_kv, + kv_group=kv_group, + clean_logits=clean_logits, + config=config, + tune=tune) test.check(op, *test.gen_inputs(), compare=Fp8LightingIndexerTest._validate_tensor_match) diff --git a/tests/ops/test_fp8_quant.py b/tests/ops/test_fp8_quant.py index bc3b9f8e..e41904ad 100644 --- 
a/tests/ops/test_fp8_quant.py +++ b/tests/ops/test_fp8_quant.py @@ -10,11 +10,12 @@ class Fp8QuantFixture(FixtureBase): PARAMS = [ - ("seq_len_kv, index_dim, in_dtype, tune", [ - (8192, 64, torch.float16, False), - (8192, 64, torch.bfloat16, False), - (4096, 128, torch.float32, False), - (16384, 32, torch.float32, False), + ("batch, seq_len_kv, kv_group, index_dim, in_dtype, tune", [ + (1, 8192, 1, 64, torch.float16, False), + (1, 8192, 1, 64, torch.bfloat16, False), + (1, 4096, 1, 128, torch.float32, False), + (1, 16384, 1, 32, torch.float32, False), + (1, 1024, 4, 64, torch.float16, False), ]), ] @@ -30,29 +31,44 @@ def _cosine_compare(output: torch.Tensor, output_ref: torch.Tensor) -> None: class Fp8QuantTest(TestBase): - def __init__(self, seq_len_kv: int, index_dim: int, in_dtype: torch.dtype): + def __init__(self, batch: int, seq_len_kv: int, kv_group: int, index_dim: int, + in_dtype: torch.dtype): + self.batch = batch self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.index_dim = index_dim self.in_dtype = in_dtype def gen_inputs(self) -> Tuple[torch.Tensor]: input_tensor = torch.randn( - self.seq_len_kv, self.index_dim, dtype=self.in_dtype, device="cuda") + self.batch, + self.seq_len_kv, + self.kv_group, + self.index_dim, + dtype=self.in_dtype, + device="cuda") return (input_tensor,) def ref_program(self, input_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - amax_value = torch.abs(input_tensor).amax(dim=1, keepdim=True).clamp(min=1e-4) + # input_tensor: (batch, seq_len_kv, kv_group, index_dim) + amax_value = torch.abs(input_tensor).amax(dim=-1, keepdim=True).clamp(min=1e-4) scale_tensor = amax_value / 448.0 output_tensor = torch.clamp(input_tensor / scale_tensor, min=-448.0, max=448.0) output_tensor = output_tensor.to(torch.float8_e4m3fn) - return scale_tensor.squeeze(dim=1), output_tensor + return scale_tensor.squeeze(dim=-1), output_tensor @Fp8QuantFixture -def test_fp8_quant_op(seq_len_kv: int, index_dim: int, in_dtype: 
torch.dtype, - tune: bool) -> None: - test = Fp8QuantTest(seq_len_kv, index_dim, in_dtype) - op = Fp8QuantOp(seq_len_kv=seq_len_kv, index_dim=index_dim, in_dtype=in_dtype, tune=tune) +def test_fp8_quant_op(batch: int, seq_len_kv: int, kv_group: int, index_dim: int, + in_dtype: torch.dtype, tune: bool) -> None: + test = Fp8QuantTest(batch, seq_len_kv, kv_group, index_dim, in_dtype) + op = Fp8QuantOp( + batch=batch, + seq_len_kv=seq_len_kv, + kv_group=kv_group, + index_dim=index_dim, + in_dtype=in_dtype, + tune=tune) test.check(op, *test.gen_inputs(), compare=_cosine_compare) diff --git a/tests/ops/test_topk_selector.py b/tests/ops/test_topk_selector.py index cf8bb6d7..2180dbf2 100644 --- a/tests/ops/test_topk_selector.py +++ b/tests/ops/test_topk_selector.py @@ -10,11 +10,11 @@ class TopkSelectorFixture(FixtureBase): PARAMS = [ - ("batch, seq_len, topk, in_dtype_str, out_dtype_str, tune", [ - (64, 32 * 1024, 1024, "float32", "int32", False), - (64, 32 * 1024, 2048, "float32", "int32", False), - (128, 64 * 1024, 1024, "float32", "int32", False), - (128, 64 * 1024, 2048, "float32", "int32", False), + ("batch, seq_len, seq_len_kv, kv_group, topk, in_dtype_str, out_dtype_str, tune", [ + (4, 256, 1024, 1, 32, "float32", "int32", False), + (8, 512, 2048, 1, 64, "float32", "int32", False), + (1, 32 * 1024, 64 * 1024, 1, 1024, "float32", "int32", False), + (1, 32 * 1024, 64 * 2048, 1, 2048, "float32", "int32", False), ]), ] @@ -33,33 +33,57 @@ def _set_compare(output: torch.Tensor, output_ref: torch.Tensor) -> None: class TopkSelectorTest(TestBase): - def __init__(self, batch: int, seq_len: int, topk: int, in_dtype: torch.dtype, - out_dtype: torch.dtype): + def __init__(self, batch: int, seq_len: int, seq_len_kv: int, kv_group: int, topk: int, + in_dtype: torch.dtype, out_dtype: torch.dtype): self.batch = batch self.seq_len = seq_len + self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.topk = topk self.in_dtype = in_dtype self.out_dtype = out_dtype def 
gen_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - index_score = torch.randn(self.batch, self.seq_len, dtype=self.in_dtype, device="cuda") - starts = torch.zeros(self.batch, dtype=self.out_dtype, device="cuda") - ends = torch.ones(self.batch, dtype=self.out_dtype, device="cuda") * self.seq_len + index_score = torch.randn( + self.batch, + self.seq_len, + self.seq_len_kv, + self.kv_group, + dtype=self.in_dtype, + device="cuda") + starts = torch.zeros(self.batch, self.seq_len, dtype=self.out_dtype, device="cuda") + ends = torch.ones(self.batch, self.seq_len, dtype=self.out_dtype, + device="cuda") * self.seq_len_kv return index_score, starts, ends def ref_program(self, index_score: torch.Tensor, starts: torch.Tensor, ends: torch.Tensor) -> torch.Tensor: - indexes_ref = torch.topk(index_score, self.topk, dim=-1)[1] - return indexes_ref + # index_score: (batch, seq_len, seq_len_kv, kv_group); topk over seq_len_kv (dim=2) + indexes_ref = torch.topk(index_score, self.topk, dim=2)[1] + # Match kernel/output layout: (batch, seq_len, kv_group, topk) + return indexes_ref.permute(0, 1, 3, 2) @TopkSelectorFixture -def test_topk_selector_op(batch: int, seq_len: int, topk: int, in_dtype_str: str, - out_dtype_str: str, tune: bool) -> None: +def test_topk_selector_op(batch: int, + seq_len: int, + seq_len_kv: int, + kv_group: int, + topk: int, + in_dtype_str: str, + out_dtype_str: str, + tune: bool) -> None: in_dtype = str2dtype[in_dtype_str] out_dtype = str2dtype[out_dtype_str] - test = TopkSelectorTest(batch, seq_len, topk, in_dtype, out_dtype) - op = TopkSelectorOp(batch, seq_len, topk, in_dtype, out_dtype, tune=tune) + test = TopkSelectorTest(batch, seq_len, seq_len_kv, kv_group, topk, in_dtype, out_dtype) + op = TopkSelectorOp(batch=batch, + seq_len=seq_len, + seq_len_kv=seq_len_kv, + kv_group=kv_group, + topk=topk, + in_dtype=in_dtype, + out_dtype=out_dtype, + tune=tune) test.check(op, *test.gen_inputs(), compare=_set_compare) diff --git 
a/tileops/kernels/deepseek_mla/fp8_lighting_indexer.py b/tileops/kernels/deepseek_mla/fp8_lighting_indexer.py index ccae631c..97498400 100644 --- a/tileops/kernels/deepseek_mla/fp8_lighting_indexer.py +++ b/tileops/kernels/deepseek_mla/fp8_lighting_indexer.py @@ -13,7 +13,13 @@ __all__ = ["Fp8LightingIndexerKernel"] -def _fp8_lighting_indexer_kernel(seq_len, heads, index_dim, seq_len_kv, clean_logits=True): +def _fp8_lighting_indexer_kernel(batch, + seq_len, + heads, + index_dim, + seq_len_kv, + kv_group, + clean_logits=True): @tilelang.jit( pass_configs={ @@ -34,10 +40,10 @@ def _fp8_lighting_indexer_func( seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") - index_q_shape = [seq_len * heads, index_dim] - index_k_shape = [seq_len_kv, index_dim] - index_k_scale_shape = [seq_len_kv] - logits_shape = [seq_len, seq_len_kv] + index_q_shape = [batch, seq_len * heads, index_dim] + index_k_shape = [batch, seq_len_kv, kv_group, index_dim] + index_k_scale_shape = [batch, seq_len_kv, kv_group] + logits_shape = [batch, seq_len, seq_len_kv, kv_group] @T.prim_func def _fp8_lighting_indexer_main( @@ -49,16 +55,21 @@ def _fp8_lighting_indexer_main( CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore ): - with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + heads_per_group = heads // kv_group + with T.Kernel(T.ceildiv(seq_len, block_Q), batch, threads=threads) as (bx, by): index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) - index_k_shared = T.alloc_shared([block_N, index_dim], dtype) - index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) - s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) - s_reshaped = T.reshape(s, (block_N, block_Q, heads)) - logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + index_q_group_shared = T.alloc_shared([block_Q * heads_per_group, index_dim], dtype) + index_k_shared = T.alloc_shared([block_N, 
kv_group, index_dim], dtype) + index_k_group_shared = T.alloc_shared([block_N, index_dim], dtype) + index_k_scale_fragment = T.alloc_fragment([block_N, kv_group], accum_dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) # + s_reshaped = T.reshape(s, (block_N, block_Q, heads_per_group, kv_group)) + s_tmp = T.alloc_fragment([block_N, heads_per_group], accum_dtype) + logits = T.alloc_fragment([block_N, block_Q, kv_group], accum_dtype) weights = T.alloc_fragment([block_Q, heads], accum_dtype) seq_len_i = bx * block_Q + b_i = by cu_k_s_min = T.alloc_var(index_dtype) cu_k_e_max = T.alloc_var(index_dtype) @@ -71,33 +82,43 @@ def _fp8_lighting_indexer_main( for bq_i in T.serial(block_Q): cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) - T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(IndexQ[b_i, seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) for nbn_i in T.Pipelined( T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages): - T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) - T.copy(IndexKScale[cu_k_s_min + nbn_i * block_N], index_k_scale_fragment) - - T.gemm( - index_k_shared, - index_q_shared, - s, - transpose_B=True, - clear_accum=True, - policy=T.GemmWarpPolicy.FullCol, - ) - - for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, - h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * - weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] - - T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) - - for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, - bq_i] + T.copy(IndexK[b_i, cu_k_s_min + nbn_i * block_N, 0, 0], index_k_shared) + T.copy(IndexKScale[b_i, cu_k_s_min + nbn_i * block_N, 0], + index_k_scale_fragment) + + for g in T.Serial(kv_group): + for bn_i, d_i in T.Parallel(block_N, index_dim): + index_k_group_shared[bn_i, d_i] = 
index_k_shared[bn_i, g, d_i] # + for i, d in T.Parallel(block_Q * heads_per_group, index_dim): + index_q_group_shared[i, d] = index_q_shared[g * heads_per_group + i, + d] # + T.gemm( + index_k_group_shared, + index_q_group_shared, + s_tmp, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + for bn_i, h_i in T.Parallel(block_N, heads_per_group): + s[bn_i, g * heads_per_group + h_i] = s_tmp[bn_i, h_i] + + for bn_i, bq_i, h_i, g in T.Parallel(block_N, block_Q, heads_per_group, + kv_group): + s_reshaped[bn_i, bq_i, h_i, g] = (T.max(s_reshaped[bn_i, bq_i, h_i, g], 0) * + weights[bq_i, g * heads_per_group + h_i] + ) * index_k_scale_fragment[bn_i, g] + + T.reduce_sum(s_reshaped, logits, dim=-2, clear=True) + + for bq_i, bn_i, g in T.Parallel(block_Q, block_N, kv_group): + Logits[b_i, seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i, + g] = logits[bn_i, bq_i, g] # Return the kernel function handle return _fp8_lighting_indexer_main @@ -106,49 +127,50 @@ def _fp8_lighting_indexer_main( @tilelang.jit -def clean_logits_( - threads: int = 512, - block_K: int = 4096, -): +def clean_logits_(block_K: int = 4096,): + batch = T.dynamic("batch") seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") + kv_group = T.dynamic("kv_group") dtype = T.float indices_dtype = T.int32 @T.prim_func def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + Logits: T.Tensor([batch, seq_len, seq_len_kv, kv_group], dtype), # type: ignore CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore ): - with T.Kernel(seq_len, threads=threads) as bx: - tx = T.thread_binding(0, threads, thread="threadIdx.x") + with T.Kernel(seq_len, batch, threads=512) as (bx, by): + tx = T.get_thread_binding() cu_k_s = CuSeqLenKS[bx] cu_k_e = CuSeqLenKE[bx] for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): - for k_i in T.serial(block_K // threads): - idx = n_i * block_K 
+ k_i * threads + tx + for k_i in T.serial(block_K // 512): + idx = n_i * block_K + k_i * 512 + tx if idx < cu_k_s or idx >= cu_k_e: - Logits[bx, idx] = -T.infinity(dtype) + for g in T.serial(kv_group): + Logits[by, bx, idx, g] = -T.infinity(dtype) return clean_logits_kernel @torch.library.custom_op("top::fp8_lighting_indexer_wrapped_kernel", mutates_args=()) -def fp8_lighting_indexer_wrapped_kernel(seq_len: int, heads: int, index_dim: int, seq_len_kv: int, - clean_logits: bool, block_N: int, num_stages: int, - threads: int, block_Q: int, IndexQ: torch.Tensor, - IndexK: torch.Tensor, IndexKScale: torch.Tensor, - Logits: torch.Tensor, Weights: torch.Tensor, - CuSeqLenKS: torch.Tensor, +def fp8_lighting_indexer_wrapped_kernel(batch: int, seq_len: int, heads: int, index_dim: int, + seq_len_kv: int, kv_group: int, clean_logits: bool, + block_N: int, num_stages: int, threads: int, block_Q: int, + IndexQ: torch.Tensor, IndexK: torch.Tensor, + IndexKScale: torch.Tensor, Logits: torch.Tensor, + Weights: torch.Tensor, CuSeqLenKS: torch.Tensor, CuSeqLenKE: torch.Tensor) -> torch.Tensor: - _fp8_lighting_indexer_kernel(seq_len, heads, index_dim, - seq_len_kv)(block_N, num_stages, threads, - block_Q)(IndexQ.view(seq_len * heads, - index_dim), IndexK, IndexKScale, - Logits, Weights, CuSeqLenKS, CuSeqLenKE) + + _fp8_lighting_indexer_kernel(batch, seq_len, heads, index_dim, seq_len_kv, + kv_group)(block_N, num_stages, threads, + block_Q)(IndexQ.view(batch, seq_len * heads, + index_dim), IndexK, IndexKScale, + Logits, Weights, CuSeqLenKS, CuSeqLenKE) if clean_logits: clean_logits_()(Logits, CuSeqLenKS, CuSeqLenKE) return Logits.clone() @@ -156,10 +178,12 @@ def fp8_lighting_indexer_wrapped_kernel(seq_len: int, heads: int, index_dim: int @fp8_lighting_indexer_wrapped_kernel.register_fake def _( + batch: int, seq_len: int, heads: int, index_dim: int, seq_len_kv: int, + kv_group: int, clean_logits: bool, block_N: int, num_stages: int, @@ -181,29 +205,34 @@ class 
Fp8LightingIndexerKernel(Kernel): supported_archs: list[int] = [90] def __init__(self, + batch, seq_len, heads, index_dim, seq_len_kv, + kv_group, clean_logits=True, config: Optional[dict] = None, tune=False): super().__init__() + self.batch = batch self.seq_len = seq_len self.heads = heads self.index_dim = index_dim self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.clean_logits = clean_logits self.config = config - self.kernel = _fp8_lighting_indexer_kernel(self.seq_len, self.heads, self.index_dim, - self.seq_len_kv, self.clean_logits) + self.kernel = _fp8_lighting_indexer_kernel(self.batch, self.seq_len, self.heads, + self.index_dim, self.seq_len_kv, self.kv_group, + self.clean_logits) self.init_config(config, tune) @property def default_config(self) -> dict: - return {"block_N": 64, "num_stages": 0, "threads": 128, "block_Q": 1} + return {"block_N": 64, "num_stages": 2, "threads": 128, "block_Q": 1} @property def autotune_configs(self) -> list[dict]: @@ -230,14 +259,14 @@ def forward( CuSeqLenKS: torch.Tensor, # type: ignore CuSeqLenKE: torch.Tensor, # type: ignore ) -> torch.Tensor: - Logits = torch.empty([self.seq_len, self.seq_len_kv], + Logits = torch.empty([self.batch, self.seq_len, self.seq_len_kv, self.kv_group], device=IndexQ.device, dtype=torch.float32) return fp8_lighting_indexer_wrapped_kernel( - self.seq_len, self.heads, self.index_dim, self.seq_len_kv, self.clean_logits, - self.config["block_N"], self.config["num_stages"], self.config["threads"], - self.config["block_Q"], IndexQ, IndexK, IndexKScale, Logits, Weights, CuSeqLenKS, - CuSeqLenKE) + self.batch, self.seq_len, self.heads, self.index_dim, self.seq_len_kv, self.kv_group, + self.clean_logits, self.config["block_N"], self.config["num_stages"], + self.config["threads"], self.config["block_Q"], IndexQ, IndexK, IndexKScale, Logits, + Weights, CuSeqLenKS, CuSeqLenKE) def supply_prog( self, @@ -250,11 +279,13 @@ def supply_prog( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: self.q = q self.kv = kv - seq_len, heads, index_dim = q.shape - seq_len_kv = kv.shape[0] - IndexQ = torch.randn(seq_len * heads, index_dim, device='cuda', dtype=torch.float8_e4m3fn) - IndexK = torch.randn(seq_len_kv, index_dim, device='cuda', dtype=self.dtype) - IndexKScale = torch.randn(seq_len_kv, device='cuda', dtype=accum_dtype) + batch, seq_len, heads, index_dim = q.shape + seq_len_kv = kv.shape[1] + IndexQ = torch.randn( + batch, seq_len * heads, index_dim, device='cuda', dtype=torch.float8_e4m3fn) + IndexK = torch.randn( + batch, seq_len_kv, self.kv_group, index_dim, device='cuda', dtype=self.dtype) + IndexKScale = torch.randn(batch, seq_len_kv, self.kv_group, device='cuda', dtype=accum_dtype) Weights = torch.randn(seq_len, heads, device='cuda', dtype=accum_dtype) CuSeqLenKS = torch.zeros(seq_len, device='cuda', dtype=index_dtype) CuSeqLenKE = torch.full((seq_len,), diff --git a/tileops/kernels/deepseek_mla/fp8_quant.py b/tileops/kernels/deepseek_mla/fp8_quant.py index cd6bc5f0..4ed38e27 100644 --- a/tileops/kernels/deepseek_mla/fp8_quant.py +++ b/tileops/kernels/deepseek_mla/fp8_quant.py @@ -10,7 +10,7 @@ __all__ = ["Fp8QuantKernel"] -def _fp8_quant_kernel(seq_len_kv, index_dim, in_dtype: str): +def _fp8_quant_kernel(batch, seq_len_kv, kv_group, index_dim, in_dtype: str): @tilelang.jit(out_idx=[1, 2]) def _fp8_quant_fwd_func(num_stages, block_m): @@ -21,30 +21,38 @@ def _fp8_quant_fwd_func(num_stages, block_m): fp8_max_inv = 1 / fp8_max @T.prim_func - def _fp8_quant_fwd_main(input_tensor: T.Tensor[(seq_len_kv, index_dim), in_dtype], - scale_tensor: T.Tensor[(seq_len_kv,), scale_dtype], - output_tensor: T.Tensor[(seq_len_kv, index_dim), out_dtype]): - with T.Kernel(T.ceildiv(seq_len_kv, block_m), threads=128) as (pid_m): - input_shared = T.alloc_shared((block_m, index_dim), in_dtype) + def _fp8_quant_fwd_main(input_tensor: T.Tensor[(batch, seq_len_kv, kv_group, index_dim), + in_dtype], + scale_tensor: 
T.Tensor[(batch, seq_len_kv, kv_group), scale_dtype], + output_tensor: T.Tensor[(batch, seq_len_kv, kv_group, index_dim), + out_dtype]): + with T.Kernel( + batch, T.ceildiv(seq_len_kv, block_m), kv_group, threads=128) as (bx, pid_m, g): input_local = T.alloc_fragment((block_m, index_dim), in_dtype) amax_local = T.alloc_fragment((block_m,), scale_dtype) scale_local = T.alloc_fragment((block_m,), scale_dtype) output_local = T.alloc_fragment((block_m, index_dim), out_dtype) - output_shared = T.alloc_shared((block_m, index_dim), out_dtype) - T.copy(input_tensor[pid_m * block_m, 0], input_shared) - T.copy(input_shared, input_local) + # Load a (block_m, index_dim) tile explicitly to avoid stride bugs with extra dims + for i, j in T.Parallel(block_m, index_dim): + input_local[i, j] = input_tensor[bx, pid_m * block_m + i, g, j] + + # Reduce over index_dim to get amax per sequence position T.reduce_absmax(input_local, amax_local, dim=1) for i in T.Parallel(block_m): amax_local[i] = T.max(amax_local[i], 1e-4) scale_local[i] = amax_local[i] * fp8_max_inv + + # Quantize: q = clamp(input / scale, [-448, 448]) for i, j in T.Parallel(block_m, index_dim): output_local[i, j] = T.clamp(input_local[i, j] / scale_local[i], fp8_min, fp8_max) + + # Write back scale and output for i in T.Parallel(block_m): - scale_tensor[pid_m * block_m + i] = scale_local[i] - T.copy(output_local, output_shared) - T.copy(output_shared, output_tensor[pid_m * block_m, 0]) + scale_tensor[bx, pid_m * block_m + i, g] = scale_local[i] + for i, j in T.Parallel(block_m, index_dim): + output_tensor[bx, pid_m * block_m + i, g, j] = output_local[i, j] return _fp8_quant_fwd_main @@ -52,18 +60,20 @@ def _fp8_quant_fwd_main(input_tensor: T.Tensor[(seq_len_kv, index_dim), in_dtype @torch.library.custom_op("top::fp8_quant_wrapped_kernel", mutates_args=()) -def _fp8_quant_wrapped_kernel(seq_len_kv: int, index_dim: int, in_dtype: str, num_stages: int, - block_m: int, +def _fp8_quant_wrapped_kernel(batch: int, seq_len_kv: 
int, kv_group: int, index_dim: int, + in_dtype: str, num_stages: int, block_m: int, input_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - return _fp8_quant_kernel(seq_len_kv, index_dim, in_dtype)(num_stages, block_m)(input_tensor) + return _fp8_quant_kernel(batch, seq_len_kv, kv_group, index_dim, in_dtype)(num_stages, block_m)( + input_tensor) @_fp8_quant_wrapped_kernel.register_fake -def _(seq_len_kv, index_dim, in_dtype, - num_stages, block_m, - *inputs): - return torch.empty((seq_len_kv), dtype=torch.float32, device=inputs[0].device), torch.empty( - (seq_len_kv, index_dim), dtype=torch.float8, device=inputs[0].device) +def _(batch, seq_len_kv, kv_group, index_dim, in_dtype, num_stages, block_m, *inputs): + return torch.empty((batch, seq_len_kv, kv_group), dtype=torch.float32, + device=inputs[0].device), torch.empty( + (batch, seq_len_kv, kv_group, index_dim), + dtype=torch.float8_e4m3fn, + device=inputs[0].device) class Fp8QuantKernel(Kernel): @@ -71,19 +81,28 @@ class Fp8QuantKernel(Kernel): supported_archs: list[int] = [90] def __init__(self, + batch: int, seq_len_kv: int, + kv_group: int, index_dim: int, in_dtype: torch.dtype, config: Optional[dict] = None, tune: bool = False): super().__init__() + self.batch = batch self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.index_dim = index_dim self.dtype = in_dtype self.config = config or {} - self.kernel = _fp8_quant_kernel(self.seq_len_kv, self.index_dim, self.dtype_str) + self.kernel = _fp8_quant_kernel(self.batch, self.seq_len_kv, self.kv_group, self.index_dim, + self.dtype_str) self.init_config(config, tune) + @property + def dtype_str(self) -> str: + return str(self.dtype).replace("torch.", "") + @property def default_config(self) -> dict: return {"num_stages": 0, "block_m": 32} @@ -91,12 +110,12 @@ def default_config(self) -> dict: @property def autotune_configs(self) -> list[dict]: num_stages = [0, 2] - block_m = [32] + block_m = [32, 64] _configs = 
list(itertools.product(num_stages, block_m)) return [{'num_stages': c[0], 'block_m': c[1]} for c in _configs] def forward(self, input_tensor: torch.Tensor): - return _fp8_quant_wrapped_kernel(self.seq_len_kv, self.index_dim, self.dtype_str, - self.config["num_stages"], self.config["block_m"], - input_tensor) + return _fp8_quant_wrapped_kernel(self.batch, self.seq_len_kv, self.kv_group, self.index_dim, + self.dtype_str, self.config["num_stages"], + self.config["block_m"], input_tensor) diff --git a/tileops/kernels/deepseek_mla/topk_selector.py b/tileops/kernels/deepseek_mla/topk_selector.py index ad90d36b..d6598515 100644 --- a/tileops/kernels/deepseek_mla/topk_selector.py +++ b/tileops/kernels/deepseek_mla/topk_selector.py @@ -22,7 +22,7 @@ def convert_to_uint16(x): def convert_to_uint32(x): - bits_uint = T.reinterpret(T.uint32, x) + bits_uint = T.reinterpret(T.uint32, T.Cast(T.float32, x)) bits_uint = T.if_then_else( x < 0, ~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)), @@ -31,25 +31,30 @@ def convert_to_uint32(x): return bits_uint -def _topk_selector_kernel(batch, seq_len, topk, in_dtype, out_dtype): +def _topk_selector_kernel(batch, seq_len, seq_len_kv, kv_group, topk, in_dtype, out_dtype): @tilelang.jit( out_idx=[1], pass_configs={ tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, }) - def topk_selector_fwd_func(RADIX=1 << 8, BLOCK_SIZE=1024, SMEM_INPUT_SIZE=4096): + def topk_selector_fwd_func(RADIX=1 << 8, BLOCK_SIZE=1024, SMEM_INPUT_SIZE=4096, block_m=32): batch = T.dynamic("batch") - seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") @T.prim_func def _topk_selector_kernel_main( - index_score: T.Tensor[(batch, seq_len), in_dtype], - index: T.Tensor[(batch, topk), out_dtype], - starts: T.Tensor[(batch), out_dtype], - ends: T.Tensor[(batch), out_dtype], + index_score: T.Tensor[(batch, seq_len, seq_len_kv, kv_group), in_dtype], + index: T.Tensor[(batch, seq_len, kv_group, topk), out_dtype], + starts: T.Tensor[(batch, seq_len), 
out_dtype], + ends: T.Tensor[(batch, seq_len), out_dtype], ): - with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): + # Parallelize over seq rows by assigning one block per (batch, seq_row, kv_group). + with T.Kernel( + batch, seq_len, kv_group, + threads=BLOCK_SIZE) as (bx, by, g): tx = T.get_thread_binding() + # by is the seq row index (one block per row; no m_i loop) + seq_row = by s_threshold_bin_id = T.alloc_shared([1], T.int32) s_histogram = T.alloc_shared([RADIX + 1], T.int32) @@ -65,20 +70,27 @@ def _topk_selector_kernel_main( l_start_idx = T.alloc_var(T.int32) l_end_idx = T.alloc_var(T.int32) l_out_pos = T.alloc_var(T.int32) + l_pos = T.alloc_var(T.int32) l_new_topk = topk - l_start_idx = starts[bx] - l_end_idx = ends[bx] + l_start_idx = starts[bx, seq_row] + l_end_idx = ends[bx, seq_row] # stage 1: use 8bit to do quick topk - T.fill(s_histogram, 0) - T.fill(s_num_input[0], 0) + # T.fill(s_histogram, 0) + # T.fill(s_num_input[0], 0) + + for j in T.serial(RADIX + 1): + s_histogram[j] = 0 + s_num_input[0] = 0 T.sync_threads() - for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + + for s in T.serial(T.ceildiv(seq_len_kv, BLOCK_SIZE)): input_idx = s * BLOCK_SIZE + tx - if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: - inval_int16 = convert_to_uint16(index_score[bx, input_idx]) + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len_kv: + inval_int16 = convert_to_uint16(index_score[bx, seq_row, + input_idx, g]) T.atomic_add(s_histogram[inval_int16], 1) T.sync_threads() @@ -102,22 +114,27 @@ def _topk_selector_kernel_main( l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] T.sync_threads() - # collect all elements with exponent ≥ threshold - for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): - T.sync_threads() + # collect all elements with exponent >= threshold + # Avoid in-loop block barriers on dynamic serial loops: this can deadlock + # on newer TileLang codegen. 
+ for s in T.serial(T.ceildiv(seq_len_kv, BLOCK_SIZE)): input_idx = s * BLOCK_SIZE + tx - if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: - bin_id = convert_to_uint16(index_score[bx, input_idx]) + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len_kv: + bin_id = convert_to_uint16(index_score[bx, seq_row, + input_idx, g]) l_bin_id32 = T.Cast(T.int32, bin_id) if l_bin_id32 > l_threshold_bin_id: # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) - pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) - index[bx, pos] = input_idx + l_pos = T.atomic_add( + s_histogram[l_bin_id32 + 1], 1, return_prev=True) + if l_pos < topk: + index[bx, seq_row, g, l_pos] = input_idx elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: # pos = s_num_input[0] - pos = T.atomic_add(s_num_input[0], 1, return_prev=True) - s_input_idx[0, pos] = input_idx + l_pos = T.atomic_add(s_num_input[0], 1, return_prev=True) + if l_pos < SMEM_INPUT_SIZE: + s_input_idx[0, l_pos] = input_idx # stage 2: tail pass for round in T.serial(4): @@ -128,7 +145,9 @@ def _topk_selector_kernel_main( l_start_pos = topk - l_new_topk T.sync_threads() - T.fill(s_histogram, 0) + for j in T.serial(RADIX + 1): + s_histogram[j] = 0 + # T.fill(s_histogram, 0) if tx == 0: s_num_input[r_idx ^ 1] = 0 T.sync_threads() @@ -137,7 +156,8 @@ def _topk_selector_kernel_main( for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: l_bin_id32 = T.Cast(T.int32, ((convert_to_uint32( - index_score[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> + index_score[bx, seq_row, + s_input_idx[r_idx, s * BLOCK_SIZE + tx], g]) >> (24 - round * 8)) & 0xFF)) T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() @@ -163,27 +183,32 @@ def _topk_selector_kernel_main( T.sync_threads() for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): - T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: l_bin_id32 = T.Cast(T.int32, 
((convert_to_uint32( - index_score[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> + index_score[bx, seq_row, + s_input_idx[r_idx, s * BLOCK_SIZE + tx], g]) >> (24 - round * 8)) & 0xFF)) if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos - index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + l_pos = T.atomic_add( + s_histogram[l_bin_id32 + 1], 1, + return_prev=True) + l_start_pos + index[bx, seq_row, g, + l_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: if round == 3: l_out_pos = T.atomic_add( s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos if l_out_pos < topk: - index[bx, l_out_pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] + index[bx, seq_row, g, + l_out_pos] = s_input_idx[r_idx, + s * BLOCK_SIZE + tx] else: - pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) - s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] + l_pos = T.atomic_add( + s_num_input[r_idx ^ 1], 1, return_prev=True) + if l_pos < SMEM_INPUT_SIZE: + s_input_idx[r_idx ^ 1, + l_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] return _topk_selector_kernel_main @@ -194,24 +219,27 @@ def _topk_selector_kernel_main( def _topk_selector_wrapped_kernel( batch: int, seq_len: int, + seq_len_kv: int, + kv_group: int, topk: int, in_dtype: str, out_dtype: str, RADIX: int, BLOCK_SIZE: int, SMEM_INPUT_SIZE: int, + block_m: int, index_score: torch.Tensor, starts: torch.Tensor, ends: torch.Tensor, ) -> torch.Tensor: - return _topk_selector_kernel(batch, seq_len, topk, in_dtype, - out_dtype)(RADIX, BLOCK_SIZE, SMEM_INPUT_SIZE)(index_score, starts, - ends) + return _topk_selector_kernel(batch, seq_len, seq_len_kv, kv_group, topk, in_dtype, + out_dtype)(RADIX, BLOCK_SIZE, SMEM_INPUT_SIZE, + block_m)(index_score, starts, ends) @_topk_selector_wrapped_kernel.register_fake -def _(batch, seq_len, topk, in_dtype, out_dtype, *inputs) -> 
None: - return torch.empty([batch, topk], device=inputs[0].device, dtype=torch.int32) +def _(batch, seq_len, seq_len_kv, kv_group, topk, in_dtype, out_dtype, *inputs) -> None: + return torch.empty([batch, seq_len, kv_group, topk], device=inputs[0].device, dtype=torch.int32) class TopkSelectorKernel(Kernel): @@ -221,6 +249,8 @@ class TopkSelectorKernel(Kernel): def __init__(self, batch: int, seq_len: int, + seq_len_kv: int, + kv_group: int, topk: int, in_dtype: torch.dtype, out_dtype: torch.dtype, @@ -229,13 +259,16 @@ def __init__(self, super().__init__() self.batch = batch self.seq_len = seq_len + self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.topk = topk self.in_dtype = in_dtype self.out_dtype = out_dtype self.in_dtype_str = self.dtype_to_str(self.in_dtype) self.out_dtype_str = self.dtype_to_str(self.out_dtype) - self.kernel = _topk_selector_kernel(self.batch, self.seq_len, self.topk, self.in_dtype_str, + self.kernel = _topk_selector_kernel(self.batch, self.seq_len, self.seq_len_kv, + self.kv_group, self.topk, self.in_dtype_str, self.out_dtype_str) self.init_config(config, tune) @@ -245,6 +278,7 @@ def default_config(self) -> dict: "RADIX": 1 << 8, "BLOCK_SIZE": 1024, "SMEM_INPUT_SIZE": 4096, + "block_m": 32, } @property @@ -258,14 +292,16 @@ def autotune_configs(self) -> list[dict]: RADIX = [1 << 8] BLOCK_SIZE = [1024] SMEM_INPUT_SIZE = [4096] - _configs = list(itertools.product(RADIX, BLOCK_SIZE, SMEM_INPUT_SIZE)) + block_m = [32] + _configs = list(itertools.product(RADIX, BLOCK_SIZE, SMEM_INPUT_SIZE, block_m)) - return [{'RADIX': c[0], 'BLOCK_SIZE': c[1], 'SMEM_INPUT_SIZE': c[2]} for c in _configs] + return [{'RADIX': c[0], 'BLOCK_SIZE': c[1], 'SMEM_INPUT_SIZE': c[2], 'block_m': c[3]} for c in _configs] def forward(self, index_score: torch.Tensor, starts: torch.Tensor, ends: torch.Tensor) -> torch.Tensor: - return _topk_selector_wrapped_kernel(self.batch, self.seq_len, self.topk, self.in_dtype_str, + return 
_topk_selector_wrapped_kernel(self.batch, self.seq_len, self.seq_len_kv, + self.kv_group, self.topk, self.in_dtype_str, self.out_dtype_str, self.config["RADIX"], self.config["BLOCK_SIZE"], - self.config["SMEM_INPUT_SIZE"], index_score, starts, - ends) + self.config["SMEM_INPUT_SIZE"], self.config["block_m"], + index_score, starts, ends) diff --git a/tileops/ops/fp8_lighting_indexer.py b/tileops/ops/fp8_lighting_indexer.py index 6787515e..99961b15 100644 --- a/tileops/ops/fp8_lighting_indexer.py +++ b/tileops/ops/fp8_lighting_indexer.py @@ -13,42 +13,63 @@ class Fp8LightingIndexerOp(Op): def __init__(self, + batch, seq_len, heads, index_dim, seq_len_kv, + kv_group, clean_logits=True, config: Optional[dict] = None, kernel_map: Optional[Dict[str, Kernel]] = None, tune=False) -> None: + self.batch = batch self.seq_len = seq_len self.heads = heads self.index_dim = index_dim self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.clean_logits = clean_logits - self.config = config self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["fp8_lighting_indexer_kernel"]( - seq_len, heads, index_dim, seq_len_kv, clean_logits, config, tune=tune) + batch, seq_len, heads, index_dim, seq_len_kv, kv_group, clean_logits, config, tune=tune) @property def default_kernel_map(self) -> Dict[str, Kernel]: return {"fp8_lighting_indexer_kernel": Fp8LightingIndexerKernel} - def forward(self, index_q: torch.Tensor, index_k: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor) -> torch.Tensor: + def torch_quant_forward(self, index_q: torch.Tensor, index_k: torch.Tensor, + weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor) -> torch.Tensor: index_q = index_q.to(torch.float8_e4m3fn) index_k, index_k_scale = self.per_custom_dims_cast_to_fp8(index_k, (0,), False) + + return self.kernel(index_q, index_k, index_k_scale, weights, cu_seqlen_ks, cu_seqlen_ke) + + def tl_quant_forward(self, index_q: torch.Tensor, index_k: 
torch.Tensor, + index_k_scale: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor) -> torch.Tensor: return self.kernel(index_q, index_k, index_k_scale, weights, cu_seqlen_ks, cu_seqlen_ke) + def forward(self, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + index_k_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + if index_k_scale is None: + return self.torch_quant_forward(index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke) + return self.tl_quant_forward(index_q, index_k, index_k_scale, weights, cu_seqlen_ks, + cu_seqlen_ke) + def per_custom_dims_cast_to_fp8(self, x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: - excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) - x_amax = x.to(torch.float32).abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 + x_absmax = x.to(torch.float32).abs().amax(dim=-1, keepdim=True).clamp(1e-4) + sf = x_absmax / 448.0 if use_ue8m0: assert sf.view(-1).amax().item() > 0 - sf = torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + sf = torch.pow(2.0, torch.ceil(torch.log2(sf))) x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - return x_scaled, sf.squeeze() + return x_scaled, sf.squeeze(-1) diff --git a/tileops/ops/fp8_quant.py b/tileops/ops/fp8_quant.py index 1785f3b9..39cce286 100644 --- a/tileops/ops/fp8_quant.py +++ b/tileops/ops/fp8_quant.py @@ -13,17 +13,21 @@ class Fp8QuantOp(Op): def __init__(self, + batch, seq_len_kv, + kv_group, index_dim, in_dtype: torch.dtype, kernel_map: Optional[Dict[str, Kernel]] = None, tune: bool = False): + self.batch = batch self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.index_dim = index_dim self.in_dtype = in_dtype self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["fp8_quant_kernel"]( - self.seq_len_kv, self.index_dim, self.in_dtype,
tune=tune) + self.batch, self.seq_len_kv, self.kv_group, self.index_dim, self.in_dtype, tune=tune) @property def default_kernel_map(self) -> Dict[str, Kernel]: diff --git a/tileops/ops/topk_selector.py b/tileops/ops/topk_selector.py index 0a303f6e..74507e5d 100644 --- a/tileops/ops/topk_selector.py +++ b/tileops/ops/topk_selector.py @@ -15,6 +15,8 @@ class TopkSelectorOp(Op): def __init__(self, batch: int, seq_len: int, + seq_len_kv: int, + kv_group: int, topk: int, in_dtype: torch.dtype, out_dtype: torch.dtype, @@ -22,13 +24,22 @@ def __init__(self, tune: bool = False) -> None: self.batch = batch self.seq_len = seq_len + self.seq_len_kv = seq_len_kv + self.kv_group = kv_group self.topk = topk self.in_dtype = in_dtype self.out_dtype = out_dtype self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["topk_selector_kernel"]( - self.batch, self.seq_len, self.topk, self.in_dtype, self.out_dtype, tune=tune) + self.batch, + self.seq_len, + self.seq_len_kv, + self.kv_group, + self.topk, + self.in_dtype, + self.out_dtype, + tune=tune) @property def default_kernel_map(self) -> Dict[str, Kernel]: