This source file includes the following definitions:
- visit
- visit
- var
- expr_depends_on_var
- visit
- expand_expr
- is_dim_always_pure
- visit
- visit
- visit
- loop_min
- visit
- visit
- sliding_window
#include "SlidingWindow.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Scope.h"
#include "Debug.h"
#include "Substitute.h"
#include "IRPrinter.h"
#include "Simplify.h"
#include "Monotonic.h"
#include "Bounds.h"
namespace Halide {
namespace Internal {
using std::string;
using std::map;
namespace {
// IR visitor that records whether an expression refers to the variable
// named `var`. A Let that rebinds the same name shadows it, so that Let's
// body is not searched (its bound value still is).
class ExprDependsOnVar : public IRVisitor {
    using IRVisitor::visit;

    void visit(const Variable *op) {
        if (op->name == var) result = true;
    }

    void visit(const Let *op) {
        // The bound value is evaluated outside the Let's own binding, so a
        // reference to `var` there always counts.
        op->value.accept(this);
        // Inside the body, a Let of the same name refers to its own binding,
        // not to the variable being searched for.
        if (op->name != var) {
            op->body.accept(this);
        }
    }

public:
    // Set to true once a reference to `var` has been seen.
    bool result;
    // Name of the variable being searched for.
    string var;

    // explicit: prevents a plain string from implicitly converting into a
    // visitor object.
    explicit ExprDependsOnVar(string v) : result(false), var(v) {
    }
};
// Returns true if `e` mentions the variable named `v` anywhere other than
// inside the body of a Let that rebinds `v`.
bool expr_depends_on_var(Expr e, string v) {
    ExprDependsOnVar visitor(v);
    e.accept(&visitor);
    const bool found = visitor.result;
    return found;
}
// IR mutator that replaces each Variable bound in `scope` with its bound
// value. Only one level of substitution is performed: the substituted value
// is not itself re-mutated. The mutator stores a reference to the scope, so
// the scope must outlive the mutator.
class ExpandExpr : public IRMutator {
    using IRMutator::visit;
    const Scope<Expr> &scope;

    void visit(const Variable *var) {
        if (scope.contains(var->name)) {
            expr = scope.get(var->name);
            debug(3) << "Fully expanded " << var->name << " -> " << expr << "\n";
        } else {
            expr = var;
        }
    }

public:
    // explicit: the mutator keeps a reference to `s`. An implicit conversion
    // from a temporary Scope would leave that reference dangling.
    explicit ExpandExpr(const Scope<Expr> &s) : scope(s) {}
};
// Returns `e` with every variable bound in `scope` replaced by its value
// (one level of substitution), logging the result at debug level 3.
Expr expand_expr(Expr e, const Scope<Expr> &scope) {
    Expr expanded = ExpandExpr(scope).mutate(e);
    debug(3) << "Expanded " << e << " into " << expanded << "\n";
    return expanded;
}
}
// IR mutator that tries to apply the sliding window optimization to one
// Function (`func`) with respect to one loop (`loop_var`, whose minimum is
// `loop_min`).
//
// The optimization applies when exactly one dimension of the region required
// of func's last stage depends on loop_var, and that region moves
// monotonically. The bound lets for that dimension are then rewritten so
// that, after the iteration where loop_var <= loop_min, each iteration only
// computes values the previous iteration did not already produce.
class SlidingWindowOnFunctionAndLoop : public IRMutator {
    Function func;
    string loop_var;
    Expr loop_min;
    // Simplified, expanded values of the enclosing LetStmts (see visit(LetStmt)).
    Scope<Expr> scope;
    // New values for bound lets, keyed by let name. Each entry is filled in
    // while visiting the producer and consumed (and erased) when the matching
    // enclosing LetStmt is rebuilt after its body has been mutated.
    map<string, Expr> replacements;

    using IRMutator::visit;

