#include <string>
#include <vector>

#include <gtest/gtest.h>

#include <c10/core/Device.h>
#include <c10/util/Exception.h>
5 | |
6 | // -- Device ------------------------------------------------------- |
7 | |
8 | struct ExpectedDeviceTestResult { |
9 | std::string device_string; |
10 | c10::DeviceType device_type; |
11 | c10::DeviceIndex device_index; |
12 | }; |
13 | |
14 | TEST(DeviceTest, BasicConstruction) { |
15 | std::vector<ExpectedDeviceTestResult> valid_devices = { |
16 | {"cpu" , c10::DeviceType::CPU, -1}, |
17 | {"cuda" , c10::DeviceType::CUDA, -1}, |
18 | {"cpu:0" , c10::DeviceType::CPU, 0}, |
19 | {"cuda:0" , c10::DeviceType::CUDA, 0}, |
20 | {"cuda:1" , c10::DeviceType::CUDA, 1}, |
21 | }; |
22 | std::vector<std::string> invalid_device_strings = { |
23 | "cpu:x" , |
24 | "cpu:foo" , |
25 | "cuda:cuda" , |
26 | "cuda:" , |
27 | "cpu:0:0" , |
28 | "cpu:0:" , |
29 | "cpu:-1" , |
30 | "::" , |
31 | ":" , |
32 | "cpu:00" , |
33 | "cpu:01" }; |
34 | |
35 | for (auto& ds : valid_devices) { |
36 | c10::Device d(ds.device_string); |
37 | ASSERT_EQ(d.type(), ds.device_type) |
38 | << "Device String: " << ds.device_string; |
39 | ASSERT_EQ(d.index(), ds.device_index) |
40 | << "Device String: " << ds.device_string; |
41 | } |
42 | |
43 | auto make_device = [](const std::string& ds) { return c10::Device(ds); }; |
44 | |
45 | for (auto& ds : invalid_device_strings) { |
46 | EXPECT_THROW(make_device(ds), c10::Error) << "Device String: " << ds; |
47 | } |
48 | } |
49 | |