Skip to content

Commit 058cd60

Browse files
committed
Use helion.cdiv
stack-info: PR: #852, branch: oulgen/stack/129
1 parent df51b71 commit 058cd60

File tree

3 files changed: +5 −4 lines changed

examples/grouped_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ def grouped_gemm_jagged_persistent(
153153

154154
if m_size > 0:
155155
# Compute tile grid dimensions for current group
156-
num_m_tiles = (m_size + BLOCK_M - 1) // BLOCK_M
156+
num_m_tiles = helion.cdiv(m_size, BLOCK_M) # pyright: ignore[reportArgumentType]
157157
# Calculate number of N tiles (shared across all groups)
158-
num_n_tiles = (N + BLOCK_N - 1) // BLOCK_N
158+
num_n_tiles = helion.cdiv(N, BLOCK_N) # pyright: ignore[reportArgumentType]
159159
num_group_tiles = num_m_tiles * num_n_tiles
160160

161161
# Distribute tiles among workers using strided access pattern

examples/rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def rms_norm_bwd(
9292
m_block = hl.register_block_size(x.size(0))
9393
grad_x = torch.empty_like(x)
9494
grad_weight = x.new_empty(
95-
[(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32
95+
[helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32
9696
)
9797
weight_shape = hl.specialize(weight.size(0))
9898
for mb_cta in hl.tile(x.size(0), block_size=m_block):

test/test_examples.expected

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3158,6 +3158,7 @@ def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.
31583158
from __future__ import annotations
31593159

31603160
import torch
3161+
import helion
31613162
import triton
31623163
import triton.language as tl
31633164
from helion.runtime import default_launcher as _default_launcher
@@ -3223,7 +3224,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
32233224
"""
32243225
m_block = 32
32253226
grad_x = torch.empty_like(x)
3226-
grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
3227+
grad_weight = x.new_empty([helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32)
32273228
_BLOCK_SIZE_0 = 32
32283229
_RDIM_SIZE_2 = 64
32293230
_launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)

0 commit comments