This source file includes the following definitions:
- visit
- visit
- visit
- visit
- visit
- visit
- has_likely_tag
- visit
- expr_uses_invalid_buffers
- visit
- new_simplification
- visit
- visit
- visit
- visit
- visit
- visit_let
- visit
- visit
- mutate
- visit
- contains_thread_barrier
- visit
- visit
- expr_contains_load
- visit
- visit
- visit
- is_trivial
- visit
- visit
- visit
- visit
- visit
- partition_loops
#include <algorithm>
#include <numeric>
#include "PartitionLoops.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Simplify.h"
#include "Solve.h"
#include "IREquality.h"
#include "ExprUsesVar.h"
#include "Substitute.h"
#include "CodeGen_GPU_Dev.h"
#include "Var.h"
#include "CSE.h"
namespace Halide {
namespace Internal {
using std::string;
using std::vector;
using std::pair;
using std::map;
namespace {
// Inside the index of a Load or Store, wraps the Ramp operand of a min/max
// (i.e. a clamped ramp) in likely(), so that loop partitioning will try to
// build a steady state in which the clamp resolves to the ramp itself.
class MarkClampedRampsAsLikely : public IRMutator {
    using IRMutator::visit;

    // Set while mutating the index expression of a Load or Store.
    bool in_index = false;

    void visit(const Min *op) {
        if (!in_index) {
            IRMutator::visit(op);
            return;
        }
        if (op->a.as<Ramp>()) {
            // The ramp itself is tagged as-is; only the other side is mutated.
            expr = min(likely(op->a), mutate(op->b));
        } else if (op->b.as<Ramp>()) {
            expr = min(mutate(op->a), likely(op->b));
        } else {
            IRMutator::visit(op);
        }
    }

    void visit(const Max *op) {
        if (!in_index) {
            IRMutator::visit(op);
            return;
        }
        if (op->a.as<Ramp>()) {
            expr = max(likely(op->a), mutate(op->b));
        } else if (op->b.as<Ramp>()) {
            expr = max(mutate(op->a), likely(op->b));
        } else {
            IRMutator::visit(op);
        }
    }

    // Everything inside a Load (index and predicate) counts as index context.
    void visit(const Load *op) {
        const bool saved_in_index = in_index;
        in_index = true;
        IRMutator::visit(op);
        in_index = saved_in_index;
    }

    // For a Store only the index is index context; the value and predicate
    // are mutated with the enclosing setting.
    void visit(const Store *op) {
        const bool saved_in_index = in_index;
        in_index = true;
        Expr new_index = mutate(op->index);
        in_index = saved_in_index;

        Expr new_value = mutate(op->value);
        Expr new_predicate = mutate(op->predicate);

        const bool unchanged = new_predicate.same_as(op->predicate) &&
                               new_index.same_as(op->index) &&
                               new_value.same_as(op->value);
        if (unchanged) {
            stmt = op;
        } else {
            stmt = Store::make(op->name, new_value, new_index, op->param, new_predicate);
        }
    }
};
// Removes every likely(...) intrinsic, replacing it with its (recursively
// mutated) single argument.
class RemoveLikelyTags : public IRMutator {
    using IRMutator::visit;

    void visit(const Call *op) {
        if (!op->is_intrinsic(Call::likely)) {
            IRMutator::visit(op);
            return;
        }
        internal_assert(op->args.size() == 1);
        expr = mutate(op->args[0]);
    }
};
// Sets `result` if the visited IR contains any likely(...) intrinsic. The
// argument of a likely() call is not searched further.
class HasLikelyTag : public IRVisitor {
    using IRVisitor::visit;