    // Returns true if argument `dim_idx` of `def` is exactly the pure
    // variable named `dim`, both in `def` itself and in every one of its
    // specializations (checked recursively).
    bool is_dim_always_pure(const Definition &def, const string& dim, int dim_idx) {
        const Variable *var = def.args()[dim_idx].as<Variable>();
        if ((!var) || (var->name != dim)) {
            return false;
        }
        for (const auto &s : def.specializations()) {
            bool pure = is_dim_always_pure(s.definition, dim, dim_idx);
            if (!pure) {
                return false;
            }
        }
        return true;
    }

    void visit(const ProducerConsumer *op) {
        if (!op->is_producer || (op->name != func.name())) {
            IRMutator::visit(op);
        } else {
            stmt = op;

            // Find the single dimension of func whose required region depends
            // on loop_var. If none does, or more than one does, give up.
            string dim = "";
            int dim_idx = 0;
            Expr min_required, max_required;

            debug(3) << "Considering sliding " << func.name()
                     << " along loop variable " << loop_var << "\n"
                     << "Region provided:\n";

            // The last stage's bounds lets are named
            // <func>.s<number of updates>.<arg>.min / .max.
            string prefix = func.name() + ".s" + std::to_string(func.updates().size()) + ".";
            const std::vector<string> func_args = func.args();
            for (int i = 0; i < func.dimensions(); i++) {
                string var = prefix + func_args[i];
                internal_assert(scope.contains(var + ".min") && scope.contains(var + ".max"));
                Expr min_req = scope.get(var + ".min");
                Expr max_req = scope.get(var + ".max");
                min_req = expand_expr(min_req, scope);
                max_req = expand_expr(max_req, scope);

                debug(3) << func_args[i] << ":" << min_req << ", " << max_req << "\n";
                if (expr_depends_on_var(min_req, loop_var) ||
                    expr_depends_on_var(max_req, loop_var)) {
                    if (!dim.empty()) {
                        // A second dependent dimension: this case is not
                        // handled, so clear everything found so far.
                        dim = "";
                        min_required = Expr();
                        max_required = Expr();
                        break;
                    } else {
                        dim = func_args[i];
                        dim_idx = i;
                        min_required = min_req;
                        max_required = max_req;
                    }
                }
            }

            if (!min_required.defined()) {
                debug(3) << "Could not perform sliding window optimization of "
                         << func.name() << " over " << loop_var << " because either zero "
                         << "or many dimensions of the function depended on the loop var\n";
                return;
            }

            // Every update (and each of its specializations) must use the
            // pure variable `dim` at position `dim_idx`. Otherwise the
            // function scatters along that axis and the region cannot be slid.
            bool pure = true;
            for (const Definition &def : func.updates()) {
                pure = is_dim_always_pure(def, dim, dim_idx);
                if (!pure) {
                    break;
                }
            }
            if (!pure) {
                debug(3) << "Could not perform sliding window optimization of "
                         << func.name() << " over " << loop_var << " because the function "
                         << "scatters along the related axis.\n";
                return;
            }

            // Sliding up requires the min to be increasing or constant in
            // loop_var. Sliding down requires the max to be decreasing or
            // constant.
            Monotonic monotonic_min = is_monotonic(min_required, loop_var);
            Monotonic monotonic_max = is_monotonic(max_required, loop_var);
            bool can_slide_up = (monotonic_min == Monotonic::Increasing ||
                                 monotonic_min == Monotonic::Constant);
            bool can_slide_down = (monotonic_max == Monotonic::Decreasing ||
                                   monotonic_max == Monotonic::Constant);

            if (!can_slide_up && !can_slide_down) {
                debug(3) << "Not sliding " << func.name()
                         << " over dimension " << dim
                         << " along loop variable " << loop_var
                         << " because I couldn't prove it moved monotonically along that dimension\n"
                         << "Min is " << min_required << "\n"
                         << "Max is " << max_required << "\n";
                return;
            }

            debug(3) << "Sliding " << func.name()
                     << " over dimension " << dim
                     << " along loop variable " << loop_var << "\n";

            // Bounds required at loop_var - 1, shifted by one. These mark
            // where the current iteration should start (sliding up) or stop
            // (sliding down).
            Expr loop_var_expr = Variable::make(Int(32), loop_var);
            Expr prev_max_plus_one = substitute(loop_var, loop_var_expr - 1, max_required) + 1;
            Expr prev_min_minus_one = substitute(loop_var, loop_var_expr - 1, min_required) - 1;

            // If consecutive iterations provably never overlap, there is
            // nothing to reuse.
            if (can_prove(min_required >= prev_max_plus_one) ||
                can_prove(max_required <= prev_min_minus_one)) {
                debug(3) << "Not sliding " << func.name()
                         << " over dimension " << dim
                         << " along loop variable " << loop_var
                         << " there's no overlap in the region computed across iterations\n"
                         << "Min is " << min_required << "\n"
                         << "Max is " << max_required << "\n";
                return;
            }

            // When loop_var <= loop_min, compute the full required region.
            // Otherwise compute only the part beyond what the previous
            // iteration required.
            Expr new_min, new_max;
            if (can_slide_up) {
                new_min = select(loop_var_expr <= loop_min, min_required, likely(prev_max_plus_one));
                new_max = max_required;
            } else {
                new_min = min_required;
                new_max = select(loop_var_expr <= loop_min, max_required, likely(prev_min_minus_one));
            }

            debug(3) << "Sliding " << func.name() << ", " << dim << "\n"
                     << "Pushing min up from " << min_required << " to " << new_min << "\n"
                     << "Shrinking max from " << max_required << " to " << new_max << "\n";

            // Rewrite the last stage's bound on the side that slides.
            if (can_slide_up) {
                replacements[prefix + dim + ".min"] = new_min;
            } else {
                replacements[prefix + dim + ".max"] = new_max;
            }

            // Make every earlier stage use the last stage's bounds for this
            // dimension.
            for (size_t i = 0; i < func.updates().size(); i++) {
                string n = func.name() + ".s" + std::to_string(i) + "." + dim;
                replacements[n + ".min"] = Variable::make(Int(32), prefix + dim + ".min");
                replacements[n + ".max"] = Variable::make(Int(32), prefix + dim + ".max");
            }

            // With update stages, widen the last stage's bound on the sliding
            // side so it also covers the box the producer body provides
            // along `dim`.
            if (!func.updates().empty()) {
                Box b = box_provided(op->body, func.name());
                if (can_slide_up) {
                    string n = prefix + dim + ".min";
                    Expr var = Variable::make(Int(32), n);
                    stmt = LetStmt::make(n, min(var, b[dim_idx].min), stmt);
                } else {
                    string n = prefix + dim + ".max";
                    Expr var = Variable::make(Int(32), n);
                    stmt = LetStmt::make(n, max(var, b[dim_idx].max), stmt);
                }
            }
        }
    }

