This source file includes the following definitions.
- make_block
- tag_linear_expression
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit_binary_linear
- visit
- visit
- visit
- visit
- visit_binary
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- max_expressions
- find_linear_expressions
- visit
- visit
- remove_varying_attributes
- visit
- replace_varying_attributes
- visit
- prune_varying_attributes
- visit
- visit
- float_type
- visit_binary_op
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- mutate
- mutate
- mutate_operator
- mutate_operator
- mutate_operator
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- dont_simplify
- used_in_codegen
- visit
- setup_gpu_vertex_buffer
#include "VaryingAttributes.h"
#include <algorithm>
#include "CodeGen_GPU_Dev.h"
#include "IRMutator.h"
#include "CSE.h"
#include "Simplify.h"
namespace Halide {
namespace Internal {
// Sequences two statements while tolerating undefined halves: when one of
// them is undefined the other is returned unchanged (it may itself be
// undefined); only when both are defined is a Block created.
Stmt make_block(Stmt first, Stmt rest) {
    if (!first.defined()) {
        return rest;
    }
    if (!rest.defined()) {
        return first;
    }
    return Block::make(first, rest);
}
class FindLinearExpressions : public IRMutator {
protected:
using IRMutator::visit;
bool in_glsl_loops;
Expr tag_linear_expression(Expr e, const std::string &name = unique_name('a')) {
internal_assert(name.length() > 0);
if (total_found >= max_expressions) {
return e;
}
Expr intrinsic = Call::make(e.type(), Call::glsl_varying,
{name + ".varying", e},
Call::Intrinsic);
++total_found;
return intrinsic;
}
virtual void visit(const Call *op) {
std::vector<Expr> new_args = op->args;
if (op->name == Call::glsl_texture_load) {
internal_assert(loop_vars.size() > 0) << "No GPU loop variables found at texture load\n";
for (int i = 2; i != 4; ++i) {
new_args[i] = mutate(op->args[i]);
if (order == 1) {
new_args[i] = tag_linear_expression(new_args[i]);
}
}
} else if (op->name == Call::glsl_texture_store) {
internal_assert(loop_vars.size() > 0) << "No GPU loop variables found at texture store\n";
new_args[5] = mutate(new_args[5]);
if (order == 1) {
new_args[5] = tag_linear_expression(new_args[5]);
}
}
order = 2;
expr = Call::make(op->type, op->name, new_args, op->call_type,
op->func, op->value_index, op->image, op->param);
}
virtual void visit(const Let *op) {
Expr mutated_value = mutate(op->value);
int value_order = order;
scope.push(op->name, order);
Expr mutated_body = mutate(op->body);
if ((value_order == 1) && (total_found < max_expressions)) {
mutated_value = Call::make(mutated_value.type(), Call::glsl_varying,
{op->name + ".varying", mutated_value},
Call::Intrinsic);
++total_found;
}
expr = Let::make(op->name, mutated_value, mutated_body);
scope.pop(op->name);
}
virtual void visit(const For *op) {
bool old_in_glsl_loops = in_glsl_loops;
bool kernel_loop = op->device_api == DeviceAPI::GLSL;
bool within_kernel_loop = !kernel_loop && in_glsl_loops;
if (kernel_loop) {
loop_vars.push_back(op->name);
in_glsl_loops = true;
} else if (within_kernel_loop) {
scope.push(op->name, 2);
}
Stmt mutated_body = mutate(op->body);
if (kernel_loop) {
loop_vars.pop_back();
} else if (within_kernel_loop) {
scope.pop(op->name);
}
in_glsl_loops = old_in_glsl_loops;
if (mutated_body.same_as(op->body)) {
stmt = op;
} else {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, mutated_body);
}
}
virtual void visit(const Variable *op) {
if (std::find(loop_vars.begin(), loop_vars.end(), op->name) != loop_vars.end()) {
order = 1;
} else if (scope.contains(op->name)) {
order = scope.get(op->name);
} else {
order = 0;
}
expr = op;
}
virtual void visit(const IntImm *op) { order = 0; expr = op; }
virtual void visit(const UIntImm *op) { order = 0; expr = op; }
virtual void visit(const FloatImm *op) { order = 0; expr = op; }
virtual void visit(const StringImm *op) { order = 0; expr = op; }
virtual void visit(const Cast *op) {
Expr mutated_value = mutate(op->value);
int value_order = order;
if (order && (!op->type.is_float())) {
order = 2;
}
if ((order > 1) && (value_order == 1)) {
mutated_value = tag_linear_expression(mutated_value);
}
expr = Cast::make(op->type, mutated_value);
}
template<typename T>
void visit_binary_linear(T *op) {
Expr a = mutate(op->a);
unsigned int order_a = order;
Expr b = mutate(op->b);
unsigned int order_b = order;
order = std::max(order_a, order_b);
if ((order > 1) && (order_a == 1)) {
a = tag_linear_expression(a);
}
if ((order > 1) && (order_b == 1)) {
b = tag_linear_expression(b);
}
expr = T::make(a, b);
}
virtual void visit(const Add *op) { visit_binary_linear(op); }
virtual void visit(const Sub *op) { visit_binary_linear(op); }
virtual void visit(const Mul *op) {
Expr a = mutate(op->a);
unsigned int order_a = order;
Expr b = mutate(op->b);
unsigned int order_b = order;
order = order_a + order_b;
if ((order > 1) && (order_a == 1)) {
a = tag_linear_expression(a);
}
if ((order > 1) && (order_b == 1)) {
b = tag_linear_expression(b);
}
expr = Mul::make(a, b);
}
virtual void visit(const Div *op) {
Expr a = mutate(op->a);
unsigned int order_a = order;
Expr b = mutate(op->b);
unsigned int order_b = order;
if (order_a && !order_b) {
order = order_a;
} else if (!order_a && order_b) {
order = 2;
} else {
order = order_a + order_b;
}
if ((order > 1) && (order_a == 1)) {
a = tag_linear_expression(a);
}
if ((order > 1) && (order_b == 1)) {
b = tag_linear_expression(b);
}
expr = Div::make(a, b);
}
template<typename T>
void visit_binary(T *op) {
Expr a = mutate(op->a);
unsigned int order_a = order;
Expr b = mutate(op->b);
unsigned int order_b = order;
if (order_a || order_b) {
order = 2;
}
if ((order > 1) && (order_a == 1)) {
a = tag_linear_expression(a);
}
if ((order > 1) && (order_b == 1)) {
b = tag_linear_expression(b);
}
expr = T::make(a, b);
}
virtual void visit(const Mod *op) { visit_binary(op); }
virtual void visit(const Min *op) { visit_binary(op); }
virtual void visit(const Max *op) { visit_binary(op); }
virtual void visit(const EQ *op) { visit_binary(op); }
virtual void visit(const NE *op) { visit_binary(op); }
virtual void visit(const LT *op) { visit_binary(op); }
virtual void visit(const LE *op) { visit_binary(op); }
virtual void visit(const GT *op) { visit_binary(op); }
virtual void visit(const GE *op) { visit_binary(op); }
virtual void visit(const And *op) { visit_binary(op); }
virtual void visit(const Or *op) { visit_binary(op); }
virtual void visit(const Not *op) {
Expr a = mutate(op->a);
unsigned int order_a = order;
if (order_a) {
order = 2;
}
expr = Not::make(a);
}
virtual void visit(const Broadcast *op) {
Expr a = mutate(op->value);
if (order == 1) {
a = tag_linear_expression(a);
}
if (order) {
order = 2;
}
expr = Broadcast::make(a, op->lanes);
}
virtual void visit(const Select *op) {
Expr mutated_condition = mutate(op->condition);
int condition_order = (order != 0) ? 2 : 0;
Expr mutated_true_value = mutate(op->true_value);
int true_value_order = order;
Expr mutated_false_value = mutate(op->false_value);
int false_value_order = order;
order = std::max(std::max(condition_order, true_value_order), false_value_order);
if ((order > 1) && (condition_order == 1)) {
mutated_condition = tag_linear_expression(mutated_condition);
}
if ((order > 1) && (true_value_order == 1)) {
mutated_true_value = tag_linear_expression(mutated_true_value);
}
if ((order > 1) && (false_value_order == 1)) {
mutated_false_value = tag_linear_expression(mutated_false_value);
}
expr = Select::make(mutated_condition, mutated_true_value, mutated_false_value);
}
public:
std::vector<std::string> loop_vars;
Scope<int> scope;
unsigned int order;
bool found;
unsigned int total_found;
const unsigned int max_expressions;
FindLinearExpressions() : in_glsl_loops(false), total_found(0), max_expressions(62) {}
};
// Runs FindLinearExpressions over `s`, returning the statement with linear
// sub-expressions of GLSL kernel loops wrapped in glsl_varying tags.
Stmt find_linear_expressions(Stmt s) {
    FindLinearExpressions linear_finder;
    return linear_finder.mutate(s);
}
class FindVaryingAttributeTags : public IRVisitor
{
public:
FindVaryingAttributeTags(std::map<std::string, Expr>& varyings_) : varyings(varyings_) { }
using IRVisitor::visit;
virtual void visit(const Call *op) {
if (op->name == Call::glsl_varying) {
std::string name = op->args[0].as<StringImm>()->value;
varyings[name] = op->args[1];
}
IRVisitor::visit(op);
}
std::map<std::string, Expr>& varyings;
};
// Strips glsl_varying tags by replacing each tag with the expression it wraps.
// The wrapped expression is substituted as-is (it is not mutated further).
class RemoveVaryingAttributeTags : public IRMutator {
public:
    using IRMutator::visit;

