/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cmath>
#include <float.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>

#include <sstream>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "bnorm/bnorm.hpp"
#include "lnorm/lnorm.hpp"

using namespace bnorm;

namespace lnorm {

static int prepare_fwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &mean,
        dnn_mem_t &var, dnn_mem_t &sc, dnn_mem_t &sh) {
    /** Idea: choose src[] values so that both mean and variance are computed
     * exactly (independently of the order of the computations).
     *
     * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
     *
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too big (flex_bits <= min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * density) so that at least want_flex_bits bits are reserved for the src
     * variation. Once src is set, variance is computed.
     *
     * ALG_0: mean is set to 0
     * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
     * ALG_AUTO: choose between ALG_0 and ALG_1 automatically
     * ALG_2: if falling back to ALG_0 would leave only one non-zero element
     *        per row, use a filling that doesn't rely on the strict approach.
     */
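    /* A worked instance of [a1]: for ALG_1 with mean m = 1 and flex_bits = 2,
     * consecutive elements come in pairs such as (m * 1.25, m * 0.75). Each
     * pair sums to 2 * m exactly, so the mean is exact in any summation
     * order, and both elements of a pair contribute the same exactly
     * representable squared deviation to the variance. */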
    const int64_t exact_bits = digits_dt(prb->dt[0]);
    const int64_t L = prb->c;
    const int64_t logL = (int64_t)ceilf(log2f(L));

    assert(logL <= 0 || (1LL << (logL - 1)) < L);
    assert(L <= (1LL << logL));

    const int64_t min_flex_bits = 3;
    const int64_t want_flex_bits = MIN2(6, exact_bits / 2);

    check_alg_t alg = prb->check_alg;
    if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
        alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;

    const int64_t flex_bits = alg == ALG_0
            ? want_flex_bits
            : MIN2(exact_bits, (exact_bits - logL) / 2 - 1);
    if (flex_bits < min_flex_bits) return FAIL;

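    // When flex_bits takes half of the mantissa (possible for low-precision
    // data types), the ALG_0 density below would be 1/L, i.e. roughly one
    // non-zero element per row, so switch to the relaxed ALG_2 filling.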
    if (exact_bits / 2 == flex_bits) alg = ALG_2;

    if ((alg == ALG_0 || alg == ALG_1) && !is_integral_dt(prb->dt[0])) {
        const int64_t flex_mask = (1 << flex_bits) - 1;

        /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
        const float density = alg == ALG_0
                ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L
                : 1.f;
        assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);

        BENCHDNN_PRINT(99, "check_alg: %s, density = %g, flex_bits = %ld\n",
                check_alg2str(alg), density, (long)flex_bits);

        benchdnn_parallel_nd(prb->n, [&](int64_t n) {
            const float m = alg == ALG_0 ? 0.f : 0.25f * (1 << (n % 7));
            float v = 0; /* current variance */

            float *s = (float *)src + n * prb->c;
            for (int64_t c = 0; c < prb->c; ++c) {
                const int64_t l = c + n * 239 * 2; // l[0] must be even

                if (alg == ALG_0 && !flip_coin(l / 2 * 257ULL, density)) {
                    s[c] = 0;
                    continue;
                }

                const int64_t gen = (l / 2 * 1637) & flex_mask;
                const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
                const float f = 1.f * sgn * gen / (1 << flex_bits);

                src.set_elem(n * prb->c + c, alg == ALG_0 ? f : m * (1.f + f));
                if (L % 2 && (c == L - 1)) { s[c] = m; }
                v += (s[c] - m) * (s[c] - m);
            }
            mean.set_elem(n, m);
            var.set_elem(n, v / prb->c);
        });
    } else {
        assert(alg == ALG_2);

        benchdnn_parallel_nd(prb->n, [&](int64_t n) {
            // Note: use a different seed for each chunk to avoid repeating
            // patterns. discard(idx_start) would work too, but its
            // complexity is O(idx_start). Add 1 to avoid seeding with 0.
            std::minstd_rand int_seed(n + 1);
            int_seed.discard(1);
            std::minstd_rand b_seed(n + 1);
            b_seed.discard(2);

            const float val_coeff = is_integral_dt(prb->dt[0]) ? 4.f : 1.f;
            const int distr_shift = prb->dt[0] == dnnl_u8 ? 2 : 0;
            std::uniform_int_distribution<> int_dist(0 + distr_shift, 6);
            std::bernoulli_distribution b_dist(0.5f);
            const float m = val_coeff * 0.25f * (1 << int_dist(int_seed));
            float v = 0; /* current variance */

            const int64_t c_shift = n * prb->c;
            float *s = (float *)src + c_shift;

            bool bigger_val = false;
            float val = 0.f;

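            // Fill elements in pairs (m + d, m - d), where d is either
            // val_coeff * 1 or val_coeff * 0.25, chosen per pair. Each pair
            // sums to 2 * m exactly, and the variance reduces to the average
            // of the squared offsets d^2.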
            for (int64_t c = 0; c < prb->c; ++c) {
                const int64_t idx = c_shift + c;

                if (c % 2 == 0) {
                    bigger_val = b_dist(b_seed);
                    val = bigger_val ? (m + val_coeff * 1.f)
                                     : (m + val_coeff * 0.25f);
                } else {
                    val = bigger_val ? (m - val_coeff * 1.f)
                                     : (m - val_coeff * 0.25f);
                }
                src.set_elem(idx, val);

                v += (s[c] - m) * (s[c] - m);
            }
            // For odd C, the last element has no pair: overwrite it with the
            // mean and remove its contribution to the variance.
            if (prb->c % 2 == 1) {
                v -= (s[prb->c - 1] - m) * (s[prb->c - 1] - m);
                s[prb->c - 1] = m;
            }
            mean.set_elem(n, m);
            var.set_elem(n, v / prb->c);
        });
    }

    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

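    // Scale values are exact powers of two in [2^-3, 2^3], so multiplying by
    // gamma doesn't lose mantissa bits; shift values are small multiples of
    // the corresponding scale.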
    benchdnn_parallel_nd(prb->c, [&](int64_t c) {
        float sc_value = 1.f / 8 * (1 << (c % 7));
        float sh_value = (c % 3 + 1) * sc_value / 64;
        ((float *)sc)[c] = use_sc ? sc_value : 1.0f;
        ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
    });
    return OK;
}

static int prepare_bwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &d_dst,
        dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc) {
    if (prb->c < 2) return FAIL;

    const bool use_sc = prb->use_sc();

    // fill gamma
    for (int64_t c = 0; c < prb->c; ++c) {
        const float sc_value = 0.125f * (1 << (c % 7));
        ((float *)sc)[c] = use_sc ? sc_value : 1.0f;
    }

    benchdnn_parallel_nd(prb->n, [&](int64_t n) {
        // Note: use a different seed for each chunk to avoid repeating
        // patterns. discard(idx_start) would work too, but its
        // complexity is O(idx_start). Add 1 to avoid seeding with 0.
        std::minstd_rand int_seed(n + 1);
        int_seed.discard(1);
        std::minstd_rand b_seed(n + 1);
        b_seed.discard(2);

        // The idea behind the filling is to reduce the chance of
        // cancellation when subtracting a part accumulated over N. For that,
        // src data is simplified to the (m + 1) and (m - 1) points, and
        // d_dst data is more or less random, but all values are kept as
        // powers of 2 so that summation is almost exact.
        std::uniform_int_distribution<> stat_dist(0, 2);
        std::uniform_int_distribution<> data_dist(0, 6);
        std::bernoulli_distribution half_dist(0.5f);

        // mean = {-0.5f, 0.f, 0.5f}
        const float m = 0.5f * (stat_dist(int_seed) - 1);
        mean.set_elem(n, m);

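        // The primitive adds eps back to the variance during normalization,
        // so storing (v - eps) makes the effective variance exactly the
        // power-of-2 value v.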
        // final variance = {0.25f, 1.f, 4.f}
        const float v = 0.25f * (1 << (stat_dist(int_seed) * 2));
        var.set_elem(n, v - prb->eps);

        const int64_t c_shift = n * prb->c;

        for (int64_t c = 0; c < prb->c; ++c) {
            const float sign = half_dist(b_seed) ? 1.f : -1.f;
            // d_dst = powf(2, {-4, ..., 2})
            float dd = sign * 0.0625f * (1LL << data_dist(int_seed));
            d_dst.set_elem(c_shift + c,
                    round_to_nearest_representable(prb->dt[1], dd));

            float s = c % 2 == 0 ? (m - 1.f) : (m + 1.f);
            src.set_elem(
                    c_shift + c, round_to_nearest_representable(prb->dt[0], s));
        }
    });

    return OK;
}

int fill_scales(
        const attr_t &attr, int arg, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    const auto &scales = attr.scales.get(arg);

    /* Do fixed partitioning to have the same filling for any number of
     * threads. */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);
    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: use a different seed for each chunk to avoid repeating
        // patterns. discard(idx_start) would work too, but its
        // complexity is O(idx_start). Add 1 to avoid seeding with 0.
        std::minstd_rand int_seed(idx_start + 1);
        int_seed.discard(1);

        std::uniform_int_distribution<> gen(-5, 5);

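        // Generate scales as powers of 2 in [2^-5, 2^5] so that applying
        // them is an exact floating-point multiplication.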
        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            int pow2 = gen(int_seed);
            int pow2_shift = 1 << std::abs(pow2);
            const float gen_val = pow2 < 0 ? (1.f / pow2_shift) : pow2_shift;
            const float fixed_val = scales.scale;
            const float val = nelems == 1 ? fixed_val : gen_val;
            mem_fp.set_elem(idx, val);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->dims.data(), prb->dt[0], prb->tag[0]);
    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> stat_d {};
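    // Statistics (mean and variance) cover all dimensions except the last,
    // normalized one, hence ndims - 1.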
    if (prb->stat_tag != tag::undef) {
        stat_d = dnn_mem_t::init_md(
                prb->ndims - 1, prb->dims.data(), dnnl_f32, prb->stat_tag);
    }

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    auto flags = (dnnl_normalization_flags_t)prb->flags;
    if (prb->dir & FLAG_FWD) {
        auto dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->dt[1], prb->tag[1]);
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;
        DNN_SAFE_STATUS(dnnl_layer_normalization_forward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop, src_d, dst_d,
                stat_d, prb->eps, flags, dnnl_attr));
    } else {
        auto diff_src_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->dt[0], prb->tag[0]);
        auto diff_dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->dt[1], prb->tag[1]);
        auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data;
        DNN_SAFE_STATUS(
                dnnl_layer_normalization_backward_primitive_desc_create(
                        &init_pd_args.pd, init_pd_args.engine, prop,
                        diff_src_d, diff_dst_d, src_d, stat_d, prb->eps,
                        flags, init_pd_args.hint, dnnl_attr));
    }

    return dnnl_success;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->dt[0], prb->dt[1]}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);

    if (is_gpu()) {
        const bool dt_ok = prb->dt[0] == prb->dt[1]
                && !is_integral_dt(prb->dt[0]) && !is_integral_dt(prb->dt[1]);
        if (!dt_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
    // See `skip_invalid_inplace` for details.
    if (prb->inplace) {
        skip_invalid_inplace(
                res, prb->dt[0], prb->dt[1], prb->tag[0], prb->tag[1]);
        if (res->state == SKIPPED) return;
    }
}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const bool compare_with_norm = (prb->dir & FLAG_BWD);
    cmp.set_norm_validation_mode(compare_with_norm);

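    // Scale the f32-based threshold by the number of mantissa bits lost in
    // the tested data type, e.g. for bf16 (8 mantissa digits) trh_coeff is
    // 2^(24 - 8) = 2^16.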
    const auto dt = prb->dir & FLAG_FWD ? prb->dt[1] : prb->dt[0];
    const int f32_mant_digits = 24;
    const float trh_coeff = (1 << (f32_mant_digits - digits_dt(dt)));
    float trh = trh_coeff * ((kind == SRC || kind == DST) ? 5e-7 : 0);
    if ((kind == SC || kind == SH) && prb->dir & FLAG_BWD)
        trh = trh_coeff * 5e-6;
    cmp.set_threshold(trh);

    // u8 dst saturates negative results to zero, which turns roughly half of
    // the output into zeros.
    if (prb->dt[1] == dnnl_u8) cmp.set_zero_trust_percent(60.f);

    // When the error is larger than `trh`, it could be due to a catastrophic
    // cancellation in the final result, which is computed as `Y = a * X + b`.
    // When `a * X` is close to `b` and their signs are opposite, a large
    // error in `a * X` could yield a final result `|Y| = |a*X - (-b)|` with
    // no meaningful digits left in the mantissa.
    //
    // Since the lambda is called when the stack is unavailable, `prb` and
    // `kind` must be captured by value to avoid dangling references.
    const auto lnorm_add_check =
            [&, kind, prb](
                    const compare::compare_t::driver_check_func_args_t &args) {
                if (!((prb->dir & FLAG_FWD) && kind == DST && prb->use_sh()))
                    return false;

                const auto &sh = ref_args.find(DNNL_ARG_SHIFT);
                const auto &dst = ref_args.find(DNNL_ARG_DST);
                const int64_t c = dst.get_scale_idx(
                        args.idx, 1 << (prb->ndims - 1) /* last_dim_mask */);
                const float beta = sh.get_elem(c);
                // Using an empirically derived threshold, check if the
                // cancellation error in `|Y| = |a*X - (-b)|` is huge.
                const float abs_exp = fabsf(args.exp);
                const float norm_denom = abs_exp > FLT_MIN ? abs_exp : 1.f;
                const float abs_exp_delta = fabsf(args.exp - beta);
                bool maybe_cancel_error = abs_exp_delta / norm_denom > 1.f;
                if (!maybe_cancel_error) return false;

                // Check for error in `a * X`.
                float diff_aX = fabsf((args.exp - beta) - (args.got - beta));
                float rel_diff_aX = diff_aX
                        / (abs_exp_delta > FLT_MIN ? abs_exp_delta : 1.f);
                return rel_diff_aX <= args.trh;
            };
    cmp.set_driver_check_function(lnorm_add_check);
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    const auto &src_md = query_md(const_pd, DNNL_ARG_SRC);
    const auto &mean_md = query_md(const_pd, DNNL_ARG_MEAN);
    const auto &var_md = query_md(const_pd, DNNL_ARG_VARIANCE);
    const auto &sc_md = query_md(const_pd, DNNL_ARG_SCALE);
    const auto &sh_md = query_md(const_pd, DNNL_ARG_SHIFT);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t placeholder_dst_dt;
    dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt;

    // On inference w/o global stats, layer norm doesn't require the stat
    // memories, hence mean_fp and var_fp need to be prepared manually.
    dnn_mem_t mean_fp(
            prb->ndims - 1, src_fp.dims(), dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t mean_dt(mean_md, test_engine);

    dnn_mem_t var_fp(
            prb->ndims - 1, src_fp.dims(), dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t var_dt(var_md, test_engine);

    dnn_mem_t sc_fp(sc_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t sc_dt(sc_md, test_engine);

    dnn_mem_t sh_fp(sh_md, dnnl_f32, use_sh ? tag::x : tag::abx, ref_engine);
    dnn_mem_t sh_dt(sh_md, test_engine);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt, d_sc_dt, d_sh_dt;

    const dnnl_dims_t scale_dims = {1};
    auto scales_md = dnn_mem_t::init_md(1, scale_dims, dnnl_f32, tag::abx);
    dnn_mem_t src_scales_dt(scales_md, test_engine);
    dnn_mem_t dst_scales_dt(scales_md, test_engine);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);

        dnn_mem_t &dst_fp = src_fp; // in-place reference
        if (!prb->inplace) {
            placeholder_dst_dt = dnn_mem_t(dst_md, test_engine);
        }

        if (prepare_fwd(prb, src_fp, mean_fp, var_fp, sc_fp, sh_fp) != OK) {
            return res->state = MISTRUSTED, OK;
        }

        SAFE(src_dt.reorder(src_fp), WARN);
        if (prb->flags & GLOB_STATS) {
            /* prepare mean & var if they are inputs */
            SAFE(mean_dt.reorder(mean_fp), WARN);
            SAFE(var_dt.reorder(var_fp), WARN);
        }
        if (use_sc) { SAFE(sc_dt.reorder(sc_fp), WARN); }
        if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

        dnn_mem_t src_scales_fp(scales_md, ref_engine);
        dnn_mem_t dst_scales_fp(scales_md, ref_engine);
        fill_scales(prb->attr, DNNL_ARG_SRC, src_scales_dt, src_scales_fp);
        fill_scales(prb->attr, DNNL_ARG_DST, dst_scales_dt, dst_scales_fp);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(DNNL_ARG_SCALE, sc_dt);
        args.set(DNNL_ARG_SHIFT, sh_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_MEAN, mean_fp);
            ref_args.set(DNNL_ARG_VARIANCE, var_fp);
            ref_args.set(DNNL_ARG_SCALE, sc_fp);
            ref_args.set(DNNL_ARG_SHIFT, sh_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);

            std::vector<data_kind_t> kinds {DST};
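            // Mean and variance are checked only when the primitive produces
            // them as outputs, i.e. training without user-provided
            // statistics.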
            if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) {
                kinds.push_back(MEAN);
                kinds.push_back(VAR);
            }

            check_correctness(prb, kinds, args, ref_args, setup_cmp, res);
        }
    } else {
        const auto &d_src_md = query_md(const_pd, DNNL_ARG_DIFF_SRC);
        const auto &d_dst_md = query_md(const_pd, DNNL_ARG_DIFF_DST);

        dnn_mem_t d_dst_fp(d_dst_md, dnnl_f32, tag::abx, ref_engine);
        d_dst_dt = dnn_mem_t(d_dst_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place in ref code
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_src_md, test_engine);
        }
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

        d_sc_dt = dnn_mem_t(sc_md, test_engine);
        dnn_mem_t d_sc_fp(sc_md, dnnl_f32, tag::abx, ref_engine);

        d_sh_dt = dnn_mem_t(sh_md, test_engine);
        dnn_mem_t d_sh_fp(
                sh_md, dnnl_f32, use_sh ? tag::x : tag::abx, ref_engine);

        if (prepare_bwd(prb, src_fp, d_dst_fp, mean_fp, var_fp, sc_fp) != OK) {
            return res->state = MISTRUSTED, OK;
        }

        SAFE(src_dt.reorder(src_fp), WARN);
        SAFE(d_dst_dt.reorder(d_dst_fp), WARN);
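        // Unlike forward, backward always consumes mean and variance as
        // inputs, so they are reordered unconditionally.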
        SAFE(mean_dt.reorder(mean_fp), WARN);
        SAFE(var_dt.reorder(var_fp), WARN);
        if (use_sc) { SAFE(sc_dt.reorder(sc_fp), WARN); }
        if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(DNNL_ARG_SCALE, sc_dt);
        args.set(DNNL_ARG_DIFF_SCALE, d_sc_dt);
        args.set(DNNL_ARG_SHIFT, sh_dt);
        args.set(DNNL_ARG_DIFF_SHIFT, d_sh_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_MEAN, mean_fp);
            ref_args.set(DNNL_ARG_VARIANCE, var_fp);
            ref_args.set(DNNL_ARG_SCALE, sc_fp);
            ref_args.set(DNNL_ARG_SHIFT, sh_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);
            ref_args.set(DNNL_ARG_DIFF_SCALE, d_sc_fp);
            ref_args.set(DNNL_ARG_DIFF_SHIFT, d_sh_fp);

            std::vector<data_kind_t> kinds {SRC};
            if (use_sc && (prb->dir & FLAG_WEI)) kinds.push_back(SC);
            if (use_sh && (prb->dir & FLAG_WEI)) kinds.push_back(SH);

            check_correctness(prb, kinds, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace lnorm