This source file includes the following definitions:
- debug_print_buffer
- to_hex_char
- debug_print_key
- keys_equal
- buffer_has_shape
- get_pointer_to_header
- init
- destroy
- djb_hash
- validate_cache
- prune_cache
- halide_memoization_cache_set_size
- halide_memoization_cache_lookup
- halide_memoization_cache_store
- halide_memoization_cache_release
- halide_memoization_cache_cleanup
- halide_cache_cleanup
#include "HalideRuntime.h"
#include "device_buffer_utils.h"
#include "printer.h"
#include "scoped_mutex_lock.h"
namespace Halide { namespace Runtime { namespace Internal {
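// Set to 1 to enable verbose key/buffer dumps and cache consistency checks.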
#define CACHE_DEBUGGING 0
#if CACHE_DEBUGGING
WEAK void debug_print_buffer(void *user_context, const char *buf_name, const halide_buffer_t &buf) {
debug(user_context) << buf_name << ": elem_size " << buf.type.bytes() << " dimensions " << buf.dimensions << ", ";
for (int i = 0; i < buf.dimensions; i++) {
debug(user_context) << "(" << buf.dim[i].min
<< ", " << buf.dim[i].extent
<< ", " << buf.dim[i].stride << ") ";
}
debug(user_context) << "\n";
}
WEAK char to_hex_char(int val) {
if (val < 10) {
return '0' + val;
}
return 'A' + (val - 10);
}
WEAK void debug_print_key(void *user_context, const char *msg, const uint8_t *cache_key, int32_t key_size) {
debug(user_context) << "Key for " << msg << "\n";
char buf[1024];
bool append_ellipses = false;
if ((size_t)key_size > (sizeof(buf) / 2) - 1) {
append_ellipses = true;
key_size = (sizeof(buf) / 2) - 4;
}
char *buf_ptr = buf;
for (int i = 0; i < key_size; i++) {
if (cache_key[i] >= 32 && cache_key[i] <= '~') {
*buf_ptr++ = cache_key[i];
} else {
*buf_ptr++ = to_hex_char((cache_key[i] >> 4));
*buf_ptr++ = to_hex_char((cache_key[i] & 0xf));
}
}
if (append_ellipses) {
*buf_ptr++ = '.';
*buf_ptr++ = '.';
*buf_ptr++ = '.';
}
*buf_ptr++ = '\0';
debug(user_context) << buf << "\n";
}
#endif
WEAK bool keys_equal(const uint8_t *key1, const uint8_t *key2, size_t key_size) {
return memcmp(key1, key2, key_size) == 0;
}
WEAK bool buffer_has_shape(const halide_buffer_t *buf, const halide_dimension_t *shape) {
for (int i = 0; i < buf->dimensions; i++) {
if (buf->dim[i] != shape[i]) return false;
}
return true;
}
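// Each host allocation handed out by the cache is preceded by this many
// bytes, which hold a CacheBlockHeader (defined below). 16 bytes is assumed
// to be large enough for the header on all supported targets and to keep the
// host pointer suitably aligned.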
const size_t extra_bytes_host_bytes = 16;
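// A cache entry owns the metadata for one memoized realization. It is a node
// in a hash-bucket chain (next) and in a doubly linked LRU list
// (more_recent/less_recent), and records the key, the computed bounds, and
// one halide_buffer_t per tuple element.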
struct CacheEntry {
CacheEntry *next;
CacheEntry *more_recent;
CacheEntry *less_recent;
uint8_t *metadata_storage;
size_t key_size;
uint8_t *key;
uint32_t hash;
uint32_t in_use_count;
uint32_t tuple_count;
int32_t dimensions;
halide_dimension_t *computed_bounds;
halide_buffer_t *buf;
bool init(const uint8_t *cache_key, size_t cache_key_size,
uint32_t key_hash,
const halide_buffer_t *computed_bounds_buf,
int32_t tuples, halide_buffer_t **tuple_buffers);
void destroy();
halide_buffer_t &buffer(int32_t i);
};
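// Prepended to every host allocation returned by the cache. entry is NULL
// until the block has been successfully stored; hash caches the key hash so
// halide_memoization_cache_store does not need to recompute it.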
struct CacheBlockHeader {
CacheEntry *entry;
uint32_t hash;
};
WEAK CacheBlockHeader *get_pointer_to_header(uint8_t *host) {
return (CacheBlockHeader *)(host - extra_bytes_host_bytes);
}
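// Initialize an entry from a key and the buffers produced by a cache miss.
// All metadata lives in a single halide_malloc block laid out as:
//   [ tuple_count x halide_buffer_t ]
//   [ (tuple_count + 1) x dimensions x halide_dimension_t ]
//     (the computed bounds first, then the shape of each tuple buffer)
//   [ key_size bytes of key ]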
WEAK bool CacheEntry::init(const uint8_t *cache_key, size_t cache_key_size,
uint32_t key_hash, const halide_buffer_t *computed_bounds_buf,
int32_t tuples, halide_buffer_t **tuple_buffers) {
next = NULL;
more_recent = NULL;
less_recent = NULL;
key_size = cache_key_size;
hash = key_hash;
in_use_count = 0;
tuple_count = tuples;
dimensions = computed_bounds_buf->dimensions;
size_t storage_bytes = 0;
storage_bytes += sizeof(halide_buffer_t) * tuple_count;
size_t shape_offset = storage_bytes;
storage_bytes += sizeof(halide_dimension_t) * dimensions * (tuple_count + 1);
size_t key_offset = storage_bytes;
storage_bytes += key_size;
metadata_storage = (uint8_t *)halide_malloc(NULL, storage_bytes);
if (!metadata_storage) {
return false;
}
buf = (halide_buffer_t *)metadata_storage;
computed_bounds = (halide_dimension_t *)(metadata_storage + shape_offset);
key = metadata_storage + key_offset;
for (size_t i = 0; i < key_size; i++) {
key[i] = cache_key[i];
}
for (int i = 0; i < dimensions; i++) {
computed_bounds[i] = computed_bounds_buf->dim[i];
}
for (uint32_t i = 0; i < tuple_count; i++) {
buf[i] = *tuple_buffers[i];
buf[i].dim = computed_bounds + (i+1)*dimensions;
for (int j = 0; j < dimensions; j++) {
buf[i].dim[j] = tuple_buffers[i]->dim[j];
}
}
return true;
}
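// Release the device allocations, the host blocks (which begin
// extra_bytes_host_bytes before buf[i].host), and the metadata block.
// Unlinking the entry from the hash table and LRU list is the caller's job.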
WEAK void CacheEntry::destroy() {
for (uint32_t i = 0; i < tuple_count; i++) {
halide_device_free(NULL, &buf[i]);
halide_free(NULL, get_pointer_to_header(buf[i].host));
}
halide_free(NULL, metadata_storage);
}
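// Bernstein's djb2 hash over the raw key bytes.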
WEAK uint32_t djb_hash(const uint8_t *key, size_t key_size) {
uint32_t h = 5381;
for (size_t i = 0; i < key_size; i++) {
h = (h << 5) + h + key[i];
}
return h;
}
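// All of the state below is protected by memoization_lock: a fixed-size hash
// table of singly linked bucket chains, a doubly linked LRU list threaded
// through the entries, and byte counts used to decide when to evict. The
// default capacity is 1 MB.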
WEAK halide_mutex memoization_lock;
const size_t kHashTableSize = 256;
WEAK CacheEntry *cache_entries[kHashTableSize];
WEAK CacheEntry *most_recently_used = NULL;
WEAK CacheEntry *least_recently_used = NULL;
const uint64_t kDefaultCacheSize = 1 << 20;
WEAK int64_t max_cache_size = kDefaultCacheSize;
WEAK int64_t current_cache_size = 0;
#if CACHE_DEBUGGING
WEAK void validate_cache() {
print(NULL) << "validating cache, "
<< "current size " << current_cache_size
<< " of maximum " << max_cache_size << "\n";
int entries_in_hash_table = 0;
for (size_t i = 0; i < kHashTableSize; i++) {
CacheEntry *entry = cache_entries[i];
while (entry != NULL) {
entries_in_hash_table++;
if (entry->more_recent == NULL && entry != most_recently_used) {
halide_print(NULL, "cache invalid case 1\n");
__builtin_trap();
}
if (entry->less_recent == NULL && entry != least_recently_used) {
halide_print(NULL, "cache invalid case 2\n");
__builtin_trap();
}
entry = entry->next;
}
}
int entries_from_mru = 0;
CacheEntry *mru_chain = most_recently_used;
while (mru_chain != NULL) {
entries_from_mru++;
mru_chain = mru_chain->less_recent;
}
int entries_from_lru = 0;
CacheEntry *lru_chain = least_recently_used;
while (lru_chain != NULL) {
entries_from_lru++;
lru_chain = lru_chain->more_recent;
}
print(NULL) << "hash entries " << entries_in_hash_table
<< ", mru entries " << entries_from_mru
<< ", lru entries " << entries_from_lru << "\n";
if (entries_in_hash_table != entries_from_mru) {
halide_print(NULL, "cache invalid case 3\n");
__builtin_trap();
}
if (entries_in_hash_table != entries_from_lru) {
halide_print(NULL, "cache invalid case 4\n");
__builtin_trap();
}
if (current_cache_size < 0) {
halide_print(NULL, "cache size is negative\n");
__builtin_trap();
}
}
#endif
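// Evict entries from the least-recently-used end of the list until
// current_cache_size fits under max_cache_size. Entries whose results are
// still checked out (in_use_count != 0) are skipped. Must be called with
// memoization_lock held.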
WEAK void prune_cache() {
#if CACHE_DEBUGGING
validate_cache();
#endif
CacheEntry *prune_candidate = least_recently_used;
while (current_cache_size > max_cache_size &&
prune_candidate != NULL) {
CacheEntry *more_recent = prune_candidate->more_recent;
if (prune_candidate->in_use_count == 0) {
uint32_t h = prune_candidate->hash;
uint32_t index = h % kHashTableSize;
CacheEntry *prev_hash_entry = cache_entries[index];
if (prev_hash_entry == prune_candidate) {
cache_entries[index] = prune_candidate->next;
} else {
while (prev_hash_entry != NULL && prev_hash_entry->next != prune_candidate) {
prev_hash_entry = prev_hash_entry->next;
}
halide_assert(NULL, prev_hash_entry != NULL);
prev_hash_entry->next = prune_candidate->next;
}
if (least_recently_used == prune_candidate) {
least_recently_used = more_recent;
}
if (more_recent != NULL) {
more_recent->less_recent = prune_candidate->less_recent;
}
if (most_recently_used == prune_candidate) {
most_recently_used = prune_candidate->less_recent;
}
if (prune_candidate->less_recent != NULL) {
prune_candidate->less_recent->more_recent = more_recent;
}
for (uint32_t i = 0; i < prune_candidate->tuple_count; i++) {
current_cache_size -= prune_candidate->buf[i].size_in_bytes();
}
prune_candidate->destroy();
halide_free(NULL, prune_candidate);
}
prune_candidate = more_recent;
}
#if CACHE_DEBUGGING
validate_cache();
#endif
}
}}}  // namespace Halide::Runtime::Internal
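// Public runtime API, called from code generated for pipelines that use
// Func::memoize(). The rough call pattern per memoized realization is (a
// sketch of the protocol implied by this file, not the literal generated
// code):
//
//   if (halide_memoization_cache_lookup(ctx, key, key_size, &bounds,
//                                       tuple_count, bufs) == 1) {
//       // Miss: bufs[i]->host now points at freshly allocated storage.
//       /* ...compute the results into bufs... */
//       halide_memoization_cache_store(ctx, key, key_size, &bounds,
//                                      tuple_count, bufs);
//   }
//   /* ...use the results... */
//   // One release call per tuple buffer once the pipeline is done with it:
//   halide_memoization_cache_release(ctx, bufs[i]->host);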
extern "C" {
WEAK void halide_memoization_cache_set_size(int64_t size) {
if (size == 0) {
size = kDefaultCacheSize;
}
ScopedMutexLock lock(&memoization_lock);
max_cache_size = size;
prune_cache();
}
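// Returns 0 on a hit (the tuple buffers are filled in from the cached entry
// and its in_use_count is bumped), 1 on a miss (host storage is allocated
// for each tuple buffer so the caller can compute and then store the
// results), and -1 if that allocation fails.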
WEAK int halide_memoization_cache_lookup(void *user_context, const uint8_t *cache_key, int32_t size,
halide_buffer_t *computed_bounds, int32_t tuple_count, halide_buffer_t **tuple_buffers) {
uint32_t h = djb_hash(cache_key, size);
uint32_t index = h % kHashTableSize;
ScopedMutexLock lock(&memoization_lock);
#if CACHE_DEBUGGING
debug_print_key(user_context, "halide_memoization_cache_lookup", cache_key, size);
debug_print_buffer(user_context, "computed_bounds", *computed_bounds);
{
for (int32_t i = 0; i < tuple_count; i++) {
halide_buffer_t *buf = tuple_buffers[i];
debug_print_buffer(user_context, "Allocation bounds", *buf);
}
}
#endif
CacheEntry *entry = cache_entries[index];
while (entry != NULL) {
if (entry->hash == h && entry->key_size == (size_t)size &&
keys_equal(entry->key, cache_key, size) &&
buffer_has_shape(computed_bounds, entry->computed_bounds) &&
entry->tuple_count == (uint32_t)tuple_count) {
bool all_bounds_equal = true;
for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
}
if (all_bounds_equal) {
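// Hit: move the entry to the most-recently-used end of the LRU list.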
if (entry != most_recently_used) {
halide_assert(user_context, entry->more_recent != NULL);
if (entry->less_recent != NULL) {
entry->less_recent->more_recent = entry->more_recent;
} else {
halide_assert(user_context, least_recently_used == entry);
least_recently_used = entry->more_recent;
}
halide_assert(user_context, entry->more_recent != NULL);
entry->more_recent->less_recent = entry->less_recent;
entry->more_recent = NULL;
entry->less_recent = most_recently_used;
if (most_recently_used != NULL) {
most_recently_used->more_recent = entry;
}
most_recently_used = entry;
}
for (int32_t i = 0; i < tuple_count; i++) {
halide_buffer_t *buf = tuple_buffers[i];
*buf = entry->buf[i];
}
entry->in_use_count += tuple_count;
return 0;
}
}
entry = entry->next;
}
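// Miss: allocate host storage for each tuple buffer, leaving room for a
// CacheBlockHeader in front of the host pointer and recording the key hash
// so a later store does not have to rehash.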
for (int32_t i = 0; i < tuple_count; i++) {
halide_buffer_t *buf = tuple_buffers[i];
buf->host = ((uint8_t *)halide_malloc(user_context, buf->size_in_bytes() + extra_bytes_host_bytes));
if (buf->host == NULL) {
for (int32_t j = i; j > 0; j--) {
halide_free(user_context, get_pointer_to_header(tuple_buffers[j - 1]->host));
tuple_buffers[j - 1]->host = NULL;
}
return -1;
}
buf->host += extra_bytes_host_bytes;
CacheBlockHeader *header = get_pointer_to_header(buf->host);
header->hash = h;
header->entry = NULL;
}
#if CACHE_DEBUGGING
validate_cache();
#endif
return 1;
}
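// Store the results of a computation that followed a cache miss. On success
// the cache takes ownership of the host allocations made by lookup. If an
// equivalent entry was stored in the meantime (e.g. by another thread), or
// if allocating the new entry fails, the blocks' headers keep a NULL entry
// so that halide_memoization_cache_release frees them; either way this call
// reports success, since failing to cache a result is not an error.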
WEAK int halide_memoization_cache_store(void *user_context, const uint8_t *cache_key, int32_t size,
halide_buffer_t *computed_bounds,
int32_t tuple_count, halide_buffer_t **tuple_buffers) {
debug(user_context) << "halide_memoization_cache_store\n";
uint32_t h = get_pointer_to_header(tuple_buffers[0]->host)->hash;
uint32_t index = h % kHashTableSize;
ScopedMutexLock lock(&memoization_lock);
#if CACHE_DEBUGGING
debug_print_key(user_context, "halide_memoization_cache_store", cache_key, size);
debug_print_buffer(user_context, "computed_bounds", *computed_bounds);
{
for (int32_t i = 0; i < tuple_count; i++) {
halide_buffer_t *buf = tuple_buffers[i];
debug_print_buffer(user_context, "Allocation bounds", *buf);
}
}
#endif
CacheEntry *entry = cache_entries[index];
while (entry != NULL) {
if (entry->hash == h && entry->key_size == (size_t)size &&
keys_equal(entry->key, cache_key, size) &&
buffer_has_shape(computed_bounds, entry->computed_bounds) &&
entry->tuple_count == (uint32_t)tuple_count) {
bool all_bounds_equal = true;
bool no_host_pointers_equal = true;
{
for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
halide_buffer_t *buf = tuple_buffers[i];
all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
if (entry->buf[i].host == buf->host) {
no_host_pointers_equal = false;
}
}
}
if (all_bounds_equal) {
halide_assert(user_context, no_host_pointers_equal);
for (int32_t i = 0; i < tuple_count; i++) {
get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
}
return 0;
}
}
entry = entry->next;
}
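// Account for the new blocks before inserting, and evict older entries if
// the cache would exceed its limit.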
uint64_t added_size = 0;
{
for (int32_t i = 0; i < tuple_count; i++) {
halide_buffer_t *buf = tuple_buffers[i];
added_size += buf->size_in_bytes();
}
}
current_cache_size += added_size;
prune_cache();
CacheEntry *new_entry = (CacheEntry *)halide_malloc(NULL, sizeof(CacheEntry));
bool inited = false;
if (new_entry) {
inited = new_entry->init(cache_key, size, h, computed_bounds, tuple_count, tuple_buffers);
}
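// If the entry could not be allocated or initialized, leave the blocks owned
// by the caller (their headers keep a NULL entry so release will free them)
// and undo the size accounting.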
if (!inited) {
current_cache_size -= added_size;
for (int32_t i = 0; i < tuple_count; i++) {
get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
}
if (new_entry) {
halide_free(user_context, new_entry);
}
return 0;
}
new_entry->next = cache_entries[index];
new_entry->less_recent = most_recently_used;
if (most_recently_used != NULL) {
most_recently_used->more_recent = new_entry;
}
most_recently_used = new_entry;
if (least_recently_used == NULL) {
least_recently_used = new_entry;
}
cache_entries[index] = new_entry;
new_entry->in_use_count = tuple_count;
for (int32_t i = 0; i < tuple_count; i++) {
get_pointer_to_header(tuple_buffers[i]->host)->entry = new_entry;
}
#if CACHE_DEBUGGING
validate_cache();
#endif
debug(user_context) << "Exiting halide_memoization_cache_store\n";
return 0;
}
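// Called once per tuple buffer host pointer when the pipeline is finished
// with it. Blocks that were never stored (entry == NULL) are freed outright;
// otherwise the owning entry's in_use_count is decremented so pruning can
// reclaim it later.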
WEAK void halide_memoization_cache_release(void *user_context, void *host) {
CacheBlockHeader *header = get_pointer_to_header((uint8_t *)host);
debug(user_context) << "halide_memoization_cache_release\n";
CacheEntry *entry = header->entry;
if (entry == NULL) {
halide_free(user_context, header);
} else {
ScopedMutexLock lock(&memoization_lock);
halide_assert(user_context, entry->in_use_count > 0);
entry->in_use_count--;
#if CACHE_DEBUGGING
validate_cache();
#endif
}
debug(user_context) << "Exited halide_memoization_cache_release.\n";
}
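// Free every cache entry and destroy the lock. Run automatically at program
// exit via the destructor-attributed halide_cache_cleanup below.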
WEAK void halide_memoization_cache_cleanup() {
debug(NULL) << "halide_memoization_cache_cleanup\n";
for (size_t i = 0; i < kHashTableSize; i++) {
CacheEntry *entry = cache_entries[i];
cache_entries[i] = NULL;
while (entry != NULL) {
CacheEntry *next = entry->next;
entry->destroy();
halide_free(NULL, entry);
entry = next;
}
}
current_cache_size = 0;
most_recently_used = NULL;
least_recently_used = NULL;
halide_mutex_destroy(&memoization_lock);
}
namespace {
__attribute__((destructor))
WEAK void halide_cache_cleanup() {
halide_memoization_cache_cleanup();
}
}  // namespace
}  // extern "C"