1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <numeric>
18#include <utility>
19#include <type_traits>
20
21#include "dnnl_test_common.hpp"
22#include "gtest/gtest.h"
23
24#include "oneapi/dnnl/dnnl.hpp"
25
26namespace dnnl {
27
28struct test_rnn_sizes_t {
29 memory::dim l, d, t, mb;
30 memory::dim slc, sic, dhc, dic;
31};
32
33struct test_rnn_formats_t {
34 dnnl::memory::format_tag src_layer_fmt;
35 dnnl::memory::format_tag src_iter_fmt;
36 dnnl::memory::format_tag weights_layer_fmt;
37 dnnl::memory::format_tag weights_iter_fmt;
38 dnnl::memory::format_tag weights_peephole_fmt;
39 dnnl::memory::format_tag weights_projection_fmt;
40 dnnl::memory::format_tag bias_fmt;
41 dnnl::memory::format_tag dst_layer_fmt;
42 dnnl::memory::format_tag dst_iter_fmt;
43};
44
45struct test_rnn_extra_t {
46 dnnl::algorithm activation;
47 float alpha;
48};
49
50struct test_rnn_params_t {
51 test_rnn_extra_t extra;
52 prop_kind aprop;
53 dnnl::rnn_direction direction;
54 test_rnn_formats_t fmts;
55 test_rnn_sizes_t sizes;
56 bool expect_to_fail;
57 dnnl_status_t expected_status;
58};
59
60// We assume uniform data type accross tensors for now
61template <typename T, typename data_t>
62class rnn_forward_test_t : public ::testing::TestWithParam<test_rnn_params_t> {
63
64private:
65 memory::dim getNGates();
66
67 typename T::primitive_desc get_pd(prop_kind aprop, algorithm activation,
68 rnn_direction direction, const memory::desc &src_layer_md,
69 const memory::desc &src_iter_md, const memory::desc &src_iter_c_md,
70 const memory::desc &attention_md,
71 const memory::desc &weights_layer_md,
72 const memory::desc &weights_iter_md,
73 const memory::desc &weights_peephole_md,
74 const memory::desc &weights_projection_md,
75 const memory::desc &bias_md, const memory::desc &dst_layer_md,
76 const memory::desc &dst_iter_md, const memory::desc &dst_iter_c_md,
77 float alpha = 0.0f);
78
79 bool skipTest(bool src_layer_match, bool augru_attention_match,
80 bool src_iter_match, bool src_iter_c_match,
81 bool weights_layer_match, bool weights_iter_match, bool bias_match,
82 bool dst_layer_match, bool dst_iter_match, bool dst_iter_c_match) {
83 // By default, we ignore src_iter_c and dst_iter_c as they are
84 // only supported for lstm. For LSTM tests, this function
85 // should be specialized to handle them.
86 return src_layer_match && src_iter_match && weights_layer_match
87 && weights_iter_match && bias_match && dst_layer_match
88 && dst_iter_match;
89 }
90
91 memory::desc querySrcIterC(const typename T::primitive_desc &rpd) {
92 return memory::desc();
93 }
94
95 memory::desc queryWeightsPeephole(const typename T::primitive_desc &rpd) {
96 return memory::desc();
97 }
98
99 memory::desc queryWeightsProjection(const typename T::primitive_desc &rpd) {
100 return memory::desc();
101 }
102
103 memory::desc queryDstIterC(const typename T::primitive_desc &rpd) {
104 return memory::desc();
105 }
106
107 void testExecArgQueries(typename T::primitive_desc pd) {
108 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER)
109 == pd.weights_layer_desc());
110 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER)
111 == pd.weights_iter_desc());
112 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE)
113 == pd.weights_peephole_desc());
114 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION)
115 == pd.weights_projection_desc());
116 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
117 == pd.bias_desc());
118 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER)
119 == pd.src_layer_desc());
120 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION)
121 == pd.augru_attention_desc());
122 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER)
123 == pd.src_iter_desc());
124 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C)
125 == querySrcIterC(pd));
126 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER)
127 == pd.dst_layer_desc());
128 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST_ITER)
129 == pd.dst_iter_desc());
130 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C)
131 == queryDstIterC(pd));
132 };
133
134 void test_primitive_param_queries(typename T::primitive_desc pd) {
135 auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
136
137 dnnl::algorithm expected_cell_kind = algorithm::undef;
138 if (is_vanilla_rnn) expected_cell_kind = algorithm::vanilla_rnn;
139 if (is_lstm) expected_cell_kind = algorithm::vanilla_lstm;
140 if (is_gru) expected_cell_kind = algorithm::vanilla_gru;
141 if (is_lbr_gru) expected_cell_kind = algorithm::lbr_gru;
142 if (is_augru) expected_cell_kind = algorithm::vanilla_augru;
143 if (is_lbr_augru) expected_cell_kind = algorithm::lbr_augru;
144
145 ASSERT_NE(expected_cell_kind, algorithm::undef);
146 ASSERT_EQ(pd.get_cell_kind(), expected_cell_kind);
147 ASSERT_EQ(pd.get_prop_kind(), p.aprop);
148 ASSERT_EQ(pd.get_direction(), p.direction);
149
150 if (is_vanilla_rnn) {
151 ASSERT_EQ(pd.get_alpha(), p.extra.alpha);
152 ASSERT_EQ(pd.get_activation_kind(), p.extra.activation);
153 } else {
154 ASSERT_EQ(pd.get_alpha(), 0.0f);
155 ASSERT_EQ(pd.get_beta(), 0.0f);
156 ASSERT_EQ(pd.get_activation_kind(), algorithm::undef);
157 }
158 }
159
160protected:
161 static constexpr bool is_lstm = std::is_same<T, lstm_forward>::value;
162 static constexpr bool is_vanilla_rnn
163 = std::is_same<T, vanilla_rnn_forward>::value;
164 static constexpr bool is_gru = std::is_same<T, gru_forward>::value;
165 static constexpr bool is_lbr_gru = std::is_same<T, lbr_gru_forward>::value;
166 static constexpr bool is_augru = std::is_same<T, augru_forward>::value;
167 static constexpr bool is_lbr_augru
168 = std::is_same<T, lbr_augru_forward>::value;
169
170 void SetUp() override {
171 auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
172 catch_expected_failures(
173 [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
174 }
175
176 void Test() {
177 auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
178 const bool is_lstm_peephole
179 = p.fmts.weights_peephole_fmt != memory::format_tag::undef;
180 const bool is_lstm_projection
181 = p.fmts.weights_projection_fmt != memory::format_tag::undef;
182 auto eng = get_test_engine();
183 auto strm = make_stream(eng);
184 //@todo check algorithm is one of the supported by RNN
185 //ASSERT_EQ(p.aalgorithm, algorithm::vanilla_lstm);
186
187 // Initialize the data
188 memory::data_type prec = data_traits<data_t>::data_type;
189 auto dims = p.sizes;
190 auto t = dims.t, mb = dims.mb, l = dims.l, d = dims.d;
191 auto slc = dims.slc, sic = dims.sic, dhc = dims.dhc, dic = dims.dic;
192 auto dlc = (p.direction == rnn_direction::bidirectional_concat ? 2 : 1)
193 * dic;
194 memory::dim g = getNGates();
195 memory::dim bias_extra_gate = std::is_same<T, lbr_gru_forward>::value
196 || std::is_same<T, lbr_augru_forward>::value
197 ? 1
198 : 0;
199
200 auto attention_dims = {t, mb, (memory::dim)1};
201 auto weights_layer_dims = {l, d, slc, g, dhc};
202 auto weights_iter_dims = {l, d, sic, g, dhc};
203 auto weights_peephole_dims = {l, d, (memory::dim)3, dhc};
204 auto weights_projection_dims = {l, d, dhc, dic};
205 auto bias_dims = {l, d, g + bias_extra_gate, dhc};
206 auto src_layer_dims = {t, mb, slc};
207 auto src_iter_dims = {l, d, mb, sic};
208 auto src_iter_c_dims = {l, d, mb, dhc};
209 auto dst_layer_dims = {t, mb, dlc};
210 auto dst_iter_dims = {l, d, mb, dic};
211 auto dst_iter_c_dims = {l, d, mb, dhc};
212
213 auto attention_md_any
214 = memory::desc({attention_dims}, prec, memory::format_tag::any);
215 auto weights_layer_md_any = memory::desc(
216 {weights_layer_dims}, prec, memory::format_tag::any);
217 auto weights_iter_md_any = memory::desc(
218 {weights_iter_dims}, prec, memory::format_tag::any);
219 auto weights_peephole_md_any = memory::desc(
220 {weights_peephole_dims}, prec, memory::format_tag::any);
221 auto weights_projection_md_any = memory::desc(
222 {weights_projection_dims}, prec, memory::format_tag::any);
223 auto bias_md_any
224 = memory::desc({bias_dims}, prec, memory::format_tag::any);
225 auto src_layer_md_any
226 = memory::desc({src_layer_dims}, prec, memory::format_tag::any);
227 auto src_iter_md_any
228 = memory::desc({src_iter_dims}, prec, memory::format_tag::any);
229 auto src_iter_c_md_any = memory::desc(
230 {src_iter_c_dims}, prec, memory::format_tag::any);
231 auto dst_layer_md_any
232 = memory::desc({dst_layer_dims}, prec, memory::format_tag::any);
233 auto dst_iter_md_any
234 = memory::desc({dst_iter_dims}, prec, memory::format_tag::any);
235 auto dst_iter_c_md_any = memory::desc(
236 {dst_iter_c_dims}, prec, memory::format_tag::any);
237
238 auto attention_md_tgt = (is_augru || is_lbr_augru)
239 ? memory::desc({attention_dims}, prec, memory::format_tag::tnc)
240 : memory::desc();
241 auto weights_layer_md_tgt = memory::desc(
242 {weights_layer_dims}, prec, p.fmts.weights_layer_fmt);
243 auto weights_iter_md_tgt = memory::desc(
244 {weights_iter_dims}, prec, p.fmts.weights_iter_fmt);
245 auto weights_peephole_md_tgt = is_lstm_peephole
246 ? memory::desc({weights_peephole_dims}, prec,
247 p.fmts.weights_peephole_fmt)
248 : memory::desc();
249 auto weights_projection_md_tgt = is_lstm_projection
250 ? memory::desc({weights_projection_dims}, prec,
251 p.fmts.weights_projection_fmt)
252 : memory::desc();
253 auto bias_md_tgt = memory::desc({bias_dims}, prec, p.fmts.bias_fmt);
254 auto src_layer_md_tgt
255 = memory::desc({src_layer_dims}, prec, p.fmts.src_layer_fmt);
256 auto augru_attention_md_tgt
257 = memory::desc({attention_dims}, prec, memory::format_tag::tnc);
258 auto src_iter_md_tgt
259 = (p.fmts.src_iter_fmt != memory::format_tag::undef)
260 ? memory::desc({src_iter_dims}, prec, p.fmts.src_iter_fmt)
261 : memory::desc();
262 auto src_iter_c_md_tgt
263 = (p.fmts.src_iter_fmt != memory::format_tag::undef)
264 ? memory::desc({src_iter_c_dims}, prec, p.fmts.src_iter_fmt)
265 : memory::desc();
266 auto dst_layer_md_tgt
267 = memory::desc({dst_layer_dims}, prec, p.fmts.dst_layer_fmt);
268 auto dst_iter_md_tgt
269 = (p.fmts.dst_iter_fmt != memory::format_tag::undef)
270 ? memory::desc({dst_iter_dims}, prec, p.fmts.dst_iter_fmt)
271 : memory::desc();
272 auto dst_iter_c_md_tgt
273 = (p.fmts.dst_iter_fmt != memory::format_tag::undef)
274 ? memory::desc({dst_iter_c_dims}, prec, p.fmts.dst_iter_fmt)
275 : memory::desc();
276
277 auto weights_projection_md_ldio = memory::desc(
278 {weights_projection_dims}, prec, memory::format_tag::ldio);
279
280 // Create the reference primitive descriptor
281 auto ref_pd = get_pd(p.aprop, p.extra.activation, p.direction,
282 src_layer_md_any, src_iter_md_any, src_iter_c_md_any,
283 attention_md_any, weights_layer_md_any, weights_iter_md_any,
284 weights_peephole_md_any, weights_projection_md_any, bias_md_any,
285 dst_layer_md_any, dst_iter_md_any, dst_iter_c_md_any,
286 p.extra.alpha);
287 // test construction from a C pd
288 ref_pd = typename T::primitive_desc(ref_pd.get());
289 testExecArgQueries(ref_pd);
290 test_primitive_param_queries(ref_pd);
291
292 // Query the descriptor for memory descriptors
293 auto weights_layer_md_ref = ref_pd.weights_layer_desc();
294 auto weights_iter_md_ref = ref_pd.weights_iter_desc();
295 auto weights_peephole_md_ref = queryWeightsPeephole(ref_pd);
296 auto weights_projection_md_ref = queryWeightsProjection(ref_pd);
297 auto bias_md_ref = ref_pd.bias_desc();
298 auto src_layer_md_ref = ref_pd.src_layer_desc();
299 auto augru_attention_md_ref = ref_pd.augru_attention_desc();
300 auto src_iter_md_ref = ref_pd.src_iter_desc();
301 auto src_iter_c_md_ref = querySrcIterC(ref_pd);
302 auto dst_layer_md_ref = ref_pd.dst_layer_desc();
303 auto dst_iter_md_ref = ref_pd.dst_iter_desc();
304 auto dst_iter_c_md_ref = queryDstIterC(ref_pd);
305
306 if (skipTest(weights_layer_md_ref == weights_layer_md_tgt,
307 weights_iter_md_ref == weights_iter_md_tgt,
308 bias_md_ref == bias_md_tgt,
309 src_layer_md_ref == src_layer_md_tgt,
310 augru_attention_md_ref == augru_attention_md_tgt,
311 src_iter_md_ref == src_iter_md_tgt,
312 src_iter_c_md_ref == src_iter_c_md_tgt,
313 dst_layer_md_ref == dst_layer_md_tgt,
314 dst_iter_md_ref == dst_iter_md_tgt,
315 dst_iter_c_md_ref == dst_iter_c_md_tgt))
316 return;
317
318 /* initialize data */
319 auto weights_layer_ref = test::make_memory(weights_layer_md_ref, eng);
320 auto weights_iter_ref = test::make_memory(weights_iter_md_ref, eng);
321 auto weights_peephole_ref
322 = test::make_memory(weights_peephole_md_ref, eng);
323 auto weights_projection_ref
324 = test::make_memory(weights_projection_md_ref, eng);
325 auto bias_ref = test::make_memory(bias_md_ref, eng);
326 auto src_layer_ref = test::make_memory(src_layer_md_ref, eng);
327 auto augru_attention_ref
328 = test::make_memory(augru_attention_md_ref, eng);
329 auto src_iter_ref = test::make_memory(src_iter_md_ref, eng);
330 auto src_iter_c_ref = test::make_memory(src_iter_c_md_ref, eng);
331 auto dst_layer_ref = test::make_memory(dst_layer_md_ref, eng);
332 auto dst_iter_ref = test::make_memory(dst_iter_md_ref, eng);
333 auto dst_iter_c_ref = test::make_memory(dst_iter_c_md_ref, eng);
334
335 auto weights_layer_tgt = test::make_memory(weights_layer_md_tgt, eng);
336 auto weights_iter_tgt = test::make_memory(weights_iter_md_tgt, eng);
337 auto weights_peephole_tgt
338 = test::make_memory(weights_peephole_md_tgt, eng);
339 auto weights_projection_tgt
340 = test::make_memory(weights_projection_md_tgt, eng);
341 auto bias_tgt = test::make_memory(bias_md_tgt, eng);
342 auto src_layer_tgt = test::make_memory(src_layer_md_tgt, eng);
343 auto augru_attention_tgt
344 = test::make_memory(augru_attention_md_tgt, eng);
345 auto src_iter_tgt = test::make_memory(src_iter_md_tgt, eng);
346 auto src_iter_c_tgt = test::make_memory(src_iter_c_md_tgt, eng);
347 auto dst_layer_tgt = test::make_memory(dst_layer_md_tgt, eng);
348 auto dst_iter_tgt = test::make_memory(dst_iter_md_tgt, eng);
349 auto dst_iter_c_tgt = test::make_memory(dst_iter_c_md_tgt, eng);
350
351 auto weights_projection_ldio = memory(weights_projection_md_ldio, eng);
352
353 // Assumption: b is a plain layout
354 auto init_tensor = [&](memory a, memory b, int scale = 1) {
355 auto desc = a.get_desc();
356 auto b_dims = desc.get_dims();
357 auto n_elems = std::accumulate(b_dims.begin(), b_dims.end(),
358 size_t(1), std::multiplies<dnnl_dim_t>());
359 const dnnl::impl::memory_desc_wrapper mdw(desc.get());
360 {
361 auto b_ptr = map_memory<float>(b);
362 for (size_t i = 0; i < n_elems; i++)
363 b_ptr[i] = scale * i;
364 }
365 reorder(b, a).execute(strm, b, a);
366 strm.wait();
367 };
368 auto init_zero_tensor = [&](const memory &a, memory::format_tag fmt) {
369 auto desc = a.get_desc();
370 memory::desc tmp_md(desc.get_dims(), desc.get_data_type(), fmt);
371 auto tmp = test::make_memory(tmp_md, eng);
372 // Zero fill the tmp tensor
373 init_tensor(a, tmp, 0);
374 };
375 auto init_id_wights_projection = [&](memory &w_plain, memory &w_rnn) {
376 auto w_plain_ptr = map_memory<float>(w_plain);
377 for_(memory::dim l = 0; l < dims.l; ++l)
378 for_(memory::dim d = 0; d < dims.d; ++d)
379 for_(memory::dim i = 0; i < dims.dhc; ++i)
380 for (memory::dim o = 0; o < dims.dic; ++o) {
381 auto off = (((l * dims.d) + d) * dims.dhc + i) * dims.dic + o;
382 w_plain_ptr[off] = (i == o) ? 1.f : 0.f;
383 }
384
385 reorder(w_plain, w_rnn).execute(strm, w_plain, w_rnn);
386 strm.wait();
387 };
388
389 init_tensor(weights_layer_ref, weights_layer_tgt);
390 init_tensor(weights_iter_ref, weights_iter_tgt);
391 if (is_lstm_peephole)
392 init_tensor(weights_peephole_ref, weights_peephole_tgt);
393 else if (std::is_same<T, lstm_forward>::value)
394 init_zero_tensor(weights_peephole_ref, memory::format_tag::ldgo);
395 if (is_lstm_projection)
396 init_tensor(weights_projection_ref, weights_projection_tgt);
397 else if (std::is_same<T, lstm_forward>::value)
398 init_id_wights_projection(
399 weights_projection_ldio, weights_projection_ref);
400 init_tensor(bias_ref, bias_tgt);
401 init_tensor(src_layer_ref, src_layer_tgt);
402 if (is_augru || is_lbr_augru)
403 init_tensor(augru_attention_ref, augru_attention_tgt);
404 if (p.fmts.src_iter_fmt != memory::format_tag::undef) {
405 init_tensor(src_iter_ref, src_iter_tgt);
406 if (std::is_same<T, lstm_forward>::value)
407 init_tensor(src_iter_c_ref, src_iter_c_tgt);
408 } else {
409 init_zero_tensor(src_iter_ref, memory::format_tag::ldnc);
410 if (std::is_same<T, lstm_forward>::value)
411 init_zero_tensor(src_iter_c_ref, memory::format_tag::ldnc);
412 }
413
414 EXPECT_ANY_THROW(T(ref_pd, {}));
415 // run the non packed version
416 T(ref_pd).execute(strm,
417 {{DNNL_ARG_SRC_LAYER, src_layer_ref},
418 {DNNL_ARG_AUGRU_ATTENTION, augru_attention_ref},
419 {DNNL_ARG_SRC_ITER, src_iter_ref},
420 {DNNL_ARG_SRC_ITER_C, src_iter_c_ref},
421 {DNNL_ARG_WEIGHTS_LAYER, weights_layer_ref},
422 {DNNL_ARG_WEIGHTS_ITER, weights_iter_ref},
423 {DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_ref},
424 {DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_ref},
425 {DNNL_ARG_BIAS, bias_ref},
426 {DNNL_ARG_DST_LAYER, dst_layer_ref},
427 {DNNL_ARG_DST_ITER, dst_iter_ref},
428 {DNNL_ARG_DST_ITER_C, dst_iter_c_ref}});
429 strm.wait();
430
431 // run the packed version
432 auto tgt_pd = get_pd(p.aprop, p.extra.activation, p.direction,
433 src_layer_md_tgt, src_iter_md_tgt, src_iter_c_md_tgt,
434 attention_md_tgt, weights_layer_md_tgt, weights_iter_md_tgt,
435 weights_peephole_md_tgt, weights_projection_md_tgt, bias_md_tgt,
436 dst_layer_md_tgt, dst_iter_md_tgt, dst_iter_c_md_tgt,
437 p.extra.alpha);
438 testExecArgQueries(tgt_pd);
439 test_primitive_param_queries(tgt_pd);
440
441 EXPECT_ANY_THROW(T(tgt_pd, {}));
442 T(tgt_pd).execute(strm,
443 {{DNNL_ARG_SRC_LAYER, src_layer_tgt},
444 {DNNL_ARG_AUGRU_ATTENTION, augru_attention_tgt},
445 {DNNL_ARG_SRC_ITER, src_iter_tgt},
446 {DNNL_ARG_SRC_ITER_C, src_iter_c_tgt},
447 {DNNL_ARG_WEIGHTS_LAYER, weights_layer_tgt},
448 {DNNL_ARG_WEIGHTS_ITER, weights_iter_tgt},
449 {DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_tgt},
450 {DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_tgt},
451 {DNNL_ARG_BIAS, bias_tgt},
452 {DNNL_ARG_DST_LAYER, dst_layer_tgt},
453 {DNNL_ARG_DST_ITER, dst_iter_tgt},
454 {DNNL_ARG_DST_ITER_C, dst_iter_c_tgt}});
455 strm.wait();
456
457 // compare dst_layer and dst_iter
458 compare_data<data_t>(dst_layer_ref, dst_layer_tgt, 1e-5);
459 if (p.fmts.dst_iter_fmt != memory::format_tag::undef) {
460 compare_data<data_t>(dst_iter_ref, dst_iter_tgt, 1e-5);
461 if (std::is_same<T, lstm_forward>::value)
462 compare_data<data_t>(dst_iter_c_ref, dst_iter_c_tgt, 1e-5);
463 }
464 }
465};
466
467/* RNN specializations */
468template <>
469memory::dim rnn_forward_test_t<vanilla_rnn_forward, float>::getNGates() {
470 return 1;
471}
472
473template <>
474vanilla_rnn_forward::primitive_desc
475rnn_forward_test_t<vanilla_rnn_forward, float>::get_pd(prop_kind aprop,
476 algorithm activation, rnn_direction direction,
477 const memory::desc &src_layer_md, const memory::desc &src_iter_md,
478 const memory::desc &src_iter_c_md, const memory::desc &attention_md,
479 const memory::desc &weights_layer_md,
480 const memory::desc &weights_iter_md, const memory::desc &,
481 const memory::desc &, const memory::desc &bias_md,
482 const memory::desc &dst_layer_md, const memory::desc &dst_iter_md,
483 const memory::desc &dst_iter_c_md, float alpha) {
484 return vanilla_rnn_forward::primitive_desc(get_test_engine(), aprop,
485 activation, direction, src_layer_md, src_iter_md, weights_layer_md,
486 weights_iter_md, bias_md, dst_layer_md, dst_iter_md, alpha);
487}
488
489/* LSTM specializations */
490template <>
491memory::dim rnn_forward_test_t<lstm_forward, float>::getNGates() {
492 return 4;
493}
494
495template <>
496lstm_forward::primitive_desc rnn_forward_test_t<lstm_forward, float>::get_pd(
497 prop_kind aprop, algorithm activation, rnn_direction direction,
498 const memory::desc &src_layer_md, const memory::desc &src_iter_md,
499 const memory::desc &src_iter_c_md, const memory::desc &attention_md,
500 const memory::desc &weights_layer_md,
501 const memory::desc &weights_iter_md,
502 const memory::desc &weights_peephole_md,
503 const memory::desc &weights_projection_md, const memory::desc &bias_md,
504 const memory::desc &dst_layer_md, const memory::desc &dst_iter_md,
505 const memory::desc &dst_iter_c_md, float alpha) {
506 return lstm_forward::primitive_desc(get_test_engine(), aprop, direction,
507 src_layer_md, src_iter_md, src_iter_c_md, weights_layer_md,
508 weights_iter_md, weights_peephole_md, weights_projection_md,
509 bias_md, dst_layer_md, dst_iter_md, dst_iter_c_md);
510}
511
512template <>
513bool rnn_forward_test_t<lstm_forward, float>::skipTest(bool src_layer_match,
514 bool augru_attention_match, bool src_iter_match, bool src_iter_c_match,
515 bool weights_layer_match, bool weights_iter_match, bool bias_match,
516 bool dst_layer_match, bool dst_iter_match, bool dst_iter_c_match) {
517 return src_layer_match && src_iter_match && src_iter_c_match
518 && weights_layer_match && weights_iter_match && bias_match
519 && dst_layer_match && dst_iter_match && dst_iter_c_match;
520}
521
522template <>
523bool rnn_forward_test_t<augru_forward, float>::skipTest(bool src_layer_match,
524 bool augru_attention_match, bool src_iter_match, bool src_iter_c_match,
525 bool weights_layer_match, bool weights_iter_match, bool bias_match,
526 bool dst_layer_match, bool dst_iter_match, bool dst_iter_c_match) {
527 return src_layer_match && augru_attention_match && src_iter_match
528 && src_iter_c_match && weights_layer_match && weights_iter_match
529 && bias_match && dst_layer_match && dst_iter_match
530 && dst_iter_c_match;
531}
532
533template <>
534bool rnn_forward_test_t<lbr_augru_forward, float>::skipTest(
535 bool src_layer_match, bool augru_attention_match, bool src_iter_match,
536 bool src_iter_c_match, bool weights_layer_match,
537 bool weights_iter_match, bool bias_match, bool dst_layer_match,
538 bool dst_iter_match, bool dst_iter_c_match) {
539 return src_layer_match && augru_attention_match && src_iter_match
540 && src_iter_c_match && weights_layer_match && weights_iter_match
541 && bias_match && dst_layer_match && dst_iter_match
542 && dst_iter_c_match;
543}
544
545template <>
546memory::desc rnn_forward_test_t<lstm_forward, float>::querySrcIterC(
547 const lstm_forward::primitive_desc &rpd) {
548 return rpd.src_iter_c_desc();
549}
550
551template <>
552memory::desc rnn_forward_test_t<lstm_forward, float>::queryWeightsPeephole(
553 const lstm_forward::primitive_desc &rpd) {
554 return rpd.weights_peephole_desc();
555}
556
557template <>
558memory::desc rnn_forward_test_t<lstm_forward, float>::queryWeightsProjection(
559 const lstm_forward::primitive_desc &rpd) {
560 return rpd.weights_projection_desc();
561}
562
563template <>
564memory::desc rnn_forward_test_t<lstm_forward, float>::queryDstIterC(
565 const lstm_forward::primitive_desc &rpd) {
566 return rpd.dst_iter_c_desc();
567}
568
569/* GRU specializations */
570template <>
571memory::dim rnn_forward_test_t<gru_forward, float>::getNGates() {
572 return 3;
573}
574
575template <>
576gru_forward::primitive_desc rnn_forward_test_t<gru_forward, float>::get_pd(
577 prop_kind aprop, algorithm activation, rnn_direction direction,
578 const memory::desc &src_layer_md, const memory::desc &src_iter_md,
579 const memory::desc &src_iter_c_md, const memory::desc &attention_md,
580 const memory::desc &weights_layer_md,
581 const memory::desc &weights_iter_md, const memory::desc &,
582 const memory::desc &, const memory::desc &bias_md,
583 const memory::desc &dst_layer_md, const memory::desc &dst_iter_md,
584 const memory::desc &dst_iter_c_md, float alpha) {
585 return gru_forward::primitive_desc(get_test_engine(), aprop, direction,
586 src_layer_md, src_iter_md, weights_layer_md, weights_iter_md,
587 bias_md, dst_layer_md, dst_iter_md);
588}
589
590/* LBR GRU specializations */
591template <>
592memory::dim rnn_forward_test_t<lbr_gru_forward, float>::getNGates() {
593 return 3;
594}
595
596template <>
597lbr_gru_forward::primitive_desc
598rnn_forward_test_t<lbr_gru_forward, float>::get_pd(prop_kind aprop,
599 algorithm activation, rnn_direction direction,
600 const memory::desc &src_layer_md, const memory::desc &src_iter_md,
601 const memory::desc &src_iter_c_md, const memory::desc &attention_md,
602 const memory::desc &weights_layer_md,
603 const memory::desc &weights_iter_md, const memory::desc &,
604 const memory::desc &, const memory::desc &bias_md,
605 const memory::desc &dst_layer_md, const memory::desc &dst_iter_md,
606 const memory::desc &dst_iter_c_md, float alpha) {
607 return lbr_gru_forward::primitive_desc(get_test_engine(), aprop, direction,
608 src_layer_md, src_iter_md, weights_layer_md, weights_iter_md,
609 bias_md, dst_layer_md, dst_iter_md);
610}
611
612/* AUGRU specializations */
613template <>
614memory::dim rnn_forward_test_t<augru_forward, float>::getNGates() {
615 return 3;
616}
617
618template <>
619augru_forward::primitive_desc rnn_forward_test_t<augru_forward, float>::get_pd(
620 prop_kind aprop, algorithm activation, rnn_direction direction,
621 const memory::desc &src_layer_md, const memory::desc &src_iter_md,
622 const memory::desc &src_iter_c_md, const memory::desc &attention_md,
623 const memory::desc &weights_layer_md,
624 const memory::desc &weights_iter_md, const memory::desc &,
625 const memory::desc &, const memory::desc &bias_md,
626 const memory::desc &dst_layer_md, const memory::desc &dst_iter_md,
627 const memory::desc &dst_iter_c_md, float alpha) {
628 return augru_forward::primitive_desc(get_test_engine(), aprop, direction,
629 src_layer_md, src_iter_md, attention_md, weights_layer_md,
630 weights_iter_md, bias_md, dst_layer_md, dst_iter_md);
631}
632
633/* LBR AUGRU specializations */
634template <>
635memory::dim rnn_forward_test_t<lbr_augru_forward, float>::getNGates() {
636 return 3;
637}
638
639template <>
640lbr_augru_forward::primitive_desc
641rnn_forward_test_t<lbr_augru_forward, float>::get_pd(prop_kind aprop,
642 algorithm activation, rnn_direction direction,
643 const memory::desc &src_layer_md, const memory::desc &src_iter_md,
644 const memory::desc &src_iter_c_md, const memory::desc &attention_md,
645 const memory::desc &weights_layer_md,
646 const memory::desc &weights_iter_md, const memory::desc &,
647 const memory::desc &, const memory::desc &bias_md,
648 const memory::desc &dst_layer_md, const memory::desc &dst_iter_md,
649 const memory::desc &dst_iter_c_md, float alpha) {
650 return lbr_augru_forward::primitive_desc(get_test_engine(), aprop,
651 direction, src_layer_md, src_iter_md, attention_md,
652 weights_layer_md, weights_iter_md, bias_md, dst_layer_md,
653 dst_iter_md);
654}
655
656using eng = engine::kind;
657using fmt = memory::format_tag;
658using alg = algorithm;
659using dir = rnn_direction;
660using rnn_forward_test_f32 = rnn_forward_test_t<vanilla_rnn_forward, float>;
661using lstm_forward_test_f32 = rnn_forward_test_t<lstm_forward, float>;
662using gru_forward_test_f32 = rnn_forward_test_t<gru_forward, float>;
663using lbr_gru_forward_test_f32 = rnn_forward_test_t<lbr_gru_forward, float>;
664using augru_forward_test_f32 = rnn_forward_test_t<augru_forward, float>;
665using lbr_augru_forward_test_f32 = rnn_forward_test_t<lbr_augru_forward, float>;
666
667using cfg_f32 = test_rnn_params_t;
668
// Initializer of test_rnn_params_t::extra for a vanilla RNN case:
// activation `a` and alpha set to zero.
#define PLAIN_RNN(a) {a, 0.0f}
// Initializer of test_rnn_params_t::extra for every other cell kind:
// no activation, alpha set to zero.
#define NOT_RNN PLAIN_RNN(alg::undef)
673
674TEST_P(rnn_forward_test_f32, TestsRnn) {}
675CPU_INSTANTIATE_TEST_SUITE_P(TestRnn, rnn_forward_test_f32,
676 ::testing::Values(
677 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
678 prop_kind::forward_inference,
679 dir::unidirectional_left2right,
680 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
681 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
682 fmt::ldnc},
683 test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}},
684 /* Check for invalid parameters: unsupported unrolling */
685 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
686 prop_kind::forward_inference,
687 dir::unidirectional_left2right,
688 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
689 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
690 fmt::ldnc},
691 test_rnn_sizes_t {2, 1, 10, 16, 200, 100, 100, 100},
692 true, dnnl_invalid_arguments},
693 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
694 prop_kind::forward_inference,
695 dir::unidirectional_left2right,
696 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
697 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
698 fmt::ldnc},
699 test_rnn_sizes_t {2, 1, 10, 16, 100, 200, 100, 100},
700 true, dnnl_invalid_arguments},
701 /* Check for invalid parameters: inconsistent dimensions */
702 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
703 prop_kind::forward_inference,
704 dir::unidirectional_left2right,
705 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
706 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
707 fmt::ldnc},
708 test_rnn_sizes_t {2, 1, 10, 16, 100, 100, 50, 100},
709 true, dnnl_invalid_arguments},
710 /* Check if passing {src,dst}_iter impacts results */
711 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
712
713 prop_kind::forward_inference,
714 dir::unidirectional_left2right,
715 {fmt::tnc, fmt::undef, fmt::ldigo, fmt::ldigo,
716 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
717 fmt::ldnc},
718 test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}},
719 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
720 prop_kind::forward_inference,
721 dir::unidirectional_left2right,
722 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
723 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
724 fmt::undef},
725 test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}},
726 cfg_f32 {PLAIN_RNN(alg::eltwise_tanh),
727 prop_kind::forward_inference,
728 dir::unidirectional_left2right,
729 {fmt::tnc, fmt::undef, fmt::ldigo, fmt::ldigo,
730 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
731 fmt::undef},
732 test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}));
733
734TEST_P(lstm_forward_test_f32, TestsLSTM) {}
735CPU_INSTANTIATE_TEST_SUITE_P(TestLSTM, lstm_forward_test_f32,
736 ::testing::Values(
737 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
738 dir::unidirectional_left2right,
739 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
740 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
741 fmt::ldnc},
742 test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}},
743 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
744 dir::unidirectional_left2right,
745 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo,
746 fmt::undef, fmt::ldgo, fmt::tnc, fmt::ldnc},
747 test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}},
748 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
749 dir::unidirectional_left2right,
750 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo,
751 fmt::ldio, fmt::ldgo, fmt::tnc, fmt::ldnc},
752 test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}},
753 /* Non uniform sizes tests */
754 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
755 dir::unidirectional_left2right,
756 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
757 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
758 fmt::ldnc},
759 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}},
760 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
761 dir::unidirectional_left2right,
762 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo,
763 fmt::undef, fmt::ldgo, fmt::tnc, fmt::ldnc},
764 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}},
765 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
766 dir::unidirectional_left2right,
767 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, fmt::ldgo,
768 fmt::ldio, fmt::ldgo, fmt::tnc, fmt::ldnc},
769 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 15}},
770 /* Check if not passing dst_iter impacts results */
771 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
772 dir::unidirectional_left2right,
773 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
774 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
775 fmt::undef},
776 test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}));
777
778TEST_P(gru_forward_test_f32, TestsGRU) {}
779CPU_INSTANTIATE_TEST_SUITE_P(TestGRU, gru_forward_test_f32,
780 ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference,
781 dir::unidirectional_left2right,
782 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
783 fmt::undef, fmt::undef, fmt::ldgo,
784 fmt::tnc, fmt::ldnc},
785 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}},
786 /* Check if not passing dst_iter impacts results */
787 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
788 dir::unidirectional_left2right,
789 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
790 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
791 fmt::undef},
792 test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}));
793
794TEST_P(lbr_gru_forward_test_f32, TestsGRUlbr) {}
795CPU_INSTANTIATE_TEST_SUITE_P(TestGRUlbr, lbr_gru_forward_test_f32,
796 ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference,
797 dir::unidirectional_left2right,
798 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
799 fmt::undef, fmt::undef, fmt::ldgo,
800 fmt::tnc, fmt::ldnc},
801 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}},
802 /* Check if not passing dst_iter impacts results */
803 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
804 dir::unidirectional_left2right,
805 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
806 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
807 fmt::undef},
808 test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}));
809
810TEST_P(augru_forward_test_f32, TestsAUGRU) {}
811CPU_INSTANTIATE_TEST_SUITE_P(TestAUGRU, augru_forward_test_f32,
812 ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference,
813 dir::unidirectional_left2right,
814 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
815 fmt::undef, fmt::undef, fmt::ldgo,
816 fmt::tnc, fmt::ldnc},
817 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}},
818 /* Check if not passing dst_iter impacts results */
819 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
820 dir::unidirectional_left2right,
821 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
822 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
823 fmt::undef},
824 test_rnn_sizes_t {1, 1, 5, 1, 4, 4, 4, 4}}));
825
826TEST_P(lbr_augru_forward_test_f32, TestsAUGRUlbr) {}
827CPU_INSTANTIATE_TEST_SUITE_P(TestAUGRUlbr, lbr_augru_forward_test_f32,
828 ::testing::Values(cfg_f32 {NOT_RNN, prop_kind::forward_inference,
829 dir::unidirectional_left2right,
830 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
831 fmt::undef, fmt::undef, fmt::ldgo,
832 fmt::tnc, fmt::ldnc},
833 test_rnn_sizes_t {1, 1, 1, 1, 10, 5, 5, 5}},
834 /* Check if not passing dst_iter impacts results */
835 cfg_f32 {NOT_RNN, prop_kind::forward_inference,
836 dir::unidirectional_left2right,
837 {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo,
838 fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc,
839 fmt::undef},
840 test_rnn_sizes_t {1, 1, 5, 1, 4, 4, 4, 4}}));
841
842} // namespace dnnl
843