1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cmath> |
18 | #include <float.h> |
19 | #include <random> |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | |
23 | #include <sstream> |
24 | |
25 | #include "oneapi/dnnl/dnnl.h" |
26 | |
27 | #include "utils/parallel.hpp" |
28 | |
29 | #include "dnnl_common.hpp" |
30 | #include "dnnl_memory.hpp" |
31 | |
32 | #include "bnorm/bnorm.hpp" |
33 | #include "lnorm/lnorm.hpp" |
34 | |
35 | using namespace bnorm; |
36 | |
37 | namespace lnorm { |
38 | |
39 | static int prepare_fwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &mean, |
40 | dnn_mem_t &var, dnn_mem_t &sc, dnn_mem_t &sh) { |
41 | /** Idea: choose src[] values so that both mean and variance are computed |
42 | * exactly (independently of the order of the computations). |
43 | * |
44 | * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean. |
45 | * |
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too long (flex_bits <= min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * density) so that at least want_flex_bits bits are reserved for src
     * variation.
50 | * Once src is set, variance is computed. |
51 | * |
52 | * ALG_0: mean is set to 0 |
     * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
54 | * ALG_AUTO: choose between ALG_0 and ALG_1 automatically |
     * ALG_2: if falling back to ALG_0 would leave only one non-zero element,
     *        use a filling that doesn't rely on the strict approach above.
57 | */ |
58 | const int64_t exact_bits = digits_dt(prb->dt[0]); |
59 | const int64_t L = prb->c; |
60 | const int64_t logL = (int64_t)ceilf(log2f(L)); |
61 | |
62 | assert(logL <= 0 || (1LL << (logL - 1)) < L); |
63 | assert(L <= (1LL << logL)); |
64 | |
65 | const int64_t min_flex_bits = 3; |
66 | const int64_t want_flex_bits = MIN2(6, exact_bits / 2); |
67 | |
68 | check_alg_t alg = prb->check_alg; |
69 | if (alg == ALG_AUTO) /* choose appropriate checking algorithm */ |
70 | alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0; |
71 | |
72 | const int64_t flex_bits = alg == ALG_0 |
73 | ? want_flex_bits |
74 | : MIN2(exact_bits, (exact_bits - logL) / 2 - 1); |
75 | if (flex_bits < min_flex_bits) return FAIL; |
76 | |
77 | if (exact_bits / 2 == flex_bits) alg = ALG_2; |
78 | |
79 | if ((alg == ALG_0 || alg == ALG_1) && !is_integral_dt(prb->dt[0])) { |
80 | const int64_t flex_mask = (1 << flex_bits) - 1; |
81 | |
82 | /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */ |
83 | const float density = alg == ALG_0 |
84 | ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L |
85 | : 1.f; |
86 | assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits); |
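        // E.g., for f32 (exact_bits = 24) with flex_bits = 6 and L = 2^16,
        // ALG_0 gives density = 2^(24 - 12) / 2^16 = 2^-4, and the check
        // above reads (24 - log2(2^12)) / 2 = 6 >= 6.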
87 | |
        BENCHDNN_PRINT(99, "check_alg: %s, density = %g, flex_bits = %ld\n",
89 | check_alg2str(alg), density, (long)flex_bits); |
90 | |
91 | benchdnn_parallel_nd(prb->n, [&](int64_t n) { |
92 | const float m = alg == ALG_0 ? 0.f : 0.25f * (1 << (n % 7)); |
93 | float v = 0; /* current variance */ |
94 | |
95 | float *s = (float *)src + n * prb->c; |
96 | for (int64_t c = 0; c < prb->c; ++c) { |
97 | const int64_t l = c + n * 239 * 2; // l[0] must be even |
98 | |
99 | if (alg == ALG_0 && !flip_coin(l / 2 * 257ULL, density)) { |
100 | s[c] = 0; |
101 | continue; |
102 | } |
103 | |
104 | const int64_t gen = (l / 2 * 1637) & flex_mask; |
105 | const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */ |
106 | const float f = 1.f * sgn * gen / (1 << flex_bits); |
107 | |
108 | src.set_elem(n * prb->c + c, alg == ALG_0 ? f : m * (1.f + f)); |
                if (L % 2 && (c == L - 1)) { s[c] = m; } // odd L: pin last to m
110 | v += (s[c] - m) * (s[c] - m); |
111 | } |
112 | mean.set_elem(n, m); |
113 | var.set_elem(n, v / prb->c); |
114 | }); |
115 | } else { |
116 | assert(alg == ALG_2); |
117 | |
118 | benchdnn_parallel_nd(prb->n, [&](int64_t n) { |
119 | // Note: we use a different seed for each chunk to avoid |
            // repeating patterns. We could use discard(idx_start) too, but
            // it has O(idx_start) complexity. We also add 1 to avoid
122 | // seeding with 0. |
123 | std::minstd_rand int_seed(n + 1); |
124 | int_seed.discard(1); |
125 | std::minstd_rand b_seed(n + 1); |
126 | b_seed.discard(2); |
127 | |
128 | const float val_coeff = is_integral_dt(prb->dt[0]) ? 4.f : 1.f; |
129 | const int distr_shift = prb->dt[0] == dnnl_u8 ? 2 : 0; |
130 | std::uniform_int_distribution<> int_dist(0 + distr_shift, 6); |
131 | std::bernoulli_distribution b_dist(0.5f); |
132 | const float m = val_coeff * 0.25f * (1 << int_dist(int_seed)); |
133 | float v = 0; /* current variance */ |
134 | |
135 | const int64_t c_shift = n * prb->c; |
136 | float *s = (float *)src + c_shift; |
137 | |
138 | bool bigger_val = false; |
139 | float val = 0.f; |
140 | |
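            // Elements come in pairs straddling the mean: either
            // (m + val_coeff, m - val_coeff) or
            // (m + 0.25 * val_coeff, m - 0.25 * val_coeff), chosen per pair
            // via `b_dist`. Each pair sums to exactly 2 * m, so the row mean
            // is m regardless of summation order.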
141 | for (int64_t c = 0; c < prb->c; ++c) { |
142 | const int64_t idx = c_shift + c; |
143 | |
144 | if (c % 2 == 0) { |
145 | bigger_val = b_dist(b_seed); |
146 | val = bigger_val ? (m + val_coeff * 1.f) |
147 | : (m + val_coeff * 0.25f); |
148 | } else { |
149 | val = bigger_val ? (m - val_coeff * 1.f) |
150 | : (m - val_coeff * 0.25f); |
151 | } |
152 | src.set_elem(idx, val); |
153 | |
154 | v += (s[c] - m) * (s[c] - m); |
155 | } |
            // For odd C, pin the last element to the mean and remove its
            // previous contribution from the variance.
157 | if (prb->c % 2 == 1) { |
158 | v -= (s[prb->c - 1] - m) * (s[prb->c - 1] - m); |
159 | s[prb->c - 1] = m; |
160 | } |
161 | mean.set_elem(n, m); |
162 | var.set_elem(n, v / prb->c); |
163 | }); |
164 | } |
165 | |
166 | const bool use_sc = prb->use_sc(); |
167 | const bool use_sh = prb->use_sh(); |
168 | |
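    // Scale values are powers of two in [1/8, 8] and shift values are small
    // multiples of them, which keeps the affine part of the reference
    // computation nearly free of extra rounding.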
169 | benchdnn_parallel_nd(prb->c, [&](int64_t c) { |
170 | float sc_value = 1.f / 8 * (1 << (c % 7)); |
171 | float sh_value = (c % 3 + 1) * sc_value / 64; |
172 | ((float *)sc)[c] = use_sc ? sc_value : 1.0f; |
173 | ((float *)sh)[c] = use_sh ? sh_value : 0.0f; |
174 | }); |
175 | return OK; |
176 | } |
177 | |
178 | static int prepare_bwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &d_dst, |
179 | dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc) { |
180 | if (prb->c < 2) return FAIL; |
181 | |
182 | const bool use_sc = prb->use_sc(); |
183 | |
184 | // fill gamma |
185 | for (int64_t c = 0; c < prb->c; ++c) { |
186 | const float sc_value = 0.125f * (1 << (c % 7)); |
187 | ((float *)sc)[c] = use_sc ? sc_value : 1.0f; |
188 | } |
189 | |
190 | benchdnn_parallel_nd(prb->n, [&](int64_t n) { |
191 | // Note: we use a different seed for each chunk to avoid |
        // repeating patterns. We could use discard(idx_start) too, but
        // it has O(idx_start) complexity. We also add 1 to avoid
194 | // seeding with 0. |
195 | std::minstd_rand int_seed(n + 1); |
196 | int_seed.discard(1); |
197 | std::minstd_rand b_seed(n + 1); |
198 | b_seed.discard(2); |
199 | |
200 | // Idea behind the filling is to reduce a possibility of cancellation |
201 | // when subtracting a part accumulated over N. For that, we simplify |
202 | // src data to (m+1) and (m-1) points, d_dst data is more or less |
203 | // random but we keep all values as pow2 values to have almost exact |
204 | // summation result. |
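        // E.g., with m = 0.5 a src row alternates {-0.5, 1.5, -0.5, 1.5, ...};
        // each (m - 1, m + 1) pair sums to exactly 2 * m.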
205 | std::uniform_int_distribution<> stat_dist(0, 2); |
206 | std::uniform_int_distribution<> data_dist(0, 6); |
207 | std::bernoulli_distribution half_dist(0.5f); |
208 | |
209 | // mean = {-0.5f, 0.f, 0.5f} |
210 | const float m = 0.5f * (stat_dist(int_seed) - 1); |
211 | mean.set_elem(n, m); |
212 | |
213 | // final variance = {0.25f, 1.f, 4.f} |
214 | const float v = 0.25f * (1 << (stat_dist(int_seed) * 2)); |
215 | var.set_elem(n, v - prb->eps); |
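        // Storing `v - eps` makes the implementation's `var + eps` recover
        // exactly v, so 1 / sqrt(var + eps) is an exact power of two.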
216 | |
217 | const int64_t c_shift = n * prb->c; |
218 | |
219 | for (int64_t c = 0; c < prb->c; ++c) { |
            const int sign = half_dist(b_seed) ? 1 : -1;
            // d_dst = ±powf(2, {-4, ..., 2})
222 | float dd = sign * 0.0625f * (1LL << data_dist(int_seed)); |
223 | d_dst.set_elem(c_shift + c, |
224 | round_to_nearest_representable(prb->dt[1], dd)); |
225 | |
226 | float s = c % 2 == 0 ? (m - 1.f) : (m + 1.f); |
227 | src.set_elem( |
228 | c_shift + c, round_to_nearest_representable(prb->dt[0], s)); |
229 | } |
230 | }); |
231 | |
232 | return OK; |
233 | } |
234 | |
235 | int fill_scales( |
236 | const attr_t &attr, int arg, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
237 | const auto nelems = mem_fp.nelems(); |
238 | if (nelems == 0) return OK; |
239 | |
240 | assert(mem_dt.nelems() == mem_fp.nelems()); |
241 | |
242 | const auto &scales = attr.scales.get(arg); |
243 | |
    /* Do fixed partitioning to have the same filling for any number of
     * threads. */
245 | const int64_t n_chunks = 16; |
246 | const int64_t chunk_size = div_up(nelems, n_chunks); |
247 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) { |
248 | int64_t idx_start = idx_chunk * chunk_size; |
249 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
250 | // Note: we use a different seed for each chunk to avoid |
        // repeating patterns. We could use discard(idx_start) too, but
        // it has O(idx_start) complexity. We also add 1 to avoid
253 | // seeding with 0. |
254 | std::minstd_rand int_seed(idx_start + 1); |
255 | int_seed.discard(1); |
256 | |
257 | std::uniform_int_distribution<> gen(-5, 5); |
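        // The generated values are powers of two in [2^-5, 2^5], so
        // multiplying by them is exact in f32.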
258 | |
259 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
260 | int pow2 = gen(int_seed); |
261 | int pow2_shift = 1 << std::abs(pow2); |
262 | const float gen_val = pow2 < 0 ? (1.f / pow2_shift) : pow2_shift; |
263 | const float fixed_val = scales.scale; |
264 | const float val = nelems == 1 ? fixed_val : gen_val; |
265 | mem_fp.set_elem(idx, val); |
266 | } |
267 | }); |
268 | |
269 | SAFE(mem_dt.reorder(mem_fp), WARN); |
270 | |
271 | return OK; |
272 | } |
273 | |
274 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
275 | const prb_t *prb = init_pd_args.prb; |
276 | |
277 | auto src_d = dnn_mem_t::init_md( |
278 | prb->ndims, prb->dims.data(), prb->dt[0], prb->tag[0]); |
279 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> stat_d {}; |
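    // Mean/variance statistics span all dimensions except the last
    // (normalized) one, hence `ndims - 1` below.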
280 | if (prb->stat_tag != tag::undef) { |
281 | stat_d = dnn_mem_t::init_md( |
282 | prb->ndims - 1, prb->dims.data(), dnnl_f32, prb->stat_tag); |
283 | } |
284 | |
285 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
286 | create_dnnl_attr(prb->attr, attr_args_t())); |
287 | |
288 | auto flags = (dnnl_normalization_flags_t)prb->flags; |
289 | if (prb->dir & FLAG_FWD) { |
290 | auto dst_d = dnn_mem_t::init_md( |
291 | prb->ndims, prb->dims.data(), prb->dt[1], prb->tag[1]); |
292 | auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference |
293 | : dnnl_forward_training; |
294 | DNN_SAFE_STATUS(dnnl_layer_normalization_forward_primitive_desc_create( |
295 | &init_pd_args.pd, init_pd_args.engine, prop, src_d, dst_d, |
296 | stat_d, prb->eps, flags, dnnl_attr)); |
297 | } else { |
298 | auto diff_src_d = dnn_mem_t::init_md( |
299 | prb->ndims, prb->dims.data(), prb->dt[0], prb->tag[0]); |
300 | auto diff_dst_d = dnn_mem_t::init_md( |
301 | prb->ndims, prb->dims.data(), prb->dt[1], prb->tag[1]); |
302 | auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data; |
303 | DNN_SAFE_STATUS(dnnl_layer_normalization_backward_primitive_desc_create( |
304 | &init_pd_args.pd, init_pd_args.engine, prop, diff_src_d, |
305 | diff_dst_d, src_d, stat_d, prb->eps, flags, init_pd_args.hint, |
306 | dnnl_attr)); |
307 | } |
308 | |
309 | return dnnl_success; |
310 | } |
311 | |
312 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
313 | skip_unimplemented_data_type({prb->dt[0], prb->dt[1]}, prb->dir, res); |
314 | skip_unimplemented_sum_po(prb->attr, res); |
315 | |
316 | if (is_gpu()) { |
317 | const bool dt_ok = prb->dt[0] == prb->dt[1] |
318 | && !is_integral_dt(prb->dt[0]) && !is_integral_dt(prb->dt[1]); |
319 | if (!dt_ok) { |
320 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
321 | return; |
322 | } |
323 | } |
324 | } |
325 | |
326 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
327 | // See `skip_invalid_inplace` for details. |
328 | if (prb->inplace) { |
329 | skip_invalid_inplace( |
330 | res, prb->dt[0], prb->dt[1], prb->tag[0], prb->tag[1]); |
331 | if (res->state == SKIPPED) return; |
332 | } |
333 | } |
334 | |
335 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
336 | const args_t &ref_args) { |
337 | const bool compare_with_norm = (prb->dir & FLAG_BWD); |
338 | cmp.set_norm_validation_mode(compare_with_norm); |
339 | |
340 | const auto dt = prb->dir & FLAG_FWD ? prb->dt[1] : prb->dt[0]; |
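    // Relax the f32-based threshold by a factor of two for every mantissa
    // bit the tested data type lacks compared to f32.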
341 | const int f32_mant_digits = 24; |
342 | const float trh_coeff = (1 << (f32_mant_digits - digits_dt(dt))); |
343 | float trh = trh_coeff * ((kind == SRC || kind == DST) ? 5e-7 : 0); |
344 | if ((kind == SC || kind == SH) && prb->dir & FLAG_BWD) |
345 | trh = trh_coeff * 5e-6; |
346 | cmp.set_threshold(trh); |
347 | |
    // u8 saturation turns roughly half of the output into zeros.
349 | if (prb->dt[1] == dnnl_u8) cmp.set_zero_trust_percent(60.f); |
350 | |
    // When the error is larger than `trh`, it could be due to a catastrophic
    // cancellation in the final result, which is computed as `Y = a * X + b`.
    // When `a * X` is close to `-b`, the subtraction `|Y| = |a * X - (-b)|`
    // cancels the leading mantissa bits, so even a small relative error in
    // `a * X` may leave no meaningful digits in the final result.
357 | // |
    // The lambda may be invoked after this function's stack frame is gone,
    // so capture `prb` and `kind` by value to avoid dangling references.
360 | const auto lnorm_add_check = |
361 | [&, kind, prb]( |
362 | const compare::compare_t::driver_check_func_args_t &args) { |
363 | if (!((prb->dir & FLAG_FWD) && kind == DST && prb->use_sh())) |
364 | return false; |
365 | |
366 | const auto &sh = ref_args.find(DNNL_ARG_SHIFT); |
367 | const auto &dst = ref_args.find(DNNL_ARG_DST); |
368 | const int64_t c = dst.get_scale_idx( |
369 | args.idx, 1 << (prb->ndims - 1) /* last_dim_mask */); |
370 | const float beta = sh.get_elem(c); |
371 | // Using an empirically derived threshold, check if |
372 | // cancellation error in `|Y| = |a*X - (-b)|` is huge. |
373 | const float abs_exp = fabsf(args.exp); |
374 | const float norm_denom = abs_exp > FLT_MIN ? abs_exp : 1.f; |
375 | const float abs_exp_delta = fabsf(args.exp - beta); |
376 | bool maybe_cancel_error = abs_exp_delta / norm_denom > 1.f; |
377 | if (!maybe_cancel_error) return false; |
378 | |
379 | // Check for error in `a * X` |
380 | float diff_aX = fabsf((args.exp - beta) - (args.got - beta)); |
381 | float rel_diff_aX = diff_aX |
382 | / (abs_exp_delta > FLT_MIN ? abs_exp_delta : 1.f); |
383 | return rel_diff_aX <= args.trh; |
384 | }; |
385 | cmp.set_driver_check_function(lnorm_add_check); |
386 | } |
387 | |
388 | int doit(const prb_t *prb, res_t *res) { |
389 | if (bench_mode == LIST) return res->state = LISTED, OK; |
390 | |
391 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
392 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN); |
393 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
394 | |
395 | auto const_pd = query_pd(prim); |
396 | |
397 | const bool use_sc = prb->use_sc(); |
398 | const bool use_sh = prb->use_sh(); |
399 | |
400 | const auto &src_md = query_md(const_pd, DNNL_ARG_SRC); |
401 | const auto &mean_md = query_md(const_pd, DNNL_ARG_MEAN); |
402 | const auto &var_md = query_md(const_pd, DNNL_ARG_VARIANCE); |
403 | const auto &sc_md = query_md(const_pd, DNNL_ARG_SCALE); |
404 | const auto &sh_md = query_md(const_pd, DNNL_ARG_SHIFT); |
405 | const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD); |
406 | |
407 | const auto &test_engine = get_test_engine(); |
408 | const auto &ref_engine = get_cpu_engine(); |
409 | |
410 | dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine); |
411 | dnn_mem_t src_dt(src_md, test_engine); |
412 | dnn_mem_t placeholder_dst_dt; |
413 | dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt; |
414 | |
    // For inference w/o global stats, layer normalization doesn't require
    // stat memories, so mean_fp and var_fp have to be prepared manually.
417 | dnn_mem_t mean_fp( |
418 | prb->ndims - 1, src_fp.dims(), dnnl_f32, tag::abx, ref_engine); |
419 | dnn_mem_t mean_dt(mean_md, test_engine); |
420 | |
421 | dnn_mem_t var_fp( |
422 | prb->ndims - 1, src_fp.dims(), dnnl_f32, tag::abx, ref_engine); |
423 | dnn_mem_t var_dt(var_md, test_engine); |
424 | |
425 | dnn_mem_t sc_fp(sc_md, dnnl_f32, tag::abx, ref_engine); |
426 | dnn_mem_t sc_dt(sc_md, test_engine); |
427 | |
428 | dnn_mem_t sh_fp(sh_md, dnnl_f32, use_sh ? tag::x : tag::abx, ref_engine); |
429 | dnn_mem_t sh_dt(sh_md, test_engine); |
430 | |
431 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
432 | |
433 | dnn_mem_t d_dst_dt, placeholder_d_src_dt, d_sc_dt, d_sh_dt; |
434 | |
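    // Attribute scales for src/dst are per-tensor, hence single-element
    // memories.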
435 | const dnnl_dims_t scale_dims = {1}; |
436 | auto scales_md = dnn_mem_t::init_md(1, scale_dims, dnnl_f32, tag::abx); |
437 | dnn_mem_t src_scales_dt(scales_md, test_engine); |
438 | dnn_mem_t dst_scales_dt(scales_md, test_engine); |
439 | |
440 | args_t args, ref_args; |
441 | |
442 | if (prb->dir & FLAG_FWD) { |
443 | const auto &dst_md = query_md(const_pd, DNNL_ARG_DST); |
444 | |
445 | dnn_mem_t &dst_fp = src_fp; // in-place reference |
446 | if (!prb->inplace) { |
447 | placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); |
448 | } |
449 | |
450 | if (prepare_fwd(prb, src_fp, mean_fp, var_fp, sc_fp, sh_fp) != OK) { |
451 | return res->state = MISTRUSTED, OK; |
452 | } |
453 | |
454 | SAFE(src_dt.reorder(src_fp), WARN); |
455 | if (prb->flags & GLOB_STATS) { |
456 | /* prepare mean & var if they are inputs */ |
457 | SAFE(mean_dt.reorder(mean_fp), WARN); |
458 | SAFE(var_dt.reorder(var_fp), WARN); |
459 | } |
460 | if (use_sc) { SAFE(sc_dt.reorder(sc_fp), WARN); } |
461 | if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); } |
462 | |
463 | dnn_mem_t src_scales_fp(scales_md, ref_engine); |
464 | dnn_mem_t dst_scales_fp(scales_md, ref_engine); |
465 | fill_scales(prb->attr, DNNL_ARG_SRC, src_scales_dt, src_scales_fp); |
466 | fill_scales(prb->attr, DNNL_ARG_DST, dst_scales_dt, dst_scales_fp); |
467 | |
468 | args.set(DNNL_ARG_SRC, src_dt); |
469 | args.set(DNNL_ARG_MEAN, mean_dt); |
470 | args.set(DNNL_ARG_VARIANCE, var_dt); |
471 | args.set(DNNL_ARG_SCALE, sc_dt); |
472 | args.set(DNNL_ARG_SHIFT, sh_dt); |
473 | args.set(DNNL_ARG_DST, dst_dt); |
474 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
475 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt); |
476 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt); |
477 | |
478 | SAFE(execute_and_wait(prim, args, res), WARN); |
479 | |
480 | if (is_bench_mode(CORR)) { |
481 | ref_args.set(DNNL_ARG_SRC, src_fp); |
482 | ref_args.set(DNNL_ARG_MEAN, mean_fp); |
483 | ref_args.set(DNNL_ARG_VARIANCE, var_fp); |
484 | ref_args.set(DNNL_ARG_SCALE, sc_fp); |
485 | ref_args.set(DNNL_ARG_SHIFT, sh_fp); |
486 | ref_args.set(DNNL_ARG_DST, dst_fp); |
487 | ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp); |
488 | ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp); |
489 | |
490 | std::vector<data_kind_t> kinds {DST}; |
491 | if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) { |
492 | kinds.push_back(MEAN); |
493 | kinds.push_back(VAR); |
494 | } |
495 | |
496 | check_correctness(prb, kinds, args, ref_args, setup_cmp, res); |
497 | } |
498 | } else { |
499 | const auto &d_src_md = query_md(const_pd, DNNL_ARG_DIFF_SRC); |
500 | const auto &d_dst_md = query_md(const_pd, DNNL_ARG_DIFF_DST); |
501 | |
502 | dnn_mem_t d_dst_fp(d_dst_md, dnnl_f32, tag::abx, ref_engine); |
503 | d_dst_dt = dnn_mem_t(d_dst_md, test_engine); |
504 | |
505 | dnn_mem_t &d_src_fp = d_dst_fp; // in-place in ref code |
506 | if (!prb->inplace) { |
507 | placeholder_d_src_dt = dnn_mem_t(d_src_md, test_engine); |
508 | } |
509 | dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt; |
510 | |
511 | d_sc_dt = dnn_mem_t(sc_md, test_engine); |
512 | dnn_mem_t d_sc_fp(sc_md, dnnl_f32, tag::abx, ref_engine); |
513 | |
514 | d_sh_dt = dnn_mem_t(sh_md, test_engine); |
515 | dnn_mem_t d_sh_fp( |
516 | sh_md, dnnl_f32, use_sh ? tag::x : tag::abx, ref_engine); |
517 | |
518 | if (prepare_bwd(prb, src_fp, d_dst_fp, mean_fp, var_fp, sc_fp) != OK) { |
519 | return res->state = MISTRUSTED, OK; |
520 | } |
521 | |
522 | SAFE(src_dt.reorder(src_fp), WARN); |
523 | SAFE(d_dst_dt.reorder(d_dst_fp), WARN); |
524 | SAFE(mean_dt.reorder(mean_fp), WARN); |
525 | SAFE(var_dt.reorder(var_fp), WARN); |
526 | if (use_sc) { SAFE(sc_dt.reorder(sc_fp), WARN); } |
527 | if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); } |
528 | |
529 | args.set(DNNL_ARG_SRC, src_dt); |
530 | args.set(DNNL_ARG_DIFF_DST, d_dst_dt); |
531 | args.set(DNNL_ARG_DIFF_SRC, d_src_dt); |
532 | args.set(DNNL_ARG_MEAN, mean_dt); |
533 | args.set(DNNL_ARG_VARIANCE, var_dt); |
534 | args.set(DNNL_ARG_SCALE, sc_dt); |
535 | args.set(DNNL_ARG_DIFF_SCALE, d_sc_dt); |
536 | args.set(DNNL_ARG_SHIFT, sh_dt); |
537 | args.set(DNNL_ARG_DIFF_SHIFT, d_sh_dt); |
538 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
539 | |
540 | SAFE(execute_and_wait(prim, args, res), WARN); |
541 | |
542 | if (is_bench_mode(CORR)) { |
543 | ref_args.set(DNNL_ARG_SRC, src_fp); |
544 | ref_args.set(DNNL_ARG_MEAN, mean_fp); |
545 | ref_args.set(DNNL_ARG_VARIANCE, var_fp); |
546 | ref_args.set(DNNL_ARG_SCALE, sc_fp); |
547 | ref_args.set(DNNL_ARG_SHIFT, sh_fp); |
548 | ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp); |
549 | ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp); |
550 | ref_args.set(DNNL_ARG_DIFF_SCALE, d_sc_fp); |
551 | ref_args.set(DNNL_ARG_DIFF_SHIFT, d_sh_fp); |
552 | |
553 | std::vector<data_kind_t> kinds {SRC}; |
554 | if (use_sc && (prb->dir & FLAG_WEI)) kinds.push_back(SC); |
555 | if (use_sh && (prb->dir & FLAG_WEI)) kinds.push_back(SH); |
556 | |
557 | check_correctness(prb, kinds, args, ref_args, setup_cmp, res); |
558 | } |
559 | } |
560 | |
561 | return measure_perf(prb->ctx_exe, res, prim, args); |
562 | } |
563 | |
564 | } // namespace lnorm |
565 | |