#ifndef HALIDE_GENERATOR_H_
#define HALIDE_GENERATOR_H_
#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "Func.h"
#include "ExternalCode.h"
#include "Introspection.h"
#include "ObjectInstanceRegistry.h"
#include "ScheduleParam.h"
#include "Target.h"
namespace Halide {
template<typename T> class Buffer;
namespace Internal {
EXPORT void generator_test();
// Holds, for each name, a list of value lists (each a std::vector<Expr>).
// track_values() is defined out of line; presumably it appends to
// values_history and uses max_unique_values as a cap on how many distinct
// value lists a single name may accumulate -- TODO confirm in the .cpp.
class ValueTracker {
private:
// Keyed by the name passed to track_values().
std::map<std::string, std::vector<std::vector<Expr>>> values_history;
// Fixed at construction; defaults to 2.
const size_t max_unique_values;
public:
explicit ValueTracker(size_t max_unique_values = 2) : max_unique_values(max_unique_values) {}
EXPORT void track_values(const std::string &name, const std::vector<Expr> &values);
};
EXPORT std::vector<Expr> parameter_constraints(const Parameter &p);
// Reverse lookup in a name->value map: returns the key of the first entry
// (in the map's key order) whose mapped value compares equal to t.
// If no entry matches, a user_error is raised with "Enumeration value not
// found."; the trailing `return ""` is presumably only there to satisfy the
// compiler -- confirm that user_error does not return.
template <typename T>
NO_INLINE std::string enum_to_string(const std::map<std::string, T> &enum_map, const T& t) {
    // Bind each entry by const reference: iterating by value would copy a
    // std::pair<const std::string, T> for every entry visited.
    for (const auto &key_value : enum_map) {
        if (t == key_value.second) {
            return key_value.first;
        }
    }
    user_error << "Enumeration value not found.\n";
    return "";
}
// Forward lookup in a name->value map: returns the value stored under key s.
// If s is not a key of enum_map, user_assert fails with
// "Enumeration value not found: <s>".
template <typename T>
T enum_from_string(const std::map<std::string, T> &enum_map, const std::string& s) {
auto it = enum_map.find(s);
user_assert(it != enum_map.end()) << "Enumeration value not found: " << s << "\n";
return it->second;
}
EXPORT extern const std::map<std::string, Halide::Type> &get_halide_type_enum_map();
inline std::string halide_type_to_enum_string(const Type &t) {
return enum_to_string(get_halide_type_enum_map(), t);
}
EXPORT extern const std::map<std::string, Halide::LoopLevel> &get_halide_looplevel_enum_map();
inline std::string halide_looplevel_to_enum_string(const LoopLevel &loop_level){
return enum_to_string(get_halide_looplevel_enum_map(), loop_level);
}
EXPORT std::string halide_type_to_c_source(const Type &t);
EXPORT std::string halide_type_to_c_type(const Type &t);
EXPORT int generate_filter_main(int argc, char **argv, std::ostream &cerr);
// One clause of a select_type<> list: a compile-time condition paired with
// the type to choose when that condition holds.
template<bool B, typename T>
struct cond {
    using type = T;
    static constexpr bool value = B;
};

// select_type<cond<...>, cond<...>, ...>::type is the type of the first
// clause whose condition is true. When no clause matches, the result is void.
template <typename Head, typename... Tail>
struct select_type {
    using type = typename std::conditional<
        Head::value,
        typename Head::type,
        typename select_type<Tail...>::type
    >::type;
};

// Terminal case: a single remaining clause either matches or yields void.
template<typename Last>
struct select_type<Last> : std::conditional<Last::value, typename Last::type, void> { };
class GeneratorBase;
// Abstract, non-copyable base class of every generator parameter.
// It stores the parameter's name, declares one pure-virtual typed set()
// overload per supported value type, and declares the string-conversion
// hooks that concrete parameter classes implement. GeneratorBase and
// StubEmitter are friends and can reach the protected hooks.
class GeneratorParamBase {
public:
EXPORT explicit GeneratorParamBase(const std::string &name);
EXPORT virtual ~GeneratorParamBase();
const std::string name;
// Each expansion below declares `virtual void set(const TYPE &) = 0;`.
#define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
virtual void set(const TYPE &new_value) = 0;
HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
#undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
// String overloads: both delegate to the subclass's set_from_string().
void set(const std::string &new_value) { set_from_string(new_value); }
void set(const char *new_value) { set_from_string(std::string(new_value)); }
protected:
friend class GeneratorBase;
friend class StubEmitter;
// Defined out of line; presumably they raise an error when the value may
// not be read / written in the owning generator's current state -- TODO
// confirm in the .cpp.
EXPORT void check_value_readable() const;
EXPORT void check_value_writable() const;
// Parse value_string and store the result as the parameter's value.
virtual void set_from_string(const std::string &value_string) = 0;
// Current value rendered as a string.
virtual std::string to_string() const = 0;
// Returns text of a C++ expression that converts the expression named by v
// to a string (see the subclass implementations).
virtual std::string call_to_string(const std::string &v) const = 0;
// Name of the C++ type used for this parameter in emitted code.
virtual std::string get_c_type() const = 0;
// Extra declarations that emitted code needs for this parameter; none by
// default.
virtual std::string get_type_decls() const {
return "";
}
// Defaults to to_string().
virtual std::string get_default_value() const {
return to_string();
}
// Defaults to get_c_type().
virtual std::string get_template_type() const {
return get_c_type();
}
// Defaults to get_default_value().
virtual std::string get_template_value() const {
return get_default_value();
}
// False unless a subclass overrides it.
virtual bool is_synthetic_param() const {
return false;
}
// Defined out of line; called by the typed setters when the supplied value
// type is not acceptable for this parameter. `type` is the name of the
// offending type.
EXPORT void fail_wrong_type(const char *type);
private:
// Non-copyable.
explicit GeneratorParamBase(const GeneratorParamBase &) = delete;
void operator=(const GeneratorParamBase &) = delete;
// Null until assigned; only friends can set it (presumably GeneratorBase
// does so when it adopts the parameter -- TODO confirm).
GeneratorBase *generator{nullptr};
};
template<typename T>
class GeneratorParamImpl : public GeneratorParamBase {
public:
using type = T;
GeneratorParamImpl(const std::string &name, const T &value) : GeneratorParamBase(name), value_(value) {}
T value() const { check_value_readable(); return value_; }
operator T() const { return this->value(); }
operator Expr() const { return make_const(type_of<T>(), this->value()); }
#define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
void set(const TYPE &new_value) override { typed_setter_impl<TYPE>(new_value, #TYPE); }
HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
#undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
protected:
virtual void set_impl(const T &new_value) { check_value_writable(); value_ = new_value; }
private:
T value_;
template <typename T2, typename std::enable_if<std::is_convertible<T2, T>::value>::type * = nullptr>
HALIDE_ALWAYS_INLINE void typed_setter_impl(const T2 &value, const char * msg) {
if (!std::is_same<T, T2>::value &&
std::is_arithmetic<T>::value &&
std::is_arithmetic<T2>::value) {
const T t = Convert<T2, T>::value(value);
const T2 t2 = Convert<T, T2>::value(t);
if (t2 != value) {
fail_wrong_type(msg);
}
}
value_ = Convert<T2, T>::value(value);
}
template <typename T2, typename std::enable_if<!std::is_convertible<T2, T>::value>::type * = nullptr>
HALIDE_ALWAYS_INLINE void typed_setter_impl(const T2 &, const char *msg) {
fail_wrong_type(msg);
}
};
// Generator parameter whose value is a Target.
template<typename T>
class GeneratorParam_Target : public GeneratorParamImpl<T> {
public:
    GeneratorParam_Target(const std::string &name, const T &value) : GeneratorParamImpl<T>(name, value) {}

    // Builds a Target from the string and stores it via the typed setter.
    void set_from_string(const std::string &new_value_string) override {
        const Target parsed(new_value_string);
        this->set(parsed);
    }

    // The current target rendered by its own to_string().
    std::string to_string() const override {
        const T current = this->value();
        return current.to_string();
    }

    // Text of an expression that calls .to_string() on the expression v.
    std::string call_to_string(const std::string &v) const override {
        return v + ".to_string()";
    }

    std::string get_c_type() const override {
        return std::string("Target");
    }
};
// Generator parameter holding an arithmetic value of type T together with an
// inclusive [min, max] range. The range defaults to the full range of T.
template<typename T>
class GeneratorParam_Arithmetic : public GeneratorParamImpl<T> {
public:
    GeneratorParam_Arithmetic(const std::string &name,
                              const T &value,
                              const T &min = std::numeric_limits<T>::lowest(),
                              const T &max = std::numeric_limits<T>::max())
        : GeneratorParamImpl<T>(name, value), min(min), max(max) {
        // Re-apply the initial value through the typed setter.
        this->set(value);
    }

    // Rejects values outside [min, max], then defers to the base set_impl().
    void set_impl(const T &new_value) override {
        user_assert(new_value >= min && new_value <= max) << "Value out of range: " << new_value;
        GeneratorParamImpl<T>::set_impl(new_value);
    }

    // Parses new_value_string as a T; the whole string must be consumed.
    void set_from_string(const std::string &new_value_string) override {
        // int8_t / uint8_t are character types as far as operator>> is
        // concerned: extracting into them would read one character ("5"
        // would become 53) instead of a number. Parse those through int.
        // bool is excluded because sizeof(bool) is commonly 1 as well.
        const bool one_byte_int = sizeof(T) == sizeof(char) &&
                                  std::is_integral<T>::value &&
                                  !std::is_same<T, bool>::value;
        std::istringstream iss(new_value_string);
        T t = T();
        int wide = 0;
        if (one_byte_int) {
            iss >> wide;
            t = static_cast<T>(wide);
        } else {
            iss >> t;
        }
        user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << new_value_string;
        // A number that does not fit in the one-byte type would otherwise be
        // truncated silently by the cast above. The short-circuit keeps the
        // cast to int from being evaluated for wider types.
        user_assert(!one_byte_int || static_cast<int>(t) == wide) << "Value out of range: " << new_value_string;
        this->set(t);
    }

    // Renders the current value as C++ source text; float values receive an
    // "f" suffix.
    std::string to_string() const override {
        std::ostringstream oss;
        // Unary plus promotes one-byte integers to int so they are printed
        // as numbers rather than as raw characters; other types are printed
        // exactly as before.
        oss << +this->value();
        if (std::is_same<T, float>::value) {
            // A float literal needs a decimal point or an exponent before
            // the "f" suffix. Add "." only when neither is present:
            // "1" -> "1.f", while "1e+10" -> "1e+10f" ("1e+10.f" would not
            // be a valid literal).
            if (oss.str().find_first_of(".eE") == std::string::npos) {
                oss << ".";
            }
            oss << "f";
        }
        return oss.str();
    }

    // Text of an expression that wraps the expression v in std::to_string().
    std::string call_to_string(const std::string &v) const override {
        std::ostringstream oss;
        oss << "std::to_string(" << v << ")";
        return oss.str();
    }

    // "float", "double", or a fixed-width name such as "int32_t"/"uint8_t"
    // built from the signedness and size of T.
    std::string get_c_type() const override {
        std::ostringstream oss;
        if (std::is_same<T, float>::value) {
            return "float";
        } else if (std::is_same<T, double>::value) {
            return "double";
        } else if (std::is_integral<T>::value) {
            if (std::is_unsigned<T>::value) oss << 'u';
            oss << "int" << (sizeof(T) * 8) << "_t";
            return oss.str();
        } else {
            user_error << "Unknown arithmetic type\n";
            return "";
        }
    }

private:
    const T min, max;
};
// Boolean generator parameter: an arithmetic parameter whose textual form is
// exactly "true" or "false".
template<typename T>
class GeneratorParam_Bool : public GeneratorParam_Arithmetic<T> {
public:
    GeneratorParam_Bool(const std::string &name, const T &value) : GeneratorParam_Arithmetic<T>(name, value) {}

    // Accepts only the exact strings "true" and "false"; anything else fails
    // the user_assert below.
    void set_from_string(const std::string &new_value_string) override {
        const bool is_true = (new_value_string == "true");
        const bool is_false = (new_value_string == "false");
        if (!is_true && !is_false) {
            user_assert(false) << "Unable to parse bool: " << new_value_string;
        }
        this->set(is_true);
    }

    std::string to_string() const override {
        if (this->value()) {
            return "true";
        }
        return "false";
    }

    // Text of a conditional expression mapping the expression v to the
    // string literal "true" or "false".
    std::string call_to_string(const std::string &v) const override {
        return "(" + v + ") ? \"true\" : \"false\"";
    }

    std::string get_c_type() const override {
        return std::string("bool");
    }
};
// Generator parameter whose value is one of the entries of a name -> value
// map supplied at construction (the map is copied into the parameter).
template<typename T>
class GeneratorParam_Enum : public GeneratorParamImpl<T> {
public:
    GeneratorParam_Enum(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
        : GeneratorParamImpl<T>(name, value), enum_map(enum_map) {}

    // Looks the string up in the map and stores the mapped value directly
    // through set_impl().
    void set_from_string(const std::string &new_value_string) override {
        auto it = enum_map.find(new_value_string);
        user_assert(it != enum_map.end()) << "Enumeration value not found: " << new_value_string;
        this->set_impl(it->second);
    }

    // Name of the current value according to the map.
    std::string to_string() const override {
        const T current = this->value();
        return enum_to_string(enum_map, current);
    }

    // Text of a lookup of the expression v in the emitted Enum_<name>_map().
    std::string call_to_string(const std::string &v) const override {
        const std::string map_fn = "Enum_" + this->name + "_map()";
        return map_fn + ".at(" + v + ")";
    }

    std::string get_c_type() const override {
        return "Enum_" + this->name;
    }

    // Qualified enumerator name of the current value: Enum_<name>::<value>.
    std::string get_default_value() const override {
        const std::string value_name = enum_to_string(enum_map, this->value());
        return "Enum_" + this->name + "::" + value_name;
    }

    // Emits the source text of `enum class Enum_<name>` listing every map
    // key, followed by a function Enum_<name>_map() that returns a static
    // map from each enumerator to its name.
    std::string get_type_decls() const override {
        const std::string enum_name = "Enum_" + this->name;
        std::string decls;
        decls += "enum class " + enum_name + " {\n";
        for (const auto &entry : enum_map) {
            decls += " " + entry.first + ",\n";
        }
        decls += "};\n";
        decls += "\n";
        decls += "inline NO_INLINE const std::map<" + enum_name + ", std::string>& " + enum_name + "_map() {\n";
        decls += " static const std::map<" + enum_name + ", std::string> m = {\n";
        for (const auto &entry : enum_map) {
            decls += " { " + enum_name + "::" + entry.first + ", \"" + entry.first + "\"},\n";
        }
        decls += " };\n";
        decls += " return m;\n";
        decls += "};\n";
        return decls;
    }

private:
    const std::map<std::string, T> enum_map;
};
// Generator parameter whose value is a Halide Type. It is an enum-style
// parameter over get_halide_type_enum_map(), with the code-emission hooks
// overridden to use Halide's own type helpers instead of a generated enum.
template<typename T>
class GeneratorParam_Type : public GeneratorParam_Enum<T> {
public:
    GeneratorParam_Type(const std::string &name, const T &value)
        : GeneratorParam_Enum<T>(name, value, get_halide_type_enum_map()) {}

    // Text of a call to halide_type_to_enum_string() on the expression v.
    std::string call_to_string(const std::string &v) const override {
        const std::string helper = "Halide::Internal::halide_type_to_enum_string";
        return helper + "(" + v + ")";
    }

    std::string get_c_type() const override {
        return std::string("Type");
    }

    std::string get_template_type() const override {
        return std::string("typename");
    }

    // C type name of the current value, per halide_type_to_c_type().
    std::string get_template_value() const override {
        const T current = this->value();
        return halide_type_to_c_type(current);
    }

    // C++ source text for the current value, per halide_type_to_c_source().
    std::string get_default_value() const override {
        const T current = this->value();
        return halide_type_to_c_source(current);
    }

    // No extra declarations are emitted for Type parameters.
    std::string get_type_decls() const override {
        return std::string();
    }
};
// Chooses the implementation class for GeneratorParam<T>. Clauses are tested
// in order and the first match wins, so Target, Type and bool are checked
// before the general arithmetic clause (bool is itself an arithmetic type).
// When no clause matches, select_type yields void.
template<typename T>
using GeneratorParamImplBase =
typename select_type<
cond<std::is_same<T, Target>::value, GeneratorParam_Target<T>>,
cond<std::is_same<T, Type>::value, GeneratorParam_Type<T>>,
cond<std::is_same<T, bool>::value, GeneratorParam_Bool<T>>,
cond<std::is_arithmetic<T>::value, GeneratorParam_Arithmetic<T>>,
cond<std::is_enum<T>::value, GeneratorParam_Enum<T>>
>::type;
}
// User-facing generator parameter. It derives from the implementation class
// chosen by Internal::GeneratorParamImplBase<T> and forwards each constructor
// unchanged. As member functions of a class template are only instantiated
// when used, each constructor needs to be valid only for the T it is
// actually called with.
template <typename T>
class GeneratorParam : public Internal::GeneratorParamImplBase<T> {
public:
// Name and initial value.
GeneratorParam(const std::string &name, const T &value)
: Internal::GeneratorParamImplBase<T>(name, value) {}
// Name, initial value and an inclusive [min, max] range (arithmetic T).
GeneratorParam(const std::string &name, const T &value, const T &min, const T &max)
: Internal::GeneratorParamImplBase<T>(name, value, min, max) {}
// Name, initial value and the name -> value map (enum T).
GeneratorParam(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
: Internal::GeneratorParamImplBase<T>(name, value, enum_map) {}
// Name and a string value, for implementation classes that accept one.
GeneratorParam(const std::string &name, const std::string &value)
: Internal::GeneratorParamImplBase<T>(name, value) {}
};
template <typename Other, typename T>
decltype((Other)0 + (T)0) operator+(const Other &a, const GeneratorParam<T> &b) { return a + (T)b; }
template <typename Other, typename T>
decltype((T)0 + (Other)0) operator+(const GeneratorParam<T> &a, const Other & b) { return (T)a + b; }
template <typename Other, typename T>
decltype((Other)0 - (T)0) operator-(const Other & a, const GeneratorParam<T> &b) { return a - (T)b; }
template <typename Other, typename T>
decltype((T)0 - (Other)0) operator-(const GeneratorParam<T> &a, const Other & b) { return (T)a - b; }
template <typename Other, typename T>
decltype((Other)0 * (T)0) operator*(const Other &a, const GeneratorParam<T> &b) { return a * (T)b; }
template <typename Other, typename T>
decltype((Other)0 * (T)0) operator*(const GeneratorParam<T> &a, const Other &b) { return (T)a * b; }
template <typename Other, typename T>
decltype((Other)0 / (T)1) operator/(const Other &a, const GeneratorParam<T> &b) { return a / (T)b; }
template <typename Other, typename T>
decltype((T)0 / (Other)1) operator/(const GeneratorParam<T> &a, const Other &b) { return (T)a / b; }
template <typename Other, typename T>
decltype((Other)0 % (T)1) operator%(const Other &a, const GeneratorParam<T> &b) { return a % (T)b; }
template <typename Other, typename T>
decltype((T)0 % (Other)1) operator%(const GeneratorParam<T> &a, const Other &b) { return (T)a % b; }
template <typename Other, typename T>
decltype((Other)0 > (T)1) operator>(const Other &a, const GeneratorParam<T> &b) { return a > (T)b; }
template <typename Other, typename T>
decltype((T)0 > (Other)1) operator>(const GeneratorParam<T> &a, const Other &b) { return (T)a > b; }
template <typename Other, typename T>
decltype((Other)0 < (T)1) operator<(const Other &a, const GeneratorParam<T> &b) { return a < (T)b; }
template <typename Other, typename T>
decltype((T)0 < (Other)1) operator<(const GeneratorParam<T> &a, const Other &b) { return (T)a < b; }
template <typename Other, typename T>
decltype((Other)0 >= (T)1) operator>=(const Other &a, const GeneratorParam<T> &b) { return a >= (T)b; }
template <typename Other, typename T>
decltype((T)0 >= (Other)1) operator>=(const GeneratorParam<T> &a, const Other &b) { return (T)a >= b; }
template <typename Other, typename T>
decltype((Other)0 <= (T)1) operator<=(const Other &a, const GeneratorParam<T> &b) { return a <= (T)b; }
template <typename Other, typename T>
decltype((T)0 <= (Other)1) operator<=(const GeneratorParam<T> &a, const Other &b) { return (T)a <= b; }
template <typename Other, typename T>
decltype((Other)0 == (T)1) operator==(const Other &a, const GeneratorParam<T> &b) { return a == (T)b; }
template <typename Other, typename T>
decltype((T)0 == (Other)1) operator==(const GeneratorParam<T> &a, const Other &b) { return (T)a == b; }
template <typename Other, typename T>
decltype((Other)0 != (T)1) operator!=(const Other &a, const GeneratorParam<T> &b) { return a != (T)b; }
template <typename Other, typename T>
decltype((T)0 != (Other)1) operator!=(const GeneratorParam<T> &a, const Other &b) { return (T)a != b; }
template <typename Other, typename T>
decltype((Other)0 && (T)1) operator&&(const Other &a, const GeneratorParam<T> &b) { return a && (T)b; }
template <typename Other, typename T>
decltype((T)0 && (Other)1) operator&&(const GeneratorParam<T> &a, const Other &b) { return (T)a && b; }
template <typename Other, typename T>
decltype((Other)0 || (T)1) operator||(const Other &a, const GeneratorParam<T> &b) { return a || (T)b; }
template <typename Other, typename T>
decltype((T)0 || (Other)1) operator||(const GeneratorParam<T> &a, const Other &b) { return (T)a || b; }
// Helpers behind the Halide::min / Halide::max overloads for GeneratorParam.
// The using-declarations make std::min / std::max candidates for the
// unqualified min(...) / max(...) calls below, alongside any overloads found
// by argument-dependent lookup. Each helper converts its GeneratorParam
// operand to T before calling min or max.
namespace Internal { namespace GeneratorMinMax {
using std::max;
using std::min;
template <typename Other, typename T>
decltype(min((Other)0, (T)1)) min_forward(const Other &a, const GeneratorParam<T> &b) { return min(a, (T)b); }
template <typename Other, typename T>
decltype(min((T)0, (Other)1)) min_forward(const GeneratorParam<T> &a, const Other &b) { return min((T)a, b); }
template <typename Other, typename T>
decltype(max((Other)0, (T)1)) max_forward(const Other &a, const GeneratorParam<T> &b) { return max(a, (T)b); }
template <typename Other, typename T>
decltype(max((T)0, (Other)1)) max_forward(const GeneratorParam<T> &a, const Other &b) { return max((T)a, b); }
}}
// min / max with a GeneratorParam<T> on either side. Each overload forwards
// to the matching Internal::GeneratorMinMax helper and takes its return type
// from that helper, so it is removed from overload resolution when the
// helper is not callable for the given operand types.
template <typename Other, typename T>
auto min(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
return Internal::GeneratorMinMax::min_forward(a, b);
}
template <typename Other, typename T>
auto min(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
return Internal::GeneratorMinMax::min_forward(a, b);
}
template <typename Other, typename T>
auto max(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
return Internal::GeneratorMinMax::max_forward(a, b);
}
template <typename Other, typename T>
auto max(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
return Internal::GeneratorMinMax::max_forward(a, b);
}
// Logical not: converts the GeneratorParam to T and negates it. The return
// type is whatever operator! yields for a T.
template <typename T>
decltype(!(T)0) operator!(const GeneratorParam<T> &a) { return !(T)a; }
namespace Internal {
template<typename T2> class GeneratorInput_Buffer;
enum class IOKind { Scalar, Function, Buffer };
// Wraps the Parameter that backs a buffer input. It can be
// default-constructed (holding a default-constructed Parameter) or built
// implicitly from a Buffer<T2>. Construction from an existing Parameter is
// private and available only to the friends StubInput and
// GeneratorInput_Buffer.
template<typename T = void>
class StubInputBuffer {
friend class StubInput;
template<typename T2> friend class GeneratorInput_Buffer;
Parameter parameter_;
// Adopts p after asserting that a buffer with p's type and dimensionality
// could be converted to Buffer<T>. The probe buffer `other` is created with
// a null data pointer and extent 1 in every dimension; it is only passed to
// can_convert_from().
NO_INLINE explicit StubInputBuffer(const Parameter &p) : parameter_(p) {
Buffer<> other(p.type(), nullptr, std::vector<int>(p.dimensions(), 1));
internal_assert((Buffer<T>::can_convert_from(other)));
}
// Creates a Parameter with b's type and dimensionality and attaches b to it.
// Fails a user_assert when b cannot be converted to Buffer<T>.
template<typename T2>
NO_INLINE static Parameter parameter_from_buffer(const Buffer<T2> &b) {
user_assert((Buffer<T>::can_convert_from(b)));
Parameter p(b.type(), true, b.dimensions());
p.set_buffer(b);
return p;
}
public:
StubInputBuffer() {}
// Intentionally non-explicit so that a Buffer can be passed wherever a
// StubInputBuffer is expected.
template<typename T2>
StubInputBuffer(const Buffer<T2> &b) : parameter_(parameter_from_buffer(b)) {}
};
// Base of StubOutputBuffer<T>: pairs an output Func with a shared reference
// to the generator that produced it, and offers realize() overloads that
// forward to Func::realize() with get_target() appended as the last argument.
class StubOutputBufferBase {
protected:
Func f;
std::shared_ptr<GeneratorBase> generator;
// Both are defined out of line. check_scheduled(m) is called with the name
// of the calling operation before every realize(); presumably it reports
// an error when the generator has not been scheduled yet -- TODO confirm.
EXPORT void check_scheduled(const char* m) const;
EXPORT Target get_target() const;
explicit StubOutputBufferBase(const Func &f, std::shared_ptr<GeneratorBase> generator) : f(f), generator(generator) {}
StubOutputBufferBase() {}
public:
// Realizes f at the given sizes.
Realization realize(std::vector<int32_t> sizes) {
check_scheduled("realize");
return f.realize(sizes, get_target());
}
// Forwards any argument list to Func::realize().
template <typename... Args>
Realization realize(Args&&... args) {
check_scheduled("realize");
return f.realize(std::forward<Args>(args)..., get_target());
}
// Realizes f into the caller-provided destination dst.
template<typename Dst>
void realize(Dst dst) {
check_scheduled("realize");
f.realize(dst, get_target());
}
};
// Typed handle over StubOutputBufferBase. T is not used inside this class.
// Only the friends GeneratorOutput_Buffer and GeneratorStub can construct a
// StubOutputBuffer bound to a Func and generator; everyone else can only
// default-construct one.
template<typename T = void>
class StubOutputBuffer : public StubOutputBufferBase {
template<typename T2> friend class GeneratorOutput_Buffer;
friend class GeneratorStub;
explicit StubOutputBuffer(const Func &f, std::shared_ptr<GeneratorBase> generator) : StubOutputBufferBase(f, generator) {}
public:
StubOutputBuffer() {}
};
// A tagged holder for one input value: a buffer (held as a Parameter), a
// Func, or a scalar Expr. kind_ records which of the three is meaningful;
// the other two members stay default-constructed. The accessors are private
// and reachable only by the friend GeneratorInputBase.
class StubInput {
const IOKind kind_;
const Parameter parameter_;
const Func func_;
const Expr expr_;
public:
// The constructors are intentionally non-explicit so that a
// StubInputBuffer, Func or Expr converts implicitly to a StubInput.
template<typename T2>
StubInput(const StubInputBuffer<T2> &b) : kind_(IOKind::Buffer), parameter_(b.parameter_) {}
StubInput(const Func &f) : kind_(IOKind::Function), func_(f) {}
StubInput(const Expr &e) : kind_(IOKind::Scalar), expr_(e) {}
private:
friend class GeneratorInputBase;
IOKind kind() const {
return kind_;
}
// Each accessor asserts that the stored kind matches what is requested.
Parameter parameter() const {
internal_assert(kind_ == IOKind::Buffer);
return parameter_;
}
Func func() const {
internal_assert(kind_ == IOKind::Function);
return func_;
}
Expr expr() const {
internal_assert(kind_ == IOKind::Scalar);
return expr_;
}
};
// Mixin interface for objects backed by a Parameter. It exposes per-dimension
// access and host alignment; implementers supply parameter().
class Constrainable {
public:
virtual ~Constrainable() {}
// The backing Parameter, returned by value.
virtual Parameter parameter() const = 0;
// Returns a Dimension for dimension i of the backing parameter.
Dimension dim(int i) {
return Dimension(parameter(), i);
}
const Dimension dim(int i) const {
return Dimension(parameter(), i);
}
int host_alignment() const {
return parameter().host_alignment();
}
// Calls set_host_alignment() on the Parameter returned by parameter().
// NOTE(review): parameter() returns by value, so this only takes effect if
// Parameter copies share underlying state -- presumably they do; confirm.
Constrainable &set_host_alignment(int alignment) {
parameter().set_host_alignment(alignment);
return *this;
}
};
// Common non-copyable base of generator inputs and outputs. It records the
// name, kind, element types, dimensionality and array size of an input or
// output, together with the Funcs / Exprs that hold its values. Most members
// are defined out of line.
class GIOBase {
public:
EXPORT bool array_size_defined() const;
EXPORT size_t array_size() const;
EXPORT virtual bool is_array() const;
EXPORT const std::string &name() const;
EXPORT IOKind kind() const;
EXPORT bool types_defined() const;
EXPORT const std::vector<Type> &types() const;
EXPORT Type type() const;
EXPORT bool dimensions_defined() const;
EXPORT int dimensions() const;
EXPORT const std::vector<Func> &funcs() const;
EXPORT const std::vector<Expr> &exprs() const;
protected:
EXPORT GIOBase(size_t array_size,
const std::string &name,
IOKind kind,
const std::vector<Type> &types,
int dimensions);
EXPORT virtual ~GIOBase();
friend class GeneratorBase;
// Stored as int although the constructor takes a size_t. NOTE(review):
// GeneratorInputImpl passes -1 for unsized arrays, so a negative value
// presumably means "not yet defined" (see array_size_defined()) -- confirm.
int array_size_;
const std::string name_;
const IOKind kind_;
std::vector<Type> types_;
int dimensions_;
std::vector<Func> funcs_;
std::vector<Expr> exprs_;
// Null until assigned by a friend.
GeneratorBase *generator{nullptr};
EXPORT std::string array_name(size_t i) const;
EXPORT virtual void verify_internals() const;
EXPORT void check_matching_array_size(size_t size);
EXPORT void check_matching_type_and_dim(const std::vector<Type> &t, int d);
// Returns exprs() or funcs() depending on ElemType. Only the Expr and Func
// specializations that follow this class are defined.
template<typename ElemType>
const std::vector<ElemType> &get_values() const;
// True unless a subclass overrides it.
virtual bool allow_synthetic_generator_params() const {
return true;
}
// Subclasses backed by a Parameter override this; the default reports an
// internal error.
virtual Parameter parameter() const {
internal_error << "Unimplemented";
return Parameter();
}
virtual void check_value_writable() const = 0;
private:
template<typename T> friend class GeneratorParam_Synthetic;
// Non-copyable.
explicit GIOBase(const GIOBase &) = delete;
void operator=(const GIOBase &) = delete;
};
// get_values<Expr>() returns the stored Exprs.
template<>
inline const std::vector<Expr> &GIOBase::get_values<Expr>() const {
return exprs();
}
// get_values<Func>() returns the stored Funcs.
template<>
inline const std::vector<Func> &GIOBase::get_values<Func>() const {
return funcs();
}
// Base class of all generator inputs. On top of GIOBase it holds the
// Parameter objects that back the input. All non-pure members are defined
// out of line.
class GeneratorInputBase : public GIOBase {
protected:
// Constructor taking an explicit array size.
EXPORT GeneratorInputBase(size_t array_size,
const std::string &name,
IOKind kind,
const std::vector<Type> &t,
int d);
// Constructor without an array size (used for non-array inputs by
// GeneratorInputImpl).
EXPORT GeneratorInputBase(const std::string &name, IOKind kind, const std::vector<Type> &t, int d);
EXPORT ~GeneratorInputBase() override;
friend class GeneratorBase;
std::vector<Parameter> parameters_;
EXPORT void init_internals();
EXPORT void set_inputs(const std::vector<StubInput> &inputs);
// Hook that scalar inputs override to apply their default value and, for
// arithmetic inputs, min/max bounds to the backing Parameters.
EXPORT virtual void set_def_min_max();
EXPORT void verify_internals() const override;
friend class StubEmitter;
// Name of the C++ type used for this input in emitted code.
virtual std::string get_c_type() const = 0;
EXPORT void check_value_writable() const override;
private:
EXPORT void init_parameters();
};
// Adds array awareness to GeneratorInputBase. T is the declared input type,
// possibly a one-dimensional array type; ValueType (Func or Expr) is the
// element type handed back by the accessors. The container-style accessors
// exist only when T is an array type.
template<typename T, typename ValueType>
class GeneratorInputImpl : public GeneratorInputBase {
protected:
// T with all array extents removed.
using TBase = typename std::remove_all_extents<T>::type;
bool is_array() const override {
return std::is_array<T>::value;
}
// Exactly one of the three constructors below is enabled, depending on T.
// Case 1: T is not an array.
template <typename T2 = T, typename std::enable_if<
!std::is_array<T2>::value
>::type * = nullptr>
GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
: GeneratorInputBase(name, kind, t, d) {
}
// Case 2: T is a one-dimensional array with a known, positive extent, which
// becomes the array size.
template <typename T2 = T, typename std::enable_if<
std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)
>::type * = nullptr>
GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
: GeneratorInputBase(std::extent<T2, 0>::value, name, kind, t, d) {
}
// Case 3: T is a one-dimensional array of unknown bound (extent reported
// as 0). -1 is passed as the array size; the parameter type is size_t, so
// the value is converted on the way in -- presumably a "not yet defined"
// marker interpreted by GIOBase (TODO confirm).
template <typename T2 = T, typename std::enable_if<
std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0
>::type * = nullptr>
GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
: GeneratorInputBase(-1, name, kind, t, d) {
}
public:
// Number of elements currently held.
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
size_t size() const {
return get_values<ValueType>().size();
}
// Unchecked element access.
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
const ValueType &operator[](size_t i) const {
return get_values<ValueType>()[i];
}
// Bounds-checked element access (std::vector::at).
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
const ValueType &at(size_t i) const {
return get_values<ValueType>().at(i);
}
// Iteration over the elements.
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
typename std::vector<ValueType>::const_iterator begin() const {
return get_values<ValueType>().begin();
}
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
typename std::vector<ValueType>::const_iterator end() const {
return get_values<ValueType>().end();
}
};
// Generator input declared as a buffer type T. T must provide the static
// members has_static_halide_type() and static_halide_type(). The input is
// backed by exactly one Parameter and can be called or converted like a Func.
template<typename T>
class GeneratorInput_Buffer : public GeneratorInputImpl<T, Func>, public Constrainable {
private:
    using Super = GeneratorInputImpl<T, Func>;

    // The element-type list implied by T alone: T's static type when it has
    // one, otherwise an empty list.
    static std::vector<Type> types_from_static_type() {
        if (T::has_static_halide_type()) {
            return std::vector<Type>{ T::static_halide_type() };
        }
        return std::vector<Type>{};
    }

protected:
    using TBase = typename Super::TBase;

    friend class ::Halide::Func;
    friend class ::Halide::Stage;

    // Synthetic generator params are allowed only when T does not already
    // fix the element type statically.
    bool allow_synthetic_generator_params() const override {
        const bool type_is_fixed = T::has_static_halide_type();
        return !type_is_fixed;
    }

    // StubInputBuffer<elem> when T has a static type, StubInputBuffer<>
    // otherwise.
    std::string get_c_type() const override {
        if (!T::has_static_halide_type()) {
            return "Halide::Internal::StubInputBuffer<>";
        }
        const std::string elem = halide_type_to_c_type(T::static_halide_type());
        return "Halide::Internal::StubInputBuffer<" + elem + ">";
    }

    // The single Parameter backing this input.
    Parameter parameter() const override {
        internal_assert(this->parameters_.size() == 1);
        return this->parameters_.at(0);
    }

public:
    // Dimensionality left unspecified (-1).
    GeneratorInput_Buffer(const std::string &name)
        : Super(name, IOKind::Buffer, types_from_static_type(), -1) {
    }

    // Explicit element type; only valid when T has no static type.
    GeneratorInput_Buffer(const std::string &name, const Type &t, int d = -1)
        : Super(name, IOKind::Buffer, {t}, d) {
        static_assert(!T::has_static_halide_type(), "Cannot use pass a Type argument for a Buffer with a non-void static type");
    }

    // Explicit dimensionality.
    GeneratorInput_Buffer(const std::string &name, int d)
        : Super(name, IOKind::Buffer, types_from_static_type(), d) {
    }

    // Calls the underlying Func with the given arguments.
    template <typename... Args>
    Expr operator()(Args&&... args) const {
        const Func &backing = this->funcs().at(0);
        return backing(std::forward<Args>(args)...);
    }

    Expr operator()(std::vector<Expr> args) const {
        const Func &backing = this->funcs().at(0);
        return backing(args);
    }

    // Wraps the backing Parameter in a StubInputBuffer.
    template<typename T2>
    operator StubInputBuffer<T2>() const {
        return StubInputBuffer<T2>(parameter());
    }

    operator Func() const {
        return this->funcs().at(0);
    }
};
// Generator input declared as a Func (or an array of Funcs). The
// constructors cover every combination of optional array size, optional
// element type and optional dimensionality. An omitted type is passed on as
// an empty type list and an omitted dimensionality as -1.
template<typename T>
class GeneratorInput_Func : public GeneratorInputImpl<T, Func> {
private:
using Super = GeneratorInputImpl<T, Func>;
protected:
using TBase = typename Super::TBase;
std::string get_c_type() const override {
return "Func";
}
public:
// Non-array forms.
GeneratorInput_Func(const std::string &name, const Type &t, int d)
: Super(name, IOKind::Function, {t}, d) {
}
GeneratorInput_Func(const std::string &name, int d)
: Super(name, IOKind::Function, {}, d) {
}
GeneratorInput_Func(const std::string &name, const Type &t)
: Super(name, IOKind::Function, {t}, -1) {
}
GeneratorInput_Func(const std::string &name)
: Super(name, IOKind::Function, {}, -1) {
}
// Array forms taking an explicit array size.
GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t, int d)
: Super(array_size, name, IOKind::Function, {t}, d) {
}
GeneratorInput_Func(size_t array_size, const std::string &name, int d)
: Super(array_size, name, IOKind::Function, {}, d) {
}
GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t)
: Super(array_size, name, IOKind::Function, {t}, -1) {
}
GeneratorInput_Func(size_t array_size, const std::string &name)
: Super(array_size, name, IOKind::Function, {}, -1) {
}
// Calls the first stored Func with the given arguments.
template <typename... Args>
Expr operator()(Args&&... args) const {
return this->funcs().at(0)(std::forward<Args>(args)...);
}
Expr operator()(std::vector<Expr> args) const {
return this->funcs().at(0)(args);
}
// Converts to the first stored Func.
operator Func() const {
return this->funcs().at(0);
}
};
// Generator input holding a scalar (or an array of scalars) of element type
// TBase, exposed as an Expr. It records a default value that
// set_def_min_max() applies to every backing Parameter.
template<typename T>
class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {
private:
using Super = GeneratorInputImpl<T, Expr>;
protected:
using TBase = typename Super::TBase;
// Default value; the in-class initializer is replaced by both constructors.
const TBase def_{TBase()};
protected:
// Sets every backing Parameter's scalar value to the default.
void set_def_min_max() override {
for (Parameter &p : this->parameters_) {
p.set_scalar<TBase>(def_);
}
}
std::string get_c_type() const override {
return "Expr";
}
public:
// The type is type_of<TBase>() and the dimensionality is always 0.
explicit GeneratorInput_Scalar(const std::string &name,
const TBase &def = static_cast<TBase>(0))
: Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def) {
}
GeneratorInput_Scalar(size_t array_size,
const std::string &name,
const TBase &def = static_cast<TBase>(0))
: Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def) {
}
// Converts to the first stored Expr.
operator Expr() const {
return this->exprs().at(0);
}
operator ExternFuncArgument() const {
return ExternFuncArgument(this->exprs().at(0));
}
};
// Scalar generator input of arithmetic type with optional min/max bounds.
// The bounds are kept as Exprs; an undefined Expr means "no bound".
template<typename T>
class GeneratorInput_Arithmetic : public GeneratorInput_Scalar<T> {
private:
using Super = GeneratorInput_Scalar<T>;
protected:
using TBase = typename Super::TBase;
const Expr min_, max_;
protected:
// Applies the default value via the base class, then applies whichever
// bounds are defined to every backing Parameter. Bounds are never applied
// when TBase is bool.
void set_def_min_max() override {
GeneratorInput_Scalar<T>::set_def_min_max();
if (!std::is_same<TBase, bool>::value) {
for (Parameter &p : this->parameters_) {
if (min_.defined()) p.set_min_value(min_);
if (max_.defined()) p.set_max_value(max_);
}
}
}
public:
// Without bounds: min_ and max_ are left undefined.
explicit GeneratorInput_Arithmetic(const std::string &name,
const TBase &def = static_cast<TBase>(0))
: Super(name, def), min_(Expr()), max_(Expr()) {
}
GeneratorInput_Arithmetic(size_t array_size,
const std::string &name,
const TBase &def = static_cast<TBase>(0))
: Super(array_size, name, def), min_(Expr()), max_(Expr()) {
}
// With bounds: min and max are converted to Exprs.
GeneratorInput_Arithmetic(const std::string &name,
const TBase &def,
const TBase &min,
const TBase &max)
: Super(name, def), min_(min), max_(max) {
}
GeneratorInput_Arithmetic(size_t array_size,
const std::string &name,
const TBase &def,
const TBase &min,
const TBase &max)
: Super(array_size, name, def), min_(min), max_(max) {
}
};
// Maps any type to void; used below to turn "this expression is well-formed"
// into a type that can select a partial specialization.
template<typename>
struct type_sink { using type = void; };

// Detects whether T2::static_halide_type() is a valid expression.
// Primary template: assume the method is absent.
template<typename T2, typename Enable = void>
struct has_static_halide_type_method : std::integral_constant<bool, false> {};

// Chosen only when decltype(T2::static_halide_type()) is well-formed.
template<typename T2>
struct has_static_halide_type_method<
    T2,
    typename type_sink<decltype(T2::static_halide_type())>::type
> : std::integral_constant<bool, true> {};
// Chooses the implementation class for GeneratorInput<T> from T's element
// type TBase (T with array extents removed). Clauses are tested in order and
// the first match wins: types with a static_halide_type() method are treated
// as buffers, then Func, then arithmetic types, then any other scalar type.
// When no clause matches, select_type yields void.
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorInputImplBase =
typename select_type<
cond<has_static_halide_type_method<TBase>::value, GeneratorInput_Buffer<T>>,
cond<std::is_same<TBase, Func>::value, GeneratorInput_Func<T>>,
cond<std::is_arithmetic<TBase>::value, GeneratorInput_Arithmetic<T>>,
cond<std::is_scalar<TBase>::value, GeneratorInput_Scalar<T>>
>::type;
}
// User-facing generator input. It derives from the implementation class
// chosen by Internal::GeneratorInputImplBase<T> and forwards each
// constructor unchanged. As member functions of a class template are only
// instantiated when used, each constructor needs to be valid only for the
// kind of input (buffer, Func, scalar) it is actually called with.
template <typename T>
class GeneratorInput : public Internal::GeneratorInputImplBase<T> {
private:
using Super = Internal::GeneratorInputImplBase<T>;
protected:
using TBase = typename Super::TBase;
// Declared but never defined, so no caller can supply a value of this type.
struct Unused;
// int for buffer and Func inputs, Unused otherwise. This keeps the
// (name, dimensions) constructors from competing with the (name, default
// value) constructors when the input is a scalar.
using IntIfNonScalar =
typename Internal::select_type<
Internal::cond<Internal::has_static_halide_type_method<TBase>::value, int>,
Internal::cond<std::is_same<TBase, Func>::value, int>,
Internal::cond<true, Unused>
>::type;
public:
// Name only.
explicit GeneratorInput(const std::string &name)
: Super(name) {
}
// Scalar forms: default value, optionally with min/max bounds and/or an
// array size.
GeneratorInput(const std::string &name, const TBase &def)
: Super(name, def) {
}
GeneratorInput(size_t array_size, const std::string &name, const TBase &def)
: Super(array_size, name, def) {
}
GeneratorInput(const std::string &name,
const TBase &def, const TBase &min, const TBase &max)
: Super(name, def, min, max) {
}
GeneratorInput(size_t array_size, const std::string &name,
const TBase &def, const TBase &min, const TBase &max)
: Super(array_size, name, def, min, max) {
}
// Buffer / Func forms: element type and/or dimensionality.
GeneratorInput(const std::string &name, const Type &t, int d)
: Super(name, t, d) {
}
GeneratorInput(const std::string &name, const Type &t)
: Super(name, t) {
}
GeneratorInput(const std::string &name, IntIfNonScalar d)
: Super(name, d) {
}
// Array forms taking an explicit array size.
GeneratorInput(size_t array_size, const std::string &name, const Type &t, int d)
: Super(array_size, name, t, d) {
}
GeneratorInput(size_t array_size, const std::string &name, const Type &t)
: Super(array_size, name, t) {
}
GeneratorInput(size_t array_size, const std::string &name, IntIfNonScalar d)
: Super(array_size, name, d) {
}
GeneratorInput(size_t array_size, const std::string &name)
: Super(array_size, name) {
}
};
namespace Internal {
// Base class shared by all Generator outputs, built on GIOBase.
// Besides construction and bookkeeping hooks, it re-exposes a large part of
// the Func API: each forwarded method calls the same-named method on the Func
// returned by get_func_ref().
class GeneratorOutputBase : public GIOBase {
public:
    // HALIDE_OUTPUT_FORWARD(method) defines a member template `method` that
    // perfectly forwards its arguments to get_func_ref().method(...) and
    // returns whatever that call returns. HALIDE_OUTPUT_FORWARD_CONST does the
    // same for a const-qualified member.
#define HALIDE_OUTPUT_FORWARD(method)                                           \
    template<typename ...Args>                                                  \
    inline auto method(Args&&... args) ->                                       \
        decltype(std::declval<Func>().method(std::forward<Args>(args)...)) {    \
        return get_func_ref().method(std::forward<Args>(args)...);              \
    }

#define HALIDE_OUTPUT_FORWARD_CONST(method)                                     \
    template<typename ...Args>                                                  \
    inline auto method(Args&&... args) const ->                                 \
        decltype(std::declval<Func>().method(std::forward<Args>(args)...)) {    \
        return get_func_ref().method(std::forward<Args>(args)...);              \
    }

