1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3* Copyright 2020-2021 FUJITSU LIMITED
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#include "dnnl_test_common.hpp"
19#include "gtest/gtest.h"
20
21#include "oneapi/dnnl/dnnl.hpp"
22#include "tests/test_isa_common.hpp"
23
24namespace dnnl {
25
// Short aliases for the memory enums used throughout these tests.
using data_type = memory::data_type;
using tag = memory::format_tag;
28
// Common fixture for all primitive-attribute tests. No shared state is
// required, so SetUp() is intentionally a no-op.
class attr_test_t : public ::testing::Test {
protected:
    void SetUp() override {}
};
33
34TEST_F(attr_test_t, TestFPMathMode) {
35 dnnl::primitive_attr attr;
36 ASSERT_EQ(attr.get_fpmath_mode(), fpmath_mode::strict);
37
38 for (auto m : {fpmath_mode::strict, fpmath_mode::bf16, fpmath_mode::f16,
39 fpmath_mode::tf32, fpmath_mode::any}) {
40 attr.set_fpmath_mode(m);
41 ASSERT_EQ(m, attr.get_fpmath_mode());
42 }
43}
44
45TEST_F(attr_test_t, TestFPMathModeDefault) {
46 ASSERT_EQ(fpmath_mode::strict, get_default_fpmath_mode());
47
48 for (auto m : {fpmath_mode::strict, fpmath_mode::bf16, fpmath_mode::f16,
49 fpmath_mode::tf32, fpmath_mode::any}) {
50 set_default_fpmath_mode(m);
51 ASSERT_EQ(m, get_default_fpmath_mode());
52 dnnl::primitive_attr attr;
53 ASSERT_EQ(m, attr.get_fpmath_mode());
54 }
55}
56
57TEST_F(attr_test_t, TestScratchpadMode) {
58 dnnl::primitive_attr attr;
59 for (auto m : {scratchpad_mode::library, scratchpad_mode::user}) {
60 attr.set_scratchpad_mode(m);
61 ASSERT_EQ(m, attr.get_scratchpad_mode());
62 }
63}
64
65TEST_F(attr_test_t, TestScratchpadModeEx) {
66 engine eng = get_test_engine();
67
68 const memory::dim N = 2, C = 2, W = 2;
69
70 memory::desc data_md(
71 {N, C, W}, memory::data_type::f32, memory::format_tag::ncw);
72
73 dnnl::primitive_attr attr;
74 for (auto m : {scratchpad_mode::library, scratchpad_mode::user}) {
75 attr.set_scratchpad_mode(m);
76 auto softmax_pd = softmax_forward::primitive_desc(eng,
77 prop_kind::forward_inference, algorithm::softmax_accurate,
78 data_md, data_md, 1, attr);
79 auto scratchpad_size = (long)softmax_pd.scratchpad_desc().get_size();
80 auto mem_consumption
81 = (long)softmax_pd.query_s64(query::memory_consumption_s64);
82
83 if (m == scratchpad_mode::library) {
84 ASSERT_EQ(scratchpad_size, 0L);
85 } else {
86 ASSERT_EQ(mem_consumption, 0L);
87 }
88 }
89}
90
// Run a softmax primitive end-to-end while always passing a
// DNNL_ARG_SCRATCHPAD memory, in both scratchpad modes, to check the
// execute() call accepts the argument either way.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestScratchpadArg) {
    engine eng = get_test_engine();

    const memory::dim N = 2, C = 2, W = 2;

    memory::desc data_md(
            {N, C, W}, memory::data_type::f32, memory::format_tag::ncw);

    dnnl::primitive_attr attr;
    for (auto m : {scratchpad_mode::library, scratchpad_mode::user}) {
        attr.set_scratchpad_mode(m);
        auto softmax_pd = softmax_forward::primitive_desc(eng,
                prop_kind::forward_inference, algorithm::softmax_accurate,
                data_md, data_md, 1, attr);

        auto src = test::make_memory(softmax_pd.src_desc(), eng);
        auto dst = test::make_memory(softmax_pd.dst_desc(), eng);
        // In library mode the scratchpad desc is empty, so this memory is
        // effectively zero-sized; in user mode it carries the needed size.
        auto scratchpad = test::make_memory(softmax_pd.scratchpad_desc(), eng);

        fill_data<float>(src.get_desc().get_size() / sizeof(float), src);

        stream s(eng);

        softmax_forward softmax_p(softmax_pd);
        softmax_p.execute(s,
                {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
                        {DNNL_ARG_SCRATCHPAD, scratchpad}});
        s.wait();
    }
}
121
122TEST_F(attr_test_t, TestZeroPoints) {
123 dnnl::primitive_attr attr;
124
125 const std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_DST};
126 const std::vector<int> unsupported_args = {DNNL_ARG_BIAS, DNNL_ARG_DST_2,
127 DNNL_ARG_MEAN, DNNL_ARG_WORKSPACE, DNNL_ARG_SCRATCHPAD};
128
129 for (auto arg : supported_args) {
130 // single non-default zero_point for supported arg
131 attr.set_zero_points_mask(arg, 0);
132 }
133
134 for (auto arg : unsupported_args) {
135 // single **default** zero_point for **unsupported** arg
136 EXPECT_ANY_THROW(attr.set_zero_points_mask(arg, 0));
137 }
138
139 // multiple zero_points not implemented yet ...
140}
141
142TEST_F(attr_test_t, TestZeroPointsExpectFailure) {
143 dnnl::primitive_attr attr;
144
145 const int unsupported_arg = DNNL_ARG_MEAN;
146
147 // single non-default zero_point for unsupported arg
148 EXPECT_ANY_THROW(attr.set_zero_points_mask(unsupported_arg, 0));
149
150 // multiple zero points for unsupported args
151 EXPECT_ANY_THROW(attr.set_zero_points_mask(unsupported_arg, 1 << 1));
152}
153
154HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestScales) {
155 dnnl::primitive_attr attr;
156
157 const std::vector<int> supported_args = {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1,
158 DNNL_ARG_MULTIPLE_SRC, DNNL_ARG_MULTIPLE_SRC + 42};
159 const std::vector<int> unsupported_args = {DNNL_ARG_BIAS, DNNL_ARG_DST_2,
160 DNNL_ARG_MEAN, DNNL_ARG_WORKSPACE, DNNL_ARG_SCRATCHPAD};
161
162 for (auto arg : supported_args) {
163 // single non-default scales for supported arg
164 attr.set_scales_mask(arg, 0);
165 // multiple scales
166 attr.set_scales_mask(arg, 1 << 1);
167 }
168
169 for (auto arg : unsupported_args) {
170 // single scales for unsupported args
171 EXPECT_ANY_THROW(attr.set_scales_mask(arg, 0));
172 }
173}
174
// Check the default and a user-set pair of RNN data quantization
// parameters (scale, shift) round-trip through the attribute.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestRNNDataQuantization) {
    dnnl::primitive_attr attr;

    // Sentinel values: if the getter failed to write, the asserts would trip.
    float scale = NAN, shift = NAN;

    // default scale and shift
    attr.get_rnn_data_qparams(scale, shift);
    ASSERT_EQ(scale, 1.f);
    ASSERT_EQ(shift, 0.f);

    // non-default
    attr.set_rnn_data_qparams(0.5f, 2.f);
    attr.get_rnn_data_qparams(scale, shift);
    ASSERT_EQ(scale, 0.5f);
    ASSERT_EQ(shift, 2.f);
}
191
192HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestRNNWeightsQuantization) {
193 dnnl::primitive_attr attr;
194
195 int scales_mask = INT_MAX;
196 std::vector<float> scales;
197
198 // default scale and shift
199 attr.get_rnn_weights_qparams(scales_mask, scales);
200 ASSERT_EQ(scales_mask, 0);
201 ASSERT_EQ(scales.size(), 1U);
202 ASSERT_EQ(scales[0], 1.f);
203
204 // single non-default scales
205 attr.set_rnn_weights_qparams(0, {2.f});
206 attr.get_rnn_weights_qparams(scales_mask, scales);
207 ASSERT_EQ(scales_mask, 0);
208 ASSERT_EQ(scales.size(), 1U);
209 ASSERT_EQ(scales[0], 2.f);
210
211 // multiple scales
212 attr.set_rnn_weights_qparams(1 << 1, {1.f, 2.f, 4.f});
213 attr.get_rnn_weights_qparams(scales_mask, scales);
214 ASSERT_EQ(scales_mask, 1 << 1);
215 ASSERT_EQ(scales.size(), 3U);
216 ASSERT_EQ(scales[0], 1.f);
217 ASSERT_EQ(scales[1], 2.f);
218 ASSERT_EQ(scales[2], 4.f);
219}
220
// Same coverage as TestRNNWeightsQuantization but for the projection
// weights quantization parameters.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestRNNProjWeightsQuantization) {
    dnnl::primitive_attr attr;

    // Sentinel mask value so a getter that fails to write is caught.
    int scales_mask = INT_MAX;
    std::vector<float> scales;

    // default scale and shift
    attr.get_rnn_weights_projection_qparams(scales_mask, scales);
    ASSERT_EQ(scales_mask, 0);
    ASSERT_EQ(scales.size(), 1U);
    ASSERT_EQ(scales[0], 1.f);

    // single non-default scales
    attr.set_rnn_weights_projection_qparams(0, {2.f});
    attr.get_rnn_weights_projection_qparams(scales_mask, scales);
    ASSERT_EQ(scales_mask, 0);
    ASSERT_EQ(scales.size(), 1U);
    ASSERT_EQ(scales[0], 2.f);

    // multiple scales
    attr.set_rnn_weights_projection_qparams(1 << 1, {1.f, 2.f, 4.f});
    attr.get_rnn_weights_projection_qparams(scales_mask, scales);
    ASSERT_EQ(scales_mask, 1 << 1);
    ASSERT_EQ(scales.size(), 3U);
    ASSERT_EQ(scales[0], 1.f);
    ASSERT_EQ(scales[1], 2.f);
    ASSERT_EQ(scales[2], 4.f);
}
249
250TEST_F(attr_test_t, TestScalesExpectFailure) {
251 dnnl::primitive_attr attr;
252 const int unsupported_arg = DNNL_ARG_MEAN;
253
254 // non-default scales for unsupported arg
255 EXPECT_ANY_THROW(attr.set_scales_mask(unsupported_arg, 0));
256 EXPECT_ANY_THROW(attr.set_scales_mask(unsupported_arg, 1 << 1));
257}
258
// Build a post-ops chain (sum -> eltwise -> binary -> prelu) one entry at a
// time and verify each appended entry can be queried back intact from the
// attribute after set_post_ops().
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestPostOps) {
    dnnl::primitive_attr attr;
    dnnl::post_ops ops;

    // Sentinel values: the getters below must overwrite all of these.
    algorithm alg = algorithm::undef;
    float scale = NAN, alpha = NAN, beta = NAN;
    int32_t sum_zp = INT_MAX;
    data_type dt = data_type::undef;

    ASSERT_EQ(ops.len(), 0);
    ASSERT_EQ(attr.get_post_ops().len(), 0);

    // 1) sum post-op with a scale, zero point, and explicit data type.
    ops.append_sum(1.1f, 1, data_type::f32);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().len(), 1);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    attr.get_post_ops().get_params_sum(0, scale, sum_zp, dt);
    ASSERT_FLOAT_EQ(scale, 1.1f);
    ASSERT_EQ(1, sum_zp);
    ASSERT_EQ(data_type::f32, dt);

    // 2) eltwise (clip) post-op with alpha/beta bounds.
    ops.append_eltwise(algorithm::eltwise_clip, 3.3f, 4.4f);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().len(), 2);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::eltwise);
    attr.get_post_ops().get_params_eltwise(1, alg, alpha, beta);
    ASSERT_EQ(alg, algorithm::eltwise_clip);
    ASSERT_FLOAT_EQ(alpha, 3.3f);
    ASSERT_FLOAT_EQ(beta, 4.4f);

    // 3) binary (add) post-op; the second source descriptor must round-trip.
    memory::desc src1_md({1}, data_type::f32, {1});
    ops.append_binary(algorithm::binary_add, src1_md);
    attr.set_post_ops(ops);

    ASSERT_EQ(attr.get_post_ops().len(), 3);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::eltwise);
    ASSERT_EQ(attr.get_post_ops().kind(2), primitive::kind::binary);
    memory::desc src1_md_out;
    attr.get_post_ops().get_params_binary(2, alg, src1_md_out);
    ASSERT_EQ(alg, algorithm::binary_add);
    ASSERT_EQ(src1_md, src1_md_out);

    // 4) prelu post-op; only the mask is stored and queried back.
    const int prelu_mask = 1;
    ops.append_prelu(prelu_mask);
    attr.set_post_ops(ops);
    ASSERT_EQ(attr.get_post_ops().len(), 4);
    ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::sum);
    ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::eltwise);
    ASSERT_EQ(attr.get_post_ops().kind(2), primitive::kind::binary);
    ASSERT_EQ(attr.get_post_ops().kind(3), primitive::kind::prelu);
    int mask = INT_MAX;
    attr.get_post_ops().get_params_prelu(3, mask);
    ASSERT_EQ(mask, prelu_mask);
}
317
318TEST_F(attr_test_t, TestPostOpsCheckLimit) {
319 dnnl::post_ops ops_sum, ops_eltwise, ops_binary, ops_prelu;
320
321 for (int i = 0; i < 32; i++) {
322 EXPECT_NO_THROW(ops_sum.append_sum(i + 1.f));
323 EXPECT_NO_THROW(ops_eltwise.append_eltwise(
324 algorithm::eltwise_relu, 2 * i, 0.f));
325 EXPECT_NO_THROW(ops_binary.append_binary(algorithm::binary_add,
326 memory::desc({i}, data_type::s8, memory::format_tag::a)));
327 EXPECT_NO_THROW(ops_prelu.append_prelu(1));
328 }
329
330 EXPECT_ANY_THROW(ops_prelu.append_prelu(1));
331 EXPECT_ANY_THROW(ops_sum.append_sum(1.f));
332 EXPECT_ANY_THROW(
333 ops_eltwise.append_eltwise(algorithm::eltwise_relu, 1.f, 0.f));
334 EXPECT_ANY_THROW(ops_binary.append_binary(algorithm::binary_add,
335 memory::desc({1}, data_type::s8, memory::format_tag::a)));
336}
337
338HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, DepthwiseFusionPostop) {
339 dnnl::primitive_attr attr;
340 dnnl::post_ops ops;
341
342 data_type wei_dt = data_type::undef;
343 data_type bias_dt = data_type::undef;
344 data_type dst_dt = data_type::undef;
345 memory::dim kernel = -1;
346 memory::dim stride = -1;
347 memory::dim padding = -1;
348
349 ASSERT_EQ(ops.len(), 0);
350 ASSERT_EQ(attr.get_post_ops().len(), 0);
351
352 ops.append_dw(memory::data_type::s8, memory::data_type::f32,
353 memory::data_type::u8, 3, 1, 1);
354 attr.set_post_ops(ops);
355
356 ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::convolution);
357 attr.get_post_ops().get_params_dw(
358 0, wei_dt, bias_dt, dst_dt, kernel, stride, padding);
359 ASSERT_EQ(wei_dt, memory::data_type::s8);
360 ASSERT_EQ(bias_dt, memory::data_type::f32);
361 ASSERT_EQ(dst_dt, memory::data_type::u8);
362 ASSERT_EQ(kernel, 3);
363 ASSERT_EQ(stride, 1);
364 ASSERT_EQ(padding, 1);
365
366 kernel = stride = padding = -1;
367 ops.append_dw(memory::data_type::u8, memory::data_type::s32,
368 memory::data_type::f32, 3, 2, 1);
369 attr.set_post_ops(ops);
370
371 ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::convolution);
372 ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::convolution);
373
374 attr.get_post_ops().get_params_dw(
375 1, wei_dt, bias_dt, dst_dt, kernel, stride, padding);
376
377 ASSERT_EQ(wei_dt, memory::data_type::u8);
378 ASSERT_EQ(bias_dt, memory::data_type::s32);
379 ASSERT_EQ(dst_dt, memory::data_type::f32);
380 ASSERT_EQ(kernel, 3);
381 ASSERT_EQ(stride, 2);
382 ASSERT_EQ(padding, 1);
383
384 kernel = stride = padding = -1;
385 ops.append_dw(memory::data_type::f32, memory::data_type::f32,
386 memory::data_type::f32, 7, 3, 2);
387 attr.set_post_ops(ops);
388
389 ASSERT_EQ(attr.get_post_ops().kind(0), primitive::kind::convolution);
390 ASSERT_EQ(attr.get_post_ops().kind(1), primitive::kind::convolution);
391 ASSERT_EQ(attr.get_post_ops().kind(2), primitive::kind::convolution);
392
393 attr.get_post_ops().get_params_dw(
394 2, wei_dt, bias_dt, dst_dt, kernel, stride, padding);
395
396 ASSERT_EQ(wei_dt, memory::data_type::f32);
397 ASSERT_EQ(bias_dt, memory::data_type::f32);
398 ASSERT_EQ(dst_dt, memory::data_type::f32);
399 ASSERT_EQ(kernel, 7);
400 ASSERT_EQ(stride, 3);
401 ASSERT_EQ(padding, 2);
402
403 kernel = stride = padding = -1;
404 ops.append_dw(memory::data_type::s8, memory::data_type::f32,
405 memory::data_type::u8, 5, 2, 1);
406 attr.set_post_ops(ops);
407
408 ASSERT_EQ(attr.get_post_ops().kind(3), primitive::kind::convolution);
409
410 attr.get_post_ops().get_params_dw(
411 3, wei_dt, bias_dt, dst_dt, kernel, stride, padding);
412
413 ASSERT_EQ(wei_dt, memory::data_type::s8);
414 ASSERT_EQ(bias_dt, memory::data_type::f32);
415 ASSERT_EQ(dst_dt, memory::data_type::u8);
416 ASSERT_EQ(kernel, 5);
417 ASSERT_EQ(stride, 2);
418 ASSERT_EQ(padding, 1);
419}
420
// Check that a depthwise-conv post-op on a jitted 1x1 convolution is handled
// by a fused (non-ref) implementation: the fused primitive descriptor must
// report the same impl string as the unfused one.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, DepthwiseFusion) {

    auto engine_kind = get_test_engine_kind();
    SKIP_IF(engine_kind != engine::kind::cpu,
            "Depthwise fusion is only supported on CPU engine");
