1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef DNNL_TEST_COMMON_HPP
18#define DNNL_TEST_COMMON_HPP
19
20#ifdef _WIN32
21#include <windows.h> // GetEnvironmentVariable
22#endif
23
#include <cassert>
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdint.h>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
32
33#include "gtest/gtest.h"
34
35#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
36#define collapse(x)
37#endif
38
39#include "oneapi/dnnl/dnnl.hpp"
40#include "oneapi/dnnl/dnnl_debug.h"
41
42#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
43#include "oneapi/dnnl/dnnl_threadpool.hpp"
44#endif
45
46#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL || defined(DNNL_WITH_SYCL)
47#include "dnnl_test_common_ocl.hpp"
48#endif
49
50#ifdef DNNL_WITH_SYCL
51#include "oneapi/dnnl/dnnl_sycl.hpp"
52#endif
53
54// Don't move it higher than library public headers
55#include "dnnl_test_macros.hpp"
56
57#include "src/common/bfloat16.hpp"
58#include "src/common/float16.hpp"
59#include "src/common/memory_desc_wrapper.hpp"
60#include "src/common/nstl.hpp"
61#include "src/common/primitive_cache.hpp"
62#include "tests/gtests/test_malloc.hpp"
63#include "tests/test_thread.hpp"
64
65#include "src/cpu/platform.hpp"
66
67#define for_ for
68
69using dnnl::impl::bfloat16_t;
70using dnnl::impl::float16_t;
71
72#ifdef DNNL_ENABLE_MEM_DEBUG
73#define DNNL_CHECK(f) \
74 do { \
75 dnnl_status_t s = (f); \
76 dnnl::error::wrap_c_api(s, dnnl_status2str(s)); \
77 } while (0)
78#else
79#define DNNL_CHECK(f) \
80 do { \
81 dnnl_status_t s = (f); \
82 ASSERT_EQ(s, dnnl_success); \
83 } while (0)
84#endif
85
86// XXX: Using EXPECT_NE in 'if' statement raises a warning when GCC compiler is
87// used: suggest explicit braces to avoid ambiguous 'else'
88#define GTEST_EXPECT_NE(val1, val2) \
89 do { \
90 EXPECT_NE(val1, val2); \
91 } while (0)
92
93using memory = dnnl::memory;
94
95bool is_current_test_failed();
96
97#ifdef DNNL_TEST_WITH_ENGINE_PARAM
98dnnl::engine::kind get_test_engine_kind();
99dnnl::engine get_test_engine();
100#endif
101
// Maps a lowercase vendor name to its PCI vendor id; -1 for unknown names.
inline int get_vendor_id(const std::string &vendor) {
    if (vendor == "intel") return 0x8086;
    if (vendor == "amd") return 0x1002;
    if (vendor == "nvidia") return 0x10DE;
    return -1;
}
113
114inline bool is_nvidia_gpu(const dnnl::engine &eng) {
115#ifdef DNNL_WITH_SYCL
116 if (eng.get_kind() != dnnl::engine::kind::gpu) return false;
117 const uint32_t nvidia_vendor_id = get_vendor_id("nvidia");
118 const auto device = dnnl::sycl_interop::get_device(eng);
119 const auto eng_vendor_id
120 = device.get_info<::sycl::info::device::vendor_id>();
121 return eng_vendor_id == nvidia_vendor_id;
122#endif
123 return false;
124}
125
126inline bool is_amd_gpu(const dnnl::engine &eng) {
127#ifdef DNNL_WITH_SYCL
128 if (eng.get_kind() != dnnl::engine::kind::gpu) return false;
129 const uint32_t amd_vendor_id = get_vendor_id("amd");
130 const auto device = dnnl::sycl_interop::get_device(eng);
131 const auto eng_vendor_id
132 = device.get_info<::sycl::info::device::vendor_id>();
133 return eng_vendor_id == amd_vendor_id;
134#endif
135 return false;
136}
137
// True when `eng_kind` maps to a SYCL-backed runtime in this build
// (CPU and/or GPU, depending on the library's compile-time configuration).
inline bool is_sycl_engine(dnnl::engine::kind eng_kind) {
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
    if (eng_kind == dnnl::engine::kind::cpu) return true;
#endif

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
    if (eng_kind == dnnl::engine::kind::gpu) return true;
#endif
    return false;
}
148
// Returns true when data type `dt` is NOT usable on engine `eng`.
// CPU support is queried from the platform layer; NVIDIA/AMD GPUs in
// SYCL builds use a hardcoded allow-list instead.
inline bool unsupported_data_type(
        memory::data_type dt, const dnnl::engine &eng) {
    bool supported = true; // optimism

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
    dnnl::engine::kind kind = eng.get_kind();
    if (kind == dnnl::engine::kind::cpu)
        supported = dnnl::impl::cpu::platform::has_data_type_support(
                memory::convert_to_c(dt));
#endif

#if defined(DNNL_SYCL_CUDA) || defined(DNNL_SYCL_HIP)
    // Vendor GPUs: only f32/f16/s8 (and undef) are treated as supported.
    if (is_nvidia_gpu(eng) || is_amd_gpu(eng)) {
        switch (dt) {
            case memory::data_type::f32: return false;
            case memory::data_type::f16: return false;
            case memory::data_type::s8: return false;
            case memory::data_type::undef: return false;
            default: return true;
        }
    }
#endif
    return !supported;
}
173
174#ifdef DNNL_TEST_WITH_ENGINE_PARAM
// Convenience overload checking against the default test engine.
inline bool unsupported_data_type(memory::data_type dt) {
    return unsupported_data_type(dt, get_test_engine());
}
178
179template <typename... Rest>
180inline bool unsupported_data_type(
181 memory::data_type first_dt, Rest... rest_dts) {
182 bool rval = unsupported_data_type(first_dt, get_test_engine());
183 if (rval) return rval;
184 return unsupported_data_type(rest_dts...);
185}
186#endif
187
// Compile-time mapping from a C++ element type to the corresponding
// dnnl memory::data_type, plus an unsigned integer type of the same
// width (useful for bit-level access to the raw representation).
template <typename data_t>
struct data_traits {};
template <>
struct data_traits<float16_t> {
    static const auto data_type = memory::data_type::f16;

