#ifndef HALIDE_RUNTIME_BUFFER_H
#define HALIDE_RUNTIME_BUFFER_H
#include <memory>
#include <vector>
#include <cassert>
#include <atomic>
#include <algorithm>
#include <stdint.h>
#include <string.h>
#include "HalideRuntime.h"
// EXPORT: symbol import/export decoration. It is only defined here if the
// including code has not already defined it.
#ifndef EXPORT
#if defined(_WIN32) && defined(Halide_SHARED)
// Windows with Halide_SHARED defined: dllexport when Halide_EXPORTS is
// defined, dllimport otherwise.
#ifdef Halide_EXPORTS
#define EXPORT __declspec(dllexport)
#else
#define EXPORT __declspec(dllimport)
#endif
#else
// Everywhere else EXPORT expands to nothing.
#define EXPORT
#endif
#endif
// HALIDE_ALLOCA: compiler-specific spelling of alloca. It is #undef'd again
// at the end of this header.
#ifdef _MSC_VER
#define HALIDE_ALLOCA _alloca
#else
#define HALIDE_ALLOCA __builtin_alloca
#endif
// GCC 5.1 only: disable -Warray-bounds for code including this header.
// NOTE(review): presumably works around false positives in that compiler
// release - confirm before removing.
#if __GNUC__ == 5 && __GNUC_MINOR__ == 1
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
namespace Halide {
namespace Runtime {
template<typename T, int D> class Buffer;
/** Type trait used to constrain variadic constructors and accessors:
 * AllInts<Args...>::value is true only when every type in Args is
 * convertible to int and none of the types examined is float or double. */
// Primary template: false unless one of the specializations below matches.
template<typename ...Args>
struct AllInts : std::false_type {};
// The empty pack is trivially "all ints".
template<>
struct AllInts<> : std::true_type {};
// Recursive case: the head must convert to int and the tail must satisfy
// AllInts as well.
template<typename T, typename ...Args>
struct AllInts<T, Args...> {
static const bool value = std::is_convertible<T, int>::value && AllInts<Args...>::value;
};
// float and double are convertible to int, so they are rejected explicitly
// by these more specialized partial specializations.
template<typename ...Args>
struct AllInts<float, Args...> : std::false_type {};
template<typename ...Args>
struct AllInts<double, Args...> : std::false_type {};
/** Bookkeeping record placed at the start of a host allocation made by
 * Buffer::allocate. Buffer objects sharing the allocation point at the same
 * header. */
struct AllocationHeader {
// Called with a pointer to this header when ref_count drops to zero.
void (*deallocate_fn)(void *);
// Number of Buffer objects currently sharing the allocation.
std::atomic<int> ref_count {0};
};
template<typename T = void, int D = 4>
class Buffer {
// The wrapped raw buffer descriptor.
halide_buffer_t buf = {0};
// In-object storage for up to D dimensions. make_shape_storage() points
// buf.dim here when buf.dimensions <= D and at a heap array otherwise.
halide_dimension_t shape[D];
// Header of the reference-counted host allocation, or nullptr when this
// object does not manage its host memory.
AllocationHeader *alloc = nullptr;
// Reference count for the device allocation, created lazily by incref().
// Mutable because incref() is a const member.
mutable std::atomic<int> *dev_ref_count = nullptr;
// True when T is void or const void.
static const bool T_is_void = std::is_same<typename std::remove_const<T>::type, void>::value;
// T2, made const when T is const.
template<typename T2>
using add_const_if_T_is_const = typename std::conditional<std::is_const<T>::value, const T2, T2>::type;
// T, or (possibly const) uint8_t when T is void, so that references to
// elements can be formed.
using not_void_T = typename std::conditional<T_is_void,
add_const_if_T_is_const<uint8_t>,
T>::type;
// Type used for address arithmetic: uint64_t when T is a pointer type,
// otherwise not_void_T.
using storage_T = typename std::conditional<std::is_pointer<T>::value, uint64_t, not_void_T>::type;
public:
/** True when the element type is known at compile time, i.e. T is not
 * (const) void. */
static constexpr bool has_static_halide_type() {
return !T_is_void;
}
/** The halide_type_t corresponding to T with cv-qualifiers removed. When T
 * is void this is the type of not_void_T, i.e. uint8_t. */
static halide_type_t static_halide_type() {
return halide_type_of<typename std::remove_cv<not_void_T>::type>();
}
/** True when this object holds a reference to a host allocation made by
 * allocate() (alloc is non-null). */
bool manages_memory() const {
return alloc != nullptr;
}
private:
/** Take an additional reference on the managed host allocation and, when a
 * device allocation is present, on the device allocation too. Does nothing
 * for buffers that do not manage their memory. */
void incref() const {
    if (manages_memory()) {
        ++(alloc->ref_count);
        if (buf.device != 0) {
            // The device counter is created on demand. It starts at one to
            // account for the reference this object already holds, and is
            // then bumped for the new sharer.
            if (dev_ref_count == nullptr) {
                dev_ref_count = new std::atomic<int>(1);
            }
            ++(*dev_ref_count);
        }
    }
}
/** Drop this object's reference on the managed host allocation, releasing
 * it through the stored deallocate_fn when this was the last reference,
 * then drop the device reference as well. Does nothing for buffers that do
 * not manage their memory. */
void decref() {
    if (!manages_memory()) {
        return;
    }
    const int remaining = --(alloc->ref_count);
    if (remaining == 0) {
        // The header lives inside the allocation being released, so the
        // function pointer is read out before the call.
        void (*release)(void *) = alloc->deallocate_fn;
        release(alloc);
    }
    alloc = nullptr;
    buf.host = nullptr;
    decref_dev();
}
/** Drop this object's reference on the device allocation. When the count
 * reaches zero (or no counter was ever created), the device allocation is
 * freed through the function returned by halide_get_device_free_fn() and
 * the counter is deleted. In every case this object forgets the device
 * handle and the counter. */
void decref_dev() {
// With no counter allocated this object is treated as the only owner, so
// new_count stays zero and the free path below is taken.
int new_count = 0;
if (dev_ref_count) {
new_count = --(*dev_ref_count);
}
if (new_count == 0) {
if (buf.device) {
halide_device_free_t fn = halide_get_device_free_fn();
assert(fn && "Buffer has a device allocation but no Halide Runtime linked");
assert(!(alloc && device_dirty()) &&
"Implicitly freeing a dirty device allocation while a host allocation still lives. "
"Call device_free explicitly if you want to drop dirty device-side data. "
"Call copy_to_host explicitly if you want the data copied to the host allocation "
"before the device allocation is freed.");
(*fn)(nullptr, &buf);
}
if (dev_ref_count) {
delete dev_ref_count;
}
}
buf.device = 0;
dev_ref_count = nullptr;
}
void free_shape_storage() {
if (buf.dim != shape) {
delete[] buf.dim;
buf.dim = nullptr;
}
}
void make_shape_storage() {
if (buf.dimensions <= D) {
buf.dim = shape;
} else {
buf.dim = new halide_dimension_t[buf.dimensions];
}
}
/** Allocate shape storage for buf.dimensions entries and copy that many
 * dimension records from `other`. buf.dimensions must already be set. */
void copy_shape_from(const halide_buffer_t &other) {
make_shape_storage();
for (int i = 0; i < buf.dimensions; i++) {
buf.dim[i] = other.dim[i];
}
}
/** Take the shape of `other`. If other keeps its shape in its in-object
 * array the entries are copied; if it keeps them in a separate array, that
 * array is taken over and other.buf.dim is set to nullptr. */
template<typename T2, int D2>
void move_shape_from(Buffer<T2, D2> &&other) {
if (other.shape == other.buf.dim) {
copy_shape_from(other.buf);
} else {
buf.dim = other.buf.dim;
other.buf.dim = nullptr;
}
}
/** Copy every field of `b` into buf, then replace the copied dim pointer
 * with this object's own storage holding a copy of b's shape. */
void initialize_from_buffer(const halide_buffer_t &b) {
memcpy(&buf, &b, sizeof(halide_buffer_t));
copy_shape_from(b);
}
template<typename ...Args>
void initialize_shape(int next, int first, Args... rest) {
buf.dim[next].min = 0;
buf.dim[next].extent = first;
if (next == 0) {
buf.dim[next].stride = 1;
} else {
buf.dim[next].stride = buf.dim[next-1].stride * buf.dim[next-1].extent;
}
initialize_shape(next + 1, rest...);
}
void initialize_shape(int) {
}
void initialize_shape(const std::vector<int> &sizes) {
for (size_t i = 0; i < sizes.size(); i++) {
buf.dim[i].min = 0;
buf.dim[i].extent = sizes[i];
if (i == 0) {
buf.dim[i].stride = 1;
} else {
buf.dim[i].stride = buf.dim[i-1].stride * buf.dim[i-1].extent;
}
}
}
/** Fill buf.dim from the compile-time extents of a (possibly nested) C
 * array. `next` is the dimension index for this array level. The function
 * recurses into vals[0] with next - 1, so the outermost array level maps to
 * the highest-numbered dimension and dimension 0 gets stride 1. */
template<typename Array, size_t N>
void initialize_shape_from_array_shape(int next, Array (&vals)[N]) {
buf.dim[next].min = 0;
buf.dim[next].extent = (int)N;
if (next == 0) {
buf.dim[next].stride = 1;
} else {
// The inner dimensions must be filled in first, because this dimension's
// stride is derived from the one below it.
initialize_shape_from_array_shape(next - 1, vals[0]);
buf.dim[next].stride = buf.dim[next - 1].stride * buf.dim[next - 1].extent;
}
}
/** Non-array overload: does nothing. */
template<typename T2>
void initialize_shape_from_array_shape(int, const T2 &) {
}
/** Number of array levels in a (possibly nested) C array: one per level,
 * counted by recursing into vals[0]. */
template<typename Array, size_t N>
static int dimensionality_of_array(Array (&vals)[N]) {
return dimensionality_of_array(vals[0]) + 1;
}
/** Base case: a non-array value has dimensionality zero. */
template<typename T2>
static int dimensionality_of_array(const T2 &) {
return 0;
}
/** halide_type_t of the scalar element type of a (possibly nested) C
 * array, found by recursing into vals[0]. */
template<typename Array, size_t N>
static halide_type_t scalar_type_of_array(Array (&vals)[N]) {
return scalar_type_of_array(vals[0]);
}
/** Base case: the type of the scalar itself, cv-qualifiers removed. */
template<typename T2>
static halide_type_t scalar_type_of_array(const T2 &) {
return halide_type_of<typename std::remove_cv<T2>::type>();
}
/** True when at least one of the given extents is zero. */
template<typename ...Args>
static bool any_zero(int first, Args... rest) {
    return (first == 0) ? true : any_zero(rest...);
}
/** End of the recursion: an empty list contains no zero. */
static bool any_zero() {
    return false;
}
/** True when the vector contains a zero. */
static bool any_zero(const std::vector<int> &v) {
    return std::find(v.begin(), v.end(), 0) != v.end();
}
public:
/** The element type this Buffer was instantiated with. */
typedef T ElemType;
/** Read-only view of one halide_dimension_t. It holds a reference, so it
 * must not outlive the dimension record it was created from. */
class Dimension {
const halide_dimension_t &d;
public:
/** Coordinate of the first element along this dimension. */
HALIDE_ALWAYS_INLINE int min() const {
return d.min;
}
/** The stride recorded for this dimension. */
HALIDE_ALWAYS_INLINE int stride() const {
return d.stride;
}
/** Number of coordinates along this dimension. */
HALIDE_ALWAYS_INLINE int extent() const {
return d.extent;
}
/** Coordinate of the last element: min() + extent() - 1. */
HALIDE_ALWAYS_INLINE int max() const {
return min() + extent() - 1;
}
/** Minimal iterator over the coordinates of this dimension, so that it can
 * be used in a range-based for loop. */
struct iterator {
int val;
int operator*() const {return val;}
bool operator!=(const iterator &other) const {return val != other.val;}
iterator &operator++() {val++; return *this;}
};
/** Iterator positioned at min(). */
HALIDE_ALWAYS_INLINE iterator begin() const {
return {min()};
}
/** Iterator positioned one past max(). */
HALIDE_ALWAYS_INLINE iterator end() const {
return {min() + extent()};
}
Dimension(const halide_dimension_t &dim) : d(dim) {};
};
/** View of dimension i. No bounds check is performed on i. */
HALIDE_ALWAYS_INLINE Dimension dim(int i) const {
return Dimension(buf.dim[i]);
}
// Shorthand accessors for the fields of dimension i.
int min(int i) const { return dim(i).min(); }
int extent(int i) const { return dim(i).extent(); }
int stride(int i) const { return dim(i).stride(); }
/** Product of the extents of all dimensions (1 for a zero-dimensional
 * buffer). */
size_t number_of_elements() const {
size_t s = 1;
for (int i = 0; i < dimensions(); i++) {
s *= dim(i).extent();
}
return s;
}
/** Number of dimensions recorded in the raw buffer. */
int dimensions() const {
return buf.dimensions;
}
/** Element type recorded in the raw buffer. */
halide_type_t type() const {
return buf.type;
}
/** Pointer to the lowest-addressed element. Starting from buf.host, each
 * dimension with a negative stride contributes stride * (extent - 1)
 * elements, scaled by the element size in bytes. */
T *begin() const {
ptrdiff_t index = 0;
for (int i = 0; i < dimensions(); i++) {
if (dim(i).stride() < 0) {
index += dim(i).stride() * (dim(i).extent() - 1);
}
}
return (T *)(buf.host + index * type().bytes());
}
/** Pointer one element past the highest-addressed element, computed like
 * begin() but over the dimensions with a positive stride. */
T *end() const {
ptrdiff_t index = 0;
for (int i = 0; i < dimensions(); i++) {
if (dim(i).stride() > 0) {
index += dim(i).stride() * (dim(i).extent() - 1);
}
}
index += 1;
return (T *)(buf.host + index * type().bytes());
}
/** Number of bytes between begin() and end(). */
size_t size_in_bytes() const {
return (size_t)((const uint8_t *)end() - (const uint8_t *)begin());
}
/** Construct an empty, zero-dimensional buffer of the static element type
 * that refers to no memory. */
Buffer() {
memset(&buf, 0, sizeof(halide_buffer_t));
buf.type = static_halide_type();
make_shape_storage();
}
/** Wrap a copy of a raw halide_buffer_t. The descriptor fields and the
 * shape are copied; no reference on any allocation is taken (alloc stays
 * null). Asserts that the element type matches T unless T is void. */
Buffer(const halide_buffer_t &buf) {
assert(T_is_void || buf.type == static_halide_type());
initialize_from_buffer(buf);
}
/** Wrap the host memory described by a legacy buffer_t. T must not be
 * void, and the legacy buffer must not carry a device allocation. */
Buffer(const buffer_t &old_buf) {
assert(!T_is_void && old_buf.elem_size == static_halide_type().bytes());
buf.host = old_buf.host;
buf.type = static_halide_type();
// The dimensionality is the number of leading non-zero extents, at most 4.
int d;
for (d = 0; d < 4 && old_buf.extent[d]; d++);
buf.dimensions = d;
make_shape_storage();
for (int i = 0; i < d; i++) {
buf.dim[i].min = old_buf.min[i];
buf.dim[i].extent = old_buf.extent[i];
buf.dim[i].stride = old_buf.stride[i];
}
buf.set_host_dirty(old_buf.host_dirty);
assert(old_buf.dev == 0 && "Cannot construct a Halide::Runtime::Buffer from a legacy buffer_t with a device allocation. Use halide_upgrade_buffer_t to upgrade it to a halide_buffer_t first.");
}
/** Build a legacy buffer_t describing the same host memory, shape and
 * host-dirty state as this buffer. The buffer must have no device
 * allocation and at most four dimensions; unused dimensions of the result
 * are left zeroed. */
buffer_t make_legacy_buffer_t() const {
    buffer_t old_buf = {0};
    assert(!has_device_allocation() && "Cannot construct a legacy buffer_t from a Halide::Runtime::Buffer with a device allocation. Use halide_downgrade_buffer_t instead.");
    old_buf.host = buf.host;
    old_buf.elem_size = buf.type.bytes();
    // Carry the host-dirty flag across, mirroring the constructor that
    // imports it from a buffer_t, so that a round trip does not lose it.
    old_buf.host_dirty = host_dirty();
    assert(dimensions() <= 4 && "Cannot construct a legacy buffer_t from a Halide::Runtime::Buffer with more than four dimensions.");
    for (int i = 0; i < dimensions(); i++) {
        old_buf.min[i] = dim(i).min();
        old_buf.extent[i] = dim(i).extent();
        old_buf.stride[i] = dim(i).stride();
    }
    return old_buf;
}
// Every instantiation of Buffer may access the private members of every
// other instantiation.
template<typename T2, int D2> friend class Buffer;
/** Whether a Buffer<T2, D2> may be viewed as a Buffer<T, D>. Dropping
 * constness and mismatched non-void element types are rejected at compile
 * time. Converting from a void buffer to a typed one is checked at runtime
 * against the type stored in `other`. */
template<typename T2, int D2>
static bool can_convert_from(const Buffer<T2, D2> &other) {
static_assert((!std::is_const<T2>::value || std::is_const<T>::value),
"Can't convert from a Buffer<const T> to a Buffer<T>");
static_assert(std::is_same<typename std::remove_const<T>::type,
typename std::remove_const<T2>::type>::value ||
T_is_void || Buffer<T2, D2>::T_is_void,
"type mismatch constructing Buffer");
if (Buffer<T2, D2>::T_is_void && !T_is_void) {
return other.type() == static_halide_type();
}
return true;
}
/** Assert that can_convert_from(other) holds. */
template<typename T2, int D2>
static void assert_can_convert_from(const Buffer<T2, D2> &other) {
assert(can_convert_from(other));
}
/** Copy constructor. The new object shares other's host (and device)
 * allocation and takes a reference on it; the shape is copied. */
Buffer(const Buffer<T, D> &other) : buf(other.buf),
alloc(other.alloc) {
other.incref();
// Read the counter only after incref(), which may have just created it.
dev_ref_count = other.dev_ref_count;
copy_shape_from(other.buf);
}
/** Converting copy constructor from a Buffer of another element type or
 * in-object dimensionality. Asserts that the conversion is valid. */
template<typename T2, int D2>
Buffer(const Buffer<T2, D2> &other) : buf(other.buf),
alloc(other.alloc) {
assert_can_convert_from(other);
other.incref();
dev_ref_count = other.dev_ref_count;
copy_shape_from(other.buf);
}
/** Move constructor. Takes over other's allocation header and device
 * counter without touching the reference counts; other is left no longer
 * managing memory. */
Buffer(Buffer<T, D> &&other) : buf(other.buf),
alloc(other.alloc),
dev_ref_count(other.dev_ref_count) {
other.dev_ref_count = nullptr;
other.alloc = nullptr;
move_shape_from(std::forward<Buffer<T, D>>(other));
}
/** Converting move constructor from a Buffer of another element type or
 * in-object dimensionality.
 * NOTE(review): unlike the converting copy constructor this does not call
 * assert_can_convert_from - confirm whether that is intentional. */
template<typename T2, int D2>
Buffer(Buffer<T2, D2> &&other) : buf(other.buf),
alloc(other.alloc),
dev_ref_count(other.dev_ref_count) {
other.dev_ref_count = nullptr;
other.alloc = nullptr;
move_shape_from(std::forward<Buffer<T2, D2>>(other));
}
/** Converting copy assignment. Shares other's allocation after releasing
 * the one currently held. Self-assignment is detected by address and is a
 * no-op. */
template<typename T2, int D2>
Buffer<T, D> &operator=(const Buffer<T2, D2> &other) {
if ((const void *)this == (const void *)&other) {
return *this;
}
assert_can_convert_from(other);
// Take the new reference before dropping the old one, so that an
// allocation shared by both objects is never released in between.
other.incref();
decref();
dev_ref_count = other.dev_ref_count;
alloc = other.alloc;
free_shape_storage();
buf = other.buf;
copy_shape_from(other.buf);
return *this;
}
/** Copy assignment from the same Buffer type; same steps as above. */
Buffer<T, D> &operator=(const Buffer<T, D> &other) {
if (this == &other) {
return *this;
}
other.incref();
decref();
dev_ref_count = other.dev_ref_count;
alloc = other.alloc;
free_shape_storage();
buf = other.buf;
copy_shape_from(other.buf);
return *this;
}
/** Converting move assignment. Releases whatever this object currently
 * holds, then takes over other's allocation header, device counter and
 * shape without touching the reference counts. Self-assignment is detected
 * by address and is a no-op, matching the copy assignment operators;
 * without the guard a self-move would release the allocation and could
 * leave buf.dim null. */
template<typename T2, int D2>
Buffer<T, D> &operator=(Buffer<T2, D2> &&other) {
    if ((const void *)this == (const void *)&other) {
        return *this;
    }
    assert_can_convert_from(other);
    decref();
    alloc = other.alloc;
    other.alloc = nullptr;
    dev_ref_count = other.dev_ref_count;
    other.dev_ref_count = nullptr;
    free_shape_storage();
    buf = other.buf;
    move_shape_from(std::forward<Buffer<T2, D2>>(other));
    return *this;
}
/** Move assignment from the same Buffer type; same steps as above. */
Buffer<T, D> &operator=(Buffer<T, D> &&other) {
    if (this == &other) {
        return *this;
    }
    decref();
    alloc = other.alloc;
    other.alloc = nullptr;
    dev_ref_count = other.dev_ref_count;
    other.dev_ref_count = nullptr;
    free_shape_storage();
    buf = other.buf;
    move_shape_from(std::forward<Buffer<T, D>>(other));
    return *this;
}
/** Assert that the total size in bytes (element size times every extent)
 * can be computed without overflowing. Callers in this file only invoke it
 * when no extent is zero, which the divisions below rely on. */
void check_overflow() {
size_t size = type().bytes();
for (int i = 0; i < dimensions(); i++) {
size *= dim(i).extent();
}
// Clear the top bit, so that a product which reached it no longer divides
// back to the element size below.
size = (size << 1) >> 1;
for (int i = 0; i < dimensions(); i++) {
size /= dim(i).extent();
}
assert(size == (size_t)type().bytes() && "Error: Overflow computing total size of buffer.");
}
/** Allocate host memory for the current shape, releasing any allocation
 * previously managed by this object. allocate_fn and deallocate_fn default
 * to malloc and free. The allocation starts with an AllocationHeader and
 * buf.host is set to the first 128-byte-aligned address after it.
 * NOTE(review): the pointer returned by allocate_fn is not checked for
 * null before the header is written. */
void allocate(void *(*allocate_fn)(size_t) = nullptr,
void (*deallocate_fn)(void *) = nullptr) {
if (!allocate_fn) {
allocate_fn = malloc;
}
if (!deallocate_fn) {
deallocate_fn = free;
}
deallocate();
size_t size = size_in_bytes();
const size_t alignment = 128;
// Round the payload size up to a multiple of the alignment.
size = (size + alignment - 1) & ~(alignment - 1);
// Reserve room for the header plus the worst-case alignment padding.
alloc = (AllocationHeader *)allocate_fn(size + sizeof(AllocationHeader) + alignment - 1);
alloc->deallocate_fn = deallocate_fn;
alloc->ref_count = 1;
uint8_t *unaligned_ptr = ((uint8_t *)alloc) + sizeof(AllocationHeader);
buf.host = (uint8_t *)((uintptr_t)(unaligned_ptr + alignment - 1) & ~(alignment - 1));
}
/** Drop this object's reference on its managed allocation (see decref). */
void deallocate() {
decref();
}
/** Drop this object's reference on the device allocation, but only when
 * this object manages its host memory. */
void device_deallocate() {
if (manages_memory()) {
decref_dev();
}
}
/** Allocate a buffer of runtime type t with the given extents. Asserts
 * that t matches T unless T is void. Nothing is allocated when any extent
 * is zero. */
template<typename ...Args,
typename = typename std::enable_if<AllInts<Args...>::value>::type>
Buffer(halide_type_t t, int first, Args... rest) {
if (!T_is_void) {
assert(static_halide_type() == t);
}
buf.type = t;
buf.dimensions = 1 + (int)(sizeof...(rest));
make_shape_storage();
initialize_shape(0, first, rest...);
if (!any_zero(first, rest...)) {
check_overflow();
allocate();
}
}
/** Allocate a one-dimensional buffer of the static type. Explicit, so that
 * an int does not convert implicitly to a Buffer. */
explicit Buffer(int first) {
static_assert(!T_is_void,
"To construct an Buffer<void>, pass a halide_type_t as the first argument to the constructor");
buf.type = static_halide_type();
buf.dimensions = 1;
make_shape_storage();
initialize_shape(0, first);
if (first != 0) {
check_overflow();
allocate();
}
}
/** Allocate a buffer of the static type with two or more extents. */
template<typename ...Args,
typename = typename std::enable_if<AllInts<Args...>::value>::type>
Buffer(int first, int second, Args... rest) {
static_assert(!T_is_void,
"To construct an Buffer<void>, pass a halide_type_t as the first argument to the constructor");
buf.type = static_halide_type();
buf.dimensions = 2 + (int)(sizeof...(rest));
make_shape_storage();
initialize_shape(0, first, second, rest...);
if (!any_zero(first, second, rest...)) {
check_overflow();
allocate();
}
}
/** Allocate a buffer of runtime type t with extents given as a vector. */
Buffer(halide_type_t t, const std::vector<int> &sizes) {
if (!T_is_void) {
assert(static_halide_type() == t);
}
buf.type = t;
buf.dimensions = (int)sizes.size();
make_shape_storage();
initialize_shape(sizes);
if (!any_zero(sizes)) {
check_overflow();
allocate();
}
}
/** Allocate a buffer of the static type with extents given as a vector.
 * NOTE(review): unlike the int overloads there is no static_assert against
 * T being void here. */
Buffer(const std::vector<int> &sizes) {
buf.type = static_halide_type();
buf.dimensions = (int)sizes.size();
make_shape_storage();
initialize_shape(sizes);
if (!any_zero(sizes)) {
check_overflow();
allocate();
}
}
// The constructors below wrap existing memory: they set buf.host to the
// caller's pointer and never call allocate(), so alloc stays null and the
// memory is not released by this object.

/** Wrap a (possibly nested) C array. Dimensionality, element type and
 * extents are taken from the array type. */
template<typename Array, size_t N>
explicit Buffer(Array (&vals)[N]) {
buf.dimensions = dimensionality_of_array(vals);
buf.type = scalar_type_of_array(vals);
buf.host = (uint8_t *)vals;
make_shape_storage();
initialize_shape_from_array_shape(buf.dimensions - 1, vals);
}
/** Wrap untyped memory of runtime type t with the given extents and dense
 * strides. Asserts that t matches T unless T is void. */
template<typename ...Args,
typename = typename std::enable_if<AllInts<Args...>::value>::type>
explicit Buffer(halide_type_t t, add_const_if_T_is_const<void> *data, int first, Args&&... rest) {
if (!T_is_void) {
assert(static_halide_type() == t);
}
buf.type = t;
buf.dimensions = 1 + (int)(sizeof...(rest));
buf.host = (uint8_t *)data;
make_shape_storage();
initialize_shape(0, first, int(rest)...);
}
/** Wrap typed memory with the given extents and dense strides. */
template<typename ...Args,
typename = typename std::enable_if<AllInts<Args...>::value>::type>
explicit Buffer(T *data, int first, Args&&... rest) {
buf.type = static_halide_type();
buf.dimensions = 1 + (int)(sizeof...(rest));
buf.host = (uint8_t *)data;
make_shape_storage();
initialize_shape(0, first, int(rest)...);
}
/** Wrap typed memory with extents given as a vector. */
explicit Buffer(T *data, const std::vector<int> &sizes) {
buf.type = static_halide_type();
buf.dimensions = (int)sizes.size();
buf.host = (uint8_t *)data;
make_shape_storage();
initialize_shape(sizes);
}
/** Wrap untyped memory of runtime type t with extents given as a vector. */
explicit Buffer(halide_type_t t, add_const_if_T_is_const<void> *data, const std::vector<int> &sizes) {
if (!T_is_void) {
assert(static_halide_type() == t);
}
buf.type = t;
buf.dimensions = (int)sizes.size();
buf.host = (uint8_t *)data;
make_shape_storage();
initialize_shape(sizes);
}
/** Wrap untyped memory of runtime type t with an explicit shape: d
 * dimension records are copied from `shape`, mins and strides included. */
explicit Buffer(halide_type_t t, add_const_if_T_is_const<void> *data, int d, const halide_dimension_t *shape) {
if (!T_is_void) {
assert(static_halide_type() == t);
}
buf.type = t;
buf.dimensions = d;
buf.host = (uint8_t *)data;
make_shape_storage();
for (int i = 0; i < d; i++) {
buf.dim[i] = shape[i];
}
}
/** Wrap typed memory with an explicit shape of d dimension records. */
explicit Buffer(T *data, int d, const halide_dimension_t *shape) {
buf.type = halide_type_of<typename std::remove_cv<T>::type>();
buf.dimensions = d;
buf.host = (uint8_t *)data;
make_shape_storage();
for (int i = 0; i < d; i++) {
buf.dim[i] = shape[i];
}
}
/** Release heap shape storage and this object's references on the host
 * and device allocations. */
~Buffer() {
free_shape_storage();
decref();
}
/** Access to the wrapped halide_buffer_t. */
halide_buffer_t *raw_buffer() {
return &buf;
}
const halide_buffer_t *raw_buffer() const {
return &buf;
}
/** Implicit conversion to a pointer to the wrapped halide_buffer_t. */
operator halide_buffer_t *() {
return &buf;
}
/** View this object as a Buffer<T2, D2> by casting the this pointer, after
 * asserting that the element-type conversion is valid. The lvalue overloads
 * are only enabled when D2 <= D. */
template<typename T2, int D2 = D,
typename = typename std::enable_if<(D2 <= D)>::type>
Buffer<T2, D2> &as() & {
Buffer<T2, D>::assert_can_convert_from(*this);
return *((Buffer<T2, D2> *)this);
}
/** Const lvalue overload of the above. */
template<typename T2, int D2 = D,
typename = typename std::enable_if<(D2 <= D)>::type>
const Buffer<T2, D2> &as() const & {
Buffer<T2, D>::assert_can_convert_from(*this);
return *((const Buffer<T2, D2> *)this);
}
/** Rvalue overload: returns a Buffer<T2, D2> by value, with no restriction
 * on D2. */
template<typename T2, int D2 = D>
Buffer<T2, D2> as() && {
Buffer<T2, D2>::assert_can_convert_from(*this);
return *((Buffer<T2, D2> *)this);
}
// Image-style accessors. width/height/channels are the extents of
// dimensions 0/1/2 and return 1 when the buffer has too few dimensions.
int width() const {
return (dimensions() > 0) ? dim(0).extent() : 1;
}
int height() const {
return (dimensions() > 1) ? dim(1).extent() : 1;
}
int channels() const {
return (dimensions() > 2) ? dim(2).extent() : 1;
}
// left/right are the min/max coordinates of dimension 0 and top/bottom
// those of dimension 1. Unlike the accessors above, these do not check
// that the dimension exists.
int left() const {
return dim(0).min();
}
int right() const {
return dim(0).max();
}
int top() const {
return dim(1).min();
}
int bottom() const {
return dim(1).max();
}
/** Make a new buffer with the same shape as this one, backed by freshly
 * allocated host memory, and copy this buffer's contents into it.
 *
 * allocate_fn / deallocate_fn are passed on to make_with_shape_of so that a
 * caller-supplied allocator is used for the new storage; null selects the
 * defaults. */
Buffer<T, D> copy(void *(*allocate_fn)(size_t) = nullptr,
                  void (*deallocate_fn)(void *) = nullptr) const {
    Buffer<T, D> dst = make_with_shape_of(*this, allocate_fn, deallocate_fn);
    dst.copy_from(*this);
    return dst;
}
/** Copy element values from `other` into this buffer over the region where
 * their coordinate ranges overlap, then mark the host data dirty. Both
 * buffers must have the same dimensionality. If the ranges do not overlap
 * in some dimension the function returns without copying anything and
 * without setting the dirty flag. */
template<typename T2, int D2>
void copy_from(const Buffer<T2, D2> &other) {
// Work on local handles so that the crops below do not alter the shape of
// either buffer.
Buffer<const T, D> src(other);
Buffer<T, D> dst(*this);
assert(src.dimensions() == dst.dimensions());
// Crop both handles to the common coordinate range of each dimension.
for (int i = 0; i < dimensions(); i++) {
int min_coord = std::max(dst.dim(i).min(), src.dim(i).min());
int max_coord = std::min(dst.dim(i).max(), src.dim(i).max());
if (max_coord < min_coord) {
return;
}
dst.crop(i, min_coord, max_coord - min_coord + 1);
src.crop(i, min_coord, max_coord - min_coord + 1);
}
// The copy is done on same-sized unsigned integers chosen by element width
// in bytes, so it does not depend on what T actually is.
if (type().bytes() == 1) {
using MemType = uint8_t;
auto &typed_dst = (Buffer<MemType, D> &)dst;
auto &typed_src = (Buffer<const MemType, D> &)src;
typed_dst.for_each_value([&](MemType &dst, MemType src) {dst = src;}, typed_src);
} else if (type().bytes() == 2) {
using MemType = uint16_t;
auto &typed_dst = (Buffer<MemType, D> &)dst;
auto &typed_src = (Buffer<const MemType, D> &)src;
typed_dst.for_each_value([&](MemType &dst, MemType src) {dst = src;}, typed_src);
} else if (type().bytes() == 4) {
using MemType = uint32_t;
auto &typed_dst = (Buffer<MemType, D> &)dst;
auto &typed_src = (Buffer<const MemType, D> &)src;
typed_dst.for_each_value([&](MemType &dst, MemType src) {dst = src;}, typed_src);
} else if (type().bytes() == 8) {
using MemType = uint64_t;
auto &typed_dst = (Buffer<MemType, D> &)dst;
auto &typed_src = (Buffer<const MemType, D> &)src;
typed_dst.for_each_value([&](MemType &dst, MemType src) {dst = src;}, typed_src);
} else {
assert(false && "type().bytes() must be 1, 2, 4, or 8");
}
set_host_dirty();
}
/** Return a buffer sharing this one's memory, cropped along dimension d to
 * the range starting at `min` with the given extent. */
Buffer<T, D> cropped(int d, int min, int extent) const {
Buffer<T, D> im = *this;
im.crop(d, min, extent);
return im;
}
/** Crop this buffer in place along dimension d. The host pointer is
 * advanced so that it refers to the element at the new min. The requested
 * range is not checked against the current bounds. */
void crop(int d, int min, int extent) {
int shift = min - dim(d).min();
// When the origin moves, this object's device reference is dropped.
if (shift) {
device_deallocate();
}
buf.host += shift * dim(d).stride() * type().bytes();
buf.dim[d].min = min;
buf.dim[d].extent = extent;
}
Buffer<T, D> cropped(const std::vector<std::pair<int, int>> &rect) const {
Buffer<T, D> im = *this;
im.crop(rect);
return im;
}
void crop(const std::vector<std::pair<int, int>> &rect) {
for (int i = 0; i < rect.size(); i++) {
crop(i, rect[i].first, rect[i].second);
}
}
Buffer<T, D> translated(int d, int dx) const {
Buffer<T, D> im = *this;
im.translate(d, dx);
return im;
}
void translate(int d, int delta) {
device_deallocate();
buf.dim[d].min += delta;
}
Buffer<T, D> translated(const std::vector<int> &delta) {
Buffer<T, D> im = *this;
im.translate(delta);
return im;
}
void translate(const std::vector<int> &delta) {
device_deallocate();
for (size_t i = 0; i < delta.size(); i++) {
translate(i, delta[i]);
}
}
/** Set the min coordinate of the leading dimensions: the i-th argument
 * becomes the min of dimension i. At most dimensions() arguments may be
 * given. This object's device reference is dropped.
 * NOTE(review): with zero arguments `x` would be a zero-length array, which
 * standard C++ does not allow. */
template<typename ...Args>
void set_min(Args... args) {
assert(sizeof...(args) <= (size_t)dimensions());
device_deallocate();
const int x[] = {args...};
for (size_t i = 0; i < sizeof...(args); i++) {
buf.dim[i].min = x[i];
}
}
template<typename ...Args>
bool contains(Args... args) {
assert(sizeof...(args) <= (size_t)dimensions());
const int x[] = {args...};
for (size_t i = 0; i < sizeof...(args); i++) {
if (x[i] < dim(i).min() || x[i] > dim(i).max()) {
return false;
}
}
return true;
}
/** Return a buffer sharing this one's memory with dimensions d1 and d2
 * exchanged. */
Buffer<T, D> transposed(int d1, int d2) const {
Buffer<T, D> im = *this;
im.transpose(d1, d2);
return im;
}
/** Exchange the dimension records d1 and d2 in place. No data is moved. */
void transpose(int d1, int d2) {
std::swap(buf.dim[d1], buf.dim[d2]);
}
/** Return a lower-dimensional buffer sharing this one's memory, fixed at
 * coordinate pos along dimension d. */
Buffer<T, D> sliced(int d, int pos) const {
Buffer<T, D> im = *this;
im.slice(d, pos);
return im;
}
/** Remove dimension d in place, fixing its coordinate at pos. The host
 * pointer is advanced to that position and the later dimensions shift down
 * by one. This object's device reference is dropped. */
void slice(int d, int pos) {
device_deallocate();
buf.dimensions--;
// buf.dim[d] is still intact here; only the dimension count has changed.
int shift = pos - dim(d).min();
assert(buf.device == 0 || shift == 0);
buf.host += shift * dim(d).stride() * type().bytes();
for (int i = d; i < dimensions(); i++) {
buf.dim[i] = buf.dim[i+1];
}
// Zero the record that is no longer part of the shape.
buf.dim[buf.dimensions] = {0, 0, 0};
}
/** Return a buffer sharing this one's memory with a new dimension of
 * extent one inserted at index d, whose single coordinate is pos.
 *
 * This delegates to embed() so that both variants perform exactly the same
 * sequence of operations. The previous inline loop called transpose() with
 * no arguments and started one past the last dimension. */
Buffer<T, D> embedded(int d, int pos) const {
    assert(d >= 0 && d <= dimensions());
    Buffer<T, D> im(*this);
    im.embed(d, pos);
    return im;
}
/** Insert a new dimension of extent one at index d, in place, whose single
 * coordinate is pos. The dimension is appended at the end, translated to
 * pos, and then moved down to index d by swapping adjacent dimensions. */
void embed(int d, int pos) {
    assert(d >= 0 && d <= dimensions());
    add_dimension();
    const int last = dimensions() - 1;
    translate(last, pos);
    int i = last;
    while (i > d) {
        transpose(i, i-1);
        --i;
    }
}
/** Append a dimension with min 0 and extent 1. Its stride is 1 for a
 * previously zero-dimensional buffer, otherwise the previous last
 * dimension's extent times its stride. The shape storage grows as needed. */
void add_dimension() {
const int dims = buf.dimensions;
buf.dimensions++;
if (buf.dim != shape) {
// The shape is already in a separate array: allocate one that is one
// entry larger, copy the entries over and free the old array.
halide_dimension_t *new_shape = new halide_dimension_t[buf.dimensions];
for (int i = 0; i < dims; i++) {
new_shape[i] = buf.dim[i];
}
delete[] buf.dim;
buf.dim = new_shape;
} else if (dims == D) {
// The in-object array is full: make_shape_storage() now allocates a heap
// array (buf.dimensions > D), into which the old entries are copied.
make_shape_storage();
for (int i = 0; i < dims; i++) {
buf.dim[i] = shape[i];
}
} else {
// The in-object array still has room; nothing to do.
}
buf.dim[dims] = {0, 1, 0};
if (dims == 0) {
buf.dim[dims].stride = 1;
} else {
buf.dim[dims].stride = buf.dim[dims-1].extent * buf.dim[dims-1].stride;
}
}
/** Append a dimension as add_dimension() does, then override its stride
 * with s. */
void add_dimension_with_stride(int s) {
add_dimension();
buf.dim[buf.dimensions-1].stride = s;
}
// Accessors for the host/device dirty flags stored in the raw buffer.
void set_host_dirty(bool v = true) {
buf.set_host_dirty(v);
}
bool device_dirty() const {
return buf.device_dirty();
}
bool host_dirty() const {
return buf.host_dirty();
}
void set_device_dirty(bool v = true) {
buf.set_device_dirty(v);
}
/** Call halide_copy_to_host if the device-dirty flag is set. Returns its
 * result, or 0 when nothing needed copying. */
int copy_to_host(void *ctx = nullptr) {
if (device_dirty()) {
return halide_copy_to_host(ctx, &buf);
}
return 0;
}
/** Call halide_copy_to_device if the host-dirty flag is set. Returns its
 * result, or 0 when nothing needed copying. */
int copy_to_device(const struct halide_device_interface_t *device_interface, void *ctx = nullptr) {
if (host_dirty()) {
return halide_copy_to_device(ctx, &buf, device_interface);
}
return 0;
}
/** Forward to halide_device_malloc for this buffer. */
int device_malloc(const struct halide_device_interface_t *device_interface, void *ctx = nullptr) {
return halide_device_malloc(ctx, &buf, device_interface);
}
/** Explicitly free the device allocation through halide_device_free and
 * delete the device reference counter. Asserts that no other Buffer shares
 * the device allocation. */
int device_free(void *ctx = nullptr) {
if (dev_ref_count) {
assert(*dev_ref_count == 1 &&
"Multiple Halide::Runtime::Buffer objects share this device "
"allocation. Freeing it would create dangling references. "
"Don't call device_free on Halide buffers that you have copied or "
"passed by value.");
}
int ret = halide_device_free(ctx, &buf);
if (dev_ref_count) {
delete dev_ref_count;
dev_ref_count = nullptr;
}
return ret;
}
/** Forward to halide_device_sync for this buffer. */
int device_sync(void *ctx = nullptr) {
return halide_device_sync(ctx, &buf);
}
/** True when the raw buffer holds a non-zero device handle. */
bool has_device_allocation() const {
return buf.device != 0;
}
// make_interleaved: build a three-dimensional buffer whose dimensions are
// (width, height, channels) but whose channel dimension has stride 1. The
// buffer is constructed with extents (channels, width, height) and dense
// strides, and the dimension records are then rotated with two transposes.

/** Allocating variant with a runtime element type. */
static Buffer<void, D> make_interleaved(halide_type_t t, int width, int height, int channels) {
Buffer<void, D> im(t, channels, width, height);
im.transpose(0, 1);
im.transpose(1, 2);
return im;
}
/** Allocating variant using the static element type. */
static Buffer<T, D> make_interleaved(int width, int height, int channels) {
Buffer<T, D> im(channels, width, height);
im.transpose(0, 1);
im.transpose(1, 2);
return im;
}
/** Variant wrapping existing memory with a runtime element type. */
static Buffer<add_const_if_T_is_const<void>, D>
make_interleaved(halide_type_t t, T *data, int width, int height, int channels) {
Buffer<add_const_if_T_is_const<void>, D> im(t, data, channels, width, height);
im.transpose(0, 1);
im.transpose(1, 2);
return im;
}
/** Variant wrapping existing memory of the static element type. */
static Buffer<T, D> make_interleaved(T *data, int width, int height, int channels) {
Buffer<T, D> im(data, channels, width, height);
im.transpose(0, 1);
im.transpose(1, 2);
return im;
}
// make_scalar: allocate a one-dimensional buffer with a single element and
// slice away dimension 0, leaving a zero-dimensional buffer.

/** Scalar buffer with a runtime element type. */
static Buffer<add_const_if_T_is_const<void>, D> make_scalar(halide_type_t t) {
Buffer<add_const_if_T_is_const<void>, 1> buf(t, 1);
buf.slice(0, 0);
return buf;
}
/** Scalar buffer of the static element type. */
static Buffer<T, D> make_scalar() {
Buffer<T, 1> buf(1);
buf.slice(0, 0);
return buf;
}
/** Allocate a new buffer of element type T with the same mins, extents and
 * stride ordering as src, but with dense strides.
 *
 * src is taken by value because its shape is rewritten in the process.
 * allocate_fn / deallocate_fn are passed on to allocate(), so a
 * caller-supplied allocator is used for the new storage; null selects the
 * defaults. */
template<typename T2, int D2>
static Buffer<T, D> make_with_shape_of(Buffer<T2, D2> src,
                                       void *(*allocate_fn)(size_t) = nullptr,
                                       void (*deallocate_fn)(void *) = nullptr) {
    // Sort the dimensions of src by increasing stride, remembering every
    // adjacent swap so that the ordering can be restored afterwards.
    std::vector<int> swaps;
    for (int i = src.dimensions()-1; i > 0; i--) {
        for (int j = i; j > 0; j--) {
            if (src.dim(j-1).stride() > src.dim(j).stride()) {
                src.transpose(j-1, j);
                swaps.push_back(j);
            }
        }
    }

    // Rewrite the strides so that they are dense in sorted order.
    halide_dimension_t *shape = src.buf.dim;
    for (int i = 0; i < src.dimensions(); i++) {
        if (i == 0) {
            shape[i].stride = 1;
        } else {
            shape[i].stride = shape[i-1].extent * shape[i-1].stride;
        }
    }

    // Undo the swaps in reverse order to recover the original dimension order.
    while (!swaps.empty()) {
        int j = swaps.back();
        std::swap(shape[j-1], shape[j]);
        swaps.pop_back();
    }

    // Build an unallocated buffer with that shape, then give it storage
    // from the requested allocator.
    Buffer<T, D> dst(nullptr, src.dimensions(), shape);
    dst.allocate(allocate_fn, deallocate_fn);
    return dst;
}
private:
/** Offset, in elements, of the given coordinates relative to the element
 * at the min of every dimension. `d` is the dimension that `first` belongs
 * to; the remaining coordinates are handled by recursion. */
template<typename ...Args>
HALIDE_ALWAYS_INLINE
ptrdiff_t offset_of(int d, int first, Args... rest) const {
return offset_of(d+1, rest...) + this->buf.dim[d].stride * (first - this->buf.dim[d].min);
}
/** End of the recursion: no coordinates left. */
HALIDE_ALWAYS_INLINE
ptrdiff_t offset_of(int d) const {
return 0;
}
/** Address of the element at the given coordinates. When T is void,
 * storage_T is a byte type, so the element offset is scaled by the runtime
 * element size; otherwise pointer arithmetic on storage_T does the scaling. */
template<typename ...Args>
HALIDE_ALWAYS_INLINE
storage_T *address_of(Args... args) const {
if (T_is_void) {
return (storage_T *)(this->buf.host) + offset_of(0, args...) * type().bytes();
} else {
return (storage_T *)(this->buf.host) + offset_of(0, args...);
}
}
/** Element offset for coordinates given as an array holding one entry per
 * dimension of the buffer. */
HALIDE_ALWAYS_INLINE
ptrdiff_t offset_of(const int *pos) const {
ptrdiff_t offset = 0;
for (int i = this->dimensions() - 1; i >= 0; i--) {
offset += this->buf.dim[i].stride * (pos[i] - this->buf.dim[i].min);
}
return offset;
}
/** Address of the element at coordinates given as an array. */
HALIDE_ALWAYS_INLINE
storage_T *address_of(const int *pos) const {
if (T_is_void) {
return (storage_T *)this->buf.host + offset_of(pos) * type().bytes();
} else {
return (storage_T *)this->buf.host + offset_of(pos);
}
}
public:
/** The host pointer, typed as T. This is the address of the element at the
 * min coordinate of every dimension. */
T *data() {
return (T *)(this->buf.host);
}
const T *data() const {
return (const T *)(this->buf.host);
}
// Element access. The const overloads return a const reference. The
// non-const overloads mark the host data dirty before returning a mutable
// reference. None of them is available for Buffer<void>.

/** Read the element at the given coordinates. */
template<typename ...Args,
typename = typename std::enable_if<AllInts<Args...>::value>::type>
HALIDE_ALWAYS_INLINE
const not_void_T &operator()(int first, Args... rest) const {
static_assert(!T_is_void,
"Cannot use operator() on Buffer<void> types");
return *((const not_void_T *)(address_of(first, rest...)));
}
/** Read the element at the host pointer (no coordinates given). */
HALIDE_ALWAYS_INLINE
const not_void_T &
operator()() const {
static_assert(!T_is_void,
"Cannot use operator() on Buffer<void> types");
return *((const not_void_T *)(data()));
}
/** Read the element at coordinates given as an array. */
HALIDE_ALWAYS_INLINE
const not_void_T &
operator()(const int *pos) const {
static_assert(!T_is_void,
"Cannot use operator() on Buffer<void> types");
return *((const not_void_T *)(address_of(pos)));
}
/** Mutable access to the element at the given coordinates. */
template<typename ...Args,
typename = typename std::enable_if<AllInts<Args...>::value>::type>
HALIDE_ALWAYS_INLINE
not_void_T &operator()(int first, Args... rest) {
static_assert(!T_is_void,
"Cannot use operator() on Buffer<void> types");
set_host_dirty();
return *((not_void_T *)(address_of(first, rest...)));
}
/** Mutable access to the element at the host pointer. */
HALIDE_ALWAYS_INLINE
not_void_T &
operator()() {
static_assert(!T_is_void,
"Cannot use operator() on Buffer<void> types");
set_host_dirty();
return *((not_void_T *)(data()));
}
/** Mutable access to the element at coordinates given as an array. */
HALIDE_ALWAYS_INLINE
not_void_T &
operator()(const int *pos) {
static_assert(!T_is_void,
"Cannot use operator() on Buffer<void> types");
set_host_dirty();
return *((not_void_T *)(address_of(pos)));
}
/** Assign val to every element and mark the host data dirty. */
void fill(not_void_T val) {
for_each_value([=](T &v) {v = val;});
set_host_dirty();
}
private:
/** One loop level of a for_each_value traversal over N buffers: the trip
 * count and, for each buffer, how far its pointer advances per iteration. */
template<int N>
struct for_each_value_task_dim {
int extent;
int stride[N];
};
/** Advance each pointer by the corresponding entry of `stride`. The
 * pointers are passed by address so that they can be modified. */
template<typename Ptr, typename ...Ptrs>
static void advance_ptrs(const int *stride, Ptr *ptr, Ptrs... ptrs) {
(*ptr) += *stride;
advance_ptrs(stride + 1, ptrs...);
}
static void advance_ptrs(const int *) {}
/** Advance each pointer by exactly one element. */
template<typename Ptr, typename ...Ptrs>
static void increment_ptrs(Ptr *ptr, Ptrs... ptrs) {
(*ptr)++;
increment_ptrs(ptrs...);
}
static void increment_ptrs() {}
/** Write the stride of dimension d of each given buffer into consecutive
 * entries of `strides`. Asserts that every buffer has the same
 * dimensionality as this one and the same min and max along dimension d. */
template<typename T2, int D2, typename ...Args>
void extract_strides(int d, int *strides, const Buffer<T2, D2> *first, Args... rest) {
assert(first->dimensions() == dimensions());
assert(first->dim(d).min() == dim(d).min() &&
first->dim(d).max() == dim(d).max());
*strides++ = first->dim(d).stride();
extract_strides(d, strides, rest...);
}
void extract_strides(int d, int *strides) {}
/** Loop nest with a compile-time depth. Level d iterates t[d].extent
 * times; d == -1 is the innermost body, which calls f on the dereferenced
 * pointers. The pointers are taken by value, so the advances made at one
 * level are not seen by the enclosing level. */
template<int d, bool innermost_strides_are_one, typename Fn, typename... Ptrs>
static void for_each_value_helper(Fn &&f, const for_each_value_task_dim<sizeof...(Ptrs)> *t, Ptrs... ptrs) {
if (d == -1) {
f((*ptrs)...);
} else {
for (int i = t[d].extent; i != 0; i--) {
// The template argument is clamped at -1 so that instantiation terminates,
// even though the d == -1 case never reaches this branch at runtime.
for_each_value_helper<(d >= 0 ? d - 1 : -1), innermost_strides_are_one>(f, t, ptrs...);
if (d == 0 && innermost_strides_are_one) {
// Every pointer steps by one element on the innermost level.
increment_ptrs((&ptrs)...);
} else {
advance_ptrs(t[d].stride, (&ptrs)...);
}
}
}
}
/** Loop nest with a runtime depth. Depths -1 to 2 dispatch to the
 * compile-time version above; deeper levels loop here and recurse. */
template<bool innermost_strides_are_one, typename Fn, typename... Ptrs>
static void for_each_value_helper(Fn &&f, int d, const for_each_value_task_dim<sizeof...(Ptrs)> *t, Ptrs... ptrs) {
if (d == -1) {
for_each_value_helper<-1, innermost_strides_are_one>(f, t, ptrs...);
} else if (d == 0) {
for_each_value_helper<0, innermost_strides_are_one>(f, t, ptrs...);
} else if (d == 1) {
for_each_value_helper<1, innermost_strides_are_one>(f, t, ptrs...);
} else if (d == 2) {
for_each_value_helper<2, innermost_strides_are_one>(f, t, ptrs...);
} else {
for (int i = t[d].extent; i != 0; i--) {
for_each_value_helper<innermost_strides_are_one>(f, d-1, t, ptrs...);
advance_ptrs(t[d].stride, (&ptrs)...);
}
}
}
public:
/** Call f once per element with a reference to the element of this buffer
 * and the corresponding element of each buffer in other_buffers. The other
 * buffers must have the same dimensionality, mins and maxes as this one
 * (asserted in extract_strides). The visiting order is chosen here and is
 * not the coordinate order. */
template<typename Fn, typename ...Args, int N = sizeof...(Args) + 1>
void for_each_value(Fn &&f, Args... other_buffers) {
// One task entry per dimension, plus one spare entry that the flattening
// step below reads when it shifts entries down.
for_each_value_task_dim<N> *t =
(for_each_value_task_dim<N> *)HALIDE_ALLOCA((dimensions()+1) * sizeof(for_each_value_task_dim<N>));
for (int i = 0; i <= dimensions(); i++) {
for (int j = 0; j < N; j++) {
t[i].stride[j] = 0;
}
t[i].extent = 1;
}
for (int i = 0; i < dimensions(); i++) {
extract_strides(i, t[i].stride, this, &other_buffers...);
t[i].extent = dim(i).extent();
// Insertion sort by this buffer's stride (stride[0]), smallest first, so
// that the innermost loop walks the smallest stride.
for (int j = i; j > 0 && t[j].stride[0] < t[j-1].stride[0]; j--) {
std::swap(t[j], t[j-1]);
}
}
// Merge adjacent levels wherever, for every buffer, the outer stride equals
// the inner stride times the inner extent: the two loops then cover one
// contiguous run and can become a single longer loop.
int d = dimensions();
for (int i = 1; i < d; i++) {
bool flat = true;
for (int j = 0; j < N; j++) {
flat = flat && t[i-1].stride[j] * t[i-1].extent == t[i].stride[j];
}
if (flat) {
t[i-1].extent *= t[i].extent;
for (int j = i; j < dimensions(); j++) {
t[j] = t[j+1];
}
// Re-examine the same index against the entry that was shifted into place.
i--;
d--;
}
}
bool innermost_strides_are_one = false;
if (dimensions() > 0) {
innermost_strides_are_one = true;
for (int j = 0; j < N; j++) {
innermost_strides_are_one &= t[0].stride[j] == 1;
}
}
// The nest is still run over dimensions() levels: the entries vacated by
// merging carry extent 1 and stride 0 copied down from the spare entry.
if (innermost_strides_are_one) {
for_each_value_helper<true>(f, dimensions() - 1, t, begin(), (other_buffers.begin())...);
} else {
for_each_value_helper<false>(f, dimensions() - 1, t, begin(), (other_buffers.begin())...);
}
}
private:
/** Inclusive coordinate range of one dimension of a for_each_element
 * traversal. */
struct for_each_element_task_dim {
int min, max;
};
// In the overload pairs below the leading int/double parameter is a
// dispatch tag. Callers pass the int literal 0, so the overload taking int
// is preferred whenever its decltype default template argument is
// well-formed, and the overload taking double is the fallback.

/** Selected when f can be called with the coordinates gathered so far:
 * calls it. */
template<typename Fn,
typename ...Args,
typename = decltype(std::declval<Fn>()(std::declval<Args>()...))>
HALIDE_ALWAYS_INLINE
static void for_each_element_variadic(int, int, const for_each_element_task_dim *, Fn &&f, Args... args) {
f(args...);
}
/** Fallback: loop over dimension d and recurse with the coordinate
 * prepended to the argument list, so that lower dimensions come first. */
template<typename Fn,
typename ...Args>
HALIDE_ALWAYS_INLINE
static void for_each_element_variadic(double, int d, const for_each_element_task_dim *t, Fn &&f, Args... args) {
for (int i = t[d].min; i <= t[d].max; i++) {
for_each_element_variadic(0, d - 1, t, std::forward<Fn>(f), i, args...);
}
}
/** Selected when f can be called with the given arguments: returns how
 * many there are. */
template<typename Fn,
typename ...Args,
typename = decltype(std::declval<Fn>()(std::declval<Args>()...))>
HALIDE_ALWAYS_INLINE
static int num_args(int, Fn &&, Args...) {
return (int)(sizeof...(Args));
}
/** Fallback: try again with one more int argument. Together these find
 * the number of int arguments that f accepts. */
template<typename Fn,
typename ...Args>
HALIDE_ALWAYS_INLINE
static int num_args(double, Fn &&f, Args... args) {
static_assert(sizeof...(args) <= 256,
"Callable passed to for_each_element must accept either a const int *,"
" or up to 256 ints. No such operator found. Expect infinite template recursion.");
return num_args(0, std::forward<Fn>(f), 0, args...);
}
/** Array-based loop nest with a compile-time depth: iterate pos[d] over
 * dimension d and recurse into dimension d - 1. Enabled for d >= 0. */
template<int d,
typename Fn,
typename = typename std::enable_if<(d >= 0)>::type>
HALIDE_ALWAYS_INLINE
static void for_each_element_array_helper(int, const for_each_element_task_dim *t, Fn &&f, int *pos) {
for (pos[d] = t[d].min; pos[d] <= t[d].max; pos[d]++) {
for_each_element_array_helper<d - 1>(0, t, std::forward<Fn>(f), pos);
}
}
/** Innermost case (d < 0): call f with the coordinate array. */
template<int d,
typename Fn,
typename = typename std::enable_if<(d < 0)>::type>
HALIDE_ALWAYS_INLINE
static void for_each_element_array_helper(double, const for_each_element_task_dim *t, Fn &&f, int *pos) {
f(pos);
}
/** Array-based loop nest with a runtime depth. Depths 0 to 3 dispatch to
 * the compile-time helper; deeper levels loop here and recurse. */
template<typename Fn>
static void for_each_element_array(int d, const for_each_element_task_dim *t, Fn &&f, int *pos) {
if (d == -1) {
f(pos);
} else if (d == 0) {
for_each_element_array_helper<0, Fn>(0, t, std::forward<Fn>(f), pos);
} else if (d == 1) {
for_each_element_array_helper<1, Fn>(0, t, std::forward<Fn>(f), pos);
} else if (d == 2) {
for_each_element_array_helper<2, Fn>(0, t, std::forward<Fn>(f), pos);
} else if (d == 3) {
for_each_element_array_helper<3, Fn>(0, t, std::forward<Fn>(f), pos);
} else {
for (pos[d] = t[d].min; pos[d] <= t[d].max; pos[d]++) {
for_each_element_array(d - 1, t, std::forward<Fn>(f), pos);
}
}
}
/** Selected when f accepts a const int *: allocate a coordinate array with
 * HALIDE_ALLOCA and iterate over all dims dimensions. */
template<typename Fn,
typename = decltype(std::declval<Fn>()((const int *)nullptr))>
static void for_each_element(int, int dims, const for_each_element_task_dim *t, Fn &&f, int check = 0) {
int *pos = (int *)HALIDE_ALLOCA(dims * sizeof(int));
for_each_element_array(dims - 1, t, std::forward<Fn>(f), pos);
}
/** Fallback for callables taking individual int coordinates: find how many
 * ints f accepts and iterate over that many leading dimensions. */
template<typename Fn>
HALIDE_ALWAYS_INLINE
static void for_each_element(double, int dims, const for_each_element_task_dim *t, Fn &&f) {
int args = num_args(0, std::forward<Fn>(f));
assert(dims >= args);
for_each_element_variadic(0, args - 1, t, std::forward<Fn>(f));
}
public:
/** Call f with the coordinates of every element of the buffer. f may take
 * either a const int * pointing at the coordinate array, or individual int
 * coordinates. Dimension 0 is the innermost loop. */
template<typename Fn>
void for_each_element(Fn &&f) const {
// Record the inclusive coordinate range of every dimension.
for_each_element_task_dim *t =
(for_each_element_task_dim *)HALIDE_ALLOCA(dimensions() * sizeof(for_each_element_task_dim));
for (int i = 0; i < dimensions(); i++) {
t[i].min = dim(i).min();
t[i].max = dim(i).max();
}
for_each_element(0, dimensions(), t, std::forward<Fn>(f));
}
private:
/** Adapter used by fill(Fn): a callable that, for a set of coordinates,
 * evaluates f there and stores the result into *buf at those coordinates.
 * Its operator() is only viable for argument lists that f itself accepts. */
template<typename Fn>
struct FillHelper {
Fn f;
Buffer<T, D> *buf;
template<typename... Args,
typename = decltype(std::declval<Fn>()(std::declval<Args>()...))>
void operator()(Args... args) {
(*buf)(args...) = f(args...);
}
FillHelper(Fn &&f, Buffer<T, D> *buf) : f(std::forward<Fn>(f)), buf(buf) {}
};
public:
/** Set every element to the value of f evaluated at the element's
 * coordinates. Disabled for arithmetic arguments, so that fill(value)
 * selects the overload taking a value instead. */
template<typename Fn,
typename = typename std::enable_if<!std::is_arithmetic<typename std::decay<Fn>::type>::value>::type>
void fill(Fn &&f) {
FillHelper<Fn> wrapper(std::forward<Fn>(f), this);
for_each_element(wrapper);
}
};
}
}
#undef HALIDE_ALLOCA
#endif