This source file includes the following definitions:
- make_checked_call
- make_wrapper
- visit
- wrap_legacy_extern_stages
- add_legacy_wrapper
#include "WrapExternStages.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Argument.h"
#include <set>
namespace Halide {
namespace Internal {
using std::set;
using std::string;
using std::vector;
using std::pair;
namespace {
// Evaluate an Int(32)-typed call once, bind its result to a fresh variable,
// and assert that the result is zero. A nonzero result becomes the assertion
// message, so the status code is what gets reported on failure.
Stmt make_checked_call(Expr call) {
    internal_assert(call.type() == Int(32));
    const string status_name = unique_name('t');
    const Expr status = Variable::make(Int(32), status_name);
    return LetStmt::make(status_name, call, AssertStmt::make(status == 0, status));
}
class WrapExternStages : public IRMutator {
using IRMutator::visit;
string make_wrapper(const Call *op) {
string wrapper_name = replace_all(prefix + op->name, ":", "_");
if (done.count(op->name)) {
return wrapper_name;
}
done.insert(op->name);
Function f(op->func);
vector<Argument> args;
for (ExternFuncArgument arg : f.extern_arguments()) {
if (arg.arg_type == ExternFuncArgument::FuncArg) {
Function f_arg(arg.func);
for (auto b : f_arg.output_buffers()) {
args.emplace_back(b.name(), Argument::InputBuffer,
b.type(), b.dimensions());
}
} else if (arg.arg_type == ExternFuncArgument::BufferArg) {
args.emplace_back(arg.buffer.name(), Argument::InputBuffer,
arg.buffer.type(), arg.buffer.dimensions());
} else if (arg.arg_type == ExternFuncArgument::ExprArg) {
args.emplace_back(unique_name('a'), Argument::InputScalar,
arg.expr.type(), 0);
} else if (arg.arg_type == ExternFuncArgument::ImageParamArg) {
args.emplace_back(arg.image_param.name(), Argument::InputBuffer,
arg.image_param.type(), arg.image_param.dimensions());
}
}
for (auto b : f.output_buffers()) {
args.emplace_back(b.name(), Argument::OutputBuffer, b.type(), b.dimensions());
}
vector<Stmt> upgrades, downgrades;
vector<pair<string, Expr>> old_buffers;
vector<Expr> call_args;
for (Argument a : args) {
if (a.kind == Argument::InputBuffer ||
a.kind == Argument::OutputBuffer) {
Expr new_buffer_var = Variable::make(a.type, a.name + ".buffer");
string old_buffer_name = a.name + ".old_buffer_t";
Expr old_buffer_var = Variable::make(type_of<struct buffer_t *>(), old_buffer_name);
Expr old_buffer = Call::make(type_of<struct buffer_t *>(), Call::alloca,
{(int)sizeof(buffer_t)}, Call::Intrinsic);
old_buffers.emplace_back(old_buffer_name, old_buffer);
Expr downgrade_call = Call::make(Int(32), "halide_downgrade_buffer_t",
{a.name, new_buffer_var, old_buffer_var},
Call::Extern);
downgrades.push_back(make_checked_call(downgrade_call));
Expr upgrade_call = Call::make(Int(32), "halide_upgrade_buffer_t",
{a.name, old_buffer_var, new_buffer_var},
Call::Extern);
upgrades.push_back(make_checked_call(upgrade_call));
call_args.push_back(old_buffer_var);
} else {
call_args.push_back(Variable::make(a.type, a.name));
}
}
Expr inner_call = Call::make(op->type, op->name, call_args, op->call_type);
Stmt body = make_checked_call(inner_call);
body = Block::make({Block::make(downgrades), body, Block::make(upgrades)});
while (!old_buffers.empty()) {
auto p = old_buffers.back();
body = LetStmt::make(p.first, p.second, body);
old_buffers.pop_back();
}
debug(2) << "Wrapped extern call to " << op->name << ":\n" << body << "\n\n";
LoweredFunc wrapper(wrapper_name, args, body, LoweredFunc::Internal, NameMangling::C);
module.append(wrapper);
return wrapper_name;
}
void visit(const Call *op) {
if ((op->call_type == Call::Extern ||
op->call_type == Call::ExternCPlusPlus) &&
op->func.defined()) {
Function f(op->func);
internal_assert(f.has_extern_definition());
if (f.extern_definition_uses_old_buffer_t()) {
vector<Expr> new_args;
for (Expr e : op->args) {
new_args.push_back(mutate(e));
}
expr = Call::make(op->type, make_wrapper(op), new_args, Call::Extern, op->func);
} else {
IRMutator::visit(op);
}
} else {
IRMutator::visit(op);
}
}
set<string> done;
Module module;
public:
WrapExternStages(Module m) : module(m) {}
string prefix;
};
}
/** Rewrites each function present in m on entry so that calls to legacy
 * buffer_t extern stages go through generated wrappers. The loop bound is
 * fixed before any rewriting, so wrappers appended during the pass are not
 * themselves visited. */
void wrap_legacy_extern_stages(Module m) {
    WrapExternStages mutator(m);
    const size_t count = m.functions().size();
    for (size_t idx = 0; idx != count; ++idx) {
        mutator.prefix = "_halide_wrapper_" + m.functions()[idx].name + "_";
        // Take a local copy of the body: mutation may append wrappers to the
        // module, which can invalidate references into the function list.
        const Stmt original = m.functions()[idx].body;
        const Stmt rewritten = mutator.mutate(original);
        m.functions()[idx].body = rewritten;
        debug(2) << "Body after wrapping extern calls:\n" << rewritten << "\n\n";
    }
}
/** Appends to module an External wrapper around fn that accepts legacy
 * buffer_t pointers in place of buffer arguments.
 *
 * For each buffer argument the wrapper:
 * - builds a halide_buffer_t with BufferBuilder;
 * - upgrades the buffer_t into it before calling fn;
 * - downgrades afterwards: fully when the call was a bounds query, device
 *   fields only otherwise.
 *
 * The wrapper keeps fn's name when the target has CPlusPlusMangling, and
 * gets an "_old_buffer_t" suffix otherwise. */
void add_legacy_wrapper(Module module, const LoweredFunc &fn) {
    vector<LoweredArgument> wrapper_args;
    vector<Stmt> to_new, to_old;
    vector<Expr> forwarded;
    vector<pair<string, Expr>> lets;

    for (const LoweredArgument &a : fn.args) {
        if (a.kind != Argument::InputScalar) {
            // Buffers arrive as legacy buffer_t pointers, passed as scalars.
            wrapper_args.emplace_back(a.name, Argument::InputScalar, type_of<buffer_t *>(), 0);

            const string upgraded_name = a.name + ".upgraded";
            const Expr upgraded = Variable::make(type_of<struct halide_buffer_t *>(), upgraded_name);
            const Expr legacy = Variable::make(type_of<struct buffer_t *>(), a.name);

            BufferBuilder bb;
            bb.type = a.type;
            bb.dimensions = a.dimensions;
            lets.emplace_back(upgraded_name, bb.build());

            // After the call: full downgrade for bounds queries, otherwise
            // only the device fields are copied back.
            Stmt full_downgrade = make_checked_call(
                Call::make(Int(32), "halide_downgrade_buffer_t",
                           {a.name, upgraded, legacy}, Call::Extern));
            Stmt device_downgrade = make_checked_call(
                Call::make(Int(32), "halide_downgrade_buffer_t_device_fields",
                           {a.name, upgraded, legacy}, Call::Extern));
            Expr is_query = Call::make(Bool(), Call::buffer_is_bounds_query,
                                       {upgraded}, Call::Extern);
            to_old.push_back(IfThenElse::make(is_query, full_downgrade, device_downgrade));

            // Before the call: buffer_t -> halide_buffer_t.
            to_new.push_back(make_checked_call(
                Call::make(Int(32), "halide_upgrade_buffer_t",
                           {a.name, legacy, upgraded}, Call::Extern)));

            forwarded.push_back(upgraded);
        } else {
            wrapper_args.push_back(a);
            forwarded.push_back(Variable::make(a.type, a.name));
        }
    }

    const bool mangle_cpp =
        fn.name_mangling == NameMangling::CPlusPlus ||
        (fn.name_mangling == NameMangling::Default &&
         module.target().has_feature(Target::CPlusPlusMangling));
    Expr inner = Call::make(Int(32), fn.name, forwarded,
                            mangle_cpp ? Call::ExternCPlusPlus : Call::Extern);

    Stmt body = make_checked_call(inner);
    body = Block::make({Block::make(to_new), body, Block::make(to_old)});
    // The first buffer's binding ends up outermost.
    for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
        body = LetStmt::make(it->first, it->second, body);
    }

    const string wrapper_name = module.target().has_feature(Target::CPlusPlusMangling) ?
                                    fn.name :
                                    fn.name + "_old_buffer_t";

    debug(2) << "Added legacy wrapper for " << fn.name << ":\n" << body << "\n\n";

    LoweredFunc wrapper(wrapper_name, wrapper_args, body, LoweredFunc::External, NameMangling::Default);
    module.append(wrapper);
}
}
}