1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// Short aliases for the memory descriptor enums used throughout this file.
using data_type = memory::data_type;
using tag = memory::format_tag;
27 | |
28 | class runtime_dim_test_t : public ::testing::Test { |
29 | protected: |
30 | engine eng = get_test_engine(); |
31 | void SetUp() override {} |
32 | |
33 | template <typename F> |
34 | void check_status(const F &f, dnnl_status_t status) { |
35 | catch_expected_failures(f, status != dnnl_success, status, false); |
36 | } |
37 | }; |
// CHECK_STATUs wraps the given statement(s) in a by-reference lambda and hands
// it, together with the expected status, to the fixture's check_status(); it
// therefore only works inside runtime_dim_test_t test bodies.
#define CHECK_STATUs(status, ...) check_status([&]() { __VA_ARGS__; }, status)
// Entry point used by the wrappers below; it simply forwards to CHECK_STATUs.
// NOTE(review): the extra level of indirection is presumably there for
// portable __VA_ARGS__ forwarding across preprocessors — confirm.
#define CHECK_STATUS(status, ...) CHECK_STATUs(status, __VA_ARGS__)

// Convenience wrappers: the statement must succeed, fail with
// dnnl_invalid_arguments, or fail with dnnl_unimplemented, respectively.
#define CHECK_OK(...) CHECK_STATUS(dnnl_success, __VA_ARGS__)
#define CHECK_INVALID(...) CHECK_STATUS(dnnl_invalid_arguments, __VA_ARGS__)
#define CHECK_UNIMPL(...) CHECK_STATUS(dnnl_unimplemented, __VA_ARGS__)
44 | |
45 | TEST_F(runtime_dim_test_t, TestMemory) { |
46 | memory::desc md_tag {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL}, |
47 | data_type::f32, tag::ab}; |
48 | ASSERT_EQ(md_tag.get_size(), DNNL_RUNTIME_SIZE_VAL); |
49 | CHECK_INVALID(test::make_memory(md_tag, eng)); |
50 | |
51 | memory::desc md_strides {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL}, |
52 | data_type::f32, {100, 1}}; |
53 | ASSERT_EQ(md_strides.get_size(), DNNL_RUNTIME_SIZE_VAL); |
54 | CHECK_INVALID(test::make_memory(md_strides, eng)); |
55 | } |
56 | |
57 | TEST_F(runtime_dim_test_t, TestBNorm) { |
58 | memory::desc md { |
59 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
60 | normalization_flags flags {}; |
61 | CHECK_UNIMPL(batch_normalization_forward::primitive_desc( |
62 | eng, prop_kind::forward, md, md, 0.1f, flags)); |
63 | |
64 | batch_normalization_forward::primitive_desc fwd_hint; |
65 | { |
66 | auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd); |
67 | CHECK_OK(fwd_hint = batch_normalization_forward::primitive_desc(eng, |
68 | prop_kind::forward, valid_md, valid_md, 0.1f, flags)); |
69 | } |
70 | CHECK_UNIMPL(batch_normalization_backward::primitive_desc( |
71 | eng, prop_kind::backward_data, md, md, md, 0.1f, flags, fwd_hint)); |
72 | } |
73 | |
74 | TEST_F(runtime_dim_test_t, TestBinary) { |
75 | memory::desc md { |
76 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
77 | CHECK_UNIMPL( |
78 | binary::primitive_desc(eng, algorithm::binary_add, md, md, md)); |
79 | } |
80 | |
81 | TEST_F(runtime_dim_test_t, TestConcat) { |
82 | memory::desc md { |
83 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
84 | CHECK_UNIMPL(concat::primitive_desc(eng, 1, {md, md})); |
85 | } |
86 | |
87 | TEST_F(runtime_dim_test_t, TestConv) { |
88 | memory::desc src_md { |
89 | {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd}; |
90 | memory::desc wei_md {{32, 16, 3, 3}, data_type::f32, tag::abcd}; |
91 | memory::desc dst_md { |
92 | {DNNL_RUNTIME_DIM_VAL, 32, 7, 7}, data_type::f32, tag::abcd}; |
93 | CHECK_UNIMPL(convolution_forward::primitive_desc(eng, prop_kind::forward, |
94 | algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1}, |
95 | {1, 1}, {1, 1})); |
96 | |
97 | convolution_forward::primitive_desc fwd_hint; |
98 | { |
99 | auto valid_src_md |
100 | = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd); |
101 | auto valid_dst_md |
102 | = memory::desc({2, 32, 7, 7}, data_type::f32, tag::abcd); |
103 | CHECK_OK(fwd_hint |
104 | = convolution_forward::primitive_desc(eng, prop_kind::forward, |
105 | algorithm::convolution_direct, valid_src_md, wei_md, |
106 | valid_dst_md, {1, 1}, {1, 1}, {1, 1})); |
107 | } |
108 | |
109 | CHECK_UNIMPL(convolution_backward_data::primitive_desc(eng, |
110 | algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1}, |
111 | {1, 1}, {1, 1}, fwd_hint)); |
112 | CHECK_UNIMPL(convolution_backward_weights::primitive_desc(eng, |
113 | algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1}, |
114 | {1, 1}, {1, 1}, fwd_hint)); |
115 | } |
116 | |
117 | TEST_F(runtime_dim_test_t, TestDeconv) { |
118 | memory::desc src_md { |
119 | {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd}; |
120 | memory::desc wei_md {{32, 16, 3, 3}, data_type::f32, tag::abcd}; |
121 | memory::desc dst_md { |
122 | {DNNL_RUNTIME_DIM_VAL, 32, 7, 7}, data_type::f32, tag::abcd}; |
123 | CHECK_UNIMPL(deconvolution_forward::primitive_desc(eng, prop_kind::forward, |
124 | algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1}, |
125 | {1, 1}, {1, 1})); |
126 | |
127 | deconvolution_forward::primitive_desc fwd_hint; |
128 | { |
129 | auto valid_src_md |
130 | = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd); |
131 | auto valid_dst_md |
132 | = memory::desc({2, 32, 7, 7}, data_type::f32, tag::abcd); |
133 | CHECK_OK(fwd_hint |
134 | = deconvolution_forward::primitive_desc(eng, prop_kind::forward, |
135 | algorithm::deconvolution_direct, valid_src_md, wei_md, |
136 | valid_dst_md, {1, 1}, {1, 1}, {1, 1})); |
137 | } |
138 | CHECK_UNIMPL(deconvolution_backward_data::primitive_desc(eng, |
139 | algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1}, |
140 | {1, 1}, {1, 1}, fwd_hint)); |
141 | CHECK_UNIMPL(deconvolution_backward_weights::primitive_desc(eng, |
142 | algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1}, |
143 | {1, 1}, {1, 1}, fwd_hint)); |
144 | } |
145 | |
146 | TEST_F(runtime_dim_test_t, TestEltwise) { |
147 | memory::desc md { |
148 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
149 | CHECK_UNIMPL(eltwise_forward::primitive_desc( |
150 | eng, prop_kind::forward, algorithm::eltwise_relu, md, md, 0.1f)); |
151 | |
152 | eltwise_forward::primitive_desc fwd_hint; |
153 | { |
154 | auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd); |
155 | CHECK_OK(fwd_hint |
156 | = eltwise_forward::primitive_desc(eng, prop_kind::forward, |
157 | algorithm::eltwise_relu, valid_md, valid_md, 0.1f)); |
158 | } |
159 | |
160 | CHECK_UNIMPL(eltwise_backward::primitive_desc( |
161 | eng, algorithm::eltwise_relu, md, md, md, 0.1f, fwd_hint)); |
162 | } |
163 | |
164 | TEST_F(runtime_dim_test_t, TestInnerProduct) { |
165 | memory::desc src_md { |
166 | {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd}; |
167 | memory::desc wei_md {{32, 16, 7, 7}, data_type::f32, tag::abcd}; |
168 | memory::desc dst_md {{DNNL_RUNTIME_DIM_VAL, 32}, data_type::f32, tag::ab}; |
169 | CHECK_UNIMPL(inner_product_forward::primitive_desc( |
170 | eng, prop_kind::forward, src_md, wei_md, dst_md)); |
171 | |
172 | inner_product_forward::primitive_desc fwd_hint; |
173 | { |
174 | auto valid_src_md |
175 | = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd); |
176 | auto valid_dst_md = memory::desc({2, 32}, data_type::f32, tag::ab); |
177 | CHECK_OK(fwd_hint |
178 | = inner_product_forward::primitive_desc(eng, prop_kind::forward, |
179 | valid_src_md, wei_md, valid_dst_md)); |
180 | } |
181 | |
182 | CHECK_UNIMPL(inner_product_backward_data::primitive_desc( |
183 | eng, src_md, wei_md, dst_md, fwd_hint)); |
184 | CHECK_UNIMPL(inner_product_backward_weights::primitive_desc( |
185 | eng, src_md, wei_md, dst_md, fwd_hint)); |
186 | } |
187 | |
188 | TEST_F(runtime_dim_test_t, TestLNorm) { |
189 | memory::desc md {{DNNL_RUNTIME_DIM_VAL, 16, 16}, data_type::f32, tag::abc}; |
190 | memory::desc stat_md {{DNNL_RUNTIME_DIM_VAL, 16}, data_type::f32, tag::ab}; |
191 | normalization_flags flags {}; |
192 | CHECK_UNIMPL(layer_normalization_forward::primitive_desc( |
193 | eng, prop_kind::forward, md, md, stat_md, 0.1f, flags)); |
194 | |
195 | layer_normalization_forward::primitive_desc fwd_hint; |
196 | { |
197 | auto valid_md = memory::desc({2, 16, 16}, data_type::f32, tag::abc); |
198 | auto valid_stat_md = memory::desc({2, 16}, data_type::f32, tag::ab); |
199 | CHECK_OK(fwd_hint = layer_normalization_forward::primitive_desc(eng, |
200 | prop_kind::forward, valid_md, valid_md, valid_stat_md, |
201 | 0.1f, flags)); |
202 | } |
203 | CHECK_UNIMPL(layer_normalization_backward::primitive_desc(eng, |
204 | prop_kind::backward_data, md, md, md, stat_md, 0.1f, flags, |
205 | fwd_hint)); |
206 | } |
207 | |
208 | TEST_F(runtime_dim_test_t, TestLRN) { |
209 | memory::desc md { |
210 | {DNNL_RUNTIME_DIM_VAL, 16, 7, 7}, data_type::f32, tag::abcd}; |
211 | |
212 | CHECK_UNIMPL(lrn_forward::primitive_desc(eng, prop_kind::forward, |
213 | algorithm::lrn_across_channels, md, md, 5, 1.f, 0.75f, 1.0f)); |
214 | |
215 | lrn_forward::primitive_desc fwd_hint; |
216 | { |
217 | auto valid_md = memory::desc({2, 16, 7, 7}, data_type::f32, tag::abcd); |
218 | CHECK_OK(fwd_hint = lrn_forward::primitive_desc(eng, prop_kind::forward, |
219 | algorithm::lrn_across_channels, valid_md, valid_md, 5, |
220 | 1.f, 0.75f, 1.0f)); |
221 | } |
222 | CHECK_UNIMPL( |
223 | lrn_backward::primitive_desc(eng, algorithm::lrn_across_channels, |
224 | md, md, md, 5, 1.f, 0.75f, 1.0f, fwd_hint)); |
225 | } |
226 | |
227 | CPU_TEST_F(runtime_dim_test_t, TestMatmul) { |
228 | memory::desc a_md {{DNNL_RUNTIME_DIM_VAL, 3}, data_type::f32, tag::ab}; |
229 | memory::desc b_md {{3, DNNL_RUNTIME_DIM_VAL}, data_type::f32, tag::ba}; |
230 | memory::desc c_md {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL}, |
231 | data_type::f32, tag::ab}; |
232 | CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md)); |
233 | } |
234 | |
235 | TEST_F(runtime_dim_test_t, TestPool) { |
236 | memory::desc src_md { |
237 | {DNNL_RUNTIME_DIM_VAL, 16, 8, 8}, data_type::f32, tag::abcd}; |
238 | memory::desc dst_md { |
239 | {DNNL_RUNTIME_DIM_VAL, 16, 4, 4}, data_type::f32, tag::abcd}; |
240 | CHECK_UNIMPL(pooling_forward::primitive_desc(eng, prop_kind::forward, |
241 | algorithm::pooling_max, src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, |
242 | {0, 0}, {0, 0})); |
243 | |
244 | pooling_forward::primitive_desc fwd_hint; |
245 | { |
246 | auto valid_src_md |
247 | = memory::desc({2, 16, 8, 8}, data_type::f32, tag::abcd); |
248 | auto valid_dst_md |
249 | = memory::desc({2, 16, 4, 4}, data_type::f32, tag::abcd); |
250 | CHECK_OK(fwd_hint |
251 | = pooling_forward::primitive_desc(eng, prop_kind::forward, |
252 | algorithm::pooling_max, valid_src_md, valid_dst_md, |
253 | {2, 2}, {2, 2}, {0, 0}, {0, 0}, {0, 0})); |
254 | } |
255 | |
256 | CHECK_UNIMPL(pooling_backward::primitive_desc(eng, algorithm::pooling_max, |
257 | src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0}, {0, 0}, fwd_hint)); |
258 | } |
259 | |
260 | TEST_F(runtime_dim_test_t, TestPReLU) { |
261 | memory::desc data_md { |
262 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
263 | memory::desc weights_md { |
264 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
265 | |
266 | CHECK_UNIMPL(prelu_forward::primitive_desc( |
267 | eng, prop_kind::forward, data_md, weights_md, data_md)); |
268 | |
269 | prelu_forward::primitive_desc fwd_hint; |
270 | { |
271 | auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd); |
272 | CHECK_OK(fwd_hint = prelu_forward::primitive_desc(eng, |
273 | prop_kind::forward, valid_md, valid_md, valid_md)); |
274 | } |
275 | |
276 | memory::desc diff_data_desc { |
277 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
278 | memory::desc diff_weights_desc { |
279 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
280 | |
281 | CHECK_UNIMPL(prelu_backward::primitive_desc(eng, data_md, weights_md, |
282 | diff_data_desc, diff_weights_desc, diff_data_desc, fwd_hint)); |
283 | } |
284 | |
285 | CPU_TEST_F(runtime_dim_test_t, TestReorder) { |
286 | memory::desc src_md { |
287 | {DNNL_RUNTIME_DIM_VAL, 16, 8, 8}, data_type::f32, tag::abcd}; |
288 | memory::desc dst_md { |
289 | {DNNL_RUNTIME_DIM_VAL, 16, 8, 8}, data_type::f32, tag::acdb}; |
290 | CHECK_OK(reorder::primitive_desc(eng, src_md, eng, dst_md)); |
291 | } |
292 | |
293 | TEST_F(runtime_dim_test_t, TestRNN) { |
294 | memory::dim l = 10, c = 8, g = 1, d = 1; |
295 | memory::desc src_layer_md {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL, c}, |
296 | data_type::f32, tag::tnc}; |
297 | memory::desc src_iter_md { |
298 | {l, d, DNNL_RUNTIME_DIM_VAL, c}, data_type::f32, tag::ldnc}; |
299 | memory::desc wei_layer_md {{l, d, c, g, c}, data_type::f32, tag::ldigo}; |
300 | memory::desc wei_iter_md {{l, d, c, g, c}, data_type::f32, tag::ldigo}; |
301 | memory::desc bia_md {{l, d, g, c}, data_type::f32, tag::ldgo}; |
302 | memory::desc dst_layer_md {{DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL, c}, |
303 | data_type::f32, tag::tnc}; |
304 | memory::desc dst_iter_md { |
305 | {l, d, DNNL_RUNTIME_DIM_VAL, c}, data_type::f32, tag::ldnc}; |
306 | CHECK_UNIMPL(vanilla_rnn_forward::primitive_desc(eng, prop_kind::forward, |
307 | algorithm::eltwise_relu, rnn_direction::unidirectional_left2right, |
308 | src_layer_md, src_iter_md, wei_layer_md, wei_iter_md, bia_md, |
309 | dst_layer_md, dst_iter_md)); |
310 | } |
311 | |
312 | TEST_F(runtime_dim_test_t, TestShuffle) { |
313 | memory::desc md { |
314 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
315 | CHECK_UNIMPL(shuffle_forward::primitive_desc( |
316 | eng, prop_kind::forward, md, md, 1, 4)); |
317 | |
318 | shuffle_forward::primitive_desc fwd_hint; |
319 | { |
320 | auto valid_md = memory::desc({2, 16, 3, 3}, data_type::f32, tag::abcd); |
321 | CHECK_OK(fwd_hint = shuffle_forward::primitive_desc( |
322 | eng, prop_kind::forward, valid_md, valid_md, 1, 4)); |
323 | } |
324 | |
325 | CHECK_UNIMPL(shuffle_backward::primitive_desc(eng, md, md, 1, 4, fwd_hint)); |
326 | } |
327 | |
328 | TEST_F(runtime_dim_test_t, TestSoftmax) { |
329 | memory::desc md {{DNNL_RUNTIME_DIM_VAL, 16}, data_type::f32, tag::ab}; |
330 | CHECK_UNIMPL(softmax_forward::primitive_desc( |
331 | eng, prop_kind::forward, algorithm::softmax_accurate, md, md, 1)); |
332 | |
333 | softmax_forward::primitive_desc fwd_hint; |
334 | { |
335 | auto valid_md = memory::desc({2, 16}, data_type::f32, tag::ab); |
336 | CHECK_OK(fwd_hint |
337 | = softmax_forward::primitive_desc(eng, prop_kind::forward, |
338 | algorithm::softmax_accurate, valid_md, valid_md, 1)); |
339 | } |
340 | |
341 | CHECK_UNIMPL(softmax_backward::primitive_desc( |
342 | eng, algorithm::softmax_accurate, md, md, md, 1, fwd_hint)); |
343 | } |
344 | |
345 | TEST_F(runtime_dim_test_t, TestSum) { |
346 | memory::desc md { |
347 | {DNNL_RUNTIME_DIM_VAL, 16, 3, 3}, data_type::f32, tag::abcd}; |
348 | CHECK_UNIMPL(sum::primitive_desc(eng, {1.f, 1.f}, {md, md})); |
349 | } |
350 | |
351 | } // namespace dnnl |
352 | |