    HALIDE_OUTPUT_FORWARD(align_bounds)
    HALIDE_OUTPUT_FORWARD(align_storage)
    HALIDE_OUTPUT_FORWARD_CONST(args)
    HALIDE_OUTPUT_FORWARD(bound)
    HALIDE_OUTPUT_FORWARD(bound_extent)
    HALIDE_OUTPUT_FORWARD(compute_at)
    HALIDE_OUTPUT_FORWARD(compute_inline)
    HALIDE_OUTPUT_FORWARD(compute_root)
    HALIDE_OUTPUT_FORWARD_CONST(defined)
    HALIDE_OUTPUT_FORWARD(fold_storage)
    HALIDE_OUTPUT_FORWARD(fuse)
    HALIDE_OUTPUT_FORWARD(glsl)
    HALIDE_OUTPUT_FORWARD(gpu)
    HALIDE_OUTPUT_FORWARD(gpu_blocks)
    HALIDE_OUTPUT_FORWARD(gpu_single_thread)
    HALIDE_OUTPUT_FORWARD(gpu_threads)
    HALIDE_OUTPUT_FORWARD(gpu_tile)
    HALIDE_OUTPUT_FORWARD_CONST(has_update_definition)
    HALIDE_OUTPUT_FORWARD(hexagon)
    HALIDE_OUTPUT_FORWARD(in)
    HALIDE_OUTPUT_FORWARD(memoize)
    HALIDE_OUTPUT_FORWARD_CONST(num_update_definitions)
    HALIDE_OUTPUT_FORWARD_CONST(output_types)
    HALIDE_OUTPUT_FORWARD_CONST(outputs)
    HALIDE_OUTPUT_FORWARD(parallel)
    HALIDE_OUTPUT_FORWARD(prefetch)
    HALIDE_OUTPUT_FORWARD(rename)
    HALIDE_OUTPUT_FORWARD(reorder)
    HALIDE_OUTPUT_FORWARD(reorder_storage)
    HALIDE_OUTPUT_FORWARD_CONST(rvars)
    HALIDE_OUTPUT_FORWARD(serial)
    HALIDE_OUTPUT_FORWARD(shader)
    HALIDE_OUTPUT_FORWARD(specialize)
    HALIDE_OUTPUT_FORWARD(specialize_fail)
    HALIDE_OUTPUT_FORWARD(split)
    HALIDE_OUTPUT_FORWARD(store_at)
    HALIDE_OUTPUT_FORWARD(store_root)
    HALIDE_OUTPUT_FORWARD(tile)
    HALIDE_OUTPUT_FORWARD(unroll)
    HALIDE_OUTPUT_FORWARD(update)
    HALIDE_OUTPUT_FORWARD_CONST(update_args)
    HALIDE_OUTPUT_FORWARD_CONST(update_value)
    HALIDE_OUTPUT_FORWARD_CONST(update_values)
    HALIDE_OUTPUT_FORWARD_CONST(value)
    HALIDE_OUTPUT_FORWARD_CONST(values)
    HALIDE_OUTPUT_FORWARD(vectorize)

