1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef DNNL_COMMON_HPP |
18 | #define DNNL_COMMON_HPP |
19 | |
20 | #include <functional> |
21 | #include <stddef.h> |
22 | #include <stdint.h> |
23 | |
24 | #include <vector> |
25 | |
26 | #include "oneapi/dnnl/dnnl.h" |
27 | #include "src/common/bfloat16.hpp" |
28 | #include "src/common/float16.hpp" |
29 | #include "src/common/nstl.hpp" |
30 | |
// Validate that the primitive descriptor / primitive was served from the
// internal cache when caching is expected to kick in (defined in the .cpp).
int check_pd_cache(const_dnnl_primitive_desc_t pd);
int check_primitive_cache(dnnl_primitive_t p);
33 | |
34 | #include "common.hpp" |
35 | #include "dnn_types.hpp" |
36 | #include "dnnl_debug.hpp" |
37 | #include "dnnl_memory.hpp" |
38 | #include "utils/compare.hpp" |
39 | #include "utils/dims.hpp" |
40 | #include "utils/dnnl_query.hpp" |
41 | |
42 | #include "tests/test_thread.hpp" |
43 | |
// Alias for `for` used by multi-dimensional loop helper macros so that
// nested loops collapse visually (benchdnn/oneDNN convention).
#define for_ for
45 | |
// Execute oneDNN call `f`; on failure print the failed expression and status
// (when severity `s` is CRIT or WARN), exit(2) on CRIT, and make the
// enclosing function `return FAIL`. Use only inside functions returning int.
#define DNN_SAFE(f, s) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { \
            if (s == CRIT || s == WARN) { \
                BENCHDNN_PRINT(0, "error [%s:%d]: '%s' -> %s(%d)\n", \
                        __PRETTY_FUNCTION__, __LINE__, #f, \
                        status2str(status__), (int)status__); \
                fflush(0); \
                if (s == CRIT) exit(2); \
            } \
            return FAIL; \
        } \
    } while (0)
60 | |
// "Void-context" variant of DNN_SAFE: unconditionally prints the failed
// expression and exits(2); usable where `return FAIL` is not possible.
#define DNN_SAFE_V(f) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { \
            BENCHDNN_PRINT(0, "error [%s:%d]: '%s' -> %s(%d)\n", \
                    __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), \
                    status2str(status__), (int)status__); \
            fflush(0); \
            exit(2); \
        } \
    } while (0)
72 | |
// Propagate a non-success status to the caller; for functions returning
// dnnl_status_t.
#define DNN_SAFE_STATUS(f) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { return status__; } \
    } while (0)
78 | |
/* aux */
using bfloat16_t = dnnl::impl::bfloat16_t;
using float16_t = dnnl::impl::float16_t;
// Maps a compile-time dnnl_data_type_t value to the C++ type that holds it.
// Only the specializations below are valid; other values fail to compile.
template <dnnl_data_type_t>
struct prec_traits;
template <>
struct prec_traits<dnnl_bf16> {
    typedef bfloat16_t type;
};
template <>
struct prec_traits<dnnl_f16> {
    typedef float16_t type;
};
template <>
struct prec_traits<dnnl_f32> {
    typedef float type;
};

// XXX: benchdnn infra doesn't support double yet.
// Use float's max/min/epsilon values to avoid following build warnings:
// warning C4756: overflow in constant arithmetic.
// This should be fixed once cpu reference in f64 is added.
template <>
struct prec_traits<dnnl_f64> {
    typedef float type;
};
template <>
struct prec_traits<dnnl_s32> {
    typedef int32_t type;
};
template <>
struct prec_traits<dnnl_s8> {
    typedef int8_t type;
};
template <>
struct prec_traits<dnnl_u8> {
    typedef uint8_t type;
};
117 | |
// Expands a switch over all supported data types. Each expansion site must
// define a local CASE(dt) macro beforehand and #undef it afterwards.
#define CASE_ALL(dt) \
    switch (dt) { \
        CASE(dnnl_bf16); \
        CASE(dnnl_f16); \
        CASE(dnnl_f32); \
        CASE(dnnl_f64); \
        CASE(dnnl_s32); \
        CASE(dnnl_s8); \
        CASE(dnnl_u8); \
        default: assert(!"bad data_type"); \
    }
