Commit 589fc42

Add epilogue subtiling

stack-info: PR: #948, branch: PaulZhang12/stack/14

1 parent: efc520e

File tree: 8 files changed (+318, -20 lines)

helion/_compiler/compile_environment.py
Lines changed: 1 addition & 0 deletions

@@ -99,6 +99,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         self.device_load_count = (
             0  # Track number of loads in all device code for eviction policy tuning
         )
+        self.device_store_count = 0  # Track number of stores for subtiling
 
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols

helion/_compiler/device_function.py
Lines changed: 9 additions & 0 deletions

@@ -250,6 +250,9 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.rng_seed_count = 0
         self.device_load_index = 0  # Track which load in device code we're generating (for eviction policy tuning)
         # Name of the RNG seed buffer parameter in kernel signature
+        self.device_store_index = (
+            0  # Track which store in device code we're generating (for subtiling)
+        )
         self.rng_seed_buffer_param_name = None
 
     def has_rng_ops(self) -> bool:

@@ -420,9 +423,15 @@ def tensor_arg(
     def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
     ) -> TensorDescriptorArg:
+        import re
+
         host_function = HostFunction.current()
         block_size_expr = ", ".join(map(self.literal_expr, block_size))
+        pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
+        replacement = r"\1 // \2"
+        block_size_expr = re.sub(pattern, replacement, block_size_expr)
         key = (fake_value, block_size_expr)
+
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
             desc_name = self.new_var(origin.suggest_var_name() + "_desc")
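
Note: the added re.sub normalizes how a halved block dimension prints, presumably so that the sympy-rendered triton_helpers.div_floor_integer(...) helper becomes plain integer division in the descriptor's block-size expression (which also serves as the cache key just below). A minimal standalone sketch of the substitution; the _BLOCK_SIZE_* names are illustrative and not taken from this commit:

import re

block_size_expr = "_BLOCK_SIZE_0, triton_helpers.div_floor_integer(_BLOCK_SIZE_1, 2)"
pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
block_size_expr = re.sub(pattern, r"\1 // \2", block_size_expr)
print(block_size_expr)  # _BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2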

helion/_compiler/device_ir.py
Lines changed: 39 additions & 14 deletions

@@ -1076,7 +1076,7 @@ def visit_For(self, node: ast.For) -> None:
         self.generic_visit(node)
 
 
-def _count_device_loads(device_ir: DeviceIR) -> int:
+def _count_device_loads_and_stores(device_ir: DeviceIR) -> int:
     """Count the number of load operations in all device code for eviction policy tuning."""
     from ..language import memory_ops
 

@@ -1087,26 +1087,29 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
         if info.new_graph_id is not None
     }
 
-    load_count = 0
+    load_count, store_count = 0, 0
     # Walk all graphs except rolled duplicates
     for graph_info in device_ir.graphs:
         if graph_info.graph_id in rolled_graph_ids:
             continue
 
         for node in graph_info.graph.nodes:
             # Check if this is a load operation
-            if node.op == "call_function" and node.target is memory_ops.load:
-                # Only count loads without explicit eviction policy
-                # (user can still specify eviction_policy to override tuning)
-                # Check kwargs first, then check if 4th arg (eviction_policy) is None
-                eviction_policy_arg = node.kwargs.get("eviction_policy")
-                if eviction_policy_arg is None:
-                    # Check if eviction_policy was passed as positional arg (index 3)
-                    if len(node.args) >= 4:
-                        eviction_policy_arg = node.args[3]
+            if node.op == "call_function":
+                if node.target is memory_ops.load:
+                    # Only count loads without explicit eviction policy
+                    # (user can still specify eviction_policy to override tuning)
+                    # Check kwargs first, then check if 4th arg (eviction_policy) is None
+                    eviction_policy_arg = node.kwargs.get("eviction_policy")
                     if eviction_policy_arg is None:
-                        load_count += 1
-    return load_count
+                        # Check if eviction_policy was passed as positional arg (index 3)
+                        if len(node.args) >= 4:
+                            eviction_policy_arg = node.args[3]
+                        if eviction_policy_arg is None:
+                            load_count += 1
+                elif node.target is memory_ops.store:
+                    store_count += 1
+    return load_count, store_count
 
 
 def _register_eviction_policy_tunable(load_count: int) -> None:

@@ -1125,6 +1128,24 @@ def _register_eviction_policy_tunable(load_count: int) -> None:
     env.device_load_count = load_count
 
 
+def _register_epilogue_subtile_tunable(store_count: int) -> None:
+    """Register the epilogue subtile tunable for all device stores."""
+    if store_count == 0:
+        return
+
+    from ..autotuner.config_fragment import EnumFragment
+    from ..autotuner.config_fragment import ListOf
+    from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
+
+    env = CompileEnvironment.current()
+    # Register a tunable for epilogue subtile for all device stores
+    fragment = ListOf(
+        EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count
+    )
+    env.config_spec.epilogue_subtiling = fragment
+    env.device_store_count = store_count
+
+
 def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     device_ir = DeviceIR()
     with func, device_ir, compile_lock:

@@ -1148,9 +1169,13 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
         CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
 
     # Count all device loads and register eviction policy tunable
-    load_count = _count_device_loads(device_ir)
+    load_count, store_count = _count_device_loads_and_stores(device_ir)
    _register_eviction_policy_tunable(load_count)
 
+    # Epilogue subtiling only for Blackwell
+    if torch.cuda.get_device_capability() >= (10, 0):
+        _register_epilogue_subtile_tunable(store_count)
+
     return device_ir
 
 
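
Note: the counting pass walks each unrolled device FX graph once and tallies call_function nodes by target, and the registered tunable is a ListOf(EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES)) whose length equals the store count, i.e. one subtile choice per store. Below is a minimal standalone sketch of the same counting pattern on a toy FX graph; toy_load and toy_store are illustrative stand-ins for Helion's memory_ops targets and are not part of this commit.

import torch
import torch.fx


def toy_load(x):  # stand-in for memory_ops.load
    return x


def toy_store(x):  # stand-in for memory_ops.store
    return x


# Keep the toy ops as leaf call_function nodes instead of tracing into them.
torch.fx.wrap("toy_load")
torch.fx.wrap("toy_store")


def kernel_body(x):
    a = toy_load(x)
    b = toy_load(a)
    return toy_store(b)


gm = torch.fx.symbolic_trace(kernel_body)
loads = sum(
    1 for n in gm.graph.nodes if n.op == "call_function" and n.target is toy_load
)
stores = sum(
    1 for n in gm.graph.nodes if n.op == "call_function" and n.target is toy_store
)
print(loads, stores)  # 2 1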

helion/_compiler/indexing_strategy.py
Lines changed: 197 additions & 2 deletions

@@ -15,6 +15,7 @@
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
 from .ast_extension import expr_from_string
+from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
 from .host_function import HostFunction

@@ -353,7 +354,6 @@ def codegen_load(
             )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
-
         # Load from tensor descriptor with permuted offsets
         load_expr = expr_from_string(
             f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"

@@ -383,10 +383,12 @@ def codegen_store(
             )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
+        store_value = indexing.reshape_store(state, value)
 
+        config = DeviceFunction.current().config
+        epilogue_subtiles = state.config.epilogue_subtiling
         # Apply permutation to the value being stored if needed
         desc_arg = indexing.tensor_descriptor_arg(state)
-        store_value = indexing.reshape_store(state, value)
 
         if desc_arg.permutation is not None:
             # Apply permutation to the value

@@ -395,11 +397,204 @@
                 store_val=store_value,
             )
 
+        if (idx := state.device_function.device_store_index) < len(epilogue_subtiles):
+            subtile_split = epilogue_subtiles[idx]
+            state.device_function.device_store_index += 1
+
+            # Check if we should fuse a pointwise operation into the epilogue store
+            fused_pointwise_node = self._get_fusable_pointwise_node(state)
+
+            subtile_codegen = self._codegen_epilogue_subtile_store(
+                state,
+                fake_tensor,
+                indexing,
+                store_value,
+                subtile_split,
+                config,
+                fused_pointwise_node,
+            )
+            if subtile_codegen is not None:
+                return subtile_codegen
+
         return expr_from_string(
             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
             value=store_value,
         )
 
+    def _get_fusable_pointwise_node(self, state: CodegenState) -> torch.fx.Node | None:
+        """Find a pointwise node feeding into this store that can be fused.
+
+        Returns the pointwise FX node if found, None otherwise.
+        """
+        if state.fx_node is None:
+            return None
+
+        # Get the value being stored (3rd argument to store)
+        if len(state.fx_node.args) < 3:
+            return None
+
+        value_node = state.fx_node.args[2]
+        if not isinstance(value_node, torch.fx.Node):
+            return None
+
+        # Check if this is a pointwise node
+        from .inductor_lowering import PointwiseLowering
+
+        lowering = value_node.meta.get("lowering")
+        if not isinstance(lowering, PointwiseLowering):
+            return None
+
+        # Check if this node only has one user (the store)
+        if len(list(value_node.users)) != 1:
+            return None
+
+        return value_node
+
+    def _apply_pointwise_to_subtile(
+        self, state: CodegenState, pointwise_node: torch.fx.Node, subtile_value: ast.AST
+    ) -> ast.AST:
+        """Apply a pointwise operation to a subtile value.
+
+        Args:
+            state: The codegen state
+            pointwise_node: The FX node representing the pointwise operation
+            subtile_value: The AST for the subtile value to apply the operation to
+
+        Returns:
+            AST for the result after applying the pointwise operation
+        """
+        from torch._inductor import ir
+
+        from .inductor_lowering import PointwiseLowering
+        from .inductor_lowering import install_inductor_kernel_handlers
+
+        lowering = pointwise_node.meta["lowering"]
+        assert isinstance(lowering, PointwiseLowering)
+
+        # Get the pointwise buffer
+        buffer = lowering.buffer
+        assert isinstance(buffer.data, ir.Pointwise)
+
+        # Create a temporary variable for the subtile
+        codegen = state.codegen
+        subtile_var = codegen.lift(subtile_value, prefix="subtile")
+
+        # Set up the inductor kernel handlers with the subtile as input
+        with install_inductor_kernel_handlers(
+            codegen, {lowering.input_names[0]: subtile_var}
+        ):
+            # Generate the pointwise operation
+            indices = [sympy.Symbol(f"i{n}") for n in range(len(buffer.data.ranges))]
+            from .inductor_lowering import _unpack_opsvalue
+
+            result_name = _unpack_opsvalue(buffer.data.inner_fn(indices))
+            return expr_from_string(result_name)
+
+    def _codegen_epilogue_subtile_store(
+        self,
+        state: CodegenState,
+        fake_tensor: torch.Tensor,
+        indexing: BlockedSubscriptIndexing,
+        store_value: ast.AST,
+        subtile_split: int,
+        config: Config,
+        fused_pointwise_node: torch.fx.Node | None = None,
+    ) -> ast.AST | None:
+        # Currently support 2D tiles without permutations
+        if (
+            len(indexing.block_shape) != 2
+            or len(indexing.offsets) != 2
+            or subtile_split == 0
+        ):
+            return None
+
+        env = CompileEnvironment.current()
+        block_m, block_n = indexing.block_shape
+        try:
+            block_n_hint = env.size_hint(block_n)
+            block_idx = env.get_block_id(block_n)
+            block_size = env.block_sizes[block_idx].from_config(config)
+        except Exception:
+            return None
+
+        if block_n_hint % 2 != 0 or block_size <= 16:
+            return None
+
+        device_fn = state.device_function
+        codegen = state.codegen
+
+        block_m_str = device_fn.literal_expr(block_m)
+        block_n_str = device_fn.literal_expr(block_n)
+        indexing.block_shape[1] //= subtile_split
+
+        # TODO(PaulZhang12): Support more epilogue subtile configs besides 2
+        block_n_half_str = f"({block_n_str} // {subtile_split})"
+
+        # If we have a fused pointwise operation, mark it to skip normal codegen
+        # and get its input value instead
+        if fused_pointwise_node is not None:
+            fused_pointwise_node.meta["fused_into_store"] = True
+
+        # Lift the store value into a temporary variable for reuse
+        acc_var = codegen.lift(store_value, prefix="acc")
+
+        reshape_expr = expr_from_string(
+            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",
+            acc=acc_var,
+            dim_m=expr_from_string(block_m_str),
+            dim_half=expr_from_string(block_n_half_str),
+        )
+        reshape_var = codegen.lift(reshape_expr, prefix="acc")
+
+        acc0_name = codegen.tmpvar(prefix="acc")
+        acc1_name = codegen.tmpvar(prefix="acc")
+        codegen.add_statement(
+            statement_from_string(
+                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
+                acc=reshape_var,
+            )
+        )
+
+        # Now apply the pointwise operation per-subtile if we have one
+        if fused_pointwise_node is not None:
+            acc0 = self._apply_pointwise_to_subtile(
+                state, fused_pointwise_node, expr_from_string(acc0_name)
+            )
+            acc1 = self._apply_pointwise_to_subtile(
+                state, fused_pointwise_node, expr_from_string(acc1_name)
+            )
+        else:
+            acc0 = expr_from_string(acc0_name)
+            acc1 = expr_from_string(acc1_name)
+
+        desc_name = indexing.tensor_descriptor(state)
+        offset0 = expr_from_string(indexing.offsets[0])
+        offset1 = expr_from_string(indexing.offsets[1])
+
+        # First subtile store
+        codegen.add_statement(
+            statement_from_string(
+                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+                off0=offset0,
+                off1=offset1,
+                value=acc0,
+            )
+        )
+
+        offset1_shifted = expr_from_string(
+            "({offset} + {half})",
+            offset=expr_from_string(indexing.offsets[1]),
+            half=expr_from_string(block_n_half_str),
+        )
+
+        # Emit second subtile store as the expression returned to the caller
+        return expr_from_string(
+            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+            off0=offset0,
+            off1=offset1_shifted,
+            value=acc1,
+        )
+
 
 class StackIndexingStrategy:
     """
