#ifndef HALIDE_TUPLE_H
#define HALIDE_TUPLE_H
#include "IR.h"
#include "IROperator.h"
#include "Util.h"
namespace Halide {
class FuncRef;
class Tuple {
private:
std::vector<Expr> exprs;
public:
size_t size() const { return exprs.size(); }
Expr &operator[](size_t x) {
user_assert(x < exprs.size()) << "Tuple access out of bounds\n";
return exprs[x];
}
Expr operator[](size_t x) const {
user_assert(x < exprs.size()) << "Tuple access out of bounds\n";
return exprs[x];
}
explicit Tuple(Expr e) {
exprs.push_back(e);
}
template<typename ...Args>
Tuple(Expr a, Expr b, Args&&... args) {
exprs = std::vector<Expr>{a, b, std::forward<Args>(args)...};
}
explicit NO_INLINE Tuple(const std::vector<Expr> &e) : exprs(e) {
user_assert(e.size() > 0) << "Tuples must have at least one element\n";
}
EXPORT Tuple(const FuncRef &);
const std::vector<Expr> &as_vector() const {
return exprs;
}
};
class Realization {
private:
std::vector<Buffer<>> images;
public:
size_t size() const { return images.size(); }
const Buffer<> &operator[](size_t x) const {
user_assert(x < images.size()) << "Realization access out of bounds\n";
return images[x];
}
Buffer<> &operator[](size_t x) {
user_assert(x < images.size()) << "Realization access out of bounds\n";
return images[x];
}
template<typename T>
operator Buffer<T>() const {
return images[0];
}
template<typename T,
typename ...Args,
typename = typename std::enable_if<Internal::all_are_convertible<Buffer<>, Args...>::value>::type>
Realization(Buffer<T> &a, Args&&... args) {
images = std::vector<Buffer<>>({a, args...});
}
explicit Realization(std::vector<Buffer<>> &e) : images(e) {
user_assert(e.size() > 0) << "Realizations must have at least one element\n";
}
};
/** Element-wise select() over Tuples.
 *
 * Element i of the result is select(condition[i], true_value[i], false_value[i]).
 *
 * @param condition   per-element conditions
 * @param true_value  values chosen where the matching condition holds
 * @param false_value values chosen otherwise
 * @return a Tuple with the same number of elements as the inputs
 *
 * All three Tuples must have the same size. Previously, extra elements in
 * true_value/false_value were silently dropped and shorter ones produced a
 * generic out-of-bounds message; now a mismatch is reported explicitly. */
inline Tuple tuple_select(const Tuple &condition, const Tuple &true_value, const Tuple &false_value) {
    user_assert(condition.size() == true_value.size() && true_value.size() == false_value.size())
        << "tuple_select() requires all the Tuples to have the same size.\n";
    // Pre-size the result with placeholder Exprs, then fill each slot.
    Tuple result(std::vector<Expr>(condition.size()));
    for (size_t i = 0; i < result.size(); i++) {
        result[i] = select(condition[i], true_value[i], false_value[i]);
    }
    return result;
}
/** select() applied to every element of two Tuples under one shared condition.
 *
 * Element i of the result is select(condition, true_value[i], false_value[i]).
 *
 * @param condition   a single condition used for every element
 * @param true_value  values chosen where the condition holds
 * @param false_value values chosen otherwise
 * @return a Tuple with as many elements as true_value
 *
 * true_value and false_value must have the same size. Previously, extra
 * elements in false_value were silently ignored and a shorter false_value
 * produced a generic out-of-bounds message; now a mismatch is reported
 * explicitly. */
inline Tuple tuple_select(Expr condition, const Tuple &true_value, const Tuple &false_value) {
    user_assert(true_value.size() == false_value.size())
        << "tuple_select() requires all the Tuples to have the same size.\n";
    // Pre-size the result with placeholder Exprs, then fill each slot.
    Tuple result(std::vector<Expr>(true_value.size()));
    for (size_t i = 0; i < result.size(); i++) {
        result[i] = select(condition, true_value[i], false_value[i]);
    }
    return result;
}
}
#endif