This source file includes the following definitions.
- visit_function
- visit
- visit
- visit
- record
- record
- name
- parameters_alignment
- function_name
- key_size
- generate_key
- generate_lookup
- store_computation
- outputs
- visit
- visit
- inject_memoization
- get_realization_name
- visit
- visit
- visit
- rewrite_memoized_allocations
#include "Memoization.h"
#include "Error.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Param.h"
#include "Scope.h"
#include "Util.h"
#include "Var.h"
#include <map>
namespace Halide {
namespace Internal {
namespace {
class FindParameterDependencies : public IRGraphVisitor {
public:
FindParameterDependencies() { }
~FindParameterDependencies() { }
void visit_function(const Function &function) {
function.accept(this);
if (function.has_extern_definition()) {
const std::vector<ExternFuncArgument> &extern_args =
function.extern_arguments();
for (size_t i = 0; i < extern_args.size(); i++) {
if (extern_args[i].is_buffer()) {
record(Halide::Internal::Parameter(extern_args[i].buffer.type(), true,
extern_args[i].buffer.dimensions(),
extern_args[i].buffer.name()));
} else if (extern_args[i].is_image_param()) {
record(extern_args[i].image_param);
}
}
}
}
using IRGraphVisitor::visit;
void visit(const Call *call) {
if (call->param.defined()) {
record(call->param);
}
if (call->is_intrinsic(Call::memoize_expr)) {
internal_assert(!call->args.empty());
if (call->args.size() == 1) {
record(call->args[0]);
} else {
for (size_t i = 1; i < call->args.size(); i++) {
record(call->args[i]);
}
}
} else if (call->func.defined()) {
Function fn(call->func);
visit_function(fn);
IRGraphVisitor::visit(call);
} else {
IRGraphVisitor::visit(call);
}
}
void visit(const Load *load) {
if (load->param.defined()) {
record(load->param);
}
IRGraphVisitor::visit(load);
}
void visit(const Variable *var) {
if (var->param.defined()) {
record(var->param);
}
IRGraphVisitor::visit(var);
}
void record(const Parameter ¶meter) {
struct DependencyInfo info;
info.type = parameter.type();
if (parameter.is_buffer()) {
internal_error << "Buffer parameter " << parameter.name() <<
" encountered in computed_cached computation.\n" <<
"Computations which depend on buffer parameters " <<
"cannot be scheduled compute_cached.\n" <<
"Use memoize_tag to provide cache key information for buffer.\n";
} else if (info.type.is_handle()) {
internal_error << "Handle parameter " << parameter.name() <<
" encountered in computed_cached computation.\n" <<
"Computations which depend on handle parameters " <<
"cannot be scheduled compute_cached.\n" <<
"Use memoize_tag to provide cache key information for handle.\n";
} else {
info.size_expr = info.type.bytes();
info.value_expr = Internal::Variable::make(info.type, parameter.name(), parameter);
}
dependency_info[DependencyKey(info.type.bytes(), parameter.name())] = info;
}
void record(const Expr &expr) {
struct DependencyInfo info;
info.type = expr.type();
info.size_expr = info.type.bytes();
info.value_expr = expr;
dependency_info[DependencyKey(info.type.bytes(), unique_name("memoize_tag"))] = info;
}
struct DependencyKey {
uint32_t size;
std::string name;
bool operator<(const DependencyKey &rhs) const {
if (size < rhs.size) {
return true;
} else if (size == rhs.size) {
return name < rhs.name;
}
return false;
}
DependencyKey(uint32_t size_arg, const std::string &name_arg)
: size(size_arg), name(name_arg) {
}
};
struct DependencyInfo {
Type type;
Expr size_expr;
Expr value_expr;
};
std::map<DependencyKey, DependencyInfo> dependency_info;
};
typedef std::pair<FindParameterDependencies::DependencyKey, FindParameterDependencies::DependencyInfo> DependencyKeyInfoPair;
class KeyInfo {
FindParameterDependencies dependencies;
Expr key_size_expr;
const std::string &top_level_name;
const std::string &function_name;
size_t parameters_alignment() {
int32_t max_alignment = 0;
for (const DependencyKeyInfoPair &i : dependencies.dependency_info) {
int alignment = i.second.type.bytes();
if (alignment > max_alignment) {
max_alignment = alignment;
}
}
int i = 0;
while (i < 4 && max_alignment > (1 << i)) {
i = i + 1;
}
return size_t(1) << i;
}
public:
KeyInfo(const Function &function, const std::string &name)
: top_level_name(name), function_name(function.name())
{
dependencies.visit_function(function);
size_t size_so_far = 0;
size_so_far += Handle().bytes() + 4;
size_t needed_alignment = parameters_alignment();
if (needed_alignment > 1) {
size_so_far = (size_so_far + needed_alignment - 1) & ~(needed_alignment - 1);
}
key_size_expr = (int32_t)size_so_far;
for (const DependencyKeyInfoPair &i : dependencies.dependency_info) {
key_size_expr += i.second.size_expr;
}
}
Expr key_size() { return cast<int32_t>(key_size_expr); };
Stmt generate_key(std::string key_name) {
std::vector<Stmt> writes;
Expr index = Expr(0);
writes.push_back(Store::make(key_name,
StringImm::make(std::to_string(top_level_name.size()) + ":" + top_level_name +
std::to_string(function_name.size()) + ":" + function_name),
(index / Handle().bytes()), Parameter(), const_true()));
size_t alignment = Handle().bytes();
index += Handle().bytes();
static std::atomic<int> memoize_instance {0};
writes.push_back(Store::make(key_name,
memoize_instance++,
(index / Int(32).bytes()),
Parameter(), const_true()));
alignment += 4;
index += 4;
size_t needed_alignment = parameters_alignment();
if (needed_alignment > 1) {
while (alignment % needed_alignment) {
writes.push_back(Store::make(key_name, Cast::make(UInt(8), 0),
index, Parameter(), const_true()));
index = index + 1;
alignment++;
}
}
for (const DependencyKeyInfoPair &i : dependencies.dependency_info) {
writes.push_back(Store::make(key_name,
i.second.value_expr,
(index / i.second.size_expr),
Parameter(), const_true()));
index += i.second.size_expr;
}
Stmt blocks = Block::make(writes);
return blocks;
}
Expr generate_lookup(std::string key_allocation_name, std::string computed_bounds_name,
int32_t tuple_count, std::string storage_base_name) {
std::vector<Expr> args;
args.push_back(Variable::make(type_of<uint8_t *>(), key_allocation_name));
args.push_back(key_size());
args.push_back(Variable::make(type_of<halide_buffer_t *>(), computed_bounds_name));
args.push_back(tuple_count);
std::vector<Expr> buffers;
if (tuple_count == 1) {
buffers.push_back(Variable::make(type_of<halide_buffer_t *>(), storage_base_name + ".buffer"));
} else {
for (int32_t i = 0; i < tuple_count; i++) {
buffers.push_back(Variable::make(type_of<halide_buffer_t *>(), storage_base_name + "." + std::to_string(i) + ".buffer"));
}
}
args.push_back(Call::make(type_of<halide_buffer_t **>(), Call::make_struct, buffers, Call::Intrinsic));
return Call::make(Int(32), "halide_memoization_cache_lookup", args, Call::Extern);
}
Stmt store_computation(std::string key_allocation_name, std::string computed_bounds_name,
int32_t tuple_count, std::string storage_base_name) {
std::vector<Expr> args;
args.push_back(Variable::make(type_of<uint8_t *>(), key_allocation_name));
args.push_back(key_size());
args.push_back(Variable::make(type_of<halide_buffer_t *>(), computed_bounds_name));
args.push_back(tuple_count);
std::vector<Expr> buffers;
if (tuple_count == 1) {
buffers.push_back(Variable::make(type_of<halide_buffer_t *>(), storage_base_name + ".buffer"));
} else {
for (int32_t i = 0; i < tuple_count; i++) {
buffers.push_back(Variable::make(type_of<halide_buffer_t *>(), storage_base_name + "." + std::to_string(i) + ".buffer"));
}
}
args.push_back(Call::make(type_of<halide_buffer_t **>(), Call::make_struct, buffers, Call::Intrinsic));
return Evaluate::make(Call::make(Int(32), "halide_memoization_cache_store", args, Call::Extern));
}
};
}
// Mutator that wraps the realization of every memoized Function with cache
// key construction, a cache lookup, and a store-back on cache miss.
class InjectMemoization : public IRMutator {
public:
    const std::map<std::string, Function> &env;
    const std::string &top_level_name;
    const std::vector<Function> &outputs;