129 | |
/* std::numeric_limits::digits functionality */
// Returns the number of mantissa/value bits (`digits`) of data type `dt`.
inline int digits_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::digits;

    CASE_ALL(dt);

#undef CASE
    // Unreachable for valid data types; the switch above returns.
    return 0;
}
142 | |
// Machine epsilon of data type `dt`, converted to float.
inline float epsilon_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return (float)dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::epsilon();

    CASE_ALL(dt);

#undef CASE

    // Unreachable for valid data types.
    return 0;
}
155 | |
// Lowest (most negative) finite value of data type `dt`, as float.
inline float lowest_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return (float)dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::lowest();

    CASE_ALL(dt);

#undef CASE

    // Unreachable for valid data types.
    return 0;
}
168 | |
// Largest finite value of data type `dt`, as float. Note: the float cast may
// round the result (e.g. INT32_MAX rounds up to 2^31).
inline float max_dt(dnnl_data_type_t dt) {
#define CASE(dt) \
    case dt: \
        return (float)dnnl::impl::nstl::numeric_limits< \
                typename prec_traits<dt>::type>::max();

    CASE_ALL(dt);

#undef CASE

    // Unreachable for valid data types.
    return 0;
}
181 | |
182 | #undef CASE_ALL |
183 | |
184 | #define BENCHDNN_S32_TO_F32_SAT_CONST 2147483520.f |
185 | |
// Clamps `val` to the representable range of `dt` and rounds it via
// `mxcsr_cvt` (current FP rounding mode).
template <dnnl_data_type_t dt>
inline float saturate_and_round(float val) {
    const float dt_max = max_dt(dt);
    const float dt_min = (float)dnnl::impl::nstl::numeric_limits<
            typename prec_traits<dt>::type>::lowest();
    // s32 special case: (float)INT32_MAX rounds up to 2^31 (see max_dt), so
    // anything at or above that boundary is pinned to max_dt(dnnl_s32)
    // before the generic clamp below.
    if (dt == dnnl_s32 && val >= max_dt(dnnl_s32)) return max_dt(dnnl_s32);
    if (val > dt_max) val = dt_max;
    if (val < dt_min) val = dt_min;
    return mxcsr_cvt(val);
}
196 | |
197 | inline bool is_integral_dt(dnnl_data_type_t dt) { |
198 | return dt == dnnl_s32 || dt == dnnl_s8 || dt == dnnl_u8; |
199 | } |
200 | |
201 | inline float maybe_saturate(dnnl_data_type_t dt, float value) { |
202 | if (!is_integral_dt(dt)) return value; |
203 | |
204 | switch (dt) { |
205 | #define CASE(dt) \ |
206 | case dt: return saturate_and_round<dt>(value); |
207 | CASE(dnnl_s32); |
208 | CASE(dnnl_s8); |
209 | CASE(dnnl_u8); |
210 | #undef CASE |
211 | default: assert(!"bad data_type" ); |
212 | } |
213 | return 0; |
214 | } |
215 | |
// Rounds `value` as a store to data type `dt` would (defined in the .cpp).
float round_to_nearest_representable(dnnl_data_type_t dt, float value);

