/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_MEMORY_DESC_WRAPPER_HPP
#define COMMON_MEMORY_DESC_WRAPPER_HPP

#include <assert.h>

#include "c_types_map.hpp"
#include "nstl.hpp"
#include "utils.hpp"

#include "type_helpers.hpp"

namespace dnnl {
namespace impl {

/** thin wrapper class over \struct memory_desc_t that simplifies
 * manipulation of the underlying C structure, which is referenced rather
 * than copied */
struct memory_desc_wrapper : public c_compatible {
    const memory_desc_t *md_;

    /** constructor that takes a pointer to a constant underlying C memory
     * descriptor \param md (a null pointer is substituted with the global
     * zero memory descriptor) */
    memory_desc_wrapper(const memory_desc_t *md)
        : md_(md ? md : &glob_zero_md) {}
    memory_desc_wrapper(const memory_desc_t &md) : memory_desc_wrapper(&md) {}
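
    /* A minimal usage sketch (illustrative only; `md` stands for any valid
     * memory_desc_t provided by the caller):
     *
     *     memory_desc_wrapper mdw(md);
     *     if (!mdw.is_zero() && !mdw.has_runtime_dims_or_strides()) {
     *         size_t bytes = mdw.size(); // bytes needed for the buffer
     *         dim_t n = mdw.nelems();    // number of logical elements
     *     }
     */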

    /* implementing attributes */
    int ndims() const { return md_->ndims; }
    const dims_t &dims() const { return md_->dims; }
    data_type_t data_type() const { return md_->data_type; }

    const dims_t &padded_dims() const { return md_->padded_dims; }
    const dims_t &padded_offsets() const { return md_->padded_offsets; }
    dim_t offset0() const { return md_->offset0; }

    format_kind_t format_kind() const { return md_->format_kind; }

    bool is_blocking_desc() const {
        return format_kind() == format_kind::blocked;
    }
    bool is_wino_desc() const { return format_kind() == format_kind::wino; }
    bool is_rnn_packed_desc() const {
        return format_kind() == format_kind::rnn_packed;
    }

    const blocking_desc_t &blocking_desc() const {
        assert(is_blocking_desc());
        return md_->format_desc.blocking;
    }
    const wino_desc_t &wino_desc() const {
        assert(is_wino_desc());
        return md_->format_desc.wino_desc;
    }
    const rnn_packed_desc_t &rnn_packed_desc() const {
        assert(is_rnn_packed_desc());
        return md_->format_desc.rnn_packed_desc;
    }

    const memory_extra_desc_t &extra() const { return md_->extra; }

    /* some useful functions */

    /** returns the number of elements including padding if \param with_padding
     * is true, and the number of data elements otherwise */
    dim_t nelems(bool with_padding = false) const {
        if (is_zero()) return 0;
        if (has_runtime_dims()) return DNNL_RUNTIME_DIM_VAL;
        return utils::array_product(
                with_padding ? padded_dims() : dims(), ndims());
    }

    /** returns true if memory descriptor is zero */
    bool is_zero() const { return ndims() == 0; }

    /** returns true if memory descriptor contains zero as one of its dims */
    bool has_zero_dim() const {
        for (int d = 0; d < ndims(); ++d)
            if (dims()[d] == 0) return true;
        return false;
    }

    /** returns the size of the data type (a shortcut) */
    size_t data_type_size() const { return types::data_type_size(data_type()); }

    /** returns the data type size of the additional buffer */
    size_t additional_buffer_data_size(uint64_t flag_select) const {
        using namespace memory_extra_flags;
        if (flag_select & compensation_conv_s8s8) return sizeof(int32_t);
        if ((flag_select & rnn_u8s8_compensation)
                && !types::extra_flag_rnn_s8s8_compensation_is_set(flag_select))
            return sizeof(float);
        if (flag_select & compensation_conv_asymmetric_src)
            return sizeof(int32_t);
        return 0;
    }

    /** returns true if the memory format has an additional buffer */
    bool is_additional_buffer() const {
        using namespace memory_extra_flags;
        // Currently compensation is not required for rnn_s8s8_compensation,
        // but it shares a bit with the rnn_u8s8_compensation constant, so the
        // rnn_s8s8_compensation case has to be excluded explicitly.
        return ((extra().flags
                        & (compensation_conv_s8s8 | rnn_u8s8_compensation
                                | compensation_conv_asymmetric_src))
                && !types::extra_flag_rnn_s8s8_compensation_is_set(
                        extra().flags));
    }

    /** returns the size required for a particular extra memory buffer */
    size_t additional_buffer_size(memory_extra_flags_t flag) const {
        using namespace memory_extra_flags;

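        // The compensation mask selects which (padded) dims contribute to the
        // extra buffer: bit d set in cmask means padded_dims()[d] enters the
        // product. An illustrative sketch (not an exhaustive list of masks):
        // cmask == 1 keeps dim 0 only and cmask == 3 keeps dims 0 and 1, so
        // for 2D weights with padded_dims() == {OC, IC} and cmask == 1 the
        // buffer holds OC elements of buff_data_size bytes each.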
        auto calculate_size = [=](int cmask, size_t buff_data_size) {
            assert(utils::one_of(cmask, 1, 2, 3, 5, 13, 27));
            dim_t prod = 1;
            for (int d = 0; d < ndims(); ++d)
                if (cmask & (1 << d)) { prod *= padded_dims()[d]; }
            return (size_t)prod * buff_data_size;
        };

        // Dispatch on the requested \param flag (and make sure the descriptor
        // actually carries it); otherwise the same compensation buffer would
        // be counted once per flag by the parameterless overload below.
        if (!(extra().flags & flag)) return 0;

        if (flag & compensation_conv_s8s8) {
            return calculate_size(extra().compensation_mask,
                    additional_buffer_data_size(flag));
        }

        if ((flag & rnn_u8s8_compensation)
                && !types::extra_flag_rnn_s8s8_compensation_is_set(
                        extra().flags)) {
            return calculate_size(extra().compensation_mask,
                    additional_buffer_data_size(flag));
        }
        if (flag & compensation_conv_asymmetric_src) {
            return calculate_size(extra().asymm_compensation_mask,
                    additional_buffer_data_size(flag));
        }

        return 0;
    }

    /** returns the size of the appended buffer when the memory descriptor
     * requires extra space to hold compensation data */
    size_t additional_buffer_size() const {
        using namespace memory_extra_flags;

        size_t buff_size = 0;
        buff_size += additional_buffer_size(compensation_conv_s8s8);
        buff_size += additional_buffer_size(rnn_u8s8_compensation);
        buff_size += additional_buffer_size(compensation_conv_asymmetric_src);
        return buff_size;
    }

    /** returns the size required to store the described memory
     * note: if offset0 != 0, returns 0 (the behavior still needs to be
     * specified) */
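    /* An illustrative sketch of the computation below (not tied to any
     * particular format): for a plain f32 2D tensor with dims == {3, 5},
     * strides == {5, 1}, and no inner blocks, max_size is
     * max(3 * 5, 5 * 1) == 15 elements, so size() == 15 * sizeof(float)
     * == 60 bytes; any compensation buffer would be appended after that,
     * 4-byte aligned. */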
    size_t size() const {
        if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)
                || is_zero() || has_zero_dim())
            return 0;

        if (has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL;

        if (format_kind() == format_kind::wino) {
            return wino_desc().size;
        } else if (format_kind() == format_kind::rnn_packed) {
            return rnn_packed_desc().size;
        } else {
            if (offset0() != 0) return 0;

            dims_t blocks = {0};
            compute_blocks(blocks);

            const auto &bd = blocking_desc();

            size_t max_size = 0;
            for (int d = 0; d < ndims(); ++d) {
                dim_t strided_pdim = padded_dims()[d] / blocks[d];
                dim_t effective_stride = strided_pdim == 1 ? 1 : bd.strides[d];
                max_size = nstl::max<size_t>(
                        max_size, strided_pdim * effective_stride);
            }

            if (max_size == 1 && bd.inner_nblks != 0) {
                max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
            }

            size_t data_size = max_size * data_type_size();
            if (is_additional_buffer()) {
                // The additional buffers (typically of data type int32_t or
                // float) are stored at the end of the data. Pad the data so
                // that the buffers are properly aligned to their data type.
                const size_t alignment_in_bytes = 4;
                data_size = utils::rnd_up(data_size, alignment_in_bytes);
            }
            return data_size + additional_buffer_size();
        }
    }

    /** returns true if some dim is broadcast (stride == 0) */
    bool has_broadcast() const {
        const auto &bd = blocking_desc();
        for (int d = 0; d < ndims(); d++)
            if (bd.strides[d] == 0) return true;
        return false;
    }

    /** returns true if the number of non-unit dims is <= \param n */
    bool count_non_unit_dims(int n) const {
        int non_unit_dims = 0;
        for (int d = 0; d < ndims(); d++) {
            if (dims()[d] != 1) non_unit_dims++;
        }
        return non_unit_dims <= n;
    }

    /** returns true if data is dense in memory */
    bool is_dense(bool with_padding = false) const {
        if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
            return false;
        if (has_runtime_dims_or_strides() || has_broadcast()) return false;
        return nelems(with_padding) * data_type_size() == size();
    }

    /** returns true if format is set to `any` */
    bool format_any() const { return format_kind() == format_kind::any; }

    /** returns true if at least one dim is not known */
    bool has_runtime_dims() const {
        for (int d = 0; d < ndims(); ++d)
            if (dims()[d] == DNNL_RUNTIME_DIM_VAL) return true;
        return false;
    }

    /** returns true if at least one stride is not known */
    bool has_runtime_strides() const {
        if (!is_blocking_desc()) return false;
        for (int d = 0; d < ndims(); ++d)
            if (blocking_desc().strides[d] == DNNL_RUNTIME_DIM_VAL) return true;
        return false;
    }

    /** returns true if the memory desc has runtime dims or strides */
    bool has_runtime_dims_or_strides() const {
        return has_runtime_dims() || has_runtime_strides();
    }

    /** returns true if the only (potentially) padded dim is \param dim */
    bool only_padded_dim(int dim) const {
        if (has_runtime_dims()) return false;
        for (int d = 0; d < ndims(); ++d)
            if (d != dim && dims()[d] != padded_dims()[d]) return false;
        return true;
    }

    /** returns true if memory desc has blocked layout and block dims are 1s */
    bool is_plain() const {
        if (!is_blocking_desc()) return false;
        return blocking_desc().inner_nblks == 0;
    }

    /** returns overall block sizes */
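    /* For example (illustrative): for a 4D descriptor in an nChw16c-like
     * blocked layout (inner_nblks == 1, inner_idxs == {1}, inner_blks == {16})
     * this fills blocks with {1, 16, 1, 1}; for a plain layout all entries
     * are 1, and for non-blocking descriptors all entries are set to 0. */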
    void compute_blocks(dims_t blocks) const {
        if (!is_blocking_desc()) {
            utils::array_set(blocks, 0, ndims());
            return;
        }

        utils::array_set(blocks, 1, ndims());

        const auto &bd = blocking_desc();
        for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
            blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
    }

    // XXX: for backward compatibility with v0.x
    // strides_compat[0]: stride between the first elements of adjacent blocks
    // strides_compat[1]: strides between elements in the same block
    //
    // For 2+ level blocking all inner blocks are treated as a single block.
    void compute_strides_compat(dims_t *strides_compat) const;

    /* comparison section */

    bool operator==(const memory_desc_wrapper &rhs) const {
        return *this->md_ == *rhs.md_;
    }
    bool operator!=(const memory_desc_wrapper &rhs) const {
        return !operator==(rhs);
    }
    bool operator==(const memory_desc_t &rhs) const {
        return operator==(memory_desc_wrapper(rhs));
    }
    bool operator!=(const memory_desc_t &rhs) const { return !operator==(rhs); }

    /** returns true if the data (w/o padding if with_padding == false and w/
     * padding otherwise) has the same physical structure, i.e. dimensions,
     * strides, and blocked structure. Depending on the with_data_type flag,
     * data_type is or is not taken into account. dim_start allows checking
     * similarity only for the logical part of the data [dim_start .. ndims()].
     * CAUTION: format kinds any and undef are not similar to anything, hence
     * the following statement might be true:
     * lhs == rhs && !lhs.similar_to(rhs) */
    /* TODO: revise */
    bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
            bool with_data_type = true, int dim_start = 0) const;

    /** returns true if one memory can be reordered to another */
    bool consistent_with(const memory_desc_wrapper &rhs) const;

    /** returns true if the memory desc corresponds to the given format tag and
     * strides.
     * @sa memory_desc_matches_tag */
    bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
        return memory_desc_matches_tag(*md_, tag, strides);
    }

    /** returns matching tag (or undef if match is not found)
     * XXX: This is a workaround that eventually should go away! */
    template <typename... Tags>
    format_tag_t matches_one_of_tag(Tags... tags) const {
        for (const auto tag : {tags...}) {
            if (memory_desc_matches_tag(*md_, tag)) return tag;
        }
        return format_tag::undef;
    }

    /* offset section */

    /** returns physical offset by logical one. logical offset is represented
     * by an array \param pos. if \param is_pos_padded is true, \param pos
     * represents the position in the already padded area */
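    /* A worked sketch (illustrative, assuming offset0 == 0 and zero
     * padded_offsets): for a 4D nChw16c-like layout with
     * strides == {C_pad * H * W, 16 * H * W, 16 * W, 16},
     * inner_idxs == {1}, inner_blks == {16}, the offset of {n, c, h, w} is
     *     (c % 16) + 16 * w + 16 * W * h + 16 * H * W * (c / 16)
     *         + C_pad * H * W * n,
     * i.e. the inner-block loop contributes c % 16 and the outer loop uses
     * c / 16 together with the stride of the channel dim. */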
    dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
        assert(is_blocking_desc());
        const blocking_desc_t &blk = blocking_desc();

        dims_t pos_copy = {0};
        for (int d = 0; d < ndims(); ++d)
            pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);

        dim_t phys_offset = offset0();

        if (blk.inner_nblks > 0) {
            dim_t blk_stride = 1;
            for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
                const int d = blk.inner_idxs[iblk];

                dim_t p;
                /* switch to faster 32-bit division when possible.
                 * inner blocks always fit 32-bit. */
                if (pos_copy[d] <= INT32_MAX) {
                    p = (int32_t)pos_copy[d] % (int32_t)blk.inner_blks[iblk];
                    pos_copy[d] = (int32_t)pos_copy[d]
                            / (int32_t)blk.inner_blks[iblk];
                } else {
                    p = pos_copy[d] % blk.inner_blks[iblk];
                    pos_copy[d] /= blk.inner_blks[iblk];
                }

                phys_offset += p * blk_stride;

                blk_stride *= blk.inner_blks[iblk];
            }
        }

        for (int d = 0; d < ndims(); ++d) {
            const dim_t p = pos_copy[d];
            phys_offset += p * blk.strides[d];
        }

        return phys_offset;
    }

    /** returns physical offset by logical one. logical offset is represented
     * by a scalar \param l_offset. if \param is_pos_padded is true, \param
     * l_offset represents the logical offset in the already padded area */
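    /* Illustrative sketch: the scalar offset is first unrolled into
     * coordinates in row-major order, e.g. for dims() == {2, 3} a logical
     * offset of 4 maps to pos == {1, 1}, which is then passed to off_v(). */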
    dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
        dims_t dims_pos;
        const auto &cur_dims = is_pos_padded ? padded_dims() : dims();
        utils::l_dims_by_l_offset(dims_pos, l_offset, cur_dims, ndims());
        return off_v(dims_pos, is_pos_padded);
    }

    /** returns physical offset by logical one. logical offset is represented
     * by a tuple of indices (\param xn, ..., \param x1, \param x0) */
    template <typename... Args>
    dim_t off(Args... args) const {
        assert(sizeof...(args) == ndims());
        dims_t pos = {args...};
        return off_v(pos, false);
    }

    /** returns physical offset by logical one. logical offset is represented
     * by a tuple of indices (\param xn, ..., \param x1, \param x0) in the
     * already padded area */
    template <typename... Args>
    dim_t off_padding(Args... args) const {
        assert(sizeof...(args) == ndims());
        dims_t pos = {args...};
        return off_v(pos, true);
    }

    /** returns physical offset by logical one. logical offset is represented
     * by a tuple of block indices (\param bn, ..., \param b1, \param b0). It
     * is the user's responsibility to adjust the result to get the offset
     * within blocks */
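    /* Illustrative sketch: for an nChw16c-like layout, blk_off(n, c / 16, h, w)
     * gives the offset of the first element of the 16-channel block that
     * contains channel c; adding c % 16 yields the offset of the element
     * itself (the same value off(n, c, h, w) would return, assuming zero
     * padded offsets). */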
    template <typename... Args>
    dim_t blk_off(Args... args) const {
        return _blk_off<sizeof...(args), Args...>(args...);
    }

    template <bool skip_first, typename T, typename... Args>
    dim_t blk_off(T xn, Args... args) const {
        return skip_first ? blk_off<Args...>(args...)
                          : blk_off<T, Args...>(xn, args...);
    }

    /* static functions section */
    /* TODO: replace with non-static, once md_ becomes non-const ref */

    static status_t compute_blocking(
            memory_desc_t &memory_desc, format_tag_t tag);

private:
    /* TODO: put logical_offset in utils */
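    /* Illustrative sketch: logical_offset linearizes indices in row-major
     * order, e.g. for a 3D descriptor with dims() == {A, B, C} it computes
     *     logical_offset(i, j, k) == (i * B + j) * C + k. */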
    template <typename T>
    dim_t logical_offset(T x0) const {
        return x0;
    }

    template <typename T, typename... Args>
    dim_t logical_offset(T xn, Args... args) const {
        const size_t n_args = sizeof...(args);
        return xn * utils::array_product<n_args>(&dims()[ndims() - n_args])
                + logical_offset(args...);
    }

    template <int ORIG_LEN, typename... Void>
    dim_t _blk_off() const {
        return offset0();
    }

    template <int ORIG_LEN, typename T, typename... Args>
    dim_t _blk_off(T xc, Args... args) const {
        assert(is_blocking_desc());
        constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
        return xc * blocking_desc().strides[dc]
                + _blk_off<ORIG_LEN, Args...>(args...);
    }
};

inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
        bool with_padding, bool with_data_type, int dim_start) const {
    using namespace utils;

    if (one_of(format_kind(), format_kind::undef, format_kind::any))
        return false;
    if (is_wino_desc() || is_rnn_packed_desc()) return false;

    const int ds = dim_start;
    const auto &blk = blocking_desc();
    const auto &r_blk = rhs.blocking_desc();

    return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
            && format_kind() == rhs.format_kind()
            && IMPLICATION(with_data_type, data_type() == rhs.data_type())
            && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
            && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
            && blk.inner_nblks == r_blk.inner_nblks
            && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
            && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
            && IMPLICATION(with_padding,
                    true
                            && array_cmp(padded_dims() + ds,
                                    rhs.padded_dims() + ds, ndims() - ds)
                            && array_cmp(padded_offsets() + ds,
                                    rhs.padded_offsets() + ds, ndims() - ds));
}

inline bool memory_desc_wrapper::consistent_with(
        const memory_desc_wrapper &rhs) const {
    if (ndims() == rhs.ndims()) {
        for (int d = 0; d < ndims(); ++d) {
            if (dims()[d] != rhs.dims()[d]) return false;
        }
        return true;
    } else {
        /* TODO: revise.
         * is the following possible?
         * [1, a, b] <--reorder--> [a, b]
         * [a, 1, b] <--reorder--> [a, b]
         * no, at least for now */
        return false;
    }
}

} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s