1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_MEMORY_TRACKING_HPP |
18 | #define COMMON_MEMORY_TRACKING_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <unordered_map> |
22 | |
23 | #include "memory_debug.hpp" |
24 | #include "memory_storage.hpp" |
25 | #include "nstl.hpp" |
26 | #include "utils.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | |
31 | struct exec_ctx_t; |
32 | |
33 | namespace memory_tracking { |
34 | |
35 | /* Memory tracking capabilities |
36 | * |
37 | * The main purpose of this header file is to provide uniform way to register |
38 | * required memory for a scratchpad at a primitive descriptor creation time |
39 | * and then easily access it having only the base address of the scratchpad. |
40 | * |
41 | * Primitives might contain multiple disjoint parts that require temporary |
42 | * buffers (known as scratchpad) during their execution. A primitive descriptor |
43 | * should summarize all the needs into one single number -- the buffer size |
44 | * that would be requested from a user. At execution time, the corresponding |
45 | * primitive will receive a base pointer to a scratchpad. It then needs to |
46 | * provide each part of algorithm the corresponding piece of memory. Three main |
47 | * challenges here are: |
48 | * 1. Track correct offset (from the base scratchpad address) for each piece |
 *  2. An algorithm might require different memory pieces to be aligned, so
 *     the scratchpad size is no longer just the sum of the sizes of the
 *     corresponding subparts.
52 | * 3. While a primitive is responsible for its scratchpad, the implementation |
53 | * might use some other basic blocks (e.g. cpu_reducer) that also require |
54 | * scratchpad memory. So there should be a simple way of passing the |
 *     information back and forth between the main algorithm (a primitive) and
56 | * auxiliary stuff that lives completely separately from it (e.g. reducer). |
57 | * |
58 | * To address these challenges this header file provides 3 structures: |
 *  1. registry_t -- the class that stores the information about requested
60 | * memory. The information includes required size and desired |
61 | * alignment for each piece. This class is also responsible |
62 | * for computing the right offset to a given piece using the |
63 | * base pointer. |
64 | * This class is basically a ledger with all entries. |
65 | * Lives in primitive descriptors. |
66 | * |
67 | * 2. registrar_t -- the interface to a registry_t to book memory. Used at |
68 | * primitive descriptor creation time only. Contains a |
69 | * reference to the corresponding *mutable* registry. |
70 | * Always modifiable. |
71 | * Allows chaining (using prefixes). |
72 | * |
73 | * 3. grantor_t -- the interface to a registry_t to access memory. Used at |
74 | * primitive execution time only. Contains a reference to |
75 | * the corresponding *constant* registry and base pointer. |
76 | * Always constant. |
77 | * Allows chaining (using prefixes). |
78 | * |
79 | * Both registrar_t and grantor_t allow chaining with extra prefix provided. |
 * The feature is useful when a primitive offloads a part of computations to
81 | * some other primitives which require their own scratchpad space |
82 | * (e.g. reducer). Prefixes are used to avoid key collision in cases when |
 * multiple sub-primitives (e.g. multiple reducers) are used.
84 | * |
85 | * A short example below demonstrates how to use aforementioned classes. In it |
86 | * the main primitive is convolution that uses scratchpad for keeping padded |
87 | * bias. It also needs a reducer, that needs its own space as well. |
88 | * |
89 | * ``` c++ |
90 | * struct reducer_t { |
91 | * static void init(registrar_t &scratchpad) { |
92 | * // reserve space for 980*1024 floats (one page aligned) |
 *         scratchpad.book<float>(key_reducer_space, 980 * 1024, 4096);
94 | * } |
95 | * |
96 | * void exec(const grantor_t &scratchpad) { |
97 | * // get the pointer to preserved space. scratchpad came from |
98 | * // upper primitive (convolution in this example) |
99 | * auto space = scratchpad.get<float>(key_reducer_space); |
100 | * |
101 | * space[:] += ...; |
102 | * } |
103 | * }; |
104 | * |
105 | * struct conv_t { |
106 | * struct pd_t { |
107 | * void init() { |
108 | * registrar_t scratchpad(scratchpad_registry_); |
109 | * |
110 | * // reserve space for 128 elements which are two bytes long that |
111 | * // require 4 byte alignment, but preferably have 64 byte |
112 | * // alignment for performance reasons |
113 | * // two alignment parameters are included for implementation |
114 | * // flexibility targeted at memory debugging purposes |
115 | * scratchpad.book(key_conv_padded_bias, 128, 2, 4, 64); |
116 | * |
117 | * // create a proxy registrar for the reducer All entries made |
118 | * // by reducer would live in convolution's registry, but would |
119 | * // have their own `prefix`, so no interference with conv's |
120 | * // buffers. |
121 | * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); |
122 | * |
123 | * reducer_t::init(reducer_scratchpad); |
124 | * } |
125 | * |
126 | * registry_t scratchpad_registry_; |
127 | * } |
128 | * |
129 | * void exec() { |
130 | * // get the base pointer to a scratchpad memory from a user |
131 | * void *scratchpad_ptr = this->input(DNNL_MEM_SCRATCHPAD); |
132 | * |
133 | * // create a grantor to the scratchpad (and provide the base |
134 | * // pointer). |
135 | * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); |
136 | * |
137 | * // access the padded_bias (need only key name and the grantor) |
138 | * auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); |
139 | * |
140 | * // to give the `right` grantor to reducer we need to add the |
141 | * // corresponding prefix, so that reducer would be able to access |
142 | * // its keys. The call is very similar to the one in pd_t::init |
143 | * // with only difference in types: grantor_t vs registrar_t. |
144 | * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); |
145 | * reducer->exec(reducer_scratchpad); |
146 | * } |
147 | * }; |
148 | * ``` |
149 | */ |
150 | |
151 | /* namespace with common keys and prefixes */ |
152 | namespace names { |
153 | enum { |
154 | key_none = 0, |
155 | key_barrier, |
156 | key_bnorm_cvt, |
157 | key_bnorm_tmp_mean, |
158 | key_bnorm_tmp_var, |
159 | key_bnorm_tmp_diff_ss, |
160 | key_bnorm_tmp_stats, |
161 | key_bnorm_reduction, |
162 | key_brgemm_primitive_batch, |
163 | key_brgemm_primitive_buffer, |
164 | key_brgemm_primitive_buffer_a, |
165 | key_brgemm_primitive_buffer_b, |
166 | key_brgemm_primitive_buffer_comp, |
167 | key_brgemm_primitive_zp_comp_a, |
168 | key_brgemm_primitive_zp_comp_b, |
169 | key_concat_iptrs, |
170 | key_concat_istrides, |
171 | key_concat_nelems, |
172 | key_concat_optrs, |
173 | key_concat_tent_dst, |
174 | key_conv_adjusted_scales, |
175 | key_conv_amx_inp_buffer, |
176 | key_conv_amx_tilecfg, |
177 | key_conv_amx_tile_buffer, |
178 | key_conv_amx_wei_buffer, |
179 | key_conv_amx_wsp_buffer, |
180 | key_conv_bia_reduction, |
181 | key_conv_bias_bf16_convert_wsp, |
182 | key_conv_cudnn, |
183 | key_conv_cudnn_algo, |
184 | key_conv_cudnn_filter, |
185 | key_conv_cudnn_temp, |
186 | key_conv_dst_bf16_convert_wsp, |
187 | key_conv_brgemm_addr_a, |
188 | key_conv_brgemm_addr_b, |
189 | key_conv_brgemm_batch, |
190 | key_conv_brgemm_buffer, |
191 | key_conv_brgemm_inp_buffer, |
192 | key_conv_brgemm_inp_buffer_mask, |
193 | key_conv_bwd_w_1st_bia_reorder, |
194 | key_conv_bwd_w_1st_wei_reorder, |
195 | key_conv_gemm_acc, |
196 | key_conv_gemm_col, |
197 | key_conv_gemm_imtr, |
198 | key_conv_gemm_zp_src_comp, |
199 | key_conv_int_dat_in_acc_dt, |
200 | key_conv_padded_bias, |
201 | key_conv_rtus_space, |
202 | key_conv_store_wsp, |
203 | key_conv_tails, |
204 | key_conv_tr_diff_dst, |
205 | key_conv_tr_diff_dst_bctx, |
206 | key_conv_tr_src, |
207 | key_conv_tr_src_bctx, |
208 | key_conv_wei_reduction, |
209 | key_conv_wei_bia_reduction, |
210 | key_conv_wei_bia_reduction_bctx, |
211 | key_conv_zero_point_flag, |
212 | key_conv_zero_point_pad, |
213 | key_deconv_bias, |
214 | key_deconv_sum, |
215 | key_deconv_zp, |
216 | key_eltwise_diff_dst, |
217 | key_eltwise_src, |
218 | key_fusion_forward_scratchpad, |
219 | key_fusion_inout_buffer, |
220 | key_gemm_int_c_in_acc_dt, |
221 | key_gemm_tmp_buffer, |
222 | key_gemm_flag, |
223 | key_iprod_bias_bf16_convert_wsp, |
224 | key_iprod_dst_bf16_convert_wsp, |
225 | key_iprod_dst_reorder, |
226 | key_iprod_int_dat_in_acc_dt, |
227 | key_lnorm_inv_sqrtvar, |
228 | key_lnorm_tmp_mean, |
229 | key_lnorm_tmp_var, |
230 | key_lnorm_tmp_diff_ss, |
231 | key_lnorm_reduction, |
232 | key_matmul_dst_in_acc_dt, |
233 | key_pool_dst_bf16cvt, |
234 | key_pool_dst_plain2blocked_cvt, |
235 | key_pool_ind_plain2blocked_cvt, |
236 | key_pool_src_bf16cvt, |
237 | key_pool_src_plain2blocked_cvt, |
238 | key_precomputed_scales, |
239 | key_prelu_reduction, |
240 | key_reducer_space, |
241 | key_reducer_space_bctx, |
242 | key_reduction, |
243 | key_reduction_1, |
244 | key_reorder_cross_space, |
245 | key_reorder_space, |
246 | key_reorder_src_scales, |
247 | key_reorder_dst_scales, |
248 | key_reorder_wino_plain, |
249 | key_reorder_wino_transform_space, |
250 | key_reorder_precomputed_dst_scales, |
251 | key_reorder_rnn_space, |
252 | key_reorder_rnn_weights_bf16_cvt, |
253 | key_reorder_rnn_weights_quantization, |
254 | key_reorder_rnn_weights_reduction, |
255 | key_reorder_rnn_weights_transposition, |
256 | key_rnn_space, |
257 | key_rnn_bf32_attention_trans, |
258 | key_rnn_bf32_wei_layer_trans, |
259 | key_rnn_bf32_wei_iter_trans, |
260 | key_rnn_cell, |
261 | key_rnn_diff_states, |
262 | key_rnn_gates, |
263 | key_rnn_gates_blocked, |
264 | key_rnn_src_layer_trans, |
265 | key_rnn_src_iter_trans, |
266 | key_rnn_ht, |
267 | key_rnn_diff_ht, |
268 | key_rnn_ptrs_bia, |
269 | key_rnn_ptrs_wei_layer, |
270 | key_rnn_ptrs_wei_iter, |
271 | key_rnn_ptrs_wei_projection, |
272 | key_softmax_reduction, |
273 | key_softmax_interim_store, |
274 | key_sum_reduction, |
275 | key_sum_srcs_cvt, |
276 | key_wino_U, |
277 | key_wino_V, |
278 | key_wino_M, |
279 | // These two keys should always be the last ones, |
280 | // even though they are not in alphabetical order |
281 | key_nested, |
282 | key_nested_multiple, |
283 | }; |
284 | |
285 | enum { |
286 | prefix_none = 0, |
287 | prefix_fusion, |
288 | prefix_reducer_bia, |
289 | prefix_reducer_wei, |
290 | }; |
291 | } // namespace names |
292 | |
293 | // level 0: 00 00 00 xxx |
294 | // level 1: 00 00 aa xxx |
295 | // level 2: 00 aa bb xxx |
296 | // level 3: aa bb cc xxx |
297 | // max # of levels: 3 + 1 (base_level) |
298 | // here: |
299 | // xxx : [1 .. MAX_KEY) : key |
300 | // aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3 |
301 | |
302 | using key_t = uint32_t; |
303 | enum { |
304 | MAX_KEY = (1u << 10), |
305 | MAX_PREFIX = (1u << 7), |
306 | }; |
307 | |
308 | /// generates global key based on a prefix and a local key |
309 | inline key_t make_key(key_t prefix, key_t key) { |
310 | return prefix + key; |
311 | } |
312 | |
313 | /// generates global prefix based on the global parent and the local ones |
314 | inline key_t make_prefix(key_t parent_prefix, key_t prefix) { |
315 | return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; |
316 | } |
317 | |
318 | struct registrar_t; |
319 | struct grantor_t; |
320 | |
321 | enum { default_alignment = 128 }; |
322 | inline size_t get_alignment(size_t alignment) { |
323 | size_t minimal_alignment |
324 | = memory_debug::is_mem_debug() ? getpagesize() : default_alignment; |
325 | return nstl::max<size_t>(alignment, minimal_alignment); |
326 | } |
327 | |
328 | inline size_t buffer_protect_size() { |
329 | return memory_debug::is_mem_debug() |
330 | ? memory_debug::protect_size() + getpagesize() |
331 | : 0; |
332 | } |
333 | |
334 | struct registry_t { |
335 | struct entry_t { |
336 | size_t offset, size, capacity, alignment; |
337 | |
338 | // apply offset and alignment + check memory_debug (host/cpu only) |
339 | const void *compute_ptr(const void *base_ptr) const; |
340 | }; |
341 | |
342 | // perf_align is the desired alignment for performance. |
343 | // data_align is the minimum data alignment required for functionality, |
344 | // this parameter is included for memory debugging purposes. |
345 | void book(const key_t &key, size_t size, size_t data_align, |
346 | size_t perf_align = default_alignment) { |
347 | if (size == 0) return; |
348 | assert(offset_map_.count(key) == 0); |
349 | size_t alignment = memory_debug::is_mem_debug() |
350 | ? data_align |
351 | : nstl::max(data_align, perf_align); |
352 | |
353 | if (memory_debug::is_mem_debug() && size_ == 0) |
354 | size_ += get_alignment(alignment) + buffer_protect_size(); |
355 | |
356 | assert(alignment > 0 && (alignment & (alignment - 1)) == 0); |
357 | size_t capacity |
358 | = size + get_alignment(alignment) + buffer_protect_size(); |
359 | assert(capacity < (SIZE_MAX + INT_MIN)); |
360 | offset_map_[key] = entry_t {size_, size, capacity, alignment}; |
361 | |
362 | size_ += capacity; |
363 | } |
364 | |
365 | entry_t get(const key_t &key) const { |
366 | if (size() == 0 || offset_map_.count(key) != 1) |
367 | return entry_t {0, 0, 0, 0}; |
368 | return offset_map_.at(key); |
369 | } |
370 | |
371 | size_t size() const { return size_; } |
372 | |
373 | registrar_t registrar(); |
374 | grantor_t grantor(const memory_storage_t *mem_storage, |
375 | const exec_ctx_t &exec_ctx) const; |
376 | |
377 | template <typename return_type> |
378 | class common_iterator_t { |
379 | private: |
380 | const void *base_ptr; |
381 | std::unordered_map<key_t, entry_t>::const_iterator iter; |
382 | |
383 | public: |
384 | common_iterator_t(const void *base_ptr_, |
385 | const std::unordered_map<key_t, entry_t> &map, |
386 | bool is_begin = true) { |
387 | base_ptr = base_ptr_; |
388 | if (is_begin) { |
389 | iter = map.cbegin(); |
390 | } else { |
391 | iter = map.cend(); |
392 | } |
393 | } |
394 | common_iterator_t &operator++(int) { |
395 | iter++; |
396 | return *this; |
397 | } |
398 | bool operator==(const common_iterator_t &rhs) const { |
399 | return iter == rhs.iter; |
400 | } |
401 | bool operator!=(const common_iterator_t &rhs) const { |
402 | return iter != rhs.iter; |
403 | } |
404 | std::pair<return_type, size_t> operator*() const { |
405 | const entry_t &entry = iter->second; |
406 | const void *ptr_start = entry.compute_ptr(base_ptr); |
407 | return std::pair<return_type, size_t> { |
408 | (return_type)ptr_start, entry.size}; |
409 | } |
410 | }; |
411 | typedef common_iterator_t<void *> iterator; |
412 | typedef common_iterator_t<const void *> const_iterator; |
413 | iterator begin(void *base_ptr_) const { |
414 | return iterator(base_ptr_, offset_map_); |
415 | } |
416 | iterator end(void *base_ptr_) const { |
417 | return iterator(base_ptr_, offset_map_, false); |
418 | } |
419 | const_iterator cbegin(const void *base_ptr_) const { |
420 | return const_iterator(base_ptr_, offset_map_); |
421 | } |
422 | const_iterator cend(const void *base_ptr_) const { |
423 | return const_iterator(base_ptr_, offset_map_, false); |
424 | } |
425 | |
426 | protected: |
427 | std::unordered_map<key_t, entry_t> offset_map_; |
428 | size_t size_ = 0; |
429 | }; |
430 | |
431 | struct registrar_t { |
432 | registrar_t(registry_t ®istry) : registry_(registry), prefix_(0) {} |
433 | registrar_t(registrar_t &parent, const key_t &prefix) |
434 | : registry_(parent.registry_) |
435 | , prefix_(make_prefix(parent.prefix_, prefix)) {} |
436 | |
437 | void book(const key_t &key, size_t nelems, size_t data_size, |
438 | size_t data_align = 0, size_t perf_align = default_alignment) { |
439 | assert(nelems < (SIZE_MAX + INT_MIN)); |
440 | if (data_align == 0) data_align = data_size; |
441 | registry_.book(make_key(prefix_, key), nelems * data_size, data_align, |
442 | perf_align); |
443 | } |
444 | template <typename T> |
445 | void book(const key_t &key, size_t nelems, |
446 | size_t perf_align = default_alignment) { |
447 | registry_.book(make_key(prefix_, key), nelems * sizeof(T), alignof(T), |
448 | perf_align); |
449 | } |
450 | |
451 | void book(const key_t &key, const registry_t ®istry, |
452 | size_t perf_align = default_alignment) { |
453 | registry_.book(make_key(prefix_, key), registry.size(), 1, perf_align); |
454 | } |
455 | |
456 | size_t size() const { return registry_.size(); } |
457 | |
458 | protected: |
459 | registry_t ®istry_; |
460 | const key_t prefix_; |
461 | }; |
462 | |
463 | struct grantor_t { |
464 | grantor_t(const registry_t ®istry, |
465 | const memory_storage_t *base_mem_storage, |
466 | const exec_ctx_t &exec_ctx) |
467 | : registry_(registry) |
468 | , prefix_(0) |
469 | , base_mem_storage_(base_mem_storage) |
470 | , exec_ctx_(&exec_ctx) {} |
471 | grantor_t(const grantor_t &parent, const key_t &prefix) |
472 | : registry_(parent.registry_) |
473 | , prefix_(make_prefix(parent.prefix_, prefix)) |
474 | , base_mem_storage_(parent.base_mem_storage_) |
475 | , exec_ctx_(parent.exec_ctx_) {} |
476 | |
477 | template <typename T = void> |
478 | T *get(const key_t &key, size_t *size = nullptr) const { |
479 | if (!base_mem_storage_) { |
480 | assert(registry_.size() == 0); |
481 | return nullptr; |
482 | } |
483 | auto e = registry_.get(make_key(prefix_, key)); |
484 | |
485 | if (size) *size = e.size; |
486 | if (e.size == 0) return nullptr; |
487 | |
488 | char *host_storage_ptr = get_host_storage_ptr(base_mem_storage_); |
489 | char *base_ptr = host_storage_ptr + base_mem_storage_->base_offset(); |
490 | return (T *)e.compute_ptr(base_ptr); |
491 | } |
492 | |
493 | std::unique_ptr<memory_storage_t> get_memory_storage( |
494 | const key_t &key) const { |
495 | if (!base_mem_storage_) { |
496 | assert(registry_.size() == 0); |
497 | return nullptr; |
498 | } |
499 | auto e = registry_.get(make_key(prefix_, key)); |
500 | if (e.size == 0) return nullptr; |
501 | |
502 | if (is_cpu_engine(base_mem_storage_)) { |
503 | char *host_storage_ptr = get_host_storage_ptr(base_mem_storage_); |
504 | char *base_ptr |
505 | = host_storage_ptr + base_mem_storage_->base_offset(); |
506 | char *aligned_ptr = (char *)e.compute_ptr(base_ptr); |
507 | size_t aligned_offset = size_t(aligned_ptr - host_storage_ptr); |
508 | return base_mem_storage_->get_sub_storage(aligned_offset, e.size); |
509 | } |
510 | |
511 | const size_t aligned_offset |
512 | = reinterpret_cast<size_t>(utils::align_ptr<char>( |
513 | reinterpret_cast<char *>(e.offset), e.alignment)); |
514 | assert(aligned_offset + e.size <= registry_.size()); |
515 | return base_mem_storage_->get_sub_storage(aligned_offset, e.size); |
516 | } |
517 | |
518 | const memory_storage_t *get_base_storage() const { |
519 | return base_mem_storage_; |
520 | } |
521 | const registry_t &get_registry() const { return registry_; } |
522 | |
523 | protected: |
524 | const registry_t ®istry_; |
525 | const key_t prefix_; |
526 | const memory_storage_t *base_mem_storage_; |
527 | const exec_ctx_t *exec_ctx_; |
528 | |
529 | private: |
530 | char *get_host_storage_ptr(const memory_storage_t *storage) const; |
531 | bool is_cpu_engine(const memory_storage_t *mem_storage) const; |
532 | }; |
533 | |
534 | inline registrar_t registry_t::registrar() { |
535 | return registrar_t(*this); |
536 | } |
537 | inline grantor_t registry_t::grantor( |
538 | const memory_storage_t *mem_storage, const exec_ctx_t &exec_ctx) const { |
539 | return grantor_t(*this, mem_storage, exec_ctx); |
540 | } |
541 | |
542 | } // namespace memory_tracking |
543 | } // namespace impl |
544 | } // namespace dnnl |
545 | |
546 | #endif |
547 | |