// root/src/Associativity.h

/* [<][>][^][v][top][bottom][index][help] */

// INCLUDED FROM


#ifndef HALIDE_ASSOCIATIVITY_H
#define HALIDE_ASSOCIATIVITY_H

/** \file
 *
 * Methods for extracting an associative operator from a Func's update definition
 * if there is any and computing the identity of the associative operator.
 */

#include "IR.h"
#include "IREquality.h"
#include "AssociativeOpsTable.h"

#include <functional>
#include <string>
#include <utility>
#include <vector>

namespace Halide {
namespace Internal {

/**
 * Represent the equivalent associative op of an update definition.
 * For example, the following associative Expr, min(f(x), g(r.x) + 2),
 * where f(x) is the self-recurrence term, is represented as:
 \code
 AssociativeOp assoc(
    AssociativePattern(min(x, y), +inf, true),
    {Replacement("x", f(x))},
    {Replacement("y", g(r.x) + 2)},
    true
 );
 \endcode
 *
 * 'pattern' contains the list of equivalent binary/unary operators (+ identities)
 * for each Tuple element in the update definition. 'pattern' also contains
 * a boolean that indicates if the op is also commutative. 'xs' and 'ys'
 * contain the corresponding definition of each variable in the list of
 * binary operators. AssociativeOp(size_t) sizes 'pattern', 'xs' and 'ys'
 * together, one entry per Tuple element.
 *
 * For a unary operator, 'xs' is not set, i.e. it will be a pair of an empty
 * string and an undefined Expr: {"", Expr()}. 'pattern' will only contain the
 * 'y' term in this case. For example, min(g(r.x), 4) will be represented as:
 \code
 AssociativeOp assoc(
    AssociativePattern(y, 0, false),
    {Replacement("", Expr())},
    {Replacement("y", min(g(r.x), 4))},
    true
 );
 \endcode
 * Since it is a unary operator, the identity does not matter. It can be
 * anything.
 */
struct AssociativeOp {
    /** Binds the pattern variable named 'var' to the concrete
     * sub-expression 'expr' that it stands for in the update definition. */
    struct Replacement {
        /** Variable name that is used to replace the expr in 'op'. */
        std::string var;
        /** Expression that 'var' stands for; undefined (Expr()) for the
         * unused 'x' slot of a unary operator. */
        Expr expr;

        Replacement() = default;
        // 'expr' is already a by-value copy owned by this call, so move it
        // into the member instead of copying the handle a second time.
        Replacement(const std::string &var, Expr expr) : var(var), expr(std::move(expr)) {}

        /** Equal when the variable names match and the two expressions
         * compare equal via equal(). */
        bool operator==(const Replacement &other) const {
            return (var == other.var) && equal(expr, other.expr);
        }
        bool operator!=(const Replacement &other) const {
            return !(*this == other);
        }
    };

    /** List of pairs of binary associative op and its identity. */
    AssociativePattern pattern;
    /** Definitions of the 'x' operands, one per Tuple element. */
    std::vector<Replacement> xs;
    /** Definitions of the 'y' operands, one per Tuple element. */
    std::vector<Replacement> ys;
    /** True only if the operation was proven to be associative. */
    bool is_associative = false;

    AssociativeOp() = default;
    /** Sizes 'pattern' (via AssociativePattern(size)), 'xs' and 'ys' for
     * 'size' Tuple elements; the xs/ys entries are default-constructed
     * Replacements and 'is_associative' starts out false. */
    AssociativeOp(size_t size) : pattern(size), xs(size), ys(size) {}
    /** The vectors are taken by value and moved into place, so callers that
     * pass temporaries (e.g. braced lists) do not pay for an extra copy. */
    AssociativeOp(const AssociativePattern &p, std::vector<Replacement> xs,
                  std::vector<Replacement> ys, bool is_associative)
        : pattern(p), xs(std::move(xs)), ys(std::move(ys)), is_associative(is_associative) {}

    /** Whether the operation was proven associative (same as 'is_associative'). */
    bool associative() const { return is_associative; }
    /** Whether the operator is also commutative, as recorded in 'pattern'. */
    bool commutative() const { return pattern.is_commutative; }
    /** Number of entries in 'pattern'. */
    size_t size() const { return pattern.size(); }
};

/**
 * Given an update definition of a Func 'f', determine its equivalent
 * associative binary/unary operator if there is any. 'is_associative'
 * indicates if the operation was successfully proven as associative.
 *
 * Note that even though f(x) = f(x) is associative, we'll treat it as
 * non-associative since it doesn't really make any sense to do any associative
 * reduction on that particular update definition.
 *
 * \param f      Name of the Func whose update definition is analyzed.
 * \param args   Presumably the left-hand-side arguments of the update
 *               definition (e.g. {x} for f(x) = ...); TODO confirm against
 *               the implementation.
 * \param exprs  Presumably the right-hand-side values of the update
 *               definition, one per Tuple element; TODO confirm.
 * \return An AssociativeOp whose 'is_associative' (also exposed through
 *         associative()) is false when no associative operator was proven.
 */
AssociativeOp prove_associativity(
    const std::string &f, std::vector<Expr> args, std::vector<Expr> exprs);

/** Test entry point for the associativity analysis. Its definition is not
 * in this header; presumably it runs internal self-tests — TODO confirm. */
EXPORT void associativity_test();

}
}

#endif

/* [<][>][^][v][top][bottom][index][help] */