/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include <algorithm>

#include "common/c_types_map.hpp"
#include "common/compiler_workarounds.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/platform.hpp"

#include "cpu/cpu_batch_normalization_utils.hpp"

#include "cpu/nspc_batch_normalization.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;
using namespace data_type;

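// Forward batch normalization for the NSPC (channels-last) layout. In this
// layout all C channel values of a point (n, sp) are contiguous, so the
// offset of a point is s_off = (n * SP + sp) * C and every inner loop below
// runs over channels. Per channel c the kernel computes
//
//     dst = gamma[c] * (src - mean[c]) / sqrt(variance[c] + eps) + beta[c]
//
// where gamma/beta default to 1/0 when scale/shift are not used, and mean and
// variance are either provided by the user or computed below. For bf16/f16
// data every row of C values is converted into an f32 scratch buffer
// (key_bnorm_cvt) before use and converted back on store.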
template <data_type_t d_type>
status_t nspc_batch_normalization_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    const bool save_stats = pd()->is_training();
    const bool is_training = pd()->is_training();
    const bool fuse_norm_relu = pd()->fuse_norm_relu();
    const bool calculate_stats = !pd()->stats_is_src();
    const bool with_relu = pd()->with_relu_post_op(is_training);

    const auto use_scale = pd()->use_scale();
    const auto use_shift = pd()->use_shift();

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto tmp_mean = scratchpad.template get<acc_data_t>(key_bnorm_tmp_mean);
    auto tmp_var = scratchpad.template get<acc_data_t>(key_bnorm_tmp_var);
    auto *ws_reduce = scratchpad.template get<acc_data_t>(key_bnorm_reduction);

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT);

    acc_data_t *mean, *variance;
    if (!calculate_stats) {
        mean = const_cast<acc_data_t *>(
                CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN));
        variance = const_cast<acc_data_t *>(
                CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE));
    } else {
        if (save_stats) {
            mean = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
            variance = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
        } else {
            mean = tmp_mean;
            variance = tmp_var;
        }
    }

    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
    acc_data_t *tmp_data_ = scratchpad.template get<acc_data_t>(key_bnorm_cvt);

    const dim_t N = pd()->MB();
    const dim_t C = pd()->C();
    const int simd_w = 16;
    const dim_t C_align = utils::rnd_up(C, simd_w);
    const dim_t SP = pd()->H() * pd()->W() * pd()->D();

    const float eps = pd()->desc()->batch_norm_epsilon;
    auto maybe_post_op = [&](acc_data_t res) {
        if (with_relu) return math::relu_fwd(res, pd()->alpha());
        return res;
    };
    const int nthr = pd()->nthr_;

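    // When the statistics are not provided by the user, compute them here in
    // two reduction passes: the first accumulates per-thread, per-channel
    // sums of src (the mean), the second accumulates squared deviations from
    // that mean (the variance).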
    if (calculate_stats) {
        parallel(nthr, [&](const int ithr, const int nthr) {
            dim_t N_s = 0, N_e = 0;
            balance211(N, nthr, ithr, N_s, N_e);

            for (dim_t c = 0; c < C; c++)
                ws_reduce[C * ithr + c] = 0.;

            for (dim_t n = N_s; n < N_e; n++) {
                for (dim_t sp = 0; sp < SP; sp++) {
                    const acc_data_t *_src;
                    const size_t s_off = (size_t)n * SP * C + sp * C;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert src from xf16 to f32
                        acc_data_t *tmp_src = tmp_data_ + ithr * C_align;
                        types::cvt_to_float(tmp_src, src + s_off, C);
                        _src = tmp_src;
                    } else {
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
                    PRAGMA_OMP_SIMD()
                    for (dim_t c = 0; c < C; c++) {
                        ws_reduce[C * ithr + c] += _src[c];
                    }
                }
            }
        });
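        // Reduce the per-thread partial sums across threads and divide by the
        // number of points to get the per-channel mean.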
        parallel_nd(C, [&](dim_t c) {
            mean[c] = 0;
            for (dim_t n = 0; n < nthr; n++)
                mean[c] += ws_reduce[C * n + c];
            mean[c] /= SP * N;
        });
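        // Second pass: each thread accumulates per-channel sums of squared
        // deviations from a private copy of the mean (mean_loc). When stats
        // are not saved to user memory, mean already lives in tmp_mean, so
        // thread 0 can skip the copy.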
        parallel(nthr, [&](const int ithr, const int nthr) {
            dim_t N_s = 0, N_e = 0;
            balance211(N, nthr, ithr, N_s, N_e);

            acc_data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;

            if (ithr > 0 || save_stats) {
                for (dim_t c = 0; c < C; c++)
                    mean_loc[c] = mean[c];
            }

            for (dim_t c = 0; c < C; c++)
                ws_reduce[C * ithr + c] = 0.;

            for (dim_t n = N_s; n < N_e; n++) {
                for (dim_t sp = 0; sp < SP; sp++) {
                    const acc_data_t *_src;
                    const size_t s_off = (size_t)n * SP * C + sp * C;
                    if (utils::one_of(d_type, bf16, f16)) {
                        // convert src from xf16 to f32
                        acc_data_t *tmp_src = tmp_data_ + ithr * C_align;
                        types::cvt_to_float(tmp_src, src + s_off, C);
                        _src = tmp_src;
                    } else {
                        _src = reinterpret_cast<const acc_data_t *>(
                                src + s_off);
                    }
                    PRAGMA_OMP_SIMD()
                    for (dim_t c = 0; c < C; c++) {
                        acc_data_t m = _src[c] - mean_loc[c];
                        ws_reduce[C * ithr + c] += m * m;
                    }
                }
            }
        });
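        // Reduce the partial sums of squared deviations to get the
        // per-channel variance.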
        parallel_nd(C, [&](dim_t c) {
            variance[c] = 0;
            for (dim_t n = 0; n < nthr; n++)
                variance[c] += ws_reduce[C * n + c];
            variance[c] /= SP * N;
        });
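        // Give each thread a private copy of the variance for the
        // normalization pass below; thread 0 reads directly from tmp_var when
        // stats are not saved to user memory.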
        parallel(nthr, [&](const int ithr, const int nthr) {
            acc_data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
            if (ithr > 0 || save_stats) {
                for (dim_t c = 0; c < C; c++)
                    variance_loc[c] = variance[c];
            }
        });
    }

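    // Normalization pass: for every (n, sp) point compute
    //     dst[c] = sm[c] * (src[c] - mean[c]) + sv[c]
    // where sm folds the scale and 1 / sqrt(variance + eps) and sv is the
    // shift, then apply the fused ReLU (recording the mask in ws during
    // training) or the ReLU post-op if requested.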
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t N_s = 0, N_e = 0;
        balance211(N, nthr, ithr, N_s, N_e);

        acc_data_t *mean_loc, *variance_loc;
        if (calculate_stats) {
            mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
            variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
        } else {
            mean_loc = mean;
            variance_loc = variance;
        }

        for (dim_t n = N_s; n < N_e; n++) {
            for (dim_t sp = 0; sp < SP; sp++) {
                acc_data_t *_dst;
                const acc_data_t *_src;
                const size_t s_off = (size_t)n * SP * C + sp * C;
                if (utils::one_of(d_type, bf16, f16)) {
                    // store dst to f32 buffer
                    _dst = tmp_data_ + ithr * C_align;
                    // convert src from xf16 to f32
                    acc_data_t *tmp_src = tmp_data_ + (nthr + ithr) * C_align;
                    types::cvt_to_float(tmp_src, src + s_off, C);
                    _src = tmp_src;
                } else {
                    _dst = reinterpret_cast<acc_data_t *>(dst + s_off);
                    _src = reinterpret_cast<const acc_data_t *>(src + s_off);
                }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (dim_t c = 0; c < C; c++) {
                    const size_t c_off = s_off + c;
                    acc_data_t sqrt_variance = static_cast<acc_data_t>(
                            sqrtf(variance_loc[c] + eps));
                    acc_data_t sm = (use_scale ? (acc_data_t)scale[c]
                                               : (acc_data_t)1.0f)
                            / sqrt_variance;
                    acc_data_t sv
                            = use_shift ? (acc_data_t)shift[c] : (acc_data_t)0;
                    acc_data_t bn_res = sm * (_src[c] - mean_loc[c]) + sv;
                    if (fuse_norm_relu) {
                        if (bn_res <= 0) {
                            bn_res = 0;
                            if (is_training) ws[c_off] = 0;
                        } else {
                            if (is_training) ws[c_off] = 1;
                        }
                    }
                    _dst[c] = maybe_post_op(bn_res);
                }
                if (utils::one_of(d_type, bf16, f16)) {
                    // convert dst from f32 to xf16
                    types::cvt_from_float(dst + s_off, _dst, C);
                }
            }
        }
    });
    return status::success;
}

template struct nspc_batch_normalization_fwd_t<f32>;
template struct nspc_batch_normalization_fwd_t<bf16>;
template struct nspc_batch_normalization_fwd_t<f16>;

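// Backward batch normalization for the NSPC layout. With
//     rsqrt_v[c] = 1 / sqrt(variance[c] + eps)
// the kernel computes, reducing over all N * SP points,
//     diff_gamma[c] = sum((src - mean[c]) * diff_dst) * rsqrt_v[c]
//     diff_beta[c]  = sum(diff_dst)
//     diff_src = gamma[c] * rsqrt_v[c] * (diff_dst - (diff_beta[c]
//             + (src - mean[c]) * diff_gamma[c] * rsqrt_v[c]) / (N * SP))
// The diff-stats correction term is skipped when global statistics are used,
// and diff_dst is masked by the ReLU workspace when batchnorm is fused with
// ReLU.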
template <data_type_t d_type>
status_t nspc_batch_normalization_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {

    const auto use_scale = pd()->use_scale();

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
    auto variance = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
    auto scale = CTX_IN_MEM(acc_data_t *, DNNL_ARG_SCALE);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
    auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE);
    auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT);

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto tmp_diff_ss
            = scratchpad.template get<acc_data_t>(key_bnorm_tmp_diff_ss);

    const size_t scratch_diff_shift_off = diff_scale ? 0 : pd()->C();
    if (diff_scale == nullptr) diff_scale = tmp_diff_ss;
    if (diff_shift == nullptr)
        diff_shift = &tmp_diff_ss[scratch_diff_shift_off];

    const dim_t N = pd()->MB();
    const dim_t C = pd()->C();
    const int simd_w = 16;
    const dim_t C_align = utils::rnd_up(C, simd_w);
    const dim_t SP = pd()->D() * pd()->H() * pd()->W();
    acc_data_t *diff_gamma = diff_scale, *diff_beta = diff_shift;
    acc_data_t *ws_reduce
            = scratchpad.template get<acc_data_t>(key_bnorm_reduction);
    acc_data_t *tmp_data_ = scratchpad.template get<acc_data_t>(key_bnorm_cvt);

    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool calculate_diff_stats = !pd()->use_global_stats();
    const bool fuse_norm_relu = pd()->fuse_norm_relu();

    /* Note: a potential seg-fault from an incorrectly compiled vectorized
     * loop. Explicit tail processing works around this issue. */
    const dim_t c_blk = std::max(
            platform::get_vector_register_size() / (int)sizeof(float), 8);
    const dim_t tail = C % c_blk;
    const dim_t nb_c_blk = (size_t)C / c_blk;
    const int nthr = pd()->nthr_;

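    // First pass: accumulate per-thread, per-channel partial sums for
    // diff_gamma ((src - mean) * diff_dst) and diff_beta (diff_dst) into
    // ws_reduce. With fused ReLU the workspace masks out points where the
    // forward output was clamped to zero.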
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t N_s = 0, N_e = 0;
        balance211(N, nthr, ithr, N_s, N_e);

        for (dim_t c = 0; c < C; c++) {
            ws_reduce[C * ithr + c] = 0.;
            ws_reduce[C * nthr + C * ithr + c] = 0.;
        }

        for (dim_t n = N_s; n < N_e; n++) {
            for (dim_t sp = 0; sp < SP; sp++) {
                const acc_data_t *_diff_dst;
                const acc_data_t *_src;
                const size_t s_off = (size_t)n * SP * C + sp * C;
                if (utils::one_of(d_type, bf16, f16)) {
                    // convert diff_dst to f32
                    acc_data_t *tmp_diff_dst = tmp_data_ + ithr * C_align;
                    types::cvt_to_float(tmp_diff_dst, diff_dst + s_off, C);
                    _diff_dst = tmp_diff_dst;
                    // convert src to f32
                    acc_data_t *tmp_src = tmp_data_ + (nthr + ithr) * C_align;
                    types::cvt_to_float(tmp_src, src + s_off, C);
                    _src = tmp_src;
                } else {
                    _diff_dst = reinterpret_cast<const acc_data_t *>(
                            diff_dst + s_off);
                    _src = reinterpret_cast<const acc_data_t *>(src + s_off);
                }
#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (dim_t c = 0; c < C; c++) {
                    const size_t c_off = s_off + c;
                    acc_data_t dd;
                    if (fuse_norm_relu && !ws[c_off])
                        dd = 0;
                    else
                        dd = _diff_dst[c];
                    ws_reduce[C * ithr + c] += (_src[c] - mean[c]) * dd;
                    ws_reduce[C * nthr + C * ithr + c] += dd;
                }
            }
        }
    });

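    // Reduce the partial sums across threads; diff_gamma is additionally
    // scaled by 1 / sqrt(variance + eps).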
    parallel_nd(C, [&](dim_t c) {
        acc_data_t sqrt_variance
                = static_cast<acc_data_t>(1.0f / sqrtf(variance[c] + eps));
        diff_gamma[c] = 0;
        diff_beta[c] = 0;
        for (dim_t n = 0; n < nthr; n++) {
            diff_gamma[c] += ws_reduce[C * n + c];
            diff_beta[c] += ws_reduce[C * nthr + C * n + c];
        }
        diff_gamma[c] *= sqrt_variance;
    });

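    // Final pass: compute diff_src from per-thread copies of diff_gamma and
    // diff_beta. The vectorized loop handles full c_blk-sized blocks of
    // channels and a scalar loop handles the tail explicitly (see the note
    // above).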
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t N_s = 0, N_e = 0;
        balance211(N, nthr, ithr, N_s, N_e);

        acc_data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr;
        acc_data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr);

        for (dim_t c = 0; c < C; c++) {
            diff_gamma_loc[c] = diff_gamma[c];
            diff_beta_loc[c] = diff_beta[c];
        }

        for (dim_t n = N_s; n < N_e; n++) {
            for (dim_t sp = 0; sp < SP; sp++) {
                acc_data_t *_diff_src;
                const acc_data_t *_diff_dst;
                const acc_data_t *_src;
                const size_t s_off = (size_t)n * SP * C + sp * C;
                if (utils::one_of(d_type, bf16, f16)) {
                    // store diff_src to f32 buffer
                    _diff_src = tmp_data_ + ithr * C_align;
                    // convert diff_dst to f32
                    acc_data_t *tmp_diff_dst = tmp_data_ + ithr * C_align;
                    types::cvt_to_float(tmp_diff_dst, diff_dst + s_off, C);
                    _diff_dst = tmp_diff_dst;
                    if (calculate_diff_stats) {
                        // convert src to f32
                        acc_data_t *tmp_src
                                = tmp_data_ + (2 * nthr + ithr) * C_align;
                        types::cvt_to_float(tmp_src, src + s_off, C);
                        _src = tmp_src;
                    } else
                        _src = nullptr; // avoid compiler warning w/ gcc 4.8.3
                } else {
                    _diff_src
                            = reinterpret_cast<acc_data_t *>(diff_src + s_off);
                    _diff_dst = reinterpret_cast<const acc_data_t *>(
                            diff_dst + s_off);
                    _src = reinterpret_cast<const acc_data_t *>(src + s_off);
                }

#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD(simdlen(16))
#endif
                for (dim_t c = 0; c < nb_c_blk * c_blk; c++) {
                    const size_t c_off = s_off + c;
                    acc_data_t gamma = use_scale ? scale[c] : 1;
                    acc_data_t sqrt_variance = static_cast<acc_data_t>(
                            1.0f / sqrtf(variance[c] + eps));
                    acc_data_t v_diff_src;
                    if (fuse_norm_relu && !ws[c_off])
                        v_diff_src = 0;
                    else
                        v_diff_src = _diff_dst[c];
                    if (calculate_diff_stats) {
                        v_diff_src -= diff_beta_loc[c] / (SP * N)
                                + (_src[c] - mean[c]) * diff_gamma_loc[c]
                                        * sqrt_variance / (SP * N);
                    }
                    v_diff_src *= gamma * sqrt_variance;
                    _diff_src[c] = v_diff_src;
                }
                for (dim_t c = 0; c < tail; c++) {
                    const size_t c_off = s_off + nb_c_blk * c_blk + c;
                    acc_data_t gamma
                            = use_scale ? scale[nb_c_blk * c_blk + c] : 1;
                    acc_data_t sqrt_variance = static_cast<acc_data_t>(
                            1.0f / sqrtf(variance[nb_c_blk * c_blk + c] + eps));
                    acc_data_t v_diff_src;
                    if (fuse_norm_relu && !ws[c_off])
                        v_diff_src = 0;
                    else
                        v_diff_src = _diff_dst[nb_c_blk * c_blk + c];
                    if (calculate_diff_stats) {
                        v_diff_src -= diff_beta_loc[nb_c_blk * c_blk + c]
                                        / (SP * N)
                                + (_src[nb_c_blk * c_blk + c]
                                          - mean[nb_c_blk * c_blk + c])
                                        * diff_gamma_loc[nb_c_blk * c_blk + c]
                                        * sqrt_variance / (SP * N);
                    }
                    v_diff_src *= gamma * sqrt_variance;
                    _diff_src[nb_c_blk * c_blk + c] = v_diff_src;
                }
                if (utils::one_of(d_type, bf16, f16)) {
                    // convert diff_src from f32
                    types::cvt_from_float(diff_src + s_off, _diff_src, C);
                }
            }
        }
    });
    return status::success;
}

template struct nspc_batch_normalization_bwd_t<f32>;
template struct nspc_batch_normalization_bwd_t<bf16>;
template struct nspc_batch_normalization_bwd_t<f16>;
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s