This source file includes the following definitions:
- reset_trace
- my_trace
- reset_alloc_counts
- my_malloc
- my_free
- visit
- visit
- main
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// Observations recorded by the custom trace hook (my_trace):
// whether any vector / scalar store was seen, and the lane count of the
// most recent vector store.
bool vector_store;
bool scalar_store;
uint16_t vector_store_lanes;

// Forget everything the trace hook has recorded so far.
void reset_trace() {
    scalar_store = false;
    vector_store = false;
    vector_store_lanes = 0;
}
// Custom trace callback. For store events it records whether the store was a
// vector store (remembering its lane count) or a scalar store; every other
// event is ignored. Always returns 0.
int my_trace(void *user_context, const halide_trace_event_t *ev) {
    if (ev->event != halide_trace_store) {
        return 0;
    }
    const bool is_vector = ev->type.lanes > 1;
    if (is_vector) {
        vector_store_lanes = ev->type.lanes;
        vector_store = true;
    } else {
        scalar_store = true;
    }
    return 0;
}
// Allocation statistics gathered by the custom allocator hooks.
int empty_allocs = 0;
int nonempty_allocs = 0;
int frees = 0;

// Zero every allocation statistic.
void reset_alloc_counts() {
    frees = 0;
    nonempty_allocs = 0;
    empty_allocs = 0;
}
// Custom allocator hook: tallies zero-size versus non-zero-size requests,
// then forwards the request to malloc.
void *my_malloc(void *ctx, size_t sz) {
    int &tally = (sz != 0) ? nonempty_allocs : empty_allocs;
    ++tally;
    return malloc(sz);
}
// Custom deallocator hook: counts the release, then hands the pointer to free.
void my_free(void *ctx, void *ptr) {
    ++frees;
    free(ptr);
}
// Number of IfThenElse nodes CountIfThenElse has seen nested below a
// ProducerConsumer node. Callers reset it before running a pass.
int if_then_else_count = 0;

// IR pass that counts IfThenElse statements appearing anywhere below a
// ProducerConsumer node, accumulating into if_then_else_count. The overrides
// only count; all traversal is delegated to Internal::IRMutator.
class CountIfThenElse : public Internal::IRMutator {
    // Depth of ProducerConsumer nodes enclosing the node currently visited.
    int producer_consumers = 0;

public:
    CountIfThenElse() = default;

    // `override` makes the compiler reject this if the base signature
    // ever drifts, instead of silently declaring an unrelated overload.
    void visit(const Internal::ProducerConsumer *op) override {
        producer_consumers++;
        Internal::IRMutator::visit(op);
        producer_consumers--;
    }

    void visit(const Internal::IfThenElse *op) override {
        // Only conditions inside a producer/consumer are counted.
        if (producer_consumers > 0) {
            if_then_else_count++;
        }
        Internal::IRMutator::visit(op);
    }

