 @_decorators.api(tiles_as_sizes=True)
 def wait(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] = [],
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> None:
     """Wait until all entries of the signal_pad slice are equal to the signal value.
     Args:
@@ -39,6 +40,7 @@ def wait(
         sem: The memory semantic for acquiring the lock (default: 'acquire')
         scope: The scope of the lock (default: 'gpu')
         skip_sync: Skip the syncthreads after the wait (default: False)
+        as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)

     Returns:
         None
@@ -49,14 +51,15 @@ def wait(
 @_decorators.prepare_args(wait)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] = [],
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "gpu",
     skip_sync: bool = False,
-) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    as_ptrs: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
     from helion.language.tile_proxy import Tile

     valid_ops = {"ld", "atomic_cas"}
@@ -88,22 +91,35 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

+    if as_ptrs:
+        if len(index) != 0:
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must be used without indexing. "
+                f"Expected 0 indices but got {len(index)}."
+            )
+        if signal_pad.dtype not in (torch.uint64, torch.int64):
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
+                f"to represent memory pointers. Got dtype {signal_pad.dtype}."
+            )
+
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)

-    return (signal_pad, index, signal, update, op, sem, scope, skip_sync)
+    return (signal_pad, index, signal, update, op, sem, scope, skip_sync, as_ptrs)


 @_decorators.register_fake(wait)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] = [],
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "sys",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> None:
     return None

@@ -123,35 +139,38 @@ def _(state: CodegenState) -> ast.AST:
     sem = state.proxy_arg(5)
     scope = state.proxy_arg(6)
     skip_sync = state.proxy_arg(7)
+    as_ptrs = state.proxy_arg(8)

     assert isinstance(signal_pad, torch.Tensor)
     assert isinstance(index, (list))

-    indices = SubscriptIndexing.create(state, signal_pad, index)
-    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
-
-    signal_expr = ast.Constant(value=signal)
-    update_expr = ast.Constant(value=update)
-
     assert type(op) is str
     assert type(sem) is str
     assert type(scope) is str

-    bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
-    is_scalar = len(bar_tensor_shape) == 0
-
-    if is_scalar:
-        call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    if as_ptrs:
+        bar_tensor_shape = signal_pad.shape
+        bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
     else:
+        indices = SubscriptIndexing.create(state, signal_pad, index)
         if signal_pad.dtype not in (torch.int32, torch.uint32):
             raise NotImplementedError(
                 f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be torch.int32 or torch.uint32."
             )
-        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+        signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+        bar_addrs = f"{signal_pad_name} + signal_pad_arg"
+
+    signal_expr = ast.Constant(value=signal)
+    update_expr = ast.Constant(value=update)
+
+    is_scalar = len(bar_tensor_shape) == 0
+
+    call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

     return expr_from_string(
         call_triton_wait_signal,
-        offset=indices.index_expr,
+        signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,
         signal=signal_expr,
         update=update_expr,
     )
@@ -161,13 +180,14 @@ def _(state: CodegenState) -> ast.AST:
 @_decorators.api(tiles_as_sizes=True)
 def signal(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] = [],
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> torch.Tensor:
     """Set the signal_pad slice to the signal value.
     Args:
@@ -179,21 +199,25 @@ def signal(
         sem: The memory semantic for acquiring the lock (default: 'release')
         scope: The scope of the lock (default: 'gpu')
         skip_sync: Skip the syncthreads before sending signal (default: False)
+        as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
+    Returns:
+        The old value of the signal_pad slice before the update.
     """
     raise exc.NotInsideKernel


 @_decorators.prepare_args(signal)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] = [],
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
-) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    as_ptrs: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
     from helion.language.tile_proxy import Tile

     valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
@@ -220,23 +244,38 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

+    if as_ptrs:
+        if len(index) != 0:
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must be used without indexing. "
+                f"Expected 0 indices but got {len(index)}."
+            )
+        if signal_pad.dtype not in (torch.uint64, torch.int64):
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
+                f"to represent memory pointers. Got dtype {signal_pad.dtype}."
+            )
+
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)

-    return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync)
+    return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync, as_ptrs)


 @_decorators.register_fake(signal)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] = [],
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> torch.Tensor:
+    if as_ptrs:
+        return signal_pad.new_empty(signal_pad.shape)
     return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))

@@ -255,43 +294,51 @@ def _(state: CodegenState) -> ast.AST:
     sem = state.proxy_arg(5)
     scope = state.proxy_arg(6)
     skip_sync = state.proxy_arg(7)
+    as_ptrs = state.proxy_arg(8)

     assert isinstance(signal_pad, torch.Tensor)
     assert isinstance(index, list)

-    indices = SubscriptIndexing.create(state, signal_pad, index)
-    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    if as_ptrs:
+        bar_tensor_shape = signal_pad.shape
+        bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
+    else:
+        indices = SubscriptIndexing.create(state, signal_pad, index)
+        if signal_pad.dtype not in (torch.int32, torch.uint32):
+            raise NotImplementedError(
+                f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be torch.int32 or torch.uint32."
+            )
+        signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+        bar_addrs = f"{signal_pad_name} + signal_pad_arg"

     signal_expr = ast.Constant(value=signal)
     if wait_for is not None:
         wait_for_expr = ast.Constant(value=wait_for)
     else:
         wait_for_expr = ast.Constant(value=0)
     skip_sync_expr = ast.Constant(value=skip_sync)
-    assert type(op) is str
-    assert type(sem) is str
-    assert type(scope) is str

     if op == "atomic_cas":
         bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
         is_scalar = len(bar_tensor_shape) == 0
-        if is_scalar:
-            call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
-        else:
-            call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
-
+        call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
         return expr_from_string(
             call_triton_wait_signal,
-            offset=indices.index_expr,
+            signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,
             wait_for=wait_for_expr,
             signal=signal_expr,
             skip_sync=skip_sync_expr,
         )
-    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
+    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={bar_addrs}, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"

     return expr_from_string(
         call_triton_send_signal,
-        offset=indices.index_expr,
+        signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,
         signal=signal_expr,
         skip_sync=skip_sync_expr,
     )
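For reference, a minimal sketch of the calling contract the prepare_args checks above enforce. It is illustrative only and not part of this diff: the tensor names and the hl alias are assumptions, and wait()/signal() only execute inside a Helion kernel (calling them on the host raises exc.NotInsideKernel), so the API calls are shown as comments.

# Hypothetical usage sketch (not from this PR). Assumes `hl` aliases
# helion.language; tensor names are illustrative.
import torch
import helion.language as hl  # wait()/signal() live here

signal_pad = torch.zeros(16, dtype=torch.int32, device="cuda")  # indexed form
ptr_pad = torch.zeros(8, dtype=torch.uint64, device="cuda")     # as_ptrs form: holds barrier addresses

# Inside a Helion kernel body:
#   hl.wait(signal_pad, [tile], signal=1)        # index into an int32/uint32 signal pad
#   hl.signal(signal_pad, [tile], signal=1)
#   hl.wait(ptr_pad, signal=1, as_ptrs=True)     # no index; entries are barrier addresses
#   hl.signal(ptr_pad, signal=1, as_ptrs=True)
#
# Rejected by the new prepare_args checks:
#   hl.wait(ptr_pad, [tile], as_ptrs=True)       # ValueError: indexing not allowed with as_ptrs=True
#   hl.wait(signal_pad, as_ptrs=True)            # ValueError: dtype must be torch.uint64/int64

With as_ptrs=True the codegen reinterprets each 64-bit entry as the address of an int32 barrier (signal_pad_arg.to(tl.pointer_type(tl.int32))), so a tensor of raw barrier addresses can be waited on or signaled directly instead of indexing into a local signal pad.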