This source file includes the following definitions.
- call_extern_and_assert
- fixup_device_api
- different_device_api
- visit
- visit
- visit
- visit
- visit
- visit
- device_api
- visit
- current_device
- make_device_interface_call
- make_dev_malloc
- make_buffer_copy
- do_copies
- should_track
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- device_api
- inject_host_dev_buffer_copies
#include "InjectHostDevBufferCopies.h"
#include "IRMutator.h"
#include "Debug.h"
#include "IRPrinter.h"
#include "CodeGen_GPU_Dev.h"
#include "IROperator.h"
#include <map>
namespace Halide {
namespace Internal {
using std::string;
using std::map;
using std::vector;
using std::set;
using std::pair;
// Builds a statement that calls the extern function `name` with `args`
// (returning Int(32)), binds the result to a uniquely named let variable, and
// asserts that the result equals zero. The result variable is also passed as
// the second argument of the AssertStmt.
Stmt call_extern_and_assert(const string& name, const vector<Expr>& args) {
    Expr extern_call = Call::make(Int(32), name, args, Call::Extern);

    const string result_name = unique_name(name + "_result");
    Expr result_var = Variable::make(Int(32), result_name);

    Expr returned_zero = EQ::make(result_var, 0);
    Stmt check = AssertStmt::make(returned_zero, result_var);

    return LetStmt::make(result_name, extern_call, check);
}
namespace {
// Resolves DeviceAPI::Default_GPU to a concrete device API using the target's
// feature flags, checked in the order Metal, OpenCL, CUDA, OpenGLCompute.
// Any other DeviceAPI value is returned unchanged. When Default_GPU is
// requested and none of those features is set, a user_error is raised.
DeviceAPI fixup_device_api(DeviceAPI device_api, const Target &target) {
    // Only the Default_GPU placeholder needs resolving.
    if (device_api != DeviceAPI::Default_GPU) {
        return device_api;
    }

    if (target.has_feature(Target::Metal)) {
        return DeviceAPI::Metal;
    }
    if (target.has_feature(Target::OpenCL)) {
        return DeviceAPI::OpenCL;
    }
    if (target.has_feature(Target::CUDA)) {
        return DeviceAPI::CUDA;
    }
    if (target.has_feature(Target::OpenGLCompute)) {
        return DeviceAPI::OpenGLCompute;
    }

    user_error << "Schedule uses Default_GPU without a valid GPU (Metal, OpenCL CUDA, or OpenGLCompute) specified in target.\n";
    return device_api;
}
// Returns true when a statement annotated with `stmt_api` switches away from
// the enclosing `device_api`. Both values are first resolved through
// fixup_device_api; a resolved statement API of DeviceAPI::None never counts
// as a switch.
bool different_device_api(DeviceAPI device_api, DeviceAPI stmt_api, const Target &target) {
    const DeviceAPI enclosing = fixup_device_api(device_api, target);
    const DeviceAPI requested = fixup_device_api(stmt_api, target);

    if (requested == DeviceAPI::None) {
        return false;
    }
    return enclosing != requested;
}
class FindBuffersToTrack : public IRVisitor {
map<string, DeviceAPI> internal;
const Target ⌖
DeviceAPI device_api;
using IRVisitor::visit;
void visit(const Allocate *op) {
debug(2) << "Buffers to track: Setting Allocate for loop " << op->name << " to " << static_cast<int>(device_api) << "\n";
internal_assert(internal.find(op->name) == internal.end()) << "Duplicate Allocate node in FindBuffersToTrack.\n";
auto it = internal.insert({ op->name, device_api });
IRVisitor::visit(op);
internal.erase(it.first);
}
void visit(const For *op) {
if (different_device_api(device_api, op->device_api, target)) {
debug(2) << "Buffers to track: switching from " << static_cast<int>(device_api) <<
" to " << static_cast<int>(op->device_api) << " for loop " << op->name << "\n";
DeviceAPI old_device_api = device_api;
device_api = fixup_device_api(op->device_api, target);
if (device_api == DeviceAPI::None) {
device_api = old_device_api;
}
internal_assert(device_api != DeviceAPI::None);
IRVisitor::visit(op);
device_api = old_device_api;
} else {
IRVisitor::visit(op);
}
}
void visit(const LetStmt *op) {
const Call *c = op->value.as<Call>();
if (ends_with(op->name, ".buffer") &&
c && c->name == Call::buffer_init) {
buffers_to_track.erase(op->name.substr(0, op->name.size() - 7));
}
IRVisitor::visit(op);
}
void visit(const Load *op) {
if (internal.find(op->name) == internal.end() ||
different_device_api(device_api, internal[op->name], target)) {
buffers_to_track.insert(op->name);
}
IRVisitor::visit(op);
}
void visit(const Store *op) {
if (internal.find(op->name) == internal.end() ||
different_device_api(device_api, internal[op->name], target)) {
buffers_to_track.insert(op->name);
}
IRVisitor::visit(op);
}
void visit(const Variable *op) {
if (op->type.is_handle() && ends_with(op->name, ".buffer")) {
buffers_to_track.insert(op->name.substr(0, op->name.size() - 7));
}
}
public:
set<string> buffers_to_track;
FindBuffersToTrack(const Target &t) : target(t), device_api(DeviceAPI::Host) {}
};
// Mutator that replaces every Variable whose name equals `buf_name` with a
// zero-valued Handle expression; all other variables are left untouched.
class NullifyHostField : public IRMutator {
    using IRMutator::visit;

