root/src/runtime/cache.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. debug_print_buffer
  2. to_hex_char
  3. debug_print_key
  4. keys_equal
  5. buffer_has_shape
  6. get_pointer_to_header
  7. init
  8. destroy
  9. djb_hash
  10. validate_cache
  11. prune_cache
  12. halide_memoization_cache_set_size
  13. halide_memoization_cache_lookup
  14. halide_memoization_cache_store
  15. halide_memoization_cache_release
  16. halide_memoization_cache_cleanup
  17. halide_cache_cleanup

#include "HalideRuntime.h"
#include "device_buffer_utils.h"
#include "printer.h"
#include "scoped_mutex_lock.h"

// This is temporary code. In particular, the hash table is simplistic (a
// fixed number of buckets with chained entries) and thread safety is
// currently accomplished via a single coarse-grained lock. It is mainly
// intended to prove the programming model and runtime interface for
// memoization. We'll improve the implementation later. In the meantime,
// on some platforms it can be replaced by a platform-specific LRU cache
// such as Apple's libcache.

namespace Halide { namespace Runtime { namespace Internal {

#define CACHE_DEBUGGING 0

#if CACHE_DEBUGGING
// Debug helper: prints a buffer's element size, dimension count and the
// (min, extent, stride) triple of every dimension on one line.
WEAK void debug_print_buffer(void *user_context, const char *buf_name, const halide_buffer_t &buf) {
    debug(user_context) << buf_name << ": elem_size " << buf.type.bytes() << " dimensions " << buf.dimensions << ", ";
    for (int d = 0; d < buf.dimensions; d++) {
        const halide_dimension_t &dim = buf.dim[d];
        debug(user_context) << "(" << dim.min
                            << ", " << dim.extent
                            << ", " << dim.stride << ") ";
    }
    debug(user_context) << "\n";
}

// Maps a nibble value (callers pass 0-15) to its uppercase hexadecimal digit.
WEAK char to_hex_char(int val) {
    const char base = (val < 10) ? '0' : ('A' - 10);
    return base + val;
}

// Debug helper: prints a cache key, emitting printable ASCII bytes as-is and
// every other byte as two uppercase hex digits. Keys too long for the local
// buffer are truncated and suffixed with "...".
WEAK void debug_print_key(void *user_context, const char *msg, const uint8_t *cache_key, int32_t key_size) {
    debug(user_context) << "Key for " << msg << "\n";
    char buf[1024];
    // A key byte may expand to two output characters, so cap the number of
    // key bytes at half the buffer, minus room for the terminator.
    const size_t max_key_bytes = (sizeof(buf) / 2) - 1;
    const bool truncated = (size_t)key_size > max_key_bytes;
    if (truncated) {
        key_size = (sizeof(buf) / 2) - 4; // leave space for "..." and the NUL
    }
    char *out = buf;
    for (int i = 0; i < key_size; i++) {
        const uint8_t c = cache_key[i];
        const bool printable = c >= 32 && c <= '~';
        if (printable) {
            *out++ = c;
        } else {
            *out++ = to_hex_char(c >> 4);
            *out++ = to_hex_char(c & 0xf);
        }
    }
    if (truncated) {
        for (int dot = 0; dot < 3; dot++) {
            *out++ = '.';
        }
    }
    *out = '\0';
    debug(user_context) << buf << "\n";
}
#endif

// Returns true when the first key_size bytes of key1 and key2 are identical.
WEAK bool keys_equal(const uint8_t *key1, const uint8_t *key2, size_t key_size) {
    const int difference = memcmp(key1, key2, key_size);
    return !difference;
}


// Returns true when every one of buf's dimensions equals the corresponding
// entry of `shape`. `shape` must have at least buf->dimensions entries.
WEAK bool buffer_has_shape(const halide_buffer_t *buf, const halide_dimension_t *shape) {
    int d = 0;
    while (d < buf->dimensions) {
        if (buf->dim[d] != shape[d]) {
            return false;
        }
        ++d;
    }
    return true;
}

// Every host allocation made by halide_memoization_cache_lookup reserves
// extra_bytes_host_bytes bytes immediately before the pointer placed in
// buf->host. That prefix holds a CacheBlockHeader (the key hash and the
// owning CacheEntry, if any), recovered with get_pointer_to_header().
//
// 16 bytes is presumably chosen so the data pointer keeps the alignment
// halide_malloc provides (assumed to be at least 16) — TODO confirm.
//
// Keeping the header inline lets store and release find a buffer's hash and
// cache entry directly from its host pointer instead of searching the cache.
const size_t extra_bytes_host_bytes = 16;

// One cached result: the key it was stored under, the bounds that were
// computed, and a halide_buffer_t per tuple element. Entries are chained into
// a hash bucket through `next` and into a doubly linked recency list through
// more_recent / less_recent.
struct CacheEntry {
    // Next entry in the same hash bucket.
    CacheEntry *next;
    // Neighbour toward most_recently_used (NULL at the most-recent end).
    CacheEntry *more_recent;
    // Neighbour toward least_recently_used (NULL at the least-recent end).
    CacheEntry *less_recent;
    // Single allocation backing buf, computed_bounds and key (see init()).
    uint8_t *metadata_storage;
    size_t key_size;
    uint8_t *key;
    uint32_t hash;
    // Outstanding references handed out by halide_memoization_cache_lookup /
    // halide_memoization_cache_store, counted once per tuple buffer and dropped
    // by halide_memoization_cache_release. Entries with a nonzero count are
    // skipped by prune_cache.
    uint32_t in_use_count;
    uint32_t tuple_count;
    // The shape of the computed data. There may be more data allocated than this.
    // `dimensions` is shared by computed_bounds and every buf[i].dim.
    int32_t dimensions;
    halide_dimension_t *computed_bounds;
    // The actual stored data: tuple_count buffers, each with its own copy of
    // its allocated shape.
    halide_buffer_t *buf;

