This source file includes the following definitions.
- modulus_remainder
- modulus_remainder
- reduce_expr_modulo
- reduce_expr_modulo
- analyze
- check
- modulus_remainder_test
- visit
- visit
- visit
- visit
- visit
- visit
- gcd
- lcm
- mod
- visit
- visit
- visit
- visit
- unify_alternatives
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
#include "ModulusRemainder.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "IR.h"
#include "Simplify.h"
namespace Halide {
namespace Internal {
/** IR visitor that derives a (modulus, remainder) pair for an
 * expression. Each visit() overload writes its answer into the public
 * `modulus` and `remainder` fields; analyze() packages those fields
 * into a ModulusRemainder. */
class ComputeModulusRemainder : public IRVisitor {
public:
    /** Visit e and return the modulus/remainder that visit recorded. */
    ModulusRemainder analyze(Expr e);

    // Result slots written by the most recent visit().
    int modulus;
    int remainder;

    // Modulus/remainder facts for variables. Let nodes push and pop
    // entries here; lookups fall through to the containing scope.
    Scope<ModulusRemainder> scope;

    /** s is installed as the containing scope of `scope`; the
     * single-argument modulus_remainder() passes nullptr here. */
    ComputeModulusRemainder(const Scope<ModulusRemainder> *s) {
        scope.set_containing_scope(s);
    }

    // Immediates
    void visit(const IntImm *);
    void visit(const UIntImm *);
    void visit(const FloatImm *);
    void visit(const StringImm *);

    // Casts and variables
    void visit(const Cast *);
    void visit(const Variable *);

    // Arithmetic
    void visit(const Add *);
    void visit(const Sub *);
    void visit(const Mul *);
    void visit(const Div *);
    void visit(const Mod *);
    void visit(const Min *);
    void visit(const Max *);

    // Comparisons and boolean logic
    void visit(const EQ *);
    void visit(const NE *);
    void visit(const LT *);
    void visit(const LE *);
    void visit(const GT *);
    void visit(const GE *);
    void visit(const And *);
    void visit(const Or *);
    void visit(const Not *);

    // Remaining expression nodes
    void visit(const Select *);
    void visit(const Load *);
    void visit(const Ramp *);
    void visit(const Broadcast *);
    void visit(const Call *);
    void visit(const Let *);

    // Statement nodes
    void visit(const LetStmt *);
    void visit(const AssertStmt *);
    void visit(const ProducerConsumer *);
    void visit(const For *);
    void visit(const Store *);
    void visit(const Provide *);
    void visit(const Allocate *);
    void visit(const Realize *);
    void visit(const Block *);
    void visit(const IfThenElse *);
    void visit(const Free *);
    void visit(const Evaluate *);

    // Declared last in the original as well
    void visit(const Shuffle *);
    void visit(const Prefetch *);
};
// Analyze e with no outer scope: the visitor's scope gets a null
// containing scope.
ModulusRemainder modulus_remainder(Expr e) {
    return ComputeModulusRemainder(nullptr).analyze(e);
}
// Analyze e, with `scope` installed as the containing scope of the
// visitor's own scope so variable lookups can fall through to it.
ModulusRemainder modulus_remainder(Expr e, const Scope<ModulusRemainder> &scope) {
    return ComputeModulusRemainder(&scope).analyze(e);
}
// Try to compute expr modulo `modulus`. Succeeds only when the modulus
// found by the analysis is a multiple of the requested one; on success
// the reduced remainder is written through `remainder` and true is
// returned. On failure *remainder is left untouched.
// `modulus` is used as a divisor, so it must be non-zero.
bool reduce_expr_modulo(Expr expr, int modulus, int *remainder) {
    const ModulusRemainder known = modulus_remainder(expr);
    if (known.modulus % modulus != 0) {
        return false;
    }
    *remainder = known.remainder % modulus;
    return true;
}
// Same as the three-argument overload, but the analysis runs with
// `scope` supplying facts about variables.
// `modulus` is used as a divisor, so it must be non-zero.
bool reduce_expr_modulo(Expr expr, int modulus, int *remainder, const Scope<ModulusRemainder> &scope) {
    const ModulusRemainder known = modulus_remainder(expr, scope);
    if (known.modulus % modulus != 0) {
        return false;
    }
    *remainder = known.remainder % modulus;
    return true;
}
// Run this visitor over e, then snapshot the fields the visit wrote.
ModulusRemainder ComputeModulusRemainder::analyze(Expr e) {
    e.accept(this);
    ModulusRemainder snapshot(modulus, remainder);
    return snapshot;
}
namespace {
// Test helper: analyze e and terminate the process with exit code -1,
// after printing a report to stderr, unless the result is exactly
// (modulus m, remainder r).
void check(Expr e, int m, int r) {
    const ModulusRemainder result = modulus_remainder(e);
    const bool as_expected = (result.modulus == m) && (result.remainder == r);
    if (as_expected) {
        return;
    }
    std::cerr << "Test failed for modulus_remainder:\n"
              << "Expression: " << e << "\n"
              << "Correct modulus, remainder = " << m << ", " << r << "\n"
              << "Computed modulus, remainder = "
              << result.modulus << ", "
              << result.remainder << "\n";
    exit(-1);
}
}
void modulus_remainder_test() {
Expr x = Variable::make(Int(32), "x");
Expr y = Variable::make(Int(32), "y");
check((30*x + 3) + (40*y + 2), 10, 5);
check((6*x + 3) * (4*y + 1), 2, 1);
check(max(30*x - 24, 40*y + 31), 5, 1);
check(10*x - 33*y, 1, 0);
check(10*x - 35*y, 5, 0);
check(123, 0, 123);
check(Let::make("y", x*3 + 4, y*3 + 4), 9, 7);
std::cout << "modulus_remainder test passed\n";
}
// An integer constant is recorded with modulus 0, with the constant
// itself as the remainder.
void ComputeModulusRemainder::visit(const IntImm *op) {
    modulus = 0;
    remainder = op->value;
}
// Unsigned immediates are not supported by this analysis; reaching
// one is reported as an internal error.
void ComputeModulusRemainder::visit(const UIntImm *) {
    internal_error
        << "modulus_remainder of uint\n";
}
// Floating-point immediates are not supported by this analysis;
// reaching one is reported as an internal error.
void ComputeModulusRemainder::visit(const FloatImm *) {
    internal_error
        << "modulus_remainder of float\n";
}
// String immediates are not supported by this analysis; reaching one
// is reported as an internal error.
void ComputeModulusRemainder::visit(const StringImm *) {
    internal_error
        << "modulus_remainder of string\n";
}
// The operand of a cast is not inspected. Record the trivial fact that
// always holds: the value is congruent to 0 modulo 1.
void ComputeModulusRemainder::visit(const Cast *) {
    remainder = 0;
    modulus = 1;
}
void ComputeModulusRemainder::visit(const Variable *op) {
if (scope.contains(op->name)) {
ModulusRemainder mod_rem = scope.get(op->name);
modulus = mod_rem.modulus;
remainder = mod_rem.remainder;
} else {
modulus = 1;
remainder = 0;
}
}
// Greatest common divisor by the Euclidean algorithm. gcd(x, 0) and
// gcd(0, x) both yield x, so a modulus of 0 (a constant) acts as the
// identity. No sign normalisation is performed on negative inputs.
int gcd(int a, int b) {
    // Start with the larger value as the dividend.
    const bool b_is_larger = a < b;
    int dividend = b_is_larger ? b : a;
    int divisor = b_is_larger ? a : b;
    while (divisor != 0) {
        const int rest = dividend % divisor;
        dividend = divisor;
        divisor = rest;
    }
    return dividend;
}
// Least common multiple of a and b.
//
// The division happens before the multiplication: gcd(a, b) divides a
// exactly, so the value is the same as (a*b)/gcd(a, b), but the
// intermediate a*b can no longer overflow when the final result fits
// in an int. lcm(0, 0) returns 0 instead of dividing by zero.
int lcm(int a, int b) {
    const int divisor = gcd(a, b);
    if (divisor == 0) {
        // gcd is zero only when both inputs are zero.
        return 0;
    }
    return (a / divisor) * b;
}
// Reduce a modulo m via mod_imp. A modulus of 0 marks a constant, in
// which case a is returned unchanged rather than dividing by zero.
int mod(int a, int m) {
    return (m == 0) ? a : mod_imp(a, m);
}
void ComputeModulusRemainder::visit(const Add *op) {
ModulusRemainder a = analyze(op->a);
ModulusRemainder b = analyze(op->b);
modulus = gcd(a.modulus, b.modulus);
remainder = mod(a.remainder + b.remainder, modulus);
}
void ComputeModulusRemainder::visit(const Sub *op) {
ModulusRemainder a = analyze(op->a);
ModulusRemainder b = analyze(op->b);
modulus = gcd(a.modulus, b.modulus);
remainder = mod(a.remainder - b.remainder, modulus);
}
// Product of two analyzed operands. Cases are tried from most to least
// precise; a modulus of 0 marks an operand that is a known constant.
void ComputeModulusRemainder::visit(const Mul *op) {
    ModulusRemainder a = analyze(op->a);
    ModulusRemainder b = analyze(op->b);
    if (a.modulus == 0) {
        // a is the constant c: c * (m*k + r) = (c*m)*k + c*r
        modulus = a.remainder * b.modulus;
        remainder = a.remainder * b.remainder;
    } else if (b.modulus == 0) {
        // b is the constant c: symmetric to the case above
        modulus = b.remainder * a.modulus;
        remainder = a.remainder * b.remainder;
    } else if (a.remainder == 0 && b.remainder == 0) {
        // multiple of a.modulus times multiple of b.modulus
        modulus = a.modulus * b.modulus;
        remainder = 0;
    } else if (a.remainder == 0) {
        // a is a multiple of a.modulus, and b is at least a multiple
        // of gcd(b.modulus, b.remainder)
        modulus = a.modulus * gcd(b.modulus, b.remainder);
        remainder = 0;
    } else if (b.remainder == 0) {
        // symmetric to the case above
        modulus = b.modulus * gcd(a.modulus, a.remainder);
        remainder = 0;
    } else {
        // Fallback: reduce both operands to a common modulus, where
        // the remainders multiply. The result must be written to the
        // member `remainder`; storing it in the local copy `a` would
        // leave the member holding whatever the last sub-analysis
        // (op->b) wrote.
        modulus = gcd(a.modulus, b.modulus);
        remainder = mod(a.remainder * b.remainder, modulus);
    }
}
// Division is not analyzed. Record the trivial fact that always holds:
// the value is congruent to 0 modulo 1.
void ComputeModulusRemainder::visit(const Div *) {
    modulus = 1;
    remainder = 0;
}
namespace {
// Find one (modulus, remainder) pair that holds whichever of the two
// alternatives a or b is selected at runtime (used for min, max and
// select). The unified modulus must divide both input moduli and the
// difference of the two remainders, so that both remainders reduce to
// the same value.
ModulusRemainder unify_alternatives(ModulusRemainder a, ModulusRemainder b) {
    int modulus = gcd(a.modulus, b.modulus);

    // Take the difference in 64 bits so the subtraction cannot overflow.
    int64_t diff = (int64_t)a.remainder - (int64_t)b.remainder;
    if (diff < 0) diff = -diff;

    // Range-check the magnitude, i.e. after taking the absolute value:
    // a raw difference of -2^31 fits in 32 bits but its magnitude does
    // not.
    if (!Int(32).can_represent(diff)) {
        // Nothing useful can be said. The constructor takes (modulus,
        // remainder), so "unknown" is modulus 1, remainder 0;
        // (0, 1) would wrongly claim the value is the constant 1.
        return ModulusRemainder(1, 0);
    }
    modulus = gcd(static_cast<int>(diff), modulus);

    int ra = mod(a.remainder, modulus);
    internal_assert(ra == mod(b.remainder, modulus))
        << "There's a bug inside ModulusRemainder in unify_alternatives:\n"
        << "a.modulus = " << a.modulus << "\n"
        << "a.remainder = " << a.remainder << "\n"
        << "b.modulus = " << b.modulus << "\n"
        << "b.remainder = " << b.remainder << "\n"
        << "diff = " << diff << "\n"
        << "unified modulus = " << modulus << "\n"
        << "unified remainder = " << ra << "\n";
    return ModulusRemainder(modulus, ra);
}
}
void ComputeModulusRemainder::visit(const Mod *op) {
ModulusRemainder a = analyze(op->a);
ModulusRemainder b = analyze(op->b);
modulus = gcd(a.modulus, b.modulus);
modulus = gcd(modulus, b.remainder);
remainder = mod(a.remainder, modulus);
}
void ComputeModulusRemainder::visit(const Min *op) {
ModulusRemainder r = unify_alternatives(analyze(op->a), analyze(op->b));
modulus = r.modulus;
remainder = r.remainder;
}
void ComputeModulusRemainder::visit(const Max *op) {
ModulusRemainder r = unify_alternatives(analyze(op->a), analyze(op->b));
modulus = r.modulus;
remainder = r.remainder;
}
// Equality yields a bool, which this analysis rejects: reaching this
// node always trips the internal assertion.
void ComputeModulusRemainder::visit(const EQ *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Inequality yields a bool, which this analysis rejects: reaching
// this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const NE *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Less-than yields a bool, which this analysis rejects: reaching this
// node always trips the internal assertion.
void ComputeModulusRemainder::visit(const LT *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Less-or-equal yields a bool, which this analysis rejects: reaching
// this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const LE *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Greater-than yields a bool, which this analysis rejects: reaching
// this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const GT *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Greater-or-equal yields a bool, which this analysis rejects:
// reaching this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const GE *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Logical and yields a bool, which this analysis rejects: reaching
// this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const And *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Logical or yields a bool, which this analysis rejects: reaching
// this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const Or *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
// Logical not yields a bool, which this analysis rejects: reaching
// this node always trips the internal assertion.
void ComputeModulusRemainder::visit(const Not *) {
    internal_assert(false)
        << "modulus_remainder of bool\n";
}
void ComputeModulusRemainder::visit(const Select *op) {
ModulusRemainder r = unify_alternatives(analyze(op->true_value),
analyze(op->false_value));
modulus = r.modulus;
remainder = r.remainder;
}
// A load is not inspected. Record the trivial fact that always holds:
// the value is congruent to 0 modulo 1.
void ComputeModulusRemainder::visit(const Load *) {
    remainder = 0;
    modulus = 1;
}
// Ramp is a vector node, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Ramp *) {
    internal_assert(false)
        << "modulus_remainder of vector\n";
}
// Broadcast is a vector node, which this analysis rejects: reaching
// it always trips the internal assertion.
void ComputeModulusRemainder::visit(const Broadcast *) {
    internal_assert(false)
        << "modulus_remainder of vector\n";
}
// A call is not inspected. Record the trivial fact that always holds:
// the value is congruent to 0 modulo 1.
void ComputeModulusRemainder::visit(const Call *) {
    remainder = 0;
    modulus = 1;
}
void ComputeModulusRemainder::visit(const Let *op) {
bool value_interesting = op->value.type().is_int();
if (value_interesting) {
ModulusRemainder val = analyze(op->value);
scope.push(op->name, val);
}
ModulusRemainder val = analyze(op->body);
if (value_interesting) {
scope.pop(op->name);
}
modulus = val.modulus;
remainder = val.remainder;
}
// Only a shuffle with exactly one index is accepted; anything else
// trips the internal assertion. The accepted case is not inspected, so
// only the trivial fact (congruent to 0 modulo 1) is recorded.
void ComputeModulusRemainder::visit(const Shuffle *op) {
    internal_assert(op->indices.size() == 1) << "modulus_remainder of vector\n";
    remainder = 0;
    modulus = 1;
}
// LetStmt is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const LetStmt *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// AssertStmt is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const AssertStmt *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// ProducerConsumer is a statement, which this analysis rejects:
// reaching it always trips the internal assertion.
void ComputeModulusRemainder::visit(const ProducerConsumer *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// For is a statement, which this analysis rejects: reaching it always
// trips the internal assertion.
void ComputeModulusRemainder::visit(const For *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Store is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Store *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Provide is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Provide *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Allocate is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Allocate *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Realize is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Realize *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Block is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Block *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Free is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Free *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// IfThenElse is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const IfThenElse *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Evaluate is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Evaluate *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
// Prefetch is a statement, which this analysis rejects: reaching it
// always trips the internal assertion.
void ComputeModulusRemainder::visit(const Prefetch *) {
    internal_assert(false)
        << "modulus_remainder of statement\n";
}
}
}