This source file includes the following definitions.
- is_linear
- visit
- block_to_vector
- block_to_vector
- scratch_index
- visit
- step_forwards
- visit
- visit
- visit
- lift_carried_values_out_of_stmt
- visit
- visit
- max_carried_values
- visit
- visit
- loop_carry
#include "LoopCarry.h"
#include "IRMutator.h"
#include "Substitute.h"
#include "IROperator.h"
#include "Simplify.h"
#include "IREquality.h"
#include "ExprUsesVar.h"
#include "CSE.h"
#include <algorithm>
namespace Halide {
namespace Internal {
using std::vector;
using std::string;
using std::pair;
using std::set;
using std::map;
namespace {
/** Determine how a scalar Int(32) expression changes per loop step.
 *
 * `linear` maps variable names to their per-step change; an undefined
 * entry marks a variable that is not linear. Variables absent from the
 * scope, and integer constants, are treated as not changing (step zero).
 *
 * Returns the per-step change of `e` when it is built from linear terms
 * via +, -, and multiplication by a non-varying factor. Returns an
 * undefined Expr when `e` is not Int(32) or its change cannot be
 * expressed this way. */
Expr is_linear(Expr e, const Scope<Expr> &linear) {
    if (e.type() != Int(32)) {
        return Expr();
    }

    if (const Variable *var = e.as<Variable>()) {
        // Unknown variables are invariant; known ones carry their recorded step.
        return linear.contains(var->name) ? linear.get(var->name) : make_zero(var->type);
    }

    if (const IntImm *imm = e.as<IntImm>()) {
        return make_zero(imm->type);
    }

    if (const Add *op = e.as<Add>()) {
        Expr step_a = is_linear(op->a, linear);
        Expr step_b = is_linear(op->b, linear);
        if (is_zero(step_b)) {
            return step_a;
        }
        if (is_zero(step_a)) {
            return step_b;
        }
        if (!step_a.defined() || !step_b.defined()) {
            return Expr();
        }
        return step_a + step_b;
    }

    if (const Sub *op = e.as<Sub>()) {
        Expr step_a = is_linear(op->a, linear);
        Expr step_b = is_linear(op->b, linear);
        if (is_zero(step_b)) {
            return step_a;
        }
        if (!step_a.defined() || !step_b.defined()) {
            return Expr();
        }
        return step_a - step_b;
    }

    if (const Mul *op = e.as<Mul>()) {
        Expr step_a = is_linear(op->a, linear);
        Expr step_b = is_linear(op->b, linear);
        const bool a_fixed = is_zero(step_a);
        const bool b_fixed = is_zero(step_b);
        if (a_fixed && b_fixed) {
            return step_a;
        }
        // A product is only linear when one factor does not change.
        if (a_fixed && step_b.defined()) {
            return op->a * step_b;
        }
        if (step_a.defined() && b_fixed) {
            return step_a * op->b;
        }
        return Expr();
    }

    // NOTE(review): Ramp and Broadcast nodes are vector-typed, so the scalar
    // Int(32) check at the top appears to reject them before this point.
    // Confirm whether these two branches are reachable.
    if (const Ramp *op = e.as<Ramp>()) {
        Expr step_base = is_linear(op->base, linear);
        Expr step_stride = is_linear(op->stride, linear);
        return is_zero(step_stride) ? step_base : Expr();
    }

    if (const Broadcast *op = e.as<Broadcast>()) {
        return is_linear(op->value, linear);
    }

    return Expr();
}
class FindLoads : public IRGraphVisitor {
using IRVisitor::visit;
set<const Load *> found;
void visit(const Load *op) {
if (found.count(op) == 0) {
found.insert(op);
result.push_back(op);
}
}
public:
vector<const Load *> result;
};
/** Flatten a (possibly nested) Block into its non-Block leaf statements,
 * appending them to `v` in execution order. A non-Block statement is
 * appended as-is. */
void block_to_vector(Stmt s, vector<Stmt> &v) {
    // Walk the `rest` spine iteratively; `first` may itself be a Block, so it
    // is still flattened recursively.
    while (const Block *block = s.as<Block>()) {
        block_to_vector(block->first, v);
        // Take our own reference to `rest` before releasing the Block node.
        Stmt remainder = block->rest;
        s = remainder;
    }
    v.push_back(s);
}
/** Convenience overload: return the flattened leaf statements of `s`
 * (see the two-argument overload) as a new vector. */
vector<Stmt> block_to_vector(Stmt s) {
    vector<Stmt> flattened;
    block_to_vector(s, flattened);
    return flattened;
}
/** Index of slot `i` in a scratch buffer holding values of type `t`.
 * Scalar values occupy one element per slot, so the index is `i`. Vector
 * values occupy `t.lanes()` consecutive elements, so the index is a dense
 * ramp starting at `i * lanes`. */
Expr scratch_index(int i, Type t) {
    if (!t.is_scalar()) {
        const int lanes = t.lanes();
        return Ramp::make(i * lanes, 1, lanes);
    }
    return i;
}
/** Rewrites an expression to its value one loop step later by replacing
 * each variable `v` that has a recorded step `s` with `v + s`. If any
 * referenced variable is in the scope with an undefined step (not linear),
 * `success` is cleared. */
class StepForwards : public IRGraphMutator {
    // Per-variable step sizes; owned by the caller.
    const Scope<Expr> &linear;

