root/src/VaryingAttributes.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes following definitions.
  1. make_block
  2. tag_linear_expression
  3. visit
  4. visit
  5. visit
  6. visit
  7. visit
  8. visit
  9. visit
  10. visit
  11. visit
  12. visit_binary_linear
  13. visit
  14. visit
  15. visit
  16. visit
  17. visit_binary
  18. visit
  19. visit
  20. visit
  21. visit
  22. visit
  23. visit
  24. visit
  25. visit
  26. visit
  27. visit
  28. visit
  29. visit
  30. visit
  31. visit
  32. max_expressions
  33. find_linear_expressions
  34. visit
  35. visit
  36. remove_varying_attributes
  37. visit
  38. replace_varying_attributes
  39. visit
  40. prune_varying_attributes
  41. visit
  42. visit
  43. float_type
  44. visit_binary_op
  45. visit
  46. visit
  47. visit
  48. visit
  49. visit
  50. visit
  51. visit
  52. visit
  53. visit
  54. visit
  55. visit
  56. visit
  57. visit
  58. visit
  59. visit
  60. visit
  61. visit
  62. visit
  63. visit
  64. mutate
  65. mutate
  66. mutate_operator
  67. mutate_operator
  68. mutate_operator
  69. visit
  70. visit
  71. visit
  72. visit
  73. visit
  74. visit
  75. visit
  76. visit
  77. visit
  78. visit
  79. visit
  80. visit
  81. visit
  82. visit
  83. visit
  84. visit
  85. visit
  86. visit
  87. visit
  88. visit
  89. visit
  90. visit
  91. visit
  92. visit
  93. visit
  94. visit
  95. visit
  96. visit
  97. visit
  98. visit
  99. visit
  100. visit
  101. visit
  102. visit
  103. visit
  104. visit
  105. visit
  106. visit
  107. visit
  108. visit
  109. visit
  110. visit
  111. visit
  112. dont_simplify
  113. used_in_codegen
  114. visit
  115. setup_gpu_vertex_buffer

#include "VaryingAttributes.h"

#include <algorithm>

#include "CodeGen_GPU_Dev.h"

#include "IRMutator.h"
#include "CSE.h"
#include "Simplify.h"

namespace Halide {
namespace Internal {

// Sequence two statements, tolerating undefined operands: an undefined side
// is dropped, and if both sides are undefined the (undefined) second
// statement is returned.
Stmt make_block(Stmt first, Stmt rest) {
    if (!first.defined()) {
        return rest;
    }
    if (!rest.defined()) {
        return first;
    }
    return Block::make(first, rest);
}

// Find expressions that we can evaluate with interpolation hardware in the GPU
//
// This visitor keeps track of the "order" of the expression in terms of the
// specified variables. The order value 0 means that the expression is contant;
// order value 1 means that it is linear in terms of only one variable, check
// the member found to determine which; order value 2 means non-linear, it
// could be disqualified due to being quadratic, bilinear or the result of an
// unknown function.
class FindLinearExpressions : public IRMutator {
protected:
    using IRMutator::visit;

    bool in_glsl_loops;

    Expr tag_linear_expression(Expr e, const std::string &name = unique_name('a')) {

        internal_assert(name.length() > 0);

        if (total_found >= max_expressions) {
            return e;
        }

        // Wrap the expression with an intrinsic to tag that it is a varying
        // attribute. These tagged variables will be pulled out of the fragment
        // shader during a subsequent pass
        Expr intrinsic = Call::make(e.type(), Call::glsl_varying,
                                    {name + ".varying", e},
                                    Call::Intrinsic);
        ++total_found;

        return intrinsic;
    }

    virtual void visit(const Call *op) {

        std::vector<Expr> new_args = op->args;

        // Check to see if this call is a load
        if (op->name == Call::glsl_texture_load) {
            // Check if the texture coordinate arguments are linear wrt the GPU
            // loop variables
            internal_assert(loop_vars.size() > 0) << "No GPU loop variables found at texture load\n";

            // Iterate over the texture coordinate arguments
            for (int i = 2; i != 4; ++i) {
                new_args[i] = mutate(op->args[i]);
                if (order == 1) {
                    new_args[i] = tag_linear_expression(new_args[i]);
                }
            }
        } else if (op->name == Call::glsl_texture_store) {
            // Check if the value expression is linear wrt the loop variables
            internal_assert(loop_vars.size() > 0) << "No GPU loop variables found at texture store\n";

            // The value is the 5th argument to the intrinsic
            new_args[5] = mutate(new_args[5]);
            if (order == 1) {
                new_args[5] = tag_linear_expression(new_args[5]);
            }
        }

        // The texture lookup itself is counted as a non-linear operation
        order = 2;
        expr = Call::make(op->type, op->name, new_args, op->call_type,
                          op->func, op->value_index, op->image, op->param);
    }

    virtual void visit(const Let *op) {

        Expr mutated_value = mutate(op->value);
        int value_order = order;

        scope.push(op->name, order);

        Expr mutated_body = mutate(op->body);

        if ((value_order == 1) && (total_found < max_expressions)) {
            // Wrap the let value with a varying tag
            mutated_value = Call::make(mutated_value.type(), Call::glsl_varying,
                                       {op->name + ".varying", mutated_value},
                                       Call::Intrinsic);
            ++total_found;
        }

        expr = Let::make(op->name, mutated_value, mutated_body);

        scope.pop(op->name);
    }

