5252        "pid_type" ,
5353        "indexing" ,
5454        "load_eviction_policies" ,
55+         "advanced_compiler_configuration" ,
5556    ]
5657)
5758VALID_PID_TYPES  =  ("flat" , "xyz" , "persistent_blocked" , "persistent_interleaved" )
@@ -105,6 +106,7 @@ class ConfigSpec:
105106            EnumFragment (choices = VALID_EVICTION_POLICIES ), length = 0 
106107        )
107108    )
109+     ptxas_supported : bool  =  False 
108110
109111    @staticmethod  
110112    def  _valid_indexing_types () ->  tuple [IndexingLiteral , ...]:
@@ -238,6 +240,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
238240            else :
239241                config [name ] =  values [0 ]
240242
243+         if  "advanced_compiler_configuration"  in  config :
244+             value  =  config .get ("advanced_compiler_configuration" ) or  0 
245+             if  not  isinstance (value , int ):
246+                 raise  InvalidConfig (
247+                     f"advanced_compiler_configuration must be integer, got { value !r}  
248+                 )
249+             if  value  and  not  self .ptxas_supported :
250+                 raise  InvalidConfig (
251+                     "advanced_compiler_configuration requires PTXAS support" 
252+                 )
253+             config ["advanced_compiler_configuration" ] =  value 
254+ 
241255        # Set default values for grid indices when pid_type is not persistent 
242256        pid_type  =  config ["pid_type" ]
243257        if  pid_type  in  ("flat" , "xyz" ) and  self .grid_block_ids :
@@ -260,8 +274,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
260274    def  default_config (self ) ->  helion .Config :
261275        return  self .flat_config (lambda  x : x .default ())
262276
263-     def  flat_config (self , fn : Callable [[ConfigSpecFragment ], object ]) ->  helion .Config :
277+     def  flat_config (
278+         self ,
279+         fn : Callable [[ConfigSpecFragment ], object ],
280+         * ,
281+         include_advanced_compiler_configuration : bool  |  None  =  None ,
282+     ) ->  helion .Config :
264283        """Map a flattened version of the config using the given function.""" 
284+         include_advanced  =  self .ptxas_supported 
285+         if  include_advanced_compiler_configuration  is  not None :
286+             include_advanced  =  (
287+                 include_advanced  and  include_advanced_compiler_configuration 
288+             )
265289        config  =  {
266290            "block_sizes" : self .block_sizes ._flat_config (self , fn ),
267291            "loop_orders" : self .loop_orders ._flat_config (self , fn ),
@@ -280,6 +304,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
280304            "pid_type" : fn (EnumFragment (self .allowed_pid_types )),
281305            "load_eviction_policies" : fn (self .load_eviction_policies ),
282306        }
307+         if  include_advanced :
308+             from  ..runtime .ptxas_configs  import  search_ptxas_configs 
309+ 
310+             config ["advanced_compiler_configuration" ] =  fn (
311+                 EnumFragment ((0 , * search_ptxas_configs ()))
312+             )
283313        # Add tunable parameters 
284314        config .update (
285315            {key : fn (fragment ) for  key , fragment  in  self .user_defined_tunables .items ()}
0 commit comments