    virtual void visit(const Call *op) {
        if (op->name != Call::glsl_varying) {
            IRMutator::visit(op);
            return;
        }
        expr = op->args[1];
    }
};
// Returns `s` with every glsl_varying tag replaced by its wrapped expression.
Stmt remove_varying_attributes(Stmt s) {
    RemoveVaryingAttributeTags remover;
    return remover.mutate(s);
}
// Replaces each glsl_varying tag with a Variable of the tag's type named after
// the tag (the name must end in ".varying").
class ReplaceVaryingAttributeTags : public IRMutator {
public:
    using IRMutator::visit;

    virtual void visit(const Call *op) {
        if (op->name != Call::glsl_varying) {
            IRMutator::visit(op);
            return;
        }
        const std::string &name = op->args[0].as<StringImm>()->value;
        internal_assert(ends_with(name, ".varying"));
        expr = Variable::make(op->type, name);
    }
};
// Returns `s` with every glsl_varying tag replaced by a "<name>.varying"
// variable reference.
Stmt replace_varying_attributes(Stmt s) {
    ReplaceVaryingAttributeTags replacer;
    return replacer.mutate(s);
}
// Records the names of all referenced variables whose name ends in ".varying".
class FindVaryingAttributeVars : public IRVisitor {
public:
    using IRVisitor::visit;

