This source file includes the following definitions.
- make_load
- visit
- visit_let
- visit
- visit
- align_loads
#include <algorithm>
#include "AlignLoads.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Scope.h"
#include "Bounds.h"
#include "ModulusRemainder.h"
#include "Simplify.h"
using std::vector;
namespace Halide {
namespace Internal {
namespace {
// IR mutator that rewrites unpredicated, dense or small-constant-stride
// vector loads in terms of loads that are `required_alignment` bytes wide,
// followed by shuffles (slices / concatenations) that extract the lanes the
// original load asked for.
class AlignLoads : public IRMutator {
public:
    // `alignment` is the required alignment of a vector load, in bytes (it is
    // divided by the element size in bytes to obtain a lane count).
    explicit AlignLoads(int alignment) : required_alignment(alignment) {}

private:
    // The required alignment of a vector load, in bytes.
    int required_alignment;

    // Modulus/remainder information for the Int(32) let-bound variables that
    // are currently in scope (maintained by visit_let).
    Scope<ModulusRemainder> alignment_info;

    using IRMutator::visit;

    // Build a load identical to `load` (same buffer name, image and param)
    // but reading from `index`. The result type takes its lane count from
    // `index` and the predicate is all-true. The new load is passed back
    // through mutate(), so it is itself subject to the rewrites in visit().
    Expr make_load(const Load *load, Expr index) {
        internal_assert(is_one(load->predicate)) << "Load should not be predicated.\n";
        return mutate(Load::make(load->type.with_lanes(index.type().lanes()), load->name,
                                 index, load->image, load->param, const_true(index.type().lanes())));
    }

    // Rewrite a vector load. Every case that is not handled explicitly falls
    // back to the default IRMutator behaviour.
    void visit(const Load *op) {
        if (!is_one(op->predicate)) {
            // Predicated loads are left untouched.
            IRMutator::visit(op);
            return;
        }

        if (!op->type.is_vector()) {
            // Nothing to do for scalar loads.
            IRMutator::visit(op);
            return;
        }

        if (op->image.defined()) {
            // Loads that carry an image are left untouched
            // (presumably because their alignment cannot be reasoned about here).
            IRMutator::visit(op);
            return;
        }

        Expr index = mutate(op->index);
        const Ramp *ramp = index.as<Ramp>();
        const int64_t *const_stride = ramp ? as_const_int(ramp->stride) : nullptr;
        if (!ramp || !const_stride) {
            // Only ramps with a constant stride are handled; anything else
            // (e.g. an indirect load) is left to the default mutator.
            IRMutator::visit(op);
            return;
        }

        int lanes = ramp->lanes;
        int native_lanes = required_alignment / op->type.bytes();
        if (native_lanes <= 0) {
            // The required alignment is non-positive or smaller than one
            // element, so no native vector can be formed. Without this guard
            // the slicing loop below (`i += native_lanes`) would never
            // terminate, and native_lanes would be used as a modulus.
            IRMutator::visit(op);
            return;
        }

        if (!(*const_stride == 1 || *const_stride == 2 || *const_stride == 3)) {
            // Only strides of 1, 2 and 3 are handled.
            IRMutator::visit(op);
            return;
        }

        // If the load is from a parameter, use the parameter's host
        // alignment. Otherwise assume the buffer is aligned to the required
        // alignment.
        int aligned_offset = 0;
        bool known_alignment = false;
        int base_alignment =
            op->param.defined() ? op->param.host_alignment() : required_alignment;
        if (base_alignment % required_alignment == 0) {
            // The buffer base is suitably aligned. Try to determine the
            // offset of the ramp base from an aligned address, in lanes.
            // NOTE(review): relies on reduce_expr_modulo returning true only
            // when it stored a valid remainder in aligned_offset - confirm.
            known_alignment = reduce_expr_modulo(ramp->base, native_lanes, &aligned_offset,
                                                 alignment_info);
        }

        int stride = static_cast<int>(*const_stride);
        if (stride != 1) {
            internal_assert(stride >= 0);

            // When the known offset is smaller than the stride, start the
            // dense load at the aligned address and compensate by beginning
            // the slice at `shift`; otherwise start at the ramp base.
            int shift = known_alignment && aligned_offset < stride ? aligned_offset : 0;

            // Dense load of lanes*stride elements covering every address the
            // strided load touches. make_load() re-mutates it, so it is in
            // turn split into native-sized loads by the stride-1 code below.
            Expr dense_base = simplify(ramp->base - shift);
            Expr dense_index = Ramp::make(dense_base, 1, lanes*stride);
            Expr dense = make_load(op, dense_index);

            // Pick every stride-th element of the dense load.
            expr = Shuffle::make_slice(dense, shift, stride, lanes);
            return;
        }

        // From here on the load is dense.
        internal_assert(stride == 1);
        if (lanes < native_lanes) {
            // Narrower than a native vector: load a whole native vector and
            // keep the first `lanes` elements.
            Expr native_load = make_load(op, Ramp::make(ramp->base, 1, native_lanes));
            expr = Shuffle::make_slice(native_load, 0, 1, lanes);
            return;
        }

        if (lanes > native_lanes) {
            // Wider than a native vector: load it in pieces of at most
            // native_lanes elements and concatenate the pieces.
            vector<Expr> slices;
            slices.reserve((lanes + native_lanes - 1) / native_lanes);
            for (int i = 0; i < lanes; i += native_lanes) {
                int slice_lanes = std::min(native_lanes, lanes - i);
                Expr slice_base = simplify(ramp->base + i);
                slices.push_back(make_load(op, Ramp::make(slice_base, 1, slice_lanes)));
            }
            expr = Shuffle::make_concat(slices);
            return;
        }

        if (known_alignment && aligned_offset != 0) {
            // Exactly one native vector at a known non-zero offset from an
            // aligned address: load two native vectors starting at the
            // aligned address and slice the requested lanes out of them.
            Expr aligned_base = simplify(ramp->base - aligned_offset);
            Expr aligned_load = make_load(op, Ramp::make(aligned_base, 1, lanes*2));
            expr = Shuffle::make_slice(aligned_load, aligned_offset, 1, lanes);
            return;
        }

        // Either already aligned or of unknown alignment: leave as is.
        IRMutator::visit(op);
    }

    // Shared implementation for Let and LetStmt. While the value and body are
    // being mutated, the modulus/remainder of an Int(32) value is recorded in
    // alignment_info under the let's name; it is removed again afterwards.
    // `result` receives the original node if nothing changed, otherwise a
    // freshly built let.
    template<typename NodeType, typename LetType>
    void visit_let(NodeType &result, const LetType *op) {
        if (op->value.type() == Int(32)) {
            alignment_info.push(op->name, modulus_remainder(op->value, alignment_info));
        }

        Expr value = mutate(op->value);
        NodeType body = mutate(op->body);

        if (op->value.type() == Int(32)) {
            alignment_info.pop(op->name);
        }

        if (!value.same_as(op->value) || !body.same_as(op->body)) {
            result = LetType::make(op->name, value, body);
        } else {
            result = op;
        }
    }

    void visit(const Let *op) { visit_let(expr, op); }
    void visit(const LetStmt *op) { visit_let(stmt, op); }
};
}
// Run the AlignLoads mutator over `s` and return the rewritten statement.
// `alignment` is the required vector-load alignment, in bytes.
Stmt align_loads(Stmt s, int alignment) {
    AlignLoads aligner(alignment);
    return aligner.mutate(s);
}
}
}