4444 "range_num_stages" ,
4545 "range_multi_buffers" ,
4646 "range_flattens" ,
47+ "static_ranges" ,
4748 "num_warps" ,
4849 "num_stages" ,
4950 "pid_type" ,
@@ -85,6 +86,9 @@ class ConfigSpec:
8586 range_flattens : BlockIdSequence [RangeFlattenSpec ] = dataclasses .field (
8687 default_factory = BlockIdSequence
8788 )
89+ static_ranges : BlockIdSequence [StaticRangeSpec ] = dataclasses .field (
90+ default_factory = BlockIdSequence
91+ )
8892 user_defined_tunables : dict [str , ConfigSpecFragment ] = dataclasses .field (
8993 default_factory = dict
9094 )
@@ -109,6 +113,7 @@ def _remove_duplicates(self) -> None:
109113 self .range_num_stages ._remove_duplicates ()
110114 self .range_multi_buffers ._remove_duplicates ()
111115 self .range_flattens ._remove_duplicates ()
116+ self .static_ranges ._remove_duplicates ()
112117
113118 def disallow_pid_type (self , pid_type : PidTypeLiteral ) -> None :
114119 """Disallow a pid_type from being used in the config."""
@@ -135,6 +140,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
135140 "range_num_stage" ,
136141 "range_multi_buffer" ,
137142 "range_flatten" ,
143+ "static_range" ,
138144 ):
139145 if name in config :
140146 names = f"{ name } s"
@@ -153,11 +159,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
153159 ("range_num_stages" , self .range_num_stages , True ),
154160 ("range_multi_buffers" , self .range_multi_buffers , True ),
155161 ("range_flattens" , self .range_flattens , True ),
162+ ("static_ranges" , self .static_ranges , True ),
156163 ]:
157164 config [name ] = mapping ._normalize (
158165 name , config .get (name , ()), flatten = flatten
159166 )
160167
168+ static_range_block_ids = []
169+ for block_id in self .static_ranges .valid_block_ids ():
170+ use_static_range = self .static_ranges .config_get (
171+ config .get ("static_ranges" , ()), # pyre-ignore[6]
172+ block_id ,
173+ )
174+ if use_static_range :
175+ static_range_block_ids .append (block_id )
176+
177+ for name , mapping in (
178+ ("range_unroll_factors" , self .range_unroll_factors ),
179+ ("range_warp_specializes" , self .range_warp_specialize ),
180+ ("range_num_stages" , self .range_num_stages ),
181+ ("range_multi_buffers" , self .range_multi_buffers ),
182+ ("range_flattens" , self .range_flattens ),
183+ ):
184+ config [name ] = mapping ._reset_config_to_default (
185+ name , config .get (name , ()), block_ids = static_range_block_ids
186+ )
187+
161188 for name in (
162189 "loop_orders" ,
163190 "l2_groupings" ,
@@ -168,6 +195,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
168195 "range_num_stages" ,
169196 "range_multi_buffers" ,
170197 "range_flattens" ,
198+ "static_ranges" ,
171199 ):
172200 if not config [name ]:
173201 config .pop (name )
@@ -209,6 +237,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
209237 "range_num_stages" : self .range_num_stages ._flat_config (self , fn ),
210238 "range_multi_buffers" : self .range_multi_buffers ._flat_config (self , fn ),
211239 "range_flattens" : self .range_flattens ._flat_config (self , fn ),
240+ "static_ranges" : self .static_ranges ._flat_config (self , fn ),
212241 "num_warps" : fn (NumWarpsFragment (1 , 32 , DEFAULT_NUM_WARPS )),
213242 "num_stages" : fn (IntegerFragment (1 , 8 , DEFAULT_NUM_STAGES )),
214243 "indexing" : fn (EnumFragment (self ._valid_indexing_types ())),
@@ -228,6 +257,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
228257 "range_num_stages" ,
229258 "range_multi_buffers" ,
230259 "range_flattens" ,
260+ "static_ranges" ,
231261 ):
232262 if not config [name ]:
233263 config .pop (name )
@@ -416,6 +446,20 @@ class RangeFlattenSpec(_OptionalBoolSpec):
416446 pass
417447
418448
449+ class StaticRangeSpec (_BlockIdItem ):
450+ def _fragment (self , base : ConfigSpec ) -> BooleanFragment :
451+ return BooleanFragment ()
452+
453+ def _normalize (self , name : str , value : object ) -> bool :
454+ if not isinstance (value , bool ):
455+ raise InvalidConfig (f"{ name } must be a boolean, got { value !r} " )
456+ return value
457+
458+ def _fill_missing (self ) -> bool :
459+ """Provide a value when not provided by the user."""
460+ return False
461+
462+
419463def _product (seq : Sequence [int ]) -> int :
420464 """Return the product of the elements in the sequence."""
421465 return functools .reduce (operator .mul , seq , 1 )
0 commit comments