This source file includes the following definitions.
- substitute_value_in_var
- mutate
- simplify_using_fact
- propagate_specialization_in_definition
- simplify_specializations
#include "SimplifySpecializations.h"
#include "IROperator.h"
#include "IRMutator.h"
#include "Simplify.h"
#include "Substitute.h"
#include "Definition.h"
#include "IREquality.h"
#include <set>
namespace Halide{
namespace Internal {
using std::map;
using std::set;
using std::string;
using std::vector;
namespace {
void substitute_value_in_var(const string &var, Expr value, vector<Definition> &definitions) {
for (Definition &def : definitions) {
for (auto &def_arg : def.args()) {
def_arg = simplify(substitute(var, value, def_arg));
}
for (auto &def_val : def.values()) {
def_val = simplify(substitute(var, value, def_val));
}
}
}
// IRMutator that rewrites boolean sub-expressions under the assumption that
// `fact` is true:
//  - a boolean expression that `fact` implies (structurally identical, or
//    provable via can_prove(!fact || e)) becomes const_true();
//  - a boolean expression whose negation `fact` implies (one is structurally
//    the negation of the other, or can_prove(!fact || !e)) becomes const_false();
//  - anything else is handed to IRMutator::mutate, which presumably rebuilds
//    the node and recurses into its children through this virtual mutate().
class SimplifyUsingFact : public IRMutator {
public:
    using IRMutator::mutate;

    // `explicit`: an Expr should never silently convert into a mutator.
    explicit SimplifyUsingFact(Expr f) : fact(f) {}

    // `override` makes the compiler reject this if IRMutator::mutate's
    // signature ever changes; without it a mismatch would quietly turn this
    // into an unrelated overload and disable the fact-based rewriting.
    Expr mutate(Expr e) override {
        if (e.type().is_bool()) {
            // fact => e
            if (equal(fact, e) ||
                can_prove(!fact || e)) {
                return const_true();
            }
            // fact => !e; both negation spellings are checked because the
            // match is structural.
            if (equal(fact, !e) ||
                equal(!fact, e) ||
                can_prove(!fact || !e)) {
                return const_false();
            }
        }
        return IRMutator::mutate(e);
    }

private:
    // The condition assumed to hold while mutating.
    Expr fact;
};
// Rewrites all arg and value expressions of every definition in `definitions`
// assuming `fact` is true (see SimplifyUsingFact), then re-simplifies each
// result. The definitions are updated in place.
void simplify_using_fact(Expr fact, vector<Definition> &definitions) {
    // A fresh mutator is built for each expression; each rewrite is independent.
    auto rewrite = [&fact](const Expr &e) {
        return simplify(SimplifyUsingFact(fact).mutate(e));
    };
    for (size_t i = 0; i < definitions.size(); i++) {
        Definition &d = definitions[i];
        for (auto &arg : d.args()) {
            arg = rewrite(arg);
        }
        for (auto &val : d.values()) {
            val = rewrite(val);
        }
    }
}
vector<Definition> propagate_specialization_in_definition(Definition &def, const string &name) {
vector<Definition> result;
result.push_back(def);
vector<Specialization> &specializations = def.specializations();
bool seen_const_true = false;
for (auto it = specializations.begin(); it != specializations.end(); ) {
Expr old_c = it->condition;
Expr c = simplify(it->condition);
it->condition = c;
if (is_zero(c) || seen_const_true) {
debug(1) << "Erasing unreachable specialization ("
<< old_c << ") -> (" << c << ") for function \"" << name << "\"\n";
it = specializations.erase(it);
} else {
it++;
}
seen_const_true |= is_one(c);
}
if (!specializations.empty() && is_one(specializations.back().condition) && specializations.back().failure_message.empty()) {
debug(1) << "Replacing default Schedule with const-true specialization for function \"" << name << "\"\n";
const Definition s_def = specializations.back().definition;
specializations.pop_back();
def.values() = s_def.values();
def.args() = s_def.args();
def.schedule().splits() = s_def.schedule().splits();
def.schedule().dims() = s_def.schedule().dims();
def.schedule().prefetches() = s_def.schedule().prefetches();
def.schedule().touched() = s_def.schedule().touched();
def.schedule().allow_race_conditions() = s_def.schedule().allow_race_conditions();
specializations.insert(specializations.end(), s_def.specializations().begin(), s_def.specializations().end());
}
for (size_t i = specializations.size(); i > 0; i--) {
Expr c = specializations[i-1].condition;
Definition &s_def = specializations[i-1].definition;
const EQ *eq = c.as<EQ>();
const Variable *var = eq ? eq->a.as<Variable>() : c.as<Variable>();
vector<Definition> s_result = propagate_specialization_in_definition(s_def, name);
if (var && eq) {
substitute_value_in_var(var->name, eq->b, s_result);
if (eq->b.type().is_bool()) {
substitute_value_in_var(var->name, !eq->b, result);
}
} else if (var) {
substitute_value_in_var(var->name, const_true(), s_result);
substitute_value_in_var(var->name, const_false(), result);
} else {
simplify_using_fact(c, s_result);
simplify_using_fact(!c, result);
}
result.insert(result.end(), s_result.begin(), s_result.end());
}
return result;
}
}
// Entry point of the pass. For every Function in `env`, simplifies and prunes
// the specializations of func.definition() in place.
// NOTE(review): only func.definition() is processed here. If Function also
// carries update definitions with their own specializations, they are not
// simplified by this pass -- confirm that is intended.
void simplify_specializations(map<string, Function> &env) {
    for (map<string, Function>::iterator entry = env.begin(); entry != env.end(); ++entry) {
        Function &f = entry->second;
        // The returned list of definitions is only needed inside the
        // recursion, so it is ignored here.
        propagate_specialization_in_definition(f.definition(), f.name());
    }
}
}
}