This source file includes the following definitions.
- should_extract
- cache
- mutate
- with_cache
- mutate
- visit
- visit
- visit
- include
- mutate
- mutate
- common_subexpression_elimination
- common_subexpression_elimination
- visit
- visit
- check
- ssa_block
- cse_test
#include <map>
#include "CSE.h"
#include "IRMutator.h"
#include "IREquality.h"
#include "IROperator.h"
#include "Scope.h"
#include "Simplify.h"
namespace Halide {
namespace Internal {
using std::vector;
using std::string;
using std::map;
using std::pair;
namespace {
// Returns true when `e` is a candidate for being pulled out into its own
// let binding, and false for expressions that should be left inline:
//   - constants and plain variables;
//   - broadcasts and casts whose wrapped value would itself be left inline;
//   - add/sub/mul/div where either operand is a constant;
//   - ramps with a constant stride.
// Every other kind of expression is a candidate.
bool should_extract(Expr e) {
    // Leaves: nothing to gain from naming these.
    if (is_const(e) || e.as<Variable>()) {
        return false;
    }

    // Wrappers are judged purely by what they wrap.
    if (const Broadcast *bcast = e.as<Broadcast>()) {
        return should_extract(bcast->value);
    }
    if (const Cast *cast = e.as<Cast>()) {
        return should_extract(cast->value);
    }

    // Binary arithmetic is only a candidate when neither side is constant.
    if (const Add *add = e.as<Add>()) {
        return !is_const(add->a) && !is_const(add->b);
    }
    if (const Sub *sub = e.as<Sub>()) {
        return !is_const(sub->a) && !is_const(sub->b);
    }
    if (const Mul *mul = e.as<Mul>()) {
        return !is_const(mul->a) && !is_const(mul->b);
    }
    if (const Div *div = e.as<Div>()) {
        return !is_const(div->a) && !is_const(div->b);
    }

    // A ramp is only a candidate when its stride is not constant.
    if (const Ramp *ramp = e.as<Ramp>()) {
        return !is_const(ramp->stride);
    }

    return true;
}
// Mutator that assigns a number to every distinct sub-expression it
// encounters and rewrites an expression so that equal sub-expressions are
// replaced by the single Expr stored for their number. Let nodes are
// removed along the way: uses of a let-bound name are replaced by the
// stored expression for the let's value.
class GVN : public IRMutator {
public:
// One record per number: the canonical expression and how many times it is
// used. use_count starts at 0 here and is not modified inside this class.
struct Entry {
Expr expr;
int use_count;
};
// Indexed by number; entries[n].expr is the canonical expression for n.
vector<Entry> entries;
// Map from an expression (compared through ExprWithCompareCache, sharing
// the `cache` member below) to its number.
typedef map<ExprWithCompareCache, int> CacheType;
CacheType numbering;
// Second map from expression to number, ordered by ExprCompare.
// NOTE(review): presumably a cheap identity-based comparison used as a fast
// path in front of `numbering` -- confirm against ExprCompare's definition.
map<Expr, int, ExprCompare> shallow_numbering;
// For each let name currently in scope, the number of the let's value.
Scope<int> let_substitutions;
// Number of the expression most recently returned by mutate(Expr).
// visit(const Let *) reads this immediately after mutating the let's value.
int number;
// Comparison cache handed to every ExprWithCompareCache built by with_cache().
IRCompareCache cache;
GVN() : number(0), cache(8) {}
// Statements are not supported; this always raises an internal error.
Stmt mutate(Stmt s) {
internal_error << "Can't call GVN on a Stmt: " << s << "\n";
return Stmt();
}
// Wraps `e` together with a pointer to this object's comparison cache so it
// can be used as a key into `numbering`.
ExprWithCompareCache with_cache(Expr e) {
return ExprWithCompareCache(e, &cache);
}
// Returns the canonical expression for `e`, creating a new entry if no
// equal expression has been numbered yet. On every return path `number`
// holds the number of the returned expression.
Expr mutate(Expr e) {
// Fast path: this Expr is already a key in shallow_numbering.
{
map<Expr, int, ExprCompare>::iterator iter = shallow_numbering.find(e);
if (iter != shallow_numbering.end()) {
number = iter->second;
internal_assert(entries[number].expr.type() == e.type());
return entries[number].expr;
}
}
// A variable bound by an enclosing let is replaced by the canonical
// expression of the let's value.
if (const Variable *var = e.as<Variable>()) {
if (let_substitutions.contains(var->name)) {
number = let_substitutions.get(var->name);
internal_assert(entries[number].expr.type() == e.type());
return entries[number].expr;
}
}
// Look `e` up in the main numbering, and remember the hit in
// shallow_numbering so the fast path above catches it next time.
CacheType::iterator iter = numbering.find(with_cache(e));
if (iter != numbering.end()) {
number = iter->second;
shallow_numbering[e] = number;
internal_assert(entries[number].expr.type() == e.type());
return entries[number].expr;
}
// Not seen before: rebuild `e` from the canonical forms of its children.
Expr old_e = e;
e = IRMutator::mutate(e);
// The rebuilt form may match an existing entry even though the original
// did not (for example after a let variable was substituted).
iter = numbering.find(with_cache(e));
if (iter != numbering.end()) {
number = iter->second;
shallow_numbering[old_e] = number;
internal_assert(entries[number].expr.type() == old_e.type());
return entries[number].expr;
}
// Genuinely new: append an entry and register the rebuilt expression in
// both maps. Children were mutated above, so their entries come earlier.
Entry entry = {e, 0};
number = (int)entries.size();
numbering[with_cache(e)] = number;
shallow_numbering[e] = number;
entries.push_back(entry);
internal_assert(e.type() == old_e.type());
return e;
}
using IRMutator::visit;
// Drops the let: the value is numbered, the name is bound to that number
// while the body is mutated, and the mutated body alone becomes the result.
void visit(const Let *let) {
Expr value = mutate(let->value);
// `number` was just set by the mutate() call on the value.
let_substitutions.push(let->name, number);
Expr body = mutate(let->body);
let_substitutions.pop(let->name);
expr = body;
}
// Mutates the index, and the predicate only when it is not trivially one.
// Reuses the original node if neither changed.
void visit(const Load *op) {
Expr predicate = op->predicate;
// NOTE(review): presumably skipped so that a constant-true predicate is
// not given a numbering entry -- confirm.
if (!is_one(predicate)) {
predicate = mutate(op->predicate);
}
Expr index = mutate(op->index);
if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
expr = op;
} else {
expr = Load::make(op->type, op->name, index, op->image, op->param, predicate);
}
}
// Same treatment as Load, applied to the value, index and predicate of a
// store. Reuses the original node if none of them changed.
void visit(const Store *op) {
Expr predicate = op->predicate;
if (!is_one(predicate)) {
predicate = mutate(op->predicate);
}
Expr value = mutate(op->value);
Expr index = mutate(op->index);
if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
stmt = op;
} else {
stmt = Store::make(op->name, value, index, op->param, predicate);
}
}
};
// Graph visitor that walks an expression and bumps the use_count of the
// matching GVN entry each time an extractable sub-expression is reached.
class ComputeUseCounts : public IRGraphVisitor {
    GVN &gvn;
public:
    ComputeUseCounts(GVN &g) : gvn(g) {}