    // Both helper macros are local to this class; undefine both so neither
    // leaks into code that includes this header.
#undef HALIDE_OUTPUT_FORWARD
#undef HALIDE_OUTPUT_FORWARD_CONST

protected:
    // Array form: explicit array size, name, kind, types and `d`.
    EXPORT GeneratorOutputBase(size_t array_size,
                               const std::string &name,
                               IOKind kind,
                               const std::vector<Type> &t,
                               int d);

    // Non-array form.
    EXPORT GeneratorOutputBase(const std::string &name,
                               IOKind kind,
                               const std::vector<Type> &t,
                               int d);

    EXPORT ~GeneratorOutputBase() override;

    friend class GeneratorBase;
    friend class StubEmitter;

    EXPORT void init_internals();
    EXPORT void resize(size_t size);

    // Type name reported for this output; subclasses may override.
    virtual std::string get_c_type() const {
        return "Func";
    }

    EXPORT void check_value_writable() const override;

    // Returns the first Func of this output. Asserts that the output is not of
    // Scalar kind, that funcs_ holds exactly array_size() entries, and that
    // exprs_ is empty.
    NO_INLINE Func &get_func_ref() {
        internal_assert(kind() != IOKind::Scalar);
        internal_assert(funcs_.size() == array_size() && exprs_.empty());
        return funcs_[0];
    }

