1 | #include <gmock/gmock.h> |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <ATen/Parallel.h> |
5 | #include <c10/core/DeviceType.h> |
6 | #include <test/cpp/jit/test_utils.h> |
7 | #include <torch/csrc/jit/runtime/instruction.h> |
8 | #include <torch/jit.h> |
9 | #include <torch/script.h> |
10 | #include <torch/torch.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | class TypeCheckTest : public ::testing::Test { |
16 | protected: |
17 | TypeCheckTest() : interp(makeInterp()) {} |
18 | |
19 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
20 | InterpreterState interp; |
21 | |
22 | private: |
23 | static InterpreterState makeInterp() { |
24 | auto graph = std::make_shared<Graph>(); |
25 | std::unordered_map<std::string, Value*> vmap; |
26 | parseIR( |
27 | R"IR( |
28 | graph(%a.1 : Tensor, |
29 | %b.1 : Tensor): |
30 | %t0 : Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), Float(3, 3, strides=[3, 1])]](%a.1, %b.1) |
31 | return (%t0, %t1, %type_matched) |
32 | )IR" , |
33 | &*graph, |
34 | vmap); |
35 | |
36 | Code function(graph, "" ); |
37 | return InterpreterState(function); |
38 | } |
39 | }; |
40 | |
41 | TEST_F(TypeCheckTest, MatchingType) { |
42 | // TypeCheck yields to true! Shape, grad and device matches. |
43 | auto a = at::zeros({2, 2}, at::kFloat); |
44 | auto b = at::ones({3, 3}, at::kFloat); |
45 | a.set_requires_grad(true); |
46 | a = a.to(at::kCPU); |
47 | std::vector<IValue> stack({a, b}); |
48 | interp.run(stack); |
49 | ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a)); |
50 | ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b)); |
51 | ASSERT_TRUE(stack[2].toBool()); |
52 | } |
53 | |
54 | TEST_F(TypeCheckTest, SizeMismatch) { |
55 | auto a = at::zeros({2, 2}, at::kFloat); |
56 | auto b = at::ones({2, 2}, at::kFloat); // Size mismatch |
57 | a.set_requires_grad(true); |
58 | a = a.to(at::kCPU); |
59 | std::vector<IValue> stack({a, b}); |
60 | interp.run(stack); |
61 | ASSERT_FALSE(stack[2].toBool()); |
62 | } |
63 | |
64 | TEST_F(TypeCheckTest, GradientMismatch) { |
65 | auto a = at::zeros({2, 2}, at::kFloat); |
66 | auto b = at::ones({3, 3}, at::kFloat); |
67 | a = a.to(at::kCPU); |
68 | a.set_requires_grad(false); // Gradient mismatch |
69 | std::vector<IValue> stack({a, b}); |
70 | interp.run(stack); |
71 | ASSERT_FALSE(stack[2].toBool()); |
72 | } |
73 | |
74 | TEST_F(TypeCheckTest, ScalarTypeMismatch) { |
75 | auto a = at::zeros({2, 2}, at::kFloat); |
76 | auto b = at::ones({3, 3}, at::kFloat); |
77 | a = a.to(at::kCPU); |
78 | a.set_requires_grad(true); |
79 | a = a.to(at::kInt); // Scalar type mismatch |
80 | std::vector<IValue> stack({a, b}); |
81 | interp.run(stack); |
82 | ASSERT_FALSE(stack[2].toBool()); |
83 | } |
84 | |
85 | TEST_F(TypeCheckTest, DeviceMismatch_CUDA) { |
86 | auto a = at::zeros({2, 2}, at::kFloat); |
87 | auto b = at::ones({3, 3}, at::kFloat); |
88 | a.set_requires_grad(true); |
89 | a = a.to(at::kCUDA); // Device mismatch |
90 | std::vector<IValue> stack({a, b}); |
91 | interp.run(stack); |
92 | ASSERT_FALSE(stack[2].toBool()); |
93 | } |
94 | |
95 | // TODO: These tests weren't doing anything. |
96 | // TEST(TypeCheckErrorTest, EmptyCheckRaises) { |
97 | // // Test empty Typecheck raises an internal assertion |
98 | // auto graph = std::make_shared<Graph>(); |
99 | // std::unordered_map<std::string, Value*> vmap; |
100 | // EXPECT_ANY_THROW(parseIR( |
101 | // R"IR( |
102 | // graph(%a.1 : Tensor, |
103 | // %b.1 : Tensor): |
104 | // %type_matched : bool = prim::TypeCheck() |
105 | // return (%type_matched) |
106 | // )IR", |
107 | // &*graph, |
108 | // vmap)); |
109 | // } |
110 | |
111 | // TODO: These tests weren't doing anything. |
112 | // TEST(TypeCheckErrorTest, WrongInputOutputCountRaises) { |
113 | // // Test for assertion if num_inputs + 1 != num_outputs |
114 | // auto graph = std::make_shared<Graph>(); |
115 | // std::unordered_map<std::string, Value*> vmap; |
116 | // EXPECT_ANY_THROW(parseIR( |
117 | // R"IR( |
118 | // graph(%a.1 : Tensor, |
119 | // %b.1 : Tensor): |
120 | // %type_matched : bool = prim::TypeCheck(%a.1) |
121 | // return (%type_matched) |
122 | // )IR", |
123 | // &*graph, |
124 | // vmap)); |
125 | // } |
126 | |
127 | TEST(InterpreterTest, Basic_CUDA) { |
128 | constexpr int batch_size = 4; |
129 | constexpr int input_size = 256; |
130 | constexpr int seq_len = 32; |
131 | |
132 | int hidden_size = 2 * input_size; |
133 | |
134 | auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA); |
135 | auto hx = at::randn({batch_size, hidden_size}, at::kCUDA); |
136 | auto cx = at::randn({batch_size, hidden_size}, at::kCUDA); |
137 | auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA)); |
138 | auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA)); |
139 | |
140 | auto lstm_g = build_lstm(); |
141 | Code lstm_function(lstm_g, "" ); |
142 | InterpreterState lstm_interp(lstm_function); |
143 | auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}); |
144 | std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); |
145 | |
146 | ASSERT_TRUE(exactlyEqual(outputs[0], hx)); |
147 | ASSERT_TRUE(exactlyEqual(outputs[1], cx)); |
148 | } |
149 | |
150 | TEST(InterpreterTest, IgnorableArgsInSchema) { |
151 | auto graph = build_mobile_export_analysis_graph(); |
152 | MobileCode function(graph, "" ); |
153 | auto op_to_specified_args = function.op_to_num_specified_args(); |
154 | ASSERT_TRUE(op_to_specified_args.size() == 2); |
155 | ASSERT_TRUE(op_to_specified_args["aten::slice.Tensor" ] == 4); |
156 | ASSERT_TRUE(op_to_specified_args["aten::slice.str" ] == 4); |
157 | auto graph_vararg = build_mobile_export_analysis_graph_with_vararg(); |
158 | MobileCode function_vararg(graph_vararg, "" ); |
159 | auto op_to_specified_args_vararg = function_vararg.op_to_num_specified_args(); |
160 | // should never register it |
161 | ASSERT_TRUE( |
162 | op_to_specified_args_vararg.find("prim::tolist" ) == |
163 | op_to_specified_args_vararg.end()); |
164 | |
165 | auto graph_nested = build_mobile_export_analysis_graph_nested(); |
166 | MobileCode function_nested(graph_nested, "" ); |
167 | auto op_to_specified_args_nested = function_nested.op_to_num_specified_args(); |
168 | ASSERT_TRUE(op_to_specified_args_nested["aten::slice.Tensor" ] == 4); |
169 | ASSERT_TRUE(op_to_specified_args_nested["aten::slice.str" ] == 4); |
170 | |
171 | auto graph_non_const = build_mobile_export_analysis_graph_non_const(); |
172 | MobileCode function_non_const(graph_non_const, "" ); |
173 | auto op_to_specified_args_non_const = |
174 | function_non_const.op_to_num_specified_args(); |
175 | ASSERT_TRUE(op_to_specified_args_non_const["aten::conv2d" ] == 6); |
176 | } |
177 | |
178 | TEST(InterpreterTest, IgnorableArgsInSchemaWithOut) { |
179 | auto graph = build_mobile_export_with_out(); |
180 | MobileCode function(graph, "" ); |
181 | auto op_to_specified_args = function.op_to_num_specified_args(); |
182 | ASSERT_TRUE(op_to_specified_args.size() == 1); |
183 | // this should be 3 when the add_out flag is set to True |
184 | ASSERT_TRUE(op_to_specified_args["aten::add.out" ] == 3); |
185 | } |
186 | |
187 | TEST(InterpreterTest, runAsyncBasicTest) { |
188 | /* |
189 | TODO: there are some problem with C++ parsing script program involving |
190 | fork. Use the test module below for now. |
191 | issue about this: github.com/pytorch/pytorch/issues/46368 |
192 | The test module file is generated by following: |
193 | class DemoModule(torch.nn.Module): |
194 | def forward(self): |
195 | r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) |
196 | r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) |
197 | return r1.wait() + r2.wait() |
198 | demo = DemoModule() |
199 | torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt') |
200 | */ |
201 | std::string filePath(__FILE__); |
202 | auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
203 | testModelFile.append("test_interpreter_async.pt" ); |
204 | auto model = load(testModelFile); |
205 | auto graph = model.get_method("forward" ).graph(); |
206 | Code function(graph, "" ); |
207 | auto asyncCounter = 0; |
208 | std::mutex mtx; |
209 | // a dummy executor which actually use at::launch, but add up a counter |
210 | auto launcher = [&](std::function<void()> f) { |
211 | mtx.lock(); |
212 | ++asyncCounter; |
213 | mtx.unlock(); |
214 | at::launch(f); |
215 | }; |
216 | std::vector<IValue> stack; |
217 | // NOLINTNEXTLINE(modernize-use-emplace) |
218 | stack.push_back(model._ivalue()); |
219 | InterpreterState interp(function, launcher); |
220 | interp.runAsync(stack)->wait(); |
221 | ASSERT_TRUE(asyncCounter > 0); |
222 | } |
223 | |
224 | TEST( |
225 | EnableRethrowCaughtExceptionTest, |
226 | EnableRethrowCaughtExceptionTestRethrowsCaughtException) { |
227 | auto graph = std::make_shared<Graph>(); |
228 | std::unordered_map<std::string, Value*> vmap; |
229 | parseIR( |
230 | R"IR( |
231 | graph(%0 : Tensor, |
232 | %1 : Tensor): |
233 | %2 : int = prim::Constant[value=2]() |
234 | %3 : Tensor = aten::add(%0, %1, %2) |
235 | return (%3) |
236 | )IR" , |
237 | &*graph, |
238 | vmap); |
239 | Code function(graph, "" ); |
240 | InterpreterState interp = InterpreterState(function); |
241 | auto a = at::zeros({2, 2}, at::kFloat); |
242 | auto b = at::ones({2, 3}, at::kFloat); |
243 | a.set_requires_grad(true); |
244 | a = a.to(at::kCPU); |
245 | std::vector<IValue> stack({a, b}); |
246 | |
247 | bool original_flag_value = FLAGS_torch_jit_enable_rethrow_caught_exception; |
248 | bool exception_handled = false; |
249 | try { |
250 | FLAGS_torch_jit_enable_rethrow_caught_exception = false; |
251 | interp.run(stack); |
252 | } catch (std::runtime_error& e) { |
253 | exception_handled = true; |
254 | std::string exception_msg = e.what(); |
255 | EXPECT_THAT( |
256 | exception_msg, |
257 | ::testing::HasSubstr("%3 : Tensor = aten::add(%0, %1, %2)" )); |
258 | EXPECT_THAT( |
259 | exception_msg, |
260 | ::testing::HasSubstr( |
261 | "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1" )); |
262 | } |
263 | EXPECT_TRUE(exception_handled); |
264 | |
265 | exception_handled = false; |
266 | try { |
267 | FLAGS_torch_jit_enable_rethrow_caught_exception = true; |
268 | interp.run(stack); |
269 | } catch (c10::Error& e) { |
270 | exception_handled = true; |
271 | std::string exception_msg = e.what_without_backtrace(); |
272 | EXPECT_STREQ( |
273 | exception_msg.c_str(), |
274 | "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1" ); |
275 | } |
276 | EXPECT_TRUE(exception_handled); |
277 | |
278 | FLAGS_torch_jit_enable_rethrow_caught_exception = true; |
279 | c10::intrusive_ptr<Future> future = interp.runAsync(stack); |
280 | future->wait(); |
281 | ASSERT_TRUE(future->completed()); |
282 | ASSERT_TRUE(future->hasError()); |
283 | try { |
284 | std::rethrow_exception(future->exception_ptr()); |
285 | } catch (c10::Error& e) { |
286 | std::string exception_msg = e.what_without_backtrace(); |
287 | EXPECT_STREQ( |
288 | exception_msg.c_str(), |
289 | "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1" ); |
290 | } |
291 | |
292 | FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value; |
293 | } |
294 | |
295 | } // namespace jit |
296 | } // namespace torch |
297 | |