Skip to content

Commit 8421b4f

Browse files
committed
[RFC] Add support for device for loop indexing
This PR was mostly vibe-coded via Cursor using GPT-5, but was then heavily modified for correctness and simplicity. Fixes #598. stack-info: PR: #673, branch: oulgen/stack/99
1 parent f25a771 commit 8421b4f

File tree

4 files changed

+233
-1
lines changed

4 files changed

+233
-1
lines changed

helion/_compiler/device_ir.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,38 @@ def visit_For(self, node: ast.For) -> None:
590590
self._assign(node.target, inner_type.proxy())
591591
self._body(node.body)
592592
elif node._loop_type == LoopType.DEVICE:
593+
# Try static unrolling when begin/end are compile-time ints
594+
begin, end = self._extract_tile_begin_end(node)
595+
if isinstance(inner_type, SequenceType):
596+
iter_vars = inner_type.unpack()
597+
if begin is None:
598+
begin_list = [0] * len(iter_vars)
599+
else:
600+
begin_list = begin if isinstance(begin, (list, tuple)) else [begin]
601+
end_list = end if isinstance(end, (list, tuple)) else [end]
602+
try_static = all(
603+
isinstance(b, int) and isinstance(e, int)
604+
for b, e in zip(begin_list, end_list, strict=True)
605+
)
606+
count = 1
607+
if try_static:
608+
for b, e in zip(begin_list, end_list, strict=True):
609+
count *= max(0, e - b) # pyright: ignore[reportOperatorIssue]
610+
if try_static and count <= 64:
611+
# Assign inner proxy to target and then unroll nested ranges over scalar indices
612+
self._assign(node.target, inner_type.proxy())
613+
self._body(node.body)
614+
return
615+
else:
616+
# 1D case
617+
b0 = 0 if begin is None else begin
618+
if isinstance(b0, int) and isinstance(end, int):
619+
trip_count = max(0, end - b0)
620+
if trip_count <= 64:
621+
for iv in range(b0, end):
622+
self._assign(node.target, iv)
623+
self._body(node.body)
624+
return
593625
rw: ReadWrites = ReadWrites.from_ast(node)
594626
inputs: LiftTensorArgs = LiftTensorArgs(
595627
{
@@ -947,6 +979,11 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
947979
assert isinstance(value, ExtendedAST)
948980
type_info = value._type_info
949981
if isinstance(type_info, SequenceType):
982+
index_val = self.visit(node.slice)
983+
if isinstance(index_val, int):
984+
sequence_val = self.visit(value)
985+
assert isinstance(sequence_val, (list, tuple))
986+
return sequence_val[index_val]
950987
if isinstance(node.slice, ast.Constant):
951988
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
952989
raise exc.InvalidSequenceSubscription(node.slice)

helion/_compiler/type_propagation.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1268,7 +1268,20 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12681268
subtype.populate_symbol_origins(GetItemOrigin(origin, i))
12691269

12701270
def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
1271-
return super().propagate_getitem(key, origin)
1271+
# Try literal indexing first
1272+
try:
1273+
return super().propagate_getitem(key, origin)
1274+
except exc.TypeInferenceError:
1275+
# If indexing with a symbolic/grid index on device and the sequence length is known,
1276+
# conservatively merge all possible element types.
1277+
if origin.is_device() and isinstance(key, (SymIntType, GridIndexType)):
1278+
if not self.element_types:
1279+
return super().propagate_getitem(key, origin)
1280+
merged: TypeInfo = self.element_types[0]
1281+
for candidate in self.element_types[1:]:
1282+
merged = merged.merge(candidate)
1283+
return merged
1284+
raise
12721285

12731286
def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:
12741287
if isinstance(other, SequenceType):
@@ -2161,6 +2174,86 @@ def visit_For(self, node: ast.For) -> TypeInfo:
21612174
raise exc.NestedGridLoop
21622175

21632176
self.device_loop_depth += device_loop
2177+
2178+
# Try static unrolling for device grid loops when iteration count is known
2179+
try:
2180+
if node._loop_type != LoopType.HOST and isinstance(node.iter, ast.Call):
2181+
call_node = node.iter
2182+
# Extract begin, end, step; support only 1D grid here
2183+
begin_val: int | None
2184+
end_val: int | None
2185+
step_val: int | None
2186+
2187+
if len(call_node.args) == 1:
2188+
begin_val = 0
2189+
end_type = self.visit(call_node.args[0])
2190+
step_type: TypeInfo | None = None
2191+
else:
2192+
begin_type = self.visit(call_node.args[0])
2193+
end_type = self.visit(call_node.args[1])
2194+
step_type = (
2195+
self.visit(call_node.args[2])
2196+
if len(call_node.args) >= 3
2197+
else None
2198+
)
2199+
begin_val = (
2200+
begin_type.as_literal() if begin_type.is_literal() else None
2201+
) # type: ignore[assignment]
2202+
2203+
for kw in call_node.keywords:
2204+
if kw.arg == "step" and step_type is None:
2205+
step_type = self.visit(kw.value)
2206+
2207+
end_val = end_type.as_literal() if end_type.is_literal() else None # type: ignore[assignment]
2208+
step_val = (
2209+
step_type.as_literal()
2210+
if (step_type is not None and step_type.is_literal())
2211+
else 1
2212+
) # type: ignore[assignment]
2213+
2214+
if (
2215+
isinstance(begin_val, int)
2216+
and isinstance(end_val, int)
2217+
and isinstance(step_val, int)
2218+
):
2219+
# Build concrete iteration values
2220+
iter_values = list(range(begin_val, end_val, step_val))
2221+
# Small guard to avoid excessive compile-time blowups
2222+
if len(iter_values) <= 64:
2223+
merged_scope: LocalScope | None = None
2224+
for iv in iter_values:
2225+
# Emulate _loop_body with loop index bound to a literal
2226+
self.push_scope()
2227+
self._assign(node.target, LiteralType(self.origin(), iv))
2228+
exit_scopes = [self.scope]
2229+
for stmt in node.body:
2230+
self.visit(stmt)
2231+
if isinstance(stmt, (ast.Break, ast.Continue)):
2232+
exit_scopes.append(self.scope.clone())
2233+
# Reset loop variable back to its GridIndexType to avoid control-flow merging issues
2234+
self._assign(
2235+
node.target,
2236+
iter_type.propagate_iter(self.origin()),
2237+
)
2238+
self.pop_scope()
2239+
iter_scope = functools.reduce(
2240+
lambda x, y: x.merge(y), exit_scopes
2241+
)
2242+
if merged_scope is None:
2243+
merged_scope = iter_scope
2244+
else:
2245+
merged_scope.merge(iter_scope)
2246+
2247+
if merged_scope is not None:
2248+
body = merged_scope
2249+
orelse = self._body(node.orelse)
2250+
self.scope.merge_if_else(body, orelse)
2251+
self.device_loop_depth -= device_loop
2252+
return NoType(origin=self.origin())
2253+
except NotImplementedError:
2254+
# Fall back to generic handling if we can't statically determine iterations
2255+
pass
2256+
21642257
body = self._loop_body(node.body)
21652258
with self.swap_scope(body):
21662259
# second pass for fixed point

test/test_type_propagation.expected

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,74 @@ def root_graph_2():
537537
_for_loop = helion_language__tracing_ops__for_loop(1, [0], [x_size0], []); x_size0 = _for_loop = None
538538
return None
539539

540+
--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code0)
541+
from __future__ import annotations
542+
543+
import torch
544+
import triton
545+
import triton.language as tl
546+
from helion.runtime import default_launcher as _default_launcher
547+
548+
@triton.jit
549+
def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
550+
pid_0 = tl.program_id(0)
551+
offset_0 = pid_0 * _BLOCK_SIZE_0
552+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
553+
mask_0 = indices_0 < out_size_0
554+
load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
555+
load_1 = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
556+
v_0 = load + load_1
557+
tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
558+
load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
559+
load_3 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
560+
v_1 = load_2 + load_3
561+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
562+
load_4 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
563+
load_5 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
564+
v_2 = load_4 + load_5
565+
tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
566+
load_6 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
567+
load_7 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
568+
v_3 = load_6 + load_7
569+
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
570+
571+
def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
572+
out = torch.zeros_like(As[0])
573+
_BLOCK_SIZE_0 = 16
574+
_launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
575+
return out
576+
577+
--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code1)
578+
from __future__ import annotations
579+
580+
import torch
581+
import triton
582+
import triton.language as tl
583+
from helion.runtime import default_launcher as _default_launcher
584+
585+
@triton.jit
586+
def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
587+
pid_0 = tl.program_id(0)
588+
offset_0 = pid_0 * _BLOCK_SIZE_0
589+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
590+
mask_0 = indices_0 < out_size_0
591+
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
592+
load = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
593+
v_0 = acc + load
594+
load_1 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
595+
v_1 = v_0 + load_1
596+
load_2 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
597+
v_2 = v_1 + load_2
598+
load_3 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
599+
v_3 = v_2 + load_3
600+
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
601+
602+
def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
603+
out = torch.zeros_like(As[0])
604+
_BLOCK_SIZE_0 = 16
605+
_launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
606+
return out
607+
540608
--- assertExpectedJournal(TestTypePropagation.test_hl_full_usage)
541609
def hl_full_usage(x: torch.Tensor):
542610
# Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:38>)