    // Descends into an inner loop only when its bounds do not vary with
    // loop_var. A loop of extent one is analyzed as a LetStmt that binds its
    // variable to the (expanded) loop min.
    void visit(const For *op) {
        Expr min = expand_expr(op->min, scope);
        Expr extent = expand_expr(op->extent, scope);
        if (is_one(extent)) {
            Stmt s = LetStmt::make(op->name, min, op->body);
            s = mutate(s);
            const LetStmt *l = s.as<LetStmt>();
            internal_assert(l);
            // Keep the original node when nothing changed, so callers'
            // same_as checks can avoid rebuilding the enclosing IR.
            if (l->body.same_as(op->body)) {
                stmt = op;
            } else {
                stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, l->body);
            }
        } else if (is_monotonic(min, loop_var) != Monotonic::Constant ||
                   is_monotonic(extent, loop_var) != Monotonic::Constant) {
            debug(3) << "Not entering loop over " << op->name
                     << " because the bounds depend on the var we're sliding over: "
                     << min << ", " << extent << "\n";
            stmt = op;
        } else {
            IRMutator::visit(op);
        }
    }

    // Records the simplified, expanded value of this let while its body is
    // mutated. Afterwards, substitutes any replacement value registered for
    // this let's name during that mutation.
    void visit(const LetStmt *op) {
        scope.push(op->name, simplify(expand_expr(op->value, scope)));
        Stmt new_body = mutate(op->body);

        Expr value = op->value;
        map<string, Expr>::iterator iter = replacements.find(op->name);
        if (iter != replacements.end()) {
            value = iter->second;
            replacements.erase(iter);
        }

        if (new_body.same_as(op->body) && value.same_as(op->value)) {
            stmt = op;
        } else {
            stmt = LetStmt::make(op->name, value, new_body);
        }

        scope.pop(op->name);
    }

public:
    SlidingWindowOnFunctionAndLoop(Function f, string v, Expr v_min) : func(f), loop_var(v), loop_min(v_min) {}
};
// IR mutator that, for one Function, visits every For loop and attempts the
// sliding window optimization of that Function over each serial or unrolled
// loop. A loop's body is mutated first, so inner loops are handled before
// the loops that contain them.
class SlidingWindowOnFunction : public IRMutator {
    Function func;