    void visit(const Call *op) {
        if (!op->is_intrinsic(Call::likely)) {
            IRVisitor::visit(op);
            return;
        }
        result = true;
    }

public:
    bool result = false;
};
// Returns true if `e` contains a likely(...) intrinsic anywhere.
bool has_likely_tag(Expr e) {
    HasLikelyTag finder;
    e.accept(&finder);
    const bool found = finder.result;
    return found;
}
// A candidate rewrite found by FindSimplifications: inside the loop being
// partitioned, `old_expr` may be replaced by `likely_value` wherever
// `condition` holds. NOTE: this struct is brace-initialized positionally in
// FindSimplifications::new_simplification, so do not reorder the fields.
struct Simplification {
    // Scalar boolean condition under which the rewrite is valid (vector
    // conditions are reduced to scalars in new_simplification).
    Expr condition;
    // The expression to replace. MakeSimplifications matches it by node
    // identity (same_as), not by structural equality.
    Expr old_expr;
    // Replacement used where `condition` holds (the steady state).
    Expr likely_value;
    // Value used where `condition` does not hold. PartitionLoops swaps it in
    // (negating the condition) for the prologue/epilogue when the
    // corresponding bound is tight.
    Expr unlikely_value;
    // True if `condition` is exact rather than a conservative relaxation.
    // Cleared when a vector condition or an inner-loop dependence is relaxed,
    // or when the inner and outer solved intervals differ.
    bool tight;
    // Range of the loop variable over which `condition` holds; filled in by
    // PartitionLoops via solve_for_inner_interval.
    Interval interval;
};
// Sets `invalid` if the visited expression loads from any buffer whose name
// is in the supplied scope.
class ExprUsesInvalidBuffers : public IRVisitor {
    using IRVisitor::visit;

    // Buffer names that must not be loaded from. Held by reference; the scope
    // must outlive this visitor.
    const Scope<int> &invalid_buffers;

    void visit(const Load *op) {
        if (!invalid_buffers.contains(op->name)) {
            IRVisitor::visit(op);
            return;
        }
        invalid = true;
    }

public:
    ExprUsesInvalidBuffers(const Scope<int> &buffers) : invalid_buffers(buffers) {}