// Global test settings; presumably populated from the command line in the
// benchdnn driver entry point -- definitions live elsewhere.
extern dnnl_engine_kind_t engine_tgt_kind;
extern size_t engine_index;
extern isa_hints_t hints;
221 | |
// RAII wrapper over a C-API dnnl_engine_t. The `is_owner_` flag suggests
// ownership depends on which constructor was used (kind vs. pre-existing
// engine) -- see the .cpp for the exact semantics. Copy-constructible;
// assignment is disabled.
struct engine_t {
    engine_t(dnnl_engine_kind_t engine_kind);
    engine_t(dnnl_engine_t engine);
    engine_t(const engine_t &other);
    ~engine_t();
    // Implicit conversion so the wrapper can be passed to C API calls.
    operator dnnl_engine_t() const { return engine_; }

private:
    engine_t &operator=(engine_t &other) = delete;
    dnnl_engine_t engine_;
    bool is_owner_;
};
234 | |
// RAII wrapper over a C-API dnnl_stream_t created on `engine`;
// `interop_obj` presumably carries a runtime-specific queue (see .cpp).
// Non-copyable.
struct stream_t {
    stream_t(dnnl_engine_t engine, void *interop_obj = nullptr);
    ~stream_t();
    // Implicit conversion so the wrapper can be passed to C API calls.
    operator dnnl_stream_t() const { return stream_; }

private:
    BENCHDNN_DISALLOW_COPY_AND_ASSIGN(stream_t);
    dnnl_stream_t stream_;
};
244 | |
// Engine used to run oneDNN primitives for testing.
// The engine is a function-local static constructed on first use from the
// global `engine_tgt_kind`. Profiling-based perf mode is rejected up front
// unless the build has an OpenCL/SYCL GPU runtime and the target is GPU.
inline const engine_t &get_test_engine() {
    if (is_bench_mode(PROF)) {
        bool is_profiling_supported = false;
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL \
        || DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
        is_profiling_supported = (engine_tgt_kind == dnnl_gpu);
#endif

        if (!is_profiling_supported) {
            fprintf(stderr,
                    "Profiling-based performance mode is supported for OpenCL "
                    "and DPC++ only.\n");
            exit(2);
        }
    }
    static const engine_t instance(engine_tgt_kind);
    return instance;
}
264 | |
// Engine used to run all reference native implementations and CPU
// implementations used by `--fast-ref-gpu` option.
// Falls back to the testing engine when the build has no CPU runtime.
inline const engine_t &get_cpu_engine() {
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE
    // In case of lacking CPU engine, just re-use testing one.
    return get_test_engine();
#else
    static const engine_t instance(dnnl_cpu);
    return instance;
#endif
}
276 | |
// Engine/device introspection predicates; all default to the global test
// engine (definitions in the .cpp).
bool is_cpu(const dnnl_engine_t &engine = get_test_engine());
bool is_gpu(const dnnl_engine_t &engine = get_test_engine());
bool is_sycl_engine(const dnnl_engine_t &engine = get_test_engine());
bool is_opencl_engine(const dnnl_engine_t &engine = get_test_engine());
bool is_nvidia_gpu(const dnnl_engine_t &engine = get_test_engine());
bool is_f64_supported(const dnnl_engine_t &engine = get_test_engine());
bool is_amd_gpu(const dnnl_engine_t &engine = get_test_engine());

// Extended version of dnnl_sycl_interop_memory_kind_t enumeration.
enum class memory_kind_ext_t {
    usm, // Same as dnnl_sycl_interop_usm
    buffer, // Same as dnnl_sycl_interop_buffer
    usm_device, // USM allocated via malloc_device()
    usm_shared, // USM allocated via malloc_shared()
};

const memory_kind_ext_t default_memory_kind = memory_kind_ext_t::usm;

// Global memory kind selected for the run (definition in the .cpp).
extern memory_kind_ext_t memory_kind;

// Applies ISA-related settings/hints (definition in the .cpp).
void init_isa_settings();
298 | |
// Ordered collection of {execution-argument id, memory} pairs passed to a
// primitive at execution and to correctness checks.
struct args_t {
    args_t &set(int arg, const dnn_mem_t &mem);
    args_t &set(
            const std::vector<int> &args, const std::vector<dnn_mem_t> &mems);
    void clear() { args_.clear(); }

    // Number of stored argument pairs.
    int size() const { return (int)args_.size(); }

    // Lookup by execution-argument id (behavior for an absent `arg` is
    // defined in the .cpp).
    const dnn_mem_t &find(int arg) const;