    // Const counterpart of get_func_ref(); same checks.
    NO_INLINE const Func &get_func_ref() const {
        internal_assert(kind() != IOKind::Scalar);
        internal_assert(funcs_.size() == array_size() && exprs_.empty());
        return funcs_[0];
    }
};
// Implementation layer between GeneratorOutputBase and the kind-specific
// output classes. T is either a single type or a one-dimensional array type;
// the enable_if guards below expose single-value accessors only when T is not
// an array, and container-style accessors only when T is an array.
template<typename T>
class GeneratorOutputImpl : public GeneratorOutputBase {
protected:
// Element type: T with all array extents removed.
using TBase = typename std::remove_all_extents<T>::type;
// Values held by an output are always Funcs.
using ValueType = Func;
bool is_array() const override {
return std::is_array<T>::value;
}
// Non-array T: uses the base-class constructor that takes no array size.
template <typename T2 = T, typename std::enable_if<
!std::is_array<T2>::value
>::type * = nullptr>
GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
: GeneratorOutputBase(name, kind, t, d) {
}
// One-dimensional array of known, non-zero extent: the extent is passed on
// as the array size.
template <typename T2 = T, typename std::enable_if<
std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)
>::type * = nullptr>
GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
: GeneratorOutputBase(std::extent<T2, 0>::value, name, kind, t, d) {
}
// One-dimensional array of unspecified extent (extent reported as 0): -1 is
// passed for the size_t array size.
// NOTE(review): presumably a sentinel for "size not yet known" -- confirm in
// GeneratorOutputBase.
template <typename T2 = T, typename std::enable_if<
std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0
>::type * = nullptr>
GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
: GeneratorOutputBase(-1, name, kind, t, d) {
}
public:
// Non-array only: call the single underlying Func with the given arguments.
template <typename... Args, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
FuncRef operator()(Args&&... args) const {
return get_values<ValueType>().at(0)(std::forward<Args>(args)...);
}
// Non-array only: same as above, with the arguments supplied as a vector.
template <typename ExprOrVar, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
FuncRef operator()(std::vector<ExprOrVar> args) const {
return get_values<ValueType>().at(0)(args);
}
// Non-array only: conversion to the single underlying Func.
template <typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
operator Func() const {
return get_values<ValueType>().at(0);
}
// Array only: number of values currently held.
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
size_t size() const {
return get_values<ValueType>().size();
}
// Array only: element access through the vector's operator[] (no bounds check).
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
const ValueType &operator[](size_t i) const {
return get_values<ValueType>()[i];
}
// Array only: element access through the vector's at() (bounds-checked).
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
const ValueType &at(size_t i) const {
return get_values<ValueType>().at(i);
}
// Array only: const iteration over the held values.
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
typename std::vector<ValueType>::const_iterator begin() const {
return get_values<ValueType>().begin();
}
template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
typename std::vector<ValueType>::const_iterator end() const {
return get_values<ValueType>().end();
}
// Only for arrays of unspecified extent: forwards to GeneratorOutputBase::resize().
template <typename T2 = T, typename std::enable_if<
std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0
>::type * = nullptr>
void resize(size_t size) {
GeneratorOutputBase::resize(size);
}
};
// Output implementation used when the declared type is a Buffer-like type
// (one exposing static_halide_type()). The output is registered with
// IOKind::Buffer and can be assigned from a Buffer<> or a StubOutputBuffer<>.
template<typename T>
class GeneratorOutput_Buffer : public GeneratorOutputImpl<T>, public Constrainable {
private:
    using Super = GeneratorOutputImpl<T>;

protected:
    using TBase = typename Super::TBase;

protected:
    // Name only. The type list is taken from T's static type when it has one
    // and left empty otherwise; -1 is passed for `d`.
    GeneratorOutput_Buffer(const std::string &name)
        : Super(name, IOKind::Buffer,
                T::has_static_halide_type() ? std::vector<Type>{ T::static_halide_type() } : std::vector<Type>{},
                -1) {
    }

    // Name, explicit type list and optional `d`.
    // If T has a static type, `t` must be empty; otherwise `t` may hold at
    // most one Type.
    GeneratorOutput_Buffer(const std::string &name, const std::vector<Type> &t, int d = -1)
        : Super(name, IOKind::Buffer,
                T::has_static_halide_type() ? std::vector<Type>{ T::static_halide_type() } : t,
                d) {
        if (T::has_static_halide_type()) {
            user_assert(t.empty()) << "Cannot pass a Type argument for a Buffer with a non-void static type\n";
        } else {
            user_assert(t.size() <= 1) << "Output<Buffer<>>(" << name << ") requires at most one Type, but has " << t.size() << "\n";
        }
    }

    // Name and `d` only; rejected at compile time unless T has a static type.
    GeneratorOutput_Buffer(const std::string &name, int d)
        : Super(name, IOKind::Buffer, std::vector<Type>{ T::static_halide_type() }, d) {
        static_assert(T::has_static_halide_type(), "Must pass a Type argument for a Buffer with a static type of void");
    }

    // StubOutputBuffer specialization name: parameterized by the C type of
    // T's static type when there is one, unparameterized otherwise.
    NO_INLINE std::string get_c_type() const override {
        if (T::has_static_halide_type()) {
            return "Halide::Internal::StubOutputBuffer<" +
                   halide_type_to_c_type(T::static_halide_type()) +
                   ">";
        } else {
            return "Halide::Internal::StubOutputBuffer<>";
        }
    }

    // Parameter of the output buffer of the single underlying Func.
    Parameter parameter() const override {
        internal_assert(this->funcs().size() == 1);
        return this->funcs().at(0).output_buffer().parameter();
    }

public:
    // Defines this output as a pure copy of `buffer`.
    // Fails with a user error if the buffer is not convertible to T, if its
    // type or dimensionality disagree with what this output already declares,
    // or if the output's Func already has a definition.
    template<typename T2>
    NO_INLINE GeneratorOutput_Buffer<T> &operator=(Buffer<T2> &buffer) {
        this->check_value_writable();

        user_assert(T::can_convert_from(buffer))
            << "Cannot assign to the Output \"" << this->name()
            << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";

        if (this->types_defined()) {
            user_assert(Type(buffer.type()) == this->type())
                << "Output should have type=" << this->type() << " but saw type=" << Type(buffer.type()) << "\n";
        }
        if (this->dimensions_defined()) {
            user_assert(buffer.dimensions() == this->dimensions())
                << "Output should have dim=" << this->dimensions() << " but saw dim=" << buffer.dimensions() << "\n";
        }

        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        user_assert(!this->funcs_.at(0).defined())
            << "Cannot assign to the Output \"" << this->name()
            << "\": it already has a definition.\n";
        this->funcs_.at(0)(_) = buffer(_);

        return *this;
    }