    bool invalid = false;
};
// Returns true if `e` loads from any buffer named in `invalid_buffers`.
bool expr_uses_invalid_buffers(Expr e, const Scope<int> &invalid_buffers) {
    ExprUsesInvalidBuffers checker(invalid_buffers);
    e.accept(&checker);
    const bool uses_invalid = checker.invalid;
    return uses_invalid;
}
// Walks the body of a loop and records a Simplification for every min, max,
// select, or if statement in which exactly one side (or the whole if
// condition) is tagged likely, and whose condition depends on the loop
// variable passed to the constructor.
class FindSimplifications : public IRVisitor {
    using IRVisitor::visit;
    // The loop variable (pushed by the constructor) plus the names of lets
    // whose values depend on it.
    Scope<int> depends_on_loop_var;
    // Names of buffers allocated inside the loop body. Entries are never
    // popped.
    Scope<int> buffers;
    void visit(const Allocate *op) {
        buffers.push(op->name, 0);
        IRVisitor::visit(op);
    }
    // Records that `old` may be replaced by `likely_val` where `condition`
    // holds (and by `unlikely_val` where it does not). Vector conditions are
    // reduced to a single scalar condition.
    void new_simplification(Expr condition, Expr old, Expr likely_val, Expr unlikely_val) {
        // A condition that does not depend on the loop variable cannot be used
        // to split this loop.
        if (!expr_uses_vars(condition, depends_on_loop_var)) {
            return;
        }
        if (expr_uses_invalid_buffers(condition, buffers)) {
            // The condition loads from a buffer allocated inside the loop.
            // Presumably it then cannot be evaluated where the partition
            // points are computed (outside the loop), so skip it.
            return;
        }
        condition = RemoveLikelyTags().mutate(condition);
        Simplification s = {condition, old, likely_val, unlikely_val, true};
        if (s.condition.type().is_vector()) {
            s.condition = simplify(s.condition);
            if (const Broadcast *b = s.condition.as<Broadcast>()) {
                s.condition = b->value;
            } else {
                // Reduce the vector condition to a scalar one (over an empty
                // domain). This is a relaxation, so it is no longer tight.
                s.condition = and_condition_over_domain(s.condition, Scope<Interval>::empty_scope());
                s.tight = false;
            }
        }
        internal_assert(s.condition.type().is_scalar()) << s.condition << "\n";
        simplifications.push_back(s);
    }
    // min(a, b) where one side is likely: in the steady state that side is
    // the smaller one.
    void visit(const Min *op) {
        IRVisitor::visit(op);
        bool likely_a = has_likely_tag(op->a);
        bool likely_b = has_likely_tag(op->b);
        if (likely_b && !likely_a) {
            new_simplification(op->b <= op->a, op, op->b, op->a);
        } else if (likely_a && !likely_b) {
            new_simplification(op->a <= op->b, op, op->a, op->b);
        }
    }
    // max(a, b) where one side is likely: in the steady state that side is
    // the larger one.
    void visit(const Max *op) {
        IRVisitor::visit(op);
        bool likely_a = has_likely_tag(op->a);
        bool likely_b = has_likely_tag(op->b);
        if (likely_b && !likely_a) {
            new_simplification(op->b >= op->a, op, op->b, op->a);
        } else if (likely_a && !likely_b) {
            new_simplification(op->a >= op->b, op, op->a, op->b);
        }
    }
    // select(c, t, f) where one value is likely: the condition (or its
    // negation) is assumed to hold in the steady state.
    void visit(const Select *op) {
        IRVisitor::visit(op);
        bool likely_t = has_likely_tag(op->true_value);
        bool likely_f = has_likely_tag(op->false_value);
        if (likely_t && !likely_f) {
            new_simplification(op->condition, op, op->true_value, op->false_value);
        } else if (likely_f && !likely_t) {
            new_simplification(!op->condition, op, op->false_value, op->true_value);
        }
    }
    void visit(const IfThenElse *op) {
        // Only an if whose entire condition is a likely(...) call is handled.
        // The condition expression itself is the thing replaced: by true in
        // the steady state, or by false where a tight bound says it fails.
        IRVisitor::visit(op);
        const Call *call = op->condition.as<Call>();
        if (call && call->is_intrinsic(Call::likely)) {
            new_simplification(op->condition, op->condition, const_true(), const_false());
        }
    }
    void visit(const For *op) {
        vector<Simplification> old;
        old.swap(simplifications);
        IRVisitor::visit(op);
        // Simplifications found inside this inner loop that mention its
        // variable are relaxed over the loop's range so they no longer depend
        // on it. If the relaxation changed the condition, it is not tight.
        for (Simplification &s : simplifications) {
            if (expr_uses_var(s.condition, op->name)) {
                Scope<Interval> varying;
                varying.push(op->name, Interval(op->min, op->min + op->extent - 1));
                Expr relaxed = and_condition_over_domain(s.condition, varying);
                internal_assert(!expr_uses_var(relaxed, op->name))
                    << "Should not have had used the loop var (" << op->name
                    << ") any longer\n  before: " << s.condition << "\n  after: "
                    << relaxed << "\n";
                if (!equal(relaxed, s.condition)) {
                    s.tight = false;
                }
                s.condition = relaxed;
            }
        }
        simplifications.insert(simplifications.end(), old.begin(), old.end());
    }
    // Shared handling for Let and LetStmt. A let whose value depends on the
    // loop variable is itself treated as loop-dependent while visiting its
    // body. Simplifications found in the body that mention the let name get
    // the let wrapped around their condition, so the condition stays
    // self-contained outside the let's scope.
    template<typename LetOrLetStmt>
    void visit_let(const LetOrLetStmt *op) {
        bool varying = expr_uses_vars(op->value, depends_on_loop_var);
        if (varying) {
            depends_on_loop_var.push(op->name, 0);
        }
        vector<Simplification> old;
        old.swap(simplifications);
        IRVisitor::visit(op);
        for (Simplification &s : simplifications) {
            if (expr_uses_var(s.condition, op->name)) {
                s.condition = Let::make(op->name, op->value, s.condition);
            }
        }
        simplifications.insert(simplifications.end(), old.begin(), old.end());
        if (varying) {
            depends_on_loop_var.pop(op->name);
        }
    }
    void visit(const LetStmt *op) {
        visit_let(op);
    }
    void visit(const Let *op) {
        visit_let(op);
    }
public:
    // Output: every simplification found, in discovery order.
    vector<Simplification> simplifications;
    // `v` is the name of the loop variable to partition on.
    FindSimplifications(const std::string &v) {
        depends_on_loop_var.push(v, 0);
    }
};
// Applies a set of Simplifications: any expression node that is identical
// (same_as) to some Simplification's old_expr is replaced by that entry's
// likely_value, and the replacement is itself mutated further. The first
// matching entry wins.
class MakeSimplifications : public IRMutator {
    using IRMutator::visit;

