77import  time 
88from  typing  import  TYPE_CHECKING 
99from  typing  import  Callable 
10+ from  typing  import  Collection 
1011from  typing  import  Literal 
1112from  typing  import  Protocol 
1213from  typing  import  Sequence 
@@ -36,6 +37,45 @@ def __call__(
3637        ) ->  BaseAutotuner : ...
3738
3839
def _validate_enum_setting(
    value: object,
    *,
    name: str,
    valid: Collection[str],
    allow_none: bool = True,
) -> str | None:
    """Normalize and validate an enum-like string setting.

    String inputs are stripped of surrounding whitespace and lower-cased
    before being compared against ``valid``.

    Args:
        value: The raw setting value to normalize and validate.
        name: Name of the setting; used only in the error message.
        valid: Collection of accepted (already lower-case) values.
        allow_none: If True, ``None`` and the "none-like" strings ``""``,
            ``"0"``, ``"false"`` and ``"none"`` return ``None``. If False,
            such values are accepted only when they are listed in ``valid``.

    Returns:
        The normalized string when it is a member of ``valid``, or ``None``
        when the value is none-like and ``allow_none`` is True.

    Raises:
        ValueError: If the value is neither an accepted none-like value nor
            a member of ``valid``.
    """
    # String spellings that mean "setting disabled / not set".
    none_values = frozenset({"", "0", "false", "none"})

    # NOTE: any non-string value (not just None) normalizes to None here, so
    # with allow_none=True it is returned as None rather than rejected.
    normalized = value.strip().lower() if isinstance(value, str) else None

    if allow_none and (normalized is None or normalized in none_values):
        return None

    # The truthiness test keeps the empty string from ever matching ``valid``.
    if normalized and normalized in valid:
        return normalized

    # sorted() keeps the message deterministic for sets and dict key views.
    valid_list = "', '".join(sorted(valid))
    raise ValueError(f"{name} must be one of '{valid_list}', got {value!r}")
