1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include <tuple>
23
24namespace dnnl {
25
26static bool are_valid_flags(
27 dnnl_engine_kind_t engine_kind, dnnl_stream_flags_t stream_flags) {
28 bool ok = true;
29#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
30 if (engine_kind == dnnl_gpu && (stream_flags & dnnl_stream_out_of_order))
31 ok = false;
32#endif
33#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
34 if (engine_kind == dnnl_cpu && (stream_flags & dnnl_stream_out_of_order))
35 ok = false;
36#endif
37 return ok;
38}
39
40class stream_test_c_t
41 : public ::testing::TestWithParam<
42 std::tuple<dnnl_engine_kind_t, dnnl_stream_flags_t>> {
43protected:
44 void SetUp() override {
45 std::tie(eng_kind, stream_flags) = GetParam();
46
47 if (dnnl_engine_get_count(eng_kind) == 0) return;
48
49 DNNL_CHECK(dnnl_engine_create(&engine, eng_kind, 0));
50
51 // Check that the flags are compatible with the engine
52 if (!are_valid_flags(eng_kind, stream_flags)) {
53 DNNL_CHECK(dnnl_engine_destroy(engine));
54 engine = nullptr;
55 return;
56 }
57
58 DNNL_CHECK(dnnl_stream_create(&stream, engine, stream_flags));
59 }
60
61 void TearDown() override {
62 if (stream) { DNNL_CHECK(dnnl_stream_destroy(stream)); }
63 if (engine) { DNNL_CHECK(dnnl_engine_destroy(engine)); }
64 }
65
66 dnnl_engine_kind_t eng_kind;
67 dnnl_stream_flags_t stream_flags;
68
69 dnnl_engine_t engine = nullptr;
70 dnnl_stream_t stream = nullptr;
71};
72
73class stream_test_cpp_t
74 : public ::testing::TestWithParam<
75 std::tuple<dnnl_engine_kind_t, dnnl_stream_flags_t>> {};
76
77TEST_P(stream_test_c_t, Create) {
78 SKIP_IF(!engine, "Engines not found or stream flags are incompatible.");
79
80 DNNL_CHECK(dnnl_stream_wait(stream));
81}
82
83TEST(stream_test_c_t, WaitNullStream) {
84 dnnl_stream_t stream = nullptr;
85 dnnl_status_t status = dnnl_stream_wait(stream);
86 ASSERT_EQ(status, dnnl_invalid_arguments);
87}
88
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
// Creates a CPU engine and a default-flags stream on it, waits on the stream,
// then releases the stream before the engine it was created on.
TEST(stream_test_c_t, Wait) {
    dnnl_engine_t cpu_engine = nullptr;
    DNNL_CHECK(dnnl_engine_create(&cpu_engine, dnnl_cpu, 0));

    dnnl_stream_t cpu_stream = nullptr;
    DNNL_CHECK(dnnl_stream_create(
            &cpu_stream, cpu_engine, dnnl_stream_default_flags));
    DNNL_CHECK(dnnl_stream_wait(cpu_stream));

    DNNL_CHECK(dnnl_stream_destroy(cpu_stream));
    DNNL_CHECK(dnnl_engine_destroy(cpu_engine));
}
#endif
103
104TEST_P(stream_test_cpp_t, Wait) {
105 dnnl_engine_kind_t eng_kind_c = dnnl_cpu;
106 dnnl_stream_flags_t stream_flags_c = dnnl_stream_in_order;
107 std::tie(eng_kind_c, stream_flags_c) = GetParam();
108
109 engine::kind eng_kind = static_cast<engine::kind>(eng_kind_c);
110 stream::flags stream_flags = static_cast<stream::flags>(stream_flags_c);
111 SKIP_IF(engine::get_count(eng_kind) == 0, "Engines not found.");
112
113 engine eng(eng_kind, 0);
114 SKIP_IF(!are_valid_flags(static_cast<dnnl_engine_kind_t>(eng.get_kind()),
115 stream_flags_c),
116 "Incompatible stream flags.");
117
118 stream s(eng, stream_flags);
119 engine s_eng = s.get_engine();
120 s.wait();
121}
122
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
// dnnl_stream_get_engine() must return the same engine handle that the stream
// was created with.
TEST(stream_test_c_t, GetStream) {
    dnnl_engine_t cpu_engine = nullptr;
    DNNL_CHECK(dnnl_engine_create(&cpu_engine, dnnl_cpu, 0));

    dnnl_stream_t cpu_stream = nullptr;
    DNNL_CHECK(dnnl_stream_create(
            &cpu_stream, cpu_engine, dnnl_stream_default_flags));

    dnnl_engine_t queried_engine = nullptr;
    DNNL_CHECK(dnnl_stream_get_engine(cpu_stream, &queried_engine));
    ASSERT_EQ(cpu_engine, queried_engine);

    DNNL_CHECK(dnnl_stream_destroy(cpu_stream));
    DNNL_CHECK(dnnl_engine_destroy(cpu_engine));
}
#endif
139
140namespace {
141struct print_to_string_param_name_t {
142 template <class ParamType>
143 std::string operator()(
144 const ::testing::TestParamInfo<ParamType> &info) const {
145 return to_string(std::get<0>(info.param)) + "_"
146 + to_string(std::get<1>(info.param));
147 }
148};
149
150auto all_params = ::testing::Combine(::testing::Values(dnnl_cpu, dnnl_gpu),
151 ::testing::Values(dnnl_stream_in_order, dnnl_stream_out_of_order));
152
153} // namespace
154
// Instantiate both the C API and the C++ API fixtures over every parameter
// pair in all_params, using print_to_string_param_name_t to name each
// instance "<engine kind>_<stream flags>".
INSTANTIATE_TEST_SUITE_P(AllEngineKinds, stream_test_c_t, all_params,
        print_to_string_param_name_t());
INSTANTIATE_TEST_SUITE_P(AllEngineKinds, stream_test_cpp_t, all_params,
        print_to_string_param_name_t());
159
160} // namespace dnnl
161