    // Borrowed; must outlive this mutator.
    const vector<Simplification> &simplifications;

public:
    MakeSimplifications(const vector<Simplification> &s) : simplifications(s) {}

    using IRMutator::mutate;

    Expr mutate(Expr e) {
        auto match = std::find_if(simplifications.begin(), simplifications.end(),
                                  [&e](const Simplification &candidate) {
                                      return e.same_as(candidate.old_expr);
                                  });
        if (match != simplifications.end()) {
            return mutate(match->likely_value);
        }
        return IRMutator::mutate(e);
    }
};
// Sets `result` if the visited IR calls halide_gpu_thread_barrier anywhere.
class ContainsThreadBarrier : public IRVisitor {
public:
    bool result = false;

protected:
    using IRVisitor::visit;

    void visit(const Call *op) {
        const bool is_barrier = (op->name == "halide_gpu_thread_barrier");
        result = result || is_barrier;
        IRVisitor::visit(op);
    }
};
// Returns true if `s` contains a call to halide_gpu_thread_barrier.
bool contains_thread_barrier(Stmt s) {
    ContainsThreadBarrier detector;
    s.accept(&detector);
    const bool found = detector.result;
    return found;
}
// Splits each loop that contains "likely" simplifications into a prologue, a
// simplified steady state, and an epilogue. Serial loops become up to three
// consecutive loops. Other loop types keep a single loop and choose the body
// per iteration with if statements.
class PartitionLoops : public IRMutator {
    using IRMutator::visit;

    // Set while visiting a loop that is (or is nested inside) a GPU
    // block/thread loop. Every exit path of visit(For) must restore it, or it
    // leaks to sibling and following loops.
    bool in_gpu_loop = false;