    // All three arguments are held by reference and must outlive the mutator.
    InjectMemoization(const std::map<std::string, Function> &e, const std::string &name,
                      const std::vector<Function> &outputs) :
        env(e), top_level_name(name), outputs(outputs) {}

private:
    using IRMutator::visit;

    // For a memoized Function, rebuild the Realize so that its body becomes:
    //   allocate <name>.cache_key
    //     <stores that fill the key>
    //     let <name>.computed_bounds.buffer = <bounds buffer>
    //       let <name>.cache_result = halide_memoization_cache_lookup(...)
    //         assert(<name>.cache_result != -1, halide_error_out_of_memory())
    //         let <name>.cache_miss = (bool)<name>.cache_result
    //           <original body, mutated>
    void visit(const Realize *op) {
        std::map<std::string, Function>::const_iterator iter = env.find(op->name);
        if (iter != env.end() &&
            iter->second.schedule().memoized()) {

            const Function f(iter->second);

            // Pipeline outputs cannot be memoized.
            for (const Function &o : outputs) {
                if (f.same_as(o)) {
                    user_error << "Function " << f.name() << " cannot be memoized because "
                               << "it is an output of pipeline " << top_level_name << ".\n";
                }
            }

            // Memoization requires compute and store at the same loop level.
            if (!f.schedule().compute_level().match(f.schedule().store_level())) {
                user_error << "Function " << f.name() << " cannot be memoized because "
                           << "it has compute and storage scheduled at different loop levels.\n";
            }

            Stmt mutated_body = mutate(op->body);

            KeyInfo key_info(f, top_level_name);

            std::string cache_key_name = op->name + ".cache_key";
            std::string cache_result_name = op->name + ".cache_result";
            std::string cache_miss_name = op->name + ".cache_miss";
            std::string computed_bounds_name = op->name + ".computed_bounds.buffer";

            // The statements are assembled innermost first.
            // Any nonzero lookup result counts as a cache miss.
            Stmt cache_miss_marker = LetStmt::make(cache_miss_name,
                                                   Cast::make(Bool(), Variable::make(Int(32), cache_result_name)),
                                                   mutated_body);
            // A lookup result of -1 is reported through halide_error_out_of_memory.
            Stmt cache_lookup_check = Block::make(AssertStmt::make(NE::make(Variable::make(Int(32), cache_result_name), -1),
                                                                   Call::make(Int(32), "halide_error_out_of_memory", { }, Call::Extern)),
                                                  cache_miss_marker);

            Stmt cache_lookup = LetStmt::make(cache_result_name,
                                              key_info.generate_lookup(cache_key_name, computed_bounds_name, f.outputs(), op->name),
                                              cache_lookup_check);

            // Describe the computed region using the min/max variables of the
            // Function's last stage (stage index == number of updates).
            BufferBuilder builder;
            builder.dimensions = f.dimensions();
            std::string max_stage_num = std::to_string(f.updates().size());
            for (const std::string &arg : f.args()) {
                std::string prefix = op->name + ".s" + max_stage_num + "." + arg;
                Expr min = Variable::make(Int(32), prefix + ".min");
                Expr max = Variable::make(Int(32), prefix + ".max");
                builder.mins.push_back(min);
                builder.extents.push_back(max + 1 - min);
            }
            Expr computed_bounds = builder.build();

            Stmt computed_bounds_let = LetStmt::make(computed_bounds_name, computed_bounds, cache_lookup);

            Stmt generate_key = Block::make(key_info.generate_key(cache_key_name), computed_bounds_let);
            Stmt cache_key_alloc =
                Allocate::make(cache_key_name, UInt(8), {key_info.key_size()},
                               const_true(), generate_key);

            stmt = Realize::make(op->name, op->types, op->bounds, op->condition, cache_key_alloc);
        } else {
            IRMutator::visit(op);
        }
    }

