1 | #include <test/cpp/lazy/test_lazy_ops_util.h> |
2 | |
3 | #include <torch/csrc/lazy/backend/lowering_context.h> |
4 | #include <torch/csrc/lazy/core/ir_builder.h> |
5 | #include <torch/csrc/lazy/core/ir_dump_util.h> |
6 | #include <torch/csrc/lazy/core/tensor_impl.h> |
7 | |
8 | #include <iostream> |
9 | #include <string> |
10 | |
11 | namespace torch { |
12 | namespace lazy { |
13 | namespace { |
14 | |
15 | bool IsLtcTensor(const at::Tensor& tensor) { |
16 | return dynamic_cast<torch::lazy::LTCTensorImpl*>( |
17 | tensor.unsafeGetTensorImpl()); |
18 | } |
19 | |
// Builds the set of counter names to skip when doing
// is-any-counter-changed assertions. The returned set is heap-allocated
// and deliberately never freed: GetIgnoredCounters() caches it in a
// function-local static, so leaking avoids destruction-order issues at
// process exit.
std::unordered_set<std::string>* CreateIgnoredCounters() {
  auto* ignored = new std::unordered_set<std::string>();
  ignored->insert("aten::rand");
  return ignored;
}
28 | |
29 | } // namespace |
30 | |
31 | const std::unordered_set<std::string>* GetIgnoredCounters() { |
32 | static const std::unordered_set<std::string>* icounters = |
33 | CreateIgnoredCounters(); |
34 | return icounters; |
35 | } |
36 | |
37 | at::Tensor ToCpuTensor(const at::Tensor& tensor) { |
38 | // tensor.to() implicitly triggers a sync if t.device=torch::kLazy. |
39 | return tensor.to(torch::kCPU); |
40 | } |
41 | |
42 | torch::Tensor CopyToDevice( |
43 | const torch::Tensor& tensor, |
44 | const torch::Device& device) { |
45 | return tensor.clone().to(device, /*non_blocking=*/false, /*copy=*/true); |
46 | } |
47 | |
48 | bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) { |
49 | tensor1 = ToCpuTensor(tensor1); |
50 | tensor2 = ToCpuTensor(tensor2); |
51 | if (torch::isnan(tensor1).any().item<bool>()) { |
52 | EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2))); |
53 | tensor1.nan_to_num_(); |
54 | tensor2.nan_to_num_(); |
55 | } |
56 | if (tensor1.sizes() != tensor2.sizes() || |
57 | tensor1.dtype() != tensor2.dtype()) { |
58 | std::cerr << "Different shape:\n" |
59 | << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n" |
60 | << tensor2.dtype() << " " << tensor2.sizes() << "\n" ; |
61 | return false; |
62 | } |
63 | at::ScalarType type1 = tensor1.scalar_type(); |
64 | at::ScalarType type2 = tensor2.scalar_type(); |
65 | if (type1 != type2) { |
66 | tensor1 = tensor1.toType(type2); |
67 | } |
68 | bool equal = tensor1.equal(tensor2); |
69 | return equal; |
70 | } |
71 | |
72 | bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) { |
73 | tensor1 = ToCpuTensor(tensor1); |
74 | tensor2 = ToCpuTensor(tensor2); |
75 | if (tensor1.sizes() != tensor2.sizes()) { |
76 | std::cerr << "Different shape:\n" |
77 | << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n" |
78 | << tensor2.dtype() << " " << tensor2.sizes() << "\n" ; |
79 | return false; |
80 | } |
81 | at::ScalarType type1 = tensor1.scalar_type(); |
82 | at::ScalarType type2 = tensor2.scalar_type(); |
83 | if (type1 != type2) { |
84 | tensor1 = tensor1.toType(type2); |
85 | } |
86 | bool equal = tensor1.equal(tensor2); |
87 | return equal; |
88 | } |
89 | |
90 | void ForEachDevice(const std::function<void(const torch::Device&)>& devfn) { |
91 | // Currently TorchScript backend only supports one type of hardware per |
92 | // process, which is set by env. And the ordinal is always 0 given distributed |
93 | // training/ multi-device is not supported yet. |
94 | auto device = torch::lazy::BackendDevice(); |
95 | torch::Device torch_device = torch::lazy::backendDeviceToAtenDevice(device); |
96 | devfn(torch_device); |
97 | } |
98 | |
99 | bool CloseValues( |
100 | at::Tensor tensor1, |
101 | at::Tensor tensor2, |
102 | double rtol, |
103 | double atol) { |
104 | tensor1 = ToCpuTensor(tensor1); |
105 | tensor2 = ToCpuTensor(tensor2); |
106 | if (torch::isnan(tensor1).any().item<bool>()) { |
107 | EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2))); |
108 | tensor1.nan_to_num_(); |
109 | tensor2.nan_to_num_(); |
110 | } |
111 | if (tensor1.sizes() != tensor2.sizes() || |
112 | tensor1.dtype() != tensor2.dtype()) { |
113 | std::cerr << "Different shape:\n" |
114 | << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n" |
115 | << tensor2.dtype() << " " << tensor2.sizes() << "\n" ; |
116 | return false; |
117 | } |
118 | bool equal = tensor1.allclose(tensor2, rtol, atol); |
119 | return equal; |
120 | } |
121 | |
122 | std::string GetTensorTextGraph(at::Tensor tensor) { |
123 | torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor); |
124 | return torch::lazy::DumpUtil::ToText({lazy_tensor->GetIrValue().node.get()}); |
125 | } |
126 | |
127 | std::string GetTensorDotGraph(at::Tensor tensor) { |
128 | torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor); |
129 | return torch::lazy::DumpUtil::ToDot({lazy_tensor->GetIrValue().node.get()}); |
130 | } |
131 | |
// Runs `testfn` twice — once on detached CPU copies of `inputs`, once on
// detached copies placed on `device` — and checks that the outputs match
// within rtol/atol. Then, up to `derivative_level` times, differentiates
// the summed outputs w.r.t. the grad-requiring inputs on both sides and
// checks that the gradients match as well (d=2 verifies second-order
// derivatives, etc.).
void TestBackward(
    const std::vector<torch::Tensor>& inputs,
    const torch::Device& device,
    const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
        testfn,
    double rtol,
    double atol,
    int derivative_level) {
  std::vector<torch::Tensor> input_vars; // CPU copies of all inputs
  std::vector<torch::Tensor> xinput_vars; // device copies of all inputs
  std::vector<torch::Tensor> inputs_w_grad; // CPU inputs with requires_grad
  std::vector<torch::Tensor> xinputs_w_grad; // device inputs with requires_grad
  for (size_t i = 0; i < inputs.size(); ++i) {
    const torch::Tensor& input = inputs[i];
    if (input.defined()) {
      // clone().detach() makes each run own fresh autograd leaves; the
      // original requires_grad flag is re-applied explicitly.
      torch::Tensor oinput =
          input.clone().detach().set_requires_grad(input.requires_grad());
      input_vars.push_back(oinput);

      torch::Tensor xinput = CopyToDevice(input, device)
                                 .detach()
                                 .set_requires_grad(input.requires_grad());
      xinput_vars.push_back(xinput);
      if (input.requires_grad()) {
        inputs_w_grad.push_back(oinput);
        xinputs_w_grad.push_back(xinput);
      }
    } else {
      // Undefined inputs are forwarded as-is (placeholders kept so the
      // positional layout seen by testfn is preserved).
      input_vars.emplace_back();
      xinput_vars.emplace_back();
    }
  }

  // Forward pass on both sides; compare outputs.
  // NOTE(review): the AllClose return value is discarded — presumably it
  // asserts internally via gtest; confirm.
  torch::Tensor output = testfn(input_vars);
  torch::Tensor xoutput = testfn(xinput_vars);
  torch::lazy::AllClose(output, xoutput, rtol, atol);

  std::vector<torch::Tensor> outs = {output};
  std::vector<torch::Tensor> xouts = {xoutput};
  for (int d = 1; d <= derivative_level; ++d) {
    // Check grad of sum(outs) w.r.t inputs_w_grad.
    // Seed with a zero scalar so sum/xsum are always defined, even when no
    // entry of outs requires grad at this derivative level.
    torch::Tensor sum = torch::zeros_like(outs[0]).sum();
    torch::Tensor xsum = torch::zeros_like(xouts[0]).sum();
    for (size_t i = 0; i < outs.size(); ++i) {
      if (outs[i].requires_grad()) {
        sum += outs[i].sum();
        xsum += xouts[i].sum();
      }
    }
    // Calculating higher order derivative requires create_graph=true
    // (only on the non-final levels; the last grad() call can drop the
    // graph).
    bool create_graph = d != derivative_level;
    outs = torch::autograd::grad(
        {sum},
        inputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/c10::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    xouts = torch::autograd::grad(
        {xsum},
        xinputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/c10::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    // Gradients may be undefined (allow_unused=true); definedness must
    // agree between CPU and device before comparing values.
    for (size_t i = 0; i < outs.size(); ++i) {
      ASSERT_EQ(outs[i].defined(), xouts[i].defined());
      if (outs[i].defined()) {
        AllClose(outs[i], xouts[i], rtol, atol);
      }
    }
  }
}
205 | |
206 | } // namespace lazy |
207 | } // namespace torch |
208 | |