This source file includes the following definitions.
- native_interleave
- native_deinterleave
- is_native_interleave_op
- is_native_interleave
- is_native_deinterleave
- bc
- with_lanes
- visit
- visit
- visit
- with_lanes
- flags
- apply_patterns
- lossless_negate
- apply_commutative_patterns
- visit
- halide_hexagon_add_2mpy
- halide_hexagon_add_2mpy
- halide_hexagon_add_4mpy
- unbroadcast_lossless_cast
- find_mpy_ops
- visit
- visit
- visit
- visit
- visit
- yields_removable_interleave
- yields_interleave
- yields_removable_interleave
- remove_interleave
- visit_binary
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- visit
- uses_var
- uses_var
- visit_let
- visit
- visit
- visit
- is_interleavable
- visit_bool_to_mask
- visit
- visit
- visit
- visit
- visit
- span_of_bounds
- visit_let
- visit
- visit
- visit
- optimize_hexagon_shuffles
- optimize_hexagon_instructions
#include "HexagonOptimize.h"
#include "ConciseCasts.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRMatch.h"
#include "IREquality.h"
#include "ExprUsesVar.h"
#include "CSE.h"
#include "Simplify.h"
#include "Substitute.h"
#include "Scope.h"
#include "Bounds.h"
#include "Lerp.h"
namespace Halide {
namespace Internal {
using std::set;
using std::vector;
using std::string;
using std::pair;
using namespace Halide::ConciseCasts;
// Wrap x in a pure-extern call to the halide.hexagon.interleave.* intrinsic
// selected by the element width of x (8, 16 or 32 bits). Any other width
// raises an internal error.
Expr native_interleave(Expr x) {
    const Type operand_type = x.type();
    const int element_bits = operand_type.bits();
    string intrinsic_name;
    if (element_bits == 8) {
        intrinsic_name = "halide.hexagon.interleave.vb";
    } else if (element_bits == 16) {
        intrinsic_name = "halide.hexagon.interleave.vh";
    } else if (element_bits == 32) {
        intrinsic_name = "halide.hexagon.interleave.vw";
    } else {
        internal_error << "Cannot interleave native vectors of type " << operand_type << "\n";
    }
    return Call::make(operand_type, intrinsic_name, {x}, Call::PureExtern);
}
// Wrap x in a pure-extern call to the halide.hexagon.deinterleave.* intrinsic
// selected by the element width of x (8, 16 or 32 bits). Any other width
// raises an internal error.
Expr native_deinterleave(Expr x) {
    const Type operand_type = x.type();
    const int element_bits = operand_type.bits();
    string intrinsic_name;
    if (element_bits == 8) {
        intrinsic_name = "halide.hexagon.deinterleave.vb";
    } else if (element_bits == 16) {
        intrinsic_name = "halide.hexagon.deinterleave.vh";
    } else if (element_bits == 32) {
        intrinsic_name = "halide.hexagon.deinterleave.vw";
    } else {
        internal_error << "Cannot deinterleave native vectors of type " << operand_type << "\n";
    }
    return Call::make(operand_type, intrinsic_name, {x}, Call::PureExtern);
}
// True when x is a Call with exactly one argument whose name begins with the
// given prefix.
bool is_native_interleave_op(Expr x, const char *name) {
    if (const Call *call = x.as<Call>()) {
        return call->args.size() == 1 && starts_with(call->name, name);
    }
    return false;
}
// True when x is a single-argument call to one of the
// halide.hexagon.interleave.* intrinsics.
bool is_native_interleave(Expr x) {
    const char *const prefix = "halide.hexagon.interleave";
    return is_native_interleave_op(x, prefix);
}
// True when x is a single-argument call to one of the
// halide.hexagon.deinterleave.* intrinsics.
bool is_native_deinterleave(Expr x) {
    const char *const prefix = "halide.hexagon.deinterleave";
    return is_native_interleave_op(x, prefix);
}
namespace {
// Shorthand used by the pattern tables: broadcast x with a lane count of 0,
// the same lane count the wild_*x vector wildcards are declared with.
Expr bc(Expr x) {
    const int pattern_lanes = 0;
    return Broadcast::make(x, pattern_lanes);
}
// Mutator that rebuilds Cast, Variable and Broadcast nodes whose lane count
// differs from the requested one so that they carry `lanes` lanes instead.
// Nodes that already have the requested lane count are handled as usual.
class WithLanes : public IRMutator {
    using IRMutator::visit;

    // Lane count every rewritten node should end up with.
    int lanes;

    // Same type as t, but with the requested lane count.
    Type with_lanes(Type t) { return t.with_lanes(lanes); }

    void visit(const Cast *op) {
        if (op->type.lanes() == lanes) {
            IRMutator::visit(op);
            return;
        }
        expr = Cast::make(with_lanes(op->type), mutate(op->value));
    }

    void visit(const Variable *op) {
        if (op->type.lanes() == lanes) {
            expr = op;
            return;
        }
        expr = Variable::make(with_lanes(op->type), op->name);
    }

