/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/compiler_workarounds.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_batch_normalization_utils.hpp"
#include "cpu/platform.hpp"

#include "cpu/ncsp_batch_normalization.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;
using namespace data_type;

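// Forward batch normalization over plain (ncsp) layouts. When statistics are
// computed here (rather than taken from the user), per-channel mean and
// variance are obtained in two passes via per-thread partial sums stored in
// ws_reduce. The data is then normalized as
//     dst = scale * (src - mean) / sqrt(variance + eps) + shift
// with an optional fused ReLU. bf16/f16 sources are converted chunk-by-chunk
// into per-thread f32 scratch buffers before accumulation.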
template <data_type_t d_type>
status_t ncsp_batch_normalization_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    const bool calculate_stats = !pd()->stats_is_src();
    const bool save_stats = pd()->is_training();
    const bool is_training = pd()->is_training();
    const bool fuse_norm_relu = pd()->fuse_norm_relu();

    const bool use_scale = pd()->use_scale();
    const bool use_shift = pd()->use_shift();

    const dim_t C = pd()->C();

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT);

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto *ws_reduce = scratchpad.template get<acc_data_t>(key_bnorm_reduction);

    acc_data_t *mean, *variance;
    if (!calculate_stats) {
        mean = const_cast<acc_data_t *>(
                CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN));
        variance = const_cast<acc_data_t *>(
                CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE));
    } else {
        if (save_stats) {
            mean = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
            variance = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
        } else {
            mean = scratchpad.template get<acc_data_t>(key_bnorm_tmp_mean);
            variance = scratchpad.template get<acc_data_t>(key_bnorm_tmp_var);
        }
    }

    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
    acc_data_t *src_cvt_wsp
            = scratchpad.template get<acc_data_t>(key_bnorm_cvt);

    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool with_relu = pd()->with_relu_post_op(is_training);
    auto maybe_post_op = [&](acc_data_t res) {
        if (with_relu) return math::relu_fwd(res, pd()->alpha());
        return res;
    };

    const dim_t SP = pd()->H() * pd()->W() * pd()->D();
    const dim_t simd_w = 16;
    const dim_t SP_cl_align = utils::rnd_up(SP, simd_w);
    const dim_t N = pd()->MB();

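    // Process channels in blocks when the data does not fit comfortably in
    // the (scaled) aggregate last-level cache; cache_balance() below then
    // chooses how many channel blocks to handle per iteration.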
    const int nthr = pd()->nthr_;
    size_t l3_size_ = platform::get_per_core_cache_size(3) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);

    parallel(nthr, [&](const int ithr, const int nthr) {
        int C_ithr = 0, C_nthr = 0;
        int N_ithr = 0, N_nthr = 0;
        int S_ithr = 0, S_nthr = 0;

        dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        dim_t N_s = 0, N_e = 0;
        dim_t S_s = 0, S_e = 0;

        dim_t C_blks_per_iter = 1;
        int64_t iters = 1;

        if (do_blocking) {
            size_t working_set_size = N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, N, nthr, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                true, false, ithr, nthr, N, C_blks_per_iter, SP, C_ithr, C_nthr,
                C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s,
                S_e);
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        for (int64_t it = 0; it < iters; ++it) {
            size_t C_off = it * C_blks_per_iter;
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && dnnl_thr_syncable()) dnnl_thr_barrier();

                S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, false, ithr, nthr, N,
                        last_iter_blks, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
                C_blks_per_iter = last_iter_blks;
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            const auto S_chunk = nstl::max(dim_t(0), S_e - S_s);
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (dnnl_thr_syncable() ? 0 : 1) * C_off;

            if (calculate_stats) {
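                // Pass 1: each (N, spatial) worker accumulates its partial
                // per-channel sum into ws_reduce; the mean is finalized after
                // the barrier below.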
                acc_data_t *mean_blk = mean + C_off;
                acc_data_t *variance_blk = variance + C_off;
                for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = (c + C_off) * SP;
                    acc_data_t sum = 0;
                    for (dim_t n = N_s; n < N_e; ++n) {
                        const acc_data_t *scr_fp32;
                        size_t soff = off + n * C * SP;
                        if (utils::one_of(d_type, bf16, f16)) {
                            acc_data_t *tmp_src
                                    = src_cvt_wsp + ithr * SP_cl_align;
                            /*TODO: remove this conversion if performance
                            doesn't degrade, since xfloat16_t supports +=
                            operator with implicit conversions from xf16 to
                            float */
                            types::cvt_to_float(
                                    tmp_src + S_s, src + soff + S_s, S_chunk);
                            scr_fp32 = tmp_src;
                        } else {
                            scr_fp32 = reinterpret_cast<const acc_data_t *>(
                                    src + soff);
                        }
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (dim_t sp = S_s; sp < S_e; ++sp) {
                            sum += scr_fp32[sp];
                        }
                    }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();

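                // Combine the partial sums of all (N, spatial) workers into
                // the final per-channel mean.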
                for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    mean_blk[c] = 0.;
                    for (dim_t n = 0; n < SP_N_nthr; n++)
                        mean_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    mean_blk[c] /= (N * SP);
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();

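                // Pass 2: accumulate partial per-channel sums of squared
                // deviations from the mean, then combine them into the
                // variance the same way.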
                for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = c + C_off;
                    acc_data_t sum = 0.;
                    for (dim_t n = N_s; n < N_e; ++n) {
                        const acc_data_t *_src;
                        size_t soff = off * SP + n * C * SP;
                        if (utils::one_of(d_type, bf16, f16)) {
                            acc_data_t *tmp_src
                                    = src_cvt_wsp + ithr * SP_cl_align;
                            /*TODO: remove this conversion if performance
                            doesn't degrade, since xfloat16_t supports +=
                            operator with implicit conversions from xf16 to
                            float */
                            types::cvt_to_float(
                                    tmp_src + S_s, src + soff + S_s, S_chunk);
                            _src = tmp_src;
                        } else {
                            _src = reinterpret_cast<const acc_data_t *>(
                                    src + soff);
                        }
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (dim_t sp = S_s; sp < S_e; ++sp) {
                            acc_data_t m = _src[sp] - mean[off];
                            sum += m * m;
                        }
                    }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();

                for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    variance_blk[c] = 0.;
                    for (dim_t n = 0; n < SP_N_nthr; n++)
                        variance_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    variance_blk[c] /= (N * SP);
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();
            }

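            // Normalize: dst = sm * (src - mean) + sv, where
            // sm = scale / sqrt(variance + eps) and sv = shift. With fused
            // norm + ReLU the sign of the result is recorded in ws during
            // training so the backward pass can reapply the mask.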
            for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                acc_data_t sqrt_variance
                        = static_cast<acc_data_t>(sqrtf(variance[off] + eps));
                acc_data_t sm = (use_scale ? (acc_data_t)scale[off]
                                           : (acc_data_t)1.0f)
                        / sqrt_variance;
                acc_data_t sv
                        = use_shift ? (acc_data_t)shift[off] : (acc_data_t)0;
                for (dim_t n = N_s; n < N_e; ++n) {
                    acc_data_t *_dst;
                    const acc_data_t *_src;
                    size_t s_off = off * SP + n * C * SP;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // store dst to f32 buffer
                        _dst = src_cvt_wsp + ithr * SP_cl_align;
                        // convert src from xf16 to f32
                        acc_data_t *tmp_src
                                = src_cvt_wsp + (nthr + ithr) * SP_cl_align;
                        /*TODO: remove this conversion if performance
                        doesn't degrade, since xfloat16_t supports +=
                        operator with implicit conversions from xf16 to
                        float */
                        types::cvt_to_float(
                                tmp_src + S_s, src + s_off + S_s, S_chunk);
                        _src = tmp_src;
                    } else {
                        _dst = reinterpret_cast<acc_data_t *>(dst + s_off);
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (dim_t sp = S_s; sp < S_e; ++sp) {
                        size_t d_off = s_off + sp;
                        acc_data_t bn_res = sm * (_src[sp] - mean[off]) + sv;
                        if (fuse_norm_relu) {
                            if (bn_res <= 0) {
                                bn_res = 0;
                                if (is_training) ws[d_off] = 0;
                            } else {
                                if (is_training) ws[d_off] = 1;
                            }
                        }
                        _dst[sp] = maybe_post_op(bn_res);
                    }
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert dst from f32 to xf16
                        types::cvt_from_float(
                                dst + s_off + S_s, _dst + S_s, S_chunk);
                    }
                }
            }
        }
    });

    return status::success;
}

template struct ncsp_batch_normalization_fwd_t<f32>;
template struct ncsp_batch_normalization_fwd_t<bf16>;
template struct ncsp_batch_normalization_fwd_t<f16>;

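// Backward batch normalization over plain (ncsp) layouts. Per-channel
// reductions
//     diff_scale = sum((src - mean) * diff_dst) / sqrt(variance + eps)
//     diff_shift = sum(diff_dst)
// are computed via per-thread partials in ws_reduce (with diff_dst masked by
// the ReLU workspace when norm + ReLU is fused). Unless global statistics are
// used, diff_src additionally subtracts the mean/variance gradient terms:
//     diff_src = scale / sqrt(variance + eps)
//             * (diff_dst - (diff_shift + (src - mean) * diff_scale
//                       / sqrt(variance + eps)) / (N * SP))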
template <data_type_t d_type>
status_t ncsp_batch_normalization_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {

    const auto use_scale = pd()->use_scale();

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
    auto variance = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
    auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
    auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE);
    auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT);

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto *ws_reduce = scratchpad.template get<acc_data_t>(key_bnorm_reduction);
    acc_data_t *tmp_data_ = scratchpad.template get<acc_data_t>(key_bnorm_cvt);

    const size_t scratch_diff_shift_off = diff_scale ? 0 : pd()->C();
    if (diff_scale == nullptr)
        diff_scale = scratchpad.template get<acc_data_t>(key_bnorm_tmp_diff_ss);

    if (diff_shift == nullptr)
        diff_shift = &scratchpad.template get<acc_data_t>(
                key_bnorm_tmp_diff_ss)[scratch_diff_shift_off];

    const dim_t SP = pd()->D() * pd()->H() * pd()->W();
    const dim_t simd_w = 16; // 16 f32 elements span one 64-byte cache line
    const dim_t SP_cl_align = utils::rnd_up(SP, simd_w);
    const dim_t C = pd()->C(), N = pd()->MB();
    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool calculate_diff_stats = !pd()->use_global_stats();
    const bool fuse_norm_relu = pd()->fuse_norm_relu();

    const int nthr = pd()->nthr_;
    size_t l3_size_ = platform::get_per_core_cache_size(3) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);

    parallel(nthr, [&](const int ithr, const int nthr) {
        int C_ithr = 0, C_nthr = 0;
        int N_ithr = 0, N_nthr = 0;
        int S_ithr = 0, S_nthr = 0;

        dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        dim_t N_s = 0, N_e = 0;
        dim_t S_s = 0, S_e = 0;

        dim_t C_blks_per_iter = 1;
        int64_t iters = 1;

        if (do_blocking) {
            size_t working_set_size = 2 * N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, N, nthr, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                true, false, ithr, nthr, N, C_blks_per_iter, SP, C_ithr, C_nthr,
                C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s,
                S_e);
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;

        for (int64_t it = 0; it < iters; ++it) {
            size_t C_off = it * C_blks_per_iter;
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && dnnl_thr_syncable()) dnnl_thr_barrier();

                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, false, ithr, nthr, N,
                        last_iter_blks, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                C_blks_per_iter = last_iter_blks;
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            const auto S_chunk = nstl::max(dim_t(0), S_e - S_s);
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (dnnl_thr_syncable() ? 0 : 1) * 2 * C_off;

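            // Accumulate per-thread partial reductions: the first
            // SP_N_nthr * C_blks_per_iter entries of this ws_reduce slice hold
            // diff_gamma partials, the following ones hold diff_beta partials.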
            acc_data_t *diff_gamma_blk = diff_scale + C_off;
            acc_data_t *diff_beta_blk = diff_shift + C_off;
            for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                acc_data_t diff_gamma = 0.0, diff_beta = 0.0;
                acc_data_t v_mean = mean[off];
                for (dim_t n = N_s; n < N_e; ++n) {
                    const acc_data_t *_diff_dst;
                    const acc_data_t *_src;
                    dim_t s_off = off * SP + n * C * SP;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert diff_dst to f32
                        acc_data_t *tmp_diff_dst
                                = tmp_data_ + ithr * SP_cl_align;
                        types::cvt_to_float(tmp_diff_dst + S_s,
                                diff_dst + s_off + S_s, S_chunk);
                        _diff_dst = tmp_diff_dst;
                        // convert src to f32
                        acc_data_t *tmp_src
                                = tmp_data_ + (nthr + ithr) * SP_cl_align;
                        types::cvt_to_float(
                                tmp_src + S_s, src + s_off + S_s, S_chunk);
                        _src = tmp_src;
                    } else {
                        _diff_dst = reinterpret_cast<const acc_data_t *>(
                                diff_dst + s_off);
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta))
#endif
                    for (dim_t sp = S_s; sp < S_e; ++sp) {
                        const dim_t d_off = s_off + sp;
                        acc_data_t dd;
                        if (fuse_norm_relu && !ws[d_off])
                            dd = 0;
                        else
                            dd = _diff_dst[sp];
                        diff_gamma += (_src[sp] - v_mean) * dd;
                        diff_beta += dd;
                    }
                }
                ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                        = diff_gamma;
                ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter
                        + SP_N_ithr * C_blks_per_iter + c]
                        = diff_beta;
            }

            if (dnnl_thr_syncable()) dnnl_thr_barrier();

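            // Combine the partials into per-channel diff_gamma / diff_beta;
            // diff_gamma picks up the 1 / sqrt(variance + eps) factor here.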
            for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                acc_data_t sqrt_variance = static_cast<acc_data_t>(
                        1.0f / sqrtf(variance[c + C_off] + eps));
                diff_gamma_blk[c] = 0.;
                diff_beta_blk[c] = 0.;
                for (dim_t n = 0; n < SP_N_nthr; n++) {
                    diff_gamma_blk[c]
                            += ws_reduce[ws_iter_off + n * C_blks_per_iter + c];
                    diff_beta_blk[c] += ws_reduce[ws_iter_off
                            + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter
                            + c];
                }
                diff_gamma_blk[c] *= sqrt_variance;
            }

            if (dnnl_thr_syncable()) dnnl_thr_barrier();

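            // Compute diff_src from diff_dst (masked by the ReLU workspace if
            // fused), subtracting the mean and variance gradient contributions
            // only when the statistics were computed from this batch.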
            for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                acc_data_t gamma = use_scale ? scale[off] : 1;
                acc_data_t sqrt_variance = static_cast<acc_data_t>(
                        1.0f / sqrtf(variance[off] + eps));
                acc_data_t v_mean = mean[off];
                for (dim_t n = N_s; n < N_e; ++n) {
                    acc_data_t *_diff_src;
                    const acc_data_t *_diff_dst;
                    const acc_data_t *_src;
                    dim_t s_off = off * SP + n * C * SP;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // store diff_src to f32 buffer
                        _diff_src = tmp_data_ + ithr * SP_cl_align;
                        acc_data_t *tmp_diff_dst
                                = tmp_data_ + ithr * SP_cl_align;
                        types::cvt_to_float(tmp_diff_dst + S_s,
                                diff_dst + s_off + S_s, S_chunk);
                        _diff_dst = tmp_diff_dst;
                        if (calculate_diff_stats) {
                            // convert src to f32
                            acc_data_t *tmp_src = tmp_data_
                                    + (2 * nthr + ithr) * SP_cl_align;
                            types::cvt_to_float(
                                    tmp_src + S_s, src + s_off + S_s, S_chunk);
                            _src = tmp_src;
                        } else
                            _src = nullptr; // to avoid compiler warning w/
                                            // gcc483
                    } else {
                        _diff_src = reinterpret_cast<acc_data_t *>(
                                diff_src + s_off);
                        _diff_dst = reinterpret_cast<const acc_data_t *>(
                                diff_dst + s_off);
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (dim_t sp = S_s; sp < S_e; ++sp) {
                        const dim_t d_off = s_off + sp;
                        acc_data_t v_diff_src;
                        if (fuse_norm_relu && !ws[d_off])
                            v_diff_src = 0;
                        else
                            v_diff_src = _diff_dst[sp];
                        if (calculate_diff_stats) {
                            v_diff_src -= diff_beta_blk[c] / (SP * N)
                                    + (_src[sp] - v_mean) * diff_gamma_blk[c]
                                            * sqrt_variance / (SP * N);
                        }
                        v_diff_src *= gamma * sqrt_variance;
                        _diff_src[sp] = v_diff_src;
                    }
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert diff_src from f32 to xf16
                        types::cvt_from_float(diff_src + s_off + S_s,
                                _diff_src + S_s, S_chunk);
                    }
                }
            }
        }
    });
    return status::success;
}

template struct ncsp_batch_normalization_bwd_t<f32>;
template struct ncsp_batch_normalization_bwd_t<bf16>;
template struct ncsp_batch_normalization_bwd_t<f16>;
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s