This source file includes the following definitions.
- visit
- visit
- visit
- visit
- visit
- visit
- has_likely_tag
- visit
- expr_uses_invalid_buffers
- visit
- new_simplification
- visit
- visit
- visit
- visit
- visit
- visit_let
- visit
- visit
- mutate
- visit
- contains_thread_barrier
- visit
- visit
- expr_contains_load
- visit
- visit
- visit
- is_trivial
- visit
- visit
- visit
- visit
- visit
- partition_loops
#include <algorithm>
#include <numeric>
#include "PartitionLoops.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Simplify.h"
#include "Solve.h"
#include "IREquality.h"
#include "ExprUsesVar.h"
#include "Substitute.h"
#include "CodeGen_GPU_Dev.h"
#include "Var.h"
#include "CSE.h"
namespace Halide {
namespace Internal {
using std::string;
using std::vector;
using std::pair;
using std::map;
namespace {
// Inside load/store indices, wraps the Ramp operand of a min/max in likely(),
// so a clamped ramp index is treated as most likely taking the unclamped
// (ramp) side.
class MarkClampedRampsAsLikely : public IRMutator {
    using IRMutator::visit;

    // True while mutating a Load index (and predicate) or a Store index.
    bool in_index = false;

    void visit(const Min *op) {
        if (!in_index) {
            IRMutator::visit(op);
            return;
        }
        if (op->a.as<Ramp>()) {
            expr = min(likely(op->a), mutate(op->b));
        } else if (op->b.as<Ramp>()) {
            expr = min(mutate(op->a), likely(op->b));
        } else {
            IRMutator::visit(op);
        }
    }

    void visit(const Max *op) {
        if (!in_index) {
            IRMutator::visit(op);
            return;
        }
        if (op->a.as<Ramp>()) {
            expr = max(likely(op->a), mutate(op->b));
        } else if (op->b.as<Ramp>()) {
            expr = max(mutate(op->a), likely(op->b));
        } else {
            IRMutator::visit(op);
        }
    }

    void visit(const Load *op) {
        // Everything under a Load is handled in index mode.
        const bool saved = in_index;
        in_index = true;
        IRMutator::visit(op);
        in_index = saved;
    }

    void visit(const Store *op) {
        // Only the index is handled in index mode; value and predicate are not.
        const bool saved = in_index;
        in_index = true;
        Expr new_index = mutate(op->index);
        in_index = saved;

        Expr new_value = mutate(op->value);
        Expr new_predicate = mutate(op->predicate);

        const bool unchanged = new_predicate.same_as(op->predicate) &&
                               new_index.same_as(op->index) &&
                               new_value.same_as(op->value);
        if (unchanged) {
            stmt = op;
            return;
        }
        stmt = Store::make(op->name, new_value, new_index, op->param, new_predicate);
    }
};
// Strips every likely(...) intrinsic, replacing it with its (mutated) argument.
class RemoveLikelyTags : public IRMutator {
    using IRMutator::visit;

    void visit(const Call *op) {
        if (!op->is_intrinsic(Call::likely)) {
            IRMutator::visit(op);
            return;
        }
        internal_assert(op->args.size() == 1);
        expr = mutate(op->args[0]);
    }
};
// Sets `result` if a likely(...) intrinsic appears anywhere in the visited IR.
class HasLikelyTag : public IRVisitor {
    using IRVisitor::visit;