    using uint_type = uint16_t;
};
template <>
struct data_traits<bfloat16_t> {
    static const auto data_type = memory::data_type::bf16;

    using uint_type = uint16_t;
};
template <>
struct data_traits<float> {
    static const auto data_type = memory::data_type::f32;

    using uint_type = uint32_t;
};
template <>
struct data_traits<uint8_t> {
    static const auto data_type = memory::data_type::u8;

    using uint_type = uint8_t;
};
template <>
struct data_traits<int8_t> {
    static const auto data_type = memory::data_type::s8;

    using uint_type = uint8_t;
};
template <>
struct data_traits<int32_t> {
    static const auto data_type = memory::data_type::s32;

    using uint_type = uint32_t;
};
226
// Per-type equality assertion; only the float specialization is defined
// here (delegates to gtest's ULP-based ASSERT_FLOAT_EQ).
template <typename T>
inline void assert_eq(T a, T b);
template <>
inline void assert_eq<float>(float a, float b) {
    ASSERT_FLOAT_EQ(a, b);
}
233
#if defined(__x86_64__) || defined(_M_X64)
#include <immintrin.h>
// float -> int conversion honoring the current MXCSR rounding mode
// (round-to-nearest-even by default), matching x86 hardware behavior.
inline int mxcsr_cvt(float f) {
    return _mm_cvtss_si32(_mm_load_ss(&f));
}
#else
// Portable fallback: round per the current FP environment.
inline int mxcsr_cvt(float f) {
    return (int)nearbyintf(f);
}
#endif
244
// Rounds `x` to the output type using hardware rounding (see mxcsr_cvt);
// the float specialization is the identity.
template <typename data_t>
data_t out_round(float x) {
    return (data_t)mxcsr_cvt(x);
}
template <>
inline float out_round<float>(float x) {
    return x;
}
253
// Clamps `x` (kept in out_t) to the representable range of data_t.
// Intended for integral data_t; the bounds come from numeric_limits.
template <typename data_t, typename out_t>
out_t saturate(const out_t &x) {
    out_t clamped = x;
    if (clamped <= std::numeric_limits<data_t>::min())
        clamped = std::numeric_limits<data_t>::min();
    if (clamped > std::numeric_limits<data_t>::max())
        clamped = std::numeric_limits<data_t>::max();
    return clamped;
}
263
// Right/bottom padding needed so that a convolution with input size i,
// output size o, kernel k, left padding p, stride s and dilation d
// (0 = dense) exactly covers the input:
//   (o - 1) * s + (k - 1) * (d + 1) - (p + i - 1)
inline memory::dim right_padding(memory::dim i, memory::dim o, memory::dim k,
        memory::dim p, memory::dim s, memory::dim d = 0) {
    return (o - 1) * s + (k - 1) * (d + 1) - (p + i - 1);
}
268
// Accumulator type for a given element type: 8-bit integers accumulate
// into int; every other type accumulates into itself.
template <typename data_t>
struct acc_t {
    typedef data_t type;
};
template <>
struct acc_t<int8_t> {
    typedef int type;
};
template <>
struct acc_t<uint8_t> {
    typedef int type;
};
281
// Smart pointer for map/unmap operations with unique_ptr semantics:
// maps the memory for host access on construction, unmaps on
// destruction. Move-only; convertible to T* for element access.
template <typename T>
struct mapped_ptr_t {
    using nonconst_type = typename std::remove_cv<T>::type;

    // Null handle: owns nothing, unmaps nothing.
    mapped_ptr_t(std::nullptr_t) : mem_(nullptr), ptr_(nullptr) {}
    // Maps `mem`; `mem` must outlive this object.
    mapped_ptr_t(const memory *mem) : mem_(mem) {
        ptr_ = mem->map_data<nonconst_type>();
    }
    // Move transfers ownership and disarms the source.
    mapped_ptr_t(mapped_ptr_t &&other) : mem_(other.mem_), ptr_(other.ptr_) {
        other.mem_ = nullptr;
        other.ptr_ = nullptr;
    }

    mapped_ptr_t(const mapped_ptr_t &) = delete;
    mapped_ptr_t &operator=(const mapped_ptr_t &) = delete;

    ~mapped_ptr_t() {
        // Unmap only when both the memory and the mapping are live.
        if (mem_ && ptr_) mem_->unmap_data(ptr_);
    };