    // Replaces this output's Func with the Func carried by `stub_output_buffer`.
    // The stub's Func must be defined and have exactly one output type. The
    // same convertibility, type and dimension checks as for Buffer assignment
    // are applied, using a null-data Buffer<> of matching type and
    // dimensionality as a stand-in for the convertibility check.
    template<typename T2>
    NO_INLINE GeneratorOutput_Buffer<T> &operator=(const StubOutputBuffer<T2> &stub_output_buffer) {
        this->check_value_writable();

        const auto &f = stub_output_buffer.f;
        internal_assert(f.defined());

        const auto &output_types = f.output_types();
        user_assert(output_types.size() == 1)
            << "Output should have size=1 but saw size=" << output_types.size() << "\n";

        Buffer<> other(output_types.at(0), nullptr, std::vector<int>(f.dimensions(), 1));
        user_assert(T::can_convert_from(other))
            << "Cannot assign to the Output \"" << this->name()
            << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";

        if (this->types_defined()) {
            user_assert(output_types.at(0) == this->type())
                << "Output should have type=" << this->type() << " but saw type=" << output_types.at(0) << "\n";
        }
        if (this->dimensions_defined()) {
            user_assert(f.dimensions() == this->dimensions())
                << "Output should have dim=" << this->dimensions() << " but saw dim=" << f.dimensions() << "\n";
        }

        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        user_assert(!this->funcs_.at(0).defined())
            << "Cannot assign to the Output \"" << this->name()
            << "\": it already has a definition.\n";
        this->funcs_[0] = f;

        return *this;
    }
};
// Output implementation used when the declared type is Func (or an array of
// Func). The output is registered with IOKind::Function.
template<typename T>
class GeneratorOutput_Func : public GeneratorOutputImpl<T> {
private:
    using Super = GeneratorOutputImpl<T>;

    // Mutable access to the Func at `index`. Asserts that no Exprs are stored
    // and that `index` is within funcs_.
    NO_INLINE Func &get_assignable_func_ref(size_t index) {
        internal_assert(this->exprs_.empty() && this->funcs_.size() > index);
        Func &slot = this->funcs_.at(index);
        return slot;
    }

protected:
    using TBase = typename Super::TBase;

protected:
    // Non-array form.
    GeneratorOutput_Func(const std::string &name, const std::vector<Type> &t, int d)
        : Super(name, IOKind::Function, t, d) {
    }

    // Array form with an explicit size.
    // NOTE(review): GeneratorOutputImpl as declared in this header has no
    // constructor taking a leading array size -- confirm this overload can be
    // instantiated.
    GeneratorOutput_Func(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
        : Super(array_size, name, IOKind::Function, t, d) {
    }

public:
    // Non-array only: replace the underlying Func with `f`, after checking
    // that the output may currently be written.
    template <typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    GeneratorOutput_Func<T> &operator=(const Func &f) {
        this->check_value_writable();
        Func &destination = get_assignable_func_ref(0);
        destination = f;
        return *this;
    }

    // Array only: writable access to element `i`, after the same writability check.
    template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    Func &operator[](size_t i) {
        this->check_value_writable();
        return get_assignable_func_ref(i);
    }

    // Array only: read-only access to element `i`, delegated to the base class.
    template <typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const Func &operator[](size_t i) const {
        return Super::operator[](i);
    }
};
// Output implementation used when the declared type is an arithmetic type
// (or an array of one). The output is registered with IOKind::Function, a
// single Type derived from TBase via type_of<TBase>(), and 0 for `d`.
template<typename T>
class GeneratorOutput_Arithmetic : public GeneratorOutputImpl<T> {
private:
using Super = GeneratorOutputImpl<T>;
protected:
using TBase = typename Super::TBase;
protected:
// Non-array form.
explicit GeneratorOutput_Arithmetic(const std::string &name)
: Super(name, IOKind::Function, {type_of<TBase>()}, 0) {
}
// Array form with an explicit size.
// NOTE(review): GeneratorOutputImpl as declared in this header has no
// constructor taking a leading array size -- confirm this overload can be
// instantiated.
GeneratorOutput_Arithmetic(size_t array_size, const std::string &name)
: Super(array_size, name, IOKind::Function, {type_of<TBase>()}, 0) {
}
};
// Chooses the implementation base class for GeneratorOutput<T>, keyed on
// TBase (T with all array extents removed): types with a static_halide_type()
// method map to GeneratorOutput_Buffer, Func maps to GeneratorOutput_Func, and
// arithmetic types map to GeneratorOutput_Arithmetic.
// NOTE(review): no entry covers any other TBase; what select_type does when
// no cond matches is not visible here -- confirm against its definition.
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorOutputImplBase =
typename select_type<
cond<has_static_halide_type_method<TBase>::value, GeneratorOutput_Buffer<T>>,
cond<std::is_same<TBase, Func>::value, GeneratorOutput_Func<T>>,
cond<std::is_arithmetic<TBase>::value, GeneratorOutput_Arithmetic<T>>
>::type;
}
// User-facing type for declaring a Generator output.
// All behavior comes from the implementation base selected by
// Internal::GeneratorOutputImplBase<T>; this class supplies the constructor
// overload set and re-declares the assignment operators so that they return
// GeneratorOutput<T>& rather than a reference to the base. Every member
// forwards to a Super member of the same shape, so which ones are usable for
// a given T depends on what the selected base provides.
template <typename T>
class GeneratorOutput : public Internal::GeneratorOutputImplBase<T> {
private:
using Super = Internal::GeneratorOutputImplBase<T>;
protected:
using TBase = typename Super::TBase;
public:
// Name only.
explicit GeneratorOutput(const std::string &name)
: Super(name) {
}
// Name given as a C string; delegates to the std::string overload.
explicit GeneratorOutput(const char *name)
: GeneratorOutput(std::string(name)) {
}
// Array form: size and name.
GeneratorOutput(size_t array_size, const std::string &name)
: Super(array_size, name) {
}
// Name and `d`, with an empty type list.
GeneratorOutput(const std::string &name, int d)
: Super(name, {}, d) {
}
// Name, a single Type (wrapped into a one-element list) and `d`.
GeneratorOutput(const std::string &name, const Type &t, int d)
: Super(name, {t}, d) {
}
// Name, a list of Types and `d`.
GeneratorOutput(const std::string &name, const std::vector<Type> &t, int d)
: Super(name, t, d) {
}
// Array forms of the three constructors above.
GeneratorOutput(size_t array_size, const std::string &name, int d)
: Super(array_size, name, {}, d) {
}
GeneratorOutput(size_t array_size, const std::string &name, const Type &t, int d)
: Super(array_size, name, {t}, d) {
}
GeneratorOutput(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
: Super(array_size, name, t, d) {
}
// Assignment from a Buffer; forwards to the base-class operator.
template <typename T2>
GeneratorOutput<T> &operator=(Buffer<T2> &buffer) {
Super::operator=(buffer);
return *this;
}
// Assignment from a StubOutputBuffer; forwards to the base-class operator.
template <typename T2>
GeneratorOutput<T> &operator=(const Internal::StubOutputBuffer<T2> &stub_output_buffer) {
Super::operator=(stub_output_buffer);
return *this;
}
// Assignment from a Func; forwards to the base-class operator.
GeneratorOutput<T> &operator=(const Func &f) {
Super::operator=(f);
return *this;
}
};
namespace Internal {
// Parses `value` as a T using stream extraction.
// The whole string must be consumed: extraction has to succeed and the next
// read has to hit end-of-stream, otherwise a user error "Unable to parse: ..."
// is raised. Leading whitespace is skipped by the extraction itself; any
// trailing character (including whitespace) is rejected.
template<typename T>
T parse_scalar(const std::string &value) {
    std::istringstream iss(value);
    // Value-initialize so the result is never indeterminate, even when the
    // stream cannot extract anything.
    T t{};
    iss >> t;
    // Compare against the stream's own end-of-file value instead of the EOF
    // macro, which lives in <cstdio> and is not included by this header.
    user_assert(!iss.fail() && iss.get() == std::istringstream::traits_type::eof()) << "Unable to parse: " << value;
    return t;
}
// Converts a string into a list of Types. Declared here and defined out of
// line; within this header it is used to fill a GIOBase's types_ from the
// string value of a synthetic GeneratorParam.
// NOTE(review): the accepted syntax (presumably a comma-separated list of
// type names) is not visible here -- confirm against the definition.
EXPORT std::vector<Type> parse_halide_type_list(const std::string &types);
// A GeneratorParam that stores no value of its own: setting it from a string
// writes straight into one field of an associated GIOBase (its type list, its
// dimensions, or its array size, depending on `which`). Only
// set_from_string() is meaningful; the string-producing queries raise an
// internal error.
template<typename T>
class GeneratorParam_Synthetic : public GeneratorParamImpl<T> {
public:
    // Dispatches to the overload matching T (Halide::Type or an integral type).
    void set_from_string(const std::string &new_value_string) override {
        set_from_string_impl<T>(new_value_string);
    }

    // Not supported for synthetic params.
    std::string to_string() const override {
        internal_error;
        return std::string();
    }

    // Not supported for synthetic params.
    std::string call_to_string(const std::string &v) const override {
        internal_error;
        return std::string();
    }

    // Not supported for synthetic params.
    std::string get_c_type() const override {
        internal_error;
        return std::string();
    }

    bool is_synthetic_param() const override {
        return true;
    }

private:
    friend class GeneratorBase;

    // Which field of the GIOBase this param writes to.
    enum Which { Type, Dim, ArraySize };

    GeneratorParam_Synthetic(const std::string &name, GIOBase &gio, Which which)
        : GeneratorParamImpl<T>(name, T()), gio(gio), which(which) {
    }

    // T == Halide::Type: the string is parsed as a type list and stored in
    // gio.types_. (`::Halide::Type` is spelled out because `Type` names the
    // enumerator inside this class.)
    template <typename T2 = T, typename std::enable_if<std::is_same<T2, ::Halide::Type>::value>::type * = nullptr>
    void set_from_string_impl(const std::string &new_value_string) {
        internal_assert(which == Type);
        gio.types_ = parse_halide_type_list(new_value_string);
    }

    // Integral T: the string is parsed as a scalar and stored in either
    // gio.dimensions_ or gio.array_size_; any other selector is an internal error.
    template <typename T2 = T, typename std::enable_if<std::is_integral<T2>::value>::type * = nullptr>
    void set_from_string_impl(const std::string &new_value_string) {
        switch (which) {
        case Dim:
            gio.dimensions_ = parse_scalar<T2>(new_value_string);
            break;
        case ArraySize:
            gio.array_size_ = parse_scalar<T2>(new_value_string);
            break;
        default:
            internal_error;
            break;
        }
    }

    GIOBase &gio;
    const Which which;
};
class GeneratorStub;
}
// Abstract interface describing the environment a Generator is created in:
// the Target it builds for, a shared map of external code, and (for
// GeneratorBase only) a shared ValueTracker.
class GeneratorContext {
public:
    virtual ~GeneratorContext() = default;

    // Target to generate code for.
    virtual Target get_target() const = 0;

    // Map from name to ExternalCode, held by shared pointer.
    using ExternsMap = std::map<std::string, ExternalCode>;
    virtual std::shared_ptr<ExternsMap> get_externs_map() const = 0;

protected:
    friend class Internal::GeneratorBase;

