1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_RNN_PD_HPP |
18 | #define COMMON_RNN_PD_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "primitive_desc.hpp" |
24 | #include "rnn.hpp" |
25 | #include "type_helpers.hpp" |
26 | #include "utils.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | |
31 | struct rnn_fwd_pd_t; |
32 | |
33 | struct rnn_pd_t : public primitive_desc_t { |
34 | static constexpr auto base_pkind = primitive_kind::rnn; |
35 | |
36 | const rnn_desc_t *desc() const { return &desc_; } |
37 | const op_desc_t *op_desc() const override { |
38 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
39 | } |
40 | |
41 | status_t query(query_t what, int idx, void *result) const override { |
42 | switch (what) { |
43 | case query::prop_kind: |
44 | *(prop_kind_t *)result = desc()->prop_kind; |
45 | break; |
46 | case query::cell_kind: |
47 | *(alg_kind_t *)result = desc()->cell_kind; |
48 | break; |
49 | case query::activation_kind: |
50 | *(alg_kind_t *)result = desc()->activation_kind; |
51 | break; |
52 | case query::direction: |
53 | *(rnn_direction_t *)result = desc()->direction; |
54 | break; |
55 | case query::alpha_f32: *(float *)result = desc()->alpha; break; |
56 | default: return primitive_desc_t::query(what, idx, result); |
57 | } |
58 | return status::success; |
59 | } |
60 | |
61 | const memory_desc_t *src_md(int index = 0) const override { |
62 | if (index == 0) return &src_layer_md_; |
63 | if (index == 1 && with_src_iter()) return &src_iter_md_; |
64 | if (index == 2 && with_src_iter_c()) return &src_iter_c_md_; |
65 | return &glob_zero_md; |
66 | } |
67 | |
68 | memory_desc_t &augru_attention_md() { |
69 | if (with_augru_attention()) return weights_peephole_md_; |
70 | return glob_zero_md; |
71 | } |
72 | |
73 | const memory_desc_t &const_augru_attention_md() const { |
74 | if (with_augru_attention()) return weights_peephole_md_; |
75 | return glob_zero_md; |
76 | } |
77 | |
78 | const memory_desc_t *weights_md(int index = 0) const override { |
79 | if (index == 0) return &weights_layer_md_; |
80 | if (index == 1) return &weights_iter_md_; |
81 | |
82 | const int peephole_index = 2; |
83 | if (is_lstm_peephole() && index == peephole_index) |
84 | return &weights_peephole_md_; |
85 | |
86 | const int projection_index = 2 + is_lstm_peephole(); |
87 | if (is_lstm_projection() && index == projection_index) |
88 | return &weights_projection_md_; |
89 | |
90 | const int bias_index = 2 + is_lstm_peephole() + is_lstm_projection(); |
91 | if (with_bias() && index == bias_index) return &bias_md_; |
92 | |
93 | return &glob_zero_md; |
94 | } |
95 | const memory_desc_t *dst_md(int index = 0) const override { |
96 | if (index == 0) return &dst_layer_md_; |
97 | if (index == 1 && with_dst_iter()) return &dst_iter_md_; |
98 | if (index == 2 && with_dst_iter_c()) return &dst_iter_c_md_; |
99 | return &glob_zero_md; |
100 | } |
101 | const memory_desc_t *workspace_md(int index = 0) const override { |
102 | return (index == 0) ? &ws_md_ : &glob_zero_md; |
103 | } |
104 | |
105 | /* common aux functions */ |
106 | |
107 | bool is_training() const { |
108 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
109 | prop_kind::backward); |
110 | } |
111 | |
112 | bool is_fwd() const { |
113 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
114 | prop_kind::forward_inference); |
115 | } |
116 | |
117 | dim_t T() const { return desc_.src_layer_desc.dims[0]; } |
118 | dim_t MB() const { return desc_.src_layer_desc.dims[1]; } |
119 | |
120 | dim_t L() const { return desc_.weights_layer_desc.dims[0]; } |
121 | dim_t D() const { return desc_.weights_layer_desc.dims[1]; } |
122 | |
123 | dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } |
124 | |
125 | dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } |
126 | dim_t G() const { return desc_.weights_layer_desc.dims[3]; } |
127 | dim_t DHC() const { return desc_.weights_layer_desc.dims[4]; } |
128 | |
129 | // Returns the number of channels for the iter tensor. |
130 | // Must be equal to the dst_iter.dims[3] if dst_iter is not zero. |
131 | dim_t DIC() const { |
132 | return is_lstm_projection() ? desc_.weights_projection_desc.dims[3] |
133 | : DHC(); |
134 | } |
135 | |
136 | dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } |
137 | |
138 | bool with_bias() const { |
139 | return !memory_desc_wrapper(desc_.bias_desc).is_zero(); |
140 | } |
141 | |
142 | bool with_augru_attention() const { return is_augru(); } |
143 | |
144 | bool with_src_iter() const { |
145 | return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); |
146 | } |
147 | |
148 | bool with_src_iter_c() const { |
149 | return is_lstm() |
150 | && !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); |
151 | } |
152 | |
153 | bool with_dst_iter() const { |
154 | return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); |
155 | } |
156 | |
157 | bool with_dst_iter_c() const { |
158 | return is_lstm() && !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); |
159 | } |
160 | |
161 | dnnl::impl::alg_kind_t cell_kind() const { return desc_.cell_kind; } |
162 | dnnl::impl::alg_kind_t activation_kind() const { |
163 | return desc_.activation_kind; |
164 | } |
165 | |
166 | bool is_lbr() const { |
167 | return utils::one_of(cell_kind(), dnnl_lbr_gru, dnnl_lbr_augru); |
168 | } |
169 | |
170 | bool is_augru() const { |
171 | return utils::one_of(cell_kind(), dnnl_vanilla_augru, dnnl_lbr_augru); |
172 | } |
173 | |
174 | bool is_lstm() const { return cell_kind() == dnnl_vanilla_lstm; } |
175 | |
176 | bool is_lstm_peephole() const { |
177 | return is_lstm() |
178 | && !memory_desc_wrapper(weights_peephole_md_).is_zero(); |
179 | } |
180 | |
181 | bool is_lstm_projection() const { |
182 | return !memory_desc_wrapper(weights_projection_md_).is_zero(); |
183 | } |
184 | |
185 | dnnl_rnn_direction_t direction() const { return desc_.direction; } |
186 | |
187 | protected: |
188 | rnn_desc_t desc_; |
189 | const rnn_fwd_pd_t *hint_fwd_pd_; |
190 | |
191 | memory_desc_t src_layer_md_; |
192 | memory_desc_t src_iter_md_; |
193 | memory_desc_t src_iter_c_md_; |
194 | memory_desc_t weights_layer_md_; |
195 | memory_desc_t weights_iter_md_; |
196 | memory_desc_t weights_peephole_md_; |
197 | memory_desc_t weights_projection_md_; |
198 | memory_desc_t bias_md_; |
199 | memory_desc_t dst_layer_md_; |
200 | memory_desc_t dst_iter_md_; |
201 | memory_desc_t dst_iter_c_md_; |
202 | |
203 | memory_desc_t ws_md_; |
204 | |
205 | rnn_pd_t(const rnn_desc_t *adesc, const primitive_attr_t *attr, |
206 | const rnn_fwd_pd_t *hint_fwd_pd) |
207 | : primitive_desc_t(attr, base_pkind) |
208 | , desc_(*adesc) |
209 | , hint_fwd_pd_(hint_fwd_pd) |
210 | , src_layer_md_(desc_.src_layer_desc) |
211 | , src_iter_md_(desc_.src_iter_desc) |
212 | , src_iter_c_md_(desc_.src_iter_c_desc) |
213 | , weights_layer_md_(desc_.weights_layer_desc) |
214 | , weights_iter_md_(desc_.weights_iter_desc) |
215 | , weights_peephole_md_(desc_.weights_peephole_desc) |
216 | , weights_projection_md_(desc_.weights_projection_desc) |
217 | , bias_md_(desc_.bias_desc) |
218 | , dst_layer_md_(desc_.dst_layer_desc) |
219 | , dst_iter_md_(desc_.dst_iter_desc) |
220 | , dst_iter_c_md_(desc_.dst_iter_c_desc) |
221 | , ws_md_() {} |
222 | }; |
223 | |
224 | struct rnn_fwd_pd_t : public rnn_pd_t { |
225 | typedef rnn_fwd_pd_t base_class; |
226 | typedef rnn_fwd_pd_t hint_class; |
227 | |
228 | arg_usage_t arg_usage(int arg) const override { |
229 | if (arg == DNNL_ARG_SRC_LAYER) return arg_usage_t::input; |
230 | |
231 | if (arg == DNNL_ARG_AUGRU_ATTENTION && with_augru_attention()) |
232 | return arg_usage_t::input; |
233 | |
234 | if (arg == DNNL_ARG_SRC_ITER && with_src_iter()) |
235 | return arg_usage_t::input; |
236 | |
237 | if (arg == DNNL_ARG_SRC_ITER_C && with_src_iter_c()) |
238 | return arg_usage_t::input; |
239 | |
240 | if (utils::one_of(arg, DNNL_ARG_WEIGHTS_LAYER, DNNL_ARG_WEIGHTS_ITER)) |
241 | return arg_usage_t::input; |
242 | |
243 | if (arg == DNNL_ARG_WEIGHTS_PEEPHOLE && is_lstm_peephole()) |
244 | return arg_usage_t::input; |
245 | |
246 | if (arg == DNNL_ARG_WEIGHTS_PROJECTION && is_lstm_projection()) |
247 | return arg_usage_t::input; |
248 | |
249 | if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; |
250 | |
251 | if (arg == DNNL_ARG_DST_LAYER) return arg_usage_t::output; |
252 | |
253 | if (arg == DNNL_ARG_DST_ITER && with_dst_iter()) |
254 | return arg_usage_t::output; |
255 | |
256 | if (arg == DNNL_ARG_DST_ITER_C && with_dst_iter() && is_lstm()) |
257 | return arg_usage_t::output; |
258 | |
259 | if (arg == DNNL_ARG_WORKSPACE && is_training()) |
260 | return arg_usage_t::output; |
261 | |
262 | return primitive_desc_t::arg_usage(arg); |
263 | } |
264 | |
265 | const memory_desc_t *arg_md(int arg) const override { |
266 | switch (arg) { |
267 | case DNNL_ARG_SRC_LAYER: return src_md(0); |
268 | case DNNL_ARG_AUGRU_ATTENTION: return &const_augru_attention_md(); |
269 | case DNNL_ARG_SRC_ITER: return src_md(1); |
270 | case DNNL_ARG_SRC_ITER_C: return src_md(2); |
271 | case DNNL_ARG_WEIGHTS_LAYER: return weights_md(0); |
272 | case DNNL_ARG_WEIGHTS_ITER: return weights_md(1); |
273 | case DNNL_ARG_WEIGHTS_PEEPHOLE: |
274 | return is_lstm_peephole() ? weights_md(2) : &glob_zero_md; |
275 | case DNNL_ARG_WEIGHTS_PROJECTION: |
276 | return is_lstm_projection() ? weights_md(2 + is_lstm_peephole()) |
277 | : &glob_zero_md; |
278 | case DNNL_ARG_BIAS: |
279 | return weights_md( |
280 | 2 + is_lstm_peephole() + is_lstm_projection()); |
281 | case DNNL_ARG_DST_LAYER: return dst_md(0); |
282 | case DNNL_ARG_DST_ITER: return dst_md(1); |
283 | case DNNL_ARG_DST_ITER_C: return dst_md(2); |
284 | default: return rnn_pd_t::arg_md(arg); |
285 | } |
286 | } |
287 | |
288 | int n_inputs() const override { |
289 | return 3 + is_lstm_peephole() + is_lstm_projection() + with_bias() |
290 | + with_src_iter() + with_src_iter_c() + is_augru(); |
291 | } |
292 | int n_outputs() const override { |
293 | return 1 + with_dst_iter() + with_dst_iter_c() + is_training(); |
294 | } |
295 | |
296 | protected: |
297 | rnn_fwd_pd_t(const rnn_desc_t *adesc, const primitive_attr_t *attr, |
298 | const rnn_fwd_pd_t *hint_fwd_pd) |
299 | : rnn_pd_t(adesc, attr, hint_fwd_pd) {} |
300 | }; |
301 | |
302 | struct rnn_bwd_pd_t : public rnn_pd_t { |
303 | typedef rnn_bwd_pd_t base_class; |
304 | typedef rnn_fwd_pd_t hint_class; |
305 | |
306 | arg_usage_t arg_usage(int arg) const override { |
307 | if (utils::one_of(arg, DNNL_ARG_SRC_LAYER, DNNL_ARG_DST_LAYER, |
308 | DNNL_ARG_DIFF_DST_LAYER, DNNL_ARG_WEIGHTS_LAYER, |
309 | DNNL_ARG_WEIGHTS_ITER)) |
310 | return arg_usage_t::input; |
311 | |
312 | if (utils::one_of(arg, DNNL_ARG_DIFF_SRC_LAYER, |
313 | DNNL_ARG_DIFF_WEIGHTS_LAYER, DNNL_ARG_DIFF_WEIGHTS_ITER)) |
314 | return arg_usage_t::output; |
315 | |
316 | if (with_augru_attention()) { |
317 | if (arg == DNNL_ARG_AUGRU_ATTENTION) return arg_usage_t::input; |
318 | if (arg == DNNL_ARG_DIFF_AUGRU_ATTENTION) |
319 | return arg_usage_t::output; |
320 | } |
321 | |
322 | if (is_lstm_peephole()) { |
323 | if (arg == DNNL_ARG_WEIGHTS_PEEPHOLE) return arg_usage_t::input; |
324 | |
325 | if (arg == DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE) |
326 | return arg_usage_t::output; |
327 | } |
328 | |
329 | if (is_lstm_projection()) { |
330 | if (arg == DNNL_ARG_WEIGHTS_PROJECTION) return arg_usage_t::input; |
331 | |
332 | if (arg == DNNL_ARG_DIFF_WEIGHTS_PROJECTION) |
333 | return arg_usage_t::output; |
334 | } |
335 | |
336 | if (with_bias()) { |
337 | if (arg == DNNL_ARG_BIAS) return arg_usage_t::input; |
338 | |
339 | if (arg == DNNL_ARG_DIFF_BIAS) return arg_usage_t::output; |
340 | } |
341 | |
342 | if (with_src_iter()) { |
343 | if (arg == DNNL_ARG_SRC_ITER) return arg_usage_t::input; |
344 | |
345 | if (arg == DNNL_ARG_DIFF_SRC_ITER) return arg_usage_t::output; |
346 | } |
347 | |
348 | if (with_src_iter_c()) { |
349 | if (arg == DNNL_ARG_SRC_ITER_C) return arg_usage_t::input; |
350 | |
351 | if (arg == DNNL_ARG_DIFF_SRC_ITER_C) return arg_usage_t::output; |
352 | } |
353 | |
354 | if (with_dst_iter() |
355 | && utils::one_of( |
356 | arg, DNNL_ARG_DST_ITER, DNNL_ARG_DIFF_DST_ITER)) |
357 | return arg_usage_t::input; |
358 | |
359 | if (with_dst_iter_c() |
360 | && utils::one_of( |
361 | arg, DNNL_ARG_DST_ITER_C, DNNL_ARG_DIFF_DST_ITER_C)) |
362 | return arg_usage_t::input; |
363 | |
364 | if (arg == DNNL_ARG_WORKSPACE) return arg_usage_t::input; |
365 | |
366 | return primitive_desc_t::arg_usage(arg); |
367 | } |
368 | |
369 | const memory_desc_t *arg_md(int arg) const override { |
370 | switch (arg) { |
371 | case DNNL_ARG_SRC_LAYER: return src_md(0); |
372 | case DNNL_ARG_AUGRU_ATTENTION: return &const_augru_attention_md(); |
373 | case DNNL_ARG_SRC_ITER: return src_md(1); |
374 | case DNNL_ARG_SRC_ITER_C: return src_md(2); |
375 | case DNNL_ARG_DIFF_SRC_LAYER: return diff_src_md(0); |
376 | case DNNL_ARG_DIFF_AUGRU_ATTENTION: |
377 | return &const_diff_augru_attention_md(); |
378 | case DNNL_ARG_DIFF_SRC_ITER: return diff_src_md(1); |
379 | case DNNL_ARG_DIFF_SRC_ITER_C: return diff_src_md(2); |
380 | case DNNL_ARG_WEIGHTS_LAYER: return weights_md(0); |
381 | case DNNL_ARG_WEIGHTS_ITER: return weights_md(1); |
382 | case DNNL_ARG_WEIGHTS_PEEPHOLE: |
383 | return is_lstm_peephole() ? weights_md(2) : &glob_zero_md; |
384 | case DNNL_ARG_WEIGHTS_PROJECTION: |
385 | return is_lstm_projection() ? weights_md(2 + is_lstm_peephole()) |
386 | : &glob_zero_md; |
387 | case DNNL_ARG_BIAS: |
388 | return weights_md( |
389 | 2 + is_lstm_peephole() + is_lstm_projection()); |
390 | case DNNL_ARG_DIFF_WEIGHTS_LAYER: return diff_weights_md(0); |
391 | case DNNL_ARG_DIFF_WEIGHTS_ITER: return diff_weights_md(1); |
392 | case DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE: |
393 | return is_lstm_peephole() ? diff_weights_md(2) : &glob_zero_md; |
394 | case DNNL_ARG_DIFF_WEIGHTS_PROJECTION: |
395 | return is_lstm_projection() |
396 | ? diff_weights_md(2 + is_lstm_peephole()) |
397 | : &glob_zero_md; |
398 | case DNNL_ARG_DIFF_BIAS: |
399 | return diff_weights_md( |
400 | 2 + is_lstm_peephole() + is_lstm_projection()); |
401 | case DNNL_ARG_DST_LAYER: return dst_md(0); |
402 | case DNNL_ARG_DST_ITER: return dst_md(1); |
403 | case DNNL_ARG_DST_ITER_C: return dst_md(2); |
404 | case DNNL_ARG_DIFF_DST_LAYER: return diff_dst_md(0); |
405 | case DNNL_ARG_DIFF_DST_ITER: return diff_dst_md(1); |
406 | case DNNL_ARG_DIFF_DST_ITER_C: return diff_dst_md(2); |
407 | default: return rnn_pd_t::arg_md(arg); |
408 | } |
409 | } |
410 | |
411 | const memory_desc_t *diff_src_md(int index = 0) const override { |
412 | if (index == 0) return &diff_src_layer_md_; |
413 | if (index == 1 && with_src_iter()) return &diff_src_iter_md_; |
414 | if (index == 2 && with_src_iter_c()) return &diff_src_iter_c_md_; |
415 | return &glob_zero_md; |
416 | } |
417 | memory_desc_t &diff_augru_attention_md() { |
418 | if (with_augru_attention()) return diff_weights_peephole_md_; |
419 | return glob_zero_md; |
420 | } |
421 | const memory_desc_t &const_diff_augru_attention_md() const { |
422 | if (with_augru_attention()) return diff_weights_peephole_md_; |
423 | return glob_zero_md; |
424 | } |
425 | const memory_desc_t *diff_weights_md(int index = 0) const override { |
426 | if (index == 0) return &diff_weights_layer_md_; |
427 | if (index == 1) return &diff_weights_iter_md_; |
428 | |
429 | const int peephole_index = 2; |
430 | if (is_lstm_peephole() && index == peephole_index) |
431 | return &diff_weights_peephole_md_; |
432 | |
433 | const int projection_index = 2 + is_lstm_peephole(); |
434 | if (is_lstm_projection() && index == projection_index) |
435 | return &diff_weights_projection_md_; |
436 | |
437 | const int bias_index = 2 + is_lstm_peephole() + is_lstm_projection(); |
438 | if (with_bias() && index == bias_index) return &diff_bias_md_; |
439 | |
440 | return &glob_zero_md; |
441 | } |
442 | const memory_desc_t *diff_dst_md(int index = 0) const override { |
443 | if (index == 0) return &diff_dst_layer_md_; |
444 | if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; |
445 | if (index == 2 && with_dst_iter_c()) return &diff_dst_iter_c_md_; |
446 | return &glob_zero_md; |
447 | } |
448 | |
449 | int n_inputs() const override { |
450 | return 6 + with_src_iter() + with_src_iter_c() |
451 | + 2 * (with_dst_iter() + with_dst_iter_c()) + is_lstm_peephole() |
452 | + is_lstm_projection() + with_bias() + is_augru(); |
453 | } |
454 | int n_outputs() const override { |
455 | return 3 + with_src_iter() + with_src_iter_c() + is_lstm_peephole() |
456 | + is_lstm_projection() + with_bias() + is_augru(); |
457 | } |
458 | |
459 | protected: |
460 | memory_desc_t diff_src_layer_md_; |
461 | memory_desc_t diff_src_iter_md_; |
462 | memory_desc_t diff_src_iter_c_md_; |
463 | memory_desc_t diff_weights_layer_md_; |
464 | memory_desc_t diff_weights_iter_md_; |
465 | memory_desc_t diff_weights_peephole_md_; |
466 | memory_desc_t diff_weights_projection_md_; |
467 | memory_desc_t diff_bias_md_; |
468 | memory_desc_t diff_dst_layer_md_; |
469 | memory_desc_t diff_dst_iter_md_; |
470 | memory_desc_t diff_dst_iter_c_md_; |
471 | |
472 | rnn_bwd_pd_t(const rnn_desc_t *adesc, const primitive_attr_t *attr, |
473 | const rnn_fwd_pd_t *hint_fwd_pd) |
474 | : rnn_pd_t(adesc, attr, hint_fwd_pd) |
475 | , diff_src_layer_md_(desc_.diff_src_layer_desc) |
476 | , diff_src_iter_md_(desc_.diff_src_iter_desc) |
477 | , diff_src_iter_c_md_(desc_.diff_src_iter_c_desc) |
478 | , diff_weights_layer_md_(desc_.diff_weights_layer_desc) |
479 | , diff_weights_iter_md_(desc_.diff_weights_iter_desc) |
480 | , diff_weights_peephole_md_(desc_.diff_weights_peephole_desc) |
481 | , diff_weights_projection_md_(desc_.diff_weights_projection_desc) |
482 | , diff_bias_md_(desc_.diff_bias_desc) |
483 | , diff_dst_layer_md_(desc_.diff_dst_layer_desc) |
484 | , diff_dst_iter_md_(desc_.diff_dst_iter_desc) |
485 | , diff_dst_iter_c_md_(desc_.diff_dst_iter_c_desc) {} |
486 | }; |
487 | |
488 | } // namespace impl |
489 | } // namespace dnnl |
490 | |
491 | #endif |
492 | |