5252 "pid_type" ,
5353 "indexing" ,
5454 "load_eviction_policies" ,
55+ "advanced_compiler_configuration" ,
5556 ]
5657)
5758VALID_PID_TYPES = ("flat" , "xyz" , "persistent_blocked" , "persistent_interleaved" )
@@ -105,6 +106,7 @@ class ConfigSpec:
105106 EnumFragment (choices = VALID_EVICTION_POLICIES ), length = 0
106107 )
107108 )
109+ ptxas_supported : bool = False
108110
109111 @staticmethod
110112 def _valid_indexing_types () -> tuple [IndexingLiteral , ...]:
@@ -238,6 +240,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
238240 else :
239241 config [name ] = values [0 ]
240242
243+ if "advanced_compiler_configuration" in config :
244+ value = config .get ("advanced_compiler_configuration" ) or 0
245+ if not isinstance (value , int ):
246+ raise InvalidConfig (
247+ f"advanced_compiler_configuration must be integer, got { value !r} "
248+ )
249+ if value and not self .ptxas_supported :
250+ raise InvalidConfig (
251+ "advanced_compiler_configuration requires PTXAS support"
252+ )
253+ config ["advanced_compiler_configuration" ] = value
254+
241255 # Set default values for grid indices when pid_type is not persistent
242256 pid_type = config ["pid_type" ]
243257 if pid_type in ("flat" , "xyz" ) and self .grid_block_ids :
@@ -260,8 +274,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
260274 def default_config (self ) -> helion .Config :
261275 return self .flat_config (lambda x : x .default ())
262276
263- def flat_config (self , fn : Callable [[ConfigSpecFragment ], object ]) -> helion .Config :
277+ def flat_config (
278+ self ,
279+ fn : Callable [[ConfigSpecFragment ], object ],
280+ * ,
281+ include_advanced_compiler_configuration : bool | None = None ,
282+ ) -> helion .Config :
264283 """Map a flattened version of the config using the given function."""
284+ include_advanced = self .ptxas_supported
285+ if include_advanced_compiler_configuration is not None :
286+ include_advanced = (
287+ include_advanced and include_advanced_compiler_configuration
288+ )
265289 config = {
266290 "block_sizes" : self .block_sizes ._flat_config (self , fn ),
267291 "loop_orders" : self .loop_orders ._flat_config (self , fn ),
@@ -280,6 +304,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
280304 "pid_type" : fn (EnumFragment (self .allowed_pid_types )),
281305 "load_eviction_policies" : fn (self .load_eviction_policies ),
282306 }
307+ if include_advanced :
308+ from ..runtime .ptxas_configs import search_ptxas_configs
309+
310+ config ["advanced_compiler_configuration" ] = fn (
311+ EnumFragment ((0 , * search_ptxas_configs ()))
312+ )
283313 # Add tunable parameters
284314 config .update (
285315 {key : fn (fragment ) for key , fragment in self .user_defined_tunables .items ()}
0 commit comments