This source file includes the following definitions.
- visit
- visit
- visit
- get_new_name
- visit
- can_parallelize_rvar
#include "IR.h"
#include "ParallelRVar.h"
#include "IRMutator.h"
#include "Debug.h"
#include "Simplify.h"
#include "IROperator.h"
#include "Substitute.h"
#include "CSE.h"
#include "IREquality.h"
namespace Halide {
namespace Internal {
using std::string;
using std::vector;
using std::map;
namespace {
class FindLoads : public IRVisitor {
using IRVisitor::visit;
const string &func;
void visit(const Call *op) {
if (op->name == func && op->call_type == Call::Halide) {
loads.push_back(op->args);
}
IRVisitor::visit(op);
}
void visit(const Let *op) {
IRVisitor::visit(op);
for (size_t i = 0; i < loads.size(); i++) {
for (size_t j = 0; j < loads[i].size(); j++) {
loads[i][j] = substitute(op->name, op->value, loads[i][j]);
}
}
}
public:
FindLoads(const string &f) : func(f) {}
vector<vector<Expr>> loads;
};
/** Mutator that gives every variable not backed by a param or an image
 * a new name (the old name with "$_" appended). The mapping is
 * remembered, so the same original name always maps to the same new
 * name for the lifetime of this object. */
class RenameFreeVars : public IRMutator {
    using IRMutator::visit;

    // Original name -> replacement name, filled in lazily.
    map<string, string> new_names;

    void visit(const Variable *op) {
        if (op->param.defined() || op->image.defined()) {
            // Variables tied to a param or image are left untouched.
            expr = op;
        } else {
            expr = Variable::make(op->type, get_new_name(op->name));
        }
    }

public:
    /** Return the replacement name for s, creating and recording it on
     * first request. */
    string get_new_name(const string &s) {
        auto entry = new_names.find(s);
        if (entry == new_names.end()) {
            entry = new_names.emplace(s, s + "$_").first;
        }
        return entry->second;
    }
};
class SubstituteInBooleanLets : public IRMutator {
using IRMutator::visit;
void visit(const Let *op) {
if (op->value.type() == Bool()) {
expr = substitute(op->name, mutate(op->value), mutate(op->body));
} else {
IRMutator::visit(op);
}
}
};
}
bool can_parallelize_rvar(const string &v,
const string &f,
const Definition &r) {
const vector<Expr> &values = r.values();
const vector<Expr> &args = r.args();
const vector<ReductionVariable> &rvars = r.schedule().rvars();
FindLoads find(f);
for (size_t i = 0; i < values.size(); i++) {
values[i].accept(&find);
}
RenameFreeVars renamer;
vector<Expr> other_store(args.size());
for (size_t i = 0; i < args.size(); i++) {
other_store[i] = renamer.mutate(args[i]);
}
Expr distinct_v = (Variable::make(Int(32), v) !=
Variable::make(Int(32), renamer.get_new_name(v)));
Expr hazard = const_true();
for (size_t i = 0; i < args.size(); i++) {
hazard = hazard && (distinct_v && (args[i] == other_store[i]));
}
for (size_t i = 0; i < find.loads.size(); i++) {
internal_assert(find.loads[i].size() == other_store.size());
Expr check = const_true();
for (size_t j = 0; j < find.loads[i].size(); j++) {
check = check && (distinct_v && (find.loads[i][j] == other_store[j]));
}
hazard = hazard || check;
}
Scope<Interval> bounds;
for (const auto &rv : rvars) {
Interval in = Interval(rv.min, simplify(rv.min + rv.extent - 1));
bounds.push(rv.var, in);
bounds.push(renamer.get_new_name(rv.var), in);
}
Expr pred = simplify(r.predicate());
if (pred.defined() || !equal(const_true(), pred)) {
Expr this_pred = pred;
Expr other_pred = renamer.mutate(pred);
debug(3) << "......this thread predicate: " << this_pred << "\n";
debug(3) << "......other thread predicate: " << other_pred << "\n";
hazard = hazard && this_pred && other_pred;
}
debug(3) << "Attempting to falsify: " << hazard << "\n";
hazard = common_subexpression_elimination(hazard);
hazard = SubstituteInBooleanLets().mutate(hazard);
hazard = simplify(hazard, false, bounds);
debug(3) << "Simplified to: " << hazard << "\n";
while (const Let *l = hazard.as<Let>()) {
hazard = l->body;
}
return is_zero(hazard);
}
}
}