1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_ELTWISE_PD_HPP |
18 | #define COMMON_ELTWISE_PD_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "primitive_desc.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | |
28 | struct eltwise_fwd_pd_t; |
29 | |
30 | struct eltwise_pd_t : public primitive_desc_t { |
31 | static constexpr auto base_pkind = primitive_kind::eltwise; |
32 | |
33 | const eltwise_desc_t *desc() const { return &desc_; } |
34 | const op_desc_t *op_desc() const override { |
35 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
36 | } |
37 | |
38 | status_t query(query_t what, int idx, void *result) const override { |
39 | switch (what) { |
40 | case query::prop_kind: |
41 | *(prop_kind_t *)result = desc()->prop_kind; |
42 | break; |
43 | case query::alg_kind: |
44 | *(alg_kind_t *)result = desc()->alg_kind; |
45 | break; |
46 | case query::alpha_f32: *(float *)result = desc()->alpha; break; |
47 | case query::beta_f32: *(float *)result = desc()->beta; break; |
48 | default: return primitive_desc_t::query(what, idx, result); |
49 | } |
50 | return status::success; |
51 | } |
52 | |
53 | /* common eltwise aux functions */ |
54 | |
55 | dim_t MB() const { return data_md()->dims[0]; } |
56 | dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } |
57 | dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } |
58 | dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } |
59 | dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; } |
60 | |
61 | int ndims() const { return data_md()->ndims; } |
62 | |
63 | bool is_fwd() const { |
64 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
65 | prop_kind::forward_inference); |
66 | } |
67 | |
68 | bool has_zero_dim_memory() const { |
69 | return memory_desc_wrapper(data_md()).has_zero_dim(); |
70 | } |
71 | |
72 | bool use_dst() const { |
73 | using namespace alg_kind; |
74 | return !is_fwd() |
75 | && utils::one_of(desc_.alg_kind, eltwise_relu_use_dst_for_bwd, |
76 | eltwise_tanh_use_dst_for_bwd, |
77 | eltwise_elu_use_dst_for_bwd, |
78 | eltwise_sqrt_use_dst_for_bwd, |
79 | eltwise_logistic_use_dst_for_bwd, |
80 | eltwise_exp_use_dst_for_bwd, |
81 | eltwise_clip_v2_use_dst_for_bwd); |
82 | } |
83 | |
84 | protected: |
85 | eltwise_desc_t desc_; |
86 | const eltwise_fwd_pd_t *hint_fwd_pd_; |
87 | |
88 | memory_desc_t src_md_; |
89 | memory_desc_t dst_md_; |
90 | |
91 | eltwise_pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr, |
92 | const eltwise_fwd_pd_t *hint_fwd_pd) |
93 | : primitive_desc_t(attr, base_pkind) |
94 | , desc_(*adesc) |
95 | , hint_fwd_pd_(hint_fwd_pd) |
96 | , src_md_(desc_.src_desc) |
97 | , dst_md_(desc_.dst_desc) {} |
98 | |
99 | private: |
100 | const memory_desc_t *data_md(int index = 0) const { |
101 | return use_dst() ? dst_md(index) : src_md(index); |
102 | } |
103 | }; |
104 | |
105 | struct eltwise_fwd_pd_t : public eltwise_pd_t { |
106 | typedef eltwise_fwd_pd_t base_class; |
107 | typedef eltwise_fwd_pd_t hint_class; |
108 | |
109 | arg_usage_t arg_usage(int arg) const override { |
110 | if (arg == DNNL_ARG_SRC) return arg_usage_t::input; |
111 | |
112 | if (arg == DNNL_ARG_DST) return arg_usage_t::output; |
113 | |
114 | return primitive_desc_t::arg_usage(arg); |
115 | } |
116 | |
117 | const memory_desc_t *arg_md(int arg) const override { |
118 | switch (arg) { |
119 | case DNNL_ARG_SRC: return src_md(0); |
120 | case DNNL_ARG_DST: return dst_md(0); |
121 | default: return eltwise_pd_t::arg_md(arg); |
122 | } |
123 | } |
124 | |
125 | const memory_desc_t *src_md(int index = 0) const override { |
126 | return index == 0 ? &src_md_ : &glob_zero_md; |
127 | } |
128 | const memory_desc_t *dst_md(int index = 0) const override { |
129 | return index == 0 ? &dst_md_ : &glob_zero_md; |
130 | } |
131 | |
132 | int n_inputs() const override { return 1 + n_binary_po_inputs(); } |
133 | int n_outputs() const override { return 1; } |
134 | |
135 | static bool eltwise_preserves_zero( |
136 | alg_kind_t alg, float alpha, float beta) { |
137 | using namespace alg_kind; |
138 | using namespace utils; |
139 | return one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, |
140 | eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_swish, |
141 | eltwise_gelu_tanh, eltwise_gelu_erf, eltwise_round, |
142 | eltwise_hardswish) |
143 | || one_of(alg, eltwise_relu_use_dst_for_bwd, |
144 | eltwise_tanh_use_dst_for_bwd, |
145 | eltwise_elu_use_dst_for_bwd, |
146 | eltwise_sqrt_use_dst_for_bwd) |
147 | || (one_of(alg, eltwise_clip, eltwise_clip_v2) && alpha <= 0 |
148 | && beta >= 0) |
149 | || (alg == eltwise_linear && beta == 0) |
150 | || (alg == eltwise_pow && beta > 0); |
151 | } |
152 | |
153 | static bool eltwise_preserves_zero( |
154 | const post_ops_t::entry_t::eltwise_t &eltwise) { |
155 | return eltwise_preserves_zero(eltwise.alg, eltwise.alpha, eltwise.beta); |
156 | } |
157 | |
158 | bool is_zero_preserved() const { |
159 | return eltwise_preserves_zero(desc_.alg_kind, desc_.alpha, desc_.beta); |
160 | } |
161 | |
162 | protected: |
163 | eltwise_fwd_pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr, |
164 | const eltwise_fwd_pd_t *hint_fwd_pd) |
165 | : eltwise_pd_t(adesc, attr, hint_fwd_pd) {} |
166 | |
167 | bool set_default_formats_common() { |
168 | return IMPLICATION(dst_md_.format_kind == format_kind::any, |
169 | memory_desc_init_by_md_and_dt( |
170 | dst_md_, src_md_, dst_md_.data_type) |
171 | == status::success); |
172 | } |
173 | }; |
174 | |
175 | struct eltwise_bwd_pd_t : public eltwise_pd_t { |
176 | typedef eltwise_bwd_pd_t base_class; |
177 | typedef eltwise_fwd_pd_t hint_class; |
178 | |
179 | arg_usage_t arg_usage(int arg) const override { |
180 | if (use_dst() ? arg == DNNL_ARG_DST : arg == DNNL_ARG_SRC) |
181 | return arg_usage_t::input; |
182 | |
183 | if (arg == DNNL_ARG_DIFF_DST) return arg_usage_t::input; |
184 | if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; |
185 | |
186 | return primitive_desc_t::arg_usage(arg); |
187 | } |
188 | |
189 | const memory_desc_t *arg_md(int arg) const override { |
190 | switch (arg) { |
191 | case DNNL_ARG_SRC: return src_md(0); |
192 | case DNNL_ARG_DST: return dst_md(0); |
193 | case DNNL_ARG_DIFF_SRC: return diff_src_md(0); |
194 | case DNNL_ARG_DIFF_DST: return diff_dst_md(0); |
195 | default: return eltwise_pd_t::arg_md(arg); |
196 | } |
197 | } |
198 | |
199 | // To avoid additional logic in implementations |
200 | const memory_desc_t *data_md(int index = 0) const { |
201 | return use_dst() ? dst_md(index) : src_md(index); |
202 | } |
203 | const memory_desc_t *src_md(int index = 0) const override { |
204 | return (index == 0 && !use_dst()) ? &src_md_ : &glob_zero_md; |
205 | } |
206 | const memory_desc_t *dst_md(int index = 0) const override { |
207 | return (index == 0 && use_dst()) ? &dst_md_ : &glob_zero_md; |
208 | } |
209 | const memory_desc_t *diff_dst_md(int index = 0) const override { |
210 | return index == 0 ? &diff_dst_md_ : &glob_zero_md; |
211 | } |
212 | const memory_desc_t *diff_src_md(int index = 0) const override { |
213 | return index == 0 ? &diff_src_md_ : &glob_zero_md; |
214 | } |
215 | |
216 | int n_inputs() const override { return 2; } |
217 | int n_outputs() const override { return 1; } |
218 | |
219 | static bool eltwise_preserves_zero( |
220 | alg_kind_t alg, float alpha, float beta) { |
221 | // Unlike forward counterpart, bwd works on two tensors (with same formats) |
222 | // and if alg moves zero to non-zero, it's fine, because diff_dst will |
223 | // still have zeros in padding and multiplication of zero and non-zero |
224 | // gives desired result. However, it doesn't work in case of special fp |
225 | // values which are NaN or infinity which give NaN when multiplying on |
226 | // zero, so excluding all those algs from here. |
227 | using namespace alg_kind; |
228 | using namespace utils; |
229 | return one_of(alg, eltwise_abs, eltwise_clip, eltwise_clip_v2, |
230 | eltwise_elu, eltwise_exp, eltwise_gelu_erf, |
231 | eltwise_gelu_tanh, eltwise_hardsigmoid, eltwise_linear, |
232 | eltwise_logistic, eltwise_mish, eltwise_relu, |
233 | eltwise_soft_relu, eltwise_square, eltwise_swish, |
234 | eltwise_tanh) |
235 | || one_of(alg, eltwise_elu_use_dst_for_bwd, |
236 | eltwise_exp_use_dst_for_bwd, |
237 | eltwise_logistic_use_dst_for_bwd, |
238 | eltwise_relu_use_dst_for_bwd, |
239 | eltwise_tanh_use_dst_for_bwd, |
240 | eltwise_clip_v2_use_dst_for_bwd) |
241 | || (alg == eltwise_pow && beta >= 1); |
242 | } |
243 | |
244 | bool is_zero_preserved() const { |
245 | return eltwise_preserves_zero(desc_.alg_kind, desc_.alpha, desc_.beta); |
246 | } |
247 | |
248 | protected: |
249 | memory_desc_t diff_src_md_; |
250 | memory_desc_t diff_dst_md_; |
251 | |
252 | eltwise_bwd_pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr, |
253 | const eltwise_fwd_pd_t *hint_fwd_pd) |
254 | : eltwise_pd_t(adesc, attr, hint_fwd_pd) |
255 | , diff_src_md_(desc_.diff_src_desc) |
256 | , diff_dst_md_(desc_.diff_dst_desc) {} |
257 | |
258 | bool set_default_formats_common() { |
259 | return IMPLICATION(diff_dst_md_.format_kind == format_kind::any, |
260 | memory_desc_init_by_md_and_dt( |
261 | diff_dst_md_, *data_md(), diff_dst_md_.data_type) |
262 | == status::success) |
263 | && IMPLICATION(diff_src_md_.format_kind == format_kind::any, |
264 | memory_desc_init_by_md_and_dt(diff_src_md_, *data_md(), |
265 | diff_src_md_.data_type) |
266 | == status::success); |
267 | } |
268 | }; |
269 | |
270 | } // namespace impl |
271 | } // namespace dnnl |
272 | |
273 | #endif |
274 | |
275 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
276 | |