/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
* Copyright 2021 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cstring>
#include <vector>

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "conv/conv.hpp"
#include "prelu/prelu.hpp"

namespace conv {

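// Returns the expected fraction of non-zero values for a given tensor kind.
// setup_cmp() below turns this into the zero trust percentage the comparator
// uses to flag suspiciously sparse results.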
double get_non_zero_trust_percent(const prb_t *prb, data_kind_t kind) {
    auto negative_to_zero = [&]() {
        using pk = attr_t::post_ops_t::kind_t;
        const auto &po = prb->attr.post_ops;
        int count = 0;

        // Check for all post-ops that convert negative values to zero
        std::vector<pk> non_neg_po {pk::ABS};
        std::vector<pk> non_neg_alpha_0_po {
                pk::CLIP, pk::CLIP_V2, pk::ELU, pk::RELU};
        for (int i = 0; i < po.len(); ++i) {
            const auto &e = po.entry[i];
            if (!e.is_eltwise_kind()) continue;

            auto k = e.kind;
            auto alpha = e.eltwise.alpha;

            count += std::any_of(non_neg_po.cbegin(), non_neg_po.cend(),
                    [k](const pk alg) { return alg == k; });
            count += std::any_of(non_neg_alpha_0_po.cbegin(),
                    non_neg_alpha_0_po.cend(), [k, alpha](const pk alg) {
                        return alg == k && alpha == 0;
                    });
        }
        // Check for u8 dst
        count += prb->cfg[DST].dt == dnnl_u8;
        // Check for physically padded area in the output
        count += prb->od > prb->id || prb->oh > prb->ih || prb->ow > prb->iw;

        return !!count;
    };

    double trust = 0.3; /* why? */
    switch (kind) {
        case SRC: trust /= prb->sd * prb->sh * prb->sw; break;
        case WEI:
            trust /= 1. * prb->kd * prb->kh * prb->kw
                    / MIN3(prb->kd * prb->kh * prb->kw,
                            prb->id * prb->ih * prb->iw,
                            prb->od * prb->oh * prb->ow);
            break;
        case BIA:
            trust = 0.8 * prb->cfg[DST].f_sparsity; /* why? */
            break;
        case DST: trust /= (1.f + negative_to_zero()); break;
        default: assert(!"unsupported data kind");
    }

    return trust;
}

bool need_src_init(const prb_t *prb) {
    return !(prb->dir == BWD_D);
}

bool need_wei_init(const prb_t *prb) {
    return !(prb->dir & FLAG_BWD && prb->dir & FLAG_WEI);
}

bool need_bia_init(const prb_t *prb) {
    return need_wei_init(prb);
}

bool need_dst_init(const prb_t *prb) {
    return !(prb->dir & FLAG_FWD)
            || (prb->attr.post_ops.find(attr_t::post_ops_t::SUM) >= 0);
}

int fill_src(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
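    // When correctness is checked and the library memory uses a different data
    // type, fill an f32 copy first, reorder it into the library memory, and
    // then reorder back to verify the data survives the round trip unchanged.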
    const bool check_reorder
            = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt());
    dnn_mem_t extra_mem;
    if (check_reorder) {
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine());
    }
    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;

    // Use dense filling for small problems.
    int src_nelems_mask = powf(2.f, prb->ndims) - 1;
    src_nelems_mask -= 1; // remove minibatch as independent dimension
    auto src_nelems = prb->desc_nelems(DNNL_ARG_SRC, src_nelems_mask);
    if (prb->has_groups) src_nelems /= prb->g; // groups are also independent

    const auto &c = prb->get_dt_conf(SRC);
    const int range = c.f_max - c.f_min + 1;
    const float sparsity
            = (!is_bench_mode(CORR) || src_nelems < 100) ? 1.f : c.f_sparsity;

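    // Each value is generated deterministically from the element coordinates:
    // 'gen' is a prime-weighted combination of the indices, flip_coin() picks
    // between the base value and one from the [f_min, f_max] range, and the
    // result is rounded to the nearest value representable in the target type.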
    benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw,
            [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) {
                const int64_t gen
                        = 101 * id + 103 * ih + 107 * iw + 109 * mb + 113 * ic;
                const bool non_base = flip_coin(gen, sparsity);
                float value = non_base ? c.f_min + gen * c.f_step % range
                                       : c.f_base;

                maybe_zero_point(
                        prb->attr, value, prb->src_zp, ic, DNNL_ARG_SRC, true);

                ((float *)mem_00)[src_off_f(prb, mb, 0, ic, id, ih, iw)]
                        = round_to_nearest_representable(mem_dt.dt(), value);
            });

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }

    return OK;
}

int fill_wei(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const bool is_def_zp = prb->attr.zero_points.is_def(DNNL_ARG_SRC);
    const bool diff_data_type = mem_dt.dt() != mem_fp.dt();

    dnnl_data_type_t dt_check = dnnl_s8;
#if defined(DNNL_AARCH64) && (DNNL_AARCH64 == 1)
    /* Note for x64:
    When both src and weights are of type s8, oneDNN adds 128 to one of the s8
    inputs to make it u8 instead, as explained in
    https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html or
    doc/advanced/int8_computations.md
    This is because the `VPDPBUSD` instruction takes a combination of u8 and s8
    inputs.

    Note for AArch64:
    Because the AArch64 dot product instruction "SDOT" takes s8 inputs for both
    src and weights, the addition (and the corresponding subtraction) is not
    required on AArch64.
    */
    if (res->impl_name.find("jit", 0) == 0) dt_check = dnnl_u8;
#endif
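    // A sketch of the identity behind that x64 shift (my reading of the doc
    // linked above, not code taken from the library): shifting src to u8 keeps
    // the result exact because
    //     sum(src_s8 * wei) == sum((src_s8 + 128) * wei) - 128 * sum(wei),
    // so only a per-output compensation term built from the weights is needed.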

    const bool wei_x8x8 = prb->get_dt_conf(WEI).dt == dnnl_s8
            && prb->get_dt_conf(SRC).dt == dt_check;
    const bool check_reorder
            = (is_bench_mode(CORR)) && diff_data_type && !wei_x8x8 && is_def_zp;

    dnn_mem_t extra_mem;
    if (check_reorder) {
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine());
    }
    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;

    const auto &c = prb->get_dt_conf(WEI);
    const int range = c.f_max - c.f_min + 1;
    const float sparsity = !is_bench_mode(CORR) ? 1.f : c.f_sparsity;

    benchdnn_parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd,
            prb->kh, prb->kw,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                const int64_t gen = 113 * g + 127 * kd + 131 * kh + 137 * kw
                        + 139 * oc + 149 * ic + 151;
                const bool non_base = flip_coin(gen, sparsity);
                const float value = non_base ? c.f_min + gen * c.f_step % range
                                             : c.f_base;
                ((float *)mem_00)[wei_off_f(prb, g, oc, ic, kd, kh, kw)]
                        = value;
            });

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }
    if ((wei_x8x8 || !is_def_zp) && is_cpu()) {
        // Check that s8 -> s8_comp exists in the library since users may have
        // already quantized data.
        dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine());
        dnn_mem_t mem_dt_s8(mem_dt.md_, get_test_engine());
        SAFE(mem_fp_s8.reorder(mem_fp), WARN);
        SAFE(mem_dt_s8.reorder(mem_fp_s8), WARN);
        SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN);
        int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size());
        SAFE(rc == 0 ? OK : FAIL, WARN);
    }
    return OK;
}

int fill_bia(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const bool check_reorder
            = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt());
    dnn_mem_t extra_mem;
    if (check_reorder)
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::x, get_cpu_engine());
    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;

    const size_t nelems = mem_00.nelems();
    if (nelems == 0) return OK;

    const auto &c = prb->get_dt_conf(BIA);
    const int range = c.f_max - c.f_min + 1;

    for (size_t i = 0; i < nelems; ++i) {
        const int gen = (int)(151 * i);
        const bool non_base = flip_coin(gen, c.f_sparsity);
        const float value
                = non_base ? c.f_min + gen * c.f_step % range : c.f_base;

        ((float *)mem_00)[i] = value;
    }

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }
    return OK;
}

int fill_dst_with_params(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
        dnnl_data_type_t dt, double sparsity, int min, int max, int base,
        int step, res_t *res) {
    const bool check_reorder
            = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt());
    dnn_mem_t extra_mem;
    if (check_reorder) {
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine());
    }

    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;
    const int range = max - min + 1;

    benchdnn_parallel_nd(prb->mb, prb->oc, prb->od, prb->oh, prb->ow,
            [&](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) {
                const int64_t gen
                        = 157 * od + 163 * oh + 167 * ow + 173 * mb + 179 * oc;
                const bool non_base = flip_coin(gen, sparsity);
                const float value = non_base ? min + gen * step % range : base;

                ((float *)mem_00)[dst_off_f(prb, mb, 0, oc, od, oh, ow)]
                        = value;
            });

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }

    return OK;
}

int fill_dst(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    auto dst_dt = mem_dt.dt();
    int sum_ind = prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM);
    auto sum_dt = (sum_ind != -1) ? prb->attr.post_ops.entry[sum_ind].sum.dt
                                  : dnnl_data_type_undef;
    bool diff_sum_dst_types
            = sum_dt != dnnl_data_type_undef && sum_dt != dst_dt;
    bool sum_dt_is_int8 = sum_dt == dnnl_s8 || sum_dt == dnnl_u8;

    const auto &c = prb->get_dt_conf(DST);
    float f_min = c.f_min;
    float f_max = c.f_max;
    if (diff_sum_dst_types && sum_dt_is_int8) {
        f_min = lowest_dt(sum_dt);
        f_max = max_dt(sum_dt);
    }

    // Change mem dt to sum dt, so we can save sum data properly.
    if (diff_sum_dst_types) { mem_dt.set_dt(sum_dt); }

    fill_dst_with_params(prb, mem_dt, mem_fp, sum_dt, c.f_sparsity, f_min,
            f_max, c.f_base, c.f_step, res);

    // Restore the original dst data type.
    if (diff_sum_dst_types) { mem_dt.set_dt(dst_dt); }
    return OK;
}

dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    auto src_d = dnn_mem_t::init_md(prb->ndims, prb->src_dims().data(),
            prb->get_dt_conf(SRC).dt, normalize_tag(prb->stag, prb->ndims));
    auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
            prb->wei_dims().data(), prb->get_dt_conf(WEI).dt,
            normalize_tag(prb->wtag, prb->ndims + prb->has_groups));
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->get_dt_conf(BIA).dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(prb->ndims, prb->dst_dims().data(),
            prb->get_dt_conf(DST).dt, normalize_tag(prb->dtag, prb->ndims));

    dnnl_alg_kind_t alg = dnnl_convolution_direct;
    if (prb->alg == WINO) alg = dnnl_convolution_winograd;
    if (prb->alg == AUTO) alg = dnnl_convolution_auto;

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
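    // Per-OC weights scales are passed with an explicit mask over the
    // output-channels dimension of the weights: dim 1 (mask 2) when a groups
    // dimension is present, dim 0 (mask 1) otherwise.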
    if (wei_scale.policy == policy_t::PER_OC) {
        auto wei_mask = prb->has_groups ? 2 : 1;
        attr_args.prepare_scales(prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales,
                prb->oc, wei_mask);
    }
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_convolution_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    alg, src_d, wei_d, bia_d, dst_d, prb->strides().data(),
                    prb->dilations().data(), prb->padding().data(),
                    prb->padding_r().data(), dnnl_attr));
            break;
        case BWD_D:
            DNN_SAFE_STATUS(
                    dnnl_convolution_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_convolution_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, bia_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    // TODO: add query for acc type in pd.
    //DNN_SAFE_STATUS(cd.accum_data_type == prb->get_dt_conf(ACC).dt
    //        ? dnnl_success
    //        : dnnl_unimplemented);
    return dnnl_success;
}

int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place.
    // DIRECT algorithm is used to prevent fallback to the slow benchdnn
    // reference implementation.
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, prb->dir, conf_f32, tag::abx, tag::abx, tag::abx,
            DIRECT, cpu_attr, prb->ctx_init, prb->ctx_exe, prb->mb};

    init_pd_args_t<prb_t> init_pd_args(
            /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
            /* hint = */ nullptr);
    init_pd(init_pd_args);

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
    fetch_impl(pdw, init_pd_args, /* res = */ nullptr,
            /* is_service_prim = */ true);

    dnnl_primitive_t prim_ref_ {};
    if (pdw) {
        if (query_impl_info(pdw) == "ref:any") return OK;
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN);
        BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n",
                query_impl_info(pdw).c_str());
    }
    prim_ref.reset(prim_ref_);
    return OK;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->get_dt_conf(SRC).dt, prb->get_dt_conf(WEI).dt,
                    prb->get_dt_conf(DST).dt},
            prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res, prb->get_dt_conf(DST).dt);

    if (is_cpu()) {
        // Specific configurations are not supported.
        const bool is_f32_src = prb->get_dt_conf(SRC).dt == dnnl_f32;
        const bool is_f32_wei = prb->get_dt_conf(WEI).dt == dnnl_f32;
        const bool is_f16 = prb->get_dt_conf(WEI).dt == dnnl_f16;
        const bool is_bf16_src = prb->get_dt_conf(SRC).dt == dnnl_bf16;
        const bool is_bf16_wei = prb->get_dt_conf(WEI).dt == dnnl_bf16;
        const bool is_int8_dst = prb->get_dt_conf(DST).dt == dnnl_s8
                || prb->get_dt_conf(DST).dt == dnnl_u8;
        const bool is_f32f32x8 = is_f32_src && is_f32_wei && is_int8_dst;
        const bool is_bf16bf16x8 = is_bf16_src && is_bf16_wei && is_int8_dst;
        // f16 weights are only supported with an f32 or f16 dst; treat non-f16
        // configurations as valid here.
        const bool is_valid_f16 = !is_f16
                || prb->get_dt_conf(DST).dt == dnnl_f32
                || prb->get_dt_conf(DST).dt == dnnl_f16;

        if (is_f32f32x8 || is_bf16bf16x8 || !is_valid_f16) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }

    // The Winograd implementation has a very limited scope and support. It
    // doesn't make sense to list all unsupported cases; instead, convert all
    // unimplemented Winograd problems into "not supported".
    if (prb->alg == WINO) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
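    // Winograd reorders the floating-point operations relative to the direct
    // reference, so a point-wise comparison is too strict; use norm-based
    // validation for it instead.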
    const bool compare_with_norm = (prb->alg & WINO);
    cmp.set_norm_validation_mode(compare_with_norm);

    float trh = prb->get_dt_conf(kind).eps;
    if ((prb->alg & WINO) && (prb->dir & FLAG_WEI)) {
        // This is an empirical equation derived by observing error growth with
        // an increasing 'k' dimension in the Winograd GEMM.
        const float log_const = log10(0.125 * prb->mb * prb->oh * prb->ow);
        trh = prb->get_dt_conf(kind).eps * (MAX2(1, pow(10, 0.4 * log_const)));
    }
    cmp.set_threshold(trh);

    const float zpp = (1.f - get_non_zero_trust_percent(prb, kind)) * 100.f;
    cmp.set_zero_trust_percent(zpp);
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

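    // If the AUTO algorithm was requested, query which algorithm the library
    // actually selected and pick the matching config.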
    if (prb->alg == AUTO) prb->alg = alg_kind2alg(query_alg_kind(const_pd));
    prb->cfg = auto_cfg(prb->alg, prb->cfg);

    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
            : query_md(const_pd, DNNL_ARG_WEIGHTS);
    const auto &bia_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
            : query_md(const_pd, DNNL_ARG_BIAS);
    const auto &dst_md = prb->dir & FLAG_BWD
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    dnn_mem_t src_scales_dt, src_scales_fp;
    dnn_mem_t wei_scales_dt, wei_scales_fp;
    dnn_mem_t dst_scales_dt, dst_scales_fp;
    dnn_mem_t src_zp_dt, src_zp_fp;
    dnn_mem_t dst_zp_dt, dst_zp_fp;
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);
    std::vector<dnn_mem_t> prelu_po_fp, prelu_po_dt;
    std::vector<int> prelu_po_args;
    SAFE(prelu::setup_prelu_po(
                 const_pd, prelu_po_args, prelu_po_fp, prelu_po_dt),
            WARN);

    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, src_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t scratchpad_fp;
    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
    if (need_wei_init(prb)) SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);

    const int src_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_SRC).policy);
    int wei_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS);
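    // Grouped weights carry a leading 'g' dimension, so the scales mask has to
    // cover it as well; e.g. a per-OC mask of 1 becomes 3 (bits for 'g' and
    // 'oc').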
    if (prb->has_groups) wei_mask = (1 << wei_mask) + 1;
    const int dst_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_DST).policy);
    maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
            prb->attr.scales.get(DNNL_ARG_SRC),
            prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
    maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
            prb->attr.scales.get(DNNL_ARG_WEIGHTS),
            prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
    maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
            prb->attr.scales.get(DNNL_ARG_DST),
            prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);

    maybe_prepare_runtime_zero_points_v2(src_zp_dt, src_zp_fp, prb->attr,
            DNNL_ARG_SRC, prb->ic, prb->src_zp);
    maybe_prepare_runtime_zero_points_v2(dst_zp_dt, dst_zp_fp, prb->attr,
            DNNL_ARG_DST, prb->oc, prb->dst_zp);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_dt);
        args.set(binary_po_args, binary_po_dt);
        args.set(prelu_po_args, prelu_po_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
            ref_args.set(
                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(prelu_po_args, prelu_po_fp);

            check_correctness(
                    prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir == BWD_D) {
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {SRC}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {WEI, BIA}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else {
        SAFE(FAIL, CRIT);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace conv