    bool init(const uint8_t *cache_key, size_t cache_key_size,
              uint32_t key_hash,
              const halide_buffer_t *computed_bounds_buf,
              int32_t tuples, halide_buffer_t **tuple_buffers);
    void destroy();
    // NOTE(review): declared but no definition is visible in this file;
    // possibly unused — confirm before relying on it.
    halide_buffer_t &buffer(int32_t i);

};

// Header stored in the extra_bytes_host_bytes immediately preceding each
// host block allocated by halide_memoization_cache_lookup.
struct CacheBlockHeader {
    // Owning cache entry, or NULL when the block is not owned by the cache;
    // halide_memoization_cache_release frees the block itself in that case.
    CacheEntry *entry;
    // Hash of the cache key, computed in lookup and read back by store.
    uint32_t hash;
};

// Returns the CacheBlockHeader stored immediately before a cache-managed host
// pointer, i.e. `host` must be a halide_malloc result advanced by
// extra_bytes_host_bytes, as produced by halide_memoization_cache_lookup.
WEAK CacheBlockHeader *get_pointer_to_header(uint8_t * host) {
    // The header occupies the reserved prefix; if it ever grew past that
    // prefix, writing it would overlap the start of the buffer's data.
    static_assert(sizeof(CacheBlockHeader) <= extra_bytes_host_bytes,
                  "CacheBlockHeader does not fit in extra_bytes_host_bytes");
    return (CacheBlockHeader *)(host - extra_bytes_host_bytes);
}

// Initializes a freshly allocated entry for `cache_key`.
//
// Copies the key, the computed bounds and every tuple buffer's
// halide_buffer_t (including its shape) into a single halide_malloc'd block
// (metadata_storage). The buffers' host/device storage is not copied: the
// entry keeps the same pointers the caller's buffers had. The entry's list
// links are cleared; the caller links it into the cache.
//
// Returns false if the metadata allocation fails.
WEAK bool CacheEntry::init(const uint8_t *cache_key, size_t cache_key_size,
                           uint32_t key_hash, const halide_buffer_t *computed_bounds_buf,
                           int32_t tuples, halide_buffer_t **tuple_buffers) {
    next = NULL;
    more_recent = NULL;
    less_recent = NULL;
    key_size = cache_key_size;
    hash = key_hash;
    in_use_count = 0;
    tuple_count = tuples;
    dimensions = computed_bounds_buf->dimensions;

    // Layout of metadata_storage:
    //   [tuple_count x halide_buffer_t]
    //   [(tuple_count + 1) x dimensions x halide_dimension_t]  computed shape first,
    //                                                          then one shape per tuple buffer
    //   [key_size bytes of key]
    const size_t buffers_bytes = sizeof(halide_buffer_t) * tuple_count;
    const size_t shapes_bytes = sizeof(halide_dimension_t) * dimensions * (tuple_count + 1);
    const size_t shape_offset = buffers_bytes;
    const size_t key_offset = shape_offset + shapes_bytes;
    const size_t storage_bytes = key_offset + key_size;

    metadata_storage = (uint8_t *)halide_malloc(NULL, storage_bytes);
    if (metadata_storage == NULL) {
        return false;
    }

    buf = (halide_buffer_t *)metadata_storage;
    computed_bounds = (halide_dimension_t *)(metadata_storage + shape_offset);
    key = metadata_storage + key_offset;

    for (size_t k = 0; k < key_size; k++) {
        key[k] = cache_key[k];
    }

    for (int d = 0; d < dimensions; d++) {
        computed_bounds[d] = computed_bounds_buf->dim[d];
    }

    // Each tuple buffer's shape lives in its own slice right after the computed shape.
    for (uint32_t t = 0; t < tuple_count; t++) {
        const halide_buffer_t *src = tuple_buffers[t];
        halide_buffer_t &dst = buf[t];
        dst = *src;
        dst.dim = computed_bounds + (t + 1) * dimensions;
        for (int d = 0; d < dimensions; d++) {
            dst.dim[d] = src->dim[d];
        }
    }
    return true;
}

// Releases everything the entry owns: each tuple buffer's device allocation,
// each host block (which begins at its CacheBlockHeader), and finally the
// shared metadata block. Does not free the CacheEntry object itself.
WEAK void CacheEntry::destroy() {
    for (uint32_t t = 0; t < tuple_count; t++) {
        halide_buffer_t *b = &buf[t];
        halide_device_free(NULL, b);
        halide_free(NULL, get_pointer_to_header(b->host));
    }
    halide_free(NULL, metadata_storage);
}

// Bernstein (djb2) hash of key_size bytes: h = h * 33 + byte, seeded with
// 5381, wrapping modulo 2^32.
WEAK uint32_t djb_hash(const uint8_t *key, size_t key_size)  {
    uint32_t h = 5381;
    const uint8_t *const end = key + key_size;
    for (const uint8_t *p = key; p != end; ++p) {
        h = h * 33 + *p;
    }
    return h;
}

// Protects the cache state below. lookup, store, release and set_size take
// it; halide_memoization_cache_cleanup does not (it destroys it instead).
WEAK halide_mutex memoization_lock;

// Number of hash buckets; an entry lives in bucket (hash % kHashTableSize).
const size_t kHashTableSize = 256;

// Heads of the per-bucket chains, linked through CacheEntry::next.
WEAK CacheEntry *cache_entries[kHashTableSize];

// Ends of the doubly linked recency list. prune_cache evicts starting from
// least_recently_used; lookup hits and new stores go to most_recently_used.
WEAK CacheEntry *most_recently_used = NULL;
WEAK CacheEntry *least_recently_used = NULL;

// Cache budget in bytes: current_cache_size is the sum of size_in_bytes() of
// the cached buffers. Passing 0 to halide_memoization_cache_set_size selects
// kDefaultCacheSize (1 MiB).
const uint64_t kDefaultCacheSize = 1 << 20;
WEAK int64_t max_cache_size = kDefaultCacheSize;
WEAK int64_t current_cache_size = 0;

#if CACHE_DEBUGGING
// Debug-only consistency check of the hash table against the recency list.
// Traps if an entry without a recency neighbour is not the matching list end,
// if the number of hashed entries differs from the list length counted from
// either end, or if the tracked cache size has gone negative.
WEAK void validate_cache() {
    print(NULL) << "validating cache, "
                << "current size " << current_cache_size
                << " of maximum " << max_cache_size << "\n";

    int hashed = 0;
    for (size_t bucket = 0; bucket < kHashTableSize; bucket++) {
        for (CacheEntry *e = cache_entries[bucket]; e != NULL; e = e->next) {
            hashed++;
            if (e->more_recent == NULL && e != most_recently_used) {
                halide_print(NULL, "cache invalid case 1\n");
                __builtin_trap();
            }
            if (e->less_recent == NULL && e != least_recently_used) {
                halide_print(NULL, "cache invalid case 2\n");
                __builtin_trap();
            }
        }
    }

    // Length of the recency list, walked from each end.
    int from_mru = 0;
    for (CacheEntry *e = most_recently_used; e != NULL; e = e->less_recent) {
        from_mru++;
    }
    int from_lru = 0;
    for (CacheEntry *e = least_recently_used; e != NULL; e = e->more_recent) {
        from_lru++;
    }

    print(NULL) << "hash entries " << hashed
                << ", mru entries " << from_mru
                << ", lru entries " << from_lru << "\n";
    if (hashed != from_mru) {
        halide_print(NULL, "cache invalid case 3\n");
        __builtin_trap();
    }
    if (hashed != from_lru) {
        halide_print(NULL, "cache invalid case 4\n");
        __builtin_trap();
    }
    if (current_cache_size < 0) {
        halide_print(NULL, "cache size is negative\n");
        __builtin_trap();
    }
}
#endif

// Evicts cache entries, least recently used first, until current_cache_size
// no longer exceeds max_cache_size or no candidates remain. Entries whose
// in_use_count is nonzero are skipped and stay cached.
//
// Does not lock; both callers in this file (set_size and store) already hold
// memoization_lock when calling it.
WEAK void prune_cache() {
#if CACHE_DEBUGGING
    validate_cache();
#endif
    CacheEntry *prune_candidate = least_recently_used;
    while (current_cache_size > max_cache_size &&
           prune_candidate != NULL) {
        // Grab the next candidate before prune_candidate may be freed.
        CacheEntry *more_recent = prune_candidate->more_recent;

        if (prune_candidate->in_use_count == 0) {
            CacheEntry *less_recent = prune_candidate->less_recent;
            uint32_t h = prune_candidate->hash;
            uint32_t index = h % kHashTableSize;

            // Unlink from the hash bucket's chain.
            CacheEntry *prev_hash_entry = cache_entries[index];
            if (prev_hash_entry == prune_candidate) {
                cache_entries[index] = prune_candidate->next;
            } else {
                while (prev_hash_entry != NULL && prev_hash_entry->next != prune_candidate) {
                    prev_hash_entry = prev_hash_entry->next;
                }
                halide_assert(NULL, prev_hash_entry != NULL);
                prev_hash_entry->next = prune_candidate->next;
            }

            // Unlink from the recency list. First, the more recent neighbour
            // (or the least-recently-used end) must skip back over this entry.
            if (least_recently_used == prune_candidate) {
                least_recently_used = more_recent;
            }
            if (more_recent != NULL) {
                more_recent->less_recent = less_recent;
            }

            // Then the less recent neighbour (or the most-recently-used end)
            // must skip forward over it. This must update the neighbour, not
            // prune_candidate: otherwise the neighbour (e.g. an in-use entry
            // skipped earlier in this walk) would keep a dangling more_recent
            // pointer to the entry freed below.
            if (most_recently_used == prune_candidate) {
                most_recently_used = less_recent;
            }
            if (less_recent != NULL) {
                less_recent->more_recent = more_recent;
            }

            // Decrease cache used amount.
            for (uint32_t i = 0; i < prune_candidate->tuple_count; i++) {
                current_cache_size -= prune_candidate->buf[i].size_in_bytes();
            }

            // Deallocate the entry.
            prune_candidate->destroy();
            halide_free(NULL, prune_candidate);
        }

        prune_candidate = more_recent;
    }
#if CACHE_DEBUGGING
    validate_cache();
#endif
}

}}} // namespace Halide::Runtime::Internal