test/test_type_propagation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88

99
import helion
1010
from helion import exc
11+
from helion._testing import DEVICE
1112
from helion._testing import RefEagerTestDisabled
1213
from helion._testing import TestCase
14+
from helion._testing import code_and_output
1315
from helion._testing import import_path
1416
import helion.language as hl
1517

@@ -132,6 +134,38 @@ def use_unsupported_property(x: torch.Tensor) -> torch.Tensor:
132134
):
133135
type_propagation_report(use_unsupported_property, x)
134136

137+
def test_for_loop_indexing_in_device_code0(self):
138+
@helion.kernel
139+
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
140+
out = torch.zeros_like(As[0])
141+
for tile in hl.tile(out.size()):
142+
for i in range(len(As)):
143+
a = As[i]
144+
out[tile] += a[tile]
145+
return out
146+
147+
args = [torch.randn(16, device=DEVICE) for _ in range(4)]
148+
code, result = code_and_output(kernel, (args,))
149+
torch.testing.assert_close(result, sum(args))
150+
self.assertExpectedJournal(code)
151+
152+
def test_for_loop_indexing_in_device_code1(self):
153+
@helion.kernel
154+
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
155+
out = torch.zeros_like(As[0])
156+
for tile in hl.tile(out.size()):
157+
acc = hl.zeros(tile)
158+
for i in range(len(As)):
159+
a = As[i]
160+
acc = acc + a[tile]
161+
out[tile] = acc
162+
return out
163+
164+
args = [torch.randn(16, device=DEVICE) for _ in range(4)]
165+
code, result = code_and_output(kernel, (args,))
166+
torch.testing.assert_close(result, sum(args))
167+
self.assertExpectedJournal(code)
168+
135169

136170
if __name__ == "__main__":
137171
unittest.main()

0 commit comments

Comments
 (0)