@@ -66,7 +66,7 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < b
         for offset_3 in tl.range(0, d.to(tl.int32), step=1):
-            indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
+            indices_3 = offset_3 + tl.zeros([1], tl.int32)
             load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + indices_3[None, None, None, :] * x_stride_3), mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None], other=0)
             v_0 = tl_math.sin(load)
             tl.store(out + (indices_0[:, None, None, None] * out_stride_0 + indices_1[None, :, None, None] * out_stride_1 + indices_2[None, None, :, None] * out_stride_2 + indices_3[None, None, None, :] * out_stride_3), v_0, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None])
@@ -197,7 +197,7 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0,
     v_3 = 2.0
     v_4 = in_x * v_3
     for offset_2 in tl.range(2, 5, step=1):
-        indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32)
+        indices_2 = offset_2 + tl.zeros([1], tl.int32)
         v_4_copy = v_4
         in_x_0_copy = in_x_0
         T0_copy = T0
@@ -245,13 +245,13 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _fn_kernel(x, end, out, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < x_size_0
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(end + 0 * end_stride_0, None)
     for offset_0 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_0):
         indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
         mask_0 = indices_0 < load
@@ -267,7 +267,7 @@ def fn(x: torch.Tensor, end: torch.Tensor):
     bs = 32
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
@@ -276,7 +276,7 @@ def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds2)
 from __future__ import annotations
@@ -286,13 +286,13 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _fn_kernel(x, end, out, out_size_0, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(end + 0 * end_stride_0, None)
     for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load
@@ -307,15 +307,15 @@ def fn(x: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds3)
 from __future__ import annotations
@@ -325,14 +325,14 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _fn_kernel(x, end0, end1, out, x_size_0, end0_stride_0, end1_stride_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float64)
-    load = tl.load(end0 + tl.zeros([], tl.int32), None)
-    load_1 = tl.load(end1 + tl.zeros([], tl.int32), None)
+    load = tl.load(end0 + 0 * end0_stride_0, None)
+    load_1 = tl.load(end1 + 0 * end1_stride_0, None)
     for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load
@@ -352,7 +352,7 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_2 = 32
     _BLOCK_SIZE_1 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
@@ -361,7 +361,7 @@ def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor
     _BLOCK_SIZE_2 = 32
     _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds4)
 from __future__ import annotations
@@ -371,14 +371,14 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < x_size_0
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(begin + tl.zeros([], tl.int32), None)
-    load_1 = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(begin + 0 * begin_stride_0, None)
+    load_1 = tl.load(end + 0 * end_stride_0, None)
     for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_0):
         indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
         mask_0 = indices_0 < load_1
@@ -394,7 +394,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
     bs = 32
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
@@ -403,7 +403,7 @@ def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds5)
 from __future__ import annotations
@@ -413,14 +413,14 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(begin + tl.zeros([], tl.int32), None)
-    load_1 = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(begin + 0 * begin_stride_0, None)
+    load_1 = tl.load(end + 0 * end_stride_0, None)
     for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load_1
@@ -435,15 +435,15 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_l2_grouping_with_register_block_size)
 from __future__ import annotations