/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_LAYER_NORMALIZATION_PD_HPP
#define COMMON_LAYER_NORMALIZATION_PD_HPP

#include "oneapi/dnnl/dnnl.h"

#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "utils.hpp"

namespace dnnl {
namespace impl {

struct layer_normalization_fwd_pd_t;

struct layer_normalization_pd_t : public primitive_desc_t {
    static constexpr auto base_pkind = primitive_kind::layer_normalization;

    const layer_normalization_desc_t *desc() const { return &desc_; }
    const op_desc_t *op_desc() const override {
        return reinterpret_cast<const op_desc_t *>(this->desc());
    }

    status_t query(query_t what, int idx, void *result) const override {
        switch (what) {
            case query::prop_kind:
                *(prop_kind_t *)result = desc()->prop_kind;
                break;
            case query::primitive_kind:
                *(primitive_kind_t *)result = desc_.primitive_kind;
                break;
            case query::epsilon_f32:
                *(float *)result = desc()->layer_norm_epsilon;
                break;
            case query::flags: *(uint32_t *)result = desc()->flags; break;

            default: return primitive_desc_t::query(what, idx, result);
        }
        return status::success;
    }
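    // Illustrative call site (hypothetical "pd" pointer), e.g. querying
    // epsilon from an already created primitive descriptor:
    //     float eps;
    //     pd->query(query::epsilon_f32, 0, &eps);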

    /* common layer_normalization aux functions */
    int ndims() const { return desc_.src_desc.ndims; }
    dim_t across_axis() const {
        return utils::array_product(desc_.src_desc.dims, ndims() - 1);
    }
    dim_t norm_axis() const { return desc_.src_desc.dims[ndims() - 1]; }
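    // E.g. for src dims {2, 128, 768} the normalization runs over the last
    // axis: norm_axis() == 768, and across_axis() == 2 * 128 == 256 is the
    // number of independent statistics computed.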

    bool stats_are_src() const {
        return desc_.flags & normalization_flags::use_global_stats;
    }
    // Stats are neither user-provided nor returned to the user: they are
    // computed on the fly (forward inference without global stats).
    bool stats_are_tmp() const { return !(stats_are_src() || is_training()); }

    bool use_scale() const {
        return desc_.flags & normalization_flags::use_scale;
    }
    bool use_shift() const {
        return desc_.flags & normalization_flags::use_shift;
    }
    bool use_global_stats() const {
        return desc_.flags & normalization_flags::use_global_stats;
    }

    bool is_fwd() const {
        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
                prop_kind::forward_inference);
    }
    bool is_bwd() const { return !this->is_fwd(); }
    bool is_training() const {
        return desc_.prop_kind == prop_kind::forward_training;
    }

    bool has_zero_dim_memory() const {
        return memory_desc_wrapper(desc_.src_desc).has_zero_dim();
    }

    const memory_desc_t *stat_md() const { return &stat_md_; }

protected:
    layer_normalization_desc_t desc_;
    const layer_normalization_fwd_pd_t *hint_fwd_pd_;

    memory_desc_t src_md_;
    memory_desc_t stat_md_;
    memory_desc_t scaleshift_md_;

    layer_normalization_pd_t(const layer_normalization_desc_t *adesc,
            const primitive_attr_t *attr,
            const layer_normalization_fwd_pd_t *hint_fwd_pd)
        : primitive_desc_t(attr, base_pkind)
        , desc_(*adesc)
        , hint_fwd_pd_(hint_fwd_pd)
        , src_md_(desc_.src_desc)
        , stat_md_(desc_.stat_desc)
        , scaleshift_md_(desc_.data_scaleshift_desc) {}

    bool set_default_stat_md_format(const memory_desc_t &src_md) {
        if (stat_md_.format_kind != format_kind::any) return true;

        // a src memory desc in a non-blocked memory format is unsupported
        if (src_md.format_kind != format_kind::blocked) return false;

        // if the normalization axis is blocked, fall back to a plain dense
        // format (nullptr strides request the default dense strides)
        bool is_norm_dim_blocked = false;
        for (int d = 0; d < src_md.format_desc.blocking.inner_nblks; ++d)
            is_norm_dim_blocked
                    |= src_md.format_desc.blocking.inner_idxs[d] == ndims() - 1;
        if (is_norm_dim_blocked)
            return memory_desc_init_by_strides(stat_md_, nullptr)
                    == status::success;

        // the default memory format for stats is derived from src_md by
        // dropping the normalization dimension and keeping the physical
        // order of the other dimensions (preserving the blocked structure,
        // if any)
        return memory_desc_init_by_blocking_desc(
                       stat_md_, src_md.format_desc.blocking)
                == status::success;
    }

    // Stats and src are compatible here if
    // `stat_strides[:] == data_strides[:] / last_data_dimension`,
    // i.e. abcd & abc, or bacd & bac, are compatible pairs.
    status_t fill_compatible_stats_md(
            const memory_desc_t &src_md, memory_desc_t &stat_md) {
        stat_md = src_md;
        stat_md.data_type = dnnl_f32;
        stat_md.ndims -= 1;
        return memory_desc_init_by_blocking_desc(
                stat_md, src_md.format_desc.blocking);
    }
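
    // Worked example (illustrative): a plain f32 src with dims {2, 3, 4}
    // and strides {12, 4, 1} yields a stats md with dims {2, 3} and
    // strides {3, 1}, i.e. the src strides divided by the normalization
    // dim size (4), with the physical order of the remaining dims kept.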

private:
    const memory_desc_t &src_desc() const { return desc_.src_desc; }
};

struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
    typedef layer_normalization_fwd_pd_t base_class;
    typedef layer_normalization_fwd_pd_t hint_class;

    arg_usage_t arg_usage(int arg) const override {
        if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
        if (arg == DNNL_ARG_DST) return arg_usage_t::output;

        if (utils::one_of(arg, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE)) {
            if (stats_are_src()) return arg_usage_t::input;
            if (!stats_are_src() && is_training()) return arg_usage_t::output;
            return arg_usage_t::unused;
        }

        if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input;
        if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input;

        return primitive_desc_t::arg_usage(arg);
    }

    const memory_desc_t *arg_md(int arg) const override {
        switch (arg) {
            case DNNL_ARG_SRC: return src_md(0);
            case DNNL_ARG_DST: return dst_md(0);
            case DNNL_ARG_MEAN: return stats_are_src() ? src_md(1) : dst_md(1);
            case DNNL_ARG_VARIANCE:
                return stats_are_src() ? src_md(2) : dst_md(2);
            case DNNL_ARG_SCALE:
            case DNNL_ARG_SHIFT: return weights_md(0);
            default: return layer_normalization_pd_t::arg_md(arg);
        }
    }

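    // Index convention for the md getters below: index 0 is the primary
    // tensor; indices 1 and 2 map to mean and variance, which live on the
    // src side when stats_are_src() and on the dst side when they are
    // produced during training.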
    const memory_desc_t *src_md(int index = 0) const override {
        if (index == 0) return &src_md_;
        if (stats_are_src() && (index == 1 || index == 2)) return &stat_md_;
        return &glob_zero_md;
    }

    const memory_desc_t *dst_md(int index = 0) const override {
        if (index == 0) return &dst_md_;
        if (!stats_are_src() && is_training() && (index == 1 || index == 2))
            return &stat_md_;
        return &glob_zero_md;
    }

    const memory_desc_t *weights_md(int index = 0) const override {
        return index == 0 ? &scaleshift_md_ : &glob_zero_md;
    }

    int n_inputs() const override {
        return 1 + 2 * stats_are_src() + use_scale() + use_shift();
    }
    int n_outputs() const override {
        return 1 + 2 * (!stats_are_src()) * is_training();
    }
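
    // E.g. forward training with use_scale() and use_shift() and no global
    // stats: n_inputs() == 3 (src, scale, shift) and n_outputs() == 3
    // (dst, mean, variance).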

protected:
    memory_desc_t dst_md_;

    layer_normalization_fwd_pd_t(const layer_normalization_desc_t *adesc,
            const primitive_attr_t *attr,
            const layer_normalization_fwd_pd_t *hint_fwd_pd)
        : layer_normalization_pd_t(adesc, attr, hint_fwd_pd)
        , dst_md_(desc_.dst_desc) {}

    bool set_default_formats_common() {
        return IMPLICATION(dst_md_.format_kind == format_kind::any,
                       memory_desc_init_by_md_and_dt(
                               dst_md_, src_md_, dst_md_.data_type)
                               == status::success)
                && set_default_stat_md_format(src_md_);
    }

    bool check_scale_shift_data_type() const {
        return IMPLICATION(use_scale() || use_shift(),
                weights_md()->data_type == data_type::f32);
    }

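    // Only quantization scales with mask == 0 are supported, i.e. a single
    // common scale value per argument (no per-channel scaling).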
    bool attr_scales_ok() const {
        const auto &scales = attr()->scales_;
        bool ok = true;
        for (const auto &e : scales.scales_) {
            ok = ok && e.second.mask_ == 0;
        }
        return ok;
    }
};
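
// A minimal sketch (hypothetical; not part of this file) of how a derived
// forward pd typically combines the helpers above in its init():
//
//     status_t init(engine_t *engine) {
//         const bool ok = is_fwd() && !has_zero_dim_memory()
//                 && check_scale_shift_data_type() && attr_scales_ok()
//                 && set_default_formats_common();
//         return ok ? status::success : status::unimplemented;
//     }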

struct layer_normalization_bwd_pd_t : public layer_normalization_pd_t {
    typedef layer_normalization_bwd_pd_t base_class;
    typedef layer_normalization_fwd_pd_t hint_class;

    arg_usage_t arg_usage(int arg) const override {
        if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE,
                    DNNL_ARG_DIFF_DST))
            return arg_usage_t::input;

        if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input;
        if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input;

        if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;

        if (arg == DNNL_ARG_DIFF_SCALE && use_scale())
            return arg_usage_t::output;
        if (arg == DNNL_ARG_DIFF_SHIFT && use_shift())
            return arg_usage_t::output;

        return primitive_desc_t::arg_usage(arg);
    }

    const memory_desc_t *arg_md(int arg) const override {
        switch (arg) {
            case DNNL_ARG_SRC: return src_md(0);
            case DNNL_ARG_MEAN: return src_md(1);
            case DNNL_ARG_VARIANCE: return src_md(2);
            case DNNL_ARG_SCALE:
            case DNNL_ARG_SHIFT: return weights_md(0);
            case DNNL_ARG_DIFF_SRC: return diff_src_md(0);
            case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
            case DNNL_ARG_DIFF_SCALE:
            case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0);
            default: return layer_normalization_pd_t::arg_md(arg);
        }
    }

    const memory_desc_t *src_md(int index = 0) const override {
        return index == 0 ? &src_md_ : index <= 2 ? &stat_md_ : &glob_zero_md;
    }
    const memory_desc_t *diff_dst_md(int index = 0) const override {
        return index == 0 ? &diff_dst_md_ : &glob_zero_md;
    }
    const memory_desc_t *diff_src_md(int index = 0) const override {
        return index == 0 ? &diff_src_md_ : &glob_zero_md;
    }

    const memory_desc_t *weights_md(int index = 0) const override {
        return index == 0 ? &scaleshift_md_ : &glob_zero_md;
    }
    const memory_desc_t *diff_weights_md(int index = 0) const override {
        return index == 0 ? &diff_scaleshift_md_ : &glob_zero_md;
    }

    // The four mandatory inputs are src, mean, variance, and diff_dst.
    int n_inputs() const override { return 4 + use_scale() + use_shift(); }
    int n_outputs() const override {
        return 1
                + (desc_.prop_kind == prop_kind::backward)
                * (use_scale() + use_shift());
    }
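
    // E.g. prop_kind::backward with use_scale() and use_shift():
    // n_outputs() == 3 (diff_src, diff_scale, diff_shift);
    // prop_kind::backward_data yields diff_src only.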

protected:
    memory_desc_t diff_src_md_;
    memory_desc_t diff_dst_md_;
    memory_desc_t diff_scaleshift_md_;

    layer_normalization_bwd_pd_t(const layer_normalization_desc_t *adesc,
            const primitive_attr_t *attr,
            const layer_normalization_fwd_pd_t *hint_fwd_pd)
        : layer_normalization_pd_t(adesc, attr, hint_fwd_pd)
        , diff_src_md_(desc_.diff_src_desc)
        , diff_dst_md_(desc_.diff_dst_desc)
        , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) {}

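    // Note: unlike the forward side, the default stat format below is
    // derived from diff_src_md_, which the && chain has already defaulted
    // to the src format when it came in as format_kind::any.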
    bool set_default_formats_common() {
        return IMPLICATION(diff_dst_md_.format_kind == format_kind::any,
                       memory_desc_init_by_md_and_dt(
                               diff_dst_md_, src_md_, diff_dst_md_.data_type)
                               == status::success)
                && IMPLICATION(diff_src_md_.format_kind == format_kind::any,
                        memory_desc_init_by_md_and_dt(
                                diff_src_md_, src_md_, diff_src_md_.data_type)
                                == status::success)
                && set_default_stat_md_format(diff_src_md_);
    }

    bool check_scale_shift_data_type() const {
        return IMPLICATION(use_scale() || use_shift(),
                utils::everyone_is(data_type::f32, weights_md()->data_type,
                        diff_weights_md()->data_type));
    }
};

} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s