Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a50ddda
add fused func
stelladuyx Feb 6, 2026
90b60e3
add dsa func test and benchmark
stelladuyx Feb 8, 2026
70a1848
Merge pull request #7 from yuxian05/dsa_layer
stelladuyx Feb 9, 2026
1add812
add ref_program
stelladuyx Feb 9, 2026
823a0ab
Merge branch 'tile-ai:main' into dsa_layer
stelladuyx Feb 9, 2026
815a532
add test
stelladuyx Feb 10, 2026
dd2a463
Merge branch 'dsa_layer' of github.com:stelladuyx/TileOPs into dsa_layer
stelladuyx Feb 10, 2026
25095c8
modified quant
stelladuyx Feb 11, 2026
72087e1
fix quant func & layer
stelladuyx Feb 11, 2026
4c102ea
modify indexer
stelladuyx Feb 11, 2026
44280a5
fix dsa fused test
stelladuyx Feb 11, 2026
b7c26ad
fix topk input
stelladuyx Feb 11, 2026
82be661
fix dimension of topk selector
stelladuyx Feb 26, 2026
bb9582c
fix indexer in dsa fused
stelladuyx Feb 26, 2026
b3bd7c1
fix kv_group in indexer
stelladuyx Feb 27, 2026
33227e6
fix indexer
stelladuyx Feb 28, 2026
ef9206a
delete print
stelladuyx Feb 28, 2026
8312e1d
fix kv_group dimension in quant
stelladuyx Feb 28, 2026
da6ef4f
add kv_group dimension into topk_selector
stelladuyx Feb 28, 2026
512f656
fix indexer and topk
stelladuyx Mar 2, 2026
55675bf
clean dsa func
stelladuyx Mar 2, 2026
53bdc97
merge upstream main into branch
stelladuyx Mar 2, 2026
7417095
merge upstream main into branch
stelladuyx Mar 2, 2026
5371186
merge upstream main into branch
stelladuyx Mar 2, 2026
8f89d1b
Merge branch 'main' into kernel_fix
stelladuyx Mar 2, 2026
f5ab71c
Apply suggestion from @gemini-code-assist[bot]
stelladuyx Mar 2, 2026
e860fa7
[chore] remove redundant code
stelladuyx Mar 2, 2026
460a247
[bug] fix topk
stelladuyx Mar 3, 2026
afa6f5a
Merge branch 'main' into kernel_fix
stelladuyx Mar 3, 2026
3f6359d
Update tileops/kernels/deepseek_mla/topk_selector.py
stelladuyx Mar 3, 2026
ba79608
X
stelladuyx Mar 4, 2026
f3ac19f
[fix] Fixed a compile-time bug
stelladuyx Mar 5, 2026
c7217ad
[fix] fix indexer
stelladuyx Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmarks/ops/bench_fp8_lighting_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def calculate_memory(self) -> Optional[float]:
accum_dtype = torch.float32
index_dtype = torch.int32

index_q_memory = t.seq_len * t.heads * t.index_dim * dtype.itemsize
index_k_memory = t.seq_len_kv * t.index_dim * dtype.itemsize
index_k_scale_memory = t.seq_len_kv * accum_dtype.itemsize
logits_memory = t.seq_len * t.seq_len_kv * accum_dtype.itemsize
index_q_memory = t.batch * t.seq_len * t.heads * t.index_dim * dtype.itemsize
index_k_memory = t.batch * t.seq_len_kv * t.index_dim * t.kv_group * dtype.itemsize
index_k_scale_memory = t.batch * t.seq_len_kv * t.kv_group * accum_dtype.itemsize
logits_memory = t.batch * t.seq_len * t.seq_len_kv * t.kv_group * accum_dtype.itemsize
weights_memory = t.seq_len * t.heads * accum_dtype.itemsize
cu_seqlens_ks_memory = t.seq_len * index_dtype.itemsize
cu_seqlens_ke_memory = t.seq_len * index_dtype.itemsize
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/ops/bench_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ class Fp8QuantBenchmark(BenchmarkBase):

