1 | #include <gtest/gtest.h> |
---|---|
2 | |
3 | #include <torch/cuda.h> |
4 | |
5 | #include <iostream> |
6 | #include <string> |
7 | |
8 | std::string add_negative_flag(const std::string& flag) { |
9 | std::string filter = ::testing::GTEST_FLAG(filter); |
10 | if (filter.find('-') == std::string::npos) { |
11 | filter.push_back('-'); |
12 | } else { |
13 | filter.push_back(':'); |
14 | } |
15 | filter += flag; |
16 | return filter; |
17 | } |
18 | |
19 | int main(int argc, char* argv[]) { |
20 | ::testing::InitGoogleTest(&argc, argv); |
21 | if (!torch::cuda::is_available()) { |
22 | std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests" |
23 | << std::endl; |
24 | ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); |
25 | } else if (torch::cuda::device_count() < 2) { |
26 | std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests" |
27 | << std::endl; |
28 | ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA"); |
29 | } |
30 | |
31 | return RUN_ALL_TESTS(); |
32 | } |
33 |