    virtual void visit(const Variable *op) {
        if (!ends_with(op->name, ".varying")) {
            return;
        }
        variables.insert(op->name);
    }

    // Distinct ".varying" variable names seen during the traversal.
    std::set<std::string> variables;
};
// Removes from `varying` every attribute whose "<name>" is never referenced as
// a variable inside `loop_stmt`, logging each removal at debug level 2.
//
// Single pass using the iterator returned by map::erase. This also avoids the
// per-element temporary copy the previous range-for made by binding
// `const std::pair<std::string, Expr>&` to `pair<const std::string, Expr>`.
void prune_varying_attributes(Stmt loop_stmt, std::map<std::string, Expr>& varying)
{
    FindVaryingAttributeVars find;
    loop_stmt.accept(&find);

    for (auto it = varying.begin(); it != varying.end();) {
        const std::string &name = it->first;
        if (find.variables.count(name) == 0) {
            debug(2) << "Removed varying attribute " << name << "\n";
            // `name` refers into the node, so log before erasing it.
            it = varying.erase(it);
        } else {
            ++it;
        }
    }
}
class CastVaryingVariables : public IRMutator {
protected:
using IRMutator::visit;
virtual void visit(const Variable *op) {
if ((ends_with(op->name, ".varying")) && (op->type != Float(32))) {
Expr v = Variable::make(Float(32), op->name);
expr = Cast::make(op->type, floor(v + 0.5f));
} else {
expr = op;
}
}
};
// Offsets the given loop variables by -0.5f, then repairs types around the
// now floating-point values:
//  - a non-float operand combined with a float operand is cast to that float
//    operand's type;
//  - Let/LetStmt-bound names whose value changed between integer and
//    floating point are re-typed at their uses.
class CastVariablesToFloatAndOffset : public IRMutator {
protected:
    using IRMutator::visit;

    virtual void visit(const Variable *op) {
        if (std::find(names.begin(), names.end(), op->name) != names.end()) {
            expr = Expr(op) - 0.5f;
        } else if (scope.contains(op->name) && (op->type != scope.get(op->name).type())) {
            expr = Variable::make(scope.get(op->name).type(), op->name);
        } else {
            expr = op;
        }
    }

    // Floating-point type with the same bit width and lane count as `e`.
    // Kept for compatibility; the casts below target the mutated float
    // operand's exact type instead.
    Type float_type(Expr e) {
        return Float(e.type().bits(), e.type().lanes());
    }

