
Commit 76c6a85

[RFC] Add support for device for loop indexing
stack-info: PR: #673, branch: oulgen/stack/99
1 parent 07b1182 commit 76c6a85
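
In user code, this change makes it possible to index a Python list of tensors with the index variable of a device-side `for ... in range(...)` loop, as exercised by the new tests in this commit. A minimal sketch of the pattern (the kernel name and sizes below are illustrative only):

import torch
import helion
import helion.language as hl

@helion.kernel
def sum_list(As: list[torch.Tensor]) -> torch.Tensor:
    out = torch.zeros_like(As[0])
    for tile in hl.tile(out.size()):
        # The range() loop below runs in device code; with this commit it is
        # statically unrolled, so As[i] resolves to a concrete tensor per index.
        for i in range(len(As)):
            a = As[i]
            out[tile] += a[tile]
    return out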

4 files changed: +225 −1 lines

helion/_compiler/device_ir.py

Lines changed: 34 additions & 0 deletions
@@ -590,6 +590,35 @@ def visit_For(self, node: ast.For) -> None:
             self._assign(node.target, inner_type.proxy())
             self._body(node.body)
         elif node._loop_type == LoopType.DEVICE:
+            # Try static unrolling when begin/end are compile-time ints
+            begin, end = self._extract_tile_begin_end(node)
+            if isinstance(inner_type, SequenceType):
+                iter_vars = inner_type.unpack()
+                if begin is None:
+                    begin_list = [0] * len(iter_vars)
+                else:
+                    begin_list = begin if isinstance(begin, (list, tuple)) else [begin]
+                end_list = end if isinstance(end, (list, tuple)) else [end]
+                try_static = all(isinstance(b, int) and isinstance(e, int) for b, e in zip(begin_list, end_list, strict=True))
+                count = 1
+                if try_static:
+                    for b, e in zip(begin_list, end_list, strict=True):
+                        count *= max(0, e - b)
+                if try_static and count <= 64:
+                    # Assign inner proxy to target and then unroll nested ranges over scalar indices
+                    self._assign(node.target, inner_type.proxy())
+                    self._body(node.body)
+                    return
+            else:
+                # 1D case
+                b0 = 0 if begin is None else begin
+                if isinstance(b0, int) and isinstance(end, int):
+                    trip_count = max(0, end - b0)
+                    if trip_count <= 64:
+                        for iv in range(b0, end):
+                            self._assign(node.target, iv)
+                            self._body(node.body)
+                        return
             rw: ReadWrites = ReadWrites.from_ast(node)
             inputs: LiftTensorArgs = LiftTensorArgs(
                 {
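
The new branch above trades a traced device loop for per-index replays of the loop body when the trip count is known and small. A rough standalone sketch of the 1D path (hypothetical helper, not the real tracer API):

def unroll_1d(trace_body, begin, end, limit=64):
    # Mirrors the "1D case" above: when bounds are compile-time ints and the
    # trip count is small, re-trace the body once per concrete index value.
    if not (isinstance(begin, int) and isinstance(end, int)):
        return False
    if max(0, end - begin) > limit:
        return False  # fall through to the generic device-loop lowering
    for iv in range(begin, end):
        trace_body(iv)
    return True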
@@ -947,6 +976,11 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
         assert isinstance(value, ExtendedAST)
         type_info = value._type_info
         if isinstance(type_info, SequenceType):
+            index_val = self.visit(node.slice)
+            if isinstance(index_val, int):
+                sequence_val = self.visit(value)
+                assert isinstance(sequence_val, (list, tuple))
+                return sequence_val[index_val]
             if isinstance(node.slice, ast.Constant):
                 return self.visit(value)[self.visit(node.slice)]  # pyright: ignore[reportIndexIssue]
             raise exc.InvalidSequenceSubscription(node.slice)
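
With the `visit_Subscript` change, once unrolling has bound the loop variable to a Python `int`, a subscript such as `As[i]` picks the corresponding element out of the already-materialized sequence at trace time, just as a constant subscript already did. A simplified illustration of the resolution rule (hypothetical helper, not the actual visitor):

def resolve_sequence_subscript(sequence_val, index_val):
    # int indices (e.g. produced by static unrolling) select one element of the
    # traced sequence; anything else is rejected as before.
    if isinstance(index_val, int):
        return sequence_val[index_val]
    raise TypeError("unsupported sequence subscript")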

helion/_compiler/type_propagation.py

Lines changed: 90 additions & 1 deletion
@@ -1268,7 +1268,20 @@ def populate_symbol_origins(self, origin: Origin) -> None:
             subtype.populate_symbol_origins(GetItemOrigin(origin, i))
 
     def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
-        return super().propagate_getitem(key, origin)
+        # Try literal indexing first
+        try:
+            return super().propagate_getitem(key, origin)
+        except exc.TypeInferenceError:
+            # If indexing with a symbolic/grid index on device and the sequence length is known,
+            # conservatively merge all possible element types.
+            if origin.is_device() and isinstance(key, (SymIntType, GridIndexType)):
+                if not self.element_types:
+                    return super().propagate_getitem(key, origin)
+                merged: TypeInfo = self.element_types[0]
+                for candidate in self.element_types[1:]:
+                    merged = merged.merge(candidate)
+                return merged
+            raise
 
     def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:
         if isinstance(other, SequenceType):
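
When the subscript key is symbolic, the fallback collapses the sequence's element types into one conservative type by pairwise merging, so in the new tests, where `As` holds four identically shaped float32 tensors, `As[i]` still type-checks before `i` becomes a literal. A minimal sketch of the reduction, assuming a caller-supplied merge function:

import functools

def merged_element_type(element_types, merge):
    # Pairwise-merge all candidate element types into a single conservative
    # type, as the propagate_getitem fallback above does for symbolic keys.
    return functools.reduce(merge, element_types)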
@@ -2161,6 +2174,82 @@ def visit_For(self, node: ast.For) -> TypeInfo:
             raise exc.NestedGridLoop
 
         self.device_loop_depth += device_loop
+
+        # Try static unrolling for device grid loops when iteration count is known
+        try:
+            if node._loop_type != LoopType.HOST and isinstance(node.iter, ast.Call):
+                call_node = node.iter
+                # Extract begin, end, step; support only 1D grid here
+                begin_val: int | None
+                end_val: int | None
+                step_val: int | None
+
+                if len(call_node.args) == 1:
+                    begin_val = 0
+                    end_type = self.visit(call_node.args[0])
+                    step_type: TypeInfo | None = None
+                else:
+                    begin_type = self.visit(call_node.args[0])
+                    end_type = self.visit(call_node.args[1])
+                    step_type = (
+                        self.visit(call_node.args[2]) if len(call_node.args) >= 3 else None
+                    )
+                    begin_val = (
+                        begin_type.as_literal() if begin_type.is_literal() else None
+                    )  # type: ignore[assignment]
+
+                for kw in call_node.keywords:
+                    if kw.arg == "step" and step_type is None:
+                        step_type = self.visit(kw.value)
+
+                end_val = (
+                    end_type.as_literal() if end_type.is_literal() else None
+                )  # type: ignore[assignment]
+                step_val = (
+                    step_type.as_literal() if (step_type is not None and step_type.is_literal()) else 1
+                )  # type: ignore[assignment]
+
+                if (
+                    isinstance(begin_val, int)
+                    and isinstance(end_val, int)
+                    and isinstance(step_val, int)
+                ):
+                    # Build concrete iteration values
+                    iter_values = list(range(begin_val, end_val, step_val))
+                    # Small guard to avoid excessive compile-time blowups
+                    if len(iter_values) <= 64:
+                        merged_scope: LocalScope | None = None
+                        for iv in iter_values:
+                            # Emulate _loop_body with loop index bound to a literal
+                            self.push_scope()
+                            self._assign(node.target, LiteralType(self.origin(), iv))
+                            exit_scopes = [self.scope]
+                            for stmt in node.body:
+                                self.visit(stmt)
+                                if isinstance(stmt, (ast.Break, ast.Continue)):
+                                    exit_scopes.append(self.scope.clone())
+                            # Reset loop variable back to its GridIndexType to avoid control-flow merging issues
+                            self._assign(
+                                node.target,
+                                iter_type.propagate_iter(self.origin()),
+                            )
+                            self.pop_scope()
+                            iter_scope = functools.reduce(lambda x, y: x.merge(y), exit_scopes)
+                            if merged_scope is None:
+                                merged_scope = iter_scope
+                            else:
+                                merged_scope.merge(iter_scope)
+
+                        if merged_scope is not None:
+                            body = merged_scope
+                            orelse = self._body(node.orelse)
+                            self.scope.merge_if_else(body, orelse)
+                            self.device_loop_depth -= device_loop
+                            return NoType(origin=self.origin())
+        except NotImplementedError:
+            # Fall back to generic handling if we can't statically determine iterations
+            pass
+
         body = self._loop_body(node.body)
         with self.swap_scope(body):
             # second pass for fixed point

test/test_type_propagation.expected

Lines changed: 68 additions & 0 deletions
@@ -537,6 +537,74 @@ def root_graph_2():
     _for_loop = helion_language__tracing_ops__for_loop(1, [0], [x_size0], []); x_size0 = _for_loop = None
     return None
 
+--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code0)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_1 = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
+    v_0 = load + load_1
+    tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
+    load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_3 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
+    v_1 = load_2 + load_3
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    load_4 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_5 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
+    v_2 = load_4 + load_5
+    tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
+    load_6 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_7 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
+    v_3 = load_6 + load_7
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+
+def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
+    out = torch.zeros_like(As[0])
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code1)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    load = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
+    v_0 = acc + load
+    load_1 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
+    v_1 = v_0 + load_1
+    load_2 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
+    v_2 = v_1 + load_2
+    load_3 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
+    v_3 = v_2 + load_3
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+
+def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
+    out = torch.zeros_like(As[0])
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestTypePropagation.test_hl_full_usage)
 def hl_full_usage(x: torch.Tensor):
     # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:38>)

test/test_type_propagation.py

Lines changed: 33 additions & 0 deletions
@@ -11,6 +11,8 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import import_path
+from helion._testing import DEVICE
+from helion._testing import code_and_output
 import helion.language as hl
 
 if TYPE_CHECKING:
@@ -132,6 +134,37 @@ def use_unsupported_property(x: torch.Tensor) -> torch.Tensor:
         ):
             type_propagation_report(use_unsupported_property, x)
 
+    def test_for_loop_indexing_in_device_code0(self):
+        @helion.kernel
+        def kernel(As: list[torch.Tensor]) -> torch.Tensor:
+            out = torch.zeros_like(As[0])
+            for tile in hl.tile(out.size()):
+                for i in range(len(As)):
+                    a = As[i]
+                    out[tile] += a[tile]
+            return out
+
+        args = [torch.randn(16, device=DEVICE) for _ in range(4)]
+        code, result = code_and_output(kernel, (args,))
+        torch.testing.assert_close(result, sum(args))
+        self.assertExpectedJournal(code)
+
+    def test_for_loop_indexing_in_device_code1(self):
+        @helion.kernel
+        def kernel(As: list[torch.Tensor]) -> torch.Tensor:
+            out = torch.zeros_like(As[0])
+            for tile in hl.tile(out.size()):
+                acc = hl.zeros(tile)
+                for i in range(len(As)):
+                    a = As[i]
+                    acc = acc + a[tile]
+                out[tile] = acc
+            return out
+
+        args = [torch.randn(16, device=DEVICE) for _ in range(4)]
+        code, result = code_and_output(kernel, (args,))
+        torch.testing.assert_close(result, sum(args))
+        self.assertExpectedJournal(code)
 
 if __name__ == "__main__":
     unittest.main()