#if DNNL_AARCH64
    SKIP_IF(true, "Depthwise fusion is not supported on AArch64 at this time");
#endif

    engine e {engine_kind, 0};

    std::vector<memory::data_type> test_dts {
            memory::data_type::f32, memory::data_type::s8};

    // bf16 is covered only when the current build/hardware supports it.
    if (!unsupported_data_type(memory::data_type::bf16))
        test_dts.push_back(memory::data_type::bf16);

    for (auto dt : test_dts) {

        memory::desc dat_md {{1024, 512, 64, 64}, dt, memory::format_tag::any};
        memory::desc wht_md {{512, 512, 1, 1}, dt, memory::format_tag::any};

        std::string impl_info_unfused;

        // Baseline: the same convolution without any post-ops.
        auto pd = convolution_forward::primitive_desc(e,
                prop_kind::forward_inference, algorithm::convolution_auto,
                dat_md, wht_md, dat_md, {1, 1}, {0, 0}, {0, 0});

        ASSERT_NO_THROW(impl_info_unfused = pd.impl_info_str(););

        // skip if above unfused impl is not jitted.
        if (impl_info_unfused.compare(0, 3, "jit") != 0) continue;

        // skip if above unfused impl is jitted amx.
        if (impl_info_unfused.find("amx") != std::string::npos) continue;

        dnnl::primitive_attr attr;
        dnnl::post_ops ops;
        ops.append_dw(dt, dt, dt, 3, 1, 1);
        attr.set_post_ops(ops);

        std::string impl_info_fused;

        pd = convolution_forward::primitive_desc(e,
                prop_kind::forward_inference, algorithm::convolution_auto,
                dat_md, wht_md, dat_md, {1, 1}, {0, 0}, {0, 0}, attr);
        ASSERT_NO_THROW(impl_info_fused = pd.impl_info_str(););

        // Make sure ref fused impl is not deployed.
        // NOTE: When out_of_memory testing enabled, all implementations that
        // construct primitive attributes will fail, hence the ref
        // implementation is deployed.
        if (!test_out_of_memory()) {
            ASSERT_EQ(impl_info_fused, impl_info_unfused);
        }
    }
}
478
// Inner product with explicitly blocked weights formats must be serviced by
// a non-reference implementation on avx512_core-capable CPUs.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, InnerProdBlockedWeights) {
    auto engine_kind = get_test_engine_kind();
    bool skip_test = !DNNL_X64 || (DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE)
            || (engine_kind != engine::kind::cpu);
