1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * Copyright 2020-2021 FUJITSU LIMITED |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #include "dnnl_test_common.hpp" |
19 | #include "gtest/gtest.h" |
20 | |
21 | #include "oneapi/dnnl/dnnl.hpp" |
22 | #include "tests/test_isa_common.hpp" |
23 | |
24 | namespace dnnl { |
25 | |
26 | using data_type = memory::data_type; |
27 | using tag = memory::format_tag; |
28 | |
// Shared fixture for all primitive-attribute tests below. No per-test state
// is required, so SetUp() is intentionally a no-op.
class attr_test_t : public ::testing::Test {
protected:
    void SetUp() override {}
};
33 | |
34 | TEST_F(attr_test_t, TestFPMathMode) { |
35 | dnnl::primitive_attr attr; |
36 | ASSERT_EQ(attr.get_fpmath_mode(), fpmath_mode::strict); |
37 | |
38 | for (auto m : {fpmath_mode::strict, fpmath_mode::bf16, fpmath_mode::f16, |
39 | fpmath_mode::tf32, fpmath_mode::any}) { |
40 | attr.set_fpmath_mode(m); |
41 | ASSERT_EQ(m, attr.get_fpmath_mode()); |
42 | } |
43 | } |
44 | |
45 | TEST_F(attr_test_t, TestFPMathModeDefault) { |
46 | ASSERT_EQ(fpmath_mode::strict, get_default_fpmath_mode()); |
47 | |
48 | for (auto m : {fpmath_mode::strict, fpmath_mode::bf16, fpmath_mode::f16, |
49 | fpmath_mode::tf32, fpmath_mode::any}) { |
50 | set_default_fpmath_mode(m); |
51 | ASSERT_EQ(m, get_default_fpmath_mode()); |
52 | dnnl::primitive_attr attr; |
53 | ASSERT_EQ(m, attr.get_fpmath_mode()); |
54 | } |
55 | } |
56 | |
57 | TEST_F(attr_test_t, TestScratchpadMode) { |
58 | dnnl::primitive_attr attr; |
59 | for (auto m : {scratchpad_mode::library, scratchpad_mode::user}) { |
60 | attr.set_scratchpad_mode(m); |
61 | ASSERT_EQ(m, attr.get_scratchpad_mode()); |
62 | } |
63 | } |
64 | |
65 | TEST_F(attr_test_t, TestScratchpadModeEx) { |
66 | engine eng = get_test_engine(); |
67 | |
68 | const memory::dim N = 2, C = 2, W = 2; |
69 | |
70 | memory::desc data_md( |
71 | {N, C, W}, memory::data_type::f32, memory::format_tag::ncw); |
72 | |
73 | dnnl::primitive_attr attr; |
74 | for (auto m : {scratchpad_mode::library, scratchpad_mode::user}) { |
75 | attr.set_scratchpad_mode(m); |
76 | auto softmax_pd = softmax_forward::primitive_desc(eng, |
77 | prop_kind::forward_inference, algorithm::softmax_accurate, |
78 | data_md, data_md, 1, attr); |
79 | auto scratchpad_size = (long)softmax_pd.scratchpad_desc().get_size(); |
80 | auto mem_consumption |
81 | = (long)softmax_pd.query_s64(query::memory_consumption_s64); |
82 | |
83 | if (m == scratchpad_mode::library) { |
84 | ASSERT_EQ(scratchpad_size, 0L); |
85 | } else { |
86 | ASSERT_EQ(mem_consumption, 0L); |
87 | } |
88 | } |
89 | } |
90 | |
// Execute a softmax primitive while passing a scratchpad buffer explicitly
// via DNNL_ARG_SCRATCHPAD. This must work in both scratchpad modes (in
// library mode the scratchpad descriptor is simply zero-sized, so the extra
// argument is harmless).
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestScratchpadArg) {
    engine eng = get_test_engine();

    const memory::dim N = 2, C = 2, W = 2;

    memory::desc data_md(
            {N, C, W}, memory::data_type::f32, memory::format_tag::ncw);

    dnnl::primitive_attr attr;
    for (auto m : {scratchpad_mode::library, scratchpad_mode::user}) {
        attr.set_scratchpad_mode(m);
        auto softmax_pd = softmax_forward::primitive_desc(eng,
                prop_kind::forward_inference, algorithm::softmax_accurate,
                data_md, data_md, 1, attr);

        // Allocate all memory objects from the descriptors the primitive
        // itself reports, including the (possibly zero-sized) scratchpad.
        auto src = test::make_memory(softmax_pd.src_desc(), eng);
        auto dst = test::make_memory(softmax_pd.dst_desc(), eng);
        auto scratchpad = test::make_memory(softmax_pd.scratchpad_desc(), eng);

        // Only the source needs real data; dst is write-only.
        fill_data<float>(src.get_desc().get_size() / sizeof(float), src);

        stream s(eng);

        softmax_forward softmax_p(softmax_pd);
        softmax_p.execute(s,
                {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
                        {DNNL_ARG_SCRATCHPAD, scratchpad}});
        s.wait();
    }
}
121 | |
122 | TEST_F(attr_test_t, TestZeroPoints) { |
123 | dnnl::primitive_attr attr; |
124 | |
125 | const std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_DST}; |
126 | const std::vector<int> unsupported_args = {DNNL_ARG_BIAS, DNNL_ARG_DST_2, |
127 | DNNL_ARG_MEAN, DNNL_ARG_WORKSPACE, DNNL_ARG_SCRATCHPAD}; |
128 | |
129 | for (auto arg : supported_args) { |
130 | // single non-default zero_point for supported arg |
131 | attr.set_zero_points_mask(arg, 0); |
132 | } |
133 | |
134 | for (auto arg : unsupported_args) { |
135 | // single **default** zero_point for **unsupported** arg |
136 | EXPECT_ANY_THROW(attr.set_zero_points_mask(arg, 0)); |
137 | } |
138 | |
139 | // multiple zero_points not implemented yet ... |
140 | } |
141 | |
142 | TEST_F(attr_test_t, TestZeroPointsExpectFailure) { |
143 | dnnl::primitive_attr attr; |
144 | |
145 | const int unsupported_arg = DNNL_ARG_MEAN; |
146 | |
147 | // single non-default zero_point for unsupported arg |
148 | EXPECT_ANY_THROW(attr.set_zero_points_mask(unsupported_arg, 0)); |
149 | |
150 | // multiple zero points for unsupported args |
151 | EXPECT_ANY_THROW(attr.set_zero_points_mask(unsupported_arg, 1 << 1)); |
152 | } |
153 | |
154 | HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestScales) { |
155 | dnnl::primitive_attr attr; |
156 | |
157 | const std::vector<int> supported_args = {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, |
158 | DNNL_ARG_MULTIPLE_SRC, DNNL_ARG_MULTIPLE_SRC + 42}; |
159 | const std::vector<int> unsupported_args = {DNNL_ARG_BIAS, DNNL_ARG_DST_2, |
160 | DNNL_ARG_MEAN, DNNL_ARG_WORKSPACE, DNNL_ARG_SCRATCHPAD}; |
161 | |
162 | for (auto arg : supported_args) { |
163 | // single non-default scales for supported arg |
164 | attr.set_scales_mask(arg, 0); |
165 | // multiple scales |
166 | attr.set_scales_mask(arg, 1 << 1); |
167 | } |
168 | |
169 | for (auto arg : unsupported_args) { |
170 | // single scales for unsupported args |
171 | EXPECT_ANY_THROW(attr.set_scales_mask(arg, 0)); |
172 | } |
173 | } |
174 | |
175 | HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestRNNDataQuantization) { |
176 | dnnl::primitive_attr attr; |
177 | |
178 | float scale = NAN, shift = NAN; |
179 | |
180 | // default scale and shift |
181 | attr.get_rnn_data_qparams(scale, shift); |
182 | ASSERT_EQ(scale, 1.f); |
183 | ASSERT_EQ(shift, 0.f); |
184 | |
185 | // non-default |
186 | attr.set_rnn_data_qparams(0.5f, 2.f); |
187 | attr.get_rnn_data_qparams(scale, shift); |
188 | ASSERT_EQ(scale, 0.5f); |
189 | ASSERT_EQ(shift, 2.f); |
190 | } |
191 | |
192 | HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestRNNWeightsQuantization) { |
193 | dnnl::primitive_attr attr; |
194 | |
195 | int scales_mask = INT_MAX; |
196 | std::vector<float> scales; |
197 | |
198 | // default scale and shift |
199 | attr.get_rnn_weights_qparams(scales_mask, scales); |
200 | ASSERT_EQ(scales_mask, 0); |
201 | ASSERT_EQ(scales.size(), 1U); |
202 | ASSERT_EQ(scales[0], 1.f); |
203 | |
204 | // single non-default scales |
205 | attr.set_rnn_weights_qparams(0, {2.f}); |
206 | attr.get_rnn_weights_qparams(scales_mask, scales); |
207 | ASSERT_EQ(scales_mask, 0); |
208 | ASSERT_EQ(scales.size(), 1U); |
209 | ASSERT_EQ(scales[0], 2.f); |
210 | |
211 | // multiple scales |
212 | attr.set_rnn_weights_qparams(1 << 1, {1.f, 2.f, 4.f}); |
213 | attr.get_rnn_weights_qparams(scales_mask, scales); |
214 | ASSERT_EQ(scales_mask, 1 << 1); |
215 | ASSERT_EQ(scales.size(), 3U); |
216 | ASSERT_EQ(scales[0], 1.f); |
217 | ASSERT_EQ(scales[1], 2.f); |
218 | ASSERT_EQ(scales[2], 4.f); |
219 | } |
220 | |
221 | HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestRNNProjWeightsQuantization) { |
222 | dnnl::primitive_attr attr; |
223 | |
224 | int scales_mask = INT_MAX; |
225 | std::vector<float> scales; |
226 | |
227 | // default scale and shift |
228 | attr.get_rnn_weights_projection_qparams(scales_mask, scales); |
229 | ASSERT_EQ(scales_mask, 0); |
230 | ASSERT_EQ(scales.size(), 1U); |
231 | ASSERT_EQ(scales[0], 1.f); |
232 | |
233 | // single non-default scales |
234 | attr.set_rnn_weights_projection_qparams(0, {2.f}); |
235 | attr.get_rnn_weights_projection_qparams(scales_mask, scales); |
236 | ASSERT_EQ(scales_mask, 0); |
237 | ASSERT_EQ(scales.size(), 1U); |
238 | ASSERT_EQ(scales[0], 2.f); |
239 | |
240 | // multiple scales |
241 | attr.set_rnn_weights_projection_qparams(1 << 1, {1.f, 2.f, 4.f}); |
242 | attr.get_rnn_weights_projection_qparams(scales_mask, scales); |
243 | ASSERT_EQ(scales_mask, 1 << 1); |
244 | ASSERT_EQ(scales.size(), 3U); |
245 | ASSERT_EQ(scales[0], 1.f); |
246 | ASSERT_EQ(scales[1], 2.f); |
247 | ASSERT_EQ(scales[2], 4.f); |
248 | } |
249 | |
250 | TEST_F(attr_test_t, TestScalesExpectFailure) { |
251 | dnnl::primitive_attr attr; |
252 | const int unsupported_arg = DNNL_ARG_MEAN; |
253 | |
254 | // non-default scales for unsupported arg |
255 | EXPECT_ANY_THROW(attr.set_scales_mask(unsupported_arg, 0)); |
256 | EXPECT_ANY_THROW(attr.set_scales_mask(unsupported_arg, 1 << 1)); |
257 | } |
258 | |
// Append each post-op kind (sum, eltwise, binary, prelu) in turn and verify
// after every append that the attribute's post-op chain reports the right
// length, kinds at every index, and the exact parameters that were appended.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestPostOps) {
    dnnl::primitive_attr attr;
    dnnl::post_ops ops;

    // Out-params are pre-poisoned so a getter that silently fails to write
    // them would be caught by the assertions below.
    algorithm alg = algorithm::undef;
    float scale = NAN, alpha = NAN, beta = NAN;
    int32_t sum_zp = INT_MAX;
    data_type dt = data_type::undef;

    ASSERT_EQ(ops.len(), 0);
    ASSERT_EQ(attr.get_post_ops().len(), 0);

    // 1) sum post-op with non-default scale, zero-point, and data type.
    ops.append_sum(1.1f, 1, data_type::f32);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().len(), 1);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    attr.get_post_ops().get_params_sum(0, scale, sum_zp, dt);
    ASSERT_FLOAT_EQ(scale, 1.1f);
    ASSERT_EQ(1, sum_zp);
    ASSERT_EQ(data_type::f32, dt);

    // 2) eltwise post-op; earlier entries must be untouched.
    ops.append_eltwise(algorithm::eltwise_clip, 3.3f, 4.4f);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().len(), 2);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::eltwise);
    attr.get_post_ops().get_params_eltwise(1, alg, alpha, beta);
    ASSERT_EQ(alg, algorithm::eltwise_clip);
    ASSERT_FLOAT_EQ(alpha, 3.3f);
    ASSERT_FLOAT_EQ(beta, 4.4f);

    // 3) binary post-op; the src1 descriptor must round-trip exactly.
    memory::desc src1_md({1}, data_type::f32, {1});
    ops.append_binary(algorithm::binary_add, src1_md);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().len(), 3);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::eltwise);
    ASSERT_EQ(attr.get_post_ops().kind(2), primitive::kind::binary);
    memory::desc src1_md_out;
    attr.get_post_ops().get_params_binary(2, alg, src1_md_out);
    ASSERT_EQ(alg, algorithm::binary_add);
    ASSERT_EQ(src1_md, src1_md_out);

    // 4) prelu post-op; the mask must round-trip.
    const int prelu_mask = 1;
    ops.append_prelu(prelu_mask);
    attr.set_post_ops(ops);
    ASSERT_EQ(attr.get_post_ops().len(), 4);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::eltwise);
    ASSERT_EQ(attr.get_post_ops().kind(2), primitive::kind::binary);
    ASSERT_EQ(attr.get_post_ops().kind(3), primitive::kind::prelu);
    int mask = INT_MAX;
    attr.get_post_ops().get_params_prelu(3, mask);
    ASSERT_EQ(mask, prelu_mask);
}
317 | |
318 | TEST_F(attr_test_t, TestPostOpsCheckLimit) { |
319 | dnnl::post_ops ops_sum, ops_eltwise, ops_binary, ops_prelu; |
320 | |
321 | for (int i = 0; i < 32; i++) { |
322 | EXPECT_NO_THROW(ops_sum.append_sum(i + 1.f)); |
323 | EXPECT_NO_THROW(ops_eltwise.append_eltwise( |
324 | algorithm::eltwise_relu, 2 * i, 0.f)); |
325 | EXPECT_NO_THROW(ops_binary.append_binary(algorithm::binary_add, |
326 | memory::desc({i}, data_type::s8, memory::format_tag::a))); |
327 | EXPECT_NO_THROW(ops_prelu.append_prelu(1)); |
328 | } |
329 | |
330 | EXPECT_ANY_THROW(ops_prelu.append_prelu(1)); |
331 | EXPECT_ANY_THROW(ops_sum.append_sum(1.f)); |
332 | EXPECT_ANY_THROW( |
333 | ops_eltwise.append_eltwise(algorithm::eltwise_relu, 1.f, 0.f)); |
334 | EXPECT_ANY_THROW(ops_binary.append_binary(algorithm::binary_add, |
335 | memory::desc({1}, data_type::s8, memory::format_tag::a))); |
336 | } |
337 | |
// Append several depthwise-convolution post-ops with different data types
// and conv parameters, and verify after each append that every entry in the
// chain reports kind convolution and that get_params_dw() returns exactly
// what was appended at that index.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, DepthwiseFusionPostop) {
    dnnl::primitive_attr attr;
    dnnl::post_ops ops;

    // Out-params pre-poisoned so a getter that fails to write them is caught.
    data_type wei_dt = data_type::undef;
    data_type bias_dt = data_type::undef;
    data_type dst_dt = data_type::undef;
    memory::dim kernel = -1;
    memory::dim stride = -1;
    memory::dim padding = -1;

    ASSERT_EQ(ops.len(), 0);
    ASSERT_EQ(attr.get_post_ops().len(), 0);

    // dw #0: s8 weights / f32 bias / u8 dst, k=3 s=1 p=1.
    ops.append_dw(memory::data_type::s8, memory::data_type::f32,
            memory::data_type::u8, 3, 1, 1);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::convolution);
    attr.get_post_ops().get_params_dw(
            0, wei_dt, bias_dt, dst_dt, kernel, stride, padding);
    ASSERT_EQ(wei_dt, memory::data_type::s8);
    ASSERT_EQ(bias_dt, memory::data_type::f32);
    ASSERT_EQ(dst_dt, memory::data_type::u8);
    ASSERT_EQ(kernel, 3);
    ASSERT_EQ(stride, 1);
    ASSERT_EQ(padding, 1);

    // dw #1: u8 / s32 / f32, k=3 s=2 p=1. Reset the dims to catch a getter
    // that leaves them untouched.
    kernel = stride = padding = -1;
    ops.append_dw(memory::data_type::u8, memory::data_type::s32,
            memory::data_type::f32, 3, 2, 1);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::convolution);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::convolution);

    attr.get_post_ops().get_params_dw(
            1, wei_dt, bias_dt, dst_dt, kernel, stride, padding);

    ASSERT_EQ(wei_dt, memory::data_type::u8);
    ASSERT_EQ(bias_dt, memory::data_type::s32);
    ASSERT_EQ(dst_dt, memory::data_type::f32);
    ASSERT_EQ(kernel, 3);
    ASSERT_EQ(stride, 2);
    ASSERT_EQ(padding, 1);

    // dw #2: all-f32, k=7 s=3 p=2.
    kernel = stride = padding = -1;
    ops.append_dw(memory::data_type::f32, memory::data_type::f32,
            memory::data_type::f32, 7, 3, 2);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::convolution);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::convolution);
    ASSERT_EQ(attr.get_post_ops().kind(2), primitive::kind::convolution);

    attr.get_post_ops().get_params_dw(
            2, wei_dt, bias_dt, dst_dt, kernel, stride, padding);

    ASSERT_EQ(wei_dt, memory::data_type::f32);
    ASSERT_EQ(bias_dt, memory::data_type::f32);
    ASSERT_EQ(dst_dt, memory::data_type::f32);
    ASSERT_EQ(kernel, 7);
    ASSERT_EQ(stride, 3);
    ASSERT_EQ(padding, 2);

    // dw #3: s8 / f32 / u8 again but with k=5 s=2 p=1.
    kernel = stride = padding = -1;
    ops.append_dw(memory::data_type::s8, memory::data_type::f32,
            memory::data_type::u8, 5, 2, 1);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().kind(3), primitive::kind::convolution);

    attr.get_post_ops().get_params_dw(
            3, wei_dt, bias_dt, dst_dt, kernel, stride, padding);

    ASSERT_EQ(wei_dt, memory::data_type::s8);
    ASSERT_EQ(bias_dt, memory::data_type::f32);
    ASSERT_EQ(dst_dt, memory::data_type::u8);
    ASSERT_EQ(kernel, 5);
    ASSERT_EQ(stride, 2);
    ASSERT_EQ(padding, 1);
}
420 | |
// Build a 1x1 convolution with a fused depthwise post-op and check that a
// fused (non-reference) implementation is actually dispatched: the fused
// impl name must match the unfused jitted impl name.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, DepthwiseFusion) {

    auto engine_kind = get_test_engine_kind();
    SKIP_IF(engine_kind != engine::kind::cpu,
            "Depthwise fusion is only supported on CPU engine" );
