#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

namespace torch {
namespace jit {

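// Moves a list of tensors into a fresh interpreter stack of IValues.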
Stack createStack(std::vector<at::Tensor>&& list) {
  return Stack(
      std::make_move_iterator(list.begin()),
      std::make_move_iterator(list.end()));
}

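// Asserts that the two tensor lists have the same length and that
// corresponding elements match in size and (approximately) in value.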
void assertAllClose(const tensor_list& a, const tensor_list& b) {
  ASSERT_EQ(a.size(), b.size());
  for (size_t i = 0; i < a.size(); ++i) {
    ASSERT_TRUE(a[i].is_same_size(b[i]));
    ASSERT_TRUE(a[i].allclose(b[i]));
  }
}

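// Feeds the given tensors to the interpreter and returns whatever is left on
// the stack afterwards, converted back to tensors.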
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs) {
  std::vector<IValue> stack(inputs.begin(), inputs.end());
  interp.run(stack);
  return fmap(stack, [](const IValue& i) { return i.toTensor(); });
}

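// Pops the tuple on top of the stack and pushes its elements back
// individually, so callers can treat multi-output returns uniformly.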
static void unpackReturnTuple(Stack& stack) {
  auto tuple = pop(stack).toTuple();
  stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
}

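// Executes a Gradient specification: runs the forward graph `f` on
// `tensors_in`, builds the backward stack from the incoming output gradients
// plus the inputs/outputs captured by the spec, then runs the backward graph
// `df`. Returns the real outputs of `f` together with the computed gradients.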
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in) {
  static const auto as_tensorlist = [](const Stack& stack) {
    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
  };
  ClearUndefinedness(grad_spec.df);
  Code f_code{grad_spec.f, ""}, df_code{grad_spec.df, ""};
  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

  auto f_stack = fmap<IValue>(tensors_in);
  f_interpreter.run(f_stack);

  Stack df_stack;
  df_stack.insert(
      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
  for (auto offset : grad_spec.df_input_captured_inputs)
    df_stack.push_back(tensors_in[offset]);
  for (auto offset : grad_spec.df_input_captured_outputs)
    df_stack.push_back(f_stack[offset]);
  df_interpreter.run(df_stack);
  unpackReturnTuple(df_stack);
  // Outputs of f need to be sliced down to the real outputs
  f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}

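// Builds the IR for a single LSTM cell. The graph takes
// (input, hx, cx, w_ih, w_hh) and returns (hy, cy); it computes the same
// function as the eager-mode lstm() reference defined later in this file.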
std::shared_ptr<Graph> build_lstm() {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor,
          %2 : Tensor,
          %3 : Tensor,
          %4 : Tensor):
      %5 : Tensor = aten::mm(%0, %3)
      %6 : Tensor = aten::mm(%1, %4)
      %7 : int = prim::Constant[value=1]()
      %8 : Tensor = aten::add(%5, %6, %7)
      %9 : Tensor, %10 : Tensor, %11 : Tensor, %12 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%8)
      %13 : Tensor = aten::sigmoid(%9)
      %14 : Tensor = aten::sigmoid(%12)
      %15 : Tensor = aten::tanh(%11)
      %16 : Tensor = aten::sigmoid(%10)
      %17 : Tensor = aten::mul(%16, %2)
      %18 : Tensor = aten::mul(%13, %15)
      %19 : int = prim::Constant[value=1]()
      %20 : Tensor = aten::add(%17, %18, %19)
      %21 : Tensor = aten::tanh(%20)
      %22 : Tensor = aten::mul(%14, %21)
      return (%22, %20))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  return g;
}

std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
  // We use the following two schemas for this graph:
  //   1. slice.Tensor(Tensor(a) self, int dim=0, int? start=None,
  //                   int? end=None, int step=1) -> Tensor(a)
  //   2. slice.str(str string, int? start=None, int? end=None,
  //                int step=1) -> str
  // %3 and %4 use slice.Tensor, while %5 uses slice.str.
  // %3 and %4 pass the same value for the last argument as the schema default,
  // so that argument can be dropped. For %5, the last three arguments match
  // the schema defaults and are therefore unnecessary.

  const auto graph_string = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %20 : int = prim::Constant[value=0]()
      %21 : int = prim::Constant[value=9223372036854775807]()
      %22 : str = prim::Constant[value="value"]()
      %3 : Tensor = aten::slice(%0, %1, %20, %2, %1)
      %4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
      %5 : str = aten::slice(%22, %20, %21, %2)
      return (%3, %4, %5))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

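// Builds a graph that calls the out variant of aten::add (the trailing
// argument is the out tensor), for testing export handling of out variants.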
std::shared_ptr<Graph> build_mobile_export_with_out() {
  const auto graph_string = R"IR(
    graph(%x.1 : Tensor,
          %y.1 : Tensor):
      %8 : NoneType = prim::Constant()
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%x.1, %y.1, %6, %y.1)
      return (%8))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() {
  // This is pretty much the same test as build_mobile_export_analysis_graph(),
  // but some of the aten::slice operators are hidden inside a block statement
  // to check that we correctly recurse through all nodes in the graph.
  const auto graph_string = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %20 : int = prim::Constant[value=0]()
      %21 : int = prim::Constant[value=9223372036854775807]()
      %22 : str = prim::Constant[value="value"]()
      %3 : Tensor = aten::slice(%0, %1, %20, %2, %1)
      %23 : bool = aten::Bool(%3)
      %c : Tensor = prim::If(%23)
        block0():
          %4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
          %5 : str = aten::slice(%22, %20, %21, %2)
          %c.1 : Tensor = aten::slice(%0, %1, %20, %2, %1)
          -> (%c.1)
        block1():
          -> (%3)
      return (%3, %3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

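// Builds a graph that calls prim::tolist with different numbers of arguments,
// for testing export handling of variadic operators.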
std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg() {
  const auto graph_string = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %3 : int = prim::Constant[value=3]()
      %4 : int[] = prim::tolist(%1, %2)
      %5 : int[] = prim::tolist(%1, %2, %3)
      return (%4, %5))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

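// Builds a graph whose list arguments come from prim::ListConstruct rather
// than prim::Constant, so they are not constant nodes from the analysis'
// point of view.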
std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const() {
  const auto graph_string = R"IR(
    graph(%input.1 : Tensor):
      %7 : int = prim::Constant[value=1]() # <string>:3:58
      %9 : int = prim::Constant[value=0]() # <string>:3:66
      %8 : int[] = prim::ListConstruct(%7, %7)
      %10 : int[] = prim::ListConstruct(%9, %9)
      %11 : int[] = prim::ListConstruct(%7, %7)
      %12 : Tensor = aten::conv2d(%input.1, %input.1, %input.1, %8, %10, %11, %7)
      return (%12))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

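// Helpers for writing reference implementations against either weight layout:
// t_use returns its argument unchanged, t_def returns its transpose.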
at::Tensor t_use(at::Tensor x) {
  return x;
}
at::Tensor t_def(at::Tensor x) {
  return x.t();
}

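// Checks that `diff` is small relative to the largest absolute value found in
// `inputs`, i.e. a relative-tolerance comparison.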
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
  double maxValue = 0.0;
  for (auto& tensor : inputs) {
    maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
  }
  return diff.abs().max().item<float>() < 2e-6 * maxValue;
}

bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
  return checkRtol(a - b, {a, b});
}

bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
  return (a - b).abs().max().item<float>() == 0.f;
}

bool exactlyEqual(
    const std::vector<at::Tensor>& a,
    const std::vector<at::Tensor>& b) {
  if (a.size() != b.size()) {
    return false;
  }
  for (size_t i = 0; i < a.size(); ++i) {
    if (!exactlyEqual(a[i], b[i])) {
      return false;
    }
  }
  return true;
}

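// Eager-mode reference implementation of a single LSTM cell; the IR built by
// build_lstm() above computes the same function.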
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh) {
  auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));

  auto chunked_gates = gates.chunk(4, 1);
  auto ingate = chunked_gates[0];
  auto forgetgate = chunked_gates[1];
  auto cellgate = chunked_gates[2];
  auto outgate = chunked_gates[3];

  ingate = ingate.sigmoid();
  outgate = outgate.sigmoid();
  cellgate = cellgate.tanh();
  forgetgate = forgetgate.sigmoid();

  auto cy = (forgetgate * cx) + (ingate * cellgate);
  auto hy = outgate * cy.tanh();

  return {hy, cy};
}
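
// A minimal usage sketch (an added example, not part of the original
// utilities): runs the IR from build_lstm() through the interpreter via
// run() and checks it against the eager lstm() reference above using
// almostEqual(). The shapes are illustrative assumptions (batch 3, input
// size 10, hidden size 4, so each gate matrix has 4 * 4 = 16 columns), and
// it assumes at::randn is visible through the headers included above.
TEST(TestUtilsExample, BuildLSTMMatchesEagerReference) {
  auto input = at::randn({3, 10});
  auto hx = at::randn({3, 4});
  auto cx = at::randn({3, 4});
  auto w_ih = at::randn({10, 16});
  auto w_hh = at::randn({4, 16});

  Code code{build_lstm(), "lstm"};
  InterpreterState interp{code};
  // run() leaves the graph outputs (%22 = hy, %20 = cy) on the stack.
  auto outputs = run(interp, {input, hx, cx, w_ih, w_hh});
  ASSERT_EQ(outputs.size(), 2u);

  auto expected = lstm(input, hx, cx, w_ih, w_hh);
  ASSERT_TRUE(almostEqual(outputs[0], expected.first)); // hy
  ASSERT_TRUE(almostEqual(outputs[1], expected.second)); // cy
}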

inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

namespace {
RegisterOperators reg({
    // This operator is intended to be used in JIT analysis and transformation
    // pass unit tests in which Values with type Tensor are often required. It
    // should not be used in situations in which the graph is actually executed
    // because it always produces empty Tensors.
    Operator(
        "prim::MakeTestTensor() -> Tensor",
        [](Stack& stack) { push(stack, at::Tensor()); },
        aliasAnalysisFromSchema()),
});
} // namespace

std::vector<at::Tensor> runGraph(
    std::shared_ptr<Graph> graph,
    const std::vector<at::Tensor>& inputs) {
  std::vector<IValue> stack = fmap<IValue>(inputs);
  Code code(graph, "test");
  InterpreterState(code).run(stack);
  TORCH_INTERNAL_ASSERT(!stack.empty());
  // Graph outputs that are handled below:
  // * A list of Tensors.
  // * 1 Tensor.
  if (stack.front().isTensorList()) {
    return stack.front().toTensorVector();
  }
  TORCH_INTERNAL_ASSERT(stack.front().isTensor());
  return {stack.front().toTensor()};
}

} // namespace jit
} // namespace torch