Skip to content

Commit 29b9304

Browse files
committed
Code simplifications for getting and validating user settings.
1 parent 4f4f66d commit 29b9304

File tree

2 files changed

+94
-89
lines changed

2 files changed

+94
-89
lines changed

helion/runtime/settings.py

Lines changed: 88 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import time
88
from typing import TYPE_CHECKING
99
from typing import Callable
10+
from typing import Collection
1011
from typing import Literal
1112
from typing import Protocol
1213
from typing import Sequence
@@ -36,6 +37,45 @@ def __call__(
3637
) -> BaseAutotuner: ...
3738

3839

40+
def _validate_enum_setting(
41+
value: object,
42+
*,
43+
name: str,
44+
valid: Collection[str],
45+
allow_none: bool = True,
46+
) -> str | None:
47+
"""Normalize and validate an enum setting.
48+
49+
Args:
50+
value: The value to normalize and validate
51+
name: Name of the setting
52+
valid: Collection of valid settings
53+
allow_none: If True, None and _NONE_VALUES strings return None. If False, they raise an error.
54+
"""
55+
# String values that should be treated as None
56+
_NONE_VALUES = frozenset({"", "0", "false", "none"})
57+
58+
# Normalize values
59+
normalized: str | None
60+
if isinstance(value, str):
61+
normalized = value.strip().lower()
62+
else:
63+
normalized = None
64+
65+
is_none_value = normalized is None or normalized in _NONE_VALUES
66+
is_valid = normalized in valid if normalized else False
67+
68+
# Valid value (none or valid setting)
69+
if is_none_value and allow_none:
70+
return None
71+
if is_valid:
72+
return normalized
73+
74+
# Invalid value, raise error
75+
valid_list = "', '".join(sorted(valid))
76+
raise ValueError(f"{name} must be one of '{valid_list}', got {value!r}")
77+
78+
3979
_tls: _TLS = cast("_TLS", threading.local())
4080

4181

@@ -108,63 +148,6 @@ def default_autotuner_fn(
108148
return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]
109149

110150

