This source file includes the following definitions.
- extern_call_uses_buffer
- treat_selects_as_guards
- visit
- visit
- visit_let
- visit
- visit
- visit
- make_and
- make_or
- make_select
- make_not
- visit_conditional
- visit
- visit
- visit
- visit
- alloc_predicate
- memoize_call_uses_buffer
- visit
- visit
- in_vector_loop
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- result
- skip_stages
#include "SkipStages.h"
#include "CSE.h"
#include "Debug.h"
#include "ExprUsesVar.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IRPrinter.h"
#include "IROperator.h"
#include "Scope.h"
#include "Simplify.h"
#include "Substitute.h"
namespace Halide {
namespace Internal {
using std::string;
using std::vector;
using std::map;
namespace {
// Reports whether `op` is an extern call (Call::Extern or
// Call::ExternCPlusPlus) that receives, as a direct Variable argument, a name
// that begins with `func + "."` and ends with ".buffer". Calls whose own name
// starts with "halide_memoization" are never reported.
bool extern_call_uses_buffer(const Call *op, const std::string &func) {
    const bool is_extern = (op->call_type == Call::Extern ||
                            op->call_type == Call::ExternCPlusPlus);
    if (!is_extern || starts_with(op->name, "halide_memoization")) {
        return false;
    }

    const std::string prefix = func + ".";
    for (const auto &arg : op->args) {
        const Variable *var = arg.as<Variable>();
        if (!var) {
            // Only bare variables are inspected; any other argument is ignored.
            continue;
        }
        if (starts_with(var->name, prefix) && ends_with(var->name, ".buffer")) {
            return true;
        }
    }
    return false;
}
}
// IR visitor that builds, in the public member `predicate`, a boolean Expr
// describing when the visited IR refers to `buffer`. `predicate` starts out
// as const_false() and becomes const_true() at any Call that names `buffer`
// (or is an extern call taking one of its ".buffer" variables). Enclosing
// conditionals, lets and loops then refine or re-bind that expression on the
// way back up the tree.
//
// Constructor arguments:
//   b - name of the buffer/function to look for.
//   s - when true, Select nodes are analysed like IfThenElse (their condition
//       guards the branches); when false they are traversed like any other
//       expression.
class PredicateFinder : public IRVisitor {
public:
// Result of the analysis; read by the caller after accept() returns.
Expr predicate;
PredicateFinder(const string &b, bool s) : predicate(const_false()),
buffer(b),
varies(false),
treat_selects_as_guards(s) {}
private:
using IRVisitor::visit;
// Name being searched for.
string buffer;
// Sticky flag: set once the expression visited so far has touched a name in
// `varying` or a call to something in `in_pipeline`. Callers that need a
// "fresh" answer save it, clear it, visit, and then OR the old value back.
bool varies;
bool treat_selects_as_guards;
// Names whose value is considered non-constant in the current context
// (loop variables, lets with varying values, allocations).
Scope<int> varying;
// Names of the ProducerConsumer nodes currently enclosing the visit.
Scope<int> in_pipeline;
// A variable makes the enclosing expression "vary" iff it is in `varying`.
void visit(const Variable *op) {
bool this_varies = varying.contains(op->name);
varies |= this_varies;
}
// Loops: the loop variable is treated as varying unless the extent is one and
// the min did not vary. In the non-varying case the loop variable has the
// single value op->min, so a predicate that mentions it is wrapped in a Let
// binding it to op->min.
void visit(const For *op) {
op->min.accept(this);
// Note: `varies` is not cleared first, so this also reflects anything that
// already varied before the min was visited.
bool min_varies = varies;
op->extent.accept(this);
bool should_pop = false;
if (!is_one(op->extent) || min_varies) {
should_pop = true;
varying.push(op->name, 0);
}
op->body.accept(this);
if (should_pop) {
varying.pop(op->name);
} else if (expr_uses_var(predicate, op->name)) {
predicate = Let::make(op->name, op->min, predicate);
}
}
// Shared handling for Let and LetStmt. The bound name is added to `varying`
// for the duration of the body only if its value varies. Afterwards, if the
// predicate refers to the name, the binding is re-introduced around the
// predicate so the name stays defined outside the let.
template<typename T>
void visit_let(const std::string &name, Expr value, T body) {
// Clear `varies` so that value_varies reflects the value alone, then merge
// the previous state back in.
bool old_varies = varies;
varies = false;
value.accept(this);
bool value_varies = varies;
varies |= old_varies;
if (value_varies) {
varying.push(name, 0);
}
body.accept(this);
if (value_varies) {
varying.pop(name);
}
if (expr_uses_var(predicate, name)) {
predicate = Let::make(name, value, predicate);
}
}
void visit(const LetStmt *op) {
visit_let(op->name, op->value, op->body);
}
void visit(const Let *op) {
visit_let(op->name, op->value, op->body);
}
// Records op->name in `in_pipeline` while the node is visited. The body of
// the producer node for `buffer` itself is deliberately not traversed; every
// other producer body, and every non-producer node, is.
void visit(const ProducerConsumer *op) {
in_pipeline.push(op->name, 0);
if (op->is_producer) {
if (op->name != buffer) {
op->body.accept(this);
}
} else {
IRVisitor::visit(op);
}
in_pipeline.pop(op->name);
}
// a && b with constant operands folded and identical operands collapsed.
Expr make_and(Expr a, Expr b) {
if (is_zero(a) || is_one(b)) {
return a;
} else if (is_zero(b) || is_one(a)) {
return b;
} else if (equal(a, b)) {
return a;
} else {
return a && b;
}
}
// a || b with constant operands folded and identical operands collapsed.
Expr make_or(Expr a, Expr b) {
if (is_zero(a) || is_one(b)) {
return b;
} else if (is_zero(b) || is_one(a)) {
return a;
} else if (equal(a, b)) {
return a;
} else {
return a || b;
}
}
// select(a, b, c) on booleans, rewritten to and/or/not whenever the condition
// or one of the branches is a constant.
Expr make_select(Expr a, Expr b, Expr c) {
if (is_one(a)) {
return b;
} else if (is_zero(a)) {
return c;
} else if (is_one(b)) {
return make_or(a, c);
} else if (is_zero(b)) {
return make_and(make_not(a), c);
} else if (is_one(c)) {
return make_or(make_not(a), b);
} else if (is_zero(c)) {
return make_and(a, b);
} else {
return select(a, b, c);
}
}
// !a with constant operands folded.
Expr make_not(Expr a) {
if (is_one(a)) {
return make_zero(a.type());
} else if (is_zero(a)) {
return make_one(a.type());
} else {
return !a;
}
}
// Shared handling for IfThenElse and (optionally) Select.
//
// Each branch and the condition are analysed starting from a false predicate
// so their individual contributions can be told apart. The final predicate is
//   <condition's own predicate> || <predicate on entry> || <branch term>
// where the branch term is select(condition, true_pred, false_pred) when the
// condition does not vary, and the conservative true_pred || false_pred when
// it does (a varying condition cannot be referenced from outside).
template<typename T>
void visit_conditional(Expr condition, T true_case, T false_case) {
Expr old_predicate = predicate;
predicate = const_false();
true_case.accept(this);
Expr true_predicate = predicate;
predicate = const_false();
// The false branch may be undefined (e.g. an if without an else).
if (false_case.defined()) false_case.accept(this);
Expr false_predicate = predicate;
// `varies` is cleared only for the condition, so the test below is about
// the condition alone; the branches' contribution is kept in old_varies.
bool old_varies = varies;
predicate = const_false();
varies = false;
condition.accept(this);
predicate = make_or(predicate, old_predicate);
if (varies) {
predicate = make_or(predicate, make_or(true_predicate, false_predicate));
} else {
predicate = make_or(predicate, make_select(condition, true_predicate, false_predicate));
}
varies = varies || old_varies;
}
void visit(const Select *op) {
if (treat_selects_as_guards) {
visit_conditional(op->condition, op->true_value, op->false_value);
} else {
IRVisitor::visit(op);
}
}
void visit(const IfThenElse *op) {
visit_conditional(op->condition, op->then_case, op->else_case);
}
// A call to anything currently in `in_pipeline` counts as varying. After the
// arguments have been visited, a call that names `buffer` (directly or via an
// extern call's ".buffer" argument) forces the predicate to true, replacing
// whatever the arguments contributed.
void visit(const Call *op) {
varies |= in_pipeline.contains(op->name);
IRVisitor::visit(op);
if (op->name == buffer || extern_call_uses_buffer(op, buffer)) {
predicate = const_true();
}
}
// While an Allocate node is traversed, both the allocation's name and its
// "<name>.buffer" variable are treated as varying.
void visit(const Allocate *op) {
varying.push(op->name, 0);
varying.push(op->name + ".buffer", 0);
IRVisitor::visit(op);
varying.pop(op->name + ".buffer");
varying.pop(op->name);
}
};
class ProductionGuarder : public IRMutator {
public:
ProductionGuarder(const string &b, Expr compute_p, Expr alloc_p):
buffer(b), compute_predicate(compute_p), alloc_predicate(alloc_p) {}
private:
string buffer;
Expr compute_predicate;
Expr alloc_predicate;
using IRMutator::visit;
bool memoize_call_uses_buffer(const Call *op) {
internal_assert(op->call_type == Call::Extern);
internal_assert(starts_with(op->name, "halide_memoization"));
for (size_t i = 0; i < op->args.size(); i++) {
const Variable *var = op->args[i].as<Variable>();
if (var &&
starts_with(var->name, buffer + ".") &&
ends_with(var->name, ".buffer")) {
return true;
}
}
return false;
}
void visit(const Call *op) {
if ((op->name == "halide_memoization_cache_lookup") &&
memoize_call_uses_buffer(op)) {
expr = Call::make(op->type, Call::if_then_else,
{alloc_predicate, op, 0}, Call::PureIntrinsic);
} else if ((op->name == "halide_memoization_cache_store") &&
memoize_call_uses_buffer(op)) {
expr = Call::make(op->type, Call::if_then_else,
{compute_predicate, op, 0}, Call::PureIntrinsic);
} else {
IRMutator::visit(op);
}
}
void visit(const ProducerConsumer *op) {
IRMutator::visit(op);
if (op->is_producer) {
op = stmt.as<ProducerConsumer>();
internal_assert(op);
if (op->name == buffer) {
Stmt body = IfThenElse::make(compute_predicate, op->body);
stmt = ProducerConsumer::make(op->name, op->is_producer, body);
}
}
}
};
// Mutator that finds the Realize node for `func`, derives predicates for when
// `func` is actually used, and guards both the realization and its production
// with them. While walking, it tracks which names depend on a vectorized loop
// variable; a compute predicate that depends on such a name is replaced by
// "always true", which leaves the realization unguarded.
class StageSkipper : public IRMutator {
public:
    StageSkipper(const string &f) : func(f), in_vector_loop(false) {}

private:
    string func;