    virtual void visit(const For *op) {
        bool old_in_glsl_loops = in_glsl_loops;
        bool kernel_loop = op->device_api == DeviceAPI::GLSL;
        bool within_kernel_loop = !kernel_loop && in_glsl_loops;
        // Check if the loop variable is a GPU variable thread variable and for GLSL
        if (kernel_loop) {
            loop_vars.push_back(op->name);
            in_glsl_loops = true;
        } else if (within_kernel_loop) {
            // The inner loop variable is non-linear w.r.t the glsl pixel coordinate.
            scope.push(op->name, 2);
        }

        Stmt mutated_body = mutate(op->body);

        if (kernel_loop) {
            loop_vars.pop_back();
        } else if (within_kernel_loop) {
            scope.pop(op->name);
        }

        in_glsl_loops = old_in_glsl_loops;

        if (mutated_body.same_as(op->body)) {
            stmt = op;
        } else {
            stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, mutated_body);
        }
    }

    virtual void visit(const Variable *op) {
        if (std::find(loop_vars.begin(), loop_vars.end(), op->name) != loop_vars.end()) {
            order = 1;
        } else if (scope.contains(op->name)) {
            order = scope.get(op->name);
        } else {
            // If the variable is not found in scope, then we assume it is
            // constant in terms of the independent variables.
            order = 0;
        }
        expr = op;
    }

    virtual void visit(const IntImm *op)    { order = 0; expr = op; }
    virtual void visit(const UIntImm *op)   { order = 0; expr = op; }
    virtual void visit(const FloatImm *op)  { order = 0; expr = op; }
    virtual void visit(const StringImm *op) { order = 0; expr = op; }

    virtual void visit(const Cast *op) {

        Expr mutated_value = mutate(op->value);
        int value_order = order;

        // We can only interpolate float values, disqualify the expression if
        // this is a cast to a different type
        if (order && (!op->type.is_float())) {
            order = 2;
        }

        if ((order > 1) && (value_order == 1)) {
            mutated_value = tag_linear_expression(mutated_value);
        }

        expr = Cast::make(op->type, mutated_value);
    }

    // Add and subtract do not make the expression non-linear, if it is already
    // linear or constant
    template<typename T>
    void visit_binary_linear(T *op) {
        Expr a = mutate(op->a);
        unsigned int order_a = order;
        Expr b = mutate(op->b);
        unsigned int order_b = order;

        order = std::max(order_a, order_b);

        // If the whole expression is greater than linear, check to see if
        // either argument is linear and if so, add it to a candidate list
        if ((order > 1) && (order_a == 1)) {
            a = tag_linear_expression(a);
        }
        if ((order > 1) && (order_b == 1)) {
            b = tag_linear_expression(b);
        }

        expr = T::make(a, b);
    }

    virtual void visit(const Add *op) { visit_binary_linear(op); }
    virtual void visit(const Sub *op) { visit_binary_linear(op); }

    // Multiplying increases the order of the expression, possibly making it
    // non-linear
    virtual void visit(const Mul *op) {
        Expr a = mutate(op->a);
        unsigned int order_a = order;
        Expr b = mutate(op->b);
        unsigned int order_b = order;

        order = order_a + order_b;

        // If the whole expression is greater than linear, check to see if
        // either argument is linear and if so, add it to a candidate list
        if ((order > 1) && (order_a == 1)) {
            a = tag_linear_expression(a);
        }
        if ((order > 1) && (order_b == 1)) {
            b = tag_linear_expression(b);
        }

        expr = Mul::make(a, b);
    }

    // Dividing is either multiplying by a constant, or makes the result
    // non-linear (i.e. order -1)
    virtual void visit(const Div *op) {
        Expr a = mutate(op->a);
        unsigned int order_a = order;
        Expr b = mutate(op->b);
        unsigned int order_b = order;

        if (order_a && !order_b) {
            // Case: x / c
            order = order_a;
        } else if (!order_a && order_b) {
            // Case: c / x
            order = 2;
        } else {
            order = order_a + order_b;
        }

        if ((order > 1) && (order_a == 1)) {
            a = tag_linear_expression(a);
        }
        if ((order > 1) && (order_b == 1)) {
            b = tag_linear_expression(b);
        }

        expr = Div::make(a, b);
    }

    // For other binary operators, if either argument is non-constant, then the
    // whole expression is non-linear
    template<typename T>
    void visit_binary(T *op) {

        Expr a = mutate(op->a);
        unsigned int order_a = order;
        Expr b = mutate(op->b);
        unsigned int order_b = order;

        if (order_a || order_b) {
            order = 2;
        }

        if ((order > 1) && (order_a == 1)) {
            a = tag_linear_expression(a);
        }
        if ((order > 1) && (order_b == 1)) {
            b = tag_linear_expression(b);
        }

        expr = T::make(a, b);
    }

    virtual void visit(const Mod *op) { visit_binary(op); }

    // Break the expression into a piecewise function, if the expressions are
    // linear, we treat the piecewise behavior specially during codegen

    // Once this is done, Min and Max should call visit_binary_linear and the code
    // in setup_mesh will handle piecewise linear behavior introduced by these
    // expressions
    virtual void visit(const Min *op) { visit_binary(op); }
    virtual void visit(const Max *op) { visit_binary(op); }

    virtual void visit(const EQ *op) { visit_binary(op); }
    virtual void visit(const NE *op) { visit_binary(op); }
    virtual void visit(const LT *op) { visit_binary(op); }
    virtual void visit(const LE *op) { visit_binary(op); }
    virtual void visit(const GT *op) { visit_binary(op); }
    virtual void visit(const GE *op) { visit_binary(op); }
    virtual void visit(const And *op) { visit_binary(op); }
    virtual void visit(const Or *op) { visit_binary(op); }

    virtual void visit(const Not *op) {
        Expr a = mutate(op->a);
        unsigned int order_a = order;

        if (order_a) {
            order = 2;
        }

        expr = Not::make(a);
    }

    virtual void visit(const Broadcast *op) {
        Expr a = mutate(op->value);

        if (order == 1) {
            a = tag_linear_expression(a);
        }

        if (order) {
            order = 2;
        }

        expr = Broadcast::make(a, op->lanes);
    }

    virtual void visit(const Select *op) {

        // If either the true expression or the false expression is non-linear
        // in terms of the loop variables, then the select expression might
        // evaluate to a non-linear expression and is disqualified.

        // If both are either linear or constant, and the condition expression
        // is constant with respect to the loop variables, then either the true
        // or false expression will be evaluated across the whole loop domain,
        // and the select expression is linear. Otherwise, the expression is
        // disqualified.

        // The condition expression must be constant (order == 0) with respect
        // to the loop variables.
        Expr mutated_condition = mutate(op->condition);
        int condition_order = (order != 0) ? 2 : 0;

        Expr mutated_true_value = mutate(op->true_value);
        int true_value_order = order;

        Expr mutated_false_value = mutate(op->false_value);
        int false_value_order = order;

        order = std::max(std::max(condition_order, true_value_order), false_value_order);

        if ((order > 1) && (condition_order == 1)) {
            mutated_condition = tag_linear_expression(mutated_condition);
        }
        if ((order > 1) && (true_value_order == 1)) {
            mutated_true_value = tag_linear_expression(mutated_true_value);
        }
        if ((order > 1) && (false_value_order == 1)) {
            mutated_false_value = tag_linear_expression(mutated_false_value);
        }

        expr = Select::make(mutated_condition, mutated_true_value, mutated_false_value);
    }

public:
    std::vector<std::string> loop_vars;

    Scope<int> scope;

    unsigned int order;
    bool found;

    unsigned int total_found;

    // This parameter controls the maximum number of linearly varying
    // expressions halide will pull out of the fragment shader and evaluate per
    // vertex, and allow the GPU to linearly interpolate across the domain. For
    // OpenGL ES 2.0 we can pass 16 vec4 varying attributes, or 64 scalars. Two
    // scalar slots are used by boilerplate code to pass pixel coordinates.
    const unsigned int max_expressions;

    FindLinearExpressions() : in_glsl_loops(false), total_found(0), max_expressions(62) {}
};

// Run the FindLinearExpressions pass over s and return the rewritten
// statement, in which linear sub-expressions are wrapped in glsl_varying tags.
Stmt find_linear_expressions(Stmt s) {
    FindLinearExpressions tagger;
    return tagger.mutate(s);
}

// This visitor produces a map containing name and expression pairs from varying
// tagged intrinsics
class FindVaryingAttributeTags : public IRVisitor
{
public:
    FindVaryingAttributeTags(std::map<std::string, Expr>& varyings_) : varyings(varyings_) { }

    using IRVisitor::visit;

    virtual void visit(const Call *op) {
        if (op->name == Call::glsl_varying) {
            std::string name = op->args[0].as<StringImm>()->value;
            varyings[name] = op->args[1];
        }
        IRVisitor::visit(op);
    }

    std::map<std::string, Expr>& varyings;
};

// This visitor removes glsl_varying intrinsics.
class RemoveVaryingAttributeTags : public IRMutator {
public:
    using IRMutator::visit;

    virtual void visit(const Call *op) {
        if (op->name == Call::glsl_varying) {
            // Replace the call expression with its wrapped argument expression
            expr = op->args[1];
        } else {
            IRMutator::visit(op);
        }
    }
};

// Return s with every glsl_varying tag handled by RemoveVaryingAttributeTags.
Stmt remove_varying_attributes(Stmt s)
{
    RemoveVaryingAttributeTags stripper;
    return stripper.mutate(s);
}

// Swaps each glsl_varying intrinsic for a Variable of the same type whose name
// is the tag's name (which must end in ".varying"). Afterwards the tagged
// expressions themselves no longer appear in the tree; only the ".varying"
// variables remain.
class ReplaceVaryingAttributeTags : public IRMutator {
public:
    using IRMutator::visit;

    virtual void visit(const Call *op) {
        if (op->name != Call::glsl_varying) {
            // Any other call is handled by the default mutator.
            IRMutator::visit(op);
            return;
        }

        // The first argument of the tag carries the variable name.
        const std::string &tag = op->args[0].as<StringImm>()->value;

        internal_assert(ends_with(tag, ".varying"));

        expr = Variable::make(op->type, tag);
    }
};

// Return s with every glsl_varying tag replaced by its ".varying" variable.
Stmt replace_varying_attributes(Stmt s)
{
    ReplaceVaryingAttributeTags replacer;
    return replacer.mutate(s);
}


// Gathers the names of all visited variables whose name ends in ".varying".
class FindVaryingAttributeVars : public IRVisitor {
public:
    using IRVisitor::visit;

    virtual void visit(const Variable *op) {
        const std::string &var_name = op->name;
        if (!ends_with(var_name, ".varying")) {
            return;
        }
        variables.insert(var_name);
    }

    // Names collected so far.
    std::set<std::string> variables;
};

// Remove varying attributes from the varying's map if they do not appear in the
// loop_stmt because they were simplified away.
//
// loop_stmt: statement to scan for variables whose name ends in ".varying".
// varying:   map of attribute name to expression; entries whose name is not
//            referenced by loop_stmt are erased in place.
void prune_varying_attributes(Stmt loop_stmt, std::map<std::string, Expr>& varying)
{
    FindVaryingAttributeVars find;
    loop_stmt.accept(&find);

    // Erase while iterating: std::map::erase returns the iterator following
    // the removed element, so no separate removal list is needed. Iterating
    // through the map's own iterator also avoids the per-entry copy that
    // binding a `const std::pair<std::string, Expr>&` to a
    // `std::pair<const std::string, Expr>` element would make.
    for (auto it = varying.begin(); it != varying.end(); ) {
        if (find.variables.count(it->first) == 0) {
            debug(2) << "Removed varying attribute " << it->first << "\n";
            it = varying.erase(it);
        } else {
            ++it;
        }
    }
}

// This visitor changes the type of variables tagged with .varying to float,
// since GLSL will only interpolate floats. In the case that the type of the
// varying attribute was integer, the interpolated float value is snapped to the
// integer grid and cast to the integer type. This case occurs with coordinate
// expressions where the integer loop variables are manipulated without being
// converted to floating point. In other cases, like an affine transformation of
// image coordinates, the loop variables are cast to floating point within the
// interpolated expression.
class CastVaryingVariables : public IRMutator {
protected:
    using IRMutator::visit;