extern "C" {

// Sets the cache budget in bytes (0 selects kDefaultCacheSize) and immediately
// prunes entries that no longer fit, under memoization_lock.
WEAK void halide_memoization_cache_set_size(int64_t size) {
    const int64_t new_max = (size == 0) ? (int64_t)kDefaultCacheSize : size;

    ScopedMutexLock lock(&memoization_lock);

    max_cache_size = new_max;
    prune_cache();
}

// Looks up a memoized result for `cache_key` (`size` bytes).
//
// A hit requires a matching hash, key length and bytes, computed bounds and
// tuple count, and every tuple buffer's shape must equal the cached
// allocation's shape.
//
// Returns:
//   0  on a hit: each tuple_buffers[i] is overwritten with the cached
//      halide_buffer_t, the entry moves to the most-recently-used end of the
//      recency list, and its in_use_count grows by tuple_count (one reference
//      per buffer, dropped by halide_memoization_cache_release).
//   1  on a miss: host storage is allocated for each tuple buffer, prefixed by
//      a CacheBlockHeader holding the key hash (read back later by
//      halide_memoization_cache_store) and no owning entry.
//  -1  if a host allocation fails; blocks already allocated by this call are
//      freed and their host pointers reset to NULL.
WEAK int halide_memoization_cache_lookup(void *user_context, const uint8_t *cache_key, int32_t size,
                                         halide_buffer_t *computed_bounds, int32_t tuple_count, halide_buffer_t **tuple_buffers) {
    uint32_t h = djb_hash(cache_key, size);
    uint32_t index = h % kHashTableSize;

    ScopedMutexLock lock(&memoization_lock);

#if CACHE_DEBUGGING
    debug_print_key(user_context, "halide_memoization_cache_lookup", cache_key, size);

    debug_print_buffer(user_context, "computed_bounds", *computed_bounds);

    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            debug_print_buffer(user_context, "Allocation bounds", *buf);
        }
    }
