This source file includes the following definitions.
- undef_z
- gcd
- lcm
- product
- A
- A
- get_func_refs
- mul
- dft2
- dft4
- dft6
- dft8
- dftN
- dft1d_c2c
- twiddle_factors
- fft_dim1
- transpose
- tiled_transpose
- fft2d_c2c
- fft2d_r2c
- fft2d_c2r
- radix_factor
- fft2d_c2c
- fft2d_r2c
- fft2d_c2r
#include "fft.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <map>
#include <ostream>
#include <string>
#include "funct.h"
using std::vector;
using std::string;
using namespace Halide;
using namespace Halide::BoundaryConditions;
namespace {
// Fallback definition of pi for toolchains whose <cmath> does not define M_PI.
#ifndef M_PI
#define M_PI 3.14159265358979310000
#endif
// Single-precision pi; used to build the angles passed to expj() in this file.
const float kPi = static_cast<float>(M_PI);
// Loop variable created by the split of n0 at the end of fft_dim1. Callers
// schedule their intermediate stages with compute_at(..., group).
Var group("g");
// Complex constant with components (0, 1); used below as the imaginary unit.
const ComplexExpr j(Expr(0), Expr(1));
// Returns a complex expression whose two components are both undef(t).
// t defaults to 32-bit float.
ComplexExpr undef_z(Type t = Float(32)) {
    Expr first = undef(t);
    Expr second = undef(t);
    return ComplexExpr(first, second);
}
// Greatest common divisor of x and y, computed by repeated remainder
// (Euclid's algorithm). gcd(x, 0) returns x unchanged.
int gcd(int x, int y) {
    if (y == 0) {
        return x;
    }
    for (;;) {
        const int remainder = x % y;
        if (remainder == 0) {
            // y divides x exactly, so y is the answer.
            return y;
        }
        x = y;
        y = remainder;
    }
}
// Least common multiple of x and y. The larger operand is divided by the gcd
// before multiplying by the smaller one. gcd(x, y) must be non-zero, i.e. x
// and y must not both be zero.
int lcm(int x, int y) {
    const int smaller = (y < x) ? y : x;
    const int larger = (x < y) ? y : x;
    const int divisor = gcd(x, y);
    return smaller * (larger / divisor);
}
// Returns the product of all elements of R; an empty vector yields 1.
int product(const vector<int> &R) {
    int result = 1;
    for (int factor : R) {
        result *= factor;
    }
    return result;
}
vector<Var> A(vector<Var> l, const vector<Var> &r) {
for (const Var& i : r) {
l.push_back(i);
}
return l;
}
template <typename T>
vector<Expr> A(vector<Expr> l, const vector<T> &r) {
for (const Var& i : r) {
l.push_back(i);
}
return l;
}
typedef FuncRefT<ComplexExpr> ComplexFuncRef;
// Returns N references into x that differ only in their first coordinate; the
// remaining coordinates are x's own trailing arguments.
//   temps == false: first coordinates 0, 1, ..., N-1.
//   temps == true:  first coordinates -1, -2, ..., -N, which the dft*
//                   helpers use as scratch locations.
vector<ComplexFuncRef> get_func_refs(ComplexFunc x, int N, bool temps = false) {
    // All of x's arguments except the first one.
    vector<Var> trailing(x.args());
    trailing.erase(trailing.begin());
    vector<ComplexFuncRef> refs;
    for (int i = 0; i < N; i++) {
        const int first_coord = temps ? (-i - 1) : i;
        refs.push_back(x(A({Expr(first_coord)}, trailing)));
    }
    return refs;
}
// Multiplies the complex expression a by the complex constant (re_b, im_b).
ComplexExpr mul(ComplexExpr a, float re_b, float im_b) {
    const ComplexExpr b(re_b, im_b);
    return a * b;
}
// Builds the 2-point DFT of f along its first dimension.
// The result, named prefix + "X2", starts out as undef and then has entries 0
// and 1 of its first dimension assigned the sum and the difference of f's
// entries 0 and 1.
ComplexFunc dft2(ComplexFunc f, const string& prefix) {
    const Type elem_type = f.output_types()[0];
    ComplexFunc F(prefix + "X2");
    F(f.args()) = undef_z(elem_type);
    vector<ComplexFuncRef> in = get_func_refs(f, 2);
    vector<ComplexFuncRef> out = get_func_refs(F, 2);
    out[0] = in[0] + in[1];
    out[1] = in[0] - in[1];
    return F;
}
// Builds the 4-point DFT of f along its first dimension.
//   f      - input; entries 0..3 of its first dimension are read.
//   sign   - multiplies the j factor of the odd-indexed outputs.
//   prefix - prepended to the name of the returned Func ("X4").
// The statements below must stay in this order: the temporaries share
// storage, and each assignment to an X[] or T[] entry is a further definition
// of F.
ComplexFunc dft4(ComplexFunc f, int sign, const string& prefix) {
Type type = f.output_types()[0];
ComplexFunc F(prefix + "X4");
// Start from undef; the entries are filled in by the assignments below.
F(f.args()) = undef_z(type);
vector<ComplexFuncRef> x = get_func_refs(f, 4);
vector<ComplexFuncRef> X = get_func_refs(F, 4);
// Two scratch locations of F, at first coordinates -1 and -2.
vector<ComplexFuncRef> T = get_func_refs(F, 2, true);
// Only two scratch locations exist: T[2] is a copy of the reference T[1] and
// T[3] a copy of T[0], so T[0]/T[3] and T[1]/T[2] name the same location.
T.push_back(T[1]);
T.push_back(T[0]);
// Even-indexed outputs from the pairwise sums.
T[0] = (x[0] + x[2]);
T[2] = (x[1] + x[3]);
X[0] = (T[0] + T[2]);
X[2] = (T[0] - T[2]);
// Odd-indexed outputs from the pairwise differences. These writes reuse the
// scratch locations consumed above (T[1] overwrites T[2], T[3] overwrites T[0]).
T[1] = (x[0] - x[2]);
T[3] = (x[1] - x[3]) * j * sign;
X[1] = (T[1] + T[3]);
X[3] = (T[1] - T[3]);
return F;
}
// Builds the 6-point DFT of f along its first dimension.
//   f      - input; entries 0..5 of its first dimension are read.
//   sign   - sign applied to the imaginary part of the third roots of unity.
//   prefix - prepended to the name of the returned Func ("X6").
// The result starts out as undef; each assignment to a T[] or X[] entry below
// is a further definition of F, so the statement order matters.
ComplexFunc dft6(ComplexFunc f, int sign, const string& prefix) {
    // Third roots of unity: (-1/2, +/- sqrt(3)/2).
    const float re_W1_3 = -0.5f;
    const float im_W1_3 = sign*0.866025404f;
    ComplexExpr W1_3(re_W1_3, im_W1_3);
    ComplexExpr W2_3(re_W1_3, -im_W1_3);
    // Four thirds of a turn is the same angle as one third.
    ComplexExpr W4_3 = W1_3;
    Type type = f.output_types()[0];
    // Named "X6" to follow the X<N> convention of dft2/dft4/dft8 (this was
    // previously "X8", duplicating the name used by dft8).
    ComplexFunc F(prefix + "X6");
    F(f.args()) = undef_z(type);
    vector<ComplexFuncRef> x = get_func_refs(f, 6);
    vector<ComplexFuncRef> X = get_func_refs(F, 6);
    // Six scratch locations of F, at first coordinates -1 .. -6.
    vector<ComplexFuncRef> T = get_func_refs(F, 6, true);
    // Three 2-point butterflies on the input pairs (k, k + 3).
    T[0] = (x[0] + x[3]);
    T[3] = (x[0] - x[3]);
    T[1] = (x[1] + x[4]);
    T[4] = (x[1] - x[4]);
    T[2] = (x[2] + x[5]);
    T[5] = (x[2] - x[5]);
    // 3-point combination of the sums gives the even-indexed outputs.
    X[0] = T[0] + T[2] + T[1];
    X[4] = T[0] + T[2]*W1_3 + T[1]*W2_3;
    X[2] = T[0] + T[2]*W2_3 + T[1]*W4_3;
    // 3-point combination of the differences gives the odd-indexed outputs.
    X[3] = T[3] + T[5] - T[4];
    X[1] = T[3] + T[5]*W1_3 - T[4]*W2_3;
    X[5] = T[3] + T[5]*W2_3 - T[4]*W4_3;
    return F;
}
// Builds the 8-point DFT of f along its first dimension.
//   f      - input; entries 0..7 of its first dimension are read.
//   sign   - multiplies the j factors and the imaginary part of the
//            (+/- sqrt(2)/2) rotation constants.
//   prefix - prepended to the name of the returned Func ("X8").
// The output entries X[] are used as scratch during the first half of the
// computation and are overwritten with the final values at the end, so the
// statement order below must be preserved.
ComplexFunc dft8(ComplexFunc f, int sign, const string& prefix) {
const float sqrt2_2 = 0.70710678f;
Type type = f.output_types()[0];
ComplexFunc F(prefix + "X8");
// Start from undef; the entries are filled in by the assignments below.
F(f.args()) = undef_z(type);
vector<ComplexFuncRef> x = get_func_refs(f, 8);
vector<ComplexFuncRef> X = get_func_refs(F, 8);
// Eight scratch locations of F, at first coordinates -1 .. -8.
vector<ComplexFuncRef> T = get_func_refs(F, 8, true);
// 4-point stage over the even-indexed inputs x[0], x[2], x[4], x[6];
// results land in T[0..3].
X[0] = (x[0] + x[4]);
X[2] = (x[2] + x[6]);
T[0] = (X[0] + X[2]);
T[2] = (X[0] - X[2]);
X[1] = (x[0] - x[4]);
X[3] = (x[2] - x[6]) * j * sign;
T[1] = (X[1] + X[3]);
T[3] = (X[1] - X[3]);
// 4-point stage over the odd-indexed inputs x[1], x[3], x[5], x[7], with the
// per-output rotation already applied; results land in T[4..7].
X[4] = (x[1] + x[5]);
X[6] = (x[3] + x[7]);
T[4] = (X[4] + X[6]);
T[6] = (X[4] - X[6]) * j * sign;
X[5] = (x[1] - x[5]);
X[7] = (x[3] - x[7]) * j * sign;
T[5] = mul(X[5] + X[7], sqrt2_2, sign * sqrt2_2);
T[7] = mul(X[5] - X[7], -sqrt2_2, sign * sqrt2_2);
// Final butterflies combining the two halves; these overwrite the scratch
// values previously stored in X[].
X[0] = (T[0] + T[4]);
X[1] = (T[1] + T[5]);
X[2] = (T[2] + T[6]);
X[3] = (T[3] + T[7]);
X[4] = (T[0] - T[4]);
X[5] = (T[1] - T[5]);
X[6] = (T[2] - T[6]);
X[7] = (T[3] - T[7]);
return F;
}
// Builds an N-point DFT of x along its first dimension directly from the
// definition: output n is the sum over k of expj(sign*2*pi*k*n/N) * x(k).
//   N < 10  - the sum is written out term by term.
//   N >= 10 - the sum is expressed as a reduction over an RDom.
// The output dimension n is unrolled. The returned Func is named
// prefix + "XN".
ComplexFunc dftN(ComplexFunc x, int N, int sign, const string& prefix) {
    // All of x's arguments except the first one.
    vector<Var> trailing(x.args());
    trailing.erase(trailing.begin());
    Var n("n");
    ComplexFunc X(prefix + "XN");
    if (N >= 10) {
        RDom k(0, N);
        X(A({n}, trailing)) = sum(expj((sign*2*kPi*k*n)/N) * x(A({k}, trailing)));
    } else {
        // The k == 0 term has a unit factor, so it seeds the accumulator.
        ComplexExpr acc = x(A({Expr(0)}, trailing));
        for (int k = 1; k < N; k++) {
            acc += expj((sign*2*kPi*k*n)/N) * x(A({Expr(k)}, trailing));
        }
        X(A({n}, trailing)) = acc;
    }
    X.unroll(n);
    return X;
}
// Selects the N-point DFT builder for x: N == 1 returns x unchanged,
// N == 2, 4, 6 and 8 use the dedicated dft2/dft4/dft6/dft8 builders, and every
// other N falls back to dftN.
ComplexFunc dft1d_c2c(ComplexFunc x, int N, int sign,
                      const string& prefix) {
    if (N == 1) {
        return x;
    }
    if (N == 2) {
        return dft2(x, prefix);
    }
    if (N == 4) {
        return dft4(x, sign, prefix);
    }
    if (N == 6) {
        return dft6(x, sign, prefix);
    }
    if (N == 8) {
        return dft8(x, sign, prefix);
    }
    return dftN(x, N, sign, prefix);
}
typedef std::map<int, ComplexFunc> TwiddleFactorSet;
// Returns a Func W with W(n) = expj(sign * 2 * pi * n / N) * gain, scheduled
// compute_root.
// When gain is the constant one, the cache entry for N is used instead of a
// fresh Func (operator[] creates the entry if it is absent); if that entry is
// already defined it is returned as-is.
// NOTE(review): whether defining W below also populates the cache entry
// depends on ComplexFunc copies sharing their underlying definition - confirm
// in funct.h.
ComplexFunc twiddle_factors(int N, Expr gain, int sign,
                            const string& prefix,
                            TwiddleFactorSet* cache) {
    ComplexFunc W(prefix + "W");
    if (is_one(gain)) {
        TwiddleFactorSet& entries = *cache;
        W = entries[N];
    }
    if (W.defined()) {
        return W;
    }
    Var n("n");
    W(n) = expj((sign * 2 * kPi * n) / N) * gain;
    W.compute_root();
    return W;
}
// Builds the pipeline that computes an N-point DFT along dimension 1 of x,
// where N = product(NR), as one pass per radix in NR.
//   x             - input; its first two arguments are used as n0 and n1.
//   NR            - radices of the passes, in the order they are applied.
//   sign          - sign of the exponent in the DFTs and twiddle factors.
//   extent_0      - upper limit on the vector width used along n0.
//   gain          - factor multiplied into the result exactly once.
//   parallel      - if true, the outer "group" loop of the last pass is
//                   parallelized.
//   prefix        - prepended to the names of the generated Funcs.
//   target        - queried for the natural vector size.
//   twiddle_cache - passed through to twiddle_factors.
// Returns the Func of the last pass, bounded to [0, N) along n1.
// NOTE(review): r_ and s_ are only assigned inside the loop, so NR is
// presumably expected to be non-empty - confirm with callers.
ComplexFunc fft_dim1(ComplexFunc x,
const vector<int>& NR,
int sign,
int extent_0,
Expr gain,
bool parallel,
const string& prefix,
const Target& target,
TwiddleFactorSet* twiddle_cache) {
int N = product(NR);
vector<Var> args = x.args();
Var n0(args[0]), n1(args[1]);
// From here on, args holds only the dimensions after the first two.
args.erase(args.begin());
args.erase(args.begin());
// Every pass and its reduction domain, remembered for scheduling after the loop.
vector<std::pair<Func, RDom>> stages;
// Reduction variables of the most recent pass.
RVar r_, s_;
// S is the product of the radices of the passes completed so far.
int S = 1;
int vector_width = 1;
for (size_t i = 0; i < NR.size(); i++) {
int R = NR[i];
// Name the pass after its stride and radix. S == N / R holds for the pass
// that completes the transform, which gets the "fft1" tag.
std::stringstream stage_id;
stage_id << prefix;
if (S == N / R) {
stage_id << "fft1";
} else {
stage_id << "x";
}
stage_id << "_S" << S << "_R" << R << "_" << n1.name();
ComplexFunc exchange(stage_id.str());
Var r("r"), s("s");
// v(r, s, n0, ...) gathers the R inputs of sub-transform s, which lie
// N / R apart along dimension 1.
ComplexFunc v("v_" + stage_id.str());
ComplexExpr x_rs = x(A({n0, s + r * (N / R)}, args));
if (S > 1) {
// Passes after the first multiply by twiddle factors. The r == 0 input has
// no twiddle factor, so it is only multiplied by the gain.
x_rs = cast<float>(x_rs);
ComplexFunc W = twiddle_factors(R * S, gain, sign, prefix, twiddle_cache);
v(A({r, s, n0}, args)) = select(r > 0, likely(x_rs * W(r * (s % S))), x_rs * gain);
// The gain has now been applied; reset it so it is not applied again.
gain = 1.0f;
} else {
v(A({r, s, n0}, args)) = x_rs;
}
vector_width = lcm(vector_width, target.natural_vector_size(v.output_types()[0]));
// R-point DFT over the first dimension (r) of v.
ComplexFunc V = dft1d_c2c(v, R, sign, prefix);
// The pass output starts as undef and is filled by one update over (r_, s_).
exchange(A({n0, n1}, args)) = undef_z(V.output_types()[0]);
RDom rs(0, R, 0, N / R);
r_ = rs.x;
s_ = rs.y;
ComplexExpr V_rs = V(A({r_, s_, n0}, args));
if (S == N / R) {
// Final pass: apply the gain here if no earlier pass has consumed it.
V_rs = V_rs * gain;
gain = 1.0f;
}
// Scatter output r_ of sub-transform s_ to its position for the next pass.
exchange(A({n0, ((s_ / S) * R * S) + (s_ % S) + (r_ * S)}, args)) = V_rs;
exchange.bound(n1, 0, N);
if (S > 1) {
v.compute_at(exchange, s_).unroll(r);
v.reorder_storage(n0, r, s);
}
V.compute_at(exchange, s_);
V.reorder_storage(V.args()[2], V.args()[0], V.args()[1]);
if (S == N / R) {
// Only the final pass vectorizes v and V along their n0 dimension.
v.vectorize(n0);
V.vectorize(V.args()[2]);
// NOTE(review): this i shadows the outer loop index; the outer i is not
// used after this point in the loop body.
for (int i = 0; i < V.num_update_definitions(); i++) {
V.update(i).vectorize(V.args()[2]);
}
}
exchange.update().unroll(r_);
stages.push_back({ exchange, rs });
// The output of this pass is the input of the next one.
x = exchange;
S *= R;
}
// Schedule the last pass: split n0 into groups of vector_width (limited by
// extent_0) and vectorize within each group.
vector_width = std::min(vector_width, extent_0);
x.update()
.split(n0, group, n0, vector_width)
.reorder(n0, r_, s_, group)
.vectorize(n0);
if (parallel) {
x.update().parallel(group);
}
// Every pass except the last is computed inside the last pass's group loop.
for (size_t i = 0; i + 1 < stages.size(); i++) {
Func stage = stages[i].first;
stage.compute_at(x, group).update().vectorize(n0);
}
return x;
}
// Returns a new, unnamed Func equal to f with its first two arguments
// exchanged: result(a1, a0, rest...) = f(a0, a1, rest...).
template <typename FuncType>
FuncType transpose(FuncType f) {
    vector<Halide::Var> swapped(f.args());
    std::swap(swapped[0], swapped[1]);
    FuncType result;
    result(swapped) = f(f.args());
    return result;
}
// Transposes the first two dimensions of f, optionally through square tiles.
//   f             - Func to transpose.
//   max_tile_size - upper limit on the tile edge length.
//   target        - tiling is used only on ARM targets unless always_tile is
//                   set; also queried for the natural vector size.
//   prefix        - prepended to the names of the generated Funcs.
//   always_tile   - force the tiled path regardless of the target.
// Returns {transposed Func, tiled intermediate}. On the untiled path the second
// element is a default-constructed FuncType; callers test it with defined()
// before scheduling it.
template <typename FuncType>
std::pair<FuncType, FuncType> tiled_transpose(FuncType f, int max_tile_size,
const Target& target,
const string& prefix,
bool always_tile = false) {
if (target.arch != Target::ARM && !always_tile) {
return { transpose(f), FuncType() };
}
// Tile edge: the natural vector size of f's type, limited by max_tile_size.
const int tile_size =
std::min(max_tile_size, target.natural_vector_size(f.output_types()[0]));
vector<Var> args = f.args();
Var x(args[0]), y(args[1]);
// From here on, args holds only the dimensions after the first two.
args.erase(args.begin());
args.erase(args.begin());
Var xo(x.name() + "o");
Var yo(y.name() + "o");
// Break f into tiles: (x, y) index within a tile, (xo, yo) select the tile.
FuncType f_tiled(prefix + "tiled");
f_tiled(A({x, y, xo, yo}, args)) = f(A({xo * tile_size + x, yo * tile_size + y}, args));
// Transpose within each tile.
FuncType f_tiledT(prefix + "tiledT");
f_tiledT(A({y, x, xo, yo}, args)) = f_tiled(A({x, y, xo, yo}, args));
// Transpose the arrangement of the tiles.
FuncType fT_tiled(prefix + "T_tiled");
fT_tiled(A({y, x, yo, xo}, args)) = f_tiledT(A({y, x, xo, yo}, args));
// Reassemble the tiles into a Func with only the two transposed dimensions.
FuncType fT(prefix + "T");
fT(A({y, x}, args)) = fT_tiled(A({y % tile_size, x % tile_size, y / tile_size, x / tile_size}, args));
f_tiledT
.vectorize(x, tile_size)
.unroll(y, tile_size);
return { fT, f_tiledT };
}
}
// Builds a 2D complex-to-complex DFT of x over its first two dimensions.
//   x      - input.
//   R0, R1 - radices used for dimension 0 and dimension 1; the transform
//            sizes are N0 = product(R0) and N1 = product(R1).
//   sign   - sign of the exponent, passed to both 1D passes.
//   target - target the schedule is built for.
//   desc   - name (Func name prefix; "c2c_" when empty), gain, parallel and
//            schedule_input options.
// Returns the result Func, bounded to [0, N0) x [0, N1).
ComplexFunc fft2d_c2c(ComplexFunc x,
vector<int> R0,
vector<int> R1,
int sign,
const Target& target,
const Fft2dDesc& desc) {
string prefix = desc.name.empty() ? "c2c_" : desc.name + "_";
int N0 = product(R0);
int N1 = product(R1);
// Intermediate stages are computed per iteration of the third dimension of
// x, or at the outermost level when x has only two dimensions.
Var outer = Var::outermost();
if (x.dimensions() > 2) {
outer = x.args()[2];
}
Var n0 = x.args()[0];
Var n1 = x.args()[1];
// Shared by both 1D passes, which use the same sign.
TwiddleFactorSet twiddle_cache;
// fft_dim1 transforms dimension 1, so transpose first to transform what is
// dimension 0 of x.
ComplexFunc xT, x_tiled;
std::tie(xT, x_tiled) = tiled_transpose(x, N1, target, prefix);
ComplexFunc dft1T = fft_dim1(xT,
R0,
sign,
N1,
1.0f,
desc.parallel,
prefix,
target,
&twiddle_cache);
// Transpose back and transform along dimension 1. The requested gain is
// applied in this second pass only.
ComplexFunc dft1, dft1_tiled;
std::tie(dft1, dft1_tiled) = tiled_transpose(dft1T, N0, target, prefix);
ComplexFunc dft = fft_dim1(dft1,
R1,
sign,
N0,
desc.gain,
desc.parallel,
prefix,
target,
&twiddle_cache);
// Schedule the transposes; the tiled stages exist only when tiled_transpose
// took its tiled path.
if (dft1_tiled.defined()) {
dft1_tiled.compute_at(dft, group);
} else {
xT.compute_at(dft, outer).vectorize(n0).unroll(n1);
}
if (x_tiled.defined()) {
x_tiled.compute_at(dft1T, group);
}
// Optionally schedule the caller's input inside the first pass.
if (desc.schedule_input) {
x.compute_at(dft1T, group);
}
dft1T.compute_at(dft, outer);
dft.bound(dft.args()[0], 0, N0);
dft.bound(dft.args()[1], 0, N1);
return dft;
}
// Builds a 2D real-to-complex DFT of r over its first two dimensions.
//   r      - real input.
//   R0, R1 - radices used for dimension 0 and dimension 1; the transform
//            sizes are N0 = product(R0) and N1 = product(R1).
//   target - target the schedule is built for.
//   desc   - name (Func name prefix; "r2c_" when empty), gain, vector_width,
//            parallel and schedule_input options.
// Pairs of real columns are packed into single complex columns so that one
// complex transform handles two real ones; the results are separated again
// afterwards. Both 1D passes use sign -1.
// Returns the result Func, bounded to [0, N0) x [0, (N1 + 1) / 2 + 1).
ComplexFunc fft2d_r2c(Func r,
const vector<int> &R0,
const vector<int> &R1,
const Target& target,
const Fft2dDesc& desc) {
string prefix = desc.name.empty() ? "r2c_" : desc.name + "_";
vector<Var> args(r.args());
Var n0(args[0]), n1(args[1]);
// From here on, args holds only the dimensions after the first two.
args.erase(args.begin());
args.erase(args.begin());
// Intermediate stages are computed per iteration of the third dimension of
// r, or at the outermost level when r has only two dimensions.
Var outer = Var::outermost();
if (!args.empty()) {
outer = args.front();
}
int N0 = product(R0);
int N1 = product(R1);
TwiddleFactorSet twiddle_cache;
Expr gain = desc.gain;
// Pack real columns zip_n0 and zip_n0 + zip_width into the two components of
// one complex column. Columns are paired in interleaved groups of zip_width.
ComplexFunc zipped(prefix + "zipped");
int zip_width = desc.vector_width;
if (zip_width <= 0) {
zip_width = target.natural_vector_size(r.output_types()[0]);
}
// Reduce zip_width to a divisor of N0 / 2.
zip_width = gcd(zip_width, N0 / 2);
Expr zip_n0 = (n0 / zip_width) * zip_width * 2 + (n0 % zip_width);
zipped(A({n0, n1}, args)) =
ComplexExpr(r(A({zip_n0, n1}, args)),
r(A({zip_n0 + zip_width, n1}, args)));
// Transform the packed columns along dimension 1, with no gain.
ComplexFunc dft1 = fft_dim1(zipped,
R1,
-1,
std::min(zip_width, N0 / 2),
1.0f,
false,
prefix,
target,
&twiddle_cache);
// Separate the two transforms that were packed into each column, using the
// value at n1 and the conjugate of the value at the mirrored index.
ComplexFunc unzipped(prefix + "unzipped"); {
Expr unzip_n0 = (n0 / (zip_width * 2)) * zip_width + (n0 % zip_width);
ComplexExpr Z = dft1(A({unzip_n0, n1}, args));
ComplexExpr conjsymZ = conj(dft1(A({unzip_n0, (N1 - n1) % N1}, args)));
ComplexExpr X = Z + conjsymZ;
ComplexExpr Y = -j * (Z - conjsymZ);
// X and Y above are twice the separated values; the factor of 1/2 is folded
// into the gain rather than applied here.
gain /= 2;
unzipped(A({n0, n1}, args)) =
select(n0 % (zip_width * 2) < zip_width, X, Y);
}
// Row 0 is replaced by a packed row holding the real parts of rows 0 and
// N1 / 2 of unzipped; all other rows pass through.
ComplexFunc zipped_0(prefix + "zipped_0");
zipped_0(A({n0, n1}, args)) =
select(n1 > 0, likely(unzipped(A({n0, n1}, args))),
ComplexExpr(re(unzipped(A({n0, 0}, args))),
re(unzipped(A({n0, N1 / 2}, args)))));
// Limit on the vector width used along dimension 0 of the transposed data.
int zipped_extent0 = std::min((N1 + 1) / 2, zip_width);
// Transpose, transform along what was dimension 0 (applying the gain), and
// transpose back.
ComplexFunc unzippedT, unzippedT_tiled;
std::tie(unzippedT, unzippedT_tiled) = tiled_transpose(zipped_0, zipped_extent0, target, prefix);
ComplexFunc dftT = fft_dim1(unzippedT,
R0,
-1,
zipped_extent0,
gain,
desc.parallel,
prefix,
target,
&twiddle_cache);
ComplexFunc dft = transpose(dftT);
// Outside n1 in [0, N1 / 2) the pure definition is undef; row N1 / 2 is
// written by the update definitions below.
dft = ComplexFunc(constant_exterior((Func)dft, Tuple(undef_z()), Expr(), Expr(), Expr(0), Expr(N1 / 2)));
// Unpack row 0 into rows 0 and N1 / 2. The six update definitions are
// numbered 0..5 in the order written; the schedule below refers to them by
// index, so this order must not change.
RDom n0z1(1, N0 / 2);
RDom n0z2(N0 / 2, N0 / 2);
// Updates 0-2: row N1 / 2 (single point, first half, mirrored second half).
dft(A({0, N1 / 2}, args)) = im(dft(A({0, 0}, args)));
dft(A({n0z1, N1 / 2}, args)) =
0.5f * -j * (dft(A({n0z1, 0}, args)) - conj(dft(A({N0 - n0z1, 0}, args))));
dft(A({n0z2, N1 / 2}, args)) = conj(dft(A({N0 - n0z2, N1 / 2}, args)));
// Updates 3-5: row 0 (single point, first half, mirrored second half).
dft(A({0, 0}, args)) = re(dft(A({0, 0}, args)));
dft(A({n0z1, 0}, args)) =
0.5f * (dft(A({n0z1, 0}, args)) + conj(dft(A({N0 - n0z1, 0}, args))));
dft(A({n0z2, 0}, args)) = conj(dft(A({N0 - n0z2, 0}, args)));
// Schedule.
dftT.compute_at(dft, outer);
if (unzippedT_tiled.defined()) {
unzippedT_tiled.compute_at(dftT, group);
}
if (desc.schedule_input) {
r.compute_at(dft1, group);
}
// Process unzipped in groups of 2 * zip_width columns: vectorize by
// zip_width and unroll the remaining factor.
Var n0o("n0o"), n0i("n0i");
unzipped.compute_at(dft, outer)
.split(n0, n0o, n0i, zip_width * 2)
.reorder(n0i, n1, n0o)
.vectorize(n0i, zip_width)
.unroll(n0i);
// dft1 is computed inside the n0o loop of unzipped, so parallelizing that
// loop covers dft1 as well.
dft1.compute_at(unzipped, n0o);
if (desc.parallel) {
unzipped.parallel(n0o);
}
dft.vectorize(n0, target.natural_vector_size<float>())
.unroll(n0, gcd(N0 / target.natural_vector_size<float>(), 4));
// The ranged updates (1, 2, 4, 5) are vectorized with race-condition checks
// disabled; the single-point updates 0 and 3 are left unscheduled.
dft.update(1).allow_race_conditions()
.vectorize(n0z1, target.natural_vector_size<float>());
dft.update(2).allow_race_conditions()
.vectorize(n0z2, target.natural_vector_size<float>());
dft.update(4).allow_race_conditions()
.vectorize(n0z1, target.natural_vector_size<float>());
dft.update(5).allow_race_conditions()
.vectorize(n0z2, target.natural_vector_size<float>());
dft.bound(n0, 0, N0);
dft.bound(n1, 0, (N1 + 1) / 2 + 1);
return dft;
}
// Builds a 2D complex-to-real DFT of c over its first two dimensions.
//   c      - complex input.
//   R0, R1 - radices used for dimension 0 and dimension 1; the transform
//            sizes are N0 = product(R0) and N1 = product(R1).
//   target - target the schedule is built for.
//   desc   - name (Func name prefix; "c2r_" when empty), gain, vector_width,
//            parallel and schedule_input options.
// Both 1D passes use sign +1. Returns a real Func bounded to
// [0, N0) x [0, N1).
Func fft2d_c2r(ComplexFunc c,
vector<int> R0,
vector<int> R1,
const Target& target,
const Fft2dDesc& desc) {
string prefix = desc.name.empty() ? "c2r_" : desc.name + "_";
vector<Var> args = c.args();
Var n0(args[0]), n1(args[1]);
// From here on, args holds only the dimensions after the first two.
args.erase(args.begin());
args.erase(args.begin());
// Intermediate stages are computed per iteration of the third dimension of
// c, or at the outermost level when c has only two dimensions.
Var outer = Var::outermost();
if (!args.empty()) {
outer = args.front();
}
int N0 = product(R0);
int N1 = product(R1);
TwiddleFactorSet twiddle_cache;
// Limit on the vector width used along dimension 0 of the transposed input.
int zipped_extent0 = (N1 + 1) / 2;
// Pack rows 0 and N1 / 2 of c into row 0 as X + j * Y; all other rows pass
// through.
ComplexFunc c_zipped(prefix + "c_zipped"); {
ComplexExpr X = c(A({n0, 0}, args));
ComplexExpr Y = c(A({n0, N1 / 2}, args));
c_zipped(A({n0, n1}, args)) = select(n1 > 0, likely(c(A({n0, n1}, args))), X + j * Y);
}
// Transpose and transform along what was dimension 0, with no gain.
ComplexFunc cT, cT_tiled;
std::tie(cT, cT_tiled) =
tiled_transpose(c_zipped, zipped_extent0, target, prefix);
ComplexFunc dft0T = fft_dim1(cT,
R0,
1,
zipped_extent0,
1.0f,
desc.parallel,
prefix,
target,
&twiddle_cache);
int zip_width = desc.vector_width;
if (zip_width <= 0) {
zip_width = target.natural_vector_size(dft0T.output_types()[0]);
}
// Transpose back. always_tile is set, so dft0_tiled is always defined and is
// scheduled unconditionally below.
ComplexFunc dft0, dft0_tiled;
std::tie(dft0, dft0_tiled) = tiled_transpose(dft0T, zip_width, target, prefix, true);
// Unpack row 0: rows n1 <= 0 take its real part, rows n1 >= N1 / 2 take its
// imaginary part, and the rows in between are read directly.
// NOTE(review): unlike the other Funcs here, this name does not include
// prefix - confirm whether that is intentional.
ComplexFunc dft0_unzipped("dft0_unzipped"); {
dft0_unzipped(A({n0, n1}, args)) =
select(n1 <= 0, re(dft0(A({n0, 0}, args))),
n1 >= N1 / 2, im(dft0(A({n0, 0}, args))),
likely(dft0(A({n0, min(n1, (N1 / 2) - 1)}, args))));
}
ComplexFunc dft0_bounded =
ComplexFunc(repeat_edge((Func)dft0_unzipped, Expr(0), Expr(N0), Expr(0), Expr((N1 + 1) / 2 + 1)));
// Reduce zip_width to a divisor of N0 / 2.
zip_width = gcd(zip_width, N0 / 2);
// Pack columns n0_X and n0_X + zip_width into one complex column X + j * Y.
// Rows at or beyond N1 / 2 are taken as the conjugate of the mirrored row.
ComplexFunc zipped(prefix + "zipped"); {
Expr n0_X = (n0 / zip_width) * zip_width * 2 + (n0 % zip_width);
Expr n1_sym = (N1 - n1) % N1;
ComplexExpr X = select(n1 < N1 / 2,
dft0_bounded(A({n0_X, n1}, args)),
conj(dft0_bounded(A({n0_X, n1_sym}, args))));
Expr n0_Y = n0_X + zip_width;
ComplexExpr Y = select(n1 < N1 / 2,
dft0_bounded(A({n0_Y, n1}, args)),
conj(dft0_bounded(A({n0_Y, n1_sym}, args))));
zipped(A({n0, n1}, args)) = X + j * Y;
}
// Transform along dimension 1; the requested gain is applied in this pass.
ComplexFunc dft = fft_dim1(zipped,
R1,
1,
std::min(zip_width, N0 / 2),
desc.gain,
desc.parallel,
prefix,
target,
&twiddle_cache);
ComplexFunc dft_padded = ComplexFunc(repeat_edge((Func)dft, Expr(), Expr(), Expr(0), Expr(N1)));
// Separate the packed columns: the real part gives one output column, the
// imaginary part the column zip_width further on.
Func unzipped(prefix + "unzipped"); {
Expr unzip_n0 = (n0 / (zip_width * 2)) * zip_width + (n0 % zip_width);
unzipped(A({n0, n1}, args)) =
select(n0 % (zip_width * 2) < zip_width,
re(dft_padded(A({unzip_n0, n1}, args))),
im(dft_padded(A({unzip_n0, n1}, args))));
}
// Schedule.
if (cT_tiled.defined()) {
cT_tiled.compute_at(dft0T, group);
}
dft0_tiled.compute_at(dft, outer);
// Note that the input is scheduled at (dft, outer) here, not inside the
// group loop of the first pass.
if (desc.schedule_input) {
c.compute_at(dft, outer);
}
dft0T.compute_at(dft, outer);
dft.compute_at(unzipped, outer);
unzipped.bound(n0, 0, N0);
unzipped.bound(n1, 0, N1);
unzipped
.vectorize(n0, zip_width)
.unroll(n0, gcd(N0 / zip_width, 4));
return unzipped;
}
namespace {
// Factors N into the radices used for the FFT passes.
// A few sizes (16, 32, 64, 128, 256) have hand-picked factorizations. Every
// other N is divided greedily by 8, 6, 4 and 2, in that order; whatever
// cannot be divided further (anything other than 1) is appended as a final
// factor. The product of the returned factors is always N, and the result is
// never empty (N == 1 yields {1}).
vector<int> radix_factor(int N) {
    // Zero is divisible by every radix and stays zero after division, so the
    // greedy loop below would never terminate. Return it as a single factor.
    if (N == 0) {
        return { 0 };
    }
    switch (N) {
    case 16: return { 4, 4 };
    case 32: return { 8, 4 };
    case 64: return { 8, 8 };
    case 128: return { 8, 4, 4 };
    case 256: return { 8, 8, 4 };
    }
    static const int radices[] = { 8, 6, 4, 2 };
    vector<int> R;
    for (int r : radices) {
        while (N % r == 0) {
            R.push_back(r);
            N /= r;
        }
    }
    // Keep any remainder that none of the radices divide; also guarantees a
    // non-empty result when nothing was divided out.
    if (N != 1 || R.empty()) {
        R.push_back(N);
    }
    return R;
}
}
// Convenience overload: factors the sizes N0 and N1 with radix_factor and
// forwards to the radix-vector overload of fft2d_c2c.
ComplexFunc fft2d_c2c(ComplexFunc x,
                      int N0, int N1,
                      int sign,
                      const Target& target,
                      const Fft2dDesc& desc) {
    const vector<int> radices0 = radix_factor(N0);
    const vector<int> radices1 = radix_factor(N1);
    return fft2d_c2c(x, radices0, radices1, sign, target, desc);
}
// Convenience overload: factors the sizes N0 and N1 with radix_factor and
// forwards to the radix-vector overload of fft2d_r2c.
ComplexFunc fft2d_r2c(Func r,
                      int N0, int N1,
                      const Target& target,
                      const Fft2dDesc& desc) {
    const vector<int> radices0 = radix_factor(N0);
    const vector<int> radices1 = radix_factor(N1);
    return fft2d_r2c(r, radices0, radices1, target, desc);
}
// Convenience overload: factors the sizes N0 and N1 with radix_factor and
// forwards to the radix-vector overload of fft2d_c2r.
Func fft2d_c2r(ComplexFunc c,
               int N0, int N1,
               const Target& target,
               const Fft2dDesc& desc) {
    const vector<int> radices0 = radix_factor(N0);
    const vector<int> radices1 = radix_factor(N1);
    return fft2d_c2r(c, radices0, radices1, target, desc);
}