#if DNNL_AARCH64
    SKIP_IF(true, "Depthwise fusion is not supported on AArch64 at this time" );
#endif

    engine e {engine_kind, 0};

    // bf16 is only added when the current CPU supports it.
    std::vector<memory::data_type> test_dts {
            memory::data_type::f32, memory::data_type::s8};

    if (!unsupported_data_type(memory::data_type::bf16))
        test_dts.push_back(memory::data_type::bf16);

    for (auto dt : test_dts) {

        memory::desc dat_md {{1024, 512, 64, 64}, dt, memory::format_tag::any};
        memory::desc wht_md {{512, 512, 1, 1}, dt, memory::format_tag::any};

        std::string impl_info_unfused;

        // Baseline: same conv without any post-ops.
        auto pd = convolution_forward::primitive_desc(e,
                prop_kind::forward_inference, algorithm::convolution_auto,
                dat_md, wht_md, dat_md, {1, 1}, {0, 0}, {0, 0});

        ASSERT_NO_THROW(impl_info_unfused = pd.impl_info_str(););

        // skip if above unfused impl is not jitted.
        if (impl_info_unfused.compare(0, 3, "jit" ) != 0) continue;

        // skip if above unfused impl is jitted amx.
        if (impl_info_unfused.find("amx" ) != std::string::npos) continue;

        dnnl::primitive_attr attr;
        dnnl::post_ops ops;
        ops.append_dw(dt, dt, dt, 3, 1, 1);
        attr.set_post_ops(ops);

        std::string impl_info_fused;

        pd = convolution_forward::primitive_desc(e,
                prop_kind::forward_inference, algorithm::convolution_auto,
                dat_md, wht_md, dat_md, {1, 1}, {0, 0}, {0, 0}, attr);
        ASSERT_NO_THROW(impl_info_fused = pd.impl_info_str(););

        // Make sure ref fused impl is not deployed.
        // NOTE: When out_of_memory testing enabled, all implementations that
        // construct primitive attributes will fail, hence the ref
        // implementation is deployed.
        if (!test_out_of_memory()) {
            ASSERT_EQ(impl_info_fused, impl_info_unfused);
        }
    }
}
478 | |
// Inner product with explicitly blocked weights formats must not fall back
// to the reference implementation on avx512_core-capable CPUs.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, InnerProdBlockedWeights) {
    auto engine_kind = get_test_engine_kind();
    // The ISA check below only compiles on x64 builds with a CPU runtime,
    // hence the two-stage skip computation.
    bool skip_test = !DNNL_X64 || (DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE)
            || (engine_kind != engine::kind::cpu);
