1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/bfloat16.hpp" |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/memory_tracking.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "cpu/ref_batch_normalization.hpp" |
26 | #include "cpu/simple_q10n.hpp" |
27 | |
// Computes the physical offset of logical element (n, c, d, h, w) inside the
// memory descriptor wrapper `f`, dropping the spatial dimensions that do not
// exist for the tensor rank: 2D -> (n, c), 3D -> (n, c, w), 4D -> (n, c, h, w),
// 5D -> (n, c, d, h, w). Relies on a variable named `ndims` being in scope at
// every expansion site.
#define DATA_OFF(f, n, c, d, h, w) \
    (ndims == 2) ? (f).off(n, c) \
                 : ((ndims == 3) ? (f).off(n, c, w) \
                                 : ((ndims == 4) ? (f).off(n, c, h, w) \
                                                 : (f).off(n, c, d, h, w)))
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | |
38 | using namespace memory_tracking::names; |
39 | |
40 | namespace { |
41 | |
42 | using acc_data_t = float; |
43 | |
// Generic case: the value is representable in float directly (f32/s8 data),
// so up-conversion is just a cast to the accumulation type.
template <typename T>
inline float maybe_up_convert(T x) {
    return static_cast<float>(x);
}
48 | |
49 | template <> |
50 | inline float maybe_up_convert<bfloat16_t>(bfloat16_t x) { |
51 | return (float)x; |
52 | } |
53 | |
54 | } // namespace |
55 | |
56 | using namespace data_type; |
57 | |
58 | template <impl::data_type_t d_type> |
59 | status_t ref_batch_normalization_fwd_t<d_type>::execute_forward( |
60 | const exec_ctx_t &ctx) const { |
61 | /* fast return */ |
62 | if (this->pd()->has_zero_dim_memory()) return status::success; |
63 | |
64 | status_t status = status::success; |
65 | |
66 | const memory_desc_wrapper data_d(pd()->src_md()); |
67 | const memory_desc_wrapper ss_d(pd()->weights_md()); |
68 | |
69 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
70 | auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE); |
71 | auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT); |
72 | |
73 | auto mean = pd()->stats_is_src() |
74 | ? const_cast<acc_data_t *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN)) |
75 | : CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_MEAN, status); |
76 | CHECK(status); |
77 | auto variance = pd()->stats_is_src() |
78 | ? const_cast<acc_data_t *>( |
79 | CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE)) |
80 | : CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_VARIANCE, status); |
81 | CHECK(status); |
82 | |
83 | auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
84 | CHECK(status); |
85 | auto ws = CTX_OUT_CLEAN_MEM(uint8_t *, DNNL_ARG_WORKSPACE, status); |
86 | CHECK(status); |
87 | |
88 | const auto ndims = data_d.ndims(); |
89 | const auto N = pd()->MB(); |
90 | const auto C = pd()->C(); |
91 | const auto D = pd()->D(); |
92 | const auto H = pd()->H(); |
93 | const auto W = pd()->W(); |
94 | |
95 | const auto eps = pd()->desc()->batch_norm_epsilon; |
96 | const auto calculate_stats = !pd()->stats_is_src(); |
97 | const auto fuse_norm_relu = pd()->fuse_norm_relu(); |
98 | const auto save_stats = pd()->is_training(); |
99 | const auto is_training = pd()->is_training(); |
100 | |
101 | /* fast return */ |
102 | if (this->pd()->has_zero_dim_memory()) { |
103 | if (calculate_stats && save_stats) |
104 | for (dim_t c = 0; c < pd()->C(); c++) { |
105 | mean[c] = 0; |
106 | variance[c] = 0; |
107 | } |
108 | return status::success; |
109 | } |
110 | |
111 | const bool with_relu = pd()->with_relu_post_op(is_training); |
112 | auto maybe_post_op = [&](acc_data_t res) { |
113 | if (with_relu) return math::relu_fwd(res, pd()->alpha()); |
114 | return res; |
115 | }; |
116 | |
117 | parallel_nd(C, [&](dim_t c) { |
118 | acc_data_t v_mean = calculate_stats ? 0 : mean[c]; |
119 | acc_data_t v_variance = calculate_stats ? 0 : variance[c]; |
120 | |
121 | if (calculate_stats) { |
122 | for_(int n = 0; n < N; ++n) |
123 | for_(int d = 0; d < D; ++d) |
124 | for_(int h = 0; h < H; ++h) |
125 | for (int w = 0; w < W; ++w) { |
126 | v_mean += maybe_up_convert( |
127 | src[DATA_OFF(data_d, n, c, d, h, w)]); |
128 | } |
129 | v_mean /= W * N * H * D; |
130 | |
131 | for_(int n = 0; n < N; ++n) |
132 | for_(int d = 0; d < D; ++d) |
133 | for_(int h = 0; h < H; ++h) |
134 | for (int w = 0; w < W; ++w) { |
135 | acc_data_t m = src[DATA_OFF(data_d, n, c, d, h, w)] - v_mean; |
136 | v_variance += m * m; |
137 | } |
138 | v_variance /= W * H * N * D; |
139 | } |
140 | |
141 | acc_data_t sqrt_variance = sqrtf(v_variance + eps); |
142 | acc_data_t sm = (scale ? scale[ss_d.off(c)] : 1.0f) / sqrt_variance; |
143 | acc_data_t sv = shift ? shift[ss_d.off(c)] : 0; |
144 | |
145 | for_(dim_t n = 0; n < N; ++n) |
146 | for_(dim_t d = 0; d < D; ++d) |
147 | for_(dim_t h = 0; h < H; ++h) |
148 | for (dim_t w = 0; w < W; ++w) { |
149 | auto d_off = DATA_OFF(data_d, n, c, d, h, w); |
150 | acc_data_t bn_res |
151 | = sm * (maybe_up_convert(src[d_off]) - v_mean) + sv; |
152 | if (fuse_norm_relu) { |
153 | if (bn_res <= 0) { |
154 | bn_res = 0; |
155 | if (is_training) ws[d_off] = 0; |
156 | } else { |
157 | if (is_training) ws[d_off] = 1; |
158 | } |
159 | } |
160 | if (d_type == s8) |
161 | dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res)); |
162 | else |
163 | dst[d_off] = maybe_post_op(bn_res); |
164 | } |
165 | |
166 | if (calculate_stats) { |
167 | if (save_stats) { |
168 | mean[c] = v_mean; |
169 | variance[c] = v_variance; |
170 | } |
171 | } |
172 | }); |
173 | return status::success; |
174 | } |
175 | |
// Explicit instantiations for every supported forward data type.
template struct ref_batch_normalization_fwd_t<s8>;
template struct ref_batch_normalization_fwd_t<f32>;
template struct ref_batch_normalization_fwd_t<bf16>;
template struct ref_batch_normalization_fwd_t<f16>;
180 | |
// Reference batch-normalization backward pass. Per channel c it computes
//   diff_gamma = sum((src - mean) * diff_dst) / sqrt(variance + eps)
//   diff_beta  = sum(diff_dst)
// and, unless use_global_stats, the full diff_src that also propagates the
// gradients through the computed statistics. With fused ReLU, the forward
// workspace mask zeroes the gradient of elements the ReLU clipped.
template <impl::data_type_t d_type>
status_t ref_batch_normalization_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

    const memory_desc_wrapper data_d(pd()->src_md());
    const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
    const memory_desc_wrapper ss_d(pd()->weights_md());
    const memory_desc_wrapper diff_ss_d(pd()->diff_weights_md());

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
    auto variance = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    auto scale = CTX_IN_MEM(acc_data_t *, DNNL_ARG_SCALE);
    auto diff_scale
            = CTX_OUT_CLEAN_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto diff_shift
            = CTX_OUT_CLEAN_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    const auto ndims = data_d.ndims();
    const auto N = pd()->MB();
    const auto C = pd()->C();
    const auto D = pd()->D();
    const auto H = pd()->H();
    const auto W = pd()->W();

    const auto eps = pd()->desc()->batch_norm_epsilon;
    // With global stats, mean/variance are constants w.r.t. src, so the
    // stats terms drop out of diff_src below.
    const auto calculate_diff_stats = !pd()->use_global_stats();
    const auto fuse_norm_relu = pd()->fuse_norm_relu();

    /* fast return: nothing to reduce over, but requested weight gradients
     * must still be written (as zeros) */
    if (this->pd()->has_zero_dim_memory()) {
        if (diff_scale) {
            for (dim_t c = 0; c < C; ++c) {
                diff_scale[diff_ss_d.off(c)] = 0.0f;
            }
        }
        if (diff_shift) {
            for (dim_t c = 0; c < C; ++c) {
                diff_shift[diff_ss_d.off(c)] = 0.0f;
            }
        }
        return status::success;
    }

    parallel_nd(C, [&](dim_t c) {
        acc_data_t v_mean = mean[c];
        acc_data_t v_variance = variance[c];
        // NOTE: despite the name, sqrt_variance holds the RECIPROCAL
        // 1 / sqrt(variance + eps).
        acc_data_t sqrt_variance
                = static_cast<acc_data_t>(1.0f / sqrtf(v_variance + eps));
        acc_data_t gamma = scale ? scale[ss_d.off(c)] : 1.0f;
        acc_data_t diff_gamma = 0;
        acc_data_t diff_beta = 0;

        // Pass 1: reduce diff_gamma/diff_beta over all N*D*H*W elements of
        // channel c.
        for_(dim_t n = 0; n < N; ++n)
        for_(dim_t d = 0; d < D; ++d)
        for_(dim_t h = 0; h < H; ++h)
        for (dim_t w = 0; w < W; ++w) {
            const size_t s_off = DATA_OFF(data_d, n, c, d, h, w);
            acc_data_t dd;
            // ws[s_off] == 0 means the forward ReLU clipped this element,
            // so no gradient flows through it.
            if (fuse_norm_relu && !ws[s_off])
                dd = 0;
            else
                dd = maybe_up_convert(
                        diff_dst[DATA_OFF(diff_data_d, n, c, d, h, w)]);
            diff_gamma += (maybe_up_convert(src[s_off]) - v_mean) * dd;
            diff_beta += dd;
        }
        diff_gamma *= sqrt_variance;

        if (diff_scale) diff_scale[diff_ss_d.off(c)] = diff_gamma;
        if (diff_shift) diff_shift[diff_ss_d.off(c)] = diff_beta;

        // Pass 2: diff_src = gamma / sqrt(var + eps)
        //     * (dd - [mean(dd) + (src - mean) * diff_gamma / (sqrt(var+eps) * NDHW)])
        // where the bracketed stats correction is skipped for global stats.
        for_(dim_t n = 0; n < N; ++n)
        for_(dim_t d = 0; d < D; ++d)
        for_(dim_t h = 0; h < H; ++h)
        for (dim_t w = 0; w < W; ++w) {
            const size_t s_off = DATA_OFF(data_d, n, c, d, h, w);
            const size_t dd_off = DATA_OFF(diff_data_d, n, c, d, h, w);
            acc_data_t dd;
            if (fuse_norm_relu && !ws[s_off])
                dd = 0;
            else
                dd = maybe_up_convert(diff_dst[dd_off]);
            acc_data_t v_diff_src = dd;
            if (calculate_diff_stats) {
                v_diff_src -= diff_beta / (D * W * H * N)
                        + (maybe_up_convert(src[s_off]) - v_mean) * diff_gamma
                                * sqrt_variance / (D * W * H * N);
            }
            v_diff_src *= gamma * sqrt_variance;
            diff_src[dd_off] = v_diff_src;
        }
    });
    return status::success;
}
285 | |
// Explicit instantiations for every supported backward data type (s8 has no
// backward pass).
template struct ref_batch_normalization_bwd_t<f32>;
template struct ref_batch_normalization_bwd_t<bf16>;
template struct ref_batch_normalization_bwd_t<f16>;
289 | |
290 | } // namespace cpu |
291 | } // namespace impl |
292 | } // namespace dnnl |
293 | |
294 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
295 | |