#ifndef HALIDE_IR_OPERATOR_H
#define HALIDE_IR_OPERATOR_H
#include <atomic>
#include "IR.h"
#include "Util.h"
namespace Halide {
namespace Internal {

// ---------------------------------------------------------------------------
// Constant inspection helpers. All of these are implemented out of line; the
// notes below only describe how this header relies on them.
// ---------------------------------------------------------------------------

/** Constant tests on an Expr. The second overload additionally takes a value
 *  v to compare against (NOTE(review): exact semantics live in the .cpp). */
EXPORT bool is_const(const Expr &e);
EXPORT bool is_const(const Expr &e, int64_t v);

/** Return a pointer to the constant value held by e when e is a constant of
 *  the corresponding kind. Callers in this header test the returned pointer
 *  for null before dereferencing it. */
EXPORT const int64_t *as_const_int(const Expr &e);
EXPORT const uint64_t *as_const_uint(const Expr &e);
EXPORT const double *as_const_float(const Expr &e);

EXPORT bool is_const_power_of_two_integer(const Expr &e, int *bits);
EXPORT bool is_positive_const(const Expr &e);
EXPORT bool is_negative_const(const Expr &e);
EXPORT bool is_negative_negatable_const(const Expr &e);
EXPORT bool is_undef(const Expr &e);
/** Used by operator% in this header to reject a constant zero modulus. */
EXPORT bool is_zero(const Expr &e);
EXPORT bool is_one(const Expr &e);
EXPORT bool is_two(const Expr &e);
EXPORT bool is_no_op(const Stmt &s);

// ---------------------------------------------------------------------------
// Constant construction. The three exported overloads do the real work; the
// inline overloads below widen narrower C++ scalar types to one of them so
// that overload resolution is never ambiguous.
// ---------------------------------------------------------------------------

EXPORT Expr make_const(Type t, int64_t val);
EXPORT Expr make_const(Type t, uint64_t val);
EXPORT Expr make_const(Type t, double val);

// Signed integers widen to int64_t.
inline Expr make_const(Type t, int32_t val) {
    return make_const(t, static_cast<int64_t>(val));
}
inline Expr make_const(Type t, int16_t val) {
    return make_const(t, static_cast<int64_t>(val));
}
inline Expr make_const(Type t, int8_t val) {
    return make_const(t, static_cast<int64_t>(val));
}

// Unsigned integers (and bool) widen to uint64_t.
inline Expr make_const(Type t, uint32_t val) {
    return make_const(t, static_cast<uint64_t>(val));
}
inline Expr make_const(Type t, uint16_t val) {
    return make_const(t, static_cast<uint64_t>(val));
}
inline Expr make_const(Type t, uint8_t val) {
    return make_const(t, static_cast<uint64_t>(val));
}
inline Expr make_const(Type t, bool val) {
    return make_const(t, static_cast<uint64_t>(val));
}

// Floating-point values widen to double.
inline Expr make_const(Type t, float val) {
    return make_const(t, static_cast<double>(val));
}
inline Expr make_const(Type t, float16_t val) {
    return make_const(t, static_cast<double>(val));
}

/** Called by the Expr-with-int operators in this header before they build a
 *  constant of type t from val (NOTE(review): failure behaviour is defined in
 *  the .cpp — presumably a user error). */
EXPORT void check_representable(Type t, int64_t val);

EXPORT Expr make_bool(bool val, int lanes = 1);
EXPORT Expr make_zero(Type t);
EXPORT Expr make_one(Type t);
EXPORT Expr make_two(Type t);
EXPORT Expr const_true(int lanes = 1);
EXPORT Expr const_false(int lanes = 1);

/** Used by clamp(): an undefined result is treated there as "this bound is
 *  possibly out of range for type t". */
EXPORT Expr lossless_cast(Type t, const Expr &e);

/** Coerces a and b (in place) to a common type; every binary Expr operator
 *  in this header calls it before building its IR node. */
EXPORT void match_types(Expr &a, Expr &b);

EXPORT Expr halide_log(const Expr &a);
EXPORT Expr halide_exp(const Expr &a);
EXPORT Expr halide_erf(const Expr &a);

/** Used by pow()/fast_pow() when the exponent is a constant integer. */
EXPORT Expr raise_to_integer_power(const Expr &a, int64_t b);

/** Appends the conjuncts of cond to result (NOTE(review): inferred from the
 *  name — confirm against the implementation). */
EXPORT void split_into_ands(const Expr &cond, std::vector<Expr> &result);

/** Aggregate of the pieces from which build() assembles a buffer expression
 *  (build() is implemented elsewhere). */
struct BufferBuilder {
    Expr buffer_memory;
    Expr shape_memory;
    Expr host;
    Expr device;
    Expr device_interface;
    Type type;
    int dimensions = 0;
    std::vector<Expr> mins;
    std::vector<Expr> extents;
    std::vector<Expr> strides;
    Expr host_dirty;
    Expr device_dirty;

