#ifndef HALIDE_EXPR_H
#define HALIDE_EXPR_H
#include <functional>
#include <string>
#include <vector>
#include "Debug.h"
#include "Error.h"
#include "Float16.h"
#include "Type.h"
#include "IntrusivePtr.h"
#include "Util.h"
namespace Halide {
namespace Internal {
class IRVisitor;
/** One tag per concrete IR node type. IRNode::type_info() returns it, and
 * IRHandle::as<T>() compares it against T::_type_info to perform checked
 * downcasts without RTTI.
 * NOTE(review): enumerator values are implicit (0, 1, 2, ...), so inserting
 * or reordering entries renumbers everything after that point. Confirm nothing
 * persists or indexes by these values before reordering. */
enum class IRNodeType {
// Immediates (IntImm, UIntImm, FloatImm and StringImm are defined below in this header).
IntImm,
UIntImm,
FloatImm,
StringImm,
Cast,
Variable,
Add,
Sub,
Mul,
Div,
Mod,
Min,
Max,
EQ,
NE,
LT,
LE,
GT,
GE,
And,
Or,
Not,
Select,
Load,
Ramp,
Broadcast,
Call,
Let,
LetStmt,
AssertStmt,
ProducerConsumer,
For,
Store,
Provide,
Allocate,
Free,
Realize,
Block,
IfThenElse,
Evaluate,
Shuffle,
Prefetch,
};
/** Root of the IR class hierarchy.
 *
 * Every node carries an embedded reference count. It also reports a runtime
 * tag through type_info(), which IRHandle::as<T>() uses for checked
 * downcasts. */
struct IRNode {
    /** Visitor entry point; implemented per concrete node type
     * (ExprNode<T> and StmtNode<T> declare it). */
    virtual void accept(IRVisitor *v) const = 0;

    IRNode() = default;
    virtual ~IRNode() = default;

    /** Intrusive reference count. It is mutable so that the count can be
     * adjusted through a pointer-to-const node (see ref_count<IRNode>). */
    mutable RefCount ref_count;

    /** Tag identifying the concrete node type. */
    virtual IRNodeType type_info() const = 0;
};
// Intrusive-pointer hooks for IRNode. These specialize the ref_count/destroy
// templates (presumably declared in IntrusivePtr.h) so that handles can reach
// the node's embedded counter and free the node.

/** Return the node's embedded (mutable) reference counter. */
template<>
EXPORT inline RefCount &ref_count<IRNode>(const IRNode *node) {
    return node->ref_count;
}

/** Free a node. Deleting through IRNode is safe because its destructor is virtual. */
template<>
EXPORT inline void destroy<IRNode>(const IRNode *node) {
    delete node;
}
/** Common base of all statement nodes. It adds no members of its own, but it
 * gives statement nodes a distinct base type, which Stmt's constructor requires. */
struct BaseStmtNode : IRNode {
};
/** Common base of all expression nodes: every expression carries a Type. */
struct BaseExprNode : IRNode {
    /** Type of the value this expression produces (read by Expr::type()). */
    Type type;
};
/** CRTP intermediate base for concrete expression nodes.
 *
 * Each node type T derives from ExprNode<T> and declares a static
 * `_type_info` member. type_info() reports that tag. accept() is declared
 * here, and its definition lives outside this header. All three members
 * override IRNode's virtuals, and are marked `override` so that a signature
 * mismatch is a compile error rather than a silent new overload. */
template<typename T>
struct ExprNode : public BaseExprNode {
    EXPORT void accept(IRVisitor *v) const override;
    IRNodeType type_info() const override {return T::_type_info;}
    ~ExprNode() override {}
};
/** CRTP intermediate base for concrete statement nodes.
 *
 * Each node type T derives from StmtNode<T> and declares a static
 * `_type_info` member. type_info() reports that tag. accept() is declared
 * here, and its definition lives outside this header. Members are marked
 * `override` so that a signature mismatch with IRNode is caught at compile time. */
template<typename T>
struct StmtNode : public BaseStmtNode {
    EXPORT void accept(IRVisitor *v) const override;
    IRNodeType type_info() const override {return T::_type_info;}
    ~StmtNode() override {}
};
/** Base class for the handles (Expr, Stmt) through which IR nodes are used.
 * Ownership is managed by the IntrusivePtr<const IRNode> base. */
struct IRHandle : public IntrusivePtr<const IRNode> {
    IRHandle() : IntrusivePtr<const IRNode>() {}
    IRHandle(const IRNode *p) : IntrusivePtr<const IRNode>(p) {}