    // Name of the variable to be replaced by a null handle.
    std::string buf_name;

    void visit(const Variable *op) {
        if (op->name != buf_name) {
            expr = op;
            return;
        }
        expr = make_zero(Handle());
    }

public:
    NullifyHostField(const std::string &b) : buf_name(b) {}
};
class InjectBufferCopies : public IRMutator {
using IRMutator::visit;
// Tracking state kept for each buffer while copies are being injected.
struct BufferInfo {
    // Whether the host / some device has accessed the buffer so far.
    bool host_touched = false;
    bool dev_touched = false;
    // Whether the host / device copy is currently considered up to date.
    bool host_current = false;
    bool dev_current = false;
    // Set for buffers introduced by an Allocate node inside the pipeline.
    bool internal = false;
    // Starts out true; the Allocate visitor clears it for internal buffers and
    // do_copies sets it again once a device malloc has been injected.
    bool dev_allocated = true;

    // Loop level at which do_copies is allowed to act on this buffer.
    string loop_level;

    // First non-host device that accessed the buffer.
    DeviceAPI device_first_touched = DeviceAPI::None;
    // Device most recently made current by a copy-to-device decision.
    DeviceAPI current_device = DeviceAPI::None;

    // Accesses recorded since the last do_copies pass over this buffer
    // (reading/writing), and over the buffer's whole lifetime (touched).
    std::set<DeviceAPI> devices_reading;
    std::set<DeviceAPI> devices_writing;
    std::set<DeviceAPI> devices_touched;
};
map<string, BufferInfo> state;
string loop_level;
const set<string> &buffers_to_track;
const Target ⌖
DeviceAPI device_api;
// Builds an extern call expression (no arguments) to the runtime function
// that corresponds to `device_api`. An unsupported DeviceAPI raises an
// internal_error.
Expr make_device_interface_call(DeviceAPI device_api) {
    std::string symbol;

    switch (device_api) {
    case DeviceAPI::CUDA:          symbol = "halide_cuda_device_interface"; break;
    case DeviceAPI::OpenCL:        symbol = "halide_opencl_device_interface"; break;
    case DeviceAPI::Metal:         symbol = "halide_metal_device_interface"; break;
    case DeviceAPI::GLSL:          symbol = "halide_opengl_device_interface"; break;
    case DeviceAPI::OpenGLCompute: symbol = "halide_openglcompute_device_interface"; break;
    case DeviceAPI::Hexagon:       symbol = "halide_hexagon_device_interface"; break;
    default:
        internal_error << "Bad DeviceAPI " << static_cast<int>(device_api) << "\n";
        break;
    }

    return Call::make(type_of<const char *>(), symbol, std::vector<Expr>(), Call::Extern);
}
// Builds a block that (1) calls the device allocation routine for
// "<buf_name>.buffer" on the given device and asserts success, then
// (2) registers the matching free routine as a destructor for the buffer.
// `is_device_and_host` selects the combined device-and-host variants of both
// routines.
Stmt make_dev_malloc(string buf_name, DeviceAPI target_device_api, bool is_device_and_host) {
    const char *malloc_fn = "halide_device_malloc";
    const char *free_fn = "halide_device_free_as_destructor";
    if (is_device_and_host) {
        malloc_fn = "halide_device_and_host_malloc";
        free_fn = "halide_device_and_host_free_as_destructor";
    }

    Expr buf = Variable::make(type_of<struct halide_buffer_t *>(), buf_name + ".buffer");
    Expr device_interface = make_device_interface_call(target_device_api);

    Stmt device_malloc = call_extern_and_assert(malloc_fn, {buf, device_interface});

    Expr register_free = Call::make(Int(32), Call::register_destructor,
                                    {Expr(free_fn), buf}, Call::Intrinsic);
    Stmt destructor = Evaluate::make(register_free);

    return Block::make(device_malloc, destructor);
}
// Direction of a buffer copy decided by do_copies.
enum CopyDirection {
// No copy is required.
NoCopy,
// Copy to the host (emitted as a halide_copy_to_host call).
ToHost,
// Copy to a device (emitted as a halide_copy_to_device call).
ToDevice
};
// Builds a checked call to halide_copy_to_host or halide_copy_to_device for
// "<buf_name>.buffer". The device interface is only passed as an argument for
// copies to a device. `direction` must not be NoCopy.
Stmt make_buffer_copy(CopyDirection direction, string buf_name, DeviceAPI target_device_api) {
    internal_assert(direction == ToHost || direction == ToDevice) << "make_buffer_copy caller logic error.\n";

    const bool to_device = (direction == ToDevice);

    std::vector<Expr> call_args;
    call_args.push_back(Variable::make(type_of<struct halide_buffer_t *>(), buf_name + ".buffer"));
    if (to_device) {
        call_args.push_back(make_device_interface_call(target_device_api));
    }

    const std::string fn_name = std::string("halide_copy_to_") + (to_device ? "device" : "host");
    return call_extern_and_assert(fn_name, call_args);
}
// Wraps `s` with whatever is required by the buffer accesses recorded since
// the previous call: a copy to host or device before `s`, dirty-bit updates
// after `s`, and a device malloc before `s` for internal buffers touched by a
// device for the first time. Only buffers whose loop_level matches the
// current loop level are processed, and nothing is done while visiting
// non-host code. The per-buffer pending read/write sets are cleared once
// processed.
Stmt do_copies(Stmt s) {
    internal_assert(s.defined());

    // Copies are only injected from host code.
    if (device_api != DeviceAPI::Host) {
        return s;
    }

    debug(4) << "At loop level " << loop_level << "\n";

    for (pair<const string, BufferInfo> &i : state) {
        CopyDirection direction = NoCopy;
        BufferInfo &buf = i.second;
        if (buf.loop_level != loop_level) {
            continue;
        }

        debug(4) << "do_copies for " << i.first << "\n"
                 << "Host current: " << buf.host_current << " Device current: " << buf.dev_current << "\n"
                 << "Host touched: " << buf.host_touched << " Device touched: " << buf.dev_touched << "\n"
                 << "Internal: " << buf.internal << " Device touching first: "
                 << static_cast<int>(buf.device_first_touched) << "\n"
                 << "Current device: " << static_cast<int>(buf.current_device) << "\n";

        // Summarize the pending reads: did the host read, and which (single)
        // non-host device read.
        DeviceAPI touching_device = DeviceAPI::None;
        bool host_read = false;
        size_t non_host_devices_reading_count = 0;
        DeviceAPI reading_device = DeviceAPI::None;
        for (DeviceAPI dev : buf.devices_reading) {
            debug(4) << "Device " << static_cast<int>(dev) << " read buffer\n";
            if (dev != DeviceAPI::Host) {
                non_host_devices_reading_count++;
                reading_device = dev;
                touching_device = dev;
            } else {
                host_read = true;
            }
        }

        // Same summary for the pending writes.
        bool host_wrote = false;
        size_t non_host_devices_writing_count = 0;
        DeviceAPI writing_device = DeviceAPI::None;
        for (DeviceAPI dev : buf.devices_writing) {
            debug(4) << "Device " << static_cast<int>(dev) << " wrote buffer\n";
            if (dev != DeviceAPI::Host) {
                non_host_devices_writing_count++;
                writing_device = dev;
                touching_device = dev;
            } else {
                host_wrote = true;
            }
        }

        // At most one non-host device may be involved in the pending accesses.
        internal_assert(non_host_devices_reading_count <= 1);
        internal_assert(non_host_devices_writing_count <= 1);
        internal_assert((non_host_devices_reading_count == 0 || non_host_devices_writing_count == 0) ||
                        reading_device == writing_device);

        bool device_read = non_host_devices_reading_count > 0;
        bool device_wrote = non_host_devices_writing_count > 0;

        buf.host_touched = host_wrote || host_read || buf.host_touched;
        if (!buf.dev_touched && (device_wrote || device_read)) {
            buf.dev_touched = true;
            buf.device_first_touched = touching_device;
        }

        // Copy to host: the host accesses the buffer, its host copy is stale,
        // and there is device-side data to fetch (the buffer is external, or
        // a device has touched it).
        if ((host_read || host_wrote) && !buf.host_current && (!buf.internal || buf.dev_touched)) {
            internal_assert(!device_wrote && !(host_wrote && device_read));
            direction = ToHost;
            buf.host_current = true;
            buf.dev_current = buf.dev_current && !host_wrote;
            debug(4) << "Needs copy to host\n";
        } else if (host_wrote) {
            internal_assert(!device_read && !device_wrote);
            buf.dev_current = false;
            debug(4) << "Invalidating dev_current\n";
        }

        // Copy to device: a device accesses the buffer, the buffer is not
        // current on that device, AND there is host-side data to send (the
        // buffer is external, or the host has touched it). This mirrors the
        // copy-to-host condition above. The two sub-conditions must be joined
        // with &&; joining them with || would emit a copy for buffers that
        // are already current on the device and for internal buffers that
        // have no host data.
        if ((device_read || device_wrote) &&
            ((!buf.dev_current || (buf.current_device != touching_device)) &&
             (!buf.internal || buf.host_touched))) {
            internal_assert(!host_wrote && !(device_wrote && host_read));
            direction = ToDevice;
            // Moving from one device to another counts as touching the host.
            buf.host_touched = buf.host_touched || (buf.current_device != DeviceAPI::None &&
                                                    buf.current_device != touching_device);
            buf.dev_current = true;
            buf.current_device = touching_device;
            buf.host_current = buf.host_current && !device_wrote;
            debug(4) << "Needs copy to dev\n";
        } else if (device_wrote) {
            internal_assert(!host_read && !host_wrote);
            buf.host_current = false;
            debug(4) << "Invalidating host_current\n";
        }

        Expr buffer = Variable::make(type_of<struct halide_buffer_t *>(), i.first + ".buffer");

        // Dirty bits are set after the statement that performed the write.
        if (host_wrote) {
            debug(4) << "Setting host dirty for " << i.first << "\n";
            Expr set_host_dirty = Call::make(Int(32), Call::buffer_set_host_dirty,
                                             {buffer, const_true()}, Call::Extern);
            s = Block::make(s, Evaluate::make(set_host_dirty));
        }

        if (device_wrote) {
            Expr set_device_dirty = Call::make(Int(32), Call::buffer_set_device_dirty,
                                               {buffer, const_true()}, Call::Extern);
            s = Block::make(s, Evaluate::make(set_device_dirty));
        }

        // The pending accesses have now been accounted for.
        buf.devices_reading.clear();
        buf.devices_writing.clear();

        // The copy itself goes before the statement.
        if (direction != NoCopy && touching_device != DeviceAPI::Host) {
            internal_assert(s.defined());
            s = Block::make(make_buffer_copy(direction, i.first, touching_device), s);
        }

        // Device allocation (if still outstanding) goes before everything else.
        if (!buf.dev_allocated &&
            buf.device_first_touched != DeviceAPI::Host &&
            buf.device_first_touched != DeviceAPI::None) {
            debug(4) << "Injecting device malloc for " << i.first << " on " <<
                static_cast<int>(buf.device_first_touched) << "\n";
            Stmt dev_malloc = make_dev_malloc(i.first, buf.device_first_touched, false);
            internal_assert(s.defined());
            s = Block::make(dev_malloc, s);
            buf.dev_allocated = true;
        }
    }

    debug(4) << "\n";

    return s;
}
// Returns true if `buf` is one of the buffers selected for tracking.
bool should_track(const string &buf) {
    return buffers_to_track.find(buf) != buffers_to_track.end();
}
void visit(const Store *op) {
IRMutator::visit(op);
if (!should_track(op->name)) {
return;
}
debug(4) << "Device " << static_cast<int>(device_api) << " writes buffer " << op->name << "\n";
state[op->name].devices_writing.insert(device_api);
state[op->name].devices_touched.insert(device_api);
}
void visit(const Load *op) {
IRMutator::visit(op);
if (!should_track(op->name)) {
return;
}
debug(4) << "Device " << static_cast<int>(device_api) << " reads buffer " << op->name << "\n";
state[op->name].devices_reading.insert(device_api);
state[op->name].devices_touched.insert(device_api);
}
void visit(const Call *op) {
if (op->is_intrinsic(Call::image_load)) {
internal_assert(device_api == DeviceAPI::GLSL);
internal_assert(op->args.size() >= 2);
const Variable *buffer_var = op->args[1].as<Variable>();
internal_assert(buffer_var && ends_with(buffer_var->name, ".buffer"));
string buf_name = buffer_var->name.substr(0, buffer_var->name.size() - 7);
debug(4) << "Adding image read via image_load for " << buffer_var->name << "\n";
state[buf_name].devices_reading.insert(device_api);
state[buf_name].devices_touched.insert(device_api);
IRMutator::visit(op);
} else if (op->is_intrinsic(Call::image_store)) {
internal_assert(device_api == DeviceAPI::GLSL);
internal_assert(op->args.size() >= 2);
const Variable *buffer_var = op->args[1].as<Variable>();
internal_assert(buffer_var && ends_with(buffer_var->name, ".buffer"));
string buf_name = buffer_var->name.substr(0, buffer_var->name.size() - 7);
debug(4) << "Adding image write via image_store for " << buffer_var->name << "\n";
state[buf_name].devices_writing.insert(device_api);
state[buf_name].devices_touched.insert(device_api);
IRMutator::visit(op);
} else {
IRMutator::visit(op);
}
}
// Mutates the body of a ProducerConsumer node in host code and lets do_copies
// wrap it with any required copies. For producer nodes, every tracked buffer
// named after the produced function (exact name, or prefixed by "<name>.") is
// then assigned the current loop level.
void visit(const ProducerConsumer *op) {
// Inside device code nothing is injected; just recurse.
if (device_api != DeviceAPI::Host) {
IRMutator::visit(op);
return;
}
Stmt body = mutate(op->body);
body = do_copies(body);
if (body.same_as(op->body)) {
stmt = op;
} else {
stmt = ProducerConsumer::make(op->name, op->is_producer, body);
}
if (op->is_producer) {
// is_output stays true only if no entry in `state` belongs to this function.
bool is_output = true;
for (pair<const string, BufferInfo> &i : state) {
const string &buf_name = i.first;
if (buf_name == op->name || starts_with(buf_name, op->name + ".")) {
i.second.loop_level = loop_level;
is_output = false;
}
}
// NOTE(review): the loop below only acts on entries matching the same name
// test as the loop above, and any such entry clears is_output. As written,
// this branch therefore appears to be unreachable in effect. Presumably
// is_output (and the loop_level assignment) was meant to be computed before
// the body is mutated - TODO confirm the intended ordering.
if (is_output) {
for (pair<const string, BufferInfo> &i : state) {
const string &buf_name = i.first;
if ((buf_name == op->name || starts_with(buf_name, op->name + ".")) &&
i.second.dev_touched && i.second.current_device != DeviceAPI::Host) {
debug(4) << "Injecting device copy for output " << buf_name << " on " <<
static_cast<int>(i.second.current_device) << "\n";
stmt = Block::make(make_buffer_copy(ToDevice, buf_name, i.second.current_device), stmt);
}
}
}
}
}
void visit(const Variable *op) {
IRMutator::visit(op);
if (ends_with(op->name, ".buffer")) {
string buf_name = op->name.substr(0, op->name.size() - 7);
if (state.find(buf_name) != state.end()) {
state[buf_name].host_touched = true;
}
}
}
// Handles allocations of tracked buffers in host code. The buffer is marked
// internal and not yet device-allocated, the body is mutated, and then the
// allocation is rewritten according to how the buffer was used:
//  - never touched by the host: the Allocate condition becomes const_false();
//  - touched by the host and by exactly one device: the allocation is
//    replaced by a combined device-and-host malloc, with the Allocate node
//    taking its pointer from the buffer's host field;
//  - otherwise: the mutated allocation is kept as is.
// The buffer's tracking entry is dropped when the allocation goes out of scope.
void visit(const Allocate *op) {
    if (device_api != DeviceAPI::Host ||
        !should_track(op->name)) {
        IRMutator::visit(op);
        return;
    }

    // Keep a copy of the name: `op` is re-pointed at the mutated node below.
    string buf_name = op->name;

    {
        BufferInfo &buf_init(state[buf_name]);
        buf_init.internal = true;
        buf_init.dev_allocated = false;
    }

    IRMutator::visit(op);
    op = stmt.as<Allocate>();
    internal_assert(op);

    BufferInfo &buf_info(state[buf_name]);

    if (buf_info.dev_touched) {
        user_assert(op->extents.size() <= 4)
            << "Buffer " << op->name
            << " cannot be used on the GPU, because it has more than four dimensions.\n";
    }

    // True if the buffer was touched by at most one API, or by the host plus
    // exactly one other API.
    bool on_single_device = ((buf_info.devices_touched.size() < 2) ||
                             (buf_info.devices_touched.size() == 2 &&
                              buf_info.devices_touched.count(DeviceAPI::Host)));

    if (!buf_info.host_touched) {
        debug(4) << "Eliding host alloc for " << op->name << "\n";
        stmt = Allocate::make(op->name, op->type, op->extents, const_false(), op->body);
    } else if (on_single_device &&
               buf_info.dev_touched &&
               buf_info.device_first_touched != DeviceAPI::None) {
        debug(4) << "Making combined host/device alloc for " << op->name << "\n";

        // Peel LetStmts off the front of the body until the one defining
        // "<name>.buffer" is found; the others are re-wrapped afterwards.
        Stmt inner_body = op->body;
        std::vector<const LetStmt *> body_lets;
        const LetStmt *buffer_init_let = nullptr;
        while (const LetStmt *inner_let = inner_body.as<LetStmt>()) {
            inner_body = inner_let->body;
            if (inner_let->name == op->name + ".buffer") {
                buffer_init_let = inner_let;
                break;
            }
            body_lets.push_back(inner_let);
        }
        internal_assert(buffer_init_let) << "Could not find definition of " << op->name << ".buffer\n";

        Stmt combined_malloc = make_dev_malloc(op->name, buf_info.device_first_touched, true);

        // The Allocate node now obtains its pointer from the buffer's host
        // field and uses "halide_device_host_nop_free" as its free function.
        inner_body = Allocate::make(op->name, op->type, op->extents, op->condition, inner_body,
                                    Call::make(Handle(), Call::buffer_get_host,
                                               { Variable::make(type_of<struct halide_buffer_t *>(), op->name + ".buffer") },
                                               Call::Extern),
                                    "halide_device_host_nop_free");
        inner_body = Block::make(combined_malloc, inner_body);

        // The buffer definition moves outside the Allocate node, so any
        // reference to the allocation's own name in it is replaced by a
        // null handle.
        Expr buf = NullifyHostField(op->name).mutate(buffer_init_let->value);
        stmt = LetStmt::make(op->name + ".buffer", buf, inner_body);

        // Restore the peeled lets, innermost first, preserving their order.
        for (size_t i = body_lets.size(); i > 0; i--) {
            stmt = LetStmt::make(body_lets[i - 1]->name, body_lets[i - 1]->value, stmt);
        }
    }

    state.erase(buf_name);
}
void visit(const LetStmt *op) {
IRMutator::visit(op);
if (device_api != DeviceAPI::Host) {
return;
}
op = stmt.as<LetStmt>();
internal_assert(op);
if (ends_with(op->name, ".buffer")) {
string buf_name = op->name.substr(0, op->name.size() - 7);
if (!should_track(buf_name)) {
return;
}
if (!state[buf_name].host_touched) {
Expr value = NullifyHostField(buf_name).mutate(op->value);
stmt = LetStmt::make(op->name, value, op->body);
}
}
}
// Processes both branches of an IfThenElse in host code starting from the
// same tracking state, injects copies into each branch independently, and
// then merges the two resulting states: "touched" flags are OR-ed, "current"
// flags are AND-ed, and the device sets are unioned.
void visit(const IfThenElse *op) {
    if (device_api != DeviceAPI::Host) {
        IRMutator::visit(op);
        return;
    }

    Expr cond = mutate(op->condition);

    // Snapshot the state before the branches so the else case starts from
    // the same point as the then case.
    map<string, BufferInfo> copy = state;

    Stmt then_case = mutate(op->then_case);
    then_case = do_copies(then_case);

    // After the swap, `copy` holds the state left by the then case and
    // `state` holds the snapshot, which the else case now updates.
    copy.swap(state);
    Stmt else_case = mutate(op->else_case);
    if (else_case.defined()) {
        else_case = do_copies(else_case);
    }

    // The element type must match the map's value_type exactly
    // (pair<const string, ...>); otherwise each iteration would bind the
    // reference to a temporary deep copy of the entry.
    for (const pair<const string, BufferInfo> &i : copy) {
        const string &buf_name = i.first;

        const BufferInfo &then_state = i.second;
        const BufferInfo &else_state = state[buf_name];
        // NOTE(review): merged_state starts from BufferInfo's defaults, and
        // `internal`, `dev_allocated` and `current_device` are not carried
        // over from either branch - confirm whether that is intended.
        BufferInfo merged_state;

        internal_assert(then_state.loop_level == else_state.loop_level)
            << "then_state and else_state should have the same loop level for " << buf_name;

        merged_state.loop_level = then_state.loop_level;
        merged_state.host_touched = then_state.host_touched || else_state.host_touched;
        merged_state.dev_touched = then_state.dev_touched || else_state.dev_touched;
        merged_state.host_current = then_state.host_current && else_state.host_current;
        merged_state.dev_current = then_state.dev_current && else_state.dev_current &&
                                   then_state.current_device == else_state.current_device;

        // Only keep device_first_touched when both branches agree on it.
        if (then_state.device_first_touched == else_state.device_first_touched) {
            merged_state.device_first_touched = then_state.device_first_touched;
        } else {
            merged_state.device_first_touched = DeviceAPI::None;
        }

        merged_state.devices_reading = then_state.devices_reading;
        merged_state.devices_reading.insert(else_state.devices_reading.begin(),
                                            else_state.devices_reading.end());
        merged_state.devices_writing = then_state.devices_writing;
        merged_state.devices_writing.insert(else_state.devices_writing.begin(),
                                            else_state.devices_writing.end());
        merged_state.devices_touched = then_state.devices_touched;
        merged_state.devices_touched.insert(else_state.devices_touched.begin(),
                                            else_state.devices_touched.end());

        state[buf_name] = merged_state;
    }

    if (cond.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
        stmt = op;
    } else {
        stmt = IfThenElse::make(cond, then_case, else_case);
    }
}
void visit(const Block *op) {
if (device_api != DeviceAPI::Host) {
IRMutator::visit(op);
return;
}
Stmt first = mutate(op->first);
first = do_copies(first);
Stmt rest = op->rest;
if (rest.defined()) {
rest = mutate(rest);
rest = do_copies(rest);
}
stmt = Block::make(first, rest);
}
void visit(const For *op) {
string old_loop_level = loop_level;
loop_level = op->name;
if (different_device_api(device_api, op->device_api, target)) {
debug(4) << "Switching from device_api " << static_cast<int>(device_api) << " to op->device_api " <<
static_cast<int>(op->device_api) << " in for loop " << op->name <<"\n";
DeviceAPI old_device_api = device_api;
device_api = fixup_device_api(op->device_api, target);
if (device_api == DeviceAPI::None) {
device_api = old_device_api;
}
internal_assert(device_api != DeviceAPI::None);
IRMutator::visit(op);
device_api = old_device_api;
} else {
IRMutator::visit(op);
}
loop_level = old_loop_level;
}
public:
// `i` is the set of buffer names to track and `t` the compilation target;
// both are held by reference. Mutation starts in host code, outside any loop.
InjectBufferCopies(const set<string> &i, const Target &t)
    : loop_level(""),
      buffers_to_track(i),
      target(t),
      device_api(DeviceAPI::Host) {}
};
}
// Entry point of the pass: first determines which buffers need host/device
// tracking, then rewrites `s` with the copies, dirty-bit updates and device
// allocations for those buffers.
Stmt inject_host_dev_buffer_copies(Stmt s, const Target &t) {
    FindBuffersToTrack finder(t);
    s.accept(&finder);

    debug(4) << "Tracking host <-> dev copies for the following buffers:\n";
    for (const std::string &name : finder.buffers_to_track) {
        debug(4) << name << "\n";
    }

    InjectBufferCopies injector(finder.buffers_to_track, t);
    return injector.mutate(s);
}
}
}