1#include <test/cpp/jit/test_utils.h>
2
3#include <gtest/gtest.h>
4
5#include <c10/core/TensorOptions.h>
6#include <torch/csrc/autograd/generated/variable_factories.h>
7#include <torch/csrc/jit/api/module.h>
8#include <torch/csrc/jit/frontend/resolver.h>
9#include <torch/csrc/jit/mobile/compatibility/backport.h>
10#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
11#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
12#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
13#include <torch/csrc/jit/mobile/import.h>
14#include <torch/csrc/jit/mobile/interpreter.h>
15#include <torch/csrc/jit/mobile/module.h>
16#include <torch/csrc/jit/mobile/parse_bytecode.h>
17#include <torch/csrc/jit/mobile/parse_operators.h>
18#include <torch/csrc/jit/mobile/upgrader_mobile.h>
19#include <torch/csrc/jit/serialization/export.h>
20#include <torch/csrc/jit/serialization/import.h>
21#include <torch/custom_class.h>
22#include <torch/torch.h>
23
24#include <torch/csrc/jit/serialization/import_export_functions.h>
25#include <unordered_set>
26
27// Tests go in torch::jit
28namespace torch {
29namespace jit {
30
31TEST(LiteInterpreterTest, UpsampleNearest2d) {
32 Module m("m");
33 m.define(R"(
34 def forward(self, input: Tensor, scale:float):
35 return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
36 )");
37
38 std::vector<IValue> inputs;
39 inputs.emplace_back(torch::rand({1, 3, 128, 128}));
40 inputs.emplace_back(at::Scalar(2.0));
41 auto ref = m.forward(inputs);
42
43 std::stringstream ss;
44 m._save_for_mobile(ss);
45 mobile::Module bc = _load_for_mobile(ss);
46 IValue res;
47 res = bc.forward(inputs);
48
49 auto resd = res.toTensor();
50 auto refd = ref.toTensor();
51 ASSERT_TRUE(resd.equal(refd));
52}
53
54TEST(LiteInterpreterTest, CheckAttrAccess) {
55 Module m("m");
56 m.register_attribute("mobile_optimized", BoolType::get(), true);
57
58 std::stringstream ss;
59 m._save_for_mobile(ss);
60 mobile::Module bc = _load_for_mobile(ss);
61 bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();
62
63 AT_ASSERT(mobile_optimized);
64 m.setattr("mobile_optimized", false);
65 ss = std::stringstream();
66 m._save_for_mobile(ss);
67 bc = _load_for_mobile(ss);
68 mobile_optimized = bc.attr("mobile_optimized", false).toBool();
69
70 AT_ASSERT(!mobile_optimized);
71}
72
73TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
74 const std::vector<std::string> test_programs{
75 // test invoking a method with default parameter
76 R"(
77 def test_func(self, x, b : int = 4):
78 return self.foo + x + b
79 )",
80 // inner method call with default parameter (gets inlined)
81 R"(
82 def add_with_default_arg(self, x, b : int = 4):
83 return self.foo + x + b
84 def test_func(self, x):
85 return self.add_with_default_arg(x) # invoke method w/ default arg
86 )",
87 // simple method call
88 R"(
89 def test_func(self, x):
90 b = 4
91 return self.foo + x + b
92 )",
93 };
94 for (const auto& test_program : test_programs) {
95 Module m("m");
96 m.register_parameter("foo", torch::ones({}), false);
97 m.define(test_program);
98
99 const int fortyTwo = 42; // (keep linter happy)
100 auto minput = fortyTwo * torch::ones({});
101 auto ref = m.run_method("test_func", minput);
102
103 std::stringstream ss;
104 m._save_for_mobile(ss);
105 mobile::Module bc = _load_for_mobile(ss);
106 const auto& test_func = bc.get_method("test_func");
107 IValue res;
108 for (int i = 0; i < 3; ++i) {
109 res = test_func({minput});
110 }
111
112 auto resd = res.toTensor().item<float>();
113 auto refd = ref.toTensor().item<float>();
114 AT_ASSERT(resd == refd);
115 }
116}
117
118TEST(LiteInterpreterTest, Conv) {
119 auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
120 if (s && strcmp(s, "1") == 0)
121 return;
122
123 std::vector<torch::jit::IValue> inputs;
124
125 Module m("m");
126 m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
127 m.register_parameter("bias", torch::ones({20}), false);
128 m.define(R"(
129 def forward(self, input):
130 return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
131 )");
132
133 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
134 inputs.push_back(torch::ones({1, 1, 28, 28}));
135
136 auto outputref = m.forward(inputs).toTensor();
137
138 std::stringstream ss;
139 m._save_for_mobile(ss);
140 mobile::Module bc = _load_for_mobile(ss);
141 IValue res;
142 for (int i = 0; i < 3; ++i) {
143 res = bc.get_method("forward")(inputs);
144 }
145 auto output = res.toTensor();
146 AT_ASSERT(outputref.dim() == output.dim());
147 AT_ASSERT(
148 outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
149}
150
151TEST(LiteInterpreterTest, Inline) {
152 Module m("m");
153 m.define(R"JIT(
154 def foo1(self, x):
155 return x + 1
156
157 def foo2(self, x):
158 return self.foo1(x) + 2
159
160 def foo3(self, x):
161 return self.foo2(x) + 3
162 )JIT");
163 std::stringstream ss;
164 m._save_for_mobile(ss);
165 mobile::Module bc = _load_for_mobile(ss);
166 std::vector<torch::jit::IValue> inputs({torch::ones({})});
167 auto output = bc.get_method("foo3")(inputs);
168 AT_ASSERT(output.toTensor().item<float>() == 7.0);
169}
170
171TEST(LiteInterpreterTest, Tuple) {
172 Module m("m");
173 m.define(R"JIT(
174 def foo(self, x):
175 return (1, 2, x + 3)
176
177 def forward(self, x):
178 tuple = self.foo(x)
179 return tuple
180 )JIT");
181 std::stringstream ss;
182 m._save_for_mobile(ss);
183 mobile::Module bc = _load_for_mobile(ss);
184 std::vector<torch::jit::IValue> inputs({torch::ones({})});
185 auto output = bc.get_method("forward")(inputs);
186 AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
187}
188
189TEST(LiteInterpreterTest, AtenFormat) {
190 Module m("m");
191 m.define(R"""(
192 def forward(self, fmt:str="first {} {}", num:str="abc"):
193 x = 2
194 x = x * x
195 return fmt.format(num, x)
196 )""");
197 std::stringstream ss;
198 m._save_for_mobile(ss);
199 mobile::Module bc = _load_for_mobile(ss);
200 std::vector<torch::jit::IValue> inputs;
201 auto output_bc = bc.get_method("forward")(inputs);
202 auto output_m = m.get_method("forward")(inputs);
203 // std::cout << output_m.toStringRef() << "\n"
204 // << output_bc.toStringRef() << std::endl;
205 AT_ASSERT(output_m.toStringRef() == output_bc.toStringRef());
206}
207
208TEST(LiteInterpreterTest, PrimDevice) {
209 Module m("m");
210 m.define(R"""(
211 def forward(self, x:torch.Tensor):
212 return x.device
213 )""");
214 std::stringstream ss;
215 m._save_for_mobile(ss);
216 mobile::Module bc = _load_for_mobile(ss);
217 std::vector<torch::jit::IValue> inputs;
218 auto minput = 3.5 * torch::ones({});
219 inputs.emplace_back(minput);
220 auto output_bc = bc.get_method("forward")(inputs);
221 auto output_m = m.get_method("forward")(inputs);
222 AT_ASSERT(output_bc.toDevice().str() == output_m.toDevice().str());
223}
224
225TEST(LiteInterpreterTest, Dict) {
226 Module m("m");
227 m.define(R"JIT(
228 def foo(self, x):
229 return {"result": x + 1}
230
231 def forward(self, x):
232 d = self.foo(x)
233 return d
234 )JIT");
235 std::stringstream ss;
236 m._save_for_mobile(ss);
237 mobile::Module bc = _load_for_mobile(ss);
238 std::vector<torch::jit::IValue> inputs({torch::ones({})});
239 auto output = bc.get_method("forward")(inputs);
240 AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
241}
242
243TEST(LiteInterpreterTest, List) {
244 Module m("m");
245 m.define(R"JIT(
246 def foo(self, x):
247 return [x + 2]
248
249 def forward(self, x):
250 d = self.foo(x)
251 return d
252 )JIT");
253 std::stringstream ss;
254 m._save_for_mobile(ss);
255 mobile::Module bc = _load_for_mobile(ss);
256 std::vector<torch::jit::IValue> inputs({torch::ones({})});
257 auto output = bc.get_method("forward")(inputs);
258 auto server_output = m.forward(inputs);
259 EXPECT_EQ(output.toList().get(0).toTensor().item().toInt(), 3);
260 EXPECT_EQ(output, server_output);
261}
262
263TEST(LiteInterpreterTest, PrimOverload) {
264 /*
265 // temporarily disabled
266 script::Module m("m");
267 m.define(R"JIT(
268 def forward(self, x):
269 result = [1, 2]
270 result.append(3)
271 return result
272 )JIT");
273 std::stringstream ss;
274 m._save_for_mobile(ss);
275 mobile::Module bc = _load_for_mobile(ss);
276 std::vector<torch::jit::IValue> inputs({torch::ones({})});
277 auto output = bc.get_method("forward")(inputs);
278 AT_ASSERT(output.toIntList()[2] == 3);
279 */
280}
281
282TEST(LiteInterpreterTest, Prim) {
283 Module m("m");
284 m.define(R"JIT(
285 def forward(self, x):
286 return int(x)
287 )JIT");
288
289 std::vector<IValue> inputs;
290 auto minput = 3.5 * torch::ones({});
291 inputs.emplace_back(minput);
292 auto ref = m.run_method("forward", minput);
293
294 std::stringstream ss;
295 m._save_for_mobile(ss);
296 mobile::Module bc = _load_for_mobile(ss);
297 IValue res;
298 for (int i = 0; i < 3; ++i) {
299 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
300 auto bcinputs = inputs;
301 res = bc.get_method("forward")(bcinputs);
302 }
303
304 auto resi = res.toInt();
305 auto refi = ref.toInt();
306 AT_ASSERT(resi == refi);
307}
308
309TEST(LiteInterpreterTest, PrimScalar) {
310 Module m("m");
311 m.define(R"JIT(
312 def forward(self, x):
313 return int(x.item())
314 )JIT");
315
316 std::vector<IValue> inputs;
317 auto minput = 3.5 * torch::ones({});
318 inputs.emplace_back(minput);
319 auto ref = m.run_method("forward", minput);
320
321 std::stringstream ss;
322 m._save_for_mobile(ss);
323 mobile::Module bc = _load_for_mobile(ss);
324 IValue res;
325 for (int i = 0; i < 3; ++i) {
326 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
327 auto bcinputs = inputs;
328 res = bc.get_method("forward")(bcinputs);
329 }
330
331 auto resi = res.toInt();
332 auto refi = ref.toInt();
333 AT_ASSERT(resi == refi);
334}
335
336TEST(LiteInterpreterTest, LoadOrigJit) {
337 Module m("m");
338 m.register_parameter("foo", torch::ones({}), false);
339 m.define(R"(
340 def forward(self, x):
341 b = 4
342 return self.foo + x + b
343 )");
344 std::stringstream ss;
345 m.save(ss);
346 ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(ss), "file not found");
347}
348
349TEST(LiteInterpreterTest, WrongMethodName) {
350 Module m("m");
351 m.register_parameter("foo", torch::ones({}), false);
352 m.define(R"(
353 def add(self, x):
354 b = 4
355 return self.foo + x + b
356 )");
357 std::stringstream ss;
358 m._save_for_mobile(ss);
359 mobile::Module bc = _load_for_mobile(ss);
360 std::vector<IValue> inputs;
361 auto minput = 5 * torch::ones({});
362 inputs.emplace_back(minput);
363 ASSERT_THROWS_WITH_MESSAGE(
364 bc.get_method("forward")(inputs), "is not defined");
365}
366
367TEST(LiteInterpreterTest, SetState) {
368 Module m("m");
369 m.register_parameter("foo", torch::ones({}), false);
370 m.define(R"(
371 def __getstate__(self):
372 return self.foo + self.foo
373 def __setstate__(self, a):
374 self.foo = a
375 def forward(self, x):
376 b = 4
377 return self.foo + x + b
378 )");
379
380 std::vector<IValue> inputs;
381 auto minput = 5 * torch::ones({});
382 inputs.emplace_back(minput);
383
384 std::stringstream ms;
385 m.save(ms);
386 auto loaded_m = load(ms);
387 auto ref = loaded_m.run_method("forward", minput);
388
389 std::stringstream ss;
390 m._save_for_mobile(ss);
391 mobile::Module bc = _load_for_mobile(ss);
392 IValue res;
393 for (int i = 0; i < 3; ++i) {
394 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
395 auto bcinputs = inputs;
396 res = bc.get_method("forward")(bcinputs);
397 }
398
399 auto resd = res.toTensor().item<float>();
400 auto refd = ref.toTensor().item<float>();
401 AT_ASSERT(resd == refd);
402}
403
404class TorchBindLiteInterpreterTestStruct
405 : public torch::jit::CustomClassHolder {
406 public:
407 std::string get(at::Tensor t) {
408 std::stringstream ss;
409 ss << "Hello! Your tensor has ";
410 ss << t.numel();
411 ss << " elements!";
412 return ss.str();
413 }
414};
415
416namespace {
417struct ClassNamespaceValue : public SugaredValue {
418 explicit ClassNamespaceValue(c10::QualifiedName name)
419 : basename_(std::move(name)) {}
420
421 std::shared_ptr<SugaredValue> attr(
422 const SourceRange& loc,
423 GraphFunction& m,
424 const std::string& name) override {
425 const auto fullName = c10::QualifiedName(basename_, name);
426
427 // Check to see if it is a custom class.
428 if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
429 return std::make_shared<ClassValue>(custom_class);
430 }
431
432 // If it's not a custom class, assume it's another namespace
433 // NOLINTNEXTLINE(performance-move-const-arg)
434 return std::make_shared<ClassNamespaceValue>(std::move(fullName));
435 }
436
437 std::string kind() const override {
438 return "Class Namespace";
439 }
440
441 private:
442 c10::QualifiedName basename_;
443};
444
445struct TestModuleResolver : public Resolver {
446 std::shared_ptr<SugaredValue> resolveValue(
447 const std::string& name,
448 GraphFunction& m,
449 const SourceRange& loc) override {
450 if (name == "torch") {
451 return std::make_shared<BuiltinModule>("aten");
452 } else if (name == "__torch__") {
453 return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
454 }
455
456 return nullptr;
457 }
458
459 TypePtr resolveType(const std::string& name, const SourceRange& loc)
460 override {
461 return nullptr;
462 }
463};
464} // namespace
465
466TEST(LiteInterpreterTest, BuiltinClass) {
467 script::Module m("m");
468
469 auto cls = getCustomClass(
470 "__torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest");
471 TORCH_INTERNAL_ASSERT(cls);
472 c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
473 m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));
474
475 m.register_parameter("foo", torch::ones({}), false);
476 m.define(
477 R"(
478 def __getstate__(self):
479 return 1
480 def __setstate__(self, a):
481 self.my_obj = __torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest()
482
483 def forward(self, x) -> str:
484 return self.my_obj.get(x)
485 )",
486 std::make_shared<TestModuleResolver>());
487
488 std::stringstream ss;
489 m._save_for_mobile(ss);
490 mobile::Module bc = _load_for_mobile(ss);
491 auto res =
492 bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
493 const auto& str = res.toStringRef();
494 std::string expected = "Hello! Your tensor has 12 elements!";
495 AT_ASSERT(str == expected);
496}
497
498TEST(LiteInterpreterTest, BuiltinFunction) {
499 script::Module m("m");
500 auto custom_class_obj =
501 make_custom_class<TorchBindLiteInterpreterTestStruct>();
502 m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
503 m.define(R"(
504 def forward(self, x) -> str:
505 return self.my_obj.get(x)
506 )");
507
508 std::stringstream ss;
509 m._save_for_mobile(ss);
510 mobile::Module bc = _load_for_mobile(ss);
511 auto res =
512 bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
513 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
514 auto str = res.toStringRef();
515 std::string expected = "Hello! Your tensor has 12 elements!";
516 AT_ASSERT(str == expected);
517}
518
519#if !defined FB_XPLAT_BUILD
520TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
521 auto runtime_bytecode_version = _get_runtime_bytecode_version();
522 AT_ASSERT(
523 runtime_bytecode_version ==
524 caffe2::serialize::kMaxSupportedBytecodeVersion);
525}
526
527TEST(LiteInterpreterTest, GetRuntimeOperatorsVersion) {
528 auto runtime_operators_version = _get_runtime_operators_min_max_versions();
529 AT_ASSERT(
530 runtime_operators_version.first ==
531 caffe2::serialize::kMinSupportedFileFormatVersion &&
532 runtime_operators_version.second ==
533 caffe2::serialize::kMaxSupportedFileFormatVersion);
534}
535
536/**
 * The test below is disabled for FB internal xplat builds since
 * BUCK requires us to pass the script_module_v4.ptl file in as
 * a resource dependency of the build rule for this file, and
540 * we would need to access it via the C++ Resources API instead
541 * of directly reading from disk (which is what the open source
542 * build/run does).
543 */
544TEST(LiteInterpreterTest, GetByteCodeVersion) {
545 std::string filePath(__FILE__);
546 auto test_model_file_v4 =
547 filePath.substr(0, filePath.find_last_of("/\\") + 1);
548 test_model_file_v4.append("script_module_v4.ptl");
549
550 auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
551 AT_ASSERT(version_v4 == 4);
552}
553
554#endif // !defined(FB_XPLAT_BUILD)
555
556TEST(LiteInterpreterTest, GetContainTypes) {
557 Module m("m");
558 m.define(R"(
559 def forward(self):
560 return 3
561 )");
562
563 std::stringstream ss;
564 m._save_for_mobile(ss, {}, true);
565
566 _get_mobile_model_contained_types(ss);
567}
568
569namespace {
570
571void compareModelOutput(
572 c10::ArrayRef<IValue> actual_result_list,
573 const std::vector<IValue>& expect_result_list) {
574 AT_ASSERT(actual_result_list.size() == expect_result_list.size());
575 AT_ASSERT(
576 actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
577 AT_ASSERT(
578 actual_result_list[1].toTensor().dim() ==
579 expect_result_list[1].toTensor().dim());
580 AT_ASSERT(
581 actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
582 AT_ASSERT(
583 actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
584 ASSERT_EQ(
585 actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
586 ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
587 ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
588 ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
589 AT_ASSERT(
590 actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
591 ASSERT_EQ(
592 actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
593 ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
594 ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
595}
596
597void runAndCheckTorchScriptModel(
598 std::stringstream& input_model_stream,
599 const std::vector<IValue>& input_data,
600 const std::vector<IValue>& expect_result_list,
601 const uint64_t expect_version) {
602 auto actual_version = _get_model_bytecode_version(input_model_stream);
603 AT_ASSERT(actual_version == expect_version);
604
  // Load and run the backported model, then compare the result with the
  // expected result.
607 Module m_mobile = load(input_model_stream);
608
609 auto actual_result = m_mobile.forward(input_data);
610 const auto& actual_result_list = actual_result.toTupleRef().elements();
611 compareModelOutput(actual_result_list, expect_result_list);
612}
613
614void runAndCheckBytecodeModel(
615 std::stringstream& input_model_stream,
616 const std::vector<IValue>& input_data,
617 const std::vector<IValue>& expect_result_list,
618 const uint64_t expect_version) {
619 auto actual_version = _get_model_bytecode_version(input_model_stream);
620 AT_ASSERT(actual_version == expect_version);
621
  // Load and run the backported model, then compare the result with the
  // expected result.
624 Module m_mobile = load(input_model_stream);
625
626 auto actual_result = m_mobile.forward(input_data);
627 const auto& actual_result_list = actual_result.toTupleRef().elements();
628
629 compareModelOutput(actual_result_list, expect_result_list);
630}
631
632void backportAllVersionCheck(
633 std::stringstream& test_model_file_stream,
634 std::vector<IValue>& input_data,
635 std::vector<IValue>& expect_result_list,
636 const uint64_t expect_from_version) {
637 auto from_version = _get_model_bytecode_version(test_model_file_stream);
638 EXPECT_EQ(from_version, expect_from_version);
639 AT_ASSERT(from_version > 0);
640
  // Backport the given model to successively older versions
642 constexpr int64_t minimum_to_version = 4;
643 auto current_to_version = from_version - 1;
644
  // Verify that every candidate to_version works as expected. Backporting to
  // any version no smaller than minimum_to_version should succeed.
647 while (current_to_version >= minimum_to_version) {
    // Do not declare std::stringstream oss outside of the while loop:
    // oss.clear() does not reset the stream content, it only clears the error
    // state flags, which would leave a problematic stream across iterations.
    // It is cleaner and safer to just declare a fresh std::stringstream each
    // iteration.
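    // A minimal sketch of the behavior described above (illustration only, not
    // part of this test):
    //
    //   std::stringstream s;
    //   s << "abc";
    //   s.clear();   // resets the error-state flags; s.str() is still "abc"
    //   s.str("");   // this is what actually discards the buffered content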
652 std::stringstream oss;
653 bool backPortSuccess =
654 _backport_for_mobile(test_model_file_stream, oss, current_to_version);
655 AT_ASSERT(backPortSuccess);
656
657 // Check backport model version
658 auto backport_version = _get_model_bytecode_version(oss);
660 AT_ASSERT(backport_version == current_to_version);
661
    // Load and run the backported model, then compare the result with the
    // expected result.
664 runAndCheckBytecodeModel(
665 oss, input_data, expect_result_list, current_to_version);
666 oss.seekg(0, oss.beg);
667 runAndCheckTorchScriptModel(
668 oss, input_data, expect_result_list, current_to_version);
669
670 current_to_version--;
671 }
672 // backport to minimum version - 1 should fail
673 std::stringstream oss;
674 bool backPortSuccess =
675 _backport_for_mobile(test_model_file_stream, oss, minimum_to_version - 1);
676 AT_ASSERT(!backPortSuccess);
677}
678} // namespace
679
680#if !defined FB_XPLAT_BUILD
681TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
682 torch::jit::Module module("m");
683 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
684 module.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
685 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
686 module.register_parameter("bias", torch::ones({20}), false);
687 module.define(R"(
688 def fn(self, x:float=1.0):
689 return x
690
691 def forward(self, input):
692 x1 = torch.zeros(2, 2)
693 x2 = torch.empty_like(torch.empty(2, 2))
694 x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
695 # Add torch.add operator to cover bytecode version bump from 6 to 7
696 # for bytecode version 7, the main change is to support defaults arguments with out arguments
697 x = 2 * torch.ones(1)
698 h = torch.ones(1)
699 torch.add(x, h, out=x)
700 device = torch.ones(1, 1).cpu().device.type
701 is_cuda = x1.is_cuda
702 bool_val = True
703 check_is = [] is None
704 check_is_not = [1] is not None
705 check_not = not bool_val
706 num_to_tensor = torch.tensor([self.fn()])
707 d = {"a": "abc"}
708 check_dict_index = d["a"]
709 check_dim = x1.dim()
710 return (
711 x1, x2, x3, x, device, is_cuda, check_is,
712 check_is_not, num_to_tensor, check_dict_index,
713 check_dim, check_not
714 )
715 )");
716
717 torch::jit::Module module_freeze = freeze(module);
718
719 std::stringstream input_model_stream;
720 module_freeze._save_for_mobile(
721 input_model_stream,
722 /*extra_files=*/{},
723 /*save_mobile_debug_info=*/false,
724 /*use_flatbuffer=*/true);
725 std::vector<IValue> input_data =
726 std::vector<IValue>({torch::ones({1, 1, 28, 28})});
727 std::vector<IValue> expect_result_list;
728 expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
729 expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
730 expect_result_list.emplace_back(
731 at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
732 expect_result_list.emplace_back(3 * at::ones({1}));
733 // "cpu" False, False, True, tensor(1), "abc", 2, False)
734 expect_result_list.emplace_back(c10::IValue("cpu"));
735 expect_result_list.emplace_back(c10::IValue(false));
736 expect_result_list.emplace_back(c10::IValue(false));
737 expect_result_list.emplace_back(c10::IValue(true));
738 expect_result_list.emplace_back(c10::IValue(at::ones({1})));
739 expect_result_list.emplace_back(c10::IValue("abc"));
740 expect_result_list.emplace_back(c10::IValue(2));
741 expect_result_list.emplace_back(c10::IValue(false));
742
743 backportAllVersionCheck(
744 input_model_stream,
745 input_data,
746 expect_result_list,
747 9); // flatbuffer starts at 9
748}
749#endif // !defined(FB_XPLAT_BUILD)
750
751TEST(LiteInterpreterTest, GetRuntimeOpsAndInfo) {
752 auto runtime_ops = _get_runtime_ops_and_info();
  // Ballpark estimate of the minimal number of ops; just used to
  // verify that the API returns a reasonably large number.
755 AT_ASSERT(runtime_ops.size() > 2900);
756}
757
758TEST(LiteInterpreterTest, isCompatibleSuccess) {
759 // test trivial success case
760 auto runtime_info = RuntimeCompatibilityInfo::get();
761 std::unordered_map<std::string, OperatorInfo> model_ops;
762 model_ops["aten::add.Scalar"] = OperatorInfo{2};
763
764 std::unordered_set<std::string> types = {"List", "int", "NamedTuple"};
765 auto model_info = ModelCompatibilityInfo{
766 caffe2::serialize::kMaxSupportedBytecodeVersion,
767 model_ops,
768 types,
769 _get_runtime_bytecode_min_max_versions().first};
770
771 AT_ASSERT(
772 is_compatible(runtime_info, model_info).status ==
773 ModelCompatibilityStatus::OK);
774}
775
776TEST(LiteInterpreterTest, isCompatibleFail) {
777 // test trivial failure due to ops
778 std::unordered_map<std::string, OperatorInfo> model_ops;
779 model_ops["aten::add.Scalar"] = OperatorInfo{2};
780 auto model_info = ModelCompatibilityInfo{
781 caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops};
782 std::unordered_map<std::string, OperatorInfo> runtime_ops;
783 runtime_ops["aten::add.Int"] = OperatorInfo{2};
784 auto runtime_info = RuntimeCompatibilityInfo{
785 std::pair<uint64_t, uint64_t>(
786 caffe2::serialize::kMinSupportedBytecodeVersion,
787 caffe2::serialize::kMaxSupportedBytecodeVersion),
788 runtime_ops,
789 _get_mobile_supported_types()};
790
791 auto result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
793 AT_ASSERT(
794 result.errors[0] ==
795 "Operator 'aten::add.Scalar' missing from runtime (not found)");
796
797 // test trivial failure due to bytecode greater than max supported bytecode
798 // version
799 runtime_ops["aten::add.Scalar"] = OperatorInfo{2};
800 runtime_info = RuntimeCompatibilityInfo{
801 std::pair<uint64_t, uint64_t>(
802 caffe2::serialize::kMinSupportedBytecodeVersion,
803 caffe2::serialize::kMaxSupportedBytecodeVersion),
804 runtime_ops,
805 _get_mobile_supported_types()};
806 model_info.bytecode_version =
807 caffe2::serialize::kMaxSupportedBytecodeVersion + 1;
808
809 result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
811
812 // test trivial failure due to bytecode less than min supported bytecode
813 // version
814 runtime_ops["aten::add.Scalar"] = OperatorInfo{2};
815 runtime_info = RuntimeCompatibilityInfo{
816 std::pair<uint64_t, uint64_t>(
817 caffe2::serialize::kMinSupportedBytecodeVersion,
818 caffe2::serialize::kMaxSupportedBytecodeVersion),
819 runtime_ops,
820 _get_mobile_supported_types()};
821 model_info.bytecode_version =
822 caffe2::serialize::kMinSupportedBytecodeVersion - 1;
823
824 result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
826
827 // test trivial failure due to type
828 runtime_info = RuntimeCompatibilityInfo::get();
829 std::unordered_set<std::string> types = {"List", "int", "Sequence"};
830
831 model_info = ModelCompatibilityInfo{
832 caffe2::serialize::kMaxSupportedBytecodeVersion,
833 model_ops,
834 types,
835 _get_runtime_bytecode_min_max_versions().first};
836
837 AT_ASSERT(
838 is_compatible(runtime_info, model_info).status ==
839 ModelCompatibilityStatus::ERROR);
840
841 // test trivial failure due to operator version
842 runtime_info = RuntimeCompatibilityInfo::get();
843
844 model_info = ModelCompatibilityInfo{
845 caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, {}, 0};
846
847 AT_ASSERT(
848 is_compatible(runtime_info, model_info).status ==
849 ModelCompatibilityStatus::ERROR);
850}
851
852TEST(LiteInterpreterTest, Eval) {
853 std::vector<torch::jit::IValue> inputs;
854
855 Module m("m");
856 m.define(R"(
857 def __init__(self, x):
858 self.training = True
859
860 def forward(self, input):
861 return torch.dropout(input, 1.0, self.training)
862 )");
863
864 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
865 inputs.push_back(torch::ones({1, 1, 28, 28}));
866 m.eval();
867 auto outputref = m.forward(inputs).toTensor();
868
  // Save m in training mode to make sure that calling eval() on the mobile
  // module correctly switches it back to eval mode: with p=1.0, dropout zeroes
  // its input in training mode but is a no-op in eval mode.
871 m.train();
872 std::stringstream ss;
873 m._save_for_mobile(ss);
874 mobile::Module bc = _load_for_mobile(ss);
875 bc.eval();
876 IValue res;
877 for (int i = 0; i < 3; ++i) {
878 res = bc.get_method("forward")(inputs);
879 }
880 auto output = res.toTensor();
881 AT_ASSERT(outputref.dim() == output.dim());
882 AT_ASSERT(
883 outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
884}
885
886TEST(LiteInterpreterTest, FindWrongMethodName) {
887 Module m("m");
888 m.register_parameter("foo", torch::ones({}), false);
889 m.define(R"(
890 def add(self, x):
891 b = 4
892 return self.foo + x + b
893 )");
894 std::stringstream ss;
895 m._save_for_mobile(ss);
896 mobile::Module bc = _load_for_mobile(ss);
897 ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
898}
899
900TEST(LiteInterpreterTest, FindAndRunMethod) {
901 Module m("m");
902 m.register_parameter("foo", torch::ones({}), false);
903 m.define(R"(
904 def add_it(self, x):
905 b = 4
906 return self.foo + x + b
907 )");
908
909 std::vector<IValue> inputs;
910 auto minput = 5 * torch::ones({});
911 inputs.emplace_back(minput);
912 auto ref = m.get_method("add_it")(inputs);
913
914 std::stringstream ss;
915 m._save_for_mobile(ss);
916 mobile::Module bc = _load_for_mobile(ss);
917 IValue res;
918 for (int i = 0; i < 3; ++i) {
919 auto bcinputs = inputs;
920 auto method = bc.find_method("add_it");
921 AT_ASSERT(method != c10::nullopt);
922 res = (*method)(std::move(bcinputs));
923 }
924
925 auto resd = res.toTensor().item<float>();
926 auto refd = ref.toTensor().item<float>();
927 AT_ASSERT(resd == refd);
928}
929
930TEST(LiteInterpreterTest, RunMethodVariadic) {
931 Module m("m");
932 m.register_parameter("foo", torch::ones({}), false);
933 m.define(R"(
934 def add_three(self, x, y):
935 return self.foo + x + y
936 )");
937
938 std::vector<IValue> inputs;
939 auto inputx = 5 * torch::ones({});
940 auto inputy = 4 * torch::ones({});
941 auto ref = m.run_method("add_three", inputx, inputy);
942
943 std::stringstream ss;
944 m._save_for_mobile(ss);
945 mobile::Module bc = _load_for_mobile(ss);
946 IValue res = bc.run_method("add_three", inputx, inputy);
947
948 auto resd = res.toTensor().item<float>();
949 auto refd = ref.toTensor().item<float>();
950 AT_ASSERT(resd == refd);
951}
952
953TEST(LiteInterpreterTest, DuplicateSetState) {
954 Module m("M");
955 m.register_parameter("foo", torch::ones({}), false);
956 m.define(R"(
957 def __getstate__(self):
958 return self.foo + self.foo
959 def __setstate__(self, a):
960 self.foo = a
961 def forward(self, x):
962 b = 4
963 return self.foo + x + b
964 )");
965
966 Module b("B");
967 b.register_module("M0", m);
968 b.register_module("M1", m);
969 b.define(R"(
970 def forward(self, x):
971 return self.M0.forward(x) + self.M1.forward(x)
972 )");
973
974 std::stringstream ss;
975 m._save_for_mobile(ss);
976 mobile::Module bc = _load_for_mobile(ss);
977 const auto methods = bc.get_methods();
978 const size_t expected_n = 3;
979 ASSERT_EQ(methods.size(), expected_n);
980}
981
982TEST(LiteInterpreterTest, ExtraFiles) {
983 const auto script = R"JIT(
984 def forward(self):
985 x = torch.rand(5, 5)
986 x = x.mm(x)
987 return x
988 )JIT";
989
990 auto module =
991 std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
992 module->define(script);
993 std::ostringstream oss;
994 std::unordered_map<std::string, std::string> extra_files;
995 extra_files["metadata.json"] = "abc";
996 extra_files["mobile_info.json"] = "{\"key\": 23}";
997 module->_save_for_mobile(oss, extra_files);
998
999 std::istringstream iss(oss.str());
1000 std::unordered_map<std::string, std::string> loaded_extra_files;
1001 loaded_extra_files["metadata.json"] = "";
1002 torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
1003 ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
1004
1005 loaded_extra_files.clear();
1006 std::vector<std::string> all_files =
1007 caffe2::serialize::PyTorchStreamReader(&iss).getAllRecords();
1008
1009 for (auto& file_name : all_files) {
1010 if (file_name.find("extra/") == 0) {
1011 loaded_extra_files[file_name.substr(6)] = "";
1012 }
1013 }
1014 iss.seekg(0, iss.beg);
1015 torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
1016 ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
1017 ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
1018}
1019
1020TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
1021 torch::jit::Module m("m");
1022 m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
1023 m.register_parameter("bias", torch::ones({20}), false);
1024 m.define(R"(
1025 def forward(self, input):
1026 x1 = torch.zeros(2, 2)
1027 x2 = torch.empty_like(torch.empty(2, 2))
1028 x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
1029 return (x1, x2, x3)
1030 )");
1031 m.eval();
1032
1033 std::stringstream ss;
1034 m._save_for_mobile(ss);
1035
1036 torch::jit::mobile::Module ptl_model = torch::jit::_load_for_mobile(ss);
1037 std::set<std::string> operator_names =
1038 torch::jit::mobile::_export_operator_list(ptl_model);
1039 std::set<std::string> expected_operator_names = {
1040 "aten::_convolution",
1041 "aten::empty.memory_format",
1042 "aten::empty_like",
1043 "aten::zeros",
1044 };
1045 EXPECT_EQ(operator_names, expected_operator_names)
1046 << "Expected the root operator lists to be the same";
1047}
1048
1049TEST(LiteInterpreterTest, DefaultArgsConv) {
1050 auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
1051 if (s && strcmp(s, "1") == 0)
1052 return;
1053
1054 std::vector<torch::jit::IValue> inputs;
1055
1056 Module m("m");
1057 m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
1058 m.register_parameter("bias", torch::ones({20}), false);
1059 m.define(R"(
1060 def forward(self, input):
1061 return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
1062 )");
1063
1064 inputs.push_back(torch::ones({1, 1, 28, 28}));
1065
1066 auto outputref = m.forward(inputs).toTensor();
1067
1068 std::stringstream ss;
1069 m._save_for_mobile(ss);
1070 mobile::Module bc = _load_for_mobile(ss);
1071 IValue res;
1072 for (int i = 0; i < 1; ++i) {
1073 res = bc.get_method("forward")(inputs);
1074 }
1075 auto output = res.toTensor();
1076 AT_ASSERT(outputref.dim() == output.dim());
1077 AT_ASSERT(output.equal(outputref));
1078}
1079
1080TEST(RunTimeTest, ParseBytecode) {
  // A simple example showing bytecode that can be executed independently of the
  // PyTorch TorchScript serialization machinery (unpickler, etc.) and of the
  // operator library. It has basic control flow (if, else) and basic data
  // orchestration (list construction). The original PyTorch program:
1085
1086 // class Module(torch.nn.Module):
1087 //
1088 // def __init__(self):
1089 // super().__init__()
1090 //
1091 // def forward(self, x: int, h: int, xfirst: bool):
1092 // if xfirst:
1093 // return [x, h]
1094 // else:
1095 // return [h, x]
1096
  // 1. Prepare the bytecode. In reality it could come from a customized
  // deserializer.
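  // Each instruction below is a ("OPCODE", X, N) triple, where X and N are the
  // opcode's two integer operands. As a rough reading of the mobile
  // interpreter's conventions: STOREN 1 4 pops the four inputs into registers
  // 1..4, JF 5 jumps 5 instructions ahead when the popped condition is false,
  // and LIST_CONSTRUCT 0 2 builds a 2-element list using types[0].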
1099 std::vector<IValue> instructions{
1100 to_tuple({"STOREN", 1, 4}),
1101 to_tuple({"DROPR", 1, 0}),
1102 to_tuple({"MOVE", 4, 0}),
1103 to_tuple({"JF", 5, 0}),
1104 to_tuple({"LOAD", 2, 0}),
1105 to_tuple({"LOAD", 3, 0}),
1106 to_tuple({"LIST_CONSTRUCT", 0, 2}),
1107 to_tuple({"JMP", 4, 0}),
1108 to_tuple({"LOAD", 3, 0}),
1109 to_tuple({"LOAD", 2, 0}),
1110 to_tuple({"LIST_CONSTRUCT", 1, 2}),
1111 to_tuple({"STORE", 5, 0}),
1112 to_tuple({"DROPR", 3, 0}),
1113 to_tuple({"DROPR", 2, 0}),
1114 to_tuple({"MOVE", 5, 0}),
1115 to_tuple({"RET", 0, 0}),
1116 };
1117 std::vector<IValue> operators; // empty for this example
1118 std::vector<IValue> constants; // empty for this example
1119
1120 std::vector<IValue> types{"List[int]", "List[int]"};
1121 // 2. Parse the function
1122 std::string function_name("test_function");
1123 auto function = std::unique_ptr<mobile::Function>(
1124 new mobile::Function(c10::QualifiedName(function_name)));
1125 c10::ivalue::TupleElements debug_handles_m_tuple;
1126 parseInstructions(
1127 function_name,
1128 std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
1129 debug_handles_m_tuple,
1130 function.get());
1131 parseTypes(c10::ivalue::Tuple::create(types)->elements(), function.get());
1132 const size_t rsize = 5;
1133 parseRegisterSize(rsize, function.get());
1134
  // 3. Prepare inputs and run the function.
  // Note that the first input is reserved for the Module object. Since this is
  // a function test and a Module object is not required, a dummy IValue (0) is
  // used here.
1139 std::vector<IValue> inputs{0, 1, 2, true};
1140 function->run(inputs);
1141 auto output = inputs[0].toList();
1142 ASSERT_EQ(output[0], 1);
1143 ASSERT_EQ(output[1], 2);
1144
1145 std::vector<IValue> inputs1{0, 1, 2, false};
1146 function->run(inputs1);
1147 auto output1 = inputs1[0].toList();
1148 ASSERT_EQ(output1[0], 2);
1149 ASSERT_EQ(output1[1], 1);
1150}
1151
1152TEST(RunTimeTest, ParseOperator) {
  // A simple example showing bytecode that can be executed independently of the
  // PyTorch TorchScript serialization machinery (unpickler, etc.) and of the
  // operator library. It has one operator, and we should be able to register
  // it. The original PyTorch program:
1157
1158 // class Add(torch.nn.Module):
1159 // def __init__(self):
1160 // super(Add, self).__init__()
1161
1162 // def forward(self, a, b):
1163 // return a + b
1164
  // 1. Prepare the bytecode. In reality it could come from a customized
  // deserializer.
1167 std::vector<IValue> instructions{
1168 to_tuple({"STOREN", 1, 3}),
1169 to_tuple({"DROPR", 1, 0}),
1170 to_tuple({"MOVE", 2, 0}),
1171 to_tuple({"MOVE", 3, 0}),
1172 to_tuple({"OP", 0, 0}),
1173 to_tuple({"RET", 0, 0}),
1174 };
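  // Each operator entry below is assumed to be ("name", "overload_name",
  // num_specified_args); the OP 0 0 instruction above then invokes
  // operators[0], i.e. aten::add.Tensor with 2 specified arguments (alpha
  // falls back to its default of 1).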
1175 std::vector<IValue> operators{
1176 to_tuple({"aten::add", "Tensor", 2}),
1177 };
1178 std::vector<IValue> constants{
1179 to_tuple({1}),
1180 };
1181 // 2. Parse the function
1182 std::string function_name("test_function");
1183 auto function = std::unique_ptr<mobile::Function>(
1184 new mobile::Function(c10::QualifiedName(function_name)));
1185 c10::ivalue::TupleElements debug_handles_m_tuple;
1186 parseInstructions(
1187 function_name,
1188 std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
1189 debug_handles_m_tuple,
1190 function.get());
1191 parseOperators(
1192 std::move(*c10::ivalue::Tuple::create(operators)).elements(),
1193 1,
1194 function.get());
1195 const size_t rsize = 5;
1196 parseRegisterSize(rsize, function.get());
1197
  // 3. Prepare inputs and run the function.
  // Note that the first input is reserved for the Module object. Since this is
  // a function test and a Module object is not required, a dummy IValue (0) is
  // used here.
1202 std::vector<IValue> inputs{0, at::tensor(1), at::tensor(2)};
1203 function->run(inputs);
1204 auto output = inputs[0];
1205 ASSERT_EQ(output, at::tensor(3));
1206}
1207
1208namespace {
1209void testLiteModuleCompareResultTensors(
1210 Module& m,
1211 const std::vector<torch::jit::IValue>& inputs,
1212 const std::string& method_name = "forward") {
1213 auto outputref = m.get_method(method_name)(inputs).toTensor();
1214
1215 std::stringstream ss;
1216 m._save_for_mobile(ss);
1217 mobile::Module bc = _load_for_mobile(ss);
1218 IValue res;
1219 for (int i = 0; i < 3; ++i) {
1220 res = bc.get_method(method_name)(inputs);
1221 }
1222 auto output = res.toTensor();
1223 AT_ASSERT(outputref.dim() == output.dim());
1224 AT_ASSERT(output.equal(outputref));
1225}
1226
1227void testDefaultArgsPinv(int num_args) {
1228 Module m("m");
1229 if (num_args == 1) {
1230 m.define(R"(
1231 def forward(self, input):
1232 return torch.linalg_pinv(input)
1233 )");
1234 } else if (num_args == 2) {
1235 m.define(R"(
1236 def forward(self, input):
1237 return torch.linalg_pinv(input, 1e-5)
1238 )");
1239 } else if (num_args == 3) {
1240 m.define(R"(
1241 def forward(self, input):
1242 return torch.linalg_pinv(input, 1e-5, True)
1243 )");
1244 }
1245
1246 std::vector<torch::jit::IValue> inputs;
1247 const int N = 28;
1248 auto input = torch::range(1, N * N, 1);
1249 input[0] = 1; // a more stable matrix
1250 input = input.view({N, N});
1251 inputs.push_back(input);
1252 testLiteModuleCompareResultTensors(m, inputs);
1253}
1254} // namespace
1255
1256#if !defined FB_XPLAT_BUILD
1257TEST(LiteInterpreterTest, DefaultArgsPinv) {
1258 // Test with different number of specified arguments.
1259 // Arguments not specified take default value.
1260 for (int num_args = 1; num_args <= 3; ++num_args) {
1261 testDefaultArgsPinv(num_args);
1262 }
1263
1264 // bytecode with one specified argument:
1265 // (6,
1266 // ('__torch__.m.forward',
1267 // (('instructions',
1268 // (('STOREN', 1, 2),
1269 // ('DROPR', 1, 0),
1270 // ('MOVE', 2, 0),
1271 // ('OP', 0, 0),
1272 // ('RET', 0, 0))),
1273 // ('operators', (('aten::linalg_pinv', '', 1),)),
1274 // ('constants', (False, 1e-15)), # default constants are not
1275 // used
1276 // ('types', ()),
1277 // ('register_size', 2)),
1278 // (('arguments',
1279 // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
1280 // None)),
1281 // (('name', 'input'), ('type', 'Tensor'), ('default_value',
1282 // None)))),
1283 // ('returns',
1284 // ((('name', ''), ('type', 'Tensor'), ('default_value',
1285 // None)),)))))
1286
1287 // bytecode with 2 specified argument:
1288 // (6,
1289 // ('__torch__.m.forward',
1290 // (('instructions',
1291 // (('STOREN', 1, 2),
1292 // ('DROPR', 1, 0),
1293 // ('MOVE', 2, 0),
1294 // ('LOADC', 1, 0), # added LOADC for specified argument
1295 // ('OP', 0, 0),
1296 // ('RET', 0, 0))),
1297 // ('operators', (('aten::linalg_pinv', '', 2),)),
1298 // ('constants', (False, 1e-05)), # updated constant table
1299 // ('types', ()),
1300 // ('register_size', 2)),
1301 // (('arguments',
1302 // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
1303 // None)),
1304 // (('name', 'input'), ('type', 'Tensor'), ('default_value',
1305 // None)))),
1306 // ('returns',
1307 // ((('name', ''), ('type', 'Tensor'), ('default_value',
1308 // None)),)))))
1309
1310 // bytecode with 3 specified arguments:
1311 // (6,
1312 // ('__torch__.m.forward',
1313 // (('instructions',
1314 // (('STOREN', 1, 2),
1315 // ('DROPR', 1, 0),
1316 // ('MOVE', 2, 0),
1317 // ('LOADC', 1, 0),
1318 // ('LOADC', 0, 0),
1319 // ('OP', 0, 0),
1320 // ('RET', 0, 0))),
1321 // ('operators', (('aten::linalg_pinv', '', 3),)),
1322 // ('constants', (True, 1e-05)),
1323 // ('types', ()),
1324 // ('register_size', 2)),
1325 // (('arguments',
1326 // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
1327 // None)),
1328 // (('name', 'input'), ('type', 'Tensor'), ('default_value',
1329 // None)))),
1330 // ('returns',
1331 // ((('name', ''), ('type', 'Tensor'), ('default_value',
1332 // None)),)))))
1333}
1334
1335TEST(LiteInterpreterTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is specified, but its value is the same as the default
  // value. It is treated as "not specified", since the value can be fetched
  // from the schema.
1339 Module m("m");
1340 m.define(R"(
1341 def forward(self, input):
1342 return torch.linalg_tensorinv(input, 2)
1343 )");
1344 torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
1345 auto arg_nums = code.op_to_num_specified_args();
1346 ASSERT_EQ(arg_nums.size(), 1);
1347 ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
1348 std::vector<torch::jit::IValue> inputs;
1349 const int N = 4;
1350 auto input = torch::rand({N, N, N, N});
1351 inputs.push_back(input);
1352 testLiteModuleCompareResultTensors(m, inputs);
1353}
1354
1355void testDefaultArgsPinvWithOutArg(int num_args) {
1356 Module m("m");
1357 if (num_args == 1) {
1358 m.define(R"(
1359 def forward(self, input):
1360 return torch.linalg_pinv(input, out=input)
1361 )");
1362 } else if (num_args == 2) {
1363 m.define(R"(
1364 def forward(self, input):
1365 return torch.linalg_pinv(input, 1e-5, out=input)
1366 )");
1367 } else if (num_args == 3) {
1368 m.define(R"(
1369 def forward(self, input):
1370 return torch.linalg_pinv(input, 1e-5, True, out=input)
1371 )");
1372 }
1373
1374 const int N = 28;
1375 auto input = torch::range(1, N * N, 1);
1376 input[0] = 10000; // a more stable matrix
1377 input = input.view({N, N});
1378 auto ref = m.run_method("forward", input);
1379 TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
1380 TORCH_CHECK(input.equal(ref.toTensor()));
1381}
1382
1383TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) {
1384 // Test with different number of specified arguments + out arg.
1385 // Arguments not specified take default value.
1386 for (int num_args = 1; num_args <= 3; ++num_args) {
1387 testDefaultArgsPinvWithOutArg(num_args);
1388 }
1389}
1390
1391TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
1392 Module m("m");
1393 m.define(R"(
1394 def forward(self, x, h):
1395 torch.add(x, h, out=x)
1396 )");
1397
1398 std::vector<IValue> inputs;
1399 auto input_x = 2 * torch::ones({});
1400 auto input_h = torch::ones({});
1401 auto ref = m.run_method("forward", input_x, input_h);
1402
1403 std::stringstream ss;
1404
1405 m._save_for_mobile(ss, {}, true);
1406 mobile::Module bc = _load_for_mobile(ss);
1407 bc.run_method("forward", input_x, input_h);
1408 AT_ASSERT(input_x.equal(4 * torch::ones({})));
1409
1410 auto ops = _get_model_ops_and_info(ss);
1411 auto op = ops.find("aten::add.out");
1412 TORCH_CHECK(
1413 op != ops.end() && op->second.num_schema_args.has_value() &&
1414 op->second.num_schema_args.value() == 3);
1415}
1416
1417TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
1418 Module a("A");
1419 a.define(R"(
1420 def bar(self, x, y):
1421 return x + y
1422 )");
1423 Module b("B");
1424 b.register_module("A0", a);
1425 b.define(R"(
1426 def foo(self, x, y):
1427 return self.A0.bar(x, y) + 2
1428 )");
1429 Module c("C");
1430 c.register_module("B0", b);
1431 c.define(R"(
1432 def forward(self, x, y):
1433 return self.B0.foo(x, y) + 3
1434 )");
1435
1436 std::vector<IValue> inputs;
1437 inputs.emplace_back(torch::rand({2, 4}));
1438 inputs.emplace_back(torch::rand({13, 9}));
1439
1440 std::stringstream ss;
1441 c._save_for_mobile(ss, ExtraFilesMap(), true);
1442 auto lite_m = _load_for_mobile(ss);
1443 std::string error_pattern = R"(
1444 Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
1445Traceback of TorchScript (most recent call last):
1446 File "<string>", line 3, in <unknown>
1447
1448 def forward(self, x, y):
1449 return self.B0.foo(x, y) + 3
1450 ~~~~~~~~~~~ <--- HERE
1451
1452 File "<string>", line 3, in foo
1453
1454 def foo(self, x, y):
1455 return self.A0.bar(x, y) + 2
1456 ~~~~~~~~~~~ <--- HERE
1457
1458 File "<string>", line 3, in bar
1459
1460 def bar(self, x, y):
1461 return x + y
1462 ~~~~~ <--- HERE
1463 )";
1464 ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
1465}
1466#endif // !defined(FB_XPLAT_BUILD)
1467
1468namespace {
1469static auto reg =
1470 torch::class_<TorchBindLiteInterpreterTestStruct>(
1471 "_TorchScriptTesting",
1472 "_LiteInterpreterTest")
1473 .def(torch::init<>())
1474 .def("get", &TorchBindLiteInterpreterTestStruct::get)
1475 .def_pickle(
            // __getstate__
1477 [](const c10::intrusive_ptr<TorchBindLiteInterpreterTestStruct>&
1478 self) -> int64_t { return 0; },
            // __setstate__
1480 [](int64_t state) {
1481 return c10::make_intrusive<TorchBindLiteInterpreterTestStruct>();
1482 });
1483
1484} // namespace
1485
1486TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) {
1487 // Create 3 methods:
1488 //
1489 // 1. forward() returns a tensor with dtype=torch.int64 (4)
1490 // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32, but
  //    the dtype is inferred from the input tensor's dtype
  //
  // If caching works correctly, then the result from the full-jit
  // module and the lite module will be the same. If we fail to ignore the
  // cache entry for an operator invoked with a different number of
  // arguments, the results will differ.
1499 Module m("m");
1500 m.define(R"(
1501 def forward(self):
1502 ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
1503 return ret1.fill_(25)
1504 )");
1505 m.define(R"(
1506 def forward2(self):
1507 ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
1508 return ret1.fill_(32.0)
1509 )");
1510 m.define(R"(
1511 def forward3(self):
1512 ret1 = torch.new_empty(torch.zeros(10), [10])
1513 return ret1.fill_(12.0)
1514 )");
1515
1516 std::vector<torch::jit::IValue> inputs;
1517 testLiteModuleCompareResultTensors(m, inputs, "forward");
1518 testLiteModuleCompareResultTensors(m, inputs, "forward2");
1519 testLiteModuleCompareResultTensors(m, inputs, "forward3");
1520}
1521
1522TEST(RunTimeTest, RuntimeCall) {
1523 // def call(x):
1524 // return x + x
1525 //
1526 // def forward(a):
1527 // x = a + call(a)
1528 // y = a + call(x)
1529 // return y
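  // With a = tensor(1): call(a) = 2, x = 1 + 2 = 3, call(x) = 6, and
  // y = 1 + 6 = 7, which is what the assertion at the end checks.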
1530
1531 std::vector<IValue> instructionsCall{
1532 to_tuple({"STORE", 1, 0}),
1533 to_tuple({"LOAD", 1, 0}),
1534 to_tuple({"MOVE", 1, 0}),
1535 to_tuple({"LOADC", 0, 0}),
1536 to_tuple({"OP", 0, 0}),
1537 to_tuple({"RET", 0, 0}),
1538 };
1539 std::vector<IValue> instructionsFoo{
1540 to_tuple({"STORE", 1, 0}),
1541 to_tuple({"LOAD", 1, 0}),
1542 to_tuple({"LOAD", 1, 0}),
1543 to_tuple({"MOVE", 1, 0}),
1544 to_tuple({"CALL", 0, 0}),
1545 to_tuple({"LOADC", 0, 0}),
1546 to_tuple({"OP", 0, 0}),
1547 to_tuple({"CALL", 0, 0}),
1548 to_tuple({"LOADC", 0, 0}),
1549 to_tuple({"OP", 0, 0}),
1550 to_tuple({"RET", 0, 0}),
1551 };
1552 std::vector<IValue> operatorsFoo{
1553 to_tuple({"aten::add", "Tensor", 3}),
1554 };
1555 std::vector<IValue> constantsFoo{
1556 1,
1557 };
1558 std::vector<IValue> operatorsCall{
1559 to_tuple({"aten::add", "Tensor", 3}),
1560 };
1561 std::vector<IValue> constantsCall{
1562 1,
1563 };
1564
1565 auto foo = std::make_unique<mobile::Function>(c10::QualifiedName("foo"));
1566 c10::ivalue::TupleElements debug_handles_m_tuple;
1567 parseInstructions(
1568 "foo",
1569 std::move(*c10::ivalue::Tuple::create(instructionsFoo)).elements(),
1570 debug_handles_m_tuple,
1571 foo.get());
1572 parseOperators(
1573 std::move(*c10::ivalue::Tuple::create(operatorsFoo)).elements(),
1574 1,
1575 foo.get());
1576 parseConstants(
1577 std::move(*c10::ivalue::Tuple::create(constantsFoo)).elements(),
1578 foo.get());
1579 const size_t rsize = 5;
1580 parseRegisterSize(rsize, foo.get());
1581
1582 auto call = std::make_unique<mobile::Function>(c10::QualifiedName("call"));
1583 parseInstructions(
1584 "call",
1585 std::move(*c10::ivalue::Tuple::create(instructionsCall)).elements(),
1586 debug_handles_m_tuple,
1587 call.get());
1588 parseOperators(
1589 std::move(*c10::ivalue::Tuple::create(operatorsCall)).elements(),
1590 1,
1591 call.get());
1592 parseConstants(
1593 std::move(*c10::ivalue::Tuple::create(constantsCall)).elements(),
1594 call.get());
1595 parseRegisterSize(rsize, call.get());
1596
1597 foo->append_function(*call);
1598
1599 std::vector<IValue> inputs{at::tensor(1)};
1600 foo->run(inputs);
1601 auto output = inputs[0];
1602 ASSERT_EQ(output, at::tensor(7));
1603}
1604
1605TEST(LiteInterpreterTest, OperatorSize1) {
1606 Module m("m");
1607 m.define(R"(
1608 def forward(self, input: Tensor, scale:float):
1609 return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
1610 )");
1611
1612 std::stringstream ss;
1613 m._save_for_mobile(ss);
1614 mobile::Module bc = _load_for_mobile(ss);
1615 const auto& func = bc.get_method("forward").function();
1616 ASSERT_EQ(
1617 func.get_code().operator_input_sizes_.size(),
1618 func.get_code().operators_.size());
1619}
1620
1621TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
1622 const std::vector<std::string> test_programs{
1623 // test invoking a method with default parameter
1624 R"(
1625 def test_func(self, x, b : int = 4):
1626 return self.foo + x + b
1627 )",
1628 // inner method call with default parameter (gets inlined)
1629 R"(
1630 def add_with_default_arg(self, x, b : int = 4):
1631 return self.foo + x + b
1632 def test_func(self, x):
1633 return self.add_with_default_arg(x) # invoke method w/ default arg
1634 )",
1635 // simple method call
1636 R"(
1637 def test_func(self, x):
1638 b = 4
1639 return self.foo + x + b
1640 )",
1641 };
1642 for (const auto& test_program : test_programs) {
1643 Module m("m");
1644 m.register_parameter("foo", torch::ones({}), false);
1645 m.define(test_program);
1646
1647 std::stringstream ss;
1648 m._save_for_mobile(ss);
1649 mobile::Module bc = _load_for_mobile(ss);
1650 const auto& func = bc.get_method("test_func").function();
1651 ASSERT_EQ(
1652 func.get_code().operator_input_sizes_.size(),
1653 func.get_code().operators_.size());
1654 }
1655}
1656
1657#if !defined FB_XPLAT_BUILD
// The following tests run in fbcode only
1659TEST(LiteInterpreterUpgraderTest, DivTensorV2) {
1660 std::string filePath(__FILE__);
1661 auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1662 test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl");
1663 /*
1664 (('__torch__.MyModule.forward',
1665 (('instructions',
1666 (('STOREN', 1, 3),
1667 ('DROPR', 1, 0),
1668 ('LOAD', 2, 0),
1669 ('LOAD', 3, 0),
1670 ('OP', 0, 0),
1671 ('LOAD', 2, 0),
1672 ('LOAD', 3, 0),
1673 ('OP', 1, 0),
1674 ('MOVE', 2, 0),
1675 ('MOVE', 3, 0),
1676 ('OP', 2, 0),
1677 ('TUPLE_CONSTRUCT', 3, 0),
1678 ('RET', 0, 0))),
1679 ('operators',
1680 (('aten::div', 'Tensor'),
1681 ('aten::div', 'Tensor'),
1682 ('aten::div', 'Tensor'))),
1683 ('constants', ()),
1684 ('types', ()),
1685 ('register_size', 3))),)
1686
1687 */
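  // Broadly speaking, when this v2 model is loaded by a newer runtime, each
  // old-semantics aten::div.Tensor op is routed through an upgrader function,
  // which shows up in the loaded bytecode as a CALL instruction; the count
  // below relies on that.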
1688 mobile::Module m_module = _load_for_mobile(test_model_file);
  auto instruction_list =
1690 m_module.get_method("forward").function().get_code().instructions_;
1691 uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
1693 number_of_call_instruction += (instruction.op == OpCode::CALL);
1694 }
  // 3 operators will use the upgrader
1696 ASSERT_EQ(number_of_call_instruction, 3);
1697
1698 std::vector<IValue> inputs = {
1699 IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
1700 auto actual_output = m_module.forward(inputs);
1701 auto expect_output = 2.0 * torch::ones({1});
1702 auto actual_output_list = actual_output.toTuple()->elements();
1703 ASSERT_TRUE(actual_output_list[0].toTensor().equal(expect_output));
1704}
1705
1706TEST(LiteInterpreterUpgraderTest, DivTensorOutV2) {
1707 std::string filePath(__FILE__);
1708 auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1709 test_model_file.append(
1710 "upgrader_models/test_versioned_div_tensor_out_v2.ptl");
1711 /*
1712 (('__torch__.MyModule.forward',
1713 (('instructions',
1714 (('STOREN', 1, 4),
1715 ('DROPR', 1, 0),
1716 ('MOVE', 2, 0),
1717 ('MOVE', 3, 0),
1718 ('MOVE', 4, 0),
1719 ('OP', 0, 0),
1720 ('RET', 0, 0))),
1721 ('operators', (('aten::div', 'out'),)),
1722 ('constants', ()),
1723 ('types', ()),
1724 ('register_size', 4))),)
1725 */
1726 mobile::Module m_module = _load_for_mobile(test_model_file);
1727
  auto instruction_list =
1729 m_module.get_method("forward").function().get_code().instructions_;
1730 uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
1732 number_of_call_instruction += (instruction.op == OpCode::CALL);
1733 }
  // One operator will use the upgrader
1735 ASSERT_EQ(number_of_call_instruction, 1);
1736
1737 std::vector<IValue> inputs{
1738 IValue(6 * torch::ones({1})),
1739 IValue(3 * torch::ones({1})),
1740 IValue(torch::empty({1}))};
1741 m_module.forward(inputs);
1742 auto expect_output = 2.0 * torch::ones({1});
1743 auto actual_output = inputs[2].toTensor();
1744 // The out argument will be overwritten with the output
1745 ASSERT_TRUE(actual_output.equal(expect_output));
1746}
1747
TEST(LiteInterpreterUpgraderTest, DivTensorInplaceV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Tensor'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator (aten::div_.Tensor) is routed through the upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = inputs[0].toTensor();
  // div_ updates the first (self) argument in place.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(LiteInterpreterUpgraderTest, DivScalarFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_float_v2.ptl");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator (aten::div.Scalar) is routed through the upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // forward() returns the result directly.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // Neither aten::reciprocal nor aten::mul needs an upgrader.
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();
  std::cout << "expect output: " << expect_output << "\n";
  std::cout << "actual output: " << actual_output << "\n";
  // forward() returns the result directly.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // Neither aten::reciprocal nor aten::mul needs an upgrader.
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();

  // forward() returns the result directly.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(LiteInterpreterUpgraderTest, DivScalarScalarV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 5),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('MOVE', 2, 0),
       ('LOAD', 4, 0),
       ('OP', 1, 0),
       ('LOAD', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 5, 0),
       ('OP', 3, 0),
       ('TUPLE_CONSTRUCT', 4, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', ''),
       ('aten::div', 'float'),
       ('aten::div', ''),
       ('aten::div', 'int'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 5))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);
  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // None of the scalar/scalar div overloads needs an upgrader.
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)};
  auto output = m_module.forward(inputs);
  auto output_list = output.toTupleRef().elements();
  auto expect_output = std::vector<IValue>(
      {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)});
  for (size_t i = 0; i < expect_output.size(); i++) {
    ASSERT_EQ(output_list[i], expect_output[i]);
  }
}

TEST(LiteInterpreterUpgraderTest, DivScalarIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_int_v2.ptl");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator (aten::div.Scalar) is routed through the upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // forward() returns the result directly.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(LiteInterpreterUpgraderTest, DivScalarInplaceFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator (aten::div_.Scalar) is routed through the upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // div_ returns self, so the returned tensor holds the in-place result.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(LiteInterpreterUpgraderTest, DivScalarInplaceIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator (aten::div_.Scalar) is routed through the upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // div_ returns self, so the returned tensor holds the in-place result.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

#endif // !defined(FB_XPLAT_BUILD)

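// Sanity-check the generated upgrader bytecode itself: after
// initialize_operators(true), every upgrader function should have one resolved
// operator per recorded operator name; if the operator table came back empty,
// the operators are appended explicitly before the function is collected.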
TEST(LiteInterpreterUpgraderTest, Upgrader) {
  std::vector<mobile::Function> upgrader_functions;

  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    byteCodeFunctionWithOperator.function.initialize_operators(true);
    ASSERT_EQ(
        byteCodeFunctionWithOperator.function.get_code().operators_.size(),
        byteCodeFunctionWithOperator.function.get_code().op_names_.size());
    if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) {
      for (const auto& op : byteCodeFunctionWithOperator.operators) {
        byteCodeFunctionWithOperator.function.append_operator(
            op.name, op.overload_name, op.num_specified_args);
      }
    }
    upgrader_functions.push_back(byteCodeFunctionWithOperator.function);
  }

  ASSERT_EQ(getUpgraderBytecodeList().size(), upgrader_functions.size());
}

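// Recursively enumerates tuple types of arity `depth` whose element types are
// drawn from `candidates`, appending both an unnamed TupleType and a
// "NamedTuple" variant (fields "field0", "field1", ...) to `out`. Candidate
// types containing AnyType are skipped.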
void enumerateTupleType(
    size_t depth,
    std::vector<TypePtr>& current,
    const std::vector<TypePtr>& candidates,
    std::vector<TypePtr>& out) {
  static std::vector<std::string> fieldNames;
  if (depth > fieldNames.size()) {
    fieldNames.reserve(depth);
    for (size_t i = fieldNames.size(); i < depth; i++) {
      fieldNames.push_back("field" + std::to_string(i));
    }
  }
  if (depth == 0) {
    out.push_back(TupleType::create(current));
    while (fieldNames.size() > current.size()) {
      fieldNames.pop_back();
    }
    out.push_back(TupleType::createNamed("NamedTuple", fieldNames, current));
    return;
  }
  for (const auto& type : candidates) {
    if (containsAnyType(type)) {
      continue;
    }
    current.push_back(type);
    enumerateTupleType(depth - 1, current, candidates, out);
    current.pop_back();
  }
}
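// Usage sketch (illustrative only, not executed by the tests):
//   std::vector<TypePtr> current, out;
//   enumerateTupleType(
//       /*depth=*/2, current, {IntType::get(), TensorType::get()}, out);
//   // `out` now holds an unnamed and a named tuple type for every ordered
//   // pair of the candidates, e.g. Tuple[int, Tensor] and its NamedTuple
//   // counterpart.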

class LiteInterpreterDynamicTypeTestFixture
    : public ::testing::TestWithParam<size_t> {
 protected:
  void SetUp() override {
    cu = std::make_shared<CompilationUnit>();
    std::vector<TypePtr> keyTypes = {
        AnyType::get(),
        IntType::get(),
        BoolType::get(),
        FloatType::get(),
        ComplexType::get(),
        StringType::get(),
        TensorType::get(),
        DeviceObjType::get(),
    };
    types = {
        NoneType::get(),
        NumberType::get(),
        ClassType::create("__torch__.TestClass1", cu),
        ClassType::create("__torch__.TestClass2", cu),
        AnyListType::get(),
        AnyTupleType::get(),
        StreamObjType::get(),
        CapsuleType::get(),
        GeneratorType::get(),
        StorageType::get(),
        VarType::create("t"),
        VarType::create("v"),
        AnyClassType::get()};
    std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
    auto expandTypes = [&](size_t tupleSize) {
      std::vector<TypePtr> nested;
      for (const auto& type : types) {
        if (!(type == AnyType::get())) {
          nested.emplace_back(ListType::create(type));
          if (!(type == NoneType::get() ||
                type->kind() == OptionalType::Kind)) {
            nested.emplace_back(OptionalType::create(type));
          }
        }
        for (const auto& keyType : keyTypes) {
          nested.emplace_back(DictType::create(keyType, type));
        }
      }
      std::vector<TypePtr> tmp;
      enumerateTupleType(tupleSize, tmp, types, nested);
      std::move(
          std::begin(nested), std::end(nested), std::back_inserter(types));
    };
    expandTypes(1);
    expandTypes(1);
  }
  std::shared_ptr<CompilationUnit> cu;
  std::vector<TypePtr> types;

 public:
  static constexpr size_t kNumSplits = 10;
};

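// Out-of-line definition of the static constexpr member; required for ODR-use
// before C++17 inline variables.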
constexpr size_t LiteInterpreterDynamicTypeTestFixture::kNumSplits;

/**
 * Enumerate all possible JIT types appearing in the mobile runtime, and test
 * whether the subtyping relation is preserved after one of the JIT types is
 * converted to DynamicType.
 *
 * We first enumerate all "base" types in a vector, and implement
 * expandTypes() to enumerate container types one "level" up for a given set
 * of types. We call expandTypes() twice to cover types nested at most two
 * levels deep, e.g. List[Optional[Tensor]], Optional[Dict[Int, Bool]], etc.
 */
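// Each parameterized instance compares every type `a` against one contiguous
// shard of `types` as `b`, so the quadratic number of subtype checks is spread
// across kNumSplits test cases.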
TEST_P(LiteInterpreterDynamicTypeTestFixture, Conformance) {
  size_t num = types.size() / LiteInterpreterDynamicTypeTestFixture::kNumSplits;
  size_t begin = num * GetParam();
  size_t end = std::min(types.size(), begin + num);
  for (const auto& a : types) {
    auto da = DynamicType::create(*a);
    for (size_t i = begin; i < end; i++) {
      const auto& b = types[i];
      bool result = a->isSubtypeOf(*b);
      EXPECT_EQ(result, da->isSubtypeOf(*b));
      result = b->isSubtypeOf(*a);
      EXPECT_EQ(result, b->isSubtypeOf(*da));
    }
  }
}

INSTANTIATE_TEST_CASE_P(
    PyTorch,
    LiteInterpreterDynamicTypeTestFixture,
    ::testing::Range(
        static_cast<size_t>(0),
        LiteInterpreterDynamicTypeTestFixture::kNumSplits));

} // namespace jit
} // namespace torch