77+ 
78+ 
3979_tls : _TLS  =  cast ("_TLS" , threading .local ())
4080
4181
@@ -108,63 +148,6 @@ def default_autotuner_fn(
108148    return  LocalAutotuneCache (autotuner_cls (bound_kernel , args , ** kwargs ))  # pyright: ignore[reportArgumentType] 
109149
110150
111- def  _get_autotune_random_seed () ->  int :
112-     value  =  os .environ .get ("HELION_AUTOTUNE_RANDOM_SEED" )
113-     if  value  is  not None :
114-         return  int (value )
115-     return  int (time .time () *  1000 ) %  2 ** 32 
116- 
117- 
118- def  _get_autotune_max_generations () ->  int  |  None :
119-     value  =  os .environ .get ("HELION_AUTOTUNE_MAX_GENERATIONS" )
120-     if  value  is  not None :
121-         return  int (value )
122-     return  None 
123- 
124- 
125- def  _get_autotune_rebenchmark_threshold () ->  float  |  None :
126-     value  =  os .environ .get ("HELION_REBENCHMARK_THRESHOLD" )
127-     if  value  is  not None :
128-         return  float (value )
129-     return  None   # Will use effort profile default 
130- 
131- 
132- def  _normalize_autotune_effort (value : object ) ->  AutotuneEffort :
133-     if  isinstance (value , str ):
134-         normalized  =  value .lower ()
135-         if  normalized  in  _PROFILES :
136-             return  cast ("AutotuneEffort" , normalized )
137-     raise  ValueError ("autotune_effort must be one of 'none', 'quick', or 'full'" )
138- 
139- 
140- def  _get_autotune_effort () ->  AutotuneEffort :
141-     return  _normalize_autotune_effort (os .environ .get ("HELION_AUTOTUNE_EFFORT" , "full" ))
142- 
143- 
144- def  _get_autotune_precompile () ->  str  |  None :
145-     value  =  os .environ .get ("HELION_AUTOTUNE_PRECOMPILE" )
146-     if  value  is  None :
147-         return  "spawn" 
148-     mode  =  value .strip ().lower ()
149-     if  mode  in  {"" , "0" , "false" , "none" }:
150-         return  None 
151-     if  mode  in  {"spawn" , "fork" }:
152-         return  mode 
153-     raise  ValueError (
154-         "HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile" 
155-     )
156- 
157- 
158- def  _get_autotune_precompile_jobs () ->  int  |  None :
159-     value  =  os .environ .get ("HELION_AUTOTUNE_PRECOMPILE_JOBS" )
160-     if  value  is  None  or  value .strip () ==  "" :
161-         return  None 
162-     jobs  =  int (value )
163-     if  jobs  <=  0 :
164-         raise  ValueError ("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer" )
165-     return  jobs 
166- 
167- 
168151@dataclasses .dataclass  
169152class  _Settings :
170153    # see __slots__ below for the doc strings that show up in help(Settings) 
@@ -182,33 +165,45 @@ class _Settings:
182165        os .environ .get ("HELION_AUTOTUNE_COMPILE_TIMEOUT" , "60" )
183166    )
184167    autotune_precompile : str  |  None  =  dataclasses .field (
185-         default_factory = _get_autotune_precompile 
168+         default_factory = lambda :  os . environ . get ( "HELION_AUTOTUNE_PRECOMPILE" ,  "spawn" ) 
186169    )
187170    autotune_precompile_jobs : int  |  None  =  dataclasses .field (
188-         default_factory = _get_autotune_precompile_jobs 
171+         default_factory = lambda : int (v )
172+         if  (v  :=  os .environ .get ("HELION_AUTOTUNE_PRECOMPILE_JOBS" ))
173+         else  None 
189174    )
190175    autotune_random_seed : int  =  dataclasses .field (
191-         default_factory = _get_autotune_random_seed 
176+         default_factory = lambda : (
177+             int (v )
178+             if  (v  :=  os .environ .get ("HELION_AUTOTUNE_RANDOM_SEED" ))
179+             else  int (time .time () *  1000 ) %  2 ** 32 
180+         )
192181    )
193182    autotune_accuracy_check : bool  =  (
194183        os .environ .get ("HELION_AUTOTUNE_ACCURACY_CHECK" , "1" ) ==  "1" 
195184    )
196185    autotune_rebenchmark_threshold : float  |  None  =  dataclasses .field (
197-         default_factory = _get_autotune_rebenchmark_threshold 
186+         default_factory = lambda : float (v )
187+         if  (v  :=  os .environ .get ("HELION_REBENCHMARK_THRESHOLD" ))
188+         else  None 
198189    )
199190    autotune_progress_bar : bool  =  (
200191        os .environ .get ("HELION_AUTOTUNE_PROGRESS_BAR" , "1" ) ==  "1" 
201192    )
202193    autotune_max_generations : int  |  None  =  dataclasses .field (
203-         default_factory = _get_autotune_max_generations 
194+         default_factory = lambda : int (v )
195+         if  (v  :=  os .environ .get ("HELION_AUTOTUNE_MAX_GENERATIONS" ))
196+         else  None 
204197    )
205198    print_output_code : bool  =  os .environ .get ("HELION_PRINT_OUTPUT_CODE" , "0" ) ==  "1" 
206199    force_autotune : bool  =  os .environ .get ("HELION_FORCE_AUTOTUNE" , "0" ) ==  "1" 
207200    autotune_config_overrides : dict [str , object ] =  dataclasses .field (
208201        default_factory = dict 
209202    )
210203    autotune_effort : AutotuneEffort  =  dataclasses .field (
211-         default_factory = _get_autotune_effort 
204+         default_factory = lambda : cast (
205+             "AutotuneEffort" , os .environ .get ("HELION_AUTOTUNE_EFFORT" , "full" )
206+         )
212207    )
213208    allow_warp_specialize : bool  =  (
214209        os .environ .get ("HELION_ALLOW_WARP_SPECIALIZE" , "1" ) ==  "1" 
@@ -220,35 +215,43 @@ class _Settings:
220215    autotuner_fn : AutotunerFunction  =  default_autotuner_fn 
221216
222217    def  __post_init__ (self ) ->  None :
223-         def  _is_bool (val : object ) ->  bool :
224-             return  isinstance (val , bool )
225- 
226-         def  _is_non_negative_int (val : object ) ->  bool :
227-             return  isinstance (val , int ) and  val  >=  0 
218+         # Validate all user settings 
219+ 
220+         self .autotune_effort  =  cast (
221+             "AutotuneEffort" ,
222+             _validate_enum_setting (
223+                 self .autotune_effort ,
224+                 name = "autotune_effort" ,
225+                 valid = _PROFILES .keys (),
226+                 allow_none = False ,  # do not allow None as "none" is a non-default setting 
227+             ),
228+         )
229+         self .autotune_precompile  =  _validate_enum_setting (
230+             self .autotune_precompile ,
231+             name = "autotune_precompile" ,
232+             valid = {"spawn" , "fork" },
233+         )
228234
229-         # Validate user settings 
230235        validators : dict [str , Callable [[object ], bool ]] =  {
231-             "autotune_log_level" : _is_non_negative_int ,
232-             "autotune_compile_timeout" : _is_non_negative_int ,
233-             "autotune_precompile" : lambda  v : v  in  (None , "spawn" , "fork" ),
234-             "autotune_precompile_jobs" : lambda  v : v  is  None  or  _is_non_negative_int (v ),
235-             "autotune_accuracy_check" : _is_bool ,
236-             "autotune_progress_bar" : _is_bool ,
237-             "autotune_max_generations" : lambda  v : v  is  None  or  _is_non_negative_int (v ),
238-             "print_output_code" : _is_bool ,
239-             "force_autotune" : _is_bool ,
240-             "allow_warp_specialize" : _is_bool ,
241-             "debug_dtype_asserts" : _is_bool ,
236+             "autotune_log_level" : lambda  v : isinstance (v , int ) and  v  >=  0 ,
237+             "autotune_compile_timeout" : lambda  v : isinstance (v , int ) and  v  >  0 ,
238+             "autotune_precompile_jobs" : lambda  v : v  is  None 
239+             or  (isinstance (v , int ) and  v  >  0 ),
240+             "autotune_accuracy_check" : lambda  v : isinstance (v , bool ),
241+             "autotune_progress_bar" : lambda  v : isinstance (v , bool ),
242+             "autotune_max_generations" : lambda  v : v  is  None 
243+             or  (isinstance (v , int ) and  v  >=  0 ),
244+             "print_output_code" : lambda  v : isinstance (v , bool ),
245+             "force_autotune" : lambda  v : isinstance (v , bool ),
246+             "allow_warp_specialize" : lambda  v : isinstance (v , bool ),
247+             "debug_dtype_asserts" : lambda  v : isinstance (v , bool ),
242248            "autotune_rebenchmark_threshold" : lambda  v : v  is  None 
243249            or  (isinstance (v , (int , float )) and  v  >=  0 ),
244250        }
245251
246-         normalized_effort  =  _normalize_autotune_effort (self .autotune_effort )
247-         object .__setattr__ (self , "autotune_effort" , normalized_effort )
248- 
249-         for  field_name , checker  in  validators .items ():
252+         for  field_name , validator  in  validators .items ():
250253            value  =  getattr (self , field_name )
251-             if  not  checker (value ):
254+             if  not  validator (value ):
252255                raise  ValueError (f"Invalid value for { field_name } { value !r}  )
253256
254257
0 commit comments