1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/cpu_primitive.hpp" |
25 | #include "cpu/ref_io_helper.hpp" |
26 | #include "cpu/ref_layer_normalization.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | status_t ref_layer_normalization_fwd_t::execute_forward( |
33 | const exec_ctx_t &ctx) const { |
34 | const memory_desc_wrapper src_d(pd()->src_md()); |
35 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
36 | const memory_desc_wrapper stat_d(pd()->stat_md()); |
37 | const memory_desc_wrapper sc_d(pd()->weights_md()); |
38 | |
39 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
40 | auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE); |
41 | auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT); |
42 | auto mean = pd()->stats_are_src() |
43 | ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN)) |
44 | : CTX_OUT_MEM(float *, DNNL_ARG_MEAN); |
45 | auto variance = pd()->stats_are_src() |
46 | ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE)) |
47 | : CTX_OUT_MEM(float *, DNNL_ARG_VARIANCE); |
48 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
49 | |
50 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
51 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
52 | |
53 | const dim_t N = pd()->across_axis(); |
54 | const dim_t C = pd()->norm_axis(); |
55 | |
56 | const float eps = pd()->desc()->layer_norm_epsilon; |
57 | const bool save_stats = pd()->is_training(); |
58 | const bool calculate_stats = !pd()->stats_are_src(); |
59 | |
60 | /* fast return */ |
61 | if (this->pd()->has_zero_dim_memory()) { |
62 | if (calculate_stats && save_stats) { |
63 | for (dim_t n = 0; n < N; n++) { |
64 | mean[n] = 0; |
65 | variance[n] = 0; |
66 | } |
67 | } |
68 | return status::success; |
69 | } |
70 | |
71 | parallel_nd(N, [&](dim_t n) { |
72 | const size_t s_off = stat_d.off_l(n); |
73 | auto v_mean = calculate_stats ? 0 : mean[s_off]; |
74 | auto v_variance = calculate_stats ? 0 : variance[s_off]; |
75 | |
76 | if (calculate_stats) { |
77 | for (dim_t c = 0; c < C; ++c) { |
78 | const auto s_off = src_d.off_l(n * C + c); |
79 | float s = io::load_float_value(src_d.data_type(), src, s_off); |
80 | v_mean += s; |
81 | } |
82 | v_mean /= C; |
83 | |
84 | for (dim_t c = 0; c < C; ++c) { |
85 | const auto s_off = src_d.off_l(n * C + c); |
86 | float s = io::load_float_value(src_d.data_type(), src, s_off); |
87 | float m = s - v_mean; |
88 | v_variance += m * m; |
89 | } |
90 | v_variance /= C; |
91 | } |
92 | |
93 | float sqrt_variance = sqrtf(v_variance + eps); |
94 | for (dim_t c = 0; c < C; ++c) { |
95 | const float sm = (scale ? scale[sc_d.off(c)] : 1.f) / sqrt_variance; |
96 | const float sv = shift ? shift[sc_d.off(c)] : 0; |
97 | const auto s_off = src_d.off_l(n * C + c); |
98 | const auto d_off = dst_d.off_l(n * C + c); |
99 | float s = io::load_float_value(src_d.data_type(), src, s_off); |
100 | float d = sm * (s - v_mean) + sv; |
101 | d *= src_scales[0] * dst_scales[0]; |
102 | io::store_float_value(dst_d.data_type(), d, dst, d_off); |
103 | } |
104 | |
105 | if (calculate_stats) { |
106 | if (save_stats) { |
107 | mean[s_off] = v_mean; |
108 | variance[s_off] = v_variance; |
109 | } |
110 | } |
111 | }); |
112 | return status::success; |
113 | } |
114 | |
115 | status_t ref_layer_normalization_bwd_t::execute_backward( |
116 | const exec_ctx_t &ctx) const { |
117 | status_t status = status::success; |
118 | |
119 | const memory_desc_wrapper src_d(pd()->src_md()); |
120 | const memory_desc_wrapper stat_d(pd()->stat_md()); |
121 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
122 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
123 | const memory_desc_wrapper sc_d(pd()->weights_md()); |
124 | const memory_desc_wrapper diff_sc_d(pd()->diff_weights_md()); |
125 | |
126 | const auto use_scale = pd()->use_scale(); |
127 | const auto use_shift = pd()->use_shift(); |
128 | |
129 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
130 | auto mean = CTX_IN_MEM(const float *, DNNL_ARG_MEAN); |
131 | auto variance = CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE); |
132 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
133 | auto scale = CTX_IN_MEM(float *, DNNL_ARG_SCALE); |
134 | auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status); |
135 | CHECK(status); |
136 | |
137 | auto diff_scale = use_scale |
138 | ? CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SCALE, status) |
139 | : nullptr; |
140 | CHECK(status); |
141 | auto diff_shift = use_shift |
142 | ? CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SHIFT, status) |
143 | : nullptr; |
144 | CHECK(status); |
145 | |
146 | const dim_t N = pd()->across_axis(); |
147 | const dim_t C = pd()->norm_axis(); |
148 | |
149 | /* fast return */ |
150 | if (this->pd()->has_zero_dim_memory()) { |
151 | if (diff_scale) { |
152 | for (dim_t c = 0; c < C; ++c) { |
153 | diff_scale[diff_sc_d.off(c)] = 0; |
154 | } |
155 | } |
156 | if (diff_shift) { |
157 | for (dim_t c = 0; c < C; ++c) { |
158 | diff_shift[diff_sc_d.off(c)] = 0; |
159 | } |
160 | } |
161 | return status::success; |
162 | } |
163 | |
164 | const float eps = pd()->desc()->layer_norm_epsilon; |
165 | const bool calculate_diff_stats = !pd()->use_global_stats(); |
166 | |
167 | if (diff_scale || diff_shift) { |
168 | parallel_nd(C, [&](dim_t c) { |
169 | float diff_gamma = 0.f; |
170 | float diff_beta = 0.f; |
171 | |
172 | for (dim_t n = 0; n < N; ++n) { |
173 | const auto src_off = src_d.off_l(n * C + c); |
174 | const auto diff_dst_off = diff_dst_d.off_l(n * C + c); |
175 | const auto stat_off = stat_d.off_l(n); |
176 | float inv_sqrt_variance = 1.f / sqrtf(variance[stat_off] + eps); |
177 | float s = io::load_float_value(src_d.data_type(), src, src_off); |
178 | float dd = io::load_float_value( |
179 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
180 | diff_gamma += (s - mean[stat_off]) * dd * inv_sqrt_variance; |
181 | diff_beta += dd; |
182 | } |
183 | |
184 | if (diff_scale) diff_scale[diff_sc_d.off(c)] = diff_gamma; |
185 | if (diff_shift) diff_shift[diff_sc_d.off(c)] = diff_beta; |
186 | }); |
187 | } |
188 | |
189 | parallel_nd(N, [&](dim_t n) { |
190 | const size_t s_off = stat_d.off_l(n); |
191 | float inv_sqrt_variance = 1.f / sqrtf(variance[s_off] + eps); |
192 | float dd_gamma = 0.f; |
193 | float dd_gamma_x = 0.f; |
194 | if (calculate_diff_stats) { |
195 | for (dim_t c = 0; c < C; ++c) { |
196 | float gamma = scale ? scale[sc_d.off(c)] : 1.f; |
197 | const auto src_off = src_d.off_l(n * C + c); |
198 | const auto diff_dst_off = diff_dst_d.off_l(n * C + c); |
199 | float s = io::load_float_value(src_d.data_type(), src, src_off); |
200 | float dd = io::load_float_value( |
201 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
202 | dd_gamma += dd * gamma; |
203 | dd_gamma_x += dd * gamma * (s - mean[s_off]); |
204 | } |
205 | dd_gamma_x *= inv_sqrt_variance; |
206 | } |
207 | |
208 | for (dim_t c = 0; c < C; ++c) { |
209 | float gamma = scale ? scale[sc_d.off(c)] : 1; |
210 | const auto src_off = src_d.off_l(n * C + c); |
211 | const auto diff_dst_off = diff_dst_d.off_l(n * C + c); |
212 | const auto diff_src_off = diff_src_d.off_l(n * C + c); |
213 | float dd = io::load_float_value( |
214 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
215 | float d_src = dd * gamma; |
216 | if (calculate_diff_stats) { |
217 | float s = io::load_float_value(src_d.data_type(), src, src_off); |
218 | d_src -= dd_gamma / C; |
219 | d_src -= (s - mean[s_off]) * dd_gamma_x * inv_sqrt_variance / C; |
220 | } |
221 | d_src *= inv_sqrt_variance; |
222 | io::store_float_value( |
223 | diff_src_d.data_type(), d_src, diff_src, diff_src_off); |
224 | } |
225 | }); |
226 | return status::success; |
227 | } |
228 | |
229 | } // namespace cpu |
230 | } // namespace impl |
231 | } // namespace dnnl |
232 | |
233 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
234 | |