@@ -134,13 +134,74 @@ class SymbolicVarCanonicalizer : public ExprMutator {
134134 return output;
135135 }
136136
137+ Expr VisitExpr_ (const ShapeExprNode* op) override {
138+ // For each dimension, check if it is a composite expression that symbolization
139+ ffi::Array<PrimExpr> new_values;
140+ bool changed = false ;
141+
142+ for (const auto & dim : op->values ) {
143+ PrimExpr new_dim = VisitPrimExpr (dim);
144+
145+ // Check if this is a composite expression (not a constant or simple variable)
146+ if (IsCompositePrimExpr (new_dim)) {
147+ // Introduce a new symbolic variable for this composite expression
148+ tir::Var symbolic_var = CreateSymbolicVar (new_dim);
149+ new_values.push_back (symbolic_var);
150+ changed = true ;
151+ } else {
152+ new_values.push_back (new_dim);
153+ if (!new_dim.same_as (dim)) {
154+ changed = true ;
155+ }
156+ }
157+ }
158+
159+ if (!changed) {
160+ return ffi::GetRef<Expr>(op);
161+ }
162+
163+ return ShapeExpr (new_values);
164+ }
165+
137166 private:
138167 struct KnownValue {
139168 PrimExpr expr;
140169 MatchCast source;
141170 };
142171
172+ bool IsCompositePrimExpr (const PrimExpr& expr) {
173+ // Constants and simple variables are not composite
174+ if (expr->IsInstance <tir::IntImmNode>() || expr->IsInstance <tir::FloatImmNode>() ||
175+ expr->IsInstance <tir::VarNode>()) {
176+ return false ;
177+ }
178+
179+ // Check if the expression contains variables
180+ auto vars = tir::UndefinedVars (expr);
181+
182+ // If it has variables, it's composite (e.g., x * y, x + 1, etc.)
183+ return vars.size () >= 1 ;
184+ }
185+
186+ tir::Var CreateSymbolicVar (const PrimExpr& expr) {
187+ tir::Var symbolic_var (" composite_" + std::to_string (composite_counter_++), expr->dtype );
188+
189+ // Create PrimValue for the composite expression
190+ PrimValue prim_val (expr);
191+ PrimStructInfo prim_sinfo (symbolic_var);
192+ Var relax_var (" comp_val_" + std::to_string (composite_counter_ - 1 ), prim_sinfo);
193+
194+ // Emit MatchCast to define the symbolic variable
195+ auto match_cast = MatchCast (relax_var, prim_val, prim_sinfo);
196+ builder_->Emit (match_cast);
197+
198+ known_values_[symbolic_var] = KnownValue{expr, match_cast};
199+
200+ return symbolic_var;
201+ }
202+
143203 std::unordered_map<tir::Var, KnownValue> known_values_;
204+ int composite_counter_ = 0 ;
144205};
145206
146207struct CanonicalizationPlan {
0 commit comments