1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <math.h>
19
20#include <algorithm>
21
22#include "common/c_types_map.hpp"
23#include "common/compiler_workarounds.hpp"
24#include "common/dnnl_thread.hpp"
25#include "common/type_helpers.hpp"
26
27#include "cpu/platform.hpp"
28
29#include "cpu/cpu_batch_normalization_utils.hpp"
30
31#include "cpu/nspc_batch_normalization.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36
37using namespace memory_tracking::names;
38using namespace data_type;
39
40template <data_type_t d_type>
41status_t nspc_batch_normalization_fwd_t<d_type>::execute_forward(
42 const exec_ctx_t &ctx) const {
43 const bool save_stats = pd()->is_training();
44 const bool is_training = pd()->is_training();
45 const bool fuse_norm_relu = pd()->fuse_norm_relu();
46 const bool calculate_stats = !pd()->stats_is_src();
47 const bool with_relu = pd()->with_relu_post_op(is_training);
48
49 const auto use_scale = pd()->use_scale();
50 const auto use_shift = pd()->use_shift();
51
52 auto scratchpad = ctx.get_scratchpad_grantor();
53 auto tmp_mean = scratchpad.template get<acc_data_t>(key_bnorm_tmp_mean);
54 auto tmp_var = scratchpad.template get<acc_data_t>(key_bnorm_tmp_var);
55 auto *ws_reduce = scratchpad.template get<acc_data_t>(key_bnorm_reduction);
56
57 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
58 auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
59 auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT);
60
61 acc_data_t *mean, *variance;
62 if (!calculate_stats) {
63 mean = const_cast<acc_data_t *>(
64 CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN));
65 variance = const_cast<acc_data_t *>(
66 CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE));
67 } else {
68 if (save_stats) {
69 mean = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
70 variance = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
71 } else {
72 mean = tmp_mean;
73 variance = tmp_var;
74 }
75 }
76
77 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
78 auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
79 acc_data_t *tmp_data_ = scratchpad.template get<acc_data_t>(key_bnorm_cvt);
80
81 const dim_t N = pd()->MB();
82 const dim_t C = pd()->C();
83 const int simd_w = 16;
84 const dim_t C_align = utils::rnd_up(C, simd_w);
85 const dim_t SP = pd()->H() * pd()->W() * pd()->D();
86
87 const float eps = pd()->desc()->batch_norm_epsilon;
88 auto maybe_post_op = [&](acc_data_t res) {
89 if (with_relu) return math::relu_fwd(res, pd()->alpha());
90 return res;
91 };
92 const int nthr = pd()->nthr_;
93
94 if (calculate_stats) {
95 parallel(nthr, [&](const int ithr, const int nthr) {
96 dim_t N_s = 0, N_e = 0;
97 balance211(N, nthr, ithr, N_s, N_e);
98
99 for (dim_t c = 0; c < C; c++)
100 ws_reduce[C * ithr + c] = 0.;
101
102 for (dim_t n = N_s; n < N_e; n++) {
103 for (dim_t sp = 0; sp < SP; sp++) {
104 const acc_data_t *_src;
105 const size_t s_off = (size_t)n * SP * C + sp * C;
106 if (utils::one_of(d_type, bf16, f16)) {
107 // convert src from xf16 to f32
108 acc_data_t *tmp_src = tmp_data_ + ithr * C_align;
109 types::cvt_to_float(tmp_src, src + s_off, C);
110 _src = tmp_src;
111 } else {
112 _src = reinterpret_cast<const acc_data_t *>(
113 src + s_off);
114 }
115 PRAGMA_OMP_SIMD()
116 for (int c = 0; c < C; c++) {
117 ws_reduce[C * ithr + c] += _src[c];
118 }
119 }
120 }
121 });
122 parallel_nd(C, [&](dim_t c) {
123 mean[c] = 0;
124 for (dim_t n = 0; n < nthr; n++)
125 mean[c] += ws_reduce[C * n + c];
126 mean[c] /= SP * N;
127 });
128 parallel(nthr, [&](const int ithr, const int nthr) {
129 dim_t N_s = 0, N_e = 0;
130 balance211(N, nthr, ithr, N_s, N_e);
131
132 acc_data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
133
134 if (ithr > 0 || save_stats) {
135 for (dim_t c = 0; c < C; c++)
136 mean_loc[c] = mean[c];
137 }
138
139 for (dim_t c = 0; c < C; c++)
140 ws_reduce[C * ithr + c] = 0.;
141
142 for (dim_t n = N_s; n < N_e; n++) {
143 for (dim_t sp = 0; sp < SP; sp++) {
144 const acc_data_t *_src;
145 const size_t s_off = (size_t)n * SP * C + sp * C;
146 if (utils::one_of(d_type, bf16, f16)) {
147 // convert src from xf16 to f32
148 acc_data_t *tmp_src = tmp_data_ + ithr * C_align;
149 types::cvt_to_float(tmp_src, src + s_off, C);
150 _src = tmp_src;
151 } else {
152 _src = reinterpret_cast<const acc_data_t *>(
153 src + s_off);
154 }
155 PRAGMA_OMP_SIMD()
156 for (int c = 0; c < C; c++) {
157 acc_data_t m = _src[c] - mean_loc[c];
158 ws_reduce[C * ithr + c] += m * m;
159 }
160 }
161 }
162 });
163 parallel_nd(C, [&](dim_t c) {
164 variance[c] = 0;
165 for (dim_t n = 0; n < nthr; n++)
166 variance[c] += ws_reduce[C * n + c];
167 variance[c] /= SP * N;
168 });
169 parallel(nthr, [&](const int ithr, const int nthr) {
170 acc_data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
171 if (ithr > 0 || save_stats) {
172 for (dim_t c = 0; c < C; c++)
173 variance_loc[c] = variance[c];
174 }
175 });
176 }
177
178 parallel(nthr, [&](const int ithr, const int nthr) {
179 dim_t N_s = 0, N_e = 0;
180 balance211(N, nthr, ithr, N_s, N_e);
181
182 acc_data_t *mean_loc, *variance_loc;
183 if (calculate_stats) {
184 mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
185 variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
186 } else {
187 mean_loc = mean;
188 variance_loc = variance;
189 }
190
191 for (dim_t n = N_s; n < N_e; n++) {
192 for (dim_t sp = 0; sp < SP; sp++) {
193 acc_data_t *_dst;
194 const acc_data_t *_src;
195 const size_t s_off = (size_t)n * SP * C + sp * C;
196 if (utils::one_of(d_type, bf16, f16)) {
197 // store dst to f32 buffer
198 _dst = tmp_data_ + ithr * C_align;
199 // convert src from xf16 to f32
200 acc_data_t *tmp_src = tmp_data_ + (nthr + ithr) * C_align;
201 types::cvt_to_float(tmp_src, src + s_off, C);
202 _src = tmp_src;
203 } else {
204 _dst = reinterpret_cast<acc_data_t *>(dst + s_off);
205 _src = reinterpret_cast<const acc_data_t *>(src + s_off);
206 }
207#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
208 PRAGMA_OMP_SIMD()
209#endif
210 for (int c = 0; c < C; c++) {
211 const size_t c_off = s_off + c;
212 acc_data_t sqrt_variance = static_cast<acc_data_t>(
213 sqrtf(variance_loc[c] + eps));
214 acc_data_t sm = (use_scale ? (acc_data_t)scale[c]
215 : (acc_data_t)1.0f)
216 / sqrt_variance;
217 acc_data_t sv
218 = use_shift ? (acc_data_t)shift[c] : (acc_data_t)0;
219 acc_data_t bn_res = sm * (_src[c] - mean_loc[c]) + sv;
220 if (fuse_norm_relu) {
221 if (bn_res <= 0) {
222 bn_res = 0;
223 if (is_training) ws[c_off] = 0;
224 } else {
225 if (is_training) ws[c_off] = 1;
226 }
227 }
228 _dst[c] = maybe_post_op(bn_res);
229 }
230 if (utils::one_of(d_type, bf16, f16)) {
231 // convert dst from f32 to xf16
232 types::cvt_from_float(dst + s_off, _dst, C);
233 }
234 }
235 }
236 });
237 return status::success;
238}
239
// Explicit instantiations for the supported forward source/destination
// data types (accumulation is always f32).
template struct nspc_batch_normalization_fwd_t<f32>;
template struct nspc_batch_normalization_fwd_t<bf16>;
template struct nspc_batch_normalization_fwd_t<f16>;
243
244template <data_type_t d_type>
245status_t nspc_batch_normalization_bwd_t<d_type>::execute_backward(
246 const exec_ctx_t &ctx) const {
247
248 const auto use_scale = pd()->use_scale();
249
250 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
251 auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
252 auto variance = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
253 auto scale = CTX_IN_MEM(acc_data_t *, DNNL_ARG_SCALE);
254 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
255 auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);
256
257 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
258 auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE);
259 auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT);
260
261 auto scratchpad = ctx.get_scratchpad_grantor();
262 auto tmp_diff_ss
263 = scratchpad.template get<acc_data_t>(key_bnorm_tmp_diff_ss);
264
265 const size_t scratch_diff_shift_off = diff_scale ? 0 : pd()->C();
266 if (diff_scale == nullptr) diff_scale = tmp_diff_ss;
267 if (diff_shift == nullptr)
268 diff_shift = &tmp_diff_ss[scratch_diff_shift_off];
269
270 const dim_t N = pd()->MB();
271 const dim_t C = pd()->C();
272 const int simd_w = 16;
273 const dim_t C_align = utils::rnd_up(C, simd_w);
274 const dim_t SP = pd()->D() * pd()->H() * pd()->W();
275 acc_data_t *diff_gamma = diff_scale, *diff_beta = diff_shift;
276 acc_data_t *ws_reduce
277 = scratchpad.template get<acc_data_t>(key_bnorm_reduction);
278 acc_data_t *tmp_data_ = scratchpad.template get<acc_data_t>(key_bnorm_cvt);
279
280 const float eps = pd()->desc()->batch_norm_epsilon;
281 const bool calculate_diff_stats = !pd()->use_global_stats();
282 const bool fuse_norm_relu = pd()->fuse_norm_relu();
283
284 /* Note: potential seg-fault from incorrectly compiled vectorized-loop.
285 * Explicit tail-processing fixes this issue. */
286 const dim_t c_blk = std::max(
287 platform::get_vector_register_size() / (int)sizeof(float), 8);
288 const dim_t tail = C % c_blk;
289 const dim_t nb_c_blk = (size_t)C / c_blk;
290 const int nthr = pd()->nthr_;
291
292 parallel(nthr, [&](const int ithr, const int nthr) {
293 dim_t N_s = 0, N_e = 0;
294 balance211(N, nthr, ithr, N_s, N_e);
295
296 for (dim_t c = 0; c < C; c++) {
297 ws_reduce[C * ithr + c] = 0.;
298 ws_reduce[C * nthr + C * ithr + c] = 0.;
299 }
300
301 for (dim_t n = N_s; n < N_e; n++) {
302 for (dim_t sp = 0; sp < SP; sp++) {
303 const acc_data_t *_diff_dst;
304 const acc_data_t *_src;
305 const size_t s_off = (size_t)n * SP * C + sp * C;
306 if (utils::one_of(d_type, bf16, f16)) {
307 // convert diff_dst to f32
308 acc_data_t *tmp_diff_dst = tmp_data_ + ithr * C_align;
309 types::cvt_to_float(tmp_diff_dst, diff_dst + s_off, C);
310 _diff_dst = tmp_diff_dst;
311 // convert src to f32
312 acc_data_t *tmp_src = tmp_data_ + (nthr + ithr) * C_align;
313 types::cvt_to_float(tmp_src, src + s_off, C);
314 _src = tmp_src;
315 } else {
316 _diff_dst = reinterpret_cast<const acc_data_t *>(
317 diff_dst + s_off);
318 _src = reinterpret_cast<const acc_data_t *>(src + s_off);
319 }
320#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
321 PRAGMA_OMP_SIMD()
322#endif
323 for (dim_t c = 0; c < C; c++) {
324 const size_t c_off = s_off + c;
325 acc_data_t dd;
326 if (fuse_norm_relu && !ws[c_off])
327 dd = 0;
328 else
329 dd = _diff_dst[c];
330 ws_reduce[C * ithr + c] += (_src[c] - mean[c]) * dd;
331 ws_reduce[C * nthr + C * ithr + c] += dd;
332 }
333 }
334 }
335 });
336
337 parallel_nd(C, [&](dim_t c) {
338 acc_data_t sqrt_variance
339 = static_cast<acc_data_t>(1.0f / sqrtf(variance[c] + eps));
340 diff_gamma[c] = 0;
341 diff_beta[c] = 0;
342 for (dim_t n = 0; n < nthr; n++) {
343 diff_gamma[c] += ws_reduce[C * n + c];
344 diff_beta[c] += ws_reduce[C * nthr + C * n + c];
345 }
346 diff_gamma[c] *= sqrt_variance;
347 });
348
349 parallel(nthr, [&](const int ithr, const int nthr) {
350 dim_t N_s = 0, N_e = 0;
351 balance211(N, nthr, ithr, N_s, N_e);
352
353 acc_data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr;
354 acc_data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr);
355
356 for (dim_t c = 0; c < C; c++) {
357 diff_gamma_loc[c] = diff_gamma[c];
358 diff_beta_loc[c] = diff_beta[c];
359 }
360
361 for (dim_t n = N_s; n < N_e; n++) {
362 for (dim_t sp = 0; sp < SP; sp++) {
363 acc_data_t *_diff_src;
364 const acc_data_t *_diff_dst;
365 const acc_data_t *_src;
366 const size_t s_off = (size_t)n * SP * C + sp * C;
367 if (utils::one_of(d_type, bf16, f16)) {
368 // store diff_src to f32 buffer
369 _diff_src = tmp_data_ + ithr * C_align;
370 // convert diff_dst to f32
371 acc_data_t *tmp_diff_dst = tmp_data_ + ithr * C_align;
372 types::cvt_to_float(tmp_diff_dst, diff_dst + s_off, C);
373 _diff_dst = tmp_diff_dst;
374 if (calculate_diff_stats) {
375 // convert src to f32
376 acc_data_t *tmp_src
377 = tmp_data_ + (2 * nthr + ithr) * C_align;
378 types::cvt_to_float(tmp_src, src + s_off, C);
379 _src = tmp_src;
380 } else
381 _src = nullptr; // to avoid compiler warning w/ gcc483
382 } else {
383 _diff_src
384 = reinterpret_cast<acc_data_t *>(diff_src + s_off);
385 _diff_dst = reinterpret_cast<const acc_data_t *>(
386 diff_dst + s_off);
387 _src = reinterpret_cast<const acc_data_t *>(src + s_off);
388 }
389
390#if CLANG_WA_02_SAFE_TO_USE_OMP_SIMD
391 PRAGMA_OMP_SIMD(simdlen(16))
392#endif
393 for (dim_t c = 0; c < nb_c_blk * c_blk; c++) {
394 const size_t c_off = s_off + c;
395 acc_data_t gamma = use_scale ? scale[c] : 1;
396 acc_data_t sqrt_variance = static_cast<acc_data_t>(
397 1.0f / sqrtf(variance[c] + eps));
398 acc_data_t v_diff_src;
399 if (fuse_norm_relu && !ws[c_off])
400 v_diff_src = 0;
401 else
402 v_diff_src = _diff_dst[c];
403 if (calculate_diff_stats) {
404 v_diff_src -= diff_beta_loc[c] / (SP * N)
405 + (_src[c] - mean[c]) * diff_gamma_loc[c]
406 * sqrt_variance / (SP * N);
407 }
408 v_diff_src *= gamma * sqrt_variance;
409 _diff_src[c] = v_diff_src;
410 }
411 for (dim_t c = 0; c < tail; c++) {
412 const size_t c_off = s_off + nb_c_blk * c_blk + c;
413 acc_data_t gamma
414 = use_scale ? scale[nb_c_blk * c_blk + c] : 1;
415 acc_data_t sqrt_variance = static_cast<acc_data_t>(
416 1.0f / sqrtf(variance[nb_c_blk * c_blk + c] + eps));
417 acc_data_t v_diff_src;
418 if (fuse_norm_relu && !ws[c_off])
419 v_diff_src = 0;
420 else
421 v_diff_src = _diff_dst[nb_c_blk * c_blk + c];
422 if (calculate_diff_stats) {
423 v_diff_src -= diff_beta_loc[nb_c_blk * c_blk + c]
424 / (SP * N)
425 + (_src[nb_c_blk * c_blk + c]
426 - mean[nb_c_blk * c_blk + c])
427 * diff_gamma_loc[nb_c_blk * c_blk + c]
428 * sqrt_variance / (SP * N);
429 }
430 v_diff_src *= gamma * sqrt_variance;
431 _diff_src[nb_c_blk * c_blk + c] = v_diff_src;
432 }
433 if (utils::one_of(d_type, bf16, f16)) {
434 // convert diff_src from f32
435 types::cvt_from_float(diff_src + s_off, _diff_src, C);
436 }
437 }
438 }
439 });
440 return status::success;
441}
442
// Explicit instantiations for the supported backward source/destination
// data types (accumulation is always f32).
template struct nspc_batch_normalization_bwd_t<f32>;
template struct nspc_batch_normalization_bwd_t<bf16>;
template struct nspc_batch_normalization_bwd_t<f16>;
446} // namespace cpu
447} // namespace impl
448} // namespace dnnl
449
450// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
451