@@ -87,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
8787 indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
8888 acc_copy = acc
8989 acc_copy_0 = acc_copy
90- load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
90+ load = tl.load(x + (offset_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
9191 load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
9292 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
9393 v_0 = acc.to(tl.float16)
94- tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
94+ tl.store(out + (offset_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
9595
9696def grid_1d(x: torch.Tensor, y: torch.Tensor):
9797 b, m, k = x.size()
@@ -225,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
225225 indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
226226 acc_copy = acc
227227 acc_copy_0 = acc_copy
228- load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
228+ load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
229229 load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
230230 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
231231 v_0 = acc.to(tl.float16)
232- tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
232+ tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
233233
234234def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
235235 bi, bj, m, k = x.size()
@@ -363,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
363363 indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
364364 acc_copy = acc
365365 acc_copy_0 = acc_copy
366- load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
366+ load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
367367 load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
368368 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
369369 v_0 = acc.to(tl.float16)
370- tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
370+ tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
371371
372372def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
373373 bi, bj, m, k = x.size()
@@ -425,10 +425,10 @@ def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0):
425425 pid_0 = tl.program_id(0)
426426 begin_0 = 2
427427 offset_0 = begin_0 + pid_0
428- load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
428+ load = tl.load(x + offset_0 * x_stride_0, None)
429429 v_0 = 2.0
430430 v_1 = load * v_0
431- tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
431+ tl.store(out + offset_0 * out_stride_0, v_1, None)
432432
433433def grid_begin_end(x: torch.Tensor):
434434 n = x.size(0)
@@ -475,10 +475,10 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor:
475475def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
476476 pid_0 = tl.program_id(0)
477477 offset_0 = pid_0 * _BLOCK_SIZE_0
478- load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
478+ load = tl.load(x + offset_0 * x_stride_0, None)
479479 v_0 = 2.0
480480 v_1 = load * v_0
481- tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
481+ tl.store(out + offset_0 * out_stride_0, v_1, None)
482482
483483def grid_begin_end_step(x: torch.Tensor):
484484 n = x.size(0)
@@ -527,10 +527,10 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor:
527527def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
528528 pid_0 = tl.program_id(0)
529529 offset_0 = pid_0 * _BLOCK_SIZE_0
530- load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
530+ load = tl.load(x + offset_0 * x_stride_0, None)
531531 v_0 = 2.0
532532 v_1 = load * v_0
533- tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
533+ tl.store(out + offset_0 * out_stride_0, v_1, None)
534534
535535def grid_end_step_kwarg(x: torch.Tensor):
536536 n = x.size(0)
@@ -587,10 +587,10 @@ def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride
587587 offset_0 = begin_0 + pid_0
588588 begin_1 = 1
589589 offset_1 = begin_1 + pid_1
590- load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
590+ load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
591591 v_0 = 2.0
592592 v_1 = load * v_0
593- tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
593+ tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
594594
595595def grid_multidim_begin_end(x: torch.Tensor):
596596 m, n = x.size()
@@ -643,10 +643,10 @@ def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_s
643643 pid_1 = tl.program_id(0) // num_blocks_0
644644 offset_0 = pid_0 * _BLOCK_SIZE_0
645645 offset_1 = pid_1 * _BLOCK_SIZE_1
646- load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
646+ load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
647647 v_0 = 2.0
648648 v_1 = load * v_0
649- tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
649+ tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
650650
651651def grid_multidim_begin_end_step(x: torch.Tensor):
652652 m, n = x.size()
0 commit comments