#endif

    CacheEntry *entry = cache_entries[index];
    while (entry != NULL) {
        if (entry->hash == h && entry->key_size == (size_t)size &&
            keys_equal(entry->key, cache_key, size) &&
            buffer_has_shape(computed_bounds, entry->computed_bounds) &&
            entry->tuple_count == (uint32_t)tuple_count) {

            // Check all the tuple buffers have the same bounds (they should).
            bool all_bounds_equal = true;
            for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
                all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
            }

            if (all_bounds_equal) {
                // Move the entry to the most-recently-used end: unlink it from
                // its current position, then push it in front of the old MRU.
                if (entry != most_recently_used) {
                    halide_assert(user_context, entry->more_recent != NULL);
                    if (entry->less_recent != NULL) {
                        entry->less_recent->more_recent = entry->more_recent;
                    } else {
                        halide_assert(user_context, least_recently_used == entry);
                        least_recently_used = entry->more_recent;
                    }
                    halide_assert(user_context, entry->more_recent != NULL);
                    entry->more_recent->less_recent = entry->less_recent;

                    entry->more_recent = NULL;
                    entry->less_recent = most_recently_used;
                    if (most_recently_used != NULL) {
                        most_recently_used->more_recent = entry;
                    }
                    most_recently_used = entry;
                }

                // NOTE(review): this whole-struct copy also repoints buf->dim
                // into the entry's metadata_storage; the entry stays alive while
                // in_use_count > 0 — confirm callers don't use buf->dim after release.
                for (int32_t i = 0; i < tuple_count; i++) {
                    halide_buffer_t *buf = tuple_buffers[i];
                    *buf = entry->buf[i];
                }

                // One reference per returned buffer; each is dropped by release.
                entry->in_use_count += tuple_count;

                return 0;
            }
        }
        entry = entry->next;
    }

    // Miss: allocate host storage for each buffer, with a header prefix.
    for (int32_t i = 0; i < tuple_count; i++) {
        halide_buffer_t *buf = tuple_buffers[i];

        // See documentation on extra_bytes_host_bytes
        buf->host = ((uint8_t *)halide_malloc(user_context, buf->size_in_bytes() + extra_bytes_host_bytes));
        if (buf->host == NULL) {
            // Undo the allocations made for buffers 0..i-1.
            for (int32_t j = i; j > 0; j--) {
                halide_free(user_context, get_pointer_to_header(tuple_buffers[j - 1]->host));
                tuple_buffers[j - 1]->host = NULL;
            }
            return -1;
        }
        buf->host += extra_bytes_host_bytes;
        CacheBlockHeader *header = get_pointer_to_header(buf->host);
        header->hash = h;
        header->entry = NULL;
    }

