1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/clear_undefinedness.h> |
6 | #include <torch/csrc/jit/runtime/custom_operator.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | Stack createStack(std::vector<at::Tensor>&& list) { |
12 | return Stack( |
13 | std::make_move_iterator(list.begin()), |
14 | std::make_move_iterator(list.end())); |
15 | } |
16 | |
17 | void assertAllClose(const tensor_list& a, const tensor_list& b) { |
18 | ASSERT_EQ(a.size(), b.size()); |
19 | for (size_t i = 0; i < a.size(); ++i) { |
20 | ASSERT_TRUE(a[i].is_same_size(b[i])); |
21 | ASSERT_TRUE(a[i].allclose(b[i])); |
22 | } |
23 | } |
24 | |
25 | std::vector<at::Tensor> run( |
26 | InterpreterState& interp, |
27 | const std::vector<at::Tensor>& inputs) { |
28 | std::vector<IValue> stack(inputs.begin(), inputs.end()); |
29 | interp.run(stack); |
30 | return fmap(stack, [](const IValue& i) { return i.toTensor(); }); |
31 | } |
32 | |
33 | static void unpackReturnTuple(Stack& stack) { |
34 | auto tuple = pop(stack).toTuple(); |
35 | stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end()); |
36 | } |
37 | |
// Executes the forward graph (grad_spec.f) and backward graph (grad_spec.df)
// of a Gradient specification with fresh interpreters, wiring the values
// captured by df (selected inputs and outputs of f) into df's input stack.
// Returns the pair (real outputs of f, gradients produced by df) as tensor
// lists. Mutates nothing in the input tensor lists themselves.
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in) {
  // Converts an interpreter stack of IValues into a plain tensor list.
  static const auto as_tensorlist = [](const Stack& stack) {
    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
  };
  ClearUndefinedness(grad_spec.df);
  Code f_code{grad_spec.f, ""}, df_code{grad_spec.df, ""};
  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

  // Forward pass: the stack is mutated in place, leaving f's outputs on it.
  auto f_stack = fmap<IValue>(tensors_in);
  f_interpreter.run(f_stack);

  // Backward stack layout: incoming output gradients first, followed by the
  // inputs and outputs of f that df captured (indexed by offset).
  Stack df_stack;
  df_stack.insert(
      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
  for (auto offset : grad_spec.df_input_captured_inputs)
    df_stack.push_back(tensors_in[offset]);
  for (auto offset : grad_spec.df_input_captured_outputs)
    df_stack.push_back(f_stack[offset]);
  df_interpreter.run(df_stack);
  // df returns its results as a single tuple; flatten it for conversion.
  unpackReturnTuple(df_stack);
  // Outputs of f needs to be sliced: only the first f_real_outputs entries
  // are real outputs; the rest exist solely so df could capture them.
  f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}
65 | |
66 | std::shared_ptr<Graph> build_lstm() { |
67 | const auto graph_string = R"IR( |
68 | graph(%0 : Tensor, |
69 | %1 : Tensor, |
70 | %2 : Tensor, |
71 | %3 : Tensor, |
72 | %4 : Tensor): |
73 | %5 : Tensor = aten::mm(%0, %3) |
74 | %6 : Tensor = aten::mm(%1, %4) |
75 | %7 : int = prim::Constant[value=1]() |
76 | %8 : Tensor = aten::add(%5, %6, %7) |
77 | %9 : Tensor, %10 : Tensor, %11 : Tensor, %12 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%8) |
78 | %13 : Tensor = aten::sigmoid(%9) |
79 | %14 : Tensor = aten::sigmoid(%12) |
80 | %15 : Tensor = aten::tanh(%11) |
81 | %16 : Tensor = aten::sigmoid(%10) |
82 | %17 : Tensor = aten::mul(%16, %2) |
83 | %18 : Tensor = aten::mul(%13, %15) |
84 | %19 : int = prim::Constant[value=1]() |
85 | %20 : Tensor = aten::add(%17, %18, %19) |
86 | %21 : Tensor = aten::tanh(%20) |
87 | %22 : Tensor = aten::mul(%14, %21) |
88 | return (%22, %20))IR" ; |
89 | auto g = std::make_shared<Graph>(); |
90 | torch::jit::parseIR(graph_string, g.get()); |
91 | g->lint(); |
92 | |
93 | return g; |
94 | } |
95 | |
96 | std::shared_ptr<Graph> build_mobile_export_analysis_graph() { |
97 | // We use following two schemas for this graph: |
98 | // 1. slice.Tensor(Tensor(a) self, int dim=0, int? start=None, |
99 | // int? end=None, int step=1) -> Tensor(a) |
100 | // 2. slice.str(str string, int? start=None, int? end=None, |
101 | // int step=1) -> str |
102 | // %3 and %4 use slice.Tensor while %5 use slice.str. |
103 | // Since we can see %3 and %4 have the same last argument that is never used |
104 | // (same as default value of schema), we know we can ignore that last arg. For |
105 | // %5, we see that last three args are same as schema default, hence |
106 | // unnecessary. |
107 | |
108 | const auto graph_string = R"IR( |
109 | graph(%0 : Tensor): |
110 | %1 : int = prim::Constant[value=1]() |
111 | %2 : int = prim::Constant[value=2]() |
112 | %20 : int = prim::Constant[value=0]() |
113 | %21 : int = prim::Constant[value=9223372036854775807]() |
114 | %22 : str = prim::Constant[value="value"]() |
115 | %3 : Tensor = aten::slice(%0, %1, %20, %2, %1) |
116 | %4 : Tensor = aten::slice(%0, %2, %20, %21, %1) |
117 | %5 : str = aten::slice(%22, %20, %21, %2) |
118 | return (%3, %4, %5))IR" ; |
119 | |
120 | auto g = std::make_shared<Graph>(); |
121 | torch::jit::parseIR(graph_string, g.get()); |
122 | g->lint(); |
123 | return g; |
124 | } |
125 | |
126 | std::shared_ptr<Graph> build_mobile_export_with_out() { |
127 | const auto graph_string = R"IR( |
128 | graph(%x.1 : Tensor, |
129 | %y.1 : Tensor): |
130 | %8 : NoneType = prim::Constant() |
131 | %6 : int = prim::Constant[value=1]() |
132 | %7 : Tensor = aten::add(%x.1, %y.1, %6, %y.1) |
133 | return (%8))IR" ; |
134 | |
135 | auto g = std::make_shared<Graph>(); |
136 | torch::jit::parseIR(graph_string, g.get()); |
137 | g->lint(); |
138 | return g; |
139 | } |
140 | |
141 | std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() { |
142 | // this is pretty much same test as build_mobile_export_analysis_graph(), |
143 | // but some aten::slice operators are hidden under block statement to check |
144 | // if we are correctly recursing all the nodes in graph. |
145 | const auto graph_string = R"IR( |
146 | graph(%0 : Tensor): |
147 | %1 : int = prim::Constant[value=1]() |
148 | %2 : int = prim::Constant[value=2]() |
149 | %20 : int = prim::Constant[value=0]() |
150 | %21 : int = prim::Constant[value=9223372036854775807]() |
151 | %22 : str = prim::Constant[value="value"]() |
152 | %3 : Tensor = aten::slice(%0, %1, %20, %2, %1) |
153 | %23 : bool = aten::Bool(%3) |
154 | %c : Tensor = prim::If(%23) |
155 | block0(): |
156 | %4 : Tensor = aten::slice(%0, %2, %20, %21, %1) |
157 | %5 : str = aten::slice(%22, %20, %21, %2) |
158 | %c.1 : Tensor = aten::slice(%0, %1, %20, %2, %1) |
159 | -> (%c.1) |
160 | block1(): |
161 | -> (%3) |
162 | return (%3, %3))IR" ; |
163 | auto g = std::make_shared<Graph>(); |
164 | torch::jit::parseIR(graph_string, g.get()); |
165 | g->lint(); |
166 | return g; |
167 | } |
168 | |
169 | std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg() { |
170 | const auto graph_string = R"IR( |
171 | graph(%0 : Tensor): |
172 | %1 : int = prim::Constant[value=1]() |
173 | %2 : int = prim::Constant[value=2]() |
174 | %3 : int = prim::Constant[value=3]() |
175 | %4 : int[] = prim::tolist(%1, %2) |
176 | %5 : int[] = prim::tolist(%1, %2, %3) |
177 | return (%4, %5))IR" ; |
178 | |
179 | auto g = std::make_shared<Graph>(); |
180 | torch::jit::parseIR(graph_string, g.get()); |
181 | g->lint(); |
182 | return g; |
183 | } |
184 | |
185 | std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const() { |
186 | const auto graph_string = R"IR( |
187 | graph(%input.1 : Tensor): |
188 | %7 : int = prim::Constant[value=1]() # <string>:3:58 |
189 | %9 : int = prim::Constant[value=0]() # <string>:3:66 |
190 | %8 : int[] = prim::ListConstruct(%7, %7) |
191 | %10 : int[] = prim::ListConstruct(%9, %9) |
192 | %11 : int[] = prim::ListConstruct(%7, %7) |
193 | %12 : Tensor = aten::conv2d(%input.1, %input.1, %input.1, %8, %10, %11, %7) |
194 | return (%12))IR" ; |
195 | |
196 | auto g = std::make_shared<Graph>(); |
197 | torch::jit::parseIR(graph_string, g.get()); |
198 | g->lint(); |
199 | return g; |
200 | } |
201 | |
// Identity helper: marks a weight as used directly (not transposed) in the
// reference lstm() below, mirroring the graph's aten::mm operand layout.
at::Tensor t_use(at::Tensor x) {
  return x;
}
205 | at::Tensor t_def(at::Tensor x) { |
206 | return x.t(); |
207 | } |
208 | |
209 | bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) { |
210 | double maxValue = 0.0; |
211 | for (auto& tensor : inputs) { |
212 | maxValue = fmax(tensor.abs().max().item<float>(), maxValue); |
213 | } |
214 | return diff.abs().max().item<float>() < 2e-6 * maxValue; |
215 | } |
216 | |
217 | bool almostEqual(const at::Tensor& a, const at::Tensor& b) { |
218 | return checkRtol(a - b, {a, b}); |
219 | } |
220 | |
221 | bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { |
222 | return (a - b).abs().max().item<float>() == 0.f; |
223 | } |
224 | |
225 | bool exactlyEqual( |
226 | const std::vector<at::Tensor>& a, |
227 | const std::vector<at::Tensor>& b) { |
228 | if (a.size() != b.size()) { |
229 | return false; |
230 | } |
231 | for (size_t i = 0; i < a.size(); ++i) { |
232 | if (!exactlyEqual(a[i], b[i])) { |
233 | return false; |
234 | } |
235 | } |
236 | return true; |
237 | } |
238 | |
239 | std::pair<at::Tensor, at::Tensor> lstm( |
240 | at::Tensor input, |
241 | at::Tensor hx, |
242 | at::Tensor cx, |
243 | at::Tensor w_ih, |
244 | at::Tensor w_hh) { |
245 | auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh)); |
246 | |
247 | auto chunked_gates = gates.chunk(4, 1); |
248 | auto ingate = chunked_gates[0]; |
249 | auto forgetgate = chunked_gates[1]; |
250 | auto cellgate = chunked_gates[2]; |
251 | auto outgate = chunked_gates[3]; |
252 | |
253 | ingate = ingate.sigmoid(); |
254 | outgate = outgate.sigmoid(); |
255 | cellgate = cellgate.tanh(); |
256 | forgetgate = forgetgate.sigmoid(); |
257 | |
258 | auto cy = (forgetgate * cx) + (ingate * cellgate); |
259 | auto hy = outgate * cy.tanh(); |
260 | |
261 | return {hy, cy}; |
262 | } |
263 | |
// Alias analysis mode for the test operators registered below: derive
// aliasing information from the operator schema.
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}
267 | |
// File-local static registration: runs at load time, making the operator
// available to any graph parsed in these tests.
namespace {
RegisterOperators reg({
    // This operator is intended to be used in JIT analysis and transformation
    // pass unit tests in which Values with type Tensor are often required. It
    // should not be used in situations in which the graph is actually executed
    // because it always produces empty Tensors.
    Operator(
        "prim::MakeTestTensor() -> Tensor",
        [](Stack& stack) { push(stack, at::Tensor()); },
        aliasAnalysisFromSchema()),
});
} // namespace
280 | |
281 | std::vector<at::Tensor> runGraph( |
282 | std::shared_ptr<Graph> graph, |
283 | const std::vector<at::Tensor>& inputs) { |
284 | std::vector<IValue> stack = fmap<IValue>(inputs); |
285 | Code code(graph, "test" ); |
286 | InterpreterState(code).run(stack); |
287 | TORCH_INTERNAL_ASSERT(!stack.empty()); |
288 | // Graph outputs that are handled below: |
289 | // * A list of Tensors. |
290 | // * 1 Tensor. |
291 | if (stack.front().isTensorList()) { |
292 | return stack.front().toTensorVector(); |
293 | } |
294 | TORCH_INTERNAL_ASSERT(stack.front().isTensor()); |
295 | return {stack.front().toTensor()}; |
296 | } |
297 | |
298 | } // namespace jit |
299 | } // namespace torch |
300 | |