1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stddef.h> |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | |
23 | #include <random> |
24 | #include <sstream> |
25 | |
26 | #include "oneapi/dnnl/dnnl.h" |
27 | |
28 | #include "utils/parallel.hpp" |
29 | |
30 | #include "dnnl_common.hpp" |
31 | #include "dnnl_memory.hpp" |
32 | |
33 | #include "bnorm/bnorm.hpp" |
34 | |
35 | namespace bnorm { |
36 | |
37 | static int prepare_fwd_with_stats(const prb_t *prb, dnn_mem_t &src, |
38 | dnn_mem_t &src_add, dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc, |
39 | dnn_mem_t &sh) { |
40 | const bool use_sc = prb->use_sc(); |
41 | const bool use_sh = prb->use_sh(); |
42 | const bool fill_src_add = prb->fuse_add_relu(); |
43 | |
44 | benchdnn_parallel_nd(prb->ic, [&](int64_t c) { |
45 | mean.set_elem(c, 4 * ((c % 5) - 2)); |
46 | var.set_elem(c, ((c % 7) << 1)); |
47 | |
48 | const float sc_value = 1 << (c % 7); |
49 | const float sh_value = ((c % 3) - 1) * sc_value; |
50 | sc.set_elem(c, use_sc ? sc_value : 1.0f); |
51 | sh.set_elem(c, use_sh ? sh_value : 0.0f); |
52 | }); |
53 | |
54 | benchdnn_parallel_nd(prb->ic, prb->mb, prb->id, prb->ih, prb->iw, |
55 | [&](int64_t c, int64_t mb, int64_t d, int64_t h, int64_t w) { |
56 | int64_t l_base = mb * prb->id * prb->ih * prb->iw + c * 239 * 2; |
57 | float *s = (float *)src + data_off(prb, mb, c, 0, 0, 0); |
58 | |
59 | const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w; |
60 | const int64_t l = l_base + sp; |
61 | const int64_t value = (l % 65) - 32; |
62 | s[sp] = round_to_nearest_representable(prb->dt, value); |
63 | if (fill_src_add) { |
64 | float *s_add |
65 | = (float *)src_add + data_off(prb, mb, c, 0, 0, 0); |
66 | s_add[sp] = round_to_nearest_representable( |
67 | prb->dt, (l % 17) - 8); |
68 | } |
69 | }); |
70 | |
71 | return OK; |
72 | } |
73 | |
// Fills inputs for the forward case where the primitive computes the
// statistics itself. The expected mean/variance are derived here and stored
// into `mean`/`var` for the reference, while scale/shift get per-channel
// values (or the neutral 1/0 when not requested).
// Returns FAIL when fewer than `min_flex_bits` bits of variation can be
// reserved for src, OK otherwise.
static int prepare_fwd_no_stats(const prb_t *prb, dnn_mem_t &src,
        dnn_mem_t &src_add, dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc,
        dnn_mem_t &sh) {
    /** Idea: choose src[] values so that both mean and variance are computed
     * exactly (independently of the order of the computations).
     *
     * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
     *
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too big (flex_bits < min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * density) so that at least want_flex_bits is reserved for src variation.
     * Once src is set, variance is computed.
     *
     * ALG_0: mean is set to 0
     * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
     * ALG_AUTO: choose between ALG_0 and ALG_1 automatically */
    const int64_t exact_bits = digits_dt(prb->dt);
    // Number of elements reduced per channel.
    const int64_t L = prb->mb * prb->id * prb->ih * prb->iw;
    const int64_t logL = (int64_t)ceilf(log2f(L));

    assert(logL <= 0 || (1LL << (logL - 1)) < L);
    assert(L <= (1LL << logL));

    const int64_t min_flex_bits = 3;
    const int64_t want_flex_bits = MIN2(6, exact_bits / 2);

    check_alg_t alg = prb->check_alg;
    if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
        alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;

    // In the ALG_1 case the variation is additionally capped at 7 bits for
    // bf16 because BFloat16 has only 7 bits of mantissa.
    const int64_t flex_bits = alg == ALG_0
            ? want_flex_bits
            : MIN2(prb->dt == dnnl_bf16 ? 7 : exact_bits,
                    (exact_bits - logL) / 2 - 1);

    if (flex_bits < min_flex_bits) return FAIL;

    const int64_t flex_mask = (1 << flex_bits) - 1;

    /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
    const float density = alg == ALG_0
            ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L
            : 1.f;
    assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);

    BENCHDNN_PRINT(6, "check_alg: %s, density = %g, flex_bits = " IFMT "\n",
            check_alg2str(alg), density, flex_bits);

    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();
    const bool fill_src_add = prb->fuse_add_relu();

    benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
        // Expected mean of channel c: 0 for ALG_0, 2^((c % 7) - 2) for ALG_1.
        const float m = ((float *)mean)[c]
                = alg == ALG_0 ? 0.f : 0.25f * (1 << (c % 7));
        float v = 0; /* current variance */

        for (int64_t mb = 0; mb < prb->mb; ++mb) {
            int64_t l_base = mb * prb->id * prb->ih * prb->iw
                    + c * 239 * 2; // l[0] must be even
            int64_t off = data_off(prb, mb, c, 0, 0, 0);
            float *s = (float *)src + off;

            for_(int64_t d = 0; d < prb->id; ++d)
            for_(int64_t h = 0; h < prb->ih; ++h)
            for (int64_t w = 0; w < prb->iw; ++w) {

                const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w;
                const int64_t l = l_base + sp;

                // The coin is seeded with l / 2, so (assuming flip_coin is
                // deterministic in its seed) both members of a pair
                // (l even, l + 1) are zeroed or kept together, keeping [a1].
                if (alg == ALG_0 && !flip_coin(l / 2 * 257ULL, density)) {
                    s[sp] = 0;
                    continue;
                }

                // Paired elements share `gen` and get opposite signs, giving
                // m * (1 + f) and m * (1 - f) (or f and -f for ALG_0).
                const int64_t gen = (l / 2 * 1637) & flex_mask;
                const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
                const float f = 1.f * sgn * gen / (1 << flex_bits);

                s[sp] = alg == ALG_0 ? f : m * (1.f + f);
                // With an odd L the last element has no partner; setting it
                // to the mean keeps the sum equal to L * m.
                if (L % 2 && (mb * prb->id * prb->ih * prb->iw + sp == L - 1)) {
                    s[sp] = m;
                }
                v += (s[sp] - m) * (s[sp] - m);
                if (fill_src_add) {
                    // The main purpose of such filling is to avoid catastrophic
                    // cancellation. To do that, the sign of `Add` tensor final
                    // values is kept the same as it would be after applying
                    // bnorm: values below the mean get a negative sign, values
                    // equal to or above it get a positive one.
                    const int64_t mod2_base = (mb + c + d + h + w) % 5;
                    const float mod2_val = 1.f / (2LL << mod2_base);
                    const int64_t sign_val = s[sp] < m ? -1 : 1;
                    float *s_add = (float *)src_add + off;
                    s_add[sp] = round_to_nearest_representable(
                            prb->dt, mod2_val * sign_val);
                }
            }
        }

        // Biased variance: sum of squared deviations divided by L.
        ((float *)var)[c] = v / (prb->mb * prb->id * prb->ih * prb->iw);

        const float sc_value = 1.f / 8 * (1 << (c % 7));
        const float sh_value = ((c % 3) - 1) * sc_value / 64;
        ((float *)sc)[c] = use_sc ? sc_value : 1.0f;
        ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
    });

    return OK;
}
185 | |
186 | static int prepare_fwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &src_add, |
187 | dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &sc, dnn_mem_t &sh) { |
188 | if (prb->flags & GLOB_STATS) |
189 | return prepare_fwd_with_stats(prb, src, src_add, mean, var, sc, sh); |
190 | else |
191 | return prepare_fwd_no_stats(prb, src, src_add, mean, var, sc, sh); |
192 | } |
193 | |
194 | static int prepare_bwd(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
195 | const auto nelems = mem_fp.nelems(); |
196 | if (nelems == 0) return OK; |
197 | |
198 | // Idea behind filling: integer diff_dst values decrease norms unlike fp32 |
199 | // values in [-1.f, 1.f] range. To decrease norms more, make data pretty |
200 | // sparse as answers sum all diff_dst values. |
201 | |
202 | /* Do fixed partitioning to have same filling for any number of threads */ |
203 | const int64_t n_chunks = 16; |
204 | const int64_t chunk_size = div_up(nelems, n_chunks); |
205 | |
206 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) { |
207 | int64_t idx_start = idx_chunk * chunk_size; |
208 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
209 | |
210 | // Note: we use a different seed for each chunk to avoid |
211 | // repeating patterns. We could use discard(idx_start) too but |
212 | // it has a complexity in O(idx_start). We also add 1 to avoid |
213 | // seeding with 0. |
214 | std::minstd_rand msr(idx_start + 1); |
215 | msr.discard(1); |
216 | |
217 | std::uniform_int_distribution<> igen_val(-2, 2); |
218 | std::uniform_int_distribution<> igen_coin(0, 256 * 1024); |
219 | |
220 | // at least 20 non-zero elems |
221 | float sparsity = MAX2(0.05f, MIN2(1.f, 20.f / nelems)); |
222 | |
223 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
224 | float value = flip_coin(igen_coin(msr), sparsity) |
225 | ? round_to_nearest_representable(prb->dt, igen_val(msr)) |
226 | : 0; |
227 | mem_fp.set_elem(idx, value); |
228 | } |
229 | }); |
230 | |
231 | SAFE(mem_dt.reorder(mem_fp), WARN); |
232 | |
233 | return OK; |
234 | } |
235 | |
236 | int check_fwd_ws(const dnn_mem_t &dst_dt, const dnn_mem_t &ws_dt, res_t *res) { |
237 | if (ws_dt.ndims() == 0) return OK; |
238 | |
239 | /* so far we know ws is just bit-mask of whether value was negative or |
240 | * positive */ |
241 | const auto nelems = dst_dt.nelems(true); |
242 | const uint8_t *ws = (const uint8_t *)ws_dt; |
243 | |
244 | /* some internal knowledge: flags in ws are either stored as bytes (e.g. |
245 | * for the ref implementation) or as bits (e.g. for the jitted one); in |
246 | * the latter case the ws memory has fewer elements than the data memory */ |
247 | enum { ws_byte, ws_bit } ws_type; |
248 | ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte; |
249 | |
250 | /* more internal knowledge: dst_dt and ws_dt are expected to have exactly |
251 | * the same data layout, and dst_dt padded regions are expected to be |
252 | * zero, and the respective ws_dt elements should be set accordingly */ |
253 | for (int64_t i = 0; i < nelems; i += 8) { |
254 | for (int64_t j = 0; j < MIN2(8, nelems - i); ++j) { |
255 | const float data = dst_dt.get_elem(i + j); |
256 | const bool want = data > 0.f; |
257 | const bool bit_set = ws_type == ws_byte ? *ws : !!(*ws & (1 << j)); |
258 | |
259 | const bool ok = bit_set == want; |
260 | res->errors += !ok; |
261 | |
262 | bool dump = false || (!ok && (res->errors < 10 || verbose >= 10)) |
263 | || (verbose >= 50 && i < 30); |
264 | if (dump) { |
265 | BENCHDNN_PRINT(0, "[%4ld] ws exp:%d got:%d (data:%g:%a)\n" , |
266 | (long)(i + j), want, bit_set, data, data); |
267 | } |
268 | |
269 | if (ws_type == ws_byte) ++ws; |
270 | } |
271 | if (ws_type == ws_bit) ++ws; |
272 | } |
273 | |
274 | if (res->errors) res->state = FAILED; |
275 | if (res->state == EXECUTED) res->state = PASSED; |
276 | |
277 | return res->state == FAILED ? FAIL : OK; |
278 | } |
279 | |
280 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
281 | const prb_t *prb = init_pd_args.prb; |
282 | const dir_t dir = init_pd_args.dir; |
283 | |
284 | auto src_d = dnn_mem_t::init_md( |
285 | prb->ndims, prb->data_dims().data(), prb->dt, prb->tag); |
286 | |
287 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
288 | create_dnnl_attr(prb->attr, attr_args_t())); |
289 | |
290 | auto flags = (dnnl_normalization_flags_t)prb->flags; |
291 | if (dir & FLAG_FWD) { |
292 | auto dst_d = dnn_mem_t::init_md( |
293 | prb->ndims, prb->data_dims().data(), prb->dt, tag::any); |
294 | auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference |
295 | : dnnl_forward_training; |
296 | DNN_SAFE_STATUS(dnnl_batch_normalization_forward_primitive_desc_create( |
297 | &init_pd_args.pd, init_pd_args.engine, prop, src_d, dst_d, |
298 | prb->eps, flags, dnnl_attr)); |
299 | } else { |
300 | auto diff_src_d = dnn_mem_t::init_md( |
301 | prb->ndims, prb->data_dims().data(), prb->dt, tag::any); |
302 | auto diff_dst_d = dnn_mem_t::init_md( |
303 | prb->ndims, prb->data_dims().data(), prb->dt, tag::any); |
304 | auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data; |
305 | DNN_SAFE_STATUS(dnnl_batch_normalization_backward_primitive_desc_create( |
306 | &init_pd_args.pd, init_pd_args.engine, prop, diff_src_d, |
307 | diff_dst_d, src_d, prb->eps, flags, init_pd_args.hint, |
308 | dnnl_attr)); |
309 | } |
310 | |
311 | return dnnl_success; |
312 | } |
313 | |
314 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
315 | skip_unimplemented_data_type({prb->dt}, prb->dir, res); |
316 | skip_unimplemented_sum_po(prb->attr, res); |
317 | |
318 | // Non-zero alpha is not supported on GPU and for training in general. |
319 | const auto &po = prb->attr.post_ops; |
320 | const auto relu_idx = po.find(attr_t::post_ops_t::kind_t::RELU); |
321 | if (relu_idx >= 0) { |
322 | const auto &e = po.entry[relu_idx]; |
323 | float alpha = e.eltwise.alpha; |
324 | bool alpha_ok |
325 | = IMPLICATION(alpha != 0.f, (prb->dir & FLAG_INF) && is_cpu()); |
326 | if (!alpha_ok) { |
327 | res->state = SKIPPED; |
328 | res->reason = CASE_NOT_SUPPORTED; |
329 | } |
330 | } |
331 | // BN+Add+ReLU fusion is not supported on CPU |
332 | if (is_cpu() && prb->fuse_add_relu()) { |
333 | res->state = SKIPPED; |
334 | res->reason = CASE_NOT_SUPPORTED; |
335 | } |
336 | } |
337 | |
338 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
339 | // See `skip_invalid_inplace` for details. |
340 | if (prb->inplace) { |
341 | skip_invalid_inplace(res, prb->dt, prb->dt, prb->tag, prb->tag); |
342 | if (res->state == SKIPPED) return; |
343 | } |
344 | } |
345 | |
346 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
347 | const args_t &ref_args) { |
348 | // Since bwd testing is done using results from forward which are random |
349 | // fp32 values, diff_scale starts fluctuating, so we check norm for both |
350 | // data, SC, and SH. |
351 | const bool compare_with_norm = (prb->dir & FLAG_BWD); |
352 | cmp.set_norm_validation_mode(compare_with_norm); |
353 | |
354 | const int f32_mant_digits = 24; |
355 | const float trh_coeff = (1 << (f32_mant_digits - digits_dt(prb->dt))); |
356 | float trh = trh_coeff |
357 | * ((kind == SRC || kind == DST || kind == SRC_1) ? 5e-7 : 0); |
358 | if ((kind == SC || kind == SH) && prb->dir & FLAG_BWD) |
359 | trh = trh_coeff * 5e-6; |
360 | |
361 | #ifdef DNNL_EXPERIMENTAL |
362 | const bool bnorm_single_pass |
363 | = dnnl::impl::experimental::use_bnorm_stats_one_pass(); |
364 | #else |
365 | const bool bnorm_single_pass = false; |
366 | #endif |
367 | |
368 | const bool use_relaxed_validation = is_nvidia_gpu() || bnorm_single_pass; |
369 | if (use_relaxed_validation) { |
370 | // Nvidia: cuDNN stores unbiased variance which requires rescaling by |
371 | // `(N - 1) / N`, where `N = MB * Spatial`. Hence, we cannot set the |
372 | // threshold to 0... |
373 | // Also mean could be computed using a single pass formula. |
374 | // |
375 | // On Intel GPUs mean and variance could be rounded incorrectly because |
376 | // they are calculated using fast but potentially unstable formula. |
377 | if (kind == MEAN) trh = 1e-7; |
378 | if (kind == VAR) trh = 4e-7; |
379 | } |
380 | cmp.set_threshold(trh); |
381 | |
382 | // TODO: improve bf16 filling |
383 | if (prb->dt == dnnl_bf16) cmp.set_zero_trust_percent(99.f); |
384 | |
385 | // When the error is larger than `trh`, it could be due to a catastrophic |
386 | // cancellation in final result which is computed as `Y = a * X + b`. |
387 | // When `a * X` is close to `b` and their signs are opposite, then large |
388 | // error in `a * X` could result in a final result (which has a cancellation |
389 | // i.e. `|Y| = |a*X - (-b)|`), which has no meaningful digits left in |
390 | // mantissa. |
391 | // |
392 | // Since lambda is called when stack is unavailable, need to capture `prb` |
393 | // and `kind` by value to avoid using dangling references. |
394 | const auto bnorm_add_check = |
395 | [&, kind, prb]( |
396 | const compare::compare_t::driver_check_func_args_t &args) { |
397 | if (!((prb->dir & FLAG_FWD) && kind == DST && prb->use_sh())) |
398 | return false; |
399 | |
400 | const auto &sh = ref_args.find(DNNL_ARG_SHIFT); |
401 | const auto &dst = ref_args.find(DNNL_ARG_DST); |
402 | const int64_t c = dst.get_scale_idx( |
403 | args.idx, 1 << 1 /* channel_mask */); |
404 | const float beta = sh.get_elem(c); |
405 | // Using an empirically derived threshold, check if |
406 | // cancellation error in `|Y| = |a*X - (-b)|` is huge. |
407 | const float abs_exp = fabsf(args.exp); |
408 | const float norm_denom = abs_exp > FLT_MIN ? abs_exp : 1.f; |
409 | const float abs_exp_delta = fabsf(args.exp - beta); |
410 | bool maybe_cancel_error = abs_exp_delta / norm_denom > 1.f; |
411 | if (!maybe_cancel_error) return false; |
412 | |
413 | // Check for error in `a * X` |
414 | float diff_aX = fabsf((args.exp - beta) - (args.got - beta)); |
415 | float rel_diff_aX = diff_aX |
416 | / (abs_exp_delta > FLT_MIN ? abs_exp_delta : 1.f); |
417 | return rel_diff_aX <= args.trh; |
418 | }; |
419 | cmp.set_driver_check_function(bnorm_add_check); |
420 | } |
421 | |
422 | int doit(const prb_t *prb, res_t *res) { |
423 | if (bench_mode == LIST) return res->state = LISTED, OK; |
424 | |
425 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
426 | bool is_service_prim = prb->dir & FLAG_BWD; |
427 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res, FLAG_FWD, nullptr, |
428 | is_service_prim), |
429 | WARN); |
430 | |
431 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
432 | |
433 | auto const_fpd = query_pd(prim); |
434 | |
435 | const bool use_sc = prb->use_sc(); |
436 | const bool use_sh = prb->use_sh(); |
437 | const bool fuse_add_relu = prb->fuse_add_relu(); |
438 | |
439 | const auto &data_md = query_md(const_fpd, DNNL_ARG_SRC); |
440 | const auto &src_add_md = query_md(const_fpd, DNNL_ARG_SRC_1); |
441 | const auto &mean_md = query_md(const_fpd, DNNL_ARG_MEAN); |
442 | const auto &var_md = query_md(const_fpd, DNNL_ARG_VARIANCE); |
443 | const auto &sc_md = query_md(const_fpd, DNNL_ARG_SCALE); |
444 | const auto &sh_md = query_md(const_fpd, DNNL_ARG_SHIFT); |
445 | const auto &ws_md = query_md(const_fpd, DNNL_ARG_WORKSPACE); |
446 | const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD); |
447 | |
448 | const auto fp = dnnl_f32; |
449 | const auto tag = tag::abx; |
450 | |
451 | const auto &test_engine = get_test_engine(); |
452 | const auto &ref_engine = get_cpu_engine(); |
453 | |
454 | dnn_mem_t src_fp(data_md, fp, tag, ref_engine); |
455 | dnn_mem_t src_dt(data_md, test_engine); |
456 | dnn_mem_t src_add_fp(src_add_md, fp, tag, ref_engine); |
457 | dnn_mem_t src_add_dt(src_add_md, test_engine); |
458 | // stash for bwd: src_hat[i] = (src[i] - mean) / sqrt(var + prb->eps) |
459 | dnn_mem_t src_hat_fp(data_md, fp, tag, ref_engine); |
460 | |
461 | dnn_mem_t &dst_fp = src_fp; // in-place in ref code |
462 | dnn_mem_t placeholder_dst_dt; |
463 | const bool inplace_fwd = prb->inplace && (prb->dir & FLAG_FWD); |
464 | if (!inplace_fwd) { placeholder_dst_dt = dnn_mem_t(data_md, test_engine); } |
465 | dnn_mem_t &dst_dt = inplace_fwd ? src_dt : placeholder_dst_dt; |
466 | |
467 | // On inference w/o global stats the batch norm doesn't require stat |
468 | // memories. Hence, we need to prepare the mean_fp and var_fp ourselves. |
469 | const dnnl_dims_t dims1d = {prb->ic}; |
470 | dnn_mem_t mean_fp(1, dims1d, fp, tag::abx, ref_engine); |
471 | dnn_mem_t mean_dt(mean_md, test_engine); |
472 | dnn_mem_t var_fp(1, dims1d, fp, tag::abx, ref_engine); |
473 | dnn_mem_t var_dt(var_md, test_engine); |
474 | |
475 | dnn_mem_t sc_fp(1, dims1d, fp, tag::abx, ref_engine); |
476 | dnn_mem_t sc_dt(sc_md, test_engine); |
477 | dnn_mem_t d_sc_fp(1, dims1d, fp, tag::abx, ref_engine); |
478 | dnn_mem_t d_sc_dt(sc_md, test_engine); |
479 | |
480 | dnn_mem_t sh_fp(1, dims1d, fp, tag::abx, ref_engine); |
481 | dnn_mem_t sh_dt(sh_md, test_engine); |
482 | dnn_mem_t d_sh_fp(1, dims1d, fp, tag::abx, ref_engine); |
483 | dnn_mem_t d_sh_dt(sh_md, test_engine); |
484 | |
485 | dnn_mem_t ws_fp(data_md, dnnl_u8, tag, ref_engine); |
486 | dnn_mem_t ws_dt(ws_md, test_engine); |
487 | if (prb->need_ws()) SAFE(ws_dt.ndims() != 0 ? OK : FAIL, WARN); |
488 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
489 | |
490 | dnn_mem_t d_dst_dt, placeholder_d_src_dt; |
491 | |
492 | if (prepare_fwd(prb, src_fp, src_add_fp, mean_fp, var_fp, sc_fp, sh_fp) |
493 | != OK) { |
494 | return res->state = MISTRUSTED, OK; |
495 | } |
496 | |
497 | SAFE(src_dt.reorder(src_fp), WARN); |
498 | if (fuse_add_relu) SAFE(src_add_dt.reorder(src_add_fp), WARN); |
499 | if (prb->flags & GLOB_STATS) { |
500 | SAFE(mean_dt.reorder(mean_fp), WARN); |
501 | SAFE(var_dt.reorder(var_fp), WARN); |
502 | } |
503 | if (use_sc) { SAFE(sc_dt.reorder(sc_fp), WARN); } |
504 | if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); } |
505 | |
506 | args_t args, ref_args; |
507 | |
508 | args.set(DNNL_ARG_SRC, src_dt); |
509 | args.set(DNNL_ARG_SRC_1, src_add_dt); |
510 | args.set(DNNL_ARG_MEAN, mean_dt); |
511 | args.set(DNNL_ARG_VARIANCE, var_dt); |
512 | args.set(DNNL_ARG_SCALE, sc_dt); |
513 | args.set(DNNL_ARG_SHIFT, sh_dt); |
514 | args.set(DNNL_ARG_WORKSPACE, ws_dt); |
515 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
516 | args.set(DNNL_ARG_DST, dst_dt); |
517 | |
518 | SAFE(execute_and_wait(prim, args, res), WARN); |
519 | |
520 | // Running ref to collect src_hat (used instead of src + mean) and ws, if |
521 | // fuse_relu flag is requested. |
522 | if (is_bench_mode(CORR)) { |
523 | if (prb->dir & FLAG_FWD) { |
524 | ref_args.set(DNNL_ARG_SRC, src_fp); |
525 | ref_args.set(DNNL_ARG_SRC_1, src_add_fp); |
526 | ref_args.set(DNNL_ARG_MEAN, mean_fp); |
527 | ref_args.set(DNNL_ARG_VARIANCE, var_fp); |
528 | ref_args.set(DNNL_ARG_SCALE, sc_fp); |
529 | ref_args.set(DNNL_ARG_SHIFT, sh_fp); |
530 | ref_args.set(DNNL_ARG_WORKSPACE, ws_fp); |
531 | ref_args.set(DNNL_ARG_DST, dst_fp); |
532 | ref_args.set(DNNL_ARG_DST_1, src_hat_fp); // Reference aux arg. |
533 | |
534 | std::vector<data_kind_t> kinds {DST}; |
535 | if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) { |
536 | kinds.push_back(MEAN); |
537 | kinds.push_back(VAR); |
538 | } |
539 | |
540 | check_correctness(prb, kinds, args, ref_args, setup_cmp, res); |
541 | |
542 | if (prb->debug_check_ws) check_fwd_ws(dst_dt, ws_dt, res); |
543 | } |
544 | } |
545 | |
546 | if (prb->dir & FLAG_BWD) { |
547 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim; |
548 | SAFE(init_prim(tmp_prim, init_pd, prb, res, FLAG_BWD, const_fpd), WARN); |
549 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
550 | prim.reset(tmp_prim.release()); |
551 | |
552 | auto const_bpd = query_pd(prim); |
553 | |
554 | const auto &d_data_md = query_md(const_bpd, DNNL_ARG_DIFF_DST); |
555 | const auto &d_src_add_md = query_md(const_bpd, DNNL_ARG_DIFF_SRC_1); |
556 | const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD); |
557 | |
558 | dnn_mem_t d_dst_fp(d_data_md, fp, tag, ref_engine); |
559 | d_dst_dt = dnn_mem_t(d_data_md, test_engine); |
560 | |
561 | dnn_mem_t &d_src_fp = d_dst_fp; // in-place in ref code |
562 | dnn_mem_t d_src_add_fp(d_src_add_md, fp, tag, ref_engine); |
563 | if (!prb->inplace) { |
564 | placeholder_d_src_dt = dnn_mem_t(d_data_md, test_engine); |
565 | } |
566 | dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt; |
567 | dnn_mem_t d_src_add_dt = dnn_mem_t(d_src_add_md, test_engine); |
568 | |
569 | scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine); |
570 | |
571 | SAFE(prepare_bwd(prb, d_dst_dt, d_dst_fp), WARN); |
572 | |
573 | args.clear(); |
574 | args.set(DNNL_ARG_SRC, src_dt); |
575 | args.set(DNNL_ARG_SRC_1, src_add_dt); |
576 | args.set(DNNL_ARG_MEAN, mean_dt); |
577 | args.set(DNNL_ARG_VARIANCE, var_dt); |
578 | args.set(DNNL_ARG_DIFF_DST, d_dst_dt); |
579 | args.set(DNNL_ARG_SCALE, sc_dt); |
580 | args.set(DNNL_ARG_WORKSPACE, ws_dt); |
581 | args.set(DNNL_ARG_DIFF_SRC, d_src_dt); |
582 | args.set(DNNL_ARG_DIFF_SCALE, d_sc_dt); |
583 | args.set(DNNL_ARG_DIFF_SHIFT, d_sh_dt); |
584 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
585 | // Since DIFF_SRC_1 is the second output it can be in blocked format |
586 | // and unconditional including leads zero-paddiing failures |
587 | if (fuse_add_relu) args.set(DNNL_ARG_DIFF_SRC_1, d_src_add_dt); |
588 | |
589 | SAFE(execute_and_wait(prim, args, res), WARN); |
590 | |
591 | if (is_bench_mode(CORR)) { |
592 | ref_args.set(DNNL_ARG_SRC, src_fp); |
593 | ref_args.set(DNNL_ARG_SRC_1, src_add_fp); |
594 | ref_args.set(DNNL_ARG_MEAN, mean_fp); |
595 | ref_args.set(DNNL_ARG_VARIANCE, var_fp); |
596 | ref_args.set(DNNL_ARG_SCALE, sc_fp); |
597 | ref_args.set(DNNL_ARG_SHIFT, sh_fp); |
598 | ref_args.set(DNNL_ARG_WORKSPACE, ws_fp); |
599 | ref_args.set(DNNL_ARG_DST, dst_fp); |
600 | ref_args.set(DNNL_ARG_DST_1, src_hat_fp); // Reference aux arg. |
601 | ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp); |
602 | ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp); |
603 | ref_args.set(DNNL_ARG_DIFF_SRC_1, d_src_add_fp); |
604 | ref_args.set(DNNL_ARG_DIFF_SCALE, d_sc_fp); |
605 | ref_args.set(DNNL_ARG_DIFF_SHIFT, d_sh_fp); |
606 | |
607 | std::vector<data_kind_t> kinds {SRC}; |
608 | if (use_sc && (prb->dir & FLAG_WEI)) kinds.push_back(SC); |
609 | if (use_sh && (prb->dir & FLAG_WEI)) kinds.push_back(SH); |
610 | if (fuse_add_relu) kinds.push_back(SRC_1); |
611 | |
612 | check_correctness(prb, kinds, args, ref_args, setup_cmp, res); |
613 | } |
614 | } |
615 | |
616 | return measure_perf(prb->ctx_exe, res, prim, args); |
617 | } |
618 | |
619 | } // namespace bnorm |
620 | |