    using IRGraphVisitor::include;
    using IRGraphVisitor::visit;

    // Called for every child expression encountered during the walk.
    void include(const Expr &e) {
        const bool extractable = should_extract(e);
        debug(4) << "Include: " << e << "; should extract: " << extractable << "\n";

        if (extractable) {
            // Count this use if the expression has a number in the GVN.
            auto found = gvn.shallow_numbering.find(e);
            if (found != gvn.shallow_numbering.end()) {
                gvn.entries[found->second].use_count++;
            }

            // Only descend into an extractable node the first time it is seen;
            // later visits still count as uses (above) but are not re-walked.
            if (visited.count(e.get())) {
                return;
            }
            visited.insert(e.get());
        }

        // Non-extractable nodes are always walked and never marked visited.
        e.accept(this);
    }
};
// Mutator that substitutes expressions according to a lookup table. The
// table also serves as a memo: each expression rewritten by the default
// mutator is recorded so a later lookup of it returns the stored result.
class Replacer : public IRMutator {
public:
    map<Expr, Expr, ExprCompare> replacements;

    Replacer(const map<Expr, Expr, ExprCompare> &r) : replacements(r) {}

    using IRMutator::mutate;

    Expr mutate(Expr e) {
        auto hit = replacements.find(e);
        if (hit == replacements.end()) {
            // No entry yet: rewrite the children, then remember the outcome.
            Expr rewritten = IRMutator::mutate(e);
            replacements[e] = rewritten;
            return rewritten;
        }
        return hit->second;
    }
};
// Mutator whose expression hook runs common_subexpression_elimination on
// each Expr handed to it; statement handling is inherited from IRMutator.
class CSEEveryExprInStmt : public IRMutator {
public:
    using IRMutator::mutate;

