1#pragma once
2
#include <algorithm>
#include <cctype>
#include <exception>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
7
namespace {
// Normalizes `s` in place so that two messages can be compared without regard
// to their layout:
//   1. leading and trailing whitespace (as classified by std::isspace) is
//      stripped;
//   2. every remaining '\n' is deleted outright (it is NOT replaced by a
//      space, so "a\nb" becomes "ab");
//   3. every run of consecutive ' ' characters is collapsed to a single ' '.
// Other interior whitespace such as '\t' or '\r' is left untouched.
static inline void trim(std::string& s) {
  // The parameter is unsigned char because passing a negative char value to
  // std::isspace is undefined behavior.
  const auto not_space = [](unsigned char ch) { return !std::isspace(ch); };

  s.erase(s.begin(), std::find_if(s.begin(), s.end(), not_space));
  s.erase(std::find_if(s.rbegin(), s.rend(), not_space).base(), s.end());

  // Erase-remove and erase-unique each make a single linear pass; erasing one
  // character at a time inside a loop would shift the tail on every hit.
  s.erase(std::remove(s.begin(), s.end(), '\n'), s.end());
  s.erase(
      std::unique(
          s.begin(),
          s.end(),
          [](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }),
      s.end());
}
} // namespace
34
// Asserts that evaluating `statement` throws an exception derived from
// std::exception whose what() text contains `substring`.
//
// Both the expected substring and the actual message are passed through
// trim() before the search, so differences in surrounding whitespace, line
// breaks and repeated spaces do not cause a mismatch.
//
// - If `statement` completes without throwing, FAIL() is invoked.
// - An exception that does not derive from std::exception is not caught here
//   and propagates to the caller.
// - On mismatch the normalized exception text is appended to the failure
//   output.
//
// `statement` must be an expression; it is parenthesized before the void cast
// so that compound expressions (e.g. `a + b`, `c ? f() : g()`) are evaluated
// as a whole. FAIL and ASSERT_NE are not defined in this header; they are
// expected to come from the test framework (presumably googletest).
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)             \
  try {                                                              \
    (void)(statement);                                               \
    FAIL();                                                          \
  } catch (const std::exception& e) {                                \
    std::string substring_s(substring);                              \
    trim(substring_s);                                               \
    auto exception_string = std::string(e.what());                   \
    trim(exception_string);                                          \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
        << " Error was: \n"                                          \
        << exception_string;                                         \
  }
48
49namespace torch {
50namespace jit {
51
52using tensor_list = std::vector<at::Tensor>;
53using namespace torch::autograd;
54
55// work around the fact that variable_tensor_list doesn't duplicate all
56// of std::vector's constructors.
57// most constructors are never used in the implementation, just in our tests.
58Stack createStack(std::vector<at::Tensor>&& list);
59
60void assertAllClose(const tensor_list& a, const tensor_list& b);
61
62std::vector<at::Tensor> run(
63 InterpreterState& interp,
64 const std::vector<at::Tensor>& inputs);
65
66std::pair<tensor_list, tensor_list> runGradient(
67 Gradient& grad_spec,
68 tensor_list& tensors_in,
69 tensor_list& tensor_grads_in);
70
71std::shared_ptr<Graph> build_lstm();
72std::shared_ptr<Graph> build_mobile_export_analysis_graph();
73std::shared_ptr<Graph> build_mobile_export_with_out();
74std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
75std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
76std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();
77
78at::Tensor t_use(at::Tensor x);
79at::Tensor t_def(at::Tensor x);
80
81// given the difference of output vs expected tensor, check whether the
82// difference is within a relative tolerance range. This is a standard way of
83// matching tensor values up to certain precision
84bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
85bool almostEqual(const at::Tensor& a, const at::Tensor& b);
86
87bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
88bool exactlyEqual(
89 const std::vector<at::Tensor>& a,
90 const std::vector<at::Tensor>& b);
91
92std::vector<at::Tensor> runGraph(
93 std::shared_ptr<Graph> graph,
94 const std::vector<at::Tensor>& inputs);
95
96std::pair<at::Tensor, at::Tensor> lstm(
97 at::Tensor input,
98 at::Tensor hx,
99 at::Tensor cx,
100 at::Tensor w_ih,
101 at::Tensor w_hh);
102
103} // namespace jit
104} // namespace torch
105