1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.h"
21#include "oneapi/dnnl/dnnl_types.h"
22
23namespace dnnl {
24
25const dnnl_status_t ok = dnnl_success;
26
27class pd_iter_test_t : public ::testing::Test {
28protected:
29 dnnl_engine_t engine;
30 void SetUp() override {
31 auto engine_kind
32 = static_cast<dnnl_engine_kind_t>(get_test_engine_kind());
33 ASSERT_EQ(dnnl_engine_create(&engine, engine_kind, 0), ok);
34 }
35 void TearDown() override { dnnl_engine_destroy(engine); }
36};
37
38TEST_F(pd_iter_test_t, TestReLUImpls) {
39 dnnl_memory_desc_t dense_md;
40 dnnl_dims_t dims = {4, 16, 16, 16};
41 ASSERT_EQ(dnnl_memory_desc_create_with_tag(
42 &dense_md, 4, dims, dnnl_f32, dnnl_nchw),
43 ok);
44
45 dnnl_primitive_desc_t pd;
46 dnnl_status_t rc = dnnl_eltwise_forward_primitive_desc_create(&pd, engine,
47 dnnl_forward_inference, dnnl_eltwise_relu, dense_md, dense_md, 0.f,
48 0.f, nullptr);
49 ASSERT_EQ(rc, ok); /* there should be at least one impl */
50
51 while ((rc = dnnl_primitive_desc_next_impl(pd)) == ok)
52 ;
53 ASSERT_EQ(rc, dnnl_last_impl_reached);
54
55 // Primitive descriptor has to be valid when the iterator
56 // reaches the end.
57 dnnl_primitive_t p;
58 rc = dnnl_primitive_create(&p, pd);
59 ASSERT_EQ(rc, ok);
60
61 rc = dnnl_primitive_desc_destroy(pd);
62 ASSERT_EQ(rc, ok);
63 rc = dnnl_primitive_destroy(p);
64 ASSERT_EQ(rc, ok);
65 rc = dnnl_memory_desc_destroy(dense_md);
66 ASSERT_EQ(rc, ok);
67}
68
69TEST_F(pd_iter_test_t, UnsupportedPrimitives) {
70 const float scales[2] = {1.0f, 1.0f};
71 dnnl_memory_desc_t mds[2];
72
73 dnnl_dims_t dims = {1, 16, 16, 16};
74 ASSERT_EQ(dnnl_memory_desc_create_with_tag(
75 &mds[0], 4, dims, dnnl_f32, dnnl_nchw),
76 ok);
77 ASSERT_EQ(dnnl_memory_desc_create_with_tag(
78 &mds[1], 4, dims, dnnl_f32, dnnl_nchw),
79 ok);
80
81 dnnl_primitive_desc_t reorder_pd;
82 dnnl_primitive_desc_t concat_pd;
83 dnnl_primitive_desc_t sum_pd;
84
85 ASSERT_EQ(dnnl_reorder_primitive_desc_create(
86 &reorder_pd, mds[0], engine, mds[1], engine, nullptr),
87 ok);
88 ASSERT_EQ(dnnl_concat_primitive_desc_create(
89 &concat_pd, engine, nullptr, 2, 0, mds, nullptr),
90 ok);
91 ASSERT_EQ(dnnl_sum_primitive_desc_create(
92 &sum_pd, engine, mds[0], 2, scales, mds, nullptr),
93 ok);
94
95 ASSERT_EQ(
96 dnnl_primitive_desc_next_impl(reorder_pd), dnnl_last_impl_reached);
97 ASSERT_EQ(dnnl_primitive_desc_next_impl(concat_pd), dnnl_last_impl_reached);
98 ASSERT_EQ(dnnl_primitive_desc_next_impl(sum_pd), dnnl_last_impl_reached);
99
100 ASSERT_EQ(dnnl_primitive_desc_destroy(reorder_pd), ok);
101 ASSERT_EQ(dnnl_primitive_desc_destroy(concat_pd), ok);
102 ASSERT_EQ(dnnl_primitive_desc_destroy(sum_pd), ok);
103
104 ASSERT_EQ(dnnl_memory_desc_destroy(mds[0]), ok);
105 ASSERT_EQ(dnnl_memory_desc_destroy(mds[1]), ok);
106}
107
108TEST(pd_next_impl, TestEltwiseImpl) {
109 SKIP_IF_CUDA(true, "Unsupported memory format for CUDA");
110 SKIP_IF_HIP(true, "Unsupported memory format for HIP");
111 auto eng = get_test_engine();
112 memory::desc md(
113 {8, 32, 4, 4}, memory::data_type::f32, memory::format_tag::nChw8c);
114
115 eltwise_forward::primitive_desc epd(eng, prop_kind::forward_training,
116 algorithm::eltwise_relu, md, md, 0.f);
117
118 std::string impl0(epd.impl_info_str());
119 eltwise_forward e0(epd);
120
121 while (epd.next_impl()) {
122 std::string impl1(epd.impl_info_str());
123 eltwise_forward e1(epd);
124 ASSERT_NE(impl0, impl1);
125 impl0 = impl1;
126 }
127
128 // When the last implementation is reached all subsequent `next_impl()`
129 // calls should return `false`.
130 ASSERT_EQ(epd.next_impl(), false);
131 ASSERT_EQ(epd.next_impl(), false);
132}
133
134} // namespace dnnl
135