    // For a memoized Function: guard the producer body with
    // `if (<name>.cache_miss)`, and prepend to the consumer body a conditional
    // halide_memoization_cache_store call executed only on a cache miss.
    void visit(const ProducerConsumer *op) {
        std::map<std::string, Function>::const_iterator iter = env.find(op->name);
        if (iter != env.end() &&
            iter->second.schedule().memoized()) {

            Stmt body = mutate(op->body);

            std::string cache_miss_name = op->name + ".cache_miss";
            Expr cache_miss = Variable::make(Bool(), cache_miss_name);

            if (op->is_producer) {
                Stmt mutated_body = IfThenElse::make(cache_miss, body);
                stmt = ProducerConsumer::make(op->name, op->is_producer, mutated_body);
            } else {
                const Function f(iter->second);
                KeyInfo key_info(f, top_level_name);

                std::string cache_key_name = op->name + ".cache_key";
                std::string computed_bounds_name = op->name + ".computed_bounds.buffer";

                Stmt cache_store_back =
                    IfThenElse::make(cache_miss, key_info.store_computation(cache_key_name, computed_bounds_name, f.outputs(), op->name));

                Stmt mutated_body = Block::make(cache_store_back, body);
                stmt = ProducerConsumer::make(op->name, op->is_producer, mutated_body);
            }
        } else {
            IRMutator::visit(op);
        }
    }
};
// Run the InjectMemoization pass over `s` and return the rewritten statement.
// `env` maps Function names to Functions, `name` is the top-level pipeline
// name, and `outputs` lists the pipeline's output Functions.
Stmt inject_memoization(Stmt s, const std::map<std::string, Function> &env,
                        const std::string &name,
                        const std::vector<Function> &outputs) {
    return InjectMemoization(env, name, outputs).mutate(s);
}
// Mutator that removes the Allocate nodes of memoized Functions and
// re-creates them inside the matching `<name>.cache_miss` LetStmt, with
// storage obtained from the `<name>.buffer` halide_buffer_t and released
// through halide_memoization_cache_release.
class RewriteMemoizedAllocations : public IRMutator {
public:
    // `e` is held by reference and must outlive the mutator.
    RewriteMemoizedAllocations(const std::map<std::string, Function> &e)
        : env(e) {}

private:
    const std::map<std::string, Function> &env;
    // Allocate nodes removed so far, grouped by realization name, waiting to
    // be re-created at the corresponding cache_miss LetStmt.
    std::map<std::string, std::vector<const Allocate *>> pending_memoized_allocations;
    // Name of the memoized realization currently being rewritten, or empty.
    std::string innermost_realization_name;

