Add epilogue subtiling #948
base: main
stack-info: PR: #948, branch: PaulZhang12/stack/14
examples/matmul.py
Outdated
    config=helion.Config(
        block_sizes=[64, 64, 64],
        loop_orders=[[0, 1]],
        l2_groupings=[4],
        range_unroll_factors=[0, 1],
        range_num_stages=[0, 3],
        range_multi_buffers=[None, False],
        range_flattens=[None, None],
        num_warps=8,
        num_stages=6,
        indexing='tensor_descriptor',
        pid_type='flat'
    )
You probably don't want to check this in, since the best config will depend on the machine.
Does this help with matmul perf?
helion/_compiler/device_function.py
Outdated
    import re
    host_function = HostFunction.current()
    block_size_expr = ", ".join(map(self.literal_expr, block_size))
    pattern = r'triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)'
@yf225 didn't you add something to fix this somewhere else?
Yes, we currently have a sanitization pass for triton_helpers.* at
helion/helion/_compiler/device_function.py
Lines 519 to 532 in 1aaba3f
    if isinstance(value, sympy.Expr):
        sanitized = value.replace(  # pyright: ignore[reportAttributeAccessIssue]
            lambda node: isinstance(node, sympy.Function)
            and getattr(node.func, "__name__", "")
            == "triton_helpers.div_floor_integer",
            lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
        ).replace(  # pyright: ignore[reportAttributeAccessIssue]
            lambda node: isinstance(node, sympy.Function)
            and getattr(node.func, "__name__", "")
            == "triton_helpers.remainder_integer",
            lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
        )
        expr = cast("sympy.Expr", sanitized)
        return HostFunction.current().sympy_expr(expr)
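The pass quoted above rewrites nodes of the expression tree rather than the generated string. As a rough, self-contained illustration of that idea, the sketch below uses Python's standard `ast` module as a stand-in for the sympy tree. `DivFloorRewriter` and `sanitize` are hypothetical names for this example and are not part of Helion.

```python
import ast


class DivFloorRewriter(ast.NodeTransformer):
    """Hypothetical helper: turn triton_helpers.div_floor_integer(a, b) into a // b."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite nested calls in the arguments first.
        self.generic_visit(node)
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "div_floor_integer"
            and isinstance(func.value, ast.Name)
            and func.value.id == "triton_helpers"
            and len(node.args) == 2
            and not node.keywords
        ):
            return ast.BinOp(left=node.args[0], op=ast.FloorDiv(), right=node.args[1])
        return node


def sanitize(expr_source: str) -> str:
    """Parse an expression, rewrite the helper calls, and emit source again."""
    tree = ast.parse(expr_source, mode="eval")
    tree = ast.fix_missing_locations(DivFloorRewriter().visit(tree))
    return ast.unparse(tree)  # requires Python 3.9+


compound = sanitize("triton_helpers.div_floor_integer(a + b, 2)")
nested = sanitize(
    "triton_helpers.div_floor_integer(triton_helpers.div_floor_integer(n, 2), 4)"
)
```

Because the rewrite happens on the tree, operator precedence is handled when the source is regenerated, and nested helper calls are rewritten from the inside out.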
examples/matmul.py
Outdated
    Returns:
        Tensor: Resulting matrix of shape [m, n].
    """
Revert this unrelated change.
        return None


    def _supports_epilogue_subtiling():
Should this be the same as the supports_tensor_descriptor helper we already have?
helion/autotuner/config_spec.py
Outdated
    config.setdefault(
        "load_eviction_policies", self.load_eviction_policies.default()
    )
    config.setdefault("epilogue_subtiling", False)
Should this be a list since we can have multiple stores in the program?
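One way to read this suggestion is to keep accepting a single boolean but normalize it to one flag per store. The sketch below is purely illustrative: `normalize_epilogue_subtiling` and its `num_stores` parameter are hypothetical and do not come from the PR or from Helion's config spec.

```python
from typing import Any


def normalize_epilogue_subtiling(config: dict[str, Any], num_stores: int) -> list[bool]:
    """Hypothetical helper: expand the setting to one boolean per store."""
    value = config.setdefault("epilogue_subtiling", [False] * num_stores)
    if isinstance(value, bool):
        # A single flag applies to every store in the program.
        value = [value] * num_stores
    if len(value) != num_stores:
        raise ValueError(
            f"expected {num_stores} epilogue_subtiling entries, got {len(value)}"
        )
    config["epilogue_subtiling"] = list(value)
    return config["epilogue_subtiling"]


defaults = normalize_epilogue_subtiling({}, num_stores=2)
broadcast = normalize_epilogue_subtiling({"epilogue_subtiling": True}, num_stores=3)
per_store = normalize_epilogue_subtiling({"epilogue_subtiling": [True, False]}, num_stores=2)
```

With a per-store list, a tuner could in principle enable subtiling for one store and leave it off for another, which a single boolean cannot express.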
helion/_compiler/device_function.py
Outdated
    block_size_expr = ", ".join(map(self.literal_expr, block_size))
    pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
    replacement = r"\1 // \2"
    block_size_expr = re.sub(pattern, replacement, block_size_expr)
I believe this should unify the triton_helpers handling: https://gist.github.com/yf225/dbb4045c44df97c902906290e0f6affa
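For context on why a unified pass may be preferable, here is a small standalone check of the regex from the outdated snippet against a few hand-written inputs. The input strings are made up for illustration and are not taken from code that Helion generates, so whether these cases occur in practice is an open question.

```python
import re

PATTERN = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
REPLACEMENT = r"\1 // \2"


def rewrite(expr: str) -> str:
    """Apply the string-level rewrite from the outdated snippet."""
    return re.sub(PATTERN, REPLACEMENT, expr)


# A plain name divided by an integer literal is rewritten as intended.
simple = rewrite("_BLOCK_SIZE_0, triton_helpers.div_floor_integer(_BLOCK_SIZE_1, 2)")

# A compound first argument loses its grouping: the result parses as a + (b // 2).
compound = rewrite("triton_helpers.div_floor_integer(a + b, 2)")

# A non-literal divisor does not match \d+, so the call is left untouched.
symbolic_input = "triton_helpers.div_floor_integer(_BLOCK_SIZE_1, split)"
symbolic = rewrite(symbolic_input)
```

A rewrite on the expression tree avoids both the precedence change and the unmatched symbolic divisor, since it does not depend on how the arguments are spelled.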
Stacked PRs:
Add epilogue subtiling