/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_LAYER_NORMALIZATION_HPP
#define CPU_X64_JIT_UNI_LAYER_NORMALIZATION_HPP

#include <memory>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/reorder.hpp"
#include "common/stream.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_layer_normalization_pd.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

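// Runtime-dispatched kernel interfaces. Each static create() picks an
// ISA-specific JIT implementation; the empty base-class operator() and the
// default create_kernel() serve as no-op fallbacks.

// Forward kernel: for a block of block_size statistics points it computes
// (or reuses) mean and variance and writes the normalized, scaled and
// shifted result to dst, applying src/dst scales when provided.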
struct stat_and_data_kernel_t {
    static stat_and_data_kernel_t *create(const layer_normalization_pd_t *pd);
    virtual ~stat_and_data_kernel_t() = default;

    virtual void operator()(const void *src, void *dst, const float *scale,
            const float *shift, float *mean, float *var,
            const float *src_scales, const float *dst_scales,
            const size_t block_size) const {};

    virtual status_t create_kernel() { return status::success; }

protected:
    stat_and_data_kernel_t(const layer_normalization_pd_t *pd) : pd_(pd) {}

    const layer_normalization_pd_t *pd_;
};

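// Backward kernel for the scale/shift gradients: accumulates diff_gamma and
// diff_beta over a block of block_size statistics points, using mean,
// variance and the inv_sqrtvar buffer (1 / sqrt(variance + eps)) that is
// shared with diff_data_kernel_t.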
struct diff_ss_kernel_t {
    static diff_ss_kernel_t *create(const layer_normalization_pd_t *pd);
    virtual ~diff_ss_kernel_t() = default;

    virtual void operator()(const void *src, const void *diff_dst,
            float *diff_gamma, float *diff_beta, const float *mean,
            const float *var, float *const inv_sqrtvar,
            const size_t block_size) const {};

    virtual status_t create_kernel() { return status::success; }

protected:
    diff_ss_kernel_t(const layer_normalization_pd_t *pd) : pd_(pd) {}

    const layer_normalization_pd_t *pd_;
};

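// Backward kernel for the data gradient: computes diff_src from src,
// diff_dst, the scale/shift values (ss), mean and inv_sqrtvar for a block
// of block_size statistics points.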
struct diff_data_kernel_t {
    static diff_data_kernel_t *create(const layer_normalization_pd_t *pd);
    virtual ~diff_data_kernel_t() = default;

    virtual void operator()(const void *src, const void *diff_dst,
            void *diff_src, const float *ss, const float *mean,
            float *const inv_sqrtvar, const size_t block_size) const {};

    virtual status_t create_kernel() { return status::success; }

protected:
    diff_data_kernel_t(const layer_normalization_pd_t *pd) : pd_(pd) {}

    const layer_normalization_pd_t *pd_;
};

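// Forward layer normalization primitive. Per statistics point the result is
//     dst = scale * (src - mean) / sqrt(variance + eps) + shift
// where mean and variance are either computed here or supplied by the user
// (stats_are_src()), and eps comes from the operation descriptor.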
struct jit_uni_layer_normalization_fwd_t : public primitive_t {
    struct pd_t : public cpu_layer_normalization_fwd_pd_t {
        using cpu_layer_normalization_fwd_pd_t::
                cpu_layer_normalization_fwd_pd_t;

        DECLARE_COMMON_PD_T("jit:uni", jit_uni_layer_normalization_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            using skip_mask_t = primitive_attr_t::skip_mask_t;
            const memory_desc_wrapper src_d(src_md());

            const bool ok = is_fwd() && !has_zero_dim_memory()
                    && utils::one_of(
                            src_md()->data_type, f32, bf16, f16, s8, u8)
                    && utils::one_of(
                            dst_md()->data_type, f32, bf16, f16, s8, u8)
                    && IMPLICATION(utils::one_of(bf16, src_md()->data_type,
                                           dst_md()->data_type),
                            mayiuse(avx512_core))
                    && IMPLICATION(utils::one_of(f16, src_md()->data_type,
                                           dst_md()->data_type),
                            mayiuse(avx512_core_fp16))
                    && stat_md()->data_type == f32
                    && check_scale_shift_data_type()
                    && attr()->has_default_values(skip_mask_t::scales_runtime)
                    && attr_scales_ok() && set_default_formats_common()
                    && src_d.is_blocking_desc()
                    // plain format, last logical dim is last physical
                    && src_d.blocking_desc().strides[ndims() - 1] == 1;
            if (!ok) return status::unimplemented;

            CHECK(fill_compatible_stats_md(*src_md(), reordered_stat_md_));

            if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
                CHECK(reorder_primitive_desc_create(reorder_pd_, engine,
                        stats_are_src() ? stat_md() : &reordered_stat_md_,
                        stats_are_src() ? &reordered_stat_md_ : stat_md()));
            }

            init_scratchpad();
            return status::success;
        }

        bool use_tmp_stats() const { return reorder_pd_ || stats_are_tmp(); }

        std::shared_ptr<primitive_desc_t> reorder_pd_;
        memory_desc_t reordered_stat_md_;

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            if (use_tmp_stats()) {
                scratchpad.template book<float>(
                        key_lnorm_tmp_mean, across_axis());
                scratchpad.template book<float>(
                        key_lnorm_tmp_var, across_axis());
            }
            if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
                scratchpad.book(key_nested, reorder_pd_->scratchpad_registry());
            }
        }
    };

    status_t init(engine_t *engine) override {
        if (pd()->reorder_pd_)
            CHECK(pd()->reorder_pd_->create_primitive(reorder_, engine));
        CHECK(safe_ptr_assign(
                stat_and_data_kernel_, stat_and_data_kernel_t::create(pd())));
        if (stat_and_data_kernel_)
            CHECK(stat_and_data_kernel_->create_kernel());
        return status::success;
    }

    jit_uni_layer_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {}
    virtual ~jit_uni_layer_normalization_fwd_t() = default;

    void reorder_stat(const exec_ctx_t &ctx, engine_t *engine,
            const memory_arg_t &in, const memory_arg_t &out) const {
        using namespace memory_tracking::names;
        exec_args_t r_args;
        r_args[DNNL_ARG_SRC] = in;
        r_args[DNNL_ARG_DST] = out;
        exec_ctx_t r_ctx(ctx, std::move(r_args));

        nested_scratchpad_t ns(ctx, key_nested, reorder_);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        reorder_->execute(r_ctx);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        /* LN supports arbitrary layouts for the input/output statistics.
         * For best performance we compute LN with the statistics in the same
         * format as the data tensor (i.e. data in abcd, stats in abc); the
         * user's input/output statistics are reordered if necessary. */
        using namespace memory_tracking::names;
        engine_t *engine = ctx.stream()->engine();
        auto scratchpad = ctx.get_scratchpad_grantor();
        auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean);
        auto variance_mem = scratchpad.get_memory_storage(key_lnorm_tmp_var);
        memory_t mean(engine, &(pd()->reordered_stat_md_), std::move(mean_mem));
        memory_t variance(
                engine, &(pd()->reordered_stat_md_), std::move(variance_mem));

        // reorder input stats
        if (pd()->stats_are_src() && reorder_) {
            reorder_stat(
                    ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false});
            reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE),
                    {&variance, false});
        }
        status_t status = execute_forward(ctx);
        if (status != status::success) return status;
        // reorder output stats
        if (!pd()->stats_are_src() && reorder_) {
            reorder_stat(
                    ctx, engine, {&mean, true}, ctx.args().at(DNNL_ARG_MEAN));
            reorder_stat(ctx, engine, {&variance, true},
                    ctx.args().at(DNNL_ARG_VARIANCE));
        }

        return status::success;
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<stat_and_data_kernel_t> stat_and_data_kernel_;
    std::shared_ptr<primitive_t> reorder_;
};

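// Backward layer normalization primitive. Given src, diff_dst and the
// forward statistics it computes diff_src and, when scale/shift are used,
// diff_scale (diff_gamma) and diff_shift (diff_beta).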
struct jit_uni_layer_normalization_bwd_t : public primitive_t {
    struct pd_t : public cpu_layer_normalization_bwd_pd_t {
        using cpu_layer_normalization_bwd_pd_t::
                cpu_layer_normalization_bwd_pd_t;

        DECLARE_COMMON_PD_T("jit:uni", jit_uni_layer_normalization_bwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            const memory_desc_wrapper src_d(src_md());

            const bool ok = is_bwd() && !has_zero_dim_memory()
                    && mayiuse(avx2) // sse41 is not supported yet
                    && utils::one_of(src_md()->data_type, f32, bf16, f16)
                    && utils::one_of(diff_dst_md()->data_type, f32, bf16, f16)
                    && utils::one_of(diff_src_md()->data_type, f32, bf16, f16)
                    && IMPLICATION(utils::one_of(bf16, src_md()->data_type,
                                           diff_dst_md()->data_type,
                                           diff_src_md()->data_type),
                            mayiuse(avx512_core))
                    && IMPLICATION(utils::one_of(f16, src_md()->data_type,
                                           diff_dst_md()->data_type,
                                           diff_src_md()->data_type),
                            mayiuse(avx512_core_fp16))
                    && stat_md()->data_type == f32
                    && check_scale_shift_data_type()
                    && attr()->has_default_values()
                    && set_default_formats_common()
                    && src_d.is_blocking_desc()
                    // plain format, last logical dim is last physical
                    && src_d.blocking_desc().strides[ndims() - 1] == 1;
            if (!ok) return status::unimplemented;

            CHECK(fill_compatible_stats_md(*src_md(), reordered_stat_md_));

            if (reordered_stat_md_ != *stat_md()) {
                CHECK(reorder_primitive_desc_create(
                        reorder_pd_, engine, stat_md(), &reordered_stat_md_));
            }

            nthr_ = dnnl_get_max_threads();
            init_scratchpad();
            return status::success;
        }

        bool use_tmp_stats() const { return reorder_pd_.get(); }

        std::shared_ptr<primitive_desc_t> reorder_pd_;
        memory_desc_t reordered_stat_md_;
        // Thread count fixed at pd creation so that execute() does not exceed
        // the limit used to size the scratchpad.
        int nthr_;

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            if (use_tmp_stats()) {
                scratchpad.template book<float>(
                        key_lnorm_tmp_mean, across_axis());
                scratchpad.template book<float>(
                        key_lnorm_tmp_var, across_axis());
            }
            scratchpad.template book<float>(
                    key_lnorm_reduction, 2 * norm_axis() * nthr_);
            scratchpad.template book<float>(
                    key_lnorm_tmp_diff_ss, 2 * norm_axis());
            if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
                scratchpad.book(key_nested, reorder_pd_->scratchpad_registry());
            }
            scratchpad.template book<float>(
                    key_lnorm_inv_sqrtvar, across_axis());
        }
    };

    status_t init(engine_t *engine) override {
        if (pd()->reorder_pd_)
            CHECK(pd()->reorder_pd_->create_primitive(reorder_, engine));
        CHECK(safe_ptr_assign(diff_ss_kernel_, diff_ss_kernel_t::create(pd())));
        CHECK(safe_ptr_assign(
                diff_data_kernel_, diff_data_kernel_t::create(pd())));
        if (diff_ss_kernel_) CHECK(diff_ss_kernel_->create_kernel());
        if (diff_data_kernel_) CHECK(diff_data_kernel_->create_kernel());
        return status::success;
    }

    jit_uni_layer_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {}
    virtual ~jit_uni_layer_normalization_bwd_t() = default;

    void reorder_stat(const exec_ctx_t &ctx, engine_t *engine,
            const memory_arg_t &in, const memory_arg_t &out) const {
        using namespace memory_tracking::names;
        exec_args_t r_args;
        r_args[DNNL_ARG_SRC] = in;
        r_args[DNNL_ARG_DST] = out;
        exec_ctx_t r_ctx(ctx, std::move(r_args));

        nested_scratchpad_t ns(ctx, key_nested, reorder_);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        reorder_->execute(r_ctx);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        /* LN supports arbitrary layouts for the input/output statistics.
         * For best performance we compute LN with the statistics in the same
         * format as the data tensor (i.e. data in abcd, stats in abc); the
         * user's input statistics are reordered if necessary. */

        if (reorder_) {
            engine_t *engine = ctx.stream()->engine();
            auto scratchpad = ctx.get_scratchpad_grantor();
            auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean);
            auto variance_mem
                    = scratchpad.get_memory_storage(key_lnorm_tmp_var);
            memory_t mean(
                    engine, &(pd()->reordered_stat_md_), std::move(mean_mem));
            memory_t variance(engine, &(pd()->reordered_stat_md_),
                    std::move(variance_mem));
            reorder_stat(
                    ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false});
            reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE),
                    {&variance, false});
        }

        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<diff_ss_kernel_t> diff_ss_kernel_;
    std::unique_ptr<diff_data_kernel_t> diff_data_kernel_;
    std::shared_ptr<primitive_t> reorder_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s