1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_TYPE_HELPERS_HPP |
18 | #define COMMON_TYPE_HELPERS_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <assert.h> |
22 | #include <math.h> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | #include "bit_cast.hpp" |
27 | #include "c_types_map.hpp" |
28 | #include "dnnl_traits.hpp" |
29 | #include "math_utils.hpp" |
30 | #include "memory_desc.hpp" |
31 | #include "nstl.hpp" |
32 | #include "utils.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | |
37 | // Global zero memory descriptor. Mostly used for queries to return |
38 | extern memory_desc_t DNNL_API glob_zero_md; |
39 | |
40 | template <typename base_type, typename derived_type> |
41 | status_t safe_ptr_assign(base_type *&lhs, derived_type *rhs) { |
42 | if (rhs == nullptr) return status::out_of_memory; |
43 | lhs = rhs; |
44 | return status::success; |
45 | } |
46 | |
47 | template <typename base_type, typename derived_type> |
48 | status_t safe_ptr_assign(std::unique_ptr<base_type> &lhs, derived_type *rhs) { |
49 | if (rhs == nullptr) return status::out_of_memory; |
50 | lhs.reset(rhs); |
51 | return status::success; |
52 | } |
53 | |
// is_subset<T, U>::value is true when every value of type T is exactly
// representable in type U (i.e. the T -> U conversion is lossless).
template <typename T, typename U>
struct is_subset {
    static constexpr bool value = false;
};
// A type is trivially a subset of itself.
template <typename T>
struct is_subset<T, T> {
    static constexpr bool value = true;
};
// Any integral type is considered a subset of float.
// NOTE(review): int32 -> float is lossy above 2^24; presumably only the
// narrower integral types matter for this trait -- confirm.
template <typename T>
struct is_subset<T,
        typename utils::enable_if<nstl::is_integral<T>::value, float>::type> {
    static constexpr bool value = true;
};
// Pointwise specializations for narrower -> wider integer pairs.
#define ISSPEC(t1, t2) \
    template <> \
    struct is_subset<t1, t2> { \
        static constexpr bool value = true; \
    }