    using IRMutator::visit;

    // Vectorized loop variables plus LetStmt names whose values use them.
    Scope<int> vector_vars;
    bool in_vector_loop;

    void visit(const For *op) {
        const bool was_in_vector_loop = in_vector_loop;
        const bool vectorized = (op->for_type == ForType::Vectorized);

        if (vectorized) {
            vector_vars.push(op->name, 0);
            in_vector_loop = true;
        }

        IRMutator::visit(op);

        if (vectorized) {
            vector_vars.pop(op->name);
        }
        in_vector_loop = was_in_vector_loop;
    }

    void visit(const LetStmt *op) {
        // Inside a vectorized loop, a let whose value uses a vector variable
        // is itself tracked as a vector variable for the extent of the let.
        const bool tracks_vector_var =
            in_vector_loop && expr_uses_vars(op->value, vector_vars);

        if (tracks_vector_var) {
            vector_vars.push(op->name, 0);
        }

        IRMutator::visit(op);

        if (tracks_vector_var) {
            vector_vars.pop(op->name);
        }
    }

    void visit(const Realize *op) {
        if (op->name != func) {
            IRMutator::visit(op);
            return;
        }

        // Compute predicate: Select nodes are treated as guards.
        debug(3) << "Finding compute predicate for " << op->name << "\n";
        PredicateFinder find_compute(op->name, true);
        op->body.accept(&find_compute);

        debug(3) << "Simplifying compute predicate for " << op->name << ": " << find_compute.predicate << "\n";
        Expr compute_predicate = simplify(common_subexpression_elimination(find_compute.predicate));

        debug(3) << "Compute predicate for " << op->name << " : " << compute_predicate << "\n";

        // A predicate that depends on a vector variable is not used as a guard.
        if (expr_uses_vars(compute_predicate, vector_vars)) {
            compute_predicate = const_true();
        }

        if (is_one(compute_predicate)) {
            // Always computed: nothing to guard, fall back to plain mutation.
            IRMutator::visit(op);
            return;
        }

        // Allocation predicate: Select nodes are NOT treated as guards.
        debug(3) << "Finding allocate predicate for " << op->name << "\n";
        PredicateFinder find_alloc(op->name, false);
        op->body.accept(&find_alloc);

        debug(3) << "Simplifying allocate predicate for " << op->name << "\n";
        Expr alloc_predicate = simplify(common_subexpression_elimination(find_alloc.predicate));

        debug(3) << "Allocate predicate for " << op->name << " : " << alloc_predicate << "\n";

        ProductionGuarder g(op->name, compute_predicate, alloc_predicate);
        Stmt body = g.mutate(op->body);

        debug(3) << "Done guarding computation for " << op->name << "\n";

        stmt = Realize::make(op->name, op->types, op->bounds,
                             alloc_predicate, body);
    }
};
// Cheap pre-check run before the full predicate analysis. After visiting a
// statement, `result` is true iff there is at least one consumer node for
// `func` in which every reference to `func` (a direct call, or an extern call
// taking one of its ".buffer" variables) sits inside the branch of an
// IfThenElse or Select.
class MightBeSkippable : public IRVisitor {
    using IRVisitor::visit;