    void visit(const For *op) {
        Stmt body = op->body;

        bool old_in_gpu_loop = in_gpu_loop;
        in_gpu_loop |= CodeGen_GPU_Dev::is_gpu_var(op->name);

        // Inside a GPU kernel, a loop whose body contains a thread barrier is
        // not partitioned; we only recurse into it.
        if (in_gpu_loop && contains_thread_barrier(body)) {
            IRMutator::visit(op);
            in_gpu_loop = old_in_gpu_loop;
            return;
        }

        // GLSL loops are left untouched (not recursed into either).
        if (op->device_api == DeviceAPI::GLSL) {
            stmt = op;
            in_gpu_loop = old_in_gpu_loop;
            return;
        }

        // Collect the simplifications that depend on this loop's variable.
        FindSimplifications finder(op->name);
        body.accept(&finder);

        if (finder.simplifications.empty()) {
            // Nothing to partition at this level; try inner loops. Restore
            // the GPU flag so it does not leak past this loop.
            IRMutator::visit(op);
            in_gpu_loop = old_in_gpu_loop;
            return;
        }

        debug(3) << "\n\n**** Partitioning loop over " << op->name << "\n";

        vector<Expr> min_vals, max_vals;
        vector<Simplification> middle_simps, prologue_simps, epilogue_simps;
        bool lower_bound_is_tight = true, upper_bound_is_tight = true;
        for (auto &s : finder.simplifications) {
            // Solve for the range of the loop variable over which the
            // condition definitely holds.
            s.interval = solve_for_inner_interval(s.condition, op->name);
            if (s.tight) {
                // The condition stays tight only if the inner and outer
                // solutions agree on both bounds.
                Interval outer = solve_for_outer_interval(s.condition, op->name);
                s.tight &= equal(outer.min, s.interval.min) && equal(outer.max, s.interval.max);
            }

            debug(3) << "\nSimplification: \n"
                     << "  condition: " << s.condition << "\n"
                     << "  old: " << s.old_expr << "\n"
                     << "  new: " << s.likely_value << "\n"
                     << "  min: " << s.interval.min << "\n"
                     << "  max: " << s.interval.max << "\n"
                     << "  tight: " << s.tight << "\n";

            // Only simplifications with a non-empty interval are used.
            if (!s.interval.is_empty()) {
                if (s.interval.has_lower_bound()) {
                    Expr m = s.interval.min;
                    if (!s.tight) {
                        lower_bound_is_tight = false;
                    }
                    if (min_vals.empty()) {
                        min_vals.push_back(m);
                    } else if (equal(m, min_vals.back())) {
                        // Same bound as the previous one; nothing new to add.
                    } else {
                        // A different bound. The steady state starts at the
                        // max of all of them, so no single one is tight.
                        min_vals.push_back(m);
                        lower_bound_is_tight = false;
                    }
                }
                if (s.interval.has_upper_bound()) {
                    Expr m = s.interval.max;
                    if (!s.tight) {
                        upper_bound_is_tight = false;
                    }
                    if (max_vals.empty()) {
                        max_vals.push_back(m);
                    } else if (equal(m, max_vals.back())) {
                        // Same bound as the previous one; nothing new to add.
                    } else {
                        // A different bound. The steady state ends at the min
                        // of all of them, so no single one is tight.
                        max_vals.push_back(m);
                        upper_bound_is_tight = false;
                    }
                }

                // Every simplification with a non-empty interval applies in
                // the steady state.
                middle_simps.push_back(s);
            }
        }

        // Prologue simplifications are only used if the simplifier can show
        // min_val - 1 < max_val + 1 for every pair of lower/upper bounds.
        bool can_simplify_prologue = true;
        for (Expr min_val : min_vals) {
            for (Expr max_val : max_vals) {
                Expr test = simplify(common_subexpression_elimination(min_val - 1 < max_val + 1));
                if (!is_one(test)) {
                    can_simplify_prologue = false;
                }
            }
        }

        for (const auto &s : middle_simps) {
            // No lower bound: the condition also holds before the steady
            // state, so it applies in the prologue too.
            if (can_simplify_prologue &&
                !s.interval.has_lower_bound()) {
                prologue_simps.push_back(s);
            }

            // No upper bound: it also applies in the epilogue.
            if (!s.interval.has_upper_bound()) {
                epilogue_simps.push_back(s);
            }

            // With a tight bound, the condition is known to fail in the
            // prologue/epilogue, so the unlikely value can be used there.
            if (can_simplify_prologue &&
                s.interval.has_lower_bound() &&
                lower_bound_is_tight) {
                internal_assert(s.tight);
                Simplification s2 = s;
                // Negate the condition and use the unlikely value.
                s2.condition = !s2.condition;
                std::swap(s2.likely_value, s2.unlikely_value);
                prologue_simps.push_back(s2);
            }
            if (s.interval.has_upper_bound() &&
                upper_bound_is_tight) {
                internal_assert(s.tight);
                Simplification s2 = s;
                s2.condition = !s2.condition;
                std::swap(s2.likely_value, s2.unlikely_value);
                epilogue_simps.push_back(s2);
            }
        }

        // Build the three versions of the body.
        Stmt simpler_body = MakeSimplifications(middle_simps).mutate(body);
        Stmt prologue = MakeSimplifications(prologue_simps).mutate(body);
        Stmt epilogue = MakeSimplifications(epilogue_simps).mutate(body);

        bool make_prologue = !equal(prologue, simpler_body);
        bool make_epilogue = !equal(epilogue, simpler_body);

        // Recurse into the steady state only.
        simpler_body = mutate(simpler_body);

        // Compute the bounds of the steady state.
        Expr min_steady = op->min, max_steady = op->extent + op->min;
        Expr prologue_val, epilogue_val;
        string prologue_name = unique_name(op->name + ".prologue");
        string epilogue_name = unique_name(op->name + ".epilogue");

        if (make_prologue) {
            // The steady state starts at the max of all lower bounds and the
            // loop min. Sort the bounds into a canonical order before folding.
            std::sort(min_vals.begin(), min_vals.end(), IRDeepCompare());
            min_vals.push_back(op->min);
            prologue_val = fold_left(min_vals, Max::make);
            // Do not let the prologue run past the end of the loop.
            prologue_val = min(prologue_val, op->extent + op->min);
            min_steady = Variable::make(Int(32), prologue_name);

            internal_assert(!expr_uses_var(prologue_val, op->name));
        }
        if (make_epilogue) {
            std::sort(max_vals.begin(), max_vals.end(), IRDeepCompare());
            max_vals.push_back(op->min + op->extent - 1);
            epilogue_val = fold_left(max_vals, Min::make) + 1;
            // Do not let the epilogue start before the prologue ends (or
            // before the loop begins).
            if (make_prologue) {
                epilogue_val = max(epilogue_val, prologue_val);
            } else {
                epilogue_val = max(op->min, epilogue_val);
            }
            max_steady = Variable::make(Int(32), epilogue_name);

            internal_assert(!expr_uses_var(epilogue_val, op->name));
        }

        if (op->for_type == ForType::Serial) {
            stmt = For::make(op->name, min_steady, max_steady - min_steady,
                             op->for_type, op->device_api, simpler_body);

            if (make_prologue) {
                prologue = For::make(op->name, op->min, min_steady - op->min,
                                     op->for_type, op->device_api, prologue);
                stmt = Block::make(prologue, stmt);
            }
            if (make_epilogue) {
                epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
                                     op->for_type, op->device_api, epilogue);
                stmt = Block::make(stmt, epilogue);
            }
        } else {
            // Non-serial loops stay as a single loop. Each iteration picks
            // its body with if statements on the loop variable.
            Expr loop_var = Variable::make(Int(32), op->name);
            stmt = simpler_body;
            if (make_epilogue && make_prologue && equal(prologue, epilogue)) {
                stmt = IfThenElse::make(min_steady <= loop_var && loop_var < max_steady, stmt, prologue);
            } else {
                if (make_epilogue) {
                    stmt = IfThenElse::make(loop_var < max_steady, stmt, epilogue);
                }
                if (make_prologue) {
                    stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
                }
            }
            stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt);
        }

