From e25448cddca785a4e8a607ee61202ddc432eaa60 Mon Sep 17 00:00:00 2001 From: Yuan Qin Date: Sat, 6 Jun 2026 01:43:50 +0200 Subject: [PATCH 1/2] add changes --- .../HandshakeCombineSteeringLogic.cpp | 647 ++++++++++++++++-- 1 file changed, 596 insertions(+), 51 deletions(-) diff --git a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp index 7fbbc54dd4..165e8577a8 100644 --- a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp +++ b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp @@ -22,7 +22,10 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include +#include +#include // [START Boilerplate code for the MLIR pass] #include "experimental/Transforms/Passes.h" // IWYU pragma: keep @@ -37,8 +40,147 @@ namespace experimental { using namespace mlir; using namespace dynamatic; +static void logLine(const char *msg) { + std::ofstream f("/home/yuqin/dynamatic-scripts/TempOutputs/" + "HandshakeCombineSteeringLogic.txt", + std::ios::app); + f << msg << "\n"; +} + +static void inheritBB(Operation *from, Operation *to) { + if (auto bbAttr = from->getAttr("handshake.bb")) + to->setAttr("handshake.bb", bbAttr); +} + +static Location getConditionLocOrFallback(Value condition, + Operation *fallback) { + if (Operation *defOp = condition.getDefiningOp()) + return defOp->getLoc(); + return fallback->getLoc(); +} + +static void inheritConditionBBOrFallback(Value condition, Operation *fallback, + Operation *to) { + if (Operation *defOp = condition.getDefiningOp()) { + if (auto bbAttr = defOp->getAttr("handshake.bb")) { + to->setAttr("handshake.bb", bbAttr); + return; + } + } + inheritBB(fallback, to); +} + +static void +refreshBranchAttrsFromCondition(handshake::ConditionalBranchOp branchOp, + Operation *fallback) { + Value condition = branchOp.getConditionOperand(); + branchOp->setLoc(getConditionLocOrFallback(condition, fallback)); + inheritConditionBBOrFallback(condition, fallback, branchOp); +} + +static std::optional getHandshakeBB(Operation *op) { + if (auto bbAttr = op->getAttrOfType("handshake.bb")) + return bbAttr.getUInt(); + return std::nullopt; +} + +/// Returns true when `op` is assigned to a BB whose CFG edge exits a loop. In +/// the serialized CFG annotation, such a BB is a conditional source with one +/// successor going back to an earlier/same BB and the other going forward. +static bool isInLoopExitBB(Operation *op) { + std::optional bb = getHandshakeBB(op); + if (!bb) + return false; + + auto funcOp = op->getParentOfType(); + if (!funcOp) + return false; + + auto edgesAttr = funcOp->getAttrOfType("cfg.edges"); + if (!edgesAttr) + return false; + + SmallVector edges; + edgesAttr.getValue().split(edges, ']'); + for (StringRef edge : edges) { + size_t openBracket = edge.find('['); + if (openBracket == StringRef::npos) + continue; + + StringRef edgeBody = edge.drop_front(openBracket + 1); + SmallVector fields; + edgeBody.split(fields, ','); + if (fields.size() != 4) + continue; + + unsigned source, trueSucc, falseSucc; + if (fields[0].getAsInteger(10, source) || + fields[1].getAsInteger(10, trueSucc) || + fields[2].getAsInteger(10, falseSucc)) + continue; + + if (source != *bb) + continue; + + bool trueIsBackedge = trueSucc <= source; + bool falseIsBackedge = falseSucc <= source; + return trueIsBackedge != falseIsBackedge; + } + + return false; +} + namespace { +/// Check if two values are functionally equivalent: +/// - Same SSA value, OR +/// - Both are ConstantOps with the same attribute value, OR +/// - Both are NotIOps whose inputs are themselves equivalent (recursive) +static bool areEquivalentValues(Value a, Value b) { + if (a == b) + return true; + + if (a.getType() != b.getType()) + return false; + + Operation *defA = a.getDefiningOp(); + Operation *defB = b.getDefiningOp(); + if (!defA || !defB) + return false; + + if (auto constA = dyn_cast(defA)) { + if (auto constB = dyn_cast(defB)) + return constA.getValueAttr() == constB.getValueAttr(); + return false; + } + + if (auto notA = dyn_cast(defA)) { + if (auto notB = dyn_cast(defB)) + return areEquivalentValues(notA.getOperand(), notB.getOperand()); + return false; + } + + return false; +} + +static FailureOr +getSingleConstantOperandIndex(handshake::MergeOp mergeOp) { + int constIdx = -1; + for (int i = 0; i < 2; i++) { + if (!isa_and_nonnull( + mergeOp.getDataOperands()[i].getDefiningOp())) + continue; + + if (constIdx != -1) + return failure(); + constIdx = i; + } + + if (constIdx == -1) + return failure(); + return constIdx; +} + /// Combine redundant init merges. These merges have one constant input and a /// condition input. If two merges are identical, then one of them can be /// removed @@ -51,33 +193,41 @@ struct CombineInits : public OpRewritePattern { if (mergeOp->getNumOperands() != 2) return failure(); - // One of the inputs of the merge must be a constants - int constIdx = -1; - for (int i = 0; i < 2; i++) { - if (isa_and_nonnull( - mergeOp.getDataOperands()[i].getDefiningOp())) - constIdx = i; - } - - if (constIdx == -1) + // Exactly one of the inputs of the merge must be a constant. + FailureOr maybeConstIdx = getSingleConstantOperandIndex(mergeOp); + if (failed(maybeConstIdx)) return failure(); + int constIdx = *maybeConstIdx; // Get the index of the other input int loopIdx = 1 - constIdx; - // If there are other merges fed from the same input at the loopIdx - DenseSet redundantInits; - for (auto *user : mergeOp.getDataOperands()[loopIdx].getUsers()) - if (isa_and_nonnull(user) && user != mergeOp) { - handshake::MergeOp mergeUser = cast(user); - if (isa_and_nonnull( - mergeUser.getDataOperands()[constIdx].getDefiningOp())) - redundantInits.insert(mergeUser); - } + SmallVector redundantInits; + mergeOp->getParentRegion()->walk([&](handshake::MergeOp mergeUser) { + if (mergeUser == mergeOp) + return; + if (mergeUser->getNumOperands() != 2) + return; + + FailureOr maybeUserConstIdx = + getSingleConstantOperandIndex(mergeUser); + if (failed(maybeUserConstIdx) || *maybeUserConstIdx != constIdx) + return; + + if (!areEquivalentValues(mergeUser.getDataOperands()[constIdx], + mergeOp.getDataOperands()[constIdx])) + return; + if (!areEquivalentValues(mergeUser.getDataOperands()[loopIdx], + mergeOp.getDataOperands()[loopIdx])) + return; + + redundantInits.push_back(mergeUser); + }); if (redundantInits.empty()) return failure(); + logLine("[HandshakeCombineSteeringLogic] CombineInits applied"); for (auto init : redundantInits) { rewriter.replaceAllUsesWith(init.getResult(), mergeOp.getResult()); rewriter.eraseOp(init); @@ -87,6 +237,50 @@ struct CombineInits : public OpRewritePattern { } }; +/// Combine NotIOps that have functionally identical inputs. +struct CombineEquivalentNotIOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(handshake::NotIOp notOp, + PatternRewriter &rewriter) const override { + SmallVector redundant; + + notOp->getParentRegion()->walk([&](handshake::NotIOp otherNot) { + if (otherNot == notOp) + return; + if (!areEquivalentValues(otherNot.getOperand(), notOp.getOperand())) + return; + redundant.push_back(otherNot); + }); + + if (redundant.empty()) + return failure(); + + logLine("[HandshakeCombineSteeringLogic] CombineEquivalentNotIOps applied"); + for (auto notUser : redundant) { + rewriter.replaceAllUsesWith(notUser.getResult(), notOp.getResult()); + rewriter.eraseOp(notUser); + } + + return success(); + } +}; + +/// Remove back-to-back NotIOps. +struct RemoveDoubleNotIOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(handshake::NotIOp notOp, + PatternRewriter &rewriter) const override { + auto innerNot = + dyn_cast_or_null(notOp.getOperand().getDefiningOp()); + if (!innerNot) + return failure(); + + logLine("[HandshakeCombineSteeringLogic] RemoveDoubleNotIOp applied"); + rewriter.replaceOp(notOp, innerNot.getOperand()); + return success(); + } +}; + /// Returns true if the loop under analysis has a self regenerating mux. One /// input of the mux comes from the mux itself, while the other input comes from /// somewhere else. @@ -213,11 +407,12 @@ struct CombineMuxes : public OpRewritePattern { if (redundantMuxes.empty()) return failure(); + logLine("[HandshakeCombineSteeringLogic] CombineMuxes applied"); // Loop over redundantMuxes and replace the users of them with the output of // muxOp Note that the users of all redundantMuxes include the Branches // forming cycles with each of them, but as we erase the redundantMuxes, // these Branches will have their two outputs feeding nothing and will be - // erased using the RemoveDoubleSinkBranches + // erased using the RemoveUnusedOp for (auto mux : redundantMuxes) { rewriter.replaceAllUsesWith(mux.getResult(), muxOp.getResult()); rewriter.eraseOp(mux); @@ -227,39 +422,102 @@ struct CombineMuxes : public OpRewritePattern { } }; -/// Remove muxes that have no successors -struct RemoveSinkMuxes : public OpRewritePattern { +/// Combine MuxOps that have functionally identical inputs. +struct CombineEquivalentMuxes : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(handshake::MuxOp muxOp, PatternRewriter &rewriter) const override { - // The pattern fails if the Mux has any successors - if (!muxOp.getResult().getUsers().empty()) + if (muxOp.getNumOperands() != 3) + return failure(); + + SmallVector redundant; + + muxOp->getParentRegion()->walk([&](handshake::MuxOp otherMux) { + if (otherMux == muxOp) + return; + if (otherMux.getNumOperands() != 3) + return; + if (!areEquivalentValues(otherMux.getSelectOperand(), + muxOp.getSelectOperand())) + return; + if (!areEquivalentValues(otherMux.getDataOperands()[0], + muxOp.getDataOperands()[0])) + return; + if (!areEquivalentValues(otherMux.getDataOperands()[1], + muxOp.getDataOperands()[1])) + return; + redundant.push_back(otherMux); + }); + + if (redundant.empty()) return failure(); - rewriter.eraseOp(muxOp); + logLine("[HandshakeCombineSteeringLogic] CombineEquivalentMuxes applied\n"); + for (auto mux : redundant) { + rewriter.replaceAllUsesWith(mux.getResult(), muxOp.getResult()); + rewriter.eraseOp(mux); + } + return success(); } }; -/// Remove conditional branches that have no successors -struct RemoveDoubleSinkBranches +/// Combine ConditionalBranchOps that have functionally identical inputs. +struct CombineEquivalentBranches : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(handshake::ConditionalBranchOp condBranchOp, PatternRewriter &rewriter) const override { - Value branchTrueResult = condBranchOp.getTrueResult(); - Value branchFalseResult = condBranchOp.getFalseResult(); - - // The pattern fails if the branch has either true or false successors - if (!branchTrueResult.getUsers().empty()) + SmallVector redundant; + + condBranchOp->getParentRegion()->walk( + [&](handshake::ConditionalBranchOp otherBr) { + if (otherBr == condBranchOp) + return; + if (!areEquivalentValues(otherBr.getConditionOperand(), + condBranchOp.getConditionOperand())) + return; + if (!areEquivalentValues(otherBr.getDataOperand(), + condBranchOp.getDataOperand())) + return; + redundant.push_back(otherBr); + }); + + if (redundant.empty()) return failure(); - if (!branchFalseResult.getUsers().empty()) - return failure(); + logLine( + "[HandshakeCombineSteeringLogic] CombineEquivalentBranches applied\n"); + for (auto br : redundant) { + rewriter.replaceAllUsesWith(br.getTrueResult(), + condBranchOp.getTrueResult()); + rewriter.replaceAllUsesWith(br.getFalseResult(), + condBranchOp.getFalseResult()); + rewriter.eraseOp(br); + } - rewriter.eraseOp(condBranchOp); + return success(); + } +}; + +/// Remove any op of type OpTy whose results are all unused. +template +struct RemoveUnusedOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // The pattern fails if the Op has any successors + for (auto result : op->getResults()) { + if (!result.use_empty()) + return failure(); + } + + logLine(("[HandshakeCombineSteeringLogic] RemoveUnusedOp<" + + std::string(OpTy::getOperationName()) + "> applied") + .c_str()); + rewriter.eraseOp(op); return success(); } }; @@ -317,6 +575,8 @@ struct CombineBranchesOppositeSign if (redundantBranches.empty()) return failure(); + logLine("[HandshakeCombineSteeringLogic] CombineBranchesOppositeSign " + "applied\n"); // Erase the redundant branch for (auto br : redundantBranches) { rewriter.replaceAllUsesWith(br.getFalseResult(), @@ -356,37 +616,314 @@ struct RemoveNotCondition rewriter.replaceAllUsesWith(condBranchOp.getFalseResult(), newBranch.getTrueResult()); - newBranch->setAttr("handshake.bb", condBranchOp->getAttr("handshake.bb")); + refreshBranchAttrsFromCondition(newBranch, condBranchOp); + rewriter.eraseOp(condBranchOp); + logLine("[HandshakeCombineSteeringLogic] RemoveNotCondition applied\n"); return success(); } }; -/// Remove branches with same data operands and same conditional operand -struct CombineBranchesSameSign +/// When a ConditionalBranch has cond == data (or they differ only by a NOT), +/// each output carries a known boolean. Replace the condition operand of any +/// downstream branch that uses these outputs as condition with a constant +/// 0 or 1, disconnecting the upstream branch's use. +struct SimplifyKnownConditionBranch : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(handshake::ConditionalBranchOp condBranchOp, PatternRewriter &rewriter) const override { + Value condOperand = condBranchOp.getConditionOperand(); Value dataOperand = condBranchOp.getDataOperand(); + + // Match three cases: + // 1) cond == data (direct) + // 2) cond == not(data) (inverted) + // 3) data == not(cond) (inverted) + bool inverted = false; + if (condOperand == dataOperand) { + inverted = false; + } else { + Operation *condDef = condOperand.getDefiningOp(); + Operation *dataDef = dataOperand.getDefiningOp(); + if (isa_and_nonnull(condDef) && + condDef->getOperand(0) == dataOperand) { + inverted = true; + } else if (isa_and_nonnull(dataDef) && + dataDef->getOperand(0) == condOperand) { + inverted = true; + } else { + return failure(); + } + } + + bool changed = false; + + // For a given output of the upstream branch, replace the condition + // of all downstream branches that use it as condition with a constant. + auto replaceDownstreamCond = [&](Value branchOutput, bool outputIsTrue) { + // Runtime boolean value carried by branchOutput: + // direct: true output -> 1, false output -> 0 + // inverted: true output -> 0, false output -> 1 + bool knownCondTrue = inverted ? !outputIsTrue : outputIsTrue; + + // Collect downstream branches using branchOutput as condition + SmallVector toSimplify; + for (auto *user : branchOutput.getUsers()) { + if (auto br = dyn_cast(user)) { + if (br.getConditionOperand() == branchOutput) + toSimplify.push_back(br); + } + } + + for (auto br : toSimplify) { + rewriter.setInsertionPoint(br); + + // Create source as trigger + auto sourceOp = rewriter.create(br.getLoc()); + if (auto bbAttr = br->getAttr("handshake.bb")) + sourceOp->setAttr("handshake.bb", bbAttr); + + // Build the i1 attribute + auto i1Type = rewriter.getIntegerType(1); + auto cstAttr = rewriter.getIntegerAttr(i1Type, knownCondTrue ? 1 : 0); + + // Check if the condition operand is channelified + Type condType = branchOutput.getType(); + handshake::ConstantOp constOp; + + if (auto channelType = dyn_cast(condType)) { + // Channelified: use 4-arg constructor (loc, resultType, attr, ctrl) + // matching the pattern from the existing codebase + constOp = rewriter.create( + br.getLoc(), channelType, cstAttr, sourceOp.getResult()); + } else { + // Raw i1: use 3-arg constructor (loc, attr, ctrl) + constOp = rewriter.create( + br.getLoc(), cstAttr, sourceOp.getResult()); + } + + if (auto bbAttr = br->getAttr("handshake.bb")) + constOp->setAttr("handshake.bb", bbAttr); + + // Replace condition operand of downstream branch + br->setOperand(0, constOp.getResult()); + refreshBranchAttrsFromCondition(br, br); + + changed = true; + } + }; + + replaceDownstreamCond(condBranchOp.getTrueResult(), /*outputIsTrue=*/true); + replaceDownstreamCond(condBranchOp.getFalseResult(), + /*outputIsTrue=*/false); + + if (changed) + logLine("[HandshakeCombineSteeringLogic] SimplifyKnownConditionBranch " + "applied\n"); + return changed ? success() : failure(); + } +}; + +/// Eliminate a ConditionalBranch whose condition is a constant. +/// Short-circuit: the always-taken output is replaced with the data operand, +/// and the branch (along with its feeding constant + source) is erased. +struct EliminateConstantCondBranch + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(handshake::ConditionalBranchOp condBranchOp, + PatternRewriter &rewriter) const override { + Value condOperand = condBranchOp.getConditionOperand(); + auto constOp = + dyn_cast_or_null(condOperand.getDefiningOp()); + if (!constOp) + return failure(); - auto redundantBranches = - findRedundantBranches(condOperand, dataOperand, condBranchOp); + auto constAttr = dyn_cast(constOp.getValueAttr()); + if (!constAttr) + return failure(); - // Nothing to erase - if (redundantBranches.empty()) + bool condIsTrue = constAttr.getValue().getBoolValue(); + + Value takenResult = condIsTrue ? condBranchOp.getTrueResult() + : condBranchOp.getFalseResult(); + Value notTakenResult = condIsTrue ? condBranchOp.getFalseResult() + : condBranchOp.getTrueResult(); + + // Only proceed when the never-taken side has no users + if (!notTakenResult.use_empty()) return failure(); - // Erase the redundant branch - for (auto br : redundantBranches) { - rewriter.replaceAllUsesWith(br.getTrueResult(), - condBranchOp.getTrueResult()); - rewriter.replaceAllUsesWith(br.getFalseResult(), - condBranchOp.getFalseResult()); - rewriter.eraseOp(br); + logLine("[HandshakeCombineSteeringLogic] EliminateConstantCondBranch " + "applied\n"); + // Short-circuit the always-taken side + rewriter.replaceAllUsesWith(takenResult, condBranchOp.getDataOperand()); + + // Erase the branch + rewriter.eraseOp(condBranchOp); + + // Clean up the constant + source if they have no other users + if (constOp.getResult().use_empty()) { + Value trigger = constOp.getCtrl(); + rewriter.eraseOp(constOp); + if (auto sourceOp = + dyn_cast_or_null(trigger.getDefiningOp())) { + if (sourceOp.getResult().use_empty()) + rewriter.eraseOp(sourceOp); + } + } + return success(); + } +}; + +static void inheritBB(Operation *from, Operation *to) { + if (auto bbAttr = from->getAttr("handshake.bb")) + to->setAttr("handshake.bb", bbAttr); +} + +static Location getConditionLocOrFallback(Value condition, + Operation *fallback) { + if (Operation *defOp = condition.getDefiningOp()) + return defOp->getLoc(); + return fallback->getLoc(); +} + +static void inheritConditionBBOrFallback(Value condition, Operation *fallback, + Operation *to) { + if (Operation *defOp = condition.getDefiningOp()) { + if (auto bbAttr = defOp->getAttr("handshake.bb")) { + to->setAttr("handshake.bb", bbAttr); + return; + } + } + inheritBB(fallback, to); +} + +/// Match: +/// br_mux : cond_br (mux %c [d0, d1]), %data +/// br_base : cond_br %c, %data +/// +/// Rewrite br_mux into: +/// br_outer : cond_br %c, %data +/// br_inner : cond_br %x, +/// +/// The newly created outer branch is intentionally left structurally identical +/// to br_base so that CombineEquivalentBranches can fold them afterwards. +/// This rewrite is only valid when: +/// - exactly one mux input is a constant +/// - branch output emptiness matches the constant value: +/// const 1 => true unused, false used +/// const 0 => false unused, true used +/// - if the constant is mux input 0, innerBranch is fed from outer.false +/// - if the constant is mux input 1, innerBranch is fed from outer.true +struct SplitBranchWithMuxCondition + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(handshake::ConditionalBranchOp condBranchOp, + PatternRewriter &rewriter) const override { + + auto muxOp = dyn_cast_or_null( + condBranchOp.getConditionOperand().getDefiningOp()); + if (!muxOp || muxOp.getNumOperands() != 3) + return failure(); + + int constIdx = -1; + int nonConstIdx = -1; + handshake::ConstantOp constOp; + for (int idx = 0; idx < 2; ++idx) { + if (auto candidate = dyn_cast_or_null( + muxOp.getDataOperands()[idx].getDefiningOp())) { + // Checking that we have not found constants before + if (constIdx != -1) + return failure(); + constIdx = idx; + constOp = candidate; + } else { + if (nonConstIdx != -1) + return failure(); + nonConstIdx = idx; + } + } + + if (constIdx == -1 || nonConstIdx == -1) + return failure(); + + auto constAttr = dyn_cast(constOp.getValueAttr()); + if (!constAttr) + return failure(); + bool constValue = constAttr.getValue().getBoolValue(); + + Value baseCond = muxOp.getSelectOperand(); + Value dataOperand = condBranchOp.getDataOperand(); + + // Keep the rewrite profitable: the outer branch should be mergeable with an + // already existing branch on the same data and condition. For loop-exit + // BBs, split anyway to expose the simple exit condition. + bool loopExitBB = isInLoopExitBB(condBranchOp); + if (!loopExitBB && + findRedundantBranches(baseCond, dataOperand, condBranchOp).empty()) + return failure(); + + bool trueEmpty = condBranchOp.getTrueResult().use_empty(); + bool falseEmpty = condBranchOp.getFalseResult().use_empty(); + // Exactly one of the two outputs of the branch must be empty + if (trueEmpty == falseEmpty) + return failure(); + + // The empty output should be consistent with the value of the constant + if (constValue) { + if (!trueEmpty || falseEmpty) + return failure(); + } else { + if (trueEmpty || !falseEmpty) + return failure(); } + + Value nestedCond = muxOp.getDataOperands()[nonConstIdx]; + + rewriter.setInsertionPoint(condBranchOp); + + auto outerBranch = rewriter.create( + getConditionLocOrFallback(baseCond, condBranchOp), baseCond, + dataOperand); + inheritConditionBBOrFallback(baseCond, condBranchOp, outerBranch); + + Value outerToInner = nonConstIdx == 0 ? outerBranch.getFalseResult() + : outerBranch.getTrueResult(); + + auto innerBranch = rewriter.create( + getConditionLocOrFallback(nestedCond, condBranchOp), nestedCond, + outerToInner); + inheritConditionBBOrFallback(nestedCond, condBranchOp, innerBranch); + + logLine(loopExitBB + ? "[HandshakeCombineSteeringLogic] " + "SplitBranchWithMuxCondition applied on loop exit BB\n" + : "[HandshakeCombineSteeringLogic] SplitBranchWithMuxCondition " + "applied\n"); + rewriter.replaceOp(condBranchOp, {innerBranch.getTrueResult(), + innerBranch.getFalseResult()}); + return success(); + } +}; + +struct EliminateMuxWithIdenticalInputs + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(handshake::MuxOp muxOp, + PatternRewriter &rewriter) const override { + ValueRange dataOperands = muxOp.getDataOperands(); + if (dataOperands.size() != 2) + return failure(); + + if (dataOperands[0] != dataOperands[1]) + return failure(); + + rewriter.replaceOp(muxOp, dataOperands[0]); + return success(); } }; @@ -403,9 +940,17 @@ struct HandshakeCombineSteeringLogicPass config.useTopDownTraversal = true; config.enableRegionSimplification = false; RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns.add, + RemoveUnusedOp, + RemoveUnusedOp, + RemoveUnusedOp, + RemoveUnusedOp, SplitBranchWithMuxCondition, + CombineBranchesOppositeSign, RemoveDoubleNotIOp, + CombineEquivalentNotIOps, CombineInits, CombineMuxes, + RemoveNotCondition, SimplifyKnownConditionBranch, + EliminateConstantCondBranch, CombineEquivalentMuxes, + CombineEquivalentBranches, EliminateMuxWithIdenticalInputs>( + ctx); if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) return signalPassFailure(); }; From b99d5375422dd8e1dfeaad3e47d179873b004b3d Mon Sep 17 00:00:00 2001 From: Yuan Qin Date: Sat, 6 Jun 2026 02:34:33 +0200 Subject: [PATCH 2/2] clean up --- .../HandshakeCombineSteeringLogic.cpp | 74 +++---------------- 1 file changed, 9 insertions(+), 65 deletions(-) diff --git a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp index 165e8577a8..bfb42437c7 100644 --- a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp +++ b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file implements the pass which simplify the resulting FTD circuit by -// merging units which have the smae inputs and the same outputs. +// This file implements the pass which simplifies the resulting FTD circuit by +// merging units which have the same inputs and the same outputs. // //===----------------------------------------------------------------------===// @@ -21,10 +21,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include -#include #include // [START Boilerplate code for the MLIR pass] @@ -40,13 +37,6 @@ namespace experimental { using namespace mlir; using namespace dynamatic; -static void logLine(const char *msg) { - std::ofstream f("/home/yuqin/dynamatic-scripts/TempOutputs/" - "HandshakeCombineSteeringLogic.txt", - std::ios::app); - f << msg << "\n"; -} - static void inheritBB(Operation *from, Operation *to) { if (auto bbAttr = from->getAttr("handshake.bb")) to->setAttr("handshake.bb", bbAttr); @@ -227,7 +217,6 @@ struct CombineInits : public OpRewritePattern { if (redundantInits.empty()) return failure(); - logLine("[HandshakeCombineSteeringLogic] CombineInits applied"); for (auto init : redundantInits) { rewriter.replaceAllUsesWith(init.getResult(), mergeOp.getResult()); rewriter.eraseOp(init); @@ -255,7 +244,6 @@ struct CombineEquivalentNotIOps : public OpRewritePattern { if (redundant.empty()) return failure(); - logLine("[HandshakeCombineSteeringLogic] CombineEquivalentNotIOps applied"); for (auto notUser : redundant) { rewriter.replaceAllUsesWith(notUser.getResult(), notOp.getResult()); rewriter.eraseOp(notUser); @@ -275,7 +263,6 @@ struct RemoveDoubleNotIOp : public OpRewritePattern { if (!innerNot) return failure(); - logLine("[HandshakeCombineSteeringLogic] RemoveDoubleNotIOp applied"); rewriter.replaceOp(notOp, innerNot.getOperand()); return success(); } @@ -300,7 +287,6 @@ bool isSelfRegenerateMux(handshake::MuxOp muxOp, int &muxCycleInputIdx) { // cycle bool foundCycle = false; int operIdx = 0; - handshake::ConditionalBranchOp condBranchOp; for (auto muxOperand : muxOp.getDataOperands()) { auto *op = muxOperand.getDefiningOp(); @@ -309,7 +295,6 @@ bool isSelfRegenerateMux(handshake::MuxOp muxOp, int &muxCycleInputIdx) { if (branches.contains(br)) { foundCycle = true; muxCycleInputIdx = operIdx; - condBranchOp = br; break; } } @@ -355,8 +340,8 @@ Operation *returnMuxAtSameDepth(Operation *op, // conventions about the index of the input coming from outside the loop and // that coming from inside through a cycle // This pattern combines all Muxes that are used to regenerate the same value -// but to different consumers.. It searches for a Mux that has a bwd edge -// (cyclic input) and searches for all Muxes using the some condition and also +// but to different consumers. It searches for a Mux that has a bwd edge +// (cyclic input) and searches for all Muxes using the same condition and also // having a bwd edge struct CombineMuxes : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -381,7 +366,7 @@ struct CombineMuxes : public OpRewritePattern { // traversal and return its produced value Value valProducedByNonMux = returnNonMuxProducerVal(muxOp, muxOutIdx); - // Get users of the non-Mux operation at the muxOuterInputIdx + // Get users of the non-Mux operation at muxOutIdx for (auto *dataUser : valProducedByNonMux.getUsers()) { Operation *returnedMux = returnMuxAtSameDepth(dataUser, muxOp); if (returnedMux != nullptr) { @@ -407,7 +392,6 @@ struct CombineMuxes : public OpRewritePattern { if (redundantMuxes.empty()) return failure(); - logLine("[HandshakeCombineSteeringLogic] CombineMuxes applied"); // Loop over redundantMuxes and replace the users of them with the output of // muxOp Note that the users of all redundantMuxes include the Branches // forming cycles with each of them, but as we erase the redundantMuxes, @@ -453,7 +437,6 @@ struct CombineEquivalentMuxes : public OpRewritePattern { if (redundant.empty()) return failure(); - logLine("[HandshakeCombineSteeringLogic] CombineEquivalentMuxes applied\n"); for (auto mux : redundant) { rewriter.replaceAllUsesWith(mux.getResult(), muxOp.getResult()); rewriter.eraseOp(mux); @@ -488,8 +471,6 @@ struct CombineEquivalentBranches if (redundant.empty()) return failure(); - logLine( - "[HandshakeCombineSteeringLogic] CombineEquivalentBranches applied\n"); for (auto br : redundant) { rewriter.replaceAllUsesWith(br.getTrueResult(), condBranchOp.getTrueResult()); @@ -514,9 +495,6 @@ struct RemoveUnusedOp : public OpRewritePattern { return failure(); } - logLine(("[HandshakeCombineSteeringLogic] RemoveUnusedOp<" + - std::string(OpTy::getOperationName()) + "> applied") - .c_str()); rewriter.eraseOp(op); return success(); } @@ -575,8 +553,6 @@ struct CombineBranchesOppositeSign if (redundantBranches.empty()) return failure(); - logLine("[HandshakeCombineSteeringLogic] CombineBranchesOppositeSign " - "applied\n"); // Erase the redundant branch for (auto br : redundantBranches) { rewriter.replaceAllUsesWith(br.getFalseResult(), @@ -590,7 +566,9 @@ struct CombineBranchesOppositeSign } }; -/// Remove branches with same data operands and same conditional operand +/// If a branch's condition is a NotIOp, rewrite it into an equivalent branch +/// driven directly by the NOT's input, with the true/false outputs swapped. +/// This drops the NOT from the condition path. struct RemoveNotCondition : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -619,7 +597,6 @@ struct RemoveNotCondition refreshBranchAttrsFromCondition(newBranch, condBranchOp); rewriter.eraseOp(condBranchOp); - logLine("[HandshakeCombineSteeringLogic] RemoveNotCondition applied\n"); return success(); } }; @@ -719,9 +696,6 @@ struct SimplifyKnownConditionBranch replaceDownstreamCond(condBranchOp.getFalseResult(), /*outputIsTrue=*/false); - if (changed) - logLine("[HandshakeCombineSteeringLogic] SimplifyKnownConditionBranch " - "applied\n"); return changed ? success() : failure(); } }; @@ -756,8 +730,6 @@ struct EliminateConstantCondBranch if (!notTakenResult.use_empty()) return failure(); - logLine("[HandshakeCombineSteeringLogic] EliminateConstantCondBranch " - "applied\n"); // Short-circuit the always-taken side rewriter.replaceAllUsesWith(takenResult, condBranchOp.getDataOperand()); @@ -778,29 +750,6 @@ struct EliminateConstantCondBranch } }; -static void inheritBB(Operation *from, Operation *to) { - if (auto bbAttr = from->getAttr("handshake.bb")) - to->setAttr("handshake.bb", bbAttr); -} - -static Location getConditionLocOrFallback(Value condition, - Operation *fallback) { - if (Operation *defOp = condition.getDefiningOp()) - return defOp->getLoc(); - return fallback->getLoc(); -} - -static void inheritConditionBBOrFallback(Value condition, Operation *fallback, - Operation *to) { - if (Operation *defOp = condition.getDefiningOp()) { - if (auto bbAttr = defOp->getAttr("handshake.bb")) { - to->setAttr("handshake.bb", bbAttr); - return; - } - } - inheritBB(fallback, to); -} - /// Match: /// br_mux : cond_br (mux %c [d0, d1]), %data /// br_base : cond_br %c, %data @@ -898,11 +847,6 @@ struct SplitBranchWithMuxCondition outerToInner); inheritConditionBBOrFallback(nestedCond, condBranchOp, innerBranch); - logLine(loopExitBB - ? "[HandshakeCombineSteeringLogic] " - "SplitBranchWithMuxCondition applied on loop exit BB\n" - : "[HandshakeCombineSteeringLogic] SplitBranchWithMuxCondition " - "applied\n"); rewriter.replaceOp(condBranchOp, {innerBranch.getTrueResult(), innerBranch.getFalseResult()}); return success(); @@ -928,7 +872,7 @@ struct EliminateMuxWithIdenticalInputs } }; -/// Simple driver for the Handshake Combine Branches Merges pass, based on a +/// Simple driver for the Handshake Combine Steering Logic pass, based on a /// greedy pattern rewriter. struct HandshakeCombineSteeringLogicPass : public dynamatic::experimental::impl::HandshakeCombineSteeringLogicBase<