#if CACHE_DEBUGGING
    validate_cache();
#endif

    return 1;
}

// Inserts the results computed after a cache miss so later lookups can reuse them.
//
// tuple_buffers must hold host blocks allocated by halide_memoization_cache_lookup:
// the key hash is read back from tuple_buffers[0]'s CacheBlockHeader rather than
// recomputed from cache_key. On success the buffers are owned by a new CacheEntry
// whose in_use_count starts at tuple_count, so each buffer must still be passed to
// halide_memoization_cache_release. If an equivalent entry already exists, or the
// entry cannot be allocated, each buffer's header->entry is set to NULL so that
// release frees the block instead.
//
// Always returns 0; failing to cache is not reported as an error.
WEAK int halide_memoization_cache_store(void *user_context, const uint8_t *cache_key, int32_t size,
                                        halide_buffer_t *computed_bounds,
                                        int32_t tuple_count, halide_buffer_t **tuple_buffers) {
    debug(user_context) << "halide_memoization_cache_store\n";

    uint32_t h = get_pointer_to_header(tuple_buffers[0]->host)->hash;

    uint32_t index = h % kHashTableSize;

    ScopedMutexLock lock(&memoization_lock);

#if CACHE_DEBUGGING
    debug_print_key(user_context, "halide_memoization_cache_store", cache_key, size);

    debug_print_buffer(user_context, "computed_bounds", *computed_bounds);

    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            debug_print_buffer(user_context, "Allocation bounds", *buf);
        }
    }