    virtual void visit(const Variable *op) {
        if ((ends_with(op->name, ".varying")) && (op->type != Float(32))) {
            // The incoming variable will be float type because GLSL only
            // interpolates floats
            Expr v = Variable::make(Float(32), op->name);

            // If the varying attribute expression that this variable replaced
            // was integer type, snap the interpolated floating point variable
            // back to the integer grid.
            expr = Cast::make(op->type, floor(v + 0.5f));
        } else {
            // Otherwise, the variable keeps its float type.
            expr = op;
        }
    }
};

// Rewrites every use of the named variables as (variable - 0.5f), which makes
// the use a float expression, and then propagates the float type outwards:
// binary operators, selects and ramps with one float operand get the other
// operand cast to float, and variables bound by a let whose value changed
// float-ness are retyped to match the new value.
class CastVariablesToFloatAndOffset : public IRMutator {
protected:
    using IRMutator::visit;

    virtual void visit(const Variable *op) {

        // Named (loop) variables are offset by half a unit; the subtraction of
        // a float constant is what turns the use into a float expression.
        const bool is_named = std::find(names.begin(), names.end(), op->name) != names.end();
        if (is_named) {
            expr = Expr(op) - 0.5f;
            return;
        }

        // A variable bound by a let whose value was retyped takes on the type
        // of the modified value.
        if (scope.contains(op->name)) {
            const Type bound_type = scope.get(op->name).type();
            if (op->type != bound_type) {
                expr = Variable::make(bound_type, op->name);
                return;
            }
        }

        expr = op;
    }

    // Float type with the same bit width and lane count as e.
    Type float_type(Expr e) {
        const Type t = e.type();
        return Float(t.bits(), t.lanes());
    }

    // Mutate both operands; if exactly one of them ended up float, cast the
    // other one to float as well.
    template<typename T>
    void visit_binary_op(const T *op) {
        Expr lhs = mutate(op->a);
        Expr rhs = mutate(op->b);

        const bool lhs_float = lhs.type().is_float();
        const bool rhs_float = rhs.type().is_float();

        if (lhs_float && !rhs_float) {
            rhs = Cast::make(float_type(op->a), rhs);
        } else if (!lhs_float && rhs_float) {
            lhs = Cast::make(float_type(op->b), lhs);
        }

        expr = T::make(lhs, rhs);
    }

    virtual void visit(const Add *op) { visit_binary_op(op); }
    virtual void visit(const Sub *op) { visit_binary_op(op); }
    virtual void visit(const Mul *op) { visit_binary_op(op); }
    virtual void visit(const Div *op) { visit_binary_op(op); }
    virtual void visit(const Mod *op) { visit_binary_op(op); }
    virtual void visit(const Min *op) { visit_binary_op(op); }
    virtual void visit(const Max *op) { visit_binary_op(op); }
    virtual void visit(const EQ *op)  { visit_binary_op(op); }
    virtual void visit(const NE *op)  { visit_binary_op(op); }
    virtual void visit(const LT *op)  { visit_binary_op(op); }
    virtual void visit(const LE *op)  { visit_binary_op(op); }
    virtual void visit(const GT *op)  { visit_binary_op(op); }
    virtual void visit(const GE *op)  { visit_binary_op(op); }
    virtual void visit(const And *op) { visit_binary_op(op); }
    virtual void visit(const Or *op)  { visit_binary_op(op); }

    virtual void visit(const Select *op)  {
        Expr cond = mutate(op->condition);
        Expr when_true = mutate(op->true_value);
        Expr when_false = mutate(op->false_value);

        const bool true_is_float = when_true.type().is_float();
        const bool false_is_float = when_false.type().is_float();

        // Both branches must agree: if exactly one is float, cast the other.
        if (false_is_float && !true_is_float) {
            when_true = Cast::make(float_type(op->true_value), when_true);
        } else if (true_is_float && !false_is_float) {
            when_false = Cast::make(float_type(op->false_value), when_false);
        }

        expr = Select::make(cond, when_true, when_false);
    }

    virtual void visit(const Ramp *op) {
        Expr base = mutate(op->base);
        Expr stride = mutate(op->stride);

        // Base and stride must agree: if exactly one is float, cast the other.
        const bool base_is_float = base.type().is_float();
        const bool stride_is_float = stride.type().is_float();
        if (stride_is_float && !base_is_float) {
            base = Cast::make(float_type(op->base), base);
        } else if (base_is_float && !stride_is_float) {
            stride = Cast::make(float_type(op->stride), stride);
        }

        // Reuse the original node when nothing changed.
        const bool untouched = base.same_as(op->base) && stride.same_as(op->stride);
        if (!untouched) {
            expr = Ramp::make(base, stride, op->lanes);
        } else {
            expr = op;
        }
    }

    virtual void visit(const Let *op) {
        Expr new_value = mutate(op->value);

        // Only lets whose value switched between float and non-float are
        // recorded, so that uses of the name in the body get retyped.
        const bool floatness_changed =
            op->value.type().is_float() != new_value.type().is_float();
        if (floatness_changed) {
            scope.push(op->name, new_value);
        }

        Expr new_body = mutate(op->body);

        if (floatness_changed) {
            scope.pop(op->name);
        }

        expr = Let::make(op->name, new_value, new_body);
    }