    void visit(const Call *op) {
        IRVisitor::visit(op);
        const bool refers_to_func =
            (op->name == func) || extern_call_uses_buffer(op, func);
        if (refers_to_func) {
            // A single unguarded use rules out skipping for this consumer.
            result = result && guarded;
        }
    }

    void visit(const IfThenElse *op) {
        // The condition itself is evaluated unconditionally.
        op->condition.accept(this);

        const bool outer_guarded = guarded;
        guarded = true;
        op->then_case.accept(this);
        if (op->else_case.defined()) {
            op->else_case.accept(this);
        }
        guarded = outer_guarded;
    }

    void visit(const Select *op) {
        // The condition itself is evaluated unconditionally.
        op->condition.accept(this);

        const bool outer_guarded = guarded;
        guarded = true;
        op->true_value.accept(this);
        op->false_value.accept(this);
        guarded = outer_guarded;
    }

    void visit(const Realize *op) {
        // Guards outside the realization of `func` do not count.
        if (op->name == func) {
            guarded = false;
        }
        IRVisitor::visit(op);
    }

    void visit(const ProducerConsumer *op) {
        const bool consumes_func = !op->is_producer && (op->name == func);
        if (!consumes_func) {
            IRVisitor::visit(op);
            return;
        }

        // Evaluate this consumer on its own, then merge with earlier ones:
        // one fully-guarded consumer is enough to make the answer true.
        const bool earlier_result = result;
        result = true;
        op->body.accept(this);
        result = result || earlier_result;
    }

