This source file includes the following definitions.
- check
- split_predicate_test
- frozen
- accept
- mutate
- deep_copy
- domain
- visit
- domain
- set_predicate
- where
- predicate
- split_predicate
- freeze
- frozen
- accept
- mutate
#include "Var.h"
#include "IR.h"
#include "IREquality.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "IRMutator.h"
#include "Reduction.h"
#include "Simplify.h"
namespace Halide {
namespace Internal {
namespace {
void check(Expr pred, std::vector<Expr> &expected) {
std::vector<Expr> result;
split_into_ands(pred, result);
bool is_equal = true;
if (result.size() != expected.size()) {
is_equal = false;
} else {
for (size_t i = 0; i < expected.size(); ++i) {
if (!equal(simplify(result[i]), simplify(expected[i]))) {
is_equal = false;
break;
}
}
}
if (!is_equal) {
std::cout << "Expect predicate " << pred << " to be split into:\n";
for (const auto &e : expected) {
std::cout << " " << e << "\n";
}
std::cout << "Got:\n";
for (const auto &e : result) {
std::cout << " " << e << "\n";
}
internal_error << "\n";
}
}
}
void split_predicate_test() {
Expr x = Var("x"), y = Var("y"), z = Var("z"), w = Var("w");
{
std::vector<Expr> expected;
expected.push_back(z < 10);
check(z < 10, expected);
}
{
std::vector<Expr> expected;
expected.push_back((x < y) || (x == 10));
check((x < y) || (x == 10), expected);
}
{
std::vector<Expr> expected;
expected.push_back(x < y);
expected.push_back(x == 10);
check((x < y) && (x == 10), expected);
}
{
std::vector<Expr> expected;
expected.push_back(x < y);
expected.push_back(x == 10);
expected.push_back(y == z);
check((x < y) && (x == 10) && (y == z), expected);
}
{
std::vector<Expr> expected;
expected.push_back((w == 1) || ((x == 10) && (y == z)));
check((w == 1) || ((x == 10) && (y == z)), expected);
}
{
std::vector<Expr> expected;
expected.push_back(x < y);
expected.push_back((w == 1) || ((x == 10) && (y == z)));
check((x < y) && ((w == 1) || ((x == 10) && (y == z))), expected);
}
std::cout << "Split predicate test passed" << std::endl;
}
struct ReductionDomainContents {
mutable RefCount ref_count;
std::vector<ReductionVariable> domain;
Expr predicate;
bool frozen;
ReductionDomainContents() : predicate(const_true()), frozen(false) {
}
void accept(IRVisitor *visitor) {
for (const ReductionVariable &rvar : domain) {
if (rvar.min.defined()) {
rvar.min.accept(visitor);
}
if (rvar.extent.defined()) {
rvar.extent.accept(visitor);
}
}
if (predicate.defined()) {
predicate.accept(visitor);
}
}
void mutate(IRMutator *mutator) {
for (ReductionVariable &rvar : domain) {
if (rvar.min.defined()) {
rvar.min = mutator->mutate(rvar.min);
}
if (rvar.extent.defined()) {
rvar.extent = mutator->mutate(rvar.extent);
}
}
if (predicate.defined()) {
predicate = mutator->mutate(predicate);
}
}
};
template<>
EXPORT RefCount &ref_count<Halide::Internal::ReductionDomainContents>(const ReductionDomainContents *p) {
    // The counter is stored inside the contents object itself (it is declared
    // `mutable`, so it can be returned through a const pointer). Presumably
    // this is the hook used by the intrusive pointer in ReductionDomain::contents.
    return p->ref_count;
}
template<>
EXPORT void destroy<Halide::Internal::ReductionDomainContents>(const ReductionDomainContents *p) {
    // Contents are allocated with plain `new` in the ReductionDomain
    // constructor, so they are released with plain `delete`.
    delete p;
}
// Builds a domain over the given reduction variables. The predicate and frozen
// flag keep their ReductionDomainContents defaults: const_true() and false.
ReductionDomain::ReductionDomain(const std::vector<ReductionVariable> &domain)
    : contents(new ReductionDomainContents) {
    std::vector<ReductionVariable> &vars = contents->domain;
    vars = domain;
}
// Returns a ReductionDomain backed by freshly allocated contents. The variable
// list, predicate and frozen flag are copied from this domain.
// If this domain has no contents, returns a default-constructed ReductionDomain.
ReductionDomain ReductionDomain::deep_copy() const {
    if (contents.defined()) {
        ReductionDomain result(contents->domain);
        result.contents->predicate = contents->predicate;
        result.contents->frozen = contents->frozen;
        return result;
    }
    return ReductionDomain();
}
// Read-only view of the reduction variables stored in this domain's contents.
// Dereferences `contents` without checking that it is defined.
const std::vector<ReductionVariable> &ReductionDomain::domain() const {
    const std::vector<ReductionVariable> &vars = contents->domain;
    return vars;
}
namespace {
// IRMutator used when storing a predicate on a ReductionDomain.
//
// For every Variable that carries a reduction_domain:
//   - if that domain is `domain`, the Variable is rebuilt from its type and
//     name only, so the result no longer carries the reduction_domain link
//     (presumably so a domain's stored predicate does not point back at the
//     domain itself — confirm);
//   - if it belongs to some other domain, a user error is raised.
// Variables without a reduction_domain are left unchanged.
class DropSelfReferences : public IRMutator {
    using IRMutator::visit;

