@@ -66,7 +66,7 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < b
         for offset_3 in tl.range(0, d.to(tl.int32), step=1):
-            indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
+            indices_3 = offset_3 + tl.zeros([1], tl.int32)
             load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + indices_3[None, None, None, :] * x_stride_3), mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None], other=0)
             v_0 = tl_math.sin(load)
             tl.store(out + (indices_0[:, None, None, None] * out_stride_0 + indices_1[None, :, None, None] * out_stride_1 + indices_2[None, None, :, None] * out_stride_2 + indices_3[None, None, None, :] * out_stride_3), v_0, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None])
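The hunk above swaps `tl.arange(0, 1).to(tl.int32)` for `tl.zeros([1], tl.int32)` when building the index of a step-1 device loop. Both expressions are a single-element int32 zero vector, so `offset_3 + ...` yields the same `[1]`-shaped index; the new form just drops the redundant cast. A minimal sketch of the same idiom outside this diff, with illustrative names (`_copy_one_at_a_time_kernel`, `n`) that are not part of the generated code:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_one_at_a_time_kernel(src, dst, n):
    # A single program walks all n elements, one per loop iteration.
    for offset in tl.range(0, n, step=1):
        # tl.zeros([1], tl.int32) is the same single-element zero vector as
        # tl.arange(0, 1).to(tl.int32), so idx is just [offset] with block shape [1].
        idx = offset + tl.zeros([1], tl.int32)
        tl.store(dst + idx, tl.load(src + idx))

src = torch.arange(8, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
_copy_one_at_a_time_kernel[(1,)](src, dst, src.numel())
assert torch.equal(src, dst)
```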
@@ -197,7 +197,7 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0,
     v_3 = 2.0
     v_4 = in_x * v_3
     for offset_2 in tl.range(2, 5, step=1):
-        indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32)
+        indices_2 = offset_2 + tl.zeros([1], tl.int32)
         v_4_copy = v_4
         in_x_0_copy = in_x_0
         T0_copy = T0
@@ -245,13 +245,13 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _fn_kernel(x, end, out, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < x_size_0
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(end + 0 * end_stride_0, None)
     for offset_0 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_0):
         indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
         mask_0 = indices_0 < load
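The journal entries in this range all follow the same data-dependent-bound pattern: the loop end lives in a one-element tensor, is loaded inside the kernel (now via an explicit `0 * end_stride_0` offset rather than `tl.zeros([], tl.int32)`, with the stride passed in from the host wrapper), and masks the final partial block. A minimal self-contained sketch of that pattern, assuming a hypothetical `_sum_to_end_kernel` and bound tensor `end` that are not part of this diff:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _sum_to_end_kernel(x, end, out, end_stride_0, _BLOCK: tl.constexpr):
    # Scalar load of the loop bound; 0 * end_stride_0 is an explicit zero offset.
    n = tl.load(end + 0 * end_stride_0, None)
    acc = tl.zeros([_BLOCK], tl.float32)
    for offset in tl.range(0, n.to(tl.int32), step=_BLOCK):
        idx = offset + tl.arange(0, _BLOCK).to(tl.int32)
        mask = idx < n  # only the last block is partial
        acc += tl.load(x + idx, mask, other=0)
    tl.store(out, tl.sum(acc, 0))

x = torch.randn(1000, device="cuda")
end = torch.tensor([777], device="cuda", dtype=torch.int32)
out = torch.empty(1, device="cuda")
_sum_to_end_kernel[(1,)](x, end, out, end.stride(0), _BLOCK=128)
# out[0] should match x[:777].sum() up to float rounding.
```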
@@ -267,7 +267,7 @@ def fn(x: torch.Tensor, end: torch.Tensor):
     bs = 32
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
@@ -276,7 +276,7 @@ def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds2)
 from __future__ import annotations
@@ -286,13 +286,13 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _fn_kernel(x, end, out, out_size_0, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(end + 0 * end_stride_0, None)
     for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load
@@ -307,15 +307,15 @@ def fn(x: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds3)
 from __future__ import annotations
@@ -325,14 +325,14 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _fn_kernel(x, end0, end1, out, x_size_0, end0_stride_0, end1_stride_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float64)
-    load = tl.load(end0 + tl.zeros([], tl.int32), None)
-    load_1 = tl.load(end1 + tl.zeros([], tl.int32), None)
+    load = tl.load(end0 + 0 * end0_stride_0, None)
+    load_1 = tl.load(end1 + 0 * end1_stride_0, None)
     for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load
@@ -352,7 +352,7 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_2 = 32
     _BLOCK_SIZE_1 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
@@ -361,7 +361,7 @@ def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor
     _BLOCK_SIZE_2 = 32
     _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds4)
 from __future__ import annotations
@@ -371,14 +371,14 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < x_size_0
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(begin + tl.zeros([], tl.int32), None)
-    load_1 = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(begin + 0 * begin_stride_0, None)
+    load_1 = tl.load(end + 0 * end_stride_0, None)
     for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_0):
         indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
         mask_0 = indices_0 < load_1
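test_data_dependent_bounds4/5 differ from the earlier entries only in also taking the loop start from a one-element `begin` tensor; the mask still only needs the end bound, since a data-dependent start cannot by itself produce out-of-range indices. A kernel-only sketch under the same assumptions as the previous example (the name `_sum_range_kernel` is illustrative, not from this diff):

```python
import triton
import triton.language as tl

@triton.jit
def _sum_range_kernel(x, begin, end, out, begin_stride_0, end_stride_0, _BLOCK: tl.constexpr):
    # Both loop bounds are scalar loads from one-element tensors.
    lo = tl.load(begin + 0 * begin_stride_0, None)
    hi = tl.load(end + 0 * end_stride_0, None)
    acc = tl.zeros([_BLOCK], tl.float32)
    # Only the last block can be partial, so masking against hi is enough.
    for offset in tl.range(lo.to(tl.int32), hi.to(tl.int32), step=_BLOCK):
        idx = offset + tl.arange(0, _BLOCK).to(tl.int32)
        acc += tl.load(x + idx, idx < hi, other=0)
    tl.store(out, tl.sum(acc, 0))
```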
@@ -394,7 +394,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
     bs = 32
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
@@ -403,7 +403,7 @@ def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor
     _BLOCK_SIZE_1 = 32
     _BLOCK_SIZE_0 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_data_dependent_bounds5)
 from __future__ import annotations
@@ -413,14 +413,14 @@ import triton
 import triton.language as tl

 @triton.jit
-def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
-    load = tl.load(begin + tl.zeros([], tl.int32), None)
-    load_1 = tl.load(end + tl.zeros([], tl.int32), None)
+    load = tl.load(begin + 0 * begin_stride_0, None)
+    load_1 = tl.load(end + 0 * end_stride_0, None)
     for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load_1
@@ -435,15 +435,15 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
-    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out

 def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
     out = x.new_empty([x.size(0)])
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

 --- assertExpectedJournal(TestLoops.test_l2_grouping_with_register_block_size)
 from __future__ import annotations