5252        "pid_type" ,
5353        "indexing" ,
5454        "load_eviction_policies" ,
55+         "advanced_compiler_configuration" ,
5556    ]
5657)
5758VALID_PID_TYPES  =  ("flat" , "xyz" , "persistent_blocked" , "persistent_interleaved" )
@@ -105,6 +106,7 @@ class ConfigSpec:
105106            EnumFragment (choices = VALID_EVICTION_POLICIES ), length = 0 
106107        )
107108    )
109+     ptxas_supported : bool  =  False 
108110
109111    @staticmethod  
110112    def  _valid_indexing_types () ->  tuple [IndexingLiteral , ...]:
@@ -231,6 +233,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
231233            else :
232234                config [name ] =  values [0 ]
233235
236+         if  "advanced_compiler_configuration"  in  config :
237+             value  =  config .get ("advanced_compiler_configuration" ) or  0 
238+             if  not  isinstance (value , int ):
239+                 raise  InvalidConfig (
240+                     f"advanced_compiler_configuration must be integer, got { value !r}  
241+                 )
242+             if  value  and  not  self .ptxas_supported :
243+                 raise  InvalidConfig (
244+                     "advanced_compiler_configuration requires PTXAS support" 
245+                 )
246+             config ["advanced_compiler_configuration" ] =  value 
247+ 
234248        # Set default values for grid indices when pid_type is not persistent 
235249        pid_type  =  config ["pid_type" ]
236250        if  pid_type  in  ("flat" , "xyz" ) and  self .grid_block_ids :
@@ -270,8 +284,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
270284    def  default_config (self ) ->  helion .Config :
271285        return  self .flat_config (lambda  x : x .default ())
272286
273-     def  flat_config (self , fn : Callable [[ConfigSpecFragment ], object ]) ->  helion .Config :
287+     def  flat_config (
288+         self ,
289+         fn : Callable [[ConfigSpecFragment ], object ],
290+         * ,
291+         include_advanced_compiler_configuration : bool  |  None  =  None ,
292+     ) ->  helion .Config :
274293        """Map a flattened version of the config using the given function.""" 
294+         include_advanced  =  self .ptxas_supported 
295+         if  include_advanced_compiler_configuration  is  not None :
296+             include_advanced  =  (
297+                 include_advanced  and  include_advanced_compiler_configuration 
298+             )
275299        config  =  {
276300            "block_sizes" : self .block_sizes ._flat_config (self , fn ),
277301            "loop_orders" : self .loop_orders ._flat_config (self , fn ),
@@ -290,6 +314,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
290314            "pid_type" : fn (EnumFragment (self .allowed_pid_types )),
291315            "load_eviction_policies" : fn (self .load_eviction_policies ),
292316        }
317+         if  include_advanced :
318+             from  ..runtime .ptxas_configs  import  search_ptxas_configs 
319+ 
320+             config ["advanced_compiler_configuration" ] =  fn (
321+                 EnumFragment ((0 , * search_ptxas_configs ()))
322+             )
293323        # Add tunable parameters 
294324        config .update (
295325            {key : fn (fragment ) for  key , fragment  in  self .user_defined_tunables .items ()}
0 commit comments