Skip to content

Commit e66ed8d

Browse files
committed
symbolizing composite PrimExpr
1 parent 2265bd1 commit e66ed8d

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

src/relax/transform/canonicalize_bindings.cc

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,74 @@ class SymbolicVarCanonicalizer : public ExprMutator {
134134
return output;
135135
}
136136

137+
Expr VisitExpr_(const ShapeExprNode* op) override {
138+
// For each dimension, check if it is a composite expression that symbolization
139+
ffi::Array<PrimExpr> new_values;
140+
bool changed = false;
141+
142+
for (const auto& dim : op->values) {
143+
PrimExpr new_dim = VisitPrimExpr(dim);
144+
145+
// Check if this is a composite expression (not a constant or simple variable)
146+
if (IsCompositePrimExpr(new_dim)) {
147+
// Introduce a new symbolic variable for this composite expression
148+
tir::Var symbolic_var = CreateSymbolicVar(new_dim);
149+
new_values.push_back(symbolic_var);
150+
changed = true;
151+
} else {
152+
new_values.push_back(new_dim);
153+
if (!new_dim.same_as(dim)) {
154+
changed = true;
155+
}
156+
}
157+
}
158+
159+
if (!changed) {
160+
return ffi::GetRef<Expr>(op);
161+
}
162+
163+
return ShapeExpr(new_values);
164+
}
165+
137166
private:
138167
struct KnownValue {
139168
PrimExpr expr;
140169
MatchCast source;
141170
};
142171

172+
bool IsCompositePrimExpr(const PrimExpr& expr) {
173+
// Constants and simple variables are not composite
174+
if (expr->IsInstance<tir::IntImmNode>() || expr->IsInstance<tir::FloatImmNode>() ||
175+
expr->IsInstance<tir::VarNode>()) {
176+
return false;
177+
}
178+
179+
// Check if the expression contains variables
180+
auto vars = tir::UndefinedVars(expr);
181+
182+
// If it has variables, it's composite (e.g., x * y, x + 1, etc.)
183+
return vars.size() >= 1;
184+
}
185+
186+
tir::Var CreateSymbolicVar(const PrimExpr& expr) {
187+
tir::Var symbolic_var("composite_" + std::to_string(composite_counter_++), expr->dtype);
188+
189+
// Create PrimValue for the composite expression
190+
PrimValue prim_val(expr);
191+
PrimStructInfo prim_sinfo(symbolic_var);
192+
Var relax_var("comp_val_" + std::to_string(composite_counter_ - 1), prim_sinfo);
193+
194+
// Emit MatchCast to define the symbolic variable
195+
auto match_cast = MatchCast(relax_var, prim_val, prim_sinfo);
196+
builder_->Emit(match_cast);
197+
198+
known_values_[symbolic_var] = KnownValue{expr, match_cast};
199+
200+
return symbolic_var;
201+
}
202+
143203
std::unordered_map<tir::Var, KnownValue> known_values_;
204+
int composite_counter_ = 0;
144205
};
145206

146207
struct CanonicalizationPlan {

0 commit comments

Comments
 (0)