    // Accessible only to subclasses and to Internal::GeneratorBase.
    virtual std::shared_ptr<Internal::ValueTracker> get_value_tracker() const = 0;
};
// Concrete GeneratorContext that is constructed from a Target.
// Construction stores a copy of the Target and allocates a fresh, empty
// ExternsMap and a fresh default-constructed ValueTracker; the getters hand
// back those stored objects.
class JITGeneratorContext : public GeneratorContext {
public:
explicit JITGeneratorContext(const Target &t)
: target(t)
, externs_map(std::make_shared<ExternsMap>())
, value_tracker(std::make_shared<Internal::ValueTracker>()) {}
Target get_target() const override { return target; }
std::shared_ptr<ExternsMap> get_externs_map() const override { return externs_map; }
protected:
std::shared_ptr<Internal::ValueTracker> get_value_tracker() const override { return value_tracker; }
private:
// All three members are fixed at construction; the pointers are const, the
// objects they point to are not.
const Target target;
const std::shared_ptr<ExternsMap> externs_map;
const std::shared_ptr<Internal::ValueTracker> value_tracker;
};
// Mix-in that re-exports a set of Halide names as protected members, so that
// classes deriving from it can refer to them unqualified (Expr, Func, Var,
// cast<>(), Int(), ...). It adds no data members; everything here is a type
// alias, an alias template, or a static forwarding function.
class NamesInterface {
protected:
// Type aliases for commonly used Halide types.
using Expr = Halide::Expr;
using ExternFuncArgument = Halide::ExternFuncArgument;
using Func = Halide::Func;
using GeneratorContext = Halide::GeneratorContext;
using ImageParam = Halide::ImageParam;
using LoopLevel = Halide::LoopLevel;
using Pipeline = Halide::Pipeline;
using RDom = Halide::RDom;
using TailStrategy = Halide::TailStrategy;
using Target = Halide::Target;
using Tuple = Halide::Tuple;
using Type = Halide::Type;
using Var = Halide::Var;
using NameMangling = Halide::NameMangling;
// Forwarders to Halide::cast, in both the templated and the runtime-Type form.
template <typename T> static Expr cast(Expr e) { return Halide::cast<T>(e); }
static inline Expr cast(Halide::Type t, Expr e) { return Halide::cast(t, e); }
// Alias templates for the parameter and buffer class templates.
// Buffer defaults its element type to void.
template <typename T> using GeneratorParam = Halide::GeneratorParam<T>;
template <typename T> using ScheduleParam = Halide::ScheduleParam<T>;
template <typename T = void> using Buffer = Halide::Buffer<T>;
template <typename T> using Param = Halide::Param<T>;
// Forwarders to the Halide Type factory functions; `lanes` defaults to 1.
static inline Type Bool(int lanes = 1) { return Halide::Bool(lanes); }
static inline Type Float(int bits, int lanes = 1) { return Halide::Float(bits, lanes); }
static inline Type Int(int bits, int lanes = 1) { return Halide::Int(bits, lanes); }
static inline Type UInt(int bits, int lanes = 1) { return Halide::UInt(bits, lanes); }
};
namespace Internal {
// Compile-time predicate over a type list: `value` is true exactly when none
// of the listed types is convertible to Realization. It is used in this header
// to keep variadic realize(Args&&...) overloads from matching calls that pass
// a Realization.
//
// Primary template; every usable instantiation is handled by one of the two
// specializations below.
template<typename... Types>
struct NoRealizations : std::false_type {
};

// Empty list: trivially contains no Realization.
template<>
struct NoRealizations<> : std::true_type {
};

// Non-empty list: the head must not convert to Realization, and the same must
// hold recursively for the tail.
template<typename Head, typename... Tail>
struct NoRealizations<Head, Tail...> {
    static const bool value = !std::is_convertible<Head, Realization>::value && NoRealizations<Tail...>::value;
};
class GeneratorStub;
class SimpleGeneratorFactory;
class GeneratorBase : public NamesInterface, public GeneratorContext {
public:
GeneratorParam<Target> target{ "target", Halide::get_host_target() };
struct EmitOptions {
bool emit_o, emit_h, emit_cpp, emit_assembly, emit_bitcode, emit_stmt, emit_stmt_html, emit_static_library, emit_cpp_stub;
std::map<std::string, std::string> substitutions;
EmitOptions()
: emit_o(false), emit_h(true), emit_cpp(false), emit_assembly(false),
emit_bitcode(false), emit_stmt(false), emit_stmt_html(false), emit_static_library(true), emit_cpp_stub(false) {}
};
EXPORT virtual ~GeneratorBase();
Target get_target() const override { return target; }
EXPORT void set_generator_param(const std::string &name, const std::string &value);
EXPORT void set_generator_and_schedule_param_values(const std::map<std::string, std::string> ¶ms);
template<typename T>
GeneratorBase &set_generator_param(const std::string &name, const T &value) {
find_generator_param_by_name(name).set(value);
return *this;
}
template<typename T>
GeneratorBase &set_schedule_param(const std::string &name, const T &value) {
find_schedule_param_by_name(name).set(value);
return *this;
}
int natural_vector_size(Halide::Type t) const {
return get_target().natural_vector_size(t);
}
template <typename data_t>
int natural_vector_size() const {
return get_target().natural_vector_size<data_t>();
}
EXPORT void emit_cpp_stub(const std::string &stub_file_path);
EXPORT Module build_module(const std::string &function_name = "",
const LoweredFunc::LinkageType linkage_type = LoweredFunc::ExternalPlusMetadata);
template <typename... Args>
void set_inputs(const Args &...args) {
ParamInfo &pi = param_info();
user_assert(sizeof...(args) == pi.filter_inputs.size())
<< "Expected exactly " << pi.filter_inputs.size()
<< " inputs but got " << sizeof...(args) << "\n";
set_inputs_vector(build_inputs(std::forward_as_tuple<const Args &...>(args...), make_index_sequence<sizeof...(Args)>{}));
}
Realization realize(std::vector<int32_t> sizes) {
check_scheduled("realize");
return get_pipeline().realize(sizes, get_target());
}
template <typename... Args, typename std::enable_if<NoRealizations<Args...>::value>::type * = nullptr>
Realization realize(Args&&... args) {
check_scheduled("realize");
return get_pipeline().realize(std::forward<Args>(args)..., get_target());
}
void realize(Realization r) {
check_scheduled("realize");
get_pipeline().realize(r, get_target());
}
EXPORT Pipeline get_pipeline();
EXPORT std::shared_ptr<ExternsMap> get_externs_map() const override;
protected:
EXPORT GeneratorBase(size_t size, const void *introspection_helper);
EXPORT void init_from_context(const Halide::GeneratorContext &context);
EXPORT virtual Pipeline build_pipeline() = 0;
EXPORT virtual void call_generate() = 0;
EXPORT virtual void call_schedule() = 0;
std::shared_ptr<ValueTracker> get_value_tracker() const override { return value_tracker; }
EXPORT void track_parameter_values(bool include_outputs);
EXPORT void pre_build();
EXPORT void post_build();
EXPORT void pre_generate();
EXPORT void post_generate();
EXPORT void pre_schedule();
EXPORT void post_schedule();
template<typename T>
using Input = GeneratorInput<T>;
template<typename T>
using Output = GeneratorOutput<T>;
template<typename T>
using ScheduleParam = ScheduleParam<T>;
enum Phase {
Created,
InputsSet,
GenerateCalled,
ScheduleCalled,
} phase{Created};
void check_exact_phase(Phase expected_phase) const;
void check_min_phase(Phase expected_phase) const;
void advance_phase(Phase new_phase);
private:
friend void ::Halide::Internal::generator_test();
friend class GeneratorParamBase;
friend class GeneratorInputBase;
friend class GeneratorOutputBase;
friend class GeneratorStub;
friend class SimpleGeneratorFactory;
friend class StubOutputBufferBase;
struct ParamInfo {
EXPORT ParamInfo(GeneratorBase *generator, const size_t size);
std::vector<Internal::GeneratorParamBase *> generator_params;
std::vector<Internal::ScheduleParamBase *> schedule_params;
std::vector<Internal::GeneratorInputBase *> filter_inputs;
std::vector<Internal::GeneratorOutputBase *> filter_outputs;
std::vector<Internal::Parameter *> filter_params;
std::map<std::string, Internal::GeneratorParamBase *> generator_params_by_name;
std::map<std::string, Internal::ScheduleParamBase *> schedule_params_by_name;
private:
std::vector<std::unique_ptr<Internal::GeneratorParamBase>> owned_synthetic_params;
};
const size_t size;
std::unique_ptr<ParamInfo> param_info_ptr;
std::shared_ptr<Internal::ValueTracker> value_tracker;
mutable std::shared_ptr<ExternsMap> externs_map;
bool inputs_set{false};
std::string generator_name;
Pipeline pipeline;
EXPORT ParamInfo ¶m_info();
EXPORT Internal::GeneratorParamBase &find_generator_param_by_name(const std::string &name);
EXPORT Internal::ScheduleParamBase &find_schedule_param_by_name(const std::string &name);
EXPORT void check_scheduled(const char* m) const;
EXPORT void build_params(bool force = false);
void get_host_target();
void get_jit_target_from_environment();
void get_target_from_environment();
EXPORT Func get_first_output();
EXPORT Func get_output(const std::string &n);
EXPORT std::vector<Func> get_output_vector(const std::string &n);
void set_generator_name(const std::string &n) {
internal_assert(generator_name.empty());
generator_name = n;
}
EXPORT void set_inputs_vector(const std::vector<std::vector<StubInput>> &inputs);
EXPORT static void check_input_is_singular(Internal::GeneratorInputBase *in);
EXPORT static void check_input_is_array(Internal::GeneratorInputBase *in);
EXPORT static void check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind);
template<typename T>
std::vector<StubInput> build_input(size_t i, const Buffer<T> &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_is_singular(in);
const auto k = in->kind();
if (k == Internal::IOKind::Buffer) {
Halide::Buffer<> b = arg;
StubInputBuffer<> sib(b);
StubInput si(sib);
return {si};
} else if (k == Internal::IOKind::Function) {
Halide::Func f(arg.name() + "_im");
f(Halide::_) = arg(Halide::_);
StubInput si(f);
return {si};
} else {
check_input_kind(in, Internal::IOKind::Buffer);
return {};
}
}
template<typename T>
std::vector<StubInput> build_input(size_t i, const GeneratorInput<Buffer<T>> &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_is_singular(in);
const auto k = in->kind();
if (k == Internal::IOKind::Buffer) {
StubInputBuffer<> sib = arg;
StubInput si(sib);
return {si};
} else if (k == Internal::IOKind::Function) {
Halide::Func f = arg.funcs().at(0);
StubInput si(f);
return {si};
} else {
check_input_kind(in, Internal::IOKind::Buffer);
return {};
}
}
std::vector<StubInput> build_input(size_t i, const Func &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_kind(in, Internal::IOKind::Function);
check_input_is_singular(in);
Halide::Func f = arg;
StubInput si(f);
return {si};
}
std::vector<StubInput> build_input(size_t i, const std::vector<Func> &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_kind(in, Internal::IOKind::Function);
check_input_is_array(in);
std::vector<StubInput> siv;
siv.reserve(arg.size());
for (const auto &f : arg) {
siv.emplace_back(f);
}
return siv;
}
std::vector<StubInput> build_input(size_t i, const Expr &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_kind(in, Internal::IOKind::Scalar);
check_input_is_singular(in);
StubInput si(arg);
return {si};
}
std::vector<StubInput> build_input(size_t i, const std::vector<Expr> &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_kind(in, Internal::IOKind::Scalar);
check_input_is_array(in);
std::vector<StubInput> siv;
siv.reserve(arg.size());
for (const auto &value : arg) {
siv.emplace_back(value);
}
return siv;
}
template<typename T,
typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
std::vector<StubInput> build_input(size_t i, const T &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_kind(in, Internal::IOKind::Scalar);
check_input_is_singular(in);
Expr e(arg);
StubInput si(e);
return {si};
}
template<typename T,
typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
std::vector<StubInput> build_input(size_t i, const std::vector<T> &arg) {
auto *in = param_info().filter_inputs.at(i);
check_input_kind(in, Internal::IOKind::Scalar);
check_input_is_array(in);
std::vector<StubInput> siv;
siv.reserve(arg.size());
for (const auto &value : arg) {
Expr e(value);
siv.emplace_back(e);
}
return siv;
}
template<typename... Args, size_t... Indices>
std::vector<std::vector<StubInput>> build_inputs(const std::tuple<const Args &...>& t, index_sequence<Indices...>) {
return {build_input(Indices, std::get<Indices>(t))...};
}
GeneratorBase(const GeneratorBase &) = delete;
void operator=(const GeneratorBase &) = delete;
GeneratorBase(GeneratorBase&& that) = delete;
void operator=(GeneratorBase&& that) = delete;
};
// Abstract factory for generators: given a context and a map of parameter
// names to string values, produce a new generator instance owned by the caller.
class GeneratorFactory {
public:
    virtual ~GeneratorFactory() {}
    virtual std::unique_ptr<GeneratorBase> create(const GeneratorContext &context,
                                                  const std::map<std::string, std::string> &params) const = 0;
};
// Callable that SimpleGeneratorFactory uses to instantiate a generator: it
// receives the GeneratorContext and returns an owning pointer to the new
// generator.
using GeneratorCreateFunc = std::function<std::unique_ptr<Internal::GeneratorBase>(const GeneratorContext &context)>;
class SimpleGeneratorFactory : public GeneratorFactory {
public:
SimpleGeneratorFactory(GeneratorCreateFunc create_func, const std::string &generator_name)
: create_func(create_func), generator_name(generator_name) {
internal_assert(create_func != nullptr);
}
std::unique_ptr<Internal::GeneratorBase> create(const GeneratorContext &context,
const std::map<std::string, std::string> ¶ms) const override {
auto g = create_func(context);
internal_assert(g.get() != nullptr);
g->set_generator_name(generator_name);
g->set_generator_and_schedule_param_values(params);
return g;
}
private:
const GeneratorCreateFunc create_func;
const std::string generator_name;
};
// Registry mapping generator names to their factories. All public operations
// are static; the instance holding the factory map and a mutex is reached
// through the private get_registry() accessor and cannot be constructed or
// copied from outside.
class GeneratorRegistry {
public:
    // Add the factory for `name`; the registry takes ownership.
    EXPORT static void register_factory(const std::string &name, std::unique_ptr<GeneratorFactory> factory);

