#ifndef HALIDE_ASSOCIATIVITY_H #define HALIDE_ASSOCIATIVITY_H /** \file * * Methods for extracting an associative operator from a Func's update definition * if there is any and computing the identity of the associative operator. */ #include "IR.h" #include "IREquality.h" #include "AssociativeOpsTable.h" #include <functional> namespace Halide { namespace Internal { /** * Represent the equivalent associative op of an update definition. * For example, the following associative Expr, min(f(x), g(r.x) + 2), * where f(x) is the self-recurrence term, is represented as: \code AssociativeOp assoc( AssociativePattern(min(x, y), +inf, true), {Replacement("x", f(x))}, {Replacement("y", g(r.x) + 2)}, true ); \endcode * * 'pattern' contains the list of equivalent binary/unary operators (+ identities) * for each Tuple element in the update definition. 'pattern' also contains * a boolean that indicates if the op is also commutative. 'xs' and 'ys' * contain the corresponding definition of each variable in the list of * binary operators. * * For unary operator, 'xs' is not set, i.e. it will be a pair of empty string * and undefined Expr: {"", Expr()}. 'pattern' will only contain the 'y' term in * this case. For example, min(g(r.x), 4), will be represented as: \code AssociativeOp assoc( AssociativePattern(y, 0, false), {Replacement("", Expr())}, {Replacement("y", min(g(r.x), 4))}, true ); \endcode * Since it is a unary operator, the identity does not matter. It can be * anything. */ struct AssociativeOp { struct Replacement { /** Variable name that is used to replace the expr in 'op'. */ std::string var; Expr expr; Replacement() {} Replacement(const std::string &var, Expr expr) : var(var), expr(expr) {} bool operator==(const Replacement &other) const { return (var == other.var) && equal(expr, other.expr); } bool operator!=(const Replacement &other) const { return !(*this == other); } }; /** List of pairs of binary associative op and its identity. */ AssociativePattern pattern; std::vector<Replacement> xs; std::vector<Replacement> ys; bool is_associative; AssociativeOp() : is_associative(false) {} AssociativeOp(size_t size) : pattern(size), xs(size), ys(size), is_associative(false) {} AssociativeOp(const AssociativePattern &p, const std::vector<Replacement> &xs, const std::vector<Replacement> &ys, bool is_associative) : pattern(p), xs(xs), ys(ys), is_associative(is_associative) {} bool associative() const { return is_associative; } bool commutative() const { return pattern.is_commutative; } size_t size() const { return pattern.size(); } }; /** * Given an update definition of a Func 'f', determine its equivalent * associative binary/unary operator if there is any. 'is_associative' * indicates if the operation was successfuly proven as associative. * * Note that even though f(x) = f(x) is associative, we'll treat it as * non-associative since it doesn't really make any sense to do any associative * reduction on that particular update definition. */ AssociativeOp prove_associativity( const std::string &f, std::vector<Expr> args, std::vector<Expr> exprs); EXPORT void associativity_test(); } } #endif