@@ -74,15 +74,35 @@ def sort_key(self) -> tuple[object, ...]:
@dataclasses.dataclass
class TensorArg(Argument):
    """A kernel argument backed by a (fake) tensor.

    ``_host_str`` is the host-side expression used to pass this argument at the
    call site; it is ``None`` for device-only arguments that must never appear
    in a host call.
    """

    fake_value: torch.Tensor
    # Host-side expression for the call site, or None when device-only.
    _host_str: str | None

    def host_str(self) -> str:
        """Return the host-side call expression; raise if device-only."""
        host_repr = self._host_str
        if host_repr is None:
            raise RuntimeError("TensorArg has no host representation")
        return host_repr
8284
@dataclasses.dataclass
class TensorDescriptorArg(TensorArg):
    """Device-only tensor argument materialized as a ``tl.make_tensor_descriptor``.

    The descriptor is created inside the kernel preamble, so this argument has
    no host-side representation.  If the underlying tensor's stride==1
    dimension was not already last, ``permutation`` records the reordering that
    was applied to move it there.
    """

    # Permutation applied to make stride==1 dimension last
    permutation: list[int] | None = None

    def host_str(self) -> str:
        """Always raises: descriptors exist only inside the device function."""
        if self._host_str is None:
            raise RuntimeError(
                "TensorDescriptorArg is device-only and has no host representation"
            )
        return self._host_str

    @property
    def inverse_permutation(self) -> list[int]:
        """Get the inverse permutation to undo the applied permutation."""
        if self.permutation is None:
            raise RuntimeError("TensorDescriptorArg.permutation is None")
        # inverse[p] = i  <=>  forward permutation maps position i to value p
        inverse = [0] * len(self.permutation)
        for new_index, old_index in enumerate(self.permutation):
            inverse[old_index] = new_index
        return inverse