    // Keep the base-class visit overloads for every other node type visible.
    using Internal::IRMutator::visit;
};
int main(int argc, char **argv) {
{
Param<bool> param;
Func f;
Var x;
f(x) = select(param, x*3, x*17);
Expr cond = (f.output_buffer().width() >= 4);
f.specialize(cond).vectorize(x, 4);
f.specialize(cond).specialize(param);
f.specialize(param);
f.set_custom_trace(&my_trace);
f.trace_stores();
Buffer<int> out(100);
param.set(true);
reset_trace();
f.realize(out);
for (int i = 0; i < out.width(); i++) {
int correct = i*3;
if (out(i) != correct) {
printf("out(%d) was %d instead of %d\n",
i, out(i), correct);
}
}
param.set(false);
f.realize(out);
for (int i = 0; i < out.width(); i++) {
int correct = i*17;
if (out(i) != correct) {
printf("out(%d) was %d instead of %d\n",
i, out(i), correct);
}
}
if (!vector_store || scalar_store) {
printf("This was supposed to use vector stores\n");
return -1;
}
out = Buffer<int>(3);
param.set(true);
reset_trace();
f.realize(out);
for (int i = 0; i < out.width(); i++) {
int correct = i*3;
if (out(i) != correct) {
printf("out(%d) was %d instead of %d\n",
i, out(i), correct);
}
}
param.set(false);
f.realize(out);
for (int i = 0; i < out.width(); i++) {
int correct = i*17;
if (out(i) != correct) {
printf("out(%d) was %d instead of %d\n",
i, out(i), correct);
}
}
if (vector_store || !scalar_store) {
printf("This was supposed to use scalar stores\n");
return -1;
}
}
{
Func f1, f2, g1, g2;
Var x;
f1(x) = x + 7;
g1(x) = f1(x) + f1(x + 1);
f2(x) = x * 34;
g2(x) = f2(x) + f2(x - 1);
Param<bool> param;
Func out;
out(x) = select(param, g1(x), g2(x));
f1.compute_root();
g1.compute_root();
f2.compute_root();
out.specialize(param);
out.set_custom_allocator(&my_malloc, &my_free);
reset_alloc_counts();
param.set(true);
out.realize(100);
if (empty_allocs != 1 || nonempty_allocs != 2 || frees != 3) {
printf("There were supposed to be 1 empty alloc, 2 nonempty allocs, and 3 frees.\n"
"Instead we got %d empty allocs, %d nonempty allocs, and %d frees.\n",
empty_allocs, nonempty_allocs, frees);
return -1;
}
reset_alloc_counts();
param.set(false);
out.realize(100);
if (empty_allocs != 2 || nonempty_allocs != 1 || frees != 3) {
printf("There were supposed to be 2 empty allocs, 1 nonempty alloc, and 3 frees.\n"
"Instead we got %d empty allocs, %d nonempty allocs, and %d frees.\n",
empty_allocs, nonempty_allocs, frees);
return -1;
}
}
{
ImageParam im(Float(32), 1);
im.dim(0).set_stride(Expr());
Func f;
Var x;
f(x) = im(x);
f.specialize(im.dim(0).stride() == 1 && im.width() >= 8).vectorize(x, 8);
f.trace_stores();
f.set_custom_trace(&my_trace);
f.infer_input_bounds(5);
int m = im.get().min(0), e = im.get().extent(0);
if (m != 0 || e != 5) {
printf("min, extent = %d, %d instead of 0, 5\n", m, e);
return -1;
}
reset_trace();
f.realize(5);
if (!scalar_store || vector_store) {
printf("These stores were supposed to be scalar.\n");
return -1;
}
Buffer<float> image(100);
im.set(image);
reset_trace();
f.realize(100);
if (scalar_store || !vector_store) {
printf("These stores were supposed to be vector.\n");
return -1;
}
}
{
ImageParam im(Float(32), 1);
Param<bool> param;
Func f;
Var x;
f(x) = select(param, im(x + 10), im(x - 10));
f.specialize(param);
param.set(true);
f.infer_input_bounds(100);
int m = im.get().min(0);
if (m != 10) {
printf("min %d instead of 10\n", m);
return -1;
}
param.set(false);
im.reset();
f.infer_input_bounds(100);
m = im.get().min(0);
if (m != -10) {
printf("min %d instead of -10\n", m);
return -1;
}
}
{
Func f;
Var x;
Param<int> start, size;
RDom r(start, size);
f(x) = x;
f(r) = 10 - r;
f.update().specialize(size == 1);
f.update().specialize(size == 0);
start.set(0);
size.set(1);
f.realize(100);
}
{
ImageParam im(Float(32), 1);
Param<bool> param;
Func f;
Var x;
f(x) = select(param, im(x), 0.0f);
f.specialize(param);
param.set(false);
Buffer<float> image(10);
im.set(image);
f.realize(100);
}
{
ImageParam im(Int(32), 2);
Func f;
Var x, y;
f(x, y) = im(x, y);
Expr cond = f.output_buffer().width() >= 4;
f.reorder(y, x).unroll(y, 2).reorder(x, y);
f.specialize(cond).vectorize(x, 4);
f.infer_input_bounds(3, 1);
if (im.get().extent(0) != 3) {
printf("extent(0) was supposed to be 3.\n");
return -1;
}
if (im.get().extent(1) != 2) {
printf("extent(1) was supposed to be 2.\n");
return -1;
}
}
{
ImageParam im(Int(32), 1);
Func f, g, h, out;
Var x;
f(x) = im(x);
g(x) = f(x);
h(x) = g(x);
out(x) = h(x);
Expr w = out.output_buffer().dim(0).extent();
out.output_buffer().dim(0).set_min(0);
f.compute_root().specialize(w >= 4).vectorize(x, 4);
g.compute_root().vectorize(x, 4);
h.compute_root().vectorize(x, 4);
out.specialize(w >= 4).vectorize(x, 4);
Buffer<int> input(3), output(3);
im.set(input);
out.realize(output);
}
{
ImageParam im(Int(32), 2);
Param<bool> cond1, cond2;
Func f, out;
Var x, y;
f(x, y) = im(x, y);
out(x, y) = f(x, y);
f.compute_at(out, x).specialize(cond1 && cond2).vectorize(x, 4);
out.compute_root().specialize(cond1 && cond2).vectorize(x, 4);
if_then_else_count = 0;
CountIfThenElse pass1;
for (auto ff : out.compile_to_module(out.infer_arguments()).functions()) {
pass1.mutate(ff.body);
}
Buffer<int> input(3, 3), output(3, 3);
im.set(input);
out.realize(output);
if (if_then_else_count != 1) {
printf("Expected 1 IfThenElse stmts. Found %d.\n", if_then_else_count);
return -1;
}
}
{
ImageParam im(Int(32), 2);
Param<bool> cond1, cond2;
Func f, out;
Var x, y;
f(x, y) = im(x, y);
out(x, y) = f(x, y);
f.compute_at(out, x).specialize(cond1).vectorize(x, 4);
out.compute_root().specialize(cond1 && cond2).vectorize(x, 4);
if_then_else_count = 0;
CountIfThenElse pass2;
for (auto ff : out.compile_to_module(out.infer_arguments()).functions()) {
pass2.mutate(ff.body);
}
Buffer<int> input(3, 3), output(3, 3);
im.set(input);
out.realize(output);
if (if_then_else_count != 2) {
printf("Expected 2 IfThenElse stmts. Found %d.\n", if_then_else_count);
return -1;
}
}
{
ImageParam im(Int(32), 2);
Param<int> p;
Expr test = (p > 73) || (p*p + p + 1 == 0);
Func f;
Var x;
f(x) = select(test, im(x, 0), im(0, x));
f.specialize(test);
p.set(100);
f.infer_input_bounds(10);
int w = im.get().width();
int h = im.get().height();
if (w != 10 || h != 1) {
printf("Incorrect inferred size: %d %d\n", w, h);
return -1;
}
im.reset();
p.set(-100);
f.infer_input_bounds(10);
w = im.get().width();
h = im.get().height();
if (w != 1 || h != 10) {
printf("Incorrect inferred size: %d %d\n", w, h);
return -1;
}
}
{
ImageParam im(Int(32), 2);
Param<int> p;
Expr test = (p > 73);
Func f;
Var x;
f(x) = select(p > 50, im(x, 0), im(0, x));
f.specialize(test);
p.set(100);
f.infer_input_bounds(10);
int w = im.get().width();
int h = im.get().height();
if (w != 10 || h != 1) {
printf("Incorrect inferred size: %d %d\n", w, h);
return -1;
}
im.reset();
p.set(-100);
f.infer_input_bounds(10);
w = im.get().width();
h = im.get().height();
if (w != 10 || h != 10) {
printf("Incorrect inferred size: %d %d\n", w, h);
return -1;
}
}
{
Var x, y;
Param<int> p;
Expr const_false = Expr(0) == Expr(1);
Expr const_true = Expr(0) == Expr(0);
Expr different_const_true = Expr(1) == Expr(1);
Func f;
f(x) = x;
f.specialize(p == 0).vectorize(x, 32);
f.specialize(const_false).vectorize(x, 8);
f.vectorize(x, 4);
_halide_user_assert(f.function().definition().specializations().size() == 2);
std::map<std::string, Internal::Function> env;
env.insert({f.function().name(), f.function()});
simplify_specializations(env);
const auto &s = f.function().definition().specializations();
_halide_user_assert(s.size() == 1);
_halide_user_assert(s[0].condition.as<Internal::EQ>() && is_zero(s[0].condition.as<Internal::EQ>()->b));
f.set_custom_trace(&my_trace);
f.trace_stores();
vector_store_lanes = 0;
p.set(0);
f.realize(100);
_halide_user_assert(vector_store_lanes == 32);
vector_store_lanes = 0;
p.set(42);
f.realize(100);
_halide_user_assert(vector_store_lanes == 4);
}
{
Var x;
Param<int> p;
Expr const_false = Expr(0) == Expr(1);
Expr const_true = Expr(0) == Expr(0);
Expr different_const_true = Expr(1) == Expr(1);
Func f;
f(x) = x;
f.specialize(p == 0).vectorize(x, 32);
f.specialize(const_true).vectorize(x, 16);
f.specialize(const_false).vectorize(x, 4);
f.specialize(p == 42).vectorize(x, 8);
f.specialize(const_true);
f.specialize(different_const_true);
_halide_user_assert(f.function().definition().specializations().size() == 5);
std::map<std::string, Internal::Function> env;
env.insert({f.function().name(), f.function()});
simplify_specializations(env);
const auto &s = f.function().definition().specializations();
_halide_user_assert(s.size() == 1);
_halide_user_assert(s[0].condition.as<Internal::EQ>() && is_zero(s[0].condition.as<Internal::EQ>()->b));
f.set_custom_trace(&my_trace);
f.trace_stores();
vector_store_lanes = 0;
p.set(42);
f.realize(100);
_halide_user_assert(vector_store_lanes == 16);
vector_store_lanes = 0;
p.set(0);
f.realize(100);
_halide_user_assert(vector_store_lanes == 32);
}
{
Var x;
Param<int> p;
Expr const_true = Expr(0) == Expr(0);
Expr different_const_true = Expr(1) == Expr(1);
Func f("foof");
f(x) = x;
f.specialize(p == 0).vectorize(x, 32);
f.specialize(const_true).vectorize(x, 16);
f.set_custom_trace(&my_trace);
f.trace_stores();
vector_store_lanes = 0;
p.set(42);
f.realize(100);
_halide_user_assert(vector_store_lanes == 16);
vector_store_lanes = 0;
p.set(0);
f.realize(100);
_halide_user_assert(vector_store_lanes == 32);
}
{
Var x;
Param<int> p;
Func f;
f(x) = x;
f.specialize(p == 0);
f.specialize_fail("Unhandled Param value encountered.");
f.specialize(p == 0).vectorize(x, 32);
f.set_custom_trace(&my_trace);
f.trace_stores();
vector_store_lanes = 0;
p.set(0);
f.realize(100);
_halide_user_assert(vector_store_lanes == 32);
}
printf("Success!\n");
return 0;
}