    operator T *() { return ptr_; }
    operator const T *() const { return ptr_; }
    operator bool() const { return ptr_ != nullptr; }

private:
    const memory *mem_; // borrowed, not owned
    nonconst_type *ptr_; // host-visible mapping; null when disarmed
};
311
// Maps `mem` for host access, returning an RAII handle that unmaps on
// scope exit.
template <typename T>
mapped_ptr_t<T> map_memory(const memory &mem) {
    return mapped_ptr_t<T>(&mem);
}
316
317// check_zero_tail - check on zero or set to zero padded memory
318template <typename data_t>
319void check_zero_tail(int set_zero_flag, const memory &src) {
320
321 auto src_data = map_memory<data_t>(src);
322
323 const memory::desc src_d = src.get_desc();
324 const int ndims = src_d.get_ndims();
325 const auto dims = src_d.get_dims();
326 const auto pdims = src_d.get_padded_dims();
327 const dnnl::impl::memory_desc_wrapper mdw(src_d.get());
328
329 memory::dim idx[DNNL_MAX_NDIMS] = {}, str[DNNL_MAX_NDIMS] = {};
330 memory::dim nelems = 1;
331 int tail_flag = 0;
332 for (int i = 0; i < ndims; ++i) {
333 if (dims[ndims - i - 1] != pdims[ndims - i - 1]) tail_flag = 1;
334 nelems *= pdims[ndims - i - 1];
335 idx[i] = 0;
336 str[i] = (i == 0) ? 1 : str[i - 1] * pdims[ndims - i];
337 }
338 if (tail_flag == 0) return;
339
340 for (memory::dim i = 0; i < nelems; ++i) {
341 memory::dim off = 0;
342 bool flag = 0;
343 for (int j = 0; j < ndims; ++j) {
344 off += idx[j] * str[j];
345 if (idx[j] >= dims[ndims - j - 1]) flag = 1;
346 }
347 if (flag == 1) {
348 memory::dim blk_off = mdw.off_l(off, true);
349 if (set_zero_flag) {
350 src_data[blk_off] = 0.0;
351 } else {
352 ASSERT_EQ(src_data[blk_off], 0.0)
353 << " blk_off = " << blk_off << "off = " << off;
354 }
355 }
356 /*Update idx*/
357 for (int j = 0; j < ndims; ++j) {
358 idx[j]++;
359 if (idx[j] < pdims[ndims - j - 1]) break;
360 idx[j] = 0;
361 }
362 }
363}
364
// Thin convenience wrapper around the memory::desc constructor.
inline memory::desc create_md(memory::dims dims, memory::data_type data_type,
        memory::format_tag fmt_tag) {
    return memory::desc(dims, data_type, fmt_tag);
}
369
// Deterministic value generator for test tensors. For f32, roughly one
// element per group of size 1/sparsity is set to
// mean + deviation * sin(index % 37) and the rest are zero. f16/bf16 go
// through the f32 path and are narrowed. Integer types use small cyclic
// patterns (mean/deviation/sparsity are ignored for them).
template <typename data_t>
static inline data_t set_value(
        memory::dim index, data_t mean, data_t deviation, double sparsity) {
    if (data_traits<data_t>::data_type == memory::data_type::f16
            || data_traits<data_t>::data_type == memory::data_type::bf16) {
        return data_t(set_value<float>(index, mean, deviation, sparsity));
    } else if (data_traits<data_t>::data_type == memory::data_type::f32) {
        const memory::dim group_size = (memory::dim)(1. / sparsity);
        const memory::dim group = index / group_size;
        const memory::dim in_group = index % group_size;
        // 1637 is just a prime used to vary which in-group slot is filled.
        const bool fill = in_group == ((group % 1637) % group_size);
        return fill ? static_cast<data_t>(
                       mean + deviation * sinf(float(index % 37)))
                    : data_t {0};
    } else if (data_traits<data_t>::data_type == memory::data_type::s32
            || data_traits<data_t>::data_type == memory::data_type::s8) {
        return data_t(index * 13 % 21 - 10);
    } else if (data_traits<data_t>::data_type == memory::data_type::u8) {
        return data_t(index * 13 % 17);
    }
    assert(!"not expected");
    return data_t(0);
}
393
// Fills data[0..nelems) with set_value() in a parallel loop; the result
// depends only on the element index, not on thread partitioning.
template <typename data_t>
static void fill_data(const memory::dim nelems, data_t *data, data_t mean,
        data_t deviation, double sparsity = 1.) {
    dnnl::impl::parallel_nd(nelems, [&](memory::dim n) {
        data[n] = set_value<data_t>(n, mean, deviation, sparsity);
    });
}
401
// Same as the pointer overload, but maps the memory object first.
template <typename data_t>
static void fill_data(const memory::dim nelems, const memory &mem, data_t mean,
        data_t deviation, double sparsity = 1.) {
    auto data_ptr = map_memory<data_t>(mem);
    fill_data<data_t>(nelems, data_ptr, mean, deviation, sparsity);
}
408
// Runtime-typed fill: derives the element count from the descriptor size
// and dispatches to the typed fill_data overload for `dt`.
inline void fill_data(memory::data_type dt, const memory &mem, float mean,
        float deviation, double sparsity = 1.) {
    size_t nelems = mem.get_desc().get_size() / memory::data_type_size(dt);
    switch (dt) {
        case memory::data_type::f32:
            fill_data<float>(nelems, mem, mean, deviation, sparsity);
            break;
        case memory::data_type::bf16:
            fill_data<bfloat16_t>(nelems, mem, mean, deviation, sparsity);
            break;
        case memory::data_type::f16:
            fill_data<float16_t>(nelems, mem, mean, deviation, sparsity);
            break;
        case memory::data_type::s32:
            fill_data<int>(nelems, mem, mean, deviation, sparsity);
            break;
        case memory::data_type::s8:
            fill_data<int8_t>(nelems, mem, mean, deviation, sparsity);
            break;
        case memory::data_type::u8:
            fill_data<uint8_t>(nelems, mem, mean, deviation, sparsity);
            break;
        default: assert(!"unsupported data type"); break;
    }
}
434
// Fills with set_value(mean = 1, deviation = 0.2); when `init_negs` is
// set, every 4th element is negated.
template <typename data_t>
static void fill_data(const memory::dim nelems, data_t *data,
        double sparsity = 1., bool init_negs = false) {
    dnnl::impl::parallel_nd(nelems, [&](memory::dim n) {
        data[n] = set_value<data_t>(n, data_t(1), data_t(2e-1), sparsity);

        if (init_negs && n % 4 == 0)
            data[n] = static_cast<data_t>(
                    -data[n]); // weird for unsigned types!
    });
}
446
// Same as the pointer overload, but maps the memory object first.
template <typename data_t>
static void fill_data(const memory::dim nelems, const memory &mem,
        double sparsity = 1., bool init_negs = false) {
    auto data_ptr = map_memory<data_t>(mem);
    fill_data<data_t>(nelems, data_ptr, sparsity, init_negs);
}
453
// Replaces every exact-zero element of `mem` with 1 (adds 1 to zeros) —
// presumably to keep later relative-error checks away from zeros; the
// element count is derived from the descriptor size.
template <typename data_t>
static void remove_zeroes(const memory &mem) {
    size_t nelems = mem.get_desc().get_size() / sizeof(data_t);
    auto data_ptr = map_memory<data_t>(mem);
    dnnl::impl::parallel_nd(nelems, [&](memory::dim n) {
        if (data_ptr[n] == data_t(0)) data_ptr[n] += data_t(1);
    });
}
462
// Element-wise comparison of two memories of data_t. Float-family types
// use a mixed relative/absolute error bound `threshold`; integer types
// must match exactly. The memories may have different layouts: logical
// offsets are resolved through each descriptor independently.
template <typename data_t>
static void compare_data(
        const memory &ref, const memory &dst, data_t threshold = (data_t)1e-4) {
    using data_type = memory::data_type;

    // Only types with comparison semantics defined below are allowed.
    ASSERT_TRUE(data_traits<data_t>::data_type == data_type::f32
            || data_traits<data_t>::data_type == data_type::f16
            || data_traits<data_t>::data_type == data_type::bf16
            || data_traits<data_t>::data_type == data_type::s32
            || data_traits<data_t>::data_type == data_type::s8);

    /* Note: size_t incompatible with MSVC++ */
    auto ref_desc = ref.get_desc();
    auto dst_desc = dst.get_desc();
    const dnnl::impl::memory_desc_wrapper mdw_ref(ref_desc.get());
    const dnnl::impl::memory_desc_wrapper mdw_dst(dst_desc.get());

    ASSERT_TRUE(ref_desc.get_ndims() == dst_desc.get_ndims());

    auto ndims = ref_desc.get_ndims();

    // Shapes must match exactly.
    for (auto d = 0; d < ndims; ++d) {
        ASSERT_TRUE(ref_desc.get_dims()[d] == dst_desc.get_dims()[d]);
    }

    auto dims = ref_desc.get_dims();

    memory::dim num = 1;
    for (auto d = 0; d < ndims; ++d) {
        num *= dims[d];
    }

    auto ref_data = map_memory<data_t>(ref);
    auto dst_data = map_memory<data_t>(dst);

    dnnl::impl::parallel_nd(num, [&](memory::dim i) {
        if (is_current_test_failed()) return;

        data_t ref = ref_data[mdw_ref.off_l(i, true)];
        data_t got = dst_data[mdw_dst.off_l(i, true)];

        if (data_traits<data_t>::data_type == data_type::f32
                || data_traits<data_t>::data_type == data_type::f16
                || data_traits<data_t>::data_type == data_type::bf16) {
            const float threshold_f32 = static_cast<float>(threshold);
            const float ref_f32 = static_cast<float>(ref);
            const float got_f32 = static_cast<float>(got);
            const float diff_f32
                    = (got_f32 == ref_f32) ? 0.0f : got_f32 - ref_f32;
            // Relative error when |ref| exceeds the threshold, absolute
            // error otherwise (avoids blow-up near zero).
            const float e = (std::abs(ref_f32) > threshold_f32)
                    ? diff_f32 / ref_f32
                    : diff_f32;
            ASSERT_NEAR(e, 0.0, threshold_f32)
                    << "Index: " << i << " Total: " << num;
        } else {
            ASSERT_EQ(ref, got) << "Index: " << i << " Total: " << num;
        }
    });
}
522
523inline const char *query_impl_info(const_dnnl_primitive_desc_t pd) {
524 const char *str;
525 dnnl_primitive_desc_query(pd, dnnl_query_impl_info_str, 0, &str);
526 return str;
527};
528
529inline dnnl_status_t get_conv_impl_status(
530 const_dnnl_primitive_desc_t pd, const char *match_str) {
531 const char *conv_str = query_impl_info(pd);
532
533 if (strstr(conv_str, match_str) != NULL) return dnnl_status_t::dnnl_success;
534 return dnnl_status_t::dnnl_unimplemented;
535};
536
// Plain aggregate of convolution problem dimensions:
// mb = minibatch, ng = groups, ic/ih/iw = input channels/height/width,
// oc/oh/ow = output, kh/kw = kernel, padh/padw = padding,
// strh/strw = strides, dilh/dilw = dilations (0 = dense).
struct test_convolution_sizes_t {
    test_convolution_sizes_t(memory::dim mb, memory::dim ng, memory::dim ic,
            memory::dim ih, memory::dim iw, memory::dim oc, memory::dim oh,
            memory::dim ow, memory::dim kh, memory::dim kw, memory::dim padh,
            memory::dim padw, memory::dim strh, memory::dim strw,
            memory::dim dilh = 0, memory::dim dilw = 0)
        : mb(mb)
        , ng(ng)
        , ic(ic)
        , ih(ih)
        , iw(iw)
        , oc(oc)
        , oh(oh)
        , ow(ow)
        , kh(kh)
        , kw(kw)
        , padh(padh)
        , padw(padw)
        , strh(strh)
        , strw(strw)
        , dilh(dilh)
        , dilw(dilw) {}
    memory::dim mb;
    memory::dim ng;
    memory::dim ic, ih, iw;
    memory::dim oc, oh, ow;
    memory::dim kh, kw;
    memory::dim padh, padw;
    memory::dim strh, strw;
    memory::dim dilh, dilw;
};
568
// Scale-attribute configuration shared by convolution tests.
struct test_convolution_attr_t {
    struct scale_t {
        enum policy_t { NONE = 0, COMMON };

