/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <float.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

#include <random>
#include <sstream>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "bnorm/bnorm.hpp"

namespace bnorm {

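// Filling for the GLOB_STATS case: mean, variance, scale, and shift get fixed
// per-channel patterns, while src (and src_add when Add+ReLU is fused) gets
// small integer values rounded to the destination data type.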
static int prepare_fwd_with_stats(const prb_t *prb, dnn_mem_t &src,
        dnn_mem_t &src_add, dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc,
        dnn_mem_t &sh) {
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();
    const bool fill_src_add = prb->fuse_add_relu();

    benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
        mean.set_elem(c, 4 * ((c % 5) - 2));
        var.set_elem(c, ((c % 7) << 1));

        const float sc_value = 1 << (c % 7);
        const float sh_value = ((c % 3) - 1) * sc_value;
        sc.set_elem(c, use_sc ? sc_value : 1.0f);
        sh.set_elem(c, use_sh ? sh_value : 0.0f);
    });

    benchdnn_parallel_nd(prb->ic, prb->mb, prb->id, prb->ih, prb->iw,
            [&](int64_t c, int64_t mb, int64_t d, int64_t h, int64_t w) {
                int64_t l_base = mb * prb->id * prb->ih * prb->iw + c * 239 * 2;
                float *s = (float *)src + data_off(prb, mb, c, 0, 0, 0);

                const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w;
                const int64_t l = l_base + sp;
                const int64_t value = (l % 65) - 32;
                s[sp] = round_to_nearest_representable(prb->dt, value);
                if (fill_src_add) {
                    float *s_add
                            = (float *)src_add + data_off(prb, mb, c, 0, 0, 0);
                    s_add[sp] = round_to_nearest_representable(
                            prb->dt, (l % 17) - 8);
                }
            });

    return OK;
}

static int prepare_fwd_no_stats(const prb_t *prb, dnn_mem_t &src,
        dnn_mem_t &src_add, dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc,
        dnn_mem_t &sh) {
    /** Idea: choose src[] values so that both mean and variance are computed
     * exactly (independently of the order of the computations).
     *
     * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
     *
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too big (flex_bits <= min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * `density`) so that at least want_flex_bits bits are reserved for src
     * variation. Once src is set, variance is computed.
     *
     * ALG_0: mean is set to 0
     * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
     * ALG_AUTO: choose between ALG_0 and ALG_1 automatically */
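    /* Worked example (illustration only, not used by the code): for f32,
     * exact_bits = 24; with L = 1024 elements, logL = 10, so ALG_AUTO picks
     * ALG_1 because (24 - 10) / 2 - 1 = 6 >= min_flex_bits = 3, and
     * flex_bits = 6, i.e. src may vary only in the last 6 mantissa bits
     * around the mean. */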
    const int64_t exact_bits = digits_dt(prb->dt);
    const int64_t L = prb->mb * prb->id * prb->ih * prb->iw;
    const int64_t logL = (int64_t)ceilf(log2f(L));

    assert(logL <= 0 || (1LL << (logL - 1)) < L);
    assert(L <= (1LL << logL));

    const int64_t min_flex_bits = 3;
    const int64_t want_flex_bits = MIN2(6, exact_bits / 2);

    check_alg_t alg = prb->check_alg;
    if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
        alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;

    const int64_t flex_bits = alg == ALG_0
            ? want_flex_bits
            /* BFloat16 has only 7 bits of mantissa */
            : MIN2(prb->dt == dnnl_bf16 ? 7 : exact_bits,
                    (exact_bits - logL) / 2 - 1);

    if (flex_bits < min_flex_bits) return FAIL;

    const int64_t flex_mask = (1 << flex_bits) - 1;

    /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
    const float density = alg == ALG_0
            ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L
            : 1.f;
    assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);
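    /* Illustration (not executed): for f32 and ALG_0, flex_bits =
     * want_flex_bits = 6, so density = 2^(24 - 12) / L = 4096 / L; e.g. for
     * L = 2^20 only about 1 in 256 elements gets a non-zero value. */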

    BENCHDNN_PRINT(6, "check_alg: %s, density = %g, flex_bits = " IFMT "\n",
            check_alg2str(alg), density, flex_bits);

    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();
    const bool fill_src_add = prb->fuse_add_relu();

    benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
        const float m = ((float *)mean)[c]
                = alg == ALG_0 ? 0.f : 0.25f * (1 << (c % 7));
        float v = 0; /* current variance */

        for (int64_t mb = 0; mb < prb->mb; ++mb) {
            int64_t l_base = mb * prb->id * prb->ih * prb->iw
                    + c * 239 * 2; // l[0] must be even
            int64_t off = data_off(prb, mb, c, 0, 0, 0);
            float *s = (float *)src + off;

            for_(int64_t d = 0; d < prb->id; ++d)
            for_(int64_t h = 0; h < prb->ih; ++h)
            for (int64_t w = 0; w < prb->iw; ++w) {

                const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w;
                const int64_t l = l_base + sp;

                if (alg == ALG_0 && !flip_coin(l / 2 * 257ULL, density)) {
                    s[sp] = 0;
                    continue;
                }

                const int64_t gen = (l / 2 * 1637) & flex_mask;
                const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
                const float f = 1.f * sgn * gen / (1 << flex_bits);

                s[sp] = alg == ALG_0 ? f : m * (1.f + f);
                if (L % 2 && (mb * prb->id * prb->ih * prb->iw + sp == L - 1)) {
                    s[sp] = m;
                }
                v += (s[sp] - m) * (s[sp] - m);
                if (fill_src_add) {
                    // The main purpose of such filling is to avoid
                    // catastrophic cancellation. To do that, the sign of the
                    // final `Add` tensor values is kept the same as it would
                    // be after applying bnorm: values below the mean get a
                    // negative sign, values equal to or above it get a
                    // positive one.
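                    // Illustration: mod2_base below lies in [0, 4], so
                    // |s_add[sp]| is one of 1/2, 1/4, ..., 1/32, with the
                    // sign matching the sign of (s[sp] - m).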
                    const int64_t mod2_base = (mb + c + d + h + w) % 5;
                    const float mod2_val = 1.f / (2LL << mod2_base);
                    const int64_t sign_val = s[sp] < m ? -1 : 1;
                    float *s_add = (float *)src_add + off;
                    s_add[sp] = round_to_nearest_representable(
                            prb->dt, mod2_val * sign_val);
                }
            }
        }

        ((float *)var)[c] = v / (prb->mb * prb->id * prb->ih * prb->iw);

        const float sc_value = 1.f / 8 * (1 << (c % 7));
        const float sh_value = ((c % 3) - 1) * sc_value / 64;
        ((float *)sc)[c] = use_sc ? sc_value : 1.0f;
        ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
    });

    return OK;
}

static int prepare_fwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &src_add,
        dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc, dnn_mem_t &sh) {
    if (prb->flags & GLOB_STATS)
        return prepare_fwd_with_stats(prb, src, src_add, mean, var, sc, sh);
    else
        return prepare_fwd_no_stats(prb, src, src_add, mean, var, sc, sh);
}
static int prepare_bwd(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    // Idea behind filling: integer diff_dst values decrease norms unlike fp32
    // values in [-1.f, 1.f] range. To decrease norms further, make the data
    // pretty sparse, since the reference answers accumulate all diff_dst
    // values.

    /* Do fixed partitioning to have the same filling for any number of
     * threads */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);

        // Note: we use a different seed for each chunk to avoid repeating
        // patterns. We could use discard(idx_start) too, but its complexity
        // is O(idx_start). We also add 1 to avoid seeding with 0.
        std::minstd_rand msr(idx_start + 1);
        msr.discard(1);

        std::uniform_int_distribution<> igen_val(-2, 2);
        std::uniform_int_distribution<> igen_coin(0, 256 * 1024);

        // at least 20 non-zero elems
        float sparsity = MAX2(0.05f, MIN2(1.f, 20.f / nelems));
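        // E.g. for nelems = 1,000,000 this evaluates to max(0.05, 2e-5)
        // = 0.05, i.e. roughly 5% of diff_dst elements end up non-zero.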

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            float value = flip_coin(igen_coin(msr), sparsity)
                    ? round_to_nearest_representable(prb->dt, igen_val(msr))
                    : 0;
            mem_fp.set_elem(idx, value);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

int check_fwd_ws(const dnn_mem_t &dst_dt, const dnn_mem_t &ws_dt, res_t *res) {
    if (ws_dt.ndims() == 0) return OK;

    /* so far we know ws is just a bit-mask of whether the value was negative
     * or positive */
    const auto nelems = dst_dt.nelems(true);
    const uint8_t *ws = (const uint8_t *)ws_dt;

    /* some internal knowledge: flags in ws are either stored as bytes (e.g.
     * for the ref implementation) or as bits (e.g. for the jitted one); in
     * the latter case the ws memory has fewer elements than the data memory */
    enum { ws_byte, ws_bit } ws_type;
    ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte;

    /* more internal knowledge: dst_dt and ws_dt are expected to have exactly
     * the same data layout, and dst_dt padded regions are expected to be
     * zero, and the respective ws_dt elements should be set accordingly */
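    /* Reading scheme used below: elements are processed in groups of 8; for
     * ws_bit, bit `j` of the current byte holds the flag for element `i + j`,
     * hence the `*ws & (1 << j)` check and the per-group `++ws`. */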
    for (int64_t i = 0; i < nelems; i += 8) {
        for (int64_t j = 0; j < MIN2(8, nelems - i); ++j) {
            const float data = dst_dt.get_elem(i + j);
            const bool want = data > 0.f;
            const bool bit_set = ws_type == ws_byte ? *ws : !!(*ws & (1 << j));

            const bool ok = bit_set == want;
            res->errors += !ok;

            bool dump = false || (!ok && (res->errors < 10 || verbose >= 10))
                    || (verbose >= 50 && i < 30);
            if (dump) {
                BENCHDNN_PRINT(0, "[%4ld] ws exp:%d got:%d (data:%g:%a)\n",
                        (long)(i + j), want, bit_set, data, data);
            }

            if (ws_type == ws_byte) ++ws;
        }
        if (ws_type == ws_bit) ++ws;
    }

    if (res->errors) res->state = FAILED;
    if (res->state == EXECUTED) res->state = PASSED;

    return res->state == FAILED ? FAIL : OK;
}

dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;
    const dir_t dir = init_pd_args.dir;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->data_dims().data(), prb->dt, prb->tag);

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    auto flags = (dnnl_normalization_flags_t)prb->flags;
    if (dir & FLAG_FWD) {
        auto dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->data_dims().data(), prb->dt, tag::any);
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;
        DNN_SAFE_STATUS(dnnl_batch_normalization_forward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop, src_d, dst_d,
                prb->eps, flags, dnnl_attr));
    } else {
        auto diff_src_d = dnn_mem_t::init_md(
                prb->ndims, prb->data_dims().data(), prb->dt, tag::any);
        auto diff_dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->data_dims().data(), prb->dt, tag::any);
        auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data;
        DNN_SAFE_STATUS(dnnl_batch_normalization_backward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop, diff_src_d,
                diff_dst_d, src_d, prb->eps, flags, init_pd_args.hint,
                dnnl_attr));
    }

    return dnnl_success;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->dt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);

    // Non-zero ReLU alpha is supported only for inference on CPU.
    const auto &po = prb->attr.post_ops;
    const auto relu_idx = po.find(attr_t::post_ops_t::kind_t::RELU);
    if (relu_idx >= 0) {
        const auto &e = po.entry[relu_idx];
        float alpha = e.eltwise.alpha;
        bool alpha_ok
                = IMPLICATION(alpha != 0.f, (prb->dir & FLAG_INF) && is_cpu());
        if (!alpha_ok) {
            res->state = SKIPPED;
            res->reason = CASE_NOT_SUPPORTED;
        }
    }
    // BN+Add+ReLU fusion is not supported on CPU
    if (is_cpu() && prb->fuse_add_relu()) {
        res->state = SKIPPED;
        res->reason = CASE_NOT_SUPPORTED;
    }
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
    // See `skip_invalid_inplace` for details.
    if (prb->inplace) {
        skip_invalid_inplace(res, prb->dt, prb->dt, prb->tag, prb->tag);
        if (res->state == SKIPPED) return;
    }
}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    // Since bwd testing is done using results from forward, which are random
    // fp32 values, diff_scale starts fluctuating, so the norm check is used
    // for data, SC, and SH.
    const bool compare_with_norm = (prb->dir & FLAG_BWD);
    cmp.set_norm_validation_mode(compare_with_norm);

    const int f32_mant_digits = 24;
    const float trh_coeff = (1 << (f32_mant_digits - digits_dt(prb->dt)));
    float trh = trh_coeff
            * ((kind == SRC || kind == DST || kind == SRC_1) ? 5e-7 : 0);
    if ((kind == SC || kind == SH) && prb->dir & FLAG_BWD)
        trh = trh_coeff * 5e-6;
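    // Illustrative arithmetic (assuming digits_dt(bf16) == 8): trh_coeff
    // = 2^(24 - 8) = 65536, so the DST threshold becomes 65536 * 5e-7
    // ~= 0.033; for f32, trh_coeff is 1 and the threshold stays at 5e-7.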

#ifdef DNNL_EXPERIMENTAL
    const bool bnorm_single_pass
            = dnnl::impl::experimental::use_bnorm_stats_one_pass();
#else
    const bool bnorm_single_pass = false;
#endif

    const bool use_relaxed_validation = is_nvidia_gpu() || bnorm_single_pass;
    if (use_relaxed_validation) {
        // Nvidia: cuDNN stores unbiased variance which requires rescaling by
        // `(N - 1) / N`, where `N = MB * Spatial`. Hence, we cannot set the
        // threshold to 0...
        // Also, the mean could be computed using a single-pass formula.
        //
        // On Intel GPUs, mean and variance could be rounded incorrectly
        // because they are calculated using a fast but potentially unstable
        // formula.
        if (kind == MEAN) trh = 1e-7;
        if (kind == VAR) trh = 4e-7;
    }
    cmp.set_threshold(trh);

    // TODO: improve bf16 filling
    if (prb->dt == dnnl_bf16) cmp.set_zero_trust_percent(99.f);

    // When the error is larger than `trh`, it could be due to a catastrophic
    // cancellation in the final result, which is computed as `Y = a * X + b`.
    // When `a * X` is close to `b` and their signs are opposite, a large
    // error in `a * X` may leave no meaningful digits in the mantissa of the
    // final result `|Y| = |a*X - (-b)|` after the cancellation.
    //
    // Since the lambda may be called after the current stack frame is gone,
    // `prb` and `kind` are captured by value to avoid dangling references.
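    // Illustrative numbers (not taken from a real run): if a*X = 100.0 and
    // b = -99.9, then Y = 0.1; a relative error of 1e-6 in a*X (~1e-4
    // absolute) becomes a relative error of about 1e-3 in Y.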
    const auto bnorm_add_check =
            [&, kind, prb](
                    const compare::compare_t::driver_check_func_args_t &args) {
                if (!((prb->dir & FLAG_FWD) && kind == DST && prb->use_sh()))
                    return false;

                const auto &sh = ref_args.find(DNNL_ARG_SHIFT);
                const auto &dst = ref_args.find(DNNL_ARG_DST);
                const int64_t c = dst.get_scale_idx(
                        args.idx, 1 << 1 /* channel_mask */);
                const float beta = sh.get_elem(c);
                // Using an empirically derived threshold, check if
                // cancellation error in `|Y| = |a*X - (-b)|` is huge.
                const float abs_exp = fabsf(args.exp);
                const float norm_denom = abs_exp > FLT_MIN ? abs_exp : 1.f;
                const float abs_exp_delta = fabsf(args.exp - beta);
                bool maybe_cancel_error = abs_exp_delta / norm_denom > 1.f;
                if (!maybe_cancel_error) return false;

                // Check for error in `a * X`
                float diff_aX = fabsf((args.exp - beta) - (args.got - beta));
                float rel_diff_aX = diff_aX
                        / (abs_exp_delta > FLT_MIN ? abs_exp_delta : 1.f);
                return rel_diff_aX <= args.trh;
            };
    cmp.set_driver_check_function(bnorm_add_check);
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    bool is_service_prim = prb->dir & FLAG_BWD;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res, FLAG_FWD, nullptr,
                 is_service_prim),
            WARN);

    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_fpd = query_pd(prim);

    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();
    const bool fuse_add_relu = prb->fuse_add_relu();

    const auto &data_md = query_md(const_fpd, DNNL_ARG_SRC);
    const auto &src_add_md = query_md(const_fpd, DNNL_ARG_SRC_1);
    const auto &mean_md = query_md(const_fpd, DNNL_ARG_MEAN);
    const auto &var_md = query_md(const_fpd, DNNL_ARG_VARIANCE);
    const auto &sc_md = query_md(const_fpd, DNNL_ARG_SCALE);
    const auto &sh_md = query_md(const_fpd, DNNL_ARG_SHIFT);
    const auto &ws_md = query_md(const_fpd, DNNL_ARG_WORKSPACE);
    const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(data_md, fp, tag, ref_engine);
    dnn_mem_t src_dt(data_md, test_engine);
    dnn_mem_t src_add_fp(src_add_md, fp, tag, ref_engine);
    dnn_mem_t src_add_dt(src_add_md, test_engine);
    // stash for bwd: src_hat[i] = (src[i] - mean) / sqrt(var + prb->eps)
    dnn_mem_t src_hat_fp(data_md, fp, tag, ref_engine);

    dnn_mem_t &dst_fp = src_fp; // in-place in ref code
    dnn_mem_t placeholder_dst_dt;
    const bool inplace_fwd = prb->inplace && (prb->dir & FLAG_FWD);
    if (!inplace_fwd) { placeholder_dst_dt = dnn_mem_t(data_md, test_engine); }
    dnn_mem_t &dst_dt = inplace_fwd ? src_dt : placeholder_dst_dt;

    // On inference w/o global stats the batch norm doesn't require stat
    // memories. Hence, we need to prepare the mean_fp and var_fp ourselves.
    const dnnl_dims_t dims1d = {prb->ic};
    dnn_mem_t mean_fp(1, dims1d, fp, tag::abx, ref_engine);
    dnn_mem_t mean_dt(mean_md, test_engine);
    dnn_mem_t var_fp(1, dims1d, fp, tag::abx, ref_engine);
    dnn_mem_t var_dt(var_md, test_engine);

    dnn_mem_t sc_fp(1, dims1d, fp, tag::abx, ref_engine);
    dnn_mem_t sc_dt(sc_md, test_engine);
    dnn_mem_t d_sc_fp(1, dims1d, fp, tag::abx, ref_engine);
    dnn_mem_t d_sc_dt(sc_md, test_engine);

    dnn_mem_t sh_fp(1, dims1d, fp, tag::abx, ref_engine);
    dnn_mem_t sh_dt(sh_md, test_engine);
    dnn_mem_t d_sh_fp(1, dims1d, fp, tag::abx, ref_engine);
    dnn_mem_t d_sh_dt(sh_md, test_engine);

    dnn_mem_t ws_fp(data_md, dnnl_u8, tag, ref_engine);
    dnn_mem_t ws_dt(ws_md, test_engine);
    if (prb->need_ws()) SAFE(ws_dt.ndims() != 0 ? OK : FAIL, WARN);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt;

    if (prepare_fwd(prb, src_fp, src_add_fp, mean_fp, var_fp, sc_fp, sh_fp)
            != OK) {
        return res->state = MISTRUSTED, OK;
    }

    SAFE(src_dt.reorder(src_fp), WARN);
    if (fuse_add_relu) SAFE(src_add_dt.reorder(src_add_fp), WARN);
    if (prb->flags & GLOB_STATS) {
        SAFE(mean_dt.reorder(mean_fp), WARN);
        SAFE(var_dt.reorder(var_fp), WARN);
    }
    if (use_sc) { SAFE(sc_dt.reorder(sc_fp), WARN); }
    if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_SRC_1, src_add_dt);
    args.set(DNNL_ARG_MEAN, mean_dt);
    args.set(DNNL_ARG_VARIANCE, var_dt);
    args.set(DNNL_ARG_SCALE, sc_dt);
    args.set(DNNL_ARG_SHIFT, sh_dt);
    args.set(DNNL_ARG_WORKSPACE, ws_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(DNNL_ARG_DST, dst_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    // Run the ref implementation to collect src_hat (used instead of
    // src + mean) and ws (if the fuse_relu flag is requested).
    if (is_bench_mode(CORR)) {
        if (prb->dir & FLAG_FWD) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_SRC_1, src_add_fp);
            ref_args.set(DNNL_ARG_MEAN, mean_fp);
            ref_args.set(DNNL_ARG_VARIANCE, var_fp);
            ref_args.set(DNNL_ARG_SCALE, sc_fp);
            ref_args.set(DNNL_ARG_SHIFT, sh_fp);
            ref_args.set(DNNL_ARG_WORKSPACE, ws_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_DST_1, src_hat_fp); // Reference aux arg.

            std::vector<data_kind_t> kinds {DST};
            if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) {
                kinds.push_back(MEAN);
                kinds.push_back(VAR);
            }

            check_correctness(prb, kinds, args, ref_args, setup_cmp, res);

            if (prb->debug_check_ws) check_fwd_ws(dst_dt, ws_dt, res);
        }
    }

    if (prb->dir & FLAG_BWD) {
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim;
        SAFE(init_prim(tmp_prim, init_pd, prb, res, FLAG_BWD, const_fpd), WARN);
        if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
        prim.reset(tmp_prim.release());

        auto const_bpd = query_pd(prim);

        const auto &d_data_md = query_md(const_bpd, DNNL_ARG_DIFF_DST);
        const auto &d_src_add_md = query_md(const_bpd, DNNL_ARG_DIFF_SRC_1);
        const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD);

        dnn_mem_t d_dst_fp(d_data_md, fp, tag, ref_engine);
        d_dst_dt = dnn_mem_t(d_data_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place in ref code
        dnn_mem_t d_src_add_fp(d_src_add_md, fp, tag, ref_engine);
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_data_md, test_engine);
        }
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;
        dnn_mem_t d_src_add_dt = dnn_mem_t(d_src_add_md, test_engine);

        scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine);

        SAFE(prepare_bwd(prb, d_dst_dt, d_dst_fp), WARN);

        args.clear();
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_SRC_1, src_add_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_SCALE, sc_dt);
        args.set(DNNL_ARG_WORKSPACE, ws_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_DIFF_SCALE, d_sc_dt);
        args.set(DNNL_ARG_DIFF_SHIFT, d_sh_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        // Since DIFF_SRC_1 is the second output, it can be in a blocked
        // format, and unconditionally including it leads to zero-padding
        // failures.
        if (fuse_add_relu) args.set(DNNL_ARG_DIFF_SRC_1, d_src_add_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_SRC_1, src_add_fp);
            ref_args.set(DNNL_ARG_MEAN, mean_fp);
            ref_args.set(DNNL_ARG_VARIANCE, var_fp);
            ref_args.set(DNNL_ARG_SCALE, sc_fp);
            ref_args.set(DNNL_ARG_SHIFT, sh_fp);
            ref_args.set(DNNL_ARG_WORKSPACE, ws_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_DST_1, src_hat_fp); // Reference aux arg.
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC_1, d_src_add_fp);
            ref_args.set(DNNL_ARG_DIFF_SCALE, d_sc_fp);
            ref_args.set(DNNL_ARG_DIFF_SHIFT, d_sh_fp);

            std::vector<data_kind_t> kinds {SRC};
            if (use_sc && (prb->dir & FLAG_WEI)) kinds.push_back(SC);
            if (use_sh && (prb->dir & FLAG_WEI)) kinds.push_back(SH);
            if (fuse_add_relu) kinds.push_back(SRC_1);

            check_correctness(prb, kinds, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace bnorm