root/apps/fft/fft_generator.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. fft_number_type_enum_map
  2. fft_direction_enum_map
  3. generate
  4. schedule

#include "Halide.h"

#include "fft.h"

namespace {

using namespace Halide;

// Numeric domain of an FFT input or output: real samples or complex values.
enum class FFTNumberType { Real, Complex };

// Builds the table that translates the textual names "real" and "complex"
// into the matching FFTNumberType enumerators.
std::map<std::string, FFTNumberType> fft_number_type_enum_map() {
    std::map<std::string, FFTNumberType> names;
    names.emplace("real", FFTNumberType::Real);
    names.emplace("complex", FFTNumberType::Complex);
    return names;
}

// Direction of the FFT. "Samples" may be interpreted as "time" or "spatial"
// samples, depending on what the input domain represents.
enum class FFTDirection { SamplesToFrequency, FrequencyToSamples };

// Builds the table that translates the textual names "samples_to_frequency"
// and "frequency_to_samples" into the matching FFTDirection enumerators.
std::map<std::string, FFTDirection> fft_direction_enum_map() {
    std::map<std::string, FFTDirection> names;
    names.emplace("samples_to_frequency", FFTDirection::SamplesToFrequency);
    names.emplace("frequency_to_samples", FFTDirection::FrequencyToSamples);
    return names;
}

// Halide generator that wraps the fft2d_r2c / fft2d_c2r / fft2d_c2c helpers
// from fft.h behind a buffer-in / buffer-out interface. The transform size,
// direction, number types and tuning options are all chosen at generation
// time through the GeneratorParams below.
class FFTGenerator : public Halide::Generator<FFTGenerator> {
public:

    // Gain to apply to the FFT. This is folded into gains already
    // being applied to the FFT. A gain of 1.0f indicates an
    // unnormalized FFT. 1 / sqrt(N) gives a unitary transform such that
    // forward and inverse operations have the same gain without changing
    // signal magnitude.
    // A common convention is 1/N for the forward direction and 1 for the
    // inverse.
    // "N" above is the size of the input, which is the product of
    // the dimensions.
    GeneratorParam<float> gain{"gain", 1.0f};

    // The following option specifies that a particular vector width should be
    // used when the vector width can change the results of the FFT.
    // Some parts of the FFT algorithm use the vector width to change the way
    // floating point operations are ordered and grouped, which causes the results
    // to vary with respect to the target architecture. Setting this option forces
    // such stages to use the specified vector width (independent of the actual
    // architecture's vector width), which eliminates the architecture specific
    // behavior.
    GeneratorParam<int32_t> vector_width{"vector_width", 0};

    // The following option indicates that the FFT should parallelize within a
    // single FFT. This only makes sense to use on large FFTs, and generally only
    // if there is no outer loop around FFTs that can be parallelized.
    GeneratorParam<bool> parallel{"parallel", false};

    // Indicates forward or inverse Fourier transform --
    // "samples_to_frequency" maps to a forward FFT. (Other packages sometimes call this a sign of -1)
    // "frequency_to_samples" maps to an inverse FFT. (Other packages sometimes call this a sign of +1)
    GeneratorParam<FFTDirection> direction{"direction", FFTDirection::SamplesToFrequency,
        fft_direction_enum_map() };

    // Whether the input is "real" or "complex".
    GeneratorParam<FFTNumberType> input_number_type{"input_number_type",
        FFTNumberType::Real, fft_number_type_enum_map() };
    // Whether the output is "real" or "complex".
    GeneratorParam<FFTNumberType> output_number_type{"output_number_type",
        FFTNumberType::Real, fft_number_type_enum_map() };

    // Size of first dimension, required to be greater than zero.
    GeneratorParam<int32_t> size0{"size0", 1};
    // Size of second dimension, may be zero for 1D FFT.
    GeneratorParam<int32_t> size1{"size1", 0};
    // TODO(zalman): Add support for 3D and maybe 4D FFTs

    // The input buffer. Must be separate from the output.
    // Only Float(32) is supported.
    //
    // For a real input FFT, this should have the following shape:
    // Dim0: extent = size0, stride = 1
    // Dim1: extent = size1 / 2 - 1, stride = size0
    // Dim2: extent = 1, stride = 1
    // NOTE(review): the Dim1 extent above looks suspicious for a real input
    // (one would expect size1) -- confirm against fft.h before relying on it.
    //
    // For a complex input FFT, this should have the following shape:
    // Dim0: extent = size0, stride = 2
    // Dim1: extent = size1, stride = size0 * 2
    // Dim2: extent = 2, stride = 1 (real followed by imaginary components)
    Input<Buffer<float>>  input{"input", 3};