    using IRGraphMutator::visit;

    void visit(const Variable *op) {
        // By default the variable is left as it is.
        expr = op;
        if (!linear.contains(op->name)) {
            return;
        }
        Expr step = linear.get(op->name);
        if (!step.defined()) {
            // Not linear: the stepped expression cannot be formed.
            success = false;
        } else if (!is_zero(step)) {
            expr = Expr(op) + step;
        }
    }

public:
    // False if an unsteppable variable was encountered during mutation.
    bool success = true;

    StepForwards(const Scope<Expr> &s) : linear(s) {}
};
/** Return `e` advanced by one loop step (see StepForwards), CSE'd,
 * simplified, and with all lets substituted back in. Returns an undefined
 * Expr if `e` references a variable that is not linear. */
Expr step_forwards(Expr e, const Scope<Expr> &linear) {
    StepForwards stepper(linear);
    Expr stepped = stepper.mutate(e);
    if (!stepper.success) {
        return Expr();
    }
    stepped = common_subexpression_elimination(stepped);
    stepped = simplify(stepped);
    return substitute_in_all_lets(stepped);
}
/** Runs over the body of one serial loop. Where a load's value was already
 * loaded (by another load of the same buffer) on the previous iteration, the
 * load is replaced with a read from a small scratch buffer that is carried
 * across iterations. Nested For and IfThenElse nodes are left untouched.
 * Each scratch buffer created is recorded in `allocs`, together with the
 * stores that initialize it, so the caller can allocate and initialize it
 * outside the loop. */
class LoopCarryOverLoop : public IRMutator {
    // Per-step change of each variable known to vary linearly with the loop;
    // an undefined entry marks a variable that is not linear.
    Scope<Expr> linear;

    // LetStmts enclosing the current statement, outermost first. Used to
    // re-wrap the initial scratch stores that are hoisted out of the loop.
    vector<pair<string, Expr>> containing_lets;

    // Names whose consume side encloses this loop; loads from them are
    // treated as safe to carry.
    const Scope<int> &in_consume;

    // Budget on the total number of carried slots across all chains in one
    // statement group. A budget below two carries nothing.
    int max_carried_values;

    using IRMutator::visit;

    void visit(const LetStmt *op) {
        Expr value = mutate(op->value);
        Expr step = is_linear(value, linear);
        linear.push(op->name, step);
        containing_lets.push_back({ op->name, value });

        Stmt body = mutate(op->body);

        if (value.same_as(op->value) &&
            body.same_as(op->body)) {
            stmt = op;
        } else {
            stmt = LetStmt::make(op->name, value, body);
        }

        containing_lets.pop_back();
        linear.pop(op->name);
    }