    // If exactly one mutated operand is floating point, cast the other to that
    // operand's *mutated* type. The previous code derived the target from the
    // other operand's pre-mutation type, which can disagree with the mutated
    // float operand and make T::make see mismatched types.
    template<typename T>
    void visit_binary_op(const T *op) {
        Expr mutated_a = mutate(op->a);
        Expr mutated_b = mutate(op->b);
        bool a_float = mutated_a.type().is_float();
        bool b_float = mutated_b.type().is_float();
        if (a_float && !b_float) {
            mutated_b = Cast::make(mutated_a.type(), mutated_b);
        } else if (b_float && !a_float) {
            mutated_a = Cast::make(mutated_b.type(), mutated_a);
        }
        expr = T::make(mutated_a, mutated_b);
    }

    virtual void visit(const Add *op) { visit_binary_op(op); }
    virtual void visit(const Sub *op) { visit_binary_op(op); }
    virtual void visit(const Mul *op) { visit_binary_op(op); }
    virtual void visit(const Div *op) { visit_binary_op(op); }
    virtual void visit(const Mod *op) { visit_binary_op(op); }
    virtual void visit(const Min *op) { visit_binary_op(op); }
    virtual void visit(const Max *op) { visit_binary_op(op); }
    virtual void visit(const EQ *op) { visit_binary_op(op); }
    virtual void visit(const NE *op) { visit_binary_op(op); }
    virtual void visit(const LT *op) { visit_binary_op(op); }
    virtual void visit(const LE *op) { visit_binary_op(op); }
    virtual void visit(const GT *op) { visit_binary_op(op); }
    virtual void visit(const GE *op) { visit_binary_op(op); }
    virtual void visit(const And *op) { visit_binary_op(op); }
    virtual void visit(const Or *op) { visit_binary_op(op); }

    // Both branches must share a type: a non-float branch is cast to the
    // float branch's mutated type.
    virtual void visit(const Select *op) {
        Expr mutated_condition = mutate(op->condition);
        Expr mutated_true_value = mutate(op->true_value);
        Expr mutated_false_value = mutate(op->false_value);
        bool t_float = mutated_true_value.type().is_float();
        bool f_float = mutated_false_value.type().is_float();
        if (t_float && !f_float) {
            mutated_false_value = Cast::make(mutated_true_value.type(), mutated_false_value);
        } else if (f_float && !t_float) {
            mutated_true_value = Cast::make(mutated_false_value.type(), mutated_true_value);
        }
        expr = Select::make(mutated_condition, mutated_true_value, mutated_false_value);
    }

    // Base and stride must share a type: the non-float one is cast to the
    // float one's mutated type.
    virtual void visit(const Ramp *op) {
        Expr mutated_base = mutate(op->base);
        Expr mutated_stride = mutate(op->stride);
        bool base_float = mutated_base.type().is_float();
        bool stride_float = mutated_stride.type().is_float();
        if (!base_float && stride_float) {
            mutated_base = Cast::make(mutated_stride.type(), mutated_base);
        } else if (base_float && !stride_float) {
            mutated_stride = Cast::make(mutated_base.type(), mutated_stride);
        }
        if (mutated_base.same_as(op->base) && mutated_stride.same_as(op->stride)) {
            expr = op;
        } else {
            expr = Ramp::make(mutated_base, mutated_stride, op->lanes);
        }
    }

    // When the value flips between integer and float, record the new value so
    // uses of the name in the body are re-typed by visit(Variable).
    virtual void visit(const Let *op) {
        Expr mutated_value = mutate(op->value);
        bool changed = op->value.type().is_float() != mutated_value.type().is_float();
        if (changed) {
            scope.push(op->name, mutated_value);
        }
        Expr mutated_body = mutate(op->body);
        if (changed) {
            scope.pop(op->name);
        }
        expr = Let::make(op->name, mutated_value, mutated_body);
    }

    virtual void visit(const LetStmt *op) {
        Expr mutated_value = mutate(op->value);
        bool changed = op->value.type().is_float() != mutated_value.type().is_float();
        if (changed) {
            scope.push(op->name, mutated_value);
        }
        Stmt mutated_body = mutate(op->body);
        if (changed) {
            scope.pop(op->name);
        }
        stmt = LetStmt::make(op->name, mutated_value, mutated_body);
    }

public:
    CastVariablesToFloatAndOffset(const std::vector<std::string>& names_) : names(names_) { }

