1 | #include <test/cpp/jit/test_utils.h> |
2 | |
3 | #include <gtest/gtest.h> |
4 | |
5 | #include <c10/core/TensorOptions.h> |
6 | #include <torch/csrc/autograd/generated/variable_factories.h> |
7 | #include <torch/csrc/jit/api/module.h> |
8 | #include <torch/csrc/jit/frontend/resolver.h> |
9 | #include <torch/csrc/jit/mobile/compatibility/backport.h> |
10 | #include <torch/csrc/jit/mobile/compatibility/backport_manager.h> |
11 | #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h> |
12 | #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h> |
13 | #include <torch/csrc/jit/mobile/import.h> |
14 | #include <torch/csrc/jit/mobile/interpreter.h> |
15 | #include <torch/csrc/jit/mobile/module.h> |
16 | #include <torch/csrc/jit/mobile/parse_bytecode.h> |
17 | #include <torch/csrc/jit/mobile/parse_operators.h> |
18 | #include <torch/csrc/jit/mobile/upgrader_mobile.h> |
19 | #include <torch/csrc/jit/serialization/export.h> |
20 | #include <torch/csrc/jit/serialization/import.h> |
21 | #include <torch/custom_class.h> |
22 | #include <torch/torch.h> |
23 | |
24 | #include <torch/csrc/jit/serialization/import_export_functions.h> |
25 | #include <unordered_set> |
26 | |
27 | // Tests go in torch::jit |
28 | namespace torch { |
29 | namespace jit { |
30 | |
31 | TEST(LiteInterpreterTest, UpsampleNearest2d) { |
32 | Module m("m" ); |
33 | m.define(R"( |
34 | def forward(self, input: Tensor, scale:float): |
35 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
36 | )" ); |
37 | |
38 | std::vector<IValue> inputs; |
39 | inputs.emplace_back(torch::rand({1, 3, 128, 128})); |
40 | inputs.emplace_back(at::Scalar(2.0)); |
41 | auto ref = m.forward(inputs); |
42 | |
43 | std::stringstream ss; |
44 | m._save_for_mobile(ss); |
45 | mobile::Module bc = _load_for_mobile(ss); |
46 | IValue res; |
47 | res = bc.forward(inputs); |
48 | |
49 | auto resd = res.toTensor(); |
50 | auto refd = ref.toTensor(); |
51 | ASSERT_TRUE(resd.equal(refd)); |
52 | } |
53 | |
54 | TEST(LiteInterpreterTest, CheckAttrAccess) { |
55 | Module m("m" ); |
56 | m.register_attribute("mobile_optimized" , BoolType::get(), true); |
57 | |
58 | std::stringstream ss; |
59 | m._save_for_mobile(ss); |
60 | mobile::Module bc = _load_for_mobile(ss); |
61 | bool mobile_optimized = bc.attr("mobile_optimized" , false).toBool(); |
62 | |
63 | AT_ASSERT(mobile_optimized); |
64 | m.setattr("mobile_optimized" , false); |
65 | ss = std::stringstream(); |
66 | m._save_for_mobile(ss); |
67 | bc = _load_for_mobile(ss); |
68 | mobile_optimized = bc.attr("mobile_optimized" , false).toBool(); |
69 | |
70 | AT_ASSERT(!mobile_optimized); |
71 | } |
72 | |
73 | TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest) |
74 | const std::vector<std::string> test_programs{ |
75 | // test invoking a method with default parameter |
76 | R"( |
77 | def test_func(self, x, b : int = 4): |
78 | return self.foo + x + b |
79 | )" , |
80 | // inner method call with default parameter (gets inlined) |
81 | R"( |
82 | def add_with_default_arg(self, x, b : int = 4): |
83 | return self.foo + x + b |
84 | def test_func(self, x): |
85 | return self.add_with_default_arg(x) # invoke method w/ default arg |
86 | )" , |
87 | // simple method call |
88 | R"( |
89 | def test_func(self, x): |
90 | b = 4 |
91 | return self.foo + x + b |
92 | )" , |
93 | }; |
94 | for (const auto& test_program : test_programs) { |
95 | Module m("m" ); |
96 | m.register_parameter("foo" , torch::ones({}), false); |
97 | m.define(test_program); |
98 | |
99 | const int fortyTwo = 42; // (keep linter happy) |
100 | auto minput = fortyTwo * torch::ones({}); |
101 | auto ref = m.run_method("test_func" , minput); |
102 | |
103 | std::stringstream ss; |
104 | m._save_for_mobile(ss); |
105 | mobile::Module bc = _load_for_mobile(ss); |
106 | const auto& test_func = bc.get_method("test_func" ); |
107 | IValue res; |
108 | for (int i = 0; i < 3; ++i) { |
109 | res = test_func({minput}); |
110 | } |
111 | |
112 | auto resd = res.toTensor().item<float>(); |
113 | auto refd = ref.toTensor().item<float>(); |
114 | AT_ASSERT(resd == refd); |
115 | } |
116 | } |
117 | |
118 | TEST(LiteInterpreterTest, Conv) { |
119 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
120 | if (s && strcmp(s, "1" ) == 0) |
121 | return; |
122 | |
123 | std::vector<torch::jit::IValue> inputs; |
124 | |
125 | Module m("m" ); |
126 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
127 | m.register_parameter("bias" , torch::ones({20}), false); |
128 | m.define(R"( |
129 | def forward(self, input): |
130 | return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
131 | )" ); |
132 | |
133 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
134 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
135 | |
136 | auto outputref = m.forward(inputs).toTensor(); |
137 | |
138 | std::stringstream ss; |
139 | m._save_for_mobile(ss); |
140 | mobile::Module bc = _load_for_mobile(ss); |
141 | IValue res; |
142 | for (int i = 0; i < 3; ++i) { |
143 | res = bc.get_method("forward" )(inputs); |
144 | } |
145 | auto output = res.toTensor(); |
146 | AT_ASSERT(outputref.dim() == output.dim()); |
147 | AT_ASSERT( |
148 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
149 | } |
150 | |
151 | TEST(LiteInterpreterTest, Inline) { |
152 | Module m("m" ); |
153 | m.define(R"JIT( |
154 | def foo1(self, x): |
155 | return x + 1 |
156 | |
157 | def foo2(self, x): |
158 | return self.foo1(x) + 2 |
159 | |
160 | def foo3(self, x): |
161 | return self.foo2(x) + 3 |
162 | )JIT" ); |
163 | std::stringstream ss; |
164 | m._save_for_mobile(ss); |
165 | mobile::Module bc = _load_for_mobile(ss); |
166 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
167 | auto output = bc.get_method("foo3" )(inputs); |
168 | AT_ASSERT(output.toTensor().item<float>() == 7.0); |
169 | } |
170 | |
171 | TEST(LiteInterpreterTest, Tuple) { |
172 | Module m("m" ); |
173 | m.define(R"JIT( |
174 | def foo(self, x): |
175 | return (1, 2, x + 3) |
176 | |
177 | def forward(self, x): |
178 | tuple = self.foo(x) |
179 | return tuple |
180 | )JIT" ); |
181 | std::stringstream ss; |
182 | m._save_for_mobile(ss); |
183 | mobile::Module bc = _load_for_mobile(ss); |
184 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
185 | auto output = bc.get_method("forward" )(inputs); |
186 | AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2); |
187 | } |
188 | |
189 | TEST(LiteInterpreterTest, AtenFormat) { |
190 | Module m("m" ); |
191 | m.define(R"""( |
192 | def forward(self, fmt:str="first {} {}", num:str="abc"): |
193 | x = 2 |
194 | x = x * x |
195 | return fmt.format(num, x) |
196 | )""" ); |
197 | std::stringstream ss; |
198 | m._save_for_mobile(ss); |
199 | mobile::Module bc = _load_for_mobile(ss); |
200 | std::vector<torch::jit::IValue> inputs; |
201 | auto output_bc = bc.get_method("forward" )(inputs); |
202 | auto output_m = m.get_method("forward" )(inputs); |
203 | // std::cout << output_m.toStringRef() << "\n" |
204 | // << output_bc.toStringRef() << std::endl; |
205 | AT_ASSERT(output_m.toStringRef() == output_bc.toStringRef()); |
206 | } |
207 | |
208 | TEST(LiteInterpreterTest, PrimDevice) { |
209 | Module m("m" ); |
210 | m.define(R"""( |
211 | def forward(self, x:torch.Tensor): |
212 | return x.device |
213 | )""" ); |
214 | std::stringstream ss; |
215 | m._save_for_mobile(ss); |
216 | mobile::Module bc = _load_for_mobile(ss); |
217 | std::vector<torch::jit::IValue> inputs; |
218 | auto minput = 3.5 * torch::ones({}); |
219 | inputs.emplace_back(minput); |
220 | auto output_bc = bc.get_method("forward" )(inputs); |
221 | auto output_m = m.get_method("forward" )(inputs); |
222 | AT_ASSERT(output_bc.toDevice().str() == output_m.toDevice().str()); |
223 | } |
224 | |
225 | TEST(LiteInterpreterTest, Dict) { |
226 | Module m("m" ); |
227 | m.define(R"JIT( |
228 | def foo(self, x): |
229 | return {"result": x + 1} |
230 | |
231 | def forward(self, x): |
232 | d = self.foo(x) |
233 | return d |
234 | )JIT" ); |
235 | std::stringstream ss; |
236 | m._save_for_mobile(ss); |
237 | mobile::Module bc = _load_for_mobile(ss); |
238 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
239 | auto output = bc.get_method("forward" )(inputs); |
240 | AT_ASSERT(output.toGenericDict().at("result" ).toTensor().item().toInt() == 2); |
241 | } |
242 | |
243 | TEST(LiteInterpreterTest, List) { |
244 | Module m("m" ); |
245 | m.define(R"JIT( |
246 | def foo(self, x): |
247 | return [x + 2] |
248 | |
249 | def forward(self, x): |
250 | d = self.foo(x) |
251 | return d |
252 | )JIT" ); |
253 | std::stringstream ss; |
254 | m._save_for_mobile(ss); |
255 | mobile::Module bc = _load_for_mobile(ss); |
256 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
257 | auto output = bc.get_method("forward" )(inputs); |
258 | auto server_output = m.forward(inputs); |
259 | EXPECT_EQ(output.toList().get(0).toTensor().item().toInt(), 3); |
260 | EXPECT_EQ(output, server_output); |
261 | } |
262 | |
// Disabled test: its entire body is commented out, so it currently passes
// without checking anything. The commented-out code appends to an int list
// (`result.append(3)`) in a mobile-saved module and expects element 2 of the
// returned list to be 3. Presumably this targets overloaded prim ops, as the
// test name suggests — TODO(review): confirm the reason it was disabled and
// re-enable or delete.
TEST(LiteInterpreterTest, PrimOverload) {
  /*
  // temporarily disabled
  script::Module m("m");
  m.define(R"JIT(
  def forward(self, x):
      result = [1, 2]
      result.append(3)
      return result
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toIntList()[2] == 3);
  */
}
281 | |
282 | TEST(LiteInterpreterTest, Prim) { |
283 | Module m("m" ); |
284 | m.define(R"JIT( |
285 | def forward(self, x): |
286 | return int(x) |
287 | )JIT" ); |
288 | |
289 | std::vector<IValue> inputs; |
290 | auto minput = 3.5 * torch::ones({}); |
291 | inputs.emplace_back(minput); |
292 | auto ref = m.run_method("forward" , minput); |
293 | |
294 | std::stringstream ss; |
295 | m._save_for_mobile(ss); |
296 | mobile::Module bc = _load_for_mobile(ss); |
297 | IValue res; |
298 | for (int i = 0; i < 3; ++i) { |
299 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
300 | auto bcinputs = inputs; |
301 | res = bc.get_method("forward" )(bcinputs); |
302 | } |
303 | |
304 | auto resi = res.toInt(); |
305 | auto refi = ref.toInt(); |
306 | AT_ASSERT(resi == refi); |
307 | } |
308 | |
309 | TEST(LiteInterpreterTest, PrimScalar) { |
310 | Module m("m" ); |
311 | m.define(R"JIT( |
312 | def forward(self, x): |
313 | return int(x.item()) |
314 | )JIT" ); |
315 | |
316 | std::vector<IValue> inputs; |
317 | auto minput = 3.5 * torch::ones({}); |
318 | inputs.emplace_back(minput); |
319 | auto ref = m.run_method("forward" , minput); |
320 | |
321 | std::stringstream ss; |
322 | m._save_for_mobile(ss); |
323 | mobile::Module bc = _load_for_mobile(ss); |
324 | IValue res; |
325 | for (int i = 0; i < 3; ++i) { |
326 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
327 | auto bcinputs = inputs; |
328 | res = bc.get_method("forward" )(bcinputs); |
329 | } |
330 | |
331 | auto resi = res.toInt(); |
332 | auto refi = ref.toInt(); |
333 | AT_ASSERT(resi == refi); |
334 | } |
335 | |
336 | TEST(LiteInterpreterTest, LoadOrigJit) { |
337 | Module m("m" ); |
338 | m.register_parameter("foo" , torch::ones({}), false); |
339 | m.define(R"( |
340 | def forward(self, x): |
341 | b = 4 |
342 | return self.foo + x + b |
343 | )" ); |
344 | std::stringstream ss; |
345 | m.save(ss); |
346 | ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(ss), "file not found" ); |
347 | } |
348 | |
349 | TEST(LiteInterpreterTest, WrongMethodName) { |
350 | Module m("m" ); |
351 | m.register_parameter("foo" , torch::ones({}), false); |
352 | m.define(R"( |
353 | def add(self, x): |
354 | b = 4 |
355 | return self.foo + x + b |
356 | )" ); |
357 | std::stringstream ss; |
358 | m._save_for_mobile(ss); |
359 | mobile::Module bc = _load_for_mobile(ss); |
360 | std::vector<IValue> inputs; |
361 | auto minput = 5 * torch::ones({}); |
362 | inputs.emplace_back(minput); |
363 | ASSERT_THROWS_WITH_MESSAGE( |
364 | bc.get_method("forward" )(inputs), "is not defined" ); |
365 | } |
366 | |
367 | TEST(LiteInterpreterTest, SetState) { |
368 | Module m("m" ); |
369 | m.register_parameter("foo" , torch::ones({}), false); |
370 | m.define(R"( |
371 | def __getstate__(self): |
372 | return self.foo + self.foo |
373 | def __setstate__(self, a): |
374 | self.foo = a |
375 | def forward(self, x): |
376 | b = 4 |
377 | return self.foo + x + b |
378 | )" ); |
379 | |
380 | std::vector<IValue> inputs; |
381 | auto minput = 5 * torch::ones({}); |
382 | inputs.emplace_back(minput); |
383 | |
384 | std::stringstream ms; |
385 | m.save(ms); |
386 | auto loaded_m = load(ms); |
387 | auto ref = loaded_m.run_method("forward" , minput); |
388 | |
389 | std::stringstream ss; |
390 | m._save_for_mobile(ss); |
391 | mobile::Module bc = _load_for_mobile(ss); |
392 | IValue res; |
393 | for (int i = 0; i < 3; ++i) { |
394 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
395 | auto bcinputs = inputs; |
396 | res = bc.get_method("forward" )(bcinputs); |
397 | } |
398 | |
399 | auto resd = res.toTensor().item<float>(); |
400 | auto refd = ref.toTensor().item<float>(); |
401 | AT_ASSERT(resd == refd); |
402 | } |
403 | |
404 | class TorchBindLiteInterpreterTestStruct |
405 | : public torch::jit::CustomClassHolder { |
406 | public: |
407 | std::string get(at::Tensor t) { |
408 | std::stringstream ss; |
409 | ss << "Hello! Your tensor has " ; |
410 | ss << t.numel(); |
411 | ss << " elements!" ; |
412 | return ss.str(); |
413 | } |
414 | }; |
415 | |
416 | namespace { |
417 | struct ClassNamespaceValue : public SugaredValue { |
418 | explicit ClassNamespaceValue(c10::QualifiedName name) |
419 | : basename_(std::move(name)) {} |
420 | |
421 | std::shared_ptr<SugaredValue> attr( |
422 | const SourceRange& loc, |
423 | GraphFunction& m, |
424 | const std::string& name) override { |
425 | const auto fullName = c10::QualifiedName(basename_, name); |
426 | |
427 | // Check to see if it is a custom class. |
428 | if (auto custom_class = getCustomClass(fullName.qualifiedName())) { |
429 | return std::make_shared<ClassValue>(custom_class); |
430 | } |
431 | |
432 | // If it's not a custom class, assume it's another namespace |
433 | // NOLINTNEXTLINE(performance-move-const-arg) |
434 | return std::make_shared<ClassNamespaceValue>(std::move(fullName)); |
435 | } |
436 | |
437 | std::string kind() const override { |
438 | return "Class Namespace" ; |
439 | } |
440 | |
441 | private: |
442 | c10::QualifiedName basename_; |
443 | }; |
444 | |
445 | struct TestModuleResolver : public Resolver { |
446 | std::shared_ptr<SugaredValue> resolveValue( |
447 | const std::string& name, |
448 | GraphFunction& m, |
449 | const SourceRange& loc) override { |
450 | if (name == "torch" ) { |
451 | return std::make_shared<BuiltinModule>("aten" ); |
452 | } else if (name == "__torch__" ) { |
453 | return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name)); |
454 | } |
455 | |
456 | return nullptr; |
457 | } |
458 | |
459 | TypePtr resolveType(const std::string& name, const SourceRange& loc) |
460 | override { |
461 | return nullptr; |
462 | } |
463 | }; |
464 | } // namespace |
465 | |
466 | TEST(LiteInterpreterTest, BuiltinClass) { |
467 | script::Module m("m" ); |
468 | |
469 | auto cls = getCustomClass( |
470 | "__torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest" ); |
471 | TORCH_INTERNAL_ASSERT(cls); |
472 | c10::intrusive_ptr<torch::CustomClassHolder> obj_holder; |
473 | m.register_attribute("my_obj" , cls, IValue::make_capsule(obj_holder)); |
474 | |
475 | m.register_parameter("foo" , torch::ones({}), false); |
476 | m.define( |
477 | R"( |
478 | def __getstate__(self): |
479 | return 1 |
480 | def __setstate__(self, a): |
481 | self.my_obj = __torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest() |
482 | |
483 | def forward(self, x) -> str: |
484 | return self.my_obj.get(x) |
485 | )" , |
486 | std::make_shared<TestModuleResolver>()); |
487 | |
488 | std::stringstream ss; |
489 | m._save_for_mobile(ss); |
490 | mobile::Module bc = _load_for_mobile(ss); |
491 | auto res = |
492 | bc.get_method("forward" )(std::vector<IValue>{torch::zeros({3, 4})}); |
493 | const auto& str = res.toStringRef(); |
494 | std::string expected = "Hello! Your tensor has 12 elements!" ; |
495 | AT_ASSERT(str == expected); |
496 | } |
497 | |
498 | TEST(LiteInterpreterTest, BuiltinFunction) { |
499 | script::Module m("m" ); |
500 | auto custom_class_obj = |
501 | make_custom_class<TorchBindLiteInterpreterTestStruct>(); |
502 | m.register_attribute("my_obj" , custom_class_obj.type(), custom_class_obj); |
503 | m.define(R"( |
504 | def forward(self, x) -> str: |
505 | return self.my_obj.get(x) |
506 | )" ); |
507 | |
508 | std::stringstream ss; |
509 | m._save_for_mobile(ss); |
510 | mobile::Module bc = _load_for_mobile(ss); |
511 | auto res = |
512 | bc.get_method("forward" )(std::vector<IValue>{torch::zeros({3, 4})}); |
513 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
514 | auto str = res.toStringRef(); |
515 | std::string expected = "Hello! Your tensor has 12 elements!" ; |
516 | AT_ASSERT(str == expected); |
517 | } |
518 | |
519 | #if !defined FB_XPLAT_BUILD |
520 | TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) { |
521 | auto runtime_bytecode_version = _get_runtime_bytecode_version(); |
522 | AT_ASSERT( |
523 | runtime_bytecode_version == |
524 | caffe2::serialize::kMaxSupportedBytecodeVersion); |
525 | } |
526 | |
527 | TEST(LiteInterpreterTest, GetRuntimeOperatorsVersion) { |
528 | auto runtime_operators_version = _get_runtime_operators_min_max_versions(); |
529 | AT_ASSERT( |
530 | runtime_operators_version.first == |
531 | caffe2::serialize::kMinSupportedFileFormatVersion && |
532 | runtime_operators_version.second == |
533 | caffe2::serialize::kMaxSupportedFileFormatVersion); |
534 | } |
535 | |
536 | /** |
537 | * The test below is disarmed for FB internal xplat builds since |
538 | * BUCK requires us to pass in the script_module_v4.ptl file in |
539 | * as a resource dependency of the build rule for this file, and |
540 | * we would need to access it via the C++ Resources API instead |
541 | * of directly reading from disk (which is what the open source |
542 | * build/run does). |
543 | */ |
544 | TEST(LiteInterpreterTest, GetByteCodeVersion) { |
545 | std::string filePath(__FILE__); |
546 | auto test_model_file_v4 = |
547 | filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
548 | test_model_file_v4.append("script_module_v4.ptl" ); |
549 | |
550 | auto version_v4 = _get_model_bytecode_version(test_model_file_v4); |
551 | AT_ASSERT(version_v4 == 4); |
552 | } |
553 | |
554 | #endif // !defined(FB_XPLAT_BUILD) |
555 | |
556 | TEST(LiteInterpreterTest, GetContainTypes) { |
557 | Module m("m" ); |
558 | m.define(R"( |
559 | def forward(self): |
560 | return 3 |
561 | )" ); |
562 | |
563 | std::stringstream ss; |
564 | m._save_for_mobile(ss, {}, true); |
565 | |
566 | _get_mobile_model_contained_types(ss); |
567 | } |
568 | |
569 | namespace { |
570 | |
571 | void compareModelOutput( |
572 | c10::ArrayRef<IValue> actual_result_list, |
573 | const std::vector<IValue>& expect_result_list) { |
574 | AT_ASSERT(actual_result_list.size() == expect_result_list.size()); |
575 | AT_ASSERT( |
576 | actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor())); |
577 | AT_ASSERT( |
578 | actual_result_list[1].toTensor().dim() == |
579 | expect_result_list[1].toTensor().dim()); |
580 | AT_ASSERT( |
581 | actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor())); |
582 | AT_ASSERT( |
583 | actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor())); |
584 | ASSERT_EQ( |
585 | actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef()); |
586 | ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool()); |
587 | ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool()); |
588 | ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool()); |
589 | AT_ASSERT( |
590 | actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor())); |
591 | ASSERT_EQ( |
592 | actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef()); |
593 | ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt()); |
594 | ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool()); |
595 | } |
596 | |
597 | void runAndCheckTorchScriptModel( |
598 | std::stringstream& input_model_stream, |
599 | const std::vector<IValue>& input_data, |
600 | const std::vector<IValue>& expect_result_list, |
601 | const uint64_t expect_version) { |
602 | auto actual_version = _get_model_bytecode_version(input_model_stream); |
603 | AT_ASSERT(actual_version == expect_version); |
604 | |
605 | // Load and run the backport model, then compare the result with expect |
606 | // result |
607 | Module m_mobile = load(input_model_stream); |
608 | |
609 | auto actual_result = m_mobile.forward(input_data); |
610 | const auto& actual_result_list = actual_result.toTupleRef().elements(); |
611 | compareModelOutput(actual_result_list, expect_result_list); |
612 | } |
613 | |
614 | void runAndCheckBytecodeModel( |
615 | std::stringstream& input_model_stream, |
616 | const std::vector<IValue>& input_data, |
617 | const std::vector<IValue>& expect_result_list, |
618 | const uint64_t expect_version) { |
619 | auto actual_version = _get_model_bytecode_version(input_model_stream); |
620 | AT_ASSERT(actual_version == expect_version); |
621 | |
622 | // Load and run the backport model, then compare the result with expect |
623 | // result |
624 | Module m_mobile = load(input_model_stream); |
625 | |
626 | auto actual_result = m_mobile.forward(input_data); |
627 | const auto& actual_result_list = actual_result.toTupleRef().elements(); |
628 | |
629 | compareModelOutput(actual_result_list, expect_result_list); |
630 | } |
631 | |
632 | void backportAllVersionCheck( |
633 | std::stringstream& test_model_file_stream, |
634 | std::vector<IValue>& input_data, |
635 | std::vector<IValue>& expect_result_list, |
636 | const uint64_t expect_from_version) { |
637 | auto from_version = _get_model_bytecode_version(test_model_file_stream); |
638 | EXPECT_EQ(from_version, expect_from_version); |
639 | AT_ASSERT(from_version > 0); |
640 | |
641 | // Backport script_module_v5.ptl to an older version |
642 | constexpr int64_t minimum_to_version = 4; |
643 | auto current_to_version = from_version - 1; |
644 | |
645 | // Verify all candidate to_version work as expected. All backport to version |
646 | // larger than minimum_to_version should success. |
647 | while (current_to_version >= minimum_to_version) { |
648 | // Do not declare std::stringstream oss outside of the while loop as |
649 | // oss.clear() doesn't reset the stream content, only clears out error state |
650 | // flag in stringstream causing a problematic stream. Instead, it's cleaner |
651 | // and safer to just declare a new std::stringstream one and swap them. |
652 | std::stringstream oss; |
653 | bool backPortSuccess = |
654 | _backport_for_mobile(test_model_file_stream, oss, current_to_version); |
655 | AT_ASSERT(backPortSuccess); |
656 | |
657 | // Check backport model version |
658 | auto backport_version = _get_model_bytecode_version(oss); |
659 | backport_version = _get_model_bytecode_version(oss); |
660 | AT_ASSERT(backport_version == current_to_version); |
661 | |
662 | // Load and run the backport model, then compare the result with expect |
663 | // result |
664 | runAndCheckBytecodeModel( |
665 | oss, input_data, expect_result_list, current_to_version); |
666 | oss.seekg(0, oss.beg); |
667 | runAndCheckTorchScriptModel( |
668 | oss, input_data, expect_result_list, current_to_version); |
669 | |
670 | current_to_version--; |
671 | } |
672 | // backport to minimum version - 1 should fail |
673 | std::stringstream oss; |
674 | bool backPortSuccess = |
675 | _backport_for_mobile(test_model_file_stream, oss, minimum_to_version - 1); |
676 | AT_ASSERT(!backPortSuccess); |
677 | } |
678 | } // namespace |
679 | |
680 | #if !defined FB_XPLAT_BUILD |
681 | TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) { |
682 | torch::jit::Module module("m" ); |
683 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
684 | module.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
685 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
686 | module.register_parameter("bias" , torch::ones({20}), false); |
687 | module.define(R"( |
688 | def fn(self, x:float=1.0): |
689 | return x |
690 | |
691 | def forward(self, input): |
692 | x1 = torch.zeros(2, 2) |
693 | x2 = torch.empty_like(torch.empty(2, 2)) |
694 | x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
695 | # Add torch.add operator to cover bytecode version bump from 6 to 7 |
696 | # for bytecode version 7, the main change is to support defaults arguments with out arguments |
697 | x = 2 * torch.ones(1) |
698 | h = torch.ones(1) |
699 | torch.add(x, h, out=x) |
700 | device = torch.ones(1, 1).cpu().device.type |
701 | is_cuda = x1.is_cuda |
702 | bool_val = True |
703 | check_is = [] is None |
704 | check_is_not = [1] is not None |
705 | check_not = not bool_val |
706 | num_to_tensor = torch.tensor([self.fn()]) |
707 | d = {"a": "abc"} |
708 | check_dict_index = d["a"] |
709 | check_dim = x1.dim() |
710 | return ( |
711 | x1, x2, x3, x, device, is_cuda, check_is, |
712 | check_is_not, num_to_tensor, check_dict_index, |
713 | check_dim, check_not |
714 | ) |
715 | )" ); |
716 | |
717 | torch::jit::Module module_freeze = freeze(module); |
718 | |
719 | std::stringstream input_model_stream; |
720 | module_freeze._save_for_mobile( |
721 | input_model_stream, |
722 | /*extra_files=*/{}, |
723 | /*save_mobile_debug_info=*/false, |
724 | /*use_flatbuffer=*/true); |
725 | std::vector<IValue> input_data = |
726 | std::vector<IValue>({torch::ones({1, 1, 28, 28})}); |
727 | std::vector<IValue> expect_result_list; |
728 | expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0); |
729 | expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float)); |
730 | expect_result_list.emplace_back( |
731 | at::ones({1, 20, 24, 24}, ScalarType::Float) * 26); |
732 | expect_result_list.emplace_back(3 * at::ones({1})); |
733 | // "cpu" False, False, True, tensor(1), "abc", 2, False) |
734 | expect_result_list.emplace_back(c10::IValue("cpu" )); |
735 | expect_result_list.emplace_back(c10::IValue(false)); |
736 | expect_result_list.emplace_back(c10::IValue(false)); |
737 | expect_result_list.emplace_back(c10::IValue(true)); |
738 | expect_result_list.emplace_back(c10::IValue(at::ones({1}))); |
739 | expect_result_list.emplace_back(c10::IValue("abc" )); |
740 | expect_result_list.emplace_back(c10::IValue(2)); |
741 | expect_result_list.emplace_back(c10::IValue(false)); |
742 | |
743 | backportAllVersionCheck( |
744 | input_model_stream, |
745 | input_data, |
746 | expect_result_list, |
747 | 9); // flatbuffer starts at 9 |
748 | } |
749 | #endif // !defined(FB_XPLAT_BUILD) |
750 | |
751 | TEST(LiteInterpreterTest, GetRuntimeOpsAndInfo) { |
752 | auto runtime_ops = _get_runtime_ops_and_info(); |
753 | // Ballpark estimate of the minimal number of ops; just used to |
754 | // verify API returns a reasonably large number. |
755 | AT_ASSERT(runtime_ops.size() > 2900); |
756 | } |
757 | |
758 | TEST(LiteInterpreterTest, isCompatibleSuccess) { |
759 | // test trivial success case |
760 | auto runtime_info = RuntimeCompatibilityInfo::get(); |
761 | std::unordered_map<std::string, OperatorInfo> model_ops; |
762 | model_ops["aten::add.Scalar" ] = OperatorInfo{2}; |
763 | |
764 | std::unordered_set<std::string> types = {"List" , "int" , "NamedTuple" }; |
765 | auto model_info = ModelCompatibilityInfo{ |
766 | caffe2::serialize::kMaxSupportedBytecodeVersion, |
767 | model_ops, |
768 | types, |
769 | _get_runtime_bytecode_min_max_versions().first}; |
770 | |
771 | AT_ASSERT( |
772 | is_compatible(runtime_info, model_info).status == |
773 | ModelCompatibilityStatus::OK); |
774 | } |
775 | |
776 | TEST(LiteInterpreterTest, isCompatibleFail) { |
777 | // test trivial failure due to ops |
778 | std::unordered_map<std::string, OperatorInfo> model_ops; |
779 | model_ops["aten::add.Scalar" ] = OperatorInfo{2}; |
780 | auto model_info = ModelCompatibilityInfo{ |
781 | caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops}; |
782 | std::unordered_map<std::string, OperatorInfo> runtime_ops; |
783 | runtime_ops["aten::add.Int" ] = OperatorInfo{2}; |
784 | auto runtime_info = RuntimeCompatibilityInfo{ |
785 | std::pair<uint64_t, uint64_t>( |
786 | caffe2::serialize::kMinSupportedBytecodeVersion, |
787 | caffe2::serialize::kMaxSupportedBytecodeVersion), |
788 | runtime_ops, |
789 | _get_mobile_supported_types()}; |
790 | |
791 | auto result = is_compatible(runtime_info, model_info); |
792 | AT_ASSERT(result.status = ModelCompatibilityStatus::ERROR); |
793 | AT_ASSERT( |
794 | result.errors[0] == |
795 | "Operator 'aten::add.Scalar' missing from runtime (not found)" ); |
796 | |
797 | // test trivial failure due to bytecode greater than max supported bytecode |
798 | // version |
799 | runtime_ops["aten::add.Scalar" ] = OperatorInfo{2}; |
800 | runtime_info = RuntimeCompatibilityInfo{ |
801 | std::pair<uint64_t, uint64_t>( |
802 | caffe2::serialize::kMinSupportedBytecodeVersion, |
803 | caffe2::serialize::kMaxSupportedBytecodeVersion), |
804 | runtime_ops, |
805 | _get_mobile_supported_types()}; |
806 | model_info.bytecode_version = |
807 | caffe2::serialize::kMaxSupportedBytecodeVersion + 1; |
808 | |
809 | result = is_compatible(runtime_info, model_info); |
810 | AT_ASSERT(result.status = ModelCompatibilityStatus::ERROR); |
811 | |
812 | // test trivial failure due to bytecode less than min supported bytecode |
813 | // version |
814 | runtime_ops["aten::add.Scalar" ] = OperatorInfo{2}; |
815 | runtime_info = RuntimeCompatibilityInfo{ |
816 | std::pair<uint64_t, uint64_t>( |
817 | caffe2::serialize::kMinSupportedBytecodeVersion, |
818 | caffe2::serialize::kMaxSupportedBytecodeVersion), |
819 | runtime_ops, |
820 | _get_mobile_supported_types()}; |
821 | model_info.bytecode_version = |
822 | caffe2::serialize::kMinSupportedBytecodeVersion - 1; |
823 | |
824 | result = is_compatible(runtime_info, model_info); |
825 | AT_ASSERT(result.status = ModelCompatibilityStatus::ERROR); |
826 | |
827 | // test trivial failure due to type |
828 | runtime_info = RuntimeCompatibilityInfo::get(); |
829 | std::unordered_set<std::string> types = {"List" , "int" , "Sequence" }; |
830 | |
831 | model_info = ModelCompatibilityInfo{ |
832 | caffe2::serialize::kMaxSupportedBytecodeVersion, |
833 | model_ops, |
834 | types, |
835 | _get_runtime_bytecode_min_max_versions().first}; |
836 | |
837 | AT_ASSERT( |
838 | is_compatible(runtime_info, model_info).status == |
839 | ModelCompatibilityStatus::ERROR); |
840 | |
841 | // test trivial failure due to operator version |
842 | runtime_info = RuntimeCompatibilityInfo::get(); |
843 | |
844 | model_info = ModelCompatibilityInfo{ |
845 | caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, {}, 0}; |
846 | |
847 | AT_ASSERT( |
848 | is_compatible(runtime_info, model_info).status == |
849 | ModelCompatibilityStatus::ERROR); |
850 | } |
851 | |
852 | TEST(LiteInterpreterTest, Eval) { |
853 | std::vector<torch::jit::IValue> inputs; |
854 | |
855 | Module m("m" ); |
856 | m.define(R"( |
857 | def __init__(self, x): |
858 | self.training = True |
859 | |
860 | def forward(self, input): |
861 | return torch.dropout(input, 1.0, self.training) |
862 | )" ); |
863 | |
864 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
865 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
866 | m.eval(); |
867 | auto outputref = m.forward(inputs).toTensor(); |
868 | |
869 | // save m in training mode to make sure that mobile eval() will correctly |
870 | // change back to eval mode |
871 | m.train(); |
872 | std::stringstream ss; |
873 | m._save_for_mobile(ss); |
874 | mobile::Module bc = _load_for_mobile(ss); |
875 | bc.eval(); |
876 | IValue res; |
877 | for (int i = 0; i < 3; ++i) { |
878 | res = bc.get_method("forward" )(inputs); |
879 | } |
880 | auto output = res.toTensor(); |
881 | AT_ASSERT(outputref.dim() == output.dim()); |
882 | AT_ASSERT( |
883 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
884 | } |
885 | |
886 | TEST(LiteInterpreterTest, FindWrongMethodName) { |
887 | Module m("m" ); |
888 | m.register_parameter("foo" , torch::ones({}), false); |
889 | m.define(R"( |
890 | def add(self, x): |
891 | b = 4 |
892 | return self.foo + x + b |
893 | )" ); |
894 | std::stringstream ss; |
895 | m._save_for_mobile(ss); |
896 | mobile::Module bc = _load_for_mobile(ss); |
897 | ASSERT_TRUE(bc.find_method("forward" ) == c10::nullopt); |
898 | } |
899 | |
900 | TEST(LiteInterpreterTest, FindAndRunMethod) { |
901 | Module m("m" ); |
902 | m.register_parameter("foo" , torch::ones({}), false); |
903 | m.define(R"( |
904 | def add_it(self, x): |
905 | b = 4 |
906 | return self.foo + x + b |
907 | )" ); |
908 | |
909 | std::vector<IValue> inputs; |
910 | auto minput = 5 * torch::ones({}); |
911 | inputs.emplace_back(minput); |
912 | auto ref = m.get_method("add_it" )(inputs); |
913 | |
914 | std::stringstream ss; |
915 | m._save_for_mobile(ss); |
916 | mobile::Module bc = _load_for_mobile(ss); |
917 | IValue res; |
918 | for (int i = 0; i < 3; ++i) { |
919 | auto bcinputs = inputs; |
920 | auto method = bc.find_method("add_it" ); |
921 | AT_ASSERT(method != c10::nullopt); |
922 | res = (*method)(std::move(bcinputs)); |
923 | } |
924 | |
925 | auto resd = res.toTensor().item<float>(); |
926 | auto refd = ref.toTensor().item<float>(); |
927 | AT_ASSERT(resd == refd); |
928 | } |
929 | |
930 | TEST(LiteInterpreterTest, RunMethodVariadic) { |
931 | Module m("m" ); |
932 | m.register_parameter("foo" , torch::ones({}), false); |
933 | m.define(R"( |
934 | def add_three(self, x, y): |
935 | return self.foo + x + y |
936 | )" ); |
937 | |
938 | std::vector<IValue> inputs; |
939 | auto inputx = 5 * torch::ones({}); |
940 | auto inputy = 4 * torch::ones({}); |
941 | auto ref = m.run_method("add_three" , inputx, inputy); |
942 | |
943 | std::stringstream ss; |
944 | m._save_for_mobile(ss); |
945 | mobile::Module bc = _load_for_mobile(ss); |
946 | IValue res = bc.run_method("add_three" , inputx, inputy); |
947 | |
948 | auto resd = res.toTensor().item<float>(); |
949 | auto refd = ref.toTensor().item<float>(); |
950 | AT_ASSERT(resd == refd); |
951 | } |
952 | |
953 | TEST(LiteInterpreterTest, DuplicateSetState) { |
954 | Module m("M" ); |
955 | m.register_parameter("foo" , torch::ones({}), false); |
956 | m.define(R"( |
957 | def __getstate__(self): |
958 | return self.foo + self.foo |
959 | def __setstate__(self, a): |
960 | self.foo = a |
961 | def forward(self, x): |
962 | b = 4 |
963 | return self.foo + x + b |
964 | )" ); |
965 | |
966 | Module b("B" ); |
967 | b.register_module("M0" , m); |
968 | b.register_module("M1" , m); |
969 | b.define(R"( |
970 | def forward(self, x): |
971 | return self.M0.forward(x) + self.M1.forward(x) |
972 | )" ); |
973 | |
974 | std::stringstream ss; |
975 | m._save_for_mobile(ss); |
976 | mobile::Module bc = _load_for_mobile(ss); |
977 | const auto methods = bc.get_methods(); |
978 | const size_t expected_n = 3; |
979 | ASSERT_EQ(methods.size(), expected_n); |
980 | } |
981 | |
982 | TEST(LiteInterpreterTest, ExtraFiles) { |
983 | const auto script = R"JIT( |
984 | def forward(self): |
985 | x = torch.rand(5, 5) |
986 | x = x.mm(x) |
987 | return x |
988 | )JIT" ; |
989 | |
990 | auto module = |
991 | std::make_shared<Module>("Module" , std::make_shared<CompilationUnit>()); |
992 | module->define(script); |
993 | std::ostringstream oss; |
994 | std::unordered_map<std::string, std::string> ; |
995 | extra_files["metadata.json" ] = "abc" ; |
996 | extra_files["mobile_info.json" ] = "{\"key\": 23}" ; |
997 | module->_save_for_mobile(oss, extra_files); |
998 | |
999 | std::istringstream iss(oss.str()); |
1000 | std::unordered_map<std::string, std::string> ; |
1001 | loaded_extra_files["metadata.json" ] = "" ; |
1002 | torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files); |
1003 | ASSERT_EQ(loaded_extra_files["metadata.json" ], "abc" ); |
1004 | |
1005 | loaded_extra_files.clear(); |
1006 | std::vector<std::string> all_files = |
1007 | caffe2::serialize::PyTorchStreamReader(&iss).getAllRecords(); |
1008 | |
1009 | for (auto& file_name : all_files) { |
1010 | if (file_name.find("extra/" ) == 0) { |
1011 | loaded_extra_files[file_name.substr(6)] = "" ; |
1012 | } |
1013 | } |
1014 | iss.seekg(0, iss.beg); |
1015 | torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files); |
1016 | ASSERT_EQ(loaded_extra_files["metadata.json" ], "abc" ); |
1017 | ASSERT_EQ(loaded_extra_files["mobile_info.json" ], "{\"key\": 23}" ); |
1018 | } |
1019 | |
1020 | TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) { |
1021 | torch::jit::Module m("m" ); |
1022 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
1023 | m.register_parameter("bias" , torch::ones({20}), false); |
1024 | m.define(R"( |
1025 | def forward(self, input): |
1026 | x1 = torch.zeros(2, 2) |
1027 | x2 = torch.empty_like(torch.empty(2, 2)) |
1028 | x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
1029 | return (x1, x2, x3) |
1030 | )" ); |
1031 | m.eval(); |
1032 | |
1033 | std::stringstream ss; |
1034 | m._save_for_mobile(ss); |
1035 | |
1036 | torch::jit::mobile::Module ptl_model = torch::jit::_load_for_mobile(ss); |
1037 | std::set<std::string> operator_names = |
1038 | torch::jit::mobile::_export_operator_list(ptl_model); |
1039 | std::set<std::string> expected_operator_names = { |
1040 | "aten::_convolution" , |
1041 | "aten::empty.memory_format" , |
1042 | "aten::empty_like" , |
1043 | "aten::zeros" , |
1044 | }; |
1045 | EXPECT_EQ(operator_names, expected_operator_names) |
1046 | << "Expected the root operator lists to be the same" ; |
1047 | } |
1048 | |
1049 | TEST(LiteInterpreterTest, DefaultArgsConv) { |
1050 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
1051 | if (s && strcmp(s, "1" ) == 0) |
1052 | return; |
1053 | |
1054 | std::vector<torch::jit::IValue> inputs; |
1055 | |
1056 | Module m("m" ); |
1057 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
1058 | m.register_parameter("bias" , torch::ones({20}), false); |
1059 | m.define(R"( |
1060 | def forward(self, input): |
1061 | return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1) |
1062 | )" ); |
1063 | |
1064 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
1065 | |
1066 | auto outputref = m.forward(inputs).toTensor(); |
1067 | |
1068 | std::stringstream ss; |
1069 | m._save_for_mobile(ss); |
1070 | mobile::Module bc = _load_for_mobile(ss); |
1071 | IValue res; |
1072 | for (int i = 0; i < 1; ++i) { |
1073 | res = bc.get_method("forward" )(inputs); |
1074 | } |
1075 | auto output = res.toTensor(); |
1076 | AT_ASSERT(outputref.dim() == output.dim()); |
1077 | AT_ASSERT(output.equal(outputref)); |
1078 | } |
1079 | |
TEST(RunTimeTest, ParseBytecode) {
  // A simple example to show a simple bytecode that can be used independent of
  // PyTorch TorchScript serialization (unpickler, etc) and operator library.
  // It has basic control flow (if, else) and basic data orchestration (list
  // construction). The original PyTorch program:

  // class Module(torch.nn.Module):
  //
  //   def __init__(self):
  //     super().__init__()
  //
  //   def forward(self, x: int, h: int, xfirst: bool):
  //     if xfirst:
  //       return [x, h]
  //     else:
  //       return [h, x]

  // 1. Prepare for the bytecode. In reality it can be from a customized
  // deserializer.
  // Each entry is an (opcode, X, N) tuple. The JF/JMP pair implements the
  // if/else. NOTE(review): the two LIST_CONSTRUCT entries use X = 0 and
  // X = 1, which presumably index into `types` below. Confirm against the
  // interpreter.
  std::vector<IValue> instructions{
      to_tuple({"STOREN", 1, 4}),
      to_tuple({"DROPR", 1, 0}),
      to_tuple({"MOVE", 4, 0}),
      to_tuple({"JF", 5, 0}),
      to_tuple({"LOAD", 2, 0}),
      to_tuple({"LOAD", 3, 0}),
      to_tuple({"LIST_CONSTRUCT", 0, 2}),
      to_tuple({"JMP", 4, 0}),
      to_tuple({"LOAD", 3, 0}),
      to_tuple({"LOAD", 2, 0}),
      to_tuple({"LIST_CONSTRUCT", 1, 2}),
      to_tuple({"STORE", 5, 0}),
      to_tuple({"DROPR", 3, 0}),
      to_tuple({"DROPR", 2, 0}),
      to_tuple({"MOVE", 5, 0}),
      to_tuple({"RET", 0, 0}),
  };
  // Neither vector is passed to a parse function below; they only document
  // that this program needs no operators or constants.
  std::vector<IValue> operators; // empty for this example
  std::vector<IValue> constants; // empty for this example

  // One type string per LIST_CONSTRUCT instruction.
  std::vector<IValue> types{"List[int]", "List[int]"};
  // 2. Parse the function
  std::string function_name("test_function");
  auto function = std::unique_ptr<mobile::Function>(
      new mobile::Function(c10::QualifiedName(function_name)));
  // Empty debug-handle tuple: this hand-written bytecode carries no debug info.
  c10::ivalue::TupleElements debug_handles_m_tuple;
  parseInstructions(
      function_name,
      std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
      debug_handles_m_tuple,
      function.get());
  parseTypes(c10::ivalue::Tuple::create(types)->elements(), function.get());
  // Number of registers the interpreter frame needs for this function.
  const size_t rsize = 5;
  parseRegisterSize(rsize, function.get());

  // 3. Prepare for inputs and run the function
  // Note that the first input is reserved for Module object.
  // Since this is a function test and Module object is not required,
  // a dummy IValue (0) is added here.
  // The returned list is read back from inputs[0] after run().
  std::vector<IValue> inputs{0, 1, 2, true};
  function->run(inputs);
  auto output = inputs[0].toList();
  ASSERT_EQ(output[0], 1);
  ASSERT_EQ(output[1], 2);

  // Same function with xfirst == false: the list order is swapped.
  std::vector<IValue> inputs1{0, 1, 2, false};
  function->run(inputs1);
  auto output1 = inputs1[0].toList();
  ASSERT_EQ(output1[0], 2);
  ASSERT_EQ(output1[1], 1);
}
1151 | |
TEST(RunTimeTest, ParseOperator) {
  // A simple example to show a simple bytecode that can be used independent of
  // PyTorch TorchScript serialization (unpickler, etc) and operator library.
  // It has one operator and we should be able to register it. The original
  // PyTorch program:

  // class Add(torch.nn.Module):
  //   def __init__(self):
  //     super(Add, self).__init__()

  //   def forward(self, a, b):
  //     return a + b

  // 1. Prepare for the bytecode. In reality it can be from a customized
  // deserializer.
  // Each instruction is an (opcode, X, N) tuple. `OP 0` presumably invokes
  // entry 0 of the operator table below.
  std::vector<IValue> instructions{
      to_tuple({"STOREN", 1, 3}),
      to_tuple({"DROPR", 1, 0}),
      to_tuple({"MOVE", 2, 0}),
      to_tuple({"MOVE", 3, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"RET", 0, 0}),
  };
  // Operator table entry: (name, overload name, N). NOTE(review): the trailing
  // 2 is presumably the number of explicitly specified arguments, with the
  // remaining ones taken from the schema defaults. Confirm against
  // parse_operators.h.
  std::vector<IValue> operators{
      to_tuple({"aten::add", "Tensor", 2}),
  };
  // NOTE(review): `constants` is never passed to parseConstants and no LOADC
  // instruction reads it, so it is unused by this test.
  std::vector<IValue> constants{
      to_tuple({1}),
  };
  // 2. Parse the function
  std::string function_name("test_function");
  auto function = std::unique_ptr<mobile::Function>(
      new mobile::Function(c10::QualifiedName(function_name)));
  // Empty debug-handle tuple: this hand-written bytecode carries no debug info.
  c10::ivalue::TupleElements debug_handles_m_tuple;
  parseInstructions(
      function_name,
      std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
      debug_handles_m_tuple,
      function.get());
  // NOTE(review): the literal 1 is presumably the model version used when
  // resolving operators. Confirm against parse_operators.h.
  parseOperators(
      std::move(*c10::ivalue::Tuple::create(operators)).elements(),
      1,
      function.get());
  const size_t rsize = 5;
  parseRegisterSize(rsize, function.get());

  // 3. Prepare for inputs and run the function
  // Note that the first input is reserved for Module object.
  // Since this is a function test and Module object is not required,
  // a dummy IValue (0) is added here.
  // The sum 1 + 2 is read back from inputs[0] after run().
  std::vector<IValue> inputs{0, at::tensor(1), at::tensor(2)};
  function->run(inputs);
  auto output = inputs[0];
  ASSERT_EQ(output, at::tensor(3));
}
1207 | |
1208 | namespace { |
1209 | void testLiteModuleCompareResultTensors( |
1210 | Module& m, |
1211 | const std::vector<torch::jit::IValue>& inputs, |
1212 | const std::string& method_name = "forward" ) { |
1213 | auto outputref = m.get_method(method_name)(inputs).toTensor(); |
1214 | |
1215 | std::stringstream ss; |
1216 | m._save_for_mobile(ss); |
1217 | mobile::Module bc = _load_for_mobile(ss); |
1218 | IValue res; |
1219 | for (int i = 0; i < 3; ++i) { |
1220 | res = bc.get_method(method_name)(inputs); |
1221 | } |
1222 | auto output = res.toTensor(); |
1223 | AT_ASSERT(outputref.dim() == output.dim()); |
1224 | AT_ASSERT(output.equal(outputref)); |
1225 | } |
1226 | |
1227 | void testDefaultArgsPinv(int num_args) { |
1228 | Module m("m" ); |
1229 | if (num_args == 1) { |
1230 | m.define(R"( |
1231 | def forward(self, input): |
1232 | return torch.linalg_pinv(input) |
1233 | )" ); |
1234 | } else if (num_args == 2) { |
1235 | m.define(R"( |
1236 | def forward(self, input): |
1237 | return torch.linalg_pinv(input, 1e-5) |
1238 | )" ); |
1239 | } else if (num_args == 3) { |
1240 | m.define(R"( |
1241 | def forward(self, input): |
1242 | return torch.linalg_pinv(input, 1e-5, True) |
1243 | )" ); |
1244 | } |
1245 | |
1246 | std::vector<torch::jit::IValue> inputs; |
1247 | const int N = 28; |
1248 | auto input = torch::range(1, N * N, 1); |
1249 | input[0] = 1; // a more stable matrix |
1250 | input = input.view({N, N}); |
1251 | inputs.push_back(input); |
1252 | testLiteModuleCompareResultTensors(m, inputs); |
1253 | } |
1254 | } // namespace |
1255 | |
1256 | #if !defined FB_XPLAT_BUILD |
1257 | TEST(LiteInterpreterTest, DefaultArgsPinv) { |
1258 | // Test with different number of specified arguments. |
1259 | // Arguments not specified take default value. |
1260 | for (int num_args = 1; num_args <= 3; ++num_args) { |
1261 | testDefaultArgsPinv(num_args); |
1262 | } |
1263 | |
1264 | // bytecode with one specified argument: |
1265 | // (6, |
1266 | // ('__torch__.m.forward', |
1267 | // (('instructions', |
1268 | // (('STOREN', 1, 2), |
1269 | // ('DROPR', 1, 0), |
1270 | // ('MOVE', 2, 0), |
1271 | // ('OP', 0, 0), |
1272 | // ('RET', 0, 0))), |
1273 | // ('operators', (('aten::linalg_pinv', '', 1),)), |
1274 | // ('constants', (False, 1e-15)), # default constants are not |
1275 | // used |
1276 | // ('types', ()), |
1277 | // ('register_size', 2)), |
1278 | // (('arguments', |
1279 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
1280 | // None)), |
1281 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
1282 | // None)))), |
1283 | // ('returns', |
1284 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
1285 | // None)),))))) |
1286 | |
1287 | // bytecode with 2 specified argument: |
1288 | // (6, |
1289 | // ('__torch__.m.forward', |
1290 | // (('instructions', |
1291 | // (('STOREN', 1, 2), |
1292 | // ('DROPR', 1, 0), |
1293 | // ('MOVE', 2, 0), |
1294 | // ('LOADC', 1, 0), # added LOADC for specified argument |
1295 | // ('OP', 0, 0), |
1296 | // ('RET', 0, 0))), |
1297 | // ('operators', (('aten::linalg_pinv', '', 2),)), |
1298 | // ('constants', (False, 1e-05)), # updated constant table |
1299 | // ('types', ()), |
1300 | // ('register_size', 2)), |
1301 | // (('arguments', |
1302 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
1303 | // None)), |
1304 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
1305 | // None)))), |
1306 | // ('returns', |
1307 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
1308 | // None)),))))) |
1309 | |
1310 | // bytecode with 3 specified arguments: |
1311 | // (6, |
1312 | // ('__torch__.m.forward', |
1313 | // (('instructions', |
1314 | // (('STOREN', 1, 2), |
1315 | // ('DROPR', 1, 0), |
1316 | // ('MOVE', 2, 0), |
1317 | // ('LOADC', 1, 0), |
1318 | // ('LOADC', 0, 0), |
1319 | // ('OP', 0, 0), |
1320 | // ('RET', 0, 0))), |
1321 | // ('operators', (('aten::linalg_pinv', '', 3),)), |
1322 | // ('constants', (True, 1e-05)), |
1323 | // ('types', ()), |
1324 | // ('register_size', 2)), |
1325 | // (('arguments', |
1326 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
1327 | // None)), |
1328 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
1329 | // None)))), |
1330 | // ('returns', |
1331 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
1332 | // None)),))))) |
1333 | } |
1334 | |
1335 | TEST(LiteInterpreterTest, DefaultArgsTensorinvSpecifyDefault) { |
1336 | // The second argument is specified, but the value is the same as the default |
1337 | // value. It's treated as "not specified" since the value can be fetched from |
1338 | // schema. |
1339 | Module m("m" ); |
1340 | m.define(R"( |
1341 | def forward(self, input): |
1342 | return torch.linalg_tensorinv(input, 2) |
1343 | )" ); |
1344 | torch::jit::MobileCode code(m.get_method("forward" ).graph(), "forward" ); |
1345 | auto arg_nums = code.op_to_num_specified_args(); |
1346 | ASSERT_EQ(arg_nums.size(), 1); |
1347 | ASSERT_EQ(arg_nums["aten::linalg_tensorinv" ], 1); |
1348 | std::vector<torch::jit::IValue> inputs; |
1349 | const int N = 4; |
1350 | auto input = torch::rand({N, N, N, N}); |
1351 | inputs.push_back(input); |
1352 | testLiteModuleCompareResultTensors(m, inputs); |
1353 | } |
1354 | |
1355 | void testDefaultArgsPinvWithOutArg(int num_args) { |
1356 | Module m("m" ); |
1357 | if (num_args == 1) { |
1358 | m.define(R"( |
1359 | def forward(self, input): |
1360 | return torch.linalg_pinv(input, out=input) |
1361 | )" ); |
1362 | } else if (num_args == 2) { |
1363 | m.define(R"( |
1364 | def forward(self, input): |
1365 | return torch.linalg_pinv(input, 1e-5, out=input) |
1366 | )" ); |
1367 | } else if (num_args == 3) { |
1368 | m.define(R"( |
1369 | def forward(self, input): |
1370 | return torch.linalg_pinv(input, 1e-5, True, out=input) |
1371 | )" ); |
1372 | } |
1373 | |
1374 | const int N = 28; |
1375 | auto input = torch::range(1, N * N, 1); |
1376 | input[0] = 10000; // a more stable matrix |
1377 | input = input.view({N, N}); |
1378 | auto ref = m.run_method("forward" , input); |
1379 | TORCH_CHECK(!input.equal(torch::range(1, N * N, 1))); |
1380 | TORCH_CHECK(input.equal(ref.toTensor())); |
1381 | } |
1382 | |
1383 | TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) { |
1384 | // Test with different number of specified arguments + out arg. |
1385 | // Arguments not specified take default value. |
1386 | for (int num_args = 1; num_args <= 3; ++num_args) { |
1387 | testDefaultArgsPinvWithOutArg(num_args); |
1388 | } |
1389 | } |
1390 | |
1391 | TEST(LiteInterpreterTest, DefaultArgsWithOutArg) { |
1392 | Module m("m" ); |
1393 | m.define(R"( |
1394 | def forward(self, x, h): |
1395 | torch.add(x, h, out=x) |
1396 | )" ); |
1397 | |
1398 | std::vector<IValue> inputs; |
1399 | auto input_x = 2 * torch::ones({}); |
1400 | auto input_h = torch::ones({}); |
1401 | auto ref = m.run_method("forward" , input_x, input_h); |
1402 | |
1403 | std::stringstream ss; |
1404 | |
1405 | m._save_for_mobile(ss, {}, true); |
1406 | mobile::Module bc = _load_for_mobile(ss); |
1407 | bc.run_method("forward" , input_x, input_h); |
1408 | AT_ASSERT(input_x.equal(4 * torch::ones({}))); |
1409 | |
1410 | auto ops = _get_model_ops_and_info(ss); |
1411 | auto op = ops.find("aten::add.out" ); |
1412 | TORCH_CHECK( |
1413 | op != ops.end() && op->second.num_schema_args.has_value() && |
1414 | op->second.num_schema_args.value() == 3); |
1415 | } |
1416 | |
// Runs a three-level call chain (C.forward -> B.foo -> A.bar) on the mobile
// interpreter with inputs of shapes {2, 4} and {13, 9}, so the aten::add in
// A.bar fails. It checks that the error message contains the module hierarchy
// and a TorchScript traceback for every level.
TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def bar(self, x, y):
      return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  std::stringstream ss;
  // NOTE(review): the trailing `true` presumably enables saving mobile debug
  // info, which the module-hierarchy text below depends on. Confirm against
  // Module::_save_for_mobile.
  c._save_for_mobile(ss, ExtraFilesMap(), true);
  auto lite_m = _load_for_mobile(ss);
  // The expected message is whitespace-sensitive (presumably matched as a
  // substring of the thrown error), so leave the literal exactly as-is.
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in foo

    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in bar

    def bar(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}