def calculate_flops(self) -> Optional[float]:
    """Estimate FLOPs for FP8 quantization of a (batch, seq_len_kv, kv_group, index_dim) tensor.

    Counts 2 ops/element for the abs+amax pass, 1 op/row for the scale
    division, and 4 ops/element for the divide/clamp/cast pass.
    """
    t = self.test
    rows = t.batch * t.seq_len_kv * t.kv_group
    # NOTE(review): the last term previously omitted t.batch, undercounting the
    # clamp/scale pass whenever batch > 1; all terms now scale with batch.
    return 2 * rows * t.index_dim + rows + 4 * rows * t.index_dim

def calculate_memory(self) -> Optional[float]:
t = self.test
return t.seq_len_kv * t.index_dim * t.in_dtype.itemsize
return t.batch * t.seq_len_kv * t.kv_group * t.index_dim * t.in_dtype.itemsize


@Fp8QuantFixture
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/ops/bench_topk_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def calculate_flops(self) -> Optional[float]:

def calculate_memory(self) -> Optional[float]:
t = self.test
index_score_memory = t.batch * t.seq_len * t.in_dtype.itemsize
index_memory = t.batch * t.topk * t.out_dtype.itemsize
starts_memory = t.batch * t.out_dtype.itemsize
ends_memory = t.batch * t.out_dtype.itemsize
index_score_memory = (t.batch * t.seq_len * t.seq_len_kv * t.kv_group * t.in_dtype.itemsize)
index_memory = t.batch * t.seq_len * t.topk * t.kv_group * t.out_dtype.itemsize
starts_memory = t.batch * t.seq_len * t.out_dtype.itemsize
ends_memory = t.batch * t.seq_len * t.out_dtype.itemsize
return index_score_memory + index_memory + starts_memory + ends_memory


Expand Down
86 changes: 67 additions & 19 deletions tests/ops/test_fp8_lighting_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,29 @@

class Fp8LightingIndexerFixture(FixtureBase):
    """Parameter fixture for the FP8 lighting indexer test.

    Post batch/kv_group refactor: each case is
    (batch, seq_len, heads, index_dim, seq_len_kv, kv_group, clean_logits, config, tune).
    The stale pre-refactor PARAMS header/tuple (diff residue) is removed.
    """
    PARAMS = [
        ("batch, seq_len, heads, index_dim, seq_len_kv, kv_group, clean_logits, config, tune", [
            (1, 4096, 32, 64, 8192, 1, True, None, False),
        ]),
    ]


class Fp8LightingIndexerTest(TestBase):

