/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/ref_io_helper.hpp"
#include "cpu/ref_layer_normalization.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
status_t ref_layer_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper stat_d(pd()->stat_md());
    const memory_desc_wrapper sc_d(pd()->weights_md());

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT);
    auto mean = pd()->stats_are_src()
            ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN))
            : CTX_OUT_MEM(float *, DNNL_ARG_MEAN);
    auto variance = pd()->stats_are_src()
            ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE))
            : CTX_OUT_MEM(float *, DNNL_ARG_VARIANCE);
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

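    // Per-tensor runtime quantization scales for src and dst (the macro
    // falls back to 1.f when a scale is not set); they are folded into the
    // normalized value before it is stored.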
    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const dim_t N = pd()->across_axis();
    const dim_t C = pd()->norm_axis();

    const float eps = pd()->desc()->layer_norm_epsilon;
    const bool save_stats = pd()->is_training();
    const bool calculate_stats = !pd()->stats_are_src();

    /* fast return */
    if (this->pd()->has_zero_dim_memory()) {
        if (calculate_stats && save_stats) {
            for (dim_t n = 0; n < N; n++) {
                mean[n] = 0;
                variance[n] = 0;
            }
        }
        return status::success;
    }

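    // Each of the N outer elements owns an independent normalization slice
    // of C values; slices are processed in parallel. Statistics are computed
    // in two passes (mean first, then variance about the mean) unless they
    // were supplied as inputs.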
    parallel_nd(N, [&](dim_t n) {
        const size_t s_off = stat_d.off_l(n);
        float v_mean = calculate_stats ? 0.f : mean[s_off];
        float v_variance = calculate_stats ? 0.f : variance[s_off];

        if (calculate_stats) {
            for (dim_t c = 0; c < C; ++c) {
                const auto s_off = src_d.off_l(n * C + c);
                float s = io::load_float_value(src_d.data_type(), src, s_off);
                v_mean += s;
            }
            v_mean /= C;

            for (dim_t c = 0; c < C; ++c) {
                const auto s_off = src_d.off_l(n * C + c);
                float s = io::load_float_value(src_d.data_type(), src, s_off);
                float m = s - v_mean;
                v_variance += m * m;
            }
            v_variance /= C;
        }

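        // Normalize, apply the optional affine (scale/shift) transform, and
        // fold in the quantization scales before storing in the dst data
        // type.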
        float sqrt_variance = sqrtf(v_variance + eps);
        for (dim_t c = 0; c < C; ++c) {
            const float sm = (scale ? scale[sc_d.off(c)] : 1.f) / sqrt_variance;
            const float sv = shift ? shift[sc_d.off(c)] : 0.f;
            const auto s_off = src_d.off_l(n * C + c);
            const auto d_off = dst_d.off_l(n * C + c);
            float s = io::load_float_value(src_d.data_type(), src, s_off);
            float d = sm * (s - v_mean) + sv;
            d *= src_scales[0] * dst_scales[0];
            io::store_float_value(dst_d.data_type(), d, dst, d_off);
        }

        if (calculate_stats && save_stats) {
            mean[s_off] = v_mean;
            variance[s_off] = v_variance;
        }
    });
    return status::success;
}

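// Reference backward layer normalization. With
//     x_hat = (src - mean) * inv_sqrt_variance,
// the gradients computed below are
//     diff_scale[c] = sum_n(diff_dst * x_hat),
//     diff_shift[c] = sum_n(diff_dst),
//     diff_src = inv_sqrt_variance * (diff_dst * scale
//             - mean_c(diff_dst * scale)
//             - x_hat * mean_c(diff_dst * scale * x_hat)),
// where the last two diff_src terms are dropped when global statistics are
// used, since then the statistics do not depend on src.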
status_t ref_layer_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper stat_d(pd()->stat_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper sc_d(pd()->weights_md());
    const memory_desc_wrapper diff_sc_d(pd()->diff_weights_md());

    const auto use_scale = pd()->use_scale();
    const auto use_shift = pd()->use_shift();

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const float *, DNNL_ARG_MEAN);
    auto variance = CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE);
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
    auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    auto diff_scale = use_scale
            ? CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SCALE, status)
            : nullptr;
    CHECK(status);
    auto diff_shift = use_shift
            ? CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SHIFT, status)
            : nullptr;
    CHECK(status);

    const dim_t N = pd()->across_axis();
    const dim_t C = pd()->norm_axis();

    /* fast return */
    if (this->pd()->has_zero_dim_memory()) {
        if (diff_scale) {
            for (dim_t c = 0; c < C; ++c) {
                diff_scale[diff_sc_d.off(c)] = 0;
            }
        }
        if (diff_shift) {
            for (dim_t c = 0; c < C; ++c) {
                diff_shift[diff_sc_d.off(c)] = 0;
            }
        }
        return status::success;
    }

    const float eps = pd()->desc()->layer_norm_epsilon;
    const bool calculate_diff_stats = !pd()->use_global_stats();

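    // Reduce over the N outer elements to get the per-channel weight
    // gradients diff_scale (diff_gamma) and diff_shift (diff_beta).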
    if (diff_scale || diff_shift) {
        parallel_nd(C, [&](dim_t c) {
            float diff_gamma = 0.f;
            float diff_beta = 0.f;

            for (dim_t n = 0; n < N; ++n) {
                const auto src_off = src_d.off_l(n * C + c);
                const auto diff_dst_off = diff_dst_d.off_l(n * C + c);
                const auto stat_off = stat_d.off_l(n);
                float inv_sqrt_variance = 1.f / sqrtf(variance[stat_off] + eps);
                float s = io::load_float_value(src_d.data_type(), src, src_off);
                float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, diff_dst_off);
                diff_gamma += (s - mean[stat_off]) * dd * inv_sqrt_variance;
                diff_beta += dd;
            }

            if (diff_scale) diff_scale[diff_sc_d.off(c)] = diff_gamma;
            if (diff_shift) diff_shift[diff_sc_d.off(c)] = diff_beta;
        });
    }

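    // Compute diff_src per slice. When the statistics were produced by this
    // primitive, first accumulate dd_gamma and dd_gamma_x, the terms that
    // propagate the gradient through mean and variance; with global
    // statistics those terms vanish.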
    parallel_nd(N, [&](dim_t n) {
        const size_t s_off = stat_d.off_l(n);
        float inv_sqrt_variance = 1.f / sqrtf(variance[s_off] + eps);
        float dd_gamma = 0.f;
        float dd_gamma_x = 0.f;
        if (calculate_diff_stats) {
            for (dim_t c = 0; c < C; ++c) {
                float gamma = scale ? scale[sc_d.off(c)] : 1.f;
                const auto src_off = src_d.off_l(n * C + c);
                const auto diff_dst_off = diff_dst_d.off_l(n * C + c);
                float s = io::load_float_value(src_d.data_type(), src, src_off);
                float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, diff_dst_off);
                dd_gamma += dd * gamma;
                dd_gamma_x += dd * gamma * (s - mean[s_off]);
            }
            dd_gamma_x *= inv_sqrt_variance;
        }

        for (dim_t c = 0; c < C; ++c) {
            float gamma = scale ? scale[sc_d.off(c)] : 1.f;
            const auto src_off = src_d.off_l(n * C + c);
            const auto diff_dst_off = diff_dst_d.off_l(n * C + c);
            const auto diff_src_off = diff_src_d.off_l(n * C + c);
            float dd = io::load_float_value(
                    diff_dst_d.data_type(), diff_dst, diff_dst_off);
            float d_src = dd * gamma;
            if (calculate_diff_stats) {
                float s = io::load_float_value(src_d.data_type(), src, src_off);
                d_src -= dd_gamma / C;
                d_src -= (s - mean[s_off]) * dd_gamma_x * inv_sqrt_variance / C;
            }
            d_src *= inv_sqrt_variance;
            io::store_float_value(
                    diff_src_d.data_type(), d_src, diff_src, diff_src_off);
        }
    });
    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s