1#include <gmock/gmock.h>
2#include <gtest/gtest.h>
3
4#include <ATen/Parallel.h>
5#include <c10/core/DeviceType.h>
6#include <test/cpp/jit/test_utils.h>
7#include <torch/csrc/jit/runtime/instruction.h>
8#include <torch/jit.h>
9#include <torch/script.h>
10#include <torch/torch.h>
11
12namespace torch {
13namespace jit {
14
15class TypeCheckTest : public ::testing::Test {
16 protected:
17 TypeCheckTest() : interp(makeInterp()) {}
18
19 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
20 InterpreterState interp;
21
22 private:
23 static InterpreterState makeInterp() {
24 auto graph = std::make_shared<Graph>();
25 std::unordered_map<std::string, Value*> vmap;
26 parseIR(
27 R"IR(
28graph(%a.1 : Tensor,
29 %b.1 : Tensor):
30 %t0 : Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
31 return (%t0, %t1, %type_matched)
32 )IR",
33 &*graph,
34 vmap);
35
36 Code function(graph, "");
37 return InterpreterState(function);
38 }
39};
40
41TEST_F(TypeCheckTest, MatchingType) {
42 // TypeCheck yields to true! Shape, grad and device matches.
43 auto a = at::zeros({2, 2}, at::kFloat);
44 auto b = at::ones({3, 3}, at::kFloat);
45 a.set_requires_grad(true);
46 a = a.to(at::kCPU);
47 std::vector<IValue> stack({a, b});
48 interp.run(stack);
49 ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
50 ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
51 ASSERT_TRUE(stack[2].toBool());
52}
53
54TEST_F(TypeCheckTest, SizeMismatch) {
55 auto a = at::zeros({2, 2}, at::kFloat);
56 auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
57 a.set_requires_grad(true);
58 a = a.to(at::kCPU);
59 std::vector<IValue> stack({a, b});
60 interp.run(stack);
61 ASSERT_FALSE(stack[2].toBool());
62}
63
64TEST_F(TypeCheckTest, GradientMismatch) {
65 auto a = at::zeros({2, 2}, at::kFloat);
66 auto b = at::ones({3, 3}, at::kFloat);
67 a = a.to(at::kCPU);
68 a.set_requires_grad(false); // Gradient mismatch
69 std::vector<IValue> stack({a, b});
70 interp.run(stack);
71 ASSERT_FALSE(stack[2].toBool());
72}
73
74TEST_F(TypeCheckTest, ScalarTypeMismatch) {
75 auto a = at::zeros({2, 2}, at::kFloat);
76 auto b = at::ones({3, 3}, at::kFloat);
77 a = a.to(at::kCPU);
78 a.set_requires_grad(true);
79 a = a.to(at::kInt); // Scalar type mismatch
80 std::vector<IValue> stack({a, b});
81 interp.run(stack);
82 ASSERT_FALSE(stack[2].toBool());
83}
84
85TEST_F(TypeCheckTest, DeviceMismatch_CUDA) {
86 auto a = at::zeros({2, 2}, at::kFloat);
87 auto b = at::ones({3, 3}, at::kFloat);
88 a.set_requires_grad(true);
89 a = a.to(at::kCUDA); // Device mismatch
90 std::vector<IValue> stack({a, b});
91 interp.run(stack);
92 ASSERT_FALSE(stack[2].toBool());
93}
94
// TODO: This test is commented out because it did not actually test anything;
// fix it and re-enable it.
// TEST(TypeCheckErrorTest, EmptyCheckRaises) {
// // Test that an empty TypeCheck raises an internal assertion
98// auto graph = std::make_shared<Graph>();
99// std::unordered_map<std::string, Value*> vmap;
100// EXPECT_ANY_THROW(parseIR(
101// R"IR(
102// graph(%a.1 : Tensor,
103// %b.1 : Tensor):
104// %type_matched : bool = prim::TypeCheck()
105// return (%type_matched)
106// )IR",
107// &*graph,
108// vmap));
109// }
110
// TODO: This test is commented out because it did not actually test anything;
// fix it and re-enable it.
// TEST(TypeCheckErrorTest, WrongInputOutputCountRaises) {
// // Test that an assertion fires if num_inputs + 1 != num_outputs
114// auto graph = std::make_shared<Graph>();
115// std::unordered_map<std::string, Value*> vmap;
116// EXPECT_ANY_THROW(parseIR(
117// R"IR(
118// graph(%a.1 : Tensor,
119// %b.1 : Tensor):
120// %type_matched : bool = prim::TypeCheck(%a.1)
121// return (%type_matched)
122// )IR",
123// &*graph,
124// vmap));
125// }
126
127TEST(InterpreterTest, Basic_CUDA) {
128 constexpr int batch_size = 4;
129 constexpr int input_size = 256;
130 constexpr int seq_len = 32;
131
132 int hidden_size = 2 * input_size;
133
134 auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
135 auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
136 auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
137 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
138 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
139
140 auto lstm_g = build_lstm();
141 Code lstm_function(lstm_g, "");
142 InterpreterState lstm_interp(lstm_function);
143 auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
144 std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
145
146 ASSERT_TRUE(exactlyEqual(outputs[0], hx));
147 ASSERT_TRUE(exactlyEqual(outputs[1], cx));
148}
149
150TEST(InterpreterTest, IgnorableArgsInSchema) {
151 auto graph = build_mobile_export_analysis_graph();
152 MobileCode function(graph, "");
153 auto op_to_specified_args = function.op_to_num_specified_args();
154 ASSERT_TRUE(op_to_specified_args.size() == 2);
155 ASSERT_TRUE(op_to_specified_args["aten::slice.Tensor"] == 4);
156 ASSERT_TRUE(op_to_specified_args["aten::slice.str"] == 4);
157 auto graph_vararg = build_mobile_export_analysis_graph_with_vararg();
158 MobileCode function_vararg(graph_vararg, "");
159 auto op_to_specified_args_vararg = function_vararg.op_to_num_specified_args();
160 // should never register it
161 ASSERT_TRUE(
162 op_to_specified_args_vararg.find("prim::tolist") ==
163 op_to_specified_args_vararg.end());
164
165 auto graph_nested = build_mobile_export_analysis_graph_nested();
166 MobileCode function_nested(graph_nested, "");
167 auto op_to_specified_args_nested = function_nested.op_to_num_specified_args();
168 ASSERT_TRUE(op_to_specified_args_nested["aten::slice.Tensor"] == 4);
169 ASSERT_TRUE(op_to_specified_args_nested["aten::slice.str"] == 4);
170
171 auto graph_non_const = build_mobile_export_analysis_graph_non_const();
172 MobileCode function_non_const(graph_non_const, "");
173 auto op_to_specified_args_non_const =
174 function_non_const.op_to_num_specified_args();
175 ASSERT_TRUE(op_to_specified_args_non_const["aten::conv2d"] == 6);
176}
177
178TEST(InterpreterTest, IgnorableArgsInSchemaWithOut) {
179 auto graph = build_mobile_export_with_out();
180 MobileCode function(graph, "");
181 auto op_to_specified_args = function.op_to_num_specified_args();
182 ASSERT_TRUE(op_to_specified_args.size() == 1);
183 // this should be 3 when the add_out flag is set to True
184 ASSERT_TRUE(op_to_specified_args["aten::add.out"] == 3);
185}
186
187TEST(InterpreterTest, runAsyncBasicTest) {
188 /*
189 TODO: there are some problem with C++ parsing script program involving
190 fork. Use the test module below for now.
191 issue about this: github.com/pytorch/pytorch/issues/46368
192 The test module file is generated by following:
193 class DemoModule(torch.nn.Module):
194 def forward(self):
195 r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
196 r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
197 return r1.wait() + r2.wait()
198 demo = DemoModule()
199 torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
200 */
201 std::string filePath(__FILE__);
202 auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
203 testModelFile.append("test_interpreter_async.pt");
204 auto model = load(testModelFile);
205 auto graph = model.get_method("forward").graph();
206 Code function(graph, "");
207 auto asyncCounter = 0;
208 std::mutex mtx;
209 // a dummy executor which actually use at::launch, but add up a counter
210 auto launcher = [&](std::function<void()> f) {
211 mtx.lock();
212 ++asyncCounter;
213 mtx.unlock();
214 at::launch(f);
215 };
216 std::vector<IValue> stack;
217 // NOLINTNEXTLINE(modernize-use-emplace)
218 stack.push_back(model._ivalue());
219 InterpreterState interp(function, launcher);
220 interp.runAsync(stack)->wait();
221 ASSERT_TRUE(asyncCounter > 0);
222}
223
224TEST(
225 EnableRethrowCaughtExceptionTest,
226 EnableRethrowCaughtExceptionTestRethrowsCaughtException) {
227 auto graph = std::make_shared<Graph>();
228 std::unordered_map<std::string, Value*> vmap;
229 parseIR(
230 R"IR(
231graph(%0 : Tensor,
232 %1 : Tensor):
233 %2 : int = prim::Constant[value=2]()
234 %3 : Tensor = aten::add(%0, %1, %2)
235 return (%3)
236 )IR",
237 &*graph,
238 vmap);
239 Code function(graph, "");
240 InterpreterState interp = InterpreterState(function);
241 auto a = at::zeros({2, 2}, at::kFloat);
242 auto b = at::ones({2, 3}, at::kFloat);
243 a.set_requires_grad(true);
244 a = a.to(at::kCPU);
245 std::vector<IValue> stack({a, b});
246
247 bool original_flag_value = FLAGS_torch_jit_enable_rethrow_caught_exception;
248 bool exception_handled = false;
249 try {
250 FLAGS_torch_jit_enable_rethrow_caught_exception = false;
251 interp.run(stack);
252 } catch (std::runtime_error& e) {
253 exception_handled = true;
254 std::string exception_msg = e.what();
255 EXPECT_THAT(
256 exception_msg,
257 ::testing::HasSubstr("%3 : Tensor = aten::add(%0, %1, %2)"));
258 EXPECT_THAT(
259 exception_msg,
260 ::testing::HasSubstr(
261 "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"));
262 }
263 EXPECT_TRUE(exception_handled);
264
265 exception_handled = false;
266 try {
267 FLAGS_torch_jit_enable_rethrow_caught_exception = true;
268 interp.run(stack);
269 } catch (c10::Error& e) {
270 exception_handled = true;
271 std::string exception_msg = e.what_without_backtrace();
272 EXPECT_STREQ(
273 exception_msg.c_str(),
274 "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
275 }
276 EXPECT_TRUE(exception_handled);
277
278 FLAGS_torch_jit_enable_rethrow_caught_exception = true;
279 c10::intrusive_ptr<Future> future = interp.runAsync(stack);
280 future->wait();
281 ASSERT_TRUE(future->completed());
282 ASSERT_TRUE(future->hasError());
283 try {
284 std::rethrow_exception(future->exception_ptr());
285 } catch (c10::Error& e) {
286 std::string exception_msg = e.what_without_backtrace();
287 EXPECT_STREQ(
288 exception_msg.c_str(),
289 "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
290 }
291
292 FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value;
293}
294
295} // namespace jit
296} // namespace torch
297