    void visit(const Call *op) {
        if (!op->is_intrinsic(Call::likely)) {
            IRVisitor::visit(op);
            return;
        }
        // No need to look inside the likely() call itself.
        result = true;
    }

public:
    bool result = false;
};
// Returns true if `e` contains a likely(...) intrinsic anywhere.
bool has_likely_tag(Expr e) {
    HasLikelyTag finder;
    e.accept(&finder);
    return finder.result;
}
// A candidate rewrite found in a loop body: old_expr may be replaced by
// likely_value wherever `condition` holds. Built by aggregate initialization
// in FindSimplifications, so the field order matters.
struct Simplification {
    // Scalar condition (usually on the loop variable) under which old_expr
    // evaluates to likely_value.
    Expr condition;
    // The exact IR node to replace (matched by identity, not structurally).
    Expr old_expr;
    // Replacement for old_expr when `condition` is true.
    Expr likely_value;
    // Value of old_expr when `condition` is false; swapped with likely_value to
    // build the negated prologue/epilogue simplifications.
    Expr unlikely_value;
    // Whether `condition` is exact. It is cleared whenever the condition was
    // conservatively relaxed, or when the inner and outer solved intervals
    // disagree.
    bool tight;
    // Range of the loop variable computed by solve_for_inner_interval for
    // `condition` (filled in by PartitionLoops).
    Interval interval;
};
class ExprUsesInvalidBuffers : public IRVisitor {
using IRVisitor::visit;
const Scope<int> &invalid_buffers;
void visit(const Load *op) {
if (invalid_buffers.contains(op->name)) {
invalid = true;
} else {
IRVisitor::visit(op);
}
}
public:
ExprUsesInvalidBuffers(const Scope<int> &buffers) : invalid_buffers(buffers), invalid(false) {}
bool invalid;
};
// Returns true if `e` loads from any buffer named in `invalid_buffers`.
bool expr_uses_invalid_buffers(Expr e, const Scope<int> &invalid_buffers) {
    ExprUsesInvalidBuffers checker(invalid_buffers);
    e.accept(&checker);
    return checker.invalid;
}
// Walks the body of a loop over the variable passed to the constructor. It
// records a Simplification for every min/max/select, and every likely()-tagged
// if-condition, where exactly one side carries a likely() tag and whose
// condition depends on that loop variable.
class FindSimplifications : public IRVisitor {
    using IRVisitor::visit;

    // The loop variable, plus the names of any Let/LetStmt whose value
    // (transitively) uses it; see the constructor and visit_let.
    Scope<int> depends_on_loop_var;
    // Names of every Allocate seen so far. Conditions that load from these
    // buffers are rejected. Entries are never popped, so a name stays
    // marked for the rest of the traversal (conservative).
    Scope<int> buffers;

    void visit(const Allocate *op) {
        buffers.push(op->name, 0);
        IRVisitor::visit(op);
    }

    // Callers pass a `condition` under which `old` evaluates to `likely_val`
    // (and to `unlikely_val` otherwise). The candidate is dropped if the
    // condition does not depend on the loop variable, or if it reads a buffer
    // allocated inside the loop.
    void new_simplification(Expr condition, Expr old, Expr likely_val, Expr unlikely_val) {
        if (!expr_uses_vars(condition, depends_on_loop_var)) {
            return;
        }

        if (expr_uses_invalid_buffers(condition, buffers)) {
            return;
        }

        condition = RemoveLikelyTags().mutate(condition);
        Simplification s = {condition, old, likely_val, unlikely_val, true};
        // Reduce vector conditions to a single scalar. A broadcast reduces
        // exactly; anything else goes through and_condition_over_domain and
        // is then marked non-tight.
        if (s.condition.type().is_vector()) {
            s.condition = simplify(s.condition);
            if (const Broadcast *b = s.condition.as<Broadcast>()) {
                s.condition = b->value;
            } else {
                s.condition = and_condition_over_domain(s.condition, Scope<Interval>::empty_scope());
                s.tight = false;
            }
        }
        internal_assert(s.condition.type().is_scalar()) << s.condition << "\n";
        simplifications.push_back(s);
    }

    void visit(const Min *op) {
        IRVisitor::visit(op);
        bool likely_a = has_likely_tag(op->a);
        bool likely_b = has_likely_tag(op->b);

        // min(a, b) == b when b <= a (and == a when a <= b).
        if (likely_b && !likely_a) {
            new_simplification(op->b <= op->a, op, op->b, op->a);
        } else if (likely_a && !likely_b) {
            new_simplification(op->a <= op->b, op, op->a, op->b);
        }
    }

    void visit(const Max *op) {
        IRVisitor::visit(op);
        bool likely_a = has_likely_tag(op->a);
        bool likely_b = has_likely_tag(op->b);

        // max(a, b) == b when b >= a (and == a when a >= b).
        if (likely_b && !likely_a) {
            new_simplification(op->b >= op->a, op, op->b, op->a);
        } else if (likely_a && !likely_b) {
            new_simplification(op->a >= op->b, op, op->a, op->b);
        }
    }