#endif

    CacheEntry *entry = cache_entries[index];
    while (entry != NULL) {
        if (entry->hash == h && entry->key_size == (size_t)size &&
            keys_equal(entry->key, cache_key, size) &&
            buffer_has_shape(computed_bounds, entry->computed_bounds) &&
            entry->tuple_count == (uint32_t)tuple_count) {

            bool all_bounds_equal = true;
            bool no_host_pointers_equal = true;
            {
                for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
                    halide_buffer_t *buf = tuple_buffers[i];
                    all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
                    if (entry->buf[i].host == buf->host) {
                        no_host_pointers_equal = false;
                    }
                }
            }
            if (all_bounds_equal) {
                halide_assert(user_context, no_host_pointers_equal);
                // This entry is still in use by the caller. Mark it as having no cache entry
                // so halide_memoization_cache_release can free the buffer.
                for (int32_t i = 0; i < tuple_count; i++) {
                    get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;

                }
                return 0;
            }
        }
        entry = entry->next;
    }

    // Account for the new data first so pruning makes room for it.
    uint64_t added_size = 0;
    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            added_size += buf->size_in_bytes();
        }
    }
    current_cache_size += added_size;
    prune_cache();

    CacheEntry *new_entry = (CacheEntry *)halide_malloc(NULL, sizeof(CacheEntry));
    bool inited = false;
    if (new_entry) {
        inited = new_entry->init(cache_key, size, h, computed_bounds, tuple_count, tuple_buffers);
    }
    if (!inited) {
        current_cache_size -= added_size;

        // This entry is still in use by the caller. Mark it as having no cache entry
        // so halide_memoization_cache_release can free the buffer.
        for (int32_t i = 0; i < tuple_count; i++) {
            get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
        }

        if (new_entry) {
            // Free with the same (NULL) user_context the entry was allocated with,
            // matching how CacheEntry objects are freed everywhere else.
            halide_free(NULL, new_entry);
        }
        return 0;
    }

    // Link the new entry into its hash bucket and at the most-recently-used end.
    new_entry->next = cache_entries[index];
    new_entry->less_recent = most_recently_used;
    if (most_recently_used != NULL) {
        most_recently_used->more_recent = new_entry;
    }
    most_recently_used = new_entry;
    if (least_recently_used == NULL) {
        least_recently_used = new_entry;
    }
    cache_entries[index] = new_entry;

    // The caller still holds every buffer; each release drops one reference.
    new_entry->in_use_count = tuple_count;

    for (int32_t i = 0; i < tuple_count; i++) {
        get_pointer_to_header(tuple_buffers[i]->host)->entry = new_entry;
    }

#if CACHE_DEBUGGING
    validate_cache();
#endif
    debug(user_context) << "Exiting halide_memoization_cache_store\n";

    return 0;
}

// Drops one reference to a host block returned by lookup/store.
// If the block belongs to a cache entry, that entry's in_use_count is
// decremented under memoization_lock (the block stays cached). Otherwise the
// block is not owned by the cache and is freed here, header included.
WEAK void halide_memoization_cache_release(void *user_context, void *host) {
    CacheBlockHeader *header = get_pointer_to_header((uint8_t *)host);
    debug(user_context) << "halide_memoization_cache_release\n";
    CacheEntry *owner = header->entry;

    if (owner != NULL) {
        ScopedMutexLock lock(&memoization_lock);

        halide_assert(user_context, owner->in_use_count > 0);
        owner->in_use_count--;
#if CACHE_DEBUGGING
        validate_cache();
#endif
    } else {
        halide_free(user_context, header);
    }

    debug(user_context) << "Exited halide_memoization_cache_release.\n";
}

// Frees every cached entry (regardless of in_use_count), empties the hash
// table and recency list, resets the size accounting and destroys
// memoization_lock. Does not take the lock itself.
WEAK void halide_memoization_cache_cleanup() {
    debug(NULL) << "halide_memoization_cache_cleanup\n";
    for (size_t bucket = 0; bucket < kHashTableSize; bucket++) {
        CacheEntry *victim = cache_entries[bucket];
        cache_entries[bucket] = NULL;
        for (CacheEntry *following = NULL; victim != NULL; victim = following) {
            following = victim->next;
            victim->destroy();
            halide_free(NULL, victim);
        }
    }
    most_recently_used = NULL;
    least_recently_used = NULL;
    current_cache_size = 0;
    halide_mutex_destroy(&memoization_lock);
}

// Runs halide_memoization_cache_cleanup() automatically at unload / process
// exit, via the compiler's `destructor` function attribute.
namespace {

__attribute__((destructor))
WEAK void halide_cache_cleanup() {
    halide_memoization_cache_cleanup();
}

}

}

/* [<][>][^][v][top][bottom][index][help] */