    // Positional accessors; `index` must be within [0, size()).
    int arg(int index) const { return args_[index].first; }
    const dnn_mem_t &dnn_mem(int index) const { return *args_[index].second; }

private:
    std::vector<std::pair<int, const dnn_mem_t *>> args_;
};
315 | |
// Bundle of inputs/outputs for a driver's `init_pd` callback: the driver
// fills `pd` (and may clear `is_iterator_supported`) from the given problem,
// direction and optional hint pd.
template <typename prb_t>
struct init_pd_args_t {
    init_pd_args_t(res_t *res, dnnl_engine_t engine, const prb_t *prb,
            dir_t dir, const_dnnl_primitive_desc_t hint)
        : pd(nullptr)
        , is_iterator_supported(true)
        , res(res)
        , engine(engine)
        , prb(prb)
        , dir(dir)
        , hint(hint) {}

    // Output members
    dnnl_primitive_desc_t pd;

    bool is_iterator_supported;

    // Input members
    res_t *res;
    dnnl_engine_t engine;
    const prb_t *prb;
    dir_t dir;
    const_dnnl_primitive_desc_t hint;
};
340 | |
341 | bool is_fwd_prop_kind(dnnl_prop_kind_t prop_kind); |
342 | int (const_dnnl_primitive_desc_t pd, res_t *res); |
343 | int check_same_pd(const dnnl_primitive_desc_t &pd_no_attr, res_t *res); |
344 | int test_persistent_cache_api(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim, |
345 | const_dnnl_primitive_desc_t pd, res_t *res); |
346 | int check_mem_size(const_dnnl_memory_desc_t md, res_t *res); |
347 | int check_mem_size(const_dnnl_primitive_desc_t const_pd, res_t *res); |
348 | |
// Skip-condition helpers shared by all drivers; each may flip `res->state`
// to SKIPPED with an appropriate reason (definitions in the .cpp).
void skip_start(res_t *res);
void skip_unimplemented_data_type(
        const std::vector<dnnl_data_type_t> &v_dt, dir_t dir, res_t *res);
void skip_unimplemented_sum_po(const attr_t &attr, res_t *res,
        dnnl_data_type_t dst_dt = dnnl_data_type_undef);
void skip_invalid_inplace(res_t *res, dnnl_data_type_t sdt,
        dnnl_data_type_t ddt, const std::string &stag, const std::string &dtag);