    void visit(const Select *op) {
        IRVisitor::visit(op);
        bool likely_t = has_likely_tag(op->true_value);
        bool likely_f = has_likely_tag(op->false_value);

        // The likely branch is taken when the (possibly negated) condition holds.
        if (likely_t && !likely_f) {
            new_simplification(op->condition, op, op->true_value, op->false_value);
        } else if (likely_f && !likely_t) {
            new_simplification(!op->condition, op, op->false_value, op->true_value);
        }
    }

    void visit(const IfThenElse *op) {
        IRVisitor::visit(op);
        // if (likely(c)): where c holds, the condition itself can become true.
        const Call *call = op->condition.as<Call>();
        if (call && call->is_intrinsic(Call::likely)) {
            new_simplification(op->condition, op->condition, const_true(), const_false());
        }
    }

    void visit(const For *op) {
        // Collect the nested loop's simplifications separately so only they
        // are rewritten below.
        vector<Simplification> old;
        old.swap(simplifications);
        IRVisitor::visit(op);

        // Conditions that mention the nested loop's variable are replaced by
        // and_condition_over_domain over its range [min, min + extent - 1],
        // which must no longer mention that variable. Any change makes the
        // condition non-tight.
        for (Simplification &s : simplifications) {
            if (expr_uses_var(s.condition, op->name)) {
                Scope<Interval> varying;
                varying.push(op->name, Interval(op->min, op->min + op->extent - 1));
                Expr relaxed = and_condition_over_domain(s.condition, varying);
                internal_assert(!expr_uses_var(relaxed, op->name))
                    << "Should not have had used the loop var (" << op->name
                    << ") any longer\n before: " << s.condition << "\n after: "
                    << relaxed << "\n";
                if (!equal(relaxed, s.condition)) {
                    s.tight = false;
                }
                s.condition = relaxed;
            }
        }

        simplifications.insert(simplifications.end(), old.begin(), old.end());
    }

    // Shared handler for Let and LetStmt. Tracks whether the bound name
    // depends on the loop variable. Conditions found in the body that
    // reference the name are wrapped in a Let, so they stay self-contained
    // outside this scope.
    template<typename LetOrLetStmt>
    void visit_let(const LetOrLetStmt *op) {
        bool varying = expr_uses_vars(op->value, depends_on_loop_var);
        if (varying) {
            depends_on_loop_var.push(op->name, 0);
        }

        vector<Simplification> old;
        old.swap(simplifications);
        IRVisitor::visit(op);
        for (Simplification &s : simplifications) {
            if (expr_uses_var(s.condition, op->name)) {
                s.condition = Let::make(op->name, op->value, s.condition);
            }
        }
        simplifications.insert(simplifications.end(), old.begin(), old.end());

        if (varying) {
            depends_on_loop_var.pop(op->name);
        }
    }

    void visit(const LetStmt *op) {
        visit_let(op);
    }

    void visit(const Let *op) {
        visit_let(op);
    }

public:
    // Every simplification found, in discovery order (outer-level entries
    // are appended after those from nested loops/lets).
    vector<Simplification> simplifications;

    // `v` is the name of the loop variable being partitioned.
    FindSimplifications(const std::string &v) {
        depends_on_loop_var.push(v, 0);
    }
};
class MakeSimplifications : public IRMutator {
using IRMutator::visit;
const vector<Simplification> &simplifications;
public:
MakeSimplifications(const vector<Simplification> &s) : simplifications(s) {}
using IRMutator::mutate;
Expr mutate(Expr e) {
for (auto const &s : simplifications) {
if (e.same_as(s.old_expr)) {
return mutate(s.likely_value);
}
}
return IRMutator::mutate(e);
}
};
// Sets `result` if a call named "halide_gpu_thread_barrier" appears anywhere.
class ContainsThreadBarrier : public IRVisitor {
public:
    bool result = false;

protected:
    using IRVisitor::visit;