    // The output buffer. Dimension 2 indexes the components of each value:
    // schedule() constrains it to extent 1 for a real output and extent 2
    // (real followed by imaginary) for a complex output, with stride 1, and
    // constrains dimension 0 to a stride equal to the component count.
    Output<Buffer<float>> output{"output", 3};

    // Defines the pipeline: wraps the input buffer in a Func/ComplexFunc,
    // calls the matching fft2d_* helper and writes the result into `output`.
    void generate() {
        _halide_user_assert(size0 > 0) << "FFT must be at least 1D\n";

        Fft2dDesc desc;

        desc.gain = gain;
        desc.vector_width = vector_width;
        // Forward the "parallel" option too; otherwise the generator
        // parameter is accepted but silently has no effect.
        desc.parallel = parallel;

        // The logic below calls the specialized r2c or c2r version if
        // applicable to take advantage of better scheduling. It is
        // assumed that projecting a real Func to a ComplexFunc and
        // immediately back has zero cost.

        // Sign passed to the general complex-to-complex transform:
        // -1 for samples -> frequency, +1 for frequency -> samples.
        const int sign = (direction == FFTDirection::SamplesToFrequency) ? -1 : 1;

        // NOTE: `target` is not declared in this class; presumably it is the
        // target parameter provided by the Halide::Generator base.
        if (input_number_type == FFTNumberType::Real) {
            if (direction == FFTDirection::SamplesToFrequency) {
                // TODO: Not sure why this is necessary as ImageParam
                // -> Func conversion should happen, It may not work
                // with implicit dimension (use of _) logic in FFT.
                Func in;
                in(x, y) = input(x, y, 0);

                complex_result = fft2d_r2c(in, size0, size1, target, desc);
            } else {
                // Real input in the frequency -> samples direction: promote
                // the input to complex with a zero imaginary part and use
                // the general complex-to-complex transform.
                ComplexFunc in;
                in(x, y) = ComplexExpr(input(x, y, 0), 0);

                complex_result = fft2d_c2c(in, size0, size1, sign, target, desc);
            }
        } else {
            // Complex input: component 0 is the real part, component 1 the
            // imaginary part.
            ComplexFunc in;
            in(x, y) = ComplexExpr(input(x, y, 0), input(x, y, 1));
            if (output_number_type == FFTNumberType::Real &&
                direction == FFTDirection::FrequencyToSamples) {
                real_result = fft2d_c2r(in, size0, size1, target, desc);
            } else {
                complex_result = fft2d_c2c(in, size0, size1, sign, target, desc);
            }
        }

        if (output_number_type == FFTNumberType::Real) {
            if (real_result.defined()) {
                output(x, y, c) = real_result(x, y);
            } else {
                // A complex result was computed but a real output was
                // requested: keep only the real part.
                output(x, y, c) = re(complex_result(x, y));
            }
        } else {
            // Interleave the components along c: 0 -> real, otherwise imaginary.
            output(x, y, c) = select(c == 0,
                                     re(complex_result(x, y)),
                                     im(complex_result(x, y)));
        }
    }

    // Constrains the buffer layouts to match the chosen number types and
    // schedules the FFT result relative to the output.
    void schedule() {
        // Number of float components per value: 1 for real, 2 for complex.
        const int input_comps = (input_number_type == FFTNumberType::Real) ? 1 : 2;
        const int output_comps = (output_number_type == FFTNumberType::Real) ? 1 : 2;

        // Components are the innermost storage dimension (stride 1), so
        // neighbouring x positions are `comps` floats apart.
        input.dim(0).set_stride(input_comps)
             .dim(2).set_min(0).set_extent(input_comps).set_stride(1);

        output.dim(0).set_stride(output_comps)
              .dim(2).set_min(0).set_extent(output_comps).set_stride(1);

        if (output_comps != 1) {
            // Make c the innermost loop and unroll it over the components.
            output.reorder(c, x, y).unroll(c);
        }

        // Exactly one of real_result / complex_result is defined by
        // generate(); compute it at the outermost loop level of the output.
        if (real_result.defined()) {
            real_result.compute_at(output, Var::outermost());
        } else {
            assert(complex_result.defined());
            complex_result.compute_at(output, Var::outermost());
        }
    }
private:
    // Pure variables of the pipeline: x and y index the transform domain,
    // c indexes the real/imaginary component.
    Var x{"x"}, y{"y"}, c{"c"};
    // Result of the complex-to-real path; left undefined on all other paths.
    Func real_result;
    // Result of the real-to-complex and complex-to-complex paths.
    ComplexFunc complex_result;
};

// File-local registration object that associates FFTGenerator with the
// string "fft" (presumably the name by which the generator is selected when
// building -- confirm against the Halide generator registry).
Halide::RegisterGenerator<FFTGenerator> register_fft{"fft"};

}

/* [<][>][^][v][top][bottom][index][help] */