    virtual void visit(const LetStmt *op) {
        Expr new_value = mutate(op->value);

        // Same bookkeeping as for Let, applied to a statement body.
        const bool floatness_changed =
            op->value.type().is_float() != new_value.type().is_float();
        if (floatness_changed) {
            scope.push(op->name, new_value);
        }

        Stmt new_body = mutate(op->body);

        if (floatness_changed) {
            scope.pop(op->name);
        }

        stmt = LetStmt::make(op->name, new_value, new_body);
    }
public:
    CastVariablesToFloatAndOffset(const std::vector<std::string>& names_) : names(names_) { }

    // Names of the variables to offset. Held by reference, so the vector must
    // outlive this mutator.
    const std::vector<std::string>& names;

    // Modified let values, keyed by let name, for lets whose float-ness changed.
    Scope<Expr> scope;
};

// This is the base class for a special mutator that, by default, turns an IR
// tree into a tree of Stmts. Derived classes overload visit methods to filter
// out specific expressions which are placed in Evaluate nodes within the new
// tree.  This functionality is used by GLSL varying attributes to transform
// tagged linear expressions into Store nodes for the vertex buffer. The
// IRFilter allows these expressions to be filtered out while maintaining the
// existing structure of Let variable scopes around them.
class IRFilter : public IRVisitor {
public:
    virtual Stmt mutate(Expr expr);
    virtual Stmt mutate(Stmt stmt);

protected:
    using IRVisitor::visit;

    Stmt stmt;

    virtual void visit(const IntImm *);
    virtual void visit(const FloatImm *);
    virtual void visit(const StringImm *);
    virtual void visit(const Cast *);
    virtual void visit(const Variable *);
    virtual void visit(const Add *);
    virtual void visit(const Sub *);
    virtual void visit(const Mul *);
    virtual void visit(const Div *);
    virtual void visit(const Mod *);
    virtual void visit(const Min *);
    virtual void visit(const Max *);
    virtual void visit(const EQ *);
    virtual void visit(const NE *);
    virtual void visit(const LT *);
    virtual void visit(const LE *);
    virtual void visit(const GT *);
    virtual void visit(const GE *);
    virtual void visit(const And *);
    virtual void visit(const Or *);
    virtual void visit(const Not *);
    virtual void visit(const Select *);
    virtual void visit(const Load *);
    virtual void visit(const Ramp *);
    virtual void visit(const Broadcast *);
    virtual void visit(const Call *);
    virtual void visit(const Let *);
    virtual void visit(const LetStmt *);
    virtual void visit(const AssertStmt *);
    virtual void visit(const ProducerConsumer *);
    virtual void visit(const For *);
    virtual void visit(const Store *);
    virtual void visit(const Provide *);
    virtual void visit(const Allocate *);
    virtual void visit(const Free *);
    virtual void visit(const Realize *);
    virtual void visit(const Block *);
    virtual void visit(const IfThenElse *);
    virtual void visit(const Evaluate *);
};

// Filter an expression: dispatch to the matching visit method and return the
// statement it left in stmt. An undefined expression yields an undefined Stmt.
Stmt IRFilter::mutate(Expr e) {
    if (!e.defined()) {
        stmt = Stmt();
        return stmt;
    }
    e.accept(this);
    return stmt;
}

// Filter a statement: dispatch to the matching visit method and return the
// statement it left in stmt. An undefined statement yields an undefined Stmt.
Stmt IRFilter::mutate(Stmt s) {
    if (!s.defined()) {
        stmt = Stmt();
        return stmt;
    }
    s.accept(this);
    return stmt;
}

namespace {
    // Helpers shared by the IRFilter visit methods: filter one, two or three
    // children (in argument order) and store the combined result in *stmt.
    // Each child result is captured in a local before the next child is
    // filtered. The op parameter is unused.

    // One child: the result is the filtered child itself.
    template<typename T, typename A>
    void mutate_operator(IRFilter *mutator, const T *op, const A op_a, Stmt *stmt) {
        Stmt only = mutator->mutate(op_a);
        *stmt = only;
    }

    // Two children: the results are sequenced first-then-second.
    template<typename T, typename A, typename B>
    void mutate_operator(IRFilter *mutator, const T *op, const A op_a, const B op_b, Stmt *stmt) {
        Stmt first = mutator->mutate(op_a);
        Stmt second = mutator->mutate(op_b);
        Stmt sequence = make_block(first, second);
        *stmt = sequence;
    }

