/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cstring>
#include <utility>
#include <vector>

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "deconv/deconv.hpp"
#include "prelu/prelu.hpp"
namespace deconv {

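// Transpose the weights by swapping the per-group OC and IC dimensions. The
// reference path consumes the transposed weights through DNNL_ARG_WEIGHTS_1
// (see doit() below), since a deconvolution can be computed as a convolution
// with the weights' input/output channels exchanged.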
int transpose_data_wei(
        const prb_t *prb, const dnn_mem_t &wei, const dnn_mem_t &wei_tr) {
    benchdnn_parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd,
            prb->kh, prb->kw,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                int64_t ch_idx
                        = (g * prb->ic / prb->g + ic) * prb->oc / prb->g + oc;
                int64_t idx = ((ch_idx * prb->kd + kd) * prb->kh + kh) * prb->kw
                        + kw;
                ((float *)wei_tr)[idx]
                        = ((float *)wei)[wei_off_f(prb, g, oc, ic, kd, kh, kw)];
            });

    return OK;
}

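// Returns the expected fraction of non-zero values in the tensor of a given
// kind; setup_cmp() turns it into the "zero trust" percentage used to flag
// results that contain suspiciously many zeros.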
double get_non_zero_trust_percent(const prb_t *prb, data_kind_t kind) {
    auto negative_to_zero = [&]() {
        using pk = attr_t::post_ops_t::kind_t;
        const auto &po = prb->attr.post_ops;
        int count = 0;

        // Check for all post-ops that convert negative to zero
        std::vector<pk> non_neg_po {pk::ABS};
        std::vector<pk> non_neg_alpha_0_po {
                pk::CLIP, pk::CLIP_V2, pk::ELU, pk::RELU};
        for (int i = 0; i < po.len(); ++i) {
            const auto &e = po.entry[i];
            if (!e.is_eltwise_kind()) continue;

            auto k = e.kind;
            auto alpha = e.eltwise.alpha;

            count += std::any_of(non_neg_po.cbegin(), non_neg_po.cend(),
                    [k](const pk alg) { return alg == k; });
            count += std::any_of(non_neg_alpha_0_po.cbegin(),
                    non_neg_alpha_0_po.cend(), [k, alpha](const pk alg) {
                        return alg == k && alpha == 0;
                    });
        }
        // Check for u8 dst
        count += prb->cfg[DST].dt == dnnl_u8;
        // Check for physically padded area in the output
        count += prb->od > prb->id || prb->oh > prb->ih || prb->ow > prb->iw;

        return !!count;
    };

    double trust = 0.3; /* why? */
    switch (kind) {
        case SRC: trust /= prb->sd * prb->sh * prb->sw; break;
        case WEI:
            trust /= 1. * prb->kd * prb->kh * prb->kw
                    / MIN3(prb->kd * prb->kh * prb->kw,
                            prb->id * prb->ih * prb->iw,
                            prb->od * prb->oh * prb->ow);
            break;
        case BIA:
            trust = 0.8 * prb->cfg[DST].f_sparsity; /* why? */
            break;
        case DST: trust /= (1.f + negative_to_zero()); break;
        default: assert(!"unsupported data kind");
    }

    return trust;
}

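// The need_*_init() helpers decide which tensors must be filled for a given
// propagation direction, e.g. the source is not needed when only diff_src is
// computed (BWD_D), and the destination is filled only for backward passes or
// when a sum post-op reads it.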
bool need_src_init(const prb_t *prb) {
    return !(prb->dir == BWD_D);
}

bool need_wei_init(const prb_t *prb) {
    return !(prb->dir & FLAG_BWD && prb->dir & FLAG_WEI);
}

bool need_bia_init(const prb_t *prb) {
    return need_wei_init(prb);
}

bool need_dst_init(const prb_t *prb) {
    return !(prb->dir & FLAG_FWD)
            || (prb->attr.post_ops.find(attr_t::post_ops_t::SUM) >= 0);
}

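// Fill the source tensor: generate values into an f32 buffer, reorder it into
// the library memory object and, in correctness mode with differing data
// types, reorder back and compare to make sure the round trip is lossless.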
int fill_src(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const bool check_reorder
            = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt());
    dnn_mem_t extra_mem;
    if (check_reorder) {
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine());
    }
    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;

    // Use dense filling for small problems.
    int src_nelems_mask = powf(2.f, prb->ndims) - 1;
    src_nelems_mask -= 1; // remove minibatch as independent dimension
    auto src_nelems = prb->desc_nelems(DNNL_ARG_SRC, src_nelems_mask);
    if (prb->has_groups) src_nelems /= prb->g; // groups are also independent

    const auto &c = prb->cfg[SRC];
    const int range = c.f_max - c.f_min + 1;
    const float sparsity
            = (!is_bench_mode(CORR) || src_nelems < 100) ? 1.f : c.f_sparsity;

    benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw,
            [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) {
                const int64_t gen
                        = 101 * id + 103 * ih + 107 * iw + 109 * mb + 113 * ic;
                const bool non_base = flip_coin(gen, sparsity);
                float value = non_base ? c.f_min + gen * c.f_step % range
                                       : c.f_base;

                maybe_zero_point(
                        prb->attr, value, prb->src_zp, ic, DNNL_ARG_SRC, true);

                ((float *)mem_00)[src_off_f(prb, mb, 0, ic, id, ih, iw)]
                        = round_to_nearest_representable(mem_dt.dt(), value);
            });

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }

    return OK;
}

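// Fill the weights tensor. The f32 round-trip check is skipped for int8
// weights and for non-default source zero points; in those cases an extra
// s8 reorder check is run on CPU instead, since users may pass already
// quantized weights.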
int fill_wei(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const bool is_def_zp = prb->attr.zero_points.is_def(DNNL_ARG_SRC);
    const bool diff_data_type = mem_dt.dt() != mem_fp.dt();

    dnnl_data_type_t dt_check = dnnl_s8;
#if DNNL_AARCH64
    /* Note for x64:
    When both src and weights have the s8 data type, oneDNN adds 128 to one of
    the s8 inputs to make it u8 instead, as explained in
    https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html or
    doc/advanced/int8_computations.md
    This is because the `VPDPBUSD` instruction takes a combination of s8 and u8
    inputs.

    Note for AArch64:
    Because the AArch64 dot product instruction "SDOT" takes s8 inputs for both
    src and weights, the addition (and its counterpart subtraction) is not
    required for AArch64.
    */
    if (res->impl_name.find("jit", 0) == 0) dt_check = dnnl_u8;
#endif

    const bool wei_x8x8
            = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[SRC].dt == dt_check;
    const bool check_reorder
            = (is_bench_mode(CORR)) && diff_data_type && !wei_x8x8 && is_def_zp;

    dnn_mem_t extra_mem;
    if (check_reorder) {
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine());
    }
    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;

    const auto &c = prb->cfg[WEI];
    const int range = c.f_max - c.f_min + 1;
    const float sparsity = !is_bench_mode(CORR) ? 1.f : c.f_sparsity;

    benchdnn_parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd,
            prb->kh, prb->kw,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                const int64_t gen = 113 * g + 127 * kd + 131 * kh + 137 * kw
                        + 139 * oc + 149 * ic + 151;
                const bool non_base = flip_coin(gen, sparsity);
                const float value = non_base ? c.f_min + gen * c.f_step % range
                                             : c.f_base;
                ((float *)mem_00)[wei_off_f(prb, g, oc, ic, kd, kh, kw)]
                        = value;
            });

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }
    if ((wei_x8x8 || !is_def_zp) && is_cpu()) {
        // Check that s8 -> s8_comp exists in the library since users may have
        // already quantized data.
        dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine());
        dnn_mem_t mem_dt_s8(mem_dt.md_, get_test_engine());
        SAFE(mem_fp_s8.reorder(mem_fp), WARN);
        SAFE(mem_dt_s8.reorder(mem_fp_s8), WARN);
        SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN);
        int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size());
        SAFE(rc == 0 ? OK : FAIL, WARN);
    }
    return OK;
}

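// Fill the bias tensor; it is small and one-dimensional, so a plain serial
// loop is sufficient.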
int fill_bia(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const bool check_reorder
            = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt());
    dnn_mem_t extra_mem;
    if (check_reorder)
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::x, get_cpu_engine());
    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;

    const size_t nelems = mem_00.nelems();
    if (nelems == 0) return OK;

    const auto &c = prb->cfg[BIA];
    const int range = c.f_max - c.f_min + 1;

    for (size_t i = 0; i < nelems; ++i) {
        const int gen = (int)(151 * i);
        const bool non_base = flip_coin(gen, c.f_sparsity);
        const float value
                = non_base ? c.f_min + gen * c.f_step % range : c.f_base;

        ((float *)mem_00)[i] = value;
    }

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }
    return OK;
}

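// Parameterized destination filler shared with fill_dst(): the caller
// provides the value range so that, with a sum post-op of a different data
// type, the generated values stay representable in that type.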
int fill_dst_with_params(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
        dnnl_data_type_t dt, double sparsity, int min, int max, int base,
        int step, res_t *res) {
    const bool check_reorder
            = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt());
    dnn_mem_t extra_mem;
    if (check_reorder) {
        extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine());
    }

    dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp;
    const int range = max - min + 1;

    benchdnn_parallel_nd(prb->mb, prb->oc, prb->od, prb->oh, prb->ow,
            [&](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) {
                const int64_t gen
                        = 157 * od + 163 * oh + 167 * ow + 173 * mb + 179 * oc;
                const bool non_base = flip_coin(gen, sparsity);
                const float value = non_base ? min + gen * step % range : base;

                ((float *)mem_00)[dst_off_f(prb, mb, 0, oc, od, oh, ow)]
                        = value;
            });

    SAFE(mem_dt.reorder(mem_00), WARN);
    if (check_reorder) {
        SAFE(mem_fp.reorder(mem_dt), WARN);
        int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size());
        if (rc != 0) {
            res->state = FAILED;
            SAFE(FAIL, WARN);
        }
    }

    return OK;
}

int fill_dst(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    auto dst_dt = mem_dt.dt();
    int sum_ind = prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM);
    auto sum_dt = (sum_ind != -1) ? prb->attr.post_ops.entry[sum_ind].sum.dt
                                  : dnnl_data_type_undef;
    bool diff_sum_dst_types
            = sum_dt != dnnl_data_type_undef && sum_dt != dst_dt;
    bool sum_dt_is_int8 = sum_dt == dnnl_s8 || sum_dt == dnnl_u8;

    const auto &c = prb->cfg[DST];
    float f_min = c.f_min;
    float f_max = c.f_max;
    if (diff_sum_dst_types && sum_dt_is_int8) {
        f_min = lowest_dt(sum_dt);
        f_max = max_dt(sum_dt);
    }

    // Change the memory data type to the sum data type, so the sum data is
    // saved properly.
    if (diff_sum_dst_types) { mem_dt.set_dt(sum_dt); }

    fill_dst_with_params(prb, mem_dt, mem_fp, sum_dt, c.f_sparsity, f_min,
            f_max, c.f_base, c.f_step, res);

    // Restore the original dst data type.
    if (diff_sum_dst_types) { mem_dt.set_dt(dst_dt); }
    return OK;
}

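// Create memory descriptors and the deconvolution primitive descriptor for
// the requested direction, translating benchdnn attributes (scales, post-ops)
// into a oneDNN primitive attribute.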
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, prb->stag);
    auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
            prb->wei_dims().data(), prb->cfg[WEI].dt, prb->wtag);
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->cfg[BIA].dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims().data(), prb->cfg[DST].dt, prb->dtag);

    dnnl_alg_kind_t alg = dnnl_deconvolution_direct;
    if (prb->alg == WINO) alg = dnnl_deconvolution_winograd;

    attr_args_t attr_args;
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    if (wei_scale.policy == policy_t::PER_OC) {
        auto wei_mask = prb->has_groups ? 2 : 1;
        attr_args.prepare_scales(prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales,
                prb->oc, wei_mask);
    }
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_deconvolution_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    alg, src_d, wei_d, bia_d, dst_d, prb->strides().data(),
                    prb->dilations().data(), prb->padding().data(),
                    prb->padding_r().data(), dnnl_attr));
            break;
        case BWD_D:
            DNN_SAFE_STATUS(
                    dnnl_deconvolution_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_deconvolution_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, bia_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    // TODO: re-enable this check once the accumulation data type can be
    // queried from the primitive descriptor.
    //DNN_SAFE_STATUS(cd.accum_data_type == prb->cfg[ACC].dt
    //        ? dnnl_success
    //        : dnnl_unimplemented);

    return dnnl_success;
}

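// For GPU correctness runs, build a CPU f32 deconvolution primitive that
// serves as a fast reference instead of the plain benchdnn implementation.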
int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place.
    // DIRECT algorithm is used to prevent fallback to the slow benchdnn
    // reference implementation.
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, prb->dir, conf_f32, tag::abx, tag::abx, tag::abx,
            DIRECT, cpu_attr, prb->ctx_init, prb->ctx_exe, prb->mb};

    init_pd_args_t<prb_t> init_pd_args(
            /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
            /* hint = */ nullptr);
    init_pd(init_pd_args);

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
    fetch_impl(pdw, init_pd_args, /* res = */ nullptr,
            /* is_service_prim = */ true);

    dnnl_primitive_t prim_ref_ {};
    if (pdw) {
        if (query_impl_info(pdw) == "ref:any") return OK;
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN);
        BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n",
                query_impl_info(pdw).c_str());
    }
    prim_ref.reset(prim_ref_);
    return OK;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir,
            res);
    skip_unimplemented_sum_po(prb->attr, res);

    // GPU supports only post-op attributes and all configs but x8s8bf16.
    if (is_gpu()) {
        const bool is_x8s8bf16_cfg
                = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[DST].dt == dnnl_bf16;
        const bool fwd_ok = !is_x8s8bf16_cfg;
        if (!fwd_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const bool compare_with_norm = (prb->alg & WINO);
    cmp.set_norm_validation_mode(compare_with_norm);

    float trh = prb->cfg[kind].eps;
    if ((prb->alg & WINO) && (prb->dir & FLAG_WEI)) {
        // This is an empirical equation derived by observing how the error
        // grows with the increasing 'k' dimension in the Winograd GEMM.
        const float log_const = log10(0.125 * prb->mb * prb->oh * prb->ow);
        trh = prb->cfg[kind].eps * (MAX2(1, pow(10, 0.4 * log_const)));
    }
    cmp.set_threshold(trh);

    const float zpp = (1.f - get_non_zero_trust_percent(prb, kind)) * 100.f;
    cmp.set_zero_trust_percent(zpp);
}

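// Main test entry point: create the primitive, allocate and fill its
// memories, execute the requested direction, and in correctness mode compare
// the results against the reference (benchdnn's own or the CPU primitive from
// init_prim_ref()).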
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
            : query_md(const_pd, DNNL_ARG_WEIGHTS);
    const auto &bia_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
            : query_md(const_pd, DNNL_ARG_BIAS);
    const auto &dst_md = prb->dir & FLAG_BWD
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

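    // Dims of the transposed weights used by the reference path: swap the OC
    // and IC dims. The weights are treated as grouped here, hence the
    // hardcoded with_groups below.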
    dnnl_dims_t wei_tr_dims {};
    dnnl_dim_t wei_ndims = query_md_ndims(wei_md);
    std::copy(query_md_dims(wei_md), query_md_dims(wei_md) + wei_ndims,
            wei_tr_dims);

    const bool with_groups = true;
    std::swap(wei_tr_dims[with_groups + 0], wei_tr_dims[with_groups + 1]);

    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use the CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);
    std::vector<dnn_mem_t> prelu_po_fp, prelu_po_dt;
    std::vector<int> prelu_po_args;
    SAFE(prelu::setup_prelu_po(
                 const_pd, prelu_po_args, prelu_po_fp, prelu_po_dt),
            WARN);

    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_tr_fp(wei_ndims, wei_tr_dims, fp, wei_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t scratchpad_fp;

    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    dnn_mem_t src_scales_dt, src_scales_fp;
    dnn_mem_t wei_scales_dt, wei_scales_fp;
    dnn_mem_t dst_scales_dt, dst_scales_fp;
    dnn_mem_t src_zp_fp, src_zp_dt;
    dnn_mem_t dst_zp_fp, dst_zp_dt;

    // Fill the memories and perform the necessary reorders.
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_wei_init(prb)) {
        SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
        SAFE(transpose_data_wei(prb, wei_fp, wei_tr_fp), WARN);
    }
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        maybe_prepare_runtime_zero_points_v2(src_zp_dt, src_zp_fp, prb->attr,
                DNNL_ARG_SRC, prb->ic, prb->src_zp);
        maybe_prepare_runtime_zero_points_v2(dst_zp_dt, dst_zp_fp, prb->attr,
                DNNL_ARG_DST, prb->oc, prb->dst_zp);
        const int src_mask = attr_t::get_default_mask(
                prb->attr.scales.get(DNNL_ARG_SRC).policy);
        int wei_mask = attr_t::get_default_mask(
                prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy,
                DNNL_ARG_WEIGHTS);
        if (prb->has_groups) wei_mask = (1 << wei_mask) + 1;
        const int dst_mask = attr_t::get_default_mask(
                prb->attr.scales.get(DNNL_ARG_DST).policy);
        maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
                prb->attr.scales.get(DNNL_ARG_SRC),
                prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
        maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
                prb->attr.scales.get(DNNL_ARG_WEIGHTS),
                prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
        maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
                prb->attr.scales.get(DNNL_ARG_DST),
                prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(binary_po_args, binary_po_dt);
        args.set(prelu_po_args, prelu_po_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_1, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(prelu_po_args, prelu_po_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
            ref_args.set(
                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);

            check_correctness(
                    prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir == BWD_D) {
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_1, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {SRC}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS_1, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {WEI, BIA}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else {
        SAFE(FAIL, CRIT);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace deconv