/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_TYPE_HELPERS_HPP
#define COMMON_TYPE_HELPERS_HPP

#include <algorithm>
#include <assert.h>
#include <cstring> // std::memcmp (used by COMPARE_FLOAT_DESC_ARRAY_MEMBERS)
#include <math.h>
#include <memory> // std::unique_ptr (used by safe_ptr_assign)

#include "oneapi/dnnl/dnnl.h"

#include "bit_cast.hpp"
#include "c_types_map.hpp"
#include "dnnl_traits.hpp"
#include "math_utils.hpp"
#include "memory_desc.hpp"
#include "nstl.hpp"
#include "utils.hpp"

namespace dnnl {
namespace impl {

// Global zero memory descriptor. Mostly used as a return value for queries
// when there is no meaningful descriptor to return.
extern memory_desc_t DNNL_API glob_zero_md;

template <typename base_type, typename derived_type>
status_t safe_ptr_assign(base_type *&lhs, derived_type *rhs) {
    if (rhs == nullptr) return status::out_of_memory;
    lhs = rhs;
    return status::success;
}

template <typename base_type, typename derived_type>
status_t safe_ptr_assign(std::unique_ptr<base_type> &lhs, derived_type *rhs) {
    if (rhs == nullptr) return status::out_of_memory;
    lhs.reset(rhs);
    return status::success;
}
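
// Usage sketch (illustrative; the allocated type and the CHECK-style error
// propagation are assumptions, not part of this header): pairing
// safe_ptr_assign with a non-throwing allocation turns an allocation failure
// into status::out_of_memory instead of dereferencing a null pointer, e.g.
//   some_impl_t *obj = nullptr;
//   CHECK(safe_ptr_assign(obj, new (std::nothrow) some_impl_t()));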

template <typename T, typename U>
struct is_subset {
    static constexpr bool value = false;
};
template <typename T>
struct is_subset<T, T> {
    static constexpr bool value = true;
};
template <typename T>
struct is_subset<T,
        typename utils::enable_if<nstl::is_integral<T>::value, float>::type> {
    static constexpr bool value = true;
};
#define ISSPEC(t1, t2) \
    template <> \
    struct is_subset<t1, t2> { \
        static constexpr bool value = true; \
    }
ISSPEC(int16_t, int32_t);
ISSPEC(int8_t, int32_t);
ISSPEC(uint8_t, int32_t);
ISSPEC(int8_t, int16_t);
ISSPEC(uint8_t, int16_t);
#undef ISSPEC
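
// Illustrative checks (assumed examples, not part of the original header):
//   is_subset<int8_t, int32_t>::value   -> true  (via ISSPEC above)
//   is_subset<uint8_t, int16_t>::value  -> true  (via ISSPEC above)
//   is_subset<int32_t, float>::value    -> true  (integral -> float partial
//                                                 specialization)
//   is_subset<int32_t, int8_t>::value   -> false (primary template)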

inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs);

namespace types {

inline size_t data_type_size(data_type_t data_type) {
    using namespace data_type;
    switch ((int)data_type) {
        case f16: return sizeof(prec_traits<f16>::type);
        case bf16: return sizeof(prec_traits<bf16>::type);
        case tf32: // the tf32 type is an f32
        case f32: return sizeof(prec_traits<f32>::type);
        case f64: return sizeof(prec_traits<f64>::type);
        case s32: return sizeof(prec_traits<s32>::type);
        case s8: return sizeof(prec_traits<s8>::type);
        case u8: return sizeof(prec_traits<u8>::type);
        case data_type::undef:
        default: assert(!"unknown data_type");
    }
    return (size_t)-1; /* not supposed to be reachable */
}

template <typename T>
inline T max_value(data_type_t data_type) {
    using namespace data_type;
#define CASE(x) \
    case x: \
        return static_cast<T>(nstl::numeric_limits<prec_traits<x>::type>::max())
    switch (data_type) {
        CASE(f16);
        CASE(bf16);
        CASE(s32);
        CASE(s8);
        CASE(u8);
        case data_type::undef:
        default: assert(!"unknown data_type");
    }
    return static_cast<T>(0); /* not supposed to be reachable */
#undef CASE
}

// This specialization is a hack needed to support the saturation behavior
// described in the comment below.
template <>
inline float max_value(data_type_t data_type) {
    using namespace data_type;
#define CASE(x) \
    case x: \
        return static_cast<float>( \
                nstl::numeric_limits<prec_traits<x>::type>::max())
    switch (data_type) {
        CASE(f16);
        CASE(bf16);
        CASE(s8);
        CASE(u8);
        // INT_MAX is not representable in float. The nearest float to it is
        // INT_MAX + 1 = 2^31 (0x4f000000). Regular conversion instructions
        // such as `cvtps2dq` or `cvtss2si` convert this number to INT_MIN,
        // making the result negative. We deliberately return the previous
        // representable float (0x4effffff) instead, keeping the output close
        // to INT_MAX but still positive. Validation is adjusted accordingly.
        // The main concern with `real` saturation is a likely (though
        // unproven) performance drop. The only drawback of the current
        // approach is that some integer values saturate earlier than they
        // would in reality.
        case s32: return 2147483520.f;
        case data_type::undef:
        default: assert(!"unknown data_type");
    }
    return 0.f; /* not supposed to be reachable */
#undef CASE
}
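
// Numeric illustration of the s32 case above (added for clarity, not from the
// original source): 2147483520.f is 0x4effffff, the largest float strictly
// below 2^31, while static_cast<float>(INT32_MAX) rounds up to 2147483648.f
// (0x4f000000). Clamping against 2147483520.f therefore stays positive after
// conversion, at the cost of saturating the top 127 representable s32 values
// early.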

inline format_kind_t format_tag_to_kind(format_tag_t tag) {
    switch (tag) {
        case format_tag::undef: return format_kind::undef;
        case format_tag::any: return format_kind::any;
        case format_tag::last: return format_kind::undef;
        default: return format_kind::blocked;
    }

    assert(!"unreachable");
    return format_kind::undef;
}

// Currently rnn_s8s8_compensation shares bits with the rnn_u8s8_compensation
// and scale_adjust constants, so additional checks are required to tell these
// cases apart.
inline bool extra_flag_rnn_s8s8_compensation_is_set(uint64_t flags) {
    return ((flags & memory_extra_flags::rnn_s8s8_compensation)
                   ^ memory_extra_flags::rnn_s8s8_compensation)
            == 0;
}

inline bool memory_extra_desc_is_equal(
        const memory_extra_desc_t &lhs, const memory_extra_desc_t &rhs) {
    using namespace memory_extra_flags;
    return true && lhs.flags == rhs.flags
            && IMPLICATION(lhs.flags & compensation_conv_s8s8,
                    lhs.compensation_mask == rhs.compensation_mask)
            && IMPLICATION((lhs.flags & rnn_u8s8_compensation)
                            && !extra_flag_rnn_s8s8_compensation_is_set(
                                    lhs.flags),
                    lhs.compensation_mask == rhs.compensation_mask)
            && IMPLICATION((lhs.flags & scale_adjust)
                            && !extra_flag_rnn_s8s8_compensation_is_set(
                                    lhs.flags),
                    lhs.scale_adjust == rhs.scale_adjust)
            && IMPLICATION(lhs.flags & compensation_conv_asymmetric_src,
                    lhs.asymm_compensation_mask == rhs.asymm_compensation_mask);
}

inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md,
        const memory_desc_t &rhs_md, bool ignore_strides = false) {
    using dnnl::impl::utils::array_cmp;

    assert(lhs_md.format_kind == format_kind::blocked);
    assert(rhs_md.format_kind == format_kind::blocked);

    const auto &lhs = lhs_md.format_desc.blocking;
    const auto &rhs = rhs_md.format_desc.blocking;
    bool equal = lhs.inner_nblks == rhs.inner_nblks
            && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks)
            && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
    if (ignore_strides) return equal;

    // Check the strides.
    // Note: for dimensions of size `1` the stride doesn't really matter.
    for (int d = 0; d < lhs_md.ndims; ++d) {
        if (lhs_md.dims[d] == 1 && lhs_md.padded_dims[d] == 1) continue;
        equal = equal && lhs.strides[d] == rhs.strides[d];
    }

    return equal;
}
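
// Example (illustrative): for dims {1, 16}, two blocked descriptors with
// strides {16, 1} and {1, 1} compare equal here, because the stride of the
// unit (and unpadded) dimension 0 is skipped by the loop above.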

inline bool wino_desc_is_equal(const wino_desc_t &lhs, const wino_desc_t &rhs) {
    return lhs.wino_format == rhs.wino_format && lhs.alpha == rhs.alpha
            && lhs.ic == rhs.ic && lhs.oc == rhs.oc
            && lhs.ic_block == rhs.ic_block && lhs.oc_block == rhs.oc_block
            && lhs.ic2_block == rhs.ic2_block && lhs.oc2_block == rhs.oc2_block
            && lhs.r == rhs.r;
}

inline bool rnn_packed_desc_is_equal(
        const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) {
    bool ok = true && lhs.format == rhs.format && lhs.ldb == rhs.ldb
            && lhs.n_parts == rhs.n_parts
            && lhs.offset_compensation == rhs.offset_compensation
            && lhs.size == rhs.size && lhs.n == rhs.n;
    if (!ok) return false;

    for (int i = 0; i < rhs.n_parts; i++)
        ok = ok && lhs.parts[i] == rhs.parts[i];
    for (int i = 0; i < rhs.n_parts; i++)
        ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
    return ok;
}

inline memory_desc_t zero_md() {
    auto zero = memory_desc_t();
    return zero;
}

inline bool is_zero_md(const memory_desc_t *md) {
    return md == nullptr || *md == zero_md();
}

inline data_type_t default_accum_data_type(
        data_type_t src_dt, data_type_t dst_dt, bool strict = true) {
    using namespace utils;
    using namespace data_type;

    // f32 accumulation is allowed only when the accumulation chain is short.
    // Otherwise, `strict` should be set to true.
    if (one_of(src_dt, s8, u8) && (dst_dt != f32 || strict)) return s32;

    if (one_of(f16, src_dt, dst_dt)) return f32;
    if (one_of(bf16, src_dt, dst_dt)) return f32;
    if (one_of(f32, src_dt, dst_dt)) return f32;
    if (one_of(f64, src_dt, dst_dt)) return f64;
    if (one_of(s32, src_dt, dst_dt)) return s32;

    if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;

    return data_type::undef;
}
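
// Examples (illustrative, following the rules above):
//   default_accum_data_type(u8, u8)          -> s32 (int8 sources accumulate in s32)
//   default_accum_data_type(s8, f32, false)  -> f32 (short chain, non-strict)
//   default_accum_data_type(bf16, bf16)      -> f32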

inline data_type_t default_accum_data_type(data_type_t src_dt,
        data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
    using namespace utils;
    using namespace data_type;
    using namespace prop_kind;

    /* prop_kind doesn't matter */
    if (everyone_is(f32, src_dt, wei_dt)) return f32;
    if (everyone_is(f64, src_dt, wei_dt)) return f64;

    if (one_of(prop_kind, forward_training, forward_inference)) {
        if ((src_dt == u8 || src_dt == s8) && wei_dt == s8) return s32;
        if (one_of(f16, src_dt, wei_dt)) return f32;
    } else if (prop_kind == backward_data) {
        if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8
                && one_of(dst_dt, s8, u8, s32))
            return s32;
        if (one_of(f16, dst_dt, wei_dt)) return f32;
        if (everyone_is(f32, dst_dt, wei_dt) && one_of(src_dt, s8, u8))
            return f32;
    }

    if (one_of(bf16, src_dt, wei_dt, dst_dt)) return f32;
    if (one_of(f16, src_dt, wei_dt, dst_dt)) return f32;

    return data_type::undef;
}

inline bool is_integral_dt(data_type_t dt) {
    using namespace data_type;
    return utils::one_of(dt, s32, s8, u8);
}

template <typename data_t>
inline void cvt_from_float(data_t *out, const float *inp, size_t nelems) {
    assert(!"unimplemented");
}

template <typename data_t>
inline void cvt_to_float(float *out, const data_t *inp, size_t nelems) {
    assert(!"unimplemented");
}

template <>
inline void cvt_from_float<bfloat16_t>(
        bfloat16_t *out, const float *inp, size_t nelems) {
    cvt_float_to_bfloat16(out, inp, nelems);
}

template <>
inline void cvt_to_float<bfloat16_t>(
        float *out, const bfloat16_t *inp, size_t nelems) {
    cvt_bfloat16_to_float(out, inp, nelems);
}

template <>
inline void cvt_from_float<float16_t>(
        float16_t *out, const float *inp, size_t nelems) {
    cvt_float_to_float16(out, inp, nelems);
}

template <>
inline void cvt_to_float<float16_t>(
        float *out, const float16_t *inp, size_t nelems) {
    cvt_float16_to_float(out, inp, nelems);
}

inline void cvt_from_float(
        data_type_t dt, void *out, const float *inp, size_t nelems) {
    switch (dt) {
        case data_type::bf16:
            cvt_from_float((bfloat16_t *)out, inp, nelems);
            break;
        case data_type::f16:
            cvt_from_float((float16_t *)out, inp, nelems);
            break;
        default: assert(!"unimplemented");
    }
}
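
// Usage sketch (illustrative; buffer names and sizes are hypothetical):
//   float src[8] = {/* ... */};
//   bfloat16_t dst[8];
//   cvt_from_float(dst, src, 8);                  // static dispatch on type
//   cvt_from_float(data_type::bf16, dst, src, 8); // runtime dispatch on dt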

} // namespace types

inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
    using namespace dnnl::impl::utils;
    // quick path for zero_mds
    if (utils::everyone_is(0, lhs.ndims, rhs.ndims)) return true;

    bool base_equal = true && lhs.ndims == rhs.ndims
            && array_cmp(lhs.dims, rhs.dims, lhs.ndims)
            && lhs.data_type == rhs.data_type
            && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims)
            && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims)
            && lhs.offset0 == rhs.offset0 && lhs.format_kind == rhs.format_kind;
    if (!base_equal) return false;
    if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false;
    if (lhs.format_kind == format_kind::blocked)
        return types::blocking_desc_is_equal(lhs, rhs);
    else if (lhs.format_kind == format_kind::wino)
        return types::wino_desc_is_equal(
                lhs.format_desc.wino_desc, rhs.format_desc.wino_desc);
    else if (lhs.format_kind == format_kind::rnn_packed)
        return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc,
                rhs.format_desc.rnn_packed_desc);
    return true;
}

inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) {
    return !operator==(lhs, rhs);
}

// Comparison operators for descriptors
#define COMPARE_DESC_MEMBERS(m) lhs.m == rhs.m
#define COMPARE_DESC_ARRAY_MEMBERS(m, s) utils::array_cmp(lhs.m, rhs.m, s)
#define DEREF_AND_COMPARE_DESC_MEMBERS(m) *lhs.m == *rhs.m
#define COMPARE_FLOAT_DESC_MEMBERS(m) utils::equal_with_nan(lhs.m, rhs.m)
#define COMPARE_FLOAT_DESC_ARRAY_MEMBERS(m, s) \
    !std::memcmp(lhs.m, rhs.m, sizeof(float) * s)

// clang-format off
inline bool operator==(const batch_normalization_desc_t &lhs,
        const batch_normalization_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(scaleshift_desc)
            && COMPARE_DESC_MEMBERS(diff_scaleshift_desc)
            && COMPARE_DESC_MEMBERS(stat_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(batch_norm_epsilon)
            && COMPARE_DESC_MEMBERS(flags);
    return ret;
}

inline bool operator==(const binary_desc_t &lhs, const binary_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc[0])
            && COMPARE_DESC_MEMBERS(src_desc[1])
            && COMPARE_DESC_MEMBERS(dst_desc);
    return ret;
}

inline bool operator==(const concat_desc_t &lhs, const concat_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && DEREF_AND_COMPARE_DESC_MEMBERS(dst_md)
            && COMPARE_DESC_MEMBERS(n)
            && COMPARE_DESC_MEMBERS(concat_dimension);

    if (!ret) return ret;

    for (int i = 0; i < lhs.n; i++) {
        ret = *lhs.src_mds[i] == *rhs.src_mds[i];
        if (!ret) break;
    }
    return ret;
}

inline bool operator==(
        const convolution_desc_t &lhs, const convolution_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(diff_bias_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_ARRAY_MEMBERS(strides, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(dilates, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[0], DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[1], DNNL_MAX_NDIMS)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

inline bool operator==(const eltwise_desc_t &lhs, const eltwise_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(alpha)
            && COMPARE_FLOAT_DESC_MEMBERS(beta);
    return ret;
}

inline bool operator==(const gemm_desc_t &lhs, const gemm_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(a_desc)
            && COMPARE_DESC_MEMBERS(b_desc)
            && COMPARE_DESC_MEMBERS(c_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(acc_type)
            && COMPARE_DESC_MEMBERS(sum_ab)
            && COMPARE_DESC_MEMBERS(sum_ab_type);
    return ret;
}

inline bool operator==(
        const inner_product_desc_t &lhs, const inner_product_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(diff_bias_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

inline bool operator==(const layer_normalization_desc_t &lhs,
        const layer_normalization_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(data_scaleshift_desc)
            && COMPARE_DESC_MEMBERS(diff_data_scaleshift_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(stat_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(layer_norm_epsilon)
            && COMPARE_DESC_MEMBERS(flags);
    return ret;
}

inline bool operator==(const lrn_desc_t &lhs, const lrn_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(local_size)
            && COMPARE_FLOAT_DESC_MEMBERS(lrn_alpha)
            && COMPARE_FLOAT_DESC_MEMBERS(lrn_beta)
            && COMPARE_FLOAT_DESC_MEMBERS(lrn_k);
    return ret;
}

inline bool operator==(const matmul_desc_t &lhs, const matmul_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

inline bool operator==(const pooling_desc_t &lhs, const pooling_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_ARRAY_MEMBERS(strides, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(kernel, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[0], DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[1], DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(dilation, DNNL_MAX_NDIMS)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

inline bool operator==(const prelu_desc_t &lhs, const prelu_desc_t &rhs) {
    const bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc);
    return ret;
}

inline bool operator==(
        const reduction_desc_t &lhs, const reduction_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(p)
            && COMPARE_FLOAT_DESC_MEMBERS(eps);
    return ret;
}

inline bool operator==(const reorder_desc_t &lhs, const reorder_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && DEREF_AND_COMPARE_DESC_MEMBERS(src_md)
            && DEREF_AND_COMPARE_DESC_MEMBERS(dst_md)
            && COMPARE_DESC_MEMBERS(src_engine_kind)
            && COMPARE_DESC_MEMBERS(dst_engine_kind)
            && COMPARE_DESC_MEMBERS(is_cross_engine);
    return ret;
}

inline bool operator==(
        const resampling_desc_t &lhs, const resampling_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_FLOAT_DESC_ARRAY_MEMBERS(factors, DNNL_MAX_NDIMS);
    return ret;
}

inline bool operator==(const rnn_desc_t &lhs, const rnn_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(cell_kind)
            && COMPARE_DESC_MEMBERS(direction)
            && COMPARE_DESC_MEMBERS(src_layer_desc)
            && COMPARE_DESC_MEMBERS(src_iter_desc)
            && COMPARE_DESC_MEMBERS(src_iter_c_desc)
            && COMPARE_DESC_MEMBERS(weights_layer_desc)
            && COMPARE_DESC_MEMBERS(weights_iter_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(dst_layer_desc)
            && COMPARE_DESC_MEMBERS(dst_iter_desc)
            && COMPARE_DESC_MEMBERS(dst_iter_c_desc)
            && COMPARE_DESC_MEMBERS(weights_peephole_desc)
            && COMPARE_DESC_MEMBERS(weights_projection_desc)
            && COMPARE_DESC_MEMBERS(diff_src_layer_desc)
            && COMPARE_DESC_MEMBERS(diff_src_iter_desc)
            && COMPARE_DESC_MEMBERS(diff_src_iter_c_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_layer_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_iter_desc)
            && COMPARE_DESC_MEMBERS(diff_bias_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_layer_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_iter_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_iter_c_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_peephole_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_projection_desc)
            && COMPARE_DESC_MEMBERS(flags)
            && COMPARE_DESC_MEMBERS(activation_kind)
            && COMPARE_FLOAT_DESC_MEMBERS(alpha)
            && COMPARE_FLOAT_DESC_MEMBERS(beta);
    return ret;
}

inline bool operator==(const shuffle_desc_t &lhs, const shuffle_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(axis)
            && COMPARE_DESC_MEMBERS(group_size);
    return ret;
}

inline bool operator==(const softmax_desc_t &lhs, const softmax_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(softmax_axis);
    return ret;
}

inline bool operator==(const sum_desc_t &lhs, const sum_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && DEREF_AND_COMPARE_DESC_MEMBERS(dst_md)
            && COMPARE_DESC_MEMBERS(n);
    if (!ret) return ret;

    for (int i = 0; i < lhs.n; i++) {
        ret = *lhs.src_mds[i] == *rhs.src_mds[i];
        if (!ret) break;
    }
    if (!ret) return ret;

    for (int i = 0; i < lhs.n; i++) {
        ret = ret && COMPARE_FLOAT_DESC_MEMBERS(scales[i]);
        if (!ret) break;
    }

    return ret;
}

inline bool operator==(const zero_pad_desc_t &lhs, const zero_pad_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind);
    return ret;
}
// clang-format on

#undef COMPARE_DESC_MEMBERS
#undef COMPARE_DESC_ARRAY_MEMBERS
#undef DEREF_AND_COMPARE_DESC_MEMBERS
#undef COMPARE_FLOAT_DESC_MEMBERS
#undef COMPARE_FLOAT_DESC_ARRAY_MEMBERS

/** returns true if strides are compatible with memory_desc_t */
inline bool memory_desc_strides_check(
        const memory_desc_t &md, const dims_t strides) {
    if (strides == nullptr || md.ndims == 0
            || md.format_kind != format_kind::blocked)
        return true;

    dims_t blocks = {0};
    int perm[DNNL_MAX_NDIMS] = {0};
    for (int d = 0; d < md.ndims; ++d) {
        // no strides check needed for empty tensor
        if (md.padded_dims[d] == 0) return true;

        // no strides verification for runtime dims
        const bool has_runtime_dim = utils::one_of(
                DNNL_RUNTIME_DIM_VAL, strides[d], md.padded_dims[d]);
        if (has_runtime_dim) return true;

        perm[d] = d;
        blocks[d] = 1;
    }

    dim_t block_size = 1;
    const auto &blk = md.format_desc.blocking;
    for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
        blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
        block_size *= blk.inner_blks[iblk];
    }

    // A custom comparator to yield linear order on perm
    auto idx_sorter = [&](const int a, const int b) -> bool {
        if (strides[a] == strides[b] && md.padded_dims[a] == md.padded_dims[b])
            return a < b;
        else if (strides[a] == strides[b])
            return md.padded_dims[a] < md.padded_dims[b];
        else
            return strides[a] < strides[b];
    };
    std::sort(perm, perm + md.ndims, idx_sorter);

    dim_t min_stride = block_size;
    for (int idx = 0; idx < md.ndims; ++idx) {
        const int d = perm[idx];

        // Make an exception for strides[d] == 0 as it has broadcast semantics.
        // Note: owing to being sorted, these are the initial strides.

        // FIXME: make an exception for dims[d] == 1 with the assumption that
        // no code applies that stride when the only index accessed for that
        // dimension is 0. This is because PT can use "dummy" padding in those
        // situations.
        if ((strides[d] == 0) || (md.padded_dims[d] == 1))
            continue;
        else if (strides[d] < min_stride)
            return false;

        // update min_stride for next iteration
        const auto padded_dim = md.padded_dims[d];
        min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
    }
    return true;
}
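
// Example (illustrative): for a blocked 2D md with dims {2, 3}, user strides
// {3, 1} (dense) or {4, 1} (row padding) pass the check, while {1, 1} is
// rejected because two non-unit dimensions would overlap in memory.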

inline status_t memory_desc_init_by_strides(
        memory_desc_t &md, const dims_t strides) {
    return memory_desc_init_by_strides(
            md, md.ndims, md.dims, md.data_type, strides);
}

inline status_t memory_desc_init_by_tag(
        memory_desc_t &md, format_tag_t tag, const dims_t strides = nullptr) {
    status_t status
            = memory_desc_init_by_tag(md, md.ndims, md.dims, md.data_type, tag);
    if (status != status::success || strides == nullptr) return status;

    if (!memory_desc_strides_check(md, strides))
        return status::invalid_arguments;

    for (int d = 0; d < md.ndims; ++d)
        md.format_desc.blocking.strides[d] = strides[d];

    return status::success;
}

/** inits memory descriptor based on the logical dimensions kept in @p md, and
 * the blocking structure @p blk.
 *
 * @note blk.strides are used only to establish the relative order of the
 *       dimensions (from the smallest stride to the largest)
 *
 * TODO: move md related functions to one single place
 */
inline status_t memory_desc_init_by_blocking_desc(
        memory_desc_t &md, const blocking_desc_t &blk) {
    dims_t blocks = {0};
    utils::array_set(blocks, 1, md.ndims);
    dim_t block_size = 1;
    for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
        blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
        block_size *= blk.inner_blks[iblk];
    }

    for (int d = 0; d < md.ndims; ++d) {
        md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
        md.padded_offsets[d] = 0;
    }
    md.offset0 = 0;

    md.format_kind = format_kind::blocked;
    auto &mblk = md.format_desc.blocking;
    mblk = blk;

    const int ndims = nstl::min(DNNL_MAX_NDIMS, md.ndims); // make GCC 5 happy
    utils::array_copy(mblk.strides, blk.strides, ndims);

    dims_t ou_blocks = {0};
    utils::array_copy(ou_blocks, md.padded_dims, ndims);

    int perm[DNNL_MAX_NDIMS];
    for (int d = 0; d < ndims; ++d) {
        perm[d] = d;
        ou_blocks[d] /= blocks[d];
    }

    utils::simultaneous_sort(
            mblk.strides, ou_blocks, perm, ndims, [](stride_t a, stride_t b) {
                if (utils::one_of(DNNL_RUNTIME_DIM_VAL, a, b))
                    return DNNL_RUNTIME_DIM_VAL;
                return b - a;
            });

    dim_t stride = block_size;
    for (int _d = ndims - 1; _d >= 0; --_d) {
        const int d = perm[_d];
        md.format_desc.blocking.strides[d] = stride;
        if (md.padded_dims[d] != 0) { // Keep same stride for zero dim
            stride *= md.padded_dims[d] / blocks[d];
        }
    }

    md.extra = utils::zero<memory_extra_desc_t>();

    return status::success;
}
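
// Worked example (illustrative): for md.dims = {2, 3, 4} and a plain blk
// (inner_nblks = 0) with blk.strides = {1, 100, 10}, only the relative order
// of the stride values matters, so the routine assigns dense strides
// {1, 8, 2}: dimension 0 becomes the innermost, dimension 2 next, and
// dimension 1 the outermost.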

/** inits memory descriptor @p md based on another memory descriptor
 * @p md_base and the given @p data_type.
 * Essentially: { md = md_base; md.dt = data_type; } */
inline status_t memory_desc_init_by_md_and_dt(memory_desc_t &md,
        const memory_desc_t &md_base, data_type_t data_type) {
    if (&md != &md_base) md = md_base;
    md.data_type = data_type;
    return status::success;
}

/** returns true if memory desc @p md corresponds to the given format tag and
 * strides.
 * If strides are not passed (or passed as nullptr) the dense structure is
 * assumed (i.e. the one that memory_desc_init_by_tag() returns).
 * Strides might contain a `0` value, indicating that the stride must match
 * the one that memory_desc_init_by_tag() returns.
 * Strides might contain `-1` values, which are ignored during the comparison.
 * For instance, this can be used if the stride along the minibatch dimension
 * doesn't matter. */
inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag,
        const dims_t strides = nullptr) {
    if (md.format_kind != types::format_tag_to_kind(tag)) return false;

    memory_desc_t md_gold;
    status_t status = memory_desc_init_by_tag(
            md_gold, md.ndims, md.dims, md.data_type, tag);
    if (status != status::success) return false;

    if (md.format_kind != format_kind::blocked)
        return false; // unimplemented yet

    const auto &blk = md.format_desc.blocking;
    const auto &blk_gold = md_gold.format_desc.blocking;

    using utils::array_cmp;
    bool same_blocks = true && blk.inner_nblks == blk_gold.inner_nblks
            && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks)
            && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks);

    if (!same_blocks) return false;

    if (strides == nullptr)
        return array_cmp(blk.strides, blk_gold.strides, md.ndims);

    for (int d = 0; d < md.ndims; ++d) {
        dim_t stride = strides[d];
        if (stride == -1) continue;
        if (stride == 0) stride = blk_gold.strides[d];
        if (blk.strides[d] != stride) return false;
    }

    return true;
}
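
// Usage sketch (illustrative): check for a plain `abcd` layout while ignoring
// the minibatch stride and requiring the default strides elsewhere:
//   dims_t strides = {-1, 0, 0, 0};
//   bool is_plain = memory_desc_matches_tag(md, format_tag::abcd, strides);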

/** returns the matching tag (or undef if no match is found)
 * XXX: This is a workaround that eventually should go away! */
template <typename... Tags>
format_tag_t memory_desc_matches_one_of_tag(
        const memory_desc_t &md, Tags... tags) {
    for (const auto tag : {tags...}) {
        if (memory_desc_matches_tag(md, tag)) return tag;
    }
    return format_tag::undef;
}
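
// Usage sketch (illustrative):
//   const auto tag = memory_desc_matches_one_of_tag(
//           md, format_tag::nchw, format_tag::nhwc, format_tag::nChw16c);
//   if (tag == format_tag::undef) return status::unimplemented;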

/** returns true if fp32 value denotes DNNL_RUNTIME_F32_VAL */
inline bool is_runtime_value(float val) {
    return utils::bit_cast<unsigned>(val) == DNNL_RUNTIME_F32_VAL_REP.u;
}

/** returns true if s32 value denotes DNNL_RUNTIME_S32_VAL */
inline bool is_runtime_value(int val) {
    return val == DNNL_RUNTIME_S32_VAL;
}

/** returns true if dim_t value denotes DNNL_RUNTIME_DIM_VAL */
inline bool is_runtime_value(dim_t val) {
    return val == DNNL_RUNTIME_DIM_VAL;
}

inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
        data_type_t data_type, format_kind_t format_kind) {
    using namespace data_type;

    if (ndims == 0) return true;

    bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
            && utils::one_of(data_type, f16, bf16, f32, f64, s32, s8, u8);
    if (!ok) return false;

    bool has_runtime_dims = false;
    for (int d = 0; d < ndims; ++d) {
        if (dims[d] != DNNL_RUNTIME_DIM_VAL && dims[d] < 0) return false;
        if (dims[d] == DNNL_RUNTIME_DIM_VAL) has_runtime_dims = true;
    }

    if (has_runtime_dims) {
        // format `any` is currently not supported for run-time dims
        if (format_kind == format_kind::any) return false;
    }

    return true;
}

inline bool memory_desc_sanity_check(const memory_desc_t &md) {
    return memory_desc_sanity_check(
            md.ndims, md.dims, md.data_type, format_kind::undef);
}

inline void copy_c_op_desc(op_desc_t *dst, const op_desc_t *src) {
#define CASE_OP_DESC(pkind) \
    case primitive_kind::pkind: dst->pkind = src->pkind; break;

    switch ((int)src->kind) {
        CASE_OP_DESC(batch_normalization);
        CASE_OP_DESC(binary);
        CASE_OP_DESC(convolution);
        CASE_OP_DESC(deconvolution);
        CASE_OP_DESC(eltwise);
        CASE_OP_DESC(gemm);
        CASE_OP_DESC(inner_product);
        CASE_OP_DESC(layer_normalization);
        CASE_OP_DESC(lrn);
        CASE_OP_DESC(matmul);
        CASE_OP_DESC(pooling);
        CASE_OP_DESC(prelu);
        CASE_OP_DESC(reduction);
        CASE_OP_DESC(resampling);
        CASE_OP_DESC(rnn);
        CASE_OP_DESC(shuffle);
        CASE_OP_DESC(softmax);

        // Internal descs
        CASE_OP_DESC(zero_pad);
        default: assert(!"unknown C primitive kind");
    }
#undef CASE_OP_DESC
}

} // namespace impl
} // namespace dnnl

#include "memory_desc_wrapper.hpp"

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s