    // Loop variable names to offset; the referenced vector must outlive this
    // mutator.
    const std::vector<std::string>& names;

    // Mutated values of let-bound names whose type changed.
    Scope<Expr> scope;
};
class IRFilter : public IRVisitor {
public:
virtual Stmt mutate(Expr expr);
virtual Stmt mutate(Stmt stmt);
protected:
using IRVisitor::visit;
Stmt stmt;
virtual void visit(const IntImm *);
virtual void visit(const FloatImm *);
virtual void visit(const StringImm *);
virtual void visit(const Cast *);
virtual void visit(const Variable *);
virtual void visit(const Add *);
virtual void visit(const Sub *);
virtual void visit(const Mul *);
virtual void visit(const Div *);
virtual void visit(const Mod *);
virtual void visit(const Min *);
virtual void visit(const Max *);
virtual void visit(const EQ *);
virtual void visit(const NE *);
virtual void visit(const LT *);
virtual void visit(const LE *);
virtual void visit(const GT *);
virtual void visit(const GE *);
virtual void visit(const And *);
virtual void visit(const Or *);
virtual void visit(const Not *);
virtual void visit(const Select *);
virtual void visit(const Load *);
virtual void visit(const Ramp *);
virtual void visit(const Broadcast *);
virtual void visit(const Call *);
virtual void visit(const Let *);
virtual void visit(const LetStmt *);
virtual void visit(const AssertStmt *);
virtual void visit(const ProducerConsumer *);
virtual void visit(const For *);
virtual void visit(const Store *);
virtual void visit(const Provide *);
virtual void visit(const Allocate *);
virtual void visit(const Free *);
virtual void visit(const Realize *);
virtual void visit(const Block *);
virtual void visit(const IfThenElse *);
virtual void visit(const Evaluate *);
};
Stmt IRFilter::mutate(Expr e) {
if (e.defined()) {
e.accept(this);
}
else {
stmt = Stmt();
}
return stmt;
}
Stmt IRFilter::mutate(Stmt s) {
if (s.defined()) {
s.accept(this);
} else {
stmt = Stmt();
}
return stmt;
}
namespace {
template<typename T, typename A>
void mutate_operator(IRFilter *mutator, const T *op, const A op_a, Stmt *stmt) {
Stmt a = mutator->mutate(op_a);
*stmt = a;
}
template<typename T, typename A, typename B>
void mutate_operator(IRFilter *mutator, const T *op, const A op_a, const B op_b, Stmt *stmt) {
Stmt a = mutator->mutate(op_a);
Stmt b = mutator->mutate(op_b);
*stmt = make_block(a, b);
}
template<typename T, typename A, typename B, typename C>
void mutate_operator(IRFilter *mutator, const T *op, const A op_a, const B op_b, const C op_c, Stmt *stmt) {
Stmt a = mutator->mutate(op_a);
Stmt b = mutator->mutate(op_b);
Stmt c = mutator->mutate(op_c);
*stmt = make_block(make_block(a, b), c);
}
}
void IRFilter::visit(const IntImm *op) {stmt = Stmt();}
void IRFilter::visit(const FloatImm *op) {stmt = Stmt();}
void IRFilter::visit(const StringImm *op) {stmt = Stmt();}
void IRFilter::visit(const Variable *op) {stmt = Stmt();}
void IRFilter::visit(const Cast *op) {
mutate_operator(this, op, op->value, &stmt);
}
void IRFilter::visit(const Add *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Sub *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Mul *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Div *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Mod *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Min *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Max *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const EQ *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const NE *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const LT *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const LE *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const GT *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const GE *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const And *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Or *op) {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Not *op) {
mutate_operator(this, op, op->a, &stmt);
}
void IRFilter::visit(const Select *op) {
mutate_operator(this, op, op->condition, op->true_value, op->false_value, &stmt);
}
void IRFilter::visit(const Load *op) {
mutate_operator(this, op, op->predicate, op->index, &stmt);
}
void IRFilter::visit(const Ramp *op) {
mutate_operator(this, op, op->base, op->stride, &stmt);
}
void IRFilter::visit(const Broadcast *op) {
mutate_operator(this, op, op->value, &stmt);
}
void IRFilter::visit(const Call *op) {
std::vector<Stmt> new_args(op->args.size());
for (size_t i = 0; i < op->args.size(); i++) {
Expr old_arg = op->args[i];
Stmt new_arg = mutate(old_arg);
new_args[i] = new_arg;
}
stmt = Stmt();
for (size_t i = 0; i < new_args.size(); ++i) {
if (new_args[i].defined()) {
stmt = make_block(new_args[i], stmt);
}
}
}
void IRFilter::visit(const Let *op) {
mutate_operator(this, op, op->value, op->body, &stmt);
}
void IRFilter::visit(const LetStmt *op) {
mutate_operator(this, op, op->value, op->body, &stmt);
}
void IRFilter::visit(const AssertStmt *op) {
mutate_operator(this, op, op->condition, op->message, &stmt);
}
void IRFilter::visit(const ProducerConsumer *op) {
mutate_operator(this, op, op->body, &stmt);
}
void IRFilter::visit(const For *op) {
mutate_operator(this, op, op->min, op->extent, op->body, &stmt);
}
void IRFilter::visit(const Store *op) {
mutate_operator(this, op, op->predicate, op->value, op->index, &stmt);
}
void IRFilter::visit(const Provide *op) {
stmt = Stmt();
for (size_t i = 0; i < op->args.size(); i++) {
Stmt new_arg = mutate(op->args[i]);
if (new_arg.defined()) {
stmt = make_block(new_arg, stmt);
}
Stmt new_value = mutate(op->values[i]);
if (new_value.defined()) {
stmt = make_block(new_value, stmt);
}
}
}
void IRFilter::visit(const Allocate *op) {
stmt = Stmt();
for (size_t i = 0; i < op->extents.size(); i++) {
Stmt new_extent = mutate(op->extents[i]);
if (new_extent.defined())
stmt = make_block(new_extent, stmt);
}
Stmt body = mutate(op->body);
if (body.defined())
stmt = make_block(body, stmt);
Stmt condition = mutate(op->condition);
if (condition.defined())
stmt = make_block(condition, stmt);
}
void IRFilter::visit(const Free *op) {
}
void IRFilter::visit(const Realize *op) {
stmt = Stmt();
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i].min;
Expr old_extent = op->bounds[i].extent;
Stmt new_min = mutate(old_min);
Stmt new_extent = mutate(old_extent);
if (new_min.defined())
stmt = make_block(new_min, stmt);
if (new_extent.defined())
stmt = make_block(new_extent, stmt);
}
Stmt body = mutate(op->body);
if (body.defined())
stmt = make_block(body, stmt);
Stmt condition = mutate(op->condition);
if (condition.defined())
stmt = make_block(condition, stmt);
}
void IRFilter::visit(const Block *op) {
mutate_operator(this, op, op->first, op->rest, &stmt);
}
void IRFilter::visit(const IfThenElse *op) {
mutate_operator(this, op, op->condition, op->then_case, op->else_case, &stmt);
}
void IRFilter::visit(const Evaluate *op) {
mutate_operator(this, op, op->value, &stmt);
}
// Filters a pair of nested GLSL kernel loops down to host code that fills the
// vertex buffer `vertex_buffer_name`:
//  - each glsl_varying tag becomes a Store of its tagged expression at index
//    "gpu.vertex_offset" + attribute_order[tag name];
//  - each GLSL GPU loop becomes a serial loop "<name>.idx" over the entries
//    of dims[<name>], binding <name> to the selected coordinate.
class CreateVertexBufferOnHost : public IRFilter {
public:
    using IRFilter::visit;