    void visit(const Store *op) {
        stmt = lift_carried_values_out_of_stmt(op);
    }

    void visit(const Block *op) {
        // Runs of consecutive Stores are processed together, so that loads
        // shared between them can be carried as a group.
        vector<Stmt> v = block_to_vector(op);
        vector<Stmt> stores;
        vector<Stmt> result;
        for (size_t i = 0; i < v.size(); i++) {
            if (v[i].as<Store>()) {
                stores.push_back(v[i]);
            } else {
                if (!stores.empty()) {
                    result.push_back(lift_carried_values_out_of_stmt(Block::make(stores)));
                    stores.clear();
                }
                result.push_back(mutate(v[i]));
            }
        }
        if (!stores.empty()) {
            result.push_back(lift_carried_values_out_of_stmt(Block::make(stores)));
        }
        stmt = Block::make(result);
    }

    Stmt lift_carried_values_out_of_stmt(Stmt orig_stmt) {
        debug(4) << "About to lift carried values out of stmt: " << orig_stmt << "\n";

        // Inline all lets so that loads and their indices can be compared
        // and substituted without reference to enclosing bindings.
        Stmt graph_stmt = substitute_in_all_lets(orig_stmt);

        FindLoads find_loads;
        graph_stmt.accept(&find_loads);

        debug(4) << "Found " << find_loads.result.size() << " loads\n";

        // Group the safe loads into classes of graph-equal loads.
        vector<vector<const Load *>> loads;
        for (const Load *load : find_loads.result) {
            bool safe = (load->image.defined() ||
                         load->param.defined() ||
                         in_consume.contains(load->name));
            if (!safe) continue;

            bool represented = false;
            for (vector<const Load *> &v : loads) {
                if (graph_equal(Expr(load), Expr(v[0]))) {
                    v.push_back(load);
                    represented = true;
                    // Classes are disjoint, so stop at the first match.
                    break;
                }
            }
            if (!represented) {
                loads.push_back({load});
            }
        }

        // For each class, record the index and predicate, plus the same
        // expressions advanced by one loop step (undefined if not steppable).
        vector<Expr> indices, next_indices, predicates, next_predicates;
        for (const vector<const Load *> &v: loads) {
            indices.push_back(v[0]->index);
            next_indices.push_back(step_forwards(v[0]->index, linear));
            predicates.push_back(v[0]->predicate);
            next_predicates.push_back(step_forwards(v[0]->predicate, linear));
        }

        // A link {j, i} means load i on this iteration reads what load j
        // reads on the next iteration.
        vector<vector<int>> chains;
        for (int i = 0; i < (int)indices.size(); i++) {
            for (int j = 0; j < (int)indices.size(); j++) {
                if (i == j) continue;
                if (loads[i][0]->name == loads[j][0]->name &&
                    next_indices[j].defined() &&
                    graph_equal(indices[i], next_indices[j]) &&
                    next_predicates[j].defined() &&
                    graph_equal(predicates[i], next_predicates[j])) {
                    chains.push_back({j, i});
                    debug(3) << "Found carried value:\n"
                             << i << ": -> " << Expr(loads[i][0]) << "\n"
                             << j << ": -> " << Expr(loads[j][0]) << "\n";
                }
            }
        }

        if (chains.empty()) {
            return orig_stmt;
        }

        // Join links end-to-front into maximal chains.
        bool done = false;
        while (!done) {
            done = true;
            for (size_t i = 0; i < chains.size(); i++) {
                if (chains[i].empty()) continue;
                for (size_t j = 0; j < chains.size(); j++) {
                    // Never merge a chain into itself: vector::insert with a
                    // range taken from the destination vector is undefined.
                    if (i == j || chains[j].empty()) continue;
                    if (chains[i].back() == chains[j].front()) {
                        chains[i].insert(chains[i].end(), chains[j].begin()+1, chains[j].end());
                        chains[j].clear();
                        done = false;
                    }
                }
            }
            // Drop chains emptied by merging.
            for (size_t i = 0; i < chains.size(); i++) {
                while (i < chains.size() && chains[i].empty()) {
                    chains[i].swap(chains.back());
                    chains.pop_back();
                }
            }
        }

        // Longest chains first, so they get priority under the budget.
        std::sort(chains.begin(), chains.end(),
                  [](const vector<int> &c1, const vector<int> &c2){return c1.size() > c2.size();});

        for (const vector<int> &c : chains) {
            debug(3) << "Found chain of carried values:\n";
            for (int i : c) {
                debug(3) << i << ": <- " << indices[i] << "\n";
            }
        }

        // Keep chains while they fit in max_carried_values slots. The chain
        // that overflows is truncated only if at least two links remain,
        // because a one-element chain carries nothing. The arithmetic is
        // signed so that a budget of zero (or less) cannot wrap around.
        vector<vector<int>> trimmed;
        int budget = max_carried_values;
        for (const vector<int> &c : chains) {
            const int len = (int)c.size();
            if (len > budget) {
                if (budget >= 2) {
                    trimmed.emplace_back(c.begin(), c.begin() + budget);
                }
                break;
            }
            trimmed.push_back(c);
            budget -= len;
        }
        chains.swap(trimmed);

        if (chains.empty()) {
            // Nothing fits in the budget; leave the statement untouched rather
            // than building Blocks out of empty store lists below.
            return orig_stmt;
        }

        vector<Stmt> not_first_iteration_scratch_stores;
        vector<Stmt> scratch_shuffles;
        Stmt core = graph_stmt;

        for (const vector<int> &c : chains) {
            string scratch = unique_name('c');
            vector<Expr> initial_scratch_values;

            for (size_t i = 0; i < c.size(); i++) {
                const Load *orig_load = loads[c[i]][0];
                Expr scratch_idx = scratch_index(i, orig_load->type);
                Expr load_from_scratch = Load::make(orig_load->type, scratch, scratch_idx,
                                                    Buffer<>(), Parameter(), const_true(orig_load->type.lanes()));

                // Every load in this class now reads slot i instead.
                for (const Load *l : loads[c[i]]) {
                    core = graph_substitute(l, load_from_scratch, core);
                }

                if (i == c.size() - 1) {
                    // The last slot is loaded fresh every iteration.
                    Stmt store_to_scratch = Store::make(scratch, orig_load, scratch_idx,
                                                        Parameter(), const_true(orig_load->type.lanes()));
                    not_first_iteration_scratch_stores.push_back(store_to_scratch);
                } else {
                    // Earlier slots are seeded once, before the loop.
                    initial_scratch_values.push_back(orig_load);
                }

                if (i > 0) {
                    // After the body, shift slot i down to slot i-1 for the
                    // next iteration.
                    Stmt shuffle = Store::make(scratch, load_from_scratch,
                                               scratch_index(i-1, orig_load->type),
                                               Parameter(), const_true(orig_load->type.lanes()));
                    scratch_shuffles.push_back(shuffle);
                }
            }

            // CSE the initial values together by packing them into a single
            // call, then peel the resulting lets back off.
            vector<pair<string, Expr>> initial_lets;
            Expr call = Call::make(Int(32), unique_name('b'), initial_scratch_values, Call::PureIntrinsic);
            call = simplify(common_subexpression_elimination(call));
            while (const Let *l = call.as<Let>()) {
                initial_lets.push_back({ l->name, l->value });
                call = l->body;
            }
            internal_assert(call.as<Call>());
            initial_scratch_values = call.as<Call>()->args;

            vector<Stmt> initial_scratch_stores;
            for (size_t i = 0; i < c.size() - 1; i++) {
                Expr scratch_idx = scratch_index(i, initial_scratch_values[i].type());
                Stmt store_to_scratch = Store::make(scratch, initial_scratch_values[i],
                                                    scratch_idx, Parameter(),
                                                    const_true(scratch_idx.type().lanes()));
                initial_scratch_stores.push_back(store_to_scratch);
            }

            Stmt initial_stores = Block::make(initial_scratch_stores);

            // Re-wrap the CSE lets, innermost last.
            for (size_t i = initial_lets.size(); i > 0; i--) {
                auto l = initial_lets[i-1];
                initial_stores = LetStmt::make(l.first, l.second, initial_stores);
            }

            // Re-wrap any enclosing loop-body lets the initial stores need.
            // Going from innermost to outermost also picks up lets that are
            // only used by other lets.
            for (size_t i = containing_lets.size(); i > 0; i--) {
                auto l = containing_lets[i-1];
                if (stmt_uses_var(initial_stores, l.first)) {
                    initial_stores = LetStmt::make(l.first, l.second, initial_stores);
                }
            }

            allocs.push_back({scratch,
                              loads[c.front()][0]->type.element_of(),
                              (int)c.size() * loads[c.front()][0]->type.lanes(),
                              initial_stores});
        }

        // Per iteration: refresh the last slots, run the rewritten body, then
        // shift the slots down.
        Stmt s = Block::make(not_first_iteration_scratch_stores);
        s = Block::make(s, core);
        s = Block::make(s, Block::make(scratch_shuffles));
        s = common_subexpression_elimination(s);

        return s;
    }

