1#include <gtest/gtest.h>
2
3#include <sstream>
4
5#include <c10/core/Device.h>
6#include <torch/csrc/lazy/backend/backend_device.h>
7#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
8#include <torch/torch.h>
9
10namespace torch {
11namespace lazy {
12
13TEST(BackendDeviceTest, BackendDeviceType) {
14 auto type = BackendDeviceType();
15
16 EXPECT_EQ(type.type, 0);
17 EXPECT_STREQ(type.toString().c_str(), "Unknown");
18}
19
20TEST(BackendDeviceTest, Basic1) {
21 auto device = BackendDevice();
22
23 EXPECT_EQ(device.ordinal(), 0);
24 if (std::getenv("LTC_TS_CUDA") != nullptr) {
25 EXPECT_EQ(device.type(), 1);
26 EXPECT_STREQ(device.toString().c_str(), "CUDA0");
27 } else {
28 EXPECT_EQ(device.type(), 0);
29 EXPECT_STREQ(device.toString().c_str(), "CPU0");
30 }
31}
32
33TEST(BackendDeviceTest, Basic2) {
34 auto type = std::make_shared<BackendDeviceType>();
35 type->type = 1;
36 auto device = BackendDevice(std::move(type), 1);
37
38 EXPECT_EQ(device.type(), 1);
39 EXPECT_EQ(device.ordinal(), 1);
40 EXPECT_STREQ(device.toString().c_str(), "Unknown1");
41}
42
43TEST(BackendDeviceTest, Basic3) {
44 struct TestType : public BackendDeviceType {
45 std::string toString() const override {
46 return "Test";
47 }
48 };
49
50 auto device = BackendDevice(std::make_shared<TestType>(), 1);
51
52 EXPECT_EQ(device.type(), 0);
53 EXPECT_EQ(device.ordinal(), 1);
54 EXPECT_STREQ(device.toString().c_str(), "Test1");
55}
56
57TEST(BackendDeviceTest, Basic4) {
58 // Seems weird to have setters in BackendImplInterface given getBackend()
59 // returns a const pointer.
60 auto default_type = getBackend()->GetDefaultDeviceType();
61 auto default_ordinal = getBackend()->GetDefaultDeviceOrdinal();
62 const_cast<BackendImplInterface*>(getBackend())
63 ->SetDefaultDeviceType(static_cast<int8_t>(c10::kCUDA));
64 const_cast<BackendImplInterface*>(getBackend())->SetDefaultDeviceOrdinal(1);
65
66 auto device = BackendDevice();
67
68 EXPECT_EQ(device.type(), 1);
69 EXPECT_EQ(device.ordinal(), 1);
70 EXPECT_STREQ(device.toString().c_str(), "CUDA1");
71
72 const_cast<BackendImplInterface*>(getBackend())
73 ->SetDefaultDeviceType(default_type->type);
74 const_cast<BackendImplInterface*>(getBackend())
75 ->SetDefaultDeviceOrdinal(default_ordinal);
76}
77
78TEST(BackendDeviceTest, Compare) {
79 auto type = std::make_shared<BackendDeviceType>();
80 type->type = 1;
81
82 auto device1 = BackendDevice(std::make_shared<BackendDeviceType>(), 1);
83 auto device2 = BackendDevice(std::move(type), 0);
84 auto device3 = BackendDevice(std::make_shared<BackendDeviceType>(), 2);
85 auto device4 = BackendDevice(std::make_shared<BackendDeviceType>(), 1);
86
87 EXPECT_NE(device1, device2);
88 EXPECT_NE(device1, device3);
89 EXPECT_EQ(device1, device4);
90 EXPECT_LT(device1, device2);
91 EXPECT_LT(device1, device3);
92}
93
94TEST(BackendDeviceTest, Ostream) {
95 auto device = BackendDevice();
96 std::stringstream ss;
97 ss << device;
98
99 EXPECT_EQ(device.toString(), ss.str());
100}
101
102TEST(BackendDeviceTest, FromAten) {
103 auto device = c10::Device(c10::kCPU);
104 EXPECT_THROW(atenDeviceToBackendDevice(device), c10::Error);
105
106 device = c10::Device(c10::kLazy);
107#ifndef FBCODE_CAFFE2
108 auto backend_device = atenDeviceToBackendDevice(device);
109#else
110 // Lazy Tensor is disabled in FBCODE until addressing non-virtual methods
111 // (e.g. sizes) in TensorImpl
112 EXPECT_THROW(atenDeviceToBackendDevice(device), c10::Error);
113#endif // FBCODE_CAFFE2
114}
115
116TEST(BackendDeviceTest, ToAten) {
117 auto device = backendDeviceToAtenDevice(BackendDevice());
118 EXPECT_EQ(device.type(), c10::kLazy);
119 EXPECT_TRUE(device.has_index());
120 EXPECT_EQ(device.index(), 0);
121}
122
123// TODO(alanwaketan): Update the following test once we have TorchScript backend
124// upstreamed.
125TEST(BackendDeviceTest, GetBackendDevice1) {
126 auto tensor = torch::rand({0, 1, 3, 0});
127 EXPECT_FALSE(GetBackendDevice(tensor));
128}
129
130TEST(BackendDeviceTest, GetBackendDevice2) {
131 auto tensor1 = torch::rand({0, 1, 3, 0});
132 auto tensor2 = torch::rand({0, 1, 3, 0});
133 // TODO(alanwaketan): Cover the test case for GetBackendDevice().
134 EXPECT_FALSE(GetBackendDevice(tensor1, tensor2));
135}
136
137} // namespace lazy
138} // namespace torch
139