    // Three children: the results are sequenced in argument order.
    template<typename T, typename A, typename B, typename C>
    void mutate_operator(IRFilter *mutator, const T *op, const A op_a, const B op_b, const C op_c, Stmt *stmt) {
        Stmt first = mutator->mutate(op_a);
        Stmt second = mutator->mutate(op_b);
        Stmt third = mutator->mutate(op_c);
        Stmt leading = make_block(first, second);
        *stmt = make_block(leading, third);
    }
}

// Leaf nodes have no children to filter, so they produce an undefined Stmt.
void IRFilter::visit(const IntImm *) {
    stmt = Stmt();
}

void IRFilter::visit(const FloatImm *) {
    stmt = Stmt();
}

void IRFilter::visit(const StringImm *) {
    stmt = Stmt();
}

void IRFilter::visit(const Variable *) {
    stmt = Stmt();
}

void IRFilter::visit(const Cast *op) {
    mutate_operator(this, op, op->value, &stmt);
}

void IRFilter::visit(const Add *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Sub *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Mul *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Div *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Mod *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Min *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Max *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const EQ *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const NE *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const LT *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const LE *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const GT *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const GE *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const And *op)     {mutate_operator(this, op, op->a, op->b, &stmt);}
void IRFilter::visit(const Or *op)      {mutate_operator(this, op, op->a, op->b, &stmt);}

void IRFilter::visit(const Not *op) {
    mutate_operator(this, op, op->a, &stmt);
}

void IRFilter::visit(const Select *op)  {
    mutate_operator(this, op, op->condition, op->true_value, op->false_value, &stmt);
}

void IRFilter::visit(const Load *op) {
    mutate_operator(this, op, op->predicate, op->index, &stmt);
}

void IRFilter::visit(const Ramp *op) {
    mutate_operator(this, op, op->base, op->stride, &stmt);
}

void IRFilter::visit(const Broadcast *op) {
    mutate_operator(this, op, op->value, &stmt);
}

void IRFilter::visit(const Call *op) {
    std::vector<Stmt> new_args(op->args.size());

    // Mutate the args
    for (size_t i = 0; i < op->args.size(); i++) {
        Expr old_arg = op->args[i];
        Stmt new_arg = mutate(old_arg);
        new_args[i] = new_arg;
    }

    stmt = Stmt();
    for (size_t i = 0; i < new_args.size(); ++i) {
        if (new_args[i].defined()) {
            stmt = make_block(new_args[i], stmt);
        }
    }
}

void IRFilter::visit(const Let *op) {
    mutate_operator(this, op, op->value, op->body, &stmt);
}

void IRFilter::visit(const LetStmt *op) {
    mutate_operator(this, op, op->value, op->body, &stmt);
}

void IRFilter::visit(const AssertStmt *op) {
    mutate_operator(this, op, op->condition, op->message, &stmt);
}

void IRFilter::visit(const ProducerConsumer *op) {
    mutate_operator(this, op, op->body, &stmt);
}

void IRFilter::visit(const For *op) {
    mutate_operator(this, op, op->min, op->extent, op->body, &stmt);
}

void IRFilter::visit(const Store *op) {
    mutate_operator(this, op, op->predicate, op->value, op->index, &stmt);
}

void IRFilter::visit(const Provide *op) {
    stmt = Stmt();
    for (size_t i = 0; i < op->args.size(); i++) {
        Stmt new_arg = mutate(op->args[i]);
        if (new_arg.defined()) {
            stmt = make_block(new_arg, stmt);
        }
        Stmt new_value = mutate(op->values[i]);
        if (new_value.defined()) {
            stmt = make_block(new_value, stmt);
        }
    }
}

void IRFilter::visit(const Allocate *op) {
    stmt = Stmt();
    for (size_t i = 0; i < op->extents.size(); i++) {
        Stmt new_extent = mutate(op->extents[i]);
        if (new_extent.defined())
            stmt = make_block(new_extent, stmt);
    }

    Stmt body = mutate(op->body);
    if (body.defined())
        stmt = make_block(body, stmt);

    Stmt condition = mutate(op->condition);
    if (condition.defined())
        stmt = make_block(condition, stmt);
}

void IRFilter::visit(const Free *op) {
}

void IRFilter::visit(const Realize *op) {

    stmt = Stmt();

    // Mutate the bounds
    for (size_t i = 0; i < op->bounds.size(); i++) {
        Expr old_min    = op->bounds[i].min;
        Expr old_extent = op->bounds[i].extent;
        Stmt new_min    = mutate(old_min);
        Stmt new_extent = mutate(old_extent);

        if (new_min.defined())
            stmt = make_block(new_min, stmt);
        if (new_extent.defined())
            stmt = make_block(new_extent, stmt);
    }

    Stmt body = mutate(op->body);
    if (body.defined())
        stmt = make_block(body, stmt);

    Stmt condition = mutate(op->condition);
    if (condition.defined())
        stmt = make_block(condition, stmt);
}

// Filters a Block by handing its first and rest statements to the two-operand
// mutate_operator helper, which writes the combined result into stmt.
void IRFilter::visit(const Block *op) {
    mutate_operator(this, op, op->first, op->rest, &stmt);
}

// Filters an IfThenElse by handing its condition, then-case and else-case to
// the three-operand mutate_operator helper, which writes the combined result
// into stmt. The branch structure itself is not passed on as an operand.
void IRFilter::visit(const IfThenElse *op) {
    mutate_operator(this, op, op->condition, op->then_case, op->else_case, &stmt);
}

// Filters an Evaluate node by handing its value to the single-operand
// mutate_operator helper, which writes the result into stmt.
void IRFilter::visit(const Evaluate *op) {
    mutate_operator(this, op, op->value, &stmt);
}


// This visitor takes a IR tree containing a set of .glsl scheduled for-loops
// and creates a matching set of serial for-loops to setup a vertex buffer on
// the  host. The visitor  filters out glsl_varying intrinsics and transforms
// them into Store nodes to evaluate the linear expressions they tag within the
// scope of all of the Let definitions they fall within.
// The statement returned by this operation should be executed on the host
// before the call to halide_dev_run.
class CreateVertexBufferOnHost : public IRFilter {
public:
    using IRFilter::visit;

    virtual void visit(const Call *op) {

        // Transform glsl_varying intrinsics into store operations to output the
        // vertex coordinate values.
        if (op->name == Call::glsl_varying) {

            // Construct an expression for the offset of the coordinate value in
            // terms of the current integer loop variables and the varying
            // attribute channel number
            std::string attribute_name = op->args[0].as<StringImm>()->value;

            Expr offset_expression = Variable::make(Int(32), "gpu.vertex_offset") +
                                     attribute_order[attribute_name];

            stmt = Store::make(vertex_buffer_name, op->args[1], offset_expression,
                               Parameter(), const_true(op->args[1].type().lanes()));
        } else {
            IRFilter::visit(op);
        }

    }
    virtual void visit(const Let *op) {
        stmt = nullptr;

        Stmt mutated_value = mutate(op->value);
        Stmt mutated_body = mutate(op->body);

        // If an operation was filtered out of the body, also filter out the
        // whole let expression so that the body may be evaluated completely. In
        // the case that the let variable is not used in the mutated body, it
        // will be removed by simplification.
        if (mutated_body.defined()) {
            stmt = LetStmt::make(op->name, op->value, mutated_body);
        }

        // If an operation with a side effect was filtered out of the value, the
        // stmt'ified value is placed in a Block, so that the side effect will
        // be included in filtered IR tree.
        if (mutated_value.defined()) {
            stmt = make_block(mutated_value, stmt);
        }
    }

