1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_LAYER_NORMALIZATION_HPP |
18 | #define CPU_X64_JIT_UNI_LAYER_NORMALIZATION_HPP |
19 | |
20 | #include <memory> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/memory_tracking.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/reorder.hpp" |
27 | #include "common/stream.hpp" |
28 | #include "common/utils.hpp" |
29 | |
30 | #include "cpu/cpu_layer_normalization_pd.hpp" |
31 | |
32 | #include "cpu/x64/cpu_isa_traits.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
39 | struct stat_and_data_kernel_t { |
40 | static stat_and_data_kernel_t *create(const layer_normalization_pd_t *pd); |
41 | virtual ~stat_and_data_kernel_t() = default; |
42 | |
43 | virtual void operator()(const void *src, void *dst, const float *scale, |
44 | const float *shift, float *mean, float *var, |
45 | const float *src_scales, const float *dst_scales, |
46 | const size_t block_size) const {}; |
47 | |
48 | virtual status_t create_kernel() { return status::success; } |
49 | |
50 | protected: |
51 | stat_and_data_kernel_t(const layer_normalization_pd_t *pd) : pd_(pd) {} |
52 | |
53 | const layer_normalization_pd_t *pd_; |
54 | }; |
55 | |
56 | struct diff_ss_kernel_t { |
57 | static diff_ss_kernel_t *create(const layer_normalization_pd_t *pd); |
58 | virtual ~diff_ss_kernel_t() = default; |
59 | |
60 | virtual void operator()(const void *src, const void *diff_dst, |
61 | float *diff_gamma, float *diff_beta, const float *mean, |
62 | const float *var, float *const inv_sqrtvar, |
63 | const size_t block_size) const {}; |
64 | |
65 | virtual status_t create_kernel() { return status::success; } |
66 | |
67 | protected: |
68 | diff_ss_kernel_t(const layer_normalization_pd_t *pd) : pd_(pd) {} |
69 | |
70 | const layer_normalization_pd_t *pd_; |
71 | }; |
72 | |
73 | struct diff_data_kernel_t { |
74 | static diff_data_kernel_t *create(const layer_normalization_pd_t *pd); |
75 | virtual ~diff_data_kernel_t() = default; |
76 | |
77 | virtual void operator()(const void *src, const void *diff_dst, |
78 | void *diff_src, const float *ss, const float *mean, |
79 | float *const inv_sqrtvar, const size_t block_size) const {}; |
80 | |
81 | virtual status_t create_kernel() { return status::success; } |
82 | |
83 | protected: |
84 | diff_data_kernel_t(const layer_normalization_pd_t *pd) : pd_(pd) {} |
85 | |
86 | const layer_normalization_pd_t *pd_; |
87 | }; |
88 | |
89 | struct jit_uni_layer_normalization_fwd_t : public primitive_t { |
90 | struct pd_t : public cpu_layer_normalization_fwd_pd_t { |
91 | using cpu_layer_normalization_fwd_pd_t:: |
92 | cpu_layer_normalization_fwd_pd_t; |
93 | |
94 | DECLARE_COMMON_PD_T("jit:uni" , jit_uni_layer_normalization_fwd_t); |
95 | |
96 | status_t init(engine_t *engine) { |
97 | using namespace data_type; |
98 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
99 | const memory_desc_wrapper src_d(src_md()); |
100 | |
101 | const bool ok = is_fwd() && !has_zero_dim_memory() |
102 | && utils::one_of( |
103 | src_md()->data_type, f32, bf16, f16, s8, u8) |
104 | && utils::one_of( |
105 | dst_md()->data_type, f32, bf16, f16, s8, u8) |
106 | && IMPLICATION(utils::one_of(bf16, src_md()->data_type, |
107 | dst_md()->data_type), |
108 | mayiuse(avx512_core)) |
109 | && IMPLICATION(utils::one_of(f16, src_md()->data_type, |
110 | dst_md()->data_type), |
111 | mayiuse(avx512_core_fp16)) |
112 | && stat_md()->data_type == f32 |
113 | && check_scale_shift_data_type() |
114 | && attr()->has_default_values(skip_mask_t::scales_runtime) |
115 | && attr_scales_ok() && set_default_formats_common() |
116 | && src_d.is_blocking_desc() |
117 | // plain format, last logical dim is last physical |
118 | && src_d.blocking_desc().strides[ndims() - 1] == 1; |
119 | if (!ok) return status::unimplemented; |
120 | |
121 | CHECK(fill_compatible_stats_md(*src_md(), reordered_stat_md_)); |
122 | |
123 | if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) { |
124 | CHECK(reorder_primitive_desc_create(reorder_pd_, engine, |
125 | stats_are_src() ? stat_md() : &reordered_stat_md_, |
126 | stats_are_src() ? &reordered_stat_md_ : stat_md())); |
127 | } |
128 | |
129 | init_scratchpad(); |
130 | return status::success; |
131 | } |
132 | |
133 | bool use_tmp_stats() const { return reorder_pd_ || stats_are_tmp(); } |
134 | |
135 | std::shared_ptr<primitive_desc_t> reorder_pd_; |
136 | memory_desc_t reordered_stat_md_; |
137 | |
138 | private: |
139 | void init_scratchpad() { |
140 | using namespace memory_tracking::names; |
141 | auto scratchpad = scratchpad_registry().registrar(); |
142 | if (use_tmp_stats()) { |
143 | scratchpad.template book<float>( |
144 | key_lnorm_tmp_mean, across_axis()); |
145 | scratchpad.template book<float>( |
146 | key_lnorm_tmp_var, across_axis()); |
147 | } |
148 | if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) { |
149 | scratchpad.book(key_nested, reorder_pd_->scratchpad_registry()); |
150 | } |
151 | } |
152 | }; |
153 | |
154 | status_t init(engine_t *engine) override { |
155 | if (pd()->reorder_pd_) |
156 | pd()->reorder_pd_->create_primitive(reorder_, engine); |
157 | CHECK(safe_ptr_assign( |
158 | stat_and_data_kernel_, stat_and_data_kernel_t::create(pd()))); |
159 | if (stat_and_data_kernel_) |
160 | CHECK(stat_and_data_kernel_->create_kernel()); |
161 | return status::success; |
162 | } |
163 | |
164 | jit_uni_layer_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
165 | virtual ~jit_uni_layer_normalization_fwd_t() = default; |
166 | |
167 | void reorder_stat(const exec_ctx_t &ctx, engine_t *engine, |
168 | const memory_arg_t &in, const memory_arg_t &out) const { |
169 | using namespace memory_tracking::names; |
170 | exec_args_t r_args; |
171 | r_args[DNNL_ARG_SRC] = in; |
172 | r_args[DNNL_ARG_DST] = out; |
173 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
174 | |
175 | nested_scratchpad_t ns(ctx, key_nested, reorder_); |
176 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
177 | reorder_->execute(r_ctx); |
178 | } |
179 | |
180 | status_t execute(const exec_ctx_t &ctx) const override { |
181 | /* LN supports arbitrary layout for input/output statistics. |
182 | * For best performance we compute LN with statistics in the same format |
183 | * as data tensor (i.e. data in abcd, stats in abc) and user's |
184 | * input/output statistics are reordered if necessary */ |
185 | using namespace memory_tracking::names; |
186 | engine_t *engine = ctx.stream()->engine(); |
187 | auto scratchpad = ctx.get_scratchpad_grantor(); |
188 | auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean); |
189 | auto variance_mem = scratchpad.get_memory_storage(key_lnorm_tmp_var); |
190 | memory_t mean(engine, &(pd()->reordered_stat_md_), std::move(mean_mem)); |
191 | memory_t variance( |
192 | engine, &(pd()->reordered_stat_md_), std::move(variance_mem)); |
193 | |
194 | // reorder input stats |
195 | if (pd()->stats_are_src() && reorder_) { |
196 | reorder_stat( |
197 | ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false}); |
198 | reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE), |
199 | {&variance, false}); |
200 | } |
201 | status_t status = execute_forward(ctx); |
202 | if (status != status::success) return status; |
203 | // reorder output stats |
204 | if (!pd()->stats_are_src() && reorder_) { |
205 | reorder_stat( |
206 | ctx, engine, {&mean, true}, ctx.args().at(DNNL_ARG_MEAN)); |
207 | reorder_stat(ctx, engine, {&variance, true}, |
208 | ctx.args().at(DNNL_ARG_VARIANCE)); |
209 | } |
210 | |
211 | return status::success; |
212 | } |
213 | |
214 | private: |
215 | status_t execute_forward(const exec_ctx_t &ctx) const; |
216 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
217 | |
218 | std::unique_ptr<stat_and_data_kernel_t> stat_and_data_kernel_; |
219 | std::shared_ptr<primitive_t> reorder_; |
220 | }; |
221 | |
222 | struct jit_uni_layer_normalization_bwd_t : public primitive_t { |
223 | struct pd_t : public cpu_layer_normalization_bwd_pd_t { |
224 | using cpu_layer_normalization_bwd_pd_t:: |
225 | cpu_layer_normalization_bwd_pd_t; |
226 | |
227 | DECLARE_COMMON_PD_T("jit:uni" , jit_uni_layer_normalization_bwd_t); |
228 | |
229 | status_t init(engine_t *engine) { |
230 | using namespace data_type; |
231 | const memory_desc_wrapper src_d(src_md()); |
232 | |
233 | const bool ok = is_bwd() && !has_zero_dim_memory() |
234 | && mayiuse(avx2) // sse41 is not supported yet |
235 | && utils::one_of(src_md()->data_type, f32, bf16, f16) |
236 | && utils::one_of(diff_dst_md()->data_type, f32, bf16, f16) |
237 | && utils::one_of(diff_src_md()->data_type, f32, bf16, f16) |
238 | && IMPLICATION(utils::one_of(bf16, src_md()->data_type, |
239 | diff_dst_md()->data_type, |
240 | diff_src_md()->data_type), |
241 | mayiuse(avx512_core)) |
242 | && IMPLICATION(utils::one_of(f16, src_md()->data_type, |
243 | diff_dst_md()->data_type, |
244 | diff_src_md()->data_type), |
245 | mayiuse(avx512_core_fp16)) |
246 | && stat_md()->data_type == f32 |
247 | && check_scale_shift_data_type() |
248 | && attr()->has_default_values() |
249 | && set_default_formats_common() |
250 | && src_d.is_blocking_desc() |
251 | // plain format, last logical dim is last physical |
252 | && src_d.blocking_desc().strides[ndims() - 1] == 1; |
253 | if (!ok) return status::unimplemented; |
254 | |
255 | CHECK(fill_compatible_stats_md(*src_md(), reordered_stat_md_)); |
256 | |
257 | if (reordered_stat_md_ != *stat_md()) { |
258 | CHECK(reorder_primitive_desc_create( |
259 | reorder_pd_, engine, stat_md(), &reordered_stat_md_)); |
260 | } |
261 | |
262 | nthr_ = dnnl_get_max_threads(); |
263 | init_scratchpad(); |
264 | return status::success; |
265 | } |
266 | |
267 | bool use_tmp_stats() const { return reorder_pd_.get(); } |
268 | |
269 | std::shared_ptr<primitive_desc_t> reorder_pd_; |
270 | memory_desc_t reordered_stat_md_; |
271 | int nthr_; // To not exceed the limit in execute used for set up. |
272 | |
273 | private: |
274 | void init_scratchpad() { |
275 | using namespace memory_tracking::names; |
276 | auto scratchpad = scratchpad_registry().registrar(); |
277 | if (use_tmp_stats()) { |
278 | scratchpad.template book<float>( |
279 | key_lnorm_tmp_mean, across_axis()); |
280 | scratchpad.template book<float>( |
281 | key_lnorm_tmp_var, across_axis()); |
282 | } |
283 | scratchpad.template book<float>( |
284 | key_lnorm_reduction, 2 * norm_axis() * nthr_); |
285 | scratchpad.template book<float>( |
286 | key_lnorm_tmp_diff_ss, 2 * norm_axis()); |
287 | if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) { |
288 | scratchpad.book(key_nested, reorder_pd_->scratchpad_registry()); |
289 | } |
290 | scratchpad.template book<float>( |
291 | key_lnorm_inv_sqrtvar, across_axis()); |
292 | } |
293 | }; |
294 | |
295 | status_t init(engine_t *engine) override { |
296 | if (pd()->reorder_pd_) |
297 | pd()->reorder_pd_->create_primitive(reorder_, engine); |
298 | CHECK(safe_ptr_assign(diff_ss_kernel_, diff_ss_kernel_t::create(pd()))); |
299 | CHECK(safe_ptr_assign( |
300 | diff_data_kernel_, diff_data_kernel_t::create(pd()))); |
301 | if (diff_ss_kernel_) CHECK(diff_ss_kernel_->create_kernel()); |
302 | if (diff_data_kernel_) CHECK(diff_data_kernel_->create_kernel()); |
303 | return status::success; |
304 | } |
305 | |
306 | jit_uni_layer_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {} |
307 | virtual ~jit_uni_layer_normalization_bwd_t() = default; |
308 | |
309 | void reorder_stat(const exec_ctx_t &ctx, engine_t *engine, |
310 | const memory_arg_t &in, const memory_arg_t &out) const { |
311 | using namespace memory_tracking::names; |
312 | exec_args_t r_args; |
313 | r_args[DNNL_ARG_SRC] = in; |
314 | r_args[DNNL_ARG_DST] = out; |
315 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
316 | |
317 | nested_scratchpad_t ns(ctx, key_nested, reorder_); |
318 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
319 | reorder_->execute(r_ctx); |
320 | } |
321 | |
322 | status_t execute(const exec_ctx_t &ctx) const override { |
323 | using namespace memory_tracking::names; |
324 | /* LN supports arbitrary layout for input/output statistics. |
325 | * For best performance we compute LN with statistics in the same format |
326 | * as data tensor (i.e. data in abcd, stats in abc) and user's |
327 | * input/output statistics are reordered if necessary */ |
328 | |
329 | if (reorder_) { |
330 | engine_t *engine = ctx.stream()->engine(); |
331 | auto scratchpad = ctx.get_scratchpad_grantor(); |
332 | auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean); |
333 | auto variance_mem |
334 | = scratchpad.get_memory_storage(key_lnorm_tmp_var); |
335 | memory_t mean( |
336 | engine, &(pd()->reordered_stat_md_), std::move(mean_mem)); |
337 | memory_t variance(engine, &(pd()->reordered_stat_md_), |
338 | std::move(variance_mem)); |
339 | reorder_stat( |
340 | ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false}); |
341 | reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE), |
342 | {&variance, false}); |
343 | } |
344 | |
345 | return execute_backward(ctx); |
346 | } |
347 | |
348 | private: |
349 | status_t execute_backward(const exec_ctx_t &ctx) const; |
350 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
351 | |
352 | std::unique_ptr<diff_ss_kernel_t> diff_ss_kernel_; |
353 | std::unique_ptr<diff_data_kernel_t> diff_data_kernel_; |
354 | std::shared_ptr<primitive_t> reorder_; |
355 | }; |
356 | |
357 | } // namespace x64 |
358 | } // namespace cpu |
359 | } // namespace impl |
360 | } // namespace dnnl |
361 | |
362 | #endif |
363 | |
364 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
365 | |