1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_SIMPLE_LAYER_NORMALIZATION_HPP |
18 | #define CPU_SIMPLE_LAYER_NORMALIZATION_HPP |
19 | |
20 | #include <memory> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/memory_tracking.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/reorder_pd.hpp" |
27 | #include "common/stream.hpp" |
28 | #include "common/utils.hpp" |
29 | |
30 | #include "cpu/cpu_layer_normalization_pd.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | |
36 | struct simple_layer_normalization_fwd_t : public primitive_t { |
37 | struct pd_t : public cpu_layer_normalization_fwd_pd_t { |
38 | using cpu_layer_normalization_fwd_pd_t:: |
39 | cpu_layer_normalization_fwd_pd_t; |
40 | |
41 | DECLARE_COMMON_PD_T("simple:any" , simple_layer_normalization_fwd_t); |
42 | |
43 | status_t init(engine_t *engine); |
44 | |
45 | bool use_tmp_stats() const { return reorder_pd_ || stats_are_tmp(); } |
46 | |
47 | std::shared_ptr<primitive_desc_t> reorder_pd_; |
48 | memory_desc_t reordered_stat_md_; |
49 | |
50 | private: |
51 | void init_scratchpad() { |
52 | using namespace memory_tracking::names; |
53 | auto scratchpad = scratchpad_registry().registrar(); |
54 | if (use_tmp_stats()) { |
55 | scratchpad.template book<float>( |
56 | key_lnorm_tmp_mean, across_axis()); |
57 | scratchpad.template book<float>( |
58 | key_lnorm_tmp_var, across_axis()); |
59 | } |
60 | if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) { |
61 | scratchpad.book(key_nested, reorder_pd_->scratchpad_registry()); |
62 | } |
63 | } |
64 | }; |
65 | |
66 | status_t init(engine_t *engine) override { |
67 | if (pd()->reorder_pd_) |
68 | pd()->reorder_pd_->create_primitive(reorder_, engine); |
69 | return status::success; |
70 | } |
71 | |
72 | simple_layer_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
73 | |
74 | void reorder_stat(const exec_ctx_t &ctx, engine_t *engine, |
75 | const memory_arg_t &in, const memory_arg_t &out) const { |
76 | using namespace memory_tracking::names; |
77 | exec_args_t r_args; |
78 | r_args[DNNL_ARG_SRC] = in; |
79 | r_args[DNNL_ARG_DST] = out; |
80 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
81 | |
82 | nested_scratchpad_t ns(ctx, key_nested, reorder_); |
83 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
84 | reorder_->execute(r_ctx); |
85 | } |
86 | |
87 | status_t execute(const exec_ctx_t &ctx) const override { |
88 | /* LN supports arbitrary layout for input/output statistics. |
89 | * For best performance we compute LN with statistics in the same format |
90 | * as data tensor (i.e. data in abcd, stats in abc) and user's |
91 | * input/output statistics are reordered if necessary */ |
92 | using namespace memory_tracking::names; |
93 | engine_t *engine = ctx.stream()->engine(); |
94 | auto scratchpad = ctx.get_scratchpad_grantor(); |
95 | auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean); |
96 | auto variance_mem = scratchpad.get_memory_storage(key_lnorm_tmp_var); |
97 | memory_t mean(engine, &(pd()->reordered_stat_md_), std::move(mean_mem)); |
98 | memory_t variance( |
99 | engine, &(pd()->reordered_stat_md_), std::move(variance_mem)); |
100 | |
101 | // reorder input stats |
102 | if (pd()->stats_are_src() && reorder_) { |
103 | reorder_stat( |
104 | ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false}); |
105 | reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE), |
106 | {&variance, false}); |
107 | } |
108 | status_t status = execute_forward(ctx); |
109 | if (status != status::success) return status; |
110 | // reorder output stats |
111 | if (!pd()->stats_are_src() && reorder_) { |
112 | reorder_stat( |
113 | ctx, engine, {&mean, true}, ctx.args().at(DNNL_ARG_MEAN)); |
114 | reorder_stat(ctx, engine, {&variance, true}, |
115 | ctx.args().at(DNNL_ARG_VARIANCE)); |
116 | } |
117 | |
118 | return status::success; |
119 | } |
120 | |
121 | private: |
122 | status_t execute_forward(const exec_ctx_t &ctx) const; |
123 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
124 | |
125 | std::shared_ptr<primitive_t> reorder_; |
126 | }; |
127 | |
128 | struct simple_layer_normalization_bwd_t : public primitive_t { |
129 | struct pd_t : public cpu_layer_normalization_bwd_pd_t { |
130 | using cpu_layer_normalization_bwd_pd_t:: |
131 | cpu_layer_normalization_bwd_pd_t; |
132 | |
133 | DECLARE_COMMON_PD_T("simple:any" , simple_layer_normalization_bwd_t); |
134 | |
135 | status_t init(engine_t *engine); |
136 | |
137 | bool use_tmp_stats() const { return reorder_pd_.get(); } |
138 | |
139 | std::shared_ptr<primitive_desc_t> reorder_pd_; |
140 | memory_desc_t reordered_stat_md_; |
141 | int nthr_; // To not exceed the limit in execute used for set up. |
142 | |
143 | private: |
144 | void init_scratchpad() { |
145 | using namespace memory_tracking::names; |
146 | auto scratchpad = scratchpad_registry().registrar(); |
147 | if (use_tmp_stats()) { |
148 | scratchpad.template book<float>( |
149 | key_lnorm_tmp_mean, across_axis()); |
150 | scratchpad.template book<float>( |
151 | key_lnorm_tmp_var, across_axis()); |
152 | } |
153 | scratchpad.template book<float>( |
154 | key_lnorm_reduction, 2 * norm_axis() * nthr_); |
155 | scratchpad.template book<float>( |
156 | key_lnorm_tmp_diff_ss, 2 * norm_axis()); |
157 | if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) { |
158 | scratchpad.book(key_nested, reorder_pd_->scratchpad_registry()); |
159 | } |
160 | scratchpad.template book<float>( |
161 | key_lnorm_inv_sqrtvar, across_axis()); |
162 | } |
163 | }; |
164 | |
165 | status_t init(engine_t *engine) override { |
166 | if (pd()->reorder_pd_) |
167 | pd()->reorder_pd_->create_primitive(reorder_, engine); |
168 | return status::success; |
169 | } |
170 | |
171 | simple_layer_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {} |
172 | |
173 | void reorder_stat(const exec_ctx_t &ctx, engine_t *engine, |
174 | const memory_arg_t &in, const memory_arg_t &out) const { |
175 | using namespace memory_tracking::names; |
176 | exec_args_t r_args; |
177 | r_args[DNNL_ARG_SRC] = in; |
178 | r_args[DNNL_ARG_DST] = out; |
179 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
180 | |
181 | nested_scratchpad_t ns(ctx, key_nested, reorder_); |
182 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
183 | reorder_->execute(r_ctx); |
184 | } |
185 | |
186 | status_t execute(const exec_ctx_t &ctx) const override { |
187 | using namespace memory_tracking::names; |
188 | /* LN supports arbitrary layout for input/output statistics. |
189 | * For best performance we compute LN with statistics in the same format |
190 | * as data tensor (i.e. data in abcd, stats in abc) and user's |
191 | * input/output statistics are reordered if necessary */ |
192 | |
193 | if (reorder_) { |
194 | engine_t *engine = ctx.stream()->engine(); |
195 | auto scratchpad = ctx.get_scratchpad_grantor(); |
196 | auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean); |
197 | auto variance_mem |
198 | = scratchpad.get_memory_storage(key_lnorm_tmp_var); |
199 | memory_t mean( |
200 | engine, &(pd()->reordered_stat_md_), std::move(mean_mem)); |
201 | memory_t variance(engine, &(pd()->reordered_stat_md_), |
202 | std::move(variance_mem)); |
203 | reorder_stat( |
204 | ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false}); |
205 | reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE), |
206 | {&variance, false}); |
207 | } |
208 | |
209 | return execute_backward(ctx); |
210 | } |
211 | |
212 | private: |
213 | status_t execute_backward(const exec_ctx_t &ctx) const; |
214 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
215 | |
216 | std::shared_ptr<primitive_t> reorder_; |
217 | }; |
218 | |
219 | } // namespace cpu |
220 | } // namespace impl |
221 | } // namespace dnnl |
222 | |
223 | #endif |
224 | |
225 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
226 | |