This source file includes the following definitions.
- visit
- count_interleaves
- check_interleave_count
- define
- define
- element
- main
#include "Halide.h"
#include <stdio.h>
#include <math.h>
#include <memory>
using std::vector;
using namespace Halide;
using namespace Halide::Internal;
class CountInterleaves : public IRVisitor {
public:
int result;
CountInterleaves() : result(0) {}
using IRVisitor::visit;
void visit(const Shuffle *op) {
if (op->is_interleave()) {
result++;
}
IRVisitor::visit(op);
}
};
// Lowers `f` and returns the number of interleaving shuffles found in the
// resulting statement.
//
// The target is taken from the JIT environment with the NoBoundsQuery and
// NoAsserts features set. Note that compute_root() is called on `f` before
// lowering.
int count_interleaves(Func f) {
    Target t = get_jit_target_from_environment();
    t.set_feature(Target::NoBoundsQuery);
    t.set_feature(Target::NoAsserts);

    f.compute_root();
    Stmt s = Internal::lower_main_stmt({f.function()}, f.name(), t);

    CountInterleaves i;
    s.accept(&i);
    return i.result;
}
// Verifies that lowering `f` yields at least `correct` interleaves.
// On a shortfall a diagnostic is printed and the process exits with -1;
// otherwise the function simply returns.
void check_interleave_count(Func f, int correct) {
    const int observed = count_interleaves(f);
    if (observed >= correct) {
        return;
    }

    printf("Func %s should have interleaved >= %d times but interleaved %d times instead.\n",
           f.name().c_str(), correct, observed);
    exit(-1);
}
// Assigns `values` to the function reference `f`: a lone value is assigned
// as a plain Expr, anything else is wrapped in a Tuple.
void define(FuncRef f, std::vector<Expr> values) {
    if (values.size() != 1) {
        f = Tuple(values);
        return;
    }
    f = values[0];
}
// Assigns `count` copies of `value` to `f` by forwarding to the vector
// overload of define(). A non-positive `count` forwards an empty vector.
void define(FuncRef f, Expr value, int count) {
    std::vector<Expr> repeated;
    int remaining = count;
    while (remaining > 0) {
        repeated.push_back(value);
        --remaining;
    }
    define(f, repeated);
}
// Returns component `i` of the function reference `f`. When `f` has a
// single component it is returned directly, and `i` must then be 0.
Expr element(FuncRef f, int i) {
    if (f.size() != 1) {
        return f[i];
    }
    assert(i == 0);
    return f;
}
// Test driver. Builds a series of pipelines, checks through
// check_interleave_count() that lowering each one produces at least the
// expected number of interleaving shuffles, and compares realized output
// against values computed directly in C++.
//
// Returns 0 on success (or when the test is skipped on arm-32) and -1 on the
// first output mismatch. check_interleave_count() exits the process itself
// when an interleave count is too low.
int main(int argc, char **argv) {
    Var x, y, c;

    {
        // NOTE(review): the message below points at an explanation "in the
        // source", but no such explanation is present here - restore it if known.
        Target t = get_host_target();
        if (t.arch == Target::ARM && t.bits == 32) {
            printf("Skipping test on arm-32 (see the source for why)\n");
            return 0;
        }
    }

    // Case 1: h picks between f(x/2) and g(x/2) depending on the parity of x,
    // for tuples of 1 to 5 elements, with h vectorized by 8.
    for (int elements = 1; elements <= 5; elements++) {
        Func f, g, h;

        std::vector<Expr> f_def, g_def;
        for (int i = 0; i < elements; i++) {
            f_def.push_back(sin(x + i));
            g_def.push_back(cos(x + i));
        }
        define(f(x), f_def);
        define(g(x), g_def);

        std::vector<Expr> h_def;
        for (int i = 0; i < elements; i++) {
            h_def.push_back(select(x % 2 == 0, 1.0f/element(f(x/2), i), element(g(x/2), i)*17.0f));
        }
        define(h(x), h_def);

        f.compute_root();
        g.compute_root();
        h.vectorize(x, 8);

        check_interleave_count(h, 1);

        Realization results = h.realize(16);
        for (int i = 0; i < elements; i++) {
            Buffer<float> result = results[i];
            for (int x = 0; x < 16; x++) {
                float correct = ((x % 2) == 0) ? (1.0f/(sinf(x/2 + i))) : (cosf(x/2 + i)*17.0f);
                float delta = result(x) - correct;
                if (delta > 0.01 || delta < -0.01) {
                    printf("result(%d) = %f instead of %f\n", x, result(x), correct);
                    return -1;
                }
            }
        }
    }

    // Case 2: a planar producer written into an output whose dimension 0 has
    // stride 3 and whose dimension 1 has stride 1 and extent 3.
    {
        Func planar, interleaved;
        planar(x, y) = Halide::cast<float>( 3 * x + y );
        interleaved(x, y) = planar(x, y);

        Var xy("xy");
        planar
            .compute_at(interleaved, xy)
            .vectorize(x, 4);

        interleaved
            .reorder(y, x)
            .bound(y, 0, 3)
            .bound(x, 0, 16)
            .fuse(y, x, xy)
            .vectorize(xy, 12);

        interleaved
            .output_buffer()
            .dim(0)
            .set_stride(3)
            .dim(1)
            .set_min(0)
            .set_stride(1)
            .set_extent(3);

        // Allocate as 3x16 and swap the dimensions so the buffer's strides
        // match the constraints set on the output buffer above.
        Buffer<float> buff3(3, 16);
        buff3.transpose(0, 1);

        interleaved.realize(buff3);

        check_interleave_count(interleaved, 1);

        for (int x = 0; x < 16; x++) {
            for (int y = 0; y < 3; y++) {
                float correct = 3*x + y;
                float delta = buff3(x, y) - correct;
                if (delta > 0.01 || delta < -0.01) {
                    // Report both coordinates; the buffer is two-dimensional.
                    printf("result(%d, %d) = %f instead of %f\n", x, y, buff3(x,y), correct);
                    return -1;
                }
            }
        }
    }

    // Case 3: select over y between 4 and then 5 separate Funcs, with y
    // unrolled and x vectorized by 4, written to outputs with stride 4 and 5
    // in dimension 0.
    {
        Func f1, f2, f3, f4, f5;
        f1(x) = sin(x);
        f2(x) = sin(2*x);
        f3(x) = sin(3*x);
        f4(x) = sin(4*x);
        f5(x) = sin(5*x);

        Func output4;
        output4(x, y) = select(y == 0, f1(x),
                               y == 1, f2(x),
                               y == 2, f3(x),
                               f4(x));

        output4
            .reorder(y, x)
            .bound(y, 0, 4)
            .unroll(y)
            .vectorize(x, 4);

        output4.output_buffer()
            .dim(0)
            .set_stride(4)
            .dim(1)
            .set_min(0)
            .set_stride(1)
            .set_extent(4);

        check_interleave_count(output4, 1);

        Buffer<float> buff4(4, 16);
        buff4.transpose(0, 1);

        output4.realize(buff4);

        for (int x = 0; x < 16; x++) {
            for (int y = 0; y < 4; y++) {
                float correct = sin((y+1)*x);
                float delta = buff4(x, y) - correct;
                if (delta > 0.01 || delta < -0.01) {
                    printf("result(%d, %d) = %f instead of %f\n", x, y, buff4(x,y), correct);
                    return -1;
                }
            }
        }

        Func output5;
        output5(x, y) = select(y == 0, f1(x),
                               y == 1, f2(x),
                               y == 2, f3(x),
                               y == 3, f4(x),
                               f5(x));

        output5
            .reorder(y, x)
            .bound(y, 0, 5)
            .unroll(y)
            .vectorize(x, 4);

        output5.output_buffer()
            .dim(0)
            .set_stride(5)
            .dim(1)
            .set_min(0)
            .set_stride(1)
            .set_extent(5);

        check_interleave_count(output5, 1);

        Buffer<float> buff5(5, 16);
        buff5.transpose(0, 1);

        output5.realize(buff5);

        for (int x = 0; x < 16; x++) {
            for (int y = 0; y < 5; y++) {
                float correct = sin((y+1)*x);
                float delta = buff5(x, y) - correct;
                if (delta > 0.01 || delta < -0.01) {
                    printf("result(%d, %d) = %f instead of %f\n", x, y, buff5(x,y), correct);
                    return -1;
                }
            }
        }
    }

    // Case 4: a parity select inside a tiled schedule that mixes unrolling
    // and vectorization. Only the interleave count is checked; nothing is
    // realized.
    {
        Func f1, f2, f3, f4, f5;
        f1(x) = sin(x);
        f1.compute_root();

        f2(x) = sin(2*x);
        f2.compute_root();

        Func unrolled;
        unrolled(x, y) = select(x % 2 == 0, f1(x), f2(x)) + y;

        Var xi, yi;
        unrolled.tile(x, y, xi, yi, 16, 2).unroll(xi, 2).vectorize(xi, 4).unroll(xi).unroll(yi);

        check_interleave_count(unrolled, 4);
    }

    // Case 5: update definitions that store to strided locations (2*r, 2*r+1,
    // 3*r, ...), for tuples of 1 to 5 elements. The pipeline is built twice:
    // pass 0 realizes it with the default schedule as a reference, pass 1
    // vectorizes every update over r and must reproduce the reference.
    for (int elements = 1; elements <= 5; elements++) {
        // Owned by a smart pointer so the reference output is released on the
        // early-return mismatch path as well.
        std::unique_ptr<Realization> refs;

        for (int pass = 0; pass < 2; pass++) {
            Func output6;
            define(output6(x, y), cast<uint8_t>(x), elements);

            RDom r(0, 16);

            define(output6(2*r, 0), cast<uint8_t>(3), elements);
            define(output6(2*r+1, 0), cast<uint8_t>(4), elements);

            define(output6(2*r, 1), cast<uint8_t>(3), elements);
            define(output6(2*r+1, 1), cast<uint8_t>(4), elements);

            define(output6(3*r, 3), cast<uint8_t>(3), elements);
            define(output6(3*r+1, 3), cast<uint8_t>(4), elements);

            // Updates on row 2 read the value they overwrite.
            std::vector<Expr> rdef0, rdef1;
            for (int e = 0; e < elements; e++) {
                rdef0.push_back(element(output6(2*r, 2), e) + 1);
                rdef1.push_back(element(output6(2*r+1, 2), e) + 1);
            }
            define(output6(2*r, 2), rdef0);
            define(output6(2*r+1, 2), rdef1);

            define(output6(3*r, 3), cast<uint8_t>(7), elements);
            define(output6(3*r+2, 3), cast<uint8_t>(9), elements);
            define(output6(3*r+1, 3), cast<uint8_t>(8), elements);

            if (pass == 0) {
                refs.reset(new Realization(output6.realize(50, 4)));
            } else {
                // One entry per update definition above: 6 + 2 + 3.
                const int num_updates = 11;
                for (int j = 0; j < num_updates; j++) {
                    output6.update(j).vectorize(r);
                }

                check_interleave_count(output6, 2*elements);

                Realization outs = output6.realize(50, 4);
                for (int e = 0; e < elements; e++) {
                    Buffer<uint8_t> ref = (*refs)[e];
                    Buffer<uint8_t> out = outs[e];
                    for (int y = 0; y < ref.height(); y++) {
                        for (int x = 0; x < ref.width(); x++) {
                            if (out(x, y) != ref(x, y)) {
                                printf("result(%d, %d) = %d instead of %d\n",
                                       x, y, out(x, y), ref(x, y));
                                return -1;
                            }
                        }
                    }
                }
            }
        }
    }

    // Case 6: two 8x8 transposes of the same input. trans1 vectorizes x and
    // unrolls y; trans2 unrolls x and vectorizes y.
    {
        Func square("square");
        square(x, y) = cast(UInt(16), 5*x + y);

        Func trans1("trans1");
        trans1(x, y) = square(y, x);

        Func trans2("trans2");
        trans2(x, y) = square(y, x);

        square.compute_root()
            .bound(x, 0, 8)
            .bound(y, 0, 8);

        trans1.compute_root()
            .bound(x, 0, 8)
            .bound(y, 0, 8)
            .vectorize(x)
            .unroll(y);

        trans2.compute_root()
            .bound(x, 0, 8)
            .bound(y, 0, 8)
            .unroll(x)
            .vectorize(y);

        trans1.output_buffer()
            .dim(0)
            .set_min(0)
            .set_stride(1)
            .set_extent(8)
            .dim(1)
            .set_min(0)
            .set_stride(8)
            .set_extent(8);

        trans2.output_buffer()
            .dim(0)
            .set_min(0)
            .set_stride(1)
            .set_extent(8)
            .dim(1)
            .set_min(0)
            .set_stride(8)
            .set_extent(8);

        Buffer<uint16_t> result6(8, 8);
        Buffer<uint16_t> result7(8, 8);
        trans1.realize(result6);
        trans2.realize(result7);

        for (int x = 0; x < 8; x++) {
            for (int y = 0; y < 8; y++) {
                // square(y, x) evaluates to 5*y + x.
                int correct = 5*y + x;
                if (result6(x,y) != correct) {
                    printf("result(%d, %d) = %d instead of %d\n", x, y, result6(x,y), correct);
                    return -1;
                }

                if (result7(x,y) != correct) {
                    printf("result(%d, %d) = %d instead of %d\n", x, y, result7(x,y), correct);
                    return -1;
                }
            }
        }

        check_interleave_count(trans1, 1);
        check_interleave_count(trans2, 1);
    }

    printf("Success!\n");
    return 0;
}