    void visit(const Broadcast *op) {
        if (op->type.lanes() == lanes) {
            IRMutator::visit(op);
            return;
        }
        // The broadcast value is reused as-is; only the lane count changes.
        expr = Broadcast::make(op->value, lanes);
    }

public:
    WithLanes(int lanes) : lanes(lanes) {}
};
// Run the WithLanes mutator over x with the requested lane count.
Expr with_lanes(Expr x, int lanes) {
    WithLanes relaner(lanes);
    return relaner.mutate(x);
}
// One rewrite rule consumed by apply_patterns: an expression pattern, the name
// of the intrinsic call that replaces a match, and a set of flags describing
// how the matched operands and the result must be adjusted.
struct Pattern {
// Flag bits. apply_patterns shifts several of these by the operand index, so
// the per-operand flags of each group must stay in consecutive bits.
enum Flags {
// Wrap the generated call in native_interleave.
InterleaveResult = 1 << 0,
// Swap matched operands 0 and 1 before building the call.
SwapOps01 = 1 << 1,
// Swap matched operands 1 and 2 before building the call.
SwapOps12 = 1 << 2,
// Operand N must be a constant power of two; it is replaced by its log2.
ExactLog2Op1 = 1 << 3,
ExactLog2Op2 = 1 << 4,
// Half-open operand index range [Begin, End) scanned for ExactLog2Op flags.
BeginExactLog2Op = 1,
EndExactLog2Op = 3,
// Operand N is wrapped in native_deinterleave (it must be a vector).
DeinterleaveOp0 = 1 << 5,
DeinterleaveOp1 = 1 << 6,
DeinterleaveOp2 = 1 << 7,
DeinterleaveOps = DeinterleaveOp0 | DeinterleaveOp1 | DeinterleaveOp2,
// Half-open operand index range [Begin, End) scanned for DeinterleaveOp flags.
BeginDeinterleaveOp = 0,
EndDeinterleaveOp = 3,
// Deinterleave operand 0 and interleave the result.
ReinterleaveOp0 = InterleaveResult | DeinterleaveOp0,
// Operand N must be losslessly castable to a type with half as many bits;
// the pattern is rejected otherwise.
NarrowOp0 = 1 << 10,
NarrowOp1 = 1 << 11,
NarrowOp2 = 1 << 12,
NarrowOp3 = 1 << 13,
NarrowOps = NarrowOp0 | NarrowOp1 | NarrowOp2 | NarrowOp3,
// As NarrowOpN, but the half-width target type is unsigned.
NarrowUnsignedOp0 = 1 << 15,
NarrowUnsignedOp1 = 1 << 16,
NarrowUnsignedOp2 = 1 << 17,
NarrowUnsignedOps = NarrowUnsignedOp0 | NarrowUnsignedOp1 | NarrowUnsignedOp2,
// Pattern is only considered when the target has the HVX_v62 feature.
v62 = 1 << 20,
};
// Name of the intrinsic called when the pattern matches.
string intrin;
// Expression to match against.
Expr pattern;
// Bitwise OR of Flags values.
int flags;
Pattern() {}
Pattern(const string &intrin, Expr p, int flags = 0)
: intrin(intrin), pattern(p), flags(flags) {}
};
// Wildcard variables (all named "*") used to build the pattern tables.
// The first group has scalar types of the given signedness and width.
Expr wild_u8 = Variable::make(UInt(8), "*");
Expr wild_u16 = Variable::make(UInt(16), "*");
Expr wild_u32 = Variable::make(UInt(32), "*");
Expr wild_u64 = Variable::make(UInt(64), "*");
Expr wild_i8 = Variable::make(Int(8), "*");
Expr wild_i16 = Variable::make(Int(16), "*");
Expr wild_i32 = Variable::make(Int(32), "*");
Expr wild_i64 = Variable::make(Int(64), "*");
// The "x" variants are declared with a lane count of 0.
// NOTE(review): presumably expr_match treats 0 lanes as "any vector width" - confirm in IRMatch.
Expr wild_u8x = Variable::make(Type(Type::UInt, 8, 0), "*");
Expr wild_u16x = Variable::make(Type(Type::UInt, 16, 0), "*");
Expr wild_u32x = Variable::make(Type(Type::UInt, 32, 0), "*");
Expr wild_u64x = Variable::make(Type(Type::UInt, 64, 0), "*");
Expr wild_i8x = Variable::make(Type(Type::Int, 8, 0), "*");
Expr wild_i16x = Variable::make(Type(Type::Int, 16, 0), "*");
Expr wild_i32x = Variable::make(Type(Type::Int, 32, 0), "*");
Expr wild_i64x = Variable::make(Type(Type::Int, 64, 0), "*");
// Try each pattern in order against x and rewrite x using the first pattern
// that applies. Returns x itself (the same node) when nothing applies, so
// callers can detect "no rewrite" with same_as.
//
// For a pattern that matches structurally, the matched operands are processed
// in this order:
//   1. NarrowOpN / NarrowUnsignedOpN: operand N is losslessly cast to a type
//      with half the bits; the pattern is rejected if that is not possible.
//   2. ExactLog2OpN: operand N must be a constant power of two and is replaced
//      by its log2 as a scalar; otherwise the pattern is rejected.
//   3. DeinterleaveOpN: operand N is wrapped in native_deinterleave.
//   4. SwapOps01 / SwapOps12: operands are swapped.
//   5. Every operand is mutated with op_mutator.
// The result is a PureExtern call to p.intrin, wrapped in native_interleave
// when InterleaveResult is set.
//
// Patterns flagged v62 are skipped unless the target has HVX_v62.
Expr apply_patterns(Expr x, const vector<Pattern> &patterns, const Target &target, IRMutator *op_mutator) {
    // Number of operand slots that actually have a narrowing flag. The shifts
    // below must not run past these: Pattern::NarrowOp0 << 5 is the same bit as
    // Pattern::NarrowUnsignedOp0, and Pattern::NarrowUnsignedOp0 << 5 is the
    // same bit as Pattern::v62, so an unbounded test would misread the flags of
    // a pattern with six or more operands.
    const size_t narrow_op_slots = 4;           // NarrowOp0 .. NarrowOp3
    const size_t narrow_unsigned_op_slots = 3;  // NarrowUnsignedOp0 .. NarrowUnsignedOp2

    debug(3) << "apply_patterns " << x << "\n";
    vector<Expr> matches;
    for (const Pattern &p : patterns) {
        if ((p.flags & (Pattern::v62)) && !target.has_feature(Target::HVX_v62)) {
            continue;
        }
        if (!expr_match(p.pattern, x, matches)) {
            continue;
        }

        debug(3) << "matched " << p.pattern << "\n";
        debug(3) << "matches:\n";
        for (const Expr &m : matches) {
            debug(3) << m << "\n";
        }

        // Step 1: narrow the operands that the pattern requires to be narrow.
        bool is_match = true;
        for (size_t i = 0; i < matches.size() && is_match; i++) {
            Type t = matches[i].type();
            Type target_t = t.with_bits(t.bits()/2);
            const bool narrow_signed =
                i < narrow_op_slots && (p.flags & (Pattern::NarrowOp0 << i));
            const bool narrow_unsigned =
                i < narrow_unsigned_op_slots && (p.flags & (Pattern::NarrowUnsignedOp0 << i));
            if (narrow_signed) {
                matches[i] = lossless_cast(target_t, matches[i]);
            } else if (narrow_unsigned) {
                matches[i] = lossless_cast(target_t.with_code(Type::UInt), matches[i]);
            }
            // lossless_cast yields an undefined Expr when the cast would lose
            // information; that rejects the pattern.
            if (!matches[i].defined()) is_match = false;
        }
        if (!is_match) continue;

        // Step 2: replace constant power-of-two operands by their log2.
        for (size_t i = Pattern::BeginExactLog2Op; i < Pattern::EndExactLog2Op && is_match; i++) {
            if (p.flags & (Pattern::ExactLog2Op1 << (i - Pattern::BeginExactLog2Op))) {
                int log2_value = 0;
                if (is_const_power_of_two_integer(matches[i], &log2_value)) {
                    matches[i] = cast(matches[i].type().with_lanes(1), log2_value);
                } else {
                    is_match = false;
                }
            }
        }
        if (!is_match) continue;

        // Step 3: deinterleave the operands the intrinsic expects deinterleaved.
        for (size_t i = Pattern::BeginDeinterleaveOp; i < Pattern::EndDeinterleaveOp; i++) {
            if (p.flags &
                (Pattern::DeinterleaveOp0 << (i - Pattern::BeginDeinterleaveOp))) {
                internal_assert(matches[i].type().is_vector());
                matches[i] = native_deinterleave(matches[i]);
            }
        }

        // Step 4: reorder operands to the order the intrinsic expects.
        if (p.flags & Pattern::SwapOps01) {
            internal_assert(matches.size() >= 2);
            std::swap(matches[0], matches[1]);
        }
        if (p.flags & Pattern::SwapOps12) {
            internal_assert(matches.size() >= 3);
            std::swap(matches[1], matches[2]);
        }

        // Step 5: recursively optimize the operands themselves.
        for (Expr &op : matches) {
            op = op_mutator->mutate(op);
        }

        x = Call::make(x.type(), p.intrin, matches, Call::PureExtern);
        if (p.flags & Pattern::InterleaveResult) {
            x = native_interleave(x);
        }
        debug(3) << "rewrote to: " << x << "\n";
        return x;
    }
    return x;
}
// Attempt to build an expression equal to -x. For a product, one factor is
// negated (the left factor is tried first). Otherwise x must satisfy
// is_negative_negatable_const or is_positive_const. Returns an undefined Expr
// when neither case applies.
Expr lossless_negate(Expr x) {
    if (const Mul *product = x.as<Mul>()) {
        Expr negated_lhs = lossless_negate(product->a);
        if (negated_lhs.defined()) {
            return Mul::make(negated_lhs, product->b);
        }
        Expr negated_rhs = lossless_negate(product->b);
        if (negated_rhs.defined()) {
            return Mul::make(product->a, negated_rhs);
        }
    }
    const bool negatable_const = is_negative_negatable_const(x) || is_positive_const(x);
    if (!negatable_const) {
        return Expr();
    }
    return simplify(-x);
}
// Apply the patterns to a commutative binary op: first with the operands as
// written, then with the operands swapped. Returns op unchanged when neither
// order produces a rewrite.
template <typename T>
Expr apply_commutative_patterns(const T *op, const vector<Pattern> &patterns, const Target &target, IRMutator *mutator) {
    const Expr as_written = op;
    Expr rewritten = apply_patterns(as_written, patterns, target, mutator);
    if (rewritten.same_as(as_written)) {
        const Expr swapped = T::make(op->b, op->a);
        rewritten = apply_patterns(swapped, patterns, target, mutator);
        if (rewritten.same_as(swapped)) {
            return as_written;
        }
    }
    return rewritten;
}
class OptimizePatterns : public IRMutator {
private:
using IRMutator::visit;
Target target;
// Rewrite vector multiplies into halide.hexagon.* intrinsic calls.
// The vector-by-broadcast-scalar table is tried first, then the
// vector-by-vector table; both go through apply_commutative_patterns so either
// operand order can match. Within a table the first applicable entry wins, so
// the order of the entries matters. Falls back to the default mutation when
// nothing applies or the multiply is not a vector.
void visit(const Mul *op) {
// Vector times broadcast scalar.
static const vector<Pattern> scalar_muls = {
// Multiplies of narrowed operands; the result is interleaved.
{ "halide.hexagon.mpy.vub.ub", wild_u16x*bc(wild_u16), Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.mpy.vub.b", wild_i16x*bc(wild_i16), Pattern::InterleaveResult | Pattern::NarrowUnsignedOp0 | Pattern::NarrowOp1 },
{ "halide.hexagon.mpy.vuh.uh", wild_u32x*bc(wild_u32), Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.mpy.vh.h", wild_i32x*bc(wild_i32), Pattern::InterleaveResult | Pattern::NarrowOps },
// Multiplication by a constant power of two becomes a shift left.
{ "halide.hexagon.shl.vub.ub", wild_u8x*bc(wild_u8), Pattern::ExactLog2Op1 },
{ "halide.hexagon.shl.vuh.uh", wild_u16x*bc(wild_u16), Pattern::ExactLog2Op1 },
{ "halide.hexagon.shl.vuw.uw", wild_u32x*bc(wild_u32), Pattern::ExactLog2Op1 },
{ "halide.hexagon.shl.vb.b", wild_i8x*bc(wild_i8), Pattern::ExactLog2Op1 },
{ "halide.hexagon.shl.vh.h", wild_i16x*bc(wild_i16), Pattern::ExactLog2Op1 },
{ "halide.hexagon.shl.vw.w", wild_i32x*bc(wild_i32), Pattern::ExactLog2Op1 },
// Only the scalar operand is narrowed.
{ "halide.hexagon.mul.vh.b", wild_i16x*bc(wild_i16), Pattern::NarrowOp1 },
{ "halide.hexagon.mul.vw.h", wild_i32x*bc(wild_i32), Pattern::NarrowOp1 },
};
// Vector times vector.
static const vector<Pattern> muls = {
{ "halide.hexagon.mpy.vub.vub", wild_u16x*wild_u16x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.mpy.vuh.vuh", wild_u32x*wild_u32x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.mpy.vb.vb", wild_i16x*wild_i16x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.mpy.vh.vh", wild_i32x*wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOps },
// Mixed signed/unsigned operands, in both operand orders.
{ "halide.hexagon.mpy.vub.vb", wild_i16x*wild_i16x, Pattern::InterleaveResult | Pattern::NarrowUnsignedOp0 | Pattern::NarrowOp1 },
{ "halide.hexagon.mpy.vh.vuh", wild_i32x*wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOp0 | Pattern::NarrowUnsignedOp1 },
{ "halide.hexagon.mpy.vub.vb", wild_i16x*wild_i16x, Pattern::InterleaveResult | Pattern::NarrowOp0 | Pattern::NarrowUnsignedOp1 | Pattern::SwapOps01 },
{ "halide.hexagon.mpy.vh.vuh", wild_i32x*wild_i32x, Pattern::InterleaveResult | Pattern::NarrowUnsignedOp0 | Pattern::NarrowOp1 | Pattern::SwapOps01 },
// Only operand 1 is narrowed; operand 0 is deinterleaved and the result interleaved.
{ "halide.hexagon.mul.vw.vh", wild_i32x*wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 },
{ "halide.hexagon.mul.vw.vuh", wild_i32x*wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 },
{ "halide.hexagon.mul.vuw.vuh", wild_u32x*wild_u32x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 },
};
if (op->type.is_vector()) {
Expr new_expr = apply_commutative_patterns(op, scalar_muls, target, this);
if (!new_expr.same_as(op)) {
expr = new_expr;
return;
}
new_expr = apply_commutative_patterns(op, muls, target, this);
if (!new_expr.same_as(op)) {
expr = new_expr;
return;
}
}
IRMutator::visit(op);
}
// Build a PureExtern call to "halide.hexagon.add_2mpy" + suffix taking two
// vector operands and two coefficient operands separately, and wrap the call
// in native_interleave.
static Expr halide_hexagon_add_2mpy(Type result_type, string suffix, Expr v0, Expr v1, Expr c0, Expr c1) {
    const string intrinsic = "halide.hexagon.add_2mpy" + suffix;
    Expr raw_call = Call::make(result_type, intrinsic, {v0, v1, c0, c1}, Call::PureExtern);
    return native_interleave(raw_call);
}
// Build a PureExtern call to "halide.hexagon.add_2mpy" + suffix taking one
// combined vector operand and one combined coefficient operand. The result is
// not interleaved.
static Expr halide_hexagon_add_2mpy(Type result_type, string suffix, Expr v01, Expr c01) {
    const string intrinsic = "halide.hexagon.add_2mpy" + suffix;
    return Call::make(result_type, intrinsic, {v01, c01}, Call::PureExtern);
}
// Build a PureExtern call to "halide.hexagon.add_4mpy" + suffix taking one
// combined vector operand and one combined coefficient operand.
static Expr halide_hexagon_add_4mpy(Type result_type, string suffix, Expr v01, Expr c01) {
    const string intrinsic = "halide.hexagon.add_4mpy" + suffix;
    return Call::make(result_type, intrinsic, {v01, c01}, Call::PureExtern);
}
typedef pair<Expr, Expr> MulExpr;
// Like lossless_cast, but when a scalar type is requested and x is a
// Broadcast, the broadcast value is cast instead. Returns an undefined Expr
// when the lane counts still disagree afterwards.
static Expr unbroadcast_lossless_cast(Type ty, Expr x) {
    const bool scalar_requested = ty.lanes() == 1;
    if (scalar_requested && x.type().lanes() > 1) {
        const Broadcast *splat = x.as<Broadcast>();
        if (splat) {
            x = splat->value;
        }
    }
    const bool lanes_agree = ty.lanes() == x.type().lanes();
    if (!lanes_agree) {
        return Expr();
    }
    return lossless_cast(ty, x);
}
// Walk a tree of adds (and subtracts of products with a negatable constant
// factor) collecting terms of the form a*b where a is losslessly castable to
// a_ty and b to b_ty.
//
//   op            expression to decompose
//   a_ty, b_ty    operand types each multiply must narrow to
//   max_mpy_count stop collecting once mpys holds this many entries
//   mpys          output: collected (a, b) operand pairs, already narrowed
//   rest          output: sum of every term that could not be collected
//                 (left undefined if there is none)
//
// Returns the number of genuine multiplies collected. A bare term that fits
// a_ty or b_ty is also added to mpys, paired with a constant one, but does not
// count towards the return value.
static int find_mpy_ops(Expr op, Type a_ty, Type b_ty, int max_mpy_count,
vector<MulExpr> &mpys, Expr &rest) {
// Already have as many multiplies as wanted: everything else goes to rest.
if ((int)mpys.size() >= max_mpy_count) {
rest = rest.defined() ? Add::make(rest, op) : op;
return 0;
}
// Width of the product of the two narrow operands.
int mpy_bits = std::max(a_ty.bits(), b_ty.bits())*2;
Expr maybe_mul = op;
// Look through one cast from the product width to twice the product width.
if (op.type().bits() == mpy_bits*2) {
if (const Cast *cast = op.as<Cast>()) {
if (cast->value.type().bits() == mpy_bits) {
maybe_mul = cast->value;
}
}
}
if (const Mul *mul = maybe_mul.as<Mul>()) {
// Try the operands as written, then swapped. If neither order narrows,
// control falls through to the bare-term handling at the bottom.
Expr a = unbroadcast_lossless_cast(a_ty, mul->a);
Expr b = unbroadcast_lossless_cast(b_ty, mul->b);
if (a.defined() && b.defined()) {
mpys.emplace_back(a, b);
return 1;
} else {
a = unbroadcast_lossless_cast(a_ty, mul->b);
b = unbroadcast_lossless_cast(b_ty, mul->a);
if (a.defined() && b.defined()) {
mpys.emplace_back(a, b);
return 1;
}
}
} else if (const Add *add = op.as<Add>()) {
// Recurse into both sides of an add.
int mpy_count = 0;
mpy_count += find_mpy_ops(add->a, a_ty, b_ty, max_mpy_count, mpys, rest);
mpy_count += find_mpy_ops(add->b, a_ty, b_ty, max_mpy_count, mpys, rest);
return mpy_count;
} else if (const Sub *sub = op.as<Sub>()) {
// a - k*b is treated as a + (-k)*b when one factor of the product is a
// constant that can be negated.
if (const Mul *mul_b = sub->b.as<Mul>()) {
if (is_positive_const(mul_b->a) || is_negative_negatable_const(mul_b->a)) {
Expr add_b = Mul::make(simplify(-mul_b->a), mul_b->b);
int mpy_count = 0;
mpy_count += find_mpy_ops(sub->a, a_ty, b_ty, max_mpy_count, mpys, rest);
mpy_count += find_mpy_ops(add_b, a_ty, b_ty, max_mpy_count, mpys, rest);
return mpy_count;
} else if (is_positive_const(mul_b->b) || is_negative_negatable_const(mul_b->b)) {
Expr add_b = Mul::make(mul_b->a, simplify(-mul_b->b));
int mpy_count = 0;
mpy_count += find_mpy_ops(sub->a, a_ty, b_ty, max_mpy_count, mpys, rest);
mpy_count += find_mpy_ops(add_b, a_ty, b_ty, max_mpy_count, mpys, rest);
return mpy_count;
}
}
}
// Not a usable multiply: a term that narrows to a_ty or b_ty is recorded as a
// multiply by one, anything else is accumulated into rest.
Expr as_a = unbroadcast_lossless_cast(a_ty, op);
Expr as_b = unbroadcast_lossless_cast(b_ty, op);
if (as_a.defined()) {
mpys.emplace_back(as_a, make_one(b_ty));
} else if (as_b.defined()) {
mpys.emplace_back(make_one(a_ty), as_b);
} else {
rest = rest.defined() ? Add::make(rest, op) : op;
}
return 0;
}
// Rewrite vector adds into halide.hexagon.* intrinsic calls. Three strategies
// are tried in order, and the first that succeeds wins:
//   1. gather four 8-bit multiplies into an add_4mpy call,
//   2. gather two multiplies into an add_2mpy call,
//   3. match the add against the `adds` pattern table.
// Falls back to the default mutation when none applies.
void visit(const Add *op) {
// Strategy 1: four-term multiply-add, for 16- and 32-bit vector results.
if (op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
int lanes = op->type.lanes();
vector<MulExpr> mpys;
Expr rest;
string suffix;
int mpy_count = 0;
// First attempt: vector operands multiplied by scalar coefficients.
if (op->type.is_uint()) {
mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8), 4, mpys, rest);
suffix = ".vub.ub";
} else {
mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 4, mpys, rest);
suffix = ".vub.b";
}
// Require at least one genuine multiply and exactly four collected terms.
if (mpy_count > 0 && mpys.size() == 4) {
Expr a0123 = Shuffle::make_interleave({mpys[0].first, mpys[1].first, mpys[2].first, mpys[3].first});
a0123 = simplify(a0123);
// For 16-bit results, only proceed if the simplifier got rid of the interleave.
if (op->type.bits() == 32 || !a0123.as<Shuffle>()) {
Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[2].second, mpys[3].second});
b0123 = simplify(b0123);
// Pack the four 8-bit coefficients into one 32-bit scalar.
b0123 = reinterpret(Type(b0123.type().code(), 32, 1), b0123);
Expr new_expr = halide_hexagon_add_4mpy(op->type, suffix, a0123, b0123);
if (op->type.bits() == 16) {
new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
expr = mutate(new_expr);
return;
}
}
// Second attempt: both multiply operands are vectors.
mpys.clear();
rest = Expr();
if (op->type.is_uint()) {
mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest);
suffix = ".vub.vub";
} else {
mpy_count = find_mpy_ops(op, Int(8, lanes), Int(8, lanes), 4, mpys, rest);
suffix = ".vb.vb";
}
if (mpy_count > 0 && mpys.size() == 4) {
Expr a0123 = Shuffle::make_interleave({mpys[0].first, mpys[1].first, mpys[2].first, mpys[3].first});
Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[2].second, mpys[3].second});
a0123 = simplify(a0123);
b0123 = simplify(b0123);
if (op->type.bits() == 32 || (!a0123.as<Shuffle>() && !b0123.as<Shuffle>())) {
Expr new_expr = halide_hexagon_add_4mpy(op->type, suffix, a0123, b0123);
if (op->type.bits() == 16) {
new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
expr = mutate(new_expr);
return;
}
}
}
// Strategy 2: two-term multiply-add, for 16- and 32-bit vector results.
if (op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
int lanes = op->type.lanes();
vector<MulExpr> mpys;
Expr rest;
string vmpa_suffix;
string vdmpy_suffix;
int mpy_count = 0;
if (op->type.bits() == 16) {
mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 2, mpys, rest);
vmpa_suffix = ".vub.vub.b.b";
vdmpy_suffix = ".vub.b";
} else if (op->type.bits() == 32) {
mpy_count = find_mpy_ops(op, Int(16, lanes), Int(8), 2, mpys, rest);
vmpa_suffix = ".vh.vh.b.b";
vdmpy_suffix = ".vh.b";
}
if (mpy_count > 0 && mpys.size() == 2) {
Expr a01 = Shuffle::make_interleave({mpys[0].first, mpys[1].first});
a01 = simplify(a01);
Expr new_expr;
// If the interleave of the vector operands simplified away, use the
// two-argument form with the coefficients packed into a 16-bit scalar;
// otherwise pass the four operands separately.
if (!a01.as<Shuffle>() || vmpa_suffix.empty()) {
Expr b01 = Shuffle::make_interleave({mpys[0].second, mpys[1].second});
b01 = simplify(b01);
b01 = reinterpret(Type(b01.type().code(), 16, 1), b01);
new_expr = halide_hexagon_add_2mpy(op->type, vdmpy_suffix, a01, b01);
} else {
new_expr = halide_hexagon_add_2mpy(op->type, vmpa_suffix, mpys[0].first, mpys[1].first, mpys[0].second, mpys[1].second);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
expr = mutate(new_expr);
return;
}
}
// Strategy 3: pattern table. The first applicable entry wins, so order matters.
static const vector<Pattern> adds = {
// Accumulating forms of the add_2mpy / add_4mpy calls built above.
{ "halide.hexagon.acc_add_2mpy.vh.vub.vub.b.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.vub.b.b", wild_u8x, wild_u8x, wild_i8, wild_i8), Pattern::ReinterleaveOp0 },
{ "halide.hexagon.acc_add_2mpy.vw.vh.vh.b.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.vh.b.b", wild_i16x, wild_i16x, wild_i8, wild_i8), Pattern::ReinterleaveOp0 },
{ "halide.hexagon.acc_add_2mpy.vh.vub.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.b", wild_u8x, wild_i16) },
{ "halide.hexagon.acc_add_2mpy.vw.vh.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.b", wild_i16x, wild_i16) },
{ "halide.hexagon.acc_add_4mpy.vw.vub.b", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.b", wild_u8x, wild_i32) },
{ "halide.hexagon.acc_add_4mpy.vuw.vub.ub", wild_u32x + halide_hexagon_add_4mpy(UInt(32, 0), ".vub.ub", wild_u8x, wild_u32) },
{ "halide.hexagon.acc_add_4mpy.vuw.vub.vub", wild_u32x + halide_hexagon_add_4mpy(UInt(32, 0), ".vub.vub", wild_u8x, wild_u8x) },
{ "halide.hexagon.acc_add_4mpy.vw.vub.vb", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.vb", wild_u8x, wild_i8x) },
{ "halide.hexagon.acc_add_4mpy.vw.vb.vb", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vb.vb", wild_i8x, wild_i8x) },
// Adds of two narrowed operands; the result is interleaved.
{ "halide.hexagon.add_vuh.vub.vub", wild_u16x + wild_u16x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.add_vuw.vuh.vuh", wild_u32x + wild_u32x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.add_vw.vh.vh", wild_i32x + wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOps },
// Accumulator plus narrowed vector times narrowed broadcast scalar, in both multiply operand orders.
{ "halide.hexagon.add_mpy.vuh.vub.ub", wild_u16x + wild_u16x*bc(wild_u16), Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vh.vub.b", wild_i16x + wild_i16x*bc(wild_i16), Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vuw.vuh.uh", wild_u32x + wild_u32x*bc(wild_u32), Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vuh.vub.ub", wild_u16x + bc(wild_u16)*wild_u16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12 },
{ "halide.hexagon.add_mpy.vh.vub.b", wild_i16x + bc(wild_i16)*wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowUnsignedOp2 | Pattern::SwapOps12 },
{ "halide.hexagon.add_mpy.vuw.vuh.uh", wild_u32x + bc(wild_u32)*wild_u32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12 },
{ "halide.hexagon.satw_add_mpy.vw.vh.h", wild_i32x + wild_i32x*bc(wild_i32), Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.satw_add_mpy.vw.vh.h", wild_i32x + bc(wild_i32)*wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12 },
// Accumulator plus narrowed vector times narrowed vector.
{ "halide.hexagon.add_mpy.vuh.vub.vub", wild_u16x + wild_u16x*wild_u16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vuw.vuh.vuh", wild_u32x + wild_u32x*wild_u32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vh.vb.vb", wild_i16x + wild_i16x*wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vw.vh.vh", wild_i32x + wild_i32x*wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vh.vub.vb", wild_i16x + wild_i16x*wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 | Pattern::NarrowOp2 },
{ "halide.hexagon.add_mpy.vw.vh.vuh", wild_i32x + wild_i32x*wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowUnsignedOp2 },
{ "halide.hexagon.add_mpy.vh.vub.vb", wild_i16x + wild_i16x*wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowUnsignedOp2 | Pattern::SwapOps12 },
{ "halide.hexagon.add_mpy.vw.vh.vuh", wild_i32x + wild_i32x*wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12 },
// Accumulator plus shifted vector; divides and multiplies by a constant power of two are treated as shifts.
{ "halide.hexagon.add_shr.vw.vw.w", wild_i32x + (wild_i32x >> bc(wild_i32)) },
{ "halide.hexagon.add_shl.vw.vw.w", wild_i32x + (wild_i32x << bc(wild_i32)) },
{ "halide.hexagon.add_shl.vw.vw.w", wild_u32x + (wild_u32x << bc(wild_u32)) },
{ "halide.hexagon.add_shr.vw.vw.w", wild_i32x + (wild_i32x/bc(wild_i32)), Pattern::ExactLog2Op2 },
{ "halide.hexagon.add_shl.vw.vw.w", wild_i32x + (wild_i32x*bc(wild_i32)), Pattern::ExactLog2Op2 },
{ "halide.hexagon.add_shl.vw.vw.w", wild_u32x + (wild_u32x*bc(wild_u32)), Pattern::ExactLog2Op2 },
{ "halide.hexagon.add_shl.vw.vw.w", wild_i32x + (bc(wild_i32)*wild_i32x), Pattern::ExactLog2Op1 | Pattern::SwapOps12 },
{ "halide.hexagon.add_shl.vw.vw.w", wild_u32x + (bc(wild_u32)*wild_u32x), Pattern::ExactLog2Op1 | Pattern::SwapOps12 },
// Accumulator plus full-width vector times narrowed scalar (or full-width vector).
{ "halide.hexagon.add_mul.vh.vh.b", wild_i16x + wild_i16x*bc(wild_i16), Pattern::NarrowOp2 },
{ "halide.hexagon.add_mul.vw.vw.h", wild_i32x + wild_i32x*bc(wild_i32), Pattern::NarrowOp2 },
{ "halide.hexagon.add_mul.vh.vh.b", wild_i16x + bc(wild_i16)*wild_i16x, Pattern::NarrowOp1 | Pattern::SwapOps12 },
{ "halide.hexagon.add_mul.vw.vw.h", wild_i32x + bc(wild_i32)*wild_i32x, Pattern::NarrowOp1 | Pattern::SwapOps12 },
{ "halide.hexagon.add_mul.vh.vh.vh", wild_i16x + wild_i16x*wild_i16x },
};
if (op->type.is_vector()) {
Expr new_expr = apply_commutative_patterns(op, adds, target, this);
if (!new_expr.same_as(op)) {
expr = new_expr;
return;
}
}
IRMutator::visit(op);
}
void visit(const Sub *op) {
if (op->type.is_vector()) {
Expr neg_b = lossless_negate(op->b);
if (neg_b.defined()) {
expr = mutate(op->a + neg_b);
return;
} else {
static const vector<Pattern> subs = {
{ "halide.hexagon.sub_vuh.vub.vub", wild_u16x - wild_u16x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.sub_vuw.vuh.vuh", wild_u32x - wild_u32x, Pattern::InterleaveResult | Pattern::NarrowOps },
{ "halide.hexagon.sub_vw.vh.vh", wild_i32x - wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOps },
};
Expr new_expr = apply_patterns(op, subs, target, this);
if (!new_expr.same_as(op)) {
expr = new_expr;
return;
}
}
}
IRMutator::visit(op);
}
void visit(const Max *op) {
IRMutator::visit(op);
if (op->type.is_vector()) {
static const pair<string, Expr> cl[] = {
{ "halide.hexagon.cls.vh", max(count_leading_zeros(wild_i16x), count_leading_zeros(~wild_i16x)) },
{ "halide.hexagon.cls.vw", max(count_leading_zeros(wild_i32x), count_leading_zeros(~wild_i32x)) },
};
vector<Expr> matches;
for (const auto &i : cl) {
if (expr_match(i.second, expr, matches) && equal(matches[0], matches[1])) {
expr = Call::make(op->type, i.first, {matches[0]}, Call::PureExtern) + 1;
return;
}
}
}
}
// Rewrite vector casts.
//
// First the `casts` table is tried: it maps cast-of-arithmetic idioms
// (averages, saturating add/sub, rounding/saturating narrowing, shifts) and
// plain narrowing/widening casts onto halide.hexagon.* calls. The first
// applicable entry wins, so the order of the entries matters.
//
// If no entry applies, `cast_rewrites` splits a cast between 8-bit and 32-bit
// types into two casts through a 16-bit type, and the result is mutated again
// so that each step can match the table above.
//
// Anything else gets the default mutation.
void visit(const Cast *op) {
    static const vector<Pattern> casts = {
        // Averaging.
        { "halide.hexagon.avg.vub.vub", u8((wild_u16x + wild_u16x)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg.vuh.vuh", u16((wild_u32x + wild_u32x)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg.vh.vh", i16((wild_i32x + wild_i32x)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg.vw.vw", i32((wild_i64x + wild_i64x)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg_rnd.vub.vub", u8((wild_u16x + wild_u16x + 1)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg_rnd.vuh.vuh", u16((wild_u32x + wild_u32x + 1)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg_rnd.vh.vh", i16((wild_i32x + wild_i32x + 1)/2), Pattern::NarrowOps },
        { "halide.hexagon.avg_rnd.vw.vw", i32((wild_i64x + wild_i64x + 1)/2), Pattern::NarrowOps },
        { "halide.hexagon.navg.vub.vub", i8_sat((wild_i16x - wild_i16x)/2), Pattern::NarrowUnsignedOps },
        { "halide.hexagon.navg.vh.vh", i16_sat((wild_i32x - wild_i32x)/2), Pattern::NarrowOps },
        { "halide.hexagon.navg.vw.vw", i32_sat((wild_i64x - wild_i64x)/2), Pattern::NarrowOps },
        // Saturating add and subtract.
        { "halide.hexagon.satub_add.vub.vub", u8_sat(wild_u16x + wild_u16x), Pattern::NarrowOps },
        { "halide.hexagon.satuh_add.vuh.vuh", u16_sat(wild_u32x + wild_u32x), Pattern::NarrowOps },
        { "halide.hexagon.satuw_add.vuw.vuw", u32_sat(wild_u64x + wild_u64x), Pattern::NarrowOps | Pattern::v62 },
        { "halide.hexagon.sath_add.vh.vh", i16_sat(wild_i32x + wild_i32x), Pattern::NarrowOps },
        { "halide.hexagon.satw_add.vw.vw", i32_sat(wild_i64x + wild_i64x), Pattern::NarrowOps },
        { "halide.hexagon.satub_sub.vub.vub", u8_sat(wild_i16x - wild_i16x), Pattern::NarrowUnsignedOps },
        { "halide.hexagon.satuh_sub.vuh.vuh", u16_sat(wild_i32x - wild_i32x), Pattern::NarrowUnsignedOps },
        { "halide.hexagon.sath_sub.vh.vh", i16_sat(wild_i32x - wild_i32x), Pattern::NarrowOps },
        { "halide.hexagon.satw_sub.vw.vw", i32_sat(wild_i64x - wild_i64x), Pattern::NarrowOps },
        // Rounding, saturating narrowing.
        { "halide.hexagon.trunc_satub_rnd.vh", u8_sat((wild_i32x + 128)/256), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0 },
        { "halide.hexagon.trunc_satb_rnd.vh", i8_sat((wild_i32x + 128)/256), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0 },
        { "halide.hexagon.trunc_satuh_rnd.vw", u16_sat((wild_i64x + 32768)/65536), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0 },
        { "halide.hexagon.trunc_sath_rnd.vw", i16_sat((wild_i64x + 32768)/65536), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0 },
        // Multiplies that keep the high part of the product.
        { "halide.hexagon.trunc_mpy.vw.vw", i32((wild_i64x*wild_i64x)/Expr(static_cast<int64_t>(1) << 32)), Pattern::NarrowOps },
        { "halide.hexagon.trunc_satw_mpy2.vh.h", i16_sat((wild_i32x*bc(wild_i32))/32768), Pattern::NarrowOps },
        { "halide.hexagon.trunc_satw_mpy2.vh.h", i16_sat((bc(wild_i32)*wild_i32x)/32768), Pattern::NarrowOps | Pattern::SwapOps01 },
        { "halide.hexagon.trunc_satw_mpy2_rnd.vh.h", i16_sat((wild_i32x*bc(wild_i32) + 16384)/32768), Pattern::NarrowOps },
        { "halide.hexagon.trunc_satw_mpy2_rnd.vh.h", i16_sat((bc(wild_i32)*wild_i32x + 16384)/32768), Pattern::NarrowOps | Pattern::SwapOps01 },
        { "halide.hexagon.trunc_satw_mpy2_rnd.vh.vh", i16_sat((wild_i32x*wild_i32x + 16384)/32768), Pattern::NarrowOps },
        { "halide.hexagon.trunc_satdw_mpy2.vw.vw", i32_sat((wild_i64x*wild_i64x)/Expr(static_cast<int64_t>(1) << 31)), Pattern::NarrowOps },
        { "halide.hexagon.trunc_satdw_mpy2_rnd.vw.vw", i32_sat((wild_i64x*wild_i64x + (1 << 30))/Expr(static_cast<int64_t>(1) << 31)), Pattern::NarrowOps },
        // Saturating narrowing with a shift right (or a divide by a power of two).
        { "halide.hexagon.trunc_satub_shr.vh.h", u8_sat(wild_i16x >> wild_i16), Pattern::DeinterleaveOp0 },
        { "halide.hexagon.trunc_satuh_shr.vw.w", u16_sat(wild_i32x >> wild_i32), Pattern::DeinterleaveOp0 },
        { "halide.hexagon.trunc_sath_shr.vw.w", i16_sat(wild_i32x >> wild_i32), Pattern::DeinterleaveOp0 },
        { "halide.hexagon.trunc_satub_shr.vh.h", u8_sat(wild_i16x/wild_i16), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1 },
        { "halide.hexagon.trunc_satuh_shr.vw.w", u16_sat(wild_i32x/wild_i32), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1 },
        { "halide.hexagon.trunc_sath_shr.vw.w", i16_sat(wild_i32x/wild_i32), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1 },
        // Saturating narrowing casts.
        { "halide.hexagon.pack_satub.vh", u8_sat(wild_i16x) },
        { "halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x) },
        { "halide.hexagon.pack_satb.vh", i8_sat(wild_i16x) },
        { "halide.hexagon.pack_sath.vw", i16_sat(wild_i32x) },
        { "halide.hexagon.trunc_satuh.vuw", u16_sat(wild_u32x), Pattern::DeinterleaveOp0 | Pattern::v62 },
        // Narrowing casts that keep the high half.
        { "halide.hexagon.packhi.vh", u8(wild_u16x/256) },
        { "halide.hexagon.packhi.vh", u8(wild_i16x/256) },
        { "halide.hexagon.packhi.vh", i8(wild_u16x/256) },
        { "halide.hexagon.packhi.vh", i8(wild_i16x/256) },
        { "halide.hexagon.packhi.vw", u16(wild_u32x/65536) },
        { "halide.hexagon.packhi.vw", u16(wild_i32x/65536) },
        { "halide.hexagon.packhi.vw", i16(wild_u32x/65536) },
        { "halide.hexagon.packhi.vw", i16(wild_i32x/65536) },
        // Narrowing with a shift right.
        { "halide.hexagon.trunc_shr.vw.w", i16(wild_i32x >> wild_i32), Pattern::DeinterleaveOp0 },
        { "halide.hexagon.trunc_shr.vw.w", i16(wild_i32x/wild_i32), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1 },
        // Plain narrowing casts.
        { "halide.hexagon.pack.vh", u8(wild_u16x) },
        { "halide.hexagon.pack.vh", u8(wild_i16x) },
        { "halide.hexagon.pack.vh", i8(wild_u16x) },
        { "halide.hexagon.pack.vh", i8(wild_i16x) },
        { "halide.hexagon.pack.vw", u16(wild_u32x) },
        { "halide.hexagon.pack.vw", u16(wild_i32x) },
        { "halide.hexagon.pack.vw", i16(wild_u32x) },
        { "halide.hexagon.pack.vw", i16(wild_i32x) },
        // Widening casts; the result is interleaved.
        { "halide.hexagon.zxt.vub", u16(wild_u8x), Pattern::InterleaveResult },
        { "halide.hexagon.zxt.vub", i16(wild_u8x), Pattern::InterleaveResult },
        { "halide.hexagon.zxt.vuh", u32(wild_u16x), Pattern::InterleaveResult },
        { "halide.hexagon.zxt.vuh", i32(wild_u16x), Pattern::InterleaveResult },
        { "halide.hexagon.sxt.vb", u16(wild_i8x), Pattern::InterleaveResult },
        { "halide.hexagon.sxt.vb", i16(wild_i8x), Pattern::InterleaveResult },
        { "halide.hexagon.sxt.vh", u32(wild_i16x), Pattern::InterleaveResult },
        { "halide.hexagon.sxt.vh", i32(wild_i16x), Pattern::InterleaveResult },
    };

    // (pattern, replacement) pairs that break a 4x width change into two 2x steps.
    static const vector<pair<Expr, Expr>> cast_rewrites = {
        // Saturating narrowing.
        { u8_sat(wild_u32x), u8_sat(u16_sat(wild_u32x)) },
        { u8_sat(wild_i32x), u8_sat(i16_sat(wild_i32x)) },
        { i8_sat(wild_u32x), i8_sat(u16_sat(wild_u32x)) },
        { i8_sat(wild_i32x), i8_sat(i16_sat(wild_i32x)) },
        // Narrowing.
        { u8(wild_u32x), u8(u16(wild_u32x)) },
        { u8(wild_i32x), u8(i16(wild_i32x)) },
        { i8(wild_u32x), i8(u16(wild_u32x)) },
        { i8(wild_i32x), i8(i16(wild_i32x)) },
        // Widening.
        { u32(wild_u8x), u32(u16(wild_u8x)) },
        { u32(wild_i8x), u32(i16(wild_i8x)) },
        { i32(wild_u8x), i32(u16(wild_u8x)) },
        { i32(wild_i8x), i32(i16(wild_i8x)) },
    };

    if (op->type.is_vector()) {
        Expr cast = op;

        Expr new_expr = apply_patterns(cast, casts, target, this);
        if (!new_expr.same_as(cast)) {
            expr = new_expr;
            return;
        }

        vector<Expr> matches;
        // Bind by const reference: copying the pair would copy two Exprs per
        // iteration for no benefit.
        for (const auto &i : cast_rewrites) {
            if (expr_match(i.first, cast, matches)) {
                debug(3) << "rewriting cast to: " << i.first << " from " << cast << "\n";
                // The replacement is written with 0-lane wildcard types; give
                // it the real lane count before substituting the operand in.
                Expr replacement = with_lanes(i.second, op->type.lanes());
                expr = substitute("*", matches[0], replacement);
                expr = mutate(expr);
                return;
            }
        }
    }
    IRMutator::visit(op);
}
void visit(const Call *op) {
if (op->is_intrinsic(Call::lerp)) {
internal_assert(op->args.size() == 3);
expr = mutate(lower_lerp(op->args[0], op->args[1], op->args[2]));
} else if (op->is_intrinsic(Call::cast_mask)) {
internal_assert(op->args.size() == 1);
Type src_type = op->args[0].type();
Type dst_type = op->type;
if (dst_type.bits() < src_type.bits()) {
expr = mutate(Cast::make(dst_type, op->args[0]));
} else {
Expr e = op->args[0];
while (src_type.bits() < dst_type.bits()) {
e = Shuffle::make_interleave({e, e});
src_type = src_type.with_bits(src_type.bits()*2);
e = reinterpret(src_type, e);
}
expr = mutate(e);
}
} else {
IRMutator::visit(op);
}
}
public:
// Remember the target so v62-only patterns can be enabled or skipped.
OptimizePatterns(Target t) : target(t) {}
};
class EliminateInterleaves : public IRMutator {
Scope<bool> vars;
int native_vector_bits;
bool in_bool_to_mask = false;
bool interleave_mask = false;
// True when x is a native interleave call, a Let whose body satisfies this
// predicate, or a variable for which a ".deinterleaved" entry is in scope.
bool yields_removable_interleave(Expr x) {
    if (is_native_interleave(x)) {
        return true;
    }
    if (const Let *let = x.as<Let>()) {
        return yields_removable_interleave(let->body);
    }
    if (const Variable *var = x.as<Variable>()) {
        return vars.contains(var->name + ".deinterleaved");
    }
    return false;
}
// True when x can be treated as the result of an interleave: it yields a
// removable interleave, is a scalar or a Broadcast, is a Let whose body
// satisfies this predicate, or is a variable for which a
// ".weak_deinterleaved" entry is in scope.
bool yields_interleave(Expr x) {
    const bool trivially_interleaved =
        yields_removable_interleave(x) || x.type().is_scalar() || x.as<Broadcast>();
    if (trivially_interleaved) {
        return true;
    }
    if (const Let *let = x.as<Let>()) {
        return yields_interleave(let->body);
    }
    if (const Variable *var = x.as<Variable>()) {
        return vars.contains(var->name + ".weak_deinterleaved");
    }
    return false;
}
bool yields_removable_interleave(const vector<Expr> &exprs) {
bool any_is_interleave = false;
for (const Expr &i : exprs) {
if (yields_removable_interleave(i)) {
any_is_interleave = true;
} else if (!yields_interleave(i)) {
return false;
}
}
return any_is_interleave;
}
// Returns x with its interleave removed:
//  - a native interleave call yields its operand;
//  - scalars and broadcasts are returned unchanged;
//  - a variable is replaced by its ".deinterleaved" (preferred) or
//    ".weak_deinterleaved" alias when one is in scope;
//  - a Let has the interleave removed from its body.
// Any other expression triggers an internal error.
Expr remove_interleave(Expr x) {
    if (is_native_interleave(x)) {
        return x.as<Call>()->args[0];
    } else if (x.type().is_scalar() || x.as<Broadcast>()) {
        return x;
    }

    if (const Variable *var = x.as<Variable>()) {
        if (vars.contains(var->name + ".deinterleaved")) {
            return Variable::make(var->type, var->name + ".deinterleaved");
        } else if (vars.contains(var->name + ".weak_deinterleaved")) {
            return Variable::make(var->type, var->name + ".weak_deinterleaved");
        }
    }

    if (const Let *let = x.as<Let>()) {
        // Recurse exactly once and reuse the result. Recursing a second
        // time when rebuilding the Let makes the cost exponential in the
        // depth of nested Lets.
        Expr body = remove_interleave(let->body);
        if (!body.same_as(let->body)) {
            return Let::make(let->name, let->value, body);
        } else {
            return x;
        }
    }

    internal_error << "Expression '" << x << "' does not yield an interleave.\n";
    return x;
}
template <typename T>
void visit_binary(const T* op) {
Expr a = mutate(op->a);
Expr b = mutate(op->b);
if (yields_removable_interleave({a, b})) {
a = remove_interleave(a);
b = remove_interleave(b);
expr = T::make(a, b);
if (expr.type().bits() == 1) {
internal_assert(!interleave_mask);
interleave_mask = true;
} else {
expr = native_interleave(expr);
}
} else if (!a.same_as(op->a) || !b.same_as(op->b)) {
expr = T::make(a, b);
} else {
expr = op;
}
}
void visit(const Add *op) { visit_binary(op); }
void visit(const Sub *op) { visit_binary(op); }
void visit(const Mul *op) { visit_binary(op); }
void visit(const Div *op) { visit_binary(op); }
void visit(const Mod *op) { visit_binary(op); }
void visit(const Min *op) { visit_binary(op); }
void visit(const Max *op) { visit_binary(op); }
void visit(const EQ *op) { visit_binary(op); }
void visit(const NE *op) { visit_binary(op); }
void visit(const LT *op) { visit_binary(op); }
void visit(const LE *op) { visit_binary(op); }
void visit(const GT *op) { visit_binary(op); }
void visit(const GE *op) { visit_binary(op); }
// Logical operators must be scalar when they reach this pass; after checking
// that, they get the default IRMutator treatment.
template <typename T>
void visit_scalar_logical(const T *op) {
    internal_assert(op->type.is_scalar());
    IRMutator::visit(op);
}
void visit(const And *op) {
    visit_scalar_logical(op);
}
void visit(const Or *op) {
    visit_scalar_logical(op);
}
void visit(const Not *op) {
    visit_scalar_logical(op);
}
void visit(const Select *op) {
Expr true_value = mutate(op->true_value);
Expr false_value = mutate(op->false_value);
internal_assert(op->condition.type().is_scalar());
Expr cond = mutate(op->condition);
if (yields_removable_interleave({true_value, false_value})) {
true_value = remove_interleave(true_value);
false_value = remove_interleave(false_value);
expr = native_interleave(Select::make(cond, true_value, false_value));
} else if (!cond.same_as(op->condition) ||
!true_value.same_as(op->true_value) ||
!false_value.same_as(op->false_value)) {
expr = Select::make(cond, true_value, false_value);
} else {
expr = op;
}
}
// Overloads that let visit_let ask "is this name referenced?" uniformly for
// statement and expression bodies.
static bool uses_var(Stmt s, const string &var) {
    const bool referenced = stmt_uses_var(s, var);
    return referenced;
}
static bool uses_var(Expr e, const string &var) {
    const bool referenced = expr_uses_var(e, var);
    return referenced;
}
// Shared implementation for Let and LetStmt. While the body is mutated, an
// alias name for the deinterleaved form of the value is made visible in
// `vars`. Afterwards the let is rebuilt so that only the names the new body
// actually references are defined.
template <typename NodeType, typename LetType>
void visit_let(NodeType &result, const LetType *op) {
    Expr value = mutate(op->value);

    // Choose the alias (if any) the body may refer to.
    string deinterleaved_name;
    if (yields_removable_interleave(value)) {
        deinterleaved_name = op->name + ".deinterleaved";
    } else if (yields_interleave(value)) {
        deinterleaved_name = op->name + ".weak_deinterleaved";
    }

    const bool has_alias = !deinterleaved_name.empty();
    if (has_alias) {
        vars.push(deinterleaved_name, true);
    }
    NodeType body = mutate(op->body);
    if (has_alias) {
        vars.pop(deinterleaved_name);
    }

    if (body.same_as(op->body)) {
        // The body never picked up the alias, so keep a plain let.
        if (value.same_as(op->value)) {
            result = op;
        } else {
            result = LetType::make(op->name, value, body);
        }
        return;
    }

    result = body;
    const bool deinterleaved_used = uses_var(result, deinterleaved_name);
    const bool interleaved_used = uses_var(result, op->name);

    if (deinterleaved_used && interleaved_used) {
        // Define the deinterleaved value first, then express the original
        // name in terms of it.
        Expr deinterleaved = remove_interleave(value);
        Expr interleaved = Variable::make(deinterleaved.type(), deinterleaved_name);
        if (!deinterleaved.same_as(value)) {
            interleaved = native_interleave(interleaved);
        }
        NodeType inner = LetType::make(op->name, interleaved, body);
        result = LetType::make(deinterleaved_name, deinterleaved, inner);
    } else if (deinterleaved_used) {
        result = LetType::make(deinterleaved_name, remove_interleave(value), body);
    } else if (interleaved_used) {
        result = LetType::make(op->name, value, body);
    } else {
        // Neither name is referenced: the let is dropped and result stays
        // the bare body.
        internal_assert(!uses_var(op->body, op->name)) << "EliminateInterleaves eliminated a non-dead let.\n";
    }
}
// Let expressions: run the shared let handling, then lift a removable
// interleave out of the resulting let body so it wraps the whole Let.
void visit(const Let *op) {
    visit_let(expr, op);

    // visit_let can drop a dead let and leave just the mutated body in
    // expr, in which case expr is not a Let and there is nothing to lift.
    // Guard against dereferencing a null pointer in that case.
    const Let *let = expr.as<Let>();
    if (let && yields_removable_interleave(let->body)) {
        expr = native_interleave(Let::make(let->name, let->value, remove_interleave(let->body)));
    }
}
void visit(const LetStmt *op) { visit_let(stmt, op); }
void visit(const Cast *op) {
if (op->type.bits() == op->value.type().bits()) {
Expr value = mutate(op->value);
if (yields_removable_interleave(value)) {
value = remove_interleave(value);
expr = native_interleave(Cast::make(op->type, value));
} else if (!value.same_as(op->value)) {
expr = Cast::make(op->type, value);
} else {
expr = op;
}
} else {
IRMutator::visit(op);
}
}
static bool is_interleavable(const Call *op) {
static const set<string> interleavable = {
Call::bitwise_and,
Call::bitwise_not,
Call::bitwise_xor,
Call::bitwise_or,
Call::shift_left,
Call::shift_right,
Call::abs,
Call::absd,
Call::select_mask
};
if (interleavable.count(op->name) != 0) return true;
static const set<string> not_interleavable = {
"halide.hexagon.interleave.vb",
"halide.hexagon.interleave.vh",
"halide.hexagon.interleave.vw",
"halide.hexagon.deinterleave.vb",
"halide.hexagon.deinterleave.vh",
"halide.hexagon.deinterleave.vw",
};
if (not_interleavable.count(op->name) != 0) return false;
if (starts_with(op->name, "halide.hexagon.")) {
for (Expr i : op->args) {
if (i.type().is_scalar()) continue;
if (i.type().bits() != op->type.bits() || i.type().lanes() != op->type.lanes()) {
return false;
}
}
}
return true;
}
void visit_bool_to_mask(const Call *op) {
bool old_in_bool_to_mask = in_bool_to_mask;
in_bool_to_mask = true;
Expr arg = mutate(op->args[0]);
if (!arg.same_as(op->args[0]) || interleave_mask) {
expr = Call::make(op->type, Call::bool_to_mask, {arg}, Call::PureIntrinsic);
if (interleave_mask) {
expr = native_interleave(expr);
interleave_mask = false;
}
} else {
expr = op;
}
in_bool_to_mask = old_in_bool_to_mask;
}
void visit(const Call *op) {
if (op->is_intrinsic(Call::bool_to_mask)) {
visit_bool_to_mask(op);
return;
}
vector<Expr> args(op->args);
bool changed = false;
for (Expr &i : args) {
Expr new_i = mutate(i);
changed = changed || !new_i.same_as(i);
i = new_i;
}
static std::map<string, string> deinterleaving_alts = {
{ "halide.hexagon.pack.vh", "halide.hexagon.trunc.vh" },
{ "halide.hexagon.pack.vw", "halide.hexagon.trunc.vw" },
{ "halide.hexagon.packhi.vh", "halide.hexagon.trunclo.vh" },
{ "halide.hexagon.packhi.vw", "halide.hexagon.trunclo.vw" },
{ "halide.hexagon.pack_satub.vh", "halide.hexagon.trunc_satub.vh" },
{ "halide.hexagon.pack_sath.vw", "halide.hexagon.trunc_sath.vw" },
{ "halide.hexagon.pack_satuh.vw", "halide.hexagon.trunc_satuh.vw" },
};
static std::map<string, string> interleaving_alts = {
{ "halide.hexagon.trunc.vh", "halide.hexagon.pack.vh" },
{ "halide.hexagon.trunc.vw", "halide.hexagon.pack.vw" },
{ "halide.hexagon.trunclo.vh", "halide.hexagon.packhi.vh" },
{ "halide.hexagon.trunclo.vw", "halide.hexagon.packhi.vw" },
{ "halide.hexagon.trunc_satub.vh", "halide.hexagon.pack_satub.vh" },
{ "halide.hexagon.trunc_sath.vw", "halide.hexagon.pack_sath.vw" },
{ "halide.hexagon.trunc_satuh.vw", "halide.hexagon.pack_satuh.vw" },
};
if (is_native_deinterleave(op) && yields_interleave(args[0])) {
expr = remove_interleave(args[0]);
} else if (is_interleavable(op) && yields_removable_interleave(args)) {
for (Expr &i : args) {
i = remove_interleave(i);
}
expr = Call::make(op->type, op->name, args, op->call_type,
op->func, op->value_index, op->image, op->param);
expr = native_interleave(expr);
} else if (deinterleaving_alts.find(op->name) != deinterleaving_alts.end() &&
yields_removable_interleave(args)) {
for (Expr &i : args) {
i = remove_interleave(i);
}
expr = Call::make(op->type, deinterleaving_alts[op->name], args, op->call_type);
} else if (interleaving_alts.count(op->name) && is_native_deinterleave(args[0])) {
Expr arg = args[0].as<Call>()->args[0];
expr = Call::make(op->type, interleaving_alts[op->name], { arg }, op->call_type,
op->func, op->value_index, op->image, op->param);
} else if (changed) {
expr = Call::make(op->type, op->name, args, op->call_type,
op->func, op->value_index, op->image, op->param);
} else {
expr = op;
}
}
// What the Store/Load visitors have concluded so far about an allocation.
enum class BufferState {
// Initial state pushed by visit(const Allocate *).
Unknown,
// Set by visit(const Store *) when an unpredicated store of a removable
// interleave is seen while the state is still Unknown.
Interleaved,
// Set when a store is predicated or its value is not interleaved, or when a
// load's size is not a multiple of twice native_vector_bits.
NotInterleaved,
};
// State of each allocation currently being analysed (first pass in
// visit(const Allocate *)).
Scope<BufferState> buffers;
// Allocations whose stores and loads are being rewritten (second pass in
// visit(const Allocate *)).
Scope<bool> deinterleave_buffers;
void visit(const Allocate *op) {
Expr condition = mutate(op->condition);
buffers.push(op->name, BufferState::Unknown);
Stmt body = mutate(op->body);
bool deinterleave = buffers.get(op->name) == BufferState::Interleaved;
buffers.pop(op->name);
if (deinterleave) {
deinterleave_buffers.push(op->name, true);
body = mutate(op->body);
deinterleave_buffers.pop(op->name);
}
if (!body.same_as(op->body) || !condition.same_as(op->condition)) {
stmt = Allocate::make(op->name, op->type, op->extents, condition, body,
op->new_expr, op->free_function);
} else {
stmt = op;
}
}
// Stores update the state of tracked allocations. For allocations registered
// in deinterleave_buffers the interleave is stripped from the stored value.
void visit(const Store *op) {
    Expr predicate = mutate(op->predicate);
    Expr value = mutate(op->value);
    Expr index = mutate(op->index);

    if (buffers.contains(op->name)) {
        BufferState &state = buffers.ref(op->name);
        if (is_one(predicate)) {
            if (yields_removable_interleave(value)) {
                // Only upgrade from Unknown; a NotInterleaved verdict sticks.
                if (state == BufferState::Unknown) {
                    state = BufferState::Interleaved;
                }
            } else if (!yields_interleave(value)) {
                state = BufferState::NotInterleaved;
            }
            // A value that merely yields a (non-removable) interleave leaves
            // the state as it is.
        } else {
            // Predicated stores rule the allocation out.
            state = BufferState::NotInterleaved;
        }
    }

    if (deinterleave_buffers.contains(op->name)) {
        internal_assert(is_one(predicate)) << "The store shouldn't have been predicated.\n";
        value = remove_interleave(value);
    }

    const bool untouched = predicate.same_as(op->predicate) &&
                           value.same_as(op->value) &&
                           index.same_as(op->index);
    if (!untouched) {
        stmt = Store::make(op->name, value, index, op->param, predicate);
    } else {
        stmt = op;
    }
}
void visit(const Load *op) {
if (buffers.contains(op->name)) {
if ((op->type.lanes()*op->type.bits()) % (native_vector_bits*2) == 0) {
} else {
BufferState &state = buffers.ref(op->name);
state = BufferState::NotInterleaved;
}
}
IRMutator::visit(op);
if (deinterleave_buffers.contains(op->name)) {
expr = native_interleave(expr);
}
}
using IRMutator::visit;
public:
// Stores the vector width (in bits) that visit(const Load *) checks load
// sizes against.
EliminateInterleaves(int native_vector_bits)
    : native_vector_bits(native_vector_bits) {
}
};
class FuseInterleaves : public IRMutator {
void visit(const Call *op) {
static const std::vector<std::pair<string, string>> non_deinterleaving_alts = {
{ "halide.hexagon.zxt.vub", "halide.hexagon.unpack.vub" },
{ "halide.hexagon.sxt.vb", "halide.hexagon.unpack.vb" },
{ "halide.hexagon.zxt.vuh", "halide.hexagon.unpack.vuh" },
{ "halide.hexagon.sxt.vh", "halide.hexagon.unpack.vh" },
};
if (is_native_interleave(op)) {
if (const Call *arg = op->args[0].as<Call>()) {
for (const auto &i : non_deinterleaving_alts) {
if (arg->name == i.first) {
std::vector<Expr> args = arg->args;
for (Expr &j : args) {
j = mutate(j);
}
expr = Call::make(op->type, i.second, args, Call::PureExtern);
return;
}
}
}
}
IRMutator::visit(op);
}
using IRMutator::visit;
};
// Returns an expression for bounds.max - bounds.min. When min and max are the
// same kind of Min/Max/Add/Sub node with an equal second operand, that
// operand is peeled off first and the span of the first operands is taken
// instead. Requires a bounded interval.
Expr span_of_bounds(Interval bounds) {
    internal_assert(bounds.is_bounded());

    const Expr &lo = bounds.min;
    const Expr &hi = bounds.max;

    // lo is at most one of these node kinds, so at most one branch can apply.
    if (const Min *lo_min = lo.as<Min>()) {
        const Min *hi_min = hi.as<Min>();
        if (hi_min && equal(lo_min->b, hi_min->b)) {
            return span_of_bounds({lo_min->a, hi_min->a});
        }
    } else if (const Max *lo_max = lo.as<Max>()) {
        const Max *hi_max = hi.as<Max>();
        if (hi_max && equal(lo_max->b, hi_max->b)) {
            return span_of_bounds({lo_max->a, hi_max->a});
        }
    } else if (const Add *lo_add = lo.as<Add>()) {
        const Add *hi_add = hi.as<Add>();
        if (hi_add && equal(lo_add->b, hi_add->b)) {
            return span_of_bounds({lo_add->a, hi_add->a});
        }
    } else if (const Sub *lo_sub = lo.as<Sub>()) {
        const Sub *hi_sub = hi.as<Sub>();
        if (hi_sub && equal(lo_sub->b, hi_sub->b)) {
            return span_of_bounds({lo_sub->a, hi_sub->a});
        }
    }

    return hi - lo;
}
class OptimizeShuffles : public IRMutator {
int lut_alignment;
Scope<Interval> bounds;
std::vector<std::pair<string, Expr>> lets;
using IRMutator::visit;
template <typename T>
void visit_let(const T *op) {
if (op->value.type().is_vector()) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
}
IRMutator::visit(op);
if (op->value.type().is_vector()) {
bounds.pop(op->name);
}
}
void visit(const Let *op) {
lets.push_back({op->name, op->value});
visit_let(op);
lets.pop_back();
}
void visit(const LetStmt *op) { visit_let(op); }
void visit(const Load *op) {
if (!is_one(op->predicate)) {
IRMutator::visit(op);
return;
}
if (!op->type.is_vector() || op->index.as<Ramp>()) {
IRMutator::visit(op);
return;
}
Expr index = mutate(op->index);
Interval unaligned_index_bounds = bounds_of_expr_in_scope(index, bounds);
if (unaligned_index_bounds.is_bounded()) {
int align = lut_alignment / op->type.bytes();
Interval aligned_index_bounds = {
(unaligned_index_bounds.min / align) * align,
((unaligned_index_bounds.max + align) / align) * align - 1
};
for (Interval index_bounds : {aligned_index_bounds, unaligned_index_bounds}) {
Expr index_span = span_of_bounds(index_bounds);
index_span = common_subexpression_elimination(index_span);
index_span = simplify(index_span);
if (can_prove(index_span < 256)) {
int const_extent = as_const_int(index_span) ? *as_const_int(index_span) + 1 : 256;
Expr base = simplify(index_bounds.min);
Expr lut = Load::make(op->type.with_lanes(const_extent), op->name,
Ramp::make(base, 1, const_extent),
op->image, op->param, const_true(const_extent));
index = simplify(cast(UInt(8).with_lanes(op->type.lanes()), index - base));
expr = Call::make(op->type, "dynamic_shuffle", {lut, index, 0, const_extent - 1}, Call::PureIntrinsic);
return;
}
}
}
if (!index.same_as(op->index)) {
expr = Load::make(op->type, op->name, index, op->image, op->param, op->predicate);
} else {
expr = op;
}
}
public:
OptimizeShuffles(int lut_alignment) : lut_alignment(lut_alignment) {}
};
}
// Runs the OptimizeShuffles mutator over s with the given lookup-table
// alignment and returns the result.
Stmt optimize_hexagon_shuffles(Stmt s, int lut_alignment) {
    OptimizeShuffles shuffle_optimizer(lut_alignment);
    return shuffle_optimizer.mutate(s);
}
// Applies, in order, OptimizePatterns, EliminateInterleaves and
// FuseInterleaves to s and returns the result.
Stmt optimize_hexagon_instructions(Stmt s, Target t) {
    Stmt patterned = OptimizePatterns(t).mutate(s);

    // EliminateInterleaves takes the vector width in bits, derived from the
    // target's natural vector size for Int(8).
    const int native_vector_bits = t.natural_vector_size(Int(8)) * 8;
    Stmt without_interleaves = EliminateInterleaves(native_vector_bits).mutate(patterned);

    return FuseInterleaves().mutate(without_interleaves);
}
}
}