ISSPEC(int16_t, int32_t);
ISSPEC(int8_t, int32_t);
ISSPEC(uint8_t, int32_t);
ISSPEC(int8_t, int16_t);
ISSPEC(uint8_t, int16_t);
#undef ISSPEC
78 | |
79 | inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); |
80 | |
81 | namespace types { |
82 | |
// Returns the size in bytes of one element of `data_type`.
// Asserts (and returns (size_t)-1 in release builds) for unknown types.
inline size_t data_type_size(data_type_t data_type) {
    using namespace data_type;
    // NOTE(review): the int cast presumably allows values that are not part
    // of the public enum (e.g. tf32) without -Wswitch warnings -- confirm.
    switch ((int)data_type) {
        case f16: return sizeof(prec_traits<f16>::type);
        case bf16: return sizeof(prec_traits<bf16>::type);
        case tf32: // the tf32 type is an f32
        case f32: return sizeof(prec_traits<f32>::type);
        case f64: return sizeof(prec_traits<f64>::type);
        case s32: return sizeof(prec_traits<s32>::type);
        case s8: return sizeof(prec_traits<s8>::type);
        case u8: return sizeof(prec_traits<u8>::type);
        case data_type::undef:
        default: assert(!"unknown data_type" );
    }
    return (size_t)-1; /* not supposed to be reachable */
}
99 | |
// Returns the maximum representable value of `data_type`, cast to T.
// f32/f64 are intentionally absent; see the float specialization below for
// why s32 requires special handling there.
template <typename T>
inline T max_value(data_type_t data_type) {
    using namespace data_type;
#define CASE(x) \
    case x: \
        return static_cast<T>(nstl::numeric_limits<prec_traits<x>::type>::max())
    switch (data_type) {
        CASE(f16);
        CASE(bf16);
        CASE(s32);
        CASE(s8);
        CASE(u8);
        case data_type::undef:
        default: assert(!"unknown data_type" );
    }
    return static_cast<T>(0); /* not supposed to be reachable */
#undef CASE
}
118 | |
119 | // This is a hack to comply with a big comment below. |
// This is a hack to comply with a big comment below.
// Float specialization of max_value: identical to the generic version except
// for s32, which is clamped to the largest float strictly below 2^31 so that
// float->int conversion does not wrap to INT_MIN (see comment below).
template <>
inline float max_value(data_type_t data_type) {
    using namespace data_type;
#define CASE(x) \
    case x: \
        return static_cast<float>( \
                nstl::numeric_limits<prec_traits<x>::type>::max())
    switch (data_type) {
        CASE(f16);
        CASE(bf16);
        CASE(s8);
        CASE(u8);
        // INT_MAX is not representable in float. The nearest float to it is
        // INT_MAX + 1 = 2^31 (0x4f000000). Regular conversion instructions such
        // as `cvtps2dq` or `cvtss2si` will convert this number to INT_MIN
        // making the result negative. We on purpose choose the previous float
        // number (0x4effffff) to return leaving the output close to INT_MAX but
        // still positive. In addition, we adjust validation of this approach.
        // The main concern against `real` saturation is performance, which
        // likely to drop (but it was not proved). The only drawback of current
        // approach is saturating on some integer values before it should happen
        // in the reality.
        case s32: return 2147483520.f;
        case data_type::undef:
        default: assert(!"unknown data_type" );
    }
    return 0.f; /* not supposed to be reachable */
#undef CASE
}
149 | |
150 | inline format_kind_t format_tag_to_kind(format_tag_t tag) { |
151 | switch (tag) { |
152 | case format_tag::undef: return format_kind::undef; |
153 | case format_tag::any: return format_kind::any; |
154 | case format_tag::last: return format_kind::undef; |
155 | default: return format_kind::blocked; |
156 | } |
157 | |
158 | assert(!"unreachable" ); |
159 | return format_kind::undef; |
160 | } |
161 | |
162 | // Currently rnn_s8s8_compensation has common bits with rnn_u8s8_compensation |
163 | // and scale_adjust constants so we have to perform additional checks to |
164 | // separate these two cases |
165 | inline bool (uint64_t flags) { |
166 | return ((flags & memory_extra_flags::rnn_s8s8_compensation) |
167 | ^ memory_extra_flags::rnn_s8s8_compensation) |
168 | == 0; |
169 | } |
170 | |
171 | inline bool ( |
172 | const memory_extra_desc_t &lhs, const memory_extra_desc_t &rhs) { |
173 | using namespace memory_extra_flags; |
174 | return true && lhs.flags == rhs.flags |
175 | && IMPLICATION(lhs.flags & compensation_conv_s8s8, |
176 | lhs.compensation_mask == rhs.compensation_mask) |
177 | && IMPLICATION((lhs.flags & rnn_u8s8_compensation) |
178 | && !extra_flag_rnn_s8s8_compensation_is_set( |
179 | lhs.flags), |
180 | lhs.compensation_mask == rhs.compensation_mask) |
181 | && IMPLICATION((lhs.flags & scale_adjust) |
182 | && !extra_flag_rnn_s8s8_compensation_is_set( |
183 | lhs.flags), |
184 | lhs.scale_adjust == rhs.scale_adjust) |
185 | && IMPLICATION(lhs.flags & compensation_conv_asymmetric_src, |
186 | lhs.asymm_compensation_mask == rhs.asymm_compensation_mask); |
187 | } |
188 | |
// Compares the blocking descriptors of two blocked memory descriptors.
// Inner-block structure must always match; outer strides are compared unless
// `ignore_strides` is set, and strides of size-1 (unpadded) dimensions are
// always ignored since they carry no layout information.
inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md,
        const memory_desc_t &rhs_md, bool ignore_strides = false) {
    using dnnl::impl::utils::array_cmp;

    assert(lhs_md.format_kind == format_kind::blocked);
    assert(rhs_md.format_kind == format_kind::blocked);

    const auto &lhs = lhs_md.format_desc.blocking;
    const auto &rhs = rhs_md.format_desc.blocking;
    bool equal = lhs.inner_nblks == rhs.inner_nblks
            && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks)
            && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
    if (ignore_strides) return equal;

    // Check the strides.
    // Note: for dimensions of size `1` the stride doesn't really matter.
    for (int d = 0; d < lhs_md.ndims; ++d) {
        if (lhs_md.dims[d] == 1 && lhs_md.padded_dims[d] == 1) continue;
        equal = equal && lhs.strides[d] == rhs.strides[d];
    }

    return equal;
}
212 | |
213 | inline bool wino_desc_is_equal(const wino_desc_t &lhs, const wino_desc_t &rhs) { |
214 | return lhs.wino_format == rhs.wino_format && lhs.alpha == rhs.alpha |
215 | && lhs.ic == rhs.ic && lhs.oc == rhs.oc |
216 | && lhs.ic_block == rhs.ic_block && lhs.oc_block == rhs.oc_block |
217 | && lhs.ic2_block == rhs.ic2_block && lhs.oc2_block == rhs.oc2_block |
218 | && lhs.r == rhs.r; |
219 | } |
220 | |
221 | inline bool rnn_packed_desc_is_equal( |
222 | const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { |
223 | bool ok = true && lhs.format == rhs.format && lhs.ldb == rhs.ldb |
224 | && lhs.n_parts == rhs.n_parts |
225 | && lhs.offset_compensation == rhs.offset_compensation |
226 | && lhs.size == rhs.size && lhs.n == rhs.n; |
227 | if (!ok) return false; |
228 | |
229 | for (int i = 0; i < rhs.n_parts; i++) |
230 | ok = ok && lhs.parts[i] == rhs.parts[i]; |
231 | for (int i = 0; i < rhs.n_parts; i++) |
232 | ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; |
233 | return ok; |
234 | } |
235 | |
236 | inline memory_desc_t zero_md() { |
237 | auto zero = memory_desc_t(); |
238 | return zero; |
239 | } |
240 | |
241 | inline bool is_zero_md(const memory_desc_t *md) { |
242 | return md == nullptr || *md == zero_md(); |
243 | } |
244 | |
// Deduces the default accumulation data type for a src -> dst pair.
// The order of the checks below is significant: the int8 + strict rule must
// run before the generic per-type rules.
inline data_type_t default_accum_data_type(
        data_type_t src_dt, data_type_t dst_dt, bool strict = true) {
    using namespace utils;
    using namespace data_type;

    // we allow to use f32 accumulation type only when the
    // accumulation chain is small. Otherwise, strict should be set to
    // true
    if (one_of(src_dt, s8, u8) && (dst_dt != f32 || strict)) return s32;

    // Low-precision floats accumulate in f32; f64/s32 accumulate in kind.
    if (one_of(f16, src_dt, dst_dt)) return f32;
    if (one_of(bf16, src_dt, dst_dt)) return f32;
    if (one_of(f32, src_dt, dst_dt)) return f32;
    if (one_of(f64, src_dt, dst_dt)) return f64;
    if (one_of(s32, src_dt, dst_dt)) return s32;

    if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;

    // Unknown combination.
    return data_type::undef;
}
265 | |
// Deduces the default accumulation data type for (src, weights, dst) given
// the propagation kind. Forward and backward-data passes have dedicated
// int8/f16 rules; bf16/f16 otherwise fall back to f32 accumulation.
inline data_type_t default_accum_data_type(data_type_t src_dt,
        data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
    using namespace utils;
    using namespace data_type;
    using namespace prop_kind;

    /* prop_kind doesn't matter */
    if (everyone_is(f32, src_dt, wei_dt)) return f32;
    if (everyone_is(f64, src_dt, wei_dt)) return f64;

    if (one_of(prop_kind, forward_training, forward_inference)) {
        // int8 inference/training accumulates in s32.
        if ((src_dt == u8 || src_dt == s8) && wei_dt == s8) return s32;
        if (one_of(f16, src_dt, wei_dt)) return f32;
    } else if (prop_kind == backward_data) {
        // For backward-data, dst plays the role of the forward src.
        if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8
                && one_of(dst_dt, s8, u8, s32))
            return s32;
        if (one_of(f16, dst_dt, wei_dt)) return f32;
        if (everyone_is(f32, dst_dt, wei_dt) && one_of(src_dt, s8, u8))
            return f32;
    }

    if (one_of(bf16, src_dt, wei_dt, dst_dt)) return f32;
    if (one_of(f16, src_dt, wei_dt, dst_dt)) return f32;

    // Unknown combination.
    return data_type::undef;
}
293 | |
294 | inline bool is_integral_dt(data_type_t dt) { |
295 | using namespace data_type; |
296 | return utils::one_of(dt, s32, s8, u8); |
297 | } |
298 | |
// Generic float -> data_t bulk conversion; only the specializations below
// are implemented, any other instantiation asserts.
template <typename data_t>
inline void cvt_from_float(data_t *out, const float *inp, size_t nelems) {
    assert(!"unimplemented" );
}

// Generic data_t -> float bulk conversion; only the specializations below
// are implemented, any other instantiation asserts.
template <typename data_t>
inline void cvt_to_float(float *out, const data_t *inp, size_t nelems) {
    assert(!"unimplemented" );
}

// f32 -> bf16 conversion of `nelems` elements.
template <>
inline void cvt_from_float<bfloat16_t>(
        bfloat16_t *out, const float *inp, size_t nelems) {
    cvt_float_to_bfloat16(out, inp, nelems);
}

// bf16 -> f32 conversion of `nelems` elements.
template <>
inline void cvt_to_float<bfloat16_t>(
        float *out, const bfloat16_t *inp, size_t nelems) {
    cvt_bfloat16_to_float(out, inp, nelems);
}

// f32 -> f16 conversion of `nelems` elements.
template <>
inline void cvt_from_float<float16_t>(
        float16_t *out, const float *inp, size_t nelems) {
    cvt_float_to_float16(out, inp, nelems);
}

// f16 -> f32 conversion of `nelems` elements.
template <>
inline void cvt_to_float<float16_t>(
        float *out, const float16_t *inp, size_t nelems) {
    cvt_float16_to_float(out, inp, nelems);
}
332 | |
333 | inline void cvt_from_float( |
334 | data_type_t dt, void *out, const float *inp, size_t nelems) { |
335 | switch (dt) { |
336 | case data_type::bf16: |
337 | cvt_from_float((bfloat16_t *)out, inp, nelems); |
338 | break; |
339 | case data_type::f16: |
340 | cvt_from_float((float16_t *)out, inp, nelems); |
341 | break; |
342 | default: assert(!"unimplemented" ); |
343 | } |
344 | } |
345 | |
346 | } // namespace types |
347 | |
// Full memory descriptor equality: format-independent fields, then the
// extra (compensation) descriptor, then the format-specific descriptor.
inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
    using namespace dnnl::impl::utils;
    // quick path for zero_mds
    if (utils::everyone_is(0, lhs.ndims, rhs.ndims)) return true;

    // Format-independent part: rank, dims, data type, padding and offsets.
    bool base_equal = true && lhs.ndims == rhs.ndims
            && array_cmp(lhs.dims, rhs.dims, lhs.ndims)
            && lhs.data_type == rhs.data_type
            && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims)
            && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims)
            && lhs.offset0 == rhs.offset0 && lhs.format_kind == rhs.format_kind;
    if (!base_equal) return false;
    if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false;
    // Format-specific part (format kinds already known to be equal here).
    if (lhs.format_kind == format_kind::blocked)
        return types::blocking_desc_is_equal(lhs, rhs);
    else if (lhs.format_kind == format_kind::wino)
        return types::wino_desc_is_equal(
                lhs.format_desc.wino_desc, rhs.format_desc.wino_desc);
    else if (lhs.format_kind == format_kind::rnn_packed)
        return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc,
                rhs.format_desc.rnn_packed_desc);
    return true;
}
371 | |
372 | inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { |
373 | return !operator==(lhs, rhs); |
374 | } |
375 | |
// Comparison operators for descriptors
// Helper macros for the operator== definitions below. They expand in a scope
// where `lhs` and `rhs` are the two descriptors being compared. Float members
// go through equal_with_nan so NaN == NaN holds for cache-lookup purposes.
#define COMPARE_DESC_MEMBERS(m) lhs.m == rhs.m
#define COMPARE_DESC_ARRAY_MEMBERS(m, s) utils::array_cmp(lhs.m, rhs.m, s)
// Compare the pointees, not the pointers.
#define DEREF_AND_COMPARE_DESC_MEMBERS(m) *lhs.m == *rhs.m
#define COMPARE_FLOAT_DESC_MEMBERS(m) utils::equal_with_nan(lhs.m, rhs.m)
// Bitwise float-array compare: NaNs compare equal, but -0.0f != 0.0f.
#define COMPARE_FLOAT_DESC_ARRAY_MEMBERS(m, s) \
    !std::memcmp(lhs.m, rhs.m, sizeof(float) * s)