    // A glsl_varying tag becomes a Store of the tagged expression into the
    // vertex buffer; any other call is filtered argument by argument.
    virtual void visit(const Call *op) {
        if (op->name == Call::glsl_varying) {
            std::string attribute_name = op->args[0].as<StringImm>()->value;
            // NOTE(review): operator[] inserts 0 for a name missing from
            // attribute_order, which would alias slot 0 — confirm every tag
            // name is registered by the caller.
            Expr offset_expression = Variable::make(Int(32), "gpu.vertex_offset") +
                                     attribute_order[attribute_name];
            stmt = Store::make(vertex_buffer_name, op->args[1], offset_expression,
                               Parameter(), const_true(op->args[1].type().lanes()));
        } else {
            IRFilter::visit(op);
        }
    }

    // A Let expression becomes a LetStmt (binding the original, unfiltered
    // value) around the statements produced for its body, so those statements
    // can still refer to the bound name. Statements produced for the value
    // itself are placed before it.
    virtual void visit(const Let *op) {
        stmt = nullptr;
        Stmt mutated_value = mutate(op->value);
        Stmt mutated_body = mutate(op->body);
        if (mutated_body.defined()) {
            stmt = LetStmt::make(op->name, op->value, mutated_body);
        }
        if (mutated_value.defined()) {
            stmt = make_block(mutated_value, stmt);
        }
    }

    // Same treatment as Let: keep the binding around any statements produced
    // for the body.
    virtual void visit(const LetStmt *op) {
        stmt = Stmt();
        Stmt mutated_value = mutate(op->value);
        Stmt mutated_body = mutate(op->body);
        if (mutated_body.defined()) {
            stmt = LetStmt::make(op->name, op->value, mutated_body);
        }
        if (mutated_value.defined()) {
            stmt = make_block(mutated_value, stmt);
        }
    }

    // GLSL GPU loops are replaced by serial loops over the vertex coordinates.
    // Other loops are filtered like any other node.
    virtual void visit(const For *op) {
        if (CodeGen_GPU_Dev::is_gpu_var(op->name) && op->device_api == DeviceAPI::GLSL) {
            std::string name = op->name + ".idx";
            const std::vector<Expr>& dim = dims[op->name];
            // At most two nested GLSL loops are handled. for_loops and
            // loop_variables are never popped, so one instance processes a
            // single loop nest.
            internal_assert(for_loops.size() <= 1);
            for_loops.push_back(op);
            Expr loop_variable = Variable::make(Int(32),name);
            loop_variables.push_back(loop_variable);
            // Index 0 selects dim[0], any other index selects dim[1]; only
            // two coordinates per dimension are supported.
            Expr coord_expr = select(loop_variable == 0, dim[0], dim[1]);
            Stmt mutated_body = mutate(op->body);
            const For *nested_for = op->body.as<For>();
            // Innermost GLSL loop: also emit the vertex position and bind the
            // per-vertex offset into the buffer.
            if (!(nested_for && CodeGen_GPU_Dev::is_gpu_var(nested_for->name))) {
                Expr gpu_varying_offset = Variable::make(Int(32), "gpu.vertex_offset");
                // for_loops[0] is the outer loop (stored at slot +1) and
                // for_loops[1] the inner loop (stored at slot +0).
                Expr coord1 = cast<float>(Variable::make(Int(32),for_loops[0]->name));
                Expr coord0 = cast<float>(Variable::make(Int(32),for_loops[1]->name));
                // Scale into [-1, 1]. NOTE(review): this maps [0, extent] and
                // ignores the loop min — confirm mins are always 0 here.
                coord1 = (coord1 / for_loops[0]->extent) * 2.0f - 1.0f;
                coord0 = (coord0 / for_loops[1]->extent) * 2.0f - 1.0f;
                mutated_body = remove_varying_attributes(mutated_body);
                // Offset the loop variables by -0.5f and repair the types of
                // the expressions that use them.
                std::vector<std::string> names = {for_loops[0]->name, for_loops[1]->name};
                CastVariablesToFloatAndOffset cast_and_offset(names);
                mutated_body = cast_and_offset.mutate(mutated_body);
                mutated_body = make_block(Store::make(vertex_buffer_name,
                                                      coord1,
                                                      gpu_varying_offset + 1,
                                                      Parameter(), const_true()),
                                          mutated_body);
                mutated_body = make_block(Store::make(vertex_buffer_name,
                                                      coord0,
                                                      gpu_varying_offset + 0,
                                                      Parameter(), const_true()),
                                          mutated_body);
                // Vertex index = outer_idx * 2 + inner_idx, times the padded
                // attribute count. The factor 2 presumably matches the two
                // coordinates per inner dimension — confirm.
                Expr offset_expression = (loop_variables[0] * num_padded_attributes * 2) +
                                         (loop_variables[1] * num_padded_attributes);
                mutated_body = LetStmt::make("gpu.vertex_offset",
                                             offset_expression, mutated_body);
            }
            Stmt loop_var = LetStmt::make(op->name, coord_expr, mutated_body);
            stmt = For::make(name, 0, (int)dim.size(), ForType::Serial, DeviceAPI::None, loop_var);
        } else {
            IRFilter::visit(op);
        }
    }

