1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/compiler_workarounds.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/cpu_batch_normalization_utils.hpp" |
26 | #include "cpu/platform.hpp" |
27 | |
28 | #include "cpu/ncsp_batch_normalization.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | using namespace memory_tracking::names; |
35 | using namespace data_type; |
36 | |
/* Forward NCSP batch normalization.
 *
 * Computes (or loads precomputed) per-channel mean and variance, then
 * writes dst = sm * (src - mean) + sv with sm = scale / sqrt(var + eps)
 * and sv = shift, optionally applying a fused ReLU (with workspace
 * bookkeeping for training) or a ReLU post-op. Work is distributed over
 * channels, minibatch, and spatial dims; for large tensors the channel
 * dimension is additionally cache-blocked into `iters` chunks.
 *
 * For bf16/f16 sources, data is converted to f32 through per-thread
 * scratch buffers (src_cvt_wsp) before accumulation/normalization. */
template <data_type_t d_type>
status_t ncsp_batch_normalization_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    const bool calculate_stats = !pd()->stats_is_src();
    const bool save_stats = pd()->is_training();
    const bool is_training = pd()->is_training();
    const bool fuse_norm_relu = pd()->fuse_norm_relu();

    const bool use_scale = pd()->use_scale();
    const bool use_shift = pd()->use_shift();

    const dim_t C = pd()->C();

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT);

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto *ws_reduce = scratchpad.template get<acc_data_t>(key_bnorm_reduction);

    // Mean/variance storage depends on the flavor:
    // - stats provided by user: read-only inputs (const_cast is safe
    //   here because they are only written under `calculate_stats`);
    // - computed and saved (training): user output buffers;
    // - computed but not saved: temporary scratchpad storage.
    acc_data_t *mean, *variance;
    if (!calculate_stats) {
        mean = const_cast<acc_data_t *>(
                CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN));
        variance = const_cast<acc_data_t *>(
                CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE));
    } else {
        if (save_stats) {
            mean = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
            variance = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
        } else {
            mean = scratchpad.template get<acc_data_t>(key_bnorm_tmp_mean);
            variance = scratchpad.template get<acc_data_t>(key_bnorm_tmp_var);
        }
    }

    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
    // f32 conversion workspace for xf16 data types (per-thread slices).
    acc_data_t *src_cvt_wsp
            = scratchpad.template get<acc_data_t>(key_bnorm_cvt);

    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool with_relu = pd()->with_relu_post_op(is_training);
    auto maybe_post_op = [&](acc_data_t res) {
        if (with_relu) return math::relu_fwd(res, pd()->alpha());
        return res;
    };

    const dim_t SP = pd()->H() * pd()->W() * pd()->D();
    // Spatial size rounded up to simd_w elements; used as the stride of
    // the per-thread slices in src_cvt_wsp.
    const dim_t simd_w = 16;
    const dim_t SP_cl_align = utils::rnd_up(SP, simd_w);
    const dim_t N = pd()->MB();

    // Enable cache blocking over C when the problem clearly does not fit
    // in (half of) the estimated aggregate L3 capacity.
    const int nthr = pd()->nthr_;
    size_t l3_size_ = platform::get_per_core_cache_size(3) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);

    parallel(nthr, [&](const int ithr, const int nthr) {
        // Per-thread partitioning over C (channels), N (minibatch) and
        // S (spatial); filled in by bnorm_utils::thread_balance below.
        int C_ithr = 0, C_nthr = 0;
        int N_ithr = 0, N_nthr = 0;
        int S_ithr = 0, S_nthr = 0;

        dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        dim_t N_s = 0, N_e = 0;
        dim_t S_s = 0, S_e = 0;

        dim_t C_blks_per_iter = 1;
        int64_t iters = 1;

        if (do_blocking) {
            size_t working_set_size = N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, N, nthr, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        // The last C-chunk may be smaller than the others.
        int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                true, false, ithr, nthr, N, C_blks_per_iter, SP, C_ithr, C_nthr,
                C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s,
                S_e);
        // Flat split over C used by the cross-thread reduction phases.
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        // Linearized thread id / count over the (N, S) partitioning;
        // indexes this thread's row in ws_reduce.
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        for (int64_t it = 0; it < iters; ++it) {
            size_t C_off = it * C_blks_per_iter;
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && dnnl_thr_syncable()) dnnl_thr_barrier();

                // Re-balance for the (smaller) trailing C-chunk.
                S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, false, ithr, nthr, N,
                        last_iter_blks, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
                C_blks_per_iter = last_iter_blks;
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            const auto S_chunk = nstl::max(dim_t(0), S_e - S_s);
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (dnnl_thr_syncable() ? 0 : 1) * C_off;

            if (calculate_stats) {
                acc_data_t *mean_blk = mean + C_off;
                acc_data_t *variance_blk = variance + C_off;
                // Phase 1: each thread accumulates partial per-channel
                // sums of src over its (N, S) slice into ws_reduce.
                for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = (c + C_off) * SP;
                    acc_data_t sum = 0;
                    for (dim_t n = N_s; n < N_e; ++n) {
                        const acc_data_t *scr_fp32;
                        size_t soff = off + n * C * SP;
                        if (utils::one_of(d_type, bf16, f16)) {
                            acc_data_t *tmp_src
                                    = src_cvt_wsp + ithr * SP_cl_align;
                            /*TODO: remove this conversion if performance
                            doesn't degrade, since xfloat16_t supports +=
                            operator with implicit conversions from xf16 to
                            float */
                            types::cvt_to_float(
                                    tmp_src + S_s, src + soff + S_s, S_chunk);
                            scr_fp32 = tmp_src;
                        } else {
                            scr_fp32 = reinterpret_cast<const acc_data_t *>(
                                    src + soff);
                        }
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (dim_t sp = S_s; sp < S_e; ++sp) {
                            sum += scr_fp32[sp];
                        }
                    }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();

                // Phase 2: reduce the partial sums across threads and
                // finalize the mean (each thread owns its flat C range).
                for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    mean_blk[c] = 0.;
                    for (dim_t n = 0; n < SP_N_nthr; n++)
                        mean_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    mean_blk[c] /= (N * SP);
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();

                // Phase 3: partial per-channel sums of squared
                // deviations (src - mean)^2, again into ws_reduce.
                for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = c + C_off;
                    acc_data_t sum = 0.;
                    for (dim_t n = N_s; n < N_e; ++n) {
                        const acc_data_t *_src;
                        size_t soff = off * SP + n * C * SP;
                        if (utils::one_of(d_type, bf16, f16)) {
                            acc_data_t *tmp_src
                                    = src_cvt_wsp + ithr * SP_cl_align;
                            /*TODO: remove this conversion if performance
                            doesn't degrade, since xfloat16_t supports +=
                            operator with implicit conversions from xf16 to
                            float */
                            types::cvt_to_float(
                                    tmp_src + S_s, src + soff + S_s, S_chunk);
                            _src = tmp_src;
                        } else {
                            _src = reinterpret_cast<const acc_data_t *>(
                                    src + soff);
                        }
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (dim_t sp = S_s; sp < S_e; ++sp) {
                            acc_data_t m = _src[sp] - mean[off];
                            sum += m * m;
                        }
                    }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();

                // Phase 4: reduce across threads and finalize variance.
                for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    variance_blk[c] = 0.;
                    for (dim_t n = 0; n < SP_N_nthr; n++)
                        variance_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    variance_blk[c] /= (N * SP);
                }

                if (dnnl_thr_syncable()) dnnl_thr_barrier();
            }

            // Normalization: dst = sm * (src - mean) + sv, with optional
            // fused ReLU (recording 0/1 activation mask in ws during
            // training) and/or ReLU post-op.
            for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                acc_data_t sqrt_variance
                        = static_cast<acc_data_t>(sqrtf(variance[off] + eps));
                acc_data_t sm = (use_scale ? (acc_data_t)scale[off]
                                           : (acc_data_t)1.0f)
                        / sqrt_variance;
                acc_data_t sv
                        = use_shift ? (acc_data_t)shift[off] : (acc_data_t)0;
                for (dim_t n = N_s; n < N_e; ++n) {
                    acc_data_t *_dst;
                    const acc_data_t *_src;
                    size_t s_off = off * SP + n * C * SP;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // store dst to f32 buffer
                        _dst = src_cvt_wsp + ithr * SP_cl_align;
                        // convert src from xf16 to f32 (separate slice so
                        // the converted src is not clobbered by _dst)
                        acc_data_t *tmp_src
                                = src_cvt_wsp + (nthr + ithr) * SP_cl_align;
                        /*TODO: remove this conversion if performance
                        doesn't degrade, since xfloat16_t supports +=
                        operator with implicit conversions from xf16 to
                        float */
                        types::cvt_to_float(
                                tmp_src + S_s, src + s_off + S_s, S_chunk);
                        _src = tmp_src;
                    } else {
                        _dst = reinterpret_cast<acc_data_t *>(dst + s_off);
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (dim_t sp = S_s; sp < S_e; ++sp) {
                        size_t d_off = s_off + sp;
                        acc_data_t bn_res = sm * (_src[sp] - mean[off]) + sv;
                        if (fuse_norm_relu) {
                            if (bn_res <= 0) {
                                bn_res = 0;
                                if (is_training) ws[d_off] = 0;
                            } else {
                                if (is_training) ws[d_off] = 1;
                            }
                        }
                        _dst[sp] = maybe_post_op(bn_res);
                    }
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert dst from f32 to xf16
                        types::cvt_from_float(
                                dst + s_off + S_s, _dst + S_s, S_chunk);
                    }
                }
            }
        }
    });

    return status::success;
}
293 | |
// Explicit instantiations for the supported forward source data types.
template struct ncsp_batch_normalization_fwd_t<f32>;
template struct ncsp_batch_normalization_fwd_t<bf16>;
template struct ncsp_batch_normalization_fwd_t<f16>;
297 | |
/* Backward NCSP batch normalization.
 *
 * First reduces diff_gamma = sum((src - mean) * diff_dst) and
 * diff_beta = sum(diff_dst) per channel across threads via ws_reduce,
 * then computes diff_src, including the mean/variance gradient terms
 * unless global stats are used. A fused ReLU is handled by zeroing
 * diff_dst elements whose forward activation mask (ws) is 0.
 *
 * Threading/blocking mirrors execute_forward: work is partitioned over
 * C, N and spatial dims, with optional cache blocking over C. For
 * bf16/f16, tensors are converted to f32 via per-thread slices of
 * tmp_data_. */
