Skip to content

Commit 0daab4c

Browse files
committed
Add advanced compiler configurations
stack-info: PR: #793, branch: jansel/stack/158
1 parent 8aefc0a commit 0daab4c

20 files changed

+184
-3
lines changed

helion/_compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,12 @@ def warps_to_threads(num_warps: int) -> int:
105105
)
106106
return num_warps * (props.warp_size or 32)
107107
return num_warps * 32
108+
109+
110+
def supports_ptxas(device: torch.device) -> bool:
    """Report whether PTXAS options can be used on ``device``.

    Only devices whose type is ``"cuda"`` qualify, and only when
    ``torch.version.hip`` is ``None``. For those, the answer is whatever
    ``supports_tensor_descriptor()`` reports.
    """
    on_non_hip_cuda = device.type == "cuda" and torch.version.hip is None
    # Short-circuit so the descriptor probe only runs for eligible devices.
    return on_non_hip_cuda and supports_tensor_descriptor()

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.fx.experimental.symbolic_shapes import ShapeEnv
2121

2222
from .. import exc
23+
from .._compat import supports_ptxas
2324
from ..language.constexpr import ConstExpr
2425
from .loop_dependency_checker import LoopDependencyChecker
2526
from .source_location import SourceLocation
@@ -90,6 +91,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9091
self.block_sizes: list[BlockSizeInfo] = []
9192
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
9293
self.config_spec = ConfigSpec()
94+
self.config_spec.ptxas_supported = supports_ptxas(device)
9395
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
9496
collections.Counter()
9597
)

helion/_compiler/device_function.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,12 @@ def codegen_function_call(self) -> ast.AST:
574574
f"num_stages={self.config.num_stages}",
575575
]
576576
)
577+
ptxas_config = self.config.ptxas_config
578+
if ptxas_config:
579+
from ..runtime.ptxas_configs import get_ptxas_option
580+
581+
ptx_option = get_ptxas_option(ptxas_config)
582+
args.append(f"ptx_options={ptx_option!r}")
577583
pid = self.pid
578584
assert pid is not None
579585
# TODO(jansel): we should run CSE this statement

helion/autotuner/config_spec.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"pid_type",
5353
"indexing",
5454
"load_eviction_policies",
55+
"ptxas_config",
5556
]
5657
)
5758
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
@@ -105,6 +106,7 @@ class ConfigSpec:
105106
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
106107
)
107108
)
109+
ptxas_supported: bool = False
108110

109111
@staticmethod
110112
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -238,6 +240,11 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
238240
else:
239241
config[name] = values[0]
240242

243+
if self.ptxas_supported:
244+
value = config.get("ptxas_config") or 0
245+
if not isinstance(value, int):
246+
raise InvalidConfig(f"ptxas_config must be integer, got {value!r}")
247+
241248
# Set default values for grid indices when pid_type is not persistent
242249
pid_type = config["pid_type"]
243250
if pid_type in ("flat", "xyz") and self.grid_block_ids:
@@ -280,6 +287,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
280287
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
281288
"load_eviction_policies": fn(self.load_eviction_policies),
282289
}
290+
if self.ptxas_supported:
291+
from ..runtime.ptxas_configs import search_ptxas_configs
292+
293+
config["ptxas_config"] = fn(EnumFragment((0, *search_ptxas_configs())))
283294
# Add tunable parameters
284295
config.update(
285296
{key: fn(fragment) for key, fragment in self.user_defined_tunables.items()}

helion/runtime/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,15 @@ def default_launcher(
6161
*args: object,
6262
num_warps: int,
6363
num_stages: int,
64+
ptx_options: str | None = None,
6465
) -> object:
6566
"""Default launcher function that executes the kernel immediately."""
66-
return triton_kernel.run(
67-
*args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
68-
)
67+
run_kwargs = {
68+
"grid": grid,
69+
"warmup": False,
70+
"num_warps": num_warps,
71+
"num_stages": num_stages,
72+
}
73+
if ptx_options:
74+
run_kwargs["ptx_options"] = ptx_options
75+
return triton_kernel.run(*args, **run_kwargs)

helion/runtime/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
num_stages: int | None = None,
4040
pid_type: PidTypeLiteral | None = None,
4141
indexing: IndexingLiteral | None = None,
42+
ptxas_config: int | None = None,
4243
# For user-defined properties
4344
**kwargs: object,
4445
) -> None:
@@ -81,6 +82,7 @@ def __init__(
8182
"num_stages": num_stages,
8283
"indexing": indexing,
8384
"pid_type": pid_type,
85+
"ptxas_config": ptxas_config,
8486
}
8587
for key, value in core_props.items():
8688
if value is not None:
@@ -178,6 +180,10 @@ def pid_type(self) -> PidTypeLiteral:
178180
def range_unroll_factors(self) -> list[int]:
179181
return cast("list[int]", self.config.get("range_unroll_factors", []))
180182