    void visit(const Call *op) {
        result = result || op->name == "halide_gpu_thread_barrier";
        IRVisitor::visit(op);
    }
};
// Returns true if `s` calls halide_gpu_thread_barrier anywhere.
bool contains_thread_barrier(Stmt s) {
    ContainsThreadBarrier detector;
    s.accept(&detector);
    return detector.result;
}
// Splits loops containing likely()-tagged simplifications into a prologue,
// a simplified steady state, and an epilogue.
class PartitionLoops : public IRMutator {
    using IRMutator::visit;

    // True while mutating code nested inside a loop over a GPU variable.
    bool in_gpu_loop = false;

    void visit(const For *op) {
        // Scope in_gpu_loop to this loop. Restoring it here, rather than
        // before each return inside partition_loop, guarantees every exit path
        // restores it. Previously the "no simplifications" and "empty steady
        // state" early returns left it set, so it leaked into sibling loops.
        bool old_in_gpu_loop = in_gpu_loop;
        in_gpu_loop |= CodeGen_GPU_Dev::is_gpu_var(op->name);
        partition_loop(op);
        in_gpu_loop = old_in_gpu_loop;
    }

    // Partitions a single loop (or recurses unchanged) and sets `stmt`.
    void partition_loop(const For *op) {
        Stmt body = op->body;

        // Inside a GPU loop, a body containing a thread barrier is not
        // partitioned; just recurse into it. NOTE(review): presumably because
        // partitioning may put the barrier under divergent control flow.
        if (in_gpu_loop && contains_thread_barrier(body)) {
            IRMutator::visit(op);
            return;
        }

        // Loops in GLSL kernels are left untouched.
        if (op->device_api == DeviceAPI::GLSL) {
            stmt = op;
            return;
        }

        // Gather likely()-driven simplifications that depend on this loop var.
        FindSimplifications finder(op->name);
        body.accept(&finder);

        if (finder.simplifications.empty()) {
            IRMutator::visit(op);
            return;
        }

        debug(3) << "\n\n**** Partitioning loop over " << op->name << "\n";

        // Candidate start points (min_vals) and end points (max_vals) of the
        // steady state, and the simplification sets for each region.
        vector<Expr> min_vals, max_vals;
        vector<Simplification> middle_simps, prologue_simps, epilogue_simps;
        bool lower_bound_is_tight = true, upper_bound_is_tight = true;
        for (auto &s : finder.simplifications) {
            // Solve for the range of the loop var over which the condition
            // holds. A simplification stays tight only if the inner and
            // outer solutions agree.
            s.interval = solve_for_inner_interval(s.condition, op->name);
            if (s.tight) {
                Interval outer = solve_for_outer_interval(s.condition, op->name);
                s.tight &= equal(outer.min, s.interval.min) && equal(outer.max, s.interval.max);
            }

            debug(3) << "\nSimplification: \n"
                     << "  condition: " << s.condition << "\n"
                     << "  old: " << s.old_expr << "\n"
                     << "  new: " << s.likely_value << "\n"
                     << "  min: " << s.interval.min << "\n"
                     << "  max: " << s.interval.max << "\n"
                     << "  tight: " << s.tight << "\n";

            if (!s.interval.is_empty()) {
                if (s.interval.has_lower_bound()) {
                    Expr m = s.interval.min;
                    if (!s.tight) {
                        lower_bound_is_tight = false;
                    }
                    // More than one distinct bound makes the combined bound non-tight.
                    if (min_vals.empty()) {
                        min_vals.push_back(m);
                    } else if (equal(m, min_vals.back())) {
                        // Duplicate of the previous bound; nothing to add.
                    } else {
                        min_vals.push_back(m);
                        lower_bound_is_tight = false;
                    }
                }
                if (s.interval.has_upper_bound()) {
                    Expr m = s.interval.max;
                    if (!s.tight) {
                        upper_bound_is_tight = false;
                    }
                    if (max_vals.empty()) {
                        max_vals.push_back(m);
                    } else if (equal(m, max_vals.back())) {
                        // Duplicate of the previous bound; nothing to add.
                    } else {
                        max_vals.push_back(m);
                        upper_bound_is_tight = false;
                    }
                }

                // Every non-empty simplification applies in the steady state.
                middle_simps.push_back(s);
            }
        }

        // The prologue may only be simplified if every lower bound can be
        // proven (after simplification) to sit below every upper bound.
        bool can_simplify_prologue = true;
        for (Expr min_val : min_vals) {
            for (Expr max_val : max_vals) {
                Expr test = simplify(common_subexpression_elimination(min_val - 1 < max_val + 1));
                if (!is_one(test)) {
                    can_simplify_prologue = false;
                }
            }
        }

        // Decide which simplifications apply to the prologue and epilogue:
        // - those unbounded on that side apply as-is;
        // - when the bound on that side is tight, the negated simplification
        //   (unlikely value) applies as well.
        for (const auto &s : middle_simps) {
            if (can_simplify_prologue &&
                !s.interval.has_lower_bound()) {
                prologue_simps.push_back(s);
            }
            if (!s.interval.has_upper_bound()) {
                epilogue_simps.push_back(s);
            }
            if (can_simplify_prologue &&
                s.interval.has_lower_bound() &&
                lower_bound_is_tight) {
                internal_assert(s.tight);
                Simplification s2 = s;
                s2.condition = !s2.condition;
                std::swap(s2.likely_value, s2.unlikely_value);
                prologue_simps.push_back(s2);
            }
            if (s.interval.has_upper_bound() &&
                upper_bound_is_tight) {
                internal_assert(s.tight);
                Simplification s2 = s;
                s2.condition = !s2.condition;
                std::swap(s2.likely_value, s2.unlikely_value);
                epilogue_simps.push_back(s2);
            }
        }

        Stmt simpler_body = MakeSimplifications(middle_simps).mutate(body);
        Stmt prologue = MakeSimplifications(prologue_simps).mutate(body);
        Stmt epilogue = MakeSimplifications(epilogue_simps).mutate(body);

        // A prologue/epilogue is only worth emitting if it differs from the steady state.
        bool make_prologue = !equal(prologue, simpler_body);
        bool make_epilogue = !equal(epilogue, simpler_body);

        // Only the steady state is recursively partitioned.
        simpler_body = mutate(simpler_body);

        // The steady state covers [min_steady, max_steady).
        Expr min_steady = op->min, max_steady = op->extent + op->min;
        Expr prologue_val, epilogue_val;
        string prologue_name = unique_name(op->name + ".prologue");
        string epilogue_name = unique_name(op->name + ".epilogue");

        if (make_prologue) {
            // Steady state starts at the largest lower bound, clamped to the loop range.
            std::sort(min_vals.begin(), min_vals.end(), IRDeepCompare());
            min_vals.push_back(op->min);
            prologue_val = fold_left(min_vals, Max::make);
            prologue_val = min(prologue_val, op->extent + op->min);
            min_steady = Variable::make(Int(32), prologue_name);

            internal_assert(!expr_uses_var(prologue_val, op->name));
        }
        if (make_epilogue) {
            // Steady state ends after the smallest upper bound, and never
            // before it starts.
            std::sort(max_vals.begin(), max_vals.end(), IRDeepCompare());
            max_vals.push_back(op->min + op->extent - 1);
            epilogue_val = fold_left(max_vals, Min::make) + 1;
            if (make_prologue) {
                epilogue_val = max(epilogue_val, prologue_val);
            } else {
                epilogue_val = max(op->min, epilogue_val);
            }
            max_steady = Variable::make(Int(32), epilogue_name);

            internal_assert(!expr_uses_var(epilogue_val, op->name));
        }

        if (op->for_type == ForType::Serial) {
            // Serial loops: emit prologue, steady, and epilogue as three loops.
            stmt = For::make(op->name, min_steady, max_steady - min_steady,
                             op->for_type, op->device_api, simpler_body);

            if (make_prologue) {
                prologue = For::make(op->name, op->min, min_steady - op->min,
                                     op->for_type, op->device_api, prologue);
                stmt = Block::make(prologue, stmt);
            }
            if (make_epilogue) {
                epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
                                     op->for_type, op->device_api, epilogue);
                stmt = Block::make(stmt, epilogue);
            }
        } else {
            // Non-serial loops keep a single loop and select the body
            // variant with if/else on the loop variable.
            Expr loop_var = Variable::make(Int(32), op->name);
            stmt = simpler_body;
            if (make_epilogue && make_prologue && equal(prologue, epilogue)) {
                stmt = IfThenElse::make(min_steady <= loop_var && loop_var < max_steady, stmt, prologue);
            } else {
                if (make_epilogue) {
                    stmt = IfThenElse::make(loop_var < max_steady, stmt, epilogue);
                }
                if (make_prologue) {
                    stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
                }
            }
            stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt);
        }

