1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_INNER_PRODUCT_PD_HPP |
18 | #define COMMON_INNER_PRODUCT_PD_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "primitive_desc.hpp" |
24 | #include "utils.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | |
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name and signature this presumably fills
// `*ip_desc` from the propagation kind and the four memory descriptors and
// reports the outcome through the returned status_t -- confirm against the
// definition before relying on it.
status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
        const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
        const memory_desc_t *bias_desc, const memory_desc_t *dst_desc);

// Forward declaration: inner_product_pd_t stores a pointer to this type
// before the type itself is defined further down in this header.
struct inner_product_fwd_pd_t;
34 | |
35 | struct inner_product_pd_t : public primitive_desc_t { |
36 | static constexpr auto base_pkind = primitive_kind::inner_product; |
37 | |
38 | inner_product_pd_t(const inner_product_desc_t *adesc, |
39 | const primitive_attr_t *attr, |
40 | const inner_product_fwd_pd_t *hint_fwd_pd) |
41 | : primitive_desc_t(attr, base_pkind) |
42 | , desc_(*adesc) |
43 | , hint_fwd_pd_(hint_fwd_pd) {} |
44 | |
45 | const inner_product_desc_t *desc() const { return &desc_; } |
46 | const op_desc_t *op_desc() const override { |
47 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
48 | } |
49 | |
50 | status_t query(query_t what, int idx, void *result) const override { |
51 | switch (what) { |
52 | case query::prop_kind: |
53 | *(prop_kind_t *)result = desc()->prop_kind; |
54 | break; |
55 | default: return primitive_desc_t::query(what, idx, result); |
56 | } |
57 | return status::success; |
58 | } |
59 | |
60 | /* common inner_product aux functions */ |
61 | |
62 | dim_t MB() const { return invariant_src_md()->dims[0]; } |
63 | dim_t IC() const { return invariant_src_md()->dims[1]; } |
64 | dim_t OC() const { return invariant_dst_md()->dims[1]; } |
65 | |
66 | dim_t ID() const { |
67 | return ndims() >= 5 ? invariant_src_md()->dims[ndims() - 3] : 1; |
68 | } |
69 | dim_t IH() const { |
70 | return ndims() >= 4 ? invariant_src_md()->dims[ndims() - 2] : 1; |
71 | } |
72 | dim_t IW() const { |
73 | return ndims() >= 3 ? invariant_src_md()->dims[ndims() - 1] : 1; |
74 | } |
75 | |
76 | dim_t OD() const { |
77 | return ndims() >= 5 ? invariant_dst_md()->dims[ndims() - 3] : 1; |
78 | } |
79 | dim_t OH() const { |
80 | return ndims() >= 4 ? invariant_dst_md()->dims[ndims() - 2] : 1; |
81 | } |
82 | dim_t OW() const { |
83 | return ndims() >= 3 ? invariant_dst_md()->dims[ndims() - 1] : 1; |
84 | } |
85 | |
86 | dim_t KD() const { |
87 | return ndims() >= 5 ? invariant_wei_md()->dims[ndims() - 3] : 1; |
88 | } |
89 | dim_t KH() const { |
90 | return ndims() >= 4 ? invariant_wei_md()->dims[ndims() - 2] : 1; |
91 | } |
92 | dim_t KW() const { |
93 | return ndims() >= 3 ? invariant_wei_md()->dims[ndims() - 1] : 1; |
94 | } |
95 | |
96 | dim_t IC_total() const { |
97 | return utils::array_product(&invariant_src_md()->dims[1], ndims() - 1); |
98 | } |
99 | |
100 | dim_t IC_total_padded() const { |
101 | auto src_d = desc()->prop_kind == prop_kind::backward_data |
102 | ? memory_desc_wrapper(diff_src_md()) |
103 | : memory_desc_wrapper(src_md()); |
104 | assert(src_d.is_blocking_desc()); |
105 | if (!src_d.is_blocking_desc()) return -1; |
106 | return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); |
107 | } |
108 | |
109 | int ndims() const { return invariant_src_md()->ndims; } |
110 | |
111 | bool with_bias() const { |
112 | auto *bia_d = desc()->prop_kind == prop_kind::backward_weights |
113 | ? &desc()->diff_bias_desc |
114 | : &desc()->bias_desc; |
115 | return !memory_desc_wrapper(bia_d).is_zero(); |
116 | } |
117 | |
118 | bool has_zero_dim_memory() const { |
119 | const auto s_d = memory_desc_wrapper(*invariant_src_md()); |
120 | const auto d_d = memory_desc_wrapper(*invariant_dst_md()); |
121 | return s_d.has_zero_dim() || d_d.has_zero_dim(); |
122 | } |
123 | |
124 | bool is_fwd() const { |
125 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
126 | prop_kind::forward_inference); |
127 | } |
128 | |
129 | virtual const memory_desc_t *invariant_src_md() const { |
130 | return desc()->prop_kind == prop_kind::backward_data ? diff_src_md() |
131 | : src_md(); |
132 | } |
133 | |
134 | virtual const memory_desc_t *invariant_wei_md(int index = 0) const { |
135 | return desc()->prop_kind == prop_kind::backward_weights |
136 | ? diff_weights_md(index) |
137 | : weights_md(index); |
138 | } |
139 | |
140 | virtual const memory_desc_t *invariant_bia_md() const { |
141 | return invariant_wei_md(1); |
142 | } |
143 | |
144 | virtual const memory_desc_t *invariant_dst_md() const { |
145 | return is_fwd() ? dst_md() : diff_dst_md(); |
146 | } |
147 | |
148 | protected: |
149 | inner_product_desc_t desc_; |
150 | const inner_product_fwd_pd_t *hint_fwd_pd_; |
151 | |
152 | bool set_default_formats_common_template(memory_desc_t &src_md, |
153 | format_tag_t src_tag, memory_desc_t &wei_md, format_tag_t wei_tag, |
154 | memory_desc_t &dst_md, format_tag_t dst_tag, |
155 | memory_desc_t &bia_md) { |
156 | using namespace format_tag; |
157 | |
158 | #define IS_OK(f) \ |
159 | do { \ |
160 | if ((f) != status::success) return false; \ |
161 | } while (0) |
162 | if (src_md.format_kind == format_kind::any |
163 | && !utils::one_of(src_tag, any, undef)) |
164 | IS_OK(memory_desc_init_by_tag(src_md, src_tag)); |
165 | if (dst_md.format_kind == format_kind::any |
166 | && !utils::one_of(dst_tag, any, undef)) |
167 | IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); |
168 | if (wei_md.format_kind == format_kind::any |
169 | && !utils::one_of(wei_tag, any, undef)) |
170 | IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); |
171 | if (with_bias() && bia_md.format_kind == format_kind::any) |
172 | IS_OK(memory_desc_init_by_tag(bia_md, x)); |
173 | #undef IS_OK |
174 | |
175 | return true; |
176 | } |
177 | |
178 | bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, |
179 | data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { |
180 | bool ok = true |
181 | && (src_dt == data_type::undef |
182 | || invariant_src_md()->data_type == src_dt) |
183 | && (wei_dt == data_type::undef |
184 | || invariant_wei_md()->data_type == wei_dt) |
185 | && (dst_dt == data_type::undef |
186 | || invariant_dst_md()->data_type == dst_dt) |
187 | && (acc_dt == data_type::undef |
188 | || desc_.accum_data_type == acc_dt); |
189 | if (with_bias() && bia_dt != data_type::undef) |
190 | ok = ok && invariant_bia_md()->data_type == bia_dt; |
191 | return ok; |
192 | } |
193 | }; |
194 | |
195 | struct inner_product_fwd_pd_t : public inner_product_pd_t { |
196 | typedef inner_product_fwd_pd_t base_class; |
197 | typedef inner_product_fwd_pd_t hint_class; |
198 | |
199 | inner_product_fwd_pd_t(const inner_product_desc_t *adesc, |
200 | const primitive_attr_t *attr, |
201 | const inner_product_fwd_pd_t *hint_fwd_pd) |
202 | : inner_product_pd_t(adesc, attr, hint_fwd_pd) |
203 | , src_md_(desc_.src_desc) |
204 | , weights_md_(desc_.weights_desc) |
205 | , bias_md_(desc_.bias_desc) |
206 | , dst_md_(desc_.dst_desc) {} |
207 | |
208 | arg_usage_t arg_usage(int arg) const override { |
209 | if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS)) |
210 | return arg_usage_t::input; |
211 | |
212 | if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; |
213 | |
214 | if (arg == DNNL_ARG_DST) return arg_usage_t::output; |
215 | |
216 | return primitive_desc_t::arg_usage(arg); |
217 | } |
218 | |
219 | const memory_desc_t *arg_md(int arg) const override { |
220 | switch (arg) { |
221 | case DNNL_ARG_SRC: return src_md(0); |
222 | case DNNL_ARG_WEIGHTS: return weights_md(0); |
223 | case DNNL_ARG_BIAS: return weights_md(1); |
224 | case DNNL_ARG_DST: return dst_md(0); |
225 | default: return inner_product_pd_t::arg_md(arg); |
226 | } |
227 | } |
228 | |
229 | const memory_desc_t *src_md(int index = 0) const override { |
230 | return index == 0 ? &src_md_ : &glob_zero_md; |
231 | } |
232 | const memory_desc_t *dst_md(int index = 0) const override { |
233 | return index == 0 ? &dst_md_ : &glob_zero_md; |
234 | } |
235 | const memory_desc_t *weights_md(int index = 0) const override { |
236 | if (index == 0) return &weights_md_; |
237 | if (index == 1 && with_bias()) return &bias_md_; |
238 | return &glob_zero_md; |
239 | } |
240 | |
241 | int n_inputs() const override { |
242 | return 2 + with_bias() + n_binary_po_inputs(); |
243 | } |
244 | int n_outputs() const override { return 1; } |
245 | |
246 | protected: |
247 | memory_desc_t src_md_; |
248 | memory_desc_t weights_md_; |
249 | memory_desc_t bias_md_; |
250 | memory_desc_t dst_md_; |
251 | |
252 | bool set_default_formats_common( |
253 | format_tag_t src_tag, format_tag_t wei_tag, format_tag_t dst_tag) { |
254 | return set_default_formats_common_template(src_md_, src_tag, |
255 | weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); |
256 | } |
257 | }; |
258 | |
259 | struct inner_product_bwd_data_pd_t : public inner_product_pd_t { |
260 | typedef inner_product_bwd_data_pd_t base_class; |
261 | typedef inner_product_fwd_pd_t hint_class; |
262 | |
263 | inner_product_bwd_data_pd_t(const inner_product_desc_t *adesc, |
264 | const primitive_attr_t *attr, |
265 | const inner_product_fwd_pd_t *hint_fwd_pd) |
266 | : inner_product_pd_t(adesc, attr, hint_fwd_pd) |
267 | , diff_src_md_(desc_.diff_src_desc) |
268 | , weights_md_(desc_.weights_desc) |
269 | , diff_dst_md_(desc_.diff_dst_desc) {} |
270 | |
271 | arg_usage_t arg_usage(int arg) const override { |
272 | if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST)) |
273 | return arg_usage_t::input; |
274 | |
275 | if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; |
276 | |
277 | return primitive_desc_t::arg_usage(arg); |
278 | } |
279 | |
280 | const memory_desc_t *arg_md(int arg) const override { |
281 | switch (arg) { |
282 | case DNNL_ARG_DIFF_SRC: return diff_src_md(0); |
283 | case DNNL_ARG_WEIGHTS: return weights_md(0); |
284 | case DNNL_ARG_DIFF_DST: return diff_dst_md(0); |
285 | default: return inner_product_pd_t::arg_md(arg); |
286 | } |
287 | } |
288 | |
289 | const memory_desc_t *diff_src_md(int index = 0) const override { |
290 | return index == 0 ? &diff_src_md_ : &glob_zero_md; |
291 | } |
292 | const memory_desc_t *diff_dst_md(int index = 0) const override { |
293 | return index == 0 ? &diff_dst_md_ : &glob_zero_md; |
294 | } |
295 | const memory_desc_t *weights_md(int index = 0) const override { |
296 | return index == 0 ? &weights_md_ : &glob_zero_md; |
297 | } |
298 | |
299 | int n_inputs() const override { return 2; } |
300 | int n_outputs() const override { return 1; } |
301 | |
302 | protected: |
303 | memory_desc_t diff_src_md_; |
304 | memory_desc_t weights_md_; |
305 | memory_desc_t diff_dst_md_; |
306 | |
307 | bool set_default_formats_common(format_tag_t diff_src_tag, |
308 | format_tag_t wei_tag, format_tag_t diff_dst_tag) { |
309 | memory_desc_t dummy_md; |
310 | return set_default_formats_common_template(diff_src_md_, diff_src_tag, |
311 | weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, dummy_md); |
312 | } |
313 | }; |
314 | |
315 | struct inner_product_bwd_weights_pd_t : public inner_product_pd_t { |
316 | typedef inner_product_bwd_weights_pd_t base_class; |
317 | typedef inner_product_fwd_pd_t hint_class; |
318 | |
319 | inner_product_bwd_weights_pd_t(const inner_product_desc_t *adesc, |
320 | const primitive_attr_t *attr, |
321 | const inner_product_fwd_pd_t *hint_fwd_pd) |
322 | : inner_product_pd_t(adesc, attr, hint_fwd_pd) |
323 | , src_md_(desc_.src_desc) |
324 | , diff_weights_md_(desc_.diff_weights_desc) |
325 | , diff_bias_md_(desc_.diff_bias_desc) |
326 | , diff_dst_md_(desc_.diff_dst_desc) {} |
327 | |
328 | arg_usage_t arg_usage(int arg) const override { |
329 | if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST)) |
330 | return arg_usage_t::input; |
331 | |
332 | if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output; |
333 | |
334 | if (arg == DNNL_ARG_DIFF_BIAS && with_bias()) |
335 | return arg_usage_t::output; |
336 | |
337 | return primitive_desc_t::arg_usage(arg); |
338 | } |
339 | |
340 | const memory_desc_t *arg_md(int arg) const override { |
341 | switch (arg) { |
342 | case DNNL_ARG_SRC: return src_md(0); |
343 | case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); |
344 | case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); |
345 | case DNNL_ARG_DIFF_DST: return diff_dst_md(0); |
346 | default: return inner_product_pd_t::arg_md(arg); |
347 | } |
348 | } |
349 | |
350 | const memory_desc_t *src_md(int index = 0) const override { |
351 | return index == 0 ? &src_md_ : &glob_zero_md; |
352 | } |
353 | const memory_desc_t *diff_dst_md(int index = 0) const override { |
354 | return index == 0 ? &diff_dst_md_ : &glob_zero_md; |
355 | } |
356 | const memory_desc_t *diff_weights_md(int index = 0) const override { |
357 | if (index == 0) return &diff_weights_md_; |
358 | if (index == 1 && with_bias()) return &diff_bias_md_; |
359 | return &glob_zero_md; |
360 | } |
361 | |
362 | int n_inputs() const override { return 2; } |
363 | int n_outputs() const override { return 1 + with_bias(); } |
364 | |
365 | protected: |
366 | memory_desc_t src_md_; |
367 | memory_desc_t diff_weights_md_; |
368 | memory_desc_t diff_bias_md_; |
369 | memory_desc_t diff_dst_md_; |
370 | |
371 | bool set_default_formats_common(format_tag_t src_tag, |
372 | format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { |
373 | return set_default_formats_common_template(src_md_, src_tag, |
374 | diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, |
375 | diff_bias_md_); |
376 | } |
377 | }; |
378 | |
379 | } // namespace impl |
380 | } // namespace dnnl |
381 | |
382 | #endif |
383 | |
384 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
385 | |