    string func;
    bool guarded;

public:
    bool result;

    MightBeSkippable(string f) : func(f), guarded(false), result(false) {}
};
// Walks the stages named in `order` from the second-to-last down to the
// first. For each stage that MightBeSkippable reports as a candidate, the
// statement is rewritten by StageSkipper so that the stage is guarded.
//
// The final entry of `order` is never examined (presumably it is the pipeline
// output, which can never be skipped - confirm against the caller).
//
// stmt  - statement to transform.
// order - stage names; may be empty, in which case `stmt` is returned as is.
//
// Returns the (possibly rewritten) statement.
Stmt skip_stages(Stmt stmt, const vector<string> &order) {
    // order.size() is unsigned: without this check an empty `order` would
    // make `order.size() - 1` wrap around and the loop would index far out
    // of bounds.
    if (order.empty()) {
        return stmt;
    }

    for (size_t i = order.size() - 1; i > 0; i--) {
        const string &name = order[i - 1];
        debug(2) << "skip_stages checking " << name << "\n";
        MightBeSkippable check(name);
        stmt.accept(&check);
        if (check.result) {
            debug(2) << "skip_stages can skip " << name << "\n";
            StageSkipper skipper(name);
            stmt = skipper.mutate(stmt);
        }
    }
    return stmt;
}
}
}