    /** Dispatch the visitor on the referenced node. The handle must be
     * defined: ptr is dereferenced without a null check. */
    void accept(IRVisitor *v) const {
        ptr->accept(v);
    }

    /** Downcast to a concrete node type T.
     *
     * Returns nullptr if the handle is undefined, or if the node's runtime
     * tag differs from T::_type_info. static_cast (rather than a C-style
     * cast) makes it a compile error to instantiate this with a T that
     * does not derive from IRNode. */
    template<typename T> const T *as() const {
        if (ptr && ptr->type_info() == T::_type_info) {
            return static_cast<const T *>(ptr);
        }
        return nullptr;
    }
};
/** Signed integer constant. */
struct IntImm : public ExprNode<IntImm> {
    int64_t value;

    /** Make a scalar signed-integer constant of type t.
     *
     * The value is truncated to t.bits() bits and then sign-extended back to
     * 64 bits. For example, make(Int(8), 200) stores -56. */
    static const IntImm *make(Type t, int64_t value) {
        internal_assert(t.is_int() && t.is_scalar())
            << "IntImm must be a scalar Int\n";
        internal_assert(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64)
            << "IntImm must be 8, 16, 32, or 64-bit\n";

        const int shift = 64 - t.bits();
        // Left-shifting a negative signed value is undefined behaviour before
        // C++20, so drop the high bits on the unsigned representation instead.
        value = static_cast<int64_t>(static_cast<uint64_t>(value) << shift);
        // The arithmetic right shift then sign-extends from bit t.bits()-1.
        value >>= shift;

        IntImm *node = new IntImm;
        node->type = t;
        node->value = value;
        return node;
    }

    static const IRNodeType _type_info = IRNodeType::IntImm;
};
/** Unsigned integer constant. */
struct UIntImm : public ExprNode<UIntImm> {
    uint64_t value;

    /** Make a scalar unsigned-integer constant of type t. Only the low
     * t.bits() bits of the value are kept. */
    static const UIntImm *make(Type t, uint64_t value) {
        internal_assert(t.is_uint() && t.is_scalar())
            << "UIntImm must be a scalar UInt\n";
        internal_assert(t.bits() == 1 || t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64)
            << "UIntImm must be 1, 8, 16, 32, or 64-bit\n";

        // A 64-bit constant keeps every bit. Masking is skipped there because
        // shifting 1 left by 64 would be undefined.
        if (t.bits() < 64) {
            const uint64_t low_bits_mask = (uint64_t(1) << t.bits()) - 1;
            value &= low_bits_mask;
        }

        UIntImm *result = new UIntImm;
        result->type = t;
        result->value = value;
        return result;
    }

    static const IRNodeType _type_info = IRNodeType::UIntImm;
};
/** Floating-point constant. */
struct FloatImm : public ExprNode<FloatImm> {
    double value;

    /** Make a scalar floating-point constant of type t.
     *
     * The stored double is the input rounded to the precision of t (16, 32
     * or 64 bits). Any other width reports an internal error. */
    static const FloatImm *make(Type t, double value) {
        internal_assert(t.is_float() && t.is_scalar())
            << "FloatImm must be a scalar Float\n";

        // Validate and round before allocating, so the error path below does
        // not leak a node whose value was never set.
        double rounded = value;
        switch (t.bits()) {
        case 16:
            rounded = static_cast<double>(static_cast<float16_t>(value));
            break;
        case 32:
            rounded = static_cast<float>(value);
            break;
        case 64:
            break;
        default:
            internal_error << "FloatImm must be 16, 32, or 64-bit\n";
        }

        FloatImm *node = new FloatImm;
        node->type = t;
        node->value = rounded;
        return node;
    }

    static const IRNodeType _type_info = IRNodeType::FloatImm;
};
/** String constant; its expression type is that of `const char *`. */
struct StringImm : public ExprNode<StringImm> {
    std::string value;