def __init__(self, seq_len: int, heads: int, index_dim: int, seq_len_kv: int,
clean_logits: bool = True, config: Optional[dict] = None):
def __init__(self,
batch: int,
seq_len: int,
heads: int,
index_dim: int,
seq_len_kv: int,
kv_group: int,
clean_logits: bool = True,
config: Optional[dict] = None):
self.batch = batch
self.seq_len = seq_len
self.heads = heads
self.index_dim = index_dim
self.seq_len_kv = seq_len_kv
self.kv_group = kv_group
self.clean_logits = clean_logits
self.config = config
self.dtype = torch.float8_e4m3fn
Expand Down Expand Up @@ -147,31 +156,59 @@ def generate_random_cu_seqlens(self,
assert len(ks) == len(ke) == self.seq_len
return ks, ke

def gen_inputs(
        self,
        params=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate random CUDA inputs for the indexer kernel.

    Returns:
        IndexQ: (batch, seq_len, heads, index_dim) bfloat16 queries.
        IndexK: (batch, seq_len_kv, kv_group, index_dim) bfloat16 keys.
        Weights: (seq_len, heads) per-head weights in the accumulation dtype.
        CuSeqLenKS / CuSeqLenKE: per-query-position KV window bounds.

    `params` is accepted for interface compatibility and unused.
    Stale pre-refactor lines (2-D tensors, fixed [0, seq_len_kv - 1) bounds)
    were diff residue; the fixed bounds were overwritten by the
    generate_random_cu_seqlens call anyway, so dropping them preserves behavior.
    """
    IndexQ = torch.randn(
        self.batch,
        self.seq_len,
        self.heads,
        self.index_dim,
        device='cuda',
        dtype=torch.bfloat16)
    IndexK = torch.randn(
        self.batch,
        self.seq_len_kv,
        self.kv_group,
        self.index_dim,
        device='cuda',
        dtype=torch.bfloat16)
    Weights = torch.randn(self.seq_len, self.heads, device='cuda', dtype=self.accum_dtype)
    CuSeqLenKS, CuSeqLenKE = self.generate_random_cu_seqlens(
        cp_size=4, cp_rank=3, kv_stride=1, average_q_len=2048)
    return IndexQ, IndexK, Weights, CuSeqLenKS, CuSeqLenKE

def ref_program(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
                cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor) -> Tuple[torch.Tensor]:
    """Eager reference implementation of the FP8 lighting indexer logits.

    Args:
        q: (batch, seq_len, heads, index_dim) queries.
        kv: (batch, seq_len_kv, kv_group, index_dim) keys.
        weights: (seq_len, heads) per-head weights.
        cu_seqlen_ks: (seq_len,) inclusive start of each query's valid KV window.
        cu_seqlen_ke: (seq_len,) exclusive end of each query's valid KV window.

    Returns:
        One-tuple holding logits of shape (batch, seq_len, seq_len_kv, kv_group)
        with positions outside [ks, ke) filled with -inf.
    """
    q = q.float()
    k = kv.float()
    batch, seq_len, heads, index_dim = q.shape
    seq_len_kv = self.seq_len_kv
    kv_group = self.kv_group
    heads_per_group = heads // kv_group

    # Split the head axis into (kv_group, heads_per_group) so each group
    # attends to its own KV slice.
    k = k.view(batch, seq_len_kv, kv_group, index_dim)
    q = q.view(batch, seq_len, kv_group, heads_per_group, index_dim)

    # Per-query-position validity mask over KV positions. Derive the device
    # from the inputs rather than hard-coding "cuda" so the reference also
    # runs on CPU (identical behavior for CUDA inputs).
    positions = torch.arange(0, seq_len_kv, device=q.device)
    mask = (positions[None, :] >= cu_seqlen_ks[:, None]) & \
           (positions[None, :] < cu_seqlen_ke[:, None])

    # score: (batch, kv_group, heads_per_group, seq_len, seq_len_kv)
    score = torch.einsum("bsghd,bngd->bghsn", q, k)
    w = weights.view(seq_len, kv_group, heads_per_group)
    w = w.permute(1, 2, 0).unsqueeze(0).unsqueeze(-1)
    # ReLU-gate, weight, then reduce over heads within each group.
    logits = (score.relu() * w).sum(dim=2)
    # -> (batch, seq_len, seq_len_kv, kv_group) to match the kernel layout.
    logits = logits.permute(0, 2, 3, 1)
    logits = logits.masked_fill(~mask.unsqueeze(0).unsqueeze(-1), float("-inf"))
    return (logits,)

@staticmethod
Expand All @@ -181,7 +218,9 @@ def _compute_correlation(a: torch.Tensor, b: torch.Tensor) -> float:
return 2 * (a * b).sum() / norm_sum

@staticmethod
def _validate_tensor_match(output: torch.Tensor, output_ref: torch.Tensor, tolerance: float = 1e-3) -> None:
def _validate_tensor_match(output: torch.Tensor,
output_ref: torch.Tensor,
tolerance: float = 1e-3) -> None:
if isinstance(output, tuple):
output = output[0]
if isinstance(output_ref, tuple):
Expand All @@ -206,11 +245,20 @@ def _validate_tensor_match(output: torch.Tensor, output_ref: torch.Tensor, toler


@Fp8LightingIndexerFixture
def test_indexer(batch: int, seq_len: int, heads: int, index_dim: int, seq_len_kv: int,
                 kv_group: int, clean_logits: bool, config: Optional[dict], tune: bool) -> None:
    """End-to-end check of Fp8LightingIndexerOp against the eager reference.

    Stale pre-refactor signature/constructor lines (no batch/kv_group) were
    diff residue and are removed.
    """
    test = Fp8LightingIndexerTest(batch, seq_len, heads, index_dim, seq_len_kv, kv_group,
                                  clean_logits, config)
    op = Fp8LightingIndexerOp(
        batch=batch,
        seq_len=seq_len,
        heads=heads,
        index_dim=index_dim,
        seq_len_kv=seq_len_kv,
        kv_group=kv_group,
        clean_logits=clean_logits,
        config=config,
        tune=tune)
    test.check(op, *test.gen_inputs(), compare=Fp8LightingIndexerTest._validate_tensor_match)


Expand Down
42 changes: 29 additions & 13 deletions tests/ops/test_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

class Fp8QuantFixture(FixtureBase):
    """Parameter fixture for the FP8 quant op.

    Each case is (batch, seq_len_kv, kv_group, index_dim, in_dtype, tune);
    the last case exercises kv_group > 1. The stale pre-refactor PARAMS hunk
    (diff residue, missing its closing bracket) is removed.
    """
    PARAMS = [
        ("batch, seq_len_kv, kv_group, index_dim, in_dtype, tune", [
            (1, 8192, 1, 64, torch.float16, False),
            (1, 8192, 1, 64, torch.bfloat16, False),
            (1, 4096, 1, 128, torch.float32, False),
            (1, 16384, 1, 32, torch.float32, False),
            (1, 1024, 4, 64, torch.float16, False),
        ]),
    ]

Expand All @@ -30,29 +31,44 @@ def _cosine_compare(output: torch.Tensor, output_ref: torch.Tensor) -> None:

class Fp8QuantTest(TestBase):

def __init__(self, batch: int, seq_len_kv: int, kv_group: int, index_dim: int,
             in_dtype: torch.dtype):
    """Store the quantization problem shape and input dtype.

    The stale pre-refactor one-line signature (diff residue) is removed.
    """
    self.batch = batch
    self.seq_len_kv = seq_len_kv
    self.kv_group = kv_group
    self.index_dim = index_dim
    self.in_dtype = in_dtype

def gen_inputs(self) -> Tuple[torch.Tensor]:
    """Random (batch, seq_len_kv, kv_group, index_dim) CUDA input for quantization.

    The stale 2-D randn argument line (diff residue, which as scraped put a
    positional arg after keywords) is removed.
    """
    input_tensor = torch.randn(
        self.batch,
        self.seq_len_kv,
        self.kv_group,
        self.index_dim,
        dtype=self.in_dtype,
        device="cuda")
    return (input_tensor,)

def ref_program(self, input_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
amax_value = torch.abs(input_tensor).amax(dim=1, keepdim=True).clamp(min=1e-4)
# input_tensor: (batch, seq_len_kv, kv_group, index_dim)
amax_value = torch.abs(input_tensor).amax(dim=-1, keepdim=True).clamp(min=1e-4)
scale_tensor = amax_value / 448.0
output_tensor = torch.clamp(input_tensor / scale_tensor, min=-448.0, max=448.0)
output_tensor = output_tensor.to(torch.float8_e4m3fn)
return scale_tensor.squeeze(dim=1), output_tensor
return scale_tensor.squeeze(dim=-1), output_tensor


@Fp8QuantFixture
def test_fp8_quant_op(batch: int, seq_len_kv: int, kv_group: int, index_dim: int,
                      in_dtype: torch.dtype, tune: bool) -> None:
    """Check Fp8QuantOp output against the eager reference via cosine comparison.

    Stale pre-refactor signature/body lines (no batch/kv_group) were diff
    residue and are removed.
    """
    test = Fp8QuantTest(batch, seq_len_kv, kv_group, index_dim, in_dtype)
    op = Fp8QuantOp(
        batch=batch,
        seq_len_kv=seq_len_kv,
        kv_group=kv_group,
        index_dim=index_dim,
        in_dtype=in_dtype,
        tune=tune)
    test.check(op, *test.gen_inputs(), compare=_cosine_compare)


Expand Down
56 changes: 40 additions & 16 deletions tests/ops/test_topk_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

class TopkSelectorFixture(FixtureBase):
    """Parameter fixture for the top-k selector over 4-D scores.

    Each case is (batch, seq_len, seq_len_kv, kv_group, topk, in_dtype_str,
    out_dtype_str, tune). The stale pre-refactor tuples (diff residue) are
    removed.
    """
    PARAMS = [
        ("batch, seq_len, seq_len_kv, kv_group, topk, in_dtype_str, out_dtype_str, tune", [
            (4, 256, 1024, 1, 32, "float32", "int32", False),
            (8, 512, 2048, 1, 64, "float32", "int32", False),
            (1, 32 * 1024, 64 * 1024, 1, 1024, "float32", "int32", False),
            (1, 32 * 1024, 64 * 2048, 1, 2048, "float32", "int32", False),
        ]),
    ]

Expand All @@ -33,33 +33,57 @@ def _set_compare(output: torch.Tensor, output_ref: torch.Tensor) -> None:

class TopkSelectorTest(TestBase):

def __init__(self, batch: int, seq_len: int, seq_len_kv: int, kv_group: int, topk: int,
             in_dtype: torch.dtype, out_dtype: torch.dtype):
    """Store the top-k selector problem shape and dtypes.

    The stale pre-refactor signature (no seq_len_kv/kv_group, diff residue)
    is removed.
    """
    self.batch = batch
    self.seq_len = seq_len
    self.seq_len_kv = seq_len_kv
    self.kv_group = kv_group
    self.topk = topk
    self.in_dtype = in_dtype
    self.out_dtype = out_dtype

def gen_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Random CUDA scores plus full-range [0, seq_len_kv) window bounds.

    Returns:
        index_score: (batch, seq_len, seq_len_kv, kv_group) scores.
        starts / ends: (batch, seq_len) per-position KV window bounds.

    Stale pre-refactor lines (2-D scores, 1-D bounds scaled by seq_len) were
    diff residue and are removed.
    """
    index_score = torch.randn(
        self.batch,
        self.seq_len,
        self.seq_len_kv,
        self.kv_group,
        dtype=self.in_dtype,
        device="cuda")
    starts = torch.zeros(self.batch, self.seq_len, dtype=self.out_dtype, device="cuda")
    ends = torch.ones(self.batch, self.seq_len, dtype=self.out_dtype,
                      device="cuda") * self.seq_len_kv
    return index_score, starts, ends

def ref_program(self, index_score: torch.Tensor, starts: torch.Tensor,
                ends: torch.Tensor) -> torch.Tensor:
    """Reference top-k: select the top-k KV positions per (batch, seq, group).

    Args:
        index_score: (batch, seq_len, seq_len_kv, kv_group); top-k is taken
            over seq_len_kv (dim=2).
        starts / ends: unused here — gen_inputs produces full-range windows,
            so no masking is needed in the reference.

    Returns:
        Indices permuted to (batch, seq_len, kv_group, topk) to match the
        kernel output layout.

    Diff residue removed: the old dim=-1 topk and an early `return` that made
    the new permuted return unreachable.
    """
    indexes_ref = torch.topk(index_score, self.topk, dim=2)[1]
    return indexes_ref.permute(0, 1, 3, 2)


@TopkSelectorFixture
def test_topk_selector_op(batch: int,
                          seq_len: int,
                          seq_len_kv: int,
                          kv_group: int,
                          topk: int,
                          in_dtype_str: str,
                          out_dtype_str: str,
                          tune: bool) -> None:
    """Check TopkSelectorOp index sets against the eager torch.topk reference.

    Stale pre-refactor signature and positional constructor calls (diff
    residue) are removed.
    """
    in_dtype = str2dtype[in_dtype_str]
    out_dtype = str2dtype[out_dtype_str]
    test = TopkSelectorTest(batch, seq_len, seq_len_kv, kv_group, topk, in_dtype, out_dtype)
    op = TopkSelectorOp(batch=batch,
                        seq_len=seq_len,
                        seq_len_kv=seq_len_kv,
                        kv_group=kv_group,
                        topk=topk,
                        in_dtype=in_dtype,
                        out_dtype=out_dtype,
                        tune=tune)
    test.check(op, *test.gen_inputs(), compare=_set_compare)


Expand Down
Loading
Loading