    using IRMutator::visit;

    void visit(const For *op) {
        debug(3) << " Doing sliding window analysis over loop: " << op->name << "\n";

        // Handle any loops nested inside this one first.
        Stmt new_body = mutate(op->body);

        // Only serial and unrolled loops, whose iterations execute in order,
        // are considered for sliding.
        if (op->for_type == ForType::Serial ||
            op->for_type == ForType::Unrolled) {
            new_body = SlidingWindowOnFunctionAndLoop(func, op->name, op->min).mutate(new_body);
        }

        if (new_body.same_as(op->body)) {
            stmt = op;
        } else {
            stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, new_body);
        }
    }

public:
    // explicit: prevents implicit conversion from a Function to this mutator.
    explicit SlidingWindowOnFunction(Function f) : func(f) {}
};
// IR mutator that, for each Realize whose name is found in `env`, runs the
// sliding window analysis for that Function over the realization's body.
// Realizations whose compute level equals their store level are skipped.
// The mutator stores a reference to `env`, which must outlive it.
class SlidingWindow : public IRMutator {
    const map<string, Function> &env;

    using IRMutator::visit;

    void visit(const Realize *op) {
        map<string, Function>::const_iterator iter = env.find(op->name);
        if (iter == env.end()) {
            IRMutator::visit(op);
            return;
        }

        const Schedule &sched = iter->second.schedule();
        if (sched.compute_level() == sched.store_level()) {
            IRMutator::visit(op);
            return;
        }

        Stmt new_body = op->body;
        debug(3) << "Doing sliding window analysis on realization of " << op->name << "\n";
        new_body = SlidingWindowOnFunction(iter->second).mutate(new_body);
        // Then handle any realizations nested inside this one.
        new_body = mutate(new_body);

        if (new_body.same_as(op->body)) {
            stmt = op;
        } else {
            stmt = Realize::make(op->name, op->types, op->bounds, op->condition, new_body);
        }
    }

public:
    // explicit: the mutator keeps a reference to `e`. An implicit conversion
    // from a temporary map would leave that reference dangling.
    explicit SlidingWindow(const map<string, Function> &e) : env(e) {}
};
// Entry point of the pass. Applies the sliding window optimization to the
// eligible realizations in `s`, looking up each realization's Function (and
// its schedule) by name in `env`.
Stmt sliding_window(Stmt s, const map<string, Function> &env) {
    SlidingWindow pass(env);
    return pass.mutate(s);
}
}
}