// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "xnnpack/cache.h"

#include <assert.h> // For assert.
#include <stddef.h> // For size_t.
#include <stdint.h> // For uint32_t.
#include <string.h> // For memcmp and memset.

#include "xnnpack.h"
#include "xnnpack/allocator.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/mutex.h"

#define XNN_CACHE_HASH_SEED 7
#define XNN_CACHE_INITIAL_BUCKETS 32
#define XNN_CACHE_MAX_LOAD 0.75
// The maximum load factor is 0.75 (3/4): the table is grown once
// num_entries / num_buckets exceeds 3/4, using the two integer multipliers
// below to avoid floating-point arithmetic.
#define XNN_CACHE_MAX_LOAD_ENTRIES_MULTIPLIER 4
#define XNN_CACHE_MAX_LOAD_BUCKETS_MULTIPLIER 3
#define XNN_CACHE_GROWTH_FACTOR 2
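
// Illustrative arithmetic (not used by the code): with the initial 32 buckets,
// growth is triggered on the first insertion attempted while 25 entries are
// already present, since 25 * 4 = 100 > 32 * 3 = 96 (a load factor of
// 25/32 > 3/4); the table then doubles to 64 buckets.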

// MurmurHash3 implementation, copied from smhasher, with minor modifications in
// style and main loop.

static inline uint32_t fmix32(uint32_t h)
{
  h ^= h >> 16;
  h *= UINT32_C(0x85EBCA6B);
  h ^= h >> 13;
  h *= UINT32_C(0xC2B2AE35);
  h ^= h >> 16;

  return h;
}

static uint32_t murmur_hash3(const void* key, size_t len, uint32_t seed)
{
  const uint8_t* data = (const uint8_t*) key;
  // Keep the total key length for the finalization step below; `len` itself is
  // consumed while iterating over the 4-byte blocks.
  const size_t total_len = len;

  uint32_t h1 = seed;

  const uint32_t c1 = UINT32_C(0xCC9E2D51);
  const uint32_t c2 = UINT32_C(0x1B873593);

  // Body: process all complete 4-byte blocks.
  const uint32_t* blocks = (const uint32_t*) data;
  for (; len >= sizeof(uint32_t); len -= sizeof(uint32_t)) {
    uint32_t k1 = *blocks++;

    k1 *= c1;
    k1 = math_rotl_u32(k1, 15);
    k1 *= c2;

    h1 ^= k1;
    h1 = math_rotl_u32(h1, 13);
    h1 = h1 * 5 + UINT32_C(0xE6546B64);
  }

  // Tail: mix in the remaining 0-3 bytes.
  const uint8_t* tail = (const uint8_t*) blocks;

  uint32_t k1 = 0;

  switch (len & 3) {
    case 3:
      k1 ^= tail[2] << 16;
      // fallthrough
    case 2:
      k1 ^= tail[1] << 8;
      // fallthrough
    case 1:
      k1 ^= tail[0];
      k1 *= c1;
      k1 = math_rotl_u32(k1, 15);
      k1 *= c2;
      h1 ^= k1;
  }

  // Finalize with the total key length (as in MurmurHash3), not the leftover
  // tail length, then force the bits to avalanche.
  h1 ^= total_len;

  return fmix32(h1);
}
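
// Illustrative use of the hash (this mirrors lookup() below): the 32-bit hash
// is reduced to a bucket index with a power-of-two mask instead of a modulo.
//
//   const uint32_t hash = murmur_hash3(key, key_size, XNN_CACHE_HASH_SEED);
//   const size_t idx = hash & (cache->num_buckets - 1);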

#ifndef NDEBUG
// This function is only used by an assert, so do not include it in non-debug
// builds.
static inline size_t cache_size(struct xnn_cache* cache) {
  switch (cache->type) {
    case xnn_cache_type_code:
      return cache->code.size;
    case xnn_cache_type_weights:
      return cache->weights.size;
    default:
      XNN_UNREACHABLE;
  }
  return SIZE_MAX;
}
#endif

static inline void* cache_start(struct xnn_cache* cache) {
  switch (cache->type) {
    case xnn_cache_type_code:
      return cache->code.start;
    case xnn_cache_type_weights:
      return cache->weights.start;
    default:
      XNN_UNREACHABLE;
  }
  return NULL;
}

enum xnn_status xnn_init_cache_with_size(struct xnn_cache* cache, size_t num_buckets, enum xnn_cache_type cache_type)
{
  memset(cache, 0, sizeof(struct xnn_cache));
  cache->buckets = (struct xnn_cache_bucket*) xnn_allocate_zero_memory(num_buckets * sizeof(struct xnn_cache_bucket));
  if (cache->buckets == NULL) {
    xnn_log_error("failed to allocate memory for cache buckets");
    return xnn_status_out_of_memory;
  }

  cache->type = cache_type;
  cache->num_buckets = num_buckets;
  return xnn_status_success;
}
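
// Note (derived from the is_po2 asserts in lookup() and cache_buckets_grow()
// below): num_buckets is expected to be a power of two, so that
// `hash & (num_buckets - 1)` yields a valid bucket index. For example,
// initializing with 32 buckets is valid, while 48 buckets would trip those
// asserts.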

enum xnn_status xnn_init_code_cache_with_size(struct xnn_code_cache* cache, size_t num_buckets)
{
  memset(cache, 0, sizeof(struct xnn_code_cache));
  enum xnn_status status = xnn_status_success;
  status = xnn_init_cache_with_size(&cache->cache, num_buckets, xnn_cache_type_code);
  if (status != xnn_status_success) {
    goto error;
  }

  status = xnn_allocate_code_memory(&cache->cache.code, XNN_DEFAULT_CODE_BUFFER_SIZE);
  if (status != xnn_status_success) {
    goto error;
  }

  return xnn_status_success;

error:
  xnn_release_code_cache(cache);
  return status;
}

enum xnn_status xnn_init_code_cache(struct xnn_code_cache* cache)
{
  return xnn_init_code_cache_with_size(cache, XNN_CACHE_INITIAL_BUCKETS);
}
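
// Minimal lifecycle sketch (illustrative only, error handling elided). The
// code-generation step is summarized in a comment because it is performed by
// XNNPACK's JIT-enabled operators, not by a single public call.
//
//   struct xnn_code_cache code_cache;
//   xnn_init_code_cache(&code_cache);
//   // ... code generators emit microkernels into code_cache.cache.code and
//   // call xnn_get_or_insert_code_cache() to deduplicate identical kernels ...
//   xnn_release_code_cache(&code_cache);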

static bool cache_buckets_grow(struct xnn_cache* cache)
{
  const size_t new_num_buckets = cache->num_buckets * XNN_CACHE_GROWTH_FACTOR;
  assert(is_po2(new_num_buckets));
  struct xnn_cache tmp_cache;
  if (xnn_init_cache_with_size(&tmp_cache, new_num_buckets, cache->type) != xnn_status_success) {
    return false;
  }

  for (size_t i = 0; i < cache->num_buckets; i++) {
    struct xnn_cache_bucket b = cache->buckets[i];
    if (b.size == 0) {
      continue;
    }

    // Find the first empty slot to insert into by linear probing. There is no
    // need to compare keys: we are only rehashing existing (unique) entries
    // into a bigger table.
    const size_t mask = tmp_cache.num_buckets - 1;
    size_t idx = b.hash & mask;
    while (tmp_cache.buckets[idx].size != 0) {
      idx = (idx + 1) & mask;
    }
    tmp_cache.buckets[idx].hash = b.hash;
    tmp_cache.buckets[idx].size = b.size;
    tmp_cache.buckets[idx].offset = b.offset;
  }

  xnn_release_memory(cache->buckets);

  cache->buckets = tmp_cache.buckets;
  cache->num_buckets = tmp_cache.num_buckets;
  return true;
}
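
// Worked example (illustrative): growing from 32 to 64 buckets changes the
// probe mask from 0x1F to 0x3F, so an entry with hash 0x2A that previously
// started probing at bucket 0x0A (0x2A & 0x1F) is rehashed to start at bucket
// 0x2A (0x2A & 0x3F).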

static inline bool bytes_equal(struct xnn_cache* cache, void* ptr, size_t size, size_t offset)
{
  return memcmp(ptr, (void*) ((uintptr_t) cache_start(cache) + offset), size) == 0;
}

static bool lookup(struct xnn_cache* cache, void* ptr, size_t size, uint32_t hash, size_t* index)
{
  assert(is_po2(cache->num_buckets));
  const size_t mask = cache->num_buckets - 1;
  size_t idx = hash & mask;
  const struct xnn_cache_bucket* buckets = cache->buckets;

  // Linear probing: advance until we find either a matching entry or an empty
  // bucket. The load factor limit guarantees that an empty bucket exists.
  while (buckets[idx].size != 0 &&
         !(buckets[idx].hash == hash &&
           size == buckets[idx].size &&
           bytes_equal(cache, ptr, buckets[idx].size, buckets[idx].offset))) {
    idx = (idx + 1) & mask;
  }
  *index = idx;
  // An empty bucket means the entry was not found.
  return buckets[idx].size != 0;
}
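
// Worked example (illustrative): with 32 buckets (mask 0x1F), a key whose hash
// ends in 0x1E starts probing at bucket 30; if buckets 30 and 31 hold other
// entries, the probe wraps around to bucket 0, continuing until a matching
// entry or an empty bucket is found.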

static bool insert(struct xnn_cache* cache, void* ptr, size_t size)
{
  const uint32_t hash = murmur_hash3(ptr, size, /*seed=*/XNN_CACHE_HASH_SEED);
  size_t idx;
  const bool found = lookup(cache, ptr, size, hash, &idx);
  if (found) {
    return false;
  }

  // Ensure we have enough buckets to stay under the load limit.
  if (cache->num_entries * XNN_CACHE_MAX_LOAD_ENTRIES_MULTIPLIER >
      cache->num_buckets * XNN_CACHE_MAX_LOAD_BUCKETS_MULTIPLIER) {
    if (!cache_buckets_grow(cache)) {
      // Can't grow the hash table anymore.
      xnn_log_error("failed to grow cache buckets");
      return false;
    }
    xnn_log_debug("successfully grew cache buckets");

    // If the cache grew, idx is stale, since it was computed with the old cache's num_buckets.
    const bool found_in_grown_cache = lookup(cache, ptr, size, hash, &idx);
    assert(!found_in_grown_cache);
    (void) found_in_grown_cache;  // Silence unused variable warnings.
  }

  // Check that ptr points into the cache's buffer.
  assert((uintptr_t) ptr >= (uintptr_t) cache_start(cache));
  if (cache->type == xnn_cache_type_code) {
    assert((uintptr_t) ptr < (uintptr_t) cache_start(cache) + cache_size(cache));
  }

  const size_t offset = (uintptr_t) ptr - (uintptr_t) cache_start(cache);

  // Insert the entry.
  cache->buckets[idx].size = size;
  cache->buckets[idx].hash = hash;
  cache->buckets[idx].offset = offset;
  cache->num_entries++;
  return true;
}

// Checks whether the cached content (a generated microkernel or packed
// weights) is already in the cache; returns its offset if found,
// XNN_CACHE_NOT_FOUND otherwise.
static size_t lookup_cache(struct xnn_cache* cache, void* ptr, size_t size)
{
  const uint32_t hash = murmur_hash3(ptr, size, /*seed=*/XNN_CACHE_HASH_SEED);
  size_t bucket_idx;
  if (lookup(cache, ptr, size, hash, &bucket_idx)) {
    cache->hits++;
    return cache->buckets[bucket_idx].offset;
  } else {
    cache->misses++;
    return XNN_CACHE_NOT_FOUND;
  }
}

size_t xnn_get_or_insert_cache(struct xnn_cache* cache, void* ptr, size_t size)
{
  const size_t found_offset = lookup_cache(cache, ptr, size);
  if (found_offset != XNN_CACHE_NOT_FOUND) {
    if (cache->type == xnn_cache_type_code) {
      // Found in the cache; rewind the buffer size, since the code generator
      // advanced it while emitting this (duplicate) kernel.
      cache->code.size -= size;
    }
    return found_offset;
  }

  if (cache->type == xnn_cache_type_weights) {
    // Cache miss: weight packing functions do not update the buffer size, so
    // update it here.
    cache->weights.size += size;
  }

  const size_t offset = (uintptr_t) ptr - (uintptr_t) cache_start(cache);
  if (!insert(cache, ptr, size)) {
    return XNN_CACHE_NOT_FOUND;
  }
  return offset;
}
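
// Bookkeeping sketch (informal, derived from the comments above):
//  - code cache: the code generator emits `size` bytes at the end of the code
//    buffer and advances code.size before calling in here; on a hit, the
//    duplicate bytes are discarded by rewinding code.size.
//  - weights cache: the packing routines write `size` bytes into space obtained
//    from xnn_reserve_space_in_weights_cache() without advancing weights.size;
//    on a miss, weights.size is advanced here so the packed data is kept.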

size_t xnn_get_or_insert_code_cache(struct xnn_code_cache* cache, void* ptr, size_t size)
{
  return xnn_get_or_insert_cache(&cache->cache, ptr, size);
}

enum xnn_status xnn_release_code_cache(struct xnn_code_cache* cache)
{
  if XNN_LIKELY(cache != NULL) {
    assert(cache->cache.type == xnn_cache_type_code);
    xnn_release_code_memory(&cache->cache.code);
    xnn_release_memory(cache->cache.buckets);
  }
  return xnn_status_success;
}

enum xnn_status xnn_internal_init_weights_cache(
  struct xnn_weights_cache* cache,
  size_t num_buckets,
  size_t buffer_size)
{
  memset(cache, 0, sizeof(struct xnn_weights_cache));

  enum xnn_status status = xnn_status_success;
  status = xnn_init_cache_with_size(&cache->cache, num_buckets, xnn_cache_type_weights);
  if (status != xnn_status_success) {
    goto error;
  }

  status = xnn_allocate_weights_memory(&cache->cache.weights, buffer_size);
  if (status != xnn_status_success) {
    goto error;
  }

  status = xnn_mutex_init(&cache->mutex);
  if (status != xnn_status_success) {
    goto error;
  }

  return xnn_status_success;

error:
  xnn_release_weights_cache(cache);
  return status;
}

enum xnn_status xnn_init_weights_cache_with_size(struct xnn_weights_cache* cache, size_t size)
{
  return xnn_internal_init_weights_cache(cache, XNN_CACHE_INITIAL_BUCKETS, size);
}

enum xnn_status xnn_init_weights_cache(struct xnn_weights_cache* cache)
{
  return xnn_init_weights_cache_with_size(cache, XNN_DEFAULT_WEIGHTS_BUFFER_SIZE);
}
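
// Typical lifecycle sketch (illustrative only, error handling elided; the
// packing step happens inside XNNPACK operator creation rather than through a
// direct call):
//
//   struct xnn_weights_cache weights_cache;
//   xnn_init_weights_cache(&weights_cache);
//   // ... operators reserve space, pack weights into it, and call
//   // xnn_get_or_insert_weights_cache() to deduplicate identical weights ...
//   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft);
//   // ... inference ...
//   xnn_release_weights_cache(&weights_cache);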

enum xnn_status xnn_finalize_weights_cache(
  struct xnn_weights_cache* cache,
  enum xnn_weights_cache_finalization_kind finalization_kind)
{
  switch (cache->finalization_state) {
    case xnn_cache_state_hard_finalized:
    case xnn_cache_state_soft_finalized:
      xnn_log_error("failed to finalize an already finalized weights cache");
      return xnn_status_invalid_state;
    case xnn_cache_state_not_finalized: {
      enum xnn_status status;
      enum xnn_cache_state finalized_state;

      if (finalization_kind == xnn_weights_cache_finalization_kind_hard) {
        xnn_log_debug("hard finalizing weights cache");
        status = xnn_finalize_weights_memory(&cache->cache.weights);
        // Also release the memory used by the hash table (but not the weights memory).
        xnn_release_memory(cache->cache.buckets);
        cache->cache.buckets = NULL;
        finalized_state = xnn_cache_state_hard_finalized;
      } else {
        xnn_log_debug("soft finalizing weights cache");
        assert(finalization_kind == xnn_weights_cache_finalization_kind_soft);
        // Finalize the weights cache by reserving enough space for the largest cached weights to be written again.
        // This guarantees there is room to write packed weights when checking for cache hits, without growing and
        // moving the memory. The cost is some memory overhead, up to the size of the largest cached weights rounded
        // up to the page size.
        status = xnn_reserve_weights_memory(&cache->cache.weights, cache->max_weights_size);
        finalized_state = xnn_cache_state_soft_finalized;
      }
      if (status != xnn_status_success) {
        xnn_log_error("failed to finalize weights cache memory");
        return xnn_status_invalid_state;
      }

      cache->finalization_state = finalized_state;
      return xnn_status_success;
    }
  }
}
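
// Illustrative consequence of hard finalization (informal sketch):
//
//   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_hard);
//   void* p = xnn_reserve_space_in_weights_cache(&weights_cache, 1024);
//   // p == NULL: a hard-finalized cache rejects further reservations (and,
//   // below, xnn_get_or_insert_weights_cache rejects further lookups and
//   // insertions); already-packed weights remain usable via their offsets.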

enum xnn_status xnn_release_weights_cache(struct xnn_weights_cache* cache)
{
  if XNN_LIKELY(cache != NULL) {
    assert(cache->cache.type == xnn_cache_type_weights);
    xnn_release_weights_memory(&cache->cache.weights);
    if (cache->cache.buckets != NULL) {
      xnn_release_memory(cache->cache.buckets);
    }
    const enum xnn_status status = xnn_mutex_destroy(&cache->mutex);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return xnn_status_success;
}

static inline bool cache_has_space(struct xnn_weights_cache* cache, size_t n)
{
  const struct xnn_weights_buffer buf = cache->cache.weights;
  return buf.size + n <= buf.capacity;
}
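
// Example (illustrative): with weights.size == 60 * 1024 and
// weights.capacity == 64 * 1024, cache_has_space(cache, 4096) is true
// (61440 + 4096 == 65536 <= 65536), while cache_has_space(cache, 4097) is false.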

void* xnn_reserve_space_in_weights_cache(struct xnn_weights_cache* cache, size_t n) {
  switch (cache->finalization_state) {
    case xnn_cache_state_hard_finalized:
      xnn_log_error("cannot reserve additional space in a finalized compact weights cache");
      return NULL;
    case xnn_cache_state_soft_finalized:
      if (!cache_has_space(cache, n)) {
        xnn_log_error("cannot reserve additional space in a finalized weights cache");
        return NULL;
      }
      // Even if the cache is finalized and has space for `n` bytes, we still need to lock the mutex, because multiple
      // writers may attempt to write to this space.
      break;
    default:
      break;
  }

  enum xnn_status status = xnn_mutex_lock(&cache->mutex);
  if (status != xnn_status_success) {
    return NULL;
  }

  struct xnn_weights_buffer* buffer = &cache->cache.weights;
  status = xnn_reserve_weights_memory(buffer, n);
  if (status != xnn_status_success) {
    xnn_mutex_unlock(&cache->mutex);
    return NULL;
  }

  return (void*) ((uintptr_t) buffer->start + buffer->size);
}
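
// Locking protocol sketch (illustrative, error handling elided). A successful
// (non-NULL) xnn_reserve_space_in_weights_cache() returns with cache->mutex
// held; the matching xnn_get_or_insert_weights_cache() call releases it. The
// pack_weights() helper below is hypothetical and stands in for XNNPACK's
// internal packing routines.
//
//   void* dst = xnn_reserve_space_in_weights_cache(&weights_cache, packed_size);
//   if (dst != NULL) {
//     pack_weights(dst, /*...*/);  // write packed_size bytes at dst
//     const size_t offset =
//       xnn_get_or_insert_weights_cache(&weights_cache, dst, packed_size);
//     // offset refers to either the newly inserted weights or an identical,
//     // previously cached copy.
//   }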

size_t xnn_get_or_insert_weights_cache(struct xnn_weights_cache* cache, void* ptr, size_t size)
{
  size_t offset = XNN_CACHE_NOT_FOUND;

  switch (cache->finalization_state) {
    case xnn_cache_state_hard_finalized: {
      xnn_log_error("cannot insert into a finalized compact weights cache");
      return XNN_CACHE_NOT_FOUND;
    }
    case xnn_cache_state_soft_finalized: {
      // Calling this on a soft-finalized weights cache is okay as long as:
      // 1. there is sufficient space in the buffer (the incoming packed weights have already been written there for
      //    the lookup to compare against), and
      // 2. the incoming packed weights are already in the cache; no new entries are inserted.
      if (!cache_has_space(cache, size)) {
        xnn_log_error("insufficient extra space in finalized weights cache buffer");
        return XNN_CACHE_NOT_FOUND;
      }

      // From this point on we must release the mutex before returning, because xnn_reserve_space_in_weights_cache
      // must have returned non-NULL (which means that it locked the mutex) for us to get here.
      const size_t found_offset = lookup_cache(&cache->cache, ptr, size);
      if (found_offset == XNN_CACHE_NOT_FOUND) {
        xnn_log_error("packed weights not found in finalized weights cache");
      }

      offset = found_offset;
      break;
    }
    case xnn_cache_state_not_finalized: {
      offset = xnn_get_or_insert_cache(&cache->cache, ptr, size);
      if (offset != XNN_CACHE_NOT_FOUND) {
        // Found or inserted packed weights; update the largest size seen so far. This is used when finalizing the
        // weights cache, to ensure there is extra space at the end for future cache checks.
        cache->max_weights_size = max(size, cache->max_weights_size);
      }
      break;
    }
  }

  // The mutex was locked by xnn_reserve_space_in_weights_cache when it returned non-NULL, i.e. when the cache is not
  // finalized, or when it is xnn_cache_state_soft_finalized and has sufficient space.
  const enum xnn_status status = xnn_mutex_unlock(&cache->mutex);
  (void) status;
  assert(status == xnn_status_success);
  return offset;
}

bool xnn_weights_cache_is_finalized(struct xnn_weights_cache* cache) {
  return cache->finalization_state != xnn_cache_state_not_finalized;
}