1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include <tuple> |
23 | |
24 | namespace dnnl { |
25 | |
26 | static bool are_valid_flags( |
27 | dnnl_engine_kind_t engine_kind, dnnl_stream_flags_t stream_flags) { |
28 | bool ok = true; |
29 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL |
30 | if (engine_kind == dnnl_gpu && (stream_flags & dnnl_stream_out_of_order)) |
31 | ok = false; |
32 | #endif |
33 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL |
34 | if (engine_kind == dnnl_cpu && (stream_flags & dnnl_stream_out_of_order)) |
35 | ok = false; |
36 | #endif |
37 | return ok; |
38 | } |
39 | |
40 | class stream_test_c_t |
41 | : public ::testing::TestWithParam< |
42 | std::tuple<dnnl_engine_kind_t, dnnl_stream_flags_t>> { |
43 | protected: |
44 | void SetUp() override { |
45 | std::tie(eng_kind, stream_flags) = GetParam(); |
46 | |
47 | if (dnnl_engine_get_count(eng_kind) == 0) return; |
48 | |
49 | DNNL_CHECK(dnnl_engine_create(&engine, eng_kind, 0)); |
50 | |
51 | // Check that the flags are compatible with the engine |
52 | if (!are_valid_flags(eng_kind, stream_flags)) { |
53 | DNNL_CHECK(dnnl_engine_destroy(engine)); |
54 | engine = nullptr; |
55 | return; |
56 | } |
57 | |
58 | DNNL_CHECK(dnnl_stream_create(&stream, engine, stream_flags)); |
59 | } |
60 | |
61 | void TearDown() override { |
62 | if (stream) { DNNL_CHECK(dnnl_stream_destroy(stream)); } |
63 | if (engine) { DNNL_CHECK(dnnl_engine_destroy(engine)); } |
64 | } |
65 | |
66 | dnnl_engine_kind_t eng_kind; |
67 | dnnl_stream_flags_t stream_flags; |
68 | |
69 | dnnl_engine_t engine = nullptr; |
70 | dnnl_stream_t stream = nullptr; |
71 | }; |
72 | |
73 | class stream_test_cpp_t |
74 | : public ::testing::TestWithParam< |
75 | std::tuple<dnnl_engine_kind_t, dnnl_stream_flags_t>> {}; |
76 | |
77 | TEST_P(stream_test_c_t, Create) { |
78 | SKIP_IF(!engine, "Engines not found or stream flags are incompatible." ); |
79 | |
80 | DNNL_CHECK(dnnl_stream_wait(stream)); |
81 | } |
82 | |
83 | TEST(stream_test_c_t, WaitNullStream) { |
84 | dnnl_stream_t stream = nullptr; |
85 | dnnl_status_t status = dnnl_stream_wait(stream); |
86 | ASSERT_EQ(status, dnnl_invalid_arguments); |
87 | } |
88 | |
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
// Creates a CPU engine and a default-flags stream through the C API, waits on
// the stream, then releases both handles.
TEST(stream_test_c_t, Wait) {
    dnnl_engine_t engine = nullptr;
    dnnl_stream_t stream = nullptr;

    DNNL_CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0));
    DNNL_CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));

    DNNL_CHECK(dnnl_stream_wait(stream));

    // Release in reverse order of creation.
    DNNL_CHECK(dnnl_stream_destroy(stream));
    DNNL_CHECK(dnnl_engine_destroy(engine));
}
#endif
103 | |
104 | TEST_P(stream_test_cpp_t, Wait) { |
105 | dnnl_engine_kind_t eng_kind_c = dnnl_cpu; |
106 | dnnl_stream_flags_t stream_flags_c = dnnl_stream_in_order; |
107 | std::tie(eng_kind_c, stream_flags_c) = GetParam(); |
108 | |
109 | engine::kind eng_kind = static_cast<engine::kind>(eng_kind_c); |
110 | stream::flags stream_flags = static_cast<stream::flags>(stream_flags_c); |
111 | SKIP_IF(engine::get_count(eng_kind) == 0, "Engines not found." ); |
112 | |
113 | engine eng(eng_kind, 0); |
114 | SKIP_IF(!are_valid_flags(static_cast<dnnl_engine_kind_t>(eng.get_kind()), |
115 | stream_flags_c), |
116 | "Incompatible stream flags." ); |
117 | |
118 | stream s(eng, stream_flags); |
119 | engine s_eng = s.get_engine(); |
120 | s.wait(); |
121 | } |
122 | |
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
// Checks that dnnl_stream_get_engine() returns the engine the stream was
// created with.
TEST(stream_test_c_t, GetStream) {
    dnnl_engine_t engine = nullptr;
    DNNL_CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0));

    dnnl_stream_t stream = nullptr;
    DNNL_CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));

    dnnl_engine_t stream_engine = nullptr;
    DNNL_CHECK(dnnl_stream_get_engine(stream, &stream_engine));
    // EXPECT rather than ASSERT: a failing ASSERT returns from the test at once
    // and would skip the destroy calls below, leaking the stream and engine.
    EXPECT_EQ(engine, stream_engine);

    DNNL_CHECK(dnnl_stream_destroy(stream));
    DNNL_CHECK(dnnl_engine_destroy(engine));
}
#endif
139 | |
140 | namespace { |
141 | struct print_to_string_param_name_t { |
142 | template <class ParamType> |
143 | std::string operator()( |
144 | const ::testing::TestParamInfo<ParamType> &info) const { |
145 | return to_string(std::get<0>(info.param)) + "_" |
146 | + to_string(std::get<1>(info.param)); |
147 | } |
148 | }; |
149 | |
150 | auto all_params = ::testing::Combine(::testing::Values(dnnl_cpu, dnnl_gpu), |
151 | ::testing::Values(dnnl_stream_in_order, dnnl_stream_out_of_order)); |
152 | |
153 | } // namespace |
154 | |
155 | INSTANTIATE_TEST_SUITE_P(AllEngineKinds, stream_test_c_t, all_params, |
156 | print_to_string_param_name_t()); |
157 | INSTANTIATE_TEST_SUITE_P(AllEngineKinds, stream_test_cpp_t, all_params, |
158 | print_to_string_param_name_t()); |
159 | |
160 | } // namespace dnnl |
161 | |