        if (make_epilogue) {
            stmt = LetStmt::make(epilogue_name, epilogue_val, stmt);
        } else {
            epilogue_val = op->min + op->extent;
        }

        if (make_prologue) {
            stmt = LetStmt::make(prologue_name, prologue_val, stmt);
        } else {
            prologue_val = op->min;
        }

        // If the steady state is provably empty, partitioning is pointless.
        if (can_prove(epilogue_val <= prologue_val)) {
            IRMutator::visit(op);
            return;
        }

        debug(3) << "Partition loop.\n"
                 << "Old: " << Stmt(op) << "\n"
                 << "New: " << stmt << "\n";
    }
};
// Sets `result` if any Load node is reached during the visit.
class ExprContainsLoad : public IRVisitor {
public:
    bool result = false;

private:
    using IRVisitor::visit;

    void visit(const Load *) {
        // Found one; no need to descend further.
        result = true;
    }
};
// Returns true if `e` contains a Load anywhere.
bool expr_contains_load(Expr e) {
    ExprContainsLoad finder;
    e.accept(&finder);
    return finder.result;
}
// Restructures GPU loop nests after partitioning. Inside GPU loops (but
// outside __thread_id_x loops) it:
// - lifts lets that do not depend on GPU vars and contain no loads out to just
//   outside the outermost GPU loop;
// - pushes LetStmts and the if/else statements introduced by partitioning
//   inward past For loops, "__shared" Allocates and LetStmts.
// NOTE(review): presumably this keeps the GPU loop nest perfectly nested for
// later GPU lowering; confirm.
class RenormalizeGPULoops : public IRMutator {
    bool in_gpu_loop = false, in_thread_loop = false;

