1 | #include <test/cpp/jit/test_utils.h> |
2 | |
3 | #include <gtest/gtest.h> |
4 | |
5 | #include <c10/core/TensorOptions.h> |
6 | #include <torch/csrc/autograd/generated/variable_factories.h> |
7 | #include <torch/csrc/jit/api/module.h> |
8 | #include <torch/csrc/jit/frontend/resolver.h> |
9 | #include <torch/csrc/jit/mobile/compatibility/backport.h> |
10 | #include <torch/csrc/jit/mobile/compatibility/backport_manager.h> |
11 | #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h> |
12 | #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h> |
13 | #include <torch/csrc/jit/mobile/import.h> |
14 | #include <torch/csrc/jit/mobile/interpreter.h> |
15 | #include <torch/csrc/jit/mobile/module.h> |
16 | #include <torch/csrc/jit/mobile/parse_bytecode.h> |
17 | #include <torch/csrc/jit/mobile/parse_operators.h> |
18 | #include <torch/csrc/jit/serialization/export.h> |
19 | #include <torch/csrc/jit/serialization/export_bytecode.h> |
20 | #include <torch/csrc/jit/serialization/import.h> |
21 | #include <torch/custom_class.h> |
22 | #include <torch/torch.h> |
23 | |
24 | #include <unordered_set> |
25 | |
26 | // Tests go in torch::jit |
27 | namespace torch { |
28 | namespace jit { |
29 | |
30 | TEST(LiteInterpreterDirectTest, UpsampleNearest2d) { |
31 | Module m("m" ); |
32 | m.define(R"( |
33 | def forward(self, input: Tensor, scale:float): |
34 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
35 | )" ); |
36 | |
37 | std::vector<IValue> inputs; |
38 | inputs.emplace_back(torch::rand({1, 3, 128, 128})); |
39 | inputs.emplace_back(at::Scalar(2.0)); |
40 | auto ref = m.forward(inputs); |
41 | |
42 | CompilationOptions options; |
43 | mobile::Module bc = jitModuleToMobile(m, options); |
44 | IValue res; |
45 | res = bc.forward(inputs); |
46 | |
47 | auto resd = res.toTensor(); |
48 | auto refd = ref.toTensor(); |
49 | ASSERT_TRUE(resd.equal(refd)); |
50 | } |
51 | |
52 | TEST(LiteInterpreterDirectTest, CheckAttrAccess) { |
53 | Module m("m" ); |
54 | m.register_attribute("mobile_optimized" , BoolType::get(), true); |
55 | |
56 | CompilationOptions options; |
57 | mobile::Module bc = jitModuleToMobile(m, options); |
58 | bool mobile_optimized = bc.attr("mobile_optimized" , false).toBool(); |
59 | |
60 | AT_ASSERT(mobile_optimized); |
61 | m.setattr("mobile_optimized" , false); |
62 | bc = jitModuleToMobile(m, options); |
63 | mobile_optimized = bc.attr("mobile_optimized" , false).toBool(); |
64 | AT_ASSERT(!mobile_optimized); |
65 | } |
66 | |
67 | TEST( |
68 | LiteInterpreterDirectTest, |
69 | MethodInvocation) { // NOLINT (use =delete in gtest) |
70 | const std::vector<std::string> test_programs{ |
71 | // test invoking a method with default parameter |
72 | R"( |
73 | def test_func(self, x, b : int = 4): |
74 | return self.foo + x + b |
75 | )" , |
76 | // inner method call with default parameter (gets inlined) |
77 | R"( |
78 | def add_with_default_arg(self, x, b : int = 4): |
79 | return self.foo + x + b |
80 | def test_func(self, x): |
81 | return self.add_with_default_arg(x) # invoke method w/ default arg |
82 | )" , |
83 | // simple method call |
84 | R"( |
85 | def test_func(self, x): |
86 | b = 4 |
87 | return self.foo + x + b |
88 | )" , |
89 | }; |
90 | for (const auto& test_program : test_programs) { |
91 | Module m("m" ); |
92 | m.register_parameter("foo" , torch::ones({}), false); |
93 | m.define(test_program); |
94 | |
95 | const int fortyTwo = 42; // (keep linter happy) |
96 | auto minput = fortyTwo * torch::ones({}); |
97 | auto ref = m.run_method("test_func" , minput); |
98 | |
99 | CompilationOptions options; |
100 | mobile::Module bc = jitModuleToMobile(m, options); |
101 | const auto& test_func = bc.get_method("test_func" ); |
102 | std::cerr << "hello " << std::endl; |
103 | IValue res; |
104 | for (int i = 0; i < 3; ++i) { |
105 | res = test_func({minput}); |
106 | } |
107 | std::cerr << "hello 3" << std::endl; |
108 | |
109 | auto resd = res.toTensor().item<float>(); |
110 | auto refd = ref.toTensor().item<float>(); |
111 | AT_ASSERT(resd == refd); |
112 | } |
113 | } |
114 | |
115 | TEST(LiteInterpreterDirectTest, Conv) { |
116 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
117 | if (s && strcmp(s, "1" ) == 0) |
118 | return; |
119 | |
120 | std::vector<torch::jit::IValue> inputs; |
121 | |
122 | Module m("m" ); |
123 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
124 | m.register_parameter("bias" , torch::ones({20}), false); |
125 | m.define(R"( |
126 | def forward(self, input): |
127 | return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
128 | )" ); |
129 | |
130 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
131 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
132 | |
133 | auto outputref = m.forward(inputs).toTensor(); |
134 | |
135 | CompilationOptions options; |
136 | mobile::Module bc = jitModuleToMobile(m, options); |
137 | IValue res; |
138 | for (int i = 0; i < 3; ++i) { |
139 | res = bc.get_method("forward" )(inputs); |
140 | } |
141 | auto output = res.toTensor(); |
142 | AT_ASSERT(outputref.dim() == output.dim()); |
143 | AT_ASSERT( |
144 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
145 | } |
146 | |
147 | TEST(LiteInterpreterDirectTest, Inline) { |
148 | Module m("m" ); |
149 | m.define(R"JIT( |
150 | def foo1(self, x): |
151 | return x + 1 |
152 | |
153 | def foo2(self, x): |
154 | return self.foo1(x) + 2 |
155 | |
156 | def foo3(self, x): |
157 | return self.foo2(x) + 3 |
158 | )JIT" ); |
159 | CompilationOptions options; |
160 | mobile::Module bc = jitModuleToMobile(m, options); |
161 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
162 | auto output = bc.get_method("foo3" )(inputs); |
163 | AT_ASSERT(output.toTensor().item<float>() == 7.0); |
164 | } |
165 | |
166 | TEST(LiteInterpreterDirectTest, Tuple) { |
167 | Module m("m" ); |
168 | m.define(R"JIT( |
169 | def foo(self, x): |
170 | return (1, 2, x + 3) |
171 | |
172 | def forward(self, x): |
173 | tuple = self.foo(x) |
174 | return tuple |
175 | )JIT" ); |
176 | CompilationOptions options; |
177 | mobile::Module bc = jitModuleToMobile(m, options); |
178 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
179 | auto output = bc.get_method("forward" )(inputs); |
180 | AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2); |
181 | } |
182 | |
183 | TEST(LiteInterpreterDirectTest, Dict) { |
184 | Module m("m" ); |
185 | m.define(R"JIT( |
186 | def foo(self, x): |
187 | return {"result": x + 1} |
188 | |
189 | def forward(self, x): |
190 | d = self.foo(x) |
191 | return d |
192 | )JIT" ); |
193 | CompilationOptions options; |
194 | mobile::Module bc = jitModuleToMobile(m, options); |
195 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
196 | auto output = bc.get_method("forward" )(inputs); |
197 | AT_ASSERT(output.toGenericDict().at("result" ).toTensor().item().toInt() == 2); |
198 | } |
199 | |
200 | TEST(LiteInterpreterDirectTest, Prim) { |
201 | Module m("m" ); |
202 | m.define(R"JIT( |
203 | def forward(self, x): |
204 | return int(x) |
205 | )JIT" ); |
206 | |
207 | std::vector<IValue> inputs; |
208 | auto minput = 3.5 * torch::ones({}); |
209 | inputs.emplace_back(minput); |
210 | auto ref = m.run_method("forward" , minput); |
211 | |
212 | CompilationOptions options; |
213 | mobile::Module bc = jitModuleToMobile(m, options); |
214 | |
215 | IValue res; |
216 | for (int i = 0; i < 3; ++i) { |
217 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
218 | auto bcinputs = inputs; |
219 | res = bc.get_method("forward" )(bcinputs); |
220 | } |
221 | |
222 | auto resi = res.toInt(); |
223 | auto refi = ref.toInt(); |
224 | AT_ASSERT(resi == refi); |
225 | } |
226 | |
227 | TEST(LiteInterpreterDirectTest, PrimScalar) { |
228 | Module m("m" ); |
229 | m.define(R"JIT( |
230 | def forward(self, x): |
231 | return int(x.item()) |
232 | )JIT" ); |
233 | |
234 | std::vector<IValue> inputs; |
235 | auto minput = 3.5 * torch::ones({}); |
236 | inputs.emplace_back(minput); |
237 | auto ref = m.run_method("forward" , minput); |
238 | |
239 | CompilationOptions options; |
240 | mobile::Module bc = jitModuleToMobile(m, options); |
241 | IValue res; |
242 | for (int i = 0; i < 3; ++i) { |
243 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
244 | auto bcinputs = inputs; |
245 | res = bc.get_method("forward" )(bcinputs); |
246 | } |
247 | |
248 | auto resi = res.toInt(); |
249 | auto refi = ref.toInt(); |
250 | AT_ASSERT(resi == refi); |
251 | } |
252 | |
253 | TEST(LiteInterpreterDirectTest, WrongMethodName) { |
254 | Module m("m" ); |
255 | m.register_parameter("foo" , torch::ones({}), false); |
256 | m.define(R"( |
257 | def add(self, x): |
258 | b = 4 |
259 | return self.foo + x + b |
260 | )" ); |
261 | CompilationOptions options; |
262 | mobile::Module bc = jitModuleToMobile(m, options); |
263 | std::vector<IValue> inputs; |
264 | auto minput = 5 * torch::ones({}); |
265 | inputs.emplace_back(minput); |
266 | ASSERT_THROWS_WITH_MESSAGE( |
267 | bc.get_method("forward" )(inputs), "is not defined" ); |
268 | } |
269 | |
270 | TEST(LiteInterpreterDirectTest, SetState) { |
271 | Module m("m" ); |
272 | m.register_parameter("foo" , torch::ones({}), false); |
273 | m.define(R"( |
274 | def __getstate__(self): |
275 | return self.foo |
276 | def __setstate__(self, a): |
277 | self.foo = a |
278 | def forward(self, x): |
279 | b = 4 |
280 | return self.foo + x + b |
281 | )" ); |
282 | |
283 | std::vector<IValue> inputs; |
284 | auto minput = 5 * torch::ones({}); |
285 | inputs.emplace_back(minput); |
286 | |
287 | std::stringstream ms; |
288 | m.save(ms); |
289 | auto loaded_m = load(ms); |
290 | auto ref = loaded_m.run_method("forward" , minput); |
291 | |
292 | CompilationOptions options; |
293 | mobile::Module bc = jitModuleToMobile(m, options); |
294 | IValue res; |
295 | for (int i = 0; i < 3; ++i) { |
296 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
297 | auto bcinputs = inputs; |
298 | res = bc.get_method("forward" )(bcinputs); |
299 | } |
300 | |
301 | auto resd = res.toTensor().item<float>(); |
302 | auto refd = ref.toTensor().item<float>(); |
303 | AT_ASSERT(resd == refd); |
304 | } |
305 | |
306 | class TorchBindLiteInterpreterDirectTestStruct |
307 | : public torch::jit::CustomClassHolder { |
308 | public: |
309 | std::string get(at::Tensor t) { |
310 | std::stringstream ss; |
311 | ss << "Hello! Your tensor has " ; |
312 | ss << t.numel(); |
313 | ss << " elements!" ; |
314 | return ss.str(); |
315 | } |
316 | }; |
317 | |
318 | namespace { |
319 | struct ClassNamespaceValue : public SugaredValue { |
320 | explicit ClassNamespaceValue(c10::QualifiedName name) |
321 | : basename_(std::move(name)) {} |
322 | |
323 | std::shared_ptr<SugaredValue> attr( |
324 | const SourceRange&, |
325 | GraphFunction&, |
326 | const std::string& name) override { |
327 | const auto fullName = c10::QualifiedName(basename_, name); |
328 | |
329 | // Check to see if it is a custom class. |
330 | if (auto custom_class = getCustomClass(fullName.qualifiedName())) { |
331 | return std::make_shared<ClassValue>(custom_class); |
332 | } |
333 | |
334 | // If it's not a custom class, assume it's another namespace |
335 | // NOLINTNEXTLINE(performance-move-const-arg) |
336 | return std::make_shared<ClassNamespaceValue>(fullName); |
337 | } |
338 | |
339 | std::string kind() const override { |
340 | return "Class Namespace" ; |
341 | } |
342 | |
343 | private: |
344 | c10::QualifiedName basename_; |
345 | }; |
346 | |
347 | struct TestModuleResolver : public Resolver { |
348 | std::shared_ptr<SugaredValue> resolveValue( |
349 | const std::string& name, |
350 | GraphFunction&, |
351 | const SourceRange&) override { |
352 | if (name == "torch" ) { |
353 | return std::make_shared<BuiltinModule>("aten" ); |
354 | } else if (name == "__torch__" ) { |
355 | return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name)); |
356 | } |
357 | |
358 | return nullptr; |
359 | } |
360 | |
361 | TypePtr resolveType(const std::string&, const SourceRange&) override { |
362 | return nullptr; |
363 | } |
364 | }; |
365 | } // namespace |
366 | |
367 | TEST(LiteInterpreterDirectTest, BuiltinFunction) { |
368 | script::Module m("m" ); |
369 | auto custom_class_obj = |
370 | make_custom_class<TorchBindLiteInterpreterDirectTestStruct>(); |
371 | m.register_attribute("my_obj" , custom_class_obj.type(), custom_class_obj); |
372 | m.define(R"( |
373 | def forward(self, x) -> str: |
374 | return self.my_obj.get(x) |
375 | )" ); |
376 | |
377 | CompilationOptions options; |
378 | mobile::Module bc = jitModuleToMobile(m, options); |
379 | auto res = |
380 | bc.get_method("forward" )(std::vector<IValue>{torch::zeros({3, 4})}); |
381 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
382 | auto str = res.toStringRef(); |
383 | std::string expected = "Hello! Your tensor has 12 elements!" ; |
384 | AT_ASSERT(str == expected); |
385 | } |
386 | |
387 | #if !defined FB_XPLAT_BUILD |
388 | TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) { |
389 | auto runtime_bytecode_version = _get_runtime_bytecode_version(); |
390 | AT_ASSERT( |
391 | runtime_bytecode_version == |
392 | caffe2::serialize::kMaxSupportedBytecodeVersion); |
393 | } |
394 | |
395 | TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) { |
396 | auto runtime_operators_version = _get_runtime_operators_min_max_versions(); |
397 | AT_ASSERT( |
398 | runtime_operators_version.first == |
399 | caffe2::serialize::kMinSupportedFileFormatVersion && |
400 | runtime_operators_version.second == |
401 | caffe2::serialize::kMaxSupportedFileFormatVersion); |
402 | } |
403 | |
404 | /** |
405 | * The test below is disarmed for FB internal xplat builds since |
406 | * BUCK requires us to pass in the script_module_v4.ptl file in |
407 | * as a resource dependency of the build rule for this file, and |
408 | * we would need to access it via the C++ Resources API instead |
409 | * of directly reading from disk (which is what the open source |
410 | * build/run does). |
411 | */ |
412 | TEST(LiteInterpreterDirectTest, GetByteCodeVersion) { |
413 | std::string filePath(__FILE__); |
414 | auto test_model_file_v4 = |
415 | filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
416 | test_model_file_v4.append("script_module_v4.ptl" ); |
417 | |
418 | auto version_v4 = _get_model_bytecode_version(test_model_file_v4); |
419 | AT_ASSERT(version_v4 == 4); |
420 | } |
421 | |
422 | #endif // !defined(FB_XPLAT_BUILD) |
423 | |
424 | TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) { |
425 | auto runtime_ops = _get_runtime_ops_and_info(); |
426 | // Ballpark estimate of the minimal number of ops; just used to |
427 | // verify API returns a reasonably large number. |
428 | AT_ASSERT(runtime_ops.size() > 2900); |
429 | } |
430 | |
431 | TEST(LiteInterpreterDirectTest, Eval) { |
432 | std::vector<torch::jit::IValue> inputs; |
433 | |
434 | Module m("m" ); |
435 | m.define(R"( |
436 | def __init__(self, x): |
437 | self.training = True |
438 | |
439 | def forward(self, input): |
440 | return torch.dropout(input, 1.0, self.training) |
441 | )" ); |
442 | |
443 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
444 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
445 | m.eval(); |
446 | auto outputref = m.forward(inputs).toTensor(); |
447 | |
448 | // save m in training mode to make sure that mobile eval() will correctly |
449 | // change back to eval mode |
450 | m.train(); |
451 | CompilationOptions options; |
452 | mobile::Module bc = jitModuleToMobile(m, options); |
453 | bc.eval(); |
454 | IValue res; |
455 | for (int i = 0; i < 3; ++i) { |
456 | res = bc.get_method("forward" )(inputs); |
457 | } |
458 | auto output = res.toTensor(); |
459 | AT_ASSERT(outputref.dim() == output.dim()); |
460 | AT_ASSERT( |
461 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
462 | } |
463 | |
464 | TEST(LiteInterpreterDirectTest, FindWrongMethodName) { |
465 | Module m("m" ); |
466 | m.register_parameter("foo" , torch::ones({}), false); |
467 | m.define(R"( |
468 | def add(self, x): |
469 | b = 4 |
470 | return self.foo + x + b |
471 | )" ); |
472 | CompilationOptions options; |
473 | mobile::Module bc = jitModuleToMobile(m, options); |
474 | ASSERT_TRUE(bc.find_method("forward" ) == c10::nullopt); |
475 | } |
476 | |
477 | TEST(LiteInterpreterDirectTest, FindAndRunMethod) { |
478 | Module m("m" ); |
479 | m.register_parameter("foo" , torch::ones({}), false); |
480 | m.define(R"( |
481 | def add_it(self, x): |
482 | b = 4 |
483 | return self.foo + x + b |
484 | )" ); |
485 | |
486 | std::vector<IValue> inputs; |
487 | auto minput = 5 * torch::ones({}); |
488 | inputs.emplace_back(minput); |
489 | auto ref = m.get_method("add_it" )(inputs); |
490 | |
491 | CompilationOptions options; |
492 | mobile::Module bc = jitModuleToMobile(m, options); |
493 | IValue res; |
494 | for (int i = 0; i < 3; ++i) { |
495 | auto bcinputs = inputs; |
496 | auto method = bc.find_method("add_it" ); |
497 | AT_ASSERT(method != c10::nullopt); |
498 | res = (*method)(std::move(bcinputs)); |
499 | } |
500 | |
501 | auto resd = res.toTensor().item<float>(); |
502 | auto refd = ref.toTensor().item<float>(); |
503 | AT_ASSERT(resd == refd); |
504 | } |
505 | |
506 | TEST(LiteInterpreterDirectTest, RunMethodVariadic) { |
507 | Module m("m" ); |
508 | m.register_parameter("foo" , torch::ones({}), false); |
509 | m.define(R"( |
510 | def add_three(self, x, y): |
511 | return self.foo + x + y |
512 | )" ); |
513 | |
514 | std::vector<IValue> inputs; |
515 | auto inputx = 5 * torch::ones({}); |
516 | auto inputy = 4 * torch::ones({}); |
517 | auto ref = m.run_method("add_three" , inputx, inputy); |
518 | |
519 | CompilationOptions options; |
520 | mobile::Module bc = jitModuleToMobile(m, options); |
521 | IValue res = bc.run_method("add_three" , inputx, inputy); |
522 | |
523 | auto resd = res.toTensor().item<float>(); |
524 | auto refd = ref.toTensor().item<float>(); |
525 | AT_ASSERT(resd == refd); |
526 | } |
527 | |
528 | TEST(LiteInterpreterDirectTest, DuplicateSetState) { |
529 | Module m("M" ); |
530 | m.register_parameter("foo" , torch::ones({}), false); |
531 | m.define(R"( |
532 | def __getstate__(self): |
533 | return self.foo + self.foo |
534 | def __setstate__(self, a): |
535 | self.foo = a |
536 | def forward(self, x): |
537 | b = 4 |
538 | return self.foo + x + b |
539 | )" ); |
540 | |
541 | Module b("B" ); |
542 | b.register_module("M0" , m); |
543 | b.register_module("M1" , m); |
544 | b.define(R"( |
545 | def forward(self, x): |
546 | return self.M0.forward(x) + self.M1.forward(x) |
547 | )" ); |
548 | |
549 | CompilationOptions options; |
550 | mobile::Module bc = jitModuleToMobile(m, options); |
551 | const auto methods = bc.get_methods(); |
552 | const size_t expected_n = 3; |
553 | ASSERT_EQ(methods.size(), expected_n); |
554 | } |
555 | |
556 | TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) { |
557 | torch::jit::Module m("m" ); |
558 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
559 | m.register_parameter("bias" , torch::ones({20}), false); |
560 | m.define(R"( |
561 | def forward(self, input): |
562 | x1 = torch.zeros(2, 2) |
563 | x2 = torch.empty_like(torch.empty(2, 2)) |
564 | x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
565 | return (x1, x2, x3) |
566 | )" ); |
567 | m.eval(); |
568 | |
569 | CompilationOptions options; |
570 | mobile::Module ptl_model = jitModuleToMobile(m, options); |
571 | std::set<std::string> operator_names = |
572 | torch::jit::mobile::_export_operator_list(ptl_model); |
573 | std::set<std::string> expected_operator_names = { |
574 | "aten::_convolution" , |
575 | "aten::empty.memory_format" , |
576 | "aten::empty_like" , |
577 | "aten::zeros" , |
578 | }; |
579 | EXPECT_EQ(operator_names, expected_operator_names) |
580 | << "Expected the root operator lists to be the same" ; |
581 | } |
582 | |
583 | TEST(LiteInterpreterDirectTest, DefaultArgsConv) { |
584 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
585 | if (s && strcmp(s, "1" ) == 0) |
586 | return; |
587 | |
588 | std::vector<torch::jit::IValue> inputs; |
589 | |
590 | Module m("m" ); |
591 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
592 | m.register_parameter("bias" , torch::ones({20}), false); |
593 | m.define(R"( |
594 | def forward(self, input): |
595 | return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1) |
596 | )" ); |
597 | |
598 | inputs.emplace_back(torch::ones({1, 1, 28, 28})); |
599 | |
600 | auto outputref = m.forward(inputs).toTensor(); |
601 | |
602 | CompilationOptions options; |
603 | mobile::Module bc = jitModuleToMobile(m, options); |
604 | IValue res; |
605 | for (int i = 0; i < 1; ++i) { |
606 | res = bc.get_method("forward" )(inputs); |
607 | } |
608 | auto output = res.toTensor(); |
609 | AT_ASSERT(outputref.dim() == output.dim()); |
610 | AT_ASSERT(output.equal(outputref)); |
611 | } |
612 | |
613 | namespace { |
614 | void testLiteModuleCompareResultTensors( |
615 | Module& m, |
616 | const std::vector<torch::jit::IValue>& inputs, |
617 | const std::string& method_name = "forward" ) { |
618 | auto outputref = m.get_method(method_name)(inputs).toTensor(); |
619 | |
620 | CompilationOptions options; |
621 | mobile::Module bc = jitModuleToMobile(m, options); |
622 | IValue res; |
623 | for (int i = 0; i < 3; ++i) { |
624 | res = bc.get_method(method_name)(inputs); |
625 | } |
626 | auto output = res.toTensor(); |
627 | AT_ASSERT(outputref.dim() == output.dim()); |
628 | AT_ASSERT(output.equal(outputref)); |
629 | } |
630 | |
631 | void testDefaultArgsPinv2(int num_args) { |
632 | Module m("m" ); |
633 | if (num_args == 1) { |
634 | m.define(R"( |
635 | def forward(self, input): |
636 | return torch.linalg_pinv(input) |
637 | )" ); |
638 | } else if (num_args == 2) { |
639 | m.define(R"( |
640 | def forward(self, input): |
641 | return torch.linalg_pinv(input, 1e-5) |
642 | )" ); |
643 | } else if (num_args == 3) { |
644 | m.define(R"( |
645 | def forward(self, input): |
646 | return torch.linalg_pinv(input, 1e-5, True) |
647 | )" ); |
648 | } |
649 | |
650 | std::vector<torch::jit::IValue> inputs; |
651 | const int N = 28; |
652 | auto input = torch::range(1, N * N, 1); |
653 | input[0] = 1; // a more stable matrix |
654 | input = input.view({N, N}); |
655 | inputs.emplace_back(input); |
656 | testLiteModuleCompareResultTensors(m, inputs); |
657 | } |
658 | } // namespace |
659 | |
660 | #if !defined FB_XPLAT_BUILD |
661 | TEST(LiteInterpreterDirectTest, DefaultArgsPinv) { |
662 | // Test with different number of specified arguments. |
663 | // Arguments not specified take default value. |
664 | for (int num_args = 1; num_args <= 3; ++num_args) { |
665 | testDefaultArgsPinv2(num_args); |
666 | } |
667 | |
668 | // bytecode with one specified argument: |
669 | // (6, |
670 | // ('__torch__.m.forward', |
671 | // (('instructions', |
672 | // (('STOREN', 1, 2), |
673 | // ('DROPR', 1, 0), |
674 | // ('MOVE', 2, 0), |
675 | // ('OP', 0, 0), |
676 | // ('RET', 0, 0))), |
677 | // ('operators', (('aten::linalg_pinv', '', 1),)), |
678 | // ('constants', (False, 1e-15)), # default constants are not |
679 | // used |
680 | // ('types', ()), |
681 | // ('register_size', 2)), |
682 | // (('arguments', |
683 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
684 | // None)), |
685 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
686 | // None)))), |
687 | // ('returns', |
688 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
689 | // None)),))))) |
690 | |
691 | // bytecode with 2 specified argument: |
692 | // (6, |
693 | // ('__torch__.m.forward', |
694 | // (('instructions', |
695 | // (('STOREN', 1, 2), |
696 | // ('DROPR', 1, 0), |
697 | // ('MOVE', 2, 0), |
698 | // ('LOADC', 1, 0), # added LOADC for specified argument |
699 | // ('OP', 0, 0), |
700 | // ('RET', 0, 0))), |
701 | // ('operators', (('aten::linalg_pinv', '', 2),)), |
702 | // ('constants', (False, 1e-05)), # updated constant table |
703 | // ('types', ()), |
704 | // ('register_size', 2)), |
705 | // (('arguments', |
706 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
707 | // None)), |
708 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
709 | // None)))), |
710 | // ('returns', |
711 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
712 | // None)),))))) |
713 | |
714 | // bytecode with 3 specified arguments: |
715 | // (6, |
716 | // ('__torch__.m.forward', |
717 | // (('instructions', |
718 | // (('STOREN', 1, 2), |
719 | // ('DROPR', 1, 0), |
720 | // ('MOVE', 2, 0), |
721 | // ('LOADC', 1, 0), |
722 | // ('LOADC', 0, 0), |
723 | // ('OP', 0, 0), |
724 | // ('RET', 0, 0))), |
725 | // ('operators', (('aten::linalg_pinv', '', 3),)), |
726 | // ('constants', (True, 1e-05)), |
727 | // ('types', ()), |
728 | // ('register_size', 2)), |
729 | // (('arguments', |
730 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
731 | // None)), |
732 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
733 | // None)))), |
734 | // ('returns', |
735 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
736 | // None)),))))) |
737 | } |
738 | |
739 | TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) { |
740 | // The second argument is specified, but the value is the same as the default |
741 | // value. It's treated as "not specified" since the value can be fetched from |
742 | // schema. |
743 | Module m("m" ); |
744 | m.define(R"( |
745 | def forward(self, input): |
746 | return torch.linalg_tensorinv(input, 2) |
747 | )" ); |
748 | torch::jit::MobileCode code(m.get_method("forward" ).graph(), "forward" ); |
749 | auto arg_nums = code.op_to_num_specified_args(); |
750 | ASSERT_EQ(arg_nums.size(), 1); |
751 | ASSERT_EQ(arg_nums["aten::linalg_tensorinv" ], 1); |
752 | std::vector<torch::jit::IValue> inputs; |
753 | const int N = 4; |
754 | auto input = torch::rand({N, N, N, N}); |
755 | inputs.emplace_back(input); |
756 | testLiteModuleCompareResultTensors(m, inputs); |
757 | } |
758 | |
759 | void testDefaultArgsPinvWithOutArg2(int num_args) { |
760 | Module m("m" ); |
761 | if (num_args == 1) { |
762 | m.define(R"( |
763 | def forward(self, input): |
764 | return torch.linalg_pinv(input, out=input) |
765 | )" ); |
766 | } else if (num_args == 2) { |
767 | m.define(R"( |
768 | def forward(self, input): |
769 | return torch.linalg_pinv(input, 1e-5, out=input) |
770 | )" ); |
771 | } else if (num_args == 3) { |
772 | m.define(R"( |
773 | def forward(self, input): |
774 | return torch.linalg_pinv(input, 1e-5, True, out=input) |
775 | )" ); |
776 | } |
777 | |
778 | const int N = 28; |
779 | auto input = torch::range(1, N * N, 1); |
780 | input[0] = 10000; // a more stable matrix |
781 | input = input.view({N, N}); |
782 | auto ref = m.run_method("forward" , input); |
783 | TORCH_CHECK(!input.equal(torch::range(1, N * N, 1))); |
784 | TORCH_CHECK(input.equal(ref.toTensor())); |
785 | } |
786 | |
787 | TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) { |
788 | // Test with different number of specified arguments + out arg. |
789 | // Arguments not specified take default value. |
790 | for (int num_args = 1; num_args <= 3; ++num_args) { |
791 | testDefaultArgsPinvWithOutArg2(num_args); |
792 | } |
793 | } |
794 | |
795 | TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) { |
796 | Module m("m" ); |
797 | m.define(R"( |
798 | def forward(self, x, h): |
799 | torch.add(x, h, out=x) |
800 | )" ); |
801 | |
802 | std::vector<IValue> inputs; |
803 | auto input_x = 2 * torch::ones({}); |
804 | auto input_h = torch::ones({}); |
805 | auto ref = m.run_method("forward" , input_x, input_h); |
806 | |
807 | CompilationOptions options; |
808 | mobile::Module bc = jitModuleToMobile(m, options); |
809 | bc.run_method("forward" , input_x, input_h); |
810 | AT_ASSERT(input_x.equal(4 * torch::ones({}))); |
811 | } |
812 | |
// Builds the hierarchy C -> B0 (type B) -> A0 (type A). A.bar adds the two
// inputs, which have shapes {2, 4} and {13, 9}, and the test expects that
// call to throw. The mobile error message must contain the module
// hierarchy and a TorchScript traceback through all three levels.
TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def bar(self, x, y):
      return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  CompilationOptions options;
  auto lite_m = jitModuleToMobile(c, options);
  // The traceback echoes the define() sources above, including their
  // indentation, so whitespace in this pattern must match them exactly.
  // NOTE(review): ASSERT_THROWS_WITH_MESSAGE (test_utils.h) presumably does
  // a substring match on e.what() — confirm before editing the pattern.
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in foo

    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in bar

    def bar(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}
