/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/type_helpers.hpp"
#include "cpu/ref_batch_normalization.hpp"
#include "cpu/simple_q10n.hpp"

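// DATA_OFF maps a logical (n, c, d, h, w) index to a physical offset through
// the memory descriptor wrapper (f), dispatching on the tensor rank: 2D (NC),
// 3D (NCW), 4D (NCHW), or 5D (NCDHW). It relies on an `ndims` variable being
// in scope at the point of use.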
#define DATA_OFF(f, n, c, d, h, w) \
    (ndims == 2) ? (f).off(n, c) \
                 : ((ndims == 3) ? (f).off(n, c, w) \
                                 : ((ndims == 4) ? (f).off(n, c, h, w) \
                                                 : (f).off(n, c, d, h, w)))

namespace dnnl {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;

namespace {

using acc_data_t = float;

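// maybe_up_convert widens a source element to f32 so that accumulation always
// happens in single precision; the bfloat16 specialization below makes the
// widening cast explicit.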
template <typename T>
inline float maybe_up_convert(T x) {
    return x;
}

template <>
inline float maybe_up_convert<bfloat16_t>(bfloat16_t x) {
    return (float)x;
}

} // namespace

using namespace data_type;

template <impl::data_type_t d_type>
status_t ref_batch_normalization_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

    const memory_desc_wrapper data_d(pd()->src_md());
    const memory_desc_wrapper ss_d(pd()->weights_md());

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT);

    auto mean = pd()->stats_is_src()
            ? const_cast<acc_data_t *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN))
            : CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_MEAN, status);
    CHECK(status);
    auto variance = pd()->stats_is_src()
            ? const_cast<acc_data_t *>(
                    CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE))
            : CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_VARIANCE, status);
    CHECK(status);

    auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
    CHECK(status);
    auto ws = CTX_OUT_CLEAN_MEM(uint8_t *, DNNL_ARG_WORKSPACE, status);
    CHECK(status);

    const auto ndims = data_d.ndims();
    const auto N = pd()->MB();
    const auto C = pd()->C();
    const auto D = pd()->D();
    const auto H = pd()->H();
    const auto W = pd()->W();

    const auto eps = pd()->desc()->batch_norm_epsilon;
    const auto calculate_stats = !pd()->stats_is_src();
    const auto fuse_norm_relu = pd()->fuse_norm_relu();
    const auto save_stats = pd()->is_training();
    const auto is_training = pd()->is_training();

    /* fast return */
    if (this->pd()->has_zero_dim_memory()) {
        if (calculate_stats && save_stats)
            for (dim_t c = 0; c < pd()->C(); c++) {
                mean[c] = 0;
                variance[c] = 0;
            }
        return status::success;
    }

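    // ReLU may come in as an eltwise post-op rather than the fused flag; in
    // that case it is applied through maybe_post_op with the post-op's alpha
    // (negative slope).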
    const bool with_relu = pd()->with_relu_post_op(is_training);
    auto maybe_post_op = [&](acc_data_t res) {
        if (with_relu) return math::relu_fwd(res, pd()->alpha());
        return res;
    };

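    // Per-channel statistics and normalization are independent, so the whole
    // kernel parallelizes over C.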
    parallel_nd(C, [&](dim_t c) {
        acc_data_t v_mean = calculate_stats ? 0 : mean[c];
        acc_data_t v_variance = calculate_stats ? 0 : variance[c];

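        // Two passes over this channel's N * D * H * W elements: mean first,
        // then variance about that mean.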
        if (calculate_stats) {
            for_(dim_t n = 0; n < N; ++n)
            for_(dim_t d = 0; d < D; ++d)
            for_(dim_t h = 0; h < H; ++h)
            for (dim_t w = 0; w < W; ++w) {
                v_mean += maybe_up_convert(
                        src[DATA_OFF(data_d, n, c, d, h, w)]);
            }
            v_mean /= N * D * H * W;

            for_(dim_t n = 0; n < N; ++n)
            for_(dim_t d = 0; d < D; ++d)
            for_(dim_t h = 0; h < H; ++h)
            for (dim_t w = 0; w < W; ++w) {
                acc_data_t m = maybe_up_convert(
                                       src[DATA_OFF(data_d, n, c, d, h, w)])
                        - v_mean;
                v_variance += m * m;
            }
            v_variance /= N * D * H * W;
        }

        acc_data_t sqrt_variance = sqrtf(v_variance + eps);
        acc_data_t sm = (scale ? scale[ss_d.off(c)] : 1.0f) / sqrt_variance;
        acc_data_t sv = shift ? shift[ss_d.off(c)] : 0;

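        // Normalize: dst = sm * (x - mean) + sv, i.e.
        // scale * (x - mean) / sqrt(variance + eps) + shift, with an optional
        // fused ReLU and, for s8, saturating quantization on the store.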
        for_(dim_t n = 0; n < N; ++n)
        for_(dim_t d = 0; d < D; ++d)
        for_(dim_t h = 0; h < H; ++h)
        for (dim_t w = 0; w < W; ++w) {
            auto d_off = DATA_OFF(data_d, n, c, d, h, w);
            acc_data_t bn_res
                    = sm * (maybe_up_convert(src[d_off]) - v_mean) + sv;
            if (fuse_norm_relu) {
                if (bn_res <= 0) {
                    bn_res = 0;
                    if (is_training) ws[d_off] = 0;
                } else {
                    if (is_training) ws[d_off] = 1;
                }
            }
            if (d_type == s8)
                dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res));
            else
                dst[d_off] = maybe_post_op(bn_res);
        }

        if (calculate_stats && save_stats) {
            mean[c] = v_mean;
            variance[c] = v_variance;
        }
    });
    return status::success;
}

template struct ref_batch_normalization_fwd_t<s8>;
template struct ref_batch_normalization_fwd_t<f32>;
template struct ref_batch_normalization_fwd_t<bf16>;
template struct ref_batch_normalization_fwd_t<f16>;

template <impl::data_type_t d_type>
status_t ref_batch_normalization_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

    const memory_desc_wrapper data_d(pd()->src_md());
    const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
    const memory_desc_wrapper ss_d(pd()->weights_md());
    const memory_desc_wrapper diff_ss_d(pd()->diff_weights_md());

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
    auto variance = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
    auto diff_scale
            = CTX_OUT_CLEAN_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto diff_shift
            = CTX_OUT_CLEAN_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    const auto ndims = data_d.ndims();
    const auto N = pd()->MB();
    const auto C = pd()->C();
    const auto D = pd()->D();
    const auto H = pd()->H();
    const auto W = pd()->W();

    const auto eps = pd()->desc()->batch_norm_epsilon;
    const auto calculate_diff_stats = !pd()->use_global_stats();
    const auto fuse_norm_relu = pd()->fuse_norm_relu();

    /* fast return */
    if (this->pd()->has_zero_dim_memory()) {
        if (diff_scale) {
            for (dim_t c = 0; c < C; ++c) {
                diff_scale[diff_ss_d.off(c)] = 0.0f;
            }
        }
        if (diff_shift) {
            for (dim_t c = 0; c < C; ++c) {
                diff_shift[diff_ss_d.off(c)] = 0.0f;
            }
        }
        return status::success;
    }

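    // Channels are again independent: each thread reduces its own
    // diff_gamma / diff_beta and writes its own slice of diff_src.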
    parallel_nd(C, [&](dim_t c) {
        acc_data_t v_mean = mean[c];
        acc_data_t v_variance = variance[c];
        acc_data_t inv_sqrt_variance
                = static_cast<acc_data_t>(1.0f / sqrtf(v_variance + eps));
        acc_data_t gamma = scale ? scale[ss_d.off(c)] : 1.0f;
        acc_data_t diff_gamma = 0;
        acc_data_t diff_beta = 0;

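        // First pass: accumulate diff_beta = sum(dy) and the raw
        // diff_gamma = sum((x - mean) * dy), treating dy as 0 wherever the
        // fused ReLU clamped (workspace mask == 0); diff_gamma is scaled by
        // 1 / sqrt(variance + eps) after the loop.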
        for_(dim_t n = 0; n < N; ++n)
        for_(dim_t d = 0; d < D; ++d)
        for_(dim_t h = 0; h < H; ++h)
        for (dim_t w = 0; w < W; ++w) {
            const size_t s_off = DATA_OFF(data_d, n, c, d, h, w);
            acc_data_t dd;
            if (fuse_norm_relu && !ws[s_off])
                dd = 0;
            else
                dd = maybe_up_convert(
                        diff_dst[DATA_OFF(diff_data_d, n, c, d, h, w)]);
            diff_gamma += (maybe_up_convert(src[s_off]) - v_mean) * dd;
            diff_beta += dd;
        }
        diff_gamma *= inv_sqrt_variance;

        if (diff_scale) diff_scale[diff_ss_d.off(c)] = diff_gamma;
        if (diff_shift) diff_shift[diff_ss_d.off(c)] = diff_beta;

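        // Second pass:
        //   diff_src = gamma / sqrt(var + eps) * (dy - diff_beta / NDHW
        //           - (x - mean) * diff_gamma / (NDHW * sqrt(var + eps)));
        // the correction terms vanish when global stats are used
        // (calculate_diff_stats == false).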
        for_(dim_t n = 0; n < N; ++n)
        for_(dim_t d = 0; d < D; ++d)
        for_(dim_t h = 0; h < H; ++h)
        for (dim_t w = 0; w < W; ++w) {
            const size_t s_off = DATA_OFF(data_d, n, c, d, h, w);
            const size_t dd_off = DATA_OFF(diff_data_d, n, c, d, h, w);
            acc_data_t dd;
            if (fuse_norm_relu && !ws[s_off])
                dd = 0;
            else
                dd = maybe_up_convert(diff_dst[dd_off]);
            acc_data_t v_diff_src = dd;
            if (calculate_diff_stats) {
                v_diff_src -= diff_beta / (N * D * H * W)
                        + (maybe_up_convert(src[s_off]) - v_mean) * diff_gamma
                                * inv_sqrt_variance / (N * D * H * W);
            }
            v_diff_src *= gamma * inv_sqrt_variance;
            diff_src[dd_off] = v_diff_src;
        }
    });
    return status::success;
}

template struct ref_batch_normalization_bwd_t<f32>;
template struct ref_batch_normalization_bwd_t<bf16>;
template struct ref_batch_normalization_bwd_t<f16>;

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s