/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef DNNL_COMMON_HPP
#define DNNL_COMMON_HPP

#include <functional>
#include <stddef.h>
#include <stdint.h>

#include <vector>

#include "oneapi/dnnl/dnnl.h"
#include "src/common/bfloat16.hpp"
#include "src/common/float16.hpp"
#include "src/common/nstl.hpp"

int check_pd_cache(const_dnnl_primitive_desc_t pd);
int check_primitive_cache(dnnl_primitive_t p);

#include "common.hpp"
#include "dnn_types.hpp"
#include "dnnl_debug.hpp"
#include "dnnl_memory.hpp"
#include "utils/compare.hpp"
#include "utils/dims.hpp"
#include "utils/dnnl_query.hpp"

#include "tests/test_thread.hpp"

#define for_ for

#define DNN_SAFE(f, s) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { \
            if (s == CRIT || s == WARN) { \
                BENCHDNN_PRINT(0, "error [%s:%d]: '%s' -> %s(%d)\n", \
                        __PRETTY_FUNCTION__, __LINE__, #f, \
                        status2str(status__), (int)status__); \
                fflush(0); \
                if (s == CRIT) exit(2); \
            } \
            return FAIL; \
        } \
    } while (0)

#define DNN_SAFE_V(f) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { \
            BENCHDNN_PRINT(0, "error [%s:%d]: '%s' -> %s(%d)\n", \
                    __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), \
                    status2str(status__), (int)status__); \
            fflush(0); \
            exit(2); \
        } \
    } while (0)

#define DNN_SAFE_STATUS(f) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { return status__; } \
    } while (0)

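// A hedged usage sketch (illustration only; `dnnl_some_call(...)` is a
// stand-in for any call returning dnnl_status_t). Inside a driver function
// that itself returns an int status:
//     DNN_SAFE(dnnl_some_call(...), WARN); // print the error, return FAIL
//     DNN_SAFE_V(dnnl_some_call(...));     // print the error, exit(2)
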
/* aux */
using bfloat16_t = dnnl::impl::bfloat16_t;
using float16_t = dnnl::impl::float16_t;
template <dnnl_data_type_t>
struct prec_traits;
template <>
struct prec_traits<dnnl_bf16> {
    typedef bfloat16_t type;
};
template <>
struct prec_traits<dnnl_f16> {
    typedef float16_t type;
};
template <>
struct prec_traits<dnnl_f32> {
    typedef float type;
};

// XXX: benchdnn infra doesn't support double yet.
// Use float's max/min/epsilon values to avoid the following build warning:
// warning C4756: overflow in constant arithmetic.
// This should be fixed once an f64 CPU reference is added.
template <>
struct prec_traits<dnnl_f64> {
    typedef float type;
};
template <>
struct prec_traits<dnnl_s32> {
    typedef int32_t type;
};
template <>
struct prec_traits<dnnl_s8> {
    typedef int8_t type;
};
template <>
struct prec_traits<dnnl_u8> {
    typedef uint8_t type;
};

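// For illustration (not part of the original header): `prec_traits` maps a
// runtime dnnl_data_type_t to the C++ type used for element access, e.g.
// `prec_traits<dnnl_s8>::type` is int8_t and `prec_traits<dnnl_bf16>::type`
// is bfloat16_t; dnnl_f64 intentionally maps to float (see the comment
// above).
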
#define CASE_ALL(dt) \
    switch (dt) { \
        CASE(dnnl_bf16); \
        CASE(dnnl_f16); \
        CASE(dnnl_f32); \
        CASE(dnnl_f64); \
        CASE(dnnl_s32); \
        CASE(dnnl_s8); \
        CASE(dnnl_u8); \
        default: assert(!"bad data_type"); \
    }

/* std::numeric_limits::digits functionality */
inline int digits_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::digits;

    CASE_ALL(dt);

#undef CASE
    return 0;
}

inline float epsilon_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return (float)dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::epsilon();

    CASE_ALL(dt);

#undef CASE

    return 0;
}

inline float lowest_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return (float)dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::lowest();

    CASE_ALL(dt);

#undef CASE

    return 0;
}

inline float max_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return (float)dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::max();

    CASE_ALL(dt);

#undef CASE

    return 0;
}

#undef CASE_ALL

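// Note (added for clarity): 2147483520.f is the largest float value that does
// not exceed INT32_MAX (2147483647); values in [2^30, 2^31) are spaced 128
// apart in float, so 2^31 - 128 is exactly representable. Presumably chosen
// so a float-to-s32 saturation bound itself never overflows.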
#define BENCHDNN_S32_TO_F32_SAT_CONST 2147483520.f

template <dnnl_data_type_t dt>
inline float saturate_and_round(float val) {
    const float dt_max = max_dt(dt);
    const float dt_min = (float)dnnl::impl::nstl::numeric_limits<
            typename prec_traits<dt>::type>::lowest();
    if (dt == dnnl_s32 && val >= max_dt(dnnl_s32)) return max_dt(dnnl_s32);
    if (val > dt_max) val = dt_max;
    if (val < dt_min) val = dt_min;
    return mxcsr_cvt(val);
}

inline bool is_integral_dt(dnnl_data_type_t dt) {
    return dt == dnnl_s32 || dt == dnnl_s8 || dt == dnnl_u8;
}

inline float maybe_saturate(dnnl_data_type_t dt, float value) {
    if (!is_integral_dt(dt)) return value;

    switch (dt) {
#define CASE(dt) \
    case dt: return saturate_and_round<dt>(value);
        CASE(dnnl_s32);
        CASE(dnnl_s8);
        CASE(dnnl_u8);
#undef CASE
        default: assert(!"bad data_type");
    }
    return 0;
}
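
// Worked examples (illustration only): maybe_saturate(dnnl_u8, 300.f) clamps
// to 255.f and maybe_saturate(dnnl_s8, -200.5f) clamps to -128.f, while
// floating-point types pass through unchanged, e.g.
// maybe_saturate(dnnl_f32, 300.f) == 300.f.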

float round_to_nearest_representable(dnnl_data_type_t dt, float value);

extern dnnl_engine_kind_t engine_tgt_kind;
extern size_t engine_index;
extern isa_hints_t hints;

struct engine_t {
    engine_t(dnnl_engine_kind_t engine_kind);
    engine_t(dnnl_engine_t engine);
    engine_t(const engine_t &other);
    ~engine_t();
    operator dnnl_engine_t() const { return engine_; }

private:
    engine_t &operator=(engine_t &other) = delete;
    dnnl_engine_t engine_;
    bool is_owner_;
};

struct stream_t {
    stream_t(dnnl_engine_t engine, void *interop_obj = nullptr);
    ~stream_t();
    operator dnnl_stream_t() const { return stream_; }

private:
    BENCHDNN_DISALLOW_COPY_AND_ASSIGN(stream_t);
    dnnl_stream_t stream_;
};
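
// A minimal usage sketch (illustration only; relies on get_test_engine()
// declared below): construct a stream bound to the global testing engine and
// let RAII release it:
//     stream_t stream(get_test_engine());
//     // ... submit primitives to `stream` ...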

// Engine used to run oneDNN primitives for testing.
inline const engine_t &get_test_engine() {
    if (is_bench_mode(PROF)) {
        bool is_profiling_supported = false;
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL \
        || DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
        is_profiling_supported = (engine_tgt_kind == dnnl_gpu);
#endif

        if (!is_profiling_supported) {
            fprintf(stderr,
                    "Profiling-based performance mode is supported for OpenCL "
                    "and DPC++ only.\n");
            exit(2);
        }
    }
    static const engine_t instance(engine_tgt_kind);
    return instance;
}

// Engine used to run all native reference implementations and the CPU
// implementations used by the `--fast-ref-gpu` option.
inline const engine_t &get_cpu_engine() {
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE
    // If no CPU engine is available, just re-use the testing one.
    return get_test_engine();
#else
    static const engine_t instance(dnnl_cpu);
    return instance;
#endif
}

bool is_cpu(const dnnl_engine_t &engine = get_test_engine());
bool is_gpu(const dnnl_engine_t &engine = get_test_engine());
bool is_sycl_engine(const dnnl_engine_t &engine = get_test_engine());
bool is_opencl_engine(const dnnl_engine_t &engine = get_test_engine());
bool is_nvidia_gpu(const dnnl_engine_t &engine = get_test_engine());
bool is_f64_supported(const dnnl_engine_t &engine = get_test_engine());
bool is_amd_gpu(const dnnl_engine_t &engine = get_test_engine());

// Extended version of the dnnl_sycl_interop_memory_kind_t enumeration.
enum class memory_kind_ext_t {
    usm, // Same as dnnl_sycl_interop_usm
    buffer, // Same as dnnl_sycl_interop_buffer
    usm_device, // USM allocated via malloc_device()
    usm_shared, // USM allocated via malloc_shared()
};

const memory_kind_ext_t default_memory_kind = memory_kind_ext_t::usm;

extern memory_kind_ext_t memory_kind;

void init_isa_settings();

struct args_t {
    args_t &set(int arg, const dnn_mem_t &mem);
    args_t &set(
            const std::vector<int> &args, const std::vector<dnn_mem_t> &mems);
    void clear() { args_.clear(); }

    int size() const { return (int)args_.size(); }

    const dnn_mem_t &find(int arg) const;

    int arg(int index) const { return args_[index].first; }
    const dnn_mem_t &dnn_mem(int index) const { return *args_[index].second; }

private:
    std::vector<std::pair<int, const dnn_mem_t *>> args_;
};
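
// A hedged usage sketch (illustration only; `src_m`, `dst_m`, and `prim` are
// hypothetical driver-owned objects). `set()` returns a reference, so calls
// can be chained before execution:
//     args_t args;
//     args.set(DNNL_ARG_SRC, src_m).set(DNNL_ARG_DST, dst_m);
//     SAFE(execute_and_wait(prim, args, res), WARN);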

template <typename prb_t>
struct init_pd_args_t {
    init_pd_args_t(res_t *res, dnnl_engine_t engine, const prb_t *prb,
            dir_t dir, const_dnnl_primitive_desc_t hint)
        : pd(nullptr)
        , is_iterator_supported(true)
        , res(res)
        , engine(engine)
        , prb(prb)
        , dir(dir)
        , hint(hint) {}

    // Output members
    dnnl_primitive_desc_t pd;

    bool is_iterator_supported;

    // Input members
    res_t *res;
    dnnl_engine_t engine;
    const prb_t *prb;
    dir_t dir;
    const_dnnl_primitive_desc_t hint;
};

bool is_fwd_prop_kind(dnnl_prop_kind_t prop_kind);
int get_memory_footprint(const_dnnl_primitive_desc_t pd, res_t *res);
int check_same_pd(const dnnl_primitive_desc_t &pd_no_attr, res_t *res);
int test_persistent_cache_api(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim,
        const_dnnl_primitive_desc_t pd, res_t *res);
int check_mem_size(const_dnnl_memory_desc_t md, res_t *res);
int check_mem_size(const_dnnl_primitive_desc_t const_pd, res_t *res);

void skip_start(res_t *res);
void skip_unimplemented_data_type(
        const std::vector<dnnl_data_type_t> &v_dt, dir_t dir, res_t *res);
void skip_unimplemented_sum_po(const attr_t &attr, res_t *res,
        dnnl_data_type_t dst_dt = dnnl_data_type_undef);
void skip_invalid_inplace(res_t *res, dnnl_data_type_t sdt,
        dnnl_data_type_t ddt, const std::string &stag, const std::string &dtag);
void skip_unimplemented_arg_scale(const attr_t &attr, res_t *res);

// `check_dnnl_status` is called to validate the result of primitive
// descriptor creation. Based on the status, it performs additional checks:
// * For `invalid_arguments` it just updates the `res` object with that state.
// * For `unimplemented` it checks whether the lack of support is expected or
//   not. It relies on the `skip_unimplemented_prb` function declared and
//   defined in every driver and expects to find it in the namespace the
//   `prb_t` type was picked up from. If the case is unknown, the
//   `UNIMPLEMENTED` status is returned.
template <typename prb_t>
int check_dnnl_status(dnnl_status_t status, const prb_t *prb, res_t *res) {
    if (!res || status == dnnl_success) return OK;

    switch (status) {
        case dnnl_invalid_arguments: res->state = INVALID_ARGUMENTS; break;
        case dnnl_unimplemented: {
            // Unconditionally treat unimplemented cases on the Nvidia and AMD
            // backends as not supported.
            if (is_nvidia_gpu() || is_amd_gpu()) {
                res->state = SKIPPED;
                res->reason = CASE_NOT_SUPPORTED;
                return OK;
            }

            // Check driver-specific cases of unimplemented functionality.
            skip_unimplemented_prb(prb, res);
            if (res->state == SKIPPED) return OK;

            // If the case is not known to be skipped, it is unimplemented.
            res->state = UNIMPLEMENTED;
        } break;
        default: assert(!"unexpected");
    }
    return FAIL;
}
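
// A hedged usage sketch (illustration only): a driver's `init_pd` callback
// returns the raw status of the oneDNN creation call, and the caller converts
// it into a benchdnn verdict, mirroring `create_primitive` below:
//     dnnl_status_t status = init_pd_func(init_pd_args);
//     SAFE(check_dnnl_status(status, prb, res), WARN);
//     if (res->state == SKIPPED) return OK;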

// `fetch_impl` is responsible for providing a valid `pd` under the following
// conditions:
// 1. Either a valid `pd` or `pd_it` was provided.
// 2a. It's a service primitive (fwd-for-bwd, cpu-for-gpu, or
//     simple-prims-of-complex-prim).
// 2b. It's a tested primitive and not all implementations hit the skip-impl
//     option values.
template <typename prb_t>
int fetch_impl(benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> &pdw,
        init_pd_args_t<prb_t> &init_pd_args, res_t *res, bool is_service_prim) {
    if (!init_pd_args.pd) return FAIL;

    // Wrapper is expected to come empty.
    assert(!pdw);

    pdw.reset(init_pd_args.pd);

    // A service primitive is not supposed to utilize further logic.
    if (is_service_prim) return OK;

    while (true) {
        const auto impl_name = query_impl_info(pdw);
        // Skip-impl is not requested or not hit. The latest pd is already
        // fetched.
        if (!maybe_skip(impl_name)) return OK;

        BENCHDNN_PRINT(6, "Implementation skipped: %s\n", impl_name.c_str());

        // The iterator is not supported, further logic is not applicable.
        if (!init_pd_args.is_iterator_supported) {
            res->state = SKIPPED;
            res->reason = SKIP_IMPL_HIT;
            return OK;
        }

        auto status = dnnl_primitive_desc_next_impl(pdw);
        if (status == dnnl_last_impl_reached) {
            BENCHDNN_PRINT(2, "%s\n", "All implementations were skipped!");
            res->state = SKIPPED;
            res->reason = SKIP_IMPL_HIT;
            pdw.reset(nullptr);
            return OK;
        } else if (status == dnnl_success) {
            continue;
        } else {
            BENCHDNN_PRINT(0, "%s\n", "Unexpected status from pd iterator.");
            return FAIL;
        }
    }

    // Unreachable; return a fail status as a safeguard.
    return FAIL;
}

// This function is internal to `init_prim`; it implements the logic of
// creating a `pd` and a `prim` and assigning them to the input wrappers. It
// removes code duplication and keeps all the logic in a single place.
template <typename func_t, typename prb_t>
int create_primitive(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &primw,
        dnnl_engine_t engine, const func_t &init_pd_func, const prb_t *prb,
        res_t *res, dir_t dir, const_dnnl_primitive_desc_t hint,
        bool is_service_prim) {
    dnnl_status_t status = dnnl_success;
    dnnl_primitive_t prim {};

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;

    init_pd_args_t<prb_t> init_pd_args(res, engine, prb, dir, hint);
    status = init_pd_func(init_pd_args);

    SAFE(check_dnnl_status(status, prb, res), WARN);
    if (res->state == SKIPPED) return OK;

    // Fetch also checks if the user requested to skip certain implementations.
    SAFE(fetch_impl(pdw, init_pd_args, res, is_service_prim), WARN);
    if (res->state == SKIPPED) return OK;

    DNN_SAFE(dnnl_primitive_create(&prim, pdw), WARN);
    primw.reset(prim);

    return OK;
}

template <typename func_t, typename prb_t>
int check_pd_w_and_wo_attr(dnnl_engine_t engine, const func_t &init_pd_func,
        const prb_t *prb, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {

    if (!attr_same_pd_check || prb->attr.is_def()) return OK;

    if (prb->attr.post_ops.convolution_index() != -1) return OK;

    // Check that adding attributes doesn't cause a fall back to another impl.
    auto *prb_mutable = const_cast<prb_t *>(prb);
    auto old_attr = prb_mutable->attr;
    prb_mutable->attr = attr_t();
    init_pd_args_t<prb_t> init_pd_args_without_attr(
            res, engine, prb_mutable, dir, hint);
    DNN_SAFE(init_pd_func(init_pd_args_without_attr), WARN);
    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw(
            init_pd_args_without_attr.pd);
    prb_mutable->attr = old_attr;
    SAFE(check_same_pd(pdw, res), WARN);
    return OK;
}

template <typename func_t, typename prb_t>
int init_prim(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &user_prim,
        const func_t &init_pd_func, const prb_t *prb, res_t *res,
        dir_t dir = FLAG_FWD, const_dnnl_primitive_desc_t hint = nullptr,
        bool is_service_prim = false) {
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> primw;

    skip_start(res);
    if (res->state == SKIPPED) return OK;
    skip_invalid_prb(prb, res);
    if (res->state == SKIPPED) return OK;
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE

    // The idea is to create the requested primitive twice using different
    // engines but the same device and context in the case of OpenCL and DPCPP.
    // Rationale: make sure that the primitive cache is robust in the case
    // where CPU and GPU engines are re-created because this is a commonly
    // used scenario in the frameworks.
    engine_t engine(get_test_engine());

    // The first primitive creation using a temporary engine.
    SAFE(create_primitive(primw, engine, init_pd_func, prb, res, dir, hint,
                 is_service_prim),
            WARN);
    if (res->state == SKIPPED) return OK;

#endif
    // The second (if the cache is enabled) primitive creation using the global
    // test engine. This primitive is expected to come from the cache.
    SAFE(create_primitive(primw, get_test_engine(), init_pd_func, prb, res, dir,
                 hint, is_service_prim),
            WARN);
    if (res->state == SKIPPED) return OK;

    auto pd = query_pd(primw);
    SAFE(check_mem_size(pd, res), WARN);
    if (res->state == SKIPPED) return OK;

    // Further checks are only for tested primitives.
    if (is_service_prim) {
        user_prim.reset(primw.release());
        return OK;
    }

    res->impl_name = query_impl_info(pd);
    BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
    // Check that adding attributes doesn't cause a fall back to another impl.
    SAFE(check_pd_w_and_wo_attr(
                 get_test_engine(), init_pd_func, prb, res, dir, hint),
            WARN);
    // Check primitive descriptor is picked up from the cache, if applicable.
    SAFE(check_pd_cache(pd), WARN);
    // Check primitive is picked up from the cache, if applicable.
    SAFE(check_primitive_cache(primw), WARN);
    // Collect memory footprint for a given primitive descriptor.
    SAFE(get_memory_footprint(pd, res), WARN);

    SAFE(test_persistent_cache_api(primw, pd, res), WARN);

    user_prim.reset(primw.release());
    return OK;
}
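
// A hedged usage sketch (illustration only): a typical driver `doit()` creates
// the primitive through this helper, bails out early on skips, and then
// executes it; `init_pd` is the driver-provided callback that fills
// `init_pd_args_t::pd`:
//     benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
//     SAFE(init_prim(prim, init_pd, prb, res), WARN);
//     if (res->state == SKIPPED) return OK;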

template <typename func_t, typename prb_t>
int init_prim(const thr_ctx_t &thr_ctx,
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &user_prim,
        const func_t &init_pd_func, prb_t *prb, res_t *res,
        dir_t dir = FLAG_FWD, const_dnnl_primitive_desc_t hint = nullptr,
        bool is_service_prim = false) {
    int (*f)(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &, func_t &,
            const prb_t *, res_t *, dir_t, const_dnnl_primitive_desc_t, bool)
            = init_prim<func_t, prb_t>;
    return create_in_thr_ctx(thr_ctx, f, user_prim, init_pd_func, prb, res, dir,
            hint, is_service_prim);
}

// `check_correctness` is designed to be called from every driver where
// correctness validation is needed. It takes:
// * A pointer to a `prb_t` problem.
// * A vector of kinds to compare, to validate several outputs, if applicable.
// * Backend arguments to compare the output.
// * Driver's reference memory arguments to compute the reference path, then
//   set up a compare object, and, finally, compare the output.
// * A reference to a function that sets up the compare object, see the
//   description below.
// * A pointer to a `res_t` structure to update the validation status.
// * An optional pointer to a CPU primitive for speeding up reference path
//   computation on GPU.
//
// The function doesn't return a status since we rely on `res` to contain all
// necessary information about validation results.
//
// The function performs several validation steps:
// * Checks that the padded areas of all memories are properly zeroed.
// * Checks that the GPU backend hasn't modified out-of-boundary memory
//   regions.
// * Executes the driver's reference path, using the problem, driver reference
//   arguments, and the CPU primitive for the GPU backend, if available.
// * For each kind to validate it:
//   - Creates and sets up the compare object. Setting up is done with
//     `setup_cmp_func`.
//   - Finds the corresponding memory arguments from the backend and the
//     reference and compares them.
//   - Saves the result of the comparison into the `res` object.
//
// `setup_cmp_func` is a function that is supposed to be defined in every
// driver's namespace. Its interface is:
// `void (compare::compare_t &, const prb_t *, data_kind_t, const args_t &);`
// It takes:
// * A reference to a `compare_t` object which the function modifies based on
//   the driver's needs.
// * A pointer to a `prb_t` problem.
// * A `data_kind` value to help set up the threshold depending on the output
//   argument.
// * Driver's reference memory arguments since some drivers can't reliably
//   validate certain scenarios without additional memory arguments.
// It returns nothing since the object is modified by reference and the
// lifetime of the compare object is controlled by `check_correctness`.
//
// Note: a dedicated non-templated type for `setup_cmp_func_t` could be used,
// but since it relies on a `prb_t` type, which is individual for each driver,
// that isn't possible without a template.
template <typename setup_cmp_func_t, typename prb_t>
void check_correctness(const prb_t *prb, const std::vector<data_kind_t> &kinds,
        const args_t &args, const args_t &ref_args,
        const setup_cmp_func_t &setup_cmp_func, res_t *res,
        dnnl_primitive_t prim_ref = nullptr) {

    for (int i = 0; i < args.size(); ++i) {
        check_zero_padding(args.dnn_mem(i), args.arg(i), res);
        check_buffer_overwrite(args.dnn_mem(i), args.arg(i), res);
    }

    TIME_REF(compute_ref(prb, ref_args, prim_ref));

    for (const auto &kind : kinds) {
        compare::compare_t cmp;
        cmp.set_data_kind(kind);
        setup_cmp_func(cmp, prb, kind, ref_args);

        int arg = 0;
        switch (kind) {
            case DST: arg = DNNL_ARG_DST; break;
            case SRC: arg = DNNL_ARG_DIFF_SRC; break;
            case SRC_1: arg = DNNL_ARG_DIFF_SRC_1; break;
            case WEI: arg = DNNL_ARG_DIFF_WEIGHTS; break;
            case BIA: arg = DNNL_ARG_DIFF_BIAS; break;
            case MEAN: arg = DNNL_ARG_MEAN; break;
            case VAR: arg = DNNL_ARG_VARIANCE; break;
            case SC: arg = DNNL_ARG_DIFF_SCALE; break;
            case SH: arg = DNNL_ARG_DIFF_SHIFT; break;
            case DST_ITER: arg = DNNL_ARG_DST_ITER; break;
            case DST_ITER_C: arg = DNNL_ARG_DST_ITER_C; break;
            case AUGRU_ATTENTION: arg = DNNL_ARG_DIFF_AUGRU_ATTENTION; break;
            case SRC_ITER: arg = DNNL_ARG_DIFF_SRC_ITER; break;
            case SRC_ITER_C: arg = DNNL_ARG_DIFF_SRC_ITER_C; break;
            case WEI_ITER: arg = DNNL_ARG_DIFF_WEIGHTS_ITER; break;
            case WEI_PEEPHOLE: arg = DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE; break;
            case WEI_PROJECTION: arg = DNNL_ARG_DIFF_WEIGHTS_PROJECTION; break;
            default: assert(!"unsupported kind"); SAFE_V(FAIL);
        }
        const auto &mem_dt = args.find(arg);
        const auto &mem_fp = ref_args.find(arg);

        cmp.compare(mem_fp, mem_dt, prb->attr, res);
    }
}
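
// A hedged usage sketch (illustration only): after executing the tested
// primitive, a driver typically lists the kinds to validate and delegates the
// comparison; `setup_cmp` is the driver's compare-setup function described
// above, and `args`/`ref_args` are the backend and reference argument sets:
//     if (is_bench_mode(CORR)) {
//         std::vector<data_kind_t> kinds {DST};
//         check_correctness(prb, kinds, args, ref_args, setup_cmp, res);
//     }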

typedef std::function<dnnl_status_t(
        const dnnl_stream_t &, const std::vector<dnnl_exec_arg_t> &)>
        perf_function_t;

int execute_and_wait(perf_function_t &exec_func, const dnnl_engine_t &engine,
        const args_t &args, res_t *res = nullptr);
int execute_and_wait(
        dnnl_primitive_t prim, const args_t &args, res_t *res = nullptr);

void reset_gpu_profiling();
int measure_perf(const thr_ctx_t &ctx, res_t *res, perf_function_t &perf_func,
        args_t &args);
int measure_perf(
        const thr_ctx_t &ctx, res_t *res, dnnl_primitive_t prim, args_t &args);

void maybe_prepare_runtime_scales(dnn_mem_t &scales_m,
        const attr_t::scale_t &scale, int64_t scale_cnt, const float *scales);

void maybe_prepare_runtime_scales_v2(dnn_mem_t &scales_dt, dnn_mem_t &scales_fp,
        const attr_t::scale_t &scale, int64_t scale_cnt, const float *scales);

void maybe_prepare_runtime_zero_points(dnn_mem_t &zero_points_m,
        const attr_t &attr, int arg, int64_t count, const int32_t *zero_points);

void maybe_prepare_runtime_zero_points_v2(dnn_mem_t &zero_points_dt,
        dnn_mem_t &zero_points_fp, const attr_t &attr, int arg, int64_t count,
        const int32_t *zero_points);

std::vector<float> prepare_po_vals(const dnn_mem_t &dst_m, const args_t &args,
        const std::vector<std::pair<int, int>> &v_po_masks,
        const size_t dst_off);

bool check_md_consistency_with_tag(
        const_dnnl_memory_desc_t md, const std::string &tag);

memory_kind_ext_t str2memory_kind(const char *str);

float reorder_rescale_factor();
dims_t md2dims(const dnnl_memory_desc_t &md);

// The function adjusts the data type if an fpmath mode is set or `sum_dt`
// differs from the destination data type. It is used in `cfg` objects that
// regulate filling.
dnnl_data_type_t deduce_cfg_data_type(
        dnnl_data_type_t in_dt, const attr_t &attr, data_kind_t dk);

#endif