This source file includes the following definitions:
- name
- visit
- visit
- sum
- sum
- product
- product
- maximum
- maximum
- minimum
- minimum
- argmax
- argmax
- argmin
- argmin
#include "InlineReductions.h"
#include "Func.h"
#include "Scope.h"
#include "IROperator.h"
#include "IRMutator.h"
#include "Debug.h"
#include "CSE.h"
namespace Halide {
using std::string;
using std::vector;
using std::ostringstream;
namespace Internal {
class FindFreeVars : public IRMutator {
public:
vector<Var> free_vars;
vector<Expr> call_args;
RDom rdom;
FindFreeVars(RDom r, const string &n) :
rdom(r), explicit_rdom(r.defined()), name(n) {
}
private:
bool explicit_rdom;
const string &name;
Scope<int> internal;
using IRMutator::visit;
void visit(const Let *op) {
Expr value = mutate(op->value);
internal.push(op->name, 0);
Expr body = mutate(op->body);
internal.pop(op->name);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
expr = op;
} else {
expr = Let::make(op->name, value, body);
}
}
void visit(const Variable *v) {
string var_name = v->name;
expr = v;
if (internal.contains(var_name)) {
return;
}
if (v->reduction_domain.defined()) {
if (explicit_rdom) {
if (v->reduction_domain.same_as(rdom.domain())) {
return;
} else {
var_name = unique_name('v');
expr = Variable::make(v->type, var_name);
}
} else {
if (!rdom.defined()) {
rdom = RDom(v->reduction_domain);
return;
} else if (!rdom.domain().same_as(v->reduction_domain)) {
user_error << "Inline reduction \"" << name
<< "\" refers to reduction variables from multiple reduction domains: "
<< v->name << ", " << rdom.x.name() << "\n";
} else {
return;
}
}
}
if (v->param.defined()) {
return;
}
for (size_t i = 0; i < free_vars.size(); i++) {
if (var_name == free_vars[i].name()) return;
}
free_vars.push_back(Var(var_name));
call_args.push_back(v);
}
};
}
/** Sum of e over the reduction domain implied by the RVars it uses. */
Expr sum(Expr e, const std::string &name) {
    // No explicit domain: it is inferred from the expression itself.
    RDom inferred_domain;
    return sum(inferred_domain, e, name);
}
/** Sum of e over the reduction domain r (or the domain inferred from e if
 * r is undefined). The variables e uses that are free with respect to that
 * domain become the dimensions of the accumulating Func. */
Expr sum(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    Expr body = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to sum must reference a reduction domain";

    Func accumulator(name);
    accumulator(finder.free_vars) += body;
    return accumulator(finder.call_args);
}
/** Product of e over the reduction domain implied by the RVars it uses. */
Expr product(Expr e, const std::string &name) {
    // No explicit domain: it is inferred from the expression itself.
    RDom inferred_domain;
    return product(inferred_domain, e, name);
}
/** Product of e over the reduction domain r (or the domain inferred from e
 * if r is undefined). */
Expr product(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    Expr body = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to product must reference a reduction domain";

    Func accumulator(name);
    accumulator(finder.free_vars) *= body;
    return accumulator(finder.call_args);
}
/** Maximum of e over the reduction domain implied by the RVars it uses. */
Expr maximum(Expr e, const std::string &name) {
    // No explicit domain: it is inferred from the expression itself.
    RDom inferred_domain;
    return maximum(inferred_domain, e, name);
}
/** Maximum of e over the reduction domain r (or the domain inferred from e
 * if r is undefined). The running value starts at the smallest value of
 * e's type. */
Expr maximum(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    Expr body = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to maximum must reference a reduction domain";

    Func running_max(name);
    const vector<Var> &pure_args = finder.free_vars;
    running_max(pure_args) = body.type().min();
    running_max(pure_args) = max(running_max(pure_args), body);
    return running_max(finder.call_args);
}
/** Minimum of e over the reduction domain implied by the RVars it uses. */
Expr minimum(Expr e, const std::string &name) {
    // No explicit domain: it is inferred from the expression itself.
    RDom inferred_domain;
    return minimum(inferred_domain, e, name);
}
/** Minimum of e over the reduction domain r (or the domain inferred from e
 * if r is undefined). The running value starts at the largest value of
 * e's type. */