    /** Wrap a copy of val as a string-constant expression. */
    static const StringImm *make(const std::string &val) {
        StringImm *imm = new StringImm;
        imm->value = val;
        imm->type = type_of<const char *>();
        return imm;
    }

    static const IRNodeType _type_info = IRNodeType::StringImm;
};
}
/** A reference-counted handle to an expression node in the IR. */
struct Expr : public Internal::IRHandle {
    /** An undefined expression (null handle). */
    Expr() : Internal::IRHandle() {}

    /** Wrap an existing expression node. */
    Expr(const Internal::BaseExprNode *n) : IRHandle(n) {}

    /** Integer constants. Only int32_t converts implicitly; the other widths
     * require an explicit Expr(...) construction. */
    EXPORT explicit Expr(int8_t x) : IRHandle(Internal::IntImm::make(Int(8), x)) {}
    EXPORT explicit Expr(int16_t x) : IRHandle(Internal::IntImm::make(Int(16), x)) {}
    EXPORT Expr(int32_t x) : IRHandle(Internal::IntImm::make(Int(32), x)) {}
    EXPORT explicit Expr(int64_t x) : IRHandle(Internal::IntImm::make(Int(64), x)) {}
    EXPORT explicit Expr(uint8_t x) : IRHandle(Internal::UIntImm::make(UInt(8), x)) {}
    EXPORT explicit Expr(uint16_t x) : IRHandle(Internal::UIntImm::make(UInt(16), x)) {}
    EXPORT explicit Expr(uint32_t x) : IRHandle(Internal::UIntImm::make(UInt(32), x)) {}
    EXPORT explicit Expr(uint64_t x) : IRHandle(Internal::UIntImm::make(UInt(64), x)) {}

    /** Floating-point constants. float16_t and float convert implicitly;
     * double requires an explicit Expr(...). */
    EXPORT Expr(float16_t x) : IRHandle(Internal::FloatImm::make(Float(16), static_cast<double>(x))) {}
    EXPORT Expr(float x) : IRHandle(Internal::FloatImm::make(Float(32), x)) {}
    EXPORT explicit Expr(double x) : IRHandle(Internal::FloatImm::make(Float(64), x)) {}

    /** String constant (a StringImm). */
    EXPORT Expr(const std::string &s) : IRHandle(Internal::StringImm::make(s)) {}

    /** Type of the expression. The handle must be defined: ptr is
     * dereferenced without a null check. */
    Type type() const {
        return static_cast<const Internal::BaseExprNode *>(ptr)->type;
    }
};
struct ExprCompare {
bool operator()(const Expr &a, const Expr &b) const {
return a.get() < b.get();
}
};
/** Enumerates the device APIs known to this header.
 * NOTE(review): all_device_apis (declared just below) lists every enumerator
 * in declaration order. Update it whenever an entry is added or removed here. */
enum class DeviceAPI {
None,
Host,
Default_GPU,
CUDA,
OpenCL,
GLSL,
OpenGLCompute,
Metal,
Hexagon
};
// Every DeviceAPI enumerator, in declaration order, for code that needs to
// iterate over all of them. Namespace-scope const, so each translation unit
// gets its own internal-linkage copy.
// NOTE(review): keep in sync with the DeviceAPI enum.
const DeviceAPI all_device_apis[] = {
    DeviceAPI::None,
    DeviceAPI::Host,
    DeviceAPI::Default_GPU,
    DeviceAPI::CUDA,
    DeviceAPI::OpenCL,
    DeviceAPI::GLSL,
    DeviceAPI::OpenGLCompute,
    DeviceAPI::Metal,
    DeviceAPI::Hexagon,
};
namespace Internal {
/** Kinds of loop. Presumably this selects how a For node's loop is executed
 * (serially, in parallel, vectorized, unrolled, or mapped to GPU blocks or
 * threads). TODO(review): confirm against the For node definition. */
enum class ForType {
Serial,
Parallel,
Vectorized,
Unrolled,
GPUBlock,
GPUThread
};
struct Stmt : public IRHandle {
Stmt() : IRHandle() {}
Stmt(const BaseStmtNode *n) : IRHandle(n) {}
struct Compare {
bool operator()(const Stmt &a, const Stmt &b) const {
return a.ptr < b.ptr;
}
};
};
}
}
#endif