
[feature] Modify dsa kernels for dsa function #273

Open
stelladuyx wants to merge 33 commits into tile-ai:main from stelladuyx:kernel_fix

Conversation

@stelladuyx
Collaborator

@stelladuyx stelladuyx commented Mar 2, 2026

Description

Type of Change

  • Bug fix
  • New operator implementation
  • Performance improvement
  • Documentation update
  • Infrastructure/CI

Checklist

  • I have run pre-commit run --all-files and fixed all linting issues.
  • I have verified that my changes pass local unit tests.
  • (For new ops) I have added the corresponding Benchmark class in benchmarks/.
  • (For new ops) I have reported benchmark results in the tracking issue.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request significantly enhances the DeepSeek MLA kernels by integrating batch and kv_group dimensions into the fp8_lighting_indexer, fp8_quant, and topk_selector components. This change allows these kernels to process data in batches and support grouped-query attention mechanisms, broadening their applicability and improving efficiency for various deep learning models. The modifications span the benchmarks, functions, layers, ops, and underlying kernel implementations to ensure comprehensive support and correct behavior with the new dimensionality.

Highlights

  • Batch and KV Group Support: Introduced batch and kv_group parameters across the fp8_lighting_indexer, fp8_quant, and topk_selector kernels, functions, ops, layers, and benchmarks to enable more flexible batched and grouped-query attention operations.
  • Kernel Logic Refinement: Refactored the core kernel logic in fp8_lighting_indexer and fp8_quant to correctly handle the new batch and kv_group dimensions, including updates to tensor shapes, memory allocation, and computation patterns.
  • Benchmark and Test Updates: Modified benchmarks and unit tests for fp8_lighting_indexer, fp8_quant, and topk_selector to incorporate the new batch and kv_group parameters, ensuring correctness and performance evaluation for the extended functionality.
  • API Simplification: Removed the config parameter from Fp8LightingIndexerFunc and Fp8LightingIndexerDecodeLayer constructors, streamlining their API.
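To make the new dimensions concrete, here is a small NumPy sketch of what a grouped indexer reference could look like. The shapes, axis order, and the plain sum-over-heads reduction are illustrative assumptions only; the actual formulas live in benchmarks/deepseek_mla/fp8_lighting_indexer.py:

```python
import numpy as np

# Illustrative shapes only; the real benchmark defines its own.
batch, seq_len, seq_len_kv, kv_group, heads, dim = 2, 8, 16, 2, 4, 32

rng = np.random.default_rng(0)
# Queries carry a per-group head axis; keys are shared within each kv_group.
index_q = rng.standard_normal((batch, seq_len, kv_group, heads, dim)).astype(np.float32)
index_k = rng.standard_normal((batch, seq_len_kv, kv_group, dim)).astype(np.float32)

# Grouped einsum: queries in group g only score against keys of group g.
logits = np.einsum("bsghd,btgd->bsght", index_q, index_k)
# Reduce over the head axis to one index score per (query, key) pair.
index_score = logits.sum(axis=3)  # shape (batch, seq_len, kv_group, seq_len_kv)
```

The key point is that kv_group appears on both operands of the einsum, so scores never cross group boundaries, while batch and seq_len batch the computation.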


Changelog
  • benchmarks/deepseek_mla/fp8_lighting_indexer.py
    • Added batch and kv_group parameters to the Fp8LightingIndexerBenchmark constructor.
    • Updated total_flops and total_memory calculations to account for batch and kv_group dimensions.
    • Modified gen_inputs to generate tensors with batch and kv_group dimensions.
    • Refactored ref_program to correctly handle the new tensor shapes and perform grouped einsum operations.
    • Added a shape assertion in validate_tensor_match for robust tensor comparison.
  • benchmarks/deepseek_mla/fp8_quant.py
    • Added batch and kv_group parameters to the Fp8QuantBenchmark constructor.
    • Updated total_flops and total_memory calculations to include batch and kv_group.
    • Modified gen_inputs to create input tensors with batch and kv_group dimensions.
    • Refactored ref_program to align with the new tensor shapes for quantization.
  • benchmarks/deepseek_mla/topk_selector.py
    • Added seq_len_kv and kv_group parameters to the TopkSelectorBenchmark constructor.
    • Updated total_memory calculation to reflect the new seq_len_kv and kv_group dimensions.
    • Modified gen_inputs to generate index_score with seq_len_kv and kv_group dimensions and adjusted ends tensor.
    • Refactored ref_program to perform top-k selection over the correct dimension and permute the output.
  • tests/functions/test_fp8_lighting_indexer_func.py
    • Removed the config parameter from test_fp8_lighting_indexer function signature and calls.
  • tests/functions/test_fp8_quant.py
    • Updated pytest.mark.parametrize decorator to include batch and kv_group parameters for test cases.
    • Modified test_fp8_quant function signature and calls to pass batch and kv_group.
  • tests/functions/test_topk_selector_func.py
    • Added seq_len_kv and kv_group parameters to test_topk_selector function signature and calls.
    • Updated argparse arguments to include seq_len_kv and kv_group with new default values for seq_len and seq_len_kv.
  • tests/ops/test_fp8_lighting_indexer.py
    • Removed Optional import.
    • Added batch and kv_group parameters to test_indexer function signature and calls.
    • Removed the config argument from Fp8LightingIndexerOp and Fp8LightingIndexerBenchmark instantiations.
    • Updated argparse arguments to include batch and kv_group and adjusted default seq_len and seq_len_kv.
  • tests/ops/test_fp8_quant.py
    • Updated pytest.mark.parametrize decorator to include batch and kv_group parameters and added a new test case.
    • Modified test_fp8_quant_op function signature and calls to pass batch and kv_group.
    • Added benchmark.profile(op, inputs) call.
  • tests/ops/test_topk_selector.py
    • Updated pytest.mark.parametrize decorator to include seq_len_kv and kv_group parameters and new test cases.
    • Modified test_topk_selector_op function signature and calls to pass seq_len_kv and kv_group.
  • top/functions/fp8_lighting_indexer.py
    • Removed Optional import.
    • Removed the config parameter from Fp8LightingIndexerFunc constructor and its instantiation of Fp8LightingIndexerOp.
  • top/functions/fp8_quant.py
    • Added batch and kv_group parameters to Fp8QuantFunc constructor and its instantiation of Fp8QuantOp.
  • top/functions/topk_selector.py
    • Added seq_len_kv and kv_group parameters to TopkSelectorFunc constructor and its instantiation of TopkSelectorOp.
  • top/kernels/deepseek_mla/fp8_lighting_indexer.py
    • Added batch and kv_group parameters to _fp8_lighting_indexer_kernel and clean_logits_ functions.
    • Updated tensor shapes for index_q, index_k, index_k_scale, and logits to include batch and kv_group dimensions.
    • Modified kernel loops and shared memory allocations to incorporate batch and kv_group.
    • Refactored einsum and related computations within the kernel to correctly handle grouped attention.
    • Updated fp8_lighting_indexer_wrapped_kernel and Fp8LightingIndexerKernel to pass the new parameters.
    • Changed the default num_stages in Fp8LightingIndexerKernel to 2.
    • Adjusted supply_prog to reflect the new tensor shapes for inputs.
  • top/kernels/deepseek_mla/fp8_quant.py
    • Added batch and kv_group parameters to _fp8_quant_kernel.
    • Updated tensor shapes in _fp8_quant_fwd_main to include batch and kv_group dimensions.
    • Modified kernel loops to iterate over batch and kv_group.
    • Updated _fp8_quant_wrapped_kernel and Fp8QuantKernel to pass the new parameters.
    • Expanded autotune_configs for block_m to include 64.
  • top/kernels/deepseek_mla/topk_selector.py
    • Modified convert_to_uint32 to explicitly cast to float32 before reinterpret.
    • Added seq_len_kv and kv_group parameters to _topk_selector_kernel.
    • Updated _topk_selector_kernel_main to handle batch, seq_len_kv, and kv_group dimensions in kernel loops and tensor accesses.
    • Modified _topk_selector_wrapped_kernel and TopkSelectorKernel to pass new parameters and added block_m to config.
    • Updated the return shape of the _topk_selector_wrapped_kernel.register_fake function.
  • top/layers/deepseek_mla.py
    • Removed the config parameter from Fp8LightingIndexerDecodeLayer constructor.
    • Added batch, seq_len_kv, and kv_group parameters to TopkSelectorLayer and Fp8QuantLayer constructors and their respective function instantiations.
  • top/ops/fp8_lighting_indexer.py
    • Added batch and kv_group parameters to Fp8LightingIndexerOp constructor and its instantiation of Fp8LightingIndexerKernel.
    • Removed the config parameter from the constructor.
    • Introduced torch_quant_forward and tl_quant_forward methods to handle different quantization input formats.
    • Updated the main forward method to dispatch calls based on the number of arguments.
    • Refined per_custom_dims_cast_to_fp8 logic to use amax(dim=-1) and squeeze(-1) for scale factor.
  • top/ops/fp8_quant.py
    • Added batch and kv_group parameters to Fp8QuantOp constructor and its instantiation of Fp8QuantKernel.
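As a rough companion to the topk_selector changes above, a top-k selection over the key dimension can be sketched in NumPy as follows. The shapes, the dimension order, and the omission of the ends masking are all simplifying assumptions, not the benchmark's exact ref_program:

```python
import numpy as np

# Toy shapes; the real ones are defined in benchmarks/deepseek_mla/topk_selector.py.
batch, kv_group, seq_len_kv, topk = 2, 2, 32, 4

rng = np.random.default_rng(1)
index_score = rng.standard_normal((batch, kv_group, seq_len_kv)).astype(np.float32)

# argpartition finds the top-k candidate indices in O(n) per row;
# a small sort then orders those k indices by descending score.
part = np.argpartition(-index_score, topk - 1, axis=-1)[..., :topk]
order = np.argsort(np.take_along_axis(-index_score, part, axis=-1), axis=-1)
topk_indices = np.take_along_axis(part, order, axis=-1)  # (batch, kv_group, topk)
```

A kernel implementation would typically avoid the full sort entirely, but a reference like this is convenient for validating the selected indices.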
Activity
  • The pull request was created by stelladuyx, introducing changes to DeepSeek MLA kernels.
  • The author provided a general description and a checklist for the pull request.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP 8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant modifications across multiple components to add support for batching and grouped-query attention (via kv_group) to several DSA (DeepSeek Sparse Attention) kernels. The changes are extensive, affecting kernels, functions, layers, benchmarks, and tests. While the overall implementation direction is correct, I have identified several issues, including bugs in benchmark metric calculations, autotuning configurations, and tensor initializations within the kernels. Additionally, there are some code quality concerns, such as duplicated test code and debugging artifacts left in the kernel code. Addressing these points will improve the correctness and maintainability of the new features.

Comment on lines 66 to +74

```diff
 def per_custom_dims_cast_to_fp8(self, x: torch.Tensor, dims: Tuple[int],
                                 use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
-    excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
-    x_amax = x.to(torch.float32).abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
-    sf = x_amax / 448.0
+    x_absmax = x.to(torch.float32).abs().amax(dim=-1, keepdim=True).clamp(1e-4)
+    sf = x_absmax / 448.0
     if use_ue8m0:
         assert sf.view(-1).amax().item() > 0
-        sf = torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+        sf = torch.pow(2.0, torch.ceil(torch.log2(x_absmax)))
     x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
-    return x_scaled, sf.squeeze()
+    return x_scaled, sf.squeeze(-1)
```

Severity: medium

The implementation of per_custom_dims_cast_to_fp8 has been changed to compute amax over the last dimension (dim=-1) instead of a more generic set of excluded dimensions. This aligns it with per-token quantization, which is what's needed here. The logic for calculating the scale factor and applying it seems correct for this purpose.
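The per-token behaviour described here can be checked independently of the kernel. Below is a NumPy approximation of the new scale computation; NumPy has no float8 dtype, so the final cast to float8_e4m3fn is omitted, and 448.0 is the largest finite magnitude of e4m3fn:

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude of float8_e4m3fn

def per_token_scale(x, use_ue8m0=False):
    """NumPy sketch of the per-token scaling; the real op casts to fp8."""
    x = np.asarray(x, dtype=np.float32)
    # amax over the last (feature) dim, clamped to avoid dividing by zero
    x_absmax = np.clip(np.abs(x).max(axis=-1, keepdims=True), 1e-4, None)
    sf = x_absmax / E4M3_MAX
    if use_ue8m0:
        # UE8M0 scales are powers of two: round the per-token amax up
        # to the next power of two, mirroring the diff above.
        sf = np.exp2(np.ceil(np.log2(x_absmax)))
    x_scaled = x * (1.0 / sf)  # the real op casts this to float8_e4m3fn
    return x_scaled, np.squeeze(sf, axis=-1)

x = np.array([[1.0, -4.0], [0.5, 0.25]], dtype=np.float32)
x_scaled, sf = per_token_scale(x)
```

With the non-ue8m0 path, each row is scaled so its amax lands exactly on the fp8 range boundary (the first row becomes [112.0, -448.0]); with use_ue8m0=True, the scale snaps to a power of two instead.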

@stelladuyx stelladuyx requested a review from a team March 2, 2026 07:37
@lcy-seso lcy-seso marked this pull request as draft March 2, 2026 09:52
@stelladuyx stelladuyx self-assigned this Mar 3, 2026
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@stelladuyx stelladuyx added feature New feature or new operator human-led Human-led task. AI assists but does not own. labels Mar 3, 2026
@stelladuyx stelladuyx marked this pull request as ready for review March 5, 2026 09:24