86106
87107
88108@dataclasses .dataclass
@@ -144,6 +164,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
144164 self .config = config
145165 self .codegen = codegen
146166 self .arguments : list [Argument ] = []
167+ self .preamble : list [ast .AST ] = []
147168 self .body : list [ast .AST ] = []
148169 self ._tensor_args : dict [torch .Tensor , TensorArg ] = {}
149170 self ._tensor_descriptor_args : dict [
@@ -272,20 +293,59 @@ def tensor_arg(
272293
273294 def tensor_descriptor_arg (
274295 self , fake_value : torch .Tensor , block_size : list [int | torch .SymInt ]
275- ) -> TensorArg :
296+ ) -> TensorDescriptorArg :
276297 host_function = HostFunction .current ()
277- block_size_expr = ", " .join (
278- map (HostFunction .current ().literal_expr , block_size )
279- )
298+ block_size_expr = ", " .join (map (self .literal_expr , block_size ))
280299 key = (fake_value , block_size_expr )
281300 if key not in self ._tensor_descriptor_args :
282301 origin = host_function .tensor_to_origin [fake_value ]
302+ desc_name = self .new_var (origin .suggest_var_name () + "_desc" )
303+ env = CompileEnvironment .current ()
304+
305+ # Find which dimension has stride==1
306+ stride_one_dim = [* map (env .size_hint , fake_value .stride ())].index (1 )
307+
308+ # Determine if we need permutation (stride==1 dimension is not last)
309+ permutation = None
310+ if stride_one_dim != fake_value .ndim - 1 :
311+ # Create permutation to move stride==1 dimension to last position
312+ permutation = [* range (fake_value .ndim )]
313+ permutation .pop (stride_one_dim )
314+ permutation .append (stride_one_dim )
315+
316+ # Create the regular tensor arg and size/stride args
317+ tensor_arg = self .tensor_arg (fake_value )
318+ size_args = [
319+ self .tensor_size (fake_value , i ) for i in range (fake_value .ndim )
320+ ]
321+ stride_args = [
322+ self .tensor_stride (fake_value , i ) for i in range (fake_value .ndim )
323+ ]
324+
325+ # Apply permutation if needed
326+ if permutation is not None :
327+ size_args = [size_args [i ] for i in permutation ]
328+ stride_args = [stride_args [i ] for i in permutation ]
329+ block_size = [block_size [i ] for i in permutation ]
330+ # Update block_size_expr for the permuted order
331+ block_size_expr = ", " .join (map (self .literal_expr , block_size ))
332+
333+ # Add tl.make_tensor_descriptor call to preamble
334+ sizes = ", " .join ([arg .name for arg in size_args ])
335+ strides = ", " .join ([arg .name for arg in stride_args ])
336+
337+ descriptor_stmt = statement_from_string (
338+ f"{ desc_name } = tl.make_tensor_descriptor({ tensor_arg .name } , [{ sizes } ], [{ strides } ], [{ block_size_expr } ])"
339+ )
340+ self .preamble .append (descriptor_stmt )
341+
283342 arg = TensorDescriptorArg (
284- self . new_var ( origin . suggest_var_name () + "_desc" ) ,
343+ desc_name ,
285344 fake_value ,
286- f"TensorDescriptor.from_tensor({ origin .host_str ()} , [{ block_size_expr } ])" ,
345+ None , # No host_str since this is device-only
346+ permutation ,
287347 )
288- self .arguments . append ( arg )
348+ # Don't add to self.arguments since this is device-only
289349 self ._tensor_descriptor_args [key ] = arg
290350 return self ._tensor_descriptor_args [key ]
291351
@@ -342,20 +402,28 @@ def sorted_args(self) -> list[Argument]:
342402 self .arguments .sort (key = lambda arg : arg .sort_key ())
343403 return self .arguments
344404
345- def codegen_function_def (self ) -> ast .FunctionDef :
346- return ast_rename (
347- create (
348- ast .FunctionDef ,
349- name = self .name ,
350- args = create_arguments (
351- [arg .arg_def_node () for arg in self .sorted_args ()]
405+ def codegen_function_def (self ) -> list [ast .stmt ]:
406+ prefix = []
407+ if self ._tensor_descriptor_args :
408+ prefix .append (
409+ statement_from_string ("helion.runtime.set_triton_allocator()" )
410+ )
411+ return [
412+ * prefix ,
413+ ast_rename (
414+ create (
415+ ast .FunctionDef ,
416+ name = self .name ,
417+ args = create_arguments (
418+ [arg .arg_def_node () for arg in self .sorted_args ()]
419+ ),
420+ body = [* self .preamble , * self .body ],
421+ decorator_list = [expr_from_string ("triton.jit" )],
422+ type_params = [],
352423 ),
353- body = self .body ,
354- decorator_list = [expr_from_string ("triton.jit" )],
355- type_params = [],
424+ {k : v [0 ] for k , v in self ._variable_renames .items ()},
356425 ),
357- {k : v [0 ] for k , v in self ._variable_renames .items ()},
358- )
426+ ]
359427
360428 def codegen_function_call (self ) -> ast .AST :
361429 args = [arg .host_str () for arg in self .sorted_args ()]
@@ -390,14 +458,15 @@ def dead_code_elimination(self) -> None:
390458 """
391459
392460 for _ in range (8 ):
393- rw = ReadWrites .from_list (self .body )
461+ rw = ReadWrites .from_list ([ * self .preamble , * self . body ] )
394462 to_remove = set ()
395463 for name in self .dce_vars :
396464 if name in rw .writes and name not in rw .reads :
397465 to_remove .add (name )
398466 if not to_remove :
399467 break
400468 self .body [:] = ast_delete_assignments (self .body , to_remove )
469+ self .preamble [:] = ast_delete_assignments (self .preamble , to_remove )
401470
402471 # drop any unused args
403472 args_to_remove = {
0 commit comments