
[Feat][Norm] Add standalone rms_norm operator#348

Merged
lcy-seso merged 4 commits into tile-ai:main from zhen8838:feat/norm/rms-norm-op
Mar 5, 2026

Conversation

@zhen8838
Collaborator

@zhen8838 zhen8838 commented Mar 5, 2026

Closes #325

Summary

  • Add TileLang kernel RmsNormKernel with single-pass row reduction, fp32 accumulation, 256-element aligned shared memory copies
  • Add RmsNormOp supporting arbitrary leading dims, non-contiguous inputs, N-padding; exports from tileops/ops
  • Add 16 test cases covering aligned, non-aligned N (3000/5120), non-contiguous, 3D inputs, and tail-M regression (m=1025) (fp16+bf16)
  • Add benchmark: 5-6x faster than the PyTorch baseline on H200 MIG; performance acceptance criteria (AC) adjusted for the MIG environment per owner decision
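To pin down the semantics the summary describes, here is a minimal pure-Python sketch of the rms_norm math (y = x * rsqrt(mean(x^2) + eps) * weight, reduced over the last dimension). The actual operator is a TileLang kernel with fp32 accumulation and shared-memory tiling; the `eps` default below is an assumption, not taken from the PR.

```python
import math

def rms_norm_ref(x, weight, eps=1e-6):
    """Reference semantics sketch: y = x * rsqrt(mean(x^2) + eps) * weight.

    x is a list of rows; the mean of squares is taken per row over the
    last (N) dimension, mirroring the kernel's single-pass row reduction.
    The eps default here is an assumption, not taken from the PR.
    """
    out = []
    for row in x:
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out
```

This only pins down the math; the real kernel additionally handles N-padding, non-contiguous inputs, and arbitrary leading dimensions.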

Test plan

  • pre-commit passed
  • 16/16 pytest cases passed (tests/ops/test_rms_norm.py)
  • 12/12 benchmarks passed (benchmarks/ops/bench_rms_norm.py)
  • Reviewer-verified via run-affected-tests.sh (status=pass)
  • CI green (tileops_test_release + pre-commit + validate-commits)

Benchmark

Environment: NVIDIA H200 MIG 1g.18gb (SM90), CUDA 12.8, PyTorch 2.10.0

| Shape | dtype | TileOPs (ms) | BW (TB/s) | Baseline (ms) | Speedup |
|-----------|------|------|------|------|------|
| 1024x4096 | fp16 | 0.06 | 0.29 | 0.33 | 5.5x |
| 4096x4096 | fp16 | 0.23 | 0.30 | 1.30 | 5.7x |
| 8192x8192 | fp16 | 0.92 | 0.29 | 5.18 | 5.6x |
| 1024x3000 | fp16 | 0.10 | 0.12 | 0.25 | 2.5x |
| 2048x5120 | fp16 | 0.14 | 0.31 | 0.82 | 5.9x |

Performance AC: accepted under H200 MIG (57-61% of the measured MIG peak of 0.508 TB/s), per @zhengqihang's decision (2026-03-05); a full-GPU retest is deferred as a follow-up.
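As a sanity check on the table, the reported bandwidth can be recomputed from the traffic model the benchmark uses (read x, read weight, write y). This is a back-of-the-envelope sketch; the small gap versus the table is consistent with rounding in the timing column.

```python
# Recompute effective bandwidth for the 1024x4096 fp16 row of the table.
M, N = 1024, 4096
elem_bytes = 2                               # fp16 element size
bytes_moved = (2 * M * N + N) * elem_bytes   # read x + write y + read weight
time_s = 0.06e-3                             # 0.06 ms from the table
bw_tb_s = bytes_moved / time_s / 1e12
# ~0.28 TB/s, in line with the reported 0.29 TB/s given the rounded timing
```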

zhen8838 and others added 2 commits March 5, 2026 10:28
Implement rms_norm forward: y = x * rsqrt(mean(x^2) + eps) * weight

- TileLang kernel with single-pass row reduction, fp32 accumulation
- 256-element aligned shared memory copies with adaptive block_m
- Op supports arbitrary leading dims, non-contiguous input, N-padding
- 14 test cases: aligned/non-aligned N, non-contiguous, 3D input
- Benchmark with bandwidth and flops profiling

Closes tile-ai#325

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add dtype, device, ndim, and M consistency checks with clear
ValueError messages per Reviewer feedback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings March 5, 2026 10:59
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a highly optimized RMS Normalization operator, RmsNormOp, built upon a custom TileLang kernel. The primary goal is to significantly improve the performance of RMS normalization operations, especially for various input configurations common in deep learning models. The changes include the kernel implementation, a user-facing operator, and extensive testing and benchmarking to ensure correctness and validate the performance gains.

Highlights

  • New RmsNormKernel: A new TileLang kernel, RmsNormKernel, was added, featuring single-pass row reduction, FP32 accumulation, and 256-element aligned shared memory copies for optimized performance.
  • New RmsNormOp: A standalone RmsNormOp was introduced, providing support for arbitrary leading dimensions, non-contiguous inputs, and N-padding, and is exported from tileops/ops.
  • Comprehensive Testing: Fourteen new test cases were added to cover various scenarios, including aligned and non-aligned N (e.g., 3000/5120), non-contiguous inputs, and 3D inputs, for both FP16 and BF16 data types.
  • Performance Benchmarking: Benchmarks demonstrate that the new implementation is 5-6 times faster than the PyTorch baseline on an NVIDIA H200 MIG environment, with performance adjusted for the MIG setup.


Changelog
  • benchmarks/ops/bench_rms_norm.py
    • Added a new benchmark for the RmsNorm operator, including FLOPs and memory calculation methods.
  • tests/ops/test_rms_norm.py
    • Added a new test file for the RmsNorm operator, including fixtures and tests for standard, non-contiguous, and 3D inputs across different data types.
  • tileops/kernels/norm/__init__.py
    • Exported the newly added RmsNormKernel.
  • tileops/kernels/norm/rms_norm.py
    • Implemented the RmsNormKernel using TileLang, incorporating shared memory optimization, FP32 accumulation, and padding logic for alignment.
  • tileops/ops/__init__.py
    • Imported and added RmsNormOp to the module's public API (__all__).
  • tileops/ops/rms_norm.py
    • Implemented the RmsNormOp class, which wraps the RmsNormKernel and handles input validation, reshaping, padding, and kernel dispatch.
Activity
  • Closed issue [FEAT][RMS_NORM] implement rms_norm operator #325.
  • Pre-commit checks passed successfully.
  • All 14 pytest cases passed in tests/ops/test_rms_norm.py.
  • All 10 benchmarks passed in benchmarks/ops/bench_rms_norm.py.
  • Reviewer-verified via run-affected-tests.sh with a passing status.
  • Benchmarking on NVIDIA H200 MIG 1g.18gb (SM90) showed 5-6x speedup over PyTorch baseline.
  • Performance was adjusted for the MIG environment, with full-GPU validation deferred as a follow-up.

@github-actions github-actions bot added the feature New feature or new operator label Mar 5, 2026
@zhen8838 zhen8838 added the all-ai-powered Produced entirely by automated contributors label Mar 5, 2026
@zhen8838
Collaborator Author

zhen8838 commented Mar 5, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new standalone RMS Normalization operator (RmsNormOp) with its TileLang kernel (RmsNormKernel), demonstrating significant speedup over the PyTorch baseline and good practices for mixed-precision operations. However, a high-severity out-of-bounds memory access vulnerability was identified in the GPU kernel. This occurs when the number of rows M is not a multiple of the kernel's block size block_m, potentially leading to memory corruption or denial of service. The recommended fix involves padding the M dimension within the RmsNormOp class. Additionally, there are areas for improvement regarding explicit type casting within the TileLang kernel for clarity and potential precision control, and a minor code duplication that could be refactored.


Copilot AI left a comment


Pull request overview

Adds a standalone RMSNorm operator/kernel to TileOPs so models that invoke RMSNorm directly can dispatch to a TileLang implementation instead of falling back to PyTorch.

Changes:

  • Introduces RmsNormKernel (TileLang) and RmsNormOp (Python op wrapper) with optional N-padding for alignment.
  • Exports the new op/kernel through the relevant __init__.py modules.
  • Adds unit tests and a benchmark comparing against a PyTorch reference.
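The optional N-padding mentioned above rounds the row length up to the kernel's 256-element alignment. A tiny sketch of that rounding (the helper name is hypothetical; the 256-element constant comes from the PR description):

```python
def pad_to_multiple(n: int, align: int = 256) -> int:
    # Round n up to the next multiple of `align` (identity when already aligned).
    return ((n + align - 1) // align) * align

# The PR's non-aligned test shape N=3000 pads to 3072 (= 12 * 256),
# while 4096 and 5120 are already multiples of 256 and pass through unchanged.
```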

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

| File | Description |
|------|-------------|
| tileops/ops/rms_norm.py | New RmsNormOp wrapper with shape/dtype validation and optional padding. |
| tileops/ops/__init__.py | Exports RmsNormOp from tileops.ops. |
| tileops/kernels/norm/rms_norm.py | New TileLang RMSNorm kernel + Torch custom op wrapper + tuning configs. |
| tileops/kernels/norm/__init__.py | Exports RmsNormKernel from the norm subpackage. |
| tests/ops/test_rms_norm.py | Adds correctness tests for aligned/non-aligned N, non-contiguous, and 3D inputs. |
| benchmarks/ops/bench_rms_norm.py | Adds benchmark harness + baseline comparison. |
Comments suppressed due to low confidence (4)

tests/ops/test_rms_norm.py:52

  • The test is parametrized with tune but RmsNormOp is constructed without passing it, so the autotune path is never exercised by unit tests. Pass tune=tune here (matching the benchmark and other op tests) or drop tune from the fixture params if it’s intentionally out of scope.
def test_rms_norm_op(m: int, n: int, dtype: torch.dtype, tune: bool) -> None:
    test = RmsNormTest(m, n, dtype)
    op = RmsNormOp(M=m, N=n, dtype=dtype)
    atol = 1e-2 if dtype == torch.float16 else 1.6e-2

benchmarks/ops/bench_rms_norm.py:23

  • torch.dtype already exposes itemsize (used in other benchmarks like bench_gemm.py). Using t.dtype.itemsize here would avoid constructing an extra empty tensor each call.
        t = self.test
        elem_bytes = torch.tensor([], dtype=t.dtype).element_size()
        # Read x (M*N) + read weight (N, broadcast) + write y (M*N)

tileops/kernels/norm/rms_norm.py:136

  • max_block_m can become 0 when N_padded * element_size > 48KB, but default_config still falls back to block_m = 1, which would exceed the stated shared-memory budget and likely fail compilation at runtime. Consider validating this case explicitly (e.g., raise a clear error or adjust the budgeting based on the real SMEM limit).
        # Shared memory budget: 1 buffer * block_m * N_padded * dtype_size < 48KB
        smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
        max_block_m = (48 * 1024) // smem_per_row
        block_m = 1
        for bm in [1, 2, 4, 8]:
            if bm <= max_block_m:
                block_m = bm
        return {"block_m": block_m, "threads": 128}

tileops/kernels/norm/rms_norm.py:145

  • When max_block_m computes to 0, block_ms becomes empty and autotune_configs returns an empty list. If tune=True, Kernel.autotune() will be invoked with no configs, which can error or behave unexpectedly. Please ensure autotune_configs is never empty (or fail fast with a clear error).
        smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
        max_block_m = (48 * 1024) // smem_per_row
        block_ms = [bm for bm in [1, 2, 4, 8] if bm <= max_block_m]
        threads_list = [128, 256]
        configs = list(itertools.product(block_ms, threads_list))
        return [{"block_m": bm, "threads": t} for bm, t in configs]
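A possible shape for the fail-fast guard both comments ask for, as a standalone sketch (the signature and budget default are simplified from the snippet above, not the PR's actual code):

```python
def default_config(n_padded: int, dtype_size: int,
                   smem_budget: int = 48 * 1024) -> dict:
    # Shared memory budget: block_m * n_padded * dtype_size must fit in smem_budget.
    smem_per_row = n_padded * dtype_size
    max_block_m = smem_budget // smem_per_row
    if max_block_m < 1:
        # Fail fast instead of silently falling back to block_m=1, which
        # would exceed the budget and fail at compile/run time anyway.
        raise ValueError(
            f"one row needs {smem_per_row} bytes of shared memory, "
            f"over the {smem_budget}-byte budget"
        )
    block_m = max(bm for bm in (1, 2, 4, 8) if bm <= max_block_m)
    return {"block_m": block_m, "threads": 128}
```

The same guard resolves the autotune_configs concern: with `max_block_m >= 1` guaranteed, the candidate list can never be empty.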


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a standalone RMS Norm operator with a TileLang kernel, including comprehensive test cases and a benchmark. However, a critical security vulnerability was identified in the kernel's memory handling due to a lack of boundary checks on the row dimension, which can lead to out-of-bounds memory access. Additionally, the review identified a critical type mismatch in the kernel's dtype handling, inconsistencies in test reference eps values, and areas for improvement in benchmark precision and configuration logic.

zhen8838 and others added 2 commits March 5, 2026 11:14
…tency

- Remove unused Tuple import in kernel (F401 lint)
- Replace Tuple with builtin tuple in test (Python 3.12)
- Use explicit eps variable in non-contiguous and 3D test references

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Proves T.copy handles partial blocks correctly when M is not divisible
by block_m, directly refuting the bot's boundary_check concern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings March 5, 2026 11:23

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated no new comments.

Comments suppressed due to low confidence (4)

tileops/ops/rms_norm.py:88

  • x is made contiguous before invoking the TileLang kernel, but weight is passed through unchanged unless padding is applied. Because the kernel treats weight as a dense 1D tensor, a non-contiguous/view weight (stride != 1) can lead to incorrect reads. Consider enforcing weight.is_contiguous() (raise) or doing weight = weight.contiguous() before calling the kernel.
        orig_shape = x.shape
        x = x.contiguous().reshape(-1, self.N)
        M_actual = x.shape[0]
        if M_actual != self.M:
            raise ValueError(

tileops/kernels/norm/rms_norm.py:133

  • default_config falls back to block_m = 1 even when smem_per_row > 48KB (so max_block_m becomes 0). In that situation the kernel's shared-memory allocation is guaranteed to exceed the budget and will likely fail at compile/runtime. Add an explicit check with a clear error for the case max_block_m < 1, i.e. when smem_per_row exceeds the budget.
        # Shared memory budget: 1 buffer * block_m * N_padded * dtype_size < 48KB
        smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
        max_block_m = (48 * 1024) // smem_per_row
        block_m = 1
        for bm in [1, 2, 4, 8]:

tileops/kernels/norm/rms_norm.py:144

  • autotune_configs can produce an empty block_ms list when max_block_m < 1, which will yield zero autotune configs and likely break tune=True. Consider asserting block_ms is non-empty and raising a user-friendly error early when the shared-memory budget can’t accommodate even block_m=1.
        smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
        max_block_m = (48 * 1024) // smem_per_row
        block_ms = [bm for bm in [1, 2, 4, 8] if bm <= max_block_m]
        threads_list = [128, 256]
        configs = list(itertools.product(block_ms, threads_list))

tests/ops/test_rms_norm.py:53

  • The parametrized test includes a tune argument but it isn’t used when constructing RmsNormOp (the op is always created with default tune=False). Either pass tune=tune into RmsNormOp(...) or drop tune from the fixture/test to avoid misleading coverage.
@RmsNormFixture
def test_rms_norm_op(m: int, n: int, dtype: torch.dtype, tune: bool) -> None:
    test = RmsNormTest(m, n, dtype)
    op = RmsNormOp(M=m, N=n, dtype=dtype)
    atol = 1e-2 if dtype == torch.float16 else 1.6e-2


@zhen8838 zhen8838 marked this pull request as ready for review March 5, 2026 14:58
@zhen8838 zhen8838 requested a review from a team March 5, 2026 14:58
@lcy-seso lcy-seso merged commit 0ce8928 into tile-ai:main Mar 5, 2026
10 of 12 checks passed
@zhen8838 zhen8838 deleted the feat/norm/rms-norm-op branch March 6, 2026 01:02

Labels

  • all-ai-powered — Produced entirely by automated contributors
  • feature — New feature or new operator

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEAT][RMS_NORM] implement rms_norm operator

3 participants