383 | |
384 | // clang-format off |
// Batch normalization: compare kinds, all memory descs, epsilon and flags.
inline bool operator==(const batch_normalization_desc_t &lhs,
        const batch_normalization_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(scaleshift_desc)
            && COMPARE_DESC_MEMBERS(diff_scaleshift_desc)
            && COMPARE_DESC_MEMBERS(stat_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(batch_norm_epsilon)
            && COMPARE_DESC_MEMBERS(flags);
    return ret;
}

// Binary: algorithm, both sources and the destination.
inline bool operator==(const binary_desc_t &lhs, const binary_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc[0])
            && COMPARE_DESC_MEMBERS(src_desc[1])
            && COMPARE_DESC_MEMBERS(dst_desc);
    return ret;
}

// Concat: fixed fields first, then every source memory desc (by value).
inline bool operator==(const concat_desc_t &lhs, const concat_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && DEREF_AND_COMPARE_DESC_MEMBERS(dst_md)
            && COMPARE_DESC_MEMBERS(n)
            && COMPARE_DESC_MEMBERS(concat_dimension);

    if (!ret) return ret;

    for (int i = 0; i < lhs.n; i++) {
        ret = *lhs.src_mds[i] == *rhs.src_mds[i];
        if (!ret) break;
    }
    return ret;
}

