#include <test/cpp/lazy/test_lazy_ops_util.h>

#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/tensor_impl.h>

#include <iostream>
#include <string>

namespace torch {
namespace lazy {
namespace {

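// Returns true when the tensor's TensorImpl is an LTCTensorImpl, i.e. the
// tensor is backed by lazy tensor core rather than by a regular eager tensor.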
bool IsLtcTensor(const at::Tensor& tensor) {
  return dynamic_cast<torch::lazy::LTCTensorImpl*>(
             tensor.unsafeGetTensorImpl()) != nullptr;
}

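// Builds the heap-allocated set of counter names that tests exclude from
// "did any counter change" checks.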
std::unordered_set<std::string>* CreateIgnoredCounters() {
  std::unordered_set<std::string>* icounters =
      new std::unordered_set<std::string>();
  // Add below the counters whose names need to be ignored when doing
  // is-any-counter-changed assertions.
  icounters->insert("aten::rand");
  return icounters;
}

} // namespace

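// Returns the process-wide set of ignored counter names. The set is created
// once on first use and never deleted.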
const std::unordered_set<std::string>* GetIgnoredCounters() {
  static const std::unordered_set<std::string>* icounters =
      CreateIgnoredCounters();
  return icounters;
}

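// Copies a tensor to CPU so it can be compared element-wise; for lazy tensors
// this forces the pending computation to materialize.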
at::Tensor ToCpuTensor(const at::Tensor& tensor) {
  // tensor.to() implicitly triggers a sync if the tensor's device is
  // torch::kLazy.
  return tensor.to(torch::kCPU);
}

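// Clones a tensor onto the given device. The explicit copy guarantees that
// the result owns its own storage even when source and target device match.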
torch::Tensor CopyToDevice(
    const torch::Tensor& tensor,
    const torch::Device& device) {
  return tensor.clone().to(device, /*non_blocking=*/false, /*copy=*/true);
}

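// Exact comparison helper: moves both tensors to CPU, requires NaNs to appear
// at the same positions (then replaces them via nan_to_num_), and reports
// shape/dtype mismatches before the element-wise equality check.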
bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  if (torch::isnan(tensor1).any().item<bool>()) {
    EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
    tensor1.nan_to_num_();
    tensor2.nan_to_num_();
  }
  if (tensor1.sizes() != tensor2.sizes() ||
      tensor1.dtype() != tensor2.dtype()) {
    std::cerr << "Different shape or dtype:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  at::ScalarType type1 = tensor1.scalar_type();
  at::ScalarType type2 = tensor2.scalar_type();
  if (type1 != type2) {
    tensor1 = tensor1.toType(type2);
  }
  bool equal = tensor1.equal(tensor2);
  return equal;
}

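// Like EqualValues(), but only shapes have to match: tensor1 is cast to
// tensor2's element type before the equality check.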
bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  if (tensor1.sizes() != tensor2.sizes()) {
    std::cerr << "Different shape:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  at::ScalarType type1 = tensor1.scalar_type();
  at::ScalarType type2 = tensor2.scalar_type();
  if (type1 != type2) {
    tensor1 = tensor1.toType(type2);
  }
  bool equal = tensor1.equal(tensor2);
  return equal;
}

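// Invokes devfn for each device under test; see the note below on why this is
// currently a single device.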
void ForEachDevice(const std::function<void(const torch::Device&)>& devfn) {
  // Currently the TorchScript backend only supports one hardware type per
  // process, selected via an environment variable, and the ordinal is always
  // 0 because distributed training / multi-device is not supported yet.
  auto device = torch::lazy::BackendDevice();
  torch::Device torch_device = torch::lazy::backendDeviceToAtenDevice(device);
  devfn(torch_device);
}

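// Approximate comparison helper: mirrors EqualValues() but compares with
// allclose using the supplied relative and absolute tolerances.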
bool CloseValues(
    at::Tensor tensor1,
    at::Tensor tensor2,
    double rtol,
    double atol) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  if (torch::isnan(tensor1).any().item<bool>()) {
    EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
    tensor1.nan_to_num_();
    tensor2.nan_to_num_();
  }
  if (tensor1.sizes() != tensor2.sizes() ||
      tensor1.dtype() != tensor2.dtype()) {
    std::cerr << "Different shape or dtype:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  bool equal = tensor1.allclose(tensor2, rtol, atol);
  return equal;
}

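// Renders the IR graph rooted at the tensor's current IR value as plain text.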
std::string GetTensorTextGraph(at::Tensor tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  return torch::lazy::DumpUtil::ToText({lazy_tensor->GetIrValue().node.get()});
}

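// Renders the IR graph rooted at the tensor's current IR value in Graphviz
// DOT format.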
std::string GetTensorDotGraph(at::Tensor tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  return torch::lazy::DumpUtil::ToDot({lazy_tensor->GetIrValue().node.get()});
}

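// Compares forward and backward results between a reference run and a run on
// `device`: every defined input is cloned twice (once as-is, once copied to
// the device), testfn is evaluated on both input sets, and the gradients of
// the summed outputs are compared up to `derivative_level` times, re-entering
// autograd with create_graph=true for all but the last level.
//
// Illustrative usage (the lambda and tensors below are hypothetical, not part
// of this file):
//
//   ForEachDevice([&](const torch::Device& device) {
//     torch::Tensor a = torch::rand({2, 2}, torch::requires_grad());
//     torch::Tensor b = torch::rand({2, 2}, torch::requires_grad());
//     TestBackward(
//         {a, b},
//         device,
//         [](const std::vector<torch::Tensor>& inputs) {
//           return inputs[0] * inputs[1];
//         },
//         /*rtol=*/1e-5,
//         /*atol=*/1e-8,
//         /*derivative_level=*/1);
//   });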
void TestBackward(
    const std::vector<torch::Tensor>& inputs,
    const torch::Device& device,
    const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
        testfn,
    double rtol,
    double atol,
    int derivative_level) {
  std::vector<torch::Tensor> input_vars;
  std::vector<torch::Tensor> xinput_vars;
  std::vector<torch::Tensor> inputs_w_grad;
  std::vector<torch::Tensor> xinputs_w_grad;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const torch::Tensor& input = inputs[i];
    if (input.defined()) {
      torch::Tensor oinput =
          input.clone().detach().set_requires_grad(input.requires_grad());
      input_vars.push_back(oinput);

      torch::Tensor xinput = CopyToDevice(input, device)
                                 .detach()
                                 .set_requires_grad(input.requires_grad());
      xinput_vars.push_back(xinput);
      if (input.requires_grad()) {
        inputs_w_grad.push_back(oinput);
        xinputs_w_grad.push_back(xinput);
      }
    } else {
      input_vars.emplace_back();
      xinput_vars.emplace_back();
    }
  }

  torch::Tensor output = testfn(input_vars);
  torch::Tensor xoutput = testfn(xinput_vars);
  torch::lazy::AllClose(output, xoutput, rtol, atol);

  std::vector<torch::Tensor> outs = {output};
  std::vector<torch::Tensor> xouts = {xoutput};
  for (int d = 1; d <= derivative_level; ++d) {
    // Check the gradient of sum(outs) w.r.t. inputs_w_grad.
    torch::Tensor sum = torch::zeros_like(outs[0]).sum();
    torch::Tensor xsum = torch::zeros_like(xouts[0]).sum();
    for (size_t i = 0; i < outs.size(); ++i) {
      if (outs[i].requires_grad()) {
        sum += outs[i].sum();
        xsum += xouts[i].sum();
      }
    }
    // Calculating higher-order derivatives requires create_graph=true.
    bool create_graph = d != derivative_level;
    outs = torch::autograd::grad(
        {sum},
        inputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/c10::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    xouts = torch::autograd::grad(
        {xsum},
        xinputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/c10::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    for (size_t i = 0; i < outs.size(); ++i) {
      ASSERT_EQ(outs[i].defined(), xouts[i].defined());
      if (outs[i].defined()) {
        AllClose(outs[i], xouts[i], rtol, atol);
      }
    }
  }
}

} // namespace lazy
} // namespace torch