183+
@property
def ptxas_config(self) -> int:
    """PTXAS configuration ID stored in this config, or 0 when none is set."""
    selected = self.config.get("ptxas_config", 0)
    return cast("int", selected)
186+
181187
@property
182188
def range_warp_specializes(self) -> list[bool | None]:
183189
return cast("list[bool | None]", self.config.get("range_warp_specializes", []))
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Utilities for working with packaged PTXAS control files."""
2+
3+
from __future__ import annotations
4+
5+
from functools import cache
6+
from pathlib import Path
7+
8+
# Maps each selectable PTXAS config ID to the file name of its control file.
# File names are resolved relative to this module's directory (see
# _config_root). ID 0 is intentionally absent: get_ptxas_option() treats 0 as
# "no PTXAS option".
_ADVANCED_COMPILER_CONFIGURATIONS: dict[int, str] = {
    # ID 1 is kept here, disabled, because it caused timeouts:
    # 1: "fp8fatt_0.bin",  # caused timeouts
    2: "fp8fatt_1.bin",
    3: "fp8fatt_2.bin",
    4: "fp8fatt_3.bin",
    5: "matmul_0.bin",
    6: "matmul_1.bin",
    7: "matmul_2.bin",
    8: "matmul_3.bin",
    9: "matmul_4.bin",
    10: "matmul_5.bin",
}
20+
21+
22+
def _config_root() -> Path:
    """Return the resolved directory that contains this module."""
    module_file = Path(__file__).resolve()
    return module_file.parent
24+
25+
26+
@cache
def search_ptxas_configs() -> tuple[int, ...]:
    """List every registered PTXAS config ID in ascending order.

    The result is computed once and then served from the ``functools.cache``
    wrapper on later calls.
    """
    ordered_ids = sorted(_ADVANCED_COMPILER_CONFIGURATIONS.keys())
    return tuple(ordered_ids)
31+
32+
33+
def _advanced_compiler_configuration_path(config_id: int) -> str:
    """Resolve ``config_id`` to the absolute path of its control file.

    Raises:
        ValueError: ``config_id`` has no entry in the configuration table.
        FileNotFoundError: the mapped path is not an existing regular file.
    """
    try:
        control_name = _ADVANCED_COMPILER_CONFIGURATIONS[config_id]
    except KeyError as err:  # pragma: no cover - defensive
        message = f"Unknown advanced compiler configuration id: {config_id}"
        raise ValueError(message) from err

    control_file = _config_root().joinpath(control_name).resolve()
    if control_file.is_file():
        return str(control_file)
    raise FileNotFoundError(
        f"Missing advanced compiler configuration file: {control_file}"
    )
48+
49+
50+
@cache
def get_ptxas_option(config_value: int) -> str | None:
    """Build the PTXAS option string for ``config_value``.

    Returns ``None`` for ``0``. Any other value yields
    ``"--apply-controls <path>"``, where ``<path>`` is the control file
    resolved for that ID.
    """
    if config_value != 0:
        control_file = _advanced_compiler_configuration_path(config_value)
        return f"--apply-controls {control_file}"
    return None
9.75 KB
Binary file not shown.
2.44 KB
Binary file not shown.
2.38 KB
Binary file not shown.

0 commit comments

Comments
 (0)