1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_BATCH_NORMALIZATION_PD_HPP
18#define COMMON_BATCH_NORMALIZATION_PD_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "c_types_map.hpp"
23#include "primitive_desc.hpp"
24#include "utils.hpp"
25
26namespace dnnl {
27namespace impl {
28
29struct batch_normalization_fwd_pd_t;
30
31struct batch_normalization_pd_t : public primitive_desc_t {
32 static constexpr auto base_pkind = primitive_kind::batch_normalization;
33
34 const batch_normalization_desc_t *desc() const { return &desc_; }
35 const op_desc_t *op_desc() const override {
36 return reinterpret_cast<const op_desc_t *>(this->desc());
37 }
38
39 status_t query(query_t what, int idx, void *result) const override {
40 switch (what) {
41 case query::prop_kind:
42 *(prop_kind_t *)result = desc()->prop_kind;
43 break;
44 case query::epsilon_f32:
45 *(float *)result = desc()->batch_norm_epsilon;
46 break;
47 case query::flags: *(uint32_t *)result = desc()->flags; break;
48 default: return primitive_desc_t::query(what, idx, result);
49 }
50 return status::success;
51 }
52
53 /* common batch_normalization aux functions */
54
55 dim_t MB() const { return src_md()->dims[0]; }
56 dim_t C() const { return src_md()->dims[1]; }
57 dim_t D() const { return ndims() >= 5 ? src_md()->dims[ndims() - 3] : 1; }
58 dim_t H() const { return ndims() >= 4 ? src_md()->dims[ndims() - 2] : 1; }
59 dim_t W() const { return ndims() >= 3 ? src_md()->dims[ndims() - 1] : 1; }
60
61 int ndims() const { return src_md()->ndims; }
62
63 bool stats_is_src() const {
64 return desc_.flags & normalization_flags::use_global_stats;
65 }
66 bool use_scale() const {
67 return desc_.flags & normalization_flags::use_scale;
68 }
69 bool use_shift() const {
70 return desc_.flags & normalization_flags::use_shift;
71 }
72 bool use_global_stats() const {
73 return desc_.flags & normalization_flags::use_global_stats;
74 }
75 bool fuse_norm_relu() const {
76 return desc_.flags & normalization_flags::fuse_norm_relu;
77 }
78 bool fuse_norm_add_relu() const {
79 return desc_.flags & normalization_flags::fuse_norm_add_relu;
80 }
81 bool with_relu_post_op(bool require_nslope_zero = true) const {
82 const auto &p = this->attr()->post_ops_;
83 const bool nslope_zero_ok
84 = IMPLICATION(is_training(), require_nslope_zero);
85 return p.len() == 1 && p.entry_[0].is_relu(true, require_nslope_zero)
86 && nslope_zero_ok;
87 }
88
89 float alpha() const {
90 const auto &p = attr()->post_ops_;
91 const bool entry_size_ok = p.entry_.size() > 0;
92 assert(entry_size_ok || fuse_norm_relu() || fuse_norm_add_relu());
93 if (entry_size_ok) return p.entry_[0].eltwise.alpha;
94 return 0.f;
95 }
96
97 bool is_fwd() const {
98 return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
99 prop_kind::forward_inference);
100 }
101
102 bool is_training() const {
103 return desc_.prop_kind == prop_kind::forward_training;
104 }
105
106 bool has_zero_dim_memory() const {
107 return memory_desc_wrapper(src_md()).has_zero_dim();
108 }
109
110protected:
111 batch_normalization_desc_t desc_;
112 const batch_normalization_fwd_pd_t *hint_fwd_pd_;
113
114 memory_desc_t src_md_;
115 memory_desc_t stat_md_;
116 memory_desc_t scaleshift_md_;
117
118 memory_desc_t ws_md_;
119
120 batch_normalization_pd_t(const batch_normalization_desc_t *adesc,
121 const primitive_attr_t *attr,
122 const batch_normalization_fwd_pd_t *hint_fwd_pd)
123 : primitive_desc_t(attr, base_pkind)
124 , desc_(*adesc)
125 , hint_fwd_pd_(hint_fwd_pd)
126 , src_md_(desc_.src_desc)
127 , stat_md_(desc_.stat_desc)
128 , scaleshift_md_(desc_.scaleshift_desc)
129 , ws_md_() {}
130
131 virtual void init_default_ws(size_t bits_per_element) {
132 const auto src_mdw = memory_desc_wrapper(src_md_);
133
134 const dim_t nelems = src_mdw.nelems(true);
135 const dim_t bits_per_byte = 8;
136 const dims_t ws_sz = {
137 (dim_t)utils::div_up(nelems * bits_per_element, bits_per_byte)};
138 memory_desc_init_by_tag(ws_md_, 1, ws_sz, data_type::u8, format_tag::x);
139 }
140};
141
142struct batch_normalization_fwd_pd_t : public batch_normalization_pd_t {
143 typedef batch_normalization_fwd_pd_t base_class;
144 typedef batch_normalization_fwd_pd_t hint_class;
145
146 arg_usage_t arg_usage(int arg) const override {
147 if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
148 if (arg == DNNL_ARG_SRC_1 && fuse_norm_add_relu())
149 return arg_usage_t::input;
150 if (arg == DNNL_ARG_DST) return arg_usage_t::output;
151
152 if (utils::one_of(arg, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE)) {
153 if (stats_is_src()) return arg_usage_t::input;
154 if (!stats_is_src() && is_training()) return arg_usage_t::output;
155 return arg_usage_t::unused;
156 }
157
158 if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input;
159 if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input;
160
161 if (arg == DNNL_ARG_WORKSPACE && !types::is_zero_md(workspace_md()))
162 return arg_usage_t::output;
163
164 return primitive_desc_t::arg_usage(arg);
165 }
166
167 const memory_desc_t *arg_md(int arg) const override {
168 switch (arg) {
169 case DNNL_ARG_SRC_1: return dst_md(3);
170 case DNNL_ARG_SRC: return src_md(0);
171 case DNNL_ARG_DST: return dst_md(0);
172 case DNNL_ARG_MEAN: return stats_is_src() ? src_md(1) : dst_md(1);
173 case DNNL_ARG_VARIANCE:
174 return stats_is_src() ? src_md(2) : dst_md(2);
175 case DNNL_ARG_SCALE:
176 case DNNL_ARG_SHIFT: return weights_md(0);
177 default: return batch_normalization_pd_t::arg_md(arg);
178 }
179 }
180
181 const memory_desc_t *src_md(int index = 0) const override {
182 if (index == 0) return &src_md_;
183 if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
184 return &glob_zero_md;
185 }
186
187 const memory_desc_t *dst_md(int index = 0) const override {
188 if (index == 0) return &dst_md_;
189 if (!stats_is_src() && is_training() && (index == 1 || index == 2))
190 return &stat_md_;
191 if (fuse_norm_add_relu() && index == 3) return &dst_md_;
192 return &glob_zero_md;
193 }
194
195 const memory_desc_t *weights_md(int index = 0) const override {
196 return index == 0 ? &scaleshift_md_ : &glob_zero_md;
197 }
198
199 const memory_desc_t *workspace_md(int index = 0) const override {
200 return index == 0 ? &ws_md_ : &glob_zero_md;
201 }
202
203 const memory_desc_t *stat_md() const {
204 return stats_is_src() ? src_md(1) : dst_md(1);
205 }
206
207 int n_inputs() const override {
208 return 1 + 2 * stats_is_src() + use_scale() + use_shift()
209 + fuse_norm_add_relu();
210 }
211 int n_outputs() const override {
212 return 1 + !types::is_zero_md(workspace_md())
213 + (2 * (!stats_is_src())) * is_training();
214 }
215
216protected:
217 memory_desc_t dst_md_;
218
219 batch_normalization_fwd_pd_t(const batch_normalization_desc_t *adesc,
220 const primitive_attr_t *attr,
221 const batch_normalization_fwd_pd_t *hint_fwd_pd)
222 : batch_normalization_pd_t(adesc, attr, hint_fwd_pd)
223 , dst_md_(desc_.dst_desc) {}
224
225 bool set_default_formats_common() {
226 return IMPLICATION(dst_md_.format_kind == format_kind::any,
227 memory_desc_init_by_md_and_dt(
228 dst_md_, src_md_, dst_md_.data_type)
229 == status::success);
230 }
231 bool check_scale_shift_data_type() const {
232 return IMPLICATION(use_scale() || use_shift(),
233 weights_md()->data_type == data_type::f32);
234 }
235};
236
237struct batch_normalization_bwd_pd_t : public batch_normalization_pd_t {
238 typedef batch_normalization_bwd_pd_t base_class;
239 typedef batch_normalization_fwd_pd_t hint_class;
240
241 arg_usage_t arg_usage(int arg) const override {
242 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE,
243 DNNL_ARG_DIFF_DST))
244 return arg_usage_t::input;
245
246 if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input;
247 if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input;
248
249 if (arg == DNNL_ARG_WORKSPACE && !types::is_zero_md(workspace_md()))
250 return arg_usage_t::input;
251
252 if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
253 if (arg == DNNL_ARG_DIFF_SRC_1 && fuse_norm_add_relu())
254 return arg_usage_t::output;
255
256 if (arg == DNNL_ARG_DIFF_SCALE && use_scale())
257 return arg_usage_t::output;
258 if (arg == DNNL_ARG_DIFF_SHIFT && use_shift())
259 return arg_usage_t::output;
260 return primitive_desc_t::arg_usage(arg);
261 }
262
263 const memory_desc_t *arg_md(int arg) const override {
264 switch (arg) {
265 case DNNL_ARG_SRC: return src_md(0);
266 case DNNL_ARG_MEAN: return src_md(1);
267 case DNNL_ARG_VARIANCE: return src_md(2);
268 case DNNL_ARG_SCALE:
269 case DNNL_ARG_SHIFT: return weights_md(0);
270 case DNNL_ARG_DIFF_SRC_1: return diff_dst_md(1);
271 case DNNL_ARG_DIFF_SRC: return diff_src_md(0);
272 case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
273 case DNNL_ARG_DIFF_SCALE:
274 case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0);
275 default: return batch_normalization_pd_t::arg_md(arg);
276 }
277 }
278
279 const memory_desc_t *src_md(int index = 0) const override {
280 return index == 0 ? &src_md_ : index <= 2 ? &stat_md_ : &glob_zero_md;
281 }
282 const memory_desc_t *diff_dst_md(int index = 0) const override {
283 if (index == 0) return &diff_dst_md_;
284 if (fuse_norm_add_relu() && index == 1) return &diff_dst_md_;
285 return &glob_zero_md;
286 }
287 const memory_desc_t *diff_src_md(int index = 0) const override {
288 return index == 0 ? &diff_src_md_ : &glob_zero_md;
289 }
290
291 const memory_desc_t *weights_md(int index = 0) const override {
292 return index == 0 ? &scaleshift_md_ : &glob_zero_md;
293 }
294 const memory_desc_t *diff_weights_md(int index = 0) const override {
295 return index == 0 ? &diff_scaleshift_md_ : &glob_zero_md;
296 }
297
298 const memory_desc_t *workspace_md(int index = 0) const override {
299 return index == 0 ? &ws_md_ : &glob_zero_md;
300 }
301
302 const memory_desc_t *stat_md() const { return src_md(1); }
303
304 int n_inputs() const override {
305 return 4 + (!types::is_zero_md(workspace_md())) + use_scale();
306 }
307 int n_outputs() const override {
308 return 1 + fuse_norm_add_relu()
309 + (!types::is_zero_md(diff_weights_md()))
310 * (use_scale() + use_shift());
311 }
312
313protected:
314 memory_desc_t diff_src_md_;
315 memory_desc_t diff_dst_md_;
316 memory_desc_t diff_scaleshift_md_;
317
318 batch_normalization_bwd_pd_t(const batch_normalization_desc_t *adesc,
319 const primitive_attr_t *attr,
320 const batch_normalization_fwd_pd_t *hint_fwd_pd)
321 : batch_normalization_pd_t(adesc, attr, hint_fwd_pd)
322 , diff_src_md_(desc_.diff_src_desc)
323 , diff_dst_md_(desc_.diff_dst_desc)
324 , diff_scaleshift_md_(desc_.diff_scaleshift_desc) {}
325
326 bool set_default_formats_common() {
327 return IMPLICATION(diff_dst_md_.format_kind == format_kind::any,
328 memory_desc_init_by_md_and_dt(
329 diff_dst_md_, src_md_, diff_dst_md_.data_type)
330 == status::success)
331 && IMPLICATION(diff_src_md_.format_kind == format_kind::any,
332 memory_desc_init_by_md_and_dt(
333 diff_src_md_, src_md_, diff_src_md_.data_type)
334 == status::success);
335 }
336
337 bool check_scale_shift_data_type() const {
338 return IMPLICATION(use_scale() || use_shift(),
339 utils::everyone_is(data_type::f32, weights_md()->data_type,
340 diff_weights_md()->data_type));
341 }
342};
343
344} // namespace impl
345} // namespace dnnl
346
347#endif
348
349// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
350