2121@_decorators .api (tiles_as_sizes = True )
2222def wait (
2323 signal_pad : torch .Tensor ,
24- index : list [object ],
24+ index : list [object ] | None = None ,
2525 signal : int = 1 ,
2626 update : int | None = None ,
2727 op : str = "ld" ,
2828 sem : str = "acquire" ,
2929 scope : str = "gpu" ,
3030 skip_sync : bool = False ,
31+ as_ptrs : bool = False ,
3132) -> None :
3233 """Wait until all entries of the signal_pad slice are equal to the signal value.
3334 Args:
@@ -39,6 +40,7 @@ def wait(
3940 sem: The memory semantics for acquiring the lock (default: 'acquire')
4041 scope: The scope of the lock (default: 'gpu')
4142 skip_sync: Skip the syncthreads after the wait (default: False)
43+ as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
4244
4345 Returns:
4446 None
@@ -49,14 +51,15 @@ def wait(
4951@_decorators .prepare_args (wait )
5052def _ (
5153 signal_pad : torch .Tensor ,
52- index : list [object ],
54+ index : list [object ] | None = None ,
5355 signal : int = 1 ,
5456 update : int | None = None ,
5557 op : str = "ld" ,
5658 sem : str = "acquire" ,
5759 scope : str = "gpu" ,
5860 skip_sync : bool = False ,
59- ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool ]:
61+ as_ptrs : bool = False ,
62+ ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool , bool ]:
6063 from helion .language .tile_proxy import Tile
6164
6265 valid_ops = {"ld" , "atomic_cas" }
@@ -88,22 +91,37 @@ def _(
8891 if scope not in valid_scopes :
8992 raise ValueError (f"Invalid scope '{ scope } '. Must be one of { valid_scopes } ." )
9093
94+ if as_ptrs :
95+ if index is not None :
96+ raise ValueError (
97+ f"When as_ptrs=True, signal_pad must be used without indexing. "
98+ f"Expected 0 indices but got { len (index )} . "
99+ )
100+ if signal_pad .dtype not in (torch .uint64 , torch .int64 ):
101+ raise ValueError (
102+ f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
103+ f"to represent memory pointers. Got dtype { signal_pad .dtype } . "
104+ )
105+ if index is None :
106+ index = []
107+
91108 index = Tile ._prepare_index (index )
92109 index = Tile ._tiles_to_sizes (index )
93110
94- return (signal_pad , index , signal , update , op , sem , scope , skip_sync )
111+ return (signal_pad , index , signal , update , op , sem , scope , skip_sync , as_ptrs )
95112
96113
97114@_decorators .register_fake (wait )
98115def _ (
99116 signal_pad : torch .Tensor ,
100- index : list [object ],
117+ index : list [object ] | None = None ,
101118 signal : int = 1 ,
102119 update : int | None = None ,
103120 op : str = "ld" ,
104121 sem : str = "acquire" ,
105122 scope : str = "sys" ,
106123 skip_sync : bool = False ,
124+ as_ptrs : bool = False ,
107125) -> None :
108126 return None
109127
@@ -123,35 +141,38 @@ def _(state: CodegenState) -> ast.AST:
123141 sem = state .proxy_arg (5 )
124142 scope = state .proxy_arg (6 )
125143 skip_sync = state .proxy_arg (7 )
144+ as_ptrs = state .proxy_arg (8 )
126145
127146 assert isinstance (signal_pad , torch .Tensor )
128147 assert isinstance (index , (list ))
129148
130- indices = SubscriptIndexing .create (state , signal_pad , index )
131- signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
132-
133- signal_expr = ast .Constant (value = signal ) # pyright: ignore[reportArgumentType]
134- update_expr = ast .Constant (value = update ) # pyright: ignore[reportArgumentType]
135-
136149 assert type (op ) is str
137150 assert type (sem ) is str
138151 assert type (scope ) is str
139152
140- bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
141- is_scalar = len (bar_tensor_shape ) == 0
142-
143- if is_scalar :
144- call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={ signal_pad_name } + offset, expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
153+ if as_ptrs :
154+ bar_tensor_shape = signal_pad .shape
155+ bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
145156 else :
157+ indices = SubscriptIndexing .create (state , signal_pad , index )
146158 if signal_pad .dtype not in (torch .int32 , torch .uint32 ):
147159 raise NotImplementedError (
148160 f"Unsupported signal pad dtype: { signal_pad .dtype } . Must be of torch.int32 or torch.uint32."
149161 )
150- call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={ signal_pad_name } + offset, expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
162+ signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
163+ bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
164+ bar_addrs = f"{ signal_pad_name } + signal_pad_arg"
165+
166+ signal_expr = ast .Constant (value = signal ) # pyright: ignore[reportArgumentType]
167+ update_expr = ast .Constant (value = update ) # pyright: ignore[reportArgumentType]
168+
169+ is_scalar = len (bar_tensor_shape ) == 0
170+
171+ call_triton_wait_signal = f"helion.runtime.triton_wait_{ '' if is_scalar else 'multiple_' } signal(addr={ bar_addrs } , expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
151172
152173 return expr_from_string (
153174 call_triton_wait_signal ,
154- offset = indices .index_expr ,
175+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr , # pyright: ignore[reportPossiblyUnboundVariable]
155176 signal = signal_expr ,
156177 update = update_expr ,
157178 )
@@ -161,13 +182,14 @@ def _(state: CodegenState) -> ast.AST:
161182@_decorators .api (tiles_as_sizes = True )
162183def signal (
163184 signal_pad : torch .Tensor ,
164- index : list [object ],
185+ index : list [object ] | None = None ,
165186 signal : int = 1 ,
166187 wait_for : int | None = None ,
167188 op : str = "atomic_xchg" ,
168189 sem : str = "release" ,
169190 scope : str = "gpu" ,
170191 skip_sync : bool = False ,
192+ as_ptrs : bool = False ,
171193) -> torch .Tensor :
172194 """Set the signal_pad slice to the signal value.
173195 Args:
@@ -179,21 +201,25 @@ def signal(
179201 sem: The memory semantics for acquiring the lock (default: 'release')
180202 scope: The scope of the lock (default: 'gpu')
181203 skip_sync: Skip the syncthreads before sending signal (default: False)
204+ as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
205+ Returns:
206+ The old value of the signal_pad slice before the update.
182207 """
183208 raise exc .NotInsideKernel
184209
185210
186211@_decorators .prepare_args (signal )
187212def _ (
188213 signal_pad : torch .Tensor ,
189- index : list [object ],
214+ index : list [object ] | None = None ,
190215 signal : int = 1 ,
191216 wait_for : int | None = None ,
192217 op : str = "atomic_xchg" ,
193218 sem : str = "release" ,
194219 scope : str = "gpu" ,
195220 skip_sync : bool = False ,
196- ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool ]:
221+ as_ptrs : bool = False ,
222+ ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool , bool ]:
197223 from helion .language .tile_proxy import Tile
198224
199225 valid_ops = {"atomic_add" , "atomic_xchg" , "atomic_cas" }
@@ -220,23 +246,42 @@ def _(
220246 if scope not in valid_scopes :
221247 raise ValueError (f"Invalid scope '{ scope } '. Must be one of { valid_scopes } ." )
222248
249+ if as_ptrs :
250+ if index is not None :
251+ raise ValueError (
252+ f"When as_ptrs=True, signal_pad must be used without indexing. "
253+ f"Expected 0 indices but got { len (index )} . "
254+ )
255+ if signal_pad .dtype not in (torch .uint64 , torch .int64 ):
256+ raise ValueError (
257+ f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
258+ f"to represent memory pointers. Got dtype { signal_pad .dtype } . "
259+ )
260+ if index is None :
261+ index = []
262+
223263 index = Tile ._prepare_index (index )
224264 index = Tile ._tiles_to_sizes (index )
225265
226- return (signal_pad , index , signal , wait_for , op , sem , scope , skip_sync )
266+ return (signal_pad , index , signal , wait_for , op , sem , scope , skip_sync , as_ptrs )
227267
228268
229269@_decorators .register_fake (signal )
230270def _ (
231271 signal_pad : torch .Tensor ,
232- index : list [object ],
272+ index : list [object ] | None = None ,
233273 signal : int = 1 ,
234274 wait_for : int | None = None ,
235275 op : str = "atomic_xchg" ,
236276 sem : str = "release" ,
237277 scope : str = "gpu" ,
238278 skip_sync : bool = False ,
279+ as_ptrs : bool = False ,
239280) -> torch .Tensor :
281+ if index is None :
282+ index = []
283+ if as_ptrs :
284+ return signal_pad .new_empty (signal_pad .shape )
240285 return signal_pad .new_empty (SubscriptIndexing .compute_shape (signal_pad , index ))
241286
242287
@@ -255,43 +300,51 @@ def _(state: CodegenState) -> ast.AST:
255300 sem = state .proxy_arg (5 )
256301 scope = state .proxy_arg (6 )
257302 skip_sync = state .proxy_arg (7 )
303+ as_ptrs = state .proxy_arg (8 )
258304
259305 assert isinstance (signal_pad , torch .Tensor )
260306 assert isinstance (index , list )
261307
262- indices = SubscriptIndexing .create (state , signal_pad , index )
263- signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
308+ assert type (op ) is str
309+ assert type (sem ) is str
310+ assert type (scope ) is str
311+
312+ if as_ptrs :
313+ bar_tensor_shape = signal_pad .shape
314+ bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
315+ else :
316+ indices = SubscriptIndexing .create (state , signal_pad , index )
317+ if signal_pad .dtype not in (torch .int32 , torch .uint32 ):
318+ raise NotImplementedError (
319+ f"Unsupported signal pad dtype: { signal_pad .dtype } . Must be of torch.int32 or torch.uint32."
320+ )
321+ signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
322+ bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
323+ bar_addrs = f"{ signal_pad_name } + signal_pad_arg"
324+
325+ is_scalar = len (bar_tensor_shape ) == 0
264326
265327 signal_expr = ast .Constant (value = signal ) # pyright: ignore[reportArgumentType]
266328 if wait_for is not None :
267329 wait_for_expr = ast .Constant (value = wait_for ) # pyright: ignore[reportArgumentType]
268330 else :
269331 wait_for_expr = ast .Constant (value = 0 )
270332 skip_sync_expr = ast .Constant (value = skip_sync ) # pyright: ignore[reportArgumentType]
271- assert type (op ) is str
272- assert type (sem ) is str
273- assert type (scope ) is str
274333
275334 if op == "atomic_cas" :
276- bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
277- is_scalar = len (bar_tensor_shape ) == 0
278- if is_scalar :
279- call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={ signal_pad_name } + offset, expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
280- else :
281- call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={ signal_pad_name } + offset, expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
282-
335+ call_triton_wait_signal = f"helion.runtime.triton_wait_{ '' if is_scalar else 'multiple_' } signal(addr={ bar_addrs } , expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
283336 return expr_from_string (
284337 call_triton_wait_signal ,
285- offset = indices .index_expr ,
338+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr , # pyright: ignore[reportPossiblyUnboundVariable]
286339 wait_for = wait_for_expr ,
287340 signal = signal_expr ,
288341 skip_sync = skip_sync_expr ,
289342 )
290- call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={ signal_pad_name } + offset , update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=skip_sync)"
343+ call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={ bar_addrs } , update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=skip_sync)"
291344
292345 return expr_from_string (
293346 call_triton_send_signal ,
294- offset = indices .index_expr ,
347+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr , # pyright: ignore[reportPossiblyUnboundVariable]
295348 signal = signal_expr ,
296349 skip_sync = skip_sync_expr ,
297350 )
0 commit comments