Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,6 @@ tvm-site/

# GDB history file
.gdb_history

# Less command history file
.lesshst
97 changes: 86 additions & 11 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,12 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
* \brief Helper to fuse epilogue block into reduction block
* Analyzes epilogue pattern and transforms reduction init/update
*/
// Epilogue type enumeration
enum class EpilogueType {
Bias, // temp + C
BiasReLU, // max(temp + C, 0)
};

class ReductionEpilogueFuser : public BaseInliner {
public:
explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block,
Expand All @@ -996,7 +1002,19 @@ class ReductionEpilogueFuser : public BaseInliner {
: BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref),
reduction_block_(reduction_block),
epilogue_block_(epilogue_block_realize->block.get()),
mod_(mod) {}
mod_(mod),
epilogue_type_(EpilogueType::Bias) {
// Disable opaque access check for epilogue fusion
// Epilogue blocks can read multiple buffers (temp + bias), which is allowed
has_opaque_access = false;
}

// Override CheckOpaqueAccess to allow multiple buffer reads
void CheckOpaqueAccess(const VarNode* buffer_var) {
// For epilogue fusion, we allow multiple buffer reads (temp + bias)
// So we don't check for opaque access
// BaseInliner::CheckOpaqueAccess(buffer_var); // Don't call base class
}

bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);

Expand All @@ -1013,18 +1031,21 @@ class ReductionEpilogueFuser : public BaseInliner {
const BufferStoreNode* from) {
struct Extractor : public ExprVisitor {
void VisitExpr_(const BufferLoadNode* load) final {
if (load->buffer.get() == buffer) {
if (load->buffer.same_as(buffer)) {
result.push_back(load);
}
// Continue visiting child nodes (indices)
ExprVisitor::VisitExpr_(load);
}
const BufferNode* buffer;
Buffer buffer;
std::vector<const BufferLoadNode*> result;
} extractor;
extractor.buffer = buffer.get();
extractor.buffer = buffer;
// Visit indices first (though they typically don't contain BufferLoad)
for (const PrimExpr& expr : from->indices) {
extractor(expr);
}
// Visit the value expression (e.g., max(temp + C, 0) for ReLU)
extractor(from->value);
return std::move(extractor.result);
}
Expand All @@ -1038,6 +1059,7 @@ class ReductionEpilogueFuser : public BaseInliner {
BufferRegion epilogue_output_region_{nullptr}; // Write region of D
Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C
BufferRegion epilogue_addend_region_{nullptr}; // Read region of C
EpilogueType epilogue_type_; // Type of epilogue operation
};

bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) {
Expand Down Expand Up @@ -1079,7 +1101,7 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
}

bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
// Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
// Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
if (const auto* add = value.as<AddNode>()) {
const auto* load_a = add->a.as<BufferLoadNode>();
const auto* load_b = add->b.as<BufferLoadNode>();
Expand All @@ -1090,10 +1112,40 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
// Ensure exactly one operand is from the reduction buffer
if (a_is_target != b_is_target) {
epilogue_addend_ = a_is_target ? add->b : add->a;
epilogue_type_ = EpilogueType::Bias;
return true;
}
}

// Pattern 2: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
if (const auto* max_node = value.as<MaxNode>()) {
// Check if second operand is zero (ReLU: max(x, 0))
// Support both integer and float zero constants
bool is_zero_const = false;
if (tir::is_zero(max_node->b)) {
is_zero_const = true;
} else if (const auto* float_imm = max_node->b.as<FloatImmNode>()) {
is_zero_const = (float_imm->value == 0.0);
}
Comment on lines +1124 to +1129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tir::is_zero function handles both integer and floating-point zero constants, so the else if condition checking for FloatImmNode is redundant. This check can be simplified.

Suggested change
bool is_zero_const = false;
if (tir::is_zero(max_node->b)) {
is_zero_const = true;
} else if (const auto* float_imm = max_node->b.as<FloatImmNode>()) {
is_zero_const = (float_imm->value == 0.0);
}
bool is_zero_const = tir::is_zero(max_node->b);

if (is_zero_const) {
// Check if first operand is AddNode
if (const auto* add = max_node->a.as<AddNode>()) {
const auto* load_a = add->a.as<BufferLoadNode>();
const auto* load_b = add->b.as<BufferLoadNode>();

bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);

// Ensure exactly one operand is from the reduction buffer
if (a_is_target != b_is_target) {
epilogue_addend_ = a_is_target ? add->b : add->a;
epilogue_type_ = EpilogueType::BiasReLU;
return true;
}
}
}
}

return false;
}

Expand Down Expand Up @@ -1160,20 +1212,40 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}

// 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
Substitute(epilogue_output_indices_, var_map));
// 2. Change init to epilogue value based on epilogue type
BufferStore new_init_store;
if (epilogue_type_ == EpilogueType::BiasReLU) {
// For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics
PrimExpr init_value = Substitute(epilogue_addend_, var_map);
PrimExpr zero = tir::make_zero(init_value.dtype());
new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero),
Substitute(epilogue_output_indices_, var_map));
} else {
// Bias: D[vi, vj] = C[vi, vj]
new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
Substitute(epilogue_output_indices_, var_map));
}
new_block->init = new_init_store;

// 3. Replace output buffer from temp to D in body
class BufferReplacer : public StmtExprMutator {
public:
BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {}
BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype)
: old_buffer_(old_buf),
new_buffer_(new_buf),
epilogue_type_(epilogue_type),
dtype_(dtype) {}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (store->buffer.same_as(old_buffer_)) {
return BufferStore(new_buffer_, store->value, store->indices);
PrimExpr new_value = store->value;
// For ReLU, apply max per iteration to match per-iteration ReLU semantics
if (epilogue_type_ == EpilogueType::BiasReLU) {
PrimExpr zero = tir::make_zero(dtype_);
new_value = Max(new_value, zero);
}
return BufferStore(new_buffer_, new_value, store->indices);
}
return store;
}
Expand All @@ -1189,9 +1261,12 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
private:
Buffer old_buffer_;
Buffer new_buffer_;
EpilogueType epilogue_type_;
DataType dtype_;
};

BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
DataType dtype = epilogue_output_buffer_->dtype;
BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype);
new_block->body = replacer(reduction_block->body);

// 4. Update write regions
Expand Down
Loading