#if DNNL_X64 && (DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE)
    // mayiuse() is only available when the x64 CPU runtime is built in.
    skip_test = skip_test || !dnnl::mayiuse(cpu_isa::avx512_core);
#endif
    SKIP_IF(skip_test,
            "Inner product blocked weights test is supported only on "
            "avx512_core CPU");

    engine e {engine_kind, 0};

    std::vector<memory::format_tag> blocked_weights_tags {
            memory::format_tag::OIhw16i64o, memory::format_tag::OIhw16i32o,
            memory::format_tag::OIhw16i16o};

    for (const auto &weights_tag : blocked_weights_tags) {
        memory::desc src_md {{1024, 512, 1, 1}, memory::data_type::f32,
                memory::format_tag::any};
        memory::desc wei_md {
                {256, 512, 1, 1}, memory::data_type::f32, weights_tag};
        memory::desc bia_md {
                {256}, memory::data_type::f32, memory::format_tag::any};
        memory::desc dst_md {
                {1024, 256}, memory::data_type::f32, memory::format_tag::any};

        auto pd = inner_product_forward::primitive_desc(
                e, prop_kind::forward_training, src_md, wei_md, bia_md, dst_md);

        // The reference implementation would indicate the blocked format
        // was not handled by an optimized kernel.
        std::string impl_info;
        ASSERT_NO_THROW(impl_info = pd.impl_info_str(););
        ASSERT_NE(impl_info, "ref:any");
    }
}
514
// Verify get_primitive_attr() can be queried back from primitive
// descriptors created with various attributes (scales on src/dst, and
// scales on a depthwise post-op's weights) without throwing.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestGetAttr) {
    auto engine_kind = get_test_engine_kind();
    SKIP_IF(engine_kind != engine::kind::cpu,
            "Depthwise fusion is only supported on CPU engine");

    engine eng {engine_kind, 0};

    auto dt = memory::data_type::s8;
    dnnl::primitive_attr attr_s, attr_os, attr_dw;
    dnnl::post_ops ops;
    ops.append_dw(dt, dt, dt, 3, 1, 1);
    attr_s.set_scales_mask(DNNL_ARG_SRC_0, 0);
    attr_os.set_scales_mask(DNNL_ARG_DST, 0);
    // Scales attached to the weights of the fused depthwise post-op.
    attr_dw.set_scales_mask(
            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS, 1 << 1);
    attr_dw.set_post_ops(ops);

    memory::desc dat_md {{512, 512, 3, 3}, dt, memory::format_tag::nchw};
    memory::desc wht_md {{512, 512, 1, 1}, dt, memory::format_tag::nchw};
    auto bin_pd = binary::primitive_desc(
            eng, algorithm::binary_add, wht_md, wht_md, wht_md, attr_s);

    auto cd_pd_os = convolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::convolution_auto, dat_md,
            wht_md, dat_md, {1, 1}, {0, 0}, {0, 0}, attr_os);
    auto cd_pd_dw = convolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::convolution_auto, dat_md,
            wht_md, dat_md, {1, 1}, {0, 0}, {0, 0}, attr_dw);
    // Under out-of-memory testing the queries may legitimately throw, so
    // only wrap them in ASSERT_NO_THROW in the normal mode.
    if (test_out_of_memory()) {
        attr_s = bin_pd.get_primitive_attr();
        attr_os = cd_pd_os.get_primitive_attr();
        attr_dw = cd_pd_dw.get_primitive_attr();
    } else {
        ASSERT_NO_THROW(attr_s = bin_pd.get_primitive_attr());
        ASSERT_NO_THROW(attr_os = cd_pd_os.get_primitive_attr());
        ASSERT_NO_THROW(attr_dw = cd_pd_dw.get_primitive_attr());
    }
}
553
// Regression test: querying post-ops through a temporary attribute object
// must yield an independent copy, not a dangling reference.
HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestGetCppObjects) {
    SKIP_IF_CUDA(true, "Binary post-op is not supported for CUDA");
    SKIP_IF_HIP(true, "Binary post-op is not supported for HIP");

    auto engine_kind = get_test_engine_kind();
    engine eng {engine_kind, 0};

    // Post-ops is the only object that is returned from primitive attr, rest
    // calls are of `void` type. Lack of "cloning" for post-ops led to a problem
    // of using a dangling pointer from destroyed object via
    // `pd.get_primitive_attr().get_post_ops()` construction as attributes will
    // be destroyed once post-ops are saved on stack.
    // See https://github.com/oneapi-src/oneDNN/issues/1337 for details.
    dnnl::primitive_attr attr;
    dnnl::post_ops ops;
    memory::desc po_src1_md({1, 1, 1, 1}, data_type::f32, tag::abcd);
    ops.append_binary(algorithm::binary_add, po_src1_md);
    attr.set_post_ops(ops);

    memory::desc md {{512, 512, 3, 3}, data_type::f32, tag::abcd};
    auto bin_pd = binary::primitive_desc(
            eng, algorithm::binary_add, md, md, md, attr);

    // `q_po` must stay valid after the temporary attribute is destroyed.
    const auto q_po = bin_pd.get_primitive_attr().get_post_ops();
    ASSERT_EQ(q_po.len(), 1);
    ASSERT_EQ(q_po.kind(0), primitive::kind::binary);

    algorithm q_alg;
    memory::desc q_po_src1_md;
    q_po.get_params_binary(0, q_alg, q_po_src1_md);
    ASSERT_EQ(q_alg, algorithm::binary_add);
    ASSERT_EQ(q_po_src1_md, po_src1_md);
}
587
588} // namespace dnnl
589