    using IRMutator::visit;

    // Names of GPU loop vars and of lets that may depend on them. Entries are
    // never popped.
    Scope<int> gpu_vars;

    // Lets lifted out of the current GPU loop nest. They are re-wrapped
    // around the outermost GPU loop in visit(const For *).
    vector<pair<string, Expr> > lifted_lets;

    void visit(const For *op) {
        // GLSL loops are left untouched.
        if (op->device_api == DeviceAPI::GLSL) {
            stmt = op;
            return;
        }

        // Inside a __thread_id_x loop, lets and ifs are left in place.
        if (ends_with(op->name, "__thread_id_x")) {
            in_thread_loop = true;
            IRMutator::visit(op);
            in_thread_loop = false;
            return;
        }

        bool old_in_gpu_loop = in_gpu_loop;

        if (in_gpu_loop || CodeGen_GPU_Dev::is_gpu_var(op->name)) {
            gpu_vars.push(op->name, 0);
            in_gpu_loop = true;
        }

        IRMutator::visit(op);

        // Leaving the outermost GPU loop: re-emit the lifted lets around it.
        // The earliest-lifted let ends up outermost.
        if (in_gpu_loop && !old_in_gpu_loop) {
            while (lifted_lets.size()) {
                stmt = LetStmt::make(lifted_lets.back().first,
                                     lifted_lets.back().second,
                                     stmt);
                lifted_lets.pop_back();
            }
        }

        in_gpu_loop = old_in_gpu_loop;
    }

