[Transform] Split instruction inference from LowerTileOp to InferInstructions #1808
Rachmanino wants to merge 7 commits into tile-ai:main
📝 Walkthrough
Adds annotation-driven GEMM instruction inference: a new TIR pass infers and injects an "instruction" annotation into Gemm calls; Gemm/GemmPy nodes and Python tileop layers store and consume that annotation during layout inference and lowering.
Sequence Diagram(s)
sequenceDiagram
participant PythonLayer as Python Layer
participant TIR as TIR IR
participant InferPass as InferInstructionsPass
participant GemmNode as Gemm/GemmPy Node
participant Lowering as Lowering
PythonLayer->>TIR: Emit TIR with Gemm call (annotations optional)
TIR->>InferPass: Run InferInstructions pass
InferPass->>InferPass: Read thread block size (AttrStmt threadIdx.x)
InferPass->>InferPass: TryParseGemmNode(call) -> GemmNodeVariant
InferPass->>InferPass: Compute GemmInst from block_size + target (if missing)
InferPass->>TIR: Inject "instruction" annotation into Gemm call
TIR->>GemmNode: Lower / InferLayout invoked
GemmNode->>GemmNode: Read GemmInst from annotations_["instruction"]
GemmNode->>Lowering: Use annotated GemmInst for warp partitioning & lowering
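The flow in the sequence diagram can be sketched in plain Python. This is only an illustrative model, not the pass itself: `GemmCall`, `select_inst`, the enum values, and the selection rule are all hypothetical stand-ins, since the real `getGemmInst` logic is not shown on this page. What it does show is the contract described above: user-supplied `"instruction"` annotations are kept, and missing ones are filled in from the block size and target.

```python
from dataclasses import dataclass, field
from enum import IntEnum


class GemmInst(IntEnum):
    # Hypothetical numbering; the real enum values are not shown on this page.
    MMA = 0
    WGMMA = 1


@dataclass
class GemmCall:
    """Stand-in for a TIR Gemm call; only the annotations map matters here."""
    annotations: dict = field(default_factory=dict)


def select_inst(block_size: int, target: str) -> GemmInst:
    # Placeholder selection rule, assumed purely for illustration.
    if target == "sm90" and block_size % 128 == 0:
        return GemmInst.WGMMA
    return GemmInst.MMA


def infer_instructions(calls, block_size, target):
    """Return new calls in which every Gemm call carries an "instruction" annotation."""
    assert block_size > 0, "thread block size must be known before inference"
    result = []
    for call in calls:
        if "instruction" in call.annotations:
            # Instruction was specified by the user: skip inference.
            result.append(call)
            continue
        new_annotations = dict(call.annotations)  # copy, do not mutate the input
        new_annotations["instruction"] = int(select_inst(block_size, target))
        result.append(GemmCall(new_annotations))
    return result
```

Later stages then only need to read `annotations["instruction"]` instead of recomputing the choice, which is the split the PR title describes.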
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/__init__.py (1)
158-167: ⚠️ Potential issue | 🟡 Minor
Unreachable duplicate branch: `is_tcgen5mma()` is checked twice.
Line 160 handles `is_tcgen5mma()` and returns `GemmTCGEN5`. The second check at line 164 is identical and can never be reached. This appears to be a copy-paste leftover; the `raise NotImplementedError` is dead code.
Proposed fix: remove the dead branch
      elif gemm_inst.is_mfma():
          return GemmMFMA
    - elif gemm_inst.is_tcgen5mma():
    -     raise NotImplementedError("TCGEN5MMA is not implemented")
      else:
          raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
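To see why the second branch in the comment above is dead, here is a small stand-alone sketch. The function `pick_impl` and its string arguments are hypothetical simplifications of the real dispatch, which works on a `GemmInst` object; only the branch order mirrors the review comment.

```python
def pick_impl(inst: str) -> str:
    """Simplified dispatch with the same branch order as the reviewed code."""
    if inst == "tcgen5mma":
        return "GemmTCGEN5"
    elif inst == "mfma":
        return "GemmMFMA"
    elif inst == "tcgen5mma":
        # Never reached: the first branch already returned for this value.
        raise NotImplementedError("TCGEN5MMA is not implemented")
    else:
        raise ValueError(f"Unsupported GEMM instruction: {inst}")
```

Calling `pick_impl("tcgen5mma")` returns from the first branch, so the `NotImplementedError` can never fire and removing that branch does not change behavior.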
🤖 Fix all issues with AI agents
In `@src/op/gemm_py.h`:
- Around line 70-71: Change the reflection key on GemmPyNode so the annotations
map is exposed under the same name as GemmNode: replace the registration that
uses "ann" with "annotations" (the property registered via
GemmPyNode::annotations_) so Python accessors (e.g., GemmBase.annotations which
reads getattr(self.gemm_node, "annotations", {})) can find it; update the
.def_ro call on GemmPyNode that currently binds policy_ and annotations_ to use
"annotations" instead of "ann".
In `@src/op/gemm.h`:
- Around line 137-139: Fix the typo in the comment above the field annotations_:
change "Anotations" to "Annotations" so the comment reads "Annotations for the
GEMM operation"; update the nearby comment block that lists supported annotation
keys (mentioning "instruction" and GemmInst) if needed to keep wording
consistent with the corrected spelling.
- Around line 163-164: GemmPyNode registers the annotations field with the
reflection key "ann", causing Python accessors to miss it; change the
registration in GemmPyNode where it exposes annotations_ (the .def_ro call for
the annotations_ member) to use the key "annotations" so it matches GemmNode's
.def_ro("annotations", &GemmNode::annotations_) and other node types, ensuring
getattr(self.gemm_node, "annotations", {}) returns the actual annotations.
In `@src/transform/common/gemm.h`:
- Around line 25-33: The header defines TryParseGemmNode (returning
GemmNodeVariant and using TileOperator, ParseOperator, CallNode, GemmNode,
GemmPyNode) which can cause ODR violations when included by multiple translation
units; fix by either marking the function inline in the header (add the inline
keyword to the TryParseGemmNode function signature) or move the full definition
into a .cc file and leave only a declaration in the header so that only one
translation unit contains the definition.
In `@tilelang/language/gemm_op.py`:
- Line 219: The top-level docstring for the GEMM operator is written as an
f-string (f""") which prevents it from being recognized as the function/module
docstring and keeps gemm.__doc__ unset; change the leading f""" to a plain
triple-quoted string """ so the string becomes the actual docstring for the GEMM
operator (the docstring appearing near the gemm function/class in gemm_op.py) —
no interpolation needed, just remove the f prefix.
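The docstring issue described above is easy to reproduce in isolation. The two functions below are hypothetical examples, not the real `gemm` function: Python only treats a plain string literal as a docstring, so an f-string in the same position is evaluated and discarded.

```python
def with_plain_docstring():
    """Perform a GEMM."""


def with_fstring():
    f"""Perform a GEMM."""


# A plain literal becomes the docstring.
print(with_plain_docstring.__doc__)  # Perform a GEMM.

# An f-string is never a docstring, even without any interpolation.
print(with_fstring.__doc__)  # None
```

Tools that rely on `__doc__`, such as `help()` and documentation generators, therefore see nothing for the f-string variant.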
In `@tilelang/tileop/gemm/gemm_base.py`:
- Around line 134-136: The annotations property on GemmBase uses
getattr(self.gemm_node, "annotations", {}) which silently returns {} for
GemmPyNode because that class registers its data under "ann"; update the
property to handle both names (check "annotations" first then fall back to
"ann") or normalize the gemm_node metadata at construction so both GemmPyNode
and GemmNode expose the same key; locate the annotations property on the
GemmBase class (method annotations) and adjust it to read
getattr(self.gemm_node, "annotations", getattr(self.gemm_node, "ann", {})) or
perform equivalent normalization to ensure consistent behavior.
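The silent-empty-dict behavior and the suggested fallback can be sketched as follows. `FakeGemmNode` and `FakeGemmPyNode` are hypothetical stand-ins for the reflected C++ nodes; the only assumption carried over from the review is that one exposes `annotations` and the other exposes `ann`.

```python
class FakeGemmNode:
    """Stand-in for a node that exposes its map as `annotations`."""
    def __init__(self):
        self.annotations = {"instruction": 1}


class FakeGemmPyNode:
    """Stand-in for a node that exposes its map as `ann`."""
    def __init__(self):
        self.ann = {"instruction": 1}


def read_annotations_strict(node):
    # Original behavior: silently yields {} when the attribute name differs.
    return getattr(node, "annotations", {})


def read_annotations(node):
    # Suggested fallback: try "annotations" first, then "ann".
    return getattr(node, "annotations", getattr(node, "ann", {}))
```

With the strict reader, `FakeGemmPyNode` looks as if no instruction had been inferred; the fallback reader returns the same map for both node types. Aligning the reflection key on the C++ side, as the other comments suggest, would make the fallback unnecessary.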
🧹 Nitpick comments (3)
src/transform/common/gemm.h (1)
11-11: Avoid `using namespace` directives in header files.
`using namespace tir;` in a header pollutes the namespace of every translation unit that includes this header, potentially causing name collisions. Move it inside the function body, or qualify names explicitly (e.g., `tir::Call`).
src/op/gemm.cc (1)
444-446: Consider extracting the repeated annotation-to-GemmInst conversion into a helper.
The identical pattern appears in both `Lower` (lines 444-446) and `InferLayout` (lines 608-610):
    ICHECK(annotations_.count("instruction")) << "Gemm instruction is not inferred";
    GemmInst gemm_inst = static_cast<GemmInst>(Downcast<IntImm>(annotations_.at("instruction"))->value);
A small helper (e.g., `GemmInst GemmNode::getInferredInst() const`) would reduce duplication and centralize the annotation key.
Also applies to: 608-610
src/transform/infer_instructions.cc (1)
47-74: Clean instruction inference logic with a couple of minor items.
The overall approach (parse GEMM, skip if already annotated, infer and annotate) is solid. Two small notes:
- Line 68: Remove the commented-out `LOG(INFO)` debug artifact.
- Lines 48-72: The indentation of the inner block is slightly off (6 spaces instead of the 4 used in the rest of the class). This is cosmetic but inconsistent.
Proposed cleanup
     Stmt VisitStmt_(const EvaluateNode *op) final {
       if (const auto *call = op->value.as<CallNode>()) {
    -      // Only handle Gemm operator for now
    -      // Copy instruction inference requires layout information, which is not available yet
    -      auto gemm_node = TryParseGemmNode(*call);
    -      if (!gemm_node.has_value()) {
    -        return StmtExprMutator::VisitStmt_(op);
    -      }
    -      if (call->annotations.count("instruction")) {
    -        // Instruction is specified by user, skip inference
    -        return StmtExprMutator::VisitStmt_(op);
    -      }
    -      Map<String, ObjectRef> new_annotations = call->annotations;
    -      ICHECK(thread_block_size_ > 0)
    -          << "Thread block size not set, ensure AttrStmt with thread_extent is visited first";
    -      GemmInst gemm_inst = std::visit(
    -          [this](const auto *node) {
    -            return node->getGemmInst(thread_block_size_, target_);
    -          },
    -          *gemm_node);
    -      new_annotations.Set("instruction", Integer(static_cast<int>(gemm_inst)));
    -      // LOG(INFO) << "Inferred GEMM instruction: " << static_cast<int>(gemm_inst);
    -      Call new_call =
    -          Call(call->dtype, call->op, call->args, new_annotations, call->span);
    -      return Evaluate(new_call);
    +    // Only handle Gemm operator for now
    +    // Copy instruction inference requires layout information, which is not available yet
    +    auto gemm_node = TryParseGemmNode(*call);
    +    if (!gemm_node.has_value()) {
    +      return StmtExprMutator::VisitStmt_(op);
    +    }
    +    if (call->annotations.count("instruction")) {
    +      // Instruction is specified by user, skip inference
    +      return StmtExprMutator::VisitStmt_(op);
    +    }
    +    Map<String, ObjectRef> new_annotations = call->annotations;
    +    ICHECK(thread_block_size_ > 0)
    +        << "Thread block size not set, ensure AttrStmt with thread_extent is visited first";
    +    GemmInst gemm_inst = std::visit(
    +        [this](const auto *node) {
    +          return node->getGemmInst(thread_block_size_, target_);
    +        },
    +        *gemm_node);
    +    new_annotations.Set("instruction", Integer(static_cast<int>(gemm_inst)));
    +    Call new_call =
    +        Call(call->dtype, call->op, call->args, new_annotations, call->span);
    +    return Evaluate(new_call);
       }
       return StmtExprMutator::VisitStmt_(op);
     }
- Introduced a new test file to validate the functionality of the GEMM kernel, ensuring correct behavior for both MMA and WGMMA configurations.
- Implemented a helper function to detect the use of MMA in the generated kernel source.
- Added parameterized tests to check the output against reference results using PyTorch.
- Enhanced the GemmInst class with a string representation method for better debugging.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/__init__.py (1)
159-164: ⚠️ Potential issue | 🟡 Minor
Dead code: `is_tcgen5mma()` is checked twice, so the second branch is unreachable.
Line 159 already handles `is_tcgen5mma()` by returning `GemmTCGEN5`. The `elif` on line 163 re-checks the same condition and raises `NotImplementedError`, but this branch can never be reached.
📝 Proposed fix: remove the dead branch
      elif gemm_inst.is_mfma():
          return GemmMFMA
    - elif gemm_inst.is_tcgen5mma():
    -     raise NotImplementedError("TCGEN5MMA is not implemented")
      else:
          raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
🤖 Fix all issues with AI agents
In `@src/op/gemm.cc`:
- Around line 92-94: The code unconditionally reads args[17] and args[18] to set
node->cCoords_ which can OOB; add bounds checks like the existing guards for
args[14]/args[15]/args[16] so you only access args[17] and args[18] when
args.size() > 17 and args.size() > 18 respectively (or a single check for
args.size() > 18) before calling args[17].as<PrimExpr>().value() and
args[18].as<PrimExpr>().value(); leave node->annotations_ assignment unchanged.
Ensure you follow the same pattern used in gemm_sp_py.cc for consistency.
🧹 Nitpick comments (7)
tilelang/tileop/gemm/gemm_cutedsl.py (1)
65-68: `_get_inferred_gemm_instruction` is duplicated in `GemmPy` (`__init__.py` lines 129-133).
Both `GemmCuTeDSL` and `GemmPy` define identical `_get_inferred_gemm_instruction()` methods. Consider moving this to `GemmBase` to avoid duplication.
tilelang/tileop/gemm/__init__.py (1)
113-115: Workaround property for reflection key mismatch.
This `annotations` property exists solely because `GemmPyNode` exposes its annotations as `"ann"` instead of `"annotations"`. If the reflection key is aligned (as suggested in `src/op/gemm_py.h`), this alias can be removed.
tilelang/language/gemm_op.py (1)
239-239: Long line with many positional arguments.
This single-line call with 11 positional arguments is hard to read and fragile; a positional mix-up would be a silent bug. Consider splitting across multiple lines or using keyword arguments.
♻️ Suggested formatting
    - return impl(A, B, C, transpose_A, transpose_B, policy, clear_accum, k_pack, wg_wait, mbar, annotations)
    + return impl(
    +     A, B, C,
    +     transpose_A=transpose_A,
    +     transpose_B=transpose_B,
    +     policy=policy,
    +     clear_accum=clear_accum,
    +     k_pack=k_pack,
    +     wg_wait=wg_wait,
    +     mbar=mbar,
    +     annotations=annotations,
    + )
src/op/gemm.cc (1)
444-447: Consider extracting the repeated annotation-to-GemmInst logic into a helper.
The same 4-line pattern (ICHECK + Downcast + static_cast) is duplicated in both `Lower` and `InferLayout`. A small helper method on `GemmNode` (e.g., `getAnnotatedGemmInst()`) would reduce duplication and centralize validation.
♻️ Example helper in gemm.h
    GemmInst getAnnotatedGemmInst() const {
      ICHECK(annotations_.count("instruction"))
          << "Gemm instruction is not inferred";
      return static_cast<GemmInst>(
          Downcast<IntImm>(annotations_.at("instruction"))->value);
    }
Then in both `Lower` and `InferLayout`:
    - ICHECK(annotations_.count("instruction"))
    -     << "Gemm instruction is not inferred";
    - GemmInst gemm_inst = static_cast<GemmInst>(
    -     Downcast<IntImm>(annotations_.at("instruction"))->value);
    + GemmInst gemm_inst = getAnnotatedGemmInst();
Also applies to: 609-612
testing/python/transform/test_tilelang_transform_infer_instructions.py (1)
10-24: Detection heuristic is fragile and could silently produce false negatives.
`_kernel_uses_mma` returns `False` when no `gemm_ss` match is found (line 22), silently treating "cannot determine" the same as "not MMA." If the codegen template name ever changes (e.g., `gemm_rs`, `gemm_sr`), the test would silently pass for the wrong reason.
Consider raising or logging when the regex doesn't match at all, so a broken heuristic is caught early:
♻️ Proposed improvement
      match = re.search(r"gemm_ss\s*<\s*([^>]+)\s*>", kernel_src)
      if not match:
    -     return False
    +     raise RuntimeError(
    +         "Could not find gemm_ss<...> in kernel source; "
    +         "detection heuristic needs updating"
    +     )
src/transform/infer_instructions.cc (2)
48-78: Instruction inference logic is clean and correctly preserves user-provided annotations.
The flow (skip if annotation exists, otherwise infer via `getGemmInst` and inject) is sound. The `std::visit` dispatch for the `GemmNodeVariant` is a nice pattern.
One nit: Lines 71-72 contain a commented-out `LOG(INFO)` statement. Consider removing it before merge to keep the code clean.
🧹 Remove debug artifact
      new_annotations.Set("instruction", Integer(static_cast<int>(gemm_inst)));
    - // LOG(INFO) << "Inferred GEMM instruction: " <<
    - // static_cast<int>(gemm_inst);
      Call new_call =
86-91: Minor: `[=]` capture is unnecessary.
The lambda captures nothing from the enclosing scope (`InferInstructions()` is a free function with no locals captured). Using `[]` or `[](PrimFunc f, ...)` would be more precise, though this has no functional impact.
♻️ Suggested change
    - auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
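For the detection-heuristic nitpick on `test_tilelang_transform_infer_instructions.py` above, the fail-loudly variant might look like the sketch below. The regex and error message come from the review comment; the function name `gemm_ss_template_args` and the idea of returning the split template arguments are assumptions, since the rest of `_kernel_uses_mma` is not shown on this page.

```python
import re


def gemm_ss_template_args(kernel_src: str) -> list:
    """Return the template arguments of the first gemm_ss<...> call.

    Raises instead of returning a default, so a renamed codegen template
    is reported as a broken heuristic rather than as "not MMA".
    """
    match = re.search(r"gemm_ss\s*<\s*([^>]+)\s*>", kernel_src)
    if not match:
        raise RuntimeError(
            "Could not find gemm_ss<...> in kernel source; "
            "detection heuristic needs updating"
        )
    # Split the captured argument list on commas and trim whitespace.
    return [arg.strip() for arg in match.group(1).split(",")]
```

Note that `[^>]+` stops at the first `>`, so nested template arguments would be truncated; that limitation is inherited from the regex in the review comment.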
    node->cCoords_ = Array<PrimExpr>(
        {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
    node->annotations_ = annotations;
🧩 Analysis chain
🏁 Script executed:
# First, locate the gemm.cc file
fd -type f -name "gemm.cc"
Repository: tile-ai/tilelang
Length of output: 230
🏁 Script executed:
# Once found, examine the specific lines and surrounding context
cat -n src/op/gemm.cc | sed -n '80,110p'
Repository: tile-ai/tilelang
Length of output: 1195
🏁 Script executed:
# Search for size guards and args bounds checking patterns in the function
rg -n "args\.size()|args\[1[4-9]\]|args\[2[0-9]\]" src/op/gemm.cc -B5 -A5
Repository: tile-ai/tilelang
Length of output: 1173
🏁 Script executed:
# Check the function signature to understand args parameter
rg -n "void.*\(.*args.*\)" src/op/gemm.cc -A2 | head -50
Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for guards on args[14], args[15], args[16] to compare
rg -n "args.size()" src/op/gemm.cc -B3 -A3
Repository: tile-ai/tilelang
Length of output: 735
🏁 Script executed:
# Get more context - find the function name
cat -n src/op/gemm.cc | sed -n '40,96p'
Repository: tile-ai/tilelang
Length of output: 2746
🏁 Script executed:
# Search for where Gemm is called to understand args construction
rg -n "Gemm\(" src/op/gemm.cc -B3 -A10 | head -80
Repository: tile-ai/tilelang
Length of output: 1164
🏁 Script executed:
# Check the header file for any documentation about args requirements
fd -type f -name "gemm.h" -o -name "gemm.cc"
Repository: tile-ai/tilelang
Length of output: 230
🏁 Script executed:
# Look for args construction in related files
rg -n "args.*17|args.*18" src/ -C2
Repository: tile-ai/tilelang
Length of output: 885
Add bounds checking before accessing args[17] and args[18].
Lines 92–93 unconditionally access args[17] and args[18] without size validation. The code defensively guards args[14], args[15], and args[16] with if (args.size() > N) checks (lines 74–90), but skips this guard for indices 17–18, creating an out-of-bounds vulnerability if a caller provides fewer than 19 elements. A parallel implementation in gemm_sp_py.cc (line 83) correctly guards args[17], confirming this pattern should be applied here as well.
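The guard pattern requested above can be illustrated with a Python analogue of the C++ argument parsing. This is a sketch only: `parse_c_coords` and its `default` value are hypothetical, and the page does not say what `cCoords_` should be when the trailing arguments are absent.

```python
def parse_c_coords(args, default=(0, 0)):
    """Read the two trailing coordinate arguments only if they are present."""
    # A single length check covers both indices 17 and 18.
    if len(args) > 18:
        return (args[17], args[18])
    # Hypothetical fallback; the real default is not specified in the review.
    return default
```

In Python an unguarded `args[17]` would raise `IndexError`; in C++ the equivalent unchecked access is the out-of-bounds read the comment warns about, which is why the guard matters more there.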
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/gemm_cutedsl.py (1)
30-36: ⚠️ Potential issue | 🔴 Critical
Add TCGEN5MMA handling to `GemmCuTeDSL.infer_layout()` or prevent TCGEN5MMA from being inferred with CuTeDSL targets.
`GemmCuTeDSL.infer_layout()` currently only branches on `is_wgmma()`, routing all other instructions (including TCGEN5MMA) to `GemmMMA` layout inference. However, TCGEN5MMA can be inferred on SM100 targets with appropriate memory scopes, independent of the CuTeDSL backend selection. This causes TCGEN5MMA to silently use `GemmMMA`'s layout logic instead of `GemmTCGEN5`'s dedicated handling, producing incorrect layouts.
The `GemmCuTeDSL` docstring explicitly documents support for only "WGMMA/MMA" instructions. Either add explicit TCGEN5MMA and MFMA handling by importing `GemmTCGEN5`/`GemmMFMA` and delegating appropriately, or prevent TCGEN5MMA instruction selection when `is_cutedsl_target()` is true.
Converting this PR into a draft for now. InferInstructions may be helpful in the future, but right now it would introduce extra complexity.