1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "opdesc.hpp"
18#include "primitive_desc_iface.hpp"
19#include <initializer_list>
20
21#include "oneapi/dnnl/dnnl.h"
22
23#include "c_types_map.hpp"
24#include "type_helpers.hpp"
25#include "utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace rnn {
30
31int get_gates_count(dnnl_alg_kind_t cell_kind) {
32 switch (cell_kind) {
33 case dnnl::impl::alg_kind::vanilla_rnn: return 1;
34 case dnnl::impl::alg_kind::vanilla_gru:
35 case dnnl::impl::alg_kind::vanilla_augru: return 3;
36 case dnnl::impl::alg_kind::lbr_gru:
37 case dnnl::impl::alg_kind::lbr_augru: return 3;
38 case dnnl::impl::alg_kind::vanilla_lstm: return 4;
39 default: assert(!"unknown cell kind"); return 0;
40 }
41 return 0;
42}
43
44} // namespace rnn
45} // namespace impl
46} // namespace dnnl
47
48namespace {
49using namespace dnnl::impl;
50using namespace dnnl::impl::status;
51using namespace dnnl::impl::types;
52using namespace dnnl::impl::utils;
53
54void maybe_init_md(memory_desc_t &md, const memory_desc_t *with_md) {
55 if (with_md) md = *with_md;
56}
57
58bool xnor_md(const memory_desc_t *a_md, const memory_desc_t *b_md) {
59 return is_zero_md(a_md) == is_zero_md(b_md);
60}
61
62status_t check_runtime_dims_or_strides(
63 std::initializer_list<const memory_desc_t *> l) {
64 bool runtime_dims_or_strides = false;
65 for (auto md : l)
66 runtime_dims_or_strides = runtime_dims_or_strides
67 || memory_desc_wrapper(md).has_runtime_dims_or_strides();
68 return runtime_dims_or_strides ? unimplemented : success;
69}
70
71template <typename... DTs>
72bool expect_dt(const memory_desc_t &md, DTs... dts) {
73 return IMPLICATION(!is_zero_md(&md), utils::one_of(md.data_type, dts...));
74}
75
76status_t expect_dims(const memory_desc_t &md, std::initializer_list<dim_t> dims,
77 bool allow_zero = true) {
78 if (is_zero_md(&md))
79 return (allow_zero || dims.size() == 0) ? success : invalid_arguments;
80
81 if (md.ndims != (int)dims.size()) return invalid_arguments;
82
83 int d_in_md = 0;
84 for (auto d : dims)
85 if (d != md.dims[d_in_md++]) return invalid_arguments;
86
87 return success;
88}
89
90status_t check_data_type_consistency_fwd(const rnn_desc_t &r) {
91 using namespace data_type;
92 data_type_t src_layer_dt = r.src_layer_desc.data_type;
93 data_type_t dst_layer_dt = r.dst_layer_desc.data_type;
94 data_type_t weights_iter_dt = r.weights_iter_desc.data_type;
95 data_type_t weights_layer_dt = r.weights_layer_desc.data_type;
96 data_type_t weights_projection_dt = r.weights_projection_desc.data_type;
97
98 const bool is_forward = !(r.prop_kind == prop_kind::backward);
99 const bool is_inference = r.prop_kind == prop_kind::forward_inference;
100 const bool is_int8_ok
101 = one_of(r.cell_kind, dnnl_vanilla_lstm, dnnl_vanilla_gru);
102
103 const bool cell_state_check = expect_dt(r.src_iter_c_desc, f32, bf16, f16)
104 && expect_dt(r.dst_iter_c_desc, f32, bf16, f16);
105
106 const bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt,
107 weights_iter_dt, weights_layer_dt)
108 && expect_dt(r.src_iter_desc, f32)
109 && expect_dt(r.weights_peephole_desc, f32)
110 && expect_dt(r.weights_projection_desc, f32)
111 && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
112
113 const bool is_bf16 = everyone_is(bf16, src_layer_dt, dst_layer_dt,
114 weights_iter_dt, weights_layer_dt)
115 && expect_dt(r.src_iter_desc, bf16)
116 && IMPLICATION(r.cell_kind == dnnl_vanilla_lstm,
117 expect_dt(r.weights_peephole_desc, f32))
118 /* weights_peephole_desc is reused as attention_desc */
119 && IMPLICATION(
120 one_of(r.cell_kind, dnnl_vanilla_augru, dnnl_lbr_augru),
121 expect_dt(r.weights_peephole_desc, bf16))
122 && one_of(weights_projection_dt, bf16, data_type::undef)
123 && expect_dt(r.dst_iter_desc, bf16)
124 && one_of(r.bias_desc.data_type, bf16, f32);
125
126 const bool is_f16 = is_forward
127 && everyone_is(f16, src_layer_dt, dst_layer_dt, weights_iter_dt,
128 weights_layer_dt)
129 && expect_dt(r.src_iter_desc, f16)
130 && expect_dt(r.weights_peephole_desc, f16)
131 && r.weights_peephole_desc.data_type == data_type::undef
132 && expect_dt(r.dst_iter_desc, f16) && expect_dt(r.bias_desc, f16);
133
134 const bool is_u8u8u8 = is_inference && is_int8_ok && src_layer_dt == u8
135 && one_of(dst_layer_dt, u8, f32)
136 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
137 && expect_dt(r.src_iter_desc, u8)
138 && expect_dt(r.src_iter_c_desc, f32)
139 && r.weights_peephole_desc.data_type == data_type::undef
140 && one_of(weights_projection_dt, s8, data_type::undef)
141 && expect_dt(r.dst_iter_desc, u8)
142 && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);
143
144 const bool is_f32u8f32 = is_inference && is_int8_ok && src_layer_dt == u8
145 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
146 && r.weights_peephole_desc.data_type == data_type::undef
147 && one_of(weights_projection_dt, s8, data_type::undef)
148 && one_of(dst_layer_dt, u8, f32) && expect_dt(r.src_iter_desc, f32)
149 && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
150
151 const bool is_s8s8s8 = is_inference && is_int8_ok && src_layer_dt == s8
152 && one_of(dst_layer_dt, s8, f32)
153 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
154 && expect_dt(r.src_iter_desc, s8)
155 && expect_dt(r.src_iter_c_desc, f32)
156 && r.weights_peephole_desc.data_type == data_type::undef
157 && one_of(weights_projection_dt, s8, data_type::undef)
158 && expect_dt(r.dst_iter_desc, s8)
159 && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);
160
161 const bool is_f32s8f32 = is_inference && is_int8_ok && src_layer_dt == s8
162 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
163 && r.weights_peephole_desc.data_type == data_type::undef
164 && one_of(weights_projection_dt, s8, data_type::undef)
165 && one_of(dst_layer_dt, s8, f32) && expect_dt(r.src_iter_desc, f32)
166 && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
167
168 return cell_state_check
169 && (is_f32 || is_bf16 || is_f16 || is_u8u8u8 || is_f32u8f32
170 || is_s8s8s8 || is_f32s8f32)
171 ? success
172 : unimplemented;
173}
174
175status_t check_data_type_consistency_bwd(const rnn_desc_t &r) {
176 using namespace data_type;
177
178 /* We require diffs to be f32, even for bf16 */
179 bool are_diff_f32 = everyone_is(f32, r.diff_src_layer_desc.data_type,
180 r.diff_dst_layer_desc.data_type,
181 r.diff_weights_iter_desc.data_type,
182 r.diff_weights_layer_desc.data_type)
183 && expect_dt(r.diff_src_iter_desc, f32)
184 && expect_dt(r.diff_dst_iter_desc, f32)
185 && expect_dt(r.diff_weights_peephole_desc, f32)
186 && expect_dt(r.diff_weights_projection_desc, f32)
187 && expect_dt(r.diff_bias_desc, f32)
188 && expect_dt(r.diff_src_iter_c_desc, f32)
189 && expect_dt(r.diff_dst_iter_c_desc, f32);
190
191 return are_diff_f32 ? success : unimplemented;
192}
193
194status_t check_dim_consistency(const rnn_desc_t &r) {
195 const bool is_lstm_projection = r.cell_kind == dnnl_vanilla_lstm
196 && !is_zero_md(&r.weights_projection_desc);
197
198 const dim_t L = r.weights_layer_desc.dims[0];
199 const dim_t T = r.src_layer_desc.dims[0];
200 const dim_t N = r.src_layer_desc.dims[1];
201 const dim_t D = one_of(r.direction, dnnl_unidirectional_left2right,
202 dnnl_unidirectional_right2left)
203 ? 1
204 : 2;
205 const dim_t G = rnn::get_gates_count(r.cell_kind);
206 const dim_t SLC = r.src_layer_desc.dims[2];
207 const dim_t SIC = r.weights_iter_desc.dims[2];
208 const dim_t DLC = r.dst_layer_desc.dims[2];
209 const dim_t DHC = r.weights_layer_desc.dims[4];
210 const dim_t DIC
211 = is_lstm_projection ? r.weights_projection_desc.dims[3] : DHC;
212
213 const bool extra_bias = utils::one_of(
214 r.cell_kind, alg_kind::lbr_gru, alg_kind::lbr_augru);
215 const dim_t dlc_multiplier
216 = (r.direction == dnnl_bidirectional_concat) ? 2 : 1;
217
218 bool args_ok
219 = IMPLICATION(utils::one_of(r.cell_kind, alg_kind::vanilla_gru,
220 alg_kind::lbr_gru, alg_kind::vanilla_augru,
221 alg_kind::lbr_augru),
222 SIC == DHC)
223 && dlc_multiplier * DIC == DLC
224 && IMPLICATION(L > 1, dlc_multiplier * SLC == DLC)
225 && IMPLICATION(T > 1, SIC == DIC);
226 if (!args_ok) return invalid_arguments;
227
228 const bool is_augru = utils::one_of(
229 r.cell_kind, alg_kind::vanilla_augru, alg_kind::lbr_augru);
230 CHECK(expect_dims(r.src_layer_desc, {T, N, SLC}, false));
231 CHECK(expect_dims(r.src_iter_desc, {L, D, N, SIC}));
232 CHECK(expect_dims(r.src_iter_c_desc, {L, D, N, DHC}));
233 CHECK(expect_dims(r.weights_layer_desc, {L, D, SLC, G, DHC}, false));
234 CHECK(expect_dims(r.weights_iter_desc, {L, D, SIC, G, DHC}, false));
235 if (is_augru)
236 CHECK(expect_dims(r.weights_peephole_desc, {T, N, 1}));
237 else
238 CHECK(expect_dims(r.weights_peephole_desc, {L, D, 3, DHC}));
239 CHECK(expect_dims(r.weights_projection_desc, {L, D, DHC, DIC}));
240 CHECK(expect_dims(r.bias_desc, {L, D, G + extra_bias, DHC}));
241 CHECK(expect_dims(r.dst_layer_desc, {T, N, DLC}, false));
242 CHECK(expect_dims(r.dst_iter_desc, {L, D, N, DIC}));
243 CHECK(expect_dims(r.dst_iter_c_desc, {L, D, N, DHC}));
244
245 if (r.prop_kind == prop_kind::backward) {
246 CHECK(expect_dims(r.diff_src_layer_desc, {T, N, SLC}, false));
247 CHECK(expect_dims(r.diff_src_iter_desc, {L, D, N, SIC}));
248 CHECK(expect_dims(r.diff_src_iter_c_desc, {L, D, N, DHC}));
249 CHECK(expect_dims(
250 r.diff_weights_layer_desc, {L, D, SLC, G, DHC}, false));
251 CHECK(expect_dims(
252 r.diff_weights_iter_desc, {L, D, SIC, G, DHC}, false));
253 if (is_augru)
254 CHECK(expect_dims(r.diff_weights_peephole_desc, {T, N, 1}));
255 else
256 CHECK(expect_dims(r.diff_weights_peephole_desc, {L, D, 3, DHC}));
257 CHECK(expect_dims(r.diff_weights_projection_desc, {L, D, DHC, DIC}));
258 CHECK(expect_dims(r.diff_bias_desc, {L, D, G + extra_bias, DHC}));
259 CHECK(expect_dims(r.diff_dst_layer_desc, {T, N, DLC}, false));
260 CHECK(expect_dims(r.diff_dst_iter_desc, {L, D, N, DIC}));
261 CHECK(expect_dims(r.diff_dst_iter_c_desc, {L, D, N, DHC}));
262 }
263
264 return success;
265}
266
267status_t rnn_common_fwd_desc_init(rnn_desc_t *rnn_desc, prop_kind_t prop_kind,
268 alg_kind_t cell_kind, const rnn_direction_t direction,
269 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
270 const memory_desc_t *src_iter_c_desc,
271 const memory_desc_t *attention_desc,
272 const memory_desc_t *weights_layer_desc,
273 const memory_desc_t *weights_iter_desc,
274 const memory_desc_t *weights_peephole_desc,
275 const memory_desc_t *weights_projection_desc,
276 const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc,
277 const memory_desc_t *dst_iter_desc,
278 const memory_desc_t *dst_iter_c_desc, unsigned flags,
279 alg_kind_t activation = alg_kind::undef, float alpha = 0.0f,
280 float beta = 0.0f) {
281
282 // check that a supported cell kind has been passed
283 bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm,
284 dnnl_vanilla_gru, dnnl_lbr_gru, dnnl_vanilla_augru, dnnl_lbr_augru);
285 if (!args_ok) return invalid_arguments;
286
287 // check that all mandatory parameters are non-null
288 args_ok = args_ok
289 && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
290 dst_layer_desc);
291 if (!args_ok) return invalid_arguments;
292
293 if (cell_kind == dnnl_vanilla_rnn) {
294 using namespace alg_kind;
295 args_ok = args_ok
296 && one_of(activation, eltwise_relu, eltwise_tanh,
297 eltwise_logistic);
298 if (!args_ok) return invalid_arguments;
299 }
300
301 if (cell_kind == dnnl_vanilla_lstm) {
302 // check if optional *_iter is provided then *_iter_c is provided too
303 args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc)
304 && xnor_md(dst_iter_desc, dst_iter_c_desc);
305 if (!args_ok) return invalid_arguments;
306 }
307
308 // check augru-specific restrictions
309 const bool is_augru = one_of(cell_kind, dnnl_vanilla_augru, dnnl_lbr_augru);
310 if (is_augru) {
311 const dim_t L = weights_layer_desc->dims[0];
312 args_ok = args_ok && direction == dnnl_unidirectional_left2right
313 && L == 1;
314 if (!args_ok) return invalid_arguments;
315 }
316
317 CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc,
318 src_iter_c_desc, weights_layer_desc, weights_iter_desc,
319 weights_peephole_desc, weights_projection_desc, bias_desc,
320 dst_layer_desc, dst_iter_desc, dst_iter_c_desc}));
321
322 // Create the descriptor
323 auto rd = rnn_desc_t();
324
325 rd.primitive_kind = primitive_kind::rnn;
326 rd.prop_kind = prop_kind;
327 rd.cell_kind = cell_kind;
328 rd.direction = direction;
329 maybe_init_md(rd.src_layer_desc, src_layer_desc);
330 maybe_init_md(rd.src_iter_desc, src_iter_desc);
331 maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc);
332 maybe_init_md(rd.weights_layer_desc, weights_layer_desc);
333 maybe_init_md(rd.weights_iter_desc, weights_iter_desc);
334 maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc);
335 if (is_augru) maybe_init_md(rd.weights_peephole_desc, attention_desc);
336 maybe_init_md(rd.weights_projection_desc, weights_projection_desc);
337 maybe_init_md(rd.bias_desc, bias_desc);
338 maybe_init_md(rd.dst_layer_desc, dst_layer_desc);
339 maybe_init_md(rd.dst_iter_desc, dst_iter_desc);
340 maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc);
341
342 rd.flags = flags;
343 rd.activation_kind = activation;
344 rd.alpha = alpha;
345 rd.beta = beta;
346
347 CHECK(check_data_type_consistency_fwd(rd));
348 CHECK(check_dim_consistency(rd));
349
350 *rnn_desc = rd;
351
352 return success;
353}
354
355status_t rnn_common_bwd_desc_init(rnn_desc_t *rnn_desc, prop_kind_t prop_kind,
356 alg_kind_t cell_kind, const dnnl_rnn_direction_t direction,
357 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
358 const memory_desc_t *src_iter_c_desc,
359 const memory_desc_t *attention_desc,
360 const memory_desc_t *weights_layer_desc,
361 const memory_desc_t *weights_iter_desc,
362 const memory_desc_t *weights_peephole_desc,
363 const memory_desc_t *weights_projection_desc,
364 const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc,
365 const memory_desc_t *dst_iter_desc,
366 const memory_desc_t *dst_iter_c_desc,
367 const memory_desc_t *diff_src_layer_desc,
368 const memory_desc_t *diff_src_iter_desc,
369 const memory_desc_t *diff_src_iter_c_desc,
370 const memory_desc_t *diff_attention_desc,
371 const memory_desc_t *diff_weights_layer_desc,
372 const memory_desc_t *diff_weights_iter_desc,
373 const memory_desc_t *diff_weights_peephole_desc,
374 const memory_desc_t *diff_weights_projection_desc,
375 const memory_desc_t *diff_bias_desc,
376 const memory_desc_t *diff_dst_layer_desc,
377 const memory_desc_t *diff_dst_iter_desc,
378 const memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
379 alg_kind_t activation = alg_kind::undef, float alpha = 0.0f,
380 float beta = 0.0f) {
381
382 // check that a supported cell kind has been passed
383 bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm,
384 dnnl_vanilla_gru, dnnl_lbr_gru, dnnl_vanilla_augru, dnnl_lbr_augru);
385 if (!args_ok) return invalid_arguments;
386
387 // check that all mandatory parameters are non-null
388 args_ok = args_ok
389 && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
390 dst_layer_desc, diff_src_layer_desc,
391 diff_weights_layer_desc, diff_weights_iter_desc,
392 diff_dst_layer_desc);
393 if (!args_ok) return invalid_arguments;
394
395 if (cell_kind == dnnl_vanilla_rnn) {
396 using namespace alg_kind;
397 args_ok = args_ok
398 && one_of(activation, eltwise_relu, eltwise_tanh,
399 eltwise_logistic);
400 if (!args_ok) return invalid_arguments;
401 }
402
403 if (cell_kind == dnnl_vanilla_lstm) {
404 // check if optional *_iter is provided then *_iter_c is provided too
405 args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc)
406 && xnor_md(dst_iter_desc, dst_iter_c_desc);
407 if (!args_ok) return invalid_arguments;
408 }
409
410 const bool is_augru = one_of(cell_kind, dnnl_vanilla_augru, dnnl_lbr_augru);
411 // check augru-specific restrictions
412 if (is_augru) {
413 const dim_t L = weights_layer_desc->dims[0];
414 const bool dims_ok = args_ok
415 && direction == dnnl_unidirectional_left2right && L == 1;
416 const bool descs_ok = !any_null(attention_desc, diff_attention_desc);
417 const bool args_ok = dims_ok && descs_ok;
418 if (!args_ok) return invalid_arguments;
419 }
420
421 // check if optional md is provided then diff_md is provided too
422 args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
423 && xnor_md(weights_peephole_desc, diff_weights_peephole_desc)
424 && xnor_md(weights_projection_desc, diff_weights_projection_desc)
425 && xnor_md(src_iter_desc, diff_src_iter_desc)
426 && xnor_md(src_iter_c_desc, diff_src_iter_c_desc)
427 && xnor_md(dst_iter_desc, diff_dst_iter_desc)
428 && xnor_md(dst_iter_c_desc, diff_dst_iter_c_desc);
429 if (!args_ok) return invalid_arguments;
430
431 CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc,
432 src_iter_c_desc, attention_desc, weights_layer_desc,
433 weights_iter_desc, weights_peephole_desc, weights_projection_desc,
434 bias_desc, dst_layer_desc, dst_iter_desc, dst_iter_c_desc,
435 diff_src_layer_desc, diff_src_iter_desc, diff_src_iter_c_desc,
436 diff_attention_desc, diff_weights_layer_desc,
437 diff_weights_iter_desc, diff_weights_peephole_desc,
438 diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc,
439 diff_dst_iter_desc, diff_dst_iter_c_desc}));
440
441 auto rd = rnn_desc_t();
442
443 rd.primitive_kind = primitive_kind::rnn;
444 rd.prop_kind = prop_kind;
445 rd.cell_kind = cell_kind;
446 rd.direction = direction;
447
448 maybe_init_md(rd.src_layer_desc, src_layer_desc);
449 maybe_init_md(rd.src_iter_desc, src_iter_desc);
450 maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc);
451 maybe_init_md(rd.weights_layer_desc, weights_layer_desc);
452 maybe_init_md(rd.weights_iter_desc, weights_iter_desc);
453 maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc);
454 if (is_augru) maybe_init_md(rd.weights_peephole_desc, attention_desc);
455 maybe_init_md(rd.weights_projection_desc, weights_projection_desc);
456 maybe_init_md(rd.bias_desc, bias_desc);
457 maybe_init_md(rd.dst_layer_desc, dst_layer_desc);
458 maybe_init_md(rd.dst_iter_desc, dst_iter_desc);
459 maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc);
460 maybe_init_md(rd.diff_src_layer_desc, diff_src_layer_desc);
461 maybe_init_md(rd.diff_src_iter_desc, diff_src_iter_desc);
462 maybe_init_md(rd.diff_src_iter_c_desc, diff_src_iter_c_desc);
463 maybe_init_md(rd.diff_weights_layer_desc, diff_weights_layer_desc);
464 maybe_init_md(rd.diff_weights_iter_desc, diff_weights_iter_desc);
465 maybe_init_md(rd.diff_weights_peephole_desc, diff_weights_peephole_desc);
466 if (is_augru)
467 maybe_init_md(rd.diff_weights_peephole_desc, diff_attention_desc);
468 maybe_init_md(
469 rd.diff_weights_projection_desc, diff_weights_projection_desc);
470 maybe_init_md(rd.diff_bias_desc, diff_bias_desc);
471 maybe_init_md(rd.diff_dst_layer_desc, diff_dst_layer_desc);
472 maybe_init_md(rd.diff_dst_iter_desc, diff_dst_iter_desc);
473 maybe_init_md(rd.diff_dst_iter_c_desc, diff_dst_iter_c_desc);
474
475 rd.flags = flags;
476 rd.activation_kind = activation;
477 rd.alpha = alpha;
478 rd.beta = beta;
479
480 CHECK(check_data_type_consistency_fwd(rd));
481 CHECK(check_data_type_consistency_bwd(rd));
482
483 CHECK(check_dim_consistency(rd));
484
485 *rnn_desc = rd;
486
487 return success;
488}
489
490} // namespace
491
492/* Public C Api */
493
494status_t dnnl_vanilla_rnn_forward_primitive_desc_create(
495 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
496 dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
497 const dnnl_rnn_direction_t direction,
498 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
499 const memory_desc_t *weights_layer_desc,
500 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
501 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
502 unsigned flags, float alpha, float beta, const primitive_attr_t *attr) {
503
504 auto rnn_desc = rnn_desc_t();
505 CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_rnn,
506 direction, src_layer_desc, src_iter_desc, nullptr, nullptr,
507 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
508 dst_layer_desc, dst_iter_desc, nullptr, flags, activation, alpha,
509 beta));
510 return primitive_desc_create(primitive_desc_iface, engine,
511 (const op_desc_t *)&rnn_desc, nullptr, attr);
512}
513
514status_t dnnl_lstm_forward_primitive_desc_create(
515 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
516 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
517 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
518 const memory_desc_t *src_iter_c_desc,
519 const memory_desc_t *weights_layer_desc,
520 const memory_desc_t *weights_iter_desc,
521 const memory_desc_t *weights_peephole_desc,
522 const memory_desc_t *weights_projection_desc,
523 const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc,
524 const memory_desc_t *dst_iter_desc,
525 const memory_desc_t *dst_iter_c_desc, unsigned flags,
526 const primitive_attr_t *attr) {
527
528 auto rnn_desc = rnn_desc_t();
529 CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_lstm,
530 direction, src_layer_desc, src_iter_desc, src_iter_c_desc, nullptr,
531 weights_layer_desc, weights_iter_desc, weights_peephole_desc,
532 weights_projection_desc, bias_desc, dst_layer_desc, dst_iter_desc,
533 dst_iter_c_desc, flags));
534 return primitive_desc_create(primitive_desc_iface, engine,
535 (const op_desc_t *)&rnn_desc, nullptr, attr);
536}
537
538status_t dnnl_gru_forward_primitive_desc_create(
539 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
540 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
541 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
542 const memory_desc_t *weights_layer_desc,
543 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
544 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
545 unsigned flags, const primitive_attr_t *attr) {
546
547 auto rnn_desc = rnn_desc_t();
548 CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_gru,
549 direction, src_layer_desc, src_iter_desc, nullptr, nullptr,
550 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
551 dst_layer_desc, dst_iter_desc, nullptr, flags));
552 return primitive_desc_create(primitive_desc_iface, engine,
553 (const op_desc_t *)&rnn_desc, nullptr, attr);
554}
555
556status_t dnnl_lbr_gru_forward_primitive_desc_create(
557 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
558 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
559 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
560 const memory_desc_t *weights_layer_desc,
561 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
562 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
563 unsigned flags, const primitive_attr_t *attr) {
564
565 auto rnn_desc = rnn_desc_t();
566 CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_gru,
567 direction, src_layer_desc, src_iter_desc, nullptr, nullptr,
568 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
569 dst_layer_desc, dst_iter_desc, nullptr, flags));
570 return primitive_desc_create(primitive_desc_iface, engine,
571 (const op_desc_t *)&rnn_desc, nullptr, attr);
572}
573
574status_t dnnl_augru_forward_primitive_desc_create(
575 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
576 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
577 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
578 const memory_desc_t *attention_desc,
579 const memory_desc_t *weights_layer_desc,
580 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
581 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
582 unsigned flags, const primitive_attr_t *attr) {
583
584 auto rnn_desc = rnn_desc_t();
585 CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_augru,
586 direction, src_layer_desc, src_iter_desc, nullptr, attention_desc,
587 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
588 dst_layer_desc, dst_iter_desc, nullptr, flags));
589 return primitive_desc_create(primitive_desc_iface, engine,
590 (const op_desc_t *)&rnn_desc, nullptr, attr);
591}
592
593status_t dnnl_lbr_augru_forward_primitive_desc_create(
594 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
595 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
596 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
597 const memory_desc_t *attention_desc,
598 const memory_desc_t *weights_layer_desc,
599 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
600 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
601 unsigned flags, const primitive_attr_t *attr) {
602
603 auto rnn_desc = rnn_desc_t();
604 CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_augru,
605 direction, src_layer_desc, src_iter_desc, nullptr, attention_desc,
606 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
607 dst_layer_desc, dst_iter_desc, nullptr, flags));
608 return primitive_desc_create(primitive_desc_iface, engine,
609 (const op_desc_t *)&rnn_desc, nullptr, attr);
610}
611
612status_t dnnl_vanilla_rnn_backward_primitive_desc_create(
613 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
614 dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
615 const dnnl_rnn_direction_t direction,
616 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
617 const memory_desc_t *weights_layer_desc,
618 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
619 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
620 const memory_desc_t *diff_src_layer_desc,
621 const memory_desc_t *diff_src_iter_desc,
622 const memory_desc_t *diff_weights_layer_desc,
623 const memory_desc_t *diff_weights_iter_desc,
624 const memory_desc_t *diff_bias_desc,
625 const memory_desc_t *diff_dst_layer_desc,
626 const memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha,
627 float beta, const primitive_desc_iface_t *hint_fwd_pd,
628 const primitive_attr_t *attr) {
629
630 auto rnn_desc = rnn_desc_t();
631 CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_rnn,
632 direction, src_layer_desc, src_iter_desc, nullptr, nullptr,
633 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
634 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
635 diff_src_iter_desc, nullptr, nullptr, diff_weights_layer_desc,
636 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
637 diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags, activation,
638 alpha, beta));
639 return primitive_desc_create(primitive_desc_iface, engine,
640 (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr);
641}
642
643status_t dnnl_lstm_backward_primitive_desc_create(
644 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
645 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
646 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
647 const memory_desc_t *src_iter_c_desc,
648 const memory_desc_t *weights_layer_desc,
649 const memory_desc_t *weights_iter_desc,
650 const memory_desc_t *weights_peephole_desc,
651 const memory_desc_t *weights_projection_desc,
652 const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc,
653 const memory_desc_t *dst_iter_desc,
654 const memory_desc_t *dst_iter_c_desc,
655 const memory_desc_t *diff_src_layer_desc,
656 const memory_desc_t *diff_src_iter_desc,
657 const memory_desc_t *diff_src_iter_c_desc,
658 const memory_desc_t *diff_weights_layer_desc,
659 const memory_desc_t *diff_weights_iter_desc,
660 const memory_desc_t *diff_weights_peephole_desc,
661 const memory_desc_t *diff_weights_projection_desc,
662 const memory_desc_t *diff_bias_desc,
663 const memory_desc_t *diff_dst_layer_desc,
664 const memory_desc_t *diff_dst_iter_desc,
665 const memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
666 const primitive_desc_iface_t *hint_fwd_pd,
667 const primitive_attr_t *attr) {
668
669 auto rnn_desc = rnn_desc_t();
670 CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_lstm,
671 direction, src_layer_desc, src_iter_desc, src_iter_c_desc, nullptr,
672 weights_layer_desc, weights_iter_desc, weights_peephole_desc,
673 weights_projection_desc, bias_desc, dst_layer_desc, dst_iter_desc,
674 dst_iter_c_desc, diff_src_layer_desc, diff_src_iter_desc,
675 diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
676 diff_weights_iter_desc, diff_weights_peephole_desc,
677 diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc,
678 diff_dst_iter_desc, diff_dst_iter_c_desc, flags));
679 return primitive_desc_create(primitive_desc_iface, engine,
680 (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr);
681}
682
683status_t dnnl_gru_backward_primitive_desc_create(
684 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
685 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
686 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
687 const memory_desc_t *weights_layer_desc,
688 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
689 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
690 const memory_desc_t *diff_src_layer_desc,
691 const memory_desc_t *diff_src_iter_desc,
692 const memory_desc_t *diff_weights_layer_desc,
693 const memory_desc_t *diff_weights_iter_desc,
694 const memory_desc_t *diff_bias_desc,
695 const memory_desc_t *diff_dst_layer_desc,
696 const memory_desc_t *diff_dst_iter_desc, unsigned flags,
697 const primitive_desc_iface_t *hint_fwd_pd,
698 const primitive_attr_t *attr) {
699
700 auto rnn_desc = rnn_desc_t();
701 CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_gru,
702 direction, src_layer_desc, src_iter_desc, nullptr, nullptr,
703 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
704 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
705 diff_src_iter_desc, nullptr, nullptr, diff_weights_layer_desc,
706 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
707 diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags));
708 return primitive_desc_create(primitive_desc_iface, engine,
709 (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr);
710}
711
712status_t dnnl_lbr_gru_backward_primitive_desc_create(
713 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
714 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
715 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
716 const memory_desc_t *weights_layer_desc,
717 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
718 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
719 const memory_desc_t *diff_src_layer_desc,
720 const memory_desc_t *diff_src_iter_desc,
721 const memory_desc_t *diff_weights_layer_desc,
722 const memory_desc_t *diff_weights_iter_desc,
723 const memory_desc_t *diff_bias_desc,
724 const memory_desc_t *diff_dst_layer_desc,
725 const memory_desc_t *diff_dst_iter_desc, unsigned flags,
726 const primitive_desc_iface_t *hint_fwd_pd,
727 const primitive_attr_t *attr) {
728
729 auto rnn_desc = rnn_desc_t();
730 CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_gru,
731 direction, src_layer_desc, src_iter_desc, nullptr, nullptr,
732 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
733 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
734 diff_src_iter_desc, nullptr, nullptr, diff_weights_layer_desc,
735 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
736 diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags));
737 return primitive_desc_create(primitive_desc_iface, engine,
738 (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr);
739}
740
741status_t dnnl_augru_backward_primitive_desc_create(
742 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
743 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
744 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
745 const memory_desc_t *attention_desc,
746 const memory_desc_t *weights_layer_desc,
747 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
748 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
749 const memory_desc_t *diff_src_layer_desc,
750 const memory_desc_t *diff_src_iter_desc,
751 const memory_desc_t *diff_attention_desc,
752 const memory_desc_t *diff_weights_layer_desc,
753 const memory_desc_t *diff_weights_iter_desc,
754 const memory_desc_t *diff_bias_desc,
755 const memory_desc_t *diff_dst_layer_desc,
756 const memory_desc_t *diff_dst_iter_desc, unsigned flags,
757 const primitive_desc_iface_t *hint_fwd_pd,
758 const primitive_attr_t *attr) {
759
760 auto rnn_desc = rnn_desc_t();
761 CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_augru,
762 direction, src_layer_desc, src_iter_desc, nullptr, attention_desc,
763 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
764 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
765 diff_src_iter_desc, nullptr, diff_attention_desc,
766 diff_weights_layer_desc, diff_weights_iter_desc, nullptr, nullptr,
767 diff_bias_desc, diff_dst_layer_desc, diff_dst_iter_desc, nullptr,
768 flags));
769 return primitive_desc_create(primitive_desc_iface, engine,
770 (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr);
771}
772
773status_t dnnl_lbr_augru_backward_primitive_desc_create(
774 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
775 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
776 const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc,
777 const memory_desc_t *attention_desc,
778 const memory_desc_t *weights_layer_desc,
779 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
780 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
781 const memory_desc_t *diff_src_layer_desc,
782 const memory_desc_t *diff_src_iter_desc,
783 const memory_desc_t *diff_attention_desc,
784 const memory_desc_t *diff_weights_layer_desc,
785 const memory_desc_t *diff_weights_iter_desc,
786 const memory_desc_t *diff_bias_desc,
787 const memory_desc_t *diff_dst_layer_desc,
788 const memory_desc_t *diff_dst_iter_desc, unsigned flags,
789 const primitive_desc_iface_t *hint_fwd_pd,
790 const primitive_attr_t *attr) {
791
792 auto rnn_desc = rnn_desc_t();
793 CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_augru,
794 direction, src_layer_desc, src_iter_desc, nullptr, attention_desc,
795 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
796 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
797 diff_src_iter_desc, nullptr, diff_attention_desc,
798 diff_weights_layer_desc, diff_weights_iter_desc, nullptr, nullptr,
799 diff_bias_desc, diff_dst_layer_desc, diff_dst_iter_desc, nullptr,
800 flags));
801 return primitive_desc_create(primitive_desc_iface, engine,
802 (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr);
803}
804