// Convolution: all memory descs plus strides/dilations/paddings arrays.
inline bool operator==(
        const convolution_desc_t &lhs, const convolution_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(diff_bias_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_ARRAY_MEMBERS(strides, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(dilates, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[0], DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[1], DNNL_MAX_NDIMS)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}
445 | |
// Eltwise: kinds, memory descs, and the (NaN-safe) alpha/beta parameters.
inline bool operator==(const eltwise_desc_t &lhs, const eltwise_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(alpha)
            && COMPARE_FLOAT_DESC_MEMBERS(beta);
    return ret;
}

// GEMM: matrix descs, accumulation type, and sum-of-AB settings.
inline bool operator==(const gemm_desc_t &lhs, const gemm_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(a_desc)
            && COMPARE_DESC_MEMBERS(b_desc)
            && COMPARE_DESC_MEMBERS(c_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(acc_type)
            && COMPARE_DESC_MEMBERS(sum_ab)
            && COMPARE_DESC_MEMBERS(sum_ab_type);
    return ret;
}

// Inner product: all forward/backward memory descs and accumulation type.
inline bool operator==(
        const inner_product_desc_t &lhs, const inner_product_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(diff_bias_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

// Layer normalization: memory descs, epsilon (NaN-safe) and flags.
inline bool operator==(
        const layer_normalization_desc_t &lhs, const layer_normalization_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(data_scaleshift_desc)
            && COMPARE_DESC_MEMBERS(diff_data_scaleshift_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(stat_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(layer_norm_epsilon)
            && COMPARE_DESC_MEMBERS(flags);
    return ret;
}

// LRN: kinds, memory descs, window size and the alpha/beta/k parameters.
inline bool operator==(const lrn_desc_t &lhs, const lrn_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(local_size)
            && COMPARE_FLOAT_DESC_MEMBERS(lrn_alpha)
            && COMPARE_FLOAT_DESC_MEMBERS(lrn_beta)
            && COMPARE_FLOAT_DESC_MEMBERS(lrn_k);
    return ret;
}
517 | |
// Matmul: memory descs and accumulation type.
inline bool operator==(const matmul_desc_t &lhs, const matmul_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

// Pooling: memory descs plus strides/kernel/padding/dilation arrays.
inline bool operator==(
        const pooling_desc_t &lhs, const pooling_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_ARRAY_MEMBERS(strides, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(kernel, DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[0], DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(padding[1], DNNL_MAX_NDIMS)
            && COMPARE_DESC_ARRAY_MEMBERS(dilation, DNNL_MAX_NDIMS)
            && COMPARE_DESC_MEMBERS(accum_data_type);
    return ret;
}

// PReLU: forward and backward memory descs.
inline bool operator==(const prelu_desc_t &lhs, const prelu_desc_t &rhs) {
    const bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(weights_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc);
    return ret;
}

// Reduction: algorithm, memory descs, and the (NaN-safe) p/eps parameters.
inline bool operator==(
        const reduction_desc_t &lhs, const reduction_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_FLOAT_DESC_MEMBERS(p)
            && COMPARE_FLOAT_DESC_MEMBERS(eps);
    return ret;
}

// Reorder: source/destination mds (by value) and engine placement.
inline bool operator==(const reorder_desc_t &lhs, const reorder_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && DEREF_AND_COMPARE_DESC_MEMBERS(src_md)
            && DEREF_AND_COMPARE_DESC_MEMBERS(dst_md)
            && COMPARE_DESC_MEMBERS(src_engine_kind)
            && COMPARE_DESC_MEMBERS(dst_engine_kind)
            && COMPARE_DESC_MEMBERS(is_cross_engine);
    return ret;
}

// Resampling: memory descs and the per-dim scale factors (bitwise compare).
inline bool operator==(
        const resampling_desc_t &lhs, const resampling_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_FLOAT_DESC_ARRAY_MEMBERS(factors, DNNL_MAX_NDIMS);
    return ret;
}
590 | |
// RNN: cell configuration plus every forward/backward memory descriptor.
inline bool operator==(const rnn_desc_t &lhs, const rnn_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(cell_kind)
            && COMPARE_DESC_MEMBERS(direction)
            && COMPARE_DESC_MEMBERS(src_layer_desc)
            && COMPARE_DESC_MEMBERS(src_iter_desc)
            && COMPARE_DESC_MEMBERS(src_iter_c_desc)
            && COMPARE_DESC_MEMBERS(weights_layer_desc)
            && COMPARE_DESC_MEMBERS(weights_iter_desc)
            && COMPARE_DESC_MEMBERS(bias_desc)
            && COMPARE_DESC_MEMBERS(dst_layer_desc)
            && COMPARE_DESC_MEMBERS(dst_iter_desc)
            && COMPARE_DESC_MEMBERS(dst_iter_c_desc)
            && COMPARE_DESC_MEMBERS(weights_peephole_desc)
            && COMPARE_DESC_MEMBERS(weights_projection_desc)
            && COMPARE_DESC_MEMBERS(diff_src_layer_desc)
            && COMPARE_DESC_MEMBERS(diff_src_iter_desc)
            && COMPARE_DESC_MEMBERS(diff_src_iter_c_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_layer_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_iter_desc)
            && COMPARE_DESC_MEMBERS(diff_bias_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_layer_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_iter_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_iter_c_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_peephole_desc)
            && COMPARE_DESC_MEMBERS(diff_weights_projection_desc)
            && COMPARE_DESC_MEMBERS(flags)
            && COMPARE_DESC_MEMBERS(activation_kind)
            && COMPARE_FLOAT_DESC_MEMBERS(alpha)
            && COMPARE_FLOAT_DESC_MEMBERS(beta);
    return ret;
}

// Shuffle: memory descs, axis and group size.
inline bool operator==(const shuffle_desc_t &lhs, const shuffle_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(axis)
            && COMPARE_DESC_MEMBERS(group_size);
    return ret;
}

// Softmax: kinds, memory descs and the softmax axis.
inline bool operator==(
        const softmax_desc_t &lhs, const softmax_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && COMPARE_DESC_MEMBERS(prop_kind)
            && COMPARE_DESC_MEMBERS(alg_kind)
            && COMPARE_DESC_MEMBERS(src_desc)
            && COMPARE_DESC_MEMBERS(diff_src_desc)
            && COMPARE_DESC_MEMBERS(dst_desc)
            && COMPARE_DESC_MEMBERS(diff_dst_desc)
            && COMPARE_DESC_MEMBERS(softmax_axis);
    return ret;
}

// Sum: fixed fields, then every source memory desc, then the scales.
inline bool operator==(const sum_desc_t &lhs, const sum_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
            && DEREF_AND_COMPARE_DESC_MEMBERS(dst_md)
            && COMPARE_DESC_MEMBERS(n);
    if (!ret) return ret;

    for (int i = 0; i < lhs.n; i++) {
        ret = *lhs.src_mds[i] == *rhs.src_mds[i];
        if (!ret) break;
    }
    if (!ret) return ret;

    for (int i = 0; i < lhs.n; i++) {
        ret = ret && COMPARE_FLOAT_DESC_MEMBERS(scales[i]);
        if (!ret) break;
    }

    return ret;
}

// Zero-pad has no parameters beyond the primitive kind.
inline bool operator==(const zero_pad_desc_t &lhs, const zero_pad_desc_t &rhs) {
    bool ret = COMPARE_DESC_MEMBERS(primitive_kind);
    return ret;
}
672 | // clang-format on |
673 | |
674 | #undef COMPARE_DESC_MEMBERS |
675 | #undef COMPARE_DESC_ARRAY_MEMBERS |
676 | #undef DEREF_AND_COMPARE_DESC_MEMBERS |
677 | #undef COMPARE_FLOAT_DESC_MEMBERS |
678 | #undef COMPARE_FLOAT_DESC_ARRAY_MEMBERS |
679 | |
/** returns true if strides are compatible with memory_desc_t, i.e. the
 * user-provided strides leave enough room for the (blocked) padded dims so
 * that elements do not overlap. Trivially true for null strides, 0-dim mds,
 * non-blocked formats, empty tensors, and runtime dims/strides. */
inline bool memory_desc_strides_check(
        const memory_desc_t &md, const dims_t strides) {
    if (strides == nullptr || md.ndims == 0
            || md.format_kind != format_kind::blocked)
        return true;

    dims_t blocks = {0};
    int perm[DNNL_MAX_NDIMS] = {0};
    for (int d = 0; d < md.ndims; ++d) {
        // no strides check needed for empty tensor
        if (md.padded_dims[d] == 0) return true;

        // no strides verification for runtime dims
        const bool has_runtime_dim = utils::one_of(
                DNNL_RUNTIME_DIM_VAL, strides[d], md.padded_dims[d]);
        if (has_runtime_dim) return true;

        perm[d] = d;
        blocks[d] = 1;
    }

    // Per-dimension total blocking factor and overall inner-block size.
    dim_t block_size = 1;
    const auto &blk = md.format_desc.blocking;
    for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
        blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
        block_size *= blk.inner_blks[iblk];
    }

    // A custom comparator to yield linear order on perm
    // (sorts by stride, breaking ties by padded dim, then by index).
    auto idx_sorter = [&](const int a, const int b) -> bool {
        if (strides[a] == strides[b] && md.padded_dims[a] == md.padded_dims[b])
            return a < b;
        else if (strides[a] == strides[b])
            return md.padded_dims[a] < md.padded_dims[b];
        else
            return strides[a] < strides[b];
    };
    std::sort(perm, perm + md.ndims, idx_sorter);

    // Walk dims from smallest to largest stride; each stride must be at
    // least the space the previous dims occupy (starting at block_size).
    dim_t min_stride = block_size;
    for (int idx = 0; idx < md.ndims; ++idx) {
        const int d = perm[idx];

        // Make an exception for strides[d] == 0 as it has broadcast semantics
        // Note: owing to being sorted, these are the initial strides

        // FIXME: make an exception for dims[d] == 1 with the
        // assumption that no code applies that stride when the only
        // index accessed for that dimenstion is 0. This is because PT
        // can use "dummy" padding in those situations
        if ((strides[d] == 0) || (md.padded_dims[d] == 1))
            continue;
        else if (strides[d] < min_stride)
            return false;

        // update min_stride for next iteration
        const auto padded_dim = md.padded_dims[d];
        min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
    }
    return true;
}
742 | |
// Convenience overload: re-initializes `md` with the given `strides` while
// keeping its current ndims/dims/data_type.
inline status_t memory_desc_init_by_strides(
        memory_desc_t &md, const dims_t strides) {
    return memory_desc_init_by_strides(
            md, md.ndims, md.dims, md.data_type, strides);
}
748 | |
// Initializes `md` from the given format tag (keeping ndims/dims/data_type),
// then optionally overrides the blocked strides with user-provided ones.
// Returns invalid_arguments if the strides are incompatible with the md.
inline status_t memory_desc_init_by_tag(
        memory_desc_t &md, format_tag_t tag, const dims_t strides = nullptr) {
    status_t status
            = memory_desc_init_by_tag(md, md.ndims, md.dims, md.data_type, tag);
    if (status != status::success || strides == nullptr) return status;

    if (!memory_desc_strides_check(md, strides))
        return status::invalid_arguments;

    for (int d = 0; d < md.ndims; ++d)
        md.format_desc.blocking.strides[d] = strides[d];

    return status::success;
}
763 | |
764 | /** inits memory descriptor based on logical dimensions kept in @p md, and the |
765 | * blocking structure @p blk. |
766 | * |
767 | * @note blk.strides represent the order only (from smaller to bigger) |
768 | * |
769 | * TODO: move md related functions to one single place |
770 | */ |
inline status_t memory_desc_init_by_blocking_desc(
        memory_desc_t &md, const blocking_desc_t &blk) {
    // Per-dimension total blocking factor and overall inner-block size.
    dims_t blocks = {0};
    utils::array_set(blocks, 1, md.ndims);
    dim_t block_size = 1;
    for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
        blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
        block_size *= blk.inner_blks[iblk];
    }

    // Pad each dimension up to a multiple of its blocking factor.
    for (int d = 0; d < md.ndims; ++d) {
        md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
        md.padded_offsets[d] = 0;
    }
    md.offset0 = 0;

    md.format_kind = format_kind::blocked;
    auto &mblk = md.format_desc.blocking;
    mblk = blk;

    const int ndims = nstl::min(DNNL_MAX_NDIMS, md.ndims); // make GCC 5 happy
    utils::array_copy(mblk.strides, blk.strides, ndims);

    // Number of outer blocks per dimension (padded dim / blocking factor).
    dims_t ou_blocks = {0};
    utils::array_copy(ou_blocks, md.padded_dims, ndims);

    int perm[DNNL_MAX_NDIMS];
    for (int d = 0; d < ndims; ++d) {
        perm[d] = d;
        ou_blocks[d] /= blocks[d];
    }

    // Sort dims by decreasing input stride: blk.strides only encode the
    // relative order; the actual dense strides are recomputed below.
    utils::simultaneous_sort(
            mblk.strides, ou_blocks, perm, ndims, [](stride_t a, stride_t b) {
                if (utils::one_of(DNNL_RUNTIME_DIM_VAL, a, b))
                    return DNNL_RUNTIME_DIM_VAL;
                return b - a;
            });

    // Assign dense strides from the innermost dimension outwards.
    dim_t stride = block_size;
    for (int _d = ndims - 1; _d >= 0; --_d) {
        const int d = perm[_d];
        md.format_desc.blocking.strides[d] = stride;
        if (md.padded_dims[d] != 0) { // Keep same stride for zero dim
            stride *= md.padded_dims[d] / blocks[d];
        }
    }

    md.extra = utils::zero<memory_extra_desc_t>();

    return status::success;
}
823 | |
824 | /** inits memory descriptor @p md based on another one memory descriptor |
825 | * @p md_base and given @p data_type. |
826 | * Essentially: { md = md_base; md.dt = data_type; } */ |
827 | inline status_t memory_desc_init_by_md_and_dt(memory_desc_t &md, |
828 | const memory_desc_t &md_base, data_type_t data_type) { |
829 | if (&md != &md_base) md = md_base; |
830 | md.data_type = data_type; |
831 | return status::success; |
832 | } |
833 | |
834 | /** returns true if memory desc @p md corresponds to the given format tag and |
835 | * strides. |
836 | * If strides are not passed (or passed as nullptr) the dense structure is |
837 | * assumed (i.e. the one that memory_desc_init_by_tag() returns). |
838 | * Strides might contain `0` value, indicating the stride must match the one |
839 | * that memory_desc_init_by_tag() returns. |
840 | * Strides might contain `-1` values, that would be ignored during the |
841 | * comparison. For instance, this can be used if a stride along minibatch |
842 | * doesn't matter. */ |
843 | inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, |
844 | const dims_t strides = nullptr) { |
845 | if (md.format_kind != types::format_tag_to_kind(tag)) return false; |
846 | |
847 | memory_desc_t md_gold; |
848 | status_t status = memory_desc_init_by_tag( |
849 | md_gold, md.ndims, md.dims, md.data_type, tag); |
850 | if (status != status::success) return false; |
851 | |
852 | if (md.format_kind != format_kind::blocked) |
853 | return false; // unimplemented yet |
854 | |
855 | const auto &blk = md.format_desc.blocking; |
856 | const auto &blk_gold = md_gold.format_desc.blocking; |
857 | |
858 | using utils::array_cmp; |
859 | bool same_blocks = true && blk.inner_nblks == blk_gold.inner_nblks |
860 | && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) |
861 | && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); |
862 | |
863 | if (!same_blocks) return false; |
864 | |
865 | if (strides == nullptr) |
866 | return array_cmp(blk.strides, blk_gold.strides, md.ndims); |
867 | |
868 | for (int d = 0; d < md.ndims; ++d) { |
869 | dim_t stride = strides[d]; |
870 | if (stride == -1) continue; |
871 | if (stride == 0) stride = blk_gold.strides[d]; |
872 | if (blk.strides[d] != stride) return false; |
873 | } |
874 | |
875 | return true; |
876 | } |
877 | |
878 | /** returns matching tag (or undef if match is not found) |
879 | * XXX: This is a workaround that eventually should go away! */ |
880 | template <typename... Tags> |
881 | format_tag_t memory_desc_matches_one_of_tag( |
882 | const memory_desc_t &md, Tags... tags) { |
883 | for (const auto tag : {tags...}) { |
884 | if (memory_desc_matches_tag(md, tag)) return tag; |
885 | } |
886 | return format_tag::undef; |
887 | } |
888 | |
889 | /** returns true if fp32 value denotes DNNL_RUNTIME_F32_VAL */ |
890 | inline bool is_runtime_value(float val) { |
891 | return utils::bit_cast<unsigned>(val) == DNNL_RUNTIME_F32_VAL_REP.u; |
892 | } |
893 | |
894 | /** returns true if s32 value denotes DNNL_RUNTIME_S32_VAL */ |
895 | inline bool is_runtime_value(int val) { |
896 | return val == DNNL_RUNTIME_S32_VAL; |
897 | } |
898 | |
899 | /** returns true if dim_t value denotes DNNL_RUNTIME_DIM_VAL */ |
900 | inline bool is_runtime_value(dim_t val) { |
901 | return val == DNNL_RUNTIME_DIM_VAL; |
902 | } |
903 | |
904 | inline bool memory_desc_sanity_check(int ndims, const dims_t dims, |
905 | data_type_t data_type, format_kind_t format_kind) { |
906 | using namespace data_type; |
907 | |
908 | if (ndims == 0) return true; |
909 | |
910 | bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS |
911 | && utils::one_of(data_type, f16, bf16, f32, f64, s32, s8, u8); |
912 | if (!ok) return false; |
913 | |
914 | bool has_runtime_dims = false; |
915 | for (int d = 0; d < ndims; ++d) { |
916 | if (dims[d] != DNNL_RUNTIME_DIM_VAL && dims[d] < 0) return false; |
917 | if (dims[d] == DNNL_RUNTIME_DIM_VAL) has_runtime_dims = true; |
918 | } |
919 | |
920 | if (has_runtime_dims) { |
921 | // format `any` is currently not supported for run-time dims |
922 | if (format_kind == format_kind::any) return false; |
923 | } |
924 | |
925 | return true; |
926 | } |
927 | |
928 | inline bool memory_desc_sanity_check(const memory_desc_t &md) { |
929 | return memory_desc_sanity_check( |
930 | md.ndims, md.dims, md.data_type, format_kind::undef); |
931 | } |
932 | |
// Copies the C operation descriptor @p src into @p dst, dispatching on the
// descriptor's primitive kind so that the matching union member is copied.
// Expands to one `case` per primitive kind that copies the same-named member.
#define CASE_OP_DESC(pkind) \
    case primitive_kind::pkind: dst->pkind = src->pkind; break;

inline void copy_c_op_desc(op_desc_t *dst, const op_desc_t *src) {
#define CASE_OP_DESC(pkind) \
    case primitive_kind::pkind: dst->pkind = src->pkind; break;

    // NOTE(review): the cast to int presumably silences -Wswitch warnings
    // about primitive kinds intentionally left unhandled -- confirm.
    switch ((int)src->kind) {
        CASE_OP_DESC(batch_normalization);
        CASE_OP_DESC(binary);
        CASE_OP_DESC(convolution);
        CASE_OP_DESC(deconvolution);
        CASE_OP_DESC(eltwise);
        CASE_OP_DESC(gemm);
        CASE_OP_DESC(inner_product);
        CASE_OP_DESC(layer_normalization);
        CASE_OP_DESC(lrn);
        CASE_OP_DESC(matmul);
        CASE_OP_DESC(pooling);
        CASE_OP_DESC(prelu);
        CASE_OP_DESC(reduction);
        CASE_OP_DESC(resampling);
        CASE_OP_DESC(rnn);
        CASE_OP_DESC(shuffle);
        CASE_OP_DESC(softmax);

        // Internal descs
        CASE_OP_DESC(zero_pad);
        default: assert(!"unknown C primitive kind" );
    }
#undef CASE_OP_DESC
}
962 | |
963 | } // namespace impl |
964 | } // namespace dnnl |
965 | |
966 | #include "memory_desc_wrapper.hpp" |
967 | |
968 | #endif |
969 | |
970 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
971 | |