        // Bind the steady-state bounds outside the new construct. When a
        // section is not made, record its trivial bound for the emptiness
        // check below.
        if (make_epilogue) {
            stmt = LetStmt::make(epilogue_name, epilogue_val, stmt);
        } else {
            epilogue_val = op->min + op->extent;
        }
        if (make_prologue) {
            stmt = LetStmt::make(prologue_name, prologue_val, stmt);
        } else {
            prologue_val = op->min;
        }

        if (can_prove(epilogue_val <= prologue_val)) {
            // The steady state is provably empty, so partitioning here gains
            // nothing. Discard the result and try inner loops instead.
            // Restore the GPU flag so it does not leak past this loop.
            IRMutator::visit(op);
            in_gpu_loop = old_in_gpu_loop;
            return;
        }

        in_gpu_loop = old_in_gpu_loop;

        debug(3) << "Partition loop.\n"
                 << "Old: " << Stmt(op) << "\n"
                 << "New: " << stmt << "\n";
    }
};
// Sets `result` if the visited IR contains any Load node.
class ExprContainsLoad : public IRVisitor {
public:
    bool result = false;

private:
    using IRVisitor::visit;

    // A load was found; its index does not need to be searched.
    void visit(const Load *) {
        result = true;
    }
};
// Returns true if `e` reads from memory (contains a Load).
bool expr_contains_load(Expr e) {
    ExprContainsLoad detector;
    e.accept(&detector);
    const bool has_load = detector.result;
    return has_load;
}
class RenormalizeGPULoops : public IRMutator {
    bool in_gpu_loop = false, in_thread_loop = false;
    using IRMutator::visit;
    
