Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion helion/autotuner/base_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
counters["autotune"]["cache_miss"] += 1
log.debug("cache miss")

self.autotuner.log("Starting autotuning process, this may take a while...")
effort = self.kernel.settings.autotune_effort
self.autotuner.log(
f"Starting autotuning process with effort={effort}, this may take a while..."
)

config = self.autotuner.autotune()

Expand Down
1 change: 1 addition & 0 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,7 @@ def _wait_for_all_step(
) -> list[PrecompileFuture]:
"""Start up to the concurrency cap, wait for progress, and return remaining futures."""
cap = futures[0].search._jobs if futures else 1
assert cap > 0, "autotune_precompile_jobs must be positive"
running = [f for f in futures if f.started and f.ok is None and f.is_alive()]

# Start queued futures up to the cap
Expand Down
2 changes: 1 addition & 1 deletion helion/runtime/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def _implicit_config(self) -> Config | None:
if not is_ref_mode_enabled(self.kernel.settings):
kernel_decorator = self.format_kernel_decorator(config, self.settings)
print(
f"Using default config: {kernel_decorator}",
f"Using default config (autotune_effort=none): {kernel_decorator}",
file=sys.stderr,
)
return config
Expand Down
155 changes: 100 additions & 55 deletions helion/runtime/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import threading
import time
from typing import TYPE_CHECKING
from typing import Callable
from typing import Collection
from typing import Literal
from typing import Protocol
from typing import Sequence
Expand All @@ -15,6 +17,7 @@
from torch._environment import is_fbcode

from .. import exc
from ..autotuner.effort_profile import _PROFILES
from ..autotuner.effort_profile import AutotuneEffort
from ..autotuner.effort_profile import get_effort_profile
from .ref_mode import RefMode
Expand All @@ -34,6 +37,45 @@ def __call__(
) -> BaseAutotuner: ...


def _validate_enum_setting(
value: object,
*,
name: str,
valid: Collection[str],
allow_none: bool = True,
) -> str | None:
"""Normalize and validate an enum setting.

Args:
value: The value to normalize and validate
name: Name of the setting
valid: Collection of valid settings
allow_none: If True, None and _NONE_VALUES strings return None. If False, they raise an error.
"""
# String values that should be treated as None
_NONE_VALUES = frozenset({"", "0", "false", "none"})

# Normalize values
normalized: str | None
if isinstance(value, str):
normalized = value.strip().lower()
else:
normalized = None

is_none_value = normalized is None or normalized in _NONE_VALUES
is_valid = normalized in valid if normalized else False

# Valid value (none or valid setting)
if is_none_value and allow_none:
return None
if is_valid:
return normalized

# Invalid value, raise error
valid_list = "', '".join(sorted(valid))
raise ValueError(f"{name} must be one of '{valid_list}', got {value!r}")


_tls: _TLS = cast("_TLS", threading.local())


Expand Down Expand Up @@ -106,55 +148,6 @@ def default_autotuner_fn(
return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]


def _get_autotune_random_seed() -> int:
value = os.environ.get("HELION_AUTOTUNE_RANDOM_SEED")
if value is not None:
return int(value)
return int(time.time() * 1000) % 2**32


def _get_autotune_max_generations() -> int | None:
value = os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS")
if value is not None:
return int(value)
return None


def _get_autotune_rebenchmark_threshold() -> float | None:
value = os.environ.get("HELION_REBENCHMARK_THRESHOLD")
if value is not None:
return float(value)
return None # Will use effort profile default


def _get_autotune_effort() -> AutotuneEffort:
return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))


def _get_autotune_precompile() -> str | None:
value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE")
if value is None:
return "spawn"
mode = value.strip().lower()
if mode in {"", "0", "false", "none"}:
return None
if mode in {"spawn", "fork"}:
return mode
raise ValueError(
"HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile"
)


def _get_autotune_precompile_jobs() -> int | None:
value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
if value is None or value.strip() == "":
return None
jobs = int(value)
if jobs <= 0:
raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
return jobs


