[Feat][Norm] Add standalone rms_norm operator #348
Conversation
Implement rms_norm forward: y = x * rsqrt(mean(x^2) + eps) * weight

- TileLang kernel with single-pass row reduction, fp32 accumulation
- 256-element aligned shared memory copies with adaptive block_m
- Op supports arbitrary leading dims, non-contiguous input, N-padding
- 14 test cases: aligned/non-aligned N, non-contiguous, 3D input
- Benchmark with bandwidth and flops profiling

Closes tile-ai#325

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
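The forward formula in the commit message can be sketched as a plain-Python reference (an illustrative sketch, not the project's kernel; the function name is hypothetical):

```python
import math

def rms_norm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm for one row: y = x * rsqrt(mean(x^2) + eps) * weight.

    The mean of squares is accumulated in full precision before the
    rsqrt, mirroring the kernel's fp32 accumulation strategy.
    """
    n = len(x)
    mean_sq = sum(v * v for v in x) / n          # single-pass row reduction
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)     # rsqrt(mean(x^2) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```

For a row `[3.0, 4.0]` with unit weights and `eps=0`, the mean of squares is 12.5, so each element is divided by `sqrt(12.5)`.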
Add dtype, device, ndim, and M consistency checks with clear ValueError messages, per reviewer feedback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces a highly optimized RMS Normalization operator.
Activity
/gemini review
Code Review
This pull request introduces a new standalone RMS Normalization operator (RmsNormOp) with its TileLang kernel (RmsNormKernel), demonstrating significant speedup over the PyTorch baseline and good practices for mixed-precision operations. However, a high-severity out-of-bounds memory access vulnerability was identified in the GPU kernel. This occurs when the number of rows M is not a multiple of the kernel's block size block_m, potentially leading to memory corruption or denial of service. The recommended fix involves padding the M dimension within the RmsNormOp class. Additionally, there are areas for improvement regarding explicit type casting within the TileLang kernel for clarity and potential precision control, and a minor code duplication that could be refactored.
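The recommended fix — padding the M dimension in `RmsNormOp` so every launched block covers full rows — can be sketched with plain Python lists (an illustrative sketch; the real op would pad and slice torch tensors, and `pad_rows` is a hypothetical name):

```python
def pad_rows(x_rows, block_m):
    """Pad with zero rows so the row count is a multiple of block_m.

    The kernel can then launch only full blocks, and the wrapper slices
    the padding off the output afterwards.
    """
    m = len(x_rows)
    m_padded = ((m + block_m - 1) // block_m) * block_m  # ceil to multiple
    n = len(x_rows[0])
    pad = [[0.0] * n for _ in range(m_padded - m)]
    return x_rows + pad, m  # return original m so the caller can slice y[:m]
```

Because each row is normalized independently, zero padding rows cannot affect the results of the real rows.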
Pull request overview
Adds a standalone RMSNorm operator/kernel to TileOPs so models that invoke RMSNorm directly can dispatch to a TileLang implementation instead of falling back to PyTorch.
Changes:
- Introduces `RmsNormKernel` (TileLang) and `RmsNormOp` (Python op wrapper) with optional N-padding for alignment.
- Exports the new op/kernel through the relevant `__init__.py` modules.
- Adds unit tests and a benchmark comparing against a PyTorch reference.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Summary per file:
| File | Description |
|---|---|
| `tileops/ops/rms_norm.py` | New `RmsNormOp` wrapper with shape/dtype validation and optional padding. |
| `tileops/ops/__init__.py` | Exports `RmsNormOp` from `tileops.ops`. |
| `tileops/kernels/norm/rms_norm.py` | New TileLang RMSNorm kernel + Torch custom op wrapper + tuning configs. |
| `tileops/kernels/norm/__init__.py` | Exports `RmsNormKernel` from the norm subpackage. |
| `tests/ops/test_rms_norm.py` | Adds correctness tests for aligned/non-aligned N, non-contiguous, and 3D inputs. |
| `benchmarks/ops/bench_rms_norm.py` | Adds benchmark harness + baseline comparison. |
Comments suppressed due to low confidence (4)
tests/ops/test_rms_norm.py:52
- The test is parametrized with `tune` but `RmsNormOp` is constructed without passing it, so the autotune path is never exercised by unit tests. Pass `tune=tune` here (matching the benchmark and other op tests) or drop `tune` from the fixture params if it's intentionally out of scope.
def test_rms_norm_op(m: int, n: int, dtype: torch.dtype, tune: bool) -> None:
test = RmsNormTest(m, n, dtype)
op = RmsNormOp(M=m, N=n, dtype=dtype)
atol = 1e-2 if dtype == torch.float16 else 1.6e-2
benchmarks/ops/bench_rms_norm.py:23
`torch.dtype` already exposes `itemsize` (used in other benchmarks like `bench_gemm.py`). Using `t.dtype.itemsize` here would avoid constructing an extra empty tensor each call.
t = self.test
elem_bytes = torch.tensor([], dtype=t.dtype).element_size()
# Read x (M*N) + read weight (N, broadcast) + write y (M*N)
tileops/kernels/norm/rms_norm.py:136
`max_block_m` can become 0 when `N_padded * element_size > 48KB`, but `default_config` still falls back to `block_m = 1`, which would exceed the stated shared-memory budget and likely fail compilation at runtime. Consider validating this case explicitly (e.g., raise a clear error or adjust the budgeting based on the real SMEM limit).
# Shared memory budget: 1 buffer * block_m * N_padded * dtype_size < 48KB
smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
max_block_m = (48 * 1024) // smem_per_row
block_m = 1
for bm in [1, 2, 4, 8]:
if bm <= max_block_m:
block_m = bm
return {"block_m": block_m, "threads": 128}
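The guard the reviewer asks for can be sketched as a standalone helper (an illustrative sketch; the parameter names and 48 KB budget follow the snippet above, not the project's actual API):

```python
def default_config(n_padded, dtype_size, smem_budget=48 * 1024):
    """Pick the largest block_m in {1, 2, 4, 8} that fits the shared-memory
    budget, failing fast instead of silently returning block_m = 1 when
    even a single row exceeds the budget."""
    smem_per_row = n_padded * dtype_size
    max_block_m = smem_budget // smem_per_row
    if max_block_m < 1:
        raise ValueError(
            f"one row needs {smem_per_row} B of shared memory, exceeding "
            f"the {smem_budget} B budget; reduce N or use a smaller dtype")
    block_m = max(bm for bm in (1, 2, 4, 8) if bm <= max_block_m)
    return {"block_m": block_m, "threads": 128}
```

For example, `n_padded=4096` with 2-byte elements gives 8192 B per row, so `max_block_m = 6` and the chosen `block_m` is 4; a row larger than the budget raises instead of producing a config that cannot compile.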
tileops/kernels/norm/rms_norm.py:145
- When `max_block_m` computes to 0, `block_ms` becomes empty and `autotune_configs` returns an empty list. If `tune=True`, `Kernel.autotune()` will be invoked with no configs, which can error or behave unexpectedly. Please ensure `autotune_configs` is never empty (or fail fast with a clear error).
smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
max_block_m = (48 * 1024) // smem_per_row
block_ms = [bm for bm in [1, 2, 4, 8] if bm <= max_block_m]
threads_list = [128, 256]
configs = list(itertools.product(block_ms, threads_list))
return [{"block_m": bm, "threads": t} for bm, t in configs]
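A non-empty-or-fail version of the config enumeration above could look like this (an illustrative sketch under the same 48 KB assumption; names are not the project's actual API):

```python
import itertools

def autotune_configs(n_padded, dtype_size, smem_budget=48 * 1024):
    """Enumerate (block_m, threads) candidates, raising early rather than
    returning an empty list when nothing fits the shared-memory budget."""
    max_block_m = smem_budget // (n_padded * dtype_size)
    block_ms = [bm for bm in (1, 2, 4, 8) if bm <= max_block_m]
    if not block_ms:
        raise ValueError(
            "shared-memory budget cannot accommodate even block_m=1")
    return [{"block_m": bm, "threads": t}
            for bm, t in itertools.product(block_ms, (128, 256))]
```

With `n_padded=4096` and 2-byte elements, `block_ms` is `[1, 2, 4]`, yielding six configs; an oversized row raises instead of silently handing the autotuner an empty list.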
Code Review
This pull request introduces a standalone RMS Norm operator with a TileLang kernel, including comprehensive test cases and a benchmark. However, a critical security vulnerability was identified in the kernel's memory handling due to a lack of boundary checks on the row dimension, which can lead to out-of-bounds memory access. Additionally, the review identified a critical type mismatch in the kernel's dtype handling, inconsistencies in test reference eps values, and areas for improvement in benchmark precision and configuration logic.
…tency

- Remove unused Tuple import in kernel (F401 lint)
- Replace Tuple with builtin tuple in test (Python 3.12)
- Use explicit eps variable in non-contiguous and 3D test references

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Proves `T.copy` handles partial blocks correctly when M is not divisible by block_m, directly refuting the bot's `boundary_check` concern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated no new comments.
Comments suppressed due to low confidence (4)
tileops/ops/rms_norm.py:88
`x` is made contiguous before invoking the TileLang kernel, but `weight` is passed through unchanged unless padding is applied. Because the kernel treats `weight` as a dense 1D tensor, a non-contiguous view `weight` (stride != 1) can lead to incorrect reads. Consider enforcing `weight.is_contiguous()` (raising otherwise) or doing `weight = weight.contiguous()` before calling the kernel.
orig_shape = x.shape
x = x.contiguous().reshape(-1, self.N)
M_actual = x.shape[0]
if M_actual != self.M:
raise ValueError(
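The suggested contiguity guard is a one-liner; a minimal sketch (the helper name is hypothetical, and this assumes PyTorch is available):

```python
import torch

def ensure_dense_weight(weight: torch.Tensor) -> torch.Tensor:
    # A strided view (e.g. every other element of a larger buffer) is not
    # dense; materialize it before handing it to a kernel that assumes
    # stride-1 reads.
    return weight if weight.is_contiguous() else weight.contiguous()

w_view = torch.arange(8, dtype=torch.float32)[::2]  # stride-2 view
assert not w_view.is_contiguous()
assert ensure_dense_weight(w_view).is_contiguous()
```

The no-op branch matters: `contiguous()` on an already-dense tensor returns it unchanged, so the guard costs nothing in the common case.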
tileops/kernels/norm/rms_norm.py:133
`default_config` falls back to `block_m = 1` even when `smem_per_row > 48KB` (so `max_block_m` becomes 0). In that situation the kernel's shared-memory allocation is guaranteed to exceed the budget and likely fail at compile/runtime. Add an explicit check (and a clear error) when `max_block_m < 1`, i.e. when `smem_per_row` exceeds the budget.
# Shared memory budget: 1 buffer * block_m * N_padded * dtype_size < 48KB
smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
max_block_m = (48 * 1024) // smem_per_row
block_m = 1
for bm in [1, 2, 4, 8]:
tileops/kernels/norm/rms_norm.py:144
`autotune_configs` can produce an empty `block_ms` list when `max_block_m < 1`, which will yield zero autotune configs and likely break `tune=True`. Consider asserting `block_ms` is non-empty and raising a user-friendly error early when the shared-memory budget can't accommodate even `block_m=1`.
smem_per_row = self.N_padded * torch.tensor([], dtype=self.dtype).element_size()
max_block_m = (48 * 1024) // smem_per_row
block_ms = [bm for bm in [1, 2, 4, 8] if bm <= max_block_m]
threads_list = [128, 256]
configs = list(itertools.product(block_ms, threads_list))
tests/ops/test_rms_norm.py:53
- The parametrized test includes a `tune` argument but it isn't used when constructing `RmsNormOp` (the op is always created with the default `tune=False`). Either pass `tune=tune` into `RmsNormOp(...)` or drop `tune` from the fixture/test to avoid misleading coverage.
@RmsNormFixture
def test_rms_norm_op(m: int, n: int, dtype: torch.dtype, tune: bool) -> None:
test = RmsNormTest(m, n, dtype)
op = RmsNormOp(M=m, N=n, dtype=dtype)
atol = 1e-2 if dtype == torch.float16 else 1.6e-2
Closes #325
Summary
- `RmsNormKernel` with single-pass row reduction, fp32 accumulation, 256-element aligned shared memory copies
- `RmsNormOp` supporting arbitrary leading dims, non-contiguous inputs, N-padding; exports from `tileops/ops`

Test plan
- `tests/ops/test_rms_norm.py`
- `benchmarks/ops/bench_rms_norm.py`
- `run-affected-tests.sh` (status=pass)

Benchmark
Environment: NVIDIA H200 MIG 1g.18gb (SM90), CUDA 12.8, PyTorch 2.10.0
Performance AC: Accepted under H200 MIG (57-61% of measured MIG peak = 0.508 TB/s). Per @zhengqihang decision (2026-03-05); full-GPU retest deferred as follow-up.
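The bandwidth percentage follows from the benchmark's traffic model quoted in review (read `x`, read the broadcast `weight`, write `y`); a minimal sketch of that arithmetic, with an illustrative function name:

```python
def rms_norm_bandwidth_gbs(m, n, elem_bytes, seconds):
    """Effective bandwidth in GB/s from the benchmark's byte model:
    read x (m*n) + read weight (n, broadcast) + write y (m*n)."""
    total_bytes = elem_bytes * (2 * m * n + n)
    return total_bytes / seconds / 1e9
```

Dividing the result by the measured MIG peak (0.508 TB/s here) gives the quoted fraction of peak; the inputs below are illustrative, not the PR's actual measurements.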