    virtual void visit(const LetStmt *op) {
        stmt = Stmt();

        Stmt mutated_value = mutate(op->value);
        Stmt mutated_body = mutate(op->body);

        if (mutated_body.defined()) {
            stmt = LetStmt::make(op->name, op->value, mutated_body);
        }

        if (mutated_value.defined()) {
            stmt = make_block(mutated_value, stmt);
        }
    }

    virtual void visit(const For *op) {
        if (CodeGen_GPU_Dev::is_gpu_var(op->name) && op->device_api == DeviceAPI::GLSL) {
            // Create a for-loop of integers iterating over the coordinates in
            // this dimension

            std::string name = op->name + ".idx";
            const std::vector<Expr>& dim = dims[op->name];

            internal_assert(for_loops.size() <= 1);
            for_loops.push_back(op);

            Expr loop_variable = Variable::make(Int(32),name);
            loop_variables.push_back(loop_variable);

            // TODO: When support for piecewise linear expressions is added this
            // expression must support more than two coordinates in each
            // dimension.
            Expr coord_expr = select(loop_variable == 0, dim[0], dim[1]);

            // Visit the body of the for-loop
            Stmt mutated_body = mutate(op->body);

            // If this was the inner most for-loop of the .glsl scheduled pair,
            // add a let definition for the vertex index and Store the spatial
            // coordinates
            const For *nested_for = op->body.as<For>();
            if (!(nested_for && CodeGen_GPU_Dev::is_gpu_var(nested_for->name))) {

                // Create a variable to store the offset in floats of this
                // vertex
                Expr gpu_varying_offset = Variable::make(Int(32), "gpu.vertex_offset");

                // Add expressions for the x and y vertex coordinates.
                Expr coord1 = cast<float>(Variable::make(Int(32),for_loops[0]->name));
                Expr coord0 = cast<float>(Variable::make(Int(32),for_loops[1]->name));

                // Transform the vertex coordinates to GPU device coordinates on
                // [-1,1]
                coord1 = (coord1 / for_loops[0]->extent) * 2.0f - 1.0f;
                coord0 = (coord0 / for_loops[1]->extent) * 2.0f - 1.0f;

                // Remove varying attribute intrinsics from the vertex setup IR
                // tree.
                mutated_body = remove_varying_attributes(mutated_body);

                // The GPU will take texture coordinates at pixel centers during
                // interpolation, we offset the Halide integer grid by 0.5 so that
                // these coordinates line up on integer coordinate values.
                std::vector<std::string> names = {for_loops[0]->name, for_loops[1]->name};
                CastVariablesToFloatAndOffset cast_and_offset(names);
                mutated_body = cast_and_offset.mutate(mutated_body);

                // Store the coordinates into the vertex buffer in interleaved
                // order
                mutated_body = make_block(Store::make(vertex_buffer_name,
                                                      coord1,
                                                      gpu_varying_offset + 1,
                                                      Parameter(), const_true()),
                                           mutated_body);

                mutated_body = make_block(Store::make(vertex_buffer_name,
                                                       coord0,
                                                       gpu_varying_offset + 0,
                                                       Parameter(), const_true()),
                                           mutated_body);

                // TODO: The value 2 in this expression must be changed to reflect
                // addition coordinate values in the fastest changing dimension when
                // support for piecewise linear functions is added
                Expr offset_expression = (loop_variables[0] * num_padded_attributes * 2) +
                (loop_variables[1] * num_padded_attributes);
                mutated_body = LetStmt::make("gpu.vertex_offset",
                                             offset_expression, mutated_body);

            }


            // Add a let statement for the for-loop name variable
            Stmt loop_var = LetStmt::make(op->name, coord_expr, mutated_body);

            stmt = For::make(name, 0, (int)dim.size(), ForType::Serial, DeviceAPI::None, loop_var);

        } else {
            IRFilter::visit(op);
        }
    }

    // The name of the previously allocated vertex buffer to store values
    std::string vertex_buffer_name;

    // Expressions for the spatial values of each coordinate in the GPU scheduled
    // loop dimensions.
    typedef std::map<std::string, std::vector<Expr>> DimsType;
    DimsType dims;

    // The channel of each varying attribute in the interleaved vertex buffer
    std::map<std::string, int> attribute_order;

    // The number of attributes padded up to the next multiple of four. This is
    // the stride from one vertex to the next in the buffer
    int num_padded_attributes;

    // Independent variable names in the linear expressions
    std::vector<const For*> for_loops;

    // Loop variables iterated across per GPU scheduled loop dimension to
    // construct the vertex buffer
    std::vector<Expr> loop_variables;
};

// These two methods provide a workaround to maintain unused let statements in
// the IR tree until calls are added that use them in codegen.

// TODO: We want to define a set of variables during lowering, and then use
// them during GLSL host codegen to pass values to the
// halide_dev_run function. It turns out that these variables will
// be simplified away since the call to the function does not appear
// in the IR. To avoid this we wrap the declaration in a
// return_second intrinsic as well as add a return_second intrinsic
// to consume the value.
// This prevents simplification passes that occur before codegen
// from removing the variables or substituting in their constant
// values.

// Wraps v_ as the second argument of a return_second intrinsic call, with a
// constant 0 as the first argument. The call has the same type as v_.
Expr dont_simplify(Expr v_) {
    const std::vector<Expr> wrapped_args = {0, v_};
    return Call::make(v_.type(), Call::return_second, wrapped_args, Call::Intrinsic);
}

// Builds an Evaluate statement that consumes the variable named v_ (of type
// type_) as the first argument of an Int(32) return_second intrinsic call
// whose second argument is the constant 0.
Stmt used_in_codegen(Type type_, const std::string &v_) {
    const std::vector<Expr> consumed_args = {Variable::make(type_, v_), 0};
    Expr consumer = Call::make(Int(32), Call::return_second, consumed_args, Call::Intrinsic);
    return Evaluate::make(consumer);
}


// This mutator inserts a set of serial for-loops to create the vertex buffer
// on the host using CreateVertexBufferOnHost above.
class CreateVertexBufferHostLoops : public IRMutator {
public:
    using IRMutator::visit;