111-
def _get_autotune_random_seed() -> int:
112-
value = os.environ.get("HELION_AUTOTUNE_RANDOM_SEED")
113-
if value is not None:
114-
return int(value)
115-
return int(time.time() * 1000) % 2**32
116-
117-
118-
def _get_autotune_max_generations() -> int | None:
119-
value = os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS")
120-
if value is not None:
121-
return int(value)
122-
return None
123-
124-
125-
def _get_autotune_rebenchmark_threshold() -> float | None:
126-
value = os.environ.get("HELION_REBENCHMARK_THRESHOLD")
127-
if value is not None:
128-
return float(value)
129-
return None # Will use effort profile default
130-
131-
132-
def _normalize_autotune_effort(value: object) -> AutotuneEffort:
133-
if isinstance(value, str):
134-
normalized = value.lower()
135-
if normalized in _PROFILES:
136-
return cast("AutotuneEffort", normalized)
137-
raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'")
138-
139-
140-
def _get_autotune_effort() -> AutotuneEffort:
141-
return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
142-
143-
144-
def _get_autotune_precompile() -> str | None:
145-
value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE")
146-
if value is None:
147-
return "spawn"
148-
mode = value.strip().lower()
149-
if mode in {"", "0", "false", "none"}:
150-
return None
151-
if mode in {"spawn", "fork"}:
152-
return mode
153-
raise ValueError(
154-
"HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile"
155-
)
156-
157-
158-
def _get_autotune_precompile_jobs() -> int | None:
159-
value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
160-
if value is None or value.strip() == "":
161-
return None
162-
jobs = int(value)
163-
if jobs <= 0:
164-
raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
165-
return jobs
166-
167-
168151
@dataclasses.dataclass
169152
class _Settings:
170153
# see __slots__ below for the doc strings that show up in help(Settings)
@@ -182,33 +165,45 @@ class _Settings:
182165
os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
183166
)
184167
autotune_precompile: str | None = dataclasses.field(
185-
default_factory=_get_autotune_precompile
168+
default_factory=lambda: os.environ.get("HELION_AUTOTUNE_PRECOMPILE", "spawn")
186169
)
187170
autotune_precompile_jobs: int | None = dataclasses.field(
188-
default_factory=_get_autotune_precompile_jobs
171+
default_factory=lambda: int(v)
172+
if (v := os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS"))
173+
else None
189174
)
190175
autotune_random_seed: int = dataclasses.field(
191-
default_factory=_get_autotune_random_seed
176+
default_factory=lambda: (
177+
int(v)
178+
if (v := os.environ.get("HELION_AUTOTUNE_RANDOM_SEED"))
179+
else int(time.time() * 1000) % 2**32
180+
)
192181
)
193182
autotune_accuracy_check: bool = (
194183
os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
195184
)
196185
autotune_rebenchmark_threshold: float | None = dataclasses.field(
197-
default_factory=_get_autotune_rebenchmark_threshold
186+
default_factory=lambda: float(v)
187+
if (v := os.environ.get("HELION_REBENCHMARK_THRESHOLD"))
188+
else None
198189
)
199190
autotune_progress_bar: bool = (
200191
os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1"
201192
)
202193
autotune_max_generations: int | None = dataclasses.field(
203-
default_factory=_get_autotune_max_generations
194+
default_factory=lambda: int(v)
195+
if (v := os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS"))
196+
else None
204197
)
205198
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
206199
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
207200
autotune_config_overrides: dict[str, object] = dataclasses.field(
208201
default_factory=dict
209202
)
210203
autotune_effort: AutotuneEffort = dataclasses.field(
211-
default_factory=_get_autotune_effort
204+
default_factory=lambda: cast(
205+
"AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full")
206+
)
212207
)
213208
allow_warp_specialize: bool = (
214209
os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
@@ -220,35 +215,43 @@ class _Settings:
220215
autotuner_fn: AutotunerFunction = default_autotuner_fn
221216

222217
def __post_init__(self) -> None:
223-
def _is_bool(val: object) -> bool:
224-
return isinstance(val, bool)
225-
226-
def _is_non_negative_int(val: object) -> bool:
227-
return isinstance(val, int) and val >= 0
218+
# Validate all user settings
219+
220+
self.autotune_effort = cast(
221+
"AutotuneEffort",
222+
_validate_enum_setting(
223+
self.autotune_effort,
224+
name="autotune_effort",
225+
valid=_PROFILES.keys(),
226+
allow_none=False, # do not allow None as "none" is a non-default setting
227+
),
228+
)
229+
self.autotune_precompile = _validate_enum_setting(
230+
self.autotune_precompile,
231+
name="autotune_precompile",
232+
valid={"spawn", "fork"},
233+
)
228234

229-
# Validate user settings
230235
validators: dict[str, Callable[[object], bool]] = {
231-
"autotune_log_level": _is_non_negative_int,
232-
"autotune_compile_timeout": _is_non_negative_int,
233-
"autotune_precompile": lambda v: v in (None, "spawn", "fork"),
234-
"autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v),
235-
"autotune_accuracy_check": _is_bool,
236-
"autotune_progress_bar": _is_bool,
237-
"autotune_max_generations": lambda v: v is None or _is_non_negative_int(v),
238-
"print_output_code": _is_bool,
239-
"force_autotune": _is_bool,
240-
"allow_warp_specialize": _is_bool,
241-
"debug_dtype_asserts": _is_bool,
236+
"autotune_log_level": lambda v: isinstance(v, int) and v >= 0,
237+
"autotune_compile_timeout": lambda v: isinstance(v, int) and v > 0,
238+
"autotune_precompile_jobs": lambda v: v is None
239+
or (isinstance(v, int) and v > 0),
240+
"autotune_accuracy_check": lambda v: isinstance(v, bool),
241+
"autotune_progress_bar": lambda v: isinstance(v, bool),
242+
"autotune_max_generations": lambda v: v is None
243+
or (isinstance(v, int) and v >= 0),
244+
"print_output_code": lambda v: isinstance(v, bool),
245+
"force_autotune": lambda v: isinstance(v, bool),
246+
"allow_warp_specialize": lambda v: isinstance(v, bool),
247+
"debug_dtype_asserts": lambda v: isinstance(v, bool),
242248
"autotune_rebenchmark_threshold": lambda v: v is None
243249
or (isinstance(v, (int, float)) and v >= 0),
244250
}
245251

246-
normalized_effort = _normalize_autotune_effort(self.autotune_effort)
247-
object.__setattr__(self, "autotune_effort", normalized_effort)
248-
249-
for field_name, checker in validators.items():
252+
for field_name, validator in validators.items():
250253
value = getattr(self, field_name)
251-
if not checker(value):
254+
if not validator(value):
252255
raise ValueError(f"Invalid value for {field_name}: {value!r}")
253256

254257

test/test_settings.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77

88
class TestSettingsValidation(unittest.TestCase):
99
def test_autotune_effort_none_raises(self) -> None:
10-
with self.assertRaisesRegex(
11-
ValueError, "autotune_effort must be one of 'none', 'quick', or 'full'"
12-
):
10+
with self.assertRaisesRegex(ValueError, "autotune_effort must be one of"):
1311
helion.Settings(autotune_effort=None)
1412

15-
def test_autotune_effort_quick_normalized(self) -> None:
13+
def test_autotune_effort_case_insensitive(self) -> None:
1614
settings = helion.Settings(autotune_effort="Quick")
1715
self.assertEqual(settings.autotune_effort, "quick")
1816

@@ -33,3 +31,7 @@ def test_autotune_max_generations_negative_raises(self) -> None:
3331
ValueError, r"Invalid value for autotune_max_generations: -1"
3432
):
3533
helion.Settings(autotune_max_generations=-1)
34+
35+
def test_autotune_effort_invalid_raises(self) -> None:
36+
with self.assertRaisesRegex(ValueError, "autotune_effort must be one of"):
37+
helion.Settings(autotune_effort="super-fast")

0 commit comments

Comments
 (0)