    Expr mutate(Expr e) {
        Expr result = common_subexpression_elimination(e);
        return result;
    }
};
}
// Rewrites `e` so that every extractable sub-expression used more than once
// is computed once in a let binding and referenced by a variable.
//
// Constants and plain variables are returned unchanged. Otherwise:
//   1. GVN canonicalises the expression and numbers its sub-expressions.
//   2. ComputeUseCounts records how often each numbered expression is used.
//   3. Each entry used more than once gets a fresh name from unique_name('t').
//   4. Replacer swaps those entries for variables, and the lets are wrapped
//      around the result.
Expr common_subexpression_elimination(Expr e) {
    if (is_const(e) || e.as<Variable>()) return e;

    debug(4) << "\n\n\nInput to letify " << e << "\n";

    GVN gvn;
    e = gvn.mutate(e);

    ComputeUseCounts count_uses(gvn);
    count_uses.include(e);

    debug(4) << "Canonical form without lets " << e << "\n";

    // Pick the entries that deserve a let, in numbering order.
    vector<pair<string, Expr>> lets;
    map<Expr, Expr, ExprCompare> replacements;
    for (size_t i = 0; i < gvn.entries.size(); i++) {
        const GVN::Entry &entry = gvn.entries[i];
        if (entry.use_count > 1) {
            string name = unique_name('t');
            lets.push_back({ name, entry.expr });
            // Route every reference to this expression to the new variable.
            replacements[entry.expr] = Variable::make(entry.expr.type(), name);
        }
        debug(4) << i << ": " << entry.expr << ", " << entry.use_count << "\n";
    }

    // Substitute the variables into the main expression.
    Replacer replacer(replacements);
    e = replacer.mutate(e);

    debug(4) << "With variables " << e << "\n";

    // Wrap the lets from last to first, so the first one collected ends up
    // outermost and is in scope for the definitions of all later ones.
    for (size_t i = lets.size(); i > 0; i--) {
        const pair<string, Expr> &let = lets[i-1];
        // The let's own value currently maps to its variable. Drop that
        // mapping so mutate() rewrites the value's children instead of
        // returning the variable itself.
        replacer.replacements.erase(let.second);
        Expr value = replacer.mutate(let.second);
        e = Let::make(let.first, value, e);
    }

    debug(4) << "With lets: " << e << "\n";

    return e;
}
// Statement overload: runs a CSEEveryExprInStmt mutator over `s`, which
// applies expression-level CSE to each Expr the mutator is handed.
Stmt common_subexpression_elimination(Stmt s) {
    CSEEveryExprInStmt per_expr_cse;
    return per_expr_cse.mutate(s);
}
namespace {
class NormalizeVarNames : public IRMutator {
int counter;
map<string, string> new_names;
using IRMutator::visit;
void visit(const Variable *var) {
map<string, string>::iterator iter = new_names.find(var->name);
if (iter == new_names.end()) {
expr = var;
} else {
expr = Variable::make(var->type, iter->second);
}
}
void visit(const Let *let) {
string new_name = "t" + std::to_string(counter++);
new_names[let->name] = new_name;
Expr value = mutate(let->value);
Expr body = mutate(let->body);
expr = Let::make(new_name, value, body);
}
public:
NormalizeVarNames() : counter(0) {}
};
// Test helper: runs CSE on `in`, normalises the generated let names, and
// raises an internal assertion failure if the result differs from `correct`.
void check(Expr in, Expr correct) {
    NormalizeVarNames normalizer;
    Expr result = normalizer.mutate(common_subexpression_elimination(in));

    internal_assert(equal(result, correct))
        << "Incorrect CSE:\n" << in
        << "\nbecame:\n" << result
        << "\ninstead of:\n" << correct << "\n";
}
// Test helper: builds a chain of nested lets from a list of expressions.
// The last element becomes the innermost body, and element k (for every k
// except the last) is bound to the name "t<k>", with t0 outermost.
//
// `exprs` must be non-empty. An empty list would read back() of an empty
// vector and underflow the loop bound, so it is rejected up front.
Expr ssa_block(vector<Expr> exprs) {
    internal_assert(!exprs.empty()) << "ssa_block requires at least one expression\n";

    Expr e = exprs.back();
    // Wrap from the second-to-last expression outwards.
    for (size_t i = exprs.size() - 1; i > 0; i--) {
        string name = "t" + std::to_string(i-1);
        e = Let::make(name, exprs[i-1], e);
    }
    return e;
}
}
// Self-test for common_subexpression_elimination. Each case builds an input
// expression and the expected let-form, then compares them through check(),
// which fails with an internal assertion on mismatch.
void cse_test() {
Expr x = Variable::make(Int(32), "x");
Expr y = Variable::make(Int(32), "y");
// t[i] / tf[i] are Int(32) / Float(32) variables named "t<i>", matching the
// names produced by ssa_block and NormalizeVarNames.
Expr t[32], tf[32];
for (int i = 0; i < 32; i++) {
t[i] = Variable::make(Int(32), "t" + std::to_string(i));
tf[i] = Variable::make(Float(32), "t" + std::to_string(i));
}
Expr e, correct;
// Already in the expected form: output must equal input.
e = ssa_block({sin(x), tf[0]*tf[0]});
check(e, e);
// A simple case with repeated sub-expressions and no lets in the input.
e = ((x*x + x)*(x*x + x)) + x*x;
e += e;
correct = ssa_block({x*x,
t[0] + x,
t[1] * t[1] + t[0],
t[2] + t[2]});
check(e, correct);
// Idempotence: running CSE on its own output (which contains lets) must
// give the same result.
check(correct, correct);
// Input containing redundant lets (t0 and t1 bind the same value).
e = ssa_block({x*x,
x*x,
t[0] / t[1],
t[1] / t[1],
t[2] % t[3],
(t[4] + x*x) + x*x});
correct = ssa_block({x*x,
t[0] / t[0],
(t[1] % t[1] + t[0]) + t[0]});
check(e, correct);
// Nested let blocks that share sub-expressions and reuse the same names
// (both e1 and e2 bind t0 and t1).
Expr e1 = ssa_block({x*x,
t[0] + x,
t[1] * t[1] * t[0]});
Expr e2 = ssa_block({x*x,
t[0] - x,
t[1] * t[1] * t[0]});
e = ssa_block({e1 + x*x,
e1 + e2,
t[0] + t[0] * t[1]});
correct = ssa_block({x*x,
t[0] + x,
t[1] * t[1] * t[0],
t[2] + t[0],
t[0] - x,
t[3] + t[3] * (t[2] + t[4] * t[4] * t[0])});
check(e, correct);
// A heavily self-referential expression built over 100 iterations. The
// result is not compared against anything; the call only has to complete.
// NOTE(review): presumably a scalability check -- confirm.
e = x;
for (int i = 0; i < 100; i++) {
e = e*e + e + i;
e = e*e - e * i;
}
Expr result = common_subexpression_elimination(e);
// Loads with and without a predicate sharing the same index: the index and
// the predicate are extracted, but the two loads stay distinct.
{
Expr pred = x*x + y*y > 0;
Expr index = select(x*x + y*y > 0, x*x + y*y + 2, x*x + y*y + 10);
Expr load = Load::make(Int(32), "buf", index, Buffer<>(), Parameter(), const_true());
Expr pred_load = Load::make(Int(32), "buf", index, Buffer<>(), Parameter(), pred);
e = select(x*y > 10, x*y + 2, x*y + 3 + load) + pred_load;
Expr t2 = Variable::make(Bool(), "t2");
Expr cse_load = Load::make(Int(32), "buf", t[3], Buffer<>(), Parameter(), const_true());
Expr cse_pred_load = Load::make(Int(32), "buf", t[3], Buffer<>(), Parameter(), t2);
correct = ssa_block({x*y,
x*x + y*y,
t[1] > 0,
select(t2, t[1] + 2, t[1] + 10),
select(t[0] > 10, t[0] + 2, t[0] + 3 + cse_load) + cse_pred_load});
check(e, correct);
}
// The same predicated load used twice: the load itself becomes a let, and
// its index (used only once) stays inline inside it.
{
Expr pred = x*x + y*y > 0;
Expr index = select(x*x + y*y > 0, x*x + y*y + 2, x*x + y*y + 10);
Expr load = Load::make(Int(32), "buf", index, Buffer<>(), Parameter(), const_true());
Expr pred_load = Load::make(Int(32), "buf", index, Buffer<>(), Parameter(), pred);
e = select(x*y > 10, x*y + 2, x*y + 3 + pred_load) + pred_load;
Expr t2 = Variable::make(Bool(), "t2");
// cse_load is constructed here but not referenced by `correct` below.
Expr cse_load = Load::make(Int(32), "buf", select(t2, t[1] + 2, t[1] + 10), Buffer<>(), Parameter(), const_true());
Expr cse_pred_load = Load::make(Int(32), "buf", select(t2, t[1] + 2, t[1] + 10), Buffer<>(), Parameter(), t2);
correct = ssa_block({x*y,
x*x + y*y,
t[1] > 0,
cse_pred_load,
select(t[0] > 10, t[0] + 2, t[0] + 3 + t[3]) + t[3]});
check(e, correct);
}
// Handles: the two float* reinterprets are identical and get shared through
// a let, while the int* reinterpret of the same zero stays separate.
{
Expr handle_a = reinterpret(type_of<int *>(), make_zero(UInt(64)));
Expr handle_b = reinterpret(type_of<float *>(), make_zero(UInt(64)));
Expr handle_c = reinterpret(type_of<float *>(), make_zero(UInt(64)));
e = Call::make(Int(32), "dummy", {handle_a, handle_b, handle_c}, Call::Extern);
Expr t0 = Variable::make(handle_b.type(), "t0");
correct = Let::make("t0", handle_b,
Call::make(Int(32), "dummy", {handle_a, t0, t0}, Call::Extern));
check(e, correct);
}
debug(0) << "common_subexpression_elimination test passed\n";
}
}
}