    Scope<int> gpu_vars;
    vector<pair<string, Expr> > lifted_lets;
    void visit(const For *op) {
        if (op->device_api == DeviceAPI::GLSL) {
            
            stmt = op;
            return;
        }
        if (ends_with(op->name, "__thread_id_x")) {
            in_thread_loop = true;
            IRMutator::visit(op);
            in_thread_loop = false;
            return;
        }
        bool old_in_gpu_loop = in_gpu_loop;
        if (in_gpu_loop || CodeGen_GPU_Dev::is_gpu_var(op->name)) {
            gpu_vars.push(op->name, 0);
            in_gpu_loop = true;
        }
        IRMutator::visit(op);
        if (in_gpu_loop && !old_in_gpu_loop) {
            
            while (lifted_lets.size()) {
                stmt = LetStmt::make(lifted_lets.back().first,
                                     lifted_lets.back().second,
                                     stmt);
                lifted_lets.pop_back();
            }
        }
        in_gpu_loop = old_in_gpu_loop;
    }
    void visit(const LetStmt *op) {
        if (!in_gpu_loop) {
            IRMutator::visit(op);
            return;
        }
        if (!expr_uses_vars(op->value, gpu_vars) && !expr_contains_load(op->value)) {
            
            
            
            
            string new_name = unique_name('t');
            Expr new_var = Variable::make(op->value.type(), new_name);
            lifted_lets.push_back({ new_name, op->value });
            stmt = mutate(substitute(op->name, new_var, op->body));
            return;
        }
        gpu_vars.push(op->name, 0);
        if (in_thread_loop) {
            IRMutator::visit(op);
            return;
        }
        Stmt body = mutate(op->body);
        const For *f = body.as<For>();
        const Allocate *a = body.as<Allocate>();
        
        if (f && in_gpu_loop && !in_thread_loop) {
            internal_assert(!expr_uses_var(f->min, op->name) &&
                            !expr_uses_var(f->extent, op->name));
            Stmt inner = LetStmt::make(op->name, op->value, f->body);
            inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner);
            stmt = mutate(inner);
        } else if (a && in_gpu_loop && !in_thread_loop) {
            internal_assert(a->name == "__shared" && a->extents.size() == 1);
            if (expr_uses_var(a->extents[0], op->name)) {
                
                
                
                
                
                IRMutator::visit(op);
            } else {
                Stmt inner = LetStmt::make(op->name, op->value, a->body);
                inner = Allocate::make(a->name, a->type, a->extents, a->condition, inner);
                stmt = mutate(inner);
            }
        } else {
            IRMutator::visit(op);
        }
    }
    void visit(const IfThenElse *op) {
        if (!in_gpu_loop || in_thread_loop) {
            IRMutator::visit(op);
            return;
        }
        internal_assert(op->else_case.defined())
            << "PartitionLoops should only introduce if statements with an else branch\n";
        Stmt then_case = mutate(op->then_case);
        Stmt else_case = mutate(op->else_case);
        if (equal(then_case, else_case)) {
            
            
            stmt = then_case;
            return;
        }
        const Allocate *allocate_a = then_case.as<Allocate>();
        const Allocate *allocate_b = else_case.as<Allocate>();
        const For *for_a = then_case.as<For>();
        const For *for_b = else_case.as<For>();
        const LetStmt *let_a = then_case.as<LetStmt>();
        const LetStmt *let_b = else_case.as<LetStmt>();
        if (allocate_a && allocate_b &&
            allocate_a->name == "__shared" &&
            allocate_b->name == "__shared") {
            Stmt inner = IfThenElse::make(op->condition, allocate_a->body, allocate_b->body);
            inner = Allocate::make(allocate_a->name, allocate_a->type, allocate_a->extents, allocate_a->condition, inner);
            stmt = mutate(inner);
        } else if (let_a && let_b && let_a->name == let_b->name) {
            string condition_name = unique_name('t');
            Expr condition = Variable::make(op->condition.type(), condition_name);
            Stmt inner = IfThenElse::make(condition, let_a->body, let_b->body);
            inner = LetStmt::make(let_a->name, select(condition, let_a->value, let_b->value), inner);
            inner = LetStmt::make(condition_name, op->condition, inner);
            stmt = mutate(inner);
        } else if (let_a) {
            string new_name = unique_name(let_a->name);
            Stmt inner = let_a->body;
            inner = substitute(let_a->name, Variable::make(let_a->value.type(), new_name), inner);
            inner = IfThenElse::make(op->condition, inner, else_case);
            inner = LetStmt::make(new_name, let_a->value, inner);
            stmt = mutate(inner);
        } else if (let_b) {
            string new_name = unique_name(let_b->name);
            Stmt inner = let_b->body;
            inner = substitute(let_b->name, Variable::make(let_b->value.type(), new_name), inner);
            inner = IfThenElse::make(op->condition, then_case, inner);
            inner = LetStmt::make(new_name, let_b->value, inner);
            stmt = mutate(inner);
        } else if (for_a && for_b &&
                   for_a->name == for_b->name &&
                   for_a->min.same_as(for_b->min) &&
                   for_a->extent.same_as(for_b->extent)) {
            Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body);
            inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner);
            stmt = mutate(inner);
        } else {
            internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n";
        }
    }
};
class ExpandSelects : public IRMutator {
    using IRMutator::visit;
    bool is_trivial(Expr e) {
        return e.as<Variable>() || is_const(e);
    }
    void visit(const Select *op) {
        Expr condition   = mutate(op->condition);
        Expr true_value  = mutate(op->true_value);
        Expr false_value = mutate(op->false_value);
        if (const Or *o = condition.as<Or>()) {
            if (is_trivial(true_value)) {
                expr = mutate(Select::make(o->a, true_value, Select::make(o->b, true_value, false_value)));
            } else {
                string var_name = unique_name('t');
                Expr var = Variable::make(true_value.type(), var_name);
                expr = mutate(Select::make(o->a, var, Select::make(o->b, var, false_value)));
                expr = Let::make(var_name, true_value, expr);
            }
        } else if (const And *a = condition.as<And>()) {
            if (is_trivial(false_value)) {
                expr = mutate(Select::make(a->a, Select::make(a->b, true_value, false_value), false_value));
            } else {
                string var_name = unique_name('t');
                Expr var = Variable::make(false_value.type(), var_name);
                expr = mutate(Select::make(a->a, Select::make(a->b, true_value, var), var));
                expr = Let::make(var_name, false_value, expr);
            }
        } else if (const Not *n = condition.as<Not>()) {
            expr = mutate(Select::make(n->a, false_value, true_value));
        } else if (condition.same_as(op->condition) &&
                   true_value.same_as(op->true_value) &&
                   false_value.same_as(op->false_value)) {
            expr = op;
        } else {
            expr = Select::make(condition, true_value, false_value);
        }
    }
};
class CollapseSelects : public IRMutator {
    using IRMutator::visit;
    void visit(const Select *op) {
        const Select *t = op->true_value.as<Select>();
        const Select *f = op->false_value.as<Select>();
        if (t && equal(t->false_value, op->false_value)) {
            
            expr = mutate(select(op->condition && t->condition, t->true_value, op->false_value));
        } else if (f && equal(op->true_value, f->true_value)) {
            
            expr = mutate(select(op->condition || f->condition, op->true_value, f->false_value));
        } else {
            IRMutator::visit(op);
        }
    }
};
// Sets `result` if the visited statement contains any For loop.
class ContainsLoop : public IRVisitor {
public:
    bool result = false;

private:
    using IRVisitor::visit;

