1#include <test/cpp/jit/test_utils.h>
2
3#include <gtest/gtest.h>
4
5#include <c10/core/TensorOptions.h>
6#include <torch/csrc/autograd/generated/variable_factories.h>
7#include <torch/csrc/jit/api/module.h>
8#include <torch/csrc/jit/frontend/resolver.h>
9#include <torch/csrc/jit/mobile/compatibility/backport.h>
10#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
11#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
12#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
13#include <torch/csrc/jit/mobile/import.h>
14#include <torch/csrc/jit/mobile/interpreter.h>
15#include <torch/csrc/jit/mobile/module.h>
16#include <torch/csrc/jit/mobile/parse_bytecode.h>
17#include <torch/csrc/jit/mobile/parse_operators.h>
18#include <torch/csrc/jit/serialization/export.h>
19#include <torch/csrc/jit/serialization/export_bytecode.h>
20#include <torch/csrc/jit/serialization/import.h>
21#include <torch/custom_class.h>
22#include <torch/torch.h>
23
24#include <unordered_set>
25
26// Tests go in torch::jit
27namespace torch {
28namespace jit {
29
30TEST(LiteInterpreterDirectTest, UpsampleNearest2d) {
31 Module m("m");
32 m.define(R"(
33 def forward(self, input: Tensor, scale:float):
34 return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
35 )");
36
37 std::vector<IValue> inputs;
38 inputs.emplace_back(torch::rand({1, 3, 128, 128}));
39 inputs.emplace_back(at::Scalar(2.0));
40 auto ref = m.forward(inputs);
41
42 CompilationOptions options;
43 mobile::Module bc = jitModuleToMobile(m, options);
44 IValue res;
45 res = bc.forward(inputs);
46
47 auto resd = res.toTensor();
48 auto refd = ref.toTensor();
49 ASSERT_TRUE(resd.equal(refd));
50}
51
52TEST(LiteInterpreterDirectTest, CheckAttrAccess) {
53 Module m("m");
54 m.register_attribute("mobile_optimized", BoolType::get(), true);
55
56 CompilationOptions options;
57 mobile::Module bc = jitModuleToMobile(m, options);
58 bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();
59
60 AT_ASSERT(mobile_optimized);
61 m.setattr("mobile_optimized", false);
62 bc = jitModuleToMobile(m, options);
63 mobile_optimized = bc.attr("mobile_optimized", false).toBool();
64 AT_ASSERT(!mobile_optimized);
65}
66
67TEST(
68 LiteInterpreterDirectTest,
69 MethodInvocation) { // NOLINT (use =delete in gtest)
70 const std::vector<std::string> test_programs{
71 // test invoking a method with default parameter
72 R"(
73 def test_func(self, x, b : int = 4):
74 return self.foo + x + b
75 )",
76 // inner method call with default parameter (gets inlined)
77 R"(
78 def add_with_default_arg(self, x, b : int = 4):
79 return self.foo + x + b
80 def test_func(self, x):
81 return self.add_with_default_arg(x) # invoke method w/ default arg
82 )",
83 // simple method call
84 R"(
85 def test_func(self, x):
86 b = 4
87 return self.foo + x + b
88 )",
89 };
90 for (const auto& test_program : test_programs) {
91 Module m("m");
92 m.register_parameter("foo", torch::ones({}), false);
93 m.define(test_program);
94
95 const int fortyTwo = 42; // (keep linter happy)
96 auto minput = fortyTwo * torch::ones({});
97 auto ref = m.run_method("test_func", minput);
98
99 CompilationOptions options;
100 mobile::Module bc = jitModuleToMobile(m, options);
101 const auto& test_func = bc.get_method("test_func");
102 std::cerr << "hello " << std::endl;
103 IValue res;
104 for (int i = 0; i < 3; ++i) {
105 res = test_func({minput});
106 }
107 std::cerr << "hello 3" << std::endl;
108
109 auto resd = res.toTensor().item<float>();
110 auto refd = ref.toTensor().item<float>();
111 AT_ASSERT(resd == refd);
112 }
113}
114
115TEST(LiteInterpreterDirectTest, Conv) {
116 auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
117 if (s && strcmp(s, "1") == 0)
118 return;
119
120 std::vector<torch::jit::IValue> inputs;
121
122 Module m("m");
123 m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
124 m.register_parameter("bias", torch::ones({20}), false);
125 m.define(R"(
126 def forward(self, input):
127 return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
128 )");
129
130 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
131 inputs.push_back(torch::ones({1, 1, 28, 28}));
132
133 auto outputref = m.forward(inputs).toTensor();
134
135 CompilationOptions options;
136 mobile::Module bc = jitModuleToMobile(m, options);
137 IValue res;
138 for (int i = 0; i < 3; ++i) {
139 res = bc.get_method("forward")(inputs);
140 }
141 auto output = res.toTensor();
142 AT_ASSERT(outputref.dim() == output.dim());
143 AT_ASSERT(
144 outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
145}
146
147TEST(LiteInterpreterDirectTest, Inline) {
148 Module m("m");
149 m.define(R"JIT(
150 def foo1(self, x):
151 return x + 1
152
153 def foo2(self, x):
154 return self.foo1(x) + 2
155
156 def foo3(self, x):
157 return self.foo2(x) + 3
158 )JIT");
159 CompilationOptions options;
160 mobile::Module bc = jitModuleToMobile(m, options);
161 std::vector<torch::jit::IValue> inputs({torch::ones({})});
162 auto output = bc.get_method("foo3")(inputs);
163 AT_ASSERT(output.toTensor().item<float>() == 7.0);
164}
165
166TEST(LiteInterpreterDirectTest, Tuple) {
167 Module m("m");
168 m.define(R"JIT(
169 def foo(self, x):
170 return (1, 2, x + 3)
171
172 def forward(self, x):
173 tuple = self.foo(x)
174 return tuple
175 )JIT");
176 CompilationOptions options;
177 mobile::Module bc = jitModuleToMobile(m, options);
178 std::vector<torch::jit::IValue> inputs({torch::ones({})});
179 auto output = bc.get_method("forward")(inputs);
180 AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
181}
182
183TEST(LiteInterpreterDirectTest, Dict) {
184 Module m("m");
185 m.define(R"JIT(
186 def foo(self, x):
187 return {"result": x + 1}
188
189 def forward(self, x):
190 d = self.foo(x)
191 return d
192 )JIT");
193 CompilationOptions options;
194 mobile::Module bc = jitModuleToMobile(m, options);
195 std::vector<torch::jit::IValue> inputs({torch::ones({})});
196 auto output = bc.get_method("forward")(inputs);
197 AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
198}
199
200TEST(LiteInterpreterDirectTest, Prim) {
201 Module m("m");
202 m.define(R"JIT(
203 def forward(self, x):
204 return int(x)
205 )JIT");
206
207 std::vector<IValue> inputs;
208 auto minput = 3.5 * torch::ones({});
209 inputs.emplace_back(minput);
210 auto ref = m.run_method("forward", minput);
211
212 CompilationOptions options;
213 mobile::Module bc = jitModuleToMobile(m, options);
214
215 IValue res;
216 for (int i = 0; i < 3; ++i) {
217 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
218 auto bcinputs = inputs;
219 res = bc.get_method("forward")(bcinputs);
220 }
221
222 auto resi = res.toInt();
223 auto refi = ref.toInt();
224 AT_ASSERT(resi == refi);
225}
226
227TEST(LiteInterpreterDirectTest, PrimScalar) {
228 Module m("m");
229 m.define(R"JIT(
230 def forward(self, x):
231 return int(x.item())
232 )JIT");
233
234 std::vector<IValue> inputs;
235 auto minput = 3.5 * torch::ones({});
236 inputs.emplace_back(minput);
237 auto ref = m.run_method("forward", minput);
238
239 CompilationOptions options;
240 mobile::Module bc = jitModuleToMobile(m, options);
241 IValue res;
242 for (int i = 0; i < 3; ++i) {
243 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
244 auto bcinputs = inputs;
245 res = bc.get_method("forward")(bcinputs);
246 }
247
248 auto resi = res.toInt();
249 auto refi = ref.toInt();
250 AT_ASSERT(resi == refi);
251}
252
253TEST(LiteInterpreterDirectTest, WrongMethodName) {
254 Module m("m");
255 m.register_parameter("foo", torch::ones({}), false);
256 m.define(R"(
257 def add(self, x):
258 b = 4
259 return self.foo + x + b
260 )");
261 CompilationOptions options;
262 mobile::Module bc = jitModuleToMobile(m, options);
263 std::vector<IValue> inputs;
264 auto minput = 5 * torch::ones({});
265 inputs.emplace_back(minput);
266 ASSERT_THROWS_WITH_MESSAGE(
267 bc.get_method("forward")(inputs), "is not defined");
268}
269
270TEST(LiteInterpreterDirectTest, SetState) {
271 Module m("m");
272 m.register_parameter("foo", torch::ones({}), false);
273 m.define(R"(
274 def __getstate__(self):
275 return self.foo
276 def __setstate__(self, a):
277 self.foo = a
278 def forward(self, x):
279 b = 4
280 return self.foo + x + b
281 )");
282
283 std::vector<IValue> inputs;
284 auto minput = 5 * torch::ones({});
285 inputs.emplace_back(minput);
286
287 std::stringstream ms;
288 m.save(ms);
289 auto loaded_m = load(ms);
290 auto ref = loaded_m.run_method("forward", minput);
291
292 CompilationOptions options;
293 mobile::Module bc = jitModuleToMobile(m, options);
294 IValue res;
295 for (int i = 0; i < 3; ++i) {
296 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
297 auto bcinputs = inputs;
298 res = bc.get_method("forward")(bcinputs);
299 }
300
301 auto resd = res.toTensor().item<float>();
302 auto refd = ref.toTensor().item<float>();
303 AT_ASSERT(resd == refd);
304}
305
306class TorchBindLiteInterpreterDirectTestStruct
307 : public torch::jit::CustomClassHolder {
308 public:
309 std::string get(at::Tensor t) {
310 std::stringstream ss;
311 ss << "Hello! Your tensor has ";
312 ss << t.numel();
313 ss << " elements!";
314 return ss.str();
315 }
316};
317
318namespace {
319struct ClassNamespaceValue : public SugaredValue {
320 explicit ClassNamespaceValue(c10::QualifiedName name)
321 : basename_(std::move(name)) {}
322
323 std::shared_ptr<SugaredValue> attr(
324 const SourceRange&,
325 GraphFunction&,
326 const std::string& name) override {
327 const auto fullName = c10::QualifiedName(basename_, name);
328
329 // Check to see if it is a custom class.
330 if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
331 return std::make_shared<ClassValue>(custom_class);
332 }
333
334 // If it's not a custom class, assume it's another namespace
335 // NOLINTNEXTLINE(performance-move-const-arg)
336 return std::make_shared<ClassNamespaceValue>(fullName);
337 }
338
339 std::string kind() const override {
340 return "Class Namespace";
341 }
342
343 private:
344 c10::QualifiedName basename_;
345};
346
347struct TestModuleResolver : public Resolver {
348 std::shared_ptr<SugaredValue> resolveValue(
349 const std::string& name,
350 GraphFunction&,
351 const SourceRange&) override {
352 if (name == "torch") {
353 return std::make_shared<BuiltinModule>("aten");
354 } else if (name == "__torch__") {
355 return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
356 }
357
358 return nullptr;
359 }
360
361 TypePtr resolveType(const std::string&, const SourceRange&) override {
362 return nullptr;
363 }
364};
365} // namespace
366
367TEST(LiteInterpreterDirectTest, BuiltinFunction) {
368 script::Module m("m");
369 auto custom_class_obj =
370 make_custom_class<TorchBindLiteInterpreterDirectTestStruct>();
371 m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
372 m.define(R"(
373 def forward(self, x) -> str:
374 return self.my_obj.get(x)
375 )");
376
377 CompilationOptions options;
378 mobile::Module bc = jitModuleToMobile(m, options);
379 auto res =
380 bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
381 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
382 auto str = res.toStringRef();
383 std::string expected = "Hello! Your tensor has 12 elements!";
384 AT_ASSERT(str == expected);
385}
386
387#if !defined FB_XPLAT_BUILD
388TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) {
389 auto runtime_bytecode_version = _get_runtime_bytecode_version();
390 AT_ASSERT(
391 runtime_bytecode_version ==
392 caffe2::serialize::kMaxSupportedBytecodeVersion);
393}
394
395TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) {
396 auto runtime_operators_version = _get_runtime_operators_min_max_versions();
397 AT_ASSERT(
398 runtime_operators_version.first ==
399 caffe2::serialize::kMinSupportedFileFormatVersion &&
400 runtime_operators_version.second ==
401 caffe2::serialize::kMaxSupportedFileFormatVersion);
402}
403
404/**
405 * The test below is disarmed for FB internal xplat builds since
406 * BUCK requires us to pass in the script_module_v4.ptl file in
407 * as a resource dependency of the build rule for this file, and
408 * we would need to access it via the C++ Resources API instead
409 * of directly reading from disk (which is what the open source
410 * build/run does).
411 */
412TEST(LiteInterpreterDirectTest, GetByteCodeVersion) {
413 std::string filePath(__FILE__);
414 auto test_model_file_v4 =
415 filePath.substr(0, filePath.find_last_of("/\\") + 1);
416 test_model_file_v4.append("script_module_v4.ptl");
417
418 auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
419 AT_ASSERT(version_v4 == 4);
420}
421
422#endif // !defined(FB_XPLAT_BUILD)
423
424TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) {
425 auto runtime_ops = _get_runtime_ops_and_info();
426 // Ballpark estimate of the minimal number of ops; just used to
427 // verify API returns a reasonably large number.
428 AT_ASSERT(runtime_ops.size() > 2900);
429}
430
431TEST(LiteInterpreterDirectTest, Eval) {
432 std::vector<torch::jit::IValue> inputs;
433
434 Module m("m");
435 m.define(R"(
436 def __init__(self, x):
437 self.training = True
438
439 def forward(self, input):
440 return torch.dropout(input, 1.0, self.training)
441 )");
442
443 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
444 inputs.push_back(torch::ones({1, 1, 28, 28}));
445 m.eval();
446 auto outputref = m.forward(inputs).toTensor();
447
448 // save m in training mode to make sure that mobile eval() will correctly
449 // change back to eval mode
450 m.train();
451 CompilationOptions options;
452 mobile::Module bc = jitModuleToMobile(m, options);
453 bc.eval();
454 IValue res;
455 for (int i = 0; i < 3; ++i) {
456 res = bc.get_method("forward")(inputs);
457 }
458 auto output = res.toTensor();
459 AT_ASSERT(outputref.dim() == output.dim());
460 AT_ASSERT(
461 outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
462}
463
464TEST(LiteInterpreterDirectTest, FindWrongMethodName) {
465 Module m("m");
466 m.register_parameter("foo", torch::ones({}), false);
467 m.define(R"(
468 def add(self, x):
469 b = 4
470 return self.foo + x + b
471 )");
472 CompilationOptions options;
473 mobile::Module bc = jitModuleToMobile(m, options);
474 ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
475}
476
477TEST(LiteInterpreterDirectTest, FindAndRunMethod) {
478 Module m("m");
479 m.register_parameter("foo", torch::ones({}), false);
480 m.define(R"(
481 def add_it(self, x):
482 b = 4
483 return self.foo + x + b
484 )");
485
486 std::vector<IValue> inputs;
487 auto minput = 5 * torch::ones({});
488 inputs.emplace_back(minput);
489 auto ref = m.get_method("add_it")(inputs);
490
491 CompilationOptions options;
492 mobile::Module bc = jitModuleToMobile(m, options);
493 IValue res;
494 for (int i = 0; i < 3; ++i) {
495 auto bcinputs = inputs;
496 auto method = bc.find_method("add_it");
497 AT_ASSERT(method != c10::nullopt);
498 res = (*method)(std::move(bcinputs));
499 }
500
501 auto resd = res.toTensor().item<float>();
502 auto refd = ref.toTensor().item<float>();
503 AT_ASSERT(resd == refd);
504}
505
506TEST(LiteInterpreterDirectTest, RunMethodVariadic) {
507 Module m("m");
508 m.register_parameter("foo", torch::ones({}), false);
509 m.define(R"(
510 def add_three(self, x, y):
511 return self.foo + x + y
512 )");
513
514 std::vector<IValue> inputs;
515 auto inputx = 5 * torch::ones({});
516 auto inputy = 4 * torch::ones({});
517 auto ref = m.run_method("add_three", inputx, inputy);
518
519 CompilationOptions options;
520 mobile::Module bc = jitModuleToMobile(m, options);
521 IValue res = bc.run_method("add_three", inputx, inputy);
522
523 auto resd = res.toTensor().item<float>();
524 auto refd = ref.toTensor().item<float>();
525 AT_ASSERT(resd == refd);
526}
527
528TEST(LiteInterpreterDirectTest, DuplicateSetState) {
529 Module m("M");
530 m.register_parameter("foo", torch::ones({}), false);
531 m.define(R"(
532 def __getstate__(self):
533 return self.foo + self.foo
534 def __setstate__(self, a):
535 self.foo = a
536 def forward(self, x):
537 b = 4
538 return self.foo + x + b
539 )");
540
541 Module b("B");
542 b.register_module("M0", m);
543 b.register_module("M1", m);
544 b.define(R"(
545 def forward(self, x):
546 return self.M0.forward(x) + self.M1.forward(x)
547 )");
548
549 CompilationOptions options;
550 mobile::Module bc = jitModuleToMobile(m, options);
551 const auto methods = bc.get_methods();
552 const size_t expected_n = 3;
553 ASSERT_EQ(methods.size(), expected_n);
554}
555
556TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) {
557 torch::jit::Module m("m");
558 m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
559 m.register_parameter("bias", torch::ones({20}), false);
560 m.define(R"(
561 def forward(self, input):
562 x1 = torch.zeros(2, 2)
563 x2 = torch.empty_like(torch.empty(2, 2))
564 x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
565 return (x1, x2, x3)
566 )");
567 m.eval();
568
569 CompilationOptions options;
570 mobile::Module ptl_model = jitModuleToMobile(m, options);
571 std::set<std::string> operator_names =
572 torch::jit::mobile::_export_operator_list(ptl_model);
573 std::set<std::string> expected_operator_names = {
574 "aten::_convolution",
575 "aten::empty.memory_format",
576 "aten::empty_like",
577 "aten::zeros",
578 };
579 EXPECT_EQ(operator_names, expected_operator_names)
580 << "Expected the root operator lists to be the same";
581}
582
583TEST(LiteInterpreterDirectTest, DefaultArgsConv) {
584 auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
585 if (s && strcmp(s, "1") == 0)
586 return;
587
588 std::vector<torch::jit::IValue> inputs;
589
590 Module m("m");
591 m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
592 m.register_parameter("bias", torch::ones({20}), false);
593 m.define(R"(
594 def forward(self, input):
595 return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
596 )");
597
598 inputs.emplace_back(torch::ones({1, 1, 28, 28}));
599
600 auto outputref = m.forward(inputs).toTensor();
601
602 CompilationOptions options;
603 mobile::Module bc = jitModuleToMobile(m, options);
604 IValue res;
605 for (int i = 0; i < 1; ++i) {
606 res = bc.get_method("forward")(inputs);
607 }
608 auto output = res.toTensor();
609 AT_ASSERT(outputref.dim() == output.dim());
610 AT_ASSERT(output.equal(outputref));
611}
612
613namespace {
614void testLiteModuleCompareResultTensors(
615 Module& m,
616 const std::vector<torch::jit::IValue>& inputs,
617 const std::string& method_name = "forward") {
618 auto outputref = m.get_method(method_name)(inputs).toTensor();
619
620 CompilationOptions options;
621 mobile::Module bc = jitModuleToMobile(m, options);
622 IValue res;
623 for (int i = 0; i < 3; ++i) {
624 res = bc.get_method(method_name)(inputs);
625 }
626 auto output = res.toTensor();
627 AT_ASSERT(outputref.dim() == output.dim());
628 AT_ASSERT(output.equal(outputref));
629}
630
631void testDefaultArgsPinv2(int num_args) {
632 Module m("m");
633 if (num_args == 1) {
634 m.define(R"(
635 def forward(self, input):
636 return torch.linalg_pinv(input)
637 )");
638 } else if (num_args == 2) {
639 m.define(R"(
640 def forward(self, input):
641 return torch.linalg_pinv(input, 1e-5)
642 )");
643 } else if (num_args == 3) {
644 m.define(R"(
645 def forward(self, input):
646 return torch.linalg_pinv(input, 1e-5, True)
647 )");
648 }
649
650 std::vector<torch::jit::IValue> inputs;
651 const int N = 28;
652 auto input = torch::range(1, N * N, 1);
653 input[0] = 1; // a more stable matrix
654 input = input.view({N, N});
655 inputs.emplace_back(input);
656 testLiteModuleCompareResultTensors(m, inputs);
657}
658} // namespace
659
660#if !defined FB_XPLAT_BUILD
661TEST(LiteInterpreterDirectTest, DefaultArgsPinv) {
662 // Test with different number of specified arguments.
663 // Arguments not specified take default value.
664 for (int num_args = 1; num_args <= 3; ++num_args) {
665 testDefaultArgsPinv2(num_args);
666 }
667
668 // bytecode with one specified argument:
669 // (6,
670 // ('__torch__.m.forward',
671 // (('instructions',
672 // (('STOREN', 1, 2),
673 // ('DROPR', 1, 0),
674 // ('MOVE', 2, 0),
675 // ('OP', 0, 0),
676 // ('RET', 0, 0))),
677 // ('operators', (('aten::linalg_pinv', '', 1),)),
678 // ('constants', (False, 1e-15)), # default constants are not
679 // used
680 // ('types', ()),
681 // ('register_size', 2)),
682 // (('arguments',
683 // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
684 // None)),
685 // (('name', 'input'), ('type', 'Tensor'), ('default_value',
686 // None)))),
687 // ('returns',
688 // ((('name', ''), ('type', 'Tensor'), ('default_value',
689 // None)),)))))
690
691 // bytecode with 2 specified argument:
692 // (6,
693 // ('__torch__.m.forward',
694 // (('instructions',
695 // (('STOREN', 1, 2),
696 // ('DROPR', 1, 0),
697 // ('MOVE', 2, 0),
698 // ('LOADC', 1, 0), # added LOADC for specified argument
699 // ('OP', 0, 0),
700 // ('RET', 0, 0))),
701 // ('operators', (('aten::linalg_pinv', '', 2),)),
702 // ('constants', (False, 1e-05)), # updated constant table
703 // ('types', ()),
704 // ('register_size', 2)),
705 // (('arguments',
706 // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
707 // None)),
708 // (('name', 'input'), ('type', 'Tensor'), ('default_value',
709 // None)))),
710 // ('returns',
711 // ((('name', ''), ('type', 'Tensor'), ('default_value',
712 // None)),)))))
713
714 // bytecode with 3 specified arguments:
715 // (6,
716 // ('__torch__.m.forward',
717 // (('instructions',
718 // (('STOREN', 1, 2),
719 // ('DROPR', 1, 0),
720 // ('MOVE', 2, 0),
721 // ('LOADC', 1, 0),
722 // ('LOADC', 0, 0),
723 // ('OP', 0, 0),
724 // ('RET', 0, 0))),
725 // ('operators', (('aten::linalg_pinv', '', 3),)),
726 // ('constants', (True, 1e-05)),
727 // ('types', ()),
728 // ('register_size', 2)),
729 // (('arguments',
730 // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
731 // None)),
732 // (('name', 'input'), ('type', 'Tensor'), ('default_value',
733 // None)))),
734 // ('returns',
735 // ((('name', ''), ('type', 'Tensor'), ('default_value',
736 // None)),)))))
737}
738
739TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) {
740 // The second argument is specified, but the value is the same as the default
741 // value. It's treated as "not specified" since the value can be fetched from
742 // schema.
743 Module m("m");
744 m.define(R"(
745 def forward(self, input):
746 return torch.linalg_tensorinv(input, 2)
747 )");
748 torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
749 auto arg_nums = code.op_to_num_specified_args();
750 ASSERT_EQ(arg_nums.size(), 1);
751 ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
752 std::vector<torch::jit::IValue> inputs;
753 const int N = 4;
754 auto input = torch::rand({N, N, N, N});
755 inputs.emplace_back(input);
756 testLiteModuleCompareResultTensors(m, inputs);
757}
758
759void testDefaultArgsPinvWithOutArg2(int num_args) {
760 Module m("m");
761 if (num_args == 1) {
762 m.define(R"(
763 def forward(self, input):
764 return torch.linalg_pinv(input, out=input)
765 )");
766 } else if (num_args == 2) {
767 m.define(R"(
768 def forward(self, input):
769 return torch.linalg_pinv(input, 1e-5, out=input)
770 )");
771 } else if (num_args == 3) {
772 m.define(R"(
773 def forward(self, input):
774 return torch.linalg_pinv(input, 1e-5, True, out=input)
775 )");
776 }
777
778 const int N = 28;
779 auto input = torch::range(1, N * N, 1);
780 input[0] = 10000; // a more stable matrix
781 input = input.view({N, N});
782 auto ref = m.run_method("forward", input);
783 TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
784 TORCH_CHECK(input.equal(ref.toTensor()));
785}
786
787TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) {
788 // Test with different number of specified arguments + out arg.
789 // Arguments not specified take default value.
790 for (int num_args = 1; num_args <= 3; ++num_args) {
791 testDefaultArgsPinvWithOutArg2(num_args);
792 }
793}
794
795TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) {
796 Module m("m");
797 m.define(R"(
798 def forward(self, x, h):
799 torch.add(x, h, out=x)
800 )");
801
802 std::vector<IValue> inputs;
803 auto input_x = 2 * torch::ones({});
804 auto input_h = torch::ones({});
805 auto ref = m.run_method("forward", input_x, input_h);
806
807 CompilationOptions options;
808 mobile::Module bc = jitModuleToMobile(m, options);
809 bc.run_method("forward", input_x, input_h);
810 AT_ASSERT(input_x.equal(4 * torch::ones({})));
811}
812
813TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
814 Module a("A");
815 a.define(R"(
816 def bar(self, x, y):
817 return x + y
818 )");
819 Module b("B");
820 b.register_module("A0", a);
821 b.define(R"(
822 def foo(self, x, y):
823 return self.A0.bar(x, y) + 2
824 )");
825 Module c("C");
826 c.register_module("B0", b);
827 c.define(R"(
828 def forward(self, x, y):
829 return self.B0.foo(x, y) + 3
830 )");
831
832 std::vector<IValue> inputs;
833 inputs.emplace_back(torch::rand({2, 4}));
834 inputs.emplace_back(torch::rand({13, 9}));
835
836 CompilationOptions options;
837 auto lite_m = jitModuleToMobile(c, options);
838 std::string error_pattern = R"(
839 Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
840Traceback of TorchScript (most recent call last):
841 File "<string>", line 3, in <unknown>
842
843 def forward(self, x, y):
844 return self.B0.foo(x, y) + 3
845 ~~~~~~~~~~~ <--- HERE
846
847 File "<string>", line 3, in foo
848
849 def foo(self, x, y):
850 return self.A0.bar(x, y) + 2
851 ~~~~~~~~~~~ <--- HERE
852
853 File "<string>", line 3, in bar
854
855 def bar(self, x, y):
856 return x + y
857 ~~~~~ <--- HERE
858 )";
859 ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
860}
861#endif // !defined(FB_XPLAT_BUILD)
862
863namespace {
864static auto reg =
865 torch::class_<TorchBindLiteInterpreterDirectTestStruct>(
866 "_TorchScriptTesting",
867 "_LiteInterpreterDirectTest")
868 .def(torch::init<>())
869 .def("get", &TorchBindLiteInterpreterDirectTestStruct::get)
870 .def_pickle(
871 // __getattr__
872 [](const c10::intrusive_ptr<
873 TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t {
874 return 0;
875 },
876 // __setattr__
877 [](int64_t) {
878 return c10::make_intrusive<
879 TorchBindLiteInterpreterDirectTestStruct>();
880 });
881
882} // namespace
883
884TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) {
885 // Create 3 methods:
886 //
887 // 1. forward() returns a tensor with dtype=torch.int64 (4)
888 // 2. forward2() returns a tensor with dtype=torch.float32 (6)
889 // 3. forward3() returns a tensor with dtype=torch.float32 but
890 // the dtype is inferred by the input tensor's dtype
891 //
892 // If caching works correctly, then the result from the full-jit
893 // module and the lite module will be the same. Otherwise, it
894 // will be different if we don't correctly ignore the cache
895 // entry for an operator that has a different number of
896 // arguments.
897 Module m("m");
898 m.define(R"(
899 def forward(self):
900 ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
901 return ret1.fill_(25)
902 )");
903 m.define(R"(
904 def forward2(self):
905 ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
906 return ret1.fill_(32.0)
907 )");
908 m.define(R"(
909 def forward3(self):
910 ret1 = torch.new_empty(torch.zeros(10), [10])
911 return ret1.fill_(12.0)
912 )");
913
914 std::vector<torch::jit::IValue> inputs;
915 testLiteModuleCompareResultTensors(m, inputs, "forward");
916 testLiteModuleCompareResultTensors(m, inputs, "forward2");
917 testLiteModuleCompareResultTensors(m, inputs, "forward3");
918}
919
920} // namespace jit
921} // namespace torch
922