#if DNNL_X64 && (DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE)
    skip_test = skip_test || !dnnl::mayiuse(cpu_isa::avx512_core);
#endif
    SKIP_IF(skip_test,
            "Inner product blocked weights test is supported only on "
            "avx512_core CPU" );

    engine e {engine_kind, 0};

    std::vector<memory::format_tag> blocked_weights_tags {
            memory::format_tag::OIhw16i64o, memory::format_tag::OIhw16i32o,
            memory::format_tag::OIhw16i16o};

    for (const auto &weights_tag : blocked_weights_tags) {
        memory::desc src_md {{1024, 512, 1, 1}, memory::data_type::f32,
                memory::format_tag::any};
        memory::desc wei_md {
                {256, 512, 1, 1}, memory::data_type::f32, weights_tag};
        memory::desc bia_md {
                {256}, memory::data_type::f32, memory::format_tag::any};
        memory::desc dst_md {
                {1024, 256}, memory::data_type::f32, memory::format_tag::any};

        auto pd = inner_product_forward::primitive_desc(
                e, prop_kind::forward_training, src_md, wei_md, bia_md, dst_md);

        // A reference fallback would report "ref:any".
        std::string impl_info;
        ASSERT_NO_THROW(impl_info = pd.impl_info_str(););
        ASSERT_NE(impl_info, "ref:any" );
    }
}
514 | |
// Attributes set on a primitive descriptor (scales, dw post-op scales) must
// be retrievable back via get_primitive_attr() for binary and convolution
// primitives, including the depthwise fused case.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestGetAttr) {
    auto engine_kind = get_test_engine_kind();
    SKIP_IF(engine_kind != engine::kind::cpu,
            "Depthwise fusion is only supported on CPU engine" );

    engine eng {engine_kind, 0};

    auto dt = memory::data_type::s8;
    // attr_s: src scales; attr_os: dst scales; attr_dw: per-channel scales
    // on the depthwise post-op weights plus the dw post-op itself.
    dnnl::primitive_attr attr_s, attr_os, attr_dw;
    dnnl::post_ops ops;
    ops.append_dw(dt, dt, dt, 3, 1, 1);
    attr_s.set_scales_mask(DNNL_ARG_SRC_0, 0);
    attr_os.set_scales_mask(DNNL_ARG_DST, 0);
    attr_dw.set_scales_mask(
            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS, 1 << 1);
    attr_dw.set_post_ops(ops);

    memory::desc dat_md {{512, 512, 3, 3}, dt, memory::format_tag::nchw};
    memory::desc wht_md {{512, 512, 1, 1}, dt, memory::format_tag::nchw};
    auto bin_pd = binary::primitive_desc(
            eng, algorithm::binary_add, wht_md, wht_md, wht_md, attr_s);

    auto cd_pd_os = convolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::convolution_auto, dat_md,
            wht_md, dat_md, {1, 1}, {0, 0}, {0, 0}, attr_os);
    auto cd_pd_dw = convolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::convolution_auto, dat_md,
            wht_md, dat_md, {1, 1}, {0, 0}, {0, 0}, attr_dw);
    // Under out-of-memory testing the getters may legitimately throw, so
    // only wrap them in ASSERT_NO_THROW in the normal mode.
    if (test_out_of_memory()) {
        attr_s = bin_pd.get_primitive_attr();
        attr_os = cd_pd_os.get_primitive_attr();
        attr_dw = cd_pd_dw.get_primitive_attr();
    } else {
        ASSERT_NO_THROW(attr_s = bin_pd.get_primitive_attr());
        ASSERT_NO_THROW(attr_os = cd_pd_os.get_primitive_attr());
        ASSERT_NO_THROW(attr_dw = cd_pd_dw.get_primitive_attr());
    }
}
553 | |
// Regression test: post-ops obtained via a temporary attribute object must
// remain valid after that temporary is destroyed (i.e., they are cloned, not
// aliased).
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestGetCppObjects) {
    SKIP_IF_CUDA(true, "Binary post-op is not supported for CUDA" );
    SKIP_IF_HIP(true, "Binary post-op is not supported for HIP" );

    auto engine_kind = get_test_engine_kind();
    engine eng {engine_kind, 0};

    // Post-ops is the only object that is returned from primitive attr, rest
    // calls are of `void` type. Lack of "cloning" for post-ops led to a problem
    // of using a dangling pointer from destroyed object via
    // `pd.get_primitive_attr().get_post_ops()` construction as attributes will
    // be destroyed once post-ops are saved on stack.
    // See https://github.com/oneapi-src/oneDNN/issues/1337 for details.
    dnnl::primitive_attr attr;
    dnnl::post_ops ops;
    memory::desc po_src1_md({1, 1, 1, 1}, data_type::f32, tag::abcd);
    ops.append_binary(algorithm::binary_add, po_src1_md);
    attr.set_post_ops(ops);

    memory::desc md {{512, 512, 3, 3}, data_type::f32, tag::abcd};
    auto bin_pd = binary::primitive_desc(
            eng, algorithm::binary_add, md, md, md, attr);

    // `get_primitive_attr()` returns a temporary; `q_po` must stay usable
    // after it dies.
    const auto q_po = bin_pd.get_primitive_attr().get_post_ops();
    ASSERT_EQ(q_po.len(), 1);
    ASSERT_EQ(q_po.kind(0), primitive::kind::binary);

    // The queried binary post-op must reproduce the original src1 descriptor.
    algorithm q_alg;
    memory::desc q_po_src1_md;
    q_po.get_params_binary(0, q_alg, q_po_src1_md);
    ASSERT_EQ(q_alg, algorithm::binary_add);
    ASSERT_EQ(q_po_src1_md, po_src1_md);
}
587 | |
588 | } // namespace dnnl |
589 | |