861 | #endif // !defined(FB_XPLAT_BUILD) |
862 | |
863 | namespace { |
864 | static auto reg = |
865 | torch::class_<TorchBindLiteInterpreterDirectTestStruct>( |
866 | "_TorchScriptTesting" , |
867 | "_LiteInterpreterDirectTest" ) |
868 | .def(torch::init<>()) |
869 | .def("get" , &TorchBindLiteInterpreterDirectTestStruct::get) |
870 | .def_pickle( |
871 | // __getattr__ |
872 | [](const c10::intrusive_ptr< |
873 | TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t { |
874 | return 0; |
875 | }, |
876 | // __setattr__ |
877 | [](int64_t) { |
878 | return c10::make_intrusive< |
879 | TorchBindLiteInterpreterDirectTestStruct>(); |
880 | }); |
881 | |
882 | } // namespace |
883 | |
884 | TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) { |
885 | // Create 3 methods: |
886 | // |
887 | // 1. forward() returns a tensor with dtype=torch.int64 (4) |
888 | // 2. forward2() returns a tensor with dtype=torch.float32 (6) |
889 | // 3. forward3() returns a tensor with dtype=torch.float32 but |
890 | // the dtype is inferred by the input tensor's dtype |
891 | // |
892 | // If caching works correctly, then the result from the full-jit |
893 | // module and the lite module will be the same. Otherwise, it |
894 | // will be different if we don't correctly ignore the cache |
895 | // entry for an operator that has a different number of |
896 | // arguments. |
897 | Module m("m" ); |
898 | m.define(R"( |
899 | def forward(self): |
900 | ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4) |
901 | return ret1.fill_(25) |
902 | )" ); |
903 | m.define(R"( |
904 | def forward2(self): |
905 | ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6) |
906 | return ret1.fill_(32.0) |
907 | )" ); |
908 | m.define(R"( |
909 | def forward3(self): |
910 | ret1 = torch.new_empty(torch.zeros(10), [10]) |
911 | return ret1.fill_(12.0) |
912 | )" ); |
913 | |
914 | std::vector<torch::jit::IValue> inputs; |
915 | testLiteModuleCompareResultTensors(m, inputs, "forward" ); |
916 | testLiteModuleCompareResultTensors(m, inputs, "forward2" ); |
917 | testLiteModuleCompareResultTensors(m, inputs, "forward3" ); |
918 | } |
919 | |
920 | } // namespace jit |
921 | } // namespace torch |
922 | |