@dataclasses.dataclass
class _Settings:
# see __slots__ below for the doc strings that show up in help(Settings)
Expand All @@ -172,33 +165,45 @@ class _Settings:
os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
)
autotune_precompile: str | None = dataclasses.field(
default_factory=_get_autotune_precompile
default_factory=lambda: os.environ.get("HELION_AUTOTUNE_PRECOMPILE", "spawn")
)
autotune_precompile_jobs: int | None = dataclasses.field(
default_factory=_get_autotune_precompile_jobs
default_factory=lambda: int(v)
if (v := os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS"))
else None
)
autotune_random_seed: int = dataclasses.field(
default_factory=_get_autotune_random_seed
default_factory=lambda: (
int(v)
if (v := os.environ.get("HELION_AUTOTUNE_RANDOM_SEED"))
else int(time.time() * 1000) % 2**32
)
)
autotune_accuracy_check: bool = (
os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
)
autotune_rebenchmark_threshold: float | None = dataclasses.field(
default_factory=_get_autotune_rebenchmark_threshold
default_factory=lambda: float(v)
if (v := os.environ.get("HELION_REBENCHMARK_THRESHOLD"))
else None
)
autotune_progress_bar: bool = (
os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1"
)
autotune_max_generations: int | None = dataclasses.field(
default_factory=_get_autotune_max_generations
default_factory=lambda: int(v)
if (v := os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS"))
else None
)
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
autotune_config_overrides: dict[str, object] = dataclasses.field(
default_factory=dict
)
autotune_effort: AutotuneEffort = dataclasses.field(
default_factory=_get_autotune_effort
default_factory=lambda: cast(
"AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full")
)
)
allow_warp_specialize: bool = (
os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
Expand All @@ -209,6 +214,46 @@ class _Settings:
)
autotuner_fn: AutotunerFunction = default_autotuner_fn

def __post_init__(self) -> None:
    """Normalize and validate every user-configurable setting.

    Raises:
        ValueError: If any setting holds a value of the wrong type or range.
    """
    # Enum-style settings are normalized (case/whitespace) before storing.
    self.autotune_effort = cast(
        "AutotuneEffort",
        _validate_enum_setting(
            self.autotune_effort,
            name="autotune_effort",
            valid=_PROFILES.keys(),
            # "none" is a real (non-default) effort level, so a missing
            # value must be rejected rather than mapped to None.
            allow_none=False,
        ),
    )
    self.autotune_precompile = _validate_enum_setting(
        self.autotune_precompile,
        name="autotune_precompile",
        valid={"spawn", "fork"},
    )

    # Small named predicates shared by the type/range checks below.
    def flag(v: object) -> bool:
        return isinstance(v, bool)

    def optional_positive(v: object) -> bool:
        return v is None or (isinstance(v, int) and v > 0)

    def optional_nonnegative(v: object) -> bool:
        return v is None or (isinstance(v, int) and v >= 0)

    # Checked in declaration order; the first failing field raises.
    checks: list[tuple[str, Callable[[object], bool]]] = [
        ("autotune_log_level", lambda v: isinstance(v, int) and v >= 0),
        ("autotune_compile_timeout", lambda v: isinstance(v, int) and v > 0),
        ("autotune_precompile_jobs", optional_positive),
        ("autotune_accuracy_check", flag),
        ("autotune_progress_bar", flag),
        ("autotune_max_generations", optional_nonnegative),
        ("print_output_code", flag),
        ("force_autotune", flag),
        ("allow_warp_specialize", flag),
        ("debug_dtype_asserts", flag),
        (
            "autotune_rebenchmark_threshold",
            lambda v: v is None or (isinstance(v, (int, float)) and v >= 0),
        ),
    ]
    for attr, accepts in checks:
        current = getattr(self, attr)
        if not accepts(current):
            raise ValueError(f"Invalid value for {attr}: {current!r}")


class Settings(_Settings):
"""
Expand Down
37 changes: 37 additions & 0 deletions test/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import unittest

import helion


class TestSettingsValidation(unittest.TestCase):
    """Exercise the validation performed in ``Settings.__post_init__``."""

    def _assert_rejected(self, pattern: str, **overrides: object) -> None:
        # Constructing Settings with a bad field must raise ValueError whose
        # message matches `pattern`.
        with self.assertRaisesRegex(ValueError, pattern):
            helion.Settings(**overrides)

    def test_autotune_effort_none_raises(self) -> None:
        self._assert_rejected(
            "autotune_effort must be one of", autotune_effort=None
        )

    def test_autotune_effort_case_insensitive(self) -> None:
        # Mixed-case input is normalized to the canonical lowercase name.
        settings = helion.Settings(autotune_effort="Quick")
        self.assertEqual(settings.autotune_effort, "quick")

    def test_negative_compile_timeout_raises(self) -> None:
        self._assert_rejected(
            r"Invalid value for autotune_compile_timeout: -1",
            autotune_compile_timeout=-1,
        )

    def test_autotune_precompile_jobs_negative_raises(self) -> None:
        self._assert_rejected(
            r"Invalid value for autotune_precompile_jobs: -1",
            autotune_precompile_jobs=-1,
        )

    def test_autotune_max_generations_negative_raises(self) -> None:
        self._assert_rejected(
            r"Invalid value for autotune_max_generations: -1",
            autotune_max_generations=-1,
        )

    def test_autotune_effort_invalid_raises(self) -> None:
        self._assert_rejected(
            "autotune_effort must be one of", autotune_effort="super-fast"
        )