// This source file contains the following definitions:
//  - flip_x
//  - main
#include "Halide.h"
#include <stdio.h>
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// Extern stage used by the test pipeline: computes
/// out(x) = in1(-x) + in2(-x) over the 1-D region described by out->dim[0].
///
/// If either input has a null host pointer, the call is treated as a
/// bounds-inference request. Each input without storage is told to cover the
/// mirrored output interval [-max, -min], and nothing is computed.
/// Otherwise the output region is filled in.
///
/// @param in1 uint8 input (asserted), read at negated coordinates.
/// @param in2 int32 input (asserted), read at negated coordinates.
/// @param out uint8 output (asserted); its dim[0] is the region to produce.
/// @return Always 0.
extern "C" DLLEXPORT int flip_x(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) {
    const int min = out->dim[0].min;
    const int max = out->dim[0].min + out->dim[0].extent - 1;
    const int extent = out->dim[0].extent;
    // Reading at -x mirrors the output interval [min, max] into [-max, -min].
    const int flipped_min = -max;
    const int flipped_max = -min;

    if (in1->host == nullptr || in2->host == nullptr) {
        printf("Doing flip_x bounds inference over [%d %d]\n", min, max);
        // Only fill in the required region for inputs that have no storage.
        if (in1->host == nullptr) {
            in1->dim[0].min = flipped_min;
            in1->dim[0].extent = extent;
        }
        if (in2->host == nullptr) {
            in2->dim[0].min = flipped_min;
            in2->dim[0].extent = extent;
        }
        return 0;
    }

    assert(in1->type == halide_type_of<uint8_t>());
    assert(in2->type == halide_type_of<int32_t>());
    assert(out->type == halide_type_of<uint8_t>());
    printf("Computing flip_x over [%d %d]\n", min, max);
    // Each input must cover the whole mirrored interval read below.
    assert(in1->dim[0].min <= flipped_min &&
           in1->dim[0].min + in1->dim[0].extent > flipped_max);
    assert(in2->dim[0].min <= flipped_min &&
           in2->dim[0].min + in2->dim[0].extent > flipped_max);
    // The indexing below assumes densely packed elements.
    assert(in1->dim[0].stride == 1 && in2->dim[0].stride == 1 && out->dim[0].stride == 1);

    // Index each buffer relative to its own min. We do not pre-offset the
    // host pointers by -dim[0].min: for a nonzero min that pointer can fall
    // outside the allocation, and merely forming it is undefined behavior.
    const int in1_min = in1->dim[0].min;
    const int in2_min = in2->dim[0].min;
    uint8_t *dst = out->host;
    const uint8_t *src1 = in1->host;
    const int32_t *src2 = reinterpret_cast<const int32_t *>(in2->host);
    for (int i = min; i <= max; i++) {
        // The uint8 + int32 sum is computed in int; storing it wraps modulo
        // 256, the same result the previous implicit conversion produced.
        dst[i - min] = static_cast<uint8_t>(src1[-i - in1_min] + src2[-i - in2_min]);
    }
    return 0;
}
using namespace Halide;
int main(int argc, char **argv) {
Func f, g, h;
Var x;
Buffer<uint8_t> input(100);
input.set_min(-99);
lambda(x, cast<uint8_t>(x*x)).realize(input);
assert(input(-99) == (uint8_t)(-99*-99));
f(x) = x*x;
std::vector<ExternFuncArgument> args(2);
args[0] = input;
args[1] = f;
g.define_extern("flip_x", args, UInt(8), 1);
h(x) = g(x) * 2;
f.compute_at(h, x);
g.compute_at(h, x);
Var xi;
h.vectorize(x, 8).unroll(x, 2).split(x, x, xi, 4).parallel(x);
Buffer<uint8_t> result = h.realize(100);
for (int i = 0; i < 100; i++) {
uint8_t correct = 4*i*i;
if (result(i) != correct) {
printf("result(%d) = %d instead of %d\n", i, result(i), correct);
return -1;
}
}
printf("Success!\n");
return 0;
}