1 | #pragma once |
2 | |
#include <gtest/gtest.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/torch.h>

#include <cmath>
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>
14 | |
15 | namespace torch { |
16 | namespace lazy { |
17 | |
18 | const std::unordered_set<std::string>* GetIgnoredCounters(); |
19 | |
20 | // Converts an at::Tensor(device=torch::kLazy) to at::Tensor(device=torch::kCPU) |
21 | // This at::Tensor can be torch::Tensor which is a Variable, or at::Tensor which |
22 | // know nothing about autograd. If the input tensor is already a CPU tensor, it |
23 | // will be returned. Needed because EqualValues and AllClose require CPU tensors |
24 | // on both sides. |
25 | at::Tensor ToCpuTensor(const at::Tensor& tensor); |
26 | |
27 | // Helper function to copy a tensor to device. |
28 | torch::Tensor CopyToDevice( |
29 | const torch::Tensor& tensor, |
30 | const torch::Device& device); |
31 | |
32 | bool EqualValues(at::Tensor tensor1, at::Tensor tensor2); |
33 | |
34 | bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2); |
35 | |
36 | bool CloseValues( |
37 | at::Tensor tensor1, |
38 | at::Tensor tensor2, |
39 | double rtol = 1e-5, |
40 | double atol = 1e-8); |
41 | |
42 | static inline void AllClose( |
43 | at::Tensor tensor, |
44 | at::Tensor xla_tensor, |
45 | double rtol = 1e-5, |
46 | double atol = 1e-8) { |
47 | EXPECT_TRUE(CloseValues(tensor, xla_tensor, rtol, atol)); |
48 | } |
49 | |
50 | static inline void AllClose( |
51 | at::Tensor tensor, |
52 | torch::lazy::LazyTensor& xla_tensor, |
53 | double rtol = 1e-5, |
54 | double atol = 1e-8) { |
55 | EXPECT_TRUE( |
56 | CloseValues(tensor, xla_tensor.ToTensor(/*detached=*/false), rtol, atol)); |
57 | } |
58 | |
59 | static inline void AllEqual(at::Tensor tensor, at::Tensor xla_tensor) { |
60 | EXPECT_TRUE(EqualValues(tensor, xla_tensor)); |
61 | } |
62 | |
63 | void ForEachDevice(const std::function<void(const torch::Device&)>& devfn); |
64 | |
65 | std::string GetTensorTextGraph(at::Tensor tensor); |
66 | |
67 | std::string GetTensorDotGraph(at::Tensor tensor); |
68 | |
69 | std::string GetTensorHloGraph(at::Tensor tensor); |
70 | |
71 | void TestBackward( |
72 | const std::vector<torch::Tensor>& inputs, |
73 | const torch::Device& device, |
74 | const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>& |
75 | testfn, |
76 | double rtol = 1e-5, |
77 | double atol = 1e-8, |
78 | int derivative_level = 1); |
79 | |
80 | } // namespace lazy |
81 | } // namespace torch |
82 | |