1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24// short names for brevity
25using data_type = memory::data_type;
26using tag = memory::format_tag;
27
28class runtime_dim_test_t : public ::testing::Test {
29protected:
30 engine eng = get_test_engine();
31 void SetUp() override {}
32
33 template <typename F>
34 void check_status(const F &f, dnnl_status_t status) {
35 catch_expected_failures(f, status != dnnl_success, status, false);
36 }
37};
// Status-checking helpers used by the tests below. Each macro wraps its
// arguments in a by-reference lambda (so the statement may assign to locals,
// e.g. a forward hint) and passes it to runtime_dim_test_t::check_status
// together with the status the statement is expected to produce. The
// arguments are variadic so that code containing top-level commas can be
// passed through unchanged.
#define CHECK_STATUS(status, ...) \
    check_status([&]() { __VA_ARGS__; }, status)

#define CHECK_OK(...) CHECK_STATUS(dnnl_success, __VA_ARGS__)
#define CHECK_INVALID(...) CHECK_STATUS(dnnl_invalid_arguments, __VA_ARGS__)
#define CHECK_UNIMPL(...) CHECK_STATUS(dnnl_unimplemented, __VA_ARGS__)
44
45TEST_F(runtime_dim_test_t, TestMemory) {
46 memory::desc md_tag {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL},
47 data_type::f32, tag::ab};
48 ASSERT_EQ(md_tag.get_size(), DNNL_RUNTIME_SIZE_VAL);
49 CHECK_INVALID(test::make_memory(md_tag, eng));
50
51 memory::desc md_strides {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL},
52 data_type::f32, {100, 1}};
53 ASSERT_EQ(md_strides.get_size(), DNNL_RUNTIME_SIZE_VAL);
54 CHECK_INVALID(test::make_memory(md_strides, eng));
55}
56
57TEST_F(runtime_dim_test_t, TestBNorm) {
58 memory::desc md {
59 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
60 normalization_flags flags {};
61 CHECK_UNIMPL(batch_normalization_forward::primitive_desc(
62 eng, prop_kind::forward, md, md, 0.1f, flags));
63
64 batch_normalization_forward::primitive_desc fwd_hint;
65 {
66 auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd);
67 CHECK_OK(fwd_hint = batch_normalization_forward::primitive_desc(eng,
68 prop_kind::forward, valid_md, valid_md, 0.1f, flags));
69 }
70 CHECK_UNIMPL(batch_normalization_backward::primitive_desc(
71 eng, prop_kind::backward_data, md, md, md, 0.1f, flags, fwd_hint));
72}
73
74TEST_F(runtime_dim_test_t, TestBinary) {
75 memory::desc md {
76 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
77 CHECK_UNIMPL(
78 binary::primitive_desc(eng, algorithm::binary_add, md, md, md));
79}
80
81TEST_F(runtime_dim_test_t, TestConcat) {
82 memory::desc md {
83 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
84 CHECK_UNIMPL(concat::primitive_desc(eng, 1, {md, md}));
85}
86
87TEST_F(runtime_dim_test_t, TestConv) {
88 memory::desc src_md {
89 {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd};
90 memory::desc wei_md {{32, 16, 3, 3}, data_type::f32, tag::abcd};
91 memory::desc dst_md {
92 {DNNL_RUNTIME_DIM_VAL, 32, 7, 7}, data_type::f32, tag::abcd};
93 CHECK_UNIMPL(convolution_forward::primitive_desc(eng, prop_kind::forward,
94 algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1},
95 {1, 1}, {1, 1}));
96
97 convolution_forward::primitive_desc fwd_hint;
98 {
99 auto valid_src_md
100 = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd);
101 auto valid_dst_md
102 = memory::desc({2, 32, 7, 7}, data_type::f32, tag::abcd);
103 CHECK_OK(fwd_hint
104 = convolution_forward::primitive_desc(eng, prop_kind::forward,
105 algorithm::convolution_direct, valid_src_md, wei_md,
106 valid_dst_md, {1, 1}, {1, 1}, {1, 1}));
107 }
108
109 CHECK_UNIMPL(convolution_backward_data::primitive_desc(eng,
110 algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1},
111 {1, 1}, {1, 1}, fwd_hint));
112 CHECK_UNIMPL(convolution_backward_weights::primitive_desc(eng,
113 algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1},
114 {1, 1}, {1, 1}, fwd_hint));
115}
116
117TEST_F(runtime_dim_test_t, TestDeconv) {
118 memory::desc src_md {
119 {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd};
120 memory::desc wei_md {{32, 16, 3, 3}, data_type::f32, tag::abcd};
121 memory::desc dst_md {
122 {DNNL_RUNTIME_DIM_VAL, 32, 7, 7}, data_type::f32, tag::abcd};
123 CHECK_UNIMPL(deconvolution_forward::primitive_desc(eng, prop_kind::forward,
124 algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1},
125 {1, 1}, {1, 1}));
126
127 deconvolution_forward::primitive_desc fwd_hint;
128 {
129 auto valid_src_md
130 = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd);
131 auto valid_dst_md
132 = memory::desc({2, 32, 7, 7}, data_type::f32, tag::abcd);
133 CHECK_OK(fwd_hint
134 = deconvolution_forward::primitive_desc(eng, prop_kind::forward,
135 algorithm::deconvolution_direct, valid_src_md, wei_md,
136 valid_dst_md, {1, 1}, {1, 1}, {1, 1}));
137 }
138 CHECK_UNIMPL(deconvolution_backward_data::primitive_desc(eng,
139 algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1},
140 {1, 1}, {1, 1}, fwd_hint));
141 CHECK_UNIMPL(deconvolution_backward_weights::primitive_desc(eng,
142 algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1},
143 {1, 1}, {1, 1}, fwd_hint));
144}
145
146TEST_F(runtime_dim_test_t, TestEltwise) {
147 memory::desc md {
148 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
149 CHECK_UNIMPL(eltwise_forward::primitive_desc(
150 eng, prop_kind::forward, algorithm::eltwise_relu, md, md, 0.1f));
151
152 eltwise_forward::primitive_desc fwd_hint;
153 {
154 auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd);
155 CHECK_OK(fwd_hint
156 = eltwise_forward::primitive_desc(eng, prop_kind::forward,
157 algorithm::eltwise_relu, valid_md, valid_md, 0.1f));
158 }
159
160 CHECK_UNIMPL(eltwise_backward::primitive_desc(
161 eng, algorithm::eltwise_relu, md, md, md, 0.1f, fwd_hint));
162}
163
164TEST_F(runtime_dim_test_t, TestInnerProduct) {
165 memory::desc src_md {
166 {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd};
167 memory::desc wei_md {{32, 16, 7, 7}, data_type::f32, tag::abcd};
168 memory::desc dst_md {{DNNL_RUNTIME_DIM_VAL, 32}, data_type::f32, tag::ab};
169 CHECK_UNIMPL(inner_product_forward::primitive_desc(
170 eng, prop_kind::forward, src_md, wei_md, dst_md));
171
172 inner_product_forward::primitive_desc fwd_hint;
173 {
174 auto valid_src_md
175 = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd);
176 auto valid_dst_md = memory::desc({2, 32}, data_type::f32, tag::ab);
177 CHECK_OK(fwd_hint
178 = inner_product_forward::primitive_desc(eng, prop_kind::forward,
179 valid_src_md, wei_md, valid_dst_md));
180 }
181
182 CHECK_UNIMPL(inner_product_backward_data::primitive_desc(
183 eng, src_md, wei_md, dst_md, fwd_hint));
184 CHECK_UNIMPL(inner_product_backward_weights::primitive_desc(
185 eng, src_md, wei_md, dst_md, fwd_hint));
186}
187
188TEST_F(runtime_dim_test_t, TestLNorm) {
189 memory::desc md {{DNNL_RUNTIME_DIM_VAL, 16, 16}, data_type::f32, tag::abc};
190 memory::desc stat_md {{DNNL_RUNTIME_DIM_VAL, 16}, data_type::f32, tag::ab};
191 normalization_flags flags {};
192 CHECK_UNIMPL(layer_normalization_forward::primitive_desc(
193 eng, prop_kind::forward, md, md, stat_md, 0.1f, flags));
194
195 layer_normalization_forward::primitive_desc fwd_hint;
196 {
197 auto valid_md = memory::desc({2, 16, 16}, data_type::f32, tag::abc);
198 auto valid_stat_md = memory::desc({2, 16}, data_type::f32, tag::ab);
199 CHECK_OK(fwd_hint = layer_normalization_forward::primitive_desc(eng,
200 prop_kind::forward, valid_md, valid_md, valid_stat_md,
201 0.1f, flags));
202 }
203 CHECK_UNIMPL(layer_normalization_backward::primitive_desc(eng,
204 prop_kind::backward_data, md, md, md, stat_md, 0.1f, flags,
205 fwd_hint));
206}
207
208TEST_F(runtime_dim_test_t, TestLRN) {
209 memory::desc md {
210 {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd};
211
212 CHECK_UNIMPL(lrn_forward::primitive_desc(eng, prop_kind::forward,
213 algorithm::lrn_across_channels, md, md, 5, 1.f, 0.75f, 1.0f));
214
215 lrn_forward::primitive_desc fwd_hint;
216 {
217 auto valid_md = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd);
218 CHECK_OK(fwd_hint = lrn_forward::primitive_desc(eng, prop_kind::forward,
219 algorithm::lrn_across_channels, valid_md, valid_md, 5,
220 1.f, 0.75f, 1.0f));
221 }
222 CHECK_UNIMPL(
223 lrn_backward::primitive_desc(eng, algorithm::lrn_across_channels,
224 md, md, md, 5, 1.f, 0.75f, 1.0f, fwd_hint));
225}
226
227CPU_TEST_F(runtime_dim_test_t, TestMatmul) {
228 memory::desc a_md {{DNNL_RUNTIME_DIM_VAL, 3}, data_type::f32, tag::ab};
229 memory::desc b_md {{3, DNNL_RUNTIME_DIM_VAL}, data_type::f32, tag::ba};
230 memory::desc c_md {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL},
231 data_type::f32, tag::ab};
232 CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md));
233}
234
235TEST_F(runtime_dim_test_t, TestPool) {
236 memory::desc src_md {
237 {DNNL_RUNTIME_DIM_VAL, 16, 8, 8}, data_type::f32, tag::abcd};
238 memory::desc dst_md {
239 {DNNL_RUNTIME_DIM_VAL, 16, 4, 4}, data_type::f32, tag::abcd};
240 CHECK_UNIMPL(pooling_forward::primitive_desc(eng, prop_kind::forward,
241 algorithm::pooling_max, src_md, dst_md, {2, 2}, {2, 2}, {0, 0},
242 {0, 0}, {0, 0}));
243
244 pooling_forward::primitive_desc fwd_hint;
245 {
246 auto valid_src_md
247 = memory::desc({2, 16, 8, 8}, data_type::f32, tag::abcd);
248 auto valid_dst_md
249 = memory::desc({2, 16, 4, 4}, data_type::f32, tag::abcd);
250 CHECK_OK(fwd_hint
251 = pooling_forward::primitive_desc(eng, prop_kind::forward,
252 algorithm::pooling_max, valid_src_md, valid_dst_md,
253 {2, 2}, {2, 2}, {0, 0}, {0, 0}, {0, 0}));
254 }
255
256 CHECK_UNIMPL(pooling_backward::primitive_desc(eng, algorithm::pooling_max,
257 src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0}, {0, 0}, fwd_hint));
258}
259
260TEST_F(runtime_dim_test_t, TestPReLU) {
261 memory::desc data_md {
262 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
263 memory::desc weights_md {
264 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
265
266 CHECK_UNIMPL(prelu_forward::primitive_desc(
267 eng, prop_kind::forward, data_md, weights_md, data_md));
268
269 prelu_forward::primitive_desc fwd_hint;
270 {
271 auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd);
272 CHECK_OK(fwd_hint = prelu_forward::primitive_desc(eng,
273 prop_kind::forward, valid_md, valid_md, valid_md));
274 }
275
276 memory::desc diff_data_desc {
277 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
278 memory::desc diff_weights_desc {
279 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
280
281 CHECK_UNIMPL(prelu_backward::primitive_desc(eng, data_md, weights_md,
282 diff_data_desc, diff_weights_desc, diff_data_desc, fwd_hint));
283}
284
285CPU_TEST_F(runtime_dim_test_t, TestReorder) {
286 memory::desc src_md {
287 {DNNL_RUNTIME_DIM_VAL, 16, 8, 8}, data_type::f32, tag::abcd};
288 memory::desc dst_md {
289 {DNNL_RUNTIME_DIM_VAL, 16, 8, 8}, data_type::f32, tag::acdb};
290 CHECK_OK(reorder::primitive_desc(eng, src_md, eng, dst_md));
291}
292
293TEST_F(runtime_dim_test_t, TestRNN) {
294 memory::dim l = 10, c = 8, g = 1, d = 1;
295 memory::desc src_layer_md {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL, c},
296 data_type::f32, tag::tnc};
297 memory::desc src_iter_md {
298 {l, d, DNNL_RUNTIME_DIM_VAL, c}, data_type::f32, tag::ldnc};
299 memory::desc wei_layer_md {{l, d, c, g, c}, data_type::f32, tag::ldigo};
300 memory::desc wei_iter_md {{l, d, c, g, c}, data_type::f32, tag::ldigo};
301 memory::desc bia_md {{l, d, g, c}, data_type::f32, tag::ldgo};
302 memory::desc dst_layer_md {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL, c},
303 data_type::f32, tag::tnc};
304 memory::desc dst_iter_md {
305 {l, d, DNNL_RUNTIME_DIM_VAL, c}, data_type::f32, tag::ldnc};
306 CHECK_UNIMPL(vanilla_rnn_forward::primitive_desc(eng, prop_kind::forward,
307 algorithm::eltwise_relu, rnn_direction::unidirectional_left2right,
308 src_layer_md, src_iter_md, wei_layer_md, wei_iter_md, bia_md,
309 dst_layer_md, dst_iter_md));
310}
311
312TEST_F(runtime_dim_test_t, TestShuffle) {
313 memory::desc md {
314 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
315 CHECK_UNIMPL(shuffle_forward::primitive_desc(
316 eng, prop_kind::forward, md, md, 1, 4));
317
318 shuffle_forward::primitive_desc fwd_hint;
319 {
320 auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd);
321 CHECK_OK(fwd_hint = shuffle_forward::primitive_desc(
322 eng, prop_kind::forward, valid_md, valid_md, 1, 4));
323 }
324
325 CHECK_UNIMPL(shuffle_backward::primitive_desc(eng, md, md, 1, 4, fwd_hint));
326}
327
328TEST_F(runtime_dim_test_t, TestSoftmax) {
329 memory::desc md {{DNNL_RUNTIME_DIM_VAL, 16}, data_type::f32, tag::ab};
330 CHECK_UNIMPL(softmax_forward::primitive_desc(
331 eng, prop_kind::forward, algorithm::softmax_accurate, md, md, 1));
332
333 softmax_forward::primitive_desc fwd_hint;
334 {
335 auto valid_md = memory::desc({2, 16}, data_type::f32, tag::ab);
336 CHECK_OK(fwd_hint
337 = softmax_forward::primitive_desc(eng, prop_kind::forward,
338 algorithm::softmax_accurate, valid_md, valid_md, 1));
339 }
340
341 CHECK_UNIMPL(softmax_backward::primitive_desc(
342 eng, algorithm::softmax_accurate, md, md, md, 1, fwd_hint));
343}
344
345TEST_F(runtime_dim_test_t, TestSum) {
346 memory::desc md {
347 {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd};
348 CHECK_UNIMPL(sum::primitive_desc(eng, {1.f, 1.f}, {md, md}));
349}
350
351} // namespace dnnl
352