1 | /******************************************************************************* |
2 | * Copyright 2016-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gtest/gtest.h" |
18 | |
19 | #include <assert.h> |
20 | #include <atomic> |
21 | #include <string> |
22 | |
23 | #include "dnnl_test_common.hpp" |
24 | |
25 | #include "gtest/gtest.h" |
26 | |
27 | using namespace testing; |
28 | |
// Flag tracking whether the test that is currently running has reported a
// fatal failure. assert_fail_handler_t clears it when a test starts and sets
// it on a fatal test-part result.
static std::atomic<bool> g_is_current_test_failed {false};

// Returns true once the currently running test has produced a fatal failure.
bool is_current_test_failed() {
    const bool failed = g_is_current_test_failed.load();
    return failed;
}
33 | |
34 | class assert_fail_handler_t : public EmptyTestEventListener { |
35 | protected: |
36 | void OnTestStart(const TestInfo &test_info) override { |
37 | g_is_current_test_failed = false; |
38 | } |
39 | void OnTestPartResult(const testing::TestPartResult &part_result) override { |
40 | if (part_result.type() == testing::TestPartResult::kFatalFailure) { |
41 | g_is_current_test_failed = true; |
42 | } |
43 | } |
44 | }; |
45 | |
46 | class dnnl_environment_t : public ::testing::Environment { |
47 | public: |
48 | void SetUp() override; |
49 | void TearDown() override; |
50 | }; |
51 | |
52 | static void test_init(int argc, char *argv[]); |
53 | |
54 | int main(int argc, char *argv[]) { |
55 | int result; |
56 | { |
57 | ::testing::InitGoogleTest(&argc, argv); |
58 | |
59 | // Parse oneDNN command line arguments |
60 | test_init(argc, argv); |
61 | |
62 | TestEventListeners &listeners = UnitTest::GetInstance()->listeners(); |
63 | |
64 | auto *fail_handler = new assert_fail_handler_t(); |
65 | listeners.Append(fail_handler); |
66 | |
67 | ::testing::AddGlobalTestEnvironment(new dnnl_environment_t()); |
68 | |
69 | #if _WIN32 |
70 | // Safety cleanup. |
71 | system("where /q umdh && del pre_cpu.txt" ); |
72 | system("where /q umdh && del post_cpu.txt" ); |
73 | system("where /q umdh && del memdiff_cpu.txt" ); |
74 | |
75 | // Get first snapshot. |
76 | system("where /q umdh && umdh -pn:tests.exe -f:pre_cpu.txt" ); |
77 | #endif |
78 | |
79 | result = RUN_ALL_TESTS(); |
80 | } |
81 | |
82 | #if _WIN32 |
83 | // Get second snapshot. |
84 | system("where /q umdh && umdh -pn:tests.exe -f:post_cpu.txt" ); |
85 | |
86 | // Prepare memory diff. |
87 | system("where /q umdh && umdh pre_cpu.txt post_cpu.txt -f:memdiff_cpu.txt" ); |
88 | |
89 | // Cleanup. |
90 | system("where /q umdh && del pre_cpu.txt" ); |
91 | system("where /q umdh && del post_cpu.txt" ); |
92 | #endif |
93 | |
94 | return result; |
95 | } |
96 | |
// Returns the value of the first argument in [argv_beg, argv_end) that starts
// with `option`, i.e. everything after the option text ("--engine=gpu" yields
// "gpu" for option "--engine="). Returns an empty string when no argument
// starts with `option`. Null entries in the range are skipped.
static std::string find_cmd_option(
        char **argv_beg, char **argv_end, const std::string &option) {
    for (auto arg = argv_beg; arg != argv_end; arg++) {
        if (*arg == nullptr) continue;
        const std::string s(*arg);
        // Require the option at the very start of the argument: a plain
        // substring search would also accept e.g. "--foo--engine=gpu".
        if (s.compare(0, option.length(), option) == 0)
            return s.substr(option.length());
    }
    return {};
}
106 | |
107 | inline dnnl::engine::kind to_engine_kind(const std::string &str) { |
108 | if (str.empty() || str == "cpu" ) return dnnl::engine::kind::cpu; |
109 | |
110 | if (str == "gpu" ) return dnnl::engine::kind::gpu; |
111 | |
112 | assert(!"not expected" ); |
113 | return dnnl::engine::kind::cpu; |
114 | } |
115 | |
116 | // test_engine can be accessed only from tests compiled with |
117 | // DNNL_TEST_WITH_ENGINE_PARAM macro |
118 | #ifdef DNNL_TEST_WITH_ENGINE_PARAM |
119 | static dnnl::engine::kind test_engine_kind; |
120 | static std::unique_ptr<dnnl::engine> test_engine; |
121 | |
122 | dnnl::engine::kind get_test_engine_kind() { |
123 | return test_engine_kind; |
124 | } |
125 | |
126 | dnnl::engine get_test_engine() { |
127 | return *test_engine; |
128 | } |
129 | |
130 | void dnnl_environment_t::SetUp() { |
131 | test_engine.reset(new dnnl::engine(get_test_engine_kind(), 0)); |
132 | } |
133 | |
134 | void dnnl_environment_t::TearDown() { |
135 | test_engine.reset(); |
136 | } |
137 | #else |
138 | void dnnl_environment_t::SetUp() {} |
139 | void dnnl_environment_t::TearDown() {} |
140 | #endif |
141 | |
142 | void test_init(int argc, char *argv[]) { |
143 | auto engine_str = find_cmd_option(argv, argv + argc, "--engine=" ); |
144 | #ifdef DNNL_TEST_WITH_ENGINE_PARAM |
145 | test_engine_kind = to_engine_kind(engine_str); |
146 | |
147 | std::string filter_str = ::testing::GTEST_FLAG(filter); |
148 | if (test_engine_kind == dnnl::engine::kind::cpu) { |
149 | // Exclude non-CPU tests |
150 | ::testing::GTEST_FLAG(filter) = filter_str + ":-*_GPU*" ; |
151 | } else if (test_engine_kind == dnnl::engine::kind::gpu) { |
152 | // Exclude non-GPU tests |
153 | ::testing::GTEST_FLAG(filter) = filter_str + ":-*_CPU*" ; |
154 | } |
155 | #else |
156 | assert(engine_str.empty() |
157 | && "--engine parameter is not supported by this test" ); |
158 | #endif |
159 | } |
160 | |