Expr minimum(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    Expr body = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to minimum must reference a reduction domain";

    Func running_min(name);
    const vector<Var> &pure_args = finder.free_vars;
    running_min(pure_args) = body.type().max();
    running_min(pure_args) = min(running_min(pure_args), body);
    return running_min(finder.call_args);
}
/** Coordinates (and value) of the maximum of e over the reduction domain
 * implied by the RVars it uses. */
Tuple argmax(Expr e, const std::string &name) {
    // No explicit domain: it is inferred from the expression itself.
    RDom inferred_domain;
    return argmax(inferred_domain, e, name);
}
/** Returns a Tuple of the reduction-domain coordinates at which e is
 * largest, followed by that largest value. Ties keep the earliest point
 * visited, because the update only fires on a strictly greater value.
 *
 * The coordinates are seeded with the domain's minimum point (the first
 * point the reduction visits) rather than 0. If no value exceeds the
 * initial e.type().min(), the result is therefore a point inside the
 * domain, even when the domain does not start at 0. */
Tuple argmax(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars v(r, name);
    e = v.mutate(common_subexpression_elimination(e));
    user_assert(v.rdom.defined()) << "Expression passed to argmax must reference a reduction domain";

    Func f(name);
    const int dims = v.rdom.dimensions();
    Tuple initial_tup(vector<Expr>(dims + 1));
    Tuple update_tup(vector<Expr>(dims + 1));
    for (int i = 0; i < dims; i++) {
        Expr coord = v.rdom[i];
        update_tup[i] = coord;
        // Cast so the pure and update definitions agree on element types.
        initial_tup[i] = cast(coord.type(), v.rdom[i].min());
    }
    // The value being maximized is stored after the coordinates.
    const int value_index = dims;
    initial_tup[value_index] = e.type().min();
    update_tup[value_index] = e;

    f(v.free_vars) = initial_tup;
    Expr better = e > f(v.free_vars)[value_index];
    f(v.free_vars) = tuple_select(better, update_tup, f(v.free_vars));
    return f(v.call_args);
}
/** Coordinates (and value) of the minimum of e over the reduction domain
 * implied by the RVars it uses. */
Tuple argmin(Expr e, const std::string &name) {
    // No explicit domain: it is inferred from the expression itself.
    RDom inferred_domain;
    return argmin(inferred_domain, e, name);
}
/** Returns a Tuple of the reduction-domain coordinates at which e is
 * smallest, followed by that smallest value. Ties keep the earliest point
 * visited, because the update only fires on a strictly smaller value.
 *
 * The coordinates are seeded with the domain's minimum point (the first
 * point the reduction visits) rather than 0. If no value is below the
 * initial e.type().max(), the result is therefore a point inside the
 * domain, even when the domain does not start at 0. */
Tuple argmin(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars v(r, name);
    e = v.mutate(common_subexpression_elimination(e));
    user_assert(v.rdom.defined()) << "Expression passed to argmin must reference a reduction domain";

    Func f(name);
    const int dims = v.rdom.dimensions();
    Tuple initial_tup(vector<Expr>(dims + 1));
    Tuple update_tup(vector<Expr>(dims + 1));
    for (int i = 0; i < dims; i++) {
        Expr coord = v.rdom[i];
        update_tup[i] = coord;
        // Cast so the pure and update definitions agree on element types.
        initial_tup[i] = cast(coord.type(), v.rdom[i].min());
    }
    // The value being minimized is stored after the coordinates.
    const int value_index = dims;
    initial_tup[value_index] = e.type().max();
    update_tup[value_index] = e;

    f(v.free_vars) = initial_tup;
    Expr better = e < f(v.free_vars)[value_index];
    f(v.free_vars) = tuple_select(better, update_tup, f(v.free_vars));
    return f(v.call_args);
}
}