    EXPORT Expr build() const;
};

}  // namespace Internal
/** Cast a to the Halide type that corresponds to the C++ type T. */
template<typename T>
inline Expr cast(const Expr &a) {
    // Keep type_of<T>() inside the call so the call stays dependent and is
    // resolved at instantiation, after cast(Type, const Expr &) is declared.
    Expr converted = cast(type_of<T>(), a);
    return converted;
}

/** Cast a to type t.
 *
 *  - A no-op when a already has type t.
 *  - Casting between handles and non-handles is a user error; reinterpret
 *    must be used instead.
 *  - Constant operands are folded into a constant of type t.
 *  - Scalar (or broadcast) operands cast to a vector type become a Broadcast
 *    of the element-wise cast.
 *  - Otherwise a Cast node is built. */
inline Expr cast(Type t, const Expr &a) {
    user_assert(a.defined()) << "cast of undefined Expr\n";
    const Type from = a.type();
    if (from == t) {
        return a;
    }

    const bool to_handle = t.is_handle();
    const bool from_handle = from.is_handle();
    if (to_handle && !from_handle) {
        user_error << "Can't cast \"" << a << "\" to a handle. "
                   << "The only legal cast from scalar types to a handle is: "
                   << "reinterpret(Handle(), cast<uint64_t>(" << a << "));\n";
    } else if (from_handle && !to_handle) {
        user_error << "Can't cast handle \"" << a << "\" to type " << t << ". "
                   << "The only legal cast from handles to scalar types is: "
                   << "reinterpret(UInt(64), " << a << ");\n";
    }

    // Fold constants directly instead of wrapping them in a Cast node.
    if (const int64_t *ival = Internal::as_const_int(a)) {
        return Internal::make_const(t, *ival);
    }
    if (const uint64_t *uval = Internal::as_const_uint(a)) {
        return Internal::make_const(t, *uval);
    }
    if (const double *fval = Internal::as_const_float(a)) {
        return Internal::make_const(t, *fval);
    }

    if (t.is_vector()) {
        if (from.is_scalar()) {
            return Internal::Broadcast::make(cast(t.element_of(), a), t.lanes());
        }
        if (const Internal::Broadcast *bc = a.as<Internal::Broadcast>()) {
            internal_assert(bc->lanes == t.lanes());
            return Internal::Broadcast::make(cast(t.element_of(), bc->value), t.lanes());
        }
    }

    return Internal::Cast::make(t, a);
}
/** Sum of two Exprs. Both operands are coerced with match_types first. */
inline Expr operator+(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator+ of undefined Expr\n";
    Internal::match_types(a, b);
    Expr sum = Internal::Add::make(a, b);
    return sum;
}

/** Expr + int. The int is checked for representability in a's type and
 *  then turned into a constant of that type. */
inline Expr operator+(const Expr &a, int b) {
    user_assert(a.defined()) << "operator+ of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::Add::make(a, rhs);
}

/** int + Expr. The int takes on b's type. */
inline Expr operator+(int a, const Expr &b) {
    user_assert(b.defined()) << "operator+ of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::Add::make(lhs, b);
}

/** In-place add. b is cast to a's type (a's type is never changed). */
inline Expr &operator+=(Expr &a, const Expr &b) {
    user_assert(a.defined() && b.defined()) << "operator+= of undefined Expr\n";
    Expr rhs = cast(a.type(), b);
    a = Internal::Add::make(a, rhs);
    return a;
}
/** Difference of two Exprs. Both operands are coerced with match_types first. */
inline Expr operator-(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator- of undefined Expr\n";
    Internal::match_types(a, b);
    Expr diff = Internal::Sub::make(a, b);
    return diff;
}

/** Expr - int. The int becomes a constant of a's type. */
inline Expr operator-(const Expr &a, int b) {
    user_assert(a.defined()) << "operator- of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::Sub::make(a, rhs);
}

/** int - Expr. The int becomes a constant of b's type. */
inline Expr operator-(int a, const Expr &b) {
    user_assert(b.defined()) << "operator- of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::Sub::make(lhs, b);
}

/** Unary negation, expressed as (zero of a's type) - a. */
inline Expr operator-(const Expr &a) {
    user_assert(a.defined()) << "operator- of undefined Expr\n";
    Expr zero = Internal::make_zero(a.type());
    return Internal::Sub::make(zero, a);
}

/** In-place subtract. b is cast to a's type. */
inline Expr &operator-=(Expr &a, const Expr &b) {
    user_assert(a.defined() && b.defined()) << "operator-= of undefined Expr\n";
    Expr rhs = cast(a.type(), b);
    a = Internal::Sub::make(a, rhs);
    return a;
}
/** Product of two Exprs. Both operands are coerced with match_types first. */
inline Expr operator*(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator* of undefined Expr\n";
    Internal::match_types(a, b);
    Expr product = Internal::Mul::make(a, b);
    return product;
}

/** Expr * int. The int becomes a constant of a's type. */
inline Expr operator*(const Expr &a, int b) {
    user_assert(a.defined()) << "operator* of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::Mul::make(a, rhs);
}

/** int * Expr. The int becomes a constant of b's type. */
inline Expr operator*(int a, const Expr &b) {
    user_assert(b.defined()) << "operator* of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::Mul::make(lhs, b);
}

/** In-place multiply. b is cast to a's type. */
inline Expr &operator*=(Expr &a, const Expr &b) {
    user_assert(a.defined() && b.defined()) << "operator*= of undefined Expr\n";
    Expr rhs = cast(a.type(), b);
    a = Internal::Mul::make(a, rhs);
    return a;
}
/** Quotient of two Exprs. Both operands are coerced with match_types first. */
inline Expr operator/(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator/ of undefined Expr\n";
    Internal::match_types(a, b);
    return Internal::Div::make(a, b);
}

/** In-place divide. b is cast to a's type. */
inline Expr &operator/=(Expr &a, const Expr &b) {
    user_assert(a.defined() && b.defined()) << "operator/= of undefined Expr\n";
    a = Internal::Div::make(a, cast(a.type(), b));
    return a;
}

/** Expr / int. The int becomes a constant of a's type. */
inline Expr operator/(const Expr &a, int b) {
    user_assert(a.defined()) << "operator/ of undefined Expr\n";
    Internal::check_representable(a.type(), b);
    return Internal::Div::make(a, Internal::make_const(a.type(), b));
}

/** int / Expr. The int becomes a constant of b's type. */
inline Expr operator/(int a, const Expr &b) {
    // Previously reported "operator-" here (copy-paste error).
    user_assert(b.defined()) << "operator/ of undefined Expr\n";
    Internal::check_representable(b.type(), a);
    return Internal::Div::make(Internal::make_const(b.type(), a), b);
}
/** Remainder of two Exprs. A modulus that is a constant zero is rejected
 *  before the operands are coerced with match_types. */
inline Expr operator%(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator% of undefined Expr\n";
    const bool zero_modulus = Internal::is_zero(b);
    user_assert(!zero_modulus) << "operator% with constant 0 modulus\n";
    Internal::match_types(a, b);
    Expr rem = Internal::Mod::make(a, b);
    return rem;
}

/** Expr % int. A zero int modulus is rejected; otherwise the int becomes a
 *  constant of a's type. */
inline Expr operator%(const Expr &a, int b) {
    user_assert(a.defined()) << "operator% of undefined Expr\n";
    user_assert(b != 0) << "operator% with constant 0 modulus\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::Mod::make(a, rhs);
}

/** int % Expr. A constant-zero b is rejected; the int takes on b's type. */
inline Expr operator%(int a, const Expr &b) {
    user_assert(b.defined()) << "operator% of undefined Expr\n";
    const bool zero_modulus = Internal::is_zero(b);
    user_assert(!zero_modulus) << "operator% with constant 0 modulus\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::Mod::make(lhs, b);
}
/** a > b, after coercing both operands with match_types. */
inline Expr operator>(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator> of undefined Expr\n";
    Internal::match_types(a, b);
    Expr cmp = Internal::GT::make(a, b);
    return cmp;
}

/** Expr > int. The int becomes a constant of a's type. */
inline Expr operator>(const Expr &a, int b) {
    user_assert(a.defined()) << "operator> of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::GT::make(a, rhs);
}

/** int > Expr. The int becomes a constant of b's type. */
inline Expr operator>(int a, const Expr &b) {
    user_assert(b.defined()) << "operator> of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::GT::make(lhs, b);
}
/** a < b, after coercing both operands with match_types. */
inline Expr operator<(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator< of undefined Expr\n";
    Internal::match_types(a, b);
    Expr cmp = Internal::LT::make(a, b);
    return cmp;
}

/** Expr < int. The int becomes a constant of a's type. */
inline Expr operator<(const Expr &a, int b) {
    user_assert(a.defined()) << "operator< of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::LT::make(a, rhs);
}

/** int < Expr. The int becomes a constant of b's type. */
inline Expr operator<(int a, const Expr &b) {
    user_assert(b.defined()) << "operator< of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::LT::make(lhs, b);
}
/** a <= b, after coercing both operands with match_types. */
inline Expr operator<=(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator<= of undefined Expr\n";
    Internal::match_types(a, b);
    Expr cmp = Internal::LE::make(a, b);
    return cmp;
}

/** Expr <= int. The int becomes a constant of a's type. */
inline Expr operator<=(const Expr &a, int b) {
    user_assert(a.defined()) << "operator<= of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::LE::make(a, rhs);
}

/** int <= Expr. The int becomes a constant of b's type. */
inline Expr operator<=(int a, const Expr &b) {
    user_assert(b.defined()) << "operator<= of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::LE::make(lhs, b);
}
/** a >= b, after coercing both operands with match_types. */
inline Expr operator>=(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator>= of undefined Expr\n";
    Internal::match_types(a, b);
    Expr cmp = Internal::GE::make(a, b);
    return cmp;
}

/** Expr >= int. The int becomes a constant of a's type. */
inline Expr operator>=(const Expr &a, int b) {
    user_assert(a.defined()) << "operator>= of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::GE::make(a, rhs);
}

/** int >= Expr. The int becomes a constant of b's type. */
inline Expr operator>=(int a, const Expr &b) {
    user_assert(b.defined()) << "operator>= of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::GE::make(lhs, b);
}
/** a == b as an Expr, after coercing both operands with match_types. */
inline Expr operator==(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator== of undefined Expr\n";
    Internal::match_types(a, b);
    Expr cmp = Internal::EQ::make(a, b);
    return cmp;
}

/** Expr == int. The int becomes a constant of a's type. */
inline Expr operator==(const Expr &a, int b) {
    user_assert(a.defined()) << "operator== of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::EQ::make(a, rhs);
}

/** int == Expr. The int becomes a constant of b's type. */
inline Expr operator==(int a, const Expr &b) {
    user_assert(b.defined()) << "operator== of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::EQ::make(lhs, b);
}
/** a != b as an Expr, after coercing both operands with match_types. */
inline Expr operator!=(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator!= of undefined Expr\n";
    Internal::match_types(a, b);
    Expr cmp = Internal::NE::make(a, b);
    return cmp;
}

/** Expr != int. The int becomes a constant of a's type. */
inline Expr operator!=(const Expr &a, int b) {
    user_assert(a.defined()) << "operator!= of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::NE::make(a, rhs);
}

/** int != Expr. The int becomes a constant of b's type. */
inline Expr operator!=(int a, const Expr &b) {
    user_assert(b.defined()) << "operator!= of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::NE::make(lhs, b);
}
/** Logical and of two Exprs, after coercing both with match_types. */
inline Expr operator&&(Expr a, Expr b) {
    // Match every other binary operator: reject undefined operands with a
    // user-facing error instead of failing deeper in the IR.
    user_assert(a.defined() && b.defined()) << "operator&& of undefined Expr\n";
    Internal::match_types(a, b);
    return Internal::And::make(a, b);
}

/** Expr && bool. Folds the C++ bool immediately: true yields a unchanged,
 *  false yields a zero (false) of a's type. a must be boolean. */
inline Expr operator&&(const Expr &a, bool b) {
    // These check caller input, so they are user errors, not internal ones.
    user_assert(a.defined()) << "operator&& of undefined Expr\n";
    user_assert(a.type().is_bool()) << "operator&& of Expr of type " << a.type() << "\n";
    if (b) {
        return a;
    } else {
        return Internal::make_zero(a.type());
    }
}

/** bool && Expr; delegates to the (Expr, bool) overload. */
inline Expr operator&&(bool a, const Expr &b) {
    return b && a;
}

/** Logical or of two Exprs, after coercing both with match_types. */
inline Expr operator||(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "operator|| of undefined Expr\n";
    Internal::match_types(a, b);
    return Internal::Or::make(a, b);
}

/** Expr || bool. Folds the C++ bool immediately: true yields a one (true)
 *  of a's type, false yields a unchanged. a must be boolean. */
inline Expr operator||(const Expr &a, bool b) {
    user_assert(a.defined()) << "operator|| of undefined Expr\n";
    user_assert(a.type().is_bool()) << "operator|| of Expr of type " << a.type() << "\n";
    if (b) {
        return Internal::make_one(a.type());
    } else {
        return a;
    }
}

/** bool || Expr; delegates to the (Expr, bool) overload. */
inline Expr operator||(bool a, const Expr &b) {
    return b || a;
}

/** Logical not. */
inline Expr operator!(const Expr &a) {
    user_assert(a.defined()) << "operator! of undefined Expr\n";
    return Internal::Not::make(a);
}
/** Maximum of two Exprs, after coercing both with match_types. */
inline Expr max(Expr a, Expr b) {
    user_assert(a.defined() && b.defined())
        << "max of undefined Expr\n";
    Internal::match_types(a, b);
    Expr larger = Internal::Max::make(a, b);
    return larger;
}

/** max(Expr, int). The int becomes a constant of a's type. */
inline Expr max(const Expr &a, int b) {
    user_assert(a.defined()) << "max of undefined Expr\n";
    const Type t = a.type();
    Internal::check_representable(t, b);
    const Expr rhs = Internal::make_const(t, b);
    return Internal::Max::make(a, rhs);
}

/** max(int, Expr). The int becomes a constant of b's type. */
inline Expr max(int a, const Expr &b) {
    user_assert(b.defined()) << "max of undefined Expr\n";
    const Type t = b.type();
    Internal::check_representable(t, a);
    const Expr lhs = Internal::make_const(t, a);
    return Internal::Max::make(lhs, b);
}

/** Float mixes wrap the float in an Expr and defer to max(Expr, Expr). */
inline Expr max(float a, const Expr &b) {
    return max(Expr(a), b);
}
inline Expr max(const Expr &a, float b) {
    return max(a, Expr(b));
}

/** Variadic max of three or more values, folded from the right. */
template<typename A, typename B, typename C, typename... Rest,
         typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Rest...>::value>::type* = nullptr>
inline Expr max(const A &a, const B &b, const C &c, Rest&&... rest) {
    return max(a, max(b, c, std::forward<Rest>(rest)...));
}
/** Minimum of two Exprs, after coercing both with match_types. */
inline Expr min(Expr a, Expr b) {
    user_assert(a.defined() && b.defined())
        << "min of undefined Expr\n";
    Internal::match_types(a, b);
    return Internal::Min::make(a, b);
}

/** min(Expr, int). The int becomes a constant of a's type. */
inline Expr min(const Expr &a, int b) {
    // Previously reported "max of undefined Expr" (copy-paste error).
    user_assert(a.defined()) << "min of undefined Expr\n";
    Internal::check_representable(a.type(), b);
    return Internal::Min::make(a, Internal::make_const(a.type(), b));
}

/** min(int, Expr). The int becomes a constant of b's type. */
inline Expr min(int a, const Expr &b) {
    user_assert(b.defined()) << "min of undefined Expr\n";
    Internal::check_representable(b.type(), a);
    return Internal::Min::make(Internal::make_const(b.type(), a), b);
}

/** Float mixes wrap the float in an Expr and defer to min(Expr, Expr). */
inline Expr min(float a, const Expr &b) {return min(Expr(a), b);}
inline Expr min(const Expr &a, float b) {return min(a, Expr(b));}

/** Variadic min of three or more values, folded from the right. */
template<typename A, typename B, typename C, typename... Rest,
         typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Rest...>::value>::type* = nullptr>
inline Expr min(const A &a, const B &b, const C &c, Rest&&... rest) {
    return min(a, min(b, c, std::forward<Rest>(rest)...));
}
// Mixed Expr/float arithmetic and comparisons. In every overload the float
// operand is wrapped in an Expr and the Expr/Expr operator does the work,
// so type reconciliation is done by match_types there.

inline Expr operator+(const Expr &a, float b) {
    return a + Expr(b);
}
inline Expr operator+(float a, const Expr &b) {
    return Expr(a) + b;
}

inline Expr operator-(const Expr &a, float b) {
    return a - Expr(b);
}
inline Expr operator-(float a, const Expr &b) {
    return Expr(a) - b;
}

inline Expr operator*(const Expr &a, float b) {
    return a * Expr(b);
}
inline Expr operator*(float a, const Expr &b) {
    return Expr(a) * b;
}

inline Expr operator/(const Expr &a, float b) {
    return a / Expr(b);
}
inline Expr operator/(float a, const Expr &b) {
    return Expr(a) / b;
}

inline Expr operator%(const Expr &a, float b) {
    return a % Expr(b);
}
inline Expr operator%(float a, const Expr &b) {
    return Expr(a) % b;
}

inline Expr operator>(const Expr &a, float b) {
    return a > Expr(b);
}
inline Expr operator>(float a, const Expr &b) {
    return Expr(a) > b;
}

inline Expr operator<(const Expr &a, float b) {
    return a < Expr(b);
}
inline Expr operator<(float a, const Expr &b) {
    return Expr(a) < b;
}

inline Expr operator>=(const Expr &a, float b) {
    return a >= Expr(b);
}
inline Expr operator>=(float a, const Expr &b) {
    return Expr(a) >= b;
}

inline Expr operator<=(const Expr &a, float b) {
    return a <= Expr(b);
}
inline Expr operator<=(float a, const Expr &b) {
    return Expr(a) <= b;
}

inline Expr operator==(const Expr &a, float b) {
    return a == Expr(b);
}
inline Expr operator==(float a, const Expr &b) {
    return Expr(a) == b;
}

inline Expr operator!=(const Expr &a, float b) {
    return a != Expr(b);
}
inline Expr operator!=(float a, const Expr &b) {
    return Expr(a) != b;
}
/** Clamp a to [min_val, max_val]. Both bounds are converted to a's type with
 *  lossless_cast; an undefined result from that conversion is reported as a
 *  possibly out-of-range bound. The result is max(min(a, hi), lo). */
inline Expr clamp(const Expr &a, const Expr &min_val, const Expr &max_val) {
    user_assert(a.defined() && min_val.defined() && max_val.defined())
        << "clamp of undefined Expr\n";
    const Type t = a.type();

    Expr lo = Internal::lossless_cast(t, min_val);
    user_assert(lo.defined())
        << "clamp with possibly out of range minimum bound: " << min_val << "\n";

    Expr hi = Internal::lossless_cast(t, max_val);
    user_assert(hi.defined())
        << "clamp with possibly out of range maximum bound: " << max_val << "\n";

    Expr capped_above = Internal::Min::make(a, hi);
    return Internal::Max::make(capped_above, lo);
}
/** Absolute value.
 *  - Unsigned input: a warning is emitted and a is returned unchanged.
 *  - Signed integer input: the result is the unsigned type of the same width
 *    (presumably so the magnitude of the most negative value fits).
 *  - Other types keep their code. */
inline Expr abs(const Expr &a) {
    user_assert(a.defined())
        << "abs of undefined Expr\n";
    const Type t = a.type();
    if (t.is_uint()) {
        user_warning << "Warning: abs of an unsigned type is a no-op\n";
        return a;
    }
    const Type result_type = t.with_code(t.is_int() ? Type::UInt : t.code());
    return Internal::Call::make(result_type, Internal::Call::abs, {a},
                                Internal::Call::PureIntrinsic);
}
/** Absolute difference |a - b|. Operands are coerced with match_types.
 *  Floats are lowered to abs(a - b); signed integers produce the unsigned
 *  type of the same width; other types keep their code. */
inline Expr absd(Expr a, Expr b) {
    user_assert(a.defined() && b.defined()) << "absd of undefined Expr\n";
    Internal::match_types(a, b);
    const Type t = a.type();
    if (t.is_float()) {
        Expr diff = a - b;
        return abs(diff);
    }
    const Type result_type = t.with_code(t.is_int() ? Type::UInt : t.code());
    return Internal::Call::make(result_type, Internal::Call::absd, {a, b},
                                Internal::Call::PureIntrinsic);
}
/** Choose between true_value and false_value based on condition.
 *
 *  Integer-constant arguments are adapted before type checking:
 *  - an int-constant condition is cast to Bool();
 *  - an int-constant value is cast to the other value's type.
 *  The condition must then be boolean and both values must share a type. */
inline Expr select(Expr condition, Expr true_value, Expr false_value) {
    // Without this check an undefined argument reaches .type() below and
    // dereferences a null node instead of producing a user-facing error.
    user_assert(condition.defined() && true_value.defined() && false_value.defined())
        << "select of undefined Expr\n";
    if (Internal::as_const_int(condition)) {
        condition = cast(Bool(), condition);
    }
    if (Internal::as_const_int(true_value)) {
        true_value = cast(false_value.type(), true_value);
    }
    if (Internal::as_const_int(false_value)) {
        false_value = cast(true_value.type(), false_value);
    }
    user_assert(condition.type().is_bool())
        << "The first argument to a select must be a boolean:\n"
        << "  " << condition << " has type " << condition.type() << "\n";
    user_assert(true_value.type() == false_value.type())
        << "The second and third arguments to a select do not have a matching type:\n"
        << "  " << true_value << " has type " << true_value.type() << "\n"
        << "  " << false_value << " has type " << false_value.type() << "\n";
    return Internal::Select::make(condition, true_value, false_value);
}

/** Multi-way select: select(c0, v0, c1, v1, ..., default). Nested from the
 *  right, so the first true condition wins. */
template<typename... Args,
         typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Args...>::value>::type* = nullptr>
inline Expr select(const Expr &c0, const Expr &v0, const Expr &c1, const Expr &v1, Args&&... args) {
    return select(c0, v0, select(c1, v1, std::forward<Args>(args)...));
}
// Trigonometric functions. Float(64) and Float(16) arguments are evaluated in
// their own precision; every other argument is cast to float and evaluated
// via the *_f32 extern.

inline Expr sin(const Expr &x) {
    user_assert(x.defined()) << "sin of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "sin_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "sin_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "sin_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr asin(const Expr &x) {
    user_assert(x.defined()) << "asin of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "asin_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "asin_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "asin_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr cos(const Expr &x) {
    user_assert(x.defined()) << "cos of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "cos_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "cos_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "cos_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr acos(const Expr &x) {
    user_assert(x.defined()) << "acos of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "acos_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "acos_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "acos_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr tan(const Expr &x) {
    user_assert(x.defined()) << "tan of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "tan_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "tan_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "tan_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr atan(const Expr &x) {
    user_assert(x.defined()) << "atan of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "atan_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "atan_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "atan_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

/** atan2(y, x). The precision is chosen by y's type; x is cast to match. */
inline Expr atan2(Expr y, Expr x) {
    user_assert(x.defined() && y.defined()) << "atan2 of undefined Expr\n";
    const Type yt = y.type();
    if (yt == Float(16)) {
        x = cast<float16_t>(x);
        return Internal::Call::make(Float(16), "atan2_f16", {y, x}, Internal::Call::PureExtern);
    }
    if (yt == Float(64)) {
        x = cast<double>(x);
        return Internal::Call::make(Float(64), "atan2_f64", {y, x}, Internal::Call::PureExtern);
    }
    y = cast<float>(y);
    x = cast<float>(x);
    return Internal::Call::make(Float(32), "atan2_f32", {y, x}, Internal::Call::PureExtern);
}
// Hyperbolic functions. Float(64) and Float(16) arguments are evaluated in
// their own precision; every other argument is cast to float and evaluated
// via the *_f32 extern.

inline Expr sinh(const Expr &x) {
    user_assert(x.defined()) << "sinh of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "sinh_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "sinh_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "sinh_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr asinh(const Expr &x) {
    user_assert(x.defined()) << "asinh of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "asinh_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "asinh_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "asinh_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr cosh(const Expr &x) {
    user_assert(x.defined()) << "cosh of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "cosh_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "cosh_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "cosh_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr acosh(const Expr &x) {
    user_assert(x.defined()) << "acosh of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "acosh_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "acosh_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "acosh_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr tanh(const Expr &x) {
    user_assert(x.defined()) << "tanh of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "tanh_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "tanh_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "tanh_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

inline Expr atanh(const Expr &x) {
    user_assert(x.defined()) << "atanh of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "atanh_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "atanh_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "atanh_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}
/** Square root. Float(64)/Float(16) keep their precision; any other argument
 *  is cast to float and uses sqrt_f32. */
inline Expr sqrt(const Expr &x) {
    user_assert(x.defined()) << "sqrt of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "sqrt_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "sqrt_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "sqrt_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

/** Euclidean norm, computed naively as sqrt(x*x + y*y). */
inline Expr hypot(const Expr &x, const Expr &y) {
    Expr sum_of_squares = x * x + y * y;
    return sqrt(sum_of_squares);
}
/** Natural exponential. Float(64)/Float(16) keep their precision; any other
 *  argument is cast to float and uses exp_f32. */
inline Expr exp(const Expr &x) {
    user_assert(x.defined()) << "exp of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "exp_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "exp_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "exp_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}

/** Natural logarithm. Float(64)/Float(16) keep their precision; any other
 *  argument is cast to float and uses log_f32. */
inline Expr log(const Expr &x) {
    user_assert(x.defined()) << "log of undefined Expr\n";
    const Type xt = x.type();
    if (xt == Float(16)) {
        return Internal::Call::make(Float(16), "log_f16", {x}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        return Internal::Call::make(Float(64), "log_f64", {x}, Internal::Call::PureExtern);
    }
    return Internal::Call::make(Float(32), "log_f32", {cast<float>(x)}, Internal::Call::PureExtern);
}
/** x raised to the power y.
 *  A constant-integer exponent is handed to raise_to_integer_power instead
 *  of emitting a pow call. Otherwise the precision follows x's type
 *  (Float(64) or Float(16)), with y cast to match; all other cases cast both
 *  operands to float and use pow_f32. */
inline Expr pow(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "pow of undefined Expr\n";
    if (const int64_t *exponent = Internal::as_const_int(y)) {
        return Internal::raise_to_integer_power(x, *exponent);
    }

    const Type xt = x.type();
    if (xt == Float(16)) {
        y = cast<float16_t>(y);
        return Internal::Call::make(Float(16), "pow_f16", {x, y}, Internal::Call::PureExtern);
    }
    if (xt == Float(64)) {
        y = cast<double>(y);
        return Internal::Call::make(Float(64), "pow_f64", {x, y}, Internal::Call::PureExtern);
    }

    x = cast<float>(x);
    y = cast<float>(y);
    return Internal::Call::make(Float(32), "pow_f32", {x, y}, Internal::Call::PureExtern);
}
/** Error function; only Float(32) arguments are accepted. Delegates to
 *  Internal::halide_erf. */
inline Expr erf(const Expr &x) {
    user_assert(x.defined()) << "erf of undefined Expr\n";
    const bool is_f32 = (x.type() == Float(32));
    user_assert(is_f32) << "erf only takes float arguments\n";
    return Internal::halide_erf(x);
}
/** Approximate log/exp; implemented out of line. */
EXPORT Expr fast_log(const Expr &x);
EXPORT Expr fast_exp(const Expr &x);

/** Approximate x^y. A constant-integer exponent goes to
 *  raise_to_integer_power. Otherwise both operands are cast to float and the
 *  result is exp(log(x) * y), with x == 0 special-cased to 0. */
inline Expr fast_pow(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "fast_pow of undefined Expr\n";
    if (const int64_t *i = as_const_int(y)) {
        return raise_to_integer_power(x, *i);
    }
    x = cast<float>(x);
    y = cast<float>(y);
    return select(x == 0.0f, 0.0f, fast_exp(fast_log(x) * y));
}

/** Approximate 1/x; only Float(32) arguments are accepted. */
inline Expr fast_inverse(const Expr &x) {
    // Check definedness first: x.type() on an undefined Expr has no node to read.
    user_assert(x.defined()) << "fast_inverse of undefined Expr\n";
    user_assert(x.type() == Float(32)) << "fast_inverse only takes float arguments\n";
    return Internal::Call::make(x.type(), "fast_inverse_f32", {x}, Internal::Call::PureExtern);
}

/** Approximate 1/sqrt(x); only Float(32) arguments are accepted. */
inline Expr fast_inverse_sqrt(const Expr &x) {
    user_assert(x.defined()) << "fast_inverse_sqrt of undefined Expr\n";
    user_assert(x.type() == Float(32)) << "fast_inverse_sqrt only takes float arguments\n";
    return Internal::Call::make(x.type(), "fast_inverse_sqrt_f32", {x}, Internal::Call::PureExtern);
}
// Rounding functions. Dispatch is on the element type, so the result type
// must carry x's lane count. The Float(16) branches (and Float(64) in
// round/trunc) used to build a scalar result type, which mismatched vector
// arguments. Other inputs are cast to float with x's lane count.

/** Round towards negative infinity. */
inline Expr floor(const Expr &x) {
    user_assert(x.defined()) << "floor of undefined Expr\n";
    const Type t = x.type();
    if (t.element_of() == Float(64)) {
        return Internal::Call::make(t, "floor_f64", {x}, Internal::Call::PureExtern);
    } else if (t.element_of() == Float(16)) {
        return Internal::Call::make(t, "floor_f16", {x}, Internal::Call::PureExtern);
    } else {
        Type ft = Float(32, t.lanes());
        return Internal::Call::make(ft, "floor_f32", {cast(ft, x)}, Internal::Call::PureExtern);
    }
}

/** Round towards positive infinity. */
inline Expr ceil(const Expr &x) {
    user_assert(x.defined()) << "ceil of undefined Expr\n";
    const Type t = x.type();
    if (t.element_of() == Float(64)) {
        return Internal::Call::make(t, "ceil_f64", {x}, Internal::Call::PureExtern);
    } else if (t.element_of() == Float(16)) {
        return Internal::Call::make(t, "ceil_f16", {x}, Internal::Call::PureExtern);
    } else {
        Type ft = Float(32, t.lanes());
        return Internal::Call::make(ft, "ceil_f32", {cast(ft, x)}, Internal::Call::PureExtern);
    }
}

/** Round to nearest (tie-breaking is defined by the round_* externs). */
inline Expr round(const Expr &x) {
    user_assert(x.defined()) << "round of undefined Expr\n";
    const Type t = x.type();
    if (t.element_of() == Float(64)) {
        return Internal::Call::make(t, "round_f64", {x}, Internal::Call::PureExtern);
    } else if (t.element_of() == Float(16)) {
        return Internal::Call::make(t, "round_f16", {x}, Internal::Call::PureExtern);
    } else {
        Type ft = Float(32, t.lanes());
        return Internal::Call::make(ft, "round_f32", {cast(ft, x)}, Internal::Call::PureExtern);
    }
}

/** Round towards zero. */
inline Expr trunc(const Expr &x) {
    user_assert(x.defined()) << "trunc of undefined Expr\n";
    const Type t = x.type();
    if (t.element_of() == Float(64)) {
        return Internal::Call::make(t, "trunc_f64", {x}, Internal::Call::PureExtern);
    } else if (t.element_of() == Float(16)) {
        return Internal::Call::make(t, "trunc_f16", {x}, Internal::Call::PureExtern);
    } else {
        Type ft = Float(32, t.lanes());
        return Internal::Call::make(ft, "trunc_f32", {cast(ft, x)}, Internal::Call::PureExtern);
    }
}
/** Boolean Expr (one lane per lane of x) that is true where x is NaN.
 *  x must be floating point. Float(64)/Float(16) elements use their own
 *  extern; any other float type is cast to Float(32) with x's lane count. */
inline Expr is_nan(const Expr &x) {
    user_assert(x.defined()) << "is_nan of undefined Expr\n";
    user_assert(x.type().is_float()) << "is_nan only works for float";
    Type t = Bool(x.type().lanes());
    if (x.type().element_of() == Float(64)) {
        return Internal::Call::make(t, "is_nan_f64", {x}, Internal::Call::PureExtern);
    } else if (x.type().element_of() == Float(16)) {
        // Previously this tested Float(64) again, making the f16 path unreachable.
        return Internal::Call::make(t, "is_nan_f16", {x}, Internal::Call::PureExtern);
    } else {
        Type ft = Float(32, x.type().lanes());
        return Internal::Call::make(t, "is_nan_f32", {cast(ft, x)}, Internal::Call::PureExtern);
    }
}
/** Fractional part of x, computed as x - trunc(x); the result therefore has
 *  the same sign as x. */
inline Expr fract(const Expr &x) {
    user_assert(x.defined()) << "fract of undefined Expr\n";
    Expr whole_part = trunc(x);
    return x - whole_part;
}
inline Expr reinterpret(Type t, const Expr &e) {
user_assert(e.defined()) << "reinterpret of undefined Expr\n";
int from_bits = e.type().bits() * e.type().lanes();
int to_bits = t.bits() * t.lanes();
user_assert(from_bits == to_bits)
<< "Reinterpret cast from type " << e.type()
<< " which has " << from_bits
<< " bits, to type " << t
<< " which has " << to_bits << " bits\n";
return Internal::Call::make(t, Internal::Call::reinterpret, {e}, Internal::Call::PureIntrinsic);
}
/** Reinterpret the bits of e as the Halide type corresponding to the C++
 * type T. Delegates to the Type-based overload. */
template<typename T>
inline Expr reinterpret(const Expr &e) {
    const auto target = type_of<T>();
    return reinterpret(target, e);
}
/** Bitwise and of two integer (signed or unsigned) expressions.
 * y is first cast to x's bit width, then reinterpreted as x's exact type, so
 * the result always has x's type (including its signedness). */
inline Expr operator&(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "bitwise and of undefined Expr\n";
    user_assert(x.type().is_int() || x.type().is_uint())
        << "The first argument to bitwise and must be an integer or unsigned integer";
    user_assert(y.type().is_int() || y.type().is_uint())
        << "The second argument to bitwise and must be an integer or unsigned integer";
    const auto result_type = x.type();
    const int width = result_type.bits();
    // Match the width with a value-preserving cast first...
    if (y.type().bits() != width) {
        y = cast(y.type().with_bits(width), y);
    }
    // ...then fix up any remaining difference (signedness) by bit reinterpretation.
    if (y.type() != result_type) {
        y = reinterpret(result_type, y);
    }
    return Internal::Call::make(result_type, Internal::Call::bitwise_and, {x, y}, Internal::Call::PureIntrinsic);
}
/** Bitwise or of two integer (signed or unsigned) expressions.
 * y is first cast to x's bit width, then reinterpreted as x's exact type, so
 * the result always has x's type (including its signedness). */
inline Expr operator|(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "bitwise or of undefined Expr\n";
    user_assert(x.type().is_int() || x.type().is_uint())
        << "The first argument to bitwise or must be an integer or unsigned integer";
    user_assert(y.type().is_int() || y.type().is_uint())
        << "The second argument to bitwise or must be an integer or unsigned integer";
    const auto result_type = x.type();
    const int width = result_type.bits();
    // Match the width with a value-preserving cast first...
    if (y.type().bits() != width) {
        y = cast(y.type().with_bits(width), y);
    }
    // ...then fix up any remaining difference (signedness) by bit reinterpretation.
    if (y.type() != result_type) {
        y = reinterpret(result_type, y);
    }
    return Internal::Call::make(result_type, Internal::Call::bitwise_or, {x, y}, Internal::Call::PureIntrinsic);
}
/** Bitwise xor of two integer (signed or unsigned) expressions.
 * y is first cast to x's bit width, then reinterpreted as x's exact type, so
 * the result always has x's type (including its signedness). */
inline Expr operator^(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "bitwise xor of undefined Expr\n";
    user_assert(x.type().is_int() || x.type().is_uint())
        << "The first argument to bitwise xor must be an integer or unsigned integer";
    user_assert(y.type().is_int() || y.type().is_uint())
        << "The second argument to bitwise xor must be an integer or unsigned integer";
    const auto result_type = x.type();
    const int width = result_type.bits();
    // Match the width with a value-preserving cast first...
    if (y.type().bits() != width) {
        y = cast(y.type().with_bits(width), y);
    }
    // ...then fix up any remaining difference (signedness) by bit reinterpretation.
    if (y.type() != result_type) {
        y = reinterpret(result_type, y);
    }
    return Internal::Call::make(result_type, Internal::Call::bitwise_xor, {x, y}, Internal::Call::PureIntrinsic);
}
/** Bitwise complement of an integer (signed or unsigned) expression.
 * The result has the same type as x. */
inline Expr operator~(const Expr &x) {
    user_assert(x.defined()) << "bitwise not of undefined Expr\n";
    // Read the type only after the definedness check above.
    const auto operand_type = x.type();
    user_assert(operand_type.is_int() || operand_type.is_uint())
        << "Argument to bitwise not must be an integer or unsigned integer";
    return Internal::Call::make(operand_type, Internal::Call::bitwise_not, {x}, Internal::Call::PureIntrinsic);
}
/** Shift x left by y. Neither operand may be floating point.
 * Internal::match_types reconciles the two operand types first, and the
 * result takes the (reconciled) type of x. */
inline Expr operator<<(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "shift left of undefined Expr\n";
    user_assert(!x.type().is_float()) << "First argument to shift left is a float: " << x << "\n";
    user_assert(!y.type().is_float()) << "Second argument to shift left is a float: " << y << "\n";
    Internal::match_types(x, y);
    const auto shifted_type = x.type();
    return Internal::Call::make(shifted_type, Internal::Call::shift_left, {x, y}, Internal::Call::PureIntrinsic);
}
/** Shift x left by an integer constant. The constant must be representable
 * in x's type and is materialized as a constant of that type. */
inline Expr operator<<(const Expr &x, int y) {
    const auto t = x.type();
    Internal::check_representable(t, y);
    const Expr amount = Internal::make_const(t, y);
    return x << amount;
}
/** Shift an integer constant left by y. The constant must be representable
 * in y's type and is materialized as a constant of that type. */
inline Expr operator<<(int x, const Expr &y) {
    const auto t = y.type();
    Internal::check_representable(t, x);
    const Expr value = Internal::make_const(t, x);
    return value << y;
}
/** Shift x right by y. Neither operand may be floating point.
 * Internal::match_types reconciles the two operand types first, and the
 * result takes the (reconciled) type of x. */
inline Expr operator>>(Expr x, Expr y) {
    user_assert(x.defined() && y.defined()) << "shift right of undefined Expr\n";
    user_assert(!x.type().is_float()) << "First argument to shift right is a float: " << x << "\n";
    user_assert(!y.type().is_float()) << "Second argument to shift right is a float: " << y << "\n";
    Internal::match_types(x, y);
    const auto shifted_type = x.type();
    return Internal::Call::make(shifted_type, Internal::Call::shift_right, {x, y}, Internal::Call::PureIntrinsic);
}
/** Shift x right by an integer constant. The constant must be representable
 * in x's type and is materialized as a constant of that type. */
inline Expr operator>>(const Expr &x, int y) {
    const auto t = x.type();
    Internal::check_representable(t, y);
    const Expr amount = Internal::make_const(t, y);
    return x >> amount;
}
/** Shift an integer constant right by y. The constant must be representable
 * in y's type and is materialized as a constant of that type. */
inline Expr operator>>(int x, const Expr &y) {
    const auto t = y.type();
    Internal::check_representable(t, x);
    const Expr value = Internal::make_const(t, x);
    return value >> y;
}
/** Linearly interpolate between zero_val and one_val by weight.
 *
 * - An integer-constant endpoint is first cast to the other endpoint's type.
 *   After that, both endpoints must have the same type.
 * - The weight must be an unsigned integer or a float.
 * - Integer endpoints wider than 32 bits are rejected.
 * - For integer endpoints, a constant floating-point weight must lie in
 *   [0.0, 1.0].
 *
 * The result has the type of zero_val. */
inline Expr lerp(Expr zero_val, Expr one_val, Expr weight) {
    user_assert(zero_val.defined()) << "lerp with undefined zero value";
    user_assert(one_val.defined()) << "lerp with undefined one value";
    user_assert(weight.defined()) << "lerp with undefined weight";
    // Let bare integer constants adopt the type of the other endpoint.
    if (as_const_int(zero_val)) {
        zero_val = cast(one_val.type(), zero_val);
    }
    if (as_const_int(one_val)) {
        one_val = cast(zero_val.type(), one_val);
    }
    user_assert(zero_val.type() == one_val.type())
        << "Can't lerp between " << zero_val << " of type " << zero_val.type()
        << " and " << one_val << " of different type " << one_val.type() << "\n";
    user_assert((weight.type().is_uint() || weight.type().is_float()))
        << "A lerp weight must be an unsigned integer or a float, but "
        << "lerp weight " << weight << " has type " << weight.type() << ".\n";
    // The restriction is on the integer *bit width*, as the message says.
    // Comparing the lane count here would reject wide vectors of narrow
    // integers while letting 64-bit integers through.
    user_assert((zero_val.type().is_float() || zero_val.type().bits() <= 32))
        << "Lerping between 64-bit integers is not supported\n";
    if (!zero_val.type().is_float()) {
        const double *const_weight = as_const_float(weight);
        if (const_weight) {
            user_assert(*const_weight >= 0.0 && *const_weight <= 1.0)
                << "Floating-point weight for lerp with integer arguments is "
                << *const_weight << ", which is not in the range [0.0, 1.0].\n";
        }
    }
    return Internal::Call::make(zero_val.type(), Internal::Call::lerp,
                                {zero_val, one_val, weight},
                                Internal::Call::PureIntrinsic);
}
/** Count the set bits of x. The result has the same type as x. */
inline Expr popcount(const Expr &x) {
    user_assert(x.defined()) << "popcount of undefined Expr\n";
    const auto result_type = x.type();
    return Internal::Call::make(result_type, Internal::Call::popcount, {x}, Internal::Call::PureIntrinsic);
}
/** Count the leading zero bits of x. The result has the same type as x. */
inline Expr count_leading_zeros(const Expr &x) {
    user_assert(x.defined()) << "count leading zeros of undefined Expr\n";
    const auto result_type = x.type();
    return Internal::Call::make(result_type, Internal::Call::count_leading_zeros, {x}, Internal::Call::PureIntrinsic);
}
/** Count the trailing zero bits of x. The result has the same type as x. */
inline Expr count_trailing_zeros(const Expr &x) {
    user_assert(x.defined()) << "count trailing zeros of undefined Expr\n";
    const auto result_type = x.type();
    return Internal::Call::make(result_type, Internal::Call::count_trailing_zeros, {x}, Internal::Call::PureIntrinsic);
}
/** Integer division that rounds toward zero.
 * Operand types are reconciled with Internal::match_types. For unsigned
 * operands, flooring and truncation coincide, so plain x / y is returned.
 * Otherwise both operands must be signed integers and the dedicated
 * div_round_to_zero intrinsic is emitted. */
inline Expr div_round_to_zero(Expr x, Expr y) {
    user_assert(x.defined()) << "div_round_to_zero of undefined dividend\n";
    user_assert(y.defined()) << "div_round_to_zero of undefined divisor\n";
    Internal::match_types(x, y);
    const bool operands_unsigned = x.type().is_uint();
    if (!operands_unsigned) {
        user_assert(x.type().is_int()) << "First argument to div_round_to_zero is not an integer: " << x << "\n";
        user_assert(y.type().is_int()) << "Second argument to div_round_to_zero is not an integer: " << y << "\n";
        return Internal::Call::make(x.type(), Internal::Call::div_round_to_zero,
                                    {x, y},
                                    Internal::Call::PureIntrinsic);
    }
    return x / y;
}
/** Remainder paired with round-toward-zero division.
 * Operand types are reconciled with Internal::match_types. For unsigned
 * operands, plain x % y is returned. Otherwise both operands must be signed
 * integers and the dedicated mod_round_to_zero intrinsic is emitted. */
inline Expr mod_round_to_zero(Expr x, Expr y) {
    user_assert(x.defined()) << "mod_round_to_zero of undefined dividend\n";
    user_assert(y.defined()) << "mod_round_to_zero of undefined divisor\n";
    Internal::match_types(x, y);
    const bool operands_unsigned = x.type().is_uint();
    if (!operands_unsigned) {
        user_assert(x.type().is_int()) << "First argument to mod_round_to_zero is not an integer: " << x << "\n";
        user_assert(y.type().is_int()) << "Second argument to mod_round_to_zero is not an integer: " << y << "\n";
        return Internal::Call::make(x.type(), Internal::Call::mod_round_to_zero,
                                    {x, y},
                                    Internal::Call::PureIntrinsic);
    }
    return x % y;
}
/** Build a Float(32) random-number intrinsic call.
 * Every invocation draws a fresh even id from a function-local atomic
 * counter (random_uint uses odd ids), and the id is appended as the last
 * argument. An optional seed, which must be Int(32), is passed first. */
inline Expr random_float(const Expr &seed = Expr()) {
    static std::atomic<int> counter;
    // Take the id before validating the seed, so the counter advances either way.
    const int call_id = 2 * (counter++);
    std::vector<Expr> call_args;
    if (seed.defined()) {
        user_assert(seed.type() == Int(32))
            << "The seed passed to random_float must have type Int(32), but instead is "
            << seed << " of type " << seed.type() << "\n";
        call_args.push_back(seed);
    }
    call_args.push_back(call_id);
    return Internal::Call::make(Float(32), Internal::Call::random,
                                call_args, Internal::Call::PureIntrinsic);
}
/** Build a UInt(32) random-number intrinsic call.
 * Every invocation draws a fresh odd id from a function-local atomic counter
 * (random_float uses even ids), and the id is appended as the last argument.
 * An optional seed, which must be Int(32) or UInt(32), is passed first. The
 * error text names random_int because random_int forwards here. */
inline Expr random_uint(const Expr &seed = Expr()) {
    static std::atomic<int> counter;
    // Take the id before validating the seed, so the counter advances either way.
    const int call_id = 2 * (counter++) + 1;
    std::vector<Expr> call_args;
    if (seed.defined()) {
        user_assert(seed.type() == Int(32) || seed.type() == UInt(32))
            << "The seed passed to random_int must have type Int(32) or UInt(32), but instead is "
            << seed << " of type " << seed.type() << "\n";
        call_args.push_back(seed);
    }
    call_args.push_back(call_id);
    return Internal::Call::make(UInt(32), Internal::Call::random,
                                call_args, Internal::Call::PureIntrinsic);
}
/** Random Int(32): a random_uint call reinterpreted via cast<int32_t>. */
inline Expr random_int(const Expr &seed = Expr()) {
    Expr unsigned_draw = random_uint(seed);
    return cast<int32_t>(unsigned_draw);
}
namespace Internal {
// Recursion terminator: no arguments remain to be collected.
inline NO_INLINE void collect_print_args(std::vector<Expr> &args) {
}
// C-string argument: wrap it as a string Expr, then continue with the rest.
template<typename ...Args>
inline NO_INLINE void collect_print_args(std::vector<Expr> &args, const char *arg, Args&&... rest) {
    const std::string text(arg);
    args.push_back(Expr(text));
    collect_print_args(args, std::forward<Args>(rest)...);
}
// Expr argument: append it as-is, then continue with the rest.
template<typename ...Args>
inline NO_INLINE void collect_print_args(std::vector<Expr> &args, const Expr &arg, Args&&... rest) {
    args.push_back(arg);
    collect_print_args(args, std::forward<Args>(rest)...);
}
}
/** Build a print expression from a list of values (defined out of line). */
EXPORT Expr print(const std::vector<Expr> &values);
/** Variadic convenience form: gathers a, followed by each of args (Exprs or C
 * strings), into a vector and forwards it to the vector overload. */
template <typename... Args>
inline NO_INLINE Expr print(const Expr &a, Args&&... args) {
    std::vector<Expr> values;
    values.push_back(a);
    Internal::collect_print_args(values, std::forward<Args>(args)...);
    return print(values);
}
/** Build a conditional print expression from a list of values (defined out of line). */
EXPORT Expr print_when(const Expr &condition, const std::vector<Expr> &values);
/** Variadic convenience form: gathers a, followed by each of args (Exprs or C
 * strings), and forwards them with condition to the vector overload. */
template<typename ...Args>
inline NO_INLINE Expr print_when(const Expr &condition, const Expr &a, Args&&... args) {
    std::vector<Expr> values;
    values.push_back(a);
    Internal::collect_print_args(values, std::forward<Args>(args)...);
    return print_when(condition, values);
}
/** Build a require expression from a condition and a list of values (defined out of line). */
EXPORT Expr require(const Expr &condition, const std::vector<Expr> &values);
/** Variadic convenience form: gathers value, followed by each of args (Exprs
 * or C strings), and forwards them with condition to the vector overload. */
template<typename ...Args>
inline NO_INLINE Expr require(const Expr &condition, const Expr &value, Args&&... args) {
    std::vector<Expr> values;
    values.push_back(value);
    Internal::collect_print_args(values, std::forward<Args>(args)...);
    return require(condition, values);
}
/** Build an argument-less undef intrinsic call of type t. */
inline Expr undef(Type t) {
    const std::vector<Expr> no_args;
    return Internal::Call::make(t, Internal::Call::undef, no_args, Internal::Call::PureIntrinsic);
}
/** undef of the Halide type corresponding to the C++ type T. */
template<typename T>
inline Expr undef() {
    const auto t = type_of<T>();
    return undef(t);
}
namespace Internal {
// Out-of-line implementation behind memoize_tag.
EXPORT Expr memoize_tag_helper(const Expr &result, const std::vector<Expr> &cache_key_values);
}
/** Tag result with extra cache-key values: every trailing argument is
 * converted to Expr via brace-initialization of a vector, and the whole list
 * is handed to Internal::memoize_tag_helper. */
template<typename ...Args>
inline NO_INLINE Expr memoize_tag(const Expr &result, Args&&... args) {
    return Internal::memoize_tag_helper(result, std::vector<Expr>{std::forward<Args>(args)...});
}
/** Wrap e in the 'likely' intrinsic. The result has e's type. */
inline Expr likely(const Expr &e) {
    const auto t = e.type();
    return Internal::Call::make(t, Internal::Call::likely, {e}, Internal::Call::PureIntrinsic);
}
/** Wrap e in the 'likely_if_innermost' intrinsic. The result has e's type. */
inline Expr likely_if_innermost(const Expr &e) {
    const auto t = e.type();
    return Internal::Call::make(t, Internal::Call::likely_if_innermost, {e}, Internal::Call::PureIntrinsic);
}
/** Cast e to type t, saturating rather than wrapping (defined out of line).
 * Declared before the template below so that ordinary unqualified lookup at
 * the template's definition finds it. Previously the call resolved only via
 * argument-dependent lookup at instantiation time. */
EXPORT Expr saturating_cast(Type t, Expr e);
/** saturating_cast to the Halide type corresponding to the C++ type T. */
template <typename T>
Expr saturating_cast(const Expr &e) {
    return saturating_cast(type_of<T>(), e);
}
}
#endif