Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,6 @@ tvm-site/

# GDB history file
.gdb_history

# Less command history file
.lesshst
31 changes: 29 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,14 +2356,41 @@ def fuse_reduction_epilogue(
It requires:
1) The reduction block is a complete reduction block
2) The epilogue block only reads from the reduction block's output
3) The epilogue performs a simple addition: output = reduction_result + bias
3) The epilogue matches one of the supported patterns:
- Bias: ``output = reduction_result + bias``
- BiasReLU: ``output = max(reduction_result + bias, 0)``
- Clipping: ``output = min(max(reduction_result, lower), upper)``
or their commutative variants

.. warning::

**Semantic Change for Non-Linear Epilogues (BiasReLU, Clipping):**

For non-linear epilogues (BiasReLU and Clipping), fusion changes the
computation semantics from post-reduction application to per-iteration
application. This can lead to different numerical results.

**Example with Clipping to [-5, 5] and inputs [6, -2]:**

- **Post-reduction clipping** (original): ``clip(sum([6, -2])) = clip(4) = 4``
- **Per-iteration clipping** (fused): ``acc=0 → clip(0+6)=5 → clip(5+(-2))=3``

The fused version applies clipping at each reduction iteration, which
may be an intended optimization for some models but can cause unexpected
correctness issues if users are not aware of this behavior.

For linear epilogues (Bias), fusion preserves exact numerical equivalence.

Parameters
----------
reduction_block : Union[BlockRV, str]
The reduction block (e.g., matmul)
epilogue_block : Union[BlockRV, str]
The epilogue block to be fused (e.g., bias add)
The epilogue block to be fused (e.g., bias add, ReLU, clipping)