1466 | #endif // !defined(FB_XPLAT_BUILD) |
1467 | |
1468 | namespace { |
1469 | static auto reg = |
1470 | torch::class_<TorchBindLiteInterpreterTestStruct>( |
1471 | "_TorchScriptTesting" , |
1472 | "_LiteInterpreterTest" ) |
1473 | .def(torch::init<>()) |
1474 | .def("get" , &TorchBindLiteInterpreterTestStruct::get) |
1475 | .def_pickle( |
1476 | // __getattr__ |
1477 | [](const c10::intrusive_ptr<TorchBindLiteInterpreterTestStruct>& |
1478 | self) -> int64_t { return 0; }, |
1479 | // __setattr__ |
1480 | [](int64_t state) { |
1481 | return c10::make_intrusive<TorchBindLiteInterpreterTestStruct>(); |
1482 | }); |
1483 | |
1484 | } // namespace |
1485 | |
1486 | TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) { |
1487 | // Create 3 methods: |
1488 | // |
1489 | // 1. forward() returns a tensor with dtype=torch.int64 (4) |
1490 | // 2. forward2() returns a tensor with dtype=torch.float32 (6) |
1491 | // 3. forward3() returns a tensor with dtype=torch.float32 but |
1492 | // the dtype is inferred by the input tensor's dtype |
1493 | // |
1494 | // If caching works correctly, then the result from the full-jit |
1495 | // module and the lite module will be the same. Otherwise, it |
1496 | // will be different if we don't correctly ignore the cache |
1497 | // entry for an operator that has a different number of |
1498 | // arguments. |
1499 | Module m("m" ); |
1500 | m.define(R"( |
1501 | def forward(self): |
1502 | ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4) |
1503 | return ret1.fill_(25) |
1504 | )" ); |
1505 | m.define(R"( |
1506 | def forward2(self): |
1507 | ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6) |
1508 | return ret1.fill_(32.0) |
1509 | )" ); |
1510 | m.define(R"( |
1511 | def forward3(self): |
1512 | ret1 = torch.new_empty(torch.zeros(10), [10]) |
1513 | return ret1.fill_(12.0) |
1514 | )" ); |
1515 | |
1516 | std::vector<torch::jit::IValue> inputs; |
1517 | testLiteModuleCompareResultTensors(m, inputs, "forward" ); |
1518 | testLiteModuleCompareResultTensors(m, inputs, "forward2" ); |
1519 | testLiteModuleCompareResultTensors(m, inputs, "forward3" ); |
1520 | } |
1521 | |
// Hand-builds two mobile functions where `foo` invokes `call` twice through
// CALL instructions. It checks the result of the whole chain without going
// through TorchScript serialization.
TEST(RunTimeTest, RuntimeCall) {
  //     def call(x):
  //         return x + x
  //
  //     def forward(a):
  //         x = a + call(a)
  //         y = a + call(x)
  //         return y

  // Body of `call`: x + x, with constant 0 (= 1) supplied as the third
  // operand to aten::add.
  std::vector<IValue> instructionsCall{
      to_tuple({"STORE", 1, 0}),
      to_tuple({"LOAD", 1, 0}),
      to_tuple({"MOVE", 1, 0}),
      to_tuple({"LOADC", 0, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"RET", 0, 0}),
  };
  // Body of `foo` (the `forward` above): two CALL/LOADC/OP groups, one per
  // line of the Python source.
  std::vector<IValue> instructionsFoo{
      to_tuple({"STORE", 1, 0}),
      to_tuple({"LOAD", 1, 0}),
      to_tuple({"LOAD", 1, 0}),
      to_tuple({"MOVE", 1, 0}),
      to_tuple({"CALL", 0, 0}),
      to_tuple({"LOADC", 0, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"CALL", 0, 0}),
      to_tuple({"LOADC", 0, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"RET", 0, 0}),
  };
  std::vector<IValue> operatorsFoo{
      to_tuple({"aten::add", "Tensor", 3}),
  };
  std::vector<IValue> constantsFoo{
      1,
  };
  std::vector<IValue> operatorsCall{
      to_tuple({"aten::add", "Tensor", 3}),
  };
  std::vector<IValue> constantsCall{
      1,
  };

  auto foo = std::make_unique<mobile::Function>(c10::QualifiedName("foo"));
  // Empty debug-handle tuple, shared by both functions; this hand-written
  // bytecode carries no debug info.
  c10::ivalue::TupleElements debug_handles_m_tuple;
  parseInstructions(
      "foo",
      std::move(*c10::ivalue::Tuple::create(instructionsFoo)).elements(),
      debug_handles_m_tuple,
      foo.get());
  parseOperators(
      std::move(*c10::ivalue::Tuple::create(operatorsFoo)).elements(),
      1,
      foo.get());
  parseConstants(
      std::move(*c10::ivalue::Tuple::create(constantsFoo)).elements(),
      foo.get());
  const size_t rsize = 5;
  parseRegisterSize(rsize, foo.get());

  auto call = std::make_unique<mobile::Function>(c10::QualifiedName("call"));
  parseInstructions(
      "call",
      std::move(*c10::ivalue::Tuple::create(instructionsCall)).elements(),
      debug_handles_m_tuple,
      call.get());
  parseOperators(
      std::move(*c10::ivalue::Tuple::create(operatorsCall)).elements(),
      1,
      call.get());
  parseConstants(
      std::move(*c10::ivalue::Tuple::create(constantsCall)).elements(),
      call.get());
  parseRegisterSize(rsize, call.get());

  // Presumably makes `call` the target of foo's `CALL 0` instructions. `call`
  // must stay alive while foo runs; both live until the end of this scope.
  foo->append_function(*call);

  // With a = 1: call(a) = 2, x = 3, call(x) = 6, y = 1 + 6 = 7.
  std::vector<IValue> inputs{at::tensor(1)};
  foo->run(inputs);
  auto output = inputs[0];
  ASSERT_EQ(output, at::tensor(7));
}
1604 | |
1605 | TEST(LiteInterpreterTest, OperatorSize1) { |
1606 | Module m("m" ); |
1607 | m.define(R"( |
1608 | def forward(self, input: Tensor, scale:float): |
1609 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
1610 | )" ); |
1611 | |
1612 | std::stringstream ss; |
1613 | m._save_for_mobile(ss); |
1614 | mobile::Module bc = _load_for_mobile(ss); |
1615 | const auto& func = bc.get_method("forward" ).function(); |
1616 | ASSERT_EQ( |
1617 | func.get_code().operator_input_sizes_.size(), |
1618 | func.get_code().operators_.size()); |
1619 | } |
1620 | |
1621 | TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest) |
1622 | const std::vector<std::string> test_programs{ |
1623 | // test invoking a method with default parameter |
1624 | R"( |
1625 | def test_func(self, x, b : int = 4): |
1626 | return self.foo + x + b |
1627 | )" , |
1628 | // inner method call with default parameter (gets inlined) |
1629 | R"( |
1630 | def add_with_default_arg(self, x, b : int = 4): |
1631 | return self.foo + x + b |
1632 | def test_func(self, x): |
1633 | return self.add_with_default_arg(x) # invoke method w/ default arg |
1634 | )" , |
1635 | // simple method call |
1636 | R"( |
1637 | def test_func(self, x): |
1638 | b = 4 |
1639 | return self.foo + x + b |
1640 | )" , |
1641 | }; |
1642 | for (const auto& test_program : test_programs) { |
1643 | Module m("m" ); |
1644 | m.register_parameter("foo" , torch::ones({}), false); |
1645 | m.define(test_program); |
1646 | |
1647 | std::stringstream ss; |
1648 | m._save_for_mobile(ss); |
1649 | mobile::Module bc = _load_for_mobile(ss); |
1650 | const auto& func = bc.get_method("test_func" ).function(); |
1651 | ASSERT_EQ( |
1652 | func.get_code().operator_input_sizes_.size(), |
1653 | func.get_code().operators_.size()); |
1654 | } |
1655 | } |
1656 | |
1657 | #if !defined FB_XPLAT_BUILD |
// The following tests run in fbcode only
1659 | TEST(LiteInterpreterUpgraderTest, DivTensorV2) { |
1660 | std::string filePath(__FILE__); |
1661 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1662 | test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl" ); |
1663 | /* |
1664 | (('__torch__.MyModule.forward', |
1665 | (('instructions', |
1666 | (('STOREN', 1, 3), |
1667 | ('DROPR', 1, 0), |
1668 | ('LOAD', 2, 0), |
1669 | ('LOAD', 3, 0), |
1670 | ('OP', 0, 0), |
1671 | ('LOAD', 2, 0), |
1672 | ('LOAD', 3, 0), |
1673 | ('OP', 1, 0), |
1674 | ('MOVE', 2, 0), |
1675 | ('MOVE', 3, 0), |
1676 | ('OP', 2, 0), |
1677 | ('TUPLE_CONSTRUCT', 3, 0), |
1678 | ('RET', 0, 0))), |
1679 | ('operators', |
1680 | (('aten::div', 'Tensor'), |
1681 | ('aten::div', 'Tensor'), |
1682 | ('aten::div', 'Tensor'))), |
1683 | ('constants', ()), |
1684 | ('types', ()), |
1685 | ('register_size', 3))),) |
1686 | |
1687 | */ |
1688 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1689 | auto intrsuction_list = |
1690 | m_module.get_method("forward" ).function().get_code().instructions_; |
1691 | uint64_t number_of_call_instruction = 0; |
1692 | for (auto& instruction : intrsuction_list) { |
1693 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1694 | } |
1695 | // 3 operators will use upgrader |
1696 | ASSERT_EQ(number_of_call_instruction, 3); |
1697 | |
1698 | std::vector<IValue> inputs = { |
1699 | IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))}; |
1700 | auto actual_output = m_module.forward(inputs); |
1701 | auto expect_output = 2.0 * torch::ones({1}); |
1702 | auto actual_output_list = actual_output.toTuple()->elements(); |
1703 | ASSERT_TRUE(actual_output_list[0].toTensor().equal(expect_output)); |
1704 | } |
1705 | |
1706 | TEST(LiteInterpreterUpgraderTest, DivTensorOutV2) { |
1707 | std::string filePath(__FILE__); |
1708 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1709 | test_model_file.append( |
1710 | "upgrader_models/test_versioned_div_tensor_out_v2.ptl" ); |
1711 | /* |
1712 | (('__torch__.MyModule.forward', |
1713 | (('instructions', |
1714 | (('STOREN', 1, 4), |
1715 | ('DROPR', 1, 0), |
1716 | ('MOVE', 2, 0), |
1717 | ('MOVE', 3, 0), |
1718 | ('MOVE', 4, 0), |
1719 | ('OP', 0, 0), |
1720 | ('RET', 0, 0))), |
1721 | ('operators', (('aten::div', 'out'),)), |
1722 | ('constants', ()), |
1723 | ('types', ()), |
1724 | ('register_size', 4))),) |
1725 | */ |
1726 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1727 | |
1728 | auto intrsuction_list = |
1729 | m_module.get_method("forward" ).function().get_code().instructions_; |
1730 | uint64_t number_of_call_instruction = 0; |
1731 | for (auto& instruction : intrsuction_list) { |
1732 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1733 | } |
1734 | // One operator will use upgrader |
1735 | ASSERT_EQ(number_of_call_instruction, 1); |
1736 | |
1737 | std::vector<IValue> inputs{ |
1738 | IValue(6 * torch::ones({1})), |
1739 | IValue(3 * torch::ones({1})), |
1740 | IValue(torch::empty({1}))}; |
1741 | m_module.forward(inputs); |
1742 | auto expect_output = 2.0 * torch::ones({1}); |
1743 | auto actual_output = inputs[2].toTensor(); |
1744 | // The out argument will be overwritten with the output |
1745 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1746 | } |
1747 | |
1748 | TEST(LiteInterpreterUpgraderTest, DivTensorInplaceV2) { |
1749 | std::string filePath(__FILE__); |
1750 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1751 | test_model_file.append( |
1752 | "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl" ); |
1753 | /* |
1754 | (('__torch__.MyModule.forward', |
1755 | (('instructions', |
1756 | (('STOREN', 1, 3), |
1757 | ('DROPR', 1, 0), |
1758 | ('MOVE', 2, 0), |
1759 | ('MOVE', 3, 0), |
1760 | ('OP', 0, 0), |
1761 | ('RET', 0, 0))), |
1762 | ('operators', (('aten::div_', 'Tensor'),)), |
1763 | ('constants', ()), |
1764 | ('types', ()), |
1765 | ('register_size', 3))),) |
1766 | */ |
1767 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1768 | |
1769 | auto intrsuction_list = |
1770 | m_module.get_method("forward" ).function().get_code().instructions_; |
1771 | uint64_t number_of_call_instruction = 0; |
1772 | for (auto& instruction : intrsuction_list) { |
1773 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1774 | } |
1775 | // One operator will use upgrader |
1776 | ASSERT_EQ(number_of_call_instruction, 1); |
1777 | |
1778 | std::vector<IValue> inputs{ |
1779 | IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))}; |
1780 | m_module.forward(inputs); |
1781 | auto expect_output = 2.0 * torch::ones({1}); |
1782 | auto actual_output = inputs[0].toTensor(); |
1783 | // The out argument will be overwritten with the output |
1784 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1785 | } |
1786 | |
1787 | TEST(LiteInterpreterUpgraderTest, DivScalarFloatV2) { |
1788 | std::string filePath(__FILE__); |
1789 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1790 | test_model_file.append( |
1791 | "upgrader_models/test_versioned_div_scalar_float_v2.ptl" ); |
1792 | /* |
1793 | (('__torch__.MyModuleFloat.forward', |
1794 | (('instructions', |
1795 | (('STOREN', 1, 3), |
1796 | ('DROPR', 1, 0), |
1797 | ('MOVE', 2, 0), |
1798 | ('MOVE', 3, 0), |
1799 | ('OP', 0, 0), |
1800 | ('RET', 0, 0))), |
1801 | ('operators', (('aten::div', 'Scalar'),)), |
1802 | ('constants', ()), |
1803 | ('types', ()), |
1804 | ('register_size', 3))),) |
1805 | */ |
1806 | |
1807 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1808 | |
1809 | auto intrsuction_list = |
1810 | m_module.get_method("forward" ).function().get_code().instructions_; |
1811 | uint64_t number_of_call_instruction = 0; |
1812 | for (auto& instruction : intrsuction_list) { |
1813 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1814 | } |
1815 | // One operator will use upgrader |
1816 | ASSERT_EQ(number_of_call_instruction, 1); |
1817 | |
1818 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1819 | auto output = m_module.forward(inputs); |
1820 | auto expect_output = 2.0 * torch::ones({1}); |
1821 | auto actual_output = output.toTensor(); |
1822 | |
1823 | // The out argument will be overwritten with the output |
1824 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1825 | } |
1826 | |
1827 | TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalFloatV2) { |
1828 | std::string filePath(__FILE__); |
1829 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1830 | test_model_file.append( |
1831 | "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl" ); |
1832 | /* |
1833 | (('__torch__.MyModuleFloat.forward', |
1834 | (('instructions', |
1835 | (('STOREN', 1, 3), |
1836 | ('DROPR', 1, 0), |
1837 | ('MOVE', 2, 0), |
1838 | ('OP', 0, 0), |
1839 | ('MOVE', 3, 0), |
1840 | ('OP', 1, 0), |
1841 | ('RET', 0, 0))), |
1842 | ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))), |
1843 | ('constants', ()), |
1844 | ('types', ()), |
1845 | ('register_size', 3))),) |
1846 | */ |
1847 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1848 | |
1849 | auto intrsuction_list = |
1850 | m_module.get_method("forward" ).function().get_code().instructions_; |
1851 | uint64_t number_of_call_instruction = 0; |
1852 | for (auto& instruction : intrsuction_list) { |
1853 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1854 | } |
1855 | // No operator will use upgrader |
1856 | ASSERT_EQ(number_of_call_instruction, 0); |
1857 | |
1858 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1859 | auto output = m_module.forward(inputs); |
1860 | auto expect_output = 0.5 * torch::ones({1}); |
1861 | auto actual_output = output.toTensor(); |
1862 | std::cout << "expect output: " << expect_output; |
1863 | std::cout << "actual output: " << actual_output; |
1864 | // The out argument will be overwritten with the output |
1865 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1866 | } |
1867 | |
1868 | TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalIntV2) { |
1869 | std::string filePath(__FILE__); |
1870 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1871 | test_model_file.append( |
1872 | "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl" ); |
1873 | /* |
1874 | (('__torch__.MyModuleInt.forward', |
1875 | (('instructions', |
1876 | (('STOREN', 1, 3), |
1877 | ('DROPR', 1, 0), |
1878 | ('MOVE', 2, 0), |
1879 | ('OP', 0, 0), |
1880 | ('MOVE', 3, 0), |
1881 | ('OP', 1, 0), |
1882 | ('RET', 0, 0))), |
1883 | ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))), |
1884 | ('constants', ()), |
1885 | ('types', ()), |
1886 | ('register_size', 3))),) |
1887 | */ |
1888 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1889 | |
1890 | auto intrsuction_list = |
1891 | m_module.get_method("forward" ).function().get_code().instructions_; |
1892 | uint64_t number_of_call_instruction = 0; |
1893 | for (auto& instruction : intrsuction_list) { |
1894 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1895 | } |
1896 | // No operator will use upgrader |
1897 | ASSERT_EQ(number_of_call_instruction, 0); |
1898 | |
1899 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1900 | auto output = m_module.forward(inputs); |
1901 | auto expect_output = 0.5 * torch::ones({1}); |
1902 | auto actual_output = output.toTensor(); |
1903 | |
1904 | // The out argument will be overwritten with the output |
1905 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1906 | } |
1907 | |
1908 | TEST(LiteInterpreterUpgraderTest, DivScalarScalarV2) { |
1909 | std::string filePath(__FILE__); |
1910 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1911 | test_model_file.append( |
1912 | "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl" ); |
1913 | /* |
1914 | (('__torch__.MyModule.forward', |
1915 | (('instructions', |
1916 | (('STOREN', 1, 5), |
1917 | ('DROPR', 1, 0), |
1918 | ('LOAD', 2, 0), |
1919 | ('LOAD', 3, 0), |
1920 | ('OP', 0, 0), |
1921 | ('MOVE', 2, 0), |
1922 | ('LOAD', 4, 0), |
1923 | ('OP', 1, 0), |
1924 | ('LOAD', 3, 0), |
1925 | ('MOVE', 4, 0), |
1926 | ('OP', 2, 0), |
1927 | ('MOVE', 3, 0), |
1928 | ('MOVE', 5, 0), |
1929 | ('OP', 3, 0), |
1930 | ('TUPLE_CONSTRUCT', 4, 0), |
1931 | ('RET', 0, 0))), |
1932 | ('operators', |
1933 | (('aten::div', ''), |
1934 | ('aten::div', 'float'), |
1935 | ('aten::div', ''), |
1936 | ('aten::div', 'int'))), |
1937 | ('constants', ()), |
1938 | ('types', ()), |
1939 | ('register_size', 5))),) |
1940 | */ |
1941 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1942 | auto intrsuction_list = |
1943 | m_module.get_method("forward" ).function().get_code().instructions_; |
1944 | uint64_t number_of_call_instruction = 0; |
1945 | for (auto& instruction : intrsuction_list) { |
1946 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1947 | } |
1948 | // No operator will use upgrader |
1949 | ASSERT_EQ(number_of_call_instruction, 0); |
1950 | |
1951 | std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)}; |
1952 | auto output = m_module.forward(inputs); |
1953 | auto output_list = output.toTupleRef().elements(); |
1954 | auto expect_output = std::vector<IValue>( |
1955 | {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)}); |
1956 | // auto actual_output = output.toTensor(); |
1957 | for (size_t i = 0; i < expect_output.size(); i++) { |
1958 | ASSERT_EQ(output_list[i], expect_output[i]); |
1959 | } |
1960 | } |
1961 | |
1962 | TEST(LiteInterpreterUpgraderTest, DivScalarIntV2) { |
1963 | std::string filePath(__FILE__); |
1964 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1965 | test_model_file.append( |
1966 | "upgrader_models/test_versioned_div_scalar_int_v2.ptl" ); |
1967 | /* |
1968 | (('__torch__.MyModuleInt.forward', |
1969 | (('instructions', |
1970 | (('STOREN', 1, 3), |
1971 | ('DROPR', 1, 0), |
1972 | ('MOVE', 2, 0), |
1973 | ('MOVE', 3, 0), |
1974 | ('OP', 0, 0), |
1975 | ('RET', 0, 0))), |
1976 | ('operators', (('aten::div', 'Scalar'),)), |
1977 | ('constants', ()), |
1978 | ('types', ()), |
1979 | ('register_size', 3))),) |
1980 | */ |
1981 | mobile::Module m_module = _load_for_mobile(test_model_file); |
1982 | |
1983 | auto intrsuction_list = |
1984 | m_module.get_method("forward" ).function().get_code().instructions_; |
1985 | uint64_t number_of_call_instruction = 0; |
1986 | for (auto& instruction : intrsuction_list) { |
1987 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1988 | } |
1989 | // One operator will use upgrader |
1990 | ASSERT_EQ(number_of_call_instruction, 1); |
1991 | |
1992 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)}; |
1993 | auto output = m_module.forward(inputs); |
1994 | auto expect_output = 2.0 * torch::ones({1}); |
1995 | auto actual_output = output.toTensor(); |
1996 | |
1997 | // The out argument will be overwritten with the output |
1998 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1999 | } |
2000 | |
2001 | TEST(LiteInterpreterUpgraderTest, DivScalarInplaceFloatV2) { |
2002 | std::string filePath(__FILE__); |
2003 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
2004 | test_model_file.append( |
2005 | "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl" ); |
2006 | /* |
2007 | (('__torch__.MyModuleFloat.forward', |
2008 | (('instructions', |
2009 | (('STOREN', 1, 3), |
2010 | ('DROPR', 1, 0), |
2011 | ('MOVE', 2, 0), |
2012 | ('MOVE', 3, 0), |
2013 | ('OP', 0, 0), |
2014 | ('RET', 0, 0))), |
2015 | ('operators', (('aten::div_', 'Scalar'),)), |
2016 | ('constants', ()), |
2017 | ('types', ()), |
2018 | ('register_size', 3))),) |
2019 | */ |
2020 | |
2021 | mobile::Module m_module = _load_for_mobile(test_model_file); |
2022 | |
2023 | auto intrsuction_list = |
2024 | m_module.get_method("forward" ).function().get_code().instructions_; |
2025 | uint64_t number_of_call_instruction = 0; |
2026 | for (auto& instruction : intrsuction_list) { |
2027 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
2028 | } |
2029 | // One operator will use upgrader |
2030 | ASSERT_EQ(number_of_call_instruction, 1); |
2031 | |
2032 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
2033 | auto output = m_module.forward(inputs); |
2034 | auto expect_output = 2.0 * torch::ones({1}); |
2035 | auto actual_output = output.toTensor(); |
2036 | |
2037 | // The out argument will be overwritten with the output |
2038 | ASSERT_TRUE(actual_output.equal(expect_output)); |
2039 | } |
2040 | |
2041 | TEST(LiteInterpreterUpgraderTest, DivScalarInplaceIntV2) { |
2042 | std::string filePath(__FILE__); |
2043 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
2044 | test_model_file.append( |
2045 | "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl" ); |
2046 | /* |
2047 | (('__torch__.MyModuleInt.forward', |
2048 | (('instructions', |
2049 | (('STOREN', 1, 3), |
2050 | ('DROPR', 1, 0), |
2051 | ('MOVE', 2, 0), |
2052 | ('MOVE', 3, 0), |
2053 | ('OP', 0, 0), |
2054 | ('RET', 0, 0))), |
2055 | ('operators', (('aten::div_', 'Scalar'),)), |
2056 | ('constants', ()), |
2057 | ('types', ()), |
2058 | ('register_size', 3))),) |
2059 | */ |
2060 | |
2061 | mobile::Module m_module = _load_for_mobile(test_model_file); |
2062 | |
2063 | auto intrsuction_list = |
2064 | m_module.get_method("forward" ).function().get_code().instructions_; |
2065 | uint64_t number_of_call_instruction = 0; |
2066 | for (auto& instruction : intrsuction_list) { |
2067 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
2068 | } |
2069 | // One operator will use upgrader |
2070 | ASSERT_EQ(number_of_call_instruction, 1); |
2071 | |
2072 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)}; |
2073 | auto output = m_module.forward(inputs); |
2074 | auto expect_output = 2.0 * torch::ones({1}); |
2075 | auto actual_output = output.toTensor(); |
2076 | |
2077 | // The out argument will be overwritten with the output |
2078 | ASSERT_TRUE(actual_output.equal(expect_output)); |
2079 | } |
2080 | |
2081 | #endif // !defined(FB_XPLAT_BUILD) |
2082 | |
2083 | TEST(LiteInterpreterUpgraderTest, Upgrader) { |
2084 | std::vector<mobile::Function> upgrader_functions; |
2085 | |
2086 | for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) { |
2087 | byteCodeFunctionWithOperator.function.initialize_operators(true); |
2088 | ASSERT_EQ( |
2089 | byteCodeFunctionWithOperator.function.get_code().operators_.size(), |
2090 | byteCodeFunctionWithOperator.function.get_code().op_names_.size()); |
2091 | if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) { |
2092 | for (const auto& op : byteCodeFunctionWithOperator.operators) { |
2093 | byteCodeFunctionWithOperator.function.append_operator( |
2094 | op.name, op.overload_name, op.num_specified_args); |
2095 | } |
2096 | } |
2097 | upgrader_functions.push_back(byteCodeFunctionWithOperator.function); |
2098 | } |
2099 | |
2100 | ASSERT_EQ(getUpgraderBytecodeList().size(), upgrader_functions.size()); |
2101 | } |
2102 | |
2103 | void enumerateTupleType( |
2104 | size_t depth, |
2105 | std::vector<TypePtr>& current, |
2106 | const std::vector<TypePtr>& candidates, |
2107 | std::vector<TypePtr>& out) { |
2108 | static std::vector<std::string> fieldNames; |
2109 | if (depth > fieldNames.size()) { |
2110 | fieldNames.reserve(depth); |
2111 | for (size_t i = fieldNames.size(); i < depth; i++) { |
2112 | fieldNames.push_back("field" + std::to_string(i)); |
2113 | } |
2114 | } |
2115 | if (depth == 0) { |
2116 | out.push_back(TupleType::create(current)); |
2117 | while (fieldNames.size() > current.size()) { |
2118 | fieldNames.pop_back(); |
2119 | } |
2120 | out.push_back(TupleType::createNamed("NamedTuple" , fieldNames, current)); |
2121 | return; |
2122 | } |
2123 | for (const auto& type : candidates) { |
2124 | if (containsAnyType(type)) { |
2125 | continue; |
2126 | } |
2127 | current.push_back(type); |
2128 | enumerateTupleType(depth - 1, current, candidates, out); |
2129 | current.pop_back(); |
2130 | } |
2131 | } |
2132 | |
2133 | class LiteInterpreterDynamicTypeTestFixture |
2134 | : public ::testing::TestWithParam<size_t> { |
2135 | protected: |
2136 | void SetUp() override { |
2137 | cu = std::make_shared<CompilationUnit>(); |
2138 | std::vector<TypePtr> keyTypes = { |
2139 | AnyType::get(), |
2140 | IntType::get(), |
2141 | BoolType::get(), |
2142 | FloatType::get(), |
2143 | ComplexType::get(), |
2144 | StringType::get(), |
2145 | TensorType::get(), |
2146 | DeviceObjType::get(), |
2147 | }; |
2148 | types = { |
2149 | NoneType::get(), |
2150 | NumberType::get(), |
2151 | ClassType::create("__torch__.TestClass1" , cu), |
2152 | ClassType::create("__torch__.TestClass2" , cu), |
2153 | AnyListType::get(), |
2154 | AnyTupleType::get(), |
2155 | StreamObjType::get(), |
2156 | CapsuleType::get(), |
2157 | GeneratorType::get(), |
2158 | StorageType::get(), |
2159 | VarType::create("t" ), |
2160 | VarType::create("v" ), |
2161 | AnyClassType::get()}; |
2162 | std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types)); |
2163 | auto expandTypes = [&](size_t tupleSize) { |
2164 | std::vector<TypePtr> nested; |
2165 | for (const auto& type : types) { |
2166 | if (!(type == AnyType::get())) { |
2167 | nested.emplace_back(ListType::create(type)); |
2168 | if (!(type == NoneType::get() || |
2169 | type->kind() == OptionalType::Kind)) { |
2170 | nested.emplace_back(OptionalType::create(type)); |
2171 | } |
2172 | } |
2173 | for (const auto& keyType : keyTypes) { |
2174 | nested.emplace_back(DictType::create(keyType, type)); |
2175 | } |
2176 | } |
2177 | std::vector<TypePtr> tmp; |
2178 | enumerateTupleType(tupleSize, tmp, types, nested); |
2179 | std::move( |
2180 | std::begin(nested), std::end(nested), std::back_inserter(types)); |
2181 | }; |
2182 | expandTypes(1); |
2183 | expandTypes(1); |
2184 | } |
2185 | std::shared_ptr<CompilationUnit> cu; |
2186 | std::vector<TypePtr> types; |
2187 | |
2188 | public: |
2189 | static constexpr size_t kNumSplits = 10; |
2190 | }; |
2191 | |
2192 | constexpr size_t LiteInterpreterDynamicTypeTestFixture::kNumSplits; |
2193 | |
2194 | /** |
2195 | * Enumerate all possible JIT types appearing in mobile runtime, and test |
2196 | * whether subtyping relation is preserved after one of the JIT types is |
2197 | * converted to DynamicType. |
2198 | * |
2199 | * We firstly enumerate all "base" types in a vector, and implement |
2200 | * expandTypes() to enumerate container types one "level" up for a given set |
2201 | * of types. We call expandTypes() twice to test types nested less or equal |
2202 | * to two levels. e.g. List[Optional[Tensor]], Optional[Dict[Int, Bool]], etc. |
2203 | */ |
2204 | TEST_P(LiteInterpreterDynamicTypeTestFixture, Conformance) { |
2205 | size_t num = types.size() / LiteInterpreterDynamicTypeTestFixture::kNumSplits; |
2206 | size_t begin = num * GetParam(); |
2207 | size_t end = std::min(types.size(), begin + num); |
2208 | for (const auto& a : types) { |
2209 | auto da = DynamicType::create(*a); |
2210 | for (size_t i = begin; i < end; i++) { |
2211 | const auto& b = types[i]; |
2212 | bool result = a->isSubtypeOf(*b); |
2213 | EXPECT_EQ(result, da->isSubtypeOf(*b)); |
2214 | result = b->isSubtypeOf(*a); |
2215 | EXPECT_EQ(result, b->isSubtypeOf(*da)); |
2216 | } |
2217 | } |
2218 | } |
2219 | |
2220 | INSTANTIATE_TEST_CASE_P( |
2221 | PyTorch, |
2222 | LiteInterpreterDynamicTypeTestFixture, |
2223 | ::testing::Range( |
2224 | static_cast<size_t>(0), |
2225 | LiteInterpreterDynamicTypeTestFixture::kNumSplits)); |
2226 | |
2227 | } // namespace jit |
2228 | } // namespace torch |
2229 | |