    // Strip a trailing ".<digits>" suffix (a tuple index) from an allocation
    // name. A name whose final '.' is followed by anything other than digits
    // is returned unchanged; a name ending in '.' loses that dot.
    std::string get_realization_name(const std::string &allocation_name) {
        std::string realization_name = allocation_name;
        size_t off = realization_name.rfind('.');
        if (off != std::string::npos) {
            size_t i = off + 1;
            // isdigit() requires a value representable as unsigned char;
            // passing a negative plain char is undefined behaviour.
            while (i < realization_name.size() &&
                   isdigit(static_cast<unsigned char>(realization_name[i]))) {
                i++;
            }
            if (i == realization_name.size()) {
                realization_name = realization_name.substr(0, off);
            }
        }
        return realization_name;
    }

    using IRMutator::visit;

    // Drop the Allocate of a memoized Function, remembering it as pending,
    // and continue with its body. Other allocations are left to the base.
    void visit(const Allocate *allocation) {
        std::string realization_name = get_realization_name(allocation->name);
        std::map<std::string, Function>::const_iterator iter = env.find(realization_name);

        if (iter != env.end() && iter->second.schedule().memoized()) {
            // Save and restore the enclosing name so nesting works.
            std::string old_innermost_realization_name = innermost_realization_name;
            innermost_realization_name = realization_name;

            pending_memoized_allocations[innermost_realization_name].push_back(allocation);
            stmt = mutate(allocation->body);

            innermost_realization_name = old_innermost_realization_name;
        } else {
            IRMutator::visit(allocation);
        }
    }

    // Inside a memoized realization, replace the third argument of a
    // buffer_init call with a zero handle when it is a variable that refers
    // to the realization's own allocation.
    void visit(const Call *call) {
        if (!innermost_realization_name.empty() &&
            call->name == Call::buffer_init) {
            internal_assert(call->args.size() >= 3)
                << "RewriteMemoizedAllocations: _halide_buffer_init call with fewer than three args.\n";

            const Variable *var = call->args[2].as<Variable>();
            if (var && get_realization_name(var->name) == innermost_realization_name) {
                std::vector<Expr> args = call->args;
                args[2] = make_zero(Handle());
                expr = Call::make(type_of<struct halide_buffer_t *>(), Call::buffer_init,
                                  args, Call::Extern);
                return;
            }
        }

        IRMutator::visit(call);
    }

    // At `<name>.cache_miss`, wrap the body in the pending allocations for
    // the current realization and clear the pending list.
    void visit(const LetStmt *let) {
        if (let->name == innermost_realization_name + ".cache_miss") {
            Expr value = mutate(let->value);
            Stmt body = mutate(let->body);

            std::vector<const Allocate *> &allocations = pending_memoized_allocations[innermost_realization_name];

            // Wrap from last to first so the first pending allocation ends up
            // outermost. The trailing Allocate::make arguments are presumably
            // the custom new-expression and free-function name (confirm
            // against Allocate::make).
            for (size_t i = allocations.size(); i > 0; i--) {
                const Allocate *allocation = allocations[i - 1];

                body = Allocate::make(allocation->name, allocation->type, allocation->extents, allocation->condition, body,
                                      Call::make(Handle(), Call::buffer_get_host,
                                                 { Variable::make(type_of<struct halide_buffer_t *>(), allocation->name + ".buffer") }, Call::Extern),
                                      "halide_memoization_cache_release");
            }

            pending_memoized_allocations.erase(innermost_realization_name);

            stmt = LetStmt::make(let->name, value, body);
        } else {
            IRMutator::visit(let);
        }
    }
};
// Run the RewriteMemoizedAllocations pass over `s` and return the result.
// `env` maps Function names to Functions.
Stmt rewrite_memoized_allocations(Stmt s, const std::map<std::string, Function> &env) {
    return RewriteMemoizedAllocations(env).mutate(s);
}
}
}