@@ -294,6 +294,60 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         )
         torch.testing.assert_close(result, expected)
 
+    def test_tile_id(self):
+        @helion.kernel
+        def test_tile_id_access(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile in hl.tile(x.size(0)):
+                out[tile] = tile.id
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_access,
+            (x,),
+            block_size=16,
+        )
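+        # 64 elements tiled with block_size=16 give 4 tiles, so each tile id (0..3)
+        # is written to 16 consecutive output elements.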
+        expected = torch.arange(4, device=DEVICE, dtype=torch.int32).repeat_interleave(
+            repeats=16
+        )
+        torch.testing.assert_close(result, expected)
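+        # With block_size=1 every element is its own tile, so the output is arange(64).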
+        code, result = code_and_output(
+            test_tile_id_access,
+            (x,),
+            block_size=1,
+        )
+        expected = torch.arange(64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
+    def test_tile_id_indexing(self):
+        @helion.kernel
+        def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m, tile_n in hl.tile(x.size()):
+                x[tile_m, tile_n] = x[tile_m, tile_n] + 1
+                hl.atomic_add(out, [tile_m.id, tile_n.id], 1)
+            return out
+
+        x = torch.randn(64, 64, device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_atomic_add,
+            (x,),
+            block_size=[16, 16],
+        )
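+        # A 64x64 input tiled into 16x16 blocks yields a 4x4 grid of tiles, and each
+        # tile increments out[tile_m.id, tile_n.id] exactly once.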
+        expected = torch.zeros(64, 64, device=DEVICE, dtype=torch.int32)
+        expected[:4, :4] = 1
+        torch.testing.assert_close(result, expected)
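+        # With block_size=1, every 1x1 tile maps to a distinct (tile_m.id, tile_n.id)
+        # pair, so every element of out is incremented exactly once.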
+        code, result = code_and_output(
+            test_tile_id_atomic_add,
+            (x,),
+            block_size=1,
+        )
+        expected = torch.ones(64, 64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
     def test_atomic_add_symint(self):
         @helion.kernel(config={"block_size": 32})
         def fn(x: torch.Tensor) -> torch.Tensor: