This source file includes the following definitions:
- visit
- visit
- uses_extern_image
- visit
- visit
- visit
- visit
- visit
- split_tuples
#include "SplitTuples.h"
#include "Bounds.h"
#include "IRMutator.h"
namespace Halide {
namespace Internal {
using std::map;
using std::string;
using std::vector;
using std::pair;
using std::set;
namespace {
// IR visitor that records, for every Halide function called in the visited
// IR, which tuple value indices (Call::value_index) are referenced.
// SplitTuples uses the result to decide which tuple elements of a
// prefetched Func actually need their own Prefetch node.
class FindCallValueIndices : public IRVisitor {
public:
    // Function name -> set of value indices seen in Call::Halide nodes
    // whose `func` handle is defined.
    map<string, set<int>> func_value_indices;

    using IRVisitor::visit;

    void visit(const Call *call) {
        // Recurse first so calls nested inside the arguments are recorded too.
        IRVisitor::visit(call);
        if ((call->call_type == Call::Halide) && call->func.defined()) {
            func_value_indices[call->name].insert(call->value_index);
        }
    }
};
// IR visitor that sets `result` when the visited IR contains any Call whose
// call_type is Call::Image.
class UsesExternImage : public IRVisitor {
public:
    UsesExternImage() : result(false) {}

    // Becomes true once a Call::Image node has been encountered.
    bool result;

private:
    using IRVisitor::visit;

    void visit(const Call *c) {
        // Not an image call: keep walking into its arguments.
        if (c->call_type != Call::Image) {
            IRVisitor::visit(c);
            return;
        }
        // Image call found; its arguments are not inspected.
        result = true;
    }
};
// Returns true if the statement `s` contains at least one Call::Image node.
inline bool uses_extern_image(Stmt s) {
    UsesExternImage checker;
    s.accept(&checker);
    const bool found = checker.result;
    return found;
}
class SplitTuples : public IRMutator {
using IRMutator::visit;
map<string, set<int>> func_value_indices;
void visit(const Realize *op) {
realizations.push(op->name, 0);
if (op->types.size() > 1) {
Stmt body = mutate(op->body);
for (int i = (int)op->types.size() - 1; i >= 0; i--) {
body = Realize::make(op->name + "." + std::to_string(i), {op->types[i]}, op->bounds, op->condition, body);
}
stmt = body;
} else {
IRMutator::visit(op);
}
realizations.pop(op->name);
}
void visit(const For *op) {
map<string, set<int>> old_func_value_indices = func_value_indices;
FindCallValueIndices find;
op->body.accept(&find);
func_value_indices = find.func_value_indices;
IRMutator::visit(op);
func_value_indices = old_func_value_indices;
}
void visit(const Prefetch *op) {
if (!op->param.defined() && (op->types.size() > 1)) {
const auto &indices = func_value_indices.find(op->name);
internal_assert(indices != func_value_indices.end());
auto it = indices->second.begin();
internal_assert((*it) < (int)op->types.size());
stmt = Prefetch::make(op->name + "." + std::to_string(*it), {op->types[(*it)]}, op->bounds);
for (++it; it != indices->second.end(); ++it) {
internal_assert((*it) < (int)op->types.size());
stmt = Block::make(stmt, Prefetch::make(op->name + "." + std::to_string(*it), {op->types[(*it)]}, op->bounds));
}
} else {
IRMutator::visit(op);
}
}
void visit(const Call *op) {
if (op->call_type == Call::Halide) {
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
string name = op->name;
if (f.outputs() > 1) {
name += "." + std::to_string(op->value_index);
}
vector<Expr> args;
for (Expr e : op->args) {
args.push_back(mutate(e));
}
expr = Call::make(op->type, name, args, op->call_type, f.get_contents());
} else {
IRMutator::visit(op);
}
}
void visit(const Provide *op) {
if (op->values.size() == 1) {
IRMutator::visit(op);
return;
}
bool atomic = false;
if (!realizations.contains(op->name) &&
uses_extern_image(op)) {
atomic = true;
} else {
Box provided = box_provided(op, op->name);
Box required = box_required(op, op->name);
atomic = boxes_overlap(provided, required);
}
vector<Expr> args;
for (Expr e : op->args) {
args.push_back(mutate(e));
}
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
vector<Stmt> provides;
vector<pair<string, Expr>> lets;
for (size_t i = 0; i < op->values.size(); i++) {
string name = op->name + "." + std::to_string(i);
string var_name = name + ".value";
Expr val = mutate(op->values[i]);
if (!is_undef(val) && atomic) {
lets.push_back({ var_name, val });
val = Variable::make(val.type(), var_name);
}
provides.push_back(Provide::make(name, {val}, args));
}
Stmt result = Block::make(provides);
while (!lets.empty()) {
auto p = lets.back();
lets.pop_back();
result = LetStmt::make(p.first, p.second, result);
}
stmt = result;
}
const map<string, Function> &env;
Scope<int> realizations;
public:
SplitTuples(const map<string, Function> &e) : env(e) {}
};
}
// Rewrite `s` so that every tuple-valued Func from `env` is replaced by one
// single-valued Func per tuple element (see SplitTuples).
Stmt split_tuples(Stmt s, const map<string, Function> &env) {
    SplitTuples splitter(env);
    return splitter.mutate(s);
}
}
}