1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | #include "oneapi/dnnl/dnnl_types.h" |
22 | |
23 | namespace dnnl { |
24 | |
25 | const dnnl_status_t ok = dnnl_success; |
26 | |
27 | class pd_iter_test_t : public ::testing::Test { |
28 | protected: |
29 | dnnl_engine_t engine; |
30 | void SetUp() override { |
31 | auto engine_kind |
32 | = static_cast<dnnl_engine_kind_t>(get_test_engine_kind()); |
33 | ASSERT_EQ(dnnl_engine_create(&engine, engine_kind, 0), ok); |
34 | } |
35 | void TearDown() override { dnnl_engine_destroy(engine); } |
36 | }; |
37 | |
38 | TEST_F(pd_iter_test_t, TestReLUImpls) { |
39 | dnnl_memory_desc_t dense_md; |
40 | dnnl_dims_t dims = {4, 16, 16, 16}; |
41 | ASSERT_EQ(dnnl_memory_desc_create_with_tag( |
42 | &dense_md, 4, dims, dnnl_f32, dnnl_nchw), |
43 | ok); |
44 | |
45 | dnnl_primitive_desc_t pd; |
46 | dnnl_status_t rc = dnnl_eltwise_forward_primitive_desc_create(&pd, engine, |
47 | dnnl_forward_inference, dnnl_eltwise_relu, dense_md, dense_md, 0.f, |
48 | 0.f, nullptr); |
49 | ASSERT_EQ(rc, ok); /* there should be at least one impl */ |
50 | |
51 | while ((rc = dnnl_primitive_desc_next_impl(pd)) == ok) |
52 | ; |
53 | ASSERT_EQ(rc, dnnl_last_impl_reached); |
54 | |
55 | // Primitive descriptor has to be valid when the iterator |
56 | // reaches the end. |
57 | dnnl_primitive_t p; |
58 | rc = dnnl_primitive_create(&p, pd); |
59 | ASSERT_EQ(rc, ok); |
60 | |
61 | rc = dnnl_primitive_desc_destroy(pd); |
62 | ASSERT_EQ(rc, ok); |
63 | rc = dnnl_primitive_destroy(p); |
64 | ASSERT_EQ(rc, ok); |
65 | rc = dnnl_memory_desc_destroy(dense_md); |
66 | ASSERT_EQ(rc, ok); |
67 | } |
68 | |
69 | TEST_F(pd_iter_test_t, UnsupportedPrimitives) { |
70 | const float scales[2] = {1.0f, 1.0f}; |
71 | dnnl_memory_desc_t mds[2]; |
72 | |
73 | dnnl_dims_t dims = {1, 16, 16, 16}; |
74 | ASSERT_EQ(dnnl_memory_desc_create_with_tag( |
75 | &mds[0], 4, dims, dnnl_f32, dnnl_nchw), |
76 | ok); |
77 | ASSERT_EQ(dnnl_memory_desc_create_with_tag( |
78 | &mds[1], 4, dims, dnnl_f32, dnnl_nchw), |
79 | ok); |
80 | |
81 | dnnl_primitive_desc_t reorder_pd; |
82 | dnnl_primitive_desc_t concat_pd; |
83 | dnnl_primitive_desc_t sum_pd; |
84 | |
85 | ASSERT_EQ(dnnl_reorder_primitive_desc_create( |
86 | &reorder_pd, mds[0], engine, mds[1], engine, nullptr), |
87 | ok); |
88 | ASSERT_EQ(dnnl_concat_primitive_desc_create( |
89 | &concat_pd, engine, nullptr, 2, 0, mds, nullptr), |
90 | ok); |
91 | ASSERT_EQ(dnnl_sum_primitive_desc_create( |
92 | &sum_pd, engine, mds[0], 2, scales, mds, nullptr), |
93 | ok); |
94 | |
95 | ASSERT_EQ( |
96 | dnnl_primitive_desc_next_impl(reorder_pd), dnnl_last_impl_reached); |
97 | ASSERT_EQ(dnnl_primitive_desc_next_impl(concat_pd), dnnl_last_impl_reached); |
98 | ASSERT_EQ(dnnl_primitive_desc_next_impl(sum_pd), dnnl_last_impl_reached); |
99 | |
100 | ASSERT_EQ(dnnl_primitive_desc_destroy(reorder_pd), ok); |
101 | ASSERT_EQ(dnnl_primitive_desc_destroy(concat_pd), ok); |
102 | ASSERT_EQ(dnnl_primitive_desc_destroy(sum_pd), ok); |
103 | |
104 | ASSERT_EQ(dnnl_memory_desc_destroy(mds[0]), ok); |
105 | ASSERT_EQ(dnnl_memory_desc_destroy(mds[1]), ok); |
106 | } |
107 | |
108 | TEST(pd_next_impl, TestEltwiseImpl) { |
109 | SKIP_IF_CUDA(true, "Unsupported memory format for CUDA" ); |
110 | SKIP_IF_HIP(true, "Unsupported memory format for HIP" ); |
111 | auto eng = get_test_engine(); |
112 | memory::desc md( |
113 | {8, 32, 4, 4}, memory::data_type::f32, memory::format_tag::nChw8c); |
114 | |
115 | eltwise_forward::primitive_desc epd(eng, prop_kind::forward_training, |
116 | algorithm::eltwise_relu, md, md, 0.f); |
117 | |
118 | std::string impl0(epd.impl_info_str()); |
119 | eltwise_forward e0(epd); |
120 | |
121 | while (epd.next_impl()) { |
122 | std::string impl1(epd.impl_info_str()); |
123 | eltwise_forward e1(epd); |
124 | ASSERT_NE(impl0, impl1); |
125 | impl0 = impl1; |
126 | } |
127 | |
128 | // When the last implementation is reached all subsequent `next_impl()` |
129 | // calls should return `false`. |
130 | ASSERT_EQ(epd.next_impl(), false); |
131 | ASSERT_EQ(epd.next_impl(), false); |
132 | } |
133 | |
134 | } // namespace dnnl |
135 | |