This source file includes the following definitions.
- name
- visit
- visit
- sum
- sum
- product
- product
- maximum
- maximum
- minimum
- minimum
- argmax
- argmax
- argmin
- argmin
#include "InlineReductions.h"
#include "Func.h"
#include "Scope.h"
#include "IROperator.h"
#include "IRMutator.h"
#include "Debug.h"
#include "CSE.h"
namespace Halide {
using std::string;
using std::vector;
using std::ostringstream;
namespace Internal {
class FindFreeVars : public IRMutator {
public:
    vector<Var> free_vars;
    vector<Expr> call_args;
    RDom rdom;
    FindFreeVars(RDom r, const string &n) :
        rdom(r), explicit_rdom(r.defined()), name(n) {
    }
private:
    bool explicit_rdom;
    const string &name;
    Scope<int> internal;
    using IRMutator::visit;
    void visit(const Let *op) {
        Expr value = mutate(op->value);
        internal.push(op->name, 0);
        Expr body = mutate(op->body);
        internal.pop(op->name);
        if (value.same_as(op->value) &&
            body.same_as(op->body)) {
            expr = op;
        } else {
            expr = Let::make(op->name, value, body);
        }
    }
    void visit(const Variable *v) {
        string var_name = v->name;
        expr = v;
        if (internal.contains(var_name)) {
            
            return;
        }
        if (v->reduction_domain.defined()) {
            if (explicit_rdom) {
                if (v->reduction_domain.same_as(rdom.domain())) {
                    
                    
                    return;
                } else {
                    
                    
                    var_name = unique_name('v');
                    expr = Variable::make(v->type, var_name);
                }
            } else {
                if (!rdom.defined()) {
                    
                    
                    rdom = RDom(v->reduction_domain);
                    return;
                } else if (!rdom.domain().same_as(v->reduction_domain)) {
                    
                    
                    user_error << "Inline reduction \"" << name
                               << "\" refers to reduction variables from multiple reduction domains: "
                               << v->name << ", " << rdom.x.name() << "\n";
                } else {
                    
                    return;
                }
            }
        }
        if (v->param.defined()) {
            
            return;
        }
        for (size_t i = 0; i < free_vars.size(); i++) {
            if (var_name == free_vars[i].name()) return;
        }
        free_vars.push_back(Var(var_name));
        call_args.push_back(v);
    }
};
}
// Sum of e over the reduction domain it references. No domain is given, so
// an undefined RDom is forwarded to the three-argument overload.
Expr sum(Expr e, const std::string &name) {
    RDom implicit_domain;
    return sum(implicit_domain, e, name);
}
// Sum of e over a reduction domain, expressed as a call to a helper Func
// named `name`. If r is undefined, the domain is taken from the reduction
// variables referenced by e; it is a user error if none is found.
Expr sum(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    e = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to sum must reference a reduction domain";
    const std::vector<Var> &pure_args = finder.free_vars;
    Func accumulator(name);
    // Accumulate e into the helper Func over its captured free variables.
    accumulator(pure_args) += e;
    // Call it with the expressions those free variables stood for.
    return accumulator(finder.call_args);
}
// Product of e over the reduction domain it references. No domain is given,
// so an undefined RDom is forwarded to the three-argument overload.
Expr product(Expr e, const std::string &name) {
    RDom implicit_domain;
    return product(implicit_domain, e, name);
}
// Product of e over a reduction domain, expressed as a call to a helper Func
// named `name`. If r is undefined, the domain is taken from the reduction
// variables referenced by e; it is a user error if none is found.
Expr product(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    e = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to product must reference a reduction domain";
    const std::vector<Var> &pure_args = finder.free_vars;
    Func accumulator(name);
    // Multiply e into the helper Func over its captured free variables.
    accumulator(pure_args) *= e;
    // Call it with the expressions those free variables stood for.
    return accumulator(finder.call_args);
}
// Maximum of e over the reduction domain it references. No domain is given,
// so an undefined RDom is forwarded to the three-argument overload.
Expr maximum(Expr e, const std::string &name) {
    RDom implicit_domain;
    return maximum(implicit_domain, e, name);
}
// Maximum of e over a reduction domain, expressed as a call to a helper Func
// named `name`. If r is undefined, the domain is taken from the reduction
// variables referenced by e; it is a user error if none is found.
Expr maximum(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    e = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to maximum must reference a reduction domain";
    const std::vector<Var> &pure_args = finder.free_vars;
    Func running_max(name);
    // Start from the smallest value of e's type ...
    running_max(pure_args) = e.type().min();
    // ... then keep the larger of the running value and e.
    running_max(pure_args) = max(running_max(pure_args), e);
    return running_max(finder.call_args);
}
// Minimum of e over the reduction domain it references. No domain is given,
// so an undefined RDom is forwarded to the three-argument overload.
Expr minimum(Expr e, const std::string &name) {
    RDom implicit_domain;
    return minimum(implicit_domain, e, name);
}
// Minimum of e over a reduction domain, expressed as a call to a helper Func
// named `name`. If r is undefined, the domain is taken from the reduction
// variables referenced by e; it is a user error if none is found.
Expr minimum(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    e = finder.mutate(common_subexpression_elimination(e));
    user_assert(finder.rdom.defined()) << "Expression passed to minimum must reference a reduction domain";
    const std::vector<Var> &pure_args = finder.free_vars;
    Func running_min(name);
    // Start from the largest value of e's type ...
    running_min(pure_args) = e.type().max();
    // ... then keep the smaller of the running value and e.
    running_min(pure_args) = min(running_min(pure_args), e);
    return running_min(finder.call_args);
}
// Argmax of e over the reduction domain it references. No domain is given,
// so an undefined RDom is forwarded to the three-argument overload.
Tuple argmax(Expr e, const std::string &name) {
    RDom implicit_domain;
    return argmax(implicit_domain, e, name);
}
// Argmax of e over a reduction domain, expressed as a call to a tuple-valued
// helper Func named `name`. The returned Tuple has one element per dimension
// of the reduction domain (the coordinates) followed by one final element
// (the value of e there). If r is undefined, the domain is taken from the
// reduction variables referenced by e; it is a user error if none is found.
Tuple argmax(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    e = finder.mutate(common_subexpression_elimination(e));
    Func best(name);
    user_assert(finder.rdom.defined()) << "Expression passed to argmax must reference a reduction domain";
    const std::vector<Var> &pure_args = finder.free_vars;
    const int dims = finder.rdom.dimensions();
    // One slot per reduction dimension plus a trailing slot for the value.
    Tuple initial(vector<Expr>(dims + 1));
    Tuple candidate(vector<Expr>(dims + 1));
    for (int d = 0; d < dims; d++) {
        initial[d] = 0;
        candidate[d] = finder.rdom[d];
    }
    const int value_index = static_cast<int>(initial.size()) - 1;
    initial[value_index] = e.type().min();
    candidate[value_index] = e;
    best(pure_args) = initial;
    // Take the candidate only when e is strictly greater than the value
    // stored so far; otherwise keep the stored tuple.
    Expr better = e > best(pure_args)[value_index];
    best(pure_args) = tuple_select(better, candidate, best(pure_args));
    return best(finder.call_args);
}
// Argmin of e over the reduction domain it references. No domain is given,
// so an undefined RDom is forwarded to the three-argument overload.
Tuple argmin(Expr e, const std::string &name) {
    RDom implicit_domain;
    return argmin(implicit_domain, e, name);
}
// Argmin of e over a reduction domain, expressed as a call to a tuple-valued
// helper Func named `name`. The returned Tuple has one element per dimension
// of the reduction domain (the coordinates) followed by one final element
// (the value of e there). If r is undefined, the domain is taken from the
// reduction variables referenced by e; it is a user error if none is found.
Tuple argmin(RDom r, Expr e, const std::string &name) {
    Internal::FindFreeVars finder(r, name);
    e = finder.mutate(common_subexpression_elimination(e));
    Func best(name);
    user_assert(finder.rdom.defined()) << "Expression passed to argmin must reference a reduction domain";
    const std::vector<Var> &pure_args = finder.free_vars;
    const int dims = finder.rdom.dimensions();
    // One slot per reduction dimension plus a trailing slot for the value.
    Tuple initial(vector<Expr>(dims + 1));
    Tuple candidate(vector<Expr>(dims + 1));
    for (int d = 0; d < dims; d++) {
        initial[d] = 0;
        candidate[d] = finder.rdom[d];
    }
    const int value_index = static_cast<int>(initial.size()) - 1;
    initial[value_index] = e.type().max();
    candidate[value_index] = e;
    best(pure_args) = initial;
    // Take the candidate only when e is strictly smaller than the value
    // stored so far; otherwise keep the stored tuple.
    Expr better = e < best(pure_args)[value_index];
    Tuple update = tuple_select(better, candidate, best(pure_args));
    best(pure_args) = update;
    return best(finder.call_args);
}
}