1/*******************************************************************************
2* Copyright 2016-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include "gtest/gtest.h"

#include <assert.h>
#include <atomic>
#include <cstdlib>
#include <string>

#include "dnnl_test_common.hpp"

#include "gtest/gtest.h"
26
27using namespace testing;
28
// Flag raised by assert_fail_handler_t when the running test reports a fatal
// failure; cleared again when the next test starts.
static std::atomic<bool> g_is_current_test_failed {false};

// Reports whether the currently running test has recorded a fatal failure.
bool is_current_test_failed() {
    const bool failed = g_is_current_test_failed.load();
    return failed;
}
33
34class assert_fail_handler_t : public EmptyTestEventListener {
35protected:
36 void OnTestStart(const TestInfo &test_info) override {
37 g_is_current_test_failed = false;
38 }
39 void OnTestPartResult(const testing::TestPartResult &part_result) override {
40 if (part_result.type() == testing::TestPartResult::kFatalFailure) {
41 g_is_current_test_failed = true;
42 }
43 }
44};
45
46class dnnl_environment_t : public ::testing::Environment {
47public:
48 void SetUp() override;
49 void TearDown() override;
50};
51
52static void test_init(int argc, char *argv[]);
53
54int main(int argc, char *argv[]) {
55 int result;
56 {
57 ::testing::InitGoogleTest(&argc, argv);
58
59 // Parse oneDNN command line arguments
60 test_init(argc, argv);
61
62 TestEventListeners &listeners = UnitTest::GetInstance()->listeners();
63
64 auto *fail_handler = new assert_fail_handler_t();
65 listeners.Append(fail_handler);
66
67 ::testing::AddGlobalTestEnvironment(new dnnl_environment_t());
68
69#if _WIN32
70 // Safety cleanup.
71 system("where /q umdh && del pre_cpu.txt");
72 system("where /q umdh && del post_cpu.txt");
73 system("where /q umdh && del memdiff_cpu.txt");
74
75 // Get first snapshot.
76 system("where /q umdh && umdh -pn:tests.exe -f:pre_cpu.txt");
77#endif
78
79 result = RUN_ALL_TESTS();
80 }
81
82#if _WIN32
83 // Get second snapshot.
84 system("where /q umdh && umdh -pn:tests.exe -f:post_cpu.txt");
85
86 // Prepare memory diff.
87 system("where /q umdh && umdh pre_cpu.txt post_cpu.txt -f:memdiff_cpu.txt");
88
89 // Cleanup.
90 system("where /q umdh && del pre_cpu.txt");
91 system("where /q umdh && del post_cpu.txt");
92#endif
93
94 return result;
95}
96
// Scans the arguments in [argv_beg, argv_end) and returns the text following
// `option` in the first argument that STARTS with `option` (for option
// "--engine=", the argument "--engine=gpu" yields "gpu"). Returns an empty
// string if no argument matches. Null entries are skipped.
static std::string find_cmd_option(
        char **argv_beg, char **argv_end, const std::string &option) {
    for (char **arg = argv_beg; arg != argv_end; ++arg) {
        if (*arg == nullptr) continue;
        const std::string s(*arg);
        // Prefix match: option-like text embedded inside another argument
        // (or inside the program path in argv[0]) must not be picked up.
        if (s.compare(0, option.size(), option) == 0)
            return s.substr(option.size());
    }
    return {};
}
106
107inline dnnl::engine::kind to_engine_kind(const std::string &str) {
108 if (str.empty() || str == "cpu") return dnnl::engine::kind::cpu;
109
110 if (str == "gpu") return dnnl::engine::kind::gpu;
111
112 assert(!"not expected");
113 return dnnl::engine::kind::cpu;
114}
115
116// test_engine can be accessed only from tests compiled with
117// DNNL_TEST_WITH_ENGINE_PARAM macro
118#ifdef DNNL_TEST_WITH_ENGINE_PARAM
119static dnnl::engine::kind test_engine_kind;
120static std::unique_ptr<dnnl::engine> test_engine;
121
122dnnl::engine::kind get_test_engine_kind() {
123 return test_engine_kind;
124}
125
126dnnl::engine get_test_engine() {
127 return *test_engine;
128}
129
130void dnnl_environment_t::SetUp() {
131 test_engine.reset(new dnnl::engine(get_test_engine_kind(), 0));
132}
133
134void dnnl_environment_t::TearDown() {
135 test_engine.reset();
136}
137#else
138void dnnl_environment_t::SetUp() {}
139void dnnl_environment_t::TearDown() {}
140#endif
141
142void test_init(int argc, char *argv[]) {
143 auto engine_str = find_cmd_option(argv, argv + argc, "--engine=");
144#ifdef DNNL_TEST_WITH_ENGINE_PARAM
145 test_engine_kind = to_engine_kind(engine_str);
146
147 std::string filter_str = ::testing::GTEST_FLAG(filter);
148 if (test_engine_kind == dnnl::engine::kind::cpu) {
149 // Exclude non-CPU tests
150 ::testing::GTEST_FLAG(filter) = filter_str + ":-*_GPU*";
151 } else if (test_engine_kind == dnnl::engine::kind::gpu) {
152 // Exclude non-GPU tests
153 ::testing::GTEST_FLAG(filter) = filter_str + ":-*_CPU*";
154 }
155#else
156 assert(engine_str.empty()
157 && "--engine parameter is not supported by this test");
158#endif
159}
160