diff --git a/src/AddAtomicMutex.cpp b/src/AddAtomicMutex.cpp index cf3b0ae8bb89..ef1259ab435e 100644 --- a/src/AddAtomicMutex.cpp +++ b/src/AddAtomicMutex.cpp @@ -13,278 +13,184 @@ namespace Internal { namespace { /** Collect names of all stores matching the producer name inside a statement. */ -class CollectProducerStoreNames : public IRVisitor { -public: - CollectProducerStoreNames(const std::string &producer_name) - : producer_name(producer_name) { - } - +Scope collect_producer_store_names(const Stmt &s, const std::string &producer_name) { Scope store_names; - -protected: - using IRVisitor::visit; - - void visit(const Store *op) override { - IRVisitor::visit(op); + visit_with(s, [&](auto *self, const Store *op) { + self->visit_base(op); if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { // This is a Store for the designated Producer. store_names.push(op->name); } - } - - const std::string &producer_name; -}; + }); + return store_names; +} /** Find Store inside of an Atomic node for the designated producer * and return their indices. */ -class FindProducerStoreIndex : public IRVisitor { -public: - FindProducerStoreIndex(const std::string &producer_name) - : producer_name(producer_name) { - } - - Expr index; // The returned index. - -protected: - using IRVisitor::visit; - - // Need to also extract the let bindings of a Store index. - void visit(const Let *op) override { - IRVisitor::visit(op); // Make sure we visit the Store first. - if (index.defined()) { - if (expr_uses_var(index, op->name)) { - index = Let::make(op->name, op->value, index); - } - } - } - void visit(const LetStmt *op) override { - IRVisitor::visit(op); // Make sure we visit the Store first. - if (index.defined()) { - if (expr_uses_var(index, op->name)) { - index = Let::make(op->name, op->value, index); - } - } - } - - void visit(const Store *op) override { - IRVisitor::visit(op); - if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { - // This is a Store for the designated producer. - - // Ideally we want to insert equal() checks here for different stores, - // but the indices of them actually are different in the case of tuples, - // since they usually refer to the strides/min/extents of their own tuple - // buffers. However, different elements in a tuple would have the same - // strides/min/extents so we are fine. +Expr find_producer_store_index(const Stmt &s, const std::string &producer_name) { + Expr index; + visit_with( + s, + // Need to also extract the let bindings of a Store index. + [&](auto *self, const Let *op) { + self->visit_base(op); // Make sure we visit the Store first. if (index.defined()) { - return; - } - index = op->index; - } - } - - const std::string &producer_name; -}; + if (expr_uses_var(index, op->name)) { + index = Let::make(op->name, op->value, index); + } + } // + }, + [&](auto *self, const LetStmt *op) { + self->visit_base(op); // Make sure we visit the Store first. + if (index.defined()) { + if (expr_uses_var(index, op->name)) { + index = Let::make(op->name, op->value, index); + } + } // + }, + [&](auto *self, const Store *op) { + self->visit_base(op); + if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { + // This is a Store for the designated producer. + + // Ideally we want to insert equal() checks here for different stores, + // but the indices of them actually are different in the case of tuples, + // since they usually refer to the strides/min/extents of their own tuple + // buffers. 
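For orientation, this patch's recurring move is to replace one-shot `IRVisitor` subclasses with the `visit_with` helper (defined in the `src/IRVisitor.h` hunk at the end of this diff). A minimal sketch of the pattern with a made-up helper name; `self->visit_base(op)` recurses into the node's children exactly as `IRVisitor::visit(op)` would in a hand-written subclass:

```cpp
// Sketch only: collect_load_names is hypothetical, but the calls mirror
// the visit_with API this patch uses throughout.
#include <set>
#include <string>

std::set<std::string> collect_load_names(const Stmt &s) {
    std::set<std::string> names;
    visit_with(s, [&](auto *self, const Load *op) {
        self->visit_base(op);     // visit index/predicate children first
        names.insert(op->name);   // then record this Load
    });
    return names;
}
```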
However, different elements in a tuple would have the same + // strides/min/extents so we are fine. + if (index.defined()) { + return; + } + index = op->index; + } // + }); + return index; +} /** Throws an assertion for cases where the indexing on left-hand-side of * an atomic update references to itself. * e.g. f(clamp(f(r), 0, 100)) = f(r) + 1 should be rejected. */ -class CheckAtomicValidity : public IRVisitor { -protected: - using IRVisitor::visit; - - void visit(const Atomic *op) override { +bool check_atomic_validity(const Stmt &s) { + bool any_atomic = false; + visit_with(s, [&](auto *self, const Atomic *op) { any_atomic = true; // Collect the names of all Store nodes inside. - CollectProducerStoreNames collector(op->producer_name); - op->body.accept(&collector); + Scope store_names = collect_producer_store_names(op->body, op->producer_name); // Find the indices from the Store nodes inside the body. - FindProducerStoreIndex find(op->producer_name); - op->body.accept(&find); - - Expr index = find.index; + Expr index = find_producer_store_index(op->body, op->producer_name); if (index.defined()) { - user_assert(!expr_uses_vars(index, collector.store_names)) + user_assert(!expr_uses_vars(index, store_names)) << "Can't use atomic() on an update where the index written " << "to depends on the current value of the Func\n"; } - op->body.accept(this); - } -public: - bool any_atomic = false; -}; + op->body.accept(self); + }); + return any_atomic; +} /** Search if the value of a Store node has a variable pointing to a let binding, * where the let binding contains the Store location. Use for checking whether * we need a mutex lock for Atomic since some lowering pass before lifted a let * binding from the Store node (currently only SplitTuple would do this). */ -class FindAtomicLetBindings : public IRVisitor { -public: - FindAtomicLetBindings(const Scope &store_names) - : store_names(store_names) { - } - +bool find_atomic_let_bindings(const Stmt &s, const Scope &store_names) { bool found = false; - -protected: - using IRVisitor::visit; - - void visit(const Let *op) override { - op->value.accept(this); - { - ScopedBinding bind(let_bindings, op->name, op->value); - op->body.accept(this); - } - } - - void visit(const LetStmt *op) override { - op->value.accept(this); - { - ScopedBinding bind(let_bindings, op->name, op->value); - op->body.accept(this); - } - } - - void visit(const Variable *op) override { - if (!inside_store.empty()) { - // If this Variable inside the store value is an expression - // that depends on one of the store_names, we found a lifted let. - if (expr_uses_vars(op, store_names, let_bindings)) { - found = true; - } - } - } - - void visit(const Store *op) override { - op->predicate.accept(this); - op->index.accept(this); - if (store_names.contains(op->name)) { - // If we are in a designated store and op->value has a let binding - // that uses one of the store_names, we found a lifted let. 
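To ground the validity rule that `check_atomic_validity` enforces, here is roughly what the accepted and rejected cases look like in user code; `input` and the bounds are invented for illustration:

```cpp
// A histogram-style update: the written index never reads hist itself,
// so atomic() is legal, and the passes below decide whether a mutex
// array or a plain atomic RMW is needed.
Func hist("hist");
Var i("i");
RDom r(0, 256);
hist(i) = 0;
hist(input(r) % 64) += 1;
hist.update().atomic().parallel(r);

// Rejected with the user_assert above: the index written to depends on
// the current value of the Func.
//   hist(clamp(hist(r), 0, 63)) = hist(r) + 1;
```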
- ScopedValue old_inside_store(inside_store, op->name); - op->value.accept(this); - } else { - op->value.accept(this); - } - } - std::string inside_store; - const Scope &store_names; Scope let_bindings; -}; + visit_with( + s, + [&](auto *self, const Let *op) { + op->value.accept(self); + { + ScopedBinding bind(let_bindings, op->name, op->value); + op->body.accept(self); + } // + }, + [&](auto *self, const LetStmt *op) { + op->value.accept(self); + { + ScopedBinding bind(let_bindings, op->name, op->value); + op->body.accept(self); + } // + }, + [&](auto *self, const Variable *op) { + if (!inside_store.empty()) { + // If this Variable inside the store value is an expression + // that depends on one of the store_names, we found a lifted let. + if (expr_uses_vars(op, store_names, let_bindings)) { + found = true; + } + } // + }, + [&](auto *self, const Store *op) { + op->predicate.accept(self); + op->index.accept(self); + if (store_names.contains(op->name)) { + // If we are in a designated store and op->value has a let binding + // that uses one of the store_names, we found a lifted let. + ScopedValue old_inside_store(inside_store, op->name); + op->value.accept(self); + } else { + op->value.accept(self); + } // + }); + return found; +} /** Clear out the Atomic node's mutex usages if it doesn't need one. */ -class RemoveUnnecessaryMutexUse : public IRMutator { -public: +Stmt remove_unnecessary_mutex_use(const Stmt &s) { std::set remove_mutex_lock_names; - -protected: - using IRMutator::visit; - - Stmt visit(const Atomic *op) override { + return mutate_with(s, [&](auto *self, const Atomic *op) { // Collect the names of all Store nodes inside. - CollectProducerStoreNames collector(op->producer_name); - op->body.accept(&collector); + Scope store_names = collect_producer_store_names(op->body, op->producer_name); // Search for let bindings that access the producers. - FindAtomicLetBindings finder(collector.store_names); - op->body.accept(&finder); // Each individual Store that remains can be done as a CAS // loop or an actual atomic RMW of some form. - if (finder.found) { + if (find_atomic_let_bindings(op->body, store_names)) { // Can't remove mutex lock. Leave the Stmt as is. - return IRMutator::visit(op); + return self->visit_base(op); } else { remove_mutex_lock_names.insert(op->mutex_name); - Stmt body = mutate(op->body); + Stmt body = self->mutate(op->body); return Atomic::make(op->producer_name, std::string{}, std::move(body)); } - } -}; - -/** Find Store inside an Atomic that matches the provided store_names. */ -class FindStoreInAtomicMutex : public IRVisitor { -public: - using IRVisitor::visit; - - FindStoreInAtomicMutex(const std::set &store_names) - : store_names(store_names) { - } - - bool found = false; - std::string producer_name; - std::string mutex_name; - -protected: - void visit(const Atomic *op) override { - if (!found && !op->mutex_name.empty()) { - ScopedValue old_in_atomic_mutex(in_atomic_mutex, true); - op->body.accept(this); - if (found) { - // We found a Store inside Atomic with matching name, - // record the mutex information. - producer_name = op->producer_name; - mutex_name = op->mutex_name; - } - } else { - op->body.accept(this); - } - } - - void visit(const Store *op) override { - if (in_atomic_mutex) { - if (store_names.find(op->name) != store_names.end()) { - found = true; - } - } - IRVisitor::visit(op); - } - - bool in_atomic_mutex = false; - const std::set &store_names; -}; + }); +} /** Replace the indices in the Store nodes with the specified variable. 
*/ -class ReplaceStoreIndexWithVar : public IRMutator { -public: - ReplaceStoreIndexWithVar(const std::string &producer_name, Expr var) - : producer_name(producer_name), var(std::move(var)) { - } - -protected: - using IRMutator::visit; - - Stmt visit(const Store *op) override { - Expr predicate = mutate(op->predicate); - Expr value = mutate(op->value); - return Store::make(op->name, - std::move(value), - var, - op->param, - std::move(predicate), - op->alignment); - } - - const std::string &producer_name; - Expr var; -}; - -/** Add mutex allocation & lock & unlock if required. */ -class AddAtomicMutex : public IRMutator { -public: - AddAtomicMutex(const std::vector &o) { - for (const Function &f : o) { - outputs.emplace(f.name(), f); +Stmt replace_store_index_with_var(const Stmt &s, const std::string &producer_name, Expr var) { + return mutate_with(s, [&](auto *self, const Store *op) { + if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { + return Store::make(op->name, op->value, var, op->param, op->predicate, op->alignment); } - } + return self->visit_base(op); + }); +} -protected: - using IRMutator::visit; +Stmt allocate_mutex(const std::string &mutex_name, Expr extent, const Stmt &body) { + Expr mutex_array = Call::make(type_of(), + "halide_mutex_array_create", + {std::move(extent)}, + Call::Extern); + + // Allocate a scalar of halide_mutex_array. + // This generates halide_mutex_array mutex[1]; + return Allocate::make(mutex_name, + type_of(), + MemoryType::Stack, + {}, + const_true(), + body, + mutex_array, + "halide_mutex_array_destroy"); +} +/** Add mutex allocation & lock & unlock if required. */ +Stmt inject_atomic_mutex(const Stmt &s, const std::vector &o) { // Maps from a producer name to a mutex name, for all encountered atomic // nodes. Scope needs_mutex_allocation; @@ -292,170 +198,149 @@ class AddAtomicMutex : public IRMutator { // Pipeline outputs std::map outputs; - Stmt allocate_mutex(const std::string &mutex_name, Expr extent, Stmt body) { - Expr mutex_array = Call::make(type_of(), - "halide_mutex_array_create", - {std::move(extent)}, - Call::Extern); - - // Allocate a scalar of halide_mutex_array. - // This generates halide_mutex_array mutex[1]; - body = Allocate::make(mutex_name, - type_of(), - MemoryType::Stack, - {}, - const_true(), - body, - mutex_array, - "halide_mutex_array_destroy"); - return body; + for (const Function &f : o) { + outputs.emplace(f.name(), f); } - Stmt visit(const Allocate *op) override { - // If this Allocate node is allocating a buffer for a producer, - // and there is a Store node inside of an Atomic node requiring mutex lock - // matching the name of the Allocate, allocate a mutex lock. - - Stmt body = mutate(op->body); + return mutate_with( + s, + [&](auto *self, const Allocate *op) -> Stmt { + // If this Allocate node is allocating a buffer for a producer, + // and there is a Store node inside of an Atomic node requiring mutex lock + // matching the name of the Allocate, allocate a mutex lock. 
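Unrolled, the `Allocate` node that `allocate_mutex` builds encodes roughly the following lifetime. The declarations here are assumptions sketched from how this pass calls the runtime, not copied from `src/runtime`:

```cpp
// Reading aid only; the real signatures live in the Halide runtime.
struct halide_mutex_array;  // opaque
extern "C" halide_mutex_array *halide_mutex_array_create(uint64_t sz);
extern "C" void halide_mutex_array_destroy(void *user_context, void *array);

void lifetime_sketch(uint64_t extent) {
    // new_expr: one mutex per element of the producer's buffer.
    halide_mutex_array *mutex = halide_mutex_array_create(extent);
    // ... the Atomic body runs here, locking/unlocking by store index ...
    // free_function: also invoked on error unwinding, so held locks get
    // cleaned up even if the body aborts partway through.
    halide_mutex_array_destroy(nullptr, mutex);
}
```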
- std::string producer_name; - if (ends_with(op->name, ".0")) { - producer_name = op->name.substr(0, op->name.size() - 2); - } else { - producer_name = op->name; - } + Stmt body = self->mutate(op->body); - if (const std::string *mutex_name = needs_mutex_allocation.find(producer_name)) { - Expr extent = cast(1); // uint64_t to handle LargeBuffers - for (const Expr &e : op->extents) { - extent = extent * e; + std::string producer_name; + if (ends_with(op->name, ".0")) { + producer_name = op->name.substr(0, op->name.size() - 2); + } else { + producer_name = op->name; } - body = allocate_mutex(*mutex_name, extent, body); - - // At this stage in lowering it should be impossible to have an - // allocation that shadows the name of an outer allocation, but may as - // well handle it anyway by using a scope and popping at each allocate - // node. - needs_mutex_allocation.pop(producer_name); - } - - if (body.same_as(op->body)) { - return op; - } else { - return Allocate::make(op->name, - op->type, - op->memory_type, - op->extents, - op->condition, - std::move(body), - op->new_expr, - op->free_function, - op->padding); - } - } + if (const std::string *mutex_name = needs_mutex_allocation.find(producer_name)) { + Expr extent = cast(1); // uint64_t to handle LargeBuffers + for (const Expr &e : op->extents) { + extent = extent * e; + } - Stmt visit(const ProducerConsumer *op) override { - // Usually we allocate the mutex buffer at the Allocate node, - // but outputs don't have Allocate. For those we allocate the mutex - // buffer at the producer node. + body = allocate_mutex(*mutex_name, extent, body); - if (!op->is_producer) { - // This is a consumer - return IRMutator::visit(op); - } - - auto it = outputs.find(op->name); - if (it == outputs.end()) { - // Not an output - return IRMutator::visit(op); - } - - Function f = it->second; - - Stmt body = mutate(op->body); + // At this stage in lowering it should be impossible to have an + // allocation that shadows the name of an outer allocation, but may as + // well handle it anyway by using a scope and popping at each allocate + // node. + needs_mutex_allocation.pop(producer_name); + } - if (const std::string *mutex_name = needs_mutex_allocation.find(it->first)) { - // All output buffers in a Tuple have the same extent. - OutputImageParam output_buffer = Func(f).output_buffers()[0]; - Expr extent = cast(1); // uint64_t to handle LargeBuffers - for (int i = 0; i < output_buffer.dimensions(); i++) { - extent *= output_buffer.dim(i).extent(); + if (body.same_as(op->body)) { + return op; + } else { + return Allocate::make(op->name, + op->type, + op->memory_type, + op->extents, + op->condition, + std::move(body), + op->new_expr, + op->free_function, + op->padding); + } // + }, + [&](auto *self, const ProducerConsumer *op) -> Stmt { + // Usually we allocate the mutex buffer at the Allocate node, + // but outputs don't have Allocate. For those we allocate the mutex + // buffer at the producer node. 
+ + if (!op->is_producer) { + // This is a consumer + return self->visit_base(op); } - body = allocate_mutex(*mutex_name, extent, body); - } - if (body.same_as(op->body)) { - return op; - } else { - return ProducerConsumer::make(op->name, op->is_producer, std::move(body)); - } - } + auto it = outputs.find(op->name); + if (it == outputs.end()) { + // Not an output + return self->visit_base(op); + } - Stmt visit(const Atomic *op) override { - if (op->mutex_name.empty()) { - return IRMutator::visit(op); - } + Function f = it->second; - // Lock the mutexes using the indices from the Store nodes inside the body. - FindProducerStoreIndex find(op->producer_name); - op->body.accept(&find); + Stmt body = self->mutate(op->body); - Stmt body = op->body; + if (const std::string *mutex_name = needs_mutex_allocation.find(it->first)) { + // All output buffers in a Tuple have the same extent. + OutputImageParam output_buffer = Func(f).output_buffers()[0]; + Expr extent = cast(1); // uint64_t to handle LargeBuffers + for (int i = 0; i < output_buffer.dimensions(); i++) { + extent *= output_buffer.dim(i).extent(); + } + body = allocate_mutex(*mutex_name, extent, body); + } - Expr index = find.index; - Expr index_let; // If defined, represents the value of the lifted let binding. - if (!index.defined()) { - // Scalar output. - index = Expr(0); - } else { - // Lift the index outside of the atomic node. - // This is for avoiding side-effects inside those expressions - // being evaluated twice. - std::string name = unique_name('t'); - index_let = index; - index = Variable::make(index.type(), name); - body = ReplaceStoreIndexWithVar(op->producer_name, index).mutate(body); - } - // This generates a pointer to the mutex array - Expr mutex_array = Variable::make( - type_of(), op->mutex_name); - // Add mutex locks & unlocks - // If a thread locks the mutex and throws an exception, - // halide_mutex_array_destroy will be called and cleanup the mutex locks. - body = Block::make( - Evaluate::make(Call::make(type_of(), - "halide_mutex_array_lock", - {mutex_array, index}, - Call::CallType::Extern)), - Block::make(std::move(body), - Evaluate::make(Call::make(type_of(), - "halide_mutex_array_unlock", - {mutex_array, index}, - Call::CallType::Extern)))); - Stmt ret = Atomic::make(op->producer_name, - op->mutex_name, - std::move(body)); + if (body.same_as(op->body)) { + return op; + } else { + return ProducerConsumer::make(op->name, op->is_producer, std::move(body)); + } + }, + [&](auto *self, const Atomic *op) { + if (op->mutex_name.empty()) { + return self->visit_base(op); + } - if (index_let.defined()) { - // Attach the let binding outside of the atomic node. - internal_assert(index.as() != nullptr); - ret = LetStmt::make(index.as()->name, index_let, ret); - } - needs_mutex_allocation.push(op->producer_name, op->mutex_name); + // Lock the mutexes using the indices from the Store nodes inside the body. + Stmt body = op->body; + Expr index = find_producer_store_index(body, op->producer_name); + Expr index_let; // If defined, represents the value of the lifted let binding. + if (!index.defined()) { + // Scalar output. + index = Expr(0); + } else { + // Lift the index outside of the atomic node. + // This is for avoiding side-effects inside those expressions + // being evaluated twice. 
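The branch below builds IR of roughly this shape (sketched with an invented temporary name):

```cpp
// let t42 = <store index expression>        // lifted outside the atomic
// atomic (producer, mutex) {
//     halide_mutex_array_lock(mutex, t42)
//     <original body, stores re-indexed to t42>
//     halide_mutex_array_unlock(mutex, t42)
// }
//
// Lifting the index into a let avoids evaluating a possibly
// side-effecting index expression twice, and guarantees the lock and
// unlock hit the same mutex slot.
```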
+ std::string name = unique_name('t'); + index_let = index; + index = Variable::make(index.type(), name); + body = replace_store_index_with_var(body, op->producer_name, index); + } + // This generates a pointer to the mutex array + Expr mutex_array = Variable::make( + type_of(), op->mutex_name); + // Add mutex locks & unlocks + // If a thread locks the mutex and throws an exception, + // halide_mutex_array_destroy will be called and cleanup the mutex locks. + body = Block::make( + Evaluate::make(Call::make(type_of(), + "halide_mutex_array_lock", + {mutex_array, index}, + Call::CallType::Extern)), + Block::make(std::move(body), + Evaluate::make(Call::make(type_of(), + "halide_mutex_array_unlock", + {mutex_array, index}, + Call::CallType::Extern)))); + Stmt ret = Atomic::make(op->producer_name, + op->mutex_name, + std::move(body)); + + if (index_let.defined()) { + // Attach the let binding outside of the atomic node. + internal_assert(index.as() != nullptr); + ret = LetStmt::make(index.as()->name, index_let, ret); + } + needs_mutex_allocation.push(op->producer_name, op->mutex_name); - return ret; - } + return ret; + }); }; } // namespace Stmt add_atomic_mutex(Stmt s, const std::vector &outputs) { - CheckAtomicValidity check; - s.accept(&check); - if (check.any_atomic) { - s = RemoveUnnecessaryMutexUse().mutate(s); - s = AddAtomicMutex(outputs).mutate(s); + if (check_atomic_validity(s)) { + s = remove_unnecessary_mutex_use(s); + s = inject_atomic_mutex(s, outputs); } return s; } diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 32b4159ea4f7..737062e1c88b 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1994,29 +1994,6 @@ bool box_contains(const Box &outer, const Box &inner) { namespace { -class FindInnermostVar : public IRVisitor { -public: - const Scope &vars_depth; - string innermost_var; - - FindInnermostVar(const Scope &vars_depth) - : vars_depth(vars_depth) { - } - -private: - using IRVisitor::visit; - int innermost_depth = -1; - - void visit(const Variable *op) override { - if (const int *depth = vars_depth.find(op->name)) { - if (*depth > innermost_depth) { - innermost_var = op->name; - innermost_depth = *depth; - } - } - } -}; - // Place innermost vars in an IfThenElse's condition as far to the left as possible. class SolveIfThenElse : public IRMutator { // Scope of variable names and their depths. Higher depth indicates @@ -2075,10 +2052,19 @@ class SolveIfThenElse : public IRMutator { op = stmt.as(); internal_assert(op); - FindInnermostVar find(vars_depth); - op->condition.accept(&find); - if (!find.innermost_var.empty()) { - Expr condition = solve_expression(op->condition, find.innermost_var).result; + string innermost_var; + int innermost_depth = -1; + visit_with(op->condition, [&](auto *, const Variable *var) { + if (const int *var_depth = vars_depth.find(var->name)) { + if (*var_depth > innermost_depth) { + innermost_var = var->name; + innermost_depth = *var_depth; + } + } + }); + + if (!innermost_var.empty()) { + Expr condition = solve_expression(op->condition, innermost_var).result; if (!condition.same_as(op->condition)) { stmt = IfThenElse::make(condition, op->then_case, op->else_case); } diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 72f45360b3b5..85f5bb680bc2 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -64,71 +64,55 @@ bool depends_on_bounds_inference(const Expr &e) { * bounds_of_inner_var(y) would return 2 to 12, and * bounds_of_inner_var(x) would return 0 to 10. 
*/ -class BoundsOfInnerVar : public IRVisitor { -public: +Interval bounds_of_inner_var(const string &var, const Stmt &s) { Interval result; - BoundsOfInnerVar(const string &v) - : var(v) { - } - -private: - string var; bool found = false; - - using IRVisitor::visit; - - void visit(const LetStmt *op) override { - if (op->name == var) { - result = Interval::single_point(op->value); - found = true; - } else if (!found) { - op->body.accept(this); - if (found) { - if (expr_uses_var(result.min, op->name)) { - result.min = Let::make(op->name, op->value, result.min); - } - if (expr_uses_var(result.max, op->name)) { - result.max = Let::make(op->name, op->value, result.max); + visit_with( + s, + [&](auto *self, const LetStmt *op) { + if (op->name == var) { + result = Interval::single_point(op->value); + found = true; + } else if (!found) { + op->body.accept(self); + if (found) { + if (expr_uses_var(result.min, op->name)) { + result.min = Let::make(op->name, op->value, result.min); + } + if (expr_uses_var(result.max, op->name)) { + result.max = Let::make(op->name, op->value, result.max); + } } } - } - } - - void visit(const Block *op) override { - // We're most likely to find our var at the end of a - // block. The start of the block could be unrelated producers. - op->rest.accept(this); - if (!found) { - op->first.accept(this); - } - } - - void visit(const For *op) override { - Interval in(op->min, op->max); - - if (op->name == var) { - result = in; - found = true; - } else if (!found) { - op->body.accept(this); - if (found) { - Scope scope; - scope.push(op->name, in); - if (expr_uses_var(result.min, op->name)) { - result.min = bounds_of_expr_in_scope(result.min, scope).min; - } - if (expr_uses_var(result.max, op->name)) { - result.max = bounds_of_expr_in_scope(result.max, scope).max; - } + }, + [&](auto *self, const Block *op) { + // We're most likely to find our var at the end of a + // block. The start of the block could be unrelated producers. + op->rest.accept(self); + if (!found) { + op->first.accept(self); } - } - } -}; - -Interval bounds_of_inner_var(const string &var, const Stmt &s) { - BoundsOfInnerVar b(var); - s.accept(&b); - return b.result; + }, + [&](auto *self, const For *op) { + Interval in(op->min, op->max); + if (op->name == var) { + result = in; + found = true; + } else if (!found) { + op->body.accept(self); + if (found) { + Scope scope; + scope.push(op->name, in); + if (expr_uses_var(result.min, op->name)) { + result.min = bounds_of_expr_in_scope(result.min, scope).min; + } + if (expr_uses_var(result.max, op->name)) { + result.max = bounds_of_expr_in_scope(result.max, scope).max; + } + } + } // + }); + return result; } size_t find_fused_group_index(const Function &producing_func, @@ -386,23 +370,17 @@ class BoundsInference : public IRMutator { // don't care what sites are loaded, just what sites need // to have the correct value in them. So remap all selects // to if_then_elses to get tighter bounds. 
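The reason this tightens bounds: `select` is strict, so both arms contribute to the required region, while the `if_then_else` intrinsic is lazy and only the taken arm counts. As a standalone sketch of the rewrite applied to `exprs` below (assuming the `mutate_with` helper from `src/IRVisitor.h`):

```cpp
Expr selects_to_if_then_else(const Expr &e) {
    return mutate_with(e, [](auto *self, const Select *op) -> Expr {
        // Only pure conditions can safely become the PureIntrinsic form.
        if (is_pure(op->condition)) {
            return Call::make(op->type, Call::if_then_else,
                              {self->mutate(op->condition),
                               self->mutate(op->true_value),
                               self->mutate(op->false_value)},
                              Call::PureIntrinsic);
        }
        return self->visit_base(op);
    });
}
```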
- class SelectToIfThenElse : public IRMutator { - using IRMutator::visit; - Expr visit(const Select *op) override { + for (auto &e : exprs) { + e.value = mutate_with(e.value, [](auto *self, const Select *op) { if (is_pure(op->condition)) { return Call::make(op->type, Call::if_then_else, - {mutate(op->condition), - mutate(op->true_value), - mutate(op->false_value)}, + {self->mutate(op->condition), + self->mutate(op->true_value), + self->mutate(op->false_value)}, Call::PureIntrinsic); - } else { - return IRMutator::visit(op); } - } - } select_to_if_then_else; - - for (auto &e : exprs) { - e.value = select_to_if_then_else.mutate(e.value); + return self->visit_base(op); + }); } } diff --git a/src/EarlyFree.cpp b/src/EarlyFree.cpp index 8b664c2bcf8d..f0dcf44c2ce6 100644 --- a/src/EarlyFree.cpp +++ b/src/EarlyFree.cpp @@ -109,39 +109,28 @@ class FindLastUse : public IRVisitor { } }; -class InjectMarker : public IRMutator { -public: - string func; - Stmt last_use; - -private: +Stmt inject_marker(const Stmt &stmt, const string &func, const Stmt &last_use) { bool injected = false; + return mutate_with(stmt, [&](auto *self, const Block *block) -> Stmt { + auto do_injection = [&](const Stmt &s) { + if (injected) { + return s; + } + if (s.same_as(last_use)) { + injected = true; + return Block::make(s, Free::make(func)); + } + return self->mutate(s); + }; - using IRMutator::visit; - - Stmt inject_marker(Stmt s) { - if (injected) { - return s; - } - if (s.same_as(last_use)) { - injected = true; - return Block::make(s, Free::make(func)); - } else { - return mutate(s); - } - } - - Stmt visit(const Block *block) override { - Stmt new_rest = inject_marker(block->rest); - Stmt new_first = inject_marker(block->first); + Stmt new_rest = do_injection(block->rest); + Stmt new_first = do_injection(block->first); - if (new_first.same_as(block->first) && - new_rest.same_as(block->rest)) { + if (new_first.same_as(block->first) && new_rest.same_as(block->rest)) { return block; - } else { - return Block::make(new_first, new_rest); } - } + return Block::make(new_first, new_rest); + }); }; class InjectEarlyFrees : public IRMutator { @@ -156,10 +145,7 @@ class InjectEarlyFrees : public IRMutator { stmt.accept(&last_use); if (last_use.last_use.defined()) { - InjectMarker inject_marker; - inject_marker.func = alloc->name; - inject_marker.last_use = last_use.last_use; - stmt = inject_marker.mutate(stmt); + stmt = inject_marker(stmt, alloc->name, last_use.last_use); } else { stmt = Allocate::make(alloc->name, alloc->type, alloc->memory_type, alloc->extents, alloc->condition, diff --git a/src/FindCalls.cpp b/src/FindCalls.cpp index 1fca6de1175c..5692cea0b0d9 100644 --- a/src/FindCalls.cpp +++ b/src/FindCalls.cpp @@ -10,34 +10,30 @@ namespace Internal { namespace { -/* Find all the internal halide calls in an expr */ -class FindCalls : public IRVisitor { -public: +struct CallInfo { std::map calls; std::vector order; +}; - using IRVisitor::visit; - - void include_function(const Function &f) { - auto [it, inserted] = calls.emplace(f.name(), f); - if (inserted) { - order.push_back(f); - } else { - user_assert(it->second.same_as(f)) - << "Can't compile a pipeline using multiple functions with same name: " - << f.name() << "\n"; - } - } - - void visit(const Call *call) override { - IRVisitor::visit(call); - +/* Find all the internal halide calls in an expr */ +CallInfo find_calls(const Function &f) { + CallInfo info; + visit_with(f, [&](auto *self, const Call *call) { + self->visit_base(call); if (call->call_type == Call::Halide && 
call->func.defined()) { Function f(call->func); - include_function(f); + auto [it, inserted] = info.calls.emplace(f.name(), f); + if (inserted) { + info.order.push_back(f); + } else { + user_assert(it->second.same_as(f)) + << "Can't compile a pipeline using multiple functions with same name: " + << f.name() << "\n"; + } } - } -}; + }); + return info; +} void populate_environment_helper(const Function &f, std::map *env, @@ -61,29 +57,28 @@ void populate_environment_helper(const Function &f, } }; - FindCalls calls; - f.accept(&calls); + auto [f_calls, f_order] = find_calls(f); if (f.has_extern_definition()) { for (const ExternFuncArgument &arg : f.extern_arguments()) { if (arg.is_func()) { - insert_func(Function{arg.func}, &calls.calls, &calls.order); + insert_func(Function{arg.func}, &f_calls, &f_order); } } } if (include_wrappers) { for (const auto &it : f.schedule().wrappers()) { - insert_func(Function{it.second}, &calls.calls, &calls.order); + insert_func(Function{it.second}, &f_calls, &f_order); } } if (!recursive) { - for (const Function &g : calls.order) { + for (const Function &g : f_order) { insert_func(g, env, order); } } else { insert_func(f, env, order); - for (const Function &g : calls.order) { + for (const Function &g : f_order) { populate_environment_helper(g, env, order, recursive, include_wrappers); } } diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 2873c5eb4ca9..951a691be44d 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1123,36 +1123,25 @@ class FindIntrinsics : public IRMutator { class SubstituteInWideningLets : public IRMutator { using IRMutator::visit; - bool widens(const Expr &e) { - class AllInputsNarrowerThan : public IRVisitor { - int bits; - - using IRVisitor::visit; - - void visit(const Variable *op) override { + static bool widens(const Expr &e) { + const int bits = e.type().bits(); + bool result = true; + visit_with( + e, + [&](auto *, const Variable *op) { result &= op->type.bits() < bits; - } - - void visit(const Load *op) override { + }, + [&](auto *, const Load *op) { result &= op->type.bits() < bits; - } - - void visit(const Call *op) override { + }, + [&](auto *self, const Call *op) { if (op->is_pure() && op->is_intrinsic()) { - IRVisitor::visit(op); + self->visit_base(op); } else { result &= op->type.bits() < bits; } - } - - public: - AllInputsNarrowerThan(Type t) - : bits(t.bits()) { - } - bool result = true; - } widens(e.type()); - e.accept(&widens); - return widens.result; + }); + return result; } Scope replacements; @@ -1635,27 +1624,24 @@ Expr lower_intrinsic(const Call *op) { } namespace { - -class LowerIntrinsics : public IRMutator { - using IRMutator::visit; - - Expr visit(const Call *op) override { +template +T lower_intrinsics_impl(const T &ir) { + return mutate_with(ir, [&](auto *self, const Call *op) { Expr lowered = lower_intrinsic(op); if (lowered.defined()) { - return mutate(lowered); + return self->mutate(lowered); } - return IRMutator::visit(op); - } -}; - + return self->visit_base(op); + }); +} } // namespace Expr lower_intrinsics(const Expr &e) { - return LowerIntrinsics().mutate(e); + return lower_intrinsics_impl(e); } Stmt lower_intrinsics(const Stmt &s) { - return LowerIntrinsics().mutate(s); + return lower_intrinsics_impl(s); } } // namespace Internal diff --git a/src/Func.cpp b/src/Func.cpp index ef9e2c7af8f0..caa5b98240c3 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -352,6 +352,27 @@ std::string Stage::name() const { } namespace { +struct CheckResult { + bool has_self_reference = false; 
+ bool has_rvar = false; +}; + +CheckResult check_self_ref_and_rvar(const vector &exprs, const string &func_name) { + CheckResult result; + for (const Expr &e : exprs) { + visit_with( + e, + [&](auto *, const Variable *op) { + result.has_rvar |= op->reduction_domain.defined(); // + }, + [&](auto *self, const Call *op) { + result.has_self_reference |= (op->call_type == Call::Halide && op->name == func_name); + self->visit_base(op); + }); + } + return result; +} + bool is_const_assignment(const string &func_name, const vector &args, const vector &values) { // Check if an update definition is a non-recursive and just // scatters a value that doesn't depend on the reduction @@ -364,37 +385,11 @@ bool is_const_assignment(const string &func_name, const vector &args, cons // never be races between two distinct values of the pure var by // construction (because the pure var must appear as one of the // args) e.g: f(g(r, x), x) = h(x); - class Checker : public IRVisitor { - using IRVisitor::visit; - - void visit(const Variable *op) override { - has_rvar |= op->reduction_domain.defined(); - } - - void visit(const Call *op) override { - has_self_reference |= (op->call_type == Call::Halide && op->name == func_name); - IRVisitor::visit(op); - } - - const string &func_name; - - public: - Checker(const string &func_name) - : func_name(func_name) { - } - - bool has_self_reference = false; - bool has_rvar = false; - } lhs_checker(func_name), rhs_checker(func_name); - for (const Expr &v : args) { - v.accept(&lhs_checker); - } - for (const Expr &v : values) { - v.accept(&rhs_checker); - } - return !(lhs_checker.has_self_reference || - rhs_checker.has_self_reference || - rhs_checker.has_rvar); + auto lhs_result = check_self_ref_and_rvar(args, func_name); + auto rhs_result = check_self_ref_and_rvar(values, func_name); + return !(lhs_result.has_self_reference || + rhs_result.has_self_reference || + rhs_result.has_rvar); } void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { @@ -574,47 +569,31 @@ std::string Stage::dump_argument_list() const { namespace { -class SubstituteSelfReference : public IRMutator { - using IRMutator::visit; - - const string func; - const Function substitute; - const vector new_args; - - Expr visit(const Call *c) override { - Expr expr = IRMutator::visit(c); - c = expr.as(); - internal_assert(c); - - if ((c->call_type == Call::Halide) && (func == c->name)) { - debug(4) << "...Replace call to Func \"" << c->name << "\" with " - << "\"" << substitute.name() << "\"\n"; - vector args; - args.insert(args.end(), c->args.begin(), c->args.end()); - args.insert(args.end(), new_args.begin(), new_args.end()); - expr = Call::make(substitute, args, c->value_index); - } - return expr; - } - -public: - SubstituteSelfReference(const string &func, const Function &substitute, - const vector &new_args) - : func(func), substitute(substitute), new_args(new_args) { - internal_assert(substitute.get_contents().defined()); - } -}; - /** Substitute all self-reference calls to 'func' with 'substitute' which * args (LHS) is the old args (LHS) plus 'new_args' in that order. * Expect this method to be called on the value (RHS) of an update definition. 
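 * For example (names hypothetical, not from this patch): with
 * func = "f", substitute = f_intm, and new_args = {u}, an update value
 *
 *     f(x) = f(x) + g(r, x)
 *
 * has its RHS rewritten so the recursive call retargets the substitute
 * with the extra pure arg appended:
 *
 *     f(x) = f_intm(x, u) + g(r, x)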
*/ vector substitute_self_reference(const vector &values, const string &func, const Function &substitute, const vector &new_args) { - SubstituteSelfReference subs(func, substitute, new_args); + internal_assert(substitute.get_contents().defined()); + vector result; result.reserve(values.size()); for (const auto &val : values) { - result.push_back(subs.mutate(val)); + result.push_back(mutate_with(val, [&](auto *self, const Call *c) { + Expr expr = self->visit_base(c); + c = expr.as(); + internal_assert(c); + + if (c->call_type == Call::Halide && func == c->name) { + debug(4) << "...Replace call to Func \"" << c->name << "\" with " + << "\"" << substitute.name() << "\"\n"; + vector args; + args.insert(args.end(), c->args.begin(), c->args.end()); + args.insert(args.end(), new_args.begin(), new_args.end()); + expr = Call::make(substitute, args, c->value_index); + } + return expr; + })); } return result; } diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 88f9a542550f..b65c4085927b 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1484,81 +1484,38 @@ class FuseGPUThreadLoops : public IRMutator { } }; -class ZeroGPULoopMins : public IRMutator { - bool in_non_glsl_gpu = false; - using IRMutator::visit; +} // namespace - Stmt visit(const For *op) override { +// Also used by InjectImageIntrinsics +Stmt zero_gpu_loop_mins(const Stmt &s) { + bool in_non_glsl_gpu = false; + return mutate_with(s, [&](auto *self, const For *op) { ScopedValue old_in_non_glsl_gpu(in_non_glsl_gpu); in_non_glsl_gpu = (in_non_glsl_gpu && op->device_api == DeviceAPI::None) || - (op->device_api == DeviceAPI::CUDA) || (op->device_api == DeviceAPI::OpenCL) || - (op->device_api == DeviceAPI::Metal) || - (op->device_api == DeviceAPI::D3D12Compute) || - (op->device_api == DeviceAPI::Vulkan); + op->device_api == DeviceAPI::CUDA || + op->device_api == DeviceAPI::OpenCL || + op->device_api == DeviceAPI::Metal || + op->device_api == DeviceAPI::D3D12Compute || + op->device_api == DeviceAPI::Vulkan; - Stmt stmt = IRMutator::visit(op); + Stmt stmt = self->visit_base(op); if (is_gpu(op->for_type) && !is_const_zero(op->min)) { op = stmt.as(); internal_assert(op); Expr adjusted = Variable::make(Int(32), op->name) + op->min; Stmt body = substitute(op->name, adjusted, op->body); - stmt = For::make(op->name, 0, simplify(op->max - op->min), op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, + 0, simplify(op->max - op->min), + op->for_type, op->partition_policy, op->device_api, + body); } return stmt; - } - -public: - ZeroGPULoopMins() = default; -}; - -} // namespace - -// Also used by InjectImageIntrinsics -Stmt zero_gpu_loop_mins(const Stmt &s) { - return ZeroGPULoopMins().mutate(s); + }); } namespace { -// Find the inner most GPU block of a statement. -class FindInnermostGPUBlock : public IRVisitor { - using IRVisitor::visit; - - void visit(const For *op) override { - if (op->for_type == ForType::GPUBlock) { - // Set the last found GPU block to found_gpu_block. - found_gpu_block = op; - } - IRVisitor::visit(op); - } - -public: - const For *found_gpu_block = nullptr; -}; - -// Given a condition and a loop, add the condition -// to the loop body. 
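For reference, the effect of `zero_gpu_loop_mins` above on a single GPU loop, with invented bounds:

```cpp
// Before: a GPU block loop with a nonzero min.
//   for<GPUBlock> (bx, 4, 20) { f[bx] = g[bx]; }
//
// After: the loop starts at zero with range simplify(max - min), and
// every use of the loop variable is offset via substitute().
//   for<GPUBlock> (bx, 0, 16) { f[bx + 4] = g[bx + 4]; }
```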
-class AddConditionToALoop : public IRMutator {
-    using IRMutator::visit;
-
-    Stmt visit(const For *op) override {
-        if (op != loop) {
-            return IRMutator::visit(op);
-        }
-
-        return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api,
-                         IfThenElse::make(condition, op->body, Stmt()));
-    }
-
-public:
-    AddConditionToALoop(const Expr &condition, const For *loop)
-        : condition(condition), loop(loop) {
-    }
-    const Expr &condition;
-    const For *loop;
-};
-
 // Push if statements between GPU blocks through all GPU blocks.
 // Throw error if the if statement has an else clause.
 class NormalizeIfStatements : public IRMutator {
@@ -1578,11 +1535,23 @@ class NormalizeIfStatements : public IRMutator {
         if (!inside_gpu_blocks) {
             return IRMutator::visit(op);
         }
-        FindInnermostGPUBlock find;
-        op->accept(&find);
-        if (find.found_gpu_block != nullptr) {
+        const For *innermost_gpu_block = nullptr;
+        visit_with(op, [&](auto *self, const For *loop) {
+            if (loop->for_type == ForType::GPUBlock) {
+                innermost_gpu_block = loop;
+            }
+            self->visit_base(loop);
+        });
+        if (innermost_gpu_block != nullptr) {
             internal_assert(!op->else_case.defined()) << "Found an if statement with else case between two GPU blocks.\n";
-            return AddConditionToALoop(op->condition, find.found_gpu_block).mutate(op->then_case);
+            // Add the condition to the loop body.
+            return mutate_with(op->then_case, [&](auto *self, const For *loop) {
+                if (loop != innermost_gpu_block) {
+                    return self->visit_base(loop);
+                }
+                return For::make(loop->name, loop->min, loop->max, loop->for_type, loop->partition_policy, loop->device_api,
+                                 IfThenElse::make(op->condition, loop->body, Stmt()));
+            });
         }
         return IRMutator::visit(op);
     }
@@ -1596,7 +1565,7 @@ Stmt fuse_gpu_thread_loops(Stmt s) {
     // merge the predicate into the merged GPU thread.
     s = NormalizeIfStatements().mutate(s);
     s = FuseGPUThreadLoops().mutate(s);
-    s = ZeroGPULoopMins().mutate(s);
+    s = zero_gpu_loop_mins(s);
    return s;
 }
 
diff --git a/src/IRVisitor.h b/src/IRVisitor.h
index 473f825335ec..561b9d503bff 100644
--- a/src/IRVisitor.h
+++ b/src/IRVisitor.h
@@ -252,13 +252,17 @@ struct LambdaVisitor final : IRVisitor {
     }
 };
 
-template<typename... Lambdas>
-void visit_with(const IRNode *ir, Lambdas &&...lambdas) {
+template<typename T, typename... Lambdas>
+void visit_with(T &&ir, Lambdas &&...lambdas) {
     LambdaVisitor visitor{std::forward<Lambdas>(lambdas)...};
     constexpr bool all_take_two_args = (std::is_invocable_v && ...);
     static_assert(all_take_two_args);
-    ir->accept(&visitor);
+    if constexpr (std::is_pointer_v<std::decay_t<T>>) {
+        ir->accept(&visitor);
+    } else {
+        ir.accept(&visitor);
+    }
 }
 
 template
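A usage note on the generalized signature above: call sites in this patch pass both IR handles (`Stmt`, `Expr`, `Function`) and raw node pointers (the `IfThenElse *` in `NormalizeIfStatements`), which is what the `if constexpr` dispatch accommodates. A minimal sketch, with an invented helper:

```cpp
// Counts For nodes; the same lambda works whether visit_with is handed
// a Stmt handle (ir.accept) or a raw node pointer (ir->accept).
int count_fors(const Stmt &s) {
    int n = 0;
    visit_with(s, [&](auto *self, const For *op) {
        n++;
        self->visit_base(op);  // keep recursing into nested loops
    });
    return n;
}
```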