1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <numeric> |
18 | #include <utility> |
19 | #include <type_traits> |
20 | |
21 | #include "dnnl_test_common.hpp" |
22 | #include "gtest/gtest.h" |
23 | |
24 | #include "oneapi/dnnl/dnnl.hpp" |
25 | |
26 | namespace dnnl { |
27 | |
28 | struct test_rnn_sizes_t { |
29 | memory::dim l, d, t, mb; |
30 | memory::dim slc, sic, dhc, dic; |
31 | }; |
32 | |
33 | struct test_rnn_formats_t { |
34 | dnnl::memory::format_tag src_layer_fmt; |
35 | dnnl::memory::format_tag src_iter_fmt; |
36 | dnnl::memory::format_tag weights_layer_fmt; |
37 | dnnl::memory::format_tag weights_iter_fmt; |
38 | dnnl::memory::format_tag weights_peephole_fmt; |
39 | dnnl::memory::format_tag weights_projection_fmt; |
40 | dnnl::memory::format_tag bias_fmt; |
41 | dnnl::memory::format_tag dst_layer_fmt; |
42 | dnnl::memory::format_tag dst_iter_fmt; |
43 | }; |
44 | |
45 | struct { |
46 | dnnl::algorithm ; |
47 | float ; |
48 | }; |
49 | |
50 | struct test_rnn_params_t { |
51 | test_rnn_extra_t ; |
52 | prop_kind aprop; |
53 | dnnl::rnn_direction direction; |
54 | test_rnn_formats_t fmts; |
55 | test_rnn_sizes_t sizes; |
56 | bool expect_to_fail; |
57 | dnnl_status_t expected_status; |
58 | }; |
59 | |
60 | // We assume uniform data type accross tensors for now |
61 | template <typename T, typename data_t> |
62 | class rnn_forward_test_t : public ::testing::TestWithParam<test_rnn_params_t> { |
63 | |
64 | private: |
65 | memory::dim getNGates(); |
66 | |
67 | typename T::primitive_desc get_pd(prop_kind aprop, algorithm activation, |
68 | rnn_direction direction, const memory::desc &src_layer_md, |
69 | const memory::desc &src_iter_md, const memory::desc &src_iter_c_md, |
70 | const memory::desc &attention_md, |
71 | const memory::desc &weights_layer_md, |
72 | const memory::desc &weights_iter_md, |
73 | const memory::desc &weights_peephole_md, |
74 | const memory::desc &weights_projection_md, |
75 | const memory::desc &bias_md, const memory::desc &dst_layer_md, |
76 | const memory::desc &dst_iter_md, const memory::desc &dst_iter_c_md, |
77 | float alpha = 0.0f); |
78 | |
79 | bool skipTest(bool src_layer_match, bool augru_attention_match, |
80 | bool src_iter_match, bool src_iter_c_match, |
81 | bool weights_layer_match, bool weights_iter_match, bool bias_match, |
82 | bool dst_layer_match, bool dst_iter_match, bool dst_iter_c_match) { |
83 | // By default, we ignore src_iter_c and dst_iter_c as they are |
84 | // only supported for lstm. For LSTM tests, this function |
85 | // should be specialized to handle them. |
86 | return src_layer_match && src_iter_match && weights_layer_match |
87 | && weights_iter_match && bias_match && dst_layer_match |
88 | && dst_iter_match; |
89 | } |
90 | |
91 | memory::desc querySrcIterC(const typename T::primitive_desc &rpd) { |
92 | return memory::desc(); |
93 | } |
94 | |
95 | memory::desc queryWeightsPeephole(const typename T::primitive_desc &rpd) { |
96 | return memory::desc(); |
97 | } |
98 | |
99 | memory::desc queryWeightsProjection(const typename T::primitive_desc &rpd) { |
100 | return memory::desc(); |
101 | } |
102 | |
103 | memory::desc queryDstIterC(const typename T::primitive_desc &rpd) { |
104 | return memory::desc(); |
105 | } |
106 | |
107 | void testExecArgQueries(typename T::primitive_desc pd) { |
108 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER) |
109 | == pd.weights_layer_desc()); |
110 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER) |
111 | == pd.weights_iter_desc()); |
112 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE) |
113 | == pd.weights_peephole_desc()); |
114 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION) |
115 | == pd.weights_projection_desc()); |
116 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_BIAS) |
117 | == pd.bias_desc()); |
118 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER) |
119 | == pd.src_layer_desc()); |
120 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION) |
121 | == pd.augru_attention_desc()); |
122 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER) |
123 | == pd.src_iter_desc()); |
124 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C) |
125 | == querySrcIterC(pd)); |
126 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER) |
127 | == pd.dst_layer_desc()); |
128 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST_ITER) |
129 | == pd.dst_iter_desc()); |
130 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C) |
131 | == queryDstIterC(pd)); |
132 | }; |
133 | |
134 | void test_primitive_param_queries(typename T::primitive_desc pd) { |
135 | auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam(); |
136 | |
137 | dnnl::algorithm expected_cell_kind = algorithm::undef; |
138 | if (is_vanilla_rnn) expected_cell_kind = algorithm::vanilla_rnn; |
139 | if (is_lstm) expected_cell_kind = algorithm::vanilla_lstm; |
140 | if (is_gru) expected_cell_kind = algorithm::vanilla_gru; |
141 | if (is_lbr_gru) expected_cell_kind = algorithm::lbr_gru; |
142 | if (is_augru) expected_cell_kind = algorithm::vanilla_augru; |
143 | if (is_lbr_augru) expected_cell_kind = algorithm::lbr_augru; |
144 | |
145 | ASSERT_NE(expected_cell_kind, algorithm::undef); |
146 | ASSERT_EQ(pd.get_cell_kind(), expected_cell_kind); |
147 | ASSERT_EQ(pd.get_prop_kind(), p.aprop); |
148 | ASSERT_EQ(pd.get_direction(), p.direction); |
149 | |
150 | if (is_vanilla_rnn) { |
151 | ASSERT_EQ(pd.get_alpha(), p.extra.alpha); |
152 | ASSERT_EQ(pd.get_activation_kind(), p.extra.activation); |
153 | } else { |
154 | ASSERT_EQ(pd.get_alpha(), 0.0f); |
155 | ASSERT_EQ(pd.get_beta(), 0.0f); |
156 | ASSERT_EQ(pd.get_activation_kind(), algorithm::undef); |
157 | } |
158 | } |
159 | |
160 | protected: |
161 | static constexpr bool is_lstm = std::is_same<T, lstm_forward>::value; |
162 | static constexpr bool is_vanilla_rnn |
163 | = std::is_same<T, vanilla_rnn_forward>::value; |
164 | static constexpr bool is_gru = std::is_same<T, gru_forward>::value; |
165 | static constexpr bool is_lbr_gru = std::is_same<T, lbr_gru_forward>::value; |
166 | static constexpr bool is_augru = std::is_same<T, augru_forward>::value; |
167 | static constexpr bool is_lbr_augru |
168 | = std::is_same<T, lbr_augru_forward>::value; |
169 | |
170 | void SetUp() override { |
171 | auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam(); |
172 | catch_expected_failures( |
173 | [=]() { Test(); }, p.expect_to_fail, p.expected_status, false); |
174 | } |
175 | |
176 | void Test() { |
177 | auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam(); |
178 | const bool is_lstm_peephole |
179 | = p.fmts.weights_peephole_fmt != memory::format_tag::undef; |
180 | const bool is_lstm_projection |
181 | = p.fmts.weights_projection_fmt != memory::format_tag::undef; |
182 | auto eng = get_test_engine(); |
183 | auto strm = make_stream(eng); |
184 | //@todo check algorithm is one of the supported by RNN |
185 | //ASSERT_EQ(p.aalgorithm, algorithm::vanilla_lstm); |
186 | |
187 | // Initialize the data |
188 | memory::data_type prec = data_traits<data_t>::data_type; |
189 | auto dims = p.sizes; |
190 | auto t = dims.t, mb = dims.mb, l = dims.l, d = dims.d; |
191 | auto slc = dims.slc, sic = dims.sic, dhc = dims.dhc, dic = dims.dic; |
192 | auto dlc = (p.direction == rnn_direction::bidirectional_concat ? 2 : 1) |
193 | * dic; |
194 | memory::dim g = getNGates(); |
195 | memory::dim = std::is_same<T, lbr_gru_forward>::value |
196 | || std::is_same<T, lbr_augru_forward>::value |
197 | ? 1 |
198 | : 0; |
199 | |
200 | auto attention_dims = {t, mb, (memory::dim)1}; |
201 | auto weights_layer_dims = {l, d, slc, g, dhc}; |
202 | auto weights_iter_dims = {l, d, sic, g, dhc}; |
203 | auto weights_peephole_dims = {l, d, (memory::dim)3, dhc}; |
204 | auto weights_projection_dims = {l, d, dhc, dic}; |
205 | auto bias_dims = {l, d, g + bias_extra_gate, dhc}; |
206 | auto src_layer_dims = {t, mb, slc}; |
207 | auto src_iter_dims = {l, d, mb, sic}; |
208 | auto src_iter_c_dims = {l, d, mb, dhc}; |
209 | auto dst_layer_dims = {t, mb, dlc}; |
210 | auto dst_iter_dims = {l, d, mb, dic}; |
211 | auto dst_iter_c_dims = {l, d, mb, dhc}; |
212 | |
213 | auto attention_md_any |
214 | = memory::desc({attention_dims}, prec, memory::format_tag::any); |
215 | auto weights_layer_md_any = memory::desc( |
216 | {weights_layer_dims}, prec, memory::format_tag::any); |
217 | auto weights_iter_md_any = memory::desc( |
218 | {weights_iter_dims}, prec, memory::format_tag::any); |
219 | auto weights_peephole_md_any = memory::desc( |
220 | {weights_peephole_dims}, prec, memory::format_tag::any); |
221 | auto weights_projection_md_any = memory::desc( |
222 | {weights_projection_dims}, prec, memory::format_tag::any); |
223 | auto bias_md_any |
224 | = memory::desc({bias_dims}, prec, memory::format_tag::any); |
225 | auto src_layer_md_any |
226 | = memory::desc({src_layer_dims}, prec, memory::format_tag::any); |
227 | auto src_iter_md_any |
228 | = memory::desc({src_iter_dims}, prec, memory::format_tag::any); |
229 | auto src_iter_c_md_any = memory::desc( |
230 | {src_iter_c_dims}, prec, memory::format_tag::any); |
231 | auto dst_layer_md_any |
232 | = memory::desc({dst_layer_dims}, prec, memory::format_tag::any); |
233 | auto dst_iter_md_any |
234 | = memory::desc({dst_iter_dims}, prec, memory::format_tag::any); |
235 | auto dst_iter_c_md_any = memory::desc( |
236 | {dst_iter_c_dims}, prec, memory::format_tag::any); |
237 | |
238 | auto attention_md_tgt = (is_augru || is_lbr_augru) |
239 | ? memory::desc({attention_dims}, prec, memory::format_tag::tnc) |
240 | : memory::desc(); |
241 | auto weights_layer_md_tgt = memory::desc( |
242 | {weights_layer_dims}, prec, p.fmts.weights_layer_fmt); |
243 | auto weights_iter_md_tgt = memory::desc( |
244 | {weights_iter_dims}, prec, p.fmts.weights_iter_fmt); |
245 | auto weights_peephole_md_tgt = is_lstm_peephole |
246 | ? memory::desc({weights_peephole_dims}, prec, |
247 | p.fmts.weights_peephole_fmt) |
248 | : memory::desc(); |
249 | auto weights_projection_md_tgt = is_lstm_projection |
250 | ? memory::desc({weights_projection_dims}, prec, |
251 | p.fmts.weights_projection_fmt) |
252 | : memory::desc(); |
253 | auto bias_md_tgt = memory::desc({bias_dims}, prec, p.fmts.bias_fmt); |
254 | auto src_layer_md_tgt |
255 | = memory::desc({src_layer_dims}, prec, p.fmts.src_layer_fmt); |
256 | auto augru_attention_md_tgt |
257 | = memory::desc({attention_dims}, prec, memory::format_tag::tnc); |
258 | auto src_iter_md_tgt |
259 | = (p.fmts.src_iter_fmt != memory::format_tag::undef) |
260 | ? memory::desc({src_iter_dims}, prec, p.fmts.src_iter_fmt) |
261 | : memory::desc(); |
262 | auto src_iter_c_md_tgt |
263 | = (p.fmts.src_iter_fmt != memory::format_tag::undef) |
264 | ? memory::desc({src_iter_c_dims}, prec, p.fmts.src_iter_fmt) |
265 | : memory::desc(); |
266 | auto dst_layer_md_tgt |
267 | = memory::desc({dst_layer_dims}, prec, p.fmts.dst_layer_fmt); |
268 | auto dst_iter_md_tgt |
269 | = (p.fmts.dst_iter_fmt != memory::format_tag::undef) |
270 | ? memory::desc({dst_iter_dims}, prec, p.fmts.dst_iter_fmt) |
271 | : memory::desc(); |
272 | auto dst_iter_c_md_tgt |
273 | = (p.fmts.dst_iter_fmt != memory::format_tag::undef) |
274 | ? memory::desc({dst_iter_c_dims}, prec, p.fmts.dst_iter_fmt) |
275 | : memory::desc(); |
276 | |
277 | auto weights_projection_md_ldio = memory::desc( |
278 | {weights_projection_dims}, prec, memory::format_tag::ldio); |
279 | |
280 | // Create the reference primitive descriptor |
281 | auto ref_pd = get_pd(p.aprop, p.extra.activation, p.direction, |
282 | src_layer_md_any, src_iter_md_any, src_iter_c_md_any, |
283 | attention_md_any, weights_layer_md_any, weights_iter_md_any, |
284 | weights_peephole_md_any, weights_projection_md_any, bias_md_any, |
285 | dst_layer_md_any, dst_iter_md_any, dst_iter_c_md_any, |
286 | p.extra.alpha); |
287 | // test construction from a C pd |
288 | ref_pd = typename T::primitive_desc(ref_pd.get()); |
289 | testExecArgQueries(ref_pd); |
290 | test_primitive_param_queries(ref_pd); |
291 | |
292 | // Query the descriptor for memory descriptors |
293 | auto weights_layer_md_ref = ref_pd.weights_layer_desc(); |
294 | auto weights_iter_md_ref = ref_pd.weights_iter_desc(); |
295 | auto weights_peephole_md_ref = queryWeightsPeephole(ref_pd); |
296 | auto weights_projection_md_ref = queryWeightsProjection(ref_pd); |
297 | auto bias_md_ref = ref_pd.bias_desc(); |
298 | auto src_layer_md_ref = ref_pd.src_layer_desc(); |
299 | auto augru_attention_md_ref = ref_pd.augru_attention_desc(); |
300 | auto src_iter_md_ref = ref_pd.src_iter_desc(); |
301 | auto src_iter_c_md_ref = querySrcIterC(ref_pd); |
302 | auto dst_layer_md_ref = ref_pd.dst_layer_desc(); |
303 | auto dst_iter_md_ref = ref_pd.dst_iter_desc(); |
304 | auto dst_iter_c_md_ref = queryDstIterC(ref_pd); |
305 | |
306 | if (skipTest(weights_layer_md_ref == weights_layer_md_tgt, |
307 | weights_iter_md_ref == weights_iter_md_tgt, |
308 | bias_md_ref == bias_md_tgt, |
309 | src_layer_md_ref == src_layer_md_tgt, |
310 | augru_attention_md_ref == augru_attention_md_tgt, |
311 | src_iter_md_ref == src_iter_md_tgt, |
312 | src_iter_c_md_ref == src_iter_c_md_tgt, |
313 | dst_layer_md_ref == dst_layer_md_tgt, |
314 | dst_iter_md_ref == dst_iter_md_tgt, |
315 | dst_iter_c_md_ref == dst_iter_c_md_tgt)) |
316 | return; |
317 | |
318 | /* initialize data */ |
319 | auto weights_layer_ref = test::make_memory(weights_layer_md_ref, eng); |
320 | auto weights_iter_ref = test::make_memory(weights_iter_md_ref, eng); |
321 | auto weights_peephole_ref |
322 | = test::make_memory(weights_peephole_md_ref, eng); |
323 | auto weights_projection_ref |
324 | = test::make_memory(weights_projection_md_ref, eng); |
325 | auto bias_ref = test::make_memory(bias_md_ref, eng); |
326 | auto src_layer_ref = test::make_memory(src_layer_md_ref, eng); |
327 | auto augru_attention_ref |
328 | = test::make_memory(augru_attention_md_ref, eng); |
329 | auto src_iter_ref = test::make_memory(src_iter_md_ref, eng); |
330 | auto src_iter_c_ref = test::make_memory(src_iter_c_md_ref, eng); |
331 | auto dst_layer_ref = test::make_memory(dst_layer_md_ref, eng); |
332 | auto dst_iter_ref = test::make_memory(dst_iter_md_ref, eng); |
333 | auto dst_iter_c_ref = test::make_memory(dst_iter_c_md_ref, eng); |
334 | |
335 | auto weights_layer_tgt = test::make_memory(weights_layer_md_tgt, eng); |
336 | auto weights_iter_tgt = test::make_memory(weights_iter_md_tgt, eng); |
337 | auto weights_peephole_tgt |
338 | = test::make_memory(weights_peephole_md_tgt, eng); |
339 | auto weights_projection_tgt |
340 | = test::make_memory(weights_projection_md_tgt, eng); |
341 | auto bias_tgt = test::make_memory(bias_md_tgt, eng); |
342 | auto src_layer_tgt = test::make_memory(src_layer_md_tgt, eng); |
343 | auto augru_attention_tgt |
344 | = test::make_memory(augru_attention_md_tgt, eng); |
345 | auto src_iter_tgt = test::make_memory(src_iter_md_tgt, eng); |
346 | auto src_iter_c_tgt = test::make_memory(src_iter_c_md_tgt, eng); |
347 | auto dst_layer_tgt = test::make_memory(dst_layer_md_tgt, eng); |
348 | auto dst_iter_tgt = test::make_memory(dst_iter_md_tgt, eng); |
349 | auto dst_iter_c_tgt = test::make_memory(dst_iter_c_md_tgt, eng); |
350 | |
351 | auto weights_projection_ldio = memory(weights_projection_md_ldio, eng); |
352 | |
353 | // Assumption: b is a plain layout |
354 | auto init_tensor = [&](memory a, memory b, int scale = 1) { |
355 | auto desc = a.get_desc(); |
356 | auto b_dims = desc.get_dims(); |
357 | auto n_elems = std::accumulate(b_dims.begin(), b_dims.end(), |
358 | size_t(1), std::multiplies<dnnl_dim_t>()); |
359 | const dnnl::impl::memory_desc_wrapper mdw(desc.get()); |
360 | { |
361 | auto b_ptr = map_memory<float>(b); |
362 | for (size_t i = 0; i < n_elems; i++) |
363 | b_ptr[i] = scale * i; |
364 | } |
365 | reorder(b, a).execute(strm, b, a); |
366 | strm.wait(); |
367 | }; |
368 | auto init_zero_tensor = [&](const memory &a, memory::format_tag fmt) { |
369 | auto desc = a.get_desc(); |
370 | memory::desc tmp_md(desc.get_dims(), desc.get_data_type(), fmt); |
371 | auto tmp = test::make_memory(tmp_md, eng); |
372 | // Zero fill the tmp tensor |
373 | init_tensor(a, tmp, 0); |
374 | }; |
375 | auto init_id_wights_projection = [&](memory &w_plain, memory &w_rnn) { |
376 | auto w_plain_ptr = map_memory<float>(w_plain); |
377 | for_(memory::dim l = 0; l < dims.l; ++l) |
378 | for_(memory::dim d = 0; d < dims.d; ++d) |
379 | for_(memory::dim i = 0; i < dims.dhc; ++i) |
380 | for (memory::dim o = 0; o < dims.dic; ++o) { |
381 | auto off = (((l * dims.d) + d) * dims.dhc + i) * dims.dic + o; |
382 | w_plain_ptr[off] = (i == o) ? 1.f : 0.f; |
383 | } |
384 | |
385 | reorder(w_plain, w_rnn).execute(strm, w_plain, w_rnn); |
386 | strm.wait(); |
387 | }; |
388 | |
389 | init_tensor(weights_layer_ref, weights_layer_tgt); |
390 | init_tensor(weights_iter_ref, weights_iter_tgt); |
391 | if (is_lstm_peephole) |
392 | init_tensor(weights_peephole_ref, weights_peephole_tgt); |
393 | else if (std::is_same<T, lstm_forward>::value) |
394 | init_zero_tensor(weights_peephole_ref, memory::format_tag::ldgo); |
395 | if (is_lstm_projection) |
396 | init_tensor(weights_projection_ref, weights_projection_tgt); |
397 | else if (std::is_same<T, lstm_forward>::value) |
398 | init_id_wights_projection( |
399 | weights_projection_ldio, weights_projection_ref); |
400 | init_tensor(bias_ref, bias_tgt); |
401 | init_tensor(src_layer_ref, src_layer_tgt); |
402 | if (is_augru || is_lbr_augru) |
403 | init_tensor(augru_attention_ref, augru_attention_tgt); |
404 | if (p.fmts.src_iter_fmt != memory::format_tag::undef) { |
405 | init_tensor(src_iter_ref, src_iter_tgt); |
406 | if (std::is_same<T, lstm_forward>::value) |
407 | init_tensor(src_iter_c_ref, src_iter_c_tgt); |
408 | } else { |
409 | init_zero_tensor(src_iter_ref, memory::format_tag::ldnc); |
410 | if (std::is_same<T, lstm_forward>::value) |
411 | init_zero_tensor(src_iter_c_ref, memory::format_tag::ldnc); |
412 | } |
413 | |
414 | EXPECT_ANY_THROW(T(ref_pd, {})); |
415 | // run the non packed version |
416 | T(ref_pd).execute(strm, |
417 | {{DNNL_ARG_SRC_LAYER, src_layer_ref}, |
418 | {DNNL_ARG_AUGRU_ATTENTION, augru_attention_ref}, |
419 | {DNNL_ARG_SRC_ITER, src_iter_ref}, |
420 | {DNNL_ARG_SRC_ITER_C, src_iter_c_ref}, |
421 | {DNNL_ARG_WEIGHTS_LAYER, weights_layer_ref}, |
422 | {DNNL_ARG_WEIGHTS_ITER, weights_iter_ref}, |
423 | {DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_ref}, |
424 | {DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_ref}, |
425 | {DNNL_ARG_BIAS, bias_ref}, |
426 | {DNNL_ARG_DST_LAYER, dst_layer_ref}, |
427 | {DNNL_ARG_DST_ITER, dst_iter_ref}, |
428 | {DNNL_ARG_DST_ITER_C, dst_iter_c_ref}}); |
429 | strm.wait(); |
430 | |
431 | // run the packed version |
432 | auto tgt_pd = get_pd(p.aprop, p.extra.activation, p.direction, |
433 | src_layer_md_tgt, src_iter_md_tgt, src_iter_c_md_tgt, |
434 | attention_md_tgt, weights_layer_md_tgt, weights_iter_md_tgt, |
435 | weights_peephole_md_tgt, weights_projection_md_tgt, bias_md_tgt, |
436 | dst_layer_md_tgt, dst_iter_md_tgt, dst_iter_c_md_tgt, |
437 | p.extra.alpha); |
438 | testExecArgQueries(tgt_pd); |
439 | test_primitive_param_queries(tgt_pd); |
440 | |
441 | EXPECT_ANY_THROW(T(tgt_pd, {})); |
442 | T(tgt_pd).execute(strm, |
443 | {{DNNL_ARG_SRC_LAYER, src_layer_tgt}, |
444 | {DNNL_ARG_AUGRU_ATTENTION, augru_attention_tgt}, |
445 | {DNNL_ARG_SRC_ITER, src_iter_tgt}, |
446 | {DNNL_ARG_SRC_ITER_C, src_iter_c_tgt}, |
447 | {DNNL_ARG_WEIGHTS_LAYER, weights_layer_tgt}, |
448 | {DNNL_ARG_WEIGHTS_ITER, weights_iter_tgt}, |
449 | {DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_tgt}, |
450 | {DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_tgt}, |
451 | {DNNL_ARG_BIAS, bias_tgt}, |
452 | {DNNL_ARG_DST_LAYER, dst_layer_tgt}, |
453 | {DNNL_ARG_DST_ITER, dst_iter_tgt}, |
454 | {DNNL_ARG_DST_ITER_C, dst_iter_c_tgt}}); |
455 | strm.wait(); |
456 | |
457 | // compare dst_layer and dst_iter |
458 | compare_data<data_t>(dst_layer_ref, dst_layer_tgt, 1e-5); |
459 | if (p.fmts.dst_iter_fmt != memory::format_tag::undef) { |
460 | compare_data<data_t>(dst_iter_ref, dst_iter_tgt, 1e-5); |
461 | if (std::is_same<T, lstm_forward>::value) |
462 | compare_data<data_t>(dst_iter_c_ref, dst_iter_c_tgt, 1e-5); |
463 | } |
464 | } |
465 | }; |
466 | |
467 | /* RNN specializations */ |
468 | template <> |
469 | memory::dim rnn_forward_test_t<vanilla_rnn_forward, float>::getNGates() { |
470 | return 1; |
471 | } |
472 | |
473 | template <> |
474 | vanilla_rnn_forward::primitive_desc |
475 | rnn_forward_test_t<vanilla_rnn_forward, float>::get_pd(prop_kind aprop, |
476 | algorithm activation, rnn_direction direction, |
477 | const memory::desc &src_layer_md, const memory::desc &src_iter_md, |
478 | const memory::desc &src_iter_c_md, const memory::desc &attention_md, |
479 | const memory::desc &weights_layer_md, |
480 | const memory::desc &weights_iter_md, const memory::desc &, |
481 | const memory::desc &, const memory::desc &bias_md, |
482 | const memory::desc &dst_layer_md, const memory::desc &dst_iter_md, |
483 | const memory::desc &dst_iter_c_md, float alpha) { |
484 | return vanilla_rnn_forward::primitive_desc(get_test_engine(), aprop, |
485 | activation, direction, src_layer_md, src_iter_md, weights_layer_md, |
486 | weights_iter_md, bias_md, dst_layer_md, dst_iter_md, alpha); |
487 | } |
488 | |
489 | /* LSTM specializations */ |
490 | template <> |
491 | memory::dim rnn_forward_test_t<lstm_forward, float>::getNGates() { |
492 | return 4; |
493 | } |
494 | |
495 | template <> |
496 | lstm_forward::primitive_desc rnn_forward_test_t<lstm_forward, float>::get_pd( |
497 | prop_kind aprop, algorithm activation, rnn_direction direction, |
498 | const memory::desc &src_layer_md, const memory::desc &src_iter_md, |
499 | const memory::desc &src_iter_c_md, const memory::desc &attention_md, |
500 | const memory::desc &weights_layer_md, |
501 | const memory::desc &weights_iter_md, |
502 | const memory::desc &weights_peephole_md, |
503 | const memory::desc &weights_projection_md, const memory::desc &bias_md, |
504 | const memory::desc &dst_layer_md, const memory::desc &dst_iter_md, |
505 | const memory::desc &dst_iter_c_md, float alpha) { |
506 | return lstm_forward::primitive_desc(get_test_engine(), aprop, direction, |
507 | src_layer_md, src_iter_md, src_iter_c_md, weights_layer_md, |
508 | weights_iter_md, weights_peephole_md, weights_projection_md, |
509 | bias_md, dst_layer_md, dst_iter_md, dst_iter_c_md); |
510 | } |
511 | |
512 | template <> |
513 | bool rnn_forward_test_t<lstm_forward, float>::skipTest(bool src_layer_match, |
514 | bool augru_attention_match, bool src_iter_match, bool src_iter_c_match, |
515 | bool weights_layer_match, bool weights_iter_match, bool bias_match, |
516 | bool dst_layer_match, bool dst_iter_match, bool dst_iter_c_match) { |
517 | return src_layer_match && src_iter_match && src_iter_c_match |
518 | && weights_layer_match && weights_iter_match && bias_match |
519 | && dst_layer_match && dst_iter_match && dst_iter_c_match; |
520 | } |
521 | |
522 | template <> |
523 | bool rnn_forward_test_t<augru_forward, float>::skipTest(bool src_layer_match, |
524 | bool augru_attention_match, bool src_iter_match, bool src_iter_c_match, |
525 | bool weights_layer_match, bool weights_iter_match, bool bias_match, |
526 | bool dst_layer_match, bool dst_iter_match, bool dst_iter_c_match) { |
527 | return src_layer_match && augru_attention_match && src_iter_match |
528 | && src_iter_c_match && weights_layer_match && weights_iter_match |
529 | && bias_match && dst_layer_match && dst_iter_match |
530 | && dst_iter_c_match; |
531 | } |
532 | |
533 | template <> |
534 | bool rnn_forward_test_t<lbr_augru_forward, float>::skipTest( |
535 | bool src_layer_match, bool augru_attention_match, bool src_iter_match, |
536 | bool src_iter_c_match, bool weights_layer_match, |
537 | bool weights_iter_match, bool bias_match, bool dst_layer_match, |
538 | bool dst_iter_match, bool dst_iter_c_match) { |
539 | return src_layer_match && augru_attention_match && src_iter_match |
540 | && src_iter_c_match && weights_layer_match && weights_iter_match |
541 | && bias_match && dst_layer_match && dst_iter_match |
542 | && dst_iter_c_match; |
543 | } |
544 | |
545 | template <> |
546 | memory::desc rnn_forward_test_t<lstm_forward, float>::querySrcIterC( |
547 | const lstm_forward::primitive_desc &rpd) { |
548 | return rpd.src_iter_c_desc(); |
549 | } |
550 | |
551 | template <> |
552 | memory::desc rnn_forward_test_t<lstm_forward, float>::queryWeightsPeephole( |
553 | const lstm_forward::primitive_desc &rpd) { |
554 | return rpd.weights_peephole_desc(); |
555 | } |
556 | |
557 | template <> |
558 | memory::desc rnn_forward_test_t<lstm_forward, float>::queryWeightsProjection( |
559 | const lstm_forward::primitive_desc &rpd) { |
560 | return rpd.weights_projection_desc(); |
561 | } |
562 | |
563 | template <> |
564 | memory::desc rnn_forward_test_t<lstm_forward, float>::queryDstIterC( |
565 | const lstm_forward::primitive_desc &rpd) { |
566 | return rpd.dst_iter_c_desc(); |
567 | } |
568 | |
569 | /* GRU specializations */ |
570 | template <> |
571 | memory::dim rnn_forward_test_t<gru_forward, float>::getNGates() { |
572 | return 3; |
573 | } |
574 | |
575 | template <> |
576 | gru_forward::primitive_desc rnn_forward_test_t<gru_forward, float>::get_pd( |
577 | prop_kind aprop, algorithm activation, rnn_direction direction, |
578 | const memory::desc &src_layer_md, const memory::desc &src_iter_md, |
579 | const memory::desc &src_iter_c_md, const memory::desc &attention_md, |
580 | const memory::desc &weights_layer_md, |
581 | const memory::desc &weights_iter_md, const memory::desc &, |
582 | const memory::desc &, const memory::desc &bias_md, |
583 | const memory::desc &dst_layer_md, const memory::desc &dst_iter_md, |
584 | const memory::desc &dst_iter_c_md, float alpha) { |
585 | return gru_forward::primitive_desc(get_test_engine(), aprop, direction, |
586 | src_layer_md, src_iter_md, weights_layer_md, weights_iter_md, |
587 | bias_md, dst_layer_md, dst_iter_md); |
588 | } |
589 | |
590 | /* LBR GRU specializations */ |
591 | template <> |
592 | memory::dim rnn_forward_test_t<lbr_gru_forward, float>::getNGates() { |
593 | return 3; |
594 | } |
595 | |
596 | template <> |
597 | lbr_gru_forward::primitive_desc |
598 | rnn_forward_test_t<lbr_gru_forward, float>::get_pd(prop_kind aprop, |
599 | algorithm activation, rnn_direction direction, |
600 | const memory::desc &src_layer_md, const memory::desc &src_iter_md, |
601 | const memory::desc &src_iter_c_md, const memory::desc &attention_md, |
602 | const memory::desc &weights_layer_md, |
603 | const memory::desc &weights_iter_md, const memory::desc &, |
604 | const memory::desc &, const memory::desc &bias_md, |
605 | const memory::desc &dst_layer_md, const memory::desc &dst_iter_md, |
606 | const memory::desc &dst_iter_c_md, float alpha) { |
607 | return lbr_gru_forward::primitive_desc(get_test_engine(), aprop, direction, |
608 | src_layer_md, src_iter_md, weights_layer_md, weights_iter_md, |
609 | bias_md, dst_layer_md, dst_iter_md); |
610 | } |
611 | |
612 | /* AUGRU specializations */ |
613 | template <> |
614 | memory::dim rnn_forward_test_t<augru_forward, float>::getNGates() { |
615 | return 3; |
616 | } |
617 | |
618 | template <> |
619 | augru_forward::primitive_desc rnn_forward_test_t<augru_forward, float>::get_pd( |
620 | prop_kind aprop, algorithm activation, rnn_direction direction, |
621 | const memory::desc &src_layer_md, const memory::desc &src_iter_md, |
622 | const memory::desc &src_iter_c_md, const memory::desc &attention_md, |
623 | const memory::desc &weights_layer_md, |
624 | const memory::desc &weights_iter_md, const memory::desc &, |
625 | const memory::desc &, const memory::desc &bias_md, |
626 | const memory::desc &dst_layer_md, const memory::desc &dst_iter_md, |
627 | const memory::desc &dst_iter_c_md, float alpha) { |
628 | return augru_forward::primitive_desc(get_test_engine(), aprop, direction, |
629 | src_layer_md, src_iter_md, attention_md, weights_layer_md, |
630 | weights_iter_md, bias_md, dst_layer_md, dst_iter_md); |
631 | } |
632 | |
633 | /* LBR AUGRU specializations */ |
634 | template <> |
635 | memory::dim rnn_forward_test_t<lbr_augru_forward, float>::getNGates() { |
636 | return 3; |
637 | } |
638 | |
639 | template <> |
640 | lbr_augru_forward::primitive_desc |
641 | rnn_forward_test_t<lbr_augru_forward, float>::get_pd(prop_kind aprop, |
642 | algorithm activation, rnn_direction direction, |
643 | const memory::desc &src_layer_md, const memory::desc &src_iter_md, |
644 | const memory::desc &src_iter_c_md, const memory::desc &attention_md, |
645 | const memory::desc &weights_layer_md, |
646 | const memory::desc &weights_iter_md, const memory::desc &, |
647 | const memory::desc &, const memory::desc &bias_md, |
648 | const memory::desc &dst_layer_md, const memory::desc &dst_iter_md, |
649 | const memory::desc &dst_iter_c_md, float alpha) { |
650 | return lbr_augru_forward::primitive_desc(get_test_engine(), aprop, |
651 | direction, src_layer_md, src_iter_md, attention_md, |
652 | weights_layer_md, weights_iter_md, bias_md, dst_layer_md, |
653 | dst_iter_md); |
654 | } |
655 | |
656 | using eng = engine::kind; |
657 | using fmt = memory::format_tag; |
658 | using alg = algorithm; |
659 | using dir = rnn_direction; |
660 | using rnn_forward_test_f32 = rnn_forward_test_t<vanilla_rnn_forward, float>; |
661 | using lstm_forward_test_f32 = rnn_forward_test_t<lstm_forward, float>; |
662 | using gru_forward_test_f32 = rnn_forward_test_t<gru_forward, float>; |
663 | using lbr_gru_forward_test_f32 = rnn_forward_test_t<lbr_gru_forward, float>; |
664 | using augru_forward_test_f32 = rnn_forward_test_t<augru_forward, float>; |
665 | using lbr_augru_forward_test_f32 = rnn_forward_test_t<lbr_augru_forward, float>; |
666 | |
667 | using cfg_f32 = test_rnn_params_t; |
668 | |
// Initializers for test_rnn_extra_t: PLAIN_RNN(a) selects activation `a`
// with alpha 0 (vanilla RNN); NOT_RNN is the placeholder for other cells.
#define PLAIN_RNN(a) {(a), 0.0f}
#define NOT_RNN {alg::undef, 0.0f}
673 | |
674 | TEST_P(rnn_forward_test_f32, TestsRnn) {} |
675 | CPU_INSTANTIATE_TEST_SUITE_P(TestRnn, rnn_forward_test_f32, |
676 | ::testing::Values( |
677 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
678 | prop_kind::forward_inference, |
679 | dir::unidirectional_left2right, |
680 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
681 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
682 | fmt::ldnc}, |
683 | test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}}, |
684 | /* Check for invalid parameters: unsupported unrolling */ |
685 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
686 | prop_kind::forward_inference, |
687 | dir::unidirectional_left2right, |
688 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
689 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
690 | fmt::ldnc}, |
691 | test_rnn_sizes_t {2, 1, 10, 16, 200, 100, 100, 100}, |
692 | true, dnnl_invalid_arguments}, |
693 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
694 | prop_kind::forward_inference, |
695 | dir::unidirectional_left2right, |
696 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
697 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
698 | fmt::ldnc}, |
699 | test_rnn_sizes_t {2, 1, 10, 16, 100, 200, 100, 100}, |
700 | true, dnnl_invalid_arguments}, |
701 | /* Check for invalid parameters: inconsistent dimensions */ |
702 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
703 | prop_kind::forward_inference, |
704 | dir::unidirectional_left2right, |
705 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
706 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
707 | fmt::ldnc}, |
708 | test_rnn_sizes_t {2, 1, 10, 16, 100, 100, 50, 100}, |
709 | true, dnnl_invalid_arguments}, |
710 | /* Check if passing {src,dst}_iter impacts results */ |
711 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
712 | |
713 | prop_kind::forward_inference, |
714 | dir::unidirectional_left2right, |
715 | {fmt::tnc, fmt::undef, fmt::ldigo, fmt::ldigo, |
716 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
717 | fmt::ldnc}, |
718 | test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}, |
719 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
720 | prop_kind::forward_inference, |
721 | dir::unidirectional_left2right, |
722 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
723 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
724 | fmt::undef}, |
725 | test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}, |
726 | cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), |
727 | prop_kind::forward_inference, |
728 | dir::unidirectional_left2right, |
729 | {fmt::tnc, fmt::undef, fmt::ldigo, fmt::ldigo, |
730 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
731 | fmt::undef}, |
732 | test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}})); |
733 | |
734 | TEST_P(lstm_forward_test_f32, TestsLSTM) {} |
735 | CPU_INSTANTIATE_TEST_SUITE_P(TestLSTM, lstm_forward_test_f32, |
736 | ::testing::Values( |
737 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
738 | dir::unidirectional_left2right, |
739 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
740 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
741 | fmt::ldnc}, |
742 | test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}}, |
743 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
744 | dir::unidirectional_left2right, |
745 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, |
746 | fmt::undef, fmt::ldgo, fmt::tnc, fmt::ldnc}, |
747 | test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}}, |
748 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
749 | dir::unidirectional_left2right, |
750 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, |
751 | fmt::ldio, fmt::ldgo, fmt::tnc, fmt::ldnc}, |
752 | test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}}, |
753 | /* Non uniform sizes tests */ |
754 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
755 | dir::unidirectional_left2right, |
756 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
757 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
758 | fmt::ldnc}, |
759 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}}, |
760 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
761 | dir::unidirectional_left2right, |
762 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, |
763 | fmt::undef, fmt::ldgo, fmt::tnc, fmt::ldnc}, |
764 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}}, |
765 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
766 | dir::unidirectional_left2right, |
767 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, |
768 | fmt::ldio, fmt::ldgo, fmt::tnc, fmt::ldnc}, |
769 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 15}}, |
770 | /* Check if not passing dst_iter impacts results */ |
771 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
772 | dir::unidirectional_left2right, |
773 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
774 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
775 | fmt::undef}, |
776 | test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}})); |
777 | |
778 | TEST_P(gru_forward_test_f32, TestsGRU) {} |
779 | CPU_INSTANTIATE_TEST_SUITE_P(TestGRU, gru_forward_test_f32, |
780 | ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
781 | dir::unidirectional_left2right, |
782 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
783 | fmt::undef, fmt::undef, fmt::ldgo, |
784 | fmt::tnc, fmt::ldnc}, |
785 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}}, |
786 | /* Check if not passing dst_iter impacts results */ |
787 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
788 | dir::unidirectional_left2right, |
789 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
790 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
791 | fmt::undef}, |
792 | test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}})); |
793 | |
794 | TEST_P(lbr_gru_forward_test_f32, TestsGRUlbr) {} |
795 | CPU_INSTANTIATE_TEST_SUITE_P(TestGRUlbr, lbr_gru_forward_test_f32, |
796 | ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
797 | dir::unidirectional_left2right, |
798 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
799 | fmt::undef, fmt::undef, fmt::ldgo, |
800 | fmt::tnc, fmt::ldnc}, |
801 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}}, |
802 | /* Check if not passing dst_iter impacts results */ |
803 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
804 | dir::unidirectional_left2right, |
805 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
806 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
807 | fmt::undef}, |
808 | test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}})); |
809 | |
810 | TEST_P(augru_forward_test_f32, TestsAUGRU) {} |
811 | CPU_INSTANTIATE_TEST_SUITE_P(TestAUGRU, augru_forward_test_f32, |
812 | ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
813 | dir::unidirectional_left2right, |
814 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
815 | fmt::undef, fmt::undef, fmt::ldgo, |
816 | fmt::tnc, fmt::ldnc}, |
817 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}}, |
818 | /* Check if not passing dst_iter impacts results */ |
819 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
820 | dir::unidirectional_left2right, |
821 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
822 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
823 | fmt::undef}, |
824 | test_rnn_sizes_t {1, 1, 5, 1, 4, 4, 4, 4}})); |
825 | |
826 | TEST_P(lbr_augru_forward_test_f32, TestsAUGRUlbr) {} |
827 | CPU_INSTANTIATE_TEST_SUITE_P(TestAUGRUlbr, lbr_augru_forward_test_f32, |
828 | ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
829 | dir::unidirectional_left2right, |
830 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
831 | fmt::undef, fmt::undef, fmt::ldgo, |
832 | fmt::tnc, fmt::ldnc}, |
833 | test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}}, |
834 | /* Check if not passing dst_iter impacts results */ |
835 | cfg_f32 {NOT_RNN, prop_kind::forward_inference, |
836 | dir::unidirectional_left2right, |
837 | {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, |
838 | fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, |
839 | fmt::undef}, |
840 | test_rnn_sizes_t {1, 1, 5, 1, 4, 4, 4, 4}})); |
841 | |
842 | } // namespace dnnl |
843 | |