 @_decorators.api(tiles_as_sizes=True)
 def wait(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> None:
     """Wait until all entries of the signal_pad slice are equal to the signal value.
     Args:
@@ -39,6 +40,7 @@ def wait(
         sem: The memory semantic for acquiring the lock (default: 'acquire')
         scope: The scope of the lock (default: 'gpu')
         skip_sync: Skip the syncthreads after the wait (default: False)
+        as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)

     Returns:
         None
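
For orientation, a minimal usage sketch of the two calling conventions above. It is not part of this diff: the kernel, the `peer_pads`/`dev_ptrs` names, and the host-side pointer table are illustrative assumptions, and a CUDA device with the Helion runtime is assumed.

import torch
import helion
import helion.language as hl

@helion.kernel
def consume_when_ready(x: torch.Tensor, pad: torch.Tensor, dev_ptrs: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        hl.wait(pad, [0], signal=1)                # indexed form: pad must be int32/uint32
        hl.wait(dev_ptrs, signal=1, as_ptrs=True)  # pointer form: no index, dev_ptrs holds uint64/int64 addresses
        out[tile] = x[tile]
    return out

# Host side (illustrative): gather raw addresses of each peer's int32 barrier slot.
# dev_ptrs = torch.as_tensor([p.data_ptr() for p in peer_pads],
#                            dtype=torch.uint64, device="cuda")
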
@@ -49,14 +51,15 @@ def wait(
 @_decorators.prepare_args(wait)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "gpu",
     skip_sync: bool = False,
-) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    as_ptrs: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
     from helion.language.tile_proxy import Tile

     valid_ops = {"ld", "atomic_cas"}
@@ -88,22 +91,37 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

+    if as_ptrs:
+        if index is not None:
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must be used without indexing. "
+                f"Expected 0 indices but got {len(index)}."
+            )
+        if signal_pad.dtype not in (torch.uint64, torch.int64):
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
+                f"to represent memory pointers. Got dtype {signal_pad.dtype}."
+            )
+    if index is None:
+        index = []
+
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)

-    return (signal_pad, index, signal, update, op, sem, scope, skip_sync)
+    return (signal_pad, index, signal, update, op, sem, scope, skip_sync, as_ptrs)


 @_decorators.register_fake(wait)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "sys",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> None:
     return None

@@ -123,35 +141,38 @@ def _(state: CodegenState) -> ast.AST:
     sem = state.proxy_arg(5)
     scope = state.proxy_arg(6)
     skip_sync = state.proxy_arg(7)
+    as_ptrs = state.proxy_arg(8)

     assert isinstance(signal_pad, torch.Tensor)
     assert isinstance(index, (list))

-    indices = SubscriptIndexing.create(state, signal_pad, index)
-    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
-
-    signal_expr = ast.Constant(value=signal)  # pyright: ignore[reportArgumentType]
-    update_expr = ast.Constant(value=update)  # pyright: ignore[reportArgumentType]
-
     assert type(op) is str
     assert type(sem) is str
     assert type(scope) is str

-    bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
-    is_scalar = len(bar_tensor_shape) == 0
-
-    if is_scalar:
-        call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    if as_ptrs:
+        bar_tensor_shape = signal_pad.shape
+        bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
     else:
+        indices = SubscriptIndexing.create(state, signal_pad, index)
         if signal_pad.dtype not in (torch.int32, torch.uint32):
             raise NotImplementedError(
                 f"Unsupported signal pad dtype: {signal_pad.dtype}"
             )
-        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+        signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+        bar_addrs = f"{signal_pad_name} + signal_pad_arg"
+
+    signal_expr = ast.Constant(value=signal)  # pyright: ignore[reportArgumentType]
+    update_expr = ast.Constant(value=update)  # pyright: ignore[reportArgumentType]
+
+    is_scalar = len(bar_tensor_shape) == 0
+
+    call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

     return expr_from_string(
         call_triton_wait_signal,
-        offset=indices.index_expr,
+        signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,  # pyright: ignore[reportPossiblyUnboundVariable]
         signal=signal_expr,
         update=update_expr,
     )
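
The `as_ptrs` branch above splices `signal_pad_arg.to(tl.pointer_type(tl.int32))` into the generated Triton source, i.e. the 64-bit values stored in the pad are reinterpreted as `int32*` barrier addresses. A standalone Triton sketch of that cast, with made-up names and outside Helion's codegen, just to show the device-level idea:

import triton
import triton.language as tl

@triton.jit
def _peek_barriers(ptr_table, out, N: tl.constexpr):
    # ptr_table holds raw 64-bit addresses; N must be a power of two for tl.arange.
    offs = tl.arange(0, N)
    addrs = tl.load(ptr_table + offs)            # load the addresses themselves
    bars = addrs.to(tl.pointer_type(tl.int32))   # reinterpret each address as an int32 pointer
    tl.store(out + offs, tl.load(bars))          # gather the current barrier values
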
@@ -161,13 +182,14 @@ def _(state: CodegenState) -> ast.AST:
 @_decorators.api(tiles_as_sizes=True)
 def signal(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> torch.Tensor:
     """Set the signal_pad slice to the signal value.
     Args:
@@ -179,21 +201,25 @@ def signal(
         sem: The memory semantic for acquiring the lock (default: 'release')
         scope: The scope of the lock (default: 'gpu')
         skip_sync: Skip the syncthreads before sending signal (default: False)
+        as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
+    Returns:
+        The old value of the signal_pad slice before the update.
     """
     raise exc.NotInsideKernel


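As with `wait`, a hedged sketch of how the pointer form of `signal` might be used, reusing the imports and the illustrative `dev_ptrs` table from the earlier example; the kernel below is an assumption, not code from this PR.

@helion.kernel
def produce_and_notify(x: torch.Tensor, local_pad: torch.Tensor, dev_ptrs: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = x[tile] * 2
        hl.signal(dev_ptrs, signal=1, as_ptrs=True)  # stamp every barrier addressed by dev_ptrs
        hl.wait(local_pad, [0], signal=1)            # then block on the local int32 pad
    return out
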
 @_decorators.prepare_args(signal)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
-) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    as_ptrs: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
     from helion.language.tile_proxy import Tile

     valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
@@ -220,23 +246,42 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

+    if as_ptrs:
+        if index is not None:
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must be used without indexing. "
+                f"Expected 0 indices but got {len(index)}."
+            )
+        if signal_pad.dtype not in (torch.uint64, torch.int64):
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
+                f"to represent memory pointers. Got dtype {signal_pad.dtype}."
+            )
+    if index is None:
+        index = []
+
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)

-    return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync)
+    return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync, as_ptrs)


 @_decorators.register_fake(signal)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> torch.Tensor:
+    if index is None:
+        index = []
+    if as_ptrs:
+        return signal_pad.new_empty(signal_pad.shape)
     return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))


@@ -255,43 +300,51 @@ def _(state: CodegenState) -> ast.AST:
     sem = state.proxy_arg(5)
     scope = state.proxy_arg(6)
     skip_sync = state.proxy_arg(7)
+    as_ptrs = state.proxy_arg(8)

     assert isinstance(signal_pad, torch.Tensor)
     assert isinstance(index, list)

-    indices = SubscriptIndexing.create(state, signal_pad, index)
-    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    if as_ptrs:
+        bar_tensor_shape = signal_pad.shape
+        bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
+    else:
+        indices = SubscriptIndexing.create(state, signal_pad, index)
+        if signal_pad.dtype not in (torch.int32, torch.uint32):
+            raise NotImplementedError(
+                f"Unsupported signal pad dtype: {signal_pad.dtype}"
+            )
+        signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+        bar_addrs = f"{signal_pad_name} + signal_pad_arg"
+
+    is_scalar = len(bar_tensor_shape) == 0

     signal_expr = ast.Constant(value=signal)  # pyright: ignore[reportArgumentType]
     if wait_for is not None:
         wait_for_expr = ast.Constant(value=wait_for)  # pyright: ignore[reportArgumentType]
     else:
         wait_for_expr = ast.Constant(value=0)
     skip_sync_expr = ast.Constant(value=skip_sync)  # pyright: ignore[reportArgumentType]
-    assert type(op) is str
-    assert type(sem) is str
-    assert type(scope) is str

     if op == "atomic_cas":
-        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
-        is_scalar = len(bar_tensor_shape) == 0
-        if is_scalar:
-            call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
-        else:
-            call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
-
+        call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
         return expr_from_string(
             call_triton_wait_signal,
-            offset=indices.index_expr,
+            signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,  # pyright: ignore[reportPossiblyUnboundVariable]
             wait_for=wait_for_expr,
             signal=signal_expr,
             skip_sync=skip_sync_expr,
         )
-    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
+    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={bar_addrs}, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"

     return expr_from_string(
         call_triton_send_signal,
-        offset=indices.index_expr,
+        signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,  # pyright: ignore[reportPossiblyUnboundVariable]
         signal=signal_expr,
         skip_sync=skip_sync_expr,
     )
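
Finally, note the `atomic_cas` branch above: with op="atomic_cas" and a wait_for value, `signal` is lowered to the wait helper, so it blocks until the slot equals `wait_for` and only then writes `signal`. A sketch under the same illustrative assumptions as the earlier examples:

@helion.kernel
def take_ticket(x: torch.Tensor, pad: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # Block until pad[0] == 0, then atomically set it to 1.
        hl.signal(pad, [0], signal=1, wait_for=0, op="atomic_cas")
        out[tile] = x[tile]
    return out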