#ifndef HALIDE_ASSOCIATIVE_OPS_TABLE_H
#define HALIDE_ASSOCIATIVE_OPS_TABLE_H
#include "IROperator.h"
#include "IREquality.h"
#include <iostream>
#include <utility>
#include <vector>
namespace Halide {
namespace Internal {
struct AssociativePattern {
std::vector<Expr> ops;
std::vector<Expr> identities;
bool is_commutative;
AssociativePattern() : is_commutative(false) {}
AssociativePattern(size_t size) : ops(size), identities(size), is_commutative(false) {}
AssociativePattern(const std::vector<Expr> &ops, const std::vector<Expr> &ids, bool is_commutative)
: ops(ops), identities(ids), is_commutative(is_commutative) {}
AssociativePattern(Expr op, Expr id, bool is_commutative)
: ops({op}), identities({id}), is_commutative(is_commutative) {}
bool operator==(const AssociativePattern &other) const {
if ((is_commutative != other.is_commutative) || (ops.size() != other.ops.size())) {
return false;
}
for (size_t i = 0; i < size(); ++i) {
if (!equal(ops[i], other.ops[i]) || !equal(identities[i], other.identities[i])) {
return false;
}
}
return true;
}
bool operator!=(const AssociativePattern &other) const { return !(*this == other); }
size_t size() const { return ops.size(); }
bool commutative() const { return is_commutative; }
};
/** Look up the table of associative patterns that applies to \p exprs.
 * The implementation is defined outside this header. The result is returned
 * by const reference, so the caller does not own the table.
 * NOTE(review): presumably the table is selected by the number/types of
 * \p exprs; confirm against the implementation. */
const std::vector<AssociativePattern> &get_ops_table(const std::vector<Expr> &exprs);
}
}
#endif