    virtual void visit(const For *op) {
        if (CodeGen_GPU_Dev::is_gpu_var(op->name) && op->device_api == DeviceAPI::GLSL) {

            const For *loop1 = op;
            const For *loop0 = loop1->body.as<For>();

            internal_assert(loop1->body.as<For>()) << "Did not find pair of nested For loops";

            // Construct a mesh of expressions to instantiate during runtime
            std::map<std::string, Expr> varyings;

            FindVaryingAttributeTags tag_finder(varyings);
            op->accept(&tag_finder);

            // Establish and order for the attributes in each vertex
            std::map<std::string, int> attribute_order;

            // Add the attribute names to the mesh in the order that they appear in
            // each vertex
            attribute_order["__vertex_x"] = 0;
            attribute_order["__vertex_y"] = 1;

            int idx = 2;
            for (const std::pair<std::string, Expr> &v : varyings) {
                attribute_order[v.first] = idx++;
            }

            // Construct a list of expressions giving to coordinate locations along
            // each dimension, starting with the minimum and maximum coordinates

            attribute_order[loop0->name] = 0;
            attribute_order[loop1->name] = 1;

            Expr loop0_max = Add::make(loop0->min, loop0->extent);
            Expr loop1_max = Add::make(loop1->min, loop1->extent);

            std::vector<std::vector<Expr>> coords(2);

            coords[0].push_back(loop0->min);
            coords[0].push_back(loop0_max);

            coords[1].push_back(loop1->min);
            coords[1].push_back(loop1_max);

            // Count the two spatial x and y coordinates plus the number of
            // varying attribute expressions found
            int num_attributes = varyings.size() + 2;

            // Pad the number of attributes up to a multiple of four
            int num_padded_attributes = (num_attributes + 0x3) & ~0x3;
            int vertex_buffer_size = num_padded_attributes*coords[0].size()*coords[1].size();

            // Filter out varying attribute expressions from the glsl scheduled
            // loops. The expressions are filtered out in situ, among the
            // variables in scope
            CreateVertexBufferOnHost vs;
            vs.vertex_buffer_name = "glsl.vertex_buffer";
            vs.num_padded_attributes = num_padded_attributes;
            vs.dims[loop0->name] = coords[0];
            vs.dims[loop1->name] = coords[1];
            vs.attribute_order = attribute_order;

            Stmt vertex_setup = vs.mutate(loop1);

            // Remove varying attribute intrinsics from the vertex setup IR
            // tree. These may occur if an expression such as a Let-value was
            // filtered out without being mutated.
            vertex_setup = remove_varying_attributes(vertex_setup);

            // Simplify the new host code.  Workaround for #588
            vertex_setup = simplify(vertex_setup);
            vertex_setup = simplify(vertex_setup);
            vertex_setup = simplify(vertex_setup);
            vertex_setup = simplify(vertex_setup);

            // Replace varying attribute intriniscs in the gpu scheduled loops
            // with variables with ".varying" tagged names
            Stmt loop_stmt = replace_varying_attributes(op);

            // Simplify
            loop_stmt = simplify(loop_stmt, true);

            // It is possible that linear expressions we tagged in higher-level
            // intrinsics were removed by simplification if they were only used in
            // subsequent tagged linear expressions. Run a pass to check for
            // these and remove them from the varying attribute list
            prune_varying_attributes(loop_stmt, varyings);

            // At this point the varying attribute expressions have been removed from
            // loop_stmt- it only contains variables tagged with .varying

            // The GPU will only interpolate floating point values so the varying
            // attribute variables must be converted to floating point. If the
            // original varying expression was integer, casts are inserts to
            // snap the value back to the integer grid.
            loop_stmt = CastVaryingVariables().mutate(loop_stmt);

            // Insert two new for-loops for vertex buffer generation on the host
            // before the two GPU scheduled for-loops
            stmt = LetStmt::make("glsl.num_coords_dim0", dont_simplify((int)(coords[0].size())),
                   LetStmt::make("glsl.num_coords_dim1", dont_simplify((int)(coords[1].size())),
                   LetStmt::make("glsl.num_padded_attributes", dont_simplify(num_padded_attributes),
                   Allocate::make(vs.vertex_buffer_name, Float(32), {vertex_buffer_size}, const_true(),
                   Block::make(vertex_setup,
                   Block::make(loop_stmt,
                   Block::make(used_in_codegen(Int(32), "glsl.num_coords_dim0"),
                   Block::make(used_in_codegen(Int(32), "glsl.num_coords_dim1"),
                   Block::make(used_in_codegen(Int(32), "glsl.num_padded_attributes"),
                   Free::make(vs.vertex_buffer_name))))))))));
        } else {
            IRMutator::visit(op);
        }
    }
};

// Entry point of the pass: runs a fresh CreateVertexBufferHostLoops mutator
// over s and returns the rewritten statement.
Stmt setup_gpu_vertex_buffer(Stmt s) {
    return CreateVertexBufferHostLoops().mutate(s);
}

}
}

/* [<][>][^][v][top][bottom][index][help] */