This source file includes the following definitions:
- fft_number_type_enum_map
- fft_direction_enum_map
- generate
- schedule
#include "Halide.h"
#include "fft.h"
namespace {
using namespace Halide;
/// Whether a buffer carries real samples or complex (real, imaginary) samples.
enum class FFTNumberType { Real, Complex };

/// Returns the name -> FFTNumberType table used by the "input_number_type"
/// and "output_number_type" GeneratorParams to parse their string values.
std::map<std::string, FFTNumberType> fft_number_type_enum_map() {
    std::map<std::string, FFTNumberType> by_name;
    by_name["real"] = FFTNumberType::Real;
    by_name["complex"] = FFTNumberType::Complex;
    return by_name;
}
/// Direction of the transform: samples -> frequency or frequency -> samples.
enum class FFTDirection { SamplesToFrequency, FrequencyToSamples };

/// Returns the name -> FFTDirection table used by the "direction"
/// GeneratorParam to parse its string value.
std::map<std::string, FFTDirection> fft_direction_enum_map() {
    std::map<std::string, FFTDirection> lookup;
    lookup.emplace("samples_to_frequency", FFTDirection::SamplesToFrequency);
    lookup.emplace("frequency_to_samples", FFTDirection::FrequencyToSamples);
    return lookup;
}
/// Halide generator wrapping the FFT helpers from fft.h (fft2d_r2c,
/// fft2d_c2r, fft2d_c2c).
///
/// Both buffers use an interleaved layout, enforced in schedule():
/// - dimension 2 indexes the component, starting at 0, with stride 1;
/// - it has extent 1 for real data, or 2 for complex data (real part at
///   c == 0, imaginary part at c == 1);
/// - dimension 0 has a stride equal to that component count.
class FFTGenerator : public Halide::Generator<FFTGenerator> {
public:
    // Forwarded to Fft2dDesc::gain (default 1.0f). Presumably a scale factor
    // applied to the transform; confirm the exact semantics against fft.h.
    GeneratorParam<float> gain{"gain", 1.0f};

    // Forwarded to Fft2dDesc::vector_width. The meaning of the default 0 is
    // defined by fft.h (presumably "use the target's natural width"; verify).
    GeneratorParam<int32_t> vector_width{"vector_width", 0};

    // Forwarded to Fft2dDesc::parallel to request parallelism inside a
    // single FFT.
    GeneratorParam<bool> parallel{"parallel", false};

    // "samples_to_frequency" runs the transform with sign -1;
    // "frequency_to_samples" runs it with sign +1.
    GeneratorParam<FFTDirection> direction{"direction", FFTDirection::SamplesToFrequency,
                                           fft_direction_enum_map()};

    // "real": input has one component (c == 0).
    // "complex": input has real (c == 0) and imaginary (c == 1) components.
    GeneratorParam<FFTNumberType> input_number_type{"input_number_type",
                                                    FFTNumberType::Real, fft_number_type_enum_map()};

    // "real": only the real part of the result is written (one component).
    // "complex": real (c == 0) and imaginary (c == 1) parts are written.
    GeneratorParam<FFTNumberType> output_number_type{"output_number_type",
                                                     FFTNumberType::Real, fft_number_type_enum_map()};

    // Size of the first transform dimension; generate() rejects values <= 0.
    GeneratorParam<int32_t> size0{"size0", 1};
    // Size of the second transform dimension. Defaults to 0, which presumably
    // selects a 1D FFT (NOTE(review): confirm against fft.h).
    GeneratorParam<int32_t> size1{"size1", 0};

    // Float input indexed as (x, y, component); layout fixed in schedule().
    Input<Buffer<float>> input{"input", 3};
    // Float output indexed as (x, y, component); layout fixed in schedule().
    Output<Buffer<float>> output{"output", 3};

