1 | #pragma once |
2 | |
#include <algorithm>
#include <cctype>
#include <exception>
#include <string>
#include <utility>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
7 | |
namespace {
// Normalizes `s` in place so that error messages can be compared without
// depending on incidental formatting:
//   1. leading and trailing whitespace (as classified by std::isspace) is
//      removed;
//   2. every remaining '\n' is deleted;
//   3. each run of consecutive ' ' characters is collapsed to a single ' '.
// Other interior whitespace (tabs, '\r', ...) is left untouched.
//
// Implemented as a single linear pass over the string instead of repeated
// std::string::erase calls, which would be quadratic in the message length.
//
// NOTE(review): an anonymous namespace in a header gives every including
// translation unit its own copy; kept as-is so `trim` does not gain external
// linkage and collide with other `trim` definitions.
inline void trim(std::string& s) {
  // Take unsigned char: passing a negative char value to std::isspace is UB.
  const auto is_space = [](unsigned char ch) { return std::isspace(ch) != 0; };

  const auto first = std::find_if_not(s.begin(), s.end(), is_space);
  const auto last = std::find_if_not(s.rbegin(), s.rend(), is_space).base();
  if (first >= last) {
    // Empty or all-whitespace input normalizes to the empty string.
    s.clear();
    return;
  }

  std::string normalized;
  normalized.reserve(static_cast<std::string::size_type>(last - first));
  for (auto it = first; it != last; ++it) {
    const char ch = *it;
    if (ch == '\n') {
      continue;
    }
    // Newlines are skipped before this check, so spaces on either side of a
    // removed '\n' become adjacent and are collapsed as well.
    if (ch == ' ' && !normalized.empty() && normalized.back() == ' ') {
      continue;
    }
    normalized.push_back(ch);
  }
  s = std::move(normalized);
}
} // namespace
34 | |
// Asserts that evaluating `statement` throws a std::exception whose what()
// text contains `substring`. Both strings are normalized with trim() first,
// so differences in surrounding whitespace, newlines, and runs of spaces are
// ignored. Fails the current test if nothing is thrown.
//
// The expansion is wrapped in do { } while (0) so the macro behaves as a
// single statement (usable inside an unbraced if/else) and takes a trailing
// ';'. `statement` is parenthesized so expressions such as `x = f()` work.
// FAIL() sits outside the try block so that a gtest failure reported by
// throwing (e.g. with --gtest_throw_on_failure) is never mistaken for the
// exception under test. Locals use distinctive names so a `substring`
// expression that refers to a caller variable (such as `e`) is not shadowed.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)                \
  do {                                                                  \
    bool assert_throws_did_throw_ = false;                              \
    try {                                                               \
      (void)(statement);                                                \
    } catch (const std::exception& assert_throws_error_) {              \
      assert_throws_did_throw_ = true;                                  \
      std::string substring_s(substring);                               \
      trim(substring_s);                                                \
      auto exception_string = std::string(assert_throws_error_.what()); \
      trim(exception_string);                                           \
      ASSERT_NE(exception_string.find(substring_s), std::string::npos)  \
          << " Error was: \n"                                           \
          << exception_string;                                          \
    }                                                                   \
    if (!assert_throws_did_throw_) {                                    \
      FAIL() << "Expected `" #statement "` to throw";                   \
    }                                                                   \
  } while (0)
48 | |
49 | namespace torch { |
50 | namespace jit { |
51 | |
52 | using tensor_list = std::vector<at::Tensor>; |
53 | using namespace torch::autograd; |
54 | |
55 | // work around the fact that variable_tensor_list doesn't duplicate all |
56 | // of std::vector's constructors. |
57 | // most constructors are never used in the implementation, just in our tests. |
58 | Stack createStack(std::vector<at::Tensor>&& list); |
59 | |
60 | void assertAllClose(const tensor_list& a, const tensor_list& b); |
61 | |
62 | std::vector<at::Tensor> run( |
63 | InterpreterState& interp, |
64 | const std::vector<at::Tensor>& inputs); |
65 | |
66 | std::pair<tensor_list, tensor_list> runGradient( |
67 | Gradient& grad_spec, |
68 | tensor_list& tensors_in, |
69 | tensor_list& tensor_grads_in); |
70 | |
71 | std::shared_ptr<Graph> build_lstm(); |
72 | std::shared_ptr<Graph> build_mobile_export_analysis_graph(); |
73 | std::shared_ptr<Graph> build_mobile_export_with_out(); |
74 | std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg(); |
75 | std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested(); |
76 | std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const(); |
77 | |
78 | at::Tensor t_use(at::Tensor x); |
79 | at::Tensor t_def(at::Tensor x); |
80 | |
81 | // given the difference of output vs expected tensor, check whether the |
82 | // difference is within a relative tolerance range. This is a standard way of |
83 | // matching tensor values up to certain precision |
84 | bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs); |
85 | bool almostEqual(const at::Tensor& a, const at::Tensor& b); |
86 | |
87 | bool exactlyEqual(const at::Tensor& a, const at::Tensor& b); |
88 | bool exactlyEqual( |
89 | const std::vector<at::Tensor>& a, |
90 | const std::vector<at::Tensor>& b); |
91 | |
92 | std::vector<at::Tensor> runGraph( |
93 | std::shared_ptr<Graph> graph, |
94 | const std::vector<at::Tensor>& inputs); |
95 | |
96 | std::pair<at::Tensor, at::Tensor> lstm( |
97 | at::Tensor input, |
98 | at::Tensor hx, |
99 | at::Tensor cx, |
100 | at::Tensor w_ih, |
101 | at::Tensor w_hh); |
102 | |
103 | } // namespace jit |
104 | } // namespace torch |
105 | |