Skip to content

Commit 19dc2ad

Browse files
committed
Add advanced compiler configurations
stack-info: PR: #793, branch: jansel/stack/158
1 parent c455136 commit 19dc2ad

25 files changed

+264
-6
lines changed

docs/api/settings.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ with helion.set_default_settings(
154154
155155
Validate each candidate configuration against a baseline output before accepting it. Default is ``True``. Controlled by ``HELION_AUTOTUNE_ACCURACY_CHECK``.
156156
157+
.. autoattribute:: Settings.autotune_search_acc
158+
159+
Enable searching packaged PTXAS advanced compiler configurations during autotuning. Only takes effect when the target device supports PTXAS options; otherwise the setting is ignored. Default is ``True``. Controlled by ``HELION_AUTOTUNE_SEARCH_ACC``.
160+
157161
.. autoattribute:: Settings.autotune_rebenchmark_threshold
158162
159163
Controls how aggressively Helion re-runs promising configs to avoid outliers. Default is ``1.5`` (re-benchmark anything within 1.5x of the best).
@@ -233,6 +237,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
233237
| ``HELION_AUTOTUNE_RANDOM_SEED`` | ``autotune_random_seed`` | Seed used for randomized autotuning searches. |
234238
| ``HELION_AUTOTUNE_MAX_GENERATIONS`` | ``autotune_max_generations`` | Upper bound on generations for Pattern Search and Differential Evolution. |
235239
| ``HELION_AUTOTUNE_ACCURACY_CHECK`` | ``autotune_accuracy_check`` | Toggle baseline validation for candidate configs. |
240+
| ``HELION_AUTOTUNE_SEARCH_ACC`` | ``autotune_search_acc`` | Enable packaged PTXAS advanced compiler configuration search during autotuning. |
236241
| ``HELION_REBENCHMARK_THRESHOLD`` | ``autotune_rebenchmark_threshold`` | Re-run configs whose performance is within a multiplier of the current best. |
237242
| ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
238243
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |

helion/_compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,12 @@ def warps_to_threads(num_warps: int) -> int:
243243
)
244244
return num_warps * (props.warp_size or 32)
245245
return num_warps * 32
246+
247+
248+
def supports_ptxas(device: torch.device) -> bool:
    """Report whether PTXAS options can be used for ``device``.

    Args:
        device: The torch device the kernel is being compiled for.

    Returns:
        ``False`` when ``device.type`` is not ``"cuda"`` or when
        ``torch.version.hip`` is set; otherwise whatever
        ``supports_tensor_descriptor()`` reports.
    """
    # Short-circuiting keeps the original evaluation order: device type
    # first, then the HIP check, and only then the tensor-descriptor probe.
    targets_non_hip_cuda = device.type == "cuda" and torch.version.hip is None
    return targets_non_hip_cuda and supports_tensor_descriptor()

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.fx.experimental.symbolic_shapes import ShapeEnv
2121

2222
from .. import exc
23+
from .._compat import supports_ptxas
2324
from ..language.constexpr import ConstExpr
2425
from .loop_dependency_checker import LoopDependencyChecker
2526
from .source_location import SourceLocation
@@ -90,6 +91,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9091
self.block_sizes: list[BlockSizeInfo] = []
9192
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
9293
self.config_spec = ConfigSpec()
94+
self.config_spec.ptxas_supported = supports_ptxas(device)
9395
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
9496
collections.Counter()
9597
)

helion/_compiler/device_function.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,12 @@ def codegen_function_call(self) -> ast.AST:
574574
f"num_stages={self.config.num_stages}",
575575
]
576576
)
577+
advanced_compiler_configuration = self.config.advanced_compiler_configuration
578+
if advanced_compiler_configuration:
579+
from ..runtime.ptxas_configs import get_ptxas_option
580+
581+
ptx_option = get_ptxas_option(advanced_compiler_configuration)
582+
args.append(f"ptx_options={ptx_option!r}")
577583
pid = self.pid
578584
assert pid is not None
579585
# TODO(jansel): we should run CSE this statement

