1#include <gtest/gtest.h>
2
3#include <torch/cuda.h>
4
5#include <iostream>
6#include <string>
7
8std::string add_negative_flag(const std::string& flag) {
9 std::string filter = ::testing::GTEST_FLAG(filter);
10 if (filter.find('-') == std::string::npos) {
11 filter.push_back('-');
12 } else {
13 filter.push_back(':');
14 }
15 filter += flag;
16 return filter;
17}
18
19int main(int argc, char* argv[]) {
20 ::testing::InitGoogleTest(&argc, argv);
21 if (!torch::cuda::is_available()) {
22 std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests"
23 << std::endl;
24 ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");
25 } else if (torch::cuda::device_count() < 2) {
26 std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests"
27 << std::endl;
28 ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA");
29 }
30
31 return RUN_ALL_TESTS();
32}
33