template <data_type_t d_type>
status_t ncsp_batch_normalization_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {

    const auto use_scale = pd()->use_scale();

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
    auto variance = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
    auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
    auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE);
    auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT);

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto *ws_reduce = scratchpad.template get<acc_data_t>(key_bnorm_reduction);
    acc_data_t *tmp_data_ = scratchpad.template get<acc_data_t>(key_bnorm_cvt);

    // diff_scale/diff_shift are always computed (they feed into
    // diff_src); when the user does not request them they are redirected
    // into the key_bnorm_tmp_diff_ss scratchpad. If diff_scale also
    // lives there it occupies [0, C), so diff_shift is placed at C.
    const size_t scratch_diff_shift_off = diff_scale ? 0 : pd()->C();
    if (diff_scale == nullptr)
        diff_scale = scratchpad.template get<acc_data_t>(key_bnorm_tmp_diff_ss);

    if (diff_shift == nullptr)
        diff_shift = &scratchpad.template get<acc_data_t>(
                key_bnorm_tmp_diff_ss)[scratch_diff_shift_off];

    const dim_t SP = pd()->D() * pd()->H() * pd()->W();
    // Alignment unit for the per-thread f32 conversion slices.
    const dim_t simd_w = 16;
    const dim_t SP_cl_align = utils::rnd_up(SP, simd_w);
    const dim_t C = pd()->C(), N = pd()->MB();
    const float eps = pd()->desc()->batch_norm_epsilon;
    // With global stats, mean/variance are constants w.r.t. src, so the
    // corresponding gradient terms are skipped.
    const bool calculate_diff_stats = !pd()->use_global_stats();
    const bool fuse_norm_relu = pd()->fuse_norm_relu();

    // Enable cache blocking over C when the problem clearly does not fit
    // in (half of) the estimated aggregate L3 capacity.
    const int nthr = pd()->nthr_;
    size_t l3_size_ = platform::get_per_core_cache_size(3) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);

    parallel(nthr, [&](const int ithr, const int nthr) {
        // Per-thread partitioning over C, N and spatial dims; filled in
        // by bnorm_utils::thread_balance below.
        int C_ithr = 0, C_nthr = 0;
        int N_ithr = 0, N_nthr = 0;
        int S_ithr = 0, S_nthr = 0;

        dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        dim_t N_s = 0, N_e = 0;
        dim_t S_s = 0, S_e = 0;

        dim_t C_blks_per_iter = 1;
        int64_t iters = 1;

        if (do_blocking) {
            // Factor 2: both src and diff_dst are streamed per channel.
            size_t working_set_size = 2 * N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, N, nthr, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        // The last C-chunk may be smaller than the others.
        int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                true, false, ithr, nthr, N, C_blks_per_iter, SP, C_ithr, C_nthr,
                C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s,
                S_e);
        // Flat split over C used by the cross-thread reduction phase.
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        // Linearized thread id / count over the (N, S) partitioning;
        // indexes this thread's row in ws_reduce.
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;

        for (int64_t it = 0; it < iters; ++it) {
            size_t C_off = it * C_blks_per_iter;
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && dnnl_thr_syncable()) dnnl_thr_barrier();

                // Re-balance for the (smaller) trailing C-chunk.
                // NOTE(review): unlike the forward pass, S_s/S_e are not
                // reset here before thread_balance — presumably they are
                // always overwritten inside; confirm against
                // bnorm_utils::thread_balance.
                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, false, ithr, nthr, N,
                        last_iter_blks, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                C_blks_per_iter = last_iter_blks;
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            const auto S_chunk = nstl::max(dim_t(0), S_e - S_s);
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            // Factor 2: ws_reduce holds two sections per iteration
            // (diff_gamma partials followed by diff_beta partials).
            size_t ws_iter_off = (dnnl_thr_syncable() ? 0 : 1) * 2 * C_off;

            acc_data_t *diff_gamma_blk = diff_scale + C_off;
            acc_data_t *diff_beta_blk = diff_shift + C_off;
            // Phase 1: per-thread partial diff_gamma/diff_beta sums over
            // this thread's (N, S) slice, stored into ws_reduce.
            for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                acc_data_t diff_gamma = 0.0, diff_beta = 0.0;
                acc_data_t v_mean = mean[off];
                for (dim_t n = N_s; n < N_e; ++n) {
                    const acc_data_t *_diff_dst;
                    const acc_data_t *_src;
                    dim_t s_off = off * SP + n * C * SP;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert diff_dst to f32
                        acc_data_t *tmp_diff_dst
                                = tmp_data_ + ithr * SP_cl_align;
                        types::cvt_to_float(tmp_diff_dst + S_s,
                                diff_dst + s_off + S_s, S_chunk);
                        _diff_dst = tmp_diff_dst;
                        // convert src to f32
                        acc_data_t *tmp_src
                                = tmp_data_ + (nthr + ithr) * SP_cl_align;
                        types::cvt_to_float(
                                tmp_src + S_s, src + s_off + S_s, S_chunk);
                        _src = tmp_src;
                    } else {
                        _diff_dst = reinterpret_cast<const acc_data_t *>(
                                diff_dst + s_off);
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta))
#endif
                    for (dim_t sp = S_s; sp < S_e; ++sp) {
                        const dim_t d_off = s_off + sp;
                        acc_data_t dd;
                        // Fused ReLU backward: gradient is zero where the
                        // forward activation was clamped (mask == 0).
                        if (fuse_norm_relu && !ws[d_off])
                            dd = 0;
                        else
                            dd = _diff_dst[sp];
                        diff_gamma += (_src[sp] - v_mean) * dd;
                        diff_beta += dd;
                    }
                }
                // diff_gamma partials in the first section of ws_reduce,
                // diff_beta partials in the second.
                ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                        = diff_gamma;
                ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter
                        + SP_N_ithr * C_blks_per_iter + c]
                        = diff_beta;
            }

            if (dnnl_thr_syncable()) dnnl_thr_barrier();

            // Phase 2: reduce the partials across threads; each thread
            // finalizes its flat C range. diff_gamma absorbs the
            // 1/sqrt(var + eps) factor here.
            for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                acc_data_t sqrt_variance = static_cast<acc_data_t>(
                        1.0f / sqrtf(variance[c + C_off] + eps));
                diff_gamma_blk[c] = 0.;
                diff_beta_blk[c] = 0.;
                for (dim_t n = 0; n < SP_N_nthr; n++) {
                    diff_gamma_blk[c]
                            += ws_reduce[ws_iter_off + n * C_blks_per_iter + c];
                    diff_beta_blk[c] += ws_reduce[ws_iter_off
                            + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter
                            + c];
                }
                diff_gamma_blk[c] *= sqrt_variance;
            }

            if (dnnl_thr_syncable()) dnnl_thr_barrier();

            // Phase 3: compute diff_src.
            for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                acc_data_t gamma = use_scale ? scale[off] : 1;
                acc_data_t sqrt_variance = static_cast<acc_data_t>(
                        1.0f / sqrtf(variance[off] + eps));
                acc_data_t v_mean = mean[off];
                for (dim_t n = N_s; n < N_e; ++n) {
                    acc_data_t *_diff_src;
                    const acc_data_t *_diff_dst;
                    const acc_data_t *_src;
                    dim_t s_off = off * SP + n * C * SP;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // store diff_src to f32 buffer
                        // NOTE: _diff_src deliberately aliases
                        // tmp_diff_dst — each element is read as
                        // _diff_dst[sp] before being overwritten as
                        // _diff_src[sp] in the loop below.
                        _diff_src = tmp_data_ + ithr * SP_cl_align;
                        acc_data_t *tmp_diff_dst
                                = tmp_data_ + ithr * SP_cl_align;
                        types::cvt_to_float(tmp_diff_dst + S_s,
                                diff_dst + s_off + S_s, S_chunk);
                        _diff_dst = tmp_diff_dst;
                        if (calculate_diff_stats) {
                            // convert src to f32
                            acc_data_t *tmp_src = tmp_data_
                                    + (2 * nthr + ithr) * SP_cl_align;
                            types::cvt_to_float(
                                    tmp_src + S_s, src + s_off + S_s, S_chunk);
                            _src = tmp_src;
                        } else
                            _src = nullptr; // to avoid compiler warning w/
                                    // gcc483
                    } else {
                        _diff_src = reinterpret_cast<acc_data_t *>(
                                diff_src + s_off);
                        _diff_dst = reinterpret_cast<const acc_data_t *>(
                                diff_dst + s_off);
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (dim_t sp = S_s; sp < S_e; ++sp) {
                        const dim_t d_off = s_off + sp;
                        acc_data_t v_diff_src;
                        if (fuse_norm_relu && !ws[d_off])
                            v_diff_src = 0;
                        else
                            v_diff_src = _diff_dst[sp];
                        // Subtract the mean and variance gradient
                        // contributions (not needed with global stats).
                        if (calculate_diff_stats) {
                            v_diff_src -= diff_beta_blk[c] / (SP * N)
                                    + (_src[sp] - v_mean) * diff_gamma_blk[c]
                                            * sqrt_variance / (SP * N);
                        }
                        v_diff_src *= gamma * sqrt_variance;
                        _diff_src[sp] = v_diff_src;
                    }
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert diff_src from f32
                        types::cvt_from_float(diff_src + s_off + S_s,
                                _diff_src + S_s, S_chunk);
                    }
                }
            }
        }
    });
    return status::success;
}
527 | |
// Explicit instantiations for the supported backward source data types.
template struct ncsp_batch_normalization_bwd_t<f32>;
template struct ncsp_batch_normalization_bwd_t<bf16>;
template struct ncsp_batch_normalization_bwd_t<f16>;
531 | } // namespace cpu |
532 | } // namespace impl |
533 | } // namespace dnnl |
534 | |
535 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
536 | |