    // Do not look inside nested loops or conditionals.
    void visit(const For *op) {
        stmt = op;
    }

    void visit(const IfThenElse *op) {
        stmt = op;
    }

public:
    /** `var` is the loop variable (step 1). `s` lists names whose loads are
     * safe to carry. `max_carried_values` bounds the scratch slots used per
     * statement group. */
    LoopCarryOverLoop(const string &var, const Scope<int> &s, int max_carried_values)
        : in_consume(s), max_carried_values(max_carried_values) {
        linear.push(var, 1);
    }

    /** A scratch buffer to allocate around the loop, and the stores that
     * seed it before the first iteration. */
    struct ScratchAllocation {
        string name;
        Type type;
        int size;
        Stmt initial_stores;
    };

    vector<ScratchAllocation> allocs;
};
class LoopCarry : public IRMutator {
using IRMutator::visit;
int max_carried_values;
Scope<int> in_consume;
void visit(const ProducerConsumer *op) {
if (op->is_producer) {
IRMutator::visit(op);
} else {
in_consume.push(op->name, 0);
Stmt body = mutate(op->body);
in_consume.pop(op->name);
stmt = ProducerConsumer::make(op->name, op->is_producer, body);
}
}
void visit(const For *op) {
if (op->for_type == ForType::Serial && !is_one(op->extent)) {
Stmt body = mutate(op->body);
LoopCarryOverLoop carry(op->name, in_consume, max_carried_values);
body = carry.mutate(body);
if (body.same_as(op->body)) {
stmt = op;
} else {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
}
for (const auto &alloc : carry.allocs) {
stmt = Block::make(substitute(op->name, op->min, alloc.initial_stores), stmt);
stmt = Allocate::make(alloc.name, alloc.type, {alloc.size}, const_true(), stmt);
}
if (!carry.allocs.empty()) {
stmt = IfThenElse::make(op->extent > 0, stmt);
}
} else {
IRMutator::visit(op);
}
}
public:
LoopCarry(int max_carried_values) : max_carried_values(max_carried_values) {}
};
}
/** Apply the loop-carry optimization to every eligible serial loop in `s`,
 * carrying at most `max_carried_values` values per statement group. */
Stmt loop_carry(Stmt s, int max_carried_values) {
    LoopCarry mutator(max_carried_values);
    return mutator.mutate(s);
}
}
}