helion/autotuner/base_search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def __init__(
481481
self.config_gen: ConfigGeneration = ConfigGeneration(
482482
self.config_spec,
483483
overrides=overrides,
484+
include_advanced_compiler_configuration=self.settings.autotune_search_acc,
484485
)
485486

486487
@property

helion/autotuner/config_generation.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
config_spec: ConfigSpec,
3232
*,
3333
overrides: Mapping[str, object] | None = None,
34+
include_advanced_compiler_configuration: bool = True,
3435
) -> None:
3536
def _collect_spec(spec: ConfigSpecFragment) -> object:
3637
"""
@@ -47,8 +48,14 @@ def _collect_spec(spec: ConfigSpecFragment) -> object:
4748

4849
super().__init__()
4950
self.config_spec = config_spec
51+
self._include_advanced_compiler_configuration = (
52+
include_advanced_compiler_configuration
53+
)
5054
self.flat_spec: list[ConfigSpecFragment] = []
51-
config_spec.flat_config(_collect_spec)
55+
config_spec.flat_config(
56+
_collect_spec,
57+
include_advanced_compiler_configuration=include_advanced_compiler_configuration,
58+
)
5259
assert self.flat_spec, "No config values to tune"
5360
self._override_values = dict(overrides or {})
5461
self.block_size_indices: list[int] = [
@@ -93,7 +100,10 @@ def get_next_value(spec: ConfigSpecFragment) -> object:
93100

94101
assert len(flat_values) == len(self.flat_spec)
95102
count: itertools.count[int] = itertools.count()
96-
config = self.config_spec.flat_config(get_next_value)
103+
config = self.config_spec.flat_config(
104+
get_next_value,
105+
include_advanced_compiler_configuration=self._include_advanced_compiler_configuration,
106+
)
97107
assert next(count) == len(flat_values)
98108
return self._apply_overrides(config)
99109

helion/autotuner/config_spec.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"pid_type",
5353
"indexing",
5454
"load_eviction_policies",
55+
"advanced_compiler_configuration",
5556
]
5657
)
5758
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
@@ -105,6 +106,7 @@ class ConfigSpec:
105106
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
106107
)
107108
)
109+
ptxas_supported: bool = False
108110

109111
@staticmethod
110112
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -238,6 +240,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
238240
else:
239241
config[name] = values[0]
240242

243+
if "advanced_compiler_configuration" in config:
244+
value = config.get("advanced_compiler_configuration") or 0
245+
if not isinstance(value, int):
246+
raise InvalidConfig(
247+
f"advanced_compiler_configuration must be integer, got {value!r}"
248+
)
249+
if value and not self.ptxas_supported:
250+
raise InvalidConfig(
251+
"advanced_compiler_configuration requires PTXAS support"
252+
)
253+
config["advanced_compiler_configuration"] = value
254+
241255
# Set default values for grid indices when pid_type is not persistent
242256
pid_type = config["pid_type"]
243257
if pid_type in ("flat", "xyz") and self.grid_block_ids:
@@ -260,8 +274,18 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
260274
def default_config(self) -> helion.Config:
    """Build a config by asking every spec fragment for its default value.

    Returns:
        The result of ``self.flat_config`` when each fragment is mapped to
        ``fragment.default()``.
    """

    def _take_default(fragment):
        return fragment.default()

    return self.flat_config(_take_default)
262276

263-
def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
277+
def flat_config(
278+
self,
279+
fn: Callable[[ConfigSpecFragment], object],
280+
*,
281+
include_advanced_compiler_configuration: bool | None = None,
282+
) -> helion.Config:
264283
"""Map a flattened version of the config using the given function."""
284+
include_advanced = self.ptxas_supported
285+
if include_advanced_compiler_configuration is not None:
286+
include_advanced = (
287+
include_advanced and include_advanced_compiler_configuration
288+
)
265289
config = {
266290
"block_sizes": self.block_sizes._flat_config(self, fn),
267291
"loop_orders": self.loop_orders._flat_config(self, fn),
@@ -280,6 +304,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
280304
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
281305
"load_eviction_policies": fn(self.load_eviction_policies),
282306
}
307+
if include_advanced:
308+
from ..runtime.ptxas_configs import search_ptxas_configs
309+
310+
config["advanced_compiler_configuration"] = fn(
311+
EnumFragment((0, *search_ptxas_configs()))
312+
)
283313
# Add tunable parameters
284314
config.update(
285315
{key: fn(fragment) for key, fragment in self.user_defined_tunables.items()}

helion/autotuner/random_search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ def __init__(
3939
configs=ConfigGeneration(
4040
kernel.config_spec,
4141
overrides=kernel.settings.autotune_config_overrides or None,
42+
include_advanced_compiler_configuration=kernel.settings.autotune_search_acc,
4243
).random_population(count),
4344
)

helion/runtime/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,15 @@ def default_launcher(
6262
*args: object,
6363
num_warps: int,
6464
num_stages: int,
65+
ptx_options: str | None = None,
6566
) -> object:
6667
"""Default launcher function that executes the kernel immediately."""
67-
return triton_kernel.run(
68-
*args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
69-
)
68+
run_kwargs = {
69+
"grid": grid,
70+
"warmup": False,
71+
"num_warps": num_warps,
72+
"num_stages": num_stages,
73+
}
74+
if ptx_options:
75+
run_kwargs["ptx_options"] = ptx_options
76+
return triton_kernel.run(*args, **run_kwargs)

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
num_stages: int | None = None,
4040
pid_type: PidTypeLiteral | None = None,
4141
indexing: IndexingLiteral | None = None,
42+
advanced_compiler_configuration: int | None = None,
4243
# For user-defined properties
4344
**kwargs: object,
4445
) -> None:
@@ -61,6 +62,7 @@ def __init__(
6162
num_stages: Number of stages for software pipelining.
6263
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
6364
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
65+
advanced_compiler_configuration: Identifier for packaged PTXAS control files applied during compilation.
6466
**kwargs: Additional user-defined configuration parameters.
6567
"""
6668
self.config = {}
@@ -81,6 +83,7 @@ def __init__(
8183
"num_stages": num_stages,
8284
"indexing": indexing,
8385
"pid_type": pid_type,
86+
"advanced_compiler_configuration": advanced_compiler_configuration,
8487
}
8588
for key, value in core_props.items():
8689
if value is not None:
@@ -178,6 +181,10 @@ def pid_type(self) -> PidTypeLiteral:
178181
def range_unroll_factors(self) -> list[int]:
179182
return cast("list[int]", self.config.get("range_unroll_factors", []))
180183

184+
@property
def advanced_compiler_configuration(self) -> int:
    """Return the stored ``advanced_compiler_configuration`` value.

    Falls back to ``0`` when the key is absent from ``self.config``.
    """
    stored = self.config.get("advanced_compiler_configuration", 0)
    # cast() only informs the type checker; the value is returned unchanged.
    return cast("int", stored)
187+
181188
@property
182189
def range_warp_specializes(self) -> list[bool | None]:
183190
return cast("list[bool | None]", self.config.get("range_warp_specializes", []))

0 commit comments

Comments
 (0)