This source file includes the following definitions.
- make_real
- make_complex
- log2
- main
#include "Halide.h"
#include <cstdio>
#include <vector>
#include "fft.h"
#include "halide_benchmark.h"
#ifdef WITH_FFTW
#include <fftw3.h>
#endif
using namespace Halide;
using namespace Halide::Tools;
// Pure variables shared by every Func/ComplexFunc definition in this file
// (make_real, make_complex, and the pipelines built in main). Note that the
// plain `int x` / `int y` loop counters inside main shadow these within their
// own loop scopes.
Var x("x"), y("y");
// Wraps a 2D buffer in an anonymous Func so it can be handed to the FFT
// helpers that take a Func input. The Func is defined over the file-level
// Vars x and y and simply reads the buffer at the same coordinates.
template <typename T>
Func make_real(const Buffer<T> &re) {
    Func wrapped;
    wrapped(x, y) = re(x, y);
    return wrapped;
}
// Wraps a 2D buffer of real samples in an anonymous ComplexFunc defined over
// the file-level Vars x and y. Only the buffer value is supplied; the
// imaginary component is whatever ComplexFunc produces when assigned a plain
// real expression (presumably zero -- confirm against the ComplexFunc
// definition used by fft.h).
template <typename T>
ComplexFunc make_complex(const Buffer<T> &re) {
    ComplexFunc wrapped;
    wrapped(x, y) = re(x, y);
    return wrapped;
}
// Base-2 logarithm computed via the change-of-base identity
// log2(x) = ln(x) / ln(2). Used by main for the operation-count estimates.
double log2(double x) {
    const double ln_of_x = log(x);
    const double ln_of_two = log(2.0);
    return ln_of_x / ln_of_two;
}
// Entry point: checks the Halide FFT pipelines against a direct convolution,
// then benchmarks them (and optionally FFTW) and prints a timing table.
//
// Command line:
//   argv[1], argv[2]  optional transform width and height (default 32 x 32).
//                     Both must parse to positive integers.
//   argv[3]           optional prefix for the lowered-statement HTML dumps.
//                     It is concatenated directly with "c2c.html", "r2c.html"
//                     and "c2r.html", so a directory must end with its
//                     path separator.
//
// Returns 0 on success and -1 if the sizes are invalid or if either FFT-based
// filter differs from the reference result by more than 1e-6 at any pixel.
int main(int argc, char **argv) {
    int W = 32;
    int H = 32;
    if (argc >= 3) {
        W = atoi(argv[1]);
        H = atoi(argv[2]);
    }
    // atoi yields 0 for non-numeric text and passes negatives through; either
    // would produce empty/invalid buffers and a division by zero in the
    // inverse-transform gain below, so reject such sizes up front.
    if (W <= 0 || H <= 0) {
        fprintf(stderr, "Invalid size %d x %d: width and height must be positive integers.\n", W, H);
        return -1;
    }

    std::string output_dir;
    if (argc >= 4) {
        output_dir = argv[3];
    }

    // Test image: pseudo-random samples in [0, 1]. rand() is never seeded, so
    // the data is the same on every run.
    Buffer<float> in(W, H);
    for (int y = 0; y < H; y++) {
        for (int x = 0; x < W; x++) {
            in(x, y) = (float)rand()/(float)RAND_MAX;
        }
    }

    // Box filter of size box x box centred on the origin with cyclic
    // wrap-around: u and v are the cyclic distances of (x, y) from (0, 0).
    const int box = 3;
    Buffer<float> kernel(W, H);
    for (int y = 0; y < H; y++) {
        for (int x = 0; x < W; x++) {
            int u = x < (W - x) ? x : (W - x);
            int v = y < (H - y) ? y : (H - y);
            kernel(x, y) = u <= box/2 && v <= box/2 ? 1.0f/(box*box) : 0.0f;
        }
    }

    Target target = get_jit_target_from_environment();

    // The forward transforms use the default descriptor; the inverse
    // transforms apply a gain of 1/(W*H).
    Fft2dDesc fwd_desc;
    Fft2dDesc inv_desc;
    inv_desc.gain = 1.0f/(W*H);

    // Filter via complex-to-complex transforms: transform image and kernel
    // (sign -1), multiply pointwise, transform back (sign 1), keep the real part.
    Func filtered_c2c;
    {
        ComplexFunc dft_in = fft2d_c2c(make_complex(in), W, H, -1, target, fwd_desc);
        ComplexFunc dft_kernel = fft2d_c2c(make_complex(kernel), W, H, -1, target, fwd_desc);
        dft_in.compute_root();
        dft_kernel.compute_root();

        ComplexFunc dft_filtered("dft_filtered");
        dft_filtered(x, y) = dft_in(x, y) * dft_kernel(x, y);

        ComplexFunc dft_out = fft2d_c2c(dft_filtered, W, H, 1, target, inv_desc);
        dft_out.compute_root();

        filtered_c2c(x, y) = re(dft_out(x, y));
    }

    // Same filter using the real-to-complex / complex-to-real transforms.
    Func filtered_r2c;
    {
        ComplexFunc dft_in = fft2d_r2c(make_real(in), W, H, target, fwd_desc);
        ComplexFunc dft_kernel = fft2d_r2c(make_real(kernel), W, H, target, fwd_desc);
        dft_in.compute_root();
        dft_kernel.compute_root();

        ComplexFunc dft_filtered("dft_filtered");
        dft_filtered(x, y) = dft_in(x, y) * dft_kernel(x, y);

        filtered_r2c = fft2d_c2r(dft_filtered, W, H, target, inv_desc);
    }

    Buffer<float> result_c2c = filtered_c2c.realize(W, H, target);
    Buffer<float> result_r2c = filtered_r2c.realize(W, H, target);

    // Check both results against a direct cyclic box-filter average.
    for (int y = 0; y < H; y++) {
        for (int x = 0; x < W; x++) {
            float correct = 0;
            for (int i = -box/2; i <= box/2; i++) {
                for (int j = -box/2; j <= box/2; j++) {
                    // Adding W / H before the modulo keeps the index non-negative.
                    correct += in((x + j + W)%W, (y + i + H)%H);
                }
            }
            correct /= box*box;
            if (fabs(result_c2c(x, y) - correct) > 1e-6f) {
                printf("result_c2c(%d, %d) = %f instead of %f\n", x, y, result_c2c(x, y), correct);
                return -1;
            }
            if (fabs(result_r2c(x, y) - correct) > 1e-6f) {
                printf("result_r2c(%d, %d) = %f instead of %f\n", x, y, result_r2c(x, y), correct);
                return -1;
            }
        }
    }

    // Benchmark setup. Each Halide pipeline gets a third dimension `rep` of
    // extent `reps`, so one realize call runs the transform `reps` times; the
    // measured time is divided by `reps` afterwards.
    const int samples = 100;
    const int reps = 1000;
    Var rep("rep");

    // All-zero benchmark inputs.
    Buffer<float> re_in = lambda(x, y, 0.0f).realize(W, H);
    Buffer<float> im_in = lambda(x, y, 0.0f).realize(W, H);

    printf("%12s %5s%11s%5s %5s%11s%5s\n", "", "", "Halide", "", "", "FFTW", "");
    printf("%12s %10s %10s %10s %10s %10s\n", "DFT type", "Time (us)", "MFLOP/s", "Time (us)", "MFLOP/s", "Ratio");

    // ---- complex-to-complex ----
    ComplexFunc c2c_in;
    c2c_in(x, y, rep) = {re_in(x, y), im_in(x, y)};
    Func bench_c2c = fft2d_c2c(c2c_in, W, H, -1, target, fwd_desc);
    bench_c2c.compile_to_lowered_stmt(output_dir + "c2c.html", bench_c2c.infer_arguments(), HTML);
    Realization R_c2c = bench_c2c.realize(W, H, reps, target);
    // Zero the stride of the rep dimension so every repetition writes to the
    // same output plane rather than walking through reps separate planes.
    R_c2c[0].raw_buffer()->dim[2].stride = 0;
    R_c2c[1].raw_buffer()->dim[2].stride = 0;

    // The *1e6 matches the "Time (us)" column (assumes benchmark() reports
    // seconds -- confirm against halide_benchmark.h).
    double halide_t = benchmark(samples, 1, [&]() { bench_c2c.realize(R_c2c); })*1e6/reps;

#ifdef WITH_FFTW
    std::vector<std::pair<float, float>> fftw_c1(W * H);
    std::vector<std::pair<float, float>> fftw_c2(W * H);
    fftwf_plan c2c_plan = fftwf_plan_dft_2d(W, H, (fftwf_complex*)&fftw_c1[0], (fftwf_complex*)&fftw_c2[0], FFTW_FORWARD, FFTW_EXHAUSTIVE);
    double fftw_t = benchmark(samples, reps, [&]() { fftwf_execute(c2c_plan); })*1e6;
#else
    // Without FFTW the time is reported as 0, so the FFTW MFLOP/s column is a
    // floating-point division by zero and the ratio column is 0.
    double fftw_t = 0;
#endif

    // Operation count estimated as 5*W*H*(log2(W) + log2(H)); dividing by
    // microseconds gives the value printed under the MFLOP/s header.
    printf("%12s %10.3f %10.2f %10.3f %10.2f %10.3g\n",
           "c2c",
           halide_t,
           5*W*H*(log2(W) + log2(H))/halide_t,
           fftw_t,
           5*W*H*(log2(W) + log2(H))/fftw_t,
           fftw_t / halide_t);

    // ---- real-to-complex ----
    Func r2c_in;
    r2c_in(x, y, rep) = re_in(x, y);
    Func bench_r2c = fft2d_r2c(r2c_in, W, H, target, fwd_desc);
    bench_r2c.compile_to_lowered_stmt(output_dir + "r2c.html", bench_r2c.infer_arguments(), HTML);
    // Only H/2 + 1 output rows are realized for the real-input transform.
    Realization R_r2c = bench_r2c.realize(W, H/2 + 1, reps, target);
    R_r2c[0].raw_buffer()->dim[2].stride = 0;
    R_r2c[1].raw_buffer()->dim[2].stride = 0;

    halide_t = benchmark(samples, 1, [&]() { bench_r2c.realize(R_r2c); })*1e6/reps;

#ifdef WITH_FFTW
    std::vector<float> fftw_r(W * H);
    fftwf_plan r2c_plan = fftwf_plan_dft_r2c_2d(W, H, &fftw_r[0], (fftwf_complex*)&fftw_c1[0], FFTW_EXHAUSTIVE);
    fftw_t = benchmark(samples, reps, [&]() { fftwf_execute(r2c_plan); })*1e6;
#else
    fftw_t = 0;
#endif

    // The real-input and real-output transforms use half the c2c estimate.
    printf("%12s %10.3f %10.2f %10.3f %10.2f %10.3g\n",
           "r2c",
           halide_t,
           2.5*W*H*(log2(W) + log2(H))/halide_t,
           fftw_t,
           2.5*W*H*(log2(W) + log2(H))/fftw_t,
           fftw_t / halide_t);

    // ---- complex-to-real ----
    ComplexFunc c2r_in;
    c2r_in(x, y, rep) = {re_in(x, y), im_in(x, y)};
    Func bench_c2r = fft2d_c2r(c2r_in, W, H, target, inv_desc);
    bench_c2r.compile_to_lowered_stmt(output_dir + "c2r.html", bench_c2r.infer_arguments(), HTML);
    Realization R_c2r = bench_c2r.realize(W, H, reps, target);
    // Single real output here, so only one buffer needs its rep stride zeroed.
    R_c2r[0].raw_buffer()->dim[2].stride = 0;

    halide_t = benchmark(samples, 1, [&]() { bench_c2r.realize(R_c2r); })*1e6/reps;

#ifdef WITH_FFTW
    fftwf_plan c2r_plan = fftwf_plan_dft_c2r_2d(W, H, (fftwf_complex*)&fftw_c1[0], &fftw_r[0], FFTW_EXHAUSTIVE);
    fftw_t = benchmark(samples, reps, [&]() { fftwf_execute(c2r_plan); })*1e6;
#else
    fftw_t = 0;
#endif

    printf("%12s %10.3f %10.2f %10.3f %10.2f %10.3g\n",
           "c2r",
           halide_t,
           2.5*W*H*(log2(W) + log2(H))/halide_t,
           fftw_t,
           2.5*W*H*(log2(W) + log2(H))/fftw_t,
           fftw_t / halide_t);

#ifdef WITH_FFTW
    fftwf_destroy_plan(c2c_plan);
    fftwf_destroy_plan(r2c_plan);
    fftwf_destroy_plan(c2r_plan);
#endif

    return 0;
}