
[Transform] Split instruction inference from LowerTileOp to InferInstructions#1808

Draft
Rachmanino wants to merge 7 commits into tile-ai:main from Rachmanino:infer-inst

Conversation

@Rachmanino
Collaborator

@Rachmanino Rachmanino commented Feb 7, 2026

Summary by CodeRabbit

  • New Features

    • Added an optional annotations parameter to public GEMM APIs and a new InferInstructions transform that injects inferred instruction annotations into call sites.
  • Refactor

    • Lowering and layout inference now consume instruction annotations instead of selecting instructions dynamically; GEMM operators expose annotations to callers.
  • Tests

    • Added CUDA GEMM tests validating annotation propagation, MMA vs WGMMA codegen, and numerical correctness.
  • Docs

    • Updated GEMM docstrings to document the new annotations parameter.

@github-actions

github-actions bot commented Feb 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Feb 7, 2026

Walkthrough

Adds annotation-driven GEMM instruction inference: a new TIR pass infers and injects an "instruction" annotation into Gemm calls; Gemm/GemmPy nodes and Python tileop layers store and consume that annotation during layout inference and lowering.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| C++ GEMM nodes: `src/op/gemm.h`, `src/op/gemm.cc`, `src/op/gemm_py.h`, `src/op/gemm_py.cc` | Add public annotations_ map to Gemm/GemmPy nodes and store constructor-provided annotations; Lower/InferLayout now read annotations_["instruction"] (ICHECK guards). getGemmInst visibility changed to public and annotations exposed via reflection. |
| TIR transform pass: `src/transform/common/gemm.h`, `src/transform/infer_instructions.cc` | Add TryParseGemmNode utility and new InferInstructions pass that reads thread block size, infers GemmInst when missing, injects "instruction" into call annotations, and registers the pass as tl.transform.InferInstructions. |
| Python GEMM API & pipeline: `tilelang/language/gemm_op.py`, `tilelang/transform/__init__.py`, `tilelang/engine/phase.py` | Add optional annotations parameter to public GEMM APIs and thread it into tir.call_intrin; expose InferInstructions() pass and insert it into the LowerAndLegalize pipeline. |
| TileOp Python internals: `tilelang/tileop/gemm/__init__.py`, `tilelang/tileop/gemm/gemm_base.py`, `tilelang/tileop/gemm/gemm_cutedsl.py`, `tilelang/tileop/gemm/inst.py` | Expose annotations property on GemmPy/GemmBase; remove dynamic _select_gemm_instruction and add _get_inferred_gemm_instruction() that asserts "instruction" presence and returns GemmInst; add GemmInst.__str__. |
| Tests: `testing/python/transform/test_tilelang_transform_infer_instructions.py` | Add CUDA-enabled test that exercises InferInstructions and annotation propagation, asserting generated kernel uses MMA vs WGMMA per annotation and validating numeric correctness. |
| Headers & packaging: `src/transform/common/gemm.h`, `manifest_file`, `requirements.txt`, `pyproject.toml`, `setup.py` | New header for Gemm parsing utilities and packaging/manifest metadata updates. |

Sequence Diagram(s)

sequenceDiagram
    participant PythonLayer as Python Layer
    participant TIR as TIR IR
    participant InferPass as InferInstructionsPass
    participant GemmNode as Gemm/GemmPy Node
    participant Lowering as Lowering

    PythonLayer->>TIR: Emit TIR with Gemm call (annotations optional)
    TIR->>InferPass: Run InferInstructions pass
    InferPass->>InferPass: Read thread block size (AttrStmt threadIdx.x)
    InferPass->>InferPass: TryParseGemmNode(call) -> GemmNodeVariant
    InferPass->>InferPass: Compute GemmInst from block_size + target (if missing)
    InferPass->>TIR: Inject "instruction" annotation into Gemm call
    TIR->>GemmNode: Lower / InferLayout invoked
    GemmNode->>GemmNode: Read GemmInst from annotations_["instruction"]
    GemmNode->>Lowering: Use annotated GemmInst for warp partitioning & lowering
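
The flow in the diagram can be sketched in plain Python. This is only an illustration of the "skip if annotated, otherwise infer and inject" logic: the dict-based call records, the enum values, and the selection rule in `infer_gemm_inst` are all assumptions made for the example and are not TileLang's actual data structures or heuristic.

```python
from enum import IntEnum


class GemmInst(IntEnum):
    # Hypothetical values; the real enum is defined in the TileLang sources.
    MMA = 0
    WGMMA = 1


def infer_gemm_inst(block_size: int, target: str) -> GemmInst:
    # Placeholder rule for illustration only, not the project's real heuristic.
    if target == "sm90" and block_size % 128 == 0:
        return GemmInst.WGMMA
    return GemmInst.MMA


def infer_instructions(calls, block_size: int, target: str):
    """Return new call records with an "instruction" annotation filled in."""
    result = []
    for call in calls:
        # Copy the annotations so the input records are never mutated.
        annotations = dict(call.get("annotations", {}))
        if call["op"] == "gemm" and "instruction" not in annotations:
            # Inference needs the thread block size, as in the pass.
            if block_size <= 0:
                raise ValueError("thread block size must be known first")
            annotations["instruction"] = int(infer_gemm_inst(block_size, target))
        result.append({**call, "annotations": annotations})
    return result


calls = [
    {"op": "gemm", "annotations": {}},                                   # inferred
    {"op": "gemm", "annotations": {"instruction": int(GemmInst.MMA)}},  # user-set
    {"op": "copy", "annotations": {}},                                   # ignored
]
annotated = infer_instructions(calls, block_size=128, target="sm90")
```

A user-supplied annotation is left untouched, and non-GEMM calls pass through unchanged, which mirrors the behaviour the walkthrough describes.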

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through code both neat and bright,

I hid an instruction in the night,
From Python seed to TIR's clear stream,
The gemm now knows its chosen scheme,
Hooray — annotations take flight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 47.83% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The pull request title accurately summarizes the main change: splitting instruction inference into a separate InferInstructions pass rather than having it embedded in LowerTileOp, which aligns with the substantial refactoring across multiple files to extract and centralize this logic. |


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/__init__.py (1)

158-167: ⚠️ Potential issue | 🟡 Minor

Unreachable duplicate branch: is_tcgen5mma() is checked twice.

Line 160 handles is_tcgen5mma() and returns GemmTCGEN5. The second check at line 164 is identical and can never be reached. This appears to be a copy-paste leftover — the raise NotImplementedError is dead code.

Proposed fix — remove the dead branch
         elif gemm_inst.is_mfma():
             return GemmMFMA
-        elif gemm_inst.is_tcgen5mma():
-            raise NotImplementedError("TCGEN5MMA is not implemented")
         else:
             raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
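
A minimal sketch of why the second branch is dead code. `FakeInst` and the string return values are hypothetical stand-ins; the only assumption carried over from the review is that `is_tcgen5mma()` returns the same answer each time it is called.

```python
class FakeInst:
    """Hypothetical stand-in exposing only the predicates used below."""

    def is_tcgen5mma(self):
        return True

    def is_mfma(self):
        return False


def select(inst):
    if inst.is_tcgen5mma():
        return "GemmTCGEN5"
    elif inst.is_mfma():
        return "GemmMFMA"
    elif inst.is_tcgen5mma():
        # Same predicate as the first branch, so control never gets here.
        raise NotImplementedError("TCGEN5MMA is not implemented")
    else:
        raise ValueError(f"Unsupported GEMM instruction: {inst}")


result = select(FakeInst())  # the first branch wins; no exception is raised
```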
🤖 Fix all issues with AI agents
In `@src/op/gemm_py.h`:
- Around line 70-71: Change the reflection key on GemmPyNode so the annotations
map is exposed under the same name as GemmNode: replace the registration that
uses "ann" with "annotations" (the property registered via
GemmPyNode::annotations_) so Python accessors (e.g., GemmBase.annotations which
reads getattr(self.gemm_node, "annotations", {})) can find it; update the
.def_ro call on GemmPyNode that currently binds policy_ and annotations_ to use
"annotations" instead of "ann".

In `@src/op/gemm.h`:
- Around line 137-139: Fix the typo in the comment above the field annotations_:
change "Anotations" to "Annotations" so the comment reads "Annotations for the
GEMM operation"; update the nearby comment block that lists supported annotation
keys (mentioning "instruction" and GemmInst) if needed to keep wording
consistent with the corrected spelling.
- Around line 163-164: GemmPyNode registers the annotations field with the
reflection key "ann", causing Python accessors to miss it; change the
registration in GemmPyNode where it exposes annotations_ (the .def_ro call for
the annotations_ member) to use the key "annotations" so it matches GemmNode's
.def_ro("annotations", &GemmNode::annotations_) and other node types, ensuring
getattr(self.gemm_node, "annotations", {}) returns the actual annotations.

In `@src/transform/common/gemm.h`:
- Around line 25-33: The header defines TryParseGemmNode (returning
GemmNodeVariant and using TileOperator, ParseOperator, CallNode, GemmNode,
GemmPyNode) which can cause ODR violations when included by multiple translation
units; fix by either marking the function inline in the header (add the inline
keyword to the TryParseGemmNode function signature) or move the full definition
into a .cc file and leave only a declaration in the header so that only one
translation unit contains the definition.

In `@tilelang/language/gemm_op.py`:
- Line 219: The top-level docstring for the GEMM operator is written as an
f-string (f""") which prevents it from being recognized as the function/module
docstring and keeps gemm.__doc__ unset; change the leading f""" to a plain
triple-quoted string """ so the string becomes the actual docstring for the GEMM
operator (the docstring appearing near the gemm function/class in gemm_op.py) —
no interpolation needed, just remove the f prefix.

In `@tilelang/tileop/gemm/gemm_base.py`:
- Around line 134-136: The annotations property on GemmBase uses
getattr(self.gemm_node, "annotations", {}) which silently returns {} for
GemmPyNode because that class registers its data under "ann"; update the
property to handle both names (check "annotations" first then fall back to
"ann") or normalize the gemm_node metadata at construction so both GemmPyNode
and GemmNode expose the same key; locate the annotations property on the
GemmBase class (method annotations) and adjust it to read
getattr(self.gemm_node, "annotations", getattr(self.gemm_node, "ann", {})) or
perform equivalent normalization to ensure consistent behavior.
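
The docstring point raised for `tilelang/language/gemm_op.py` is easy to check in isolation. The two functions below are made up for the demonstration; the behaviour shown is standard Python, where an f-string is never treated as a docstring even when it contains no placeholders.

```python
def with_f_prefix():
    f"""Written as an f-string, so it is just an expression statement."""
    return None


def with_plain_string():
    """Written as a plain string, so it becomes the docstring."""
    return None


f_doc = with_f_prefix.__doc__          # None: no docstring was recorded
plain_doc = with_plain_string.__doc__  # the string literal above
```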
🧹 Nitpick comments (3)
src/transform/common/gemm.h (1)

11-11: Avoid using namespace directives in header files.

using namespace tir; in a header pollutes the namespace of every translation unit that includes this header, potentially causing name collisions. Move it inside the function body, or qualify names explicitly (e.g., tir::Call).

src/op/gemm.cc (1)

444-446: Consider extracting the repeated annotation-to-GemmInst conversion into a helper.

The identical pattern appears in both Lower (lines 444-446) and InferLayout (lines 608-610):

ICHECK(annotations_.count("instruction")) << "Gemm instruction is not inferred";
GemmInst gemm_inst =
    static_cast<GemmInst>(Downcast<IntImm>(annotations_.at("instruction"))->value);

A small helper (e.g., GemmInst GemmNode::getInferredInst() const) would reduce duplication and centralize the annotation key.

Also applies to: 608-610

src/transform/infer_instructions.cc (1)

47-74: Clean instruction inference logic with a couple of minor items.

The overall approach — parse GEMM, skip if already annotated, infer and annotate — is solid. Two small notes:

  1. Line 68: Remove the commented-out LOG(INFO) debug artifact.
  2. Lines 48-72: The indentation of the inner block is slightly off (6 spaces instead of the 4 used in the rest of the class). This is cosmetic but inconsistent.
Proposed cleanup
   Stmt VisitStmt_(const EvaluateNode *op) final {
     if (const auto *call = op->value.as<CallNode>()) {
-        // Only handle Gemm operator for now
-        // Copy instruction inference requires layout information, which is not available yet
-        auto gemm_node = TryParseGemmNode(*call);
-        if (!gemm_node.has_value()) {
-          return StmtExprMutator::VisitStmt_(op);
-        }
-        if (call->annotations.count("instruction")) {
-          // Instruction is specified by user, skip inference
-          return StmtExprMutator::VisitStmt_(op);
-        }
-        Map<String, ObjectRef> new_annotations = call->annotations;
-        ICHECK(thread_block_size_ > 0)
-            << "Thread block size not set, ensure AttrStmt with thread_extent is visited first";
-        GemmInst gemm_inst = std::visit(
-            [this](const auto *node) {
-              return node->getGemmInst(thread_block_size_, target_);
-            },
-            *gemm_node);
-        new_annotations.Set("instruction", Integer(static_cast<int>(gemm_inst)));
-        // LOG(INFO) << "Inferred GEMM instruction: " << static_cast<int>(gemm_inst);
-        Call new_call =
-            Call(call->dtype, call->op, call->args, new_annotations, call->span);
-        return Evaluate(new_call);
+      // Only handle Gemm operator for now
+      // Copy instruction inference requires layout information, which is not available yet
+      auto gemm_node = TryParseGemmNode(*call);
+      if (!gemm_node.has_value()) {
+        return StmtExprMutator::VisitStmt_(op);
+      }
+      if (call->annotations.count("instruction")) {
+        // Instruction is specified by user, skip inference
+        return StmtExprMutator::VisitStmt_(op);
+      }
+      Map<String, ObjectRef> new_annotations = call->annotations;
+      ICHECK(thread_block_size_ > 0)
+          << "Thread block size not set, ensure AttrStmt with thread_extent is visited first";
+      GemmInst gemm_inst = std::visit(
+          [this](const auto *node) {
+            return node->getGemmInst(thread_block_size_, target_);
+          },
+          *gemm_node);
+      new_annotations.Set("instruction", Integer(static_cast<int>(gemm_inst)));
+      Call new_call =
+          Call(call->dtype, call->op, call->args, new_annotations, call->span);
+      return Evaluate(new_call);
     }
     return StmtExprMutator::VisitStmt_(op);
   }

- Introduced a new test file to validate the functionality of the GEMM kernel, ensuring correct behavior for both MMA and WGMMA configurations.
- Implemented a helper function to detect the use of MMA in the generated kernel source.
- Added parameterized tests to check the output against reference results using PyTorch.
- Enhanced the GemmInst class with a string representation method for better debugging.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/__init__.py (1)

159-164: ⚠️ Potential issue | 🟡 Minor

Dead code: is_tcgen5mma() is checked twice — second branch is unreachable.

Line 159 already handles is_tcgen5mma() by returning GemmTCGEN5. The elif on line 163 re-checks the same condition and raises NotImplementedError, but this branch can never be reached.

📝 Proposed fix — remove the dead branch
         elif gemm_inst.is_mfma():
             return GemmMFMA
-        elif gemm_inst.is_tcgen5mma():
-            raise NotImplementedError("TCGEN5MMA is not implemented")
         else:
             raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
🤖 Fix all issues with AI agents
In `@src/op/gemm.cc`:
- Around line 92-94: The code unconditionally reads args[17] and args[18] to set
node->cCoords_ which can OOB; add bounds checks like the existing guards for
args[14]/args[15]/args[16] so you only access args[17] and args[18] when
args.size() > 17 and args.size() > 18 respectively (or a single check for
args.size() > 18) before calling args[17].as<PrimExpr>().value() and
args[18].as<PrimExpr>().value(); leave node->annotations_ assignment unchanged.
Ensure you follow the same pattern used in gemm_sp_py.cc for consistency.
🧹 Nitpick comments (7)
tilelang/tileop/gemm/gemm_cutedsl.py (1)

65-68: _get_inferred_gemm_instruction is duplicated in GemmPy (__init__.py lines 129-133).

Both GemmCuTeDSL and GemmPy define identical _get_inferred_gemm_instruction() methods. Consider moving this to GemmBase to avoid duplication.

tilelang/tileop/gemm/__init__.py (1)

113-115: Workaround property for reflection key mismatch.

This annotations property exists solely because GemmPyNode exposes its annotations as "ann" instead of "annotations". If the reflection key is aligned (as suggested in src/op/gemm_py.h), this alias can be removed.
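
To illustrate the mismatch, here is a small sketch with two hypothetical node classes, one exposing `annotations` and one exposing `ann`. It only shows how `getattr` with a default hides the missing attribute; it does not use the real reflection machinery.

```python
class NodeWithAnnotations:
    # Hypothetical node that registers the field as "annotations".
    annotations = {"instruction": 1}


class NodeWithAnn:
    # Hypothetical node that registers the same data as "ann".
    ann = {"instruction": 1}


def read_naive(node):
    # Silently yields {} when the attribute name does not match.
    return getattr(node, "annotations", {})


def read_with_fallback(node):
    # Checks "annotations" first, then falls back to "ann".
    return getattr(node, "annotations", getattr(node, "ann", {}))


naive = read_naive(NodeWithAnn())             # {} (the data is lost)
fallback = read_with_fallback(NodeWithAnn())  # {"instruction": 1}
```

Aligning the reflection key removes the need for either the fallback or the alias property.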

tilelang/language/gemm_op.py (1)

239-239: Long line with many positional arguments.

This single-line call with 11 positional arguments is hard to read and fragile: a positional mix-up would be a silent bug. Consider splitting across multiple lines or using keyword arguments.

♻️ Suggested formatting
-    return impl(A, B, C, transpose_A, transpose_B, policy, clear_accum, k_pack, wg_wait, mbar, annotations)
+    return impl(
+        A, B, C,
+        transpose_A=transpose_A,
+        transpose_B=transpose_B,
+        policy=policy,
+        clear_accum=clear_accum,
+        k_pack=k_pack,
+        wg_wait=wg_wait,
+        mbar=mbar,
+        annotations=annotations,
+    )
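
A short sketch of the failure mode the suggestion guards against. The `impl` function here is a hypothetical stub with made-up defaults that reuses the parameter names from the diff; it is not the real implementation.

```python
def impl(A, B, C, transpose_A=False, transpose_B=False, policy=None,
         clear_accum=False, k_pack=1, wg_wait=0, mbar=None, annotations=None):
    # Stub: report the two parameters that are mixed up below.
    return {"clear_accum": clear_accum, "k_pack": k_pack}


# Positional call with two neighbouring arguments accidentally swapped.
# Python accepts it without complaint.
swapped = impl("A", "B", "C", False, False, None, 2, True, 0, None, None)

# Keyword call: argument order no longer matters.
by_keyword = impl("A", "B", "C", k_pack=2, clear_accum=True)
```

With keywords, a misspelled name also fails loudly with a `TypeError` instead of binding to the wrong parameter.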
src/op/gemm.cc (1)

444-447: Consider extracting the repeated annotation-to-GemmInst logic into a helper.

The same 4-line pattern (ICHECK + Downcast + static_cast) is duplicated in both Lower and InferLayout. A small helper method on GemmNode (e.g., getAnnotatedGemmInst()) would reduce duplication and centralize validation.

♻️ Example helper in gemm.h
GemmInst getAnnotatedGemmInst() const {
  ICHECK(annotations_.count("instruction"))
      << "Gemm instruction is not inferred";
  return static_cast<GemmInst>(
      Downcast<IntImm>(annotations_.at("instruction"))->value);
}

Then in both Lower and InferLayout:

-  ICHECK(annotations_.count("instruction"))
-      << "Gemm instruction is not inferred";
-  GemmInst gemm_inst = static_cast<GemmInst>(
-      Downcast<IntImm>(annotations_.at("instruction"))->value);
+  GemmInst gemm_inst = getAnnotatedGemmInst();

Also applies to: 609-612

testing/python/transform/test_tilelang_transform_infer_instructions.py (1)

10-24: Detection heuristic is fragile and could silently produce false negatives.

_kernel_uses_mma returns False when no gemm_ss match is found (line 22), silently treating "cannot determine" the same as "not MMA." If the codegen template name ever changes (e.g., gemm_rs, gemm_sr), the test would silently pass for the wrong reason.

Consider raising or logging when the regex doesn't match at all, so a broken heuristic is caught early:

♻️ Proposed improvement
     match = re.search(r"gemm_ss\s*<\s*([^>]+)\s*>", kernel_src)
     if not match:
-        return False
+        raise RuntimeError(
+            "Could not find gemm_ss<...> in kernel source; "
+            "detection heuristic needs updating"
+        )
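
As a self-contained sketch of the stricter behaviour, the helper below uses the same regular expression and raises when nothing matches. The function name and the sample kernel strings are invented for the example, and how the captured template arguments map to MMA or WGMMA is deliberately left out because the review does not show it.

```python
import re

GEMM_SS_PATTERN = re.compile(r"gemm_ss\s*<\s*([^>]+)\s*>")


def gemm_ss_template_args(kernel_src: str) -> str:
    """Return the text between gemm_ss< and >, or fail loudly."""
    match = GEMM_SS_PATTERN.search(kernel_src)
    if not match:
        raise RuntimeError(
            "Could not find gemm_ss<...> in kernel source; "
            "detection heuristic needs updating"
        )
    return match.group(1).strip()


# Made-up kernel snippet, used only to exercise the helper.
sample = "tl::gemm_ss<64, 64, 32, false>(A, B, C);"
args = gemm_ss_template_args(sample)
```

A kernel that only contains a differently named template such as `gemm_rs<...>` now raises instead of being counted as "not MMA".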
src/transform/infer_instructions.cc (2)

48-78: Instruction inference logic is clean and correctly preserves user-provided annotations.

The flow — skip if annotation exists, otherwise infer via getGemmInst and inject — is sound. The std::visit dispatch for the GemmNodeVariant is a nice pattern.

One nit: Lines 71-72 contain a commented-out LOG(INFO) statement. Consider removing it before merge to keep the code clean.

🧹 Remove debug artifact
       new_annotations.Set("instruction", Integer(static_cast<int>(gemm_inst)));
-      // LOG(INFO) << "Inferred GEMM instruction: " <<
-      // static_cast<int>(gemm_inst);
       Call new_call =

86-91: Minor: [=] capture is unnecessary.

The lambda captures nothing from the enclosing scope (InferInstructions() is a free function with no locals captured). Using [] or [](PrimFunc f, ...) would be more precise, though this has no functional impact.

♻️ Suggested change
-  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
+  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {

Comment on lines 92 to +94
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
node->annotations_ = annotations;
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate the gemm.cc file
fd -type f -name "gemm.cc"

Repository: tile-ai/tilelang

Length of output: 230


🏁 Script executed:

# Once found, examine the specific lines and surrounding context
cat -n src/op/gemm.cc | sed -n '80,110p'

Repository: tile-ai/tilelang

Length of output: 1195


🏁 Script executed:

# Search for size guards and args bounds checking patterns in the function
rg -n "args\.size()|args\[1[4-9]\]|args\[2[0-9]\]" src/op/gemm.cc -B5 -A5

Repository: tile-ai/tilelang

Length of output: 1173


🏁 Script executed:

# Check the function signature to understand args parameter
rg -n "void.*\(.*args.*\)" src/op/gemm.cc -A2 | head -50

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for guards on args[14], args[15], args[16] to compare
rg -n "args.size()" src/op/gemm.cc -B3 -A3

Repository: tile-ai/tilelang

Length of output: 735


🏁 Script executed:

# Get more context - find the function name
cat -n src/op/gemm.cc | sed -n '40,96p'

Repository: tile-ai/tilelang

Length of output: 2746


🏁 Script executed:

# Search for where Gemm is called to understand args construction
rg -n "Gemm\(" src/op/gemm.cc -B3 -A10 | head -80

Repository: tile-ai/tilelang

Length of output: 1164


🏁 Script executed:

# Check the header file for any documentation about args requirements
fd -type f -name "gemm.h" -o -name "gemm.cc"

Repository: tile-ai/tilelang

Length of output: 230


🏁 Script executed:

# Look for args construction in related files
rg -n "args.*17|args.*18" src/ -C2

Repository: tile-ai/tilelang

Length of output: 885


Add bounds checking before accessing args[17] and args[18].

Lines 92–93 unconditionally access args[17] and args[18] without size validation. The code defensively guards args[14], args[15], and args[16] with if (args.size() > N) checks (lines 74–90), but skips this guard for indices 17–18, creating an out-of-bounds vulnerability if a caller provides fewer than 19 elements. A parallel implementation in gemm_sp_py.cc (line 83) correctly guards args[17], confirming this pattern should be applied here as well.
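
The guard pattern can be shown with a Python analogue. The function name and the `(0, 0)` fallback are assumptions made for this sketch; the review does not say what `cCoords_` should default to when the trailing arguments are absent.

```python
def parse_c_coords(args, default=(0, 0)):
    """Read the optional coordinates at indices 17 and 18 of an argument list.

    The default value is hypothetical and only stands in for whatever
    fallback the real node would use.
    """
    # A single size check covers both indices, since 18 is the larger one.
    if len(args) > 18:
        return (args[17], args[18])
    return default


full = parse_c_coords(list(range(19)))   # both indices present
short = parse_c_coords(list(range(17)))  # too short, falls back to the default
```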


@Rachmanino Rachmanino requested review from LeiWang1999 and removed request for LeiWang1999 February 7, 2026 10:09
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/gemm_cutedsl.py (1)

30-36: ⚠️ Potential issue | 🔴 Critical

Add TCGEN5MMA handling to GemmCuTeDSL.infer_layout() or prevent TCGEN5MMA from being inferred with CuTeDSL targets.

GemmCuTeDSL.infer_layout() currently only branches on is_wgmma(), routing all other instructions (including TCGEN5MMA) to GemmMMA layout inference. However, TCGEN5MMA can be inferred on SM100 targets with appropriate memory scopes, independent of the CuTeDSL backend selection. This causes TCGEN5MMA to silently use GemmMMA's layout logic instead of GemmTCGEN5's dedicated handling, producing incorrect layouts.

The GemmCuTeDSL docstring explicitly documents support for only "WGMMA/MMA" instructions. Either add explicit TCGEN5MMA and MFMA handling by importing GemmTCGEN5/GemmMFMA and delegating appropriately, or prevent TCGEN5MMA instruction selection when is_cutedsl_target() is true.
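
One way to make the unsupported case explicit is to dispatch on every instruction the backend claims to support and reject the rest. The sketch below assumes a hypothetical enum and returns placeholder strings rather than real layout classes, so it shows only the shape of the dispatch.

```python
from enum import Enum


class GemmInst(Enum):
    # Hypothetical members mirroring the instruction names in the review.
    MMA = "mma"
    WGMMA = "wgmma"
    TCGEN5MMA = "tcgen5mma"
    MFMA = "mfma"


def pick_layout_impl(inst: GemmInst) -> str:
    """Choose a layout implementation, refusing anything not handled."""
    if inst is GemmInst.WGMMA:
        return "wgmma-layout"  # placeholder for the WGMMA path
    if inst is GemmInst.MMA:
        return "mma-layout"    # placeholder for the MMA path
    # No silent fallthrough to the MMA path for other instructions.
    raise NotImplementedError(f"{inst.name} is not supported by this backend")


chosen = pick_layout_impl(GemmInst.WGMMA)
```

Raising here turns a silently wrong layout into an immediate error, whichever of the two remedies in the comment is eventually chosen.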

@LeiWang1999 LeiWang1999 marked this pull request as draft February 24, 2026 06:12
@LeiWang1999
Member

Converting this PR to a draft for now. InferInstructions may be helpful in the future, but right now it would introduce extra complexity.