void skip_unimplemented_arg_scale(const attr_t &attr, res_t *res);
357 | |
358 | // `check_dnnl_status` function is called to validate the result of primitive |
359 | // descriptor creation. Based on the status, it produces additional checks: |
360 | // * For `invalid_arguments` it just updates the `res` object with it. |
361 | // * For `unimplemented` it checks whether the lack of support is expected or |
362 | // not. It relies on `skip_unimplemented_prb` function declared and defined |
363 | // at every driver and expects it to find in correspondent namespace from |
364 | // where `prb_t` was picked up. If the case is unknown, `UNIMPLEMENTED` status |
365 | // will be returned. |
366 | template <typename prb_t> |
367 | int check_dnnl_status(dnnl_status_t status, const prb_t *prb, res_t *res) { |
368 | if (!res || status == dnnl_success) return OK; |
369 | |
370 | switch (status) { |
371 | case dnnl_invalid_arguments: res->state = INVALID_ARGUMENTS; break; |
372 | case dnnl_unimplemented: { |
373 | // Unconditionally set all Nvidia backend unimplemented cases as |
374 | // not supported. |
375 | if (is_nvidia_gpu() || is_amd_gpu()) { |
376 | res->state = SKIPPED; |
377 | res->reason = CASE_NOT_SUPPORTED; |
378 | return OK; |
379 | } |
380 | |
381 | // Check driver specific cases of unimplemented functionality. |
382 | skip_unimplemented_prb(prb, res); |
383 | if (res->state == SKIPPED) return OK; |
384 | |
385 | // If the case is not known to be skipped, it is unimplemented. |
386 | res->state = UNIMPLEMENTED; |
387 | } break; |
388 | default: assert(!"unexpected" ); |
389 | } |
390 | return FAIL; |
391 | } |
392 | |
393 | // `fetch_impl` is responsible to provide a valid `pd` under certain conditions: |
394 | // 1. Either valid `pd` or `pd_it` were provided. |
395 | // 2a. It's a service primitive (fwd-for-bwd or cpu-for-gpu or |
396 | // simple-prims-of-complex-prim). |
397 | // 2b. It's a tested primitive and not all implementations hit skip-impl option |
398 | // values. |
// Transfers ownership of `init_pd_args.pd` into `pdw` and, for tested (non-
// service) primitives, advances the pd iterator until an implementation not
// matched by the skip-impl option is found. On "all impls skipped" the
// wrapper is emptied and `res` is marked SKIPPED.
template <typename prb_t>
int fetch_impl(benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> &pdw,
        init_pd_args_t<prb_t> &init_pd_args, res_t *res, bool is_service_prim) {
    if (!init_pd_args.pd) return FAIL;

    // Wrapper is expected to come empty.
    assert(!pdw);

    pdw.reset(init_pd_args.pd);

    // Service primitive is not supposed to utilize further logic.
    if (is_service_prim) return OK;

    while (true) {
        const auto impl_name = query_impl_info(pdw);
        // Skip-impl is not requested or hit. Latest pd already fetched.
        if (!maybe_skip(impl_name)) return OK;

        BENCHDNN_PRINT(6, "Implementation skipped: %s\n", impl_name.c_str());

        // Iterator is not supported, further logic is not applicable.
        if (!init_pd_args.is_iterator_supported) {
            res->state = SKIPPED;
            res->reason = SKIP_IMPL_HIT;
            return OK;
        }

        auto status = dnnl_primitive_desc_next_impl(pdw);
        if (status == dnnl_last_impl_reached) {
            BENCHDNN_PRINT(2, "%s\n", "All implementations were skipped!");
            res->state = SKIPPED;
            res->reason = SKIP_IMPL_HIT;
            // Leave the wrapper empty so callers see no pd was selected.
            pdw.reset(nullptr);
            return OK;
        } else if (status == dnnl_success) {
            continue;
        } else {
            BENCHDNN_PRINT(0, "%s\n", "Unexpected status from pd iterator.");
            return FAIL;
        }
    }

    // Unreached fail status.
    return FAIL;
}
444 | |
445 | // This is an internal to `init_prim` function that utilizes the logic of |
446 | // creating a `pd` and `prim` and assign them to input wrappers. It allows to |
447 | // remove code duplication and keep all the logic in a single place. |
// This is an internal to `init_prim` function that utilizes the logic of
// creating a `pd` and `prim` and assign them to input wrappers. It allows to
// remove code duplication and keep all the logic in a single place.
// Returns OK also on SKIPPED states; callers must inspect `res->state`.
template <typename func_t, typename prb_t>
int create_primitive(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &primw,
        dnnl_engine_t engine, const func_t &init_pd_func, const prb_t *prb,
        res_t *res, dir_t dir, const_dnnl_primitive_desc_t hint,
        bool is_service_prim) {
    dnnl_status_t status = dnnl_success;
    dnnl_primitive_t prim {};

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;

    // Driver-provided callback fills `init_pd_args.pd`.
    init_pd_args_t<prb_t> init_pd_args(res, engine, prb, dir, hint);
    status = init_pd_func(init_pd_args);

    SAFE(check_dnnl_status(status, prb, res), WARN);
    if (res->state == SKIPPED) return OK;

    // Fetch also checks if user requested to skip certain implementations.
    SAFE(fetch_impl(pdw, init_pd_args, res, is_service_prim), WARN);
    if (res->state == SKIPPED) return OK;

    DNN_SAFE(dnnl_primitive_create(&prim, pdw), WARN);
    primw.reset(prim);

    return OK;
}
473 | |
// Verifies that creating the pd without attributes resolves to the same
// implementation as with them (guards against silent fallbacks caused by
// attributes). No-op unless `attr_same_pd_check` is set and the problem has
// non-default attributes; post-op convolutions are excluded.
template <typename func_t, typename prb_t>
int check_pd_w_and_wo_attr(dnnl_engine_t engine, const func_t &init_pd_func,
        const prb_t *prb, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {

    if (!attr_same_pd_check || prb->attr.is_def()) return OK;

    if (prb->attr.post_ops.convolution_index() != -1) return OK;

    // Check that adding attributes doesn't cause a fall back to another impl.
    // The attr is temporarily swapped out via const_cast and restored below;
    // `prb` is unchanged once this function returns.
    auto *prb_mutable = const_cast<prb_t *>(prb);
    auto old_attr = prb_mutable->attr;
    prb_mutable->attr = attr_t();
    init_pd_args_t<prb_t> init_pd_args_without_attr(
            res, engine, prb_mutable, dir, hint);
    DNN_SAFE(init_pd_func(init_pd_args_without_attr), WARN);
    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw(
            init_pd_args_without_attr.pd);
    prb_mutable->attr = old_attr;
    SAFE(check_same_pd(pdw, res), WARN);
    return OK;
}
496 | |
// Creates the tested primitive for problem `prb` via the driver's
// `init_pd_func`, applying skip checks, skip-impl handling, cache checks and
// memory-size validation along the way. On success transfers ownership of
// the primitive to `user_prim`; returns OK also for SKIPPED states.
template <typename func_t, typename prb_t>
int init_prim(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &user_prim,
        const func_t &init_pd_func, const prb_t *prb, res_t *res,
        dir_t dir = FLAG_FWD, const_dnnl_primitive_desc_t hint = nullptr,
        bool is_service_prim = false) {
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> primw;

    skip_start(res);
    if (res->state == SKIPPED) return OK;
    skip_invalid_prb(prb, res);
    if (res->state == SKIPPED) return OK;
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE

    // The idea is to create the requested primitive twice using different
    // engines but the same device and context in the case of OpenCL and DPCPP.
    // Rationale: make sure that the primitive cache is robust in the case
    // where CPU and GPU engines are re-created because this is a commonly
    // used scenario in the frameworks.
    engine_t engine(get_test_engine());

    // The first primitive creation using a temporary engine.
    SAFE(create_primitive(primw, engine, init_pd_func, prb, res, dir, hint,
                 is_service_prim),
            WARN);
    if (res->state == SKIPPED) return OK;

#endif
    // The second (if the cache is enabled) primitive creation using the global
    // test engine. This primitive is expected to come from the cache.
    SAFE(create_primitive(primw, get_test_engine(), init_pd_func, prb, res, dir,
                 hint, is_service_prim),
            WARN);
    if (res->state == SKIPPED) return OK;

    auto pd = query_pd(primw);
    SAFE(check_mem_size(pd, res), WARN);
    if (res->state == SKIPPED) return OK;

    // Further checks are only for tested primitives.
    if (is_service_prim) {
        user_prim.reset(primw.release());
        return OK;
    }

    res->impl_name = query_impl_info(pd);
    BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
    // Check that adding attributes doesn't cause a fall back to another impl.
    SAFE(check_pd_w_and_wo_attr(
                 get_test_engine(), init_pd_func, prb, res, dir, hint),
            WARN);
    // Check primitive descriptor is picked up from the cache, if applicable.
    SAFE(check_pd_cache(pd), WARN);
    // Check primitive is picked up from the cache, if applicable.
    SAFE(check_primitive_cache(primw), WARN);
    // Collect memory footprint for a given primitive descriptor.
    SAFE(get_memory_footprint(pd, res), WARN);

    SAFE(test_persistent_cache_api(primw, pd, res), WARN);

    user_prim.reset(primw.release());
    return OK;
}
559 | |
// Overload that runs the primary `init_prim` inside the requested threading
// context `thr_ctx` via `create_in_thr_ctx`.
template <typename func_t, typename prb_t>
int init_prim(const thr_ctx_t &thr_ctx,
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &user_prim,
        const func_t &init_pd_func, prb_t *prb, res_t *res,
        dir_t dir = FLAG_FWD, const_dnnl_primitive_desc_t hint = nullptr,
        bool is_service_prim = false) {
    // Pin down the exact overload to pass as a plain function pointer.
    int (*f)(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &, func_t &,
            const prb_t *, res_t *, dir_t, const_dnnl_primitive_desc_t, bool)
            = init_prim<func_t, prb_t>;
    return create_in_thr_ctx(thr_ctx, f, user_prim, init_pd_func, prb, res, dir,
            hint, is_service_prim);
}
572 | |
573 | // `check_correctness` function is designed to be called from every driver where |
574 | // correctness validation is needed. It takes: |
575 | // * A pointer to a `prb_t` problem. |
576 | // * A vector of kinds to compare, to validate several outputs, if applicable. |
577 | // * Backend arguments to compare the output. |
578 | // * Driver's reference memory arguments to compute the reference path, then |
579 | // setup a compare object, and, finally, compare the output. |
580 | // * A reference to function that sets up the compare object, see description |
581 | // below. |
582 | // * A pointer to a `res_t` structure to update validation status. |
583 | // * An optional pointer to CPU primitive for speeding up reference path |
584 | // computation on GPU. |
585 | // |
586 | // The function doesn't return status since we rely on `res` to contain all |
587 | // necessary information about validation results. |
588 | // |
589 | // The function performs several validation steps: |
590 | // * Checks that padded area of all memories are properly zeroed. |
591 | // * Checks that GPU backend haven't modified out-of-boundary memory regions. |
592 | // * Executes driver's reference path, using the problem, driver reference |
593 | // arguments, and CPU primitive for GPU backend, if available. |
594 | // * For each kind to validate it: |
595 | // - Creates and sets up the compare object. Setting is done with |
596 | // `setup_cmp_func`. |
597 | // - Finds correspondent memory arguments from backend and reference and |
598 | // compares them. |
599 | // - Result of comparison is saved into `res` object. |
600 | // |
601 | // `setup_cmp_func` is a function that supposed to be defined in every driver's |
602 | // namespace. Its interface is: |
603 | // `void (compare::compare_t &, const prb_t *, data_kind_t, const args_t &);` |
604 | // It takes: |
605 | // * A reference to a `compare_t` object which the function modifies based on |
606 | // driver's needs. |
607 | // * A pointer to a `prb_t` problem. |
608 | // * `data_kind` value to help to setup threshold depending on output argument. |
609 | // * Driver's reference memory arguments since some drivers can't validate |
610 | // certain scenarios for sure without additional memory arguments. |
611 | // Returns nothing since the object is modified by reference due to lifetime of |
612 | // the compare object is controlled by `check_correctness`. |
613 | // |
614 | // Note: a dedicated non-templated type for `setup_cmp_func_t` could be used but |
615 | // since it relies on a `prb_t` type which is individual for each driver, |
616 | // it is'nt possible without a template. |
// See the detailed contract in the comment block above this function.
template <typename setup_cmp_func_t, typename prb_t>
void check_correctness(const prb_t *prb, const std::vector<data_kind_t> &kinds,
        const args_t &args, const args_t &ref_args,
        const setup_cmp_func_t &setup_cmp_func, res_t *res,
        dnnl_primitive_t prim_ref = nullptr) {

    // Sanity checks on every backend output memory before comparison.
    for (int i = 0; i < args.size(); ++i) {
        check_zero_padding(args.dnn_mem(i), args.arg(i), res);
        check_buffer_overwrite(args.dnn_mem(i), args.arg(i), res);
    }

    // Compute the reference result (timed separately under TIME_REF).
    TIME_REF(compute_ref(prb, ref_args, prim_ref));

    for (const auto &kind : kinds) {
        compare::compare_t cmp;
        cmp.set_data_kind(kind);
        // Driver-specific thresholds/filters for this output kind.
        setup_cmp_func(cmp, prb, kind, ref_args);

        // Map the data kind to the execution argument holding its output.
        // Non-DST kinds use DIFF_* args since they are backward outputs.
        int arg = 0;
        switch (kind) {
            case DST: arg = DNNL_ARG_DST; break;
            case SRC: arg = DNNL_ARG_DIFF_SRC; break;
            case SRC_1: arg = DNNL_ARG_DIFF_SRC_1; break;
            case WEI: arg = DNNL_ARG_DIFF_WEIGHTS; break;
            case BIA: arg = DNNL_ARG_DIFF_BIAS; break;
            case MEAN: arg = DNNL_ARG_MEAN; break;
            case VAR: arg = DNNL_ARG_VARIANCE; break;
            case SC: arg = DNNL_ARG_DIFF_SCALE; break;
            case SH: arg = DNNL_ARG_DIFF_SHIFT; break;
            case DST_ITER: arg = DNNL_ARG_DST_ITER; break;
            case DST_ITER_C: arg = DNNL_ARG_DST_ITER_C; break;
            case AUGRU_ATTENTION: arg = DNNL_ARG_DIFF_AUGRU_ATTENTION; break;
            case SRC_ITER: arg = DNNL_ARG_DIFF_SRC_ITER; break;
            case SRC_ITER_C: arg = DNNL_ARG_DIFF_SRC_ITER_C; break;
            case WEI_ITER: arg = DNNL_ARG_DIFF_WEIGHTS_ITER; break;
            case WEI_PEEPHOLE: arg = DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE; break;
            case WEI_PROJECTION: arg = DNNL_ARG_DIFF_WEIGHTS_PROJECTION; break;
            default: assert(!"unsupported kind"); SAFE_V(FAIL);
        }
        const auto &mem_dt = args.find(arg);
        const auto &mem_fp = ref_args.find(arg);

        cmp.compare(mem_fp, mem_dt, prb->attr, res);
    }
}
662 | |
663 | typedef std::function<dnnl_status_t( |
664 | const dnnl_stream_t &, const std::vector<dnnl_exec_arg_t> &)> |
665 | perf_function_t; |
666 | |
// Executes a primitive (or a custom `exec_func`) and waits for completion;
// updates `res` when provided (definitions in the .cpp).
int execute_and_wait(perf_function_t &exec_func, const dnnl_engine_t &engine,
        const args_t &args, res_t *res = nullptr);
int execute_and_wait(
        dnnl_primitive_t prim, const args_t &args, res_t *res = nullptr);

// Performance-measurement entry points for `--mode=perf`.
void reset_gpu_profiling();
int measure_perf(const thr_ctx_t &ctx, res_t *res, perf_function_t &perf_func,
        args_t &args);
int measure_perf(
        const thr_ctx_t &ctx, res_t *res, dnnl_primitive_t prim, args_t &args);

// Helpers filling runtime scale/zero-point memories from attribute values;
// `_v2` variants also fill a separate reference (fp) memory.
void maybe_prepare_runtime_scales(dnn_mem_t &scales_m,
        const attr_t::scale_t &scale, int64_t scale_cnt, const float *scales);

void maybe_prepare_runtime_scales_v2(dnn_mem_t &scales_dt, dnn_mem_t &scales_fp,
        const attr_t::scale_t &scale, int64_t scale_cnt, const float *scales);

void maybe_prepare_runtime_zero_points(dnn_mem_t &zero_points_m,
        const attr_t &attr, int arg, int64_t count, const int32_t *zero_points);

void maybe_prepare_runtime_zero_points_v2(dnn_mem_t &zero_points_dt,
        dnn_mem_t &zero_points_fp, const attr_t &attr, int arg, int64_t count,
        const int32_t *zero_points);

// Gathers post-op values at `dst_off` for the given per-argument masks.
std::vector<float> prepare_po_vals(const dnn_mem_t &dst_m, const args_t &args,
        const std::vector<std::pair<int, int>> &v_po_masks,
        const size_t dst_off);

bool check_md_consistency_with_tag(
        const_dnnl_memory_desc_t md, const std::string &tag);

memory_kind_ext_t str2memory_kind(const char *str);

float reorder_rescale_factor();
dims_t md2dims(const dnnl_memory_desc_t &md);

// Function adjusts data type if fpmath mode is present or sum_dt is different
// from destination_dt. It is used in `cfg` objects that regulate filling.
dnnl_data_type_t deduce_cfg_data_type(
        dnnl_data_type_t in_dt, const attr_t &attr, data_kind_t dk);
707 | |
708 | #endif |
709 | |