    // Remove the factory registered under `name`.
    EXPORT static void unregister_factory(const std::string &name);

    // Names of the registered generators.
    EXPORT static std::vector<std::string> enumerate();

    // Create the generator registered under `name`, passing `context` and
    // `params` on to its factory.
    EXPORT static std::unique_ptr<GeneratorBase> create(const std::string &name,
                                                        const GeneratorContext &context,
                                                        const std::map<std::string, std::string> &params);

private:
    using GeneratorFactoryMap = std::map<const std::string, std::unique_ptr<GeneratorFactory>>;

    GeneratorFactoryMap factories;
    std::mutex mutex;

    EXPORT static GeneratorRegistry &get_registry();

    GeneratorRegistry() {}
    GeneratorRegistry(const GeneratorRegistry &) = delete;
    void operator=(const GeneratorRegistry &) = delete;
};
}
// CRTP base for user-written generators: `class Foo : public Generator<Foo>`.
// It implements GeneratorBase's pure virtual hooks by detecting, at compile
// time, whether T has generate() and schedule() methods:
//   - if T has generate(), build_pipeline() runs generate() then schedule()
//     and returns get_pipeline();
//   - otherwise build_pipeline() calls T::build() and returns its Pipeline,
//     and T must not define schedule().
template <class T>
class Generator : public Internal::GeneratorBase {
protected:
    Generator() :
        Internal::GeneratorBase(sizeof(T),
                                Internal::Introspection::get_introspection_helper<T>()) {}

public:
    // Allocates a T, initializes it from `context`, and returns it as an
    // owning pointer to the base class.
    static std::unique_ptr<Internal::GeneratorBase> create(const Halide::GeneratorContext &context) {
        T *t = new T;
        // Hand the object to its owner before doing anything else, so it is
        // released should init_from_context() exit via an exception.
        std::unique_ptr<Internal::GeneratorBase> owner(t);
        t->init_from_context(context);
        return owner;
    }

private:
    // Maps any type to void, for the detection traits below.
    template<typename>
    struct type_sink { typedef void type; };

    // True when `std::declval<T2>().generate()` is well-formed.
    template<typename T2, typename = void>
    struct has_generate_method : std::false_type {};

    template<typename T2>
    struct has_generate_method<T2, typename type_sink<decltype(std::declval<T2>().generate())>::type> : std::true_type {};

    // True when `std::declval<T2>().schedule()` is well-formed.
    template<typename T2, typename = void>
    struct has_schedule_method : std::false_type {};

    template<typename T2>
    struct has_schedule_method<T2, typename type_sink<decltype(std::declval<T2>().schedule())>::type> : std::true_type {};

    // T has no generate(): use T::build(), bracketed by pre_build()/post_build().
    template <typename T2 = T,
              typename std::enable_if<!has_generate_method<T2>::value>::type * = nullptr>
    Pipeline build_pipeline_impl() {
        static_assert(!has_schedule_method<T2>::value, "The schedule() method is ignored if you define a build() method; use generate() instead.");
        pre_build();
        Pipeline p = ((T *)this)->build();
        post_build();
        return p;
    }

    // T has generate(): run the generate and schedule steps, then return the
    // resulting pipeline.
    template <typename T2 = T,
              typename std::enable_if<has_generate_method<T2>::value>::type * = nullptr>
    Pipeline build_pipeline_impl() {
        ((T *)this)->call_generate_impl();
        ((T *)this)->call_schedule_impl();
        return get_pipeline();
    }

    // T has no generate(): calling this is a user error.
    template <typename T2 = T,
              typename std::enable_if<!has_generate_method<T2>::value>::type * = nullptr>
    void call_generate_impl() {
        user_error << "Unimplemented";
    }

    // T has generate(): it must return void; it is bracketed by
    // pre_generate()/post_generate().
    template <typename T2 = T,
              typename std::enable_if<has_generate_method<T2>::value>::type * = nullptr>
    void call_generate_impl() {
        T *t = (T*)this;
        static_assert(std::is_void<decltype(t->generate())>::value, "generate() must return void");
        pre_generate();
        t->generate();
        post_generate();
    }

    // T has no schedule(): calling this is a user error.
    template <typename T2 = T,
              typename std::enable_if<!has_schedule_method<T2>::value>::type * = nullptr>
    void call_schedule_impl() {
        user_error << "Unimplemented";
    }

    // T has schedule(): it must return void; it is bracketed by
    // pre_schedule()/post_schedule().
    template <typename T2 = T,
              typename std::enable_if<has_schedule_method<T2>::value>::type * = nullptr>
    void call_schedule_impl() {
        T *t = (T*)this;
        static_assert(std::is_void<decltype(t->schedule())>::value, "schedule() must return void");
        pre_schedule();
        t->schedule();
        post_schedule();
    }

protected:
    // GeneratorBase hooks; each dispatches to the matching *_impl overload.
    Pipeline build_pipeline() override {
        return this->build_pipeline_impl();
    }

    void call_generate() override {
        this->call_generate_impl();
    }

    void call_schedule() override {
        this->call_schedule_impl();
    }

private:
    friend void ::Halide::Internal::generator_test();
    friend class Internal::SimpleGeneratorFactory;

    // Neither copyable nor movable.
    Generator(const Generator &) = delete;
    void operator=(const Generator &) = delete;
    Generator(Generator&& that) = delete;
    void operator=(Generator&& that) = delete;
};
template <class GeneratorClass>
class RegisterGenerator {
public:
RegisterGenerator(const char* generator_name) {
std::unique_ptr<Internal::SimpleGeneratorFactory> f(new Internal::SimpleGeneratorFactory(GeneratorClass::create, generator_name));
Internal::GeneratorRegistry::register_factory(generator_name, std::move(f));
}
};
namespace Internal {
class GeneratorStub : public NamesInterface {
public:
GeneratorStub() = default;
GeneratorStub(GeneratorStub&& that) : generator(std::move(that.generator)) {}
GeneratorStub& operator=(GeneratorStub&& that) {
generator = std::move(that.generator);
return *this;
}
Target get_target() const { return generator->get_target(); }
template<typename T>
GeneratorStub &set_schedule_param(const std::string &name, const T &value) {
generator->set_schedule_param(name, value);
return *this;
}
GeneratorStub &schedule() {
generator->call_schedule();
return *this;
}
operator Func() const {
return get_first_output();
}
template <typename... Args>
FuncRef operator()(Args&&... args) const {
return get_first_output()(std::forward<Args>(args)...);
}
template <typename ExprOrVar>
FuncRef operator()(std::vector<ExprOrVar> args) const {
return get_first_output()(args);
}
Realization realize(std::vector<int32_t> sizes) {
return generator->realize(sizes);
}
template <typename... Args, typename std::enable_if<NoRealizations<Args...>::value>::type * = nullptr>
Realization realize(Args&&... args) {
return generator->realize(std::forward<Args>(args)...);
}
void realize(Realization r) {
generator->realize(r);
}
virtual ~GeneratorStub() {}
protected:
typedef std::function<std::unique_ptr<GeneratorBase>(const GeneratorContext&, const std::map<std::string, std::string>&)> GeneratorFactory;
EXPORT GeneratorStub(const GeneratorContext &context,
GeneratorFactory generator_factory,
const std::map<std::string, std::string> &generator_params,
const std::vector<std::vector<Internal::StubInput>> &inputs);
ScheduleParamBase &get_schedule_param(const std::string &n) const {
return generator->find_schedule_param_by_name(n);
}
Func get_output(const std::string &n) const {
return generator->get_output(n);
}
template<typename T2>
T2 get_output_buffer(const std::string &n) const {
return T2(get_output(n), generator);
}
std::vector<Func> get_output_vector(const std::string &n) const {
return generator->get_output_vector(n);
}
bool has_generator() const {
return generator != nullptr;
}
template<typename Ratio>
static double ratio_to_double() {
return (double)Ratio::num / (double)Ratio::den;
}
static std::vector<StubInput> to_stub_input_vector(const Expr &e) {
return { StubInput(e) };
}
static std::vector<StubInput> to_stub_input_vector(const Func &f) {
return { StubInput(f) };
}
template<typename T = void>
static std::vector<StubInput> to_stub_input_vector(const StubInputBuffer<T> &b) {
return { StubInput(b) };
}
template <typename T>
static std::vector<StubInput> to_stub_input_vector(const std::vector<T> &v) {
std::vector<StubInput> r;
std::copy(v.begin(), v.end(), std::back_inserter(r));
return r;
}
EXPORT void verify_same_funcs(const Func &a, const Func &b);
EXPORT void verify_same_funcs(const std::vector<Func>& a, const std::vector<Func>& b);
template<typename T2>
void verify_same_funcs(const StubOutputBuffer<T2> &a, const StubOutputBuffer<T2> &b) {
verify_same_funcs(a.f, b.f);
}
private:
std::shared_ptr<GeneratorBase> generator;
Func get_first_output() const {
return generator->get_first_output();
}
explicit GeneratorStub(const GeneratorStub &) = delete;
GeneratorStub &operator=(const GeneratorStub &) = delete;
explicit GeneratorStub(const GeneratorStub &&) = delete;
GeneratorStub &operator=(const GeneratorStub &&) = delete;
};
}
}
// Defines, inside namespace ns_reg_gen, a static object named
// reg_<GEN_CLASS_NAME> of type Halide::RegisterGenerator<GEN_CLASS_NAME>,
// constructed with GEN_REGISTRY_NAME. GEN_CLASS_NAME is token-pasted into the
// variable name, so it has to be a single identifier.
#define HALIDE_REGISTER_GENERATOR(GEN_CLASS_NAME, GEN_REGISTRY_NAME) \
namespace ns_reg_gen { static auto reg_##GEN_CLASS_NAME = Halide::RegisterGenerator<GEN_CLASS_NAME>(GEN_REGISTRY_NAME); }
#endif