    void visit(const LetStmt *op) {
        if (!in_gpu_loop) {
            IRMutator::visit(op);
            return;
        }

        // GPU-invariant, load-free lets get a fresh name and are lifted out.
        if (!expr_uses_vars(op->value, gpu_vars) && !expr_contains_load(op->value)) {
            string new_name = unique_name('t');
            Expr new_var = Variable::make(op->value.type(), new_name);
            lifted_lets.push_back({ new_name, op->value });
            stmt = mutate(substitute(op->name, new_var, op->body));
            return;
        }

        // This let stays in place and may now vary with the GPU vars.
        gpu_vars.push(op->name, 0);

        if (in_thread_loop) {
            IRMutator::visit(op);
            return;
        }

        Stmt body = mutate(op->body);

        const For *f = body.as<For>();
        const Allocate *a = body.as<Allocate>();

        if (f && in_gpu_loop && !in_thread_loop) {
            // Move the let inside the For; the loop bounds must not depend on it.
            internal_assert(!expr_uses_var(f->min, op->name) &&
                            !expr_uses_var(f->extent, op->name));
            Stmt inner = LetStmt::make(op->name, op->value, f->body);
            inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner);
            stmt = mutate(inner);
        } else if (a && in_gpu_loop && !in_thread_loop) {
            // Only single-extent "__shared" allocations are expected here. Move
            // the let inside unless the allocation size depends on it.
            internal_assert(a->name == "__shared" && a->extents.size() == 1);
            if (expr_uses_var(a->extents[0], op->name)) {
                IRMutator::visit(op);
            } else {
                Stmt inner = LetStmt::make(op->name, op->value, a->body);
                inner = Allocate::make(a->name, a->type, a->extents, a->condition, inner);
                stmt = mutate(inner);
            }
        } else {
            IRMutator::visit(op);
        }
    }

    void visit(const IfThenElse *op) {
        if (!in_gpu_loop || in_thread_loop) {
            IRMutator::visit(op);
            return;
        }

        internal_assert(op->else_case.defined())
            << "PartitionLoops should only introduce if statements with an else branch\n";

        Stmt then_case = mutate(op->then_case);
        Stmt else_case = mutate(op->else_case);

        // Identical branches: the if is unnecessary.
        if (equal(then_case, else_case)) {
            stmt = then_case;
            return;
        }

        const Allocate *allocate_a = then_case.as<Allocate>();
        const Allocate *allocate_b = else_case.as<Allocate>();
        const For *for_a = then_case.as<For>();
        const For *for_b = else_case.as<For>();
        const LetStmt *let_a = then_case.as<LetStmt>();
        const LetStmt *let_b = else_case.as<LetStmt>();
        if (allocate_a && allocate_b &&
            allocate_a->name == "__shared" &&
            allocate_b->name == "__shared") {
            // Hoist the shared allocation above the if, using the then
            // branch's type, extents and condition.
            Stmt inner = IfThenElse::make(op->condition, allocate_a->body, allocate_b->body);
            inner = Allocate::make(allocate_a->name, allocate_a->type, allocate_a->extents, allocate_a->condition, inner);
            stmt = mutate(inner);
        } else if (let_a && let_b && let_a->name == let_b->name) {
            // Same let in both branches: bind the condition once, then hoist
            // a single let whose value selects between the two values.
            string condition_name = unique_name('t');
            Expr condition = Variable::make(op->condition.type(), condition_name);
            Stmt inner = IfThenElse::make(condition, let_a->body, let_b->body);
            inner = LetStmt::make(let_a->name, select(condition, let_a->value, let_b->value), inner);
            inner = LetStmt::make(condition_name, op->condition, inner);
            stmt = mutate(inner);
        } else if (let_a) {
            // Hoist the then-branch let (renamed to avoid clashes) above the if.
            string new_name = unique_name(let_a->name);
            Stmt inner = let_a->body;
            inner = substitute(let_a->name, Variable::make(let_a->value.type(), new_name), inner);
            inner = IfThenElse::make(op->condition, inner, else_case);
            inner = LetStmt::make(new_name, let_a->value, inner);
            stmt = mutate(inner);
        } else if (let_b) {
            // Hoist the else-branch let (renamed to avoid clashes) above the if.
            string new_name = unique_name(let_b->name);
            Stmt inner = let_b->body;
            inner = substitute(let_b->name, Variable::make(let_b->value.type(), new_name), inner);
            inner = IfThenElse::make(op->condition, then_case, inner);
            inner = LetStmt::make(new_name, let_b->value, inner);
            stmt = mutate(inner);
        } else if (for_a && for_b &&
                   for_a->name == for_b->name &&
                   for_a->min.same_as(for_b->min) &&
                   for_a->extent.same_as(for_b->extent)) {
            // Same loop (identical bound nodes) in both branches: hoist it above the if.
            Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body);
            inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner);
            stmt = mutate(inner);
        } else {
            internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n";
        }
    }
};
class ExpandSelects : public IRMutator {
using IRMutator::visit;
bool is_trivial(Expr e) {
return e.as<Variable>() || is_const(e);
}
void visit(const Select *op) {
Expr condition = mutate(op->condition);
Expr true_value = mutate(op->true_value);
Expr false_value = mutate(op->false_value);
if (const Or *o = condition.as<Or>()) {
if (is_trivial(true_value)) {
expr = mutate(Select::make(o->a, true_value, Select::make(o->b, true_value, false_value)));
} else {
string var_name = unique_name('t');
Expr var = Variable::make(true_value.type(), var_name);
expr = mutate(Select::make(o->a, var, Select::make(o->b, var, false_value)));
expr = Let::make(var_name, true_value, expr);
}
} else if (const And *a = condition.as<And>()) {
if (is_trivial(false_value)) {
expr = mutate(Select::make(a->a, Select::make(a->b, true_value, false_value), false_value));
} else {
string var_name = unique_name('t');
Expr var = Variable::make(false_value.type(), var_name);
expr = mutate(Select::make(a->a, Select::make(a->b, true_value, var), var));
expr = Let::make(var_name, false_value, expr);
}
} else if (const Not *n = condition.as<Not>()) {
expr = mutate(Select::make(n->a, false_value, true_value));
} else if (condition.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
expr = op;
} else {
expr = Select::make(condition, true_value, false_value);
}
}
};
class CollapseSelects : public IRMutator {
using IRMutator::visit;
void visit(const Select *op) {
const Select *t = op->true_value.as<Select>();
const Select *f = op->false_value.as<Select>();
if (t && equal(t->false_value, op->false_value)) {
expr = mutate(select(op->condition && t->condition, t->true_value, op->false_value));
} else if (f && equal(op->true_value, f->true_value)) {
expr = mutate(select(op->condition || f->condition, op->true_value, f->false_value));
} else {
IRMutator::visit(op);
}
}
};
// Sets `result` if any For loop is reached during the visit.
class ContainsLoop : public IRVisitor {
public:
    bool result = false;

private:
    using IRVisitor::visit;