Examples
--------
See :py:func:`test_tir_schedule_fuse_reduction_epilogue` for examples.
"""
reduction_block = self._normalize_block_arg(reduction_block)
epilogue_block = self._normalize_block_arg(epilogue_block)
Expand Down
224 changes: 209 additions & 15 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -988,14 +988,33 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
* \brief Helper to fuse epilogue block into reduction block
* Analyzes epilogue pattern and transforms reduction init/update
*/
/*! \brief Shape of the epilogue expression that is fused into the reduction block. */
enum class EpilogueType {
  /*! \brief Plain addition of an addend: temp + C */
  Bias,
  /*! \brief Addition followed by a ReLU: max(temp + C, 0) */
  BiasReLU,
  /*! \brief Clamp into a range: min(max(temp, lower), upper) */
  Clipping,
};

class ReductionEpilogueFuser : public BaseInliner {
public:
explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block,
const BlockRealize& epilogue_block_realize,
const StmtSRef& scope_root_sref)
: BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref),
reduction_block_(reduction_block),
epilogue_block_(epilogue_block_realize->block.get()) {}
epilogue_block_(epilogue_block_realize->block.get()),
epilogue_type_(EpilogueType::Bias) {
// Disable opaque access check for epilogue fusion
// Epilogue blocks can read multiple buffers (temp + bias), which is allowed
has_opaque_access = false;
}

/*!
 * \brief Intentionally empty stand-in for the opaque-access check.
 * \param buffer_var Variable that would be inspected; it is not used here.
 * \note NOTE(review): this declaration carries no `override` specifier, so it
 *       only replaces the BaseInliner check if that method is declared virtual.
 *       If it is not, this merely hides the base method and calls issued from
 *       inside BaseInliner still reach the original — confirm against the
 *       BaseInliner declaration.
 */
void CheckOpaqueAccess(const VarNode* buffer_var) {
// Epilogue blocks may read several buffers (temp + bias), so no opaque-access
// check is performed here and BaseInliner::CheckOpaqueAccess is deliberately
// not called.
}

bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);

Expand All @@ -1012,18 +1031,21 @@ class ReductionEpilogueFuser : public BaseInliner {
const BufferStoreNode* from) {
struct Extractor : public ExprVisitor {
void VisitExpr_(const BufferLoadNode* load) final {
if (load->buffer.get() == buffer) {
if (load->buffer.same_as(buffer)) {
result.push_back(load);
}
// Continue visiting child nodes (indices)
ExprVisitor::VisitExpr_(load);
}
const BufferNode* buffer;
Buffer buffer;
std::vector<const BufferLoadNode*> result;
} extractor;
extractor.buffer = buffer.get();
extractor.buffer = buffer;
// Visit indices first (though they typically don't contain BufferLoad)
for (const PrimExpr& expr : from->indices) {
extractor(expr);
}
// Visit the value expression (e.g., max(temp + C, 0) for ReLU)
extractor(from->value);
return std::move(extractor.result);
}
Expand All @@ -1036,6 +1058,9 @@ class ReductionEpilogueFuser : public BaseInliner {
BufferRegion epilogue_output_region_{nullptr}; // Write region of D
Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C
BufferRegion epilogue_addend_region_{nullptr}; // Read region of C
EpilogueType epilogue_type_; // Type of epilogue operation
PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping
PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping
};

bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) {
Expand All @@ -1058,26 +1083,36 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
return false;
}

// 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
// 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
// D[i,j] = min(max(temp[i,j], lower), upper)
if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
// Failure: epilogue is not a simple addition pattern
// Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping)
return false;
}

// 5. Verify temp appears exactly once in the epilogue pattern
// This ensures correctness for all supported patterns (Bias, BiasReLU, Clipping)
// The reduction result buffer must be used exactly once in the epilogue expression
if (loads.size() != 1) {
// Failure: The reduction result (temp) must be used exactly once in the
// epilogue expression for fusion.
return false;
}

// 5. Check if producer is a reduction block
// 6. Check if producer is a reduction block
if (!IsReductionBlock(reduction_block_)) {
// Failure: producer is not a reduction block
return false;
}

// 6. Extract epilogue information (output buffer, indices, regions, etc.)
// 7. Extract epilogue information (output buffer, indices, regions, etc.)
ExtractEpilogueInfo();

return true;
}

bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
// Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
// Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
if (const auto* add = value.as<AddNode>()) {
const auto* load_a = add->a.as<BufferLoadNode>();
const auto* load_b = add->b.as<BufferLoadNode>();
Expand All @@ -1088,10 +1123,125 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
// Ensure exactly one operand is from the reduction buffer
if (a_is_target != b_is_target) {
epilogue_addend_ = a_is_target ? add->b : add->a;
epilogue_type_ = EpilogueType::Bias;
return true;
}
}

// Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping)
// Handle all commutative variants of min/max at each level.

// Given the two operands of a commutative min/max, report whether one of them
// is a direct load from the reduction buffer. On a match the remaining operand
// is written to `*other`; operand `a` is tested before `b`. On failure `*other`
// is left untouched.
auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b,
                                             PrimExpr* other) -> bool {
  const auto reads_reduction_buffer = [this](const PrimExpr& operand) -> bool {
    const auto* load = operand.as<BufferLoadNode>();
    return load != nullptr && load->buffer.same_as(inlined_buffer_);
  };
  if (reads_reduction_buffer(a)) {
    *other = b;
    return true;
  }
  if (reads_reduction_buffer(b)) {
    *other = a;
    return true;
  }
  return false;
};

// Check for min(max(temp, lower), upper) and commutative variants
if (const auto* min_node = value.as<MinNode>()) {
const MaxNode* max_node = nullptr;
PrimExpr upper;
// Try both (a, b) as possible positions of the inner max
if ((max_node = min_node->a.as<MaxNode>())) {
upper = min_node->b;
} else if ((max_node = min_node->b.as<MaxNode>())) {
upper = min_node->a;
}
if (max_node != nullptr) {
PrimExpr lower;
if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
clipping_lower_ = lower;
clipping_upper_ = upper;
epilogue_type_ = EpilogueType::Clipping;
return true;
}
}
}

// Check for max(min(temp[i,j], upper), lower) and commutative variants
if (const auto* max_node = value.as<MaxNode>()) {
const MinNode* min_node = nullptr;
PrimExpr lower;
// Try both (a, b) as possible positions of the inner min
if ((min_node = max_node->a.as<MinNode>())) {
lower = max_node->b;
} else if ((min_node = max_node->b.as<MinNode>())) {
lower = max_node->a;
}
if (min_node != nullptr) {
PrimExpr upper;
if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
clipping_lower_ = lower;
clipping_upper_ = upper;
epilogue_type_ = EpilogueType::Clipping;
return true;
}
}
}

// Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
// Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
if (const auto* max_node = value.as<MaxNode>()) {
// Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
// Support both integer and float zero constants.
const PrimExpr* add_candidate = nullptr;
bool is_zero_const = false;
// Accepts anything tir::is_zero recognizes, plus a FloatImm whose value
// compares equal to 0.0.
auto is_zero_expr = [](const PrimExpr& expr) -> bool {
  const auto* float_imm = expr.as<FloatImmNode>();
  const bool is_float_zero = float_imm != nullptr && float_imm->value == 0.0;
  return is_float_zero || tir::is_zero(expr);
};

if (is_zero_expr(max_node->a)) {
is_zero_const = true;
add_candidate = &max_node->b;
} else if (is_zero_expr(max_node->b)) {
is_zero_const = true;
add_candidate = &max_node->a;
}

if (is_zero_const && add_candidate != nullptr) {
if (const auto* add = add_candidate->as<AddNode>()) {
const auto* load_a = add->a.as<BufferLoadNode>();
const auto* load_b = add->b.as<BufferLoadNode>();

bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);

// Ensure exactly one operand is from the reduction buffer
if (a_is_target != b_is_target) {
epilogue_addend_ = a_is_target ? add->b : add->a;
epilogue_type_ = EpilogueType::BiasReLU;
return true;
}
} else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
// Handle bias-free ReLU: max(temp, 0) or max(0, temp)
if (load->buffer.same_as(inlined_buffer_)) {
epilogue_addend_ = tir::make_zero(load->dtype);
epilogue_type_ = EpilogueType::BiasReLU;
return true;
}
}
}
}

return false;
}

Expand Down Expand Up @@ -1158,20 +1308,54 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}

// 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
Substitute(epilogue_output_indices_, var_map));
// 2. Change init to epilogue value based on epilogue type
BufferStore new_init_store;
if (epilogue_type_ == EpilogueType::BiasReLU) {
// For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics
PrimExpr init_value = Substitute(epilogue_addend_, var_map);
PrimExpr zero = tir::make_zero(init_value.dtype());
new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero),
Substitute(epilogue_output_indices_, var_map));
} else if (epilogue_type_ == EpilogueType::Clipping) {
// For Clipping, init should be min(max(init_value, lower), upper)
// Since init is typically 0, this becomes min(max(0, lower), upper)
PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)),
Substitute(clipping_upper_, var_map));
new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
Substitute(epilogue_output_indices_, var_map));
} else {
// Bias: D[vi, vj] = C[vi, vj]
new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
Substitute(epilogue_output_indices_, var_map));
}
new_block->init = new_init_store;

// 3. Replace output buffer from temp to D in body
class BufferReplacer : public StmtExprMutator {
public:
BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {}
BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype,
PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr())
: old_buffer_(old_buf),
new_buffer_(new_buf),
epilogue_type_(epilogue_type),
dtype_(dtype),
clipping_lower_(clipping_lower),
clipping_upper_(clipping_upper) {}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (store->buffer.same_as(old_buffer_)) {
return BufferStore(new_buffer_, store->value, store->indices);
PrimExpr new_value = store->value;
// For ReLU, apply max per iteration to match per-iteration ReLU semantics
if (epilogue_type_ == EpilogueType::BiasReLU) {
PrimExpr zero = tir::make_zero(dtype_);
new_value = Max(new_value, zero);
} else if (epilogue_type_ == EpilogueType::Clipping) {
// For Clipping, apply min(max(value, lower), upper) per iteration
new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
}
return BufferStore(new_buffer_, new_value, store->indices);
}
return store;
}
Expand All @@ -1187,9 +1371,19 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
private:
Buffer old_buffer_;
Buffer new_buffer_;
EpilogueType epilogue_type_;
DataType dtype_;
PrimExpr clipping_lower_;
PrimExpr clipping_upper_;
};

BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
DataType dtype = epilogue_output_buffer_->dtype;
PrimExpr clipping_lower_subst =
epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr();
PrimExpr clipping_upper_subst =
epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr();
BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype,
clipping_lower_subst, clipping_upper_subst);
new_block->body = replacer(reduction_block->body);

// 4. Update write regions
Expand Down
Loading