This source file includes the following definitions.
- generate
#include "Halide.h"
using namespace Halide;
using namespace Halide::ConciseCasts;
// Halide generator computing an unsigned 8-bit matrix product with 32-bit
// accumulation:
//
//     output(x, y) = sum over k of A(k, y) * B(x, k)
//
// The reduction only covers whole groups of (k_unroll_factor * 4) values along
// dimension 0 of A. generate() defines the algorithm and stores the scheduling
// step in `schedule`.
class MatMul : public Generator<MatMul> {
public:
    // Left operand: dimension 0 is the reduction axis, dimension 1 the output row.
    Input<Buffer<uint8_t>> A{"A", 2};
    // Right operand: dimension 0 is the output column, dimension 1 the reduction axis.
    Input<Buffer<uint8_t>> B{"B", 2};
    // Result: one unsigned 32-bit accumulator per (x, y).
    Output<Buffer<uint32_t>> output{"output", 2};

    // Scheduling step. generate() assigns this closure but never calls it;
    // presumably the Generator framework (or the caller) invokes it after
    // generate() -- TODO confirm.
    std::function<void()> schedule;

    // Defines the algorithm and installs the schedule closure.
    void generate() {
        Var x("x"), y("y");

        const int k_unroll_factor = 2;
        // The reduction is split by this factor directly in the algorithm:
        // each reduction step consumes this many consecutive k values.
        const int k_split_factor = 4;

        // Rounds `value` down to a multiple of `multiple` (integer division).
        const auto round_down = [](const Expr &value, int multiple) {
            return (value / multiple) * multiple;
        };

        // Reduce over whole groups of (k_unroll_factor * 4) elements only, so
        // the unrolled, split reduction loop never needs a tail case.
        const Expr k_extent = round_down(A.dim(0).extent(), k_unroll_factor * 4);
        RDom rk(0, k_extent / k_split_factor, "k");

        // B with its reduction axis split in two: B_swizzled(x, y, k) reads
        // B(x, 4*y + k). Kept as its own Func so it can be scheduled separately.
        Func B_swizzled("B_swizzled");
        Var k("k");
        B_swizzled(x, y, k) = B(x, k_split_factor * y + k);

        // One u8 x u8 product, formed in 16 bits and widened to 32 bits, for
        // the i-th element of the current group of four k values.
        const auto partial_product = [&](int i) {
            return u32(u16(A(k_split_factor * rk + i, y)) * u16(B_swizzled(x, rk, i)));
        };

        // Accumulator: zero-initialised, then four products added per step.
        Func AB("AB");
        AB(x, y) = u32(0);
        AB(x, y) += partial_product(0) + partial_product(1) +
                    partial_product(2) + partial_product(3);

        output(x, y) = AB(x, y);

        schedule = [=]() mutable {
            const Target &target = get_target();

            // Start from the target's natural u8 vector width; an HVX feature
            // overrides it, with HVX_64 taking precedence over HVX_128.
            const int native_u8_lanes = target.natural_vector_size<uint8_t>();
            const bool has_hvx_64 = target.has_feature(Target::HVX_64);
            const bool has_hvx_128 = target.has_feature(Target::HVX_128);
            const bool use_hexagon = has_hvx_64 || has_hvx_128;
            const int u8_lanes = has_hvx_64 ? 64 : (has_hvx_128 ? 128 : native_u8_lanes);
            // Four bytes per u32 element.
            const int u32_lanes = u8_lanes / 4;
            const int rows_per_tile = 4;

            if (use_hexagon) {
                Var x_outer("xo"), y_outer("yo");

                // Tile the output into (u8_lanes x rows_per_tile) blocks,
                // vectorised across x, unrolled across the tile's rows, and
                // parallelised over tile columns; prefetch A one yo iteration
                // ahead.
                output.compute_root()
                    .hexagon()
                    .tile(x, y, x_outer, y_outer, x, y, u8_lanes, rows_per_tile, TailStrategy::RoundUp)
                    .reorder(x, y, y_outer, x_outer)
                    .prefetch(A, y_outer, 1)
                    .vectorize(x)
                    .unroll(y)
                    .parallel(x_outer);

                // Accumulate one output tile at a time.
                AB.compute_at(output, y_outer)
                    .vectorize(x)
                    .unroll(y);

                // Reduction loop outermost, partially unrolled.
                AB.update(0)
                    .reorder(x, y, rk)
                    .vectorize(x)
                    .unroll(y)
                    .unroll(rk, k_unroll_factor);

                // Swizzle B once per tile column, storing the four k values
                // innermost.
                B_swizzled.compute_at(output, x_outer)
                    .reorder_storage(k, x, y)
                    .reorder(k, x, y)
                    .vectorize(x, u8_lanes, TailStrategy::RoundUp)
                    .unroll(k);
            } else {
                Var x_in("xi"), x_in2("xii"), y_in("yi"), y_in2("yii");
                RVar rk_in("rki");

                constexpr int kBlock = 32;
                const int kInnerBlockX = 8;

                output.compute_root()
                    .tile(x, y, x, y, x_in, y_in, u8_lanes, rows_per_tile, TailStrategy::RoundUp)
                    .reorder(x_in, y_in, x, y)
                    .vectorize(x_in)
                    .unroll(y_in)
                    .parallel(y);

                AB.compute_root()
                    .vectorize(x, u32_lanes);

                // Blocked update: kBlock x kBlock tiles of (x, y), with the
                // reduction split to match; inner x is vectorised.
                AB.update(0)
                    .split(x, x, x_in, kBlock, TailStrategy::GuardWithIf)
                    .split(x_in, x_in, x_in2, kInnerBlockX, TailStrategy::GuardWithIf)
                    .split(y, y, y_in, kBlock, TailStrategy::GuardWithIf)
                    .split(y_in, y_in, y_in2, 4, TailStrategy::GuardWithIf)
                    .split(rk, rk, rk_in, kBlock / k_split_factor, TailStrategy::GuardWithIf)
                    .reorder(x_in2, y_in2, x_in, rk_in, y_in, rk, x, y)
                    .parallel(y)
                    .vectorize(x_in2)
                    .unroll(x_in)
                    .unroll(y_in2);
            }

            // Layout constraints: every buffer starts at 0, and each extent or
            // stride is set to itself rounded down to a multiple of the
            // vector/tile size, i.e. it is declared to be such a multiple.
            A.dim(0)
                .set_bounds(0, round_down(k_extent, u8_lanes));
            A.dim(1)
                .set_bounds(0, round_down(A.dim(1).extent(), rows_per_tile))
                .set_stride(round_down(A.dim(1).stride(), u8_lanes));

            B.dim(0)
                .set_bounds(0, round_down(B.dim(0).extent(), u8_lanes));
            B.dim(1)
                .set_bounds(0, round_down(k_extent, u8_lanes))
                .set_stride(round_down(B.dim(1).stride(), u8_lanes));

            output.dim(0)
                .set_bounds(0, round_down(output.dim(0).extent(), u32_lanes));
            output.dim(1)
                .set_bounds(0, round_down(output.dim(1).extent(), rows_per_tile))
                .set_stride(round_down(output.dim(1).stride(), u32_lanes));
        };
    }
};
// Register the MatMul generator under the name "matmul".
HALIDE_REGISTER_GENERATOR(MatMul, "matmul");