Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
613 changes: 249 additions & 364 deletions src/AddAtomicMutex.cpp

Large diffs are not rendered by default.

40 changes: 13 additions & 27 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1994,29 +1994,6 @@ bool box_contains(const Box &outer, const Box &inner) {

namespace {

class FindInnermostVar : public IRVisitor {
public:
const Scope<int> &vars_depth;
string innermost_var;

FindInnermostVar(const Scope<int> &vars_depth)
: vars_depth(vars_depth) {
}

private:
using IRVisitor::visit;
int innermost_depth = -1;

void visit(const Variable *op) override {
if (const int *depth = vars_depth.find(op->name)) {
if (*depth > innermost_depth) {
innermost_var = op->name;
innermost_depth = *depth;
}
}
}
};

// Place innermost vars in an IfThenElse's condition as far to the left as possible.
class SolveIfThenElse : public IRMutator {
// Scope of variable names and their depths. Higher depth indicates
Expand Down Expand Up @@ -2075,10 +2052,19 @@ class SolveIfThenElse : public IRMutator {
op = stmt.as<IfThenElse>();
internal_assert(op);

FindInnermostVar find(vars_depth);
op->condition.accept(&find);
if (!find.innermost_var.empty()) {
Expr condition = solve_expression(op->condition, find.innermost_var).result;
string innermost_var;
int innermost_depth = -1;
visit_with(op->condition, [&](auto *, const Variable *var) {
if (const int *var_depth = vars_depth.find(var->name)) {
if (*var_depth > innermost_depth) {
innermost_var = var->name;
innermost_depth = *var_depth;
}
}
});

if (!innermost_var.empty()) {
Expr condition = solve_expression(op->condition, innermost_var).result;
if (!condition.same_as(op->condition)) {
stmt = IfThenElse::make(condition, op->then_case, op->else_case);
}
Expand Down
124 changes: 51 additions & 73 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,71 +64,55 @@ bool depends_on_bounds_inference(const Expr &e) {
* bounds_of_inner_var(y) would return 2 to 12, and
* bounds_of_inner_var(x) would return 0 to 10.
*/
class BoundsOfInnerVar : public IRVisitor {
public:
Interval bounds_of_inner_var(const string &var, const Stmt &s) {
Interval result;
BoundsOfInnerVar(const string &v)
: var(v) {
}

private:
string var;
bool found = false;

using IRVisitor::visit;

void visit(const LetStmt *op) override {
if (op->name == var) {
result = Interval::single_point(op->value);
found = true;
} else if (!found) {
op->body.accept(this);
if (found) {
if (expr_uses_var(result.min, op->name)) {
result.min = Let::make(op->name, op->value, result.min);
}
if (expr_uses_var(result.max, op->name)) {
result.max = Let::make(op->name, op->value, result.max);
visit_with(
s,
[&](auto *self, const LetStmt *op) {
if (op->name == var) {
result = Interval::single_point(op->value);
found = true;
} else if (!found) {
op->body.accept(self);
if (found) {
if (expr_uses_var(result.min, op->name)) {
result.min = Let::make(op->name, op->value, result.min);
}
if (expr_uses_var(result.max, op->name)) {
result.max = Let::make(op->name, op->value, result.max);
}
}
}
}
}

void visit(const Block *op) override {
// We're most likely to find our var at the end of a
// block. The start of the block could be unrelated producers.
op->rest.accept(this);
if (!found) {
op->first.accept(this);
}
}

void visit(const For *op) override {
Interval in(op->min, op->max);

if (op->name == var) {
result = in;
found = true;
} else if (!found) {
op->body.accept(this);
if (found) {
Scope<Interval> scope;
scope.push(op->name, in);
if (expr_uses_var(result.min, op->name)) {
result.min = bounds_of_expr_in_scope(result.min, scope).min;
}
if (expr_uses_var(result.max, op->name)) {
result.max = bounds_of_expr_in_scope(result.max, scope).max;
}
},
[&](auto *self, const Block *op) {
// We're most likely to find our var at the end of a
// block. The start of the block could be unrelated producers.
op->rest.accept(self);
if (!found) {
op->first.accept(self);
}
}
}
};

Interval bounds_of_inner_var(const string &var, const Stmt &s) {
BoundsOfInnerVar b(var);
s.accept(&b);
return b.result;
},
[&](auto *self, const For *op) {
Interval in(op->min, op->max);
if (op->name == var) {
result = in;
found = true;
} else if (!found) {
op->body.accept(self);
if (found) {
Scope<Interval> scope;
scope.push(op->name, in);
if (expr_uses_var(result.min, op->name)) {
result.min = bounds_of_expr_in_scope(result.min, scope).min;
}
if (expr_uses_var(result.max, op->name)) {
result.max = bounds_of_expr_in_scope(result.max, scope).max;
}
}
} //
});
return result;
}

size_t find_fused_group_index(const Function &producing_func,
Expand Down Expand Up @@ -386,23 +370,17 @@ class BoundsInference : public IRMutator {
// don't care what sites are loaded, just what sites need
// to have the correct value in them. So remap all selects
// to if_then_elses to get tighter bounds.
class SelectToIfThenElse : public IRMutator {
using IRMutator::visit;
Expr visit(const Select *op) override {
for (auto &e : exprs) {
e.value = mutate_with(e.value, [](auto *self, const Select *op) {
if (is_pure(op->condition)) {
return Call::make(op->type, Call::if_then_else,
{mutate(op->condition),
mutate(op->true_value),
mutate(op->false_value)},
{self->mutate(op->condition),
self->mutate(op->true_value),
self->mutate(op->false_value)},
Call::PureIntrinsic);
} else {
return IRMutator::visit(op);
}
}
} select_to_if_then_else;

for (auto &e : exprs) {
e.value = select_to_if_then_else.mutate(e.value);
return self->visit_base(op);
});
}
}

Expand Down
50 changes: 18 additions & 32 deletions src/EarlyFree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,39 +109,28 @@ class FindLastUse : public IRVisitor {
}
};

class InjectMarker : public IRMutator {
public:
string func;
Stmt last_use;

private:
Stmt inject_marker(const Stmt &stmt, const string &func, const Stmt &last_use) {
bool injected = false;
return mutate_with(stmt, [&](auto *self, const Block *block) -> Stmt {
auto do_injection = [&](const Stmt &s) {
if (injected) {
return s;
}
if (s.same_as(last_use)) {
injected = true;
return Block::make(s, Free::make(func));
}
return self->mutate(s);
};

using IRMutator::visit;

Stmt inject_marker(Stmt s) {
if (injected) {
return s;
}
if (s.same_as(last_use)) {
injected = true;
return Block::make(s, Free::make(func));
} else {
return mutate(s);
}
}

Stmt visit(const Block *block) override {
Stmt new_rest = inject_marker(block->rest);
Stmt new_first = inject_marker(block->first);
Stmt new_rest = do_injection(block->rest);
Stmt new_first = do_injection(block->first);

if (new_first.same_as(block->first) &&
new_rest.same_as(block->rest)) {
if (new_first.same_as(block->first) && new_rest.same_as(block->rest)) {
return block;
} else {
return Block::make(new_first, new_rest);
}
}
return Block::make(new_first, new_rest);
});
};

class InjectEarlyFrees : public IRMutator {
Expand All @@ -156,10 +145,7 @@ class InjectEarlyFrees : public IRMutator {
stmt.accept(&last_use);

if (last_use.last_use.defined()) {
InjectMarker inject_marker;
inject_marker.func = alloc->name;
inject_marker.last_use = last_use.last_use;
stmt = inject_marker.mutate(stmt);
stmt = inject_marker(stmt, alloc->name, last_use.last_use);
} else {
stmt = Allocate::make(alloc->name, alloc->type, alloc->memory_type,
alloc->extents, alloc->condition,
Expand Down
51 changes: 23 additions & 28 deletions src/FindCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,30 @@ namespace Internal {

namespace {

/* Find all the internal halide calls in an expr */
class FindCalls : public IRVisitor {
public:
struct CallInfo {
std::map<std::string, Function> calls;
std::vector<Function> order;
};

using IRVisitor::visit;

void include_function(const Function &f) {
auto [it, inserted] = calls.emplace(f.name(), f);
if (inserted) {
order.push_back(f);
} else {
user_assert(it->second.same_as(f))
<< "Can't compile a pipeline using multiple functions with same name: "
<< f.name() << "\n";
}
}

void visit(const Call *call) override {
IRVisitor::visit(call);

/* Find all the internal halide calls in an expr */
CallInfo find_calls(const Function &f) {
CallInfo info;
visit_with(f, [&](auto *self, const Call *call) {
self->visit_base(call);
if (call->call_type == Call::Halide && call->func.defined()) {
Function f(call->func);
include_function(f);
auto [it, inserted] = info.calls.emplace(f.name(), f);
if (inserted) {
info.order.push_back(f);
} else {
user_assert(it->second.same_as(f))
<< "Can't compile a pipeline using multiple functions with same name: "
<< f.name() << "\n";
}
}
}
};
});
return info;
}

void populate_environment_helper(const Function &f,
std::map<std::string, Function> *env,
Expand All @@ -61,29 +57,28 @@ void populate_environment_helper(const Function &f,
}
};

FindCalls calls;
f.accept(&calls);
auto [f_calls, f_order] = find_calls(f);
if (f.has_extern_definition()) {
for (const ExternFuncArgument &arg : f.extern_arguments()) {
if (arg.is_func()) {
insert_func(Function{arg.func}, &calls.calls, &calls.order);
insert_func(Function{arg.func}, &f_calls, &f_order);
}
}
}

if (include_wrappers) {
for (const auto &it : f.schedule().wrappers()) {
insert_func(Function{it.second}, &calls.calls, &calls.order);
insert_func(Function{it.second}, &f_calls, &f_order);
}
}

if (!recursive) {
for (const Function &g : calls.order) {
for (const Function &g : f_order) {
insert_func(g, env, order);
}
} else {
insert_func(f, env, order);
for (const Function &g : calls.order) {
for (const Function &g : f_order) {
populate_environment_helper(g, env, order, recursive, include_wrappers);
}
}
Expand Down
Loading
Loading