    // Name of the buffer the Stores write to.
    std::string vertex_buffer_name;
    // Per GLSL loop name, the coordinate values to iterate over.
    typedef std::map<std::string, std::vector<Expr>> DimsType;
    DimsType dims;
    // Slot within a vertex for each attribute / tag name.
    std::map<std::string, int> attribute_order;
    // Number of slots per vertex.
    int num_padded_attributes;
    // GLSL loops encountered, outermost first, and their ".idx" variables.
    std::vector<const For*> for_loops;
    std::vector<Expr> loop_variables;
};
// Wraps `v_` as return_second(0, v_), an intrinsic call of the same type as
// `v_`. Presumably used so simplification does not substitute the value as a
// plain constant — confirm against the simplifier.
Expr dont_simplify(Expr v_) {
    std::vector<Expr> call_args = {0, v_};
    return Internal::Call::make(v_.type(), Internal::Call::return_second,
                                call_args, Internal::Call::Intrinsic);
}
// Builds an Evaluate of return_second(<variable v_>, 0) so the variable named
// `v_` (of type `type_`) is referenced by a statement.
Stmt used_in_codegen(Type type_, const std::string &v_) {
    Expr referenced = Variable::make(type_, v_);
    Expr marker = Internal::Call::make(Int(32), Internal::Call::return_second,
                                       {referenced, 0}, Internal::Call::Intrinsic);
    return Evaluate::make(marker);
}
class CreateVertexBufferHostLoops : public IRMutator {
public:
using IRMutator::visit;
virtual void visit(const For *op) {
if (CodeGen_GPU_Dev::is_gpu_var(op->name) && op->device_api == DeviceAPI::GLSL) {
const For *loop1 = op;
const For *loop0 = loop1->body.as<For>();
internal_assert(loop1->body.as<For>()) << "Did not find pair of nested For loops";
std::map<std::string, Expr> varyings;
FindVaryingAttributeTags tag_finder(varyings);
op->accept(&tag_finder);
std::map<std::string, int> attribute_order;
attribute_order["__vertex_x"] = 0;
attribute_order["__vertex_y"] = 1;
int idx = 2;
for (const std::pair<std::string, Expr> &v : varyings) {
attribute_order[v.first] = idx++;
}
attribute_order[loop0->name] = 0;
attribute_order[loop1->name] = 1;
Expr loop0_max = Add::make(loop0->min, loop0->extent);
Expr loop1_max = Add::make(loop1->min, loop1->extent);
std::vector<std::vector<Expr>> coords(2);
coords[0].push_back(loop0->min);
coords[0].push_back(loop0_max);
coords[1].push_back(loop1->min);
coords[1].push_back(loop1_max);
int num_attributes = varyings.size() + 2;
int num_padded_attributes = (num_attributes + 0x3) & ~0x3;
int vertex_buffer_size = num_padded_attributes*coords[0].size()*coords[1].size();
CreateVertexBufferOnHost vs;
vs.vertex_buffer_name = "glsl.vertex_buffer";
vs.num_padded_attributes = num_padded_attributes;
vs.dims[loop0->name] = coords[0];
vs.dims[loop1->name] = coords[1];
vs.attribute_order = attribute_order;
Stmt vertex_setup = vs.mutate(loop1);
vertex_setup = remove_varying_attributes(vertex_setup);
vertex_setup = simplify(vertex_setup);
vertex_setup = simplify(vertex_setup);
vertex_setup = simplify(vertex_setup);
vertex_setup = simplify(vertex_setup);
Stmt loop_stmt = replace_varying_attributes(op);
loop_stmt = simplify(loop_stmt, true);
prune_varying_attributes(loop_stmt, varyings);
loop_stmt = CastVaryingVariables().mutate(loop_stmt);
stmt = LetStmt::make("glsl.num_coords_dim0", dont_simplify((int)(coords[0].size())),
LetStmt::make("glsl.num_coords_dim1", dont_simplify((int)(coords[1].size())),
LetStmt::make("glsl.num_padded_attributes", dont_simplify(num_padded_attributes),
Allocate::make(vs.vertex_buffer_name, Float(32), {vertex_buffer_size}, const_true(),
Block::make(vertex_setup,
Block::make(loop_stmt,
Block::make(used_in_codegen(Int(32), "glsl.num_coords_dim0"),
Block::make(used_in_codegen(Int(32), "glsl.num_coords_dim1"),
Block::make(used_in_codegen(Int(32), "glsl.num_padded_attributes"),
Free::make(vs.vertex_buffer_name))))))))));
} else {
IRMutator::visit(op);
}
}
};
// Entry point: rewrites every GLSL kernel loop nest in `s` to build and use a
// host-side vertex buffer of varying attributes.
Stmt setup_gpu_vertex_buffer(Stmt s) {
    return CreateVertexBufferHostLoops().mutate(s);
}
}
}