        // NOTE(review): despite the name, this returns true when a policy
        // IS set (policy != NONE) — i.e. "is defined", not "is default".
        bool is_def() const { return policy != NONE; }

        scale_t(float s, policy_t p = NONE) : scale(s) { policy = p; }

        policy_t policy;
        float scale;
    };

    // Rebuilds dnnl_attr from scratch, setting a 0-mask scale for each of
    // src/weights/dst whose scale policy is set.
    void dnnl_attr_recreate() {
        dnnl_attr = dnnl::primitive_attr();
        if (src_scale.is_def()) {
            const int mask = 0;
            dnnl_attr.set_scales_mask(DNNL_ARG_SRC, mask);
        }
        if (wei_scale.is_def()) {
            const int mask = 0;
            dnnl_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
        }
        if (dst_scale.is_def()) {
            const int mask = 0;
            dnnl_attr.set_scales_mask(DNNL_ARG_DST, mask);
        }
    }

    // All three scales share the same value and policy.
    test_convolution_attr_t(
            float s, scale_t::policy_t p = scale_t::policy_t::NONE)
        : src_scale(s, p), wei_scale(s, p), dst_scale(s, p), dnnl_attr() {}

    test_convolution_attr_t() : test_convolution_attr_t(1.f) {}

    scale_t src_scale;
    scale_t wei_scale;
    scale_t dst_scale;
    dnnl::primitive_attr dnnl_attr;
};
608
// Memory format tags for each tensor of a convolution test case.
struct test_convolution_formats_t {
    memory::format_tag src_format;
    memory::format_tag weights_format;
    memory::format_tag bias_format;
    memory::format_tag dst_format;
};
615
// Full description of one convolution test case; when `expect_to_fail`
// is set, creation/execution is expected to fail with `expected_status`.
struct test_convolution_params_t {
    dnnl::algorithm aalgorithm;
    test_convolution_formats_t formats;
    test_convolution_attr_t attr;
    test_convolution_sizes_t sizes;
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
624
// Convolution + eltwise test case; `alg` is the eltwise kind while
// `aalgorithm` is presumably the convolution algorithm (mirrors
// test_convolution_params_t).
struct test_convolution_eltwise_params_t {
    const dnnl::algorithm alg;
    dnnl::algorithm aalgorithm;
    const float eltwise_alpha;
    const float eltwise_beta;
    test_convolution_formats_t formats;
    test_convolution_attr_t attr;
    test_convolution_sizes_t sizes;
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
636
// Runs `f` and reconciles the outcome with expectations:
//  - returns true when an expected failure happened (status matched) or
//    when an unimplemented error is explicitly ignored;
//  - retries `f` after a test-injected malloc failure (out-of-memory
//    path), incrementing the malloc counter each time;
//  - rethrows unexpected dnnl::error exceptions;
//  - throws when a failure was expected but `f` succeeded;
//  - returns false on plain, expected success.
template <typename F>
bool catch_expected_failures(const F &f, bool expect_to_fail,
        dnnl_status_t expected_status, bool ignore_unimplemented = false) {
    try {
        f();
    } catch (const dnnl::error &e) {
        // Rethrow the exception if it is not expected or the error status did
        // not match.
        if (!(expect_to_fail) || e.status != (expected_status)) {
            // Ignore unimplemented
            if (ignore_unimplemented && (e.status == dnnl_unimplemented)) {
                // Print unimplemented but do not treat as error
                std::cout << "[  UNIMPL  ] "
                          << "Implementation not found" << std::endl;
                reset_failed_malloc_counter();
                return true;
            } else if (test_out_of_memory()
                    && (e.status == dnnl_out_of_memory
                            || e.status == dnnl_unimplemented)) {
                // Restart if error thrown due to a malloc failed intentionally,
                // and increment malloc counter.
                // TODO: This should be valid only for `dnnl_out_of_memory`
                // error. Currently a failed malloc inside
                // gemm_pack_storage_shell_t ctor makes it unable to use the
                // reference RNN impl, and the iterator produces an
                // `dnnl_unimplemented` error.
                increment_failed_malloc_counter();
                return catch_expected_failures(f, expect_to_fail,
                        expected_status, ignore_unimplemented);
            } else {
                if (expect_to_fail && (e.status != expected_status))
                    std::cout << "expect failure status mismatch: expect("
                              << dnnl_status2str(expected_status) << ") get("
                              << dnnl_status2str(e.status)
                              << "). Re-throwing...\n";
                throw e;
            }
        }
        // Return normally if the failure is expected. Reset failed malloc
        // counter to zero before performing a new test.
        if (expect_to_fail) {
            reset_failed_malloc_counter();
            return true;
        }
    }

    // Throw an exception if the failure is expected but did not happen
    if (expect_to_fail) {
        std::cout << "expect failure with status("
                  << dnnl_status2str(expected_status) << "), "
                  << "but operation succeed. Throwing an exception...\n";
        throw std::exception();
    }

    // Reset failed malloc counter to zero before performing a new test.
    reset_failed_malloc_counter();
    return false;
}
695
namespace test {
// Creates a library-allocated memory object, routing through the
// runtime-specific interop (OCL USM / SYCL buffer) when the test build
// requests it; native CPU always uses the plain constructor.
inline dnnl::memory make_memory(
        const dnnl::memory::desc &md, const dnnl::engine &eng) {

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
    if (eng.get_kind() == dnnl::engine::kind::cpu) {
        return dnnl::memory(md, eng);
    }
#endif

#if defined(TEST_DNNL_OCL_USM)
    return dnnl::ocl_interop::make_memory(
            md, eng, dnnl::ocl_interop::memory_kind::usm);
#elif defined(TEST_DNNL_DPCPP_BUFFER)
    return dnnl::sycl_interop::make_memory(
            md, eng, dnnl::sycl_interop::memory_kind::buffer);
#else
    return dnnl::memory(md, eng);
#endif
}

// Same as above, but wraps a user-provided `handle` instead of letting
// the library allocate.
inline dnnl::memory make_memory(
        const dnnl::memory::desc &md, const dnnl::engine &eng, void *handle) {
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
    if (eng.get_kind() == dnnl::engine::kind::cpu) {
        return dnnl::memory(md, eng, handle);
    }
#endif

#if defined(TEST_DNNL_OCL_USM)
    return dnnl::ocl_interop::make_memory(
            md, eng, dnnl::ocl_interop::memory_kind::usm, handle);
#elif defined(TEST_DNNL_DPCPP_BUFFER)
    return dnnl::sycl_interop::make_memory(
            md, eng, dnnl::sycl_interop::memory_kind::buffer, handle);
#else
    return dnnl::memory(md, eng, handle);
#endif
}
} // namespace test
736
#define TEST_MALLOC_OFFSET 8
// Allocates `size` bytes whose address is deliberately 8 bytes past a
// 64-byte-aligned base, to exercise unaligned-pointer code paths.
// Returns nullptr on allocation failure. Pair with test_free().
static char *test_malloc(size_t size) {
    void *ptr;
    const size_t align = 64;
    const size_t padded_size = TEST_MALLOC_OFFSET + size;
#ifdef _WIN32
    ptr = _aligned_malloc(padded_size, align);
    int rc = ((ptr) ? 0 : errno);
#else
    int rc = ::posix_memalign(&ptr, align, padded_size);
#endif /* _WIN32 */
    return rc == 0 ? (char *)ptr + TEST_MALLOC_OFFSET : nullptr;
}

// Releases a pointer obtained from test_malloc(); safe on nullptr
// (previously nullptr would have been offset and freed — UB).
static void test_free(char *ptr) {
    if (!ptr) return;
    char *base_ptr = ptr - TEST_MALLOC_OFFSET;
#ifdef _WIN32
    _aligned_free(base_ptr);
#else
    ::free(base_ptr);
#endif /* _WIN32 */
}
#undef TEST_MALLOC_OFFSET
760
// Memory object that, for native CPU engines, owns its own buffer placed
// at an 8-byte offset from a 64-byte-aligned base (see test_malloc) to
// shake out alignment assumptions; other engines let the library/runtime
// allocate. Contents are poisoned with 0xFF on construction.
class test_memory {
public:
    test_memory(const memory::desc &d, const dnnl::engine &e) {
        // Only a non-SYCL CPU engine can wrap a plain host pointer.
        bool is_cpu_native = (e.get_kind() == dnnl::engine::kind::cpu)
                && DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL;

        size_ = d.get_size();
        if (is_cpu_native) {
            data_.reset(test_malloc(size_), test_free);
            mem_ = test::make_memory(d, e, data_.get());
        } else {
            mem_ = test::make_memory(d, e);
        }
        // Fill with a magic number to catch possible uninitialized access
        mapped_ptr_t<char> ptr(&mem_);
        if (ptr) memset(ptr, 0xFF, size_);
    }