    // A loop was found; no need to look inside it.
    void visit(const For *) {
        result = true;
    }
};
class LowerLikelyIfInnermost : public IRMutator {
    using IRMutator::visit;
    bool inside_innermost_loop = false;
    void visit(const Call *op) {
        if (op->is_intrinsic(Call::likely_if_innermost)) {
            internal_assert(op->args.size() == 1);
            if (inside_innermost_loop) {
                expr = Call::make(op->type, Call::likely, {mutate(op->args[0])}, Call::PureIntrinsic);
            } else {
                expr = mutate(op->args[0]);
            }
        } else {
            IRMutator::visit(op);
        }
    }
    void visit(const For *op) {
        ContainsLoop c;
        op->body.accept(&c);
        inside_innermost_loop = !c.result;
        IRMutator::visit(op);
        inside_innermost_loop = false;
    }
};
}
// Entry point of the pass. Resolves likely_if_innermost, tags clamped ramps
// in load/store indices as likely, expands compound select conditions,
// partitions the loops, restores the GPU loop structure, then removes the
// likely tags and re-collapses the selects.
Stmt partition_loops(Stmt s) {
    Stmt result = LowerLikelyIfInnermost().mutate(s);
    result = MarkClampedRampsAsLikely().mutate(result);
    result = ExpandSelects().mutate(result);
    result = PartitionLoops().mutate(result);
    result = RenormalizeGPULoops().mutate(result);
    result = RemoveLikelyTags().mutate(result);
    result = CollapseSelects().mutate(result);
    return result;
}
}
}