1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_LAYER_NORMALIZATION_PD_HPP |
18 | #define COMMON_LAYER_NORMALIZATION_PD_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "primitive_desc.hpp" |
24 | #include "utils.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | |
29 | struct layer_normalization_fwd_pd_t; |
30 | |
31 | struct layer_normalization_pd_t : public primitive_desc_t { |
32 | static constexpr auto base_pkind = primitive_kind::layer_normalization; |
33 | |
34 | const layer_normalization_desc_t *desc() const { return &desc_; } |
35 | const op_desc_t *op_desc() const override { |
36 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
37 | } |
38 | |
39 | status_t query(query_t what, int idx, void *result) const override { |
40 | switch (what) { |
41 | case query::prop_kind: |
42 | *(prop_kind_t *)result = desc()->prop_kind; |
43 | break; |
44 | case query::primitive_kind: |
45 | *(primitive_kind_t *)result = desc_.primitive_kind; |
46 | break; |
47 | case query::epsilon_f32: |
48 | *(float *)result = desc()->layer_norm_epsilon; |
49 | break; |
50 | case query::flags: *(uint32_t *)result = desc()->flags; break; |
51 | |
52 | default: return primitive_desc_t::query(what, idx, result); |
53 | } |
54 | return status::success; |
55 | } |
56 | |
57 | /* common layer_normalization aux functions */ |
58 | int ndims() const { return desc_.src_desc.ndims; } |
59 | dim_t across_axis() const { |
60 | return utils::array_product(desc_.src_desc.dims, ndims() - 1); |
61 | } |
62 | dim_t norm_axis() const { return desc_.src_desc.dims[ndims() - 1]; } |
63 | |
64 | bool stats_are_src() const { |
65 | return desc_.flags & normalization_flags::use_global_stats; |
66 | } |
67 | bool stats_are_tmp() const { return !(stats_are_src() || is_training()); } |
68 | |
69 | bool use_scale() const { |
70 | return desc_.flags & normalization_flags::use_scale; |
71 | } |
72 | bool use_shift() const { |
73 | return desc_.flags & normalization_flags::use_shift; |
74 | } |
75 | bool use_global_stats() const { |
76 | return desc_.flags & normalization_flags::use_global_stats; |
77 | } |
78 | |
79 | bool is_fwd() const { |
80 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
81 | prop_kind::forward_inference); |
82 | } |
83 | bool is_bwd() const { return !this->is_fwd(); } |
84 | bool is_training() const { |
85 | return desc_.prop_kind == prop_kind::forward_training; |
86 | } |
87 | |
88 | bool has_zero_dim_memory() const { |
89 | return memory_desc_wrapper(desc_.src_desc).has_zero_dim(); |
90 | } |
91 | |
92 | const memory_desc_t *stat_md() const { return &stat_md_; } |
93 | |
94 | protected: |
95 | layer_normalization_desc_t desc_; |
96 | const layer_normalization_fwd_pd_t *hint_fwd_pd_; |
97 | |
98 | memory_desc_t src_md_; |
99 | memory_desc_t stat_md_; |
100 | memory_desc_t scaleshift_md_; |
101 | |
102 | layer_normalization_pd_t(const layer_normalization_desc_t *adesc, |
103 | const primitive_attr_t *attr, |
104 | const layer_normalization_fwd_pd_t *hint_fwd_pd) |
105 | : primitive_desc_t(attr, base_pkind) |
106 | , desc_(*adesc) |
107 | , hint_fwd_pd_(hint_fwd_pd) |
108 | , src_md_(desc_.src_desc) |
109 | , stat_md_(desc_.stat_desc) |
110 | , scaleshift_md_(desc_.data_scaleshift_desc) {} |
111 | |
112 | bool set_default_stat_md_format(const memory_desc_t &src_md) { |
113 | if (stat_md_.format_kind != format_kind::any) return true; |
114 | |
115 | // src memory desc in non-blocked memory format is unsupported |
116 | if (src_md.format_kind != format_kind::blocked) return false; |
117 | |
118 | // if the normalization axis is blocked, fallback to plain format |
119 | bool is_norm_dim_blocked = false; |
120 | for (int d = 0; d < src_md.format_desc.blocking.inner_nblks; ++d) |
121 | is_norm_dim_blocked |
122 | |= src_md.format_desc.blocking.inner_idxs[d] == ndims() - 1; |
123 | if (is_norm_dim_blocked) |
124 | return memory_desc_init_by_strides(stat_md_, nullptr) |
125 | == status::success; |
126 | |
127 | // the default memory format for stat is derived from src_md by |
128 | // dropping the normalization dimension and keeping the physical order |
129 | // of other dimensions (preserving the blocked structure if any) |
130 | return memory_desc_init_by_blocking_desc( |
131 | stat_md_, src_md.format_desc.blocking) |
132 | == status::success; |
133 | } |
134 | |
135 | // Stats and src here are compatible if: |
136 | // `stat_strides[:] == data_strides[:] / last_data_dimension` |
137 | // i.e. abcd & abc, bacd & bac - compatible |
138 | status_t fill_compatible_stats_md( |
139 | const memory_desc_t &src_md, memory_desc_t &stat_md) { |
140 | stat_md = src_md; |
141 | stat_md.data_type = dnnl_f32; |
142 | stat_md.ndims -= 1; |
143 | return memory_desc_init_by_blocking_desc( |
144 | stat_md, src_md.format_desc.blocking); |
145 | } |
146 | |
147 | private: |
148 | const memory_desc_t &src_desc() const { return desc_.src_desc; } |
149 | }; |
150 | |
151 | struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t { |
152 | typedef layer_normalization_fwd_pd_t base_class; |
153 | typedef layer_normalization_fwd_pd_t hint_class; |
154 | |
155 | arg_usage_t arg_usage(int arg) const override { |
156 | if (arg == DNNL_ARG_SRC) return arg_usage_t::input; |
157 | if (arg == DNNL_ARG_DST) return arg_usage_t::output; |
158 | |
159 | if (utils::one_of(arg, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE)) { |
160 | if (stats_are_src()) return arg_usage_t::input; |
161 | if (!stats_are_src() && is_training()) return arg_usage_t::output; |
162 | return arg_usage_t::unused; |
163 | } |
164 | |
165 | if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input; |
166 | if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input; |
167 | |
168 | return primitive_desc_t::arg_usage(arg); |
169 | } |
170 | |
171 | const memory_desc_t *arg_md(int arg) const override { |
172 | switch (arg) { |
173 | case DNNL_ARG_SRC: return src_md(0); |
174 | case DNNL_ARG_DST: return dst_md(0); |
175 | case DNNL_ARG_MEAN: return stats_are_src() ? src_md(1) : dst_md(1); |
176 | case DNNL_ARG_VARIANCE: |
177 | return stats_are_src() ? src_md(2) : dst_md(2); |
178 | case DNNL_ARG_SCALE: |
179 | case DNNL_ARG_SHIFT: return weights_md(0); |
180 | default: return layer_normalization_pd_t::arg_md(arg); |
181 | } |
182 | } |
183 | |
184 | const memory_desc_t *src_md(int index = 0) const override { |
185 | if (index == 0) return &src_md_; |
186 | if (stats_are_src() && (index == 1 || index == 2)) return &stat_md_; |
187 | return &glob_zero_md; |
188 | } |
189 | |
190 | const memory_desc_t *dst_md(int index = 0) const override { |
191 | if (index == 0) return &dst_md_; |
192 | if (!stats_are_src() && is_training() && (index == 1 || index == 2)) |
193 | return &stat_md_; |
194 | return &glob_zero_md; |
195 | } |
196 | |
197 | const memory_desc_t *weights_md(int index = 0) const override { |
198 | return index == 0 ? &scaleshift_md_ : &glob_zero_md; |
199 | } |
200 | |
201 | int n_inputs() const override { |
202 | return 1 + 2 * stats_are_src() + use_scale() + use_shift(); |
203 | } |
204 | int n_outputs() const override { |
205 | return 1 + 2 * (!stats_are_src()) * is_training(); |
206 | } |
207 | |
208 | protected: |
209 | memory_desc_t dst_md_; |
210 | |
211 | layer_normalization_fwd_pd_t(const layer_normalization_desc_t *adesc, |
212 | const primitive_attr_t *attr, |
213 | const layer_normalization_fwd_pd_t *hint_fwd_pd) |
214 | : layer_normalization_pd_t(adesc, attr, hint_fwd_pd) |
215 | , dst_md_(desc_.dst_desc) {} |
216 | |
217 | bool set_default_formats_common() { |
218 | return IMPLICATION(dst_md_.format_kind == format_kind::any, |
219 | memory_desc_init_by_md_and_dt( |
220 | dst_md_, src_md_, dst_md_.data_type) |
221 | == status::success) |
222 | && set_default_stat_md_format(src_md_); |
223 | } |
224 | |
225 | bool check_scale_shift_data_type() const { |
226 | return IMPLICATION(use_scale() || use_shift(), |
227 | weights_md()->data_type == data_type::f32); |
228 | } |
229 | |
230 | bool attr_scales_ok() const { |
231 | const auto &scales = attr()->scales_; |
232 | bool ok = true; |
233 | for (const auto &e : scales.scales_) { |
234 | ok = ok && e.second.mask_ == 0; |
235 | } |
236 | return ok; |
237 | } |
238 | }; |
239 | |
240 | struct layer_normalization_bwd_pd_t : public layer_normalization_pd_t { |
241 | typedef layer_normalization_bwd_pd_t base_class; |
242 | typedef layer_normalization_fwd_pd_t hint_class; |
243 | |
244 | arg_usage_t arg_usage(int arg) const override { |
245 | if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE, |
246 | DNNL_ARG_DIFF_DST)) |
247 | return arg_usage_t::input; |
248 | |
249 | if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input; |
250 | if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input; |
251 | |
252 | if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; |
253 | |
254 | if (arg == DNNL_ARG_DIFF_SCALE && use_scale()) |
255 | return arg_usage_t::output; |
256 | if (arg == DNNL_ARG_DIFF_SHIFT && use_shift()) |
257 | return arg_usage_t::output; |
258 | |
259 | return primitive_desc_t::arg_usage(arg); |
260 | } |
261 | |
262 | const memory_desc_t *arg_md(int arg) const override { |
263 | switch (arg) { |
264 | case DNNL_ARG_SRC: return src_md(0); |
265 | case DNNL_ARG_MEAN: return src_md(1); |
266 | case DNNL_ARG_VARIANCE: return src_md(2); |
267 | case DNNL_ARG_SCALE: |
268 | case DNNL_ARG_SHIFT: return weights_md(0); |
269 | case DNNL_ARG_DIFF_SRC: return diff_src_md(0); |
270 | case DNNL_ARG_DIFF_DST: return diff_dst_md(0); |
271 | case DNNL_ARG_DIFF_SCALE: |
272 | case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0); |
273 | default: return layer_normalization_pd_t::arg_md(arg); |
274 | } |
275 | } |
276 | |
277 | const memory_desc_t *src_md(int index = 0) const override { |
278 | return index == 0 ? &src_md_ : index <= 2 ? &stat_md_ : &glob_zero_md; |
279 | } |
280 | const memory_desc_t *diff_dst_md(int index = 0) const override { |
281 | return index == 0 ? &diff_dst_md_ : &glob_zero_md; |
282 | } |
283 | const memory_desc_t *diff_src_md(int index = 0) const override { |
284 | return index == 0 ? &diff_src_md_ : &glob_zero_md; |
285 | } |
286 | |
287 | const memory_desc_t *weights_md(int index = 0) const override { |
288 | return index == 0 ? &scaleshift_md_ : &glob_zero_md; |
289 | } |
290 | const memory_desc_t *diff_weights_md(int index = 0) const override { |
291 | return index == 0 ? &diff_scaleshift_md_ : &glob_zero_md; |
292 | } |
293 | |
294 | int n_inputs() const override { return 4 + use_scale() + use_shift(); } |
295 | int n_outputs() const override { |
296 | return 1 |
297 | + (desc_.prop_kind == prop_kind::backward) |
298 | * (use_scale() + use_shift()); |
299 | } |
300 | |
301 | protected: |
302 | memory_desc_t diff_src_md_; |
303 | memory_desc_t diff_dst_md_; |
304 | memory_desc_t diff_scaleshift_md_; |
305 | |
306 | layer_normalization_bwd_pd_t(const layer_normalization_desc_t *adesc, |
307 | const primitive_attr_t *attr, |
308 | const layer_normalization_fwd_pd_t *hint_fwd_pd) |
309 | : layer_normalization_pd_t(adesc, attr, hint_fwd_pd) |
310 | , diff_src_md_(desc_.diff_src_desc) |
311 | , diff_dst_md_(desc_.diff_dst_desc) |
312 | , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) {} |
313 | |
314 | bool set_default_formats_common() { |
315 | return IMPLICATION(diff_dst_md_.format_kind == format_kind::any, |
316 | memory_desc_init_by_md_and_dt( |
317 | diff_dst_md_, src_md_, diff_dst_md_.data_type) |
318 | == status::success) |
319 | && IMPLICATION(diff_src_md_.format_kind == format_kind::any, |
320 | memory_desc_init_by_md_and_dt( |
321 | diff_src_md_, src_md_, diff_src_md_.data_type) |
322 | == status::success) |
323 | && set_default_stat_md_format(diff_src_md_); |
324 | } |
325 | |
326 | bool check_scale_shift_data_type() const { |
327 | return IMPLICATION(use_scale() || use_shift(), |
328 | utils::everyone_is(data_type::f32, weights_md()->data_type, |
329 | diff_weights_md()->data_type)); |
330 | } |
331 | }; |
332 | |
333 | } // namespace impl |
334 | } // namespace dnnl |
335 | |
336 | #endif |
337 | |
338 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
339 | |