    void visit(const Variable *op) {
        if (op->reduction_domain.defined()) {
            user_assert(op->reduction_domain.same_as(domain))
                << "An RDom's predicate may only refer to its own RVars, "
                << "not the RVars of some other RDom. "
                << "Cannot set the predicate to : " << predicate << "\n";
            expr = Variable::make(op->type, op->name);
        } else {
            expr = op;
        }
    }

public:
    // The predicate being processed; used only in the error message.
    Expr predicate;
    // The domain whose own RVars are allowed. This is a reference, so the
    // mutator must not outlive the domain passed to the constructor.
    const ReductionDomain &domain;

    DropSelfReferences(Expr p, const ReductionDomain &d) :
        predicate(p), domain(d) {}
};
}
// Replaces this domain's predicate with `p`, after running DropSelfReferences
// over it. That pass removes the reduction_domain link from Variables bound to
// this domain and raises a user error on RVars from any other domain.
void ReductionDomain::set_predicate(Expr p) {
    DropSelfReferences stripper(p, *this);
    Expr cleaned = stripper.mutate(p);
    contents->predicate = cleaned;
}
// Adds `predicate` as an extra condition on the domain. The new term is
// and-ed onto the current predicate, the conjunction is simplified, and the
// result is stored via set_predicate().
void ReductionDomain::where(Expr predicate) {
    Expr combined = contents->predicate && predicate;
    set_predicate(simplify(combined));
}
// Returns (a handle to) the predicate currently stored in this domain's contents.
Expr ReductionDomain::predicate() const {
    const Expr &current = contents->predicate;
    return current;
}
std::vector<Expr> ReductionDomain::split_predicate() const {
std::vector<Expr> predicates;
split_into_ands(contents->predicate, predicates);
return predicates;
}
// Sets the frozen flag in this domain's contents. Nothing in this function
// prevents later changes; callers are presumably expected to check frozen().
void ReductionDomain::freeze() {
    bool &flag = contents->frozen;
    flag = true;
}
// Reports whether freeze() has been called on these contents. The flag is also
// carried over by deep_copy().
bool ReductionDomain::frozen() const {
    const bool is_frozen = contents->frozen;
    return is_frozen;
}
// Passes the domain's Exprs (mins, extents, predicate) to `visitor`.
// Does nothing if this domain has no contents.
void ReductionDomain::accept(IRVisitor *visitor) const {
    if (!contents.defined()) {
        return;
    }
    contents->accept(visitor);
}
// Rewrites the domain's Exprs (mins, extents, predicate) in place with
// `mutator`. Does nothing if this domain has no contents.
void ReductionDomain::mutate(IRMutator *mutator) {
    if (!contents.defined()) {
        return;
    }
    contents->mutate(mutator);
}
}
}