/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/reorder.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/ref_io_helper.hpp"

#include "cpu/simple_layer_normalization.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;
using namespace data_type;

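// This implementation handles layer normalization over plain (dense) layouts:
// the tensor is viewed as N outer rows of C contiguous elements (the last
// logical dimension), each row normalized independently. The forward pass
// optionally computes the mean/variance statistics and applies the optional
// scale (gamma) and shift (beta); the backward pass produces diff_src and,
// when requested, diff_scale and diff_shift.
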
status_t simple_layer_normalization_fwd_t::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using skip_mask_t = primitive_attr_t::skip_mask_t;
    const memory_desc_wrapper src_d(src_md());

    const bool ok = is_fwd() && !has_zero_dim_memory()
            && utils::one_of(src_md()->data_type, f32, bf16, f16, s8, u8)
            && utils::one_of(dst_md()->data_type, f32, bf16, f16, s8, u8)
            && platform::has_data_type_support(src_md()->data_type)
            && platform::has_data_type_support(dst_md()->data_type)
            && stat_md()->data_type == f32 && check_scale_shift_data_type()
            && attr()->has_default_values(skip_mask_t::scales_runtime)
            && attr_scales_ok() && set_default_formats_common()
            && src_d.is_blocking_desc()
            // plain format, last logical dim is last physical
            && src_d.blocking_desc().strides[ndims() - 1] == 1;
    if (!ok) return status::unimplemented;

    CHECK(fill_compatible_stats_md(*src_md(), reordered_stat_md_));

    if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
        CHECK(reorder_primitive_desc_create(reorder_pd_, engine,
                stats_are_src() ? stat_md() : &reordered_stat_md_,
                stats_are_src() ? &reordered_stat_md_ : stat_md()));
    }

    init_scratchpad();
    return status::success;
}

status_t simple_layer_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    const bool use_scale = pd()->use_scale();
    const bool use_shift = pd()->use_shift();

    auto scratchpad = ctx.get_scratchpad_grantor();
    const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

    auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT);

    float *mean, *variance;
    if (pd()->use_tmp_stats()) {
        mean = scratchpad.template get<float>(key_lnorm_tmp_mean);
        variance = scratchpad.template get<float>(key_lnorm_tmp_var);
    } else {
        mean = pd()->stats_are_src()
                ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN))
                : CTX_OUT_MEM(float *, DNNL_ARG_MEAN);
        variance = pd()->stats_are_src()
                ? const_cast<float *>(
                        CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE))
                : CTX_OUT_MEM(float *, DNNL_ARG_VARIANCE);
    }

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());

    const dim_t N = pd()->across_axis();
    const dim_t C = pd()->norm_axis();
    const dim_t C_padded = src_d.padded_dims()[pd()->ndims() - 1];

    const auto calculate_stats = !pd()->stats_are_src();
    const auto src_dt = pd()->src_md()->data_type;
    const auto dst_dt = pd()->dst_md()->data_type;
    const auto eps = pd()->desc()->layer_norm_epsilon;
    const auto save_stats = pd()->is_training();

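    // Rows are split across threads: each thread processes a contiguous block
    // of rows [N_start, N_end). For every row, mean and variance over the C
    // normalized elements are either computed here or taken from the
    // user-provided statistics (stats_are_src).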
    parallel(0, [&](const int ithr, const int nthr) {
        dim_t N_start = 0, N_end = 0;
        balance211(N, nthr, ithr, N_start, N_end);
        const char *const __restrict src_ptr
                = reinterpret_cast<const char *>(src)
                + N_start * C_padded * src_d.data_type_size();
        char *const __restrict dst_ptr = reinterpret_cast<char *>(dst)
                + N_start * C_padded * dst_d.data_type_size();
        float *const __restrict mean_ptr = &mean[N_start];
        float *const __restrict var_ptr = &variance[N_start];
        const size_t block_size = N_end - N_start;
        // Note: manual unrolling for scale and shift due to clang issue.
        // see: CLANG_WA_01_SAFE_TO_USE_OMP_SIMD
        for (size_t offset = 0; offset < block_size; offset++) {
            float v_mean = 0, v_variance = 0;
            if (calculate_stats) {
                PRAGMA_OMP_SIMD(reduction(+ : v_mean))
                for (dim_t c = 0; c < C; ++c) {
                    float s = io::load_float_value(
                            src_dt, src_ptr, c + C * offset);
                    v_mean += s;
                }
                v_mean /= C;

                PRAGMA_OMP_SIMD(reduction(+ : v_variance))
                for (dim_t c = 0; c < C; ++c) {
                    float s = io::load_float_value(
                            src_dt, src_ptr, c + C * offset);
                    float src_sub_mean = s - v_mean;
                    v_variance += src_sub_mean * src_sub_mean;
                }
                v_variance /= C;
            } else {
                v_mean = mean_ptr[offset];
                v_variance = var_ptr[offset];
            }

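            // Normalize the row:
            //     dst[c] = gamma[c] * (src[c] - mean) / sqrt(variance + eps)
            //             + beta[c],
            // with gamma/beta optional. The result is additionally multiplied
            // by src_scales[0] * dst_scales[0] to apply the runtime
            // src/dst quantization scales.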
            const float inv_sqrtvar = 1.f / sqrtf(v_variance + eps);
            if (use_scale && use_shift) {
                PRAGMA_OMP_SIMD()
                for (dim_t c = 0; c < C; ++c) {
                    const float sm = scale[c] * inv_sqrtvar;
                    const float sv = shift[c];
                    const size_t off = c + C * offset;
                    float s = io::load_float_value(src_dt, src_ptr, off);
                    float d = sm * (s - v_mean) + sv;
                    d *= src_scales[0] * dst_scales[0];
                    io::store_float_value(dst_dt, d, dst_ptr, off);
                }
            } else if (use_scale) {
                PRAGMA_OMP_SIMD()
                for (dim_t c = 0; c < C; ++c) {
                    const float sm = scale[c] * inv_sqrtvar;
                    const size_t off = c + C * offset;
                    float s = io::load_float_value(src_dt, src_ptr, off);
                    float d = sm * (s - v_mean);
                    d *= src_scales[0] * dst_scales[0];
                    io::store_float_value(dst_dt, d, dst_ptr, off);
                }
            } else if (use_shift) {
                PRAGMA_OMP_SIMD()
                for (dim_t c = 0; c < C; ++c) {
                    const float sm = 1.f * inv_sqrtvar;
                    const float sv = shift[c];
                    const size_t off = c + C * offset;
                    float s = io::load_float_value(src_dt, src_ptr, off);
                    float d = sm * (s - v_mean) + sv;
                    d *= src_scales[0] * dst_scales[0];
                    io::store_float_value(dst_dt, d, dst_ptr, off);
                }
            } else {
                PRAGMA_OMP_SIMD()
                for (dim_t c = 0; c < C; ++c) {
                    const float sm = 1.f * inv_sqrtvar;
                    const size_t off = c + C * offset;
                    float s = io::load_float_value(src_dt, src_ptr, off);
                    float d = sm * (s - v_mean);
                    d *= src_scales[0] * dst_scales[0];
                    io::store_float_value(dst_dt, d, dst_ptr, off);
                }
            }
            if (calculate_stats && save_stats) {
                mean_ptr[offset] = v_mean;
                var_ptr[offset] = v_variance;
            }
        }
    });
    return status::success;
}

status_t simple_layer_normalization_bwd_t::pd_t::init(engine_t *engine) {
    using namespace data_type;
    const memory_desc_wrapper src_d(src_md());

    const bool ok = is_bwd() && !has_zero_dim_memory()
            && utils::one_of(src_md()->data_type, f32, bf16, f16)
            && utils::one_of(diff_dst_md()->data_type, f32, bf16, f16)
            && utils::one_of(diff_src_md()->data_type, f32, bf16, f16)
            && platform::has_data_type_support(src_md()->data_type)
            && platform::has_data_type_support(diff_dst_md()->data_type)
            && platform::has_data_type_support(diff_src_md()->data_type)
            && stat_md()->data_type == f32 && check_scale_shift_data_type()
            && attr()->has_default_values() && set_default_formats_common()
            && src_d.is_blocking_desc()
            // plain format, last logical dim is last physical
            && src_d.blocking_desc().strides[ndims() - 1] == 1;
    if (!ok) return status::unimplemented;

    CHECK(fill_compatible_stats_md(*src_md(), reordered_stat_md_));

    if (reordered_stat_md_ != *stat_md()) {
        CHECK(reorder_primitive_desc_create(
                reorder_pd_, engine, stat_md(), &reordered_stat_md_));
    }

    nthr_ = dnnl_get_max_threads();
    init_scratchpad();
    return status::success;
}

status_t simple_layer_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

    const bool use_scale = pd()->use_scale();

    auto scratchpad = ctx.get_scratchpad_grantor();
    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    auto scale = CTX_IN_MEM(float *, DNNL_ARG_SCALE);
    auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);

    auto diff_scale = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto diff_shift = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    const float *mean, *variance;
    if (pd()->use_tmp_stats()) {
        mean = scratchpad.template get<float>(key_lnorm_tmp_mean);
        variance = scratchpad.template get<float>(key_lnorm_tmp_var);
    } else {
        mean = CTX_IN_MEM(const float *, DNNL_ARG_MEAN);
        variance = CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE);
    }

    float *const inv_sqrtvar
            = scratchpad.template get<float>(key_lnorm_inv_sqrtvar);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const dim_t N = pd()->across_axis();
    const dim_t C = pd()->norm_axis();
    const dim_t C_padded = src_d.padded_dims()[pd()->ndims() - 1];

    float *reduce = scratchpad.template get<float>(key_lnorm_reduction);
    if (diff_scale == nullptr)
        diff_scale = scratchpad.template get<float>(key_lnorm_tmp_diff_ss);
    if (diff_shift == nullptr) {
        diff_shift = scratchpad.template get<float>(key_lnorm_tmp_diff_ss);
    }

    const int max_nthr = pd()->nthr_;

    const auto src_dt = pd()->src_md()->data_type;
    const auto diff_dst_dt = pd()->diff_dst_md()->data_type;
    const auto diff_src_dt = pd()->diff_src_md()->data_type;
    const auto eps = pd()->desc()->layer_norm_epsilon;
    const auto calculate_diff_stats = !pd()->stats_are_src();

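    // First pass: each thread accumulates partial per-channel reductions into
    // the `reduce` scratchpad, laid out as [nthr][C] diff_gamma partials
    // followed by [nthr][C] diff_beta partials:
    //     diff_gamma[c] = sum_n (src[n, c] - mean[n]) * inv_sqrtvar[n]
    //             * diff_dst[n, c]
    //     diff_beta[c]  = sum_n diff_dst[n, c]
    // The per-row 1 / sqrt(variance + eps) values are cached in inv_sqrtvar
    // for reuse in the second pass.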
    parallel(max_nthr, [&](int ithr, int nthr) {
        dim_t N_start = 0, N_end = 0;
        balance211(N, nthr, ithr, N_start, N_end);
        const size_t block_size = N_end - N_start;
        const char *const __restrict src_ptr
                = reinterpret_cast<const char *>(src)
                + N_start * C_padded * src_d.data_type_size();
        const char *const __restrict diff_dst_ptr
                = reinterpret_cast<const char *>(diff_dst)
                + N_start * C_padded * diff_dst_d.data_type_size();
        const float *mean_ptr = &mean[N_start];
        const float *var_ptr = &variance[N_start];
        float *const inv_sqrtvar_ptr = &inv_sqrtvar[N_start];

        float *my_diff_gamma = reduce + C * ithr;
        float *my_diff_beta = reduce + C * nthr + C * ithr;

        PRAGMA_OMP_SIMD()
        for (dim_t c = 0; c < C; c++) {
            my_diff_gamma[c] = 0.;
            my_diff_beta[c] = 0.;
        }

        for (size_t offset = 0; offset < block_size; offset++) {
            inv_sqrtvar_ptr[offset] = 1. / sqrtf(var_ptr[offset] + eps);

            PRAGMA_OMP_SIMD()
            for (dim_t c = 0; c < C; c++) {
                const size_t off = c + C * offset;
                float s = io::load_float_value(src_dt, src_ptr, off);
                float dd = io::load_float_value(diff_dst_dt, diff_dst_ptr, off);
                my_diff_gamma[c] += (s - mean_ptr[offset]) * dd
                        * inv_sqrtvar_ptr[offset];
                my_diff_beta[c] += dd;
            }
        }
    });

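    // Reduce the per-thread partial sums into the final diff_scale/diff_shift
    // (or their scratchpad substitutes when the user did not request them).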
    parallel_nd(C, [&](dim_t c) {
        float diff_gamma = 0, diff_beta = 0;
        for (dim_t n = 0; n < max_nthr; n++) {
            diff_gamma += reduce[C * n + c];
            diff_beta += reduce[C * max_nthr + C * n + c];
        }
        diff_scale[c] = diff_gamma;
        diff_shift[c] = diff_beta;
    });

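    // Second pass: compute diff_src. With x_hat = (src - mean) * inv_sqrtvar,
    // when diff stats are required:
    //     diff_src = inv_sqrtvar * (diff_dst * gamma - (1 / C) * dd_gamma
    //             - x_hat * (1 / C) * dd_gamma_x)
    // where dd_gamma and dd_gamma_x are the per-row sums over channels of
    // diff_dst * gamma and diff_dst * gamma * x_hat; otherwise
    //     diff_src = diff_dst * gamma * inv_sqrtvar.
    // When the scale is not used, gamma is treated as 1.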
    parallel(max_nthr, [&](int ithr, int nthr) {
        dim_t N_start = 0, N_end = 0;
        balance211(N, nthr, ithr, N_start, N_end);
        const size_t block_size = N_end - N_start;
        const char *const __restrict src_ptr
                = reinterpret_cast<const char *>(src)
                + N_start * C_padded * src_d.data_type_size();
        const char *const __restrict diff_dst_ptr
                = reinterpret_cast<const char *>(diff_dst)
                + N_start * C_padded * diff_dst_d.data_type_size();
        char *const __restrict diff_src_ptr = reinterpret_cast<char *>(diff_src)
                + N_start * C_padded * diff_src_d.data_type_size();
        const float *mean_ptr = &mean[N_start];
        float *const inv_sqrtvar_ptr = &inv_sqrtvar[N_start];

        // Note: manual unrolling for scale and shift due to clang issue.
        // see: CLANG_WA_01_SAFE_TO_USE_OMP_SIMD
        float dd_gamma, dd_gamma_x;
        for (size_t offset = 0; offset < block_size; offset++) {
            // reduce gamma
            dd_gamma = dd_gamma_x = 0;
            if (calculate_diff_stats) {
                if (use_scale) {
                    PRAGMA_OMP_SIMD(reduction(+ : dd_gamma, dd_gamma_x))
                    for (dim_t c = 0; c < C; c++) {
                        const size_t off = c + C * offset;
                        float s = io::load_float_value(src_dt, src_ptr, off);
                        float dd = io::load_float_value(
                                diff_dst_dt, diff_dst_ptr, off);
                        dd_gamma += dd * scale[c];
                        dd_gamma_x += dd * scale[c] * (s - mean_ptr[offset]);
                    }
                } else {
                    PRAGMA_OMP_SIMD(reduction(+ : dd_gamma, dd_gamma_x))
                    for (dim_t c = 0; c < C; c++) {
                        const size_t off = c + C * offset;
                        float s = io::load_float_value(src_dt, src_ptr, off);
                        float dd = io::load_float_value(
                                diff_dst_dt, diff_dst_ptr, off);
                        dd_gamma += dd;
                        dd_gamma_x += dd * (s - mean_ptr[offset]);
                    }
                }
                dd_gamma_x *= inv_sqrtvar_ptr[offset];
            }

            // calculate diff_src
            if (use_scale) {
                PRAGMA_OMP_SIMD()
                for (dim_t c = 0; c < C; c++) {
                    const size_t off = c + C * offset;
                    float dd = io::load_float_value(
                            diff_dst_dt, diff_dst_ptr, off);
                    float ds = dd * scale[c];
                    if (calculate_diff_stats) {
                        float s = io::load_float_value(src_dt, src_ptr, off);
                        ds -= dd_gamma / C;
                        ds -= (s - mean_ptr[offset]) * dd_gamma_x
                                * inv_sqrtvar_ptr[offset] / C;
                    }
                    ds *= inv_sqrtvar_ptr[offset];
                    io::store_float_value(diff_src_dt, ds, diff_src_ptr, off);
                }
            } else {
                PRAGMA_OMP_SIMD()
                for (dim_t c = 0; c < C; c++) {
                    const size_t off = c + C * offset;
                    float dd = io::load_float_value(
                            diff_dst_dt, diff_dst_ptr, off);
                    float ds = dd;
                    if (calculate_diff_stats) {
                        float s = io::load_float_value(src_dt, src_ptr, off);
                        ds -= dd_gamma / C;
                        ds -= (s - mean_ptr[offset]) * dd_gamma_x
                                * inv_sqrtvar_ptr[offset] / C;
                    }
                    ds *= inv_sqrtvar_ptr[offset];
                    io::store_float_value(diff_src_dt, ds, diff_src_ptr, off);
                }
            }
        }
    });
    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl