#include <gtest/gtest.h>

#include <c10/core/Device.h>
#include <c10/util/Exception.h>

#include <string>
#include <vector>
5
6// -- Device -------------------------------------------------------
7
8struct ExpectedDeviceTestResult {
9 std::string device_string;
10 c10::DeviceType device_type;
11 c10::DeviceIndex device_index;
12};
13
14TEST(DeviceTest, BasicConstruction) {
15 std::vector<ExpectedDeviceTestResult> valid_devices = {
16 {"cpu", c10::DeviceType::CPU, -1},
17 {"cuda", c10::DeviceType::CUDA, -1},
18 {"cpu:0", c10::DeviceType::CPU, 0},
19 {"cuda:0", c10::DeviceType::CUDA, 0},
20 {"cuda:1", c10::DeviceType::CUDA, 1},
21 };
22 std::vector<std::string> invalid_device_strings = {
23 "cpu:x",
24 "cpu:foo",
25 "cuda:cuda",
26 "cuda:",
27 "cpu:0:0",
28 "cpu:0:",
29 "cpu:-1",
30 "::",
31 ":",
32 "cpu:00",
33 "cpu:01"};
34
35 for (auto& ds : valid_devices) {
36 c10::Device d(ds.device_string);
37 ASSERT_EQ(d.type(), ds.device_type)
38 << "Device String: " << ds.device_string;
39 ASSERT_EQ(d.index(), ds.device_index)
40 << "Device String: " << ds.device_string;
41 }
42
43 auto make_device = [](const std::string& ds) { return c10::Device(ds); };
44
45 for (auto& ds : invalid_device_strings) {
46 EXPECT_THROW(make_device(ds), c10::Error) << "Device String: " << ds;
47 }
48}
49