    /// Defines the FFT pipeline.
    ///
    /// The specialized entry points are used when the parameters allow it:
    /// - real-to-complex (fft2d_r2c) for a forward transform of real input;
    /// - complex-to-real (fft2d_c2r) for an inverse transform of complex
    ///   input when only the real output is requested.
    ///
    /// Every other combination goes through complex-to-complex (fft2d_c2c).
    /// Exactly one of real_result / complex_result ends up defined.
    void generate() {
        _halide_user_assert(size0 > 0) << "FFT must be at least 1D\n";

        Fft2dDesc desc;
        desc.gain = gain;
        desc.vector_width = vector_width;
        // Without this line the `parallel` GeneratorParam would be declared
        // but silently ignored.
        desc.parallel = parallel;

        const int sign = (direction == FFTDirection::SamplesToFrequency) ? -1 : 1;

        if (input_number_type == FFTNumberType::Real) {
            if (direction == FFTDirection::SamplesToFrequency) {
                // Forward transform of real data: use the r2c specialization
                // on component 0.
                Func in;
                in(x, y) = input(x, y, 0);
                complex_result = fft2d_r2c(in, size0, size1, target, desc);
            } else {
                // Inverse transform of real data: promote to complex with a
                // zero imaginary part and use the general c2c transform.
                ComplexFunc in;
                in(x, y) = ComplexExpr(input(x, y, 0), 0);
                complex_result = fft2d_c2c(in, size0, size1, sign, target, desc);
            }
        } else {
            ComplexFunc in;
            in(x, y) = ComplexExpr(input(x, y, 0), input(x, y, 1));
            if (output_number_type == FFTNumberType::Real &&
                direction == FFTDirection::FrequencyToSamples) {
                // Only the real part of an inverse transform is wanted:
                // use the c2r specialization.
                real_result = fft2d_c2r(in, size0, size1, target, desc);
            } else {
                complex_result = fft2d_c2c(in, size0, size1, sign, target, desc);
            }
        }

        if (output_number_type == FFTNumberType::Real) {
            if (real_result.defined()) {
                output(x, y, c) = real_result(x, y);
            } else {
                // A complex result was produced; keep its real part.
                output(x, y, c) = re(complex_result(x, y));
            }
        } else {
            // Component 0 holds the real part, component 1 the imaginary part.
            output(x, y, c) = select(c == 0,
                                     re(complex_result(x, y)),
                                     im(complex_result(x, y)));
        }
    }

    /// Constrains the buffer layouts and schedules the transform.
    void schedule() {
        const int input_comps = (input_number_type == FFTNumberType::Real) ? 1 : 2;
        const int output_comps = (output_number_type == FFTNumberType::Real) ? 1 : 2;

        // Interleaved components: dim 2 (component) has min 0, extent equal to
        // the component count and stride 1; dim 0 steps over all components.
        input.dim(0).set_stride(input_comps)
            .dim(2).set_min(0).set_extent(input_comps).set_stride(1);
        output.dim(0).set_stride(output_comps)
            .dim(2).set_min(0).set_extent(output_comps).set_stride(1);

        if (output_comps != 1) {
            // Make the two-wide component loop innermost, matching the
            // stride-1 component dimension, and unroll it.
            output.reorder(c, x, y).unroll(c);
        }

        // Compute whichever transform generate() produced at the outermost
        // level of the output's loop nest.
        if (real_result.defined()) {
            real_result.compute_at(output, Var::outermost());
        } else {
            // Unlike assert(), this check survives NDEBUG builds and does not
            // depend on <cassert> being included transitively.
            _halide_user_assert(complex_result.defined())
                << "FFTGenerator: generate() produced neither a real nor a complex result\n";
            complex_result.compute_at(output, Var::outermost());
        }
    }

private:
    Var x{"x"}, y{"y"}, c{"c"};
    // Set by generate() only when the c2r specialization is used.
    Func real_result;
    // Set by generate() whenever the r2c or c2c transform is used.
    ComplexFunc complex_result;
};
// Registers FFTGenerator with Halide's generator registry under the name "fft".
Halide::RegisterGenerator<FFTGenerator> register_fft("fft");
}