root/src/Reduction.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. check
  2. split_predicate_test
  3. frozen
  4. accept
  5. mutate
  6. deep_copy
  7. domain
  8. visit
  9. domain
  10. set_predicate
  11. where
  12. predicate
  13. split_predicate
  14. freeze
  15. frozen
  16. accept
  17. mutate

#include "Var.h"
#include "IR.h"
#include "IREquality.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "IRMutator.h"
#include "Reduction.h"
#include "Simplify.h"

namespace Halide {
namespace Internal {

namespace {

void check(Expr pred, std::vector<Expr> &expected) {
    std::vector<Expr> result;
    split_into_ands(pred, result);
    bool is_equal = true;

    if (result.size() != expected.size()) {
        is_equal = false;
    } else {
        for (size_t i = 0; i < expected.size(); ++i) {
            if (!equal(simplify(result[i]), simplify(expected[i]))) {
                is_equal = false;
                break;
            }
        }
    }

    if (!is_equal) {
        std::cout << "Expect predicate " << pred << " to be split into:\n";
        for (const auto &e : expected) {
            std::cout << "  " << e << "\n";
        }
        std::cout << "Got:\n";
        for (const auto &e : result) {
            std::cout << "  " << e << "\n";
        }
        internal_error << "\n";
    }
}

}

void split_predicate_test() {
    Expr x = Var("x"), y = Var("y"), z = Var("z"), w = Var("w");

    {
        std::vector<Expr> expected;
        expected.push_back(z < 10);
        check(z < 10, expected);
    }

    {
        std::vector<Expr> expected;
        expected.push_back((x < y) || (x == 10));
        check((x < y) || (x == 10), expected);
    }

    {
        std::vector<Expr> expected;
        expected.push_back(x < y);
        expected.push_back(x == 10);
        check((x < y) && (x == 10), expected);
    }

    {
        std::vector<Expr> expected;
        expected.push_back(x < y);
        expected.push_back(x == 10);
        expected.push_back(y == z);
        check((x < y) && (x == 10) && (y == z), expected);
    }

    {
        std::vector<Expr> expected;
        expected.push_back((w == 1) || ((x == 10) && (y == z)));
        check((w == 1) || ((x == 10) && (y == z)), expected);
    }

    {
        std::vector<Expr> expected;
        expected.push_back(x < y);
        expected.push_back((w == 1) || ((x == 10) && (y == z)));
        check((x < y) && ((w == 1) || ((x == 10) && (y == z))), expected);
    }

    std::cout << "Split predicate test passed" << std::endl;
}

struct ReductionDomainContents {
    mutable RefCount ref_count;
    std::vector<ReductionVariable> domain;
    Expr predicate;
    bool frozen;

    ReductionDomainContents() : predicate(const_true()), frozen(false) {
    }

    // Pass an IRVisitor through to all Exprs referenced in the ReductionDomainContents
    void accept(IRVisitor *visitor) {
        for (const ReductionVariable &rvar : domain) {
            if (rvar.min.defined()) {
                rvar.min.accept(visitor);
            }
            if (rvar.extent.defined()) {
                rvar.extent.accept(visitor);
            }
        }
        if (predicate.defined()) {
            predicate.accept(visitor);
        }
    }

    // Pass an IRMutator through to all Exprs referenced in the ReductionDomainContents
    void mutate(IRMutator *mutator) {
        for (ReductionVariable &rvar : domain) {
            if (rvar.min.defined()) {
                rvar.min = mutator->mutate(rvar.min);
            }
            if (rvar.extent.defined()) {
                rvar.extent = mutator->mutate(rvar.extent);
            }
        }
        if (predicate.defined()) {
            predicate = mutator->mutate(predicate);
        }
    }
};

// Reference-count hook for ReductionDomainContents. Presumably consumed by
// the intrusive smart pointer type of ReductionDomain::contents; verify
// against IntrusivePtr.h.
template<>
EXPORT RefCount &ref_count<ReductionDomainContents>(const ReductionDomainContents *p) {
    return p->ref_count;
}

// Destruction hook for ReductionDomainContents: frees the object with delete.
template<>
EXPORT void destroy<ReductionDomainContents>(const ReductionDomainContents *p) {
    delete p;
}

// Creates a domain with freshly allocated contents and copies the given
// reduction variables into it. The predicate and frozen flag keep the
// defaults set by ReductionDomainContents' constructor.
ReductionDomain::ReductionDomain(const std::vector<ReductionVariable> &domain)
    : contents(new ReductionDomainContents) {
    contents->domain.assign(domain.begin(), domain.end());
}

// Returns a domain backed by newly allocated contents that copy this
// domain's variables, predicate, and frozen flag. If this domain has no
// contents, an empty ReductionDomain() is returned instead.
ReductionDomain ReductionDomain::deep_copy() const {
    ReductionDomain result;
    if (contents.defined()) {
        result = ReductionDomain(contents->domain);
        result.contents->predicate = contents->predicate;
        result.contents->frozen = contents->frozen;
    }
    return result;
}

const std::vector<ReductionVariable> &ReductionDomain::domain() const {
    return contents->domain;
}

namespace {
class DropSelfReferences : public IRMutator {
    using IRMutator::visit;

    void visit(const Variable *op) {
        if (op->reduction_domain.defined()) {
            user_assert(op->reduction_domain.same_as(domain))
                << "An RDom's predicate may only refer to its own RVars, "
                << " not the RVars of some other RDom. "
                << "Cannot set the predicate to : " << predicate << "\n";
            expr = Variable::make(op->type, op->name);
        } else {
            expr = op;
        }
    }
public:
    Expr predicate;
    const ReductionDomain &domain;
    DropSelfReferences(Expr p, const ReductionDomain &d) :
        predicate(p), domain(d) {}
};
}

// Replaces this domain's predicate with `p`. References in `p` back to this
// RDom's own RVars are stripped first, so the stored predicate does not keep
// the domain alive through a reference cycle (which would leak it).
void ReductionDomain::set_predicate(Expr p) {
    DropSelfReferences dropper(p, *this);
    Expr detached = dropper.mutate(p);
    contents->predicate = detached;
}

// ANDs `predicate` onto the current predicate, simplifies the result, and
// installs it through set_predicate().
void ReductionDomain::where(Expr predicate) {
    Expr combined = contents->predicate && predicate;
    set_predicate(simplify(combined));
}

// Returns the predicate currently attached to this domain.
Expr ReductionDomain::predicate() const {
    return this->contents->predicate;
}

std::vector<Expr> ReductionDomain::split_predicate() const {
    std::vector<Expr> predicates;
    split_into_ands(contents->predicate, predicates);
    return predicates;
}

void ReductionDomain::freeze() {
    contents->frozen = true;
}

bool ReductionDomain::frozen() const {
    return contents->frozen;
}

// Passes `visitor` over all Exprs held by this domain. Does nothing if the
// domain has no contents.
void ReductionDomain::accept(IRVisitor *visitor) const {
    if (!contents.defined()) {
        return;
    }
    contents->accept(visitor);
}

// Rewrites all Exprs held by this domain in place with `mutator`. Does
// nothing if the domain has no contents.
void ReductionDomain::mutate(IRMutator *mutator) {
    if (!contents.defined()) {
        return;
    }
    contents->mutate(mutator);
}

}
}

/* [<][>][^][v][top][bottom][index][help] */