    size_t get_size() const { return size_; }
    const memory &get() const { return mem_; }

    operator bool() const { return mem_.get(true) != nullptr; }

private:
    memory mem_;
    std::shared_ptr<char> data_; // backing storage for the native CPU path
    size_t size_;
};
789
// Maps the memory held by a test_memory wrapper (see map_memory above).
template <typename T>
mapped_ptr_t<T> map_memory(const test_memory &mem) {
    return mapped_ptr_t<T>(&mem.get());
}
794
795inline std::string to_string(dnnl_engine_kind_t engine_kind) {
796 std::stringstream ss;
797 if (engine_kind == dnnl_cpu)
798 ss << "cpu";
799 else if (engine_kind == dnnl_gpu)
800 ss << "gpu";
801 else
802 ss << "unknown";
803
804 return ss.str();
805}
806
807inline std::string to_string(dnnl_stream_flags_t stream_flags) {
808 std::stringstream ss;
809 if (stream_flags & dnnl_stream_default_flags)
810 ss << "default";
811 else if (stream_flags & dnnl_stream_in_order)
812 ss << "in_order";
813 else if (stream_flags & dnnl_stream_out_of_order)
814 ss << "out_of_order";
815
816 return ss.str();
817}
818
// testing all available C++ primitive descriptor constructors
// Support matrix: which attributes the primitive under test is expected
// to accept (consumed by the test_*_pd_constructors helpers).
struct allows_attr_t {
    bool po_sum; // sum post-op
    bool po_eltwise; // eltwise post-op
    bool po_binary; // binary post-op
    bool zp; // zero points
    bool scales; // scales
};
827
828using engine = dnnl::engine;
829// forward
// pd creation with an empty (default) attribute must always succeed.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_attr(const engine &eng, const prim_params_t &... prim_params) {
    dnnl::primitive_attr attr;
    EXPECT_NO_THROW(pd_t pd(eng, prim_params..., attr));
}
835
// pd creation with a sum post-op must succeed iff the primitive
// supports it.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_attr_po_sum(const engine &eng, bool supports_po_sum,
        const prim_params_t &... prim_params) {
    dnnl::post_ops ops_sum;
    ops_sum.append_sum(1.1f);
    dnnl::primitive_attr attr_po_sum;
    attr_po_sum.set_post_ops(ops_sum);
    if (supports_po_sum)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., attr_po_sum));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., attr_po_sum));
}
848
// pd creation with an eltwise (relu) post-op must succeed iff the
// primitive supports it.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_attr_po_eltwise(const engine &eng, bool supports_po_eltwise,
        const prim_params_t &... prim_params) {
    dnnl::post_ops ops_eltwise;
    ops_eltwise.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);
    dnnl::primitive_attr attr_po_eltwise;
    attr_po_eltwise.set_post_ops(ops_eltwise);
    if (supports_po_eltwise)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., attr_po_eltwise));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., attr_po_eltwise));
}
861
// pd creation with a binary-mul post-op (16-element s8 operand) must
// succeed iff the primitive supports it.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_attr_po_binary(const engine &eng, bool supports_po_binary,
        const prim_params_t &... prim_params) {
    dnnl::post_ops ops_binary;
    dnnl::memory::desc src1_desc(
            {16}, memory::data_type::s8, memory::format_tag::x);
    ops_binary.append_binary(dnnl::algorithm::binary_mul, src1_desc);
    dnnl::primitive_attr attr_po_binary;
    attr_po_binary.set_post_ops(ops_binary);
    if (supports_po_binary)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., attr_po_binary));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., attr_po_binary));
}
876
// pd creation with a src zero-point attribute must succeed iff the
// primitive supports zero points.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_attr_zp(const engine &eng, bool supports_zero_point,
        const prim_params_t &... prim_params) {
    dnnl::primitive_attr attr_zp;
    attr_zp.set_zero_points_mask(DNNL_ARG_SRC, 0);
    if (supports_zero_point)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., attr_zp));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., attr_zp));
}
887
888template <typename pd_t, typename... prim_params_t>
889void test_fwd_pd_attr_scales(const engine &eng, bool supports_scales,
890 const prim_params_t &... prim_params) {
891 dnnl::primitive_attr attr_scales;
892 attr_scales.set_scales_mask(DNNL_ARG_SRC, 0);
893
894 if (supports_scales) { // Currently only used with binary ops
895 EXPECT_NO_THROW(pd_t pd(eng, prim_params..., attr_scales));
896 } else
897 EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., attr_scales));
898}
899
// With allow_empty = true, an unsupported attribute must produce an
// empty (falsy) pd instead of throwing.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_allow_empty(
        const pd_t &pd, const prim_params_t &... prim_params) {
    bool allow_empty = true;
    pd_t new_pd {};
    dnnl::primitive_attr unsupported_attr;
    // Assumption is that mask "10" is a legit mask for scales
    // from API perspective.
    unsupported_attr.set_scales_mask(DNNL_ARG_SRC, 10);
    ASSERT_NO_THROW(new_pd = pd_t(pd.get_engine(), prim_params...,
                            unsupported_attr, allow_empty));
    ASSERT_FALSE(new_pd);
}
913
// Note: requires a valid primitive descriptor!
// Exercises every C++ pd constructor flavor for a forward primitive:
// from a C pd handle, with an empty attribute, with each attribute kind
// per the `aa` support matrix, and the allow_empty path.
template <typename pd_t, typename... prim_params_t>
void test_fwd_pd_constructors(const pd_t &pd, const allows_attr_t &aa,
        const prim_params_t &... prim_params) {
    auto test_pd = pd_t();
    auto eng = pd.get_engine();
    // ctor from C pd, should not throw
    ASSERT_NO_THROW(test_pd = pd_t(pd.get()));
    // ctor w/ empty attr, should not throw
    test_fwd_pd_attr<pd_t>(eng, prim_params...);
    // following ctors w/ attrs may throw based on pd support
    test_fwd_pd_attr_po_sum<pd_t>(eng, aa.po_sum, prim_params...);
    test_fwd_pd_attr_po_eltwise<pd_t>(eng, aa.po_eltwise, prim_params...);
    test_fwd_pd_attr_po_binary<pd_t>(eng, aa.po_binary, prim_params...);
    test_fwd_pd_attr_zp<pd_t>(eng, aa.zp, prim_params...);
    test_fwd_pd_attr_scales<pd_t>(eng, aa.scales, prim_params...);
    // check allow empty, should not throw
    test_fwd_pd_allow_empty<pd_t>(test_pd, prim_params...);
}
933
934// backward: has hint
// Backward variant: pd creation (with the forward hint) and an empty
// attribute must always succeed.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_attr(const engine &eng, const hint_pd_t &hint,
        const prim_params_t &... prim_params) {
    dnnl::primitive_attr attr;
    EXPECT_NO_THROW(pd_t pd(eng, prim_params..., hint, attr));
}
941
// Backward pd with a sum post-op must succeed iff supported.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_attr_po_sum(const engine &eng, const hint_pd_t &hint,
        bool supports_po_sum, const prim_params_t &... prim_params) {
    dnnl::post_ops ops_sum;
    ops_sum.append_sum(1.1f);
    dnnl::primitive_attr attr_po_sum;
    attr_po_sum.set_post_ops(ops_sum);
    if (supports_po_sum)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., hint, attr_po_sum));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., hint, attr_po_sum));
}
954
// Backward pd with an eltwise (relu) post-op must succeed iff supported.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_attr_po_eltwise(const engine &eng, const hint_pd_t &hint,
        bool supports_po_eltwise, const prim_params_t &... prim_params) {
    dnnl::post_ops ops_eltwise;
    ops_eltwise.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);
    dnnl::primitive_attr attr_po_eltwise;
    attr_po_eltwise.set_post_ops(ops_eltwise);
    if (supports_po_eltwise)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., hint, attr_po_eltwise));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., hint, attr_po_eltwise));
}
967
// Backward pd with a binary-mul post-op must succeed iff supported.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_attr_po_binary(const engine &eng, const hint_pd_t &hint,
        bool supports_po_binary, const prim_params_t &... prim_params) {
    dnnl::post_ops ops_binary;
    dnnl::memory::desc src1_desc(
            {16}, memory::data_type::s8, memory::format_tag::x);
    ops_binary.append_binary(dnnl::algorithm::binary_mul, src1_desc);
    dnnl::primitive_attr attr_po_binary;
    attr_po_binary.set_post_ops(ops_binary);
    if (supports_po_binary)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., hint, attr_po_binary));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., hint, attr_po_binary));
}
982
// Backward pd with a src zero-point attribute must succeed iff supported.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_attr_zp(const engine &eng, const hint_pd_t &hint,
        bool supports_zero_point, const prim_params_t &... prim_params) {
    dnnl::primitive_attr attr_zp;
    attr_zp.set_zero_points_mask(DNNL_ARG_SRC, 0);
    if (supports_zero_point)
        EXPECT_NO_THROW(pd_t pd(eng, prim_params..., hint, attr_zp));
    else
        EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., hint, attr_zp));
}
993
// NOTE(review): unlike the forward counterpart, `supports_scales` is
// ignored here and pd creation with a scales attribute is always
// expected to throw — presumably backward primitives never accept
// scales; confirm before changing.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_attr_scales(const engine &eng, const hint_pd_t &hint,
        bool supports_scales, const prim_params_t &... prim_params) {
    dnnl::primitive_attr attr_scales;
    attr_scales.set_scales_mask(DNNL_ARG_SRC, 0);
    EXPECT_ANY_THROW(pd_t pd(eng, prim_params..., hint, attr_scales));
}
1001
// With allow_empty = true, an unsupported attribute must produce an
// empty (falsy) backward pd instead of throwing.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_allow_empty(const pd_t &pd, const hint_pd_t &hint,
        const prim_params_t &... prim_params) {
    bool allow_empty = true;
    pd_t new_pd {};
    dnnl::primitive_attr unsupported_attr;
    // Assumption is that mask "10" is a legit mask for scales
    // from API perspective.
    unsupported_attr.set_scales_mask(DNNL_ARG_SRC, 10);
    ASSERT_NO_THROW(new_pd = pd_t(pd.get_engine(), prim_params..., hint,
                            unsupported_attr, allow_empty));
    ASSERT_FALSE(new_pd);
}
1015
// Note: requires a valid primitive descriptor!
// Exercises every C++ pd constructor flavor for a backward primitive
// (which additionally takes a forward hint pd): from a C pd handle, with
// an empty attribute, with each attribute kind per `aa`, and allow_empty.
template <typename pd_t, typename hint_pd_t, typename... prim_params_t>
void test_bwd_pd_constructors(const pd_t &pd, const hint_pd_t &hint,
        const allows_attr_t &aa, const prim_params_t &... prim_params) {
    auto test_pd = pd_t();
    auto hint_pd = hint;
    auto eng = pd.get_engine();
    // ctor from C pd, should not throw
    ASSERT_NO_THROW(test_pd = pd_t(pd.get()));
    // ctor w/ empty attr, should not throw
    test_bwd_pd_attr<pd_t>(eng, hint_pd, prim_params...);
    // following ctors w/ attrs may throw based on pd support
    test_bwd_pd_attr_po_sum<pd_t>(eng, hint_pd, aa.po_sum, prim_params...);
    test_bwd_pd_attr_po_eltwise<pd_t>(
            eng, hint_pd, aa.po_eltwise, prim_params...);
    test_bwd_pd_attr_po_binary<pd_t>(
            eng, hint_pd, aa.po_binary, prim_params...);
    test_bwd_pd_attr_zp<pd_t>(eng, hint_pd, aa.zp, prim_params...);
    test_bwd_pd_attr_scales<pd_t>(eng, hint_pd, aa.scales, prim_params...);
    // check allow empty, should not throw
    test_bwd_pd_allow_empty<pd_t>(test_pd, hint_pd, prim_params...);
}
1038
// Creates a stream for `engine`; in threadpool builds, CPU streams are
// bound to the testing threadpool instead of using `flags`.
inline dnnl::stream make_stream(dnnl::engine engine,
        dnnl::stream::flags flags = dnnl::stream::flags::default_flags) {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
    if (engine.get_kind() == dnnl::engine::kind::cpu)
        return dnnl::threadpool_interop::make_stream(
                engine, dnnl::testing::get_threadpool());
#endif
    return dnnl::stream(engine, flags);
}
1048
1049inline int get_primitive_cache_size() {
1050 int result = 0;
1051 auto status = dnnl::impl::get_primitive_cache_size(&result);
1052 if (status != dnnl::impl::status::success) return -1;
1053 return result;
1054}
1055
// This is a local copy of dnnl::impl::getenv.
// Copying to avoid exposure of internal symbol from the library.
// Returns the value length on success (value copied into `buffer`,
// NUL-terminated); the negated required length when `buffer` is too
// small; INT_MIN on invalid arguments or an oversized value.
inline int gtest_getenv(const char *name, char *buffer, int buffer_size) {
    if (name == nullptr || buffer_size < 0
            || (buffer == nullptr && buffer_size > 0))
        return INT_MIN;

    int result = 0;
    int term_zero_idx = 0;
    size_t value_length = 0;

#ifdef _WIN32
    // Writes into `buffer` directly; returns the value length (or the
    // required size when the buffer is too small).
    value_length = GetEnvironmentVariable(name, buffer, buffer_size);
#else
    const char *value = ::getenv(name);
    value_length = value == nullptr ? 0 : strlen(value);
#endif

    if (value_length > INT_MAX)
        result = INT_MIN;
    else {
        int int_value_length = (int)value_length;
        if (int_value_length >= buffer_size) {
            // Too small: report required length as a negative number.
            result = -int_value_length;
        } else {
            term_zero_idx = int_value_length;
            result = int_value_length;
#ifndef _WIN32
            if (value) strncpy(buffer, value, buffer_size - 1);
#endif
        }
    }

    if (buffer != nullptr) buffer[term_zero_idx] = '\0';
    return result;
}
1092
1093#endif
1094