    void visit(const For *) {
        // One loop is enough; don't descend into it.
        result = true;
    }
};
class LowerLikelyIfInnermost : public IRMutator {
using IRMutator::visit;
bool inside_innermost_loop = false;
void visit(const Call *op) {
if (op->is_intrinsic(Call::likely_if_innermost)) {
internal_assert(op->args.size() == 1);
if (inside_innermost_loop) {
expr = Call::make(op->type, Call::likely, {mutate(op->args[0])}, Call::PureIntrinsic);
} else {
expr = mutate(op->args[0]);
}
} else {
IRMutator::visit(op);
}
}
void visit(const For *op) {
ContainsLoop c;
op->body.accept(&c);
inside_innermost_loop = !c.result;
IRMutator::visit(op);
inside_innermost_loop = false;
}
};
}
// Entry point: runs the loop-partitioning pipeline over `s`.
Stmt partition_loops(Stmt s) {
    // Resolve likely_if_innermost and tag clamped ramps in indices as likely.
    Stmt result = LowerLikelyIfInnermost().mutate(s);
    result = MarkClampedRampsAsLikely().mutate(result);
    // Split compound select conditions so each operand can be matched.
    result = ExpandSelects().mutate(result);
    // Partition loops, then restructure the GPU loop nests.
    result = PartitionLoops().mutate(result);
    result = RenormalizeGPULoops().mutate(result);
    // Drop the now-consumed likely tags and re-merge the expanded selects.
    result = RemoveLikelyTags().mutate(result);
    result = CollapseSelects().mutate(result);
    return result;
}
}
}