[WIP] Fused RMSNorm implementation #2205
Pull Request Overview
This PR adds fused RMSNorm (Root Mean Square Normalization) support for XPU devices to match PyTorch's recent implementation. RMSNorm is a simpler normalization technique compared to LayerNorm that eliminates the mean centering step.
Key Changes:
- Adds forward and backward RMSNorm kernel registrations in the native functions YAML
- Refactors existing LayerNorm kernels to support both LayerNorm and RMSNorm via a template parameter
- Implements RMSNorm-specific computation paths that skip mean centering
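To make the difference between the two normalizations concrete, here is a minimal CPU reference sketch of one normalized row. It is not the SYCL kernel from this PR: the function names `layer_norm_ref` and `rms_norm_ref` are hypothetical, and the formulas are the standard definitions (LayerNorm centers by the mean and scales by the variance; RMSNorm scales by the mean of squares and has no bias term).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helpers for illustration only.

// LayerNorm over one row: y = (x - mean) * rstd * gamma + beta,
// with rstd = 1 / sqrt(var + eps) and var the biased variance.
std::vector<float> layer_norm_ref(const std::vector<float>& x,
                                  const std::vector<float>& gamma,
                                  const std::vector<float>& beta,
                                  float eps) {
  assert(!x.empty() && gamma.size() == x.size() && beta.size() == x.size());
  const float n = static_cast<float>(x.size());
  float mean = 0.f;
  for (float v : x) mean += v;
  mean /= n;
  float var = 0.f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= n;
  const float rstd = 1.f / std::sqrt(var + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
  }
  return y;
}

// RMSNorm over one row: y = x * rstd * gamma,
// with rstd = 1 / sqrt(mean(x^2) + eps). No mean centering, no beta.
std::vector<float> rms_norm_ref(const std::vector<float>& x,
                                const std::vector<float>& gamma,
                                float eps) {
  assert(!x.empty() && gamma.size() == x.size());
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  const float rstd =
      1.f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] * rstd * gamma[i];
  }
  return y;
}
```

A reference like this can serve as a baseline when checking a fused kernel's output on small inputs.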
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description | 
|---|---|
| yaml/native/native_functions.yaml | Registers `_fused_rms_norm` and `_fused_rms_norm_backward` functions with XPU dispatch | 
| src/ATen/native/xpu/sycl/LayerNormKernels.cpp | Adds `rms_norm` template parameter to kernel functors and implements RMSNorm computation logic | 
      mean_[i] = m1;
      rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
    } else {
      rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_);
Copilot AI, Oct 22, 2025
The RMSNorm formula appears incorrect. For RMSNorm, `m1` (the mean) should be zero since we skip the mean computation. The formula should be `rsqrt(m2 + eps_)`, where `m2` represents the mean of squares. The term `m1 * m1` should not be added.
Suggested change:

    - rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_);
    + rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
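Whether the `m1 * m1` term belongs in the expression depends on what `m2` holds when this line runs, which the excerpt does not show. If `m2` is a variance taken around `m1` (as a Welford reduction produces), the mean of squares equals `m2 + m1 * m1`; if the reduction already accumulates raw squares and leaves `m1` at zero, the two expressions coincide. The sketch below checks that identity numerically. All names in it are hypothetical and it is independent of the kernel code.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helpers for illustration only.
struct Moments {
  double mean;  // first moment (plays the role of m1)
  double var;   // biased variance around the mean (plays the role of m2)
};

// Welford's online algorithm: running mean and biased variance.
Moments welford(const std::vector<double>& x) {
  assert(!x.empty());
  double mean = 0.0, m2 = 0.0, count = 0.0;
  for (double v : x) {
    count += 1.0;
    const double delta = v - mean;
    mean += delta / count;
    m2 += delta * (v - mean);
  }
  return {mean, m2 / count};
}

// Direct mean of squares, E[x^2].
double mean_square(const std::vector<double>& x) {
  assert(!x.empty());
  double s = 0.0;
  for (double v : x) s += v * v;
  return s / static_cast<double>(x.size());
}

// rstd computed from Welford moments: E[x^2] = var + mean^2.
double rstd_from_moments(const Moments& m, double eps) {
  return 1.0 / std::sqrt(m.var + m.mean * m.mean + eps);
}

// rstd computed from a directly accumulated mean of squares.
double rstd_from_mean_square(double ms, double eps) {
  return 1.0 / std::sqrt(ms + eps);
}
```

For any input row, `rstd_from_moments(welford(x), eps)` and `rstd_from_mean_square(mean_square(x), eps)` should agree up to rounding, so the right form for the kernel comes down to which reduction feeds `m1` and `m2`.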
          static_cast<T_ACC>(rstd_[i]) * gamma_v +
          beta_v;
    } else {
      Y_[index] = (static_cast<T_ACC>(X_[index])) *
Copilot AI, Oct 22, 2025
[nitpick] Remove the unnecessary inner parentheses around `static_cast<T_ACC>(X_[index])`. The expression can be simplified to `Y_[index] = static_cast<T_ACC>(X_[index]) * static_cast<T_ACC>(rstd_[i]) * gamma_v;`
Suggested change:

    - Y_[index] = (static_cast<T_ACC>(X_[index])) *
    + Y_[index] = static_cast<T_ACC>(X_[index]) *
      U new_mean = curr_sum.mean + delta * (1.f / new_count);
      return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
    } else {
      return {0.f, curr_sum.sigma2 + val * val, 0};
Copilot AI, Oct 22, 2025
The return statement uses the integer literal `0` for the count field, while the mean field already uses `0.f` and the struct fields are of type float. For consistency and clarity, use `0.f` for both: `return {0.f, curr_sum.sigma2 + val * val, 0.f};`
Suggested change:

    - return {0.f, curr_sum.sigma2 + val * val, 0};
    + return {0.f, curr_sum.sigma2 + val * val, 0.f};
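The accumulation step discussed in this thread can be sketched as a standalone function that picks the update rule at compile time. This is a simplified illustration under stated assumptions, not the PR's code: `WelfordAcc`, `online_sum`, and `finalize_rstd` are hypothetical names, all three fields are assumed to be `float` as the comment above states, and `if constexpr` requires C++17.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical accumulator; field names mirror the snippet in this thread.
struct WelfordAcc {
  float mean;
  float sigma2;
  float count;
};

// One accumulation step. With rms_norm == false this is Welford's update
// (running mean plus sum of squared deviations); with rms_norm == true it
// only accumulates the raw sum of squares and leaves mean and count at zero.
template <bool rms_norm>
WelfordAcc online_sum(float val, const WelfordAcc& curr) {
  if constexpr (!rms_norm) {
    const float delta = val - curr.mean;
    const float new_count = curr.count + 1.f;
    const float new_mean = curr.mean + delta * (1.f / new_count);
    return {new_mean, curr.sigma2 + delta * (val - new_mean), new_count};
  } else {
    return {0.f, curr.sigma2 + val * val, 0.f};
  }
}

// Converts the accumulated sigma2 over n elements into rstd. For the
// Welford path sigma2 / n is the variance; for the RMSNorm path it is
// the mean of squares.
inline float finalize_rstd(float sigma2, int n, float eps) {
  assert(n > 0);
  return 1.f / std::sqrt(sigma2 / static_cast<float>(n) + eps);
}
```

Because the RMSNorm branch never updates `count`, the element count has to come from elsewhere (here the explicit `n` argument) when the sum is finalized.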
Motivation
Fixes #1905.
Following pytorch/pytorch#153666, this PR adds fused RMSNorm support on XPU.