This source file includes the following definitions:
- visit
- my_trace
- main
#include "Halide.h"
using namespace Halide;
using namespace Halide::Internal;
class CheckForSelects : public IRVisitor {
using IRVisitor::visit;
void visit(const Select *op) {
result = true;
}
public:
bool result = false;
};
int trace_min, trace_extent;
int my_trace(void *user_context, const halide_trace_event_t *e) {
if (e->event == 2) {
trace_min = e->coordinates[0];
trace_extent = e->coordinates[1];
}
return 0;
}
int main(int argc, char **argv) {
{
Func f, g, h;
Var x;
f(x) = 3;
g(x) = select(x % 2 == 0, f(x+1), f(x-1)+8);
Param<int> p;
h(x) = g(x-p) + g(x+p);
f.compute_root();
g.compute_root().align_bounds(x, 2).unroll(x, 2).trace_realizations();
Module m = g.compile_to_module({p});
CheckForSelects checker;
m.functions()[0].body.accept(&checker);
if (checker.result) {
printf("Lowered code contained a select\n");
return -1;
}
p.set(3);
h.set_custom_trace(my_trace);
Buffer<int> result = h.realize(10);
for (int i = 0; i < 10; i++) {
int correct = (i&1) == 1 ? 6 : 22;
if (result(i) != correct) {
printf("result(%d) = %d instead of %d\n",
i, result(i), correct);
return -1;
}
}
if (trace_min != -4 || trace_extent != 18) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
p.set(4);
assert(result.data());
h.realize(result);
if (trace_min != -4 || trace_extent != 18) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
assert(result.data());
p.set(5);
h.realize(result);
if (trace_min != -6 || trace_extent != 22) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
}
{
Func f, g, h;
Var x;
f(x) = 3;
g(x) = select(x % 2 == 0, f(x+1), f(x-1)+8);
Param<int> p;
h(x) = g(x-p) + g(x+p);
f.compute_root();
g.compute_root().align_bounds(x, 2, 1).unroll(x, 2).trace_realizations();
Module m = g.compile_to_module({p});
CheckForSelects checker;
m.functions()[0].body.accept(&checker);
if (checker.result) {
printf("Lowered code contained a select\n");
return -1;
}
p.set(3);
h.set_custom_trace(my_trace);
Buffer<int> result = h.realize(10);
for (int i = 0; i < 10; i++) {
int correct = (i&1) == 1 ? 6 : 22;
if (result(i) != correct) {
printf("result(%d) = %d instead of %d\n",
i, result(i), correct);
return -1;
}
}
if (trace_min != -3 || trace_extent != 16) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
p.set(4);
h.realize(result);
if (trace_min != -5 || trace_extent != 20) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
p.set(5);
h.realize(result);
if (trace_min != -5 || trace_extent != 20) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
}
printf("Success!\n");
return 0;
}