1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_MEMORY_DESC_WRAPPER_HPP |
18 | #define COMMON_MEMORY_DESC_WRAPPER_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "nstl.hpp" |
24 | #include "utils.hpp" |
25 | |
26 | #include "type_helpers.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | |
31 | /** thin wrapper class over \struct memory_desc_t which allows easy |
32 | * manipulations with underlying C structure, which is taken by reference */ |
33 | struct memory_desc_wrapper : public c_compatible { |
34 | const memory_desc_t *md_; |
35 | |
36 | /** constructor which takes a reference to a constant underlying C memory |
37 | * descriptor \param md */ |
38 | memory_desc_wrapper(const memory_desc_t *md) |
39 | : md_(md ? md : &glob_zero_md) {} |
40 | memory_desc_wrapper(const memory_desc_t &md) : memory_desc_wrapper(&md) {} |
41 | |
42 | /* implementing attributes */ |
43 | int ndims() const { return md_->ndims; } |
44 | const dims_t &dims() const { return md_->dims; } |
45 | data_type_t data_type() const { return md_->data_type; } |
46 | |
47 | const dims_t &padded_dims() const { return md_->padded_dims; } |
48 | const dims_t &padded_offsets() const { return md_->padded_offsets; } |
49 | dim_t offset0() const { return md_->offset0; } |
50 | |
51 | format_kind_t format_kind() const { return md_->format_kind; } |
52 | |
53 | bool is_blocking_desc() const { |
54 | return format_kind() == format_kind::blocked; |
55 | } |
56 | bool is_wino_desc() const { return format_kind() == format_kind::wino; } |
57 | bool is_rnn_packed_desc() const { |
58 | return format_kind() == format_kind::rnn_packed; |
59 | } |
60 | |
61 | const blocking_desc_t &blocking_desc() const { |
62 | assert(is_blocking_desc()); |
63 | return md_->format_desc.blocking; |
64 | } |
65 | const wino_desc_t &wino_desc() const { |
66 | assert(is_wino_desc()); |
67 | return md_->format_desc.wino_desc; |
68 | } |
69 | const rnn_packed_desc_t &rnn_packed_desc() const { |
70 | assert(is_rnn_packed_desc()); |
71 | return md_->format_desc.rnn_packed_desc; |
72 | } |
73 | |
74 | const memory_extra_desc_t &() const { return md_->extra; } |
75 | |
76 | /* some useful function */ |
77 | |
78 | /** returns the number of elements including padding if \param with_padding |
79 | * is true, and the number of data elements otherwise */ |
80 | dim_t nelems(bool with_padding = false) const { |
81 | if (is_zero()) return 0; |
82 | if (has_runtime_dims()) return DNNL_RUNTIME_DIM_VAL; |
83 | return utils::array_product( |
84 | with_padding ? padded_dims() : dims(), ndims()); |
85 | } |
86 | |
87 | /** returns true if memory descriptor is zero */ |
88 | bool is_zero() const { return ndims() == 0; } |
89 | |
90 | /** returns true if memory descriptor contains zero as one of its dim */ |
91 | bool has_zero_dim() const { |
92 | for (int d = 0; d < ndims(); ++d) |
93 | if (dims()[d] == 0) return true; |
94 | return false; |
95 | } |
96 | |
97 | /** return the size of data type (a shortcut) */ |
98 | size_t data_type_size() const { return types::data_type_size(data_type()); } |
99 | |
100 | /** return the size of data type of additional buffer */ |
101 | size_t additional_buffer_data_size(uint64_t flag_select) const { |
102 | using namespace memory_extra_flags; |
103 | if (flag_select & compensation_conv_s8s8) return sizeof(int32_t); |
104 | if ((flag_select & rnn_u8s8_compensation) |
105 | && !types::extra_flag_rnn_s8s8_compensation_is_set(flag_select)) |
106 | return sizeof(float); |
107 | if (flag_select & compensation_conv_asymmetric_src) |
108 | return sizeof(int32_t); |
109 | return 0; |
110 | } |
111 | |
112 | /** return true if memory format has additional buffer */ |
113 | bool is_additional_buffer() const { |
114 | using namespace memory_extra_flags; |
115 | // Currently compensation is not required for rnn_s8s8_compensation, |
116 | // but it has common bit with rnn_u8s8_compensation constant so we have |
117 | // to exclude rnn_s8s8_compensation case explicitly |
118 | return ((extra().flags |
119 | & (compensation_conv_s8s8 | rnn_u8s8_compensation |
120 | | compensation_conv_asymmetric_src)) |
121 | && !types::extra_flag_rnn_s8s8_compensation_is_set( |
122 | extra().flags)); |
123 | } |
124 | |
125 | /** returns the size required for a particular extra memory buffer */ |
126 | size_t (memory_extra_flags_t flag) const { |
127 | using namespace memory_extra_flags; |
128 | |
129 | auto calculate_size = [=](int cmask, size_t buff_data_size) { |
130 | assert(utils::one_of(cmask, 1, 2, 3, 5, 13, 27)); |
131 | dim_t prod = 1; |
132 | for (int d = 0; d < ndims(); ++d) |
133 | if (cmask & (1 << d)) { prod *= padded_dims()[d]; } |
134 | return (size_t)prod * buff_data_size; |
135 | }; |
136 | |
137 | if (extra().flags & compensation_conv_s8s8) { |
138 | return calculate_size(extra().compensation_mask, |
139 | additional_buffer_data_size(flag)); |
140 | } |
141 | |
142 | if ((extra().flags & rnn_u8s8_compensation) |
143 | && !types::extra_flag_rnn_s8s8_compensation_is_set( |
144 | extra().flags)) { |
145 | return calculate_size(extra().compensation_mask, |
146 | additional_buffer_data_size(flag)); |
147 | } |
148 | if (extra().flags & compensation_conv_asymmetric_src) { |
149 | return calculate_size(extra().asymm_compensation_mask, |
150 | additional_buffer_data_size(flag)); |
151 | } |
152 | |
153 | return 0; |
154 | } |
155 | |
156 | /** returns the size of the appended buffer when the memory descriptor |
157 | * requires extra space to hold compensation data */ |
158 | size_t additional_buffer_size() const { |
159 | using namespace memory_extra_flags; |
160 | |
161 | size_t buff_size = 0; |
162 | buff_size += additional_buffer_size(compensation_conv_s8s8); |
163 | buff_size += additional_buffer_size(rnn_u8s8_compensation); |
164 | buff_size += additional_buffer_size(compensation_conv_asymmetric_src); |
165 | return buff_size; |
166 | } |
167 | |
168 | /** returns the size required to store described memory |
169 | * note: if offset0 != 0 returns 0 (need to specify the behavior) */ |
170 | size_t size() const { |
171 | if (utils::one_of(format_kind(), format_kind::undef, format_kind::any) |
172 | || is_zero() || has_zero_dim()) |
173 | return 0; |
174 | |
175 | if (has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL; |
176 | |
177 | if (format_kind() == format_kind::wino) { |
178 | return wino_desc().size; |
179 | } else if (format_kind() == format_kind::rnn_packed) { |
180 | return rnn_packed_desc().size; |
181 | } else { |
182 | if (offset0() != 0) return 0; |
183 | |
184 | dims_t blocks = {0}; |
185 | compute_blocks(blocks); |
186 | |
187 | const auto &bd = blocking_desc(); |
188 | |
189 | size_t max_size = 0; |
190 | for (int d = 0; d < ndims(); ++d) { |
191 | dim_t strided_pdim = padded_dims()[d] / blocks[d]; |
192 | dim_t effective_stride = strided_pdim == 1 ? 1 : bd.strides[d]; |
193 | max_size = nstl::max<size_t>( |
194 | max_size, strided_pdim * effective_stride); |
195 | } |
196 | |
197 | if (max_size == 1 && bd.inner_nblks != 0) { |
198 | max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); |
199 | } |
200 | |
201 | size_t data_size = max_size * data_type_size(); |
202 | if (is_additional_buffer()) { |
203 | // The additional buffers, typically of data type int32_t, float |
204 | // are stored at the end of data. Pad the data, so that the |
205 | // buffers are properly aligned to their data type. |
206 | const size_t alignment_in_bytes = 4; |
207 | data_size = utils::rnd_up(data_size, alignment_in_bytes); |
208 | } |
209 | return data_size + additional_buffer_size(); |
210 | } |
211 | } |
212 | |
213 | /** returns the true if some dim is broadcasted (stride == 0) */ |
214 | bool has_broadcast() const { |
215 | const auto &bd = blocking_desc(); |
216 | for (int d = 0; d < ndims(); d++) |
217 | if (bd.strides[d] == 0) return true; |
218 | return false; |
219 | } |
220 | |
221 | /** returns true if number of non unit dims is <= `n`. */ |
222 | bool count_non_unit_dims(int n) const { |
223 | int non_unit_dims = 0; |
224 | for (int d = 0; d < ndims(); d++) { |
225 | if (dims()[d] != 1) non_unit_dims++; |
226 | } |
227 | return non_unit_dims <= n; |
228 | } |
229 | |
230 | /** returns true if data is dense in memory */ |
231 | bool is_dense(bool with_padding = false) const { |
232 | if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) |
233 | return false; |
234 | if (has_runtime_dims_or_strides() || has_broadcast()) return false; |
235 | return nelems(with_padding) * data_type_size() == size(); |
236 | } |
237 | |
238 | /** returns true if format is set to `any` */ |
239 | bool format_any() const { return format_kind() == format_kind::any; } |
240 | |
241 | /** returns true if at least one dim is not known */ |
242 | bool has_runtime_dims() const { |
243 | for (int d = 0; d < ndims(); ++d) |
244 | if (dims()[d] == DNNL_RUNTIME_DIM_VAL) return true; |
245 | return false; |
246 | } |
247 | |
248 | /** returns true if at least one dim is not known */ |
249 | bool has_runtime_strides() const { |
250 | if (!is_blocking_desc()) return false; |
251 | for (int d = 0; d < ndims(); ++d) |
252 | if (blocking_desc().strides[d] == DNNL_RUNTIME_DIM_VAL) return true; |
253 | return false; |
254 | } |
255 | |
256 | /** returns true if memory format is runtime_dims_or_strides-dependent */ |
257 | bool has_runtime_dims_or_strides() const { |
258 | return has_runtime_dims() || has_runtime_strides(); |
259 | } |
260 | |
261 | /** returns true if the only (potentially) padded dim is \param dim */ |
262 | bool only_padded_dim(int dim) const { |
263 | if (has_runtime_dims()) return false; |
264 | for (int d = 0; d < ndims(); ++d) |
265 | if (d != dim && dims()[d] != padded_dims()[d]) return false; |
266 | return true; |
267 | } |
268 | |
269 | /** returns true if memory desc has blocked layout and block dims are 1s */ |
270 | bool is_plain() const { |
271 | if (!is_blocking_desc()) return false; |
272 | return blocking_desc().inner_nblks == 0; |
273 | } |
274 | |
275 | /** returns overall block sizes */ |
276 | void compute_blocks(dims_t blocks) const { |
277 | if (!is_blocking_desc()) { |
278 | utils::array_set(blocks, 0, ndims()); |
279 | return; |
280 | } |
281 | |
282 | utils::array_set(blocks, 1, ndims()); |
283 | |
284 | const auto &bd = blocking_desc(); |
285 | for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) |
286 | blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; |
287 | } |
288 | |
289 | // XXX: for backward compatibility with v0.x |
290 | // strides_compat[0]: stride between the first elements of adjacent blocks |
291 | // strides_compat[1]: strides between elements in the same block |
292 | // |
293 | // For 2+ level blocking all inner blocks are treated as a single block. |
294 | void compute_strides_compat(dims_t *strides_compat) const; |
295 | |
296 | /* comparison section */ |
297 | |
298 | bool operator==(const memory_desc_wrapper &rhs) const { |
299 | return *this->md_ == *rhs.md_; |
300 | } |
301 | bool operator!=(const memory_desc_wrapper &rhs) const { |
302 | return !operator==(rhs); |
303 | } |
304 | bool operator==(const memory_desc_t &rhs) const { |
305 | return operator==(memory_desc_wrapper(rhs)); |
306 | } |
307 | bool operator!=(const memory_desc_t &rhs) const { return !operator==(rhs); } |
308 | |
309 | /** returns true if data (w/o padding if with_padding == false and w/ |
310 | * padding otherwise) have the same physical structure, i.e. dimensions, |
311 | * strides, and blocked structure. Depending on with_data_type flag |
312 | * data_type is taken or not taken into account. dim_start allows to check |
313 | * similarity for the logical part of data [dim_start .. ndims()]. |
314 | * CAUTION: format kind any and undef are not similar to whatever, hence the |
315 | * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ |
316 | /* TODO: revise */ |
317 | bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true, |
318 | bool with_data_type = true, int dim_start = 0) const; |
319 | |
320 | /** returns true if one memory can be reordered to another */ |
321 | bool consistent_with(const memory_desc_wrapper &rhs) const; |
322 | |
323 | /** returns true if the memory desc corresponds to the given format tag and |
324 | * strides. |
325 | * @sa memory_desc_matches_tag */ |
326 | bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { |
327 | return memory_desc_matches_tag(*md_, tag, strides); |
328 | } |
329 | |
330 | /** returns matching tag (or undef if match is not found) |
331 | * XXX: This is a workaround that eventually should go away! */ |
332 | template <typename... Tags> |
333 | format_tag_t matches_one_of_tag(Tags... tags) const { |
334 | for (const auto tag : {tags...}) { |
335 | if (memory_desc_matches_tag(*md_, tag)) return tag; |
336 | } |
337 | return format_tag::undef; |
338 | } |
339 | |
340 | /* offset section */ |
341 | |
342 | /** returns physical offset by logical one. logical offset is represented by |
343 | * an array \param pos. if \param is_pos_padded is true \param pos |
344 | * represents the position in already padded area */ |
345 | dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { |
346 | assert(is_blocking_desc()); |
347 | const blocking_desc_t &blk = blocking_desc(); |
348 | |
349 | dims_t pos_copy = {0}; |
350 | for (int d = 0; d < ndims(); ++d) |
351 | pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); |
352 | |
353 | dim_t phys_offset = offset0(); |
354 | |
355 | if (blk.inner_nblks > 0) { |
356 | dim_t blk_stride = 1; |
357 | for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { |
358 | const int d = blk.inner_idxs[iblk]; |
359 | |
360 | dim_t p; |
361 | /* switch to faster 32-bit division when possible. |
362 | * inner blocks always fit 32-bit. */ |
363 | if (pos_copy[d] <= INT32_MAX) { |
364 | p = (int32_t)pos_copy[d] % (int32_t)blk.inner_blks[iblk]; |
365 | pos_copy[d] = (int32_t)pos_copy[d] |
366 | / (int32_t)blk.inner_blks[iblk]; |
367 | } else { |
368 | p = pos_copy[d] % blk.inner_blks[iblk]; |
369 | pos_copy[d] /= blk.inner_blks[iblk]; |
370 | } |
371 | |
372 | phys_offset += p * blk_stride; |
373 | |
374 | blk_stride *= blk.inner_blks[iblk]; |
375 | } |
376 | } |
377 | |
378 | for (int d = 0; d < ndims(); ++d) { |
379 | const dim_t p = pos_copy[d]; |
380 | phys_offset += p * blk.strides[d]; |
381 | } |
382 | |
383 | return phys_offset; |
384 | } |
385 | |
386 | /** returns physical offset by logical one. logical offset is represented by |
387 | * a scalar \param l_offset. if \param is_pos_padded is true, \param |
388 | * l_offset represents logical offset in already padded area */ |
389 | dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { |
390 | dims_t dims_pos; |
391 | const auto &cur_dims = is_pos_padded ? padded_dims() : dims(); |
392 | utils::l_dims_by_l_offset(dims_pos, l_offset, cur_dims, ndims()); |
393 | return off_v(dims_pos, is_pos_padded); |
394 | } |
395 | |
396 | /** returns physical offset by logical one. logical offset is represented by |
397 | * a tuple of indices (\param xn, ..., \param x1, \param x0) */ |
398 | template <typename... Args> |
399 | dim_t off(Args... args) const { |
400 | assert(sizeof...(args) == ndims()); |
401 | dims_t pos = {args...}; |
402 | return off_v(pos, false); |
403 | } |
404 | |
405 | /** returns physical offset by logical one. logical offset is represented by |
406 | * a tuple of indices (\param xn, ..., \param x1, \param x0) in already |
407 | * padded area */ |
408 | template <typename... Args> |
409 | dim_t off_padding(Args... args) const { |
410 | assert(sizeof...(args) == ndims()); |
411 | dims_t pos = {args...}; |
412 | return off_v(pos, true); |
413 | } |
414 | |
415 | /** returns physical offset by logical one. Logical offset is represented by |
416 | * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a |
417 | * user responsibility to adjust the result to get offset within blocks */ |
418 | template <typename... Args> |
419 | dim_t blk_off(Args... args) const { |
420 | return _blk_off<sizeof...(args), Args...>(args...); |
421 | } |
422 | |
423 | template <bool skip_first, typename T, typename... Args> |
424 | dim_t blk_off(T xn, Args... args) const { |
425 | return skip_first ? blk_off<Args...>(args...) |
426 | : blk_off<T, Args...>(xn, args...); |
427 | } |
428 | |
429 | /* static functions section */ |
430 | /* TODO: replace with non-static, once md_ becomes non-const ref */ |
431 | |
432 | static status_t compute_blocking( |
433 | memory_desc_t &memory_desc, format_tag_t tag); |
434 | |
435 | private: |
436 | /* TODO: put logical_offset in utils */ |
437 | template <typename T> |
438 | dim_t logical_offset(T x0) const { |
439 | return x0; |
440 | } |
441 | |
442 | template <typename T, typename... Args> |
443 | dim_t logical_offset(T xn, Args... args) const { |
444 | const size_t n_args = sizeof...(args); |
445 | return xn * utils::array_product<n_args>(&dims()[ndims() - n_args]) |
446 | + logical_offset(args...); |
447 | } |
448 | |
449 | template <int ORIG_LEN, typename... Void> |
450 | dim_t _blk_off() const { |
451 | return offset0(); |
452 | } |
453 | |
454 | template <int ORIG_LEN, typename T, typename... Args> |
455 | dim_t _blk_off(T xc, Args... args) const { |
456 | assert(is_blocking_desc()); |
457 | constexpr int dc = ORIG_LEN - sizeof...(args) - 1; |
458 | return xc * blocking_desc().strides[dc] |
459 | + _blk_off<ORIG_LEN, Args...>(args...); |
460 | } |
461 | }; |
462 | |
463 | inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, |
464 | bool with_padding, bool with_data_type, int dim_start) const { |
465 | using namespace utils; |
466 | |
467 | if (one_of(format_kind(), format_kind::undef, format_kind::any)) |
468 | return false; |
469 | if (is_wino_desc() || is_rnn_packed_desc()) return false; |
470 | |
471 | const int ds = dim_start; |
472 | const auto &blk = blocking_desc(); |
473 | const auto &r_blk = rhs.blocking_desc(); |
474 | |
475 | return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */ |
476 | && format_kind() == rhs.format_kind() |
477 | && IMPLICATION(with_data_type, data_type() == rhs.data_type()) |
478 | && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) |
479 | && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) |
480 | && blk.inner_nblks == r_blk.inner_nblks |
481 | && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) |
482 | && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) |
483 | && IMPLICATION(with_padding, |
484 | true |
485 | && array_cmp(padded_dims() + ds, |
486 | rhs.padded_dims() + ds, ndims() - ds) |
487 | && array_cmp(padded_offsets() + ds, |
488 | rhs.padded_offsets() + ds, ndims() - ds)); |
489 | } |
490 | |
491 | inline bool memory_desc_wrapper::consistent_with( |
492 | const memory_desc_wrapper &rhs) const { |
493 | if (ndims() == rhs.ndims()) { |
494 | for (int d = 0; d < ndims(); ++d) { |
495 | if (dims()[d] != rhs.dims()[d]) return false; |
496 | } |
497 | return true; |
498 | } else { |
499 | /* TODO: revise. |
500 | * is the following possible? |
501 | * [1, a, b] <--reorder--> [a, b] |
502 | * [a, 1, b] <--reorder--> [a, b] |
503 | * not, at least for now */ |
504 | return false; |
505 | } |
506 | } |
507 | |
508 | } // namespace impl |
509 | } // namespace dnnl |
510 | |
511 | #endif |
512 | |
513 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
514 | |