1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "opdesc.hpp" |
18 | #include "primitive_desc_iface.hpp" |
19 | #include <initializer_list> |
20 | |
21 | #include "oneapi/dnnl/dnnl.h" |
22 | |
23 | #include "c_types_map.hpp" |
24 | #include "type_helpers.hpp" |
25 | #include "utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace rnn { |
30 | |
31 | int get_gates_count(dnnl_alg_kind_t cell_kind) { |
32 | switch (cell_kind) { |
33 | case dnnl::impl::alg_kind::vanilla_rnn: return 1; |
34 | case dnnl::impl::alg_kind::vanilla_gru: |
35 | case dnnl::impl::alg_kind::vanilla_augru: return 3; |
36 | case dnnl::impl::alg_kind::lbr_gru: |
37 | case dnnl::impl::alg_kind::lbr_augru: return 3; |
38 | case dnnl::impl::alg_kind::vanilla_lstm: return 4; |
39 | default: assert(!"unknown cell kind" ); return 0; |
40 | } |
41 | return 0; |
42 | } |
43 | |
44 | } // namespace rnn |
45 | } // namespace impl |
46 | } // namespace dnnl |
47 | |
48 | namespace { |
49 | using namespace dnnl::impl; |
50 | using namespace dnnl::impl::status; |
51 | using namespace dnnl::impl::types; |
52 | using namespace dnnl::impl::utils; |
53 | |
54 | void maybe_init_md(memory_desc_t &md, const memory_desc_t *with_md) { |
55 | if (with_md) md = *with_md; |
56 | } |
57 | |
58 | bool xnor_md(const memory_desc_t *a_md, const memory_desc_t *b_md) { |
59 | return is_zero_md(a_md) == is_zero_md(b_md); |
60 | } |
61 | |
62 | status_t check_runtime_dims_or_strides( |
63 | std::initializer_list<const memory_desc_t *> l) { |
64 | bool runtime_dims_or_strides = false; |
65 | for (auto md : l) |
66 | runtime_dims_or_strides = runtime_dims_or_strides |
67 | || memory_desc_wrapper(md).has_runtime_dims_or_strides(); |
68 | return runtime_dims_or_strides ? unimplemented : success; |
69 | } |
70 | |
71 | template <typename... DTs> |
72 | bool expect_dt(const memory_desc_t &md, DTs... dts) { |
73 | return IMPLICATION(!is_zero_md(&md), utils::one_of(md.data_type, dts...)); |
74 | } |
75 | |
76 | status_t expect_dims(const memory_desc_t &md, std::initializer_list<dim_t> dims, |
77 | bool allow_zero = true) { |
78 | if (is_zero_md(&md)) |
79 | return (allow_zero || dims.size() == 0) ? success : invalid_arguments; |
80 | |
81 | if (md.ndims != (int)dims.size()) return invalid_arguments; |
82 | |
83 | int d_in_md = 0; |
84 | for (auto d : dims) |
85 | if (d != md.dims[d_in_md++]) return invalid_arguments; |
86 | |
87 | return success; |
88 | } |
89 | |
// Checks that the data types stored in the forward part of an RNN descriptor
// form one of the supported combinations.
//
// Optional descriptors that are absent (zero md) pass every expect_dt()
// check. Returns success when the cell-state descriptors are f32, bf16 or f16
// (or absent) AND at least one of the configurations below matches;
// returns unimplemented otherwise.
status_t check_data_type_consistency_fwd(const rnn_desc_t &r) {
    using namespace data_type;
    data_type_t src_layer_dt = r.src_layer_desc.data_type;
    data_type_t dst_layer_dt = r.dst_layer_desc.data_type;
    data_type_t weights_iter_dt = r.weights_iter_desc.data_type;
    data_type_t weights_layer_dt = r.weights_layer_desc.data_type;
    data_type_t weights_projection_dt = r.weights_projection_desc.data_type;

    const bool is_forward = !(r.prop_kind == prop_kind::backward);
    const bool is_inference = r.prop_kind == prop_kind::forward_inference;
    // int8 configurations are accepted only for LSTM and GRU cells.
    const bool is_int8_ok
            = one_of(r.cell_kind, dnnl_vanilla_lstm, dnnl_vanilla_gru);

    // Cell states (when present) must be f32, bf16 or f16 in every
    // configuration; the u8u8u8 / s8s8s8 cases below narrow this to f32.
    const bool cell_state_check = expect_dt(r.src_iter_c_desc, f32, bf16, f16)
            && expect_dt(r.dst_iter_c_desc, f32, bf16, f16);

    // f32: every present tensor is f32.
    const bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt,
                                weights_iter_dt, weights_layer_dt)
            && expect_dt(r.src_iter_desc, f32)
            && expect_dt(r.weights_peephole_desc, f32)
            && expect_dt(r.weights_projection_desc, f32)
            && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);

    // bf16: bf16 data and weights. LSTM peephole weights stay f32, while the
    // AUGRU attention (stored in weights_peephole_desc) must be bf16.
    // The projection and bias checks use one_of() on the raw data type rather
    // than expect_dt(): an absent projection (undef) is listed explicitly,
    // but an absent bias (undef) does not match.
    // NOTE(review): confirm that bias is meant to be mandatory for bf16.
    const bool is_bf16 = everyone_is(bf16, src_layer_dt, dst_layer_dt,
                                 weights_iter_dt, weights_layer_dt)
            && expect_dt(r.src_iter_desc, bf16)
            && IMPLICATION(r.cell_kind == dnnl_vanilla_lstm,
                    expect_dt(r.weights_peephole_desc, f32))
            /* weights_peephole_desc is reused as attention_desc */
            && IMPLICATION(
                    one_of(r.cell_kind, dnnl_vanilla_augru, dnnl_lbr_augru),
                    expect_dt(r.weights_peephole_desc, bf16))
            && one_of(weights_projection_dt, bf16, data_type::undef)
            && expect_dt(r.dst_iter_desc, bf16)
            && one_of(r.bias_desc.data_type, bf16, f32);

    // f16: forward propagation only. The peephole slot must be empty
    // (data_type == undef), which makes the preceding expect_dt() redundant.
    // Because that slot also holds the AUGRU attention, AUGRU cannot match
    // here.
    const bool is_f16 = is_forward
            && everyone_is(f16, src_layer_dt, dst_layer_dt, weights_iter_dt,
                    weights_layer_dt)
            && expect_dt(r.src_iter_desc, f16)
            && expect_dt(r.weights_peephole_desc, f16)
            && r.weights_peephole_desc.data_type == data_type::undef
            && expect_dt(r.dst_iter_desc, f16) && expect_dt(r.bias_desc, f16);

    // The four int8 configurations below are inference-only and LSTM/GRU-only.
    // All of them require s8 layer/iter weights, no peephole, an optional s8
    // projection and an f32 bias (when present).

    // u8 src layer/iter, u8 dst iter, u8 or f32 dst layer, f32 cell states.
    const bool is_u8u8u8 = is_inference && is_int8_ok && src_layer_dt == u8
            && one_of(dst_layer_dt, u8, f32)
            && everyone_is(s8, weights_iter_dt, weights_layer_dt)
            && expect_dt(r.src_iter_desc, u8)
            && expect_dt(r.src_iter_c_desc, f32)
            && r.weights_peephole_desc.data_type == data_type::undef
            && one_of(weights_projection_dt, s8, data_type::undef)
            && expect_dt(r.dst_iter_desc, u8)
            && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);

    // u8 src layer, f32 src/dst iter, u8 or f32 dst layer.
    const bool is_f32u8f32 = is_inference && is_int8_ok && src_layer_dt == u8
            && everyone_is(s8, weights_iter_dt, weights_layer_dt)
            && r.weights_peephole_desc.data_type == data_type::undef
            && one_of(weights_projection_dt, s8, data_type::undef)
            && one_of(dst_layer_dt, u8, f32) && expect_dt(r.src_iter_desc, f32)
            && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);

    // s8 src layer/iter, s8 dst iter, s8 or f32 dst layer, f32 cell states.
    const bool is_s8s8s8 = is_inference && is_int8_ok && src_layer_dt == s8
            && one_of(dst_layer_dt, s8, f32)
            && everyone_is(s8, weights_iter_dt, weights_layer_dt)
            && expect_dt(r.src_iter_desc, s8)
            && expect_dt(r.src_iter_c_desc, f32)
            && r.weights_peephole_desc.data_type == data_type::undef
            && one_of(weights_projection_dt, s8, data_type::undef)
            && expect_dt(r.dst_iter_desc, s8)
            && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);

    // s8 src layer, f32 src/dst iter, s8 or f32 dst layer.
    const bool is_f32s8f32 = is_inference && is_int8_ok && src_layer_dt == s8
            && everyone_is(s8, weights_iter_dt, weights_layer_dt)
            && r.weights_peephole_desc.data_type == data_type::undef
            && one_of(weights_projection_dt, s8, data_type::undef)
            && one_of(dst_layer_dt, s8, f32) && expect_dt(r.src_iter_desc, f32)
            && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);

    return cell_state_check
                    && (is_f32 || is_bf16 || is_f16 || is_u8u8u8 || is_f32u8f32
                            || is_s8s8s8 || is_f32s8f32)
            ? success
            : unimplemented;
}
174 | |
175 | status_t check_data_type_consistency_bwd(const rnn_desc_t &r) { |
176 | using namespace data_type; |
177 | |
178 | /* We require diffs to be f32, even for bf16 */ |
179 | bool are_diff_f32 = everyone_is(f32, r.diff_src_layer_desc.data_type, |
180 | r.diff_dst_layer_desc.data_type, |
181 | r.diff_weights_iter_desc.data_type, |
182 | r.diff_weights_layer_desc.data_type) |
183 | && expect_dt(r.diff_src_iter_desc, f32) |
184 | && expect_dt(r.diff_dst_iter_desc, f32) |
185 | && expect_dt(r.diff_weights_peephole_desc, f32) |
186 | && expect_dt(r.diff_weights_projection_desc, f32) |
187 | && expect_dt(r.diff_bias_desc, f32) |
188 | && expect_dt(r.diff_src_iter_c_desc, f32) |
189 | && expect_dt(r.diff_dst_iter_c_desc, f32); |
190 | |
191 | return are_diff_f32 ? success : unimplemented; |
192 | } |
193 | |
// Checks that the dimensions of every descriptor stored in an RNN descriptor
// agree with each other. Returns success or invalid_arguments.
//
// Notation (derived below from the mandatory descriptors):
//   L   - number of layers       (weights_layer dims[0])
//   T   - number of time steps   (src_layer dims[0])
//   N   - minibatch              (src_layer dims[1])
//   D   - number of directions   (1 for unidirectional, else 2)
//   G   - number of gates        (rnn::get_gates_count)
//   SLC - src layer channels     (src_layer dims[2])
//   SIC - src iter channels      (weights_iter dims[2])
//   DLC - dst layer channels     (dst_layer dims[2])
//   DHC - hidden state channels  (weights_layer dims[4])
//   DIC - dst iter channels      (projection dims[3] for LSTM with
//                                 projection, DHC otherwise)
// Optional descriptors that are absent (zero md) pass expect_dims(); the ones
// passed `false` are mandatory. Diff descriptors are checked only when
// prop_kind is backward.
status_t check_dim_consistency(const rnn_desc_t &r) {
    const bool is_lstm_projection = r.cell_kind == dnnl_vanilla_lstm
            && !is_zero_md(&r.weights_projection_desc);

    const dim_t L = r.weights_layer_desc.dims[0];
    const dim_t T = r.src_layer_desc.dims[0];
    const dim_t N = r.src_layer_desc.dims[1];
    const dim_t D = one_of(r.direction, dnnl_unidirectional_left2right,
                            dnnl_unidirectional_right2left)
            ? 1
            : 2;
    const dim_t G = rnn::get_gates_count(r.cell_kind);
    const dim_t SLC = r.src_layer_desc.dims[2];
    const dim_t SIC = r.weights_iter_desc.dims[2];
    const dim_t DLC = r.dst_layer_desc.dims[2];
    const dim_t DHC = r.weights_layer_desc.dims[4];
    const dim_t DIC
            = is_lstm_projection ? r.weights_projection_desc.dims[3] : DHC;

    // Linear-before-reset cells carry one extra bias row (G + 1).
    const bool extra_bias = utils::one_of(
            r.cell_kind, alg_kind::lbr_gru, alg_kind::lbr_augru);
    // With bidirectional_concat, both directions are concatenated along the
    // dst layer channel dimension.
    const dim_t dlc_multiplier
            = (r.direction == dnnl_bidirectional_concat) ? 2 : 1;

    // Cross-descriptor channel constraints:
    //  - GRU/AUGRU flavours need SIC == DHC;
    //  - dst layer channels equal (concatenated) dst iter channels;
    //  - with several layers, a layer's output feeds the next layer's input;
    //  - with several time steps, dst iter feeds the next step's src iter.
    bool args_ok
            = IMPLICATION(utils::one_of(r.cell_kind, alg_kind::vanilla_gru,
                                  alg_kind::lbr_gru, alg_kind::vanilla_augru,
                                  alg_kind::lbr_augru),
                      SIC == DHC)
            && dlc_multiplier * DIC == DLC
            && IMPLICATION(L > 1, dlc_multiplier * SLC == DLC)
            && IMPLICATION(T > 1, SIC == DIC);
    if (!args_ok) return invalid_arguments;

    const bool is_augru = utils::one_of(
            r.cell_kind, alg_kind::vanilla_augru, alg_kind::lbr_augru);
    CHECK(expect_dims(r.src_layer_desc, {T, N, SLC}, false));
    CHECK(expect_dims(r.src_iter_desc, {L, D, N, SIC}));
    CHECK(expect_dims(r.src_iter_c_desc, {L, D, N, DHC}));
    CHECK(expect_dims(r.weights_layer_desc, {L, D, SLC, G, DHC}, false));
    CHECK(expect_dims(r.weights_iter_desc, {L, D, SIC, G, DHC}, false));
    // For AUGRU the peephole slot holds the attention tensor {T, N, 1}.
    if (is_augru)
        CHECK(expect_dims(r.weights_peephole_desc, {T, N, 1}));
    else
        CHECK(expect_dims(r.weights_peephole_desc, {L, D, 3, DHC}));
    CHECK(expect_dims(r.weights_projection_desc, {L, D, DHC, DIC}));
    CHECK(expect_dims(r.bias_desc, {L, D, G + extra_bias, DHC}));
    CHECK(expect_dims(r.dst_layer_desc, {T, N, DLC}, false));
    CHECK(expect_dims(r.dst_iter_desc, {L, D, N, DIC}));
    CHECK(expect_dims(r.dst_iter_c_desc, {L, D, N, DHC}));

    // Diff descriptors mirror the shapes of their forward counterparts.
    if (r.prop_kind == prop_kind::backward) {
        CHECK(expect_dims(r.diff_src_layer_desc, {T, N, SLC}, false));
        CHECK(expect_dims(r.diff_src_iter_desc, {L, D, N, SIC}));
        CHECK(expect_dims(r.diff_src_iter_c_desc, {L, D, N, DHC}));
        CHECK(expect_dims(
                r.diff_weights_layer_desc, {L, D, SLC, G, DHC}, false));
        CHECK(expect_dims(
                r.diff_weights_iter_desc, {L, D, SIC, G, DHC}, false));
        if (is_augru)
            CHECK(expect_dims(r.diff_weights_peephole_desc, {T, N, 1}));
        else
            CHECK(expect_dims(r.diff_weights_peephole_desc, {L, D, 3, DHC}));
        CHECK(expect_dims(r.diff_weights_projection_desc, {L, D, DHC, DIC}));
        CHECK(expect_dims(r.diff_bias_desc, {L, D, G + extra_bias, DHC}));
        CHECK(expect_dims(r.diff_dst_layer_desc, {T, N, DLC}, false));
        CHECK(expect_dims(r.diff_dst_iter_desc, {L, D, N, DIC}));
        CHECK(expect_dims(r.diff_dst_iter_c_desc, {L, D, N, DHC}));
    }

    return success;
}
266 | |
267 | status_t rnn_common_fwd_desc_init(rnn_desc_t *rnn_desc, prop_kind_t prop_kind, |
268 | alg_kind_t cell_kind, const rnn_direction_t direction, |
269 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
270 | const memory_desc_t *src_iter_c_desc, |
271 | const memory_desc_t *attention_desc, |
272 | const memory_desc_t *weights_layer_desc, |
273 | const memory_desc_t *weights_iter_desc, |
274 | const memory_desc_t *weights_peephole_desc, |
275 | const memory_desc_t *weights_projection_desc, |
276 | const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc, |
277 | const memory_desc_t *dst_iter_desc, |
278 | const memory_desc_t *dst_iter_c_desc, unsigned flags, |
279 | alg_kind_t activation = alg_kind::undef, float alpha = 0.0f, |
280 | float beta = 0.0f) { |
281 | |
282 | // check that a supported cell kind has been passed |
283 | bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm, |
284 | dnnl_vanilla_gru, dnnl_lbr_gru, dnnl_vanilla_augru, dnnl_lbr_augru); |
285 | if (!args_ok) return invalid_arguments; |
286 | |
287 | // check that all mandatory parameters are non-null |
288 | args_ok = args_ok |
289 | && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, |
290 | dst_layer_desc); |
291 | if (!args_ok) return invalid_arguments; |
292 | |
293 | if (cell_kind == dnnl_vanilla_rnn) { |
294 | using namespace alg_kind; |
295 | args_ok = args_ok |
296 | && one_of(activation, eltwise_relu, eltwise_tanh, |
297 | eltwise_logistic); |
298 | if (!args_ok) return invalid_arguments; |
299 | } |
300 | |
301 | if (cell_kind == dnnl_vanilla_lstm) { |
302 | // check if optional *_iter is provided then *_iter_c is provided too |
303 | args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc) |
304 | && xnor_md(dst_iter_desc, dst_iter_c_desc); |
305 | if (!args_ok) return invalid_arguments; |
306 | } |
307 | |
308 | // check augru-specific restrictions |
309 | const bool is_augru = one_of(cell_kind, dnnl_vanilla_augru, dnnl_lbr_augru); |
310 | if (is_augru) { |
311 | const dim_t L = weights_layer_desc->dims[0]; |
312 | args_ok = args_ok && direction == dnnl_unidirectional_left2right |
313 | && L == 1; |
314 | if (!args_ok) return invalid_arguments; |
315 | } |
316 | |
317 | CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc, |
318 | src_iter_c_desc, weights_layer_desc, weights_iter_desc, |
319 | weights_peephole_desc, weights_projection_desc, bias_desc, |
320 | dst_layer_desc, dst_iter_desc, dst_iter_c_desc})); |
321 | |
322 | // Create the descriptor |
323 | auto rd = rnn_desc_t(); |
324 | |
325 | rd.primitive_kind = primitive_kind::rnn; |
326 | rd.prop_kind = prop_kind; |
327 | rd.cell_kind = cell_kind; |
328 | rd.direction = direction; |
329 | maybe_init_md(rd.src_layer_desc, src_layer_desc); |
330 | maybe_init_md(rd.src_iter_desc, src_iter_desc); |
331 | maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc); |
332 | maybe_init_md(rd.weights_layer_desc, weights_layer_desc); |
333 | maybe_init_md(rd.weights_iter_desc, weights_iter_desc); |
334 | maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc); |
335 | if (is_augru) maybe_init_md(rd.weights_peephole_desc, attention_desc); |
336 | maybe_init_md(rd.weights_projection_desc, weights_projection_desc); |
337 | maybe_init_md(rd.bias_desc, bias_desc); |
338 | maybe_init_md(rd.dst_layer_desc, dst_layer_desc); |
339 | maybe_init_md(rd.dst_iter_desc, dst_iter_desc); |
340 | maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc); |
341 | |
342 | rd.flags = flags; |
343 | rd.activation_kind = activation; |
344 | rd.alpha = alpha; |
345 | rd.beta = beta; |
346 | |
347 | CHECK(check_data_type_consistency_fwd(rd)); |
348 | CHECK(check_dim_consistency(rd)); |
349 | |
350 | *rnn_desc = rd; |
351 | |
352 | return success; |
353 | } |
354 | |
355 | status_t rnn_common_bwd_desc_init(rnn_desc_t *rnn_desc, prop_kind_t prop_kind, |
356 | alg_kind_t cell_kind, const dnnl_rnn_direction_t direction, |
357 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
358 | const memory_desc_t *src_iter_c_desc, |
359 | const memory_desc_t *attention_desc, |
360 | const memory_desc_t *weights_layer_desc, |
361 | const memory_desc_t *weights_iter_desc, |
362 | const memory_desc_t *weights_peephole_desc, |
363 | const memory_desc_t *weights_projection_desc, |
364 | const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc, |
365 | const memory_desc_t *dst_iter_desc, |
366 | const memory_desc_t *dst_iter_c_desc, |
367 | const memory_desc_t *diff_src_layer_desc, |
368 | const memory_desc_t *diff_src_iter_desc, |
369 | const memory_desc_t *diff_src_iter_c_desc, |
370 | const memory_desc_t *diff_attention_desc, |
371 | const memory_desc_t *diff_weights_layer_desc, |
372 | const memory_desc_t *diff_weights_iter_desc, |
373 | const memory_desc_t *diff_weights_peephole_desc, |
374 | const memory_desc_t *diff_weights_projection_desc, |
375 | const memory_desc_t *diff_bias_desc, |
376 | const memory_desc_t *diff_dst_layer_desc, |
377 | const memory_desc_t *diff_dst_iter_desc, |
378 | const memory_desc_t *diff_dst_iter_c_desc, unsigned flags, |
379 | alg_kind_t activation = alg_kind::undef, float alpha = 0.0f, |
380 | float beta = 0.0f) { |
381 | |
382 | // check that a supported cell kind has been passed |
383 | bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm, |
384 | dnnl_vanilla_gru, dnnl_lbr_gru, dnnl_vanilla_augru, dnnl_lbr_augru); |
385 | if (!args_ok) return invalid_arguments; |
386 | |
387 | // check that all mandatory parameters are non-null |
388 | args_ok = args_ok |
389 | && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, |
390 | dst_layer_desc, diff_src_layer_desc, |
391 | diff_weights_layer_desc, diff_weights_iter_desc, |
392 | diff_dst_layer_desc); |
393 | if (!args_ok) return invalid_arguments; |
394 | |
395 | if (cell_kind == dnnl_vanilla_rnn) { |
396 | using namespace alg_kind; |
397 | args_ok = args_ok |
398 | && one_of(activation, eltwise_relu, eltwise_tanh, |
399 | eltwise_logistic); |
400 | if (!args_ok) return invalid_arguments; |
401 | } |
402 | |
403 | if (cell_kind == dnnl_vanilla_lstm) { |
404 | // check if optional *_iter is provided then *_iter_c is provided too |
405 | args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc) |
406 | && xnor_md(dst_iter_desc, dst_iter_c_desc); |
407 | if (!args_ok) return invalid_arguments; |
408 | } |
409 | |
410 | const bool is_augru = one_of(cell_kind, dnnl_vanilla_augru, dnnl_lbr_augru); |
411 | // check augru-specific restrictions |
412 | if (is_augru) { |
413 | const dim_t L = weights_layer_desc->dims[0]; |
414 | const bool dims_ok = args_ok |
415 | && direction == dnnl_unidirectional_left2right && L == 1; |
416 | const bool descs_ok = !any_null(attention_desc, diff_attention_desc); |
417 | const bool args_ok = dims_ok && descs_ok; |
418 | if (!args_ok) return invalid_arguments; |
419 | } |
420 | |
421 | // check if optional md is provided then diff_md is provided too |
422 | args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) |
423 | && xnor_md(weights_peephole_desc, diff_weights_peephole_desc) |
424 | && xnor_md(weights_projection_desc, diff_weights_projection_desc) |
425 | && xnor_md(src_iter_desc, diff_src_iter_desc) |
426 | && xnor_md(src_iter_c_desc, diff_src_iter_c_desc) |
427 | && xnor_md(dst_iter_desc, diff_dst_iter_desc) |
428 | && xnor_md(dst_iter_c_desc, diff_dst_iter_c_desc); |
429 | if (!args_ok) return invalid_arguments; |
430 | |
431 | CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc, |
432 | src_iter_c_desc, attention_desc, weights_layer_desc, |
433 | weights_iter_desc, weights_peephole_desc, weights_projection_desc, |
434 | bias_desc, dst_layer_desc, dst_iter_desc, dst_iter_c_desc, |
435 | diff_src_layer_desc, diff_src_iter_desc, diff_src_iter_c_desc, |
436 | diff_attention_desc, diff_weights_layer_desc, |
437 | diff_weights_iter_desc, diff_weights_peephole_desc, |
438 | diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc, |
439 | diff_dst_iter_desc, diff_dst_iter_c_desc})); |
440 | |
441 | auto rd = rnn_desc_t(); |
442 | |
443 | rd.primitive_kind = primitive_kind::rnn; |
444 | rd.prop_kind = prop_kind; |
445 | rd.cell_kind = cell_kind; |
446 | rd.direction = direction; |
447 | |
448 | maybe_init_md(rd.src_layer_desc, src_layer_desc); |
449 | maybe_init_md(rd.src_iter_desc, src_iter_desc); |
450 | maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc); |
451 | maybe_init_md(rd.weights_layer_desc, weights_layer_desc); |
452 | maybe_init_md(rd.weights_iter_desc, weights_iter_desc); |
453 | maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc); |
454 | if (is_augru) maybe_init_md(rd.weights_peephole_desc, attention_desc); |
455 | maybe_init_md(rd.weights_projection_desc, weights_projection_desc); |
456 | maybe_init_md(rd.bias_desc, bias_desc); |
457 | maybe_init_md(rd.dst_layer_desc, dst_layer_desc); |
458 | maybe_init_md(rd.dst_iter_desc, dst_iter_desc); |
459 | maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc); |
460 | maybe_init_md(rd.diff_src_layer_desc, diff_src_layer_desc); |
461 | maybe_init_md(rd.diff_src_iter_desc, diff_src_iter_desc); |
462 | maybe_init_md(rd.diff_src_iter_c_desc, diff_src_iter_c_desc); |
463 | maybe_init_md(rd.diff_weights_layer_desc, diff_weights_layer_desc); |
464 | maybe_init_md(rd.diff_weights_iter_desc, diff_weights_iter_desc); |
465 | maybe_init_md(rd.diff_weights_peephole_desc, diff_weights_peephole_desc); |
466 | if (is_augru) |
467 | maybe_init_md(rd.diff_weights_peephole_desc, diff_attention_desc); |
468 | maybe_init_md( |
469 | rd.diff_weights_projection_desc, diff_weights_projection_desc); |
470 | maybe_init_md(rd.diff_bias_desc, diff_bias_desc); |
471 | maybe_init_md(rd.diff_dst_layer_desc, diff_dst_layer_desc); |
472 | maybe_init_md(rd.diff_dst_iter_desc, diff_dst_iter_desc); |
473 | maybe_init_md(rd.diff_dst_iter_c_desc, diff_dst_iter_c_desc); |
474 | |
475 | rd.flags = flags; |
476 | rd.activation_kind = activation; |
477 | rd.alpha = alpha; |
478 | rd.beta = beta; |
479 | |
480 | CHECK(check_data_type_consistency_fwd(rd)); |
481 | CHECK(check_data_type_consistency_bwd(rd)); |
482 | |
483 | CHECK(check_dim_consistency(rd)); |
484 | |
485 | *rnn_desc = rd; |
486 | |
487 | return success; |
488 | } |
489 | |
490 | } // namespace |
491 | |
492 | /* Public C Api */ |
493 | |
494 | status_t dnnl_vanilla_rnn_forward_primitive_desc_create( |
495 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
496 | dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, |
497 | const dnnl_rnn_direction_t direction, |
498 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
499 | const memory_desc_t *weights_layer_desc, |
500 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
501 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
502 | unsigned flags, float alpha, float beta, const primitive_attr_t *attr) { |
503 | |
504 | auto rnn_desc = rnn_desc_t(); |
505 | CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_rnn, |
506 | direction, src_layer_desc, src_iter_desc, nullptr, nullptr, |
507 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
508 | dst_layer_desc, dst_iter_desc, nullptr, flags, activation, alpha, |
509 | beta)); |
510 | return primitive_desc_create(primitive_desc_iface, engine, |
511 | (const op_desc_t *)&rnn_desc, nullptr, attr); |
512 | } |
513 | |
514 | status_t dnnl_lstm_forward_primitive_desc_create( |
515 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
516 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
517 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
518 | const memory_desc_t *src_iter_c_desc, |
519 | const memory_desc_t *weights_layer_desc, |
520 | const memory_desc_t *weights_iter_desc, |
521 | const memory_desc_t *weights_peephole_desc, |
522 | const memory_desc_t *weights_projection_desc, |
523 | const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc, |
524 | const memory_desc_t *dst_iter_desc, |
525 | const memory_desc_t *dst_iter_c_desc, unsigned flags, |
526 | const primitive_attr_t *attr) { |
527 | |
528 | auto rnn_desc = rnn_desc_t(); |
529 | CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_lstm, |
530 | direction, src_layer_desc, src_iter_desc, src_iter_c_desc, nullptr, |
531 | weights_layer_desc, weights_iter_desc, weights_peephole_desc, |
532 | weights_projection_desc, bias_desc, dst_layer_desc, dst_iter_desc, |
533 | dst_iter_c_desc, flags)); |
534 | return primitive_desc_create(primitive_desc_iface, engine, |
535 | (const op_desc_t *)&rnn_desc, nullptr, attr); |
536 | } |
537 | |
538 | status_t dnnl_gru_forward_primitive_desc_create( |
539 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
540 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
541 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
542 | const memory_desc_t *weights_layer_desc, |
543 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
544 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
545 | unsigned flags, const primitive_attr_t *attr) { |
546 | |
547 | auto rnn_desc = rnn_desc_t(); |
548 | CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_gru, |
549 | direction, src_layer_desc, src_iter_desc, nullptr, nullptr, |
550 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
551 | dst_layer_desc, dst_iter_desc, nullptr, flags)); |
552 | return primitive_desc_create(primitive_desc_iface, engine, |
553 | (const op_desc_t *)&rnn_desc, nullptr, attr); |
554 | } |
555 | |
556 | status_t dnnl_lbr_gru_forward_primitive_desc_create( |
557 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
558 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
559 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
560 | const memory_desc_t *weights_layer_desc, |
561 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
562 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
563 | unsigned flags, const primitive_attr_t *attr) { |
564 | |
565 | auto rnn_desc = rnn_desc_t(); |
566 | CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_gru, |
567 | direction, src_layer_desc, src_iter_desc, nullptr, nullptr, |
568 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
569 | dst_layer_desc, dst_iter_desc, nullptr, flags)); |
570 | return primitive_desc_create(primitive_desc_iface, engine, |
571 | (const op_desc_t *)&rnn_desc, nullptr, attr); |
572 | } |
573 | |
574 | status_t dnnl_augru_forward_primitive_desc_create( |
575 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
576 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
577 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
578 | const memory_desc_t *attention_desc, |
579 | const memory_desc_t *weights_layer_desc, |
580 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
581 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
582 | unsigned flags, const primitive_attr_t *attr) { |
583 | |
584 | auto rnn_desc = rnn_desc_t(); |
585 | CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_augru, |
586 | direction, src_layer_desc, src_iter_desc, nullptr, attention_desc, |
587 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
588 | dst_layer_desc, dst_iter_desc, nullptr, flags)); |
589 | return primitive_desc_create(primitive_desc_iface, engine, |
590 | (const op_desc_t *)&rnn_desc, nullptr, attr); |
591 | } |
592 | |
593 | status_t dnnl_lbr_augru_forward_primitive_desc_create( |
594 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
595 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
596 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
597 | const memory_desc_t *attention_desc, |
598 | const memory_desc_t *weights_layer_desc, |
599 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
600 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
601 | unsigned flags, const primitive_attr_t *attr) { |
602 | |
603 | auto rnn_desc = rnn_desc_t(); |
604 | CHECK(rnn_common_fwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_augru, |
605 | direction, src_layer_desc, src_iter_desc, nullptr, attention_desc, |
606 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
607 | dst_layer_desc, dst_iter_desc, nullptr, flags)); |
608 | return primitive_desc_create(primitive_desc_iface, engine, |
609 | (const op_desc_t *)&rnn_desc, nullptr, attr); |
610 | } |
611 | |
612 | status_t dnnl_vanilla_rnn_backward_primitive_desc_create( |
613 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
614 | dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, |
615 | const dnnl_rnn_direction_t direction, |
616 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
617 | const memory_desc_t *weights_layer_desc, |
618 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
619 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
620 | const memory_desc_t *diff_src_layer_desc, |
621 | const memory_desc_t *diff_src_iter_desc, |
622 | const memory_desc_t *diff_weights_layer_desc, |
623 | const memory_desc_t *diff_weights_iter_desc, |
624 | const memory_desc_t *diff_bias_desc, |
625 | const memory_desc_t *diff_dst_layer_desc, |
626 | const memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, |
627 | float beta, const primitive_desc_iface_t *hint_fwd_pd, |
628 | const primitive_attr_t *attr) { |
629 | |
630 | auto rnn_desc = rnn_desc_t(); |
631 | CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_rnn, |
632 | direction, src_layer_desc, src_iter_desc, nullptr, nullptr, |
633 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
634 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
635 | diff_src_iter_desc, nullptr, nullptr, diff_weights_layer_desc, |
636 | diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc, |
637 | diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags, activation, |
638 | alpha, beta)); |
639 | return primitive_desc_create(primitive_desc_iface, engine, |
640 | (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr); |
641 | } |
642 | |
643 | status_t dnnl_lstm_backward_primitive_desc_create( |
644 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
645 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
646 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
647 | const memory_desc_t *src_iter_c_desc, |
648 | const memory_desc_t *weights_layer_desc, |
649 | const memory_desc_t *weights_iter_desc, |
650 | const memory_desc_t *weights_peephole_desc, |
651 | const memory_desc_t *weights_projection_desc, |
652 | const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc, |
653 | const memory_desc_t *dst_iter_desc, |
654 | const memory_desc_t *dst_iter_c_desc, |
655 | const memory_desc_t *diff_src_layer_desc, |
656 | const memory_desc_t *diff_src_iter_desc, |
657 | const memory_desc_t *diff_src_iter_c_desc, |
658 | const memory_desc_t *diff_weights_layer_desc, |
659 | const memory_desc_t *diff_weights_iter_desc, |
660 | const memory_desc_t *diff_weights_peephole_desc, |
661 | const memory_desc_t *diff_weights_projection_desc, |
662 | const memory_desc_t *diff_bias_desc, |
663 | const memory_desc_t *diff_dst_layer_desc, |
664 | const memory_desc_t *diff_dst_iter_desc, |
665 | const memory_desc_t *diff_dst_iter_c_desc, unsigned flags, |
666 | const primitive_desc_iface_t *hint_fwd_pd, |
667 | const primitive_attr_t *attr) { |
668 | |
669 | auto rnn_desc = rnn_desc_t(); |
670 | CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_lstm, |
671 | direction, src_layer_desc, src_iter_desc, src_iter_c_desc, nullptr, |
672 | weights_layer_desc, weights_iter_desc, weights_peephole_desc, |
673 | weights_projection_desc, bias_desc, dst_layer_desc, dst_iter_desc, |
674 | dst_iter_c_desc, diff_src_layer_desc, diff_src_iter_desc, |
675 | diff_src_iter_c_desc, nullptr, diff_weights_layer_desc, |
676 | diff_weights_iter_desc, diff_weights_peephole_desc, |
677 | diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc, |
678 | diff_dst_iter_desc, diff_dst_iter_c_desc, flags)); |
679 | return primitive_desc_create(primitive_desc_iface, engine, |
680 | (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr); |
681 | } |
682 | |
683 | status_t dnnl_gru_backward_primitive_desc_create( |
684 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
685 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
686 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
687 | const memory_desc_t *weights_layer_desc, |
688 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
689 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
690 | const memory_desc_t *diff_src_layer_desc, |
691 | const memory_desc_t *diff_src_iter_desc, |
692 | const memory_desc_t *diff_weights_layer_desc, |
693 | const memory_desc_t *diff_weights_iter_desc, |
694 | const memory_desc_t *diff_bias_desc, |
695 | const memory_desc_t *diff_dst_layer_desc, |
696 | const memory_desc_t *diff_dst_iter_desc, unsigned flags, |
697 | const primitive_desc_iface_t *hint_fwd_pd, |
698 | const primitive_attr_t *attr) { |
699 | |
700 | auto rnn_desc = rnn_desc_t(); |
701 | CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_gru, |
702 | direction, src_layer_desc, src_iter_desc, nullptr, nullptr, |
703 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
704 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
705 | diff_src_iter_desc, nullptr, nullptr, diff_weights_layer_desc, |
706 | diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc, |
707 | diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags)); |
708 | return primitive_desc_create(primitive_desc_iface, engine, |
709 | (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr); |
710 | } |
711 | |
712 | status_t dnnl_lbr_gru_backward_primitive_desc_create( |
713 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
714 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
715 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
716 | const memory_desc_t *weights_layer_desc, |
717 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
718 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
719 | const memory_desc_t *diff_src_layer_desc, |
720 | const memory_desc_t *diff_src_iter_desc, |
721 | const memory_desc_t *diff_weights_layer_desc, |
722 | const memory_desc_t *diff_weights_iter_desc, |
723 | const memory_desc_t *diff_bias_desc, |
724 | const memory_desc_t *diff_dst_layer_desc, |
725 | const memory_desc_t *diff_dst_iter_desc, unsigned flags, |
726 | const primitive_desc_iface_t *hint_fwd_pd, |
727 | const primitive_attr_t *attr) { |
728 | |
729 | auto rnn_desc = rnn_desc_t(); |
730 | CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_gru, |
731 | direction, src_layer_desc, src_iter_desc, nullptr, nullptr, |
732 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
733 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
734 | diff_src_iter_desc, nullptr, nullptr, diff_weights_layer_desc, |
735 | diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc, |
736 | diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags)); |
737 | return primitive_desc_create(primitive_desc_iface, engine, |
738 | (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr); |
739 | } |
740 | |
741 | status_t dnnl_augru_backward_primitive_desc_create( |
742 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
743 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
744 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
745 | const memory_desc_t *attention_desc, |
746 | const memory_desc_t *weights_layer_desc, |
747 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
748 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
749 | const memory_desc_t *diff_src_layer_desc, |
750 | const memory_desc_t *diff_src_iter_desc, |
751 | const memory_desc_t *diff_attention_desc, |
752 | const memory_desc_t *diff_weights_layer_desc, |
753 | const memory_desc_t *diff_weights_iter_desc, |
754 | const memory_desc_t *diff_bias_desc, |
755 | const memory_desc_t *diff_dst_layer_desc, |
756 | const memory_desc_t *diff_dst_iter_desc, unsigned flags, |
757 | const primitive_desc_iface_t *hint_fwd_pd, |
758 | const primitive_attr_t *attr) { |
759 | |
760 | auto rnn_desc = rnn_desc_t(); |
761 | CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_vanilla_augru, |
762 | direction, src_layer_desc, src_iter_desc, nullptr, attention_desc, |
763 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
764 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
765 | diff_src_iter_desc, nullptr, diff_attention_desc, |
766 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, nullptr, |
767 | diff_bias_desc, diff_dst_layer_desc, diff_dst_iter_desc, nullptr, |
768 | flags)); |
769 | return primitive_desc_create(primitive_desc_iface, engine, |
770 | (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr); |
771 | } |
772 | |
773 | status_t dnnl_lbr_augru_backward_primitive_desc_create( |
774 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
775 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
776 | const memory_desc_t *src_layer_desc, const memory_desc_t *src_iter_desc, |
777 | const memory_desc_t *attention_desc, |
778 | const memory_desc_t *weights_layer_desc, |
779 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
780 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
781 | const memory_desc_t *diff_src_layer_desc, |
782 | const memory_desc_t *diff_src_iter_desc, |
783 | const memory_desc_t *diff_attention_desc, |
784 | const memory_desc_t *diff_weights_layer_desc, |
785 | const memory_desc_t *diff_weights_iter_desc, |
786 | const memory_desc_t *diff_bias_desc, |
787 | const memory_desc_t *diff_dst_layer_desc, |
788 | const memory_desc_t *diff_dst_iter_desc, unsigned flags, |
789 | const primitive_desc_iface_t *hint_fwd_pd, |
790 | const primitive_attr_t *attr) { |
791 | |
792 | auto rnn_desc = rnn_desc_t(); |
793 | CHECK(rnn_common_bwd_desc_init(&rnn_desc, prop_kind, dnnl_lbr_augru, |
794 | direction, src_layer_desc, src_iter_desc, nullptr, attention_desc, |
795 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc, |
796 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
797 | diff_src_iter_desc, nullptr, diff_attention_desc, |
798 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, nullptr, |
799 | diff_bias_desc, diff_dst_layer_desc, diff_dst_iter_desc, nullptr, |
800 | flags)); |
801 | return primitive_desc_create(primitive_desc_iface, engine, |
802 | (const op_desc_t *)&rnn_desc, hint_fwd_pd, attr); |
803 | } |
804 | |