73 changes: 62 additions & 11 deletions examples/blackwell_attention.py
@@ -13,9 +13,11 @@
from __future__ import annotations

import math
+import os
from typing import Callable

import torch
+import triton
from triton.testing import do_bench

import helion
@@ -70,6 +72,55 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    )

+
+HAS_ACC = os.getenv("WITH_ACC", "0") == "1"
+if HAS_ACC:
+    ACC_OPTIONS = [0, 11, 12, 13, 14, 15]
+
+    def make_acc_option(option):
+        return {"advanced_compiler_configuration": option}
+else:
+    ACC_OPTIONS = [None]
+    make_acc_option = lambda _: {}
+
+
+def _supports_reg_auto_ws():
+    """Check if the current Triton version supports minRegAutoWS/maxRegAutoWS"""
+    try:
+        # Try to create a Config with minRegAutoWS to test support
+        test_config = triton.Config({}, minRegAutoWS=24, maxRegAutoWS=152)
+        return True
+    except (TypeError, AttributeError):
+        # Parameter not supported in this Triton version
+        return False
+
+
+HAS_REG_AUTO_WS = _supports_reg_auto_ws()
+print(f"!!!!!!!!!! {HAS_REG_AUTO_WS=} !!!!!!!!!!!!!")
+if HAS_REG_AUTO_WS:
+    REG_AUTO_WS_OPTIONS = [152, 192]
+    M_OPTIONS = [256]
+
+    def make_reg_auto_ws_option(maxreg):
+        OUTER_LOOP = True
+        return dict(
+            _triton_range_id_data_partition_factor=0,
+            _triton_range_value_data_partition_factor=2,
+            _triton_config_maxRegAutoWS=maxreg,
+            range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True],
+            range_multi_buffers=[None, False],
+        )
+else:
+    REG_AUTO_WS_OPTIONS = [None]
+    M_OPTIONS = [128]
+
+    def make_reg_auto_ws_option(maxreg):
+        return dict(
+            _triton_range_id_data_partition_factor=-1,
+            _triton_range_value_data_partition_factor=-1,
+            _triton_config_maxRegAutoWS=-1,
+        )
+

# %%
# Attention Kernel Implementation
# -------------------------------
@@ -79,20 +130,18 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
@helion.kernel(
    configs=[
        helion.Config(
-            block_sizes=[256, N],
-            range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True],
-            range_multi_buffers=[None, False],
+            block_sizes=[M, N],
            pid_type="persistent_interleaved",
            indexing="tensor_descriptor",
            num_warps=4,
            num_stages=3,
-            _triton_range_id_data_partition_factor=0,
-            _triton_range_value_data_partition_factor=2,
-            _triton_config_maxRegAutoWS=maxreg,
+            **make_acc_option(ACC_OPTION),
+            **make_reg_auto_ws_option(REG_AUTO_WS_OPTION),
        )
+        for M in M_OPTIONS
        for N in [64, 128]
-        for OUTER_LOOP in [True]
-        for maxreg in [152, 192]
+        for ACC_OPTION in ACC_OPTIONS
+        for REG_AUTO_WS_OPTION in REG_AUTO_WS_OPTIONS
    ],
    static_shapes=True,
    autotune_accuracy_check=False,
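The comprehension above sweeps a small cartesian product: with both optional features enabled it yields 1 M value × 2 N values × 6 ACC options × 2 register options = 24 configs, and it collapses to 2 configs when neither feature is present. A minimal sketch of the same `**`-splat pattern, with an illustrative option maker rather than the example's exact code:

```python
# Option makers return {} when a feature is unavailable, so the **-splat
# silently contributes nothing to the final config.
def make_acc_option(option):
    return {} if option is None else {"advanced_compiler_configuration": option}

configs = [
    dict(block_sizes=[m, n], num_warps=4, **make_acc_option(acc))
    for m in [256]
    for n in [64, 128]
    for acc in [None, 11]
]
assert len(configs) == 4
assert "advanced_compiler_configuration" not in configs[0]
```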
@@ -136,12 +185,14 @@ def blackwell_attention_kernel(
    assert M % block_m == 0
    assert N % block_n == 0
    hl.register_tunable(
-        "_triton_range_id_data_partition_factor", EnumFragment(choices=(0,))
+        "_triton_range_id_data_partition_factor", EnumFragment(choices=(-1, 0))
    )
    hl.register_tunable(
-        "_triton_range_value_data_partition_factor", EnumFragment(choices=(2,))
+        "_triton_range_value_data_partition_factor", EnumFragment(choices=(-1, 2))
    )
-    hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192)))
+    hl.register_tunable(
+        "_triton_config_maxRegAutoWS", EnumFragment(choices=(-1, 152, 192))
+    )
    SUBTILING = True
    VECT_MUL = 1
    qk_scale = qk_scale * 1.44269504  # 1/log(2)
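The `_supports_reg_auto_ws` probe in this file gates the sweep on Triton version by attempting to construct a `triton.Config` with the new keyword arguments and treating `TypeError` (or `AttributeError`) as "not supported". A standalone sketch of that probing pattern; the `config_accepts` helper is illustrative, and the only assumption is that `triton.Config` rejects unknown keywords:

```python
import triton

def config_accepts(**kwargs) -> bool:
    """Probe whether this Triton build accepts the given Config kwargs."""
    try:
        triton.Config({}, **kwargs)  # probe object is discarded immediately
        return True
    except (TypeError, AttributeError):
        return False

# Mirrors the gating above: widen the autotuning sweep only when
# register-budgeted warp specialization is available.
REG_AUTO_WS_OPTIONS = (
    [152, 192] if config_accepts(minRegAutoWS=24, maxRegAutoWS=152) else [None]
)
```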
8 changes: 6 additions & 2 deletions helion/_compiler/device_function.py
@@ -228,7 +228,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
            + [
                x.removeprefix("_triton_config_")
                for x in config
-                if x.startswith("_triton_config_")
+                if x.startswith("_triton_config_") and config[x] != -1
            ]
        )
        self._variable_renames: dict[str, list[str]] = {}
@@ -614,6 +614,10 @@ def codegen_function_call(self) -> ast.AST:
        if any(self.config.range_warp_specializes):
            num_warps = max(4, num_warps)

+        print(
+            type(self.config["_triton_config_maxRegAutoWS"]),
+            self.config["_triton_config_maxRegAutoWS"],
+        )
        args.extend(
            [
                f"num_warps={num_warps}",
@@ -622,7 +626,7 @@
            + [
                f"{x.removeprefix('_triton_config_')}={self.config[x]}"
                for x in self.config
-                if x.startswith("_triton_config_")
+                if x.startswith("_triton_config_") and self.config[x] != -1
            ]
        )
        advanced_compiler_configuration = self.config.advanced_compiler_configuration
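Both filters in this file enforce the same rule: a `_triton_config_*` entry whose value is the sentinel `-1` means "leave this knob unset" and must not be forwarded to Triton. A minimal standalone sketch of the filter; the helper name is illustrative:

```python
def triton_config_kwargs(config: dict) -> dict:
    """Collect _triton_config_* entries as Triton kwargs, dropping -1 sentinels."""
    return {
        key.removeprefix("_triton_config_"): value
        for key, value in config.items()
        if key.startswith("_triton_config_") and value != -1
    }

assert triton_config_kwargs({"_triton_config_maxRegAutoWS": -1}) == {}
assert triton_config_kwargs({"_triton_config_maxRegAutoWS": 152}) == {
    "maxRegAutoWS": 152
}
```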
2 changes: 1 addition & 1 deletion helion/autotuner/base_search.py
@@ -295,7 +295,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
        self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}")
        t0 = time.perf_counter()
        # HACK: run checks multiple times to detect data races
-        for _ in range(5):
+        for _ in range(1):
            if self._kernel_mutates_args:
                self.args = self._clone_args(self._original_args)
            torch.accelerator.synchronize()
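For context, this loop re-runs the compiled config before timing it: re-cloning mutated inputs keeps every repeat deterministic, and more repeats make latent data races more likely to surface, at the cost of autotuning time (the diff dials the count from 5 down to 1). A hedged sketch of the idea; the helper names are illustrative, not Helion's API:

```python
import torch

RACE_CHECK_REPEATS = 1  # was 5: more repeats catch more races but slow autotuning

def run_accuracy_checks(fn, original_args, clone_args, kernel_mutates_args):
    for _ in range(RACE_CHECK_REPEATS):
        # If the kernel writes to its inputs, restore them before each run
        # so every repeat sees identical data.
        args = clone_args(original_args) if kernel_mutates_args else original_args
        torch.accelerator.synchronize()  # drain prior GPU work before launching
        fn(*args)
```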