1 | #include <test/cpp/jit/test_utils.h> |
2 | |
3 | #include <gtest/gtest.h> |
4 | |
5 | #include <c10/core/TensorOptions.h> |
6 | #include <torch/csrc/autograd/generated/variable_factories.h> |
7 | #include <torch/csrc/jit/api/module.h> |
8 | #include <torch/csrc/jit/frontend/resolver.h> |
9 | #include <torch/csrc/jit/mobile/compatibility/backport.h> |
10 | #include <torch/csrc/jit/mobile/compatibility/backport_manager.h> |
11 | #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h> |
12 | #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h> |
13 | #include <torch/csrc/jit/mobile/flatbuffer_loader.h> |
14 | #include <torch/csrc/jit/mobile/import.h> |
15 | #include <torch/csrc/jit/mobile/interpreter.h> |
16 | #include <torch/csrc/jit/mobile/module.h> |
17 | #include <torch/csrc/jit/mobile/parse_bytecode.h> |
18 | #include <torch/csrc/jit/mobile/parse_operators.h> |
19 | #include <torch/csrc/jit/serialization/export.h> |
20 | #include <torch/csrc/jit/serialization/export_bytecode.h> |
21 | #include <torch/csrc/jit/serialization/flatbuffer_serializer.h> |
22 | #include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h> |
23 | #include <torch/csrc/jit/serialization/import.h> |
24 | #include <torch/custom_class.h> |
25 | #include <torch/torch.h> |
26 | |
27 | #include <caffe2/serialize/versions.h> |
28 | #include <torch/csrc/jit/serialization/import_export_functions.h> |
29 | #include <unordered_set> |
30 | |
31 | #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2) |
32 | #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT |
33 | namespace flatbuffers = flatbuffers_fbsource; |
34 | #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT |
35 | #else |
36 | #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT |
37 | #endif |
38 | // Tests go in torch::jit |
39 | namespace torch { |
40 | namespace jit { |
41 | |
42 | namespace { |
43 | mobile::Module parse_mobile_module( |
44 | void* data, |
45 | size_t size, |
46 | bool should_copy_tensor_memory = false) { |
47 | return parse_and_initialize_mobile_module( |
48 | static_cast<char*>(data), |
49 | size, |
50 | /*device=*/c10::nullopt, |
51 | /*extra_files=*/nullptr, |
52 | should_copy_tensor_memory); |
53 | } |
54 | } // namespace |
55 | |
56 | TEST(FlatbufferTest, UpsampleNearest2d) { |
57 | Module m("m" ); |
58 | m.define(R"( |
59 | def forward(self, input: Tensor, scale:float): |
60 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
61 | )" ); |
62 | |
63 | std::vector<IValue> inputs; |
64 | inputs.emplace_back(torch::rand({1, 3, 128, 128})); |
65 | inputs.emplace_back(at::Scalar(2.0)); |
66 | auto ref = m.forward(inputs); |
67 | |
68 | CompilationOptions options; |
69 | mobile::Module bc = jitModuleToMobile(m, options); |
70 | IValue res; |
71 | res = bc.forward(inputs); |
72 | |
73 | auto resd = res.toTensor(); |
74 | auto refd = ref.toTensor(); |
75 | ASSERT_TRUE(resd.equal(refd)); |
76 | |
77 | auto buff = save_mobile_module_to_bytes(bc); |
78 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
79 | auto res2 = bc2.forward(inputs); |
80 | auto resd2 = res2.toTensor(); |
81 | ASSERT_TRUE(resd2.equal(refd)); |
82 | } |
83 | |
84 | TEST(FlatbufferTest, UpsampleNearest2dWithCopyTensorMemory) { |
85 | Module m("m" ); |
86 | m.define(R"( |
87 | def forward(self, input: Tensor, scale:float): |
88 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
89 | )" ); |
90 | |
91 | std::vector<IValue> inputs; |
92 | inputs.emplace_back(torch::rand({1, 3, 128, 128})); |
93 | inputs.emplace_back(at::Scalar(2.0)); |
94 | auto ref = m.forward(inputs); |
95 | |
96 | CompilationOptions options; |
97 | mobile::Module bc = jitModuleToMobile(m, options); |
98 | IValue res; |
99 | res = bc.forward(inputs); |
100 | |
101 | auto resd = res.toTensor(); |
102 | auto refd = ref.toTensor(); |
103 | ASSERT_TRUE(resd.equal(refd)); |
104 | |
105 | auto buff = save_mobile_module_to_bytes(bc); |
106 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true); |
107 | |
108 | auto res2 = bc2.forward(inputs); |
109 | auto resd2 = res2.toTensor(); |
110 | ASSERT_TRUE(resd2.equal(refd)); |
111 | } |
112 | |
113 | TEST(FlatbufferTest, CheckAttrAccess) { |
114 | Module m("m" ); |
115 | m.register_attribute("mobile_optimized" , BoolType::get(), true); |
116 | |
117 | CompilationOptions options; |
118 | mobile::Module bc = jitModuleToMobile(m, options); |
119 | bool mobile_optimized = bc.attr("mobile_optimized" , false).toBool(); |
120 | |
121 | AT_ASSERT(mobile_optimized); |
122 | m.setattr("mobile_optimized" , false); |
123 | bc = jitModuleToMobile(m, options); |
124 | mobile_optimized = bc.attr("mobile_optimized" , false).toBool(); |
125 | |
126 | AT_ASSERT(!mobile_optimized); |
127 | |
128 | auto buff = save_mobile_module_to_bytes(bc); |
129 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
130 | auto mobile_optimized2 = bc2.attr("mobile_optimized" , false).toBool(); |
131 | AT_ASSERT(!mobile_optimized2); |
132 | } |
133 | |
134 | TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest) |
135 | const std::vector<std::string> test_programs{ |
136 | // test invoking a method with default parameter |
137 | R"( |
138 | def test_func(self, x, b : int = 4): |
139 | return self.foo + x + b |
140 | )" , |
141 | // inner method call with default parameter (gets inlined) |
142 | R"( |
143 | def add_with_default_arg(self, x, b : int = 4): |
144 | return self.foo + x + b |
145 | def test_func(self, x): |
146 | return self.add_with_default_arg(x) # invoke method w/ default arg |
147 | )" , |
148 | // simple method call |
149 | R"( |
150 | def test_func(self, x): |
151 | b = 4 |
152 | return self.foo + x + b |
153 | )" , |
154 | }; |
155 | for (const auto& test_program : test_programs) { |
156 | Module m("m" ); |
157 | m.register_parameter("foo" , torch::ones({}), false); |
158 | m.define(test_program); |
159 | |
160 | const int fortyTwo = 42; // (keep linter happy) |
161 | auto minput = fortyTwo * torch::ones({}); |
162 | auto ref = m.run_method("test_func" , minput); |
163 | |
164 | CompilationOptions options; |
165 | mobile::Module bc = jitModuleToMobile(m, options); |
166 | const auto& test_func = bc.get_method("test_func" ); |
167 | IValue res; |
168 | for (int i = 0; i < 3; ++i) { |
169 | res = test_func({minput}); |
170 | } |
171 | |
172 | auto resd = res.toTensor().item<float>(); |
173 | auto refd = ref.toTensor().item<float>(); |
174 | AT_ASSERT(resd == refd); |
175 | |
176 | auto buff = save_mobile_module_to_bytes(bc); |
177 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
178 | const auto& test_func2 = bc2.get_method("test_func" ); |
179 | IValue res2; |
180 | for (int i = 0; i < 3; ++i) { |
181 | res2 = test_func2({minput}); |
182 | } |
183 | auto resd2 = res2.toTensor().item<float>(); |
184 | AT_ASSERT(resd2 == refd); |
185 | } |
186 | } |
187 | |
188 | #if !defined(FB_XPLAT_BUILD) |
189 | TEST(FlatbufferTest, FlatbufferBackPortTest) { |
190 | Module m("m" ); |
191 | m.define(R"( |
192 | def forward(self, input: Tensor, scale:float): |
193 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
194 | )" ); |
195 | std::stringstream ss; |
196 | m._save_for_mobile(ss, {}, false, true); |
197 | |
198 | std::stringstream oss; |
199 | bool backPortSuccess = _backport_for_mobile(ss, oss, 5); |
200 | ASSERT_TRUE(backPortSuccess); |
201 | } |
202 | #endif // !defined(FB_XPLAT_BUILD) |
203 | |
204 | TEST(FlatbufferTest, ExtraFiles) { |
205 | const auto script = R"JIT( |
206 | def forward(self): |
207 | x = torch.rand(5, 5) |
208 | x = x.mm(x) |
209 | return x |
210 | )JIT" ; |
211 | |
212 | auto module = |
213 | std::make_shared<Module>("Module" , std::make_shared<CompilationUnit>()); |
214 | module->define(script); |
215 | std::ostringstream oss; |
216 | std::unordered_map<std::string, std::string> ; |
217 | extra_files["metadata.json" ] = "abc" ; |
218 | extra_files["mobile_info.json" ] = "{\"key\": 23}" ; |
219 | |
220 | std::unordered_map<std::string, std::string> ; |
221 | std::stringstream ss; |
222 | module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true); |
223 | |
224 | loaded_extra_files["metadata.json" ] = "" ; |
225 | auto mobile_module = _load_for_mobile(ss, c10::nullopt, loaded_extra_files); |
226 | |
227 | ASSERT_EQ(loaded_extra_files["metadata.json" ], "abc" ); |
228 | ASSERT_EQ(loaded_extra_files["mobile_info.json" ], "{\"key\": 23}" ); |
229 | |
230 | // load it twice using the same stream |
231 | auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files); |
232 | |
233 | ASSERT_EQ(loaded_extra_files["metadata.json" ], "abc" ); |
234 | ASSERT_EQ(loaded_extra_files["mobile_info.json" ], "{\"key\": 23}" ); |
235 | } |
236 | |
237 | TEST(FlatbufferTest, Conv) { |
238 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
239 | if (s && strcmp(s, "1" ) == 0) |
240 | return; |
241 | |
242 | std::vector<torch::jit::IValue> inputs; |
243 | |
244 | Module m("m" ); |
245 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
246 | m.register_parameter("bias" , torch::ones({20}), false); |
247 | m.define(R"( |
248 | def forward(self, input): |
249 | return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
250 | )" ); |
251 | |
252 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
253 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
254 | |
255 | auto outputref = m.forward(inputs).toTensor(); |
256 | |
257 | CompilationOptions options; |
258 | mobile::Module bc = jitModuleToMobile(m, options); |
259 | IValue res; |
260 | for (int i = 0; i < 3; ++i) { |
261 | res = bc.get_method("forward" )(inputs); |
262 | } |
263 | auto output = res.toTensor(); |
264 | AT_ASSERT(outputref.dim() == output.dim()); |
265 | AT_ASSERT( |
266 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
267 | |
268 | auto buff = save_mobile_module_to_bytes(bc); |
269 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
270 | for (int i = 0; i < 3; ++i) { |
271 | res = bc2.get_method("forward" )(inputs); |
272 | } |
273 | output = res.toTensor(); |
274 | AT_ASSERT(outputref.dim() == output.dim()); |
275 | AT_ASSERT( |
276 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
277 | } |
278 | |
279 | TEST(FlatbufferTest, ConvWithCopyTensorMemory) { |
280 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
281 | if (s && strcmp(s, "1" ) == 0) |
282 | return; |
283 | |
284 | std::vector<torch::jit::IValue> inputs; |
285 | |
286 | Module m("m" ); |
287 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
288 | m.register_parameter("bias" , torch::ones({20}), false); |
289 | m.define(R"( |
290 | def forward(self, input): |
291 | return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
292 | )" ); |
293 | |
294 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
295 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
296 | |
297 | auto outputref = m.forward(inputs).toTensor(); |
298 | |
299 | CompilationOptions options; |
300 | mobile::Module bc = jitModuleToMobile(m, options); |
301 | IValue res; |
302 | for (int i = 0; i < 3; ++i) { |
303 | res = bc.get_method("forward" )(inputs); |
304 | } |
305 | auto output = res.toTensor(); |
306 | AT_ASSERT(outputref.dim() == output.dim()); |
307 | AT_ASSERT( |
308 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
309 | |
310 | auto buff = save_mobile_module_to_bytes(bc); |
311 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true); |
312 | |
313 | for (int i = 0; i < 3; ++i) { |
314 | res = bc2.get_method("forward" )(inputs); |
315 | } |
316 | output = res.toTensor(); |
317 | AT_ASSERT(outputref.dim() == output.dim()); |
318 | AT_ASSERT( |
319 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
320 | } |
321 | |
322 | TEST(FlatbufferTest, Inline) { |
323 | Module m("m" ); |
324 | m.define(R"JIT( |
325 | def foo1(self, x): |
326 | return x + 1 |
327 | |
328 | def foo2(self, x): |
329 | return self.foo1(x) + 2 |
330 | |
331 | def foo3(self, x): |
332 | return self.foo2(x) + 3 |
333 | )JIT" ); |
334 | CompilationOptions options; |
335 | mobile::Module bc = jitModuleToMobile(m, options); |
336 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
337 | auto output = bc.get_method("foo3" )(inputs); |
338 | AT_ASSERT(output.toTensor().item<float>() == 7.0); |
339 | |
340 | auto buff = save_mobile_module_to_bytes(bc); |
341 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
342 | std::vector<torch::jit::IValue> inputs2({torch::ones({})}); |
343 | output = bc2.get_method("foo3" )(inputs2); |
344 | AT_ASSERT(output.toTensor().item<float>() == 7.0); |
345 | } |
346 | |
347 | TEST(FlatbufferTest, InlineWithCopyTensorMemory) { |
348 | Module m("m" ); |
349 | m.define(R"JIT( |
350 | def foo1(self, x): |
351 | return x + 1 |
352 | |
353 | def foo2(self, x): |
354 | return self.foo1(x) + 2 |
355 | |
356 | def foo3(self, x): |
357 | return self.foo2(x) + 3 |
358 | )JIT" ); |
359 | CompilationOptions options; |
360 | mobile::Module bc = jitModuleToMobile(m, options); |
361 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
362 | auto output = bc.get_method("foo3" )(inputs); |
363 | AT_ASSERT(output.toTensor().item<float>() == 7.0); |
364 | |
365 | auto buff = save_mobile_module_to_bytes(bc); |
366 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true); |
367 | std::vector<torch::jit::IValue> inputs2({torch::ones({})}); |
368 | output = bc2.get_method("foo3" )(inputs2); |
369 | AT_ASSERT(output.toTensor().item<float>() == 7.0); |
370 | } |
371 | |
372 | TEST(FlatbufferTest, Tuple) { |
373 | Module m("m" ); |
374 | m.define(R"JIT( |
375 | def foo(self, x): |
376 | return (1, 2, x + 3) |
377 | |
378 | def forward(self, x): |
379 | tuple = self.foo(x) |
380 | return tuple |
381 | )JIT" ); |
382 | CompilationOptions options; |
383 | mobile::Module bc = jitModuleToMobile(m, options); |
384 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
385 | auto output = bc.get_method("forward" )(inputs); |
386 | AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2); |
387 | |
388 | auto buff = save_mobile_module_to_bytes(bc); |
389 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
390 | output = bc2.get_method("forward" )(inputs); |
391 | AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2); |
392 | } |
393 | |
394 | TEST(FlatbufferTest, Dict) { |
395 | Module m("m" ); |
396 | m.define(R"JIT( |
397 | def foo(self, x): |
398 | return {"result": x + 1} |
399 | |
400 | def forward(self, x): |
401 | d = self.foo(x) |
402 | return d |
403 | )JIT" ); |
404 | CompilationOptions options; |
405 | mobile::Module bc = jitModuleToMobile(m, options); |
406 | std::vector<torch::jit::IValue> inputs({torch::ones({})}); |
407 | auto output = bc.get_method("forward" )(inputs); |
408 | AT_ASSERT(output.toGenericDict().at("result" ).toTensor().item().toInt() == 2); |
409 | |
410 | auto buff = save_mobile_module_to_bytes(bc); |
411 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
412 | output = bc2.get_method("forward" )(inputs); |
413 | AT_ASSERT(output.toGenericDict().at("result" ).toTensor().item().toInt() == 2); |
414 | } |
415 | |
416 | TEST(FlatbufferTest, Prim) { |
417 | Module m("m" ); |
418 | m.define(R"JIT( |
419 | def forward(self, x): |
420 | return int(x) |
421 | )JIT" ); |
422 | |
423 | std::vector<IValue> inputs; |
424 | auto minput = 3.5 * torch::ones({}); |
425 | inputs.emplace_back(minput); |
426 | auto ref = m.run_method("forward" , minput); |
427 | |
428 | CompilationOptions options; |
429 | mobile::Module bc = jitModuleToMobile(m, options); |
430 | IValue res; |
431 | for (int i = 0; i < 3; ++i) { |
432 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
433 | auto bcinputs = inputs; |
434 | res = bc.get_method("forward" )(bcinputs); |
435 | } |
436 | |
437 | auto resi = res.toInt(); |
438 | auto refi = ref.toInt(); |
439 | AT_ASSERT(resi == refi); |
440 | |
441 | auto buff = save_mobile_module_to_bytes(bc); |
442 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
443 | for (int i = 0; i < 3; ++i) { |
444 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
445 | auto bcinputs = inputs; |
446 | res = bc2.get_method("forward" )(bcinputs); |
447 | } |
448 | auto resi2 = res.toInt(); |
449 | AT_ASSERT(resi2 == refi); |
450 | } |
451 | |
452 | TEST(FlatbufferTest, PrimScalar) { |
453 | Module m("m" ); |
454 | m.define(R"JIT( |
455 | def forward(self, x): |
456 | return int(x.item()) |
457 | )JIT" ); |
458 | |
459 | std::vector<IValue> inputs; |
460 | auto minput = 3.5 * torch::ones({}); |
461 | inputs.emplace_back(minput); |
462 | auto ref = m.run_method("forward" , minput); |
463 | |
464 | CompilationOptions options; |
465 | mobile::Module bc = jitModuleToMobile(m, options); |
466 | IValue res; |
467 | for (int i = 0; i < 3; ++i) { |
468 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
469 | auto bcinputs = inputs; |
470 | res = bc.get_method("forward" )(bcinputs); |
471 | } |
472 | |
473 | auto resi = res.toInt(); |
474 | auto refi = ref.toInt(); |
475 | AT_ASSERT(resi == refi); |
476 | |
477 | auto buff = save_mobile_module_to_bytes(bc); |
478 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
479 | for (int i = 0; i < 3; ++i) { |
480 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
481 | auto bcinputs = inputs; |
482 | res = bc2.get_method("forward" )(bcinputs); |
483 | } |
484 | auto resi2 = res.toInt(); |
485 | AT_ASSERT(resi2 == refi); |
486 | } |
487 | |
488 | TEST(FlatbufferTest, WrongMethodName) { |
489 | Module m("m" ); |
490 | m.register_parameter("foo" , torch::ones({}), false); |
491 | m.define(R"( |
492 | def add(self, x): |
493 | b = 4 |
494 | return self.foo + x + b |
495 | )" ); |
496 | CompilationOptions options; |
497 | mobile::Module bc = jitModuleToMobile(m, options); |
498 | std::vector<IValue> inputs; |
499 | auto minput = 5 * torch::ones({}); |
500 | inputs.emplace_back(minput); |
501 | ASSERT_THROWS_WITH_MESSAGE( |
502 | bc.get_method("forward" )(inputs), "is not defined" ); |
503 | |
504 | auto buff = save_mobile_module_to_bytes(bc); |
505 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
506 | ASSERT_THROWS_WITH_MESSAGE( |
507 | bc2.get_method("forward" )(inputs), "is not defined" ); |
508 | } |
509 | |
510 | TEST(FlatbufferTest, SetState) { |
511 | Module m("m" ); |
512 | m.register_parameter("foo" , torch::ones({}), false); |
513 | m.define(R"( |
514 | def __getstate__(self): |
515 | return self.foo |
516 | def __setstate__(self, a): |
517 | self.foo = a |
518 | def forward(self, x): |
519 | b = 4 |
520 | return self.foo + x + b |
521 | )" ); |
522 | |
523 | std::vector<IValue> inputs; |
524 | auto minput = 5 * torch::ones({}); |
525 | inputs.emplace_back(minput); |
526 | |
527 | std::stringstream ms; |
528 | m.save(ms); |
529 | auto loaded_m = load(ms); |
530 | auto ref = loaded_m.run_method("forward" , minput); |
531 | |
532 | CompilationOptions options; |
533 | mobile::Module bc = jitModuleToMobile(m, options); |
534 | IValue res; |
535 | for (int i = 0; i < 3; ++i) { |
536 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
537 | auto bcinputs = inputs; |
538 | res = bc.get_method("forward" )(bcinputs); |
539 | } |
540 | |
541 | auto resd = res.toTensor().item<float>(); |
542 | auto refd = ref.toTensor().item<float>(); |
543 | AT_ASSERT(resd == refd); |
544 | |
545 | auto buff = save_mobile_module_to_bytes(bc); |
546 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
547 | for (int i = 0; i < 3; ++i) { |
548 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
549 | auto bcinputs = inputs; |
550 | res = bc2.get_method("forward" )(bcinputs); |
551 | } |
552 | |
553 | auto resd2 = res.toTensor().item<float>(); |
554 | AT_ASSERT(resd2 == refd); |
555 | } |
556 | |
557 | class TorchBindFlatbufferTestStruct : public torch::jit::CustomClassHolder { |
558 | public: |
559 | std::string get(at::Tensor t) { |
560 | std::stringstream ss; |
561 | ss << "Hello! Your tensor has " ; |
562 | ss << t.numel(); |
563 | ss << " elements!" ; |
564 | return ss.str(); |
565 | } |
566 | }; |
567 | |
568 | namespace { |
569 | struct ClassNamespaceValue : public SugaredValue { |
570 | explicit ClassNamespaceValue(c10::QualifiedName name) |
571 | : basename_(std::move(name)) {} |
572 | |
573 | std::shared_ptr<SugaredValue> attr( |
574 | const SourceRange& loc, |
575 | GraphFunction& m, |
576 | const std::string& name) override { |
577 | const auto fullName = c10::QualifiedName(basename_, name); |
578 | |
579 | // Check to see if it is a custom class. |
580 | if (auto custom_class = getCustomClass(fullName.qualifiedName())) { |
581 | return std::make_shared<ClassValue>(custom_class); |
582 | } |
583 | |
584 | // If it's not a custom class, assume it's another namespace |
585 | // NOLINTNEXTLINE(performance-move-const-arg) |
586 | return std::make_shared<ClassNamespaceValue>(std::move(fullName)); |
587 | } |
588 | |
589 | std::string kind() const override { |
590 | return "Class Namespace" ; |
591 | } |
592 | |
593 | private: |
594 | c10::QualifiedName basename_; |
595 | }; |
596 | |
597 | struct TestModuleResolver : public Resolver { |
598 | std::shared_ptr<SugaredValue> resolveValue( |
599 | const std::string& name, |
600 | GraphFunction& m, |
601 | const SourceRange& loc) override { |
602 | if (name == "torch" ) { |
603 | return std::make_shared<BuiltinModule>("aten" ); |
604 | } else if (name == "__torch__" ) { |
605 | return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name)); |
606 | } |
607 | |
608 | return nullptr; |
609 | } |
610 | |
611 | TypePtr resolveType(const std::string& name, const SourceRange& loc) |
612 | override { |
613 | return nullptr; |
614 | } |
615 | }; |
616 | } // namespace |
617 | |
618 | TEST(FlatbufferTest, BuiltinClass) { |
619 | script::Module m("m" ); |
620 | |
621 | auto cls = getCustomClass( |
622 | "__torch__.torch.classes._TorchScriptTesting._FlatbufferTest" ); |
623 | TORCH_INTERNAL_ASSERT(cls); |
624 | c10::intrusive_ptr<torch::CustomClassHolder> obj_holder; |
625 | m.register_attribute("my_obj" , cls, IValue::make_capsule(obj_holder)); |
626 | |
627 | m.register_parameter("foo" , torch::ones({}), false); |
628 | m.define( |
629 | R"( |
630 | def __getstate__(self): |
631 | return 1 |
632 | def __setstate__(self, a): |
633 | self.my_obj = __torch__.torch.classes._TorchScriptTesting._FlatbufferTest() |
634 | |
635 | def forward(self, x) -> str: |
636 | return self.my_obj.get(x) |
637 | )" , |
638 | std::make_shared<TestModuleResolver>()); |
639 | |
640 | CompilationOptions options; |
641 | mobile::Module bc = jitModuleToMobile(m, options); |
642 | auto buff = save_mobile_module_to_bytes(bc); |
643 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
644 | std::string expected = "Hello! Your tensor has 12 elements!" ; |
645 | auto res = |
646 | bc2.get_method("forward" )(std::vector<IValue>{torch::zeros({3, 4})}); |
647 | const auto& str2 = res.toStringRef(); |
648 | AT_ASSERT(str2 == expected); |
649 | } |
650 | |
651 | TEST(FlatbufferTest, BuiltinFunction) { |
652 | script::Module m("m" ); |
653 | auto custom_class_obj = make_custom_class<TorchBindFlatbufferTestStruct>(); |
654 | m.register_attribute("my_obj" , custom_class_obj.type(), custom_class_obj); |
655 | m.define(R"( |
656 | def forward(self, x) -> str: |
657 | return self.my_obj.get(x) |
658 | )" ); |
659 | |
660 | CompilationOptions options; |
661 | mobile::Module bc = jitModuleToMobile(m, options); |
662 | auto res = |
663 | bc.get_method("forward" )(std::vector<IValue>{torch::zeros({3, 4})}); |
664 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
665 | auto str = res.toStringRef(); |
666 | std::string expected = "Hello! Your tensor has 12 elements!" ; |
667 | AT_ASSERT(str == expected); |
668 | |
669 | auto buff = save_mobile_module_to_bytes(bc); |
670 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
671 | res = bc2.get_method("forward" )(std::vector<IValue>{torch::zeros({3, 4})}); |
672 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
673 | str = res.toStringRef(); |
674 | AT_ASSERT(str == expected); |
675 | } |
676 | |
677 | TEST(FlatbufferTest, Eval) { |
678 | std::vector<torch::jit::IValue> inputs; |
679 | |
680 | Module m("m" ); |
681 | m.define(R"( |
682 | def __init__(self, x): |
683 | self.training = True |
684 | |
685 | def forward(self, input): |
686 | return torch.dropout(input, 1.0, self.training) |
687 | )" ); |
688 | |
689 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) |
690 | inputs.push_back(torch::ones({1, 1, 28, 28})); |
691 | m.eval(); |
692 | auto outputref = m.forward(inputs).toTensor(); |
693 | |
694 | // save m in training mode to make sure that mobile eval() will correctly |
695 | // change back to eval mode |
696 | m.train(); |
697 | CompilationOptions options; |
698 | mobile::Module bc = jitModuleToMobile(m, options); |
699 | bc.eval(); |
700 | IValue res; |
701 | for (int i = 0; i < 3; ++i) { |
702 | res = bc.get_method("forward" )(inputs); |
703 | } |
704 | auto output = res.toTensor(); |
705 | AT_ASSERT(outputref.dim() == output.dim()); |
706 | AT_ASSERT( |
707 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
708 | |
709 | auto buff = save_mobile_module_to_bytes(bc); |
710 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
711 | bc2.eval(); |
712 | for (int i = 0; i < 3; ++i) { |
713 | res = bc2.get_method("forward" )(inputs); |
714 | } |
715 | output = res.toTensor(); |
716 | AT_ASSERT(outputref.dim() == output.dim()); |
717 | AT_ASSERT( |
718 | outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); |
719 | } |
720 | |
721 | TEST(FlatbufferTest, FindWrongMethodName) { |
722 | Module m("m" ); |
723 | m.register_parameter("foo" , torch::ones({}), false); |
724 | m.define(R"( |
725 | def add(self, x): |
726 | b = 4 |
727 | return self.foo + x + b |
728 | )" ); |
729 | CompilationOptions options; |
730 | mobile::Module bc = jitModuleToMobile(m, options); |
731 | ASSERT_TRUE(bc.find_method("forward" ) == c10::nullopt); |
732 | |
733 | auto buff = save_mobile_module_to_bytes(bc); |
734 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
735 | ASSERT_TRUE(bc2.find_method("forward" ) == c10::nullopt); |
736 | } |
737 | |
738 | TEST(FlatbufferTest, FindAndRunMethod) { |
739 | Module m("m" ); |
740 | m.register_parameter("foo" , torch::ones({}), false); |
741 | m.define(R"( |
742 | def add_it(self, x): |
743 | b = 4 |
744 | return self.foo + x + b |
745 | )" ); |
746 | |
747 | std::vector<IValue> inputs; |
748 | auto minput = 5 * torch::ones({}); |
749 | inputs.emplace_back(minput); |
750 | auto ref = m.get_method("add_it" )(inputs); |
751 | |
752 | CompilationOptions options; |
753 | mobile::Module bc = jitModuleToMobile(m, options); |
754 | IValue res; |
755 | for (int i = 0; i < 3; ++i) { |
756 | auto bcinputs = inputs; |
757 | auto method = bc.find_method("add_it" ); |
758 | AT_ASSERT(method != c10::nullopt); |
759 | res = (*method)(std::move(bcinputs)); |
760 | } |
761 | |
762 | auto resd = res.toTensor().item<float>(); |
763 | auto refd = ref.toTensor().item<float>(); |
764 | AT_ASSERT(resd == refd); |
765 | |
766 | auto buff = save_mobile_module_to_bytes(bc); |
767 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
768 | |
769 | for (int i = 0; i < 3; ++i) { |
770 | auto bcinputs = inputs; |
771 | auto method = bc2.find_method("add_it" ); |
772 | AT_ASSERT(method != c10::nullopt); |
773 | res = (*method)(std::move(bcinputs)); |
774 | } |
775 | |
776 | resd = res.toTensor().item<float>(); |
777 | AT_ASSERT(resd == refd); |
778 | } |
779 | |
780 | TEST(FlatbufferTest, RunMethodVariadic) { |
781 | Module m("m" ); |
782 | m.register_parameter("foo" , torch::ones({}), false); |
783 | m.define(R"( |
784 | def add_three(self, x, y): |
785 | return self.foo + x + y |
786 | )" ); |
787 | |
788 | std::vector<IValue> inputs; |
789 | auto inputx = 5 * torch::ones({}); |
790 | auto inputy = 4 * torch::ones({}); |
791 | auto ref = m.run_method("add_three" , inputx, inputy); |
792 | |
793 | CompilationOptions options; |
794 | mobile::Module bc = jitModuleToMobile(m, options); |
795 | IValue res = bc.run_method("add_three" , inputx, inputy); |
796 | |
797 | auto resd = res.toTensor().item<float>(); |
798 | auto refd = ref.toTensor().item<float>(); |
799 | AT_ASSERT(resd == refd); |
800 | |
801 | auto buff = save_mobile_module_to_bytes(bc); |
802 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
803 | res = bc.run_method("add_three" , inputx, inputy); |
804 | resd = res.toTensor().item<float>(); |
805 | AT_ASSERT(resd == refd); |
806 | } |
807 | |
808 | TEST(FlatbufferTest, DuplicateSetState) { |
809 | Module m("M" ); |
810 | m.register_parameter("foo" , torch::ones({}), false); |
811 | m.define(R"( |
812 | def __getstate__(self): |
813 | return self.foo + self.foo |
814 | def __setstate__(self, a): |
815 | self.foo = a |
816 | def forward(self, x): |
817 | b = 4 |
818 | return self.foo + x + b |
819 | )" ); |
820 | |
821 | Module b("B" ); |
822 | b.register_module("M0" , m); |
823 | b.register_module("M1" , m); |
824 | b.define(R"( |
825 | def forward(self, x): |
826 | return self.M0.forward(x) + self.M1.forward(x) |
827 | )" ); |
828 | |
829 | CompilationOptions options; |
830 | mobile::Module bc = jitModuleToMobile(m, options); |
831 | const auto methods = bc.get_methods(); |
832 | const size_t expected_n = 3; |
833 | ASSERT_EQ(methods.size(), expected_n); |
834 | |
835 | auto buff = save_mobile_module_to_bytes(bc); |
836 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
837 | const auto methods2 = bc.get_methods(); |
838 | ASSERT_EQ(methods2.size(), expected_n); |
839 | } |
840 | |
841 | TEST(FlatbufferTest, OpNameExportFetchRootOperators) { |
842 | torch::jit::Module m("m" ); |
843 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
844 | m.register_parameter("bias" , torch::ones({20}), false); |
845 | m.define(R"( |
846 | def forward(self, input): |
847 | x1 = torch.zeros(2, 2) |
848 | x2 = torch.empty_like(torch.empty(2, 2)) |
849 | x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) |
850 | return (x1, x2, x3) |
851 | )" ); |
852 | m.eval(); |
853 | |
854 | CompilationOptions options; |
855 | mobile::Module ptl_model = jitModuleToMobile(m, options); |
856 | std::set<std::string> operator_names = |
857 | torch::jit::mobile::_export_operator_list(ptl_model); |
858 | std::set<std::string> expected_operator_names = { |
859 | "aten::_convolution" , |
860 | "aten::empty.memory_format" , |
861 | "aten::empty_like" , |
862 | "aten::zeros" , |
863 | }; |
864 | EXPECT_EQ(operator_names, expected_operator_names) |
865 | << "Expected the root operator lists to be the same" ; |
866 | |
867 | auto buff = save_mobile_module_to_bytes(ptl_model); |
868 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
869 | operator_names = torch::jit::mobile::_export_operator_list(bc2); |
870 | EXPECT_EQ(operator_names, expected_operator_names) |
871 | << "Expected the root operator lists to be the same" ; |
872 | } |
873 | |
874 | TEST(FlatbufferTest, DefaultArgsConv) { |
875 | auto s = std::getenv("PYTORCH_TEST_WITH_TSAN" ); |
876 | if (s && strcmp(s, "1" ) == 0) |
877 | return; |
878 | |
879 | std::vector<torch::jit::IValue> inputs; |
880 | |
881 | Module m("m" ); |
882 | m.register_parameter("weight" , torch::ones({20, 1, 5, 5}), false); |
883 | m.register_parameter("bias" , torch::ones({20}), false); |
884 | m.define(R"( |
885 | def forward(self, input): |
886 | return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1) |
887 | )" ); |
888 | |
889 | inputs.emplace_back(torch::ones({1, 1, 28, 28})); |
890 | |
891 | auto outputref = m.forward(inputs).toTensor(); |
892 | |
893 | CompilationOptions options; |
894 | mobile::Module bc = jitModuleToMobile(m, options); |
895 | IValue res; |
896 | for (int i = 0; i < 1; ++i) { |
897 | res = bc.get_method("forward" )(inputs); |
898 | } |
899 | auto output = res.toTensor(); |
900 | AT_ASSERT(outputref.dim() == output.dim()); |
901 | AT_ASSERT(output.equal(outputref)); |
902 | |
903 | auto buff = save_mobile_module_to_bytes(bc); |
904 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
905 | for (int i = 0; i < 1; ++i) { |
906 | res = bc2.get_method("forward" )(inputs); |
907 | } |
908 | output = res.toTensor(); |
909 | AT_ASSERT(outputref.dim() == output.dim()); |
910 | AT_ASSERT(output.equal(outputref)); |
911 | } |
912 | |
913 | namespace { |
914 | void testLiteModuleCompareResultTensors( |
915 | Module& m, |
916 | const std::vector<torch::jit::IValue>& inputs, |
917 | const std::string& method_name = "forward" ) { |
918 | auto outputref = m.get_method(method_name)(inputs).toTensor(); |
919 | |
920 | CompilationOptions options; |
921 | mobile::Module bc = jitModuleToMobile(m, options); |
922 | IValue res; |
923 | for (int i = 0; i < 3; ++i) { |
924 | res = bc.get_method(method_name)(inputs); |
925 | } |
926 | auto output = res.toTensor(); |
927 | AT_ASSERT(outputref.dim() == output.dim()); |
928 | AT_ASSERT(output.equal(outputref)); |
929 | |
930 | auto buff = save_mobile_module_to_bytes(bc); |
931 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
932 | for (int i = 0; i < 3; ++i) { |
933 | res = bc2.get_method(method_name)(inputs); |
934 | } |
935 | output = res.toTensor(); |
936 | AT_ASSERT(outputref.dim() == output.dim()); |
937 | AT_ASSERT(output.equal(outputref)); |
938 | } |
939 | |
940 | static void testDefaultArgsPinv(int num_args) { |
941 | Module m("m" ); |
942 | if (num_args == 1) { |
943 | m.define(R"( |
944 | def forward(self, input): |
945 | return torch.linalg_pinv(input) |
946 | )" ); |
947 | } else if (num_args == 2) { |
948 | m.define(R"( |
949 | def forward(self, input): |
950 | return torch.linalg_pinv(input, 1e-5) |
951 | )" ); |
952 | } else if (num_args == 3) { |
953 | m.define(R"( |
954 | def forward(self, input): |
955 | return torch.linalg_pinv(input, 1e-5, True) |
956 | )" ); |
957 | } |
958 | |
959 | std::vector<torch::jit::IValue> inputs; |
960 | const int N = 28; |
961 | auto input = torch::range(1, N * N, 1); |
962 | input[0] = 1; // a more stable matrix |
963 | input = input.view({N, N}); |
964 | inputs.emplace_back(input); |
965 | testLiteModuleCompareResultTensors(m, inputs); |
966 | } |
967 | } // namespace |
968 | |
969 | #if !defined FB_XPLAT_BUILD |
970 | TEST(FlatbufferTest, DefaultArgsPinv) { |
971 | // Test with different number of specified arguments. |
972 | // Arguments not specified take default value. |
973 | for (int num_args = 1; num_args <= 3; ++num_args) { |
974 | testDefaultArgsPinv(num_args); |
975 | } |
976 | |
977 | // bytecode with one specified argument: |
978 | // (6, |
979 | // ('__torch__.m.forward', |
980 | // (('instructions', |
981 | // (('STOREN', 1, 2), |
982 | // ('DROPR', 1, 0), |
983 | // ('MOVE', 2, 0), |
984 | // ('OP', 0, 0), |
985 | // ('RET', 0, 0))), |
986 | // ('operators', (('aten::linalg_pinv', '', 1),)), |
987 | // ('constants', (False, 1e-15)), # default constants are not |
988 | // used |
989 | // ('types', ()), |
990 | // ('register_size', 2)), |
991 | // (('arguments', |
992 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
993 | // None)), |
994 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
995 | // None)))), |
996 | // ('returns', |
997 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
998 | // None)),))))) |
999 | |
1000 | // bytecode with 2 specified argument: |
1001 | // (6, |
1002 | // ('__torch__.m.forward', |
1003 | // (('instructions', |
1004 | // (('STOREN', 1, 2), |
1005 | // ('DROPR', 1, 0), |
1006 | // ('MOVE', 2, 0), |
1007 | // ('LOADC', 1, 0), # added LOADC for specified argument |
1008 | // ('OP', 0, 0), |
1009 | // ('RET', 0, 0))), |
1010 | // ('operators', (('aten::linalg_pinv', '', 2),)), |
1011 | // ('constants', (False, 1e-05)), # updated constant table |
1012 | // ('types', ()), |
1013 | // ('register_size', 2)), |
1014 | // (('arguments', |
1015 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
1016 | // None)), |
1017 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
1018 | // None)))), |
1019 | // ('returns', |
1020 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
1021 | // None)),))))) |
1022 | |
1023 | // bytecode with 3 specified arguments: |
1024 | // (6, |
1025 | // ('__torch__.m.forward', |
1026 | // (('instructions', |
1027 | // (('STOREN', 1, 2), |
1028 | // ('DROPR', 1, 0), |
1029 | // ('MOVE', 2, 0), |
1030 | // ('LOADC', 1, 0), |
1031 | // ('LOADC', 0, 0), |
1032 | // ('OP', 0, 0), |
1033 | // ('RET', 0, 0))), |
1034 | // ('operators', (('aten::linalg_pinv', '', 3),)), |
1035 | // ('constants', (True, 1e-05)), |
1036 | // ('types', ()), |
1037 | // ('register_size', 2)), |
1038 | // (('arguments', |
1039 | // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', |
1040 | // None)), |
1041 | // (('name', 'input'), ('type', 'Tensor'), ('default_value', |
1042 | // None)))), |
1043 | // ('returns', |
1044 | // ((('name', ''), ('type', 'Tensor'), ('default_value', |
1045 | // None)),))))) |
1046 | } |
1047 | |
1048 | TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) { |
1049 | // The second argument is specified, but the value is the same as the default |
1050 | // value. It's treated as "not specified" since the value can be fetched from |
1051 | // schema. |
1052 | Module m("m" ); |
1053 | m.define(R"( |
1054 | def forward(self, input): |
1055 | return torch.linalg_tensorinv(input, 2) |
1056 | )" ); |
1057 | torch::jit::MobileCode code(m.get_method("forward" ).graph(), "forward" ); |
1058 | auto arg_nums = code.op_to_num_specified_args(); |
1059 | ASSERT_EQ(arg_nums.size(), 1); |
1060 | ASSERT_EQ(arg_nums["aten::linalg_tensorinv" ], 1); |
1061 | std::vector<torch::jit::IValue> inputs; |
1062 | const int N = 4; |
1063 | auto input = torch::rand({N, N, N, N}); |
1064 | inputs.emplace_back(input); |
1065 | testLiteModuleCompareResultTensors(m, inputs); |
1066 | } |
1067 | |
1068 | static void testDefaultArgsPinvWithOutArg(int num_args) { |
1069 | Module m("m" ); |
1070 | if (num_args == 1) { |
1071 | m.define(R"( |
1072 | def forward(self, input): |
1073 | return torch.linalg_pinv(input, out=input) |
1074 | )" ); |
1075 | } else if (num_args == 2) { |
1076 | m.define(R"( |
1077 | def forward(self, input): |
1078 | return torch.linalg_pinv(input, 1e-5, out=input) |
1079 | )" ); |
1080 | } else if (num_args == 3) { |
1081 | m.define(R"( |
1082 | def forward(self, input): |
1083 | return torch.linalg_pinv(input, 1e-5, True, out=input) |
1084 | )" ); |
1085 | } |
1086 | |
1087 | const int N = 28; |
1088 | auto input = torch::range(1, N * N, 1); |
1089 | input[0] = 10000; // a more stable matrix |
1090 | input = input.view({N, N}); |
1091 | auto ref = m.run_method("forward" , input); |
1092 | TORCH_CHECK(!input.equal(torch::range(1, N * N, 1))); |
1093 | TORCH_CHECK(input.equal(ref.toTensor())); |
1094 | } |
1095 | |
1096 | TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) { |
1097 | // Test with different number of specified arguments + out arg. |
1098 | // Arguments not specified take default value. |
1099 | for (int num_args = 1; num_args <= 3; ++num_args) { |
1100 | testDefaultArgsPinvWithOutArg(num_args); |
1101 | } |
1102 | } |
1103 | |
1104 | TEST(FlatbufferTest, DefaultArgsWithOutArg) { |
1105 | Module m("m" ); |
1106 | m.define(R"( |
1107 | def forward(self, x, h): |
1108 | torch.add(x, h, out=x) |
1109 | )" ); |
1110 | |
1111 | std::vector<IValue> inputs; |
1112 | auto input_x = 2 * torch::ones({}); |
1113 | auto input_h = torch::ones({}); |
1114 | auto ref = m.run_method("forward" , input_x, input_h); |
1115 | |
1116 | CompilationOptions options; |
1117 | mobile::Module bc = jitModuleToMobile(m, options); |
1118 | bc.run_method("forward" , input_x, input_h); |
1119 | AT_ASSERT(input_x.equal(4 * torch::ones({}))); |
1120 | |
1121 | auto buff = save_mobile_module_to_bytes(bc); |
1122 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
1123 | auto input_x2 = 2 * torch::ones({}); |
1124 | auto input_h2 = torch::ones({}); |
1125 | m.run_method("forward" , input_x2, input_h2); |
1126 | bc2.run_method("forward" , input_x2, input_h2); |
1127 | AT_ASSERT(input_x2.equal(4 * torch::ones({}))); |
1128 | } |
1129 | |
1130 | #endif // !defined(FB_XPLAT_BUILD) |
1131 | |
1132 | namespace { |
1133 | static auto reg = |
1134 | torch::class_<TorchBindFlatbufferTestStruct>( |
1135 | "_TorchScriptTesting" , |
1136 | "_FlatbufferTest" ) |
1137 | .def(torch::init<>()) |
1138 | .def("get" , &TorchBindFlatbufferTestStruct::get) |
1139 | .def_pickle( |
1140 | // __getattr__ |
1141 | [](const c10::intrusive_ptr<TorchBindFlatbufferTestStruct>& self) |
1142 | -> int64_t { return 0; }, |
1143 | // __setattr__ |
1144 | [](int64_t state) { |
1145 | return c10::make_intrusive<TorchBindFlatbufferTestStruct>(); |
1146 | }); |
1147 | |
1148 | } // namespace |
1149 | |
1150 | TEST(FlatbufferTest, OperatorCacheDifferentiatesDefaultArgs) { |
1151 | // Create 3 methods: |
1152 | // |
1153 | // 1. forward() returns a tensor with dtype=torch.int64 (4) |
1154 | // 2. forward2() returns a tensor with dtype=torch.float32 (6) |
1155 | // 3. forward3() returns a tensor with dtype=torch.float32 but |
1156 | // the dtype is inferred by the input tensor's dtype |
1157 | // |
1158 | // If caching works correctly, then the result from the full-jit |
1159 | // module and the lite module will be the same. Otherwise, it |
1160 | // will be different if we don't correctly ignore the cache |
1161 | // entry for an operator that has a different number of |
1162 | // arguments. |
1163 | Module m("m" ); |
1164 | m.define(R"( |
1165 | def forward(self): |
1166 | ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4) |
1167 | return ret1.fill_(25) |
1168 | )" ); |
1169 | m.define(R"( |
1170 | def forward2(self): |
1171 | ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6) |
1172 | return ret1.fill_(32.0) |
1173 | )" ); |
1174 | m.define(R"( |
1175 | def forward3(self): |
1176 | ret1 = torch.new_empty(torch.zeros(10), [10]) |
1177 | return ret1.fill_(12.0) |
1178 | )" ); |
1179 | |
1180 | std::vector<torch::jit::IValue> inputs; |
1181 | testLiteModuleCompareResultTensors(m, inputs, "forward" ); |
1182 | testLiteModuleCompareResultTensors(m, inputs, "forward2" ); |
1183 | testLiteModuleCompareResultTensors(m, inputs, "forward3" ); |
1184 | } |
1185 | |
1186 | TEST(FlatbufferTest, OperatorSize1) { |
1187 | Module m("m" ); |
1188 | m.define(R"( |
1189 | def forward(self, input: Tensor, scale:float): |
1190 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
1191 | )" ); |
1192 | |
1193 | CompilationOptions options; |
1194 | mobile::Module bc = jitModuleToMobile(m, options); |
1195 | const auto& func = bc.get_method("forward" ).function(); |
1196 | ASSERT_EQ( |
1197 | func.get_code().operator_input_sizes_.size(), |
1198 | func.get_code().operators_.size()); |
1199 | |
1200 | auto buff = save_mobile_module_to_bytes(bc); |
1201 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
1202 | const auto& func2 = bc.get_method("forward" ).function(); |
1203 | ASSERT_EQ( |
1204 | func2.get_code().operator_input_sizes_.size(), |
1205 | func2.get_code().operators_.size()); |
1206 | } |
1207 | |
1208 | TEST(FlatbufferTest, BoolAndDoubleList) { |
1209 | Module m("m" ); |
1210 | c10::List<bool> boollist; |
1211 | boollist.push_back(false); |
1212 | IValue boollist_ival = boollist; |
1213 | IValue doublelist = std::vector<double>{2.0}; |
1214 | m.register_attribute("bool_list" , boollist_ival.type(), boollist_ival); |
1215 | m.register_attribute("double_list" , doublelist.type(), doublelist); |
1216 | |
1217 | CompilationOptions options; |
1218 | mobile::Module bc = jitModuleToMobile(m, options); |
1219 | auto buff = save_mobile_module_to_bytes(bc); |
1220 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
1221 | |
1222 | // if the variables read are wrong type the conversion will raise exception |
1223 | auto boolval = bc2.attr("bool_list" , {}).toBoolList().get(0); |
1224 | auto doubleval = bc2.attr("double_list" , {}).toDoubleList().get(0); |
1225 | |
1226 | ASSERT_EQ(boolval, false); |
1227 | ASSERT_EQ(doubleval, 2.0); |
1228 | } |
1229 | |
1230 | TEST(FlatbufferTest, OperatorTest2) { // NOLINT (use =delete in gtest) |
1231 | const std::vector<std::string> test_programs{ |
1232 | // test invoking a method with default parameter |
1233 | R"( |
1234 | def test_func(self, x, b : int = 4): |
1235 | return self.foo + x + b |
1236 | )" , |
1237 | // inner method call with default parameter (gets inlined) |
1238 | R"( |
1239 | def add_with_default_arg(self, x, b : int = 4): |
1240 | return self.foo + x + b |
1241 | def test_func(self, x): |
1242 | return self.add_with_default_arg(x) # invoke method w/ default arg |
1243 | )" , |
1244 | // simple method call |
1245 | R"( |
1246 | def test_func(self, x): |
1247 | b = 4 |
1248 | return self.foo + x + b |
1249 | )" , |
1250 | }; |
1251 | for (const auto& test_program : test_programs) { |
1252 | Module m("m" ); |
1253 | m.register_parameter("foo" , torch::ones({}), false); |
1254 | m.define(test_program); |
1255 | |
1256 | CompilationOptions options; |
1257 | mobile::Module bc = jitModuleToMobile(m, options); |
1258 | const auto& func = bc.get_method("test_func" ).function(); |
1259 | ASSERT_EQ( |
1260 | func.get_code().operator_input_sizes_.size(), |
1261 | func.get_code().operators_.size()); |
1262 | |
1263 | auto buff = save_mobile_module_to_bytes(bc); |
1264 | mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size()); |
1265 | const auto& func2 = bc.get_method("test_func" ).function(); |
1266 | ASSERT_EQ( |
1267 | func2.get_code().operator_input_sizes_.size(), |
1268 | func2.get_code().operators_.size()); |
1269 | } |
1270 | } |
1271 | |
1272 | Module jitModuleFromBuffer(void* data, size_t size) { |
1273 | // Make a copy of the data so we can use the existing API, which takes |
1274 | // ownership. The `data` param might point into the middle of a buffer, so we |
1275 | // can't safely take ownership of it directly. |
1276 | // @nolint CLANGTIDY cppcoreguidelines-no-malloc |
1277 | std::shared_ptr<char> copy(static_cast<char*>(malloc(size)), free); |
1278 | memcpy(copy.get(), data, size); |
1279 | |
1280 | ExtraFilesMap ; |
1281 | return parse_and_initialize_jit_module(std::move(copy), size, extra_files); |
1282 | } |
1283 | |
1284 | TEST(TestSourceFlatbuffer, UpsampleNearest2d) { |
1285 | Module m("m" ); |
1286 | m.define(R"( |
1287 | def forward(self, input: Tensor, scale:float): |
1288 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) |
1289 | )" ); |
1290 | |
1291 | std::vector<IValue> inputs; |
1292 | inputs.emplace_back(torch::rand({1, 3, 128, 128})); |
1293 | inputs.emplace_back(at::Scalar(2.0)); |
1294 | auto ref = m.forward(inputs); |
1295 | |
1296 | std::stringstream ss; |
1297 | m._save_for_mobile(ss, {}, false, /*use_fatbuffer=*/true); |
1298 | auto mm = _load_for_mobile(ss); |
1299 | auto m2 = load(ss); |
1300 | |
1301 | auto res = m2.forward(inputs); |
1302 | auto resm = mm.forward(inputs); |
1303 | |
1304 | auto resd = res.toTensor(); |
1305 | auto refd = ref.toTensor(); |
1306 | auto resmd = resm.toTensor(); |
1307 | ASSERT_TRUE(resd.equal(refd)); |
1308 | ASSERT_TRUE(resmd.equal(refd)); |
1309 | } |
1310 | |
1311 | TEST(TestSourceFlatbuffer, CheckAttrAccess) { |
1312 | Module m("m" ); |
1313 | m.register_attribute("mobile_optimized" , BoolType::get(), true); |
1314 | auto data = save_jit_module_to_bytes(m); |
1315 | Module m2 = jitModuleFromBuffer(data->data(), data->size()); |
1316 | bool mobile_optimized = m2.attr("mobile_optimized" , false).toBool(); |
1317 | AT_ASSERT(mobile_optimized); |
1318 | mobile::Module m3 = parse_mobile_module(data->data(), data->size()); |
1319 | mobile_optimized = m3.attr("mobile_optimized" , false).toBool(); |
1320 | AT_ASSERT(mobile_optimized); |
1321 | } |
1322 | |
1323 | TEST(TestSourceFlatbuffer, |
1324 | MethodInvocation) { // NOLINT (use =delete in gtest) |
1325 | const std::vector<std::string> test_programs{ |
1326 | // test invoking a method with default parameter |
1327 | R"( |
1328 | def test_func(self, x, b : int = 4): |
1329 | return self.foo + x + b |
1330 | )" , |
1331 | // inner method call with default parameter (gets inlined) |
1332 | R"( |
1333 | def add_with_default_arg(self, x, b : int = 4): |
1334 | return self.foo + x + b |
1335 | def test_func(self, x): |
1336 | return self.add_with_default_arg(x) # invoke method w/ default arg |
1337 | )" , |
1338 | // simple method call |
1339 | R"( |
1340 | def test_func(self, x): |
1341 | b = 4 |
1342 | return self.foo + x + b |
1343 | )" , |
1344 | }; |
1345 | for (const auto& test_program : test_programs) { |
1346 | Module m("m" ); |
1347 | m.register_parameter("foo" , torch::ones({}), false); |
1348 | m.define(test_program); |
1349 | |
1350 | const int fortyTwo = 42; // (keep linter happy) |
1351 | auto minput = fortyTwo * torch::ones({}); |
1352 | auto ref = m.run_method("test_func" , minput); |
1353 | |
1354 | auto data = save_jit_module_to_bytes(m); |
1355 | Module m2 = jitModuleFromBuffer(data->data(), data->size()); |
1356 | const auto& test_func = m2.get_method("test_func" ); |
1357 | IValue res; |
1358 | for (int i = 0; i < 3; ++i) { |
1359 | res = test_func({minput}); |
1360 | } |
1361 | auto resd = res.toTensor().item<float>(); |
1362 | auto refd = ref.toTensor().item<float>(); |
1363 | AT_ASSERT(resd == refd); |
1364 | |
1365 | mobile::Module m3 = parse_mobile_module(data->data(), data->size()); |
1366 | const auto& test_func3 = m3.get_method("test_func" ); |
1367 | for (int i = 0; i < 3; ++i) { |
1368 | res = test_func3({minput}); |
1369 | } |
1370 | resd = res.toTensor().item<float>(); |
1371 | refd = ref.toTensor().item<float>(); |
1372 | AT_ASSERT(resd == refd); |
1373 | } |
1374 | } |
1375 | |
1376 | #if !defined FB_XPLAT_BUILD |
// The following tests run in fbcode only
1378 | TEST(FlatbufferUpgraderTest, DivTensorV2) { |
1379 | std::string filePath(__FILE__); |
1380 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1381 | test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl.ff" ); |
1382 | /* |
1383 | (('__torch__.MyModule.forward', |
1384 | (('instructions', |
1385 | (('STOREN', 1, 3), |
1386 | ('DROPR', 1, 0), |
1387 | ('LOAD', 2, 0), |
1388 | ('LOAD', 3, 0), |
1389 | ('OP', 0, 0), |
1390 | ('LOAD', 2, 0), |
1391 | ('LOAD', 3, 0), |
1392 | ('OP', 1, 0), |
1393 | ('MOVE', 2, 0), |
1394 | ('MOVE', 3, 0), |
1395 | ('OP', 2, 0), |
1396 | ('TUPLE_CONSTRUCT', 3, 0), |
1397 | ('RET', 0, 0))), |
1398 | ('operators', |
1399 | (('aten::div', 'Tensor'), |
1400 | ('aten::div', 'Tensor'), |
1401 | ('aten::div', 'Tensor'))), |
1402 | ('constants', ()), |
1403 | ('types', ()), |
1404 | ('register_size', 3))),) |
1405 | |
1406 | */ |
1407 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1408 | auto intrsuction_list = |
1409 | m_module.get_method("forward" ).function().get_code().instructions_; |
1410 | uint64_t number_of_call_instruction = 0; |
1411 | for (auto& instruction : intrsuction_list) { |
1412 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1413 | } |
1414 | // 3 operators will use upgrader |
1415 | ASSERT_EQ(number_of_call_instruction, 3); |
1416 | |
1417 | std::vector<IValue> inputs = { |
1418 | IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))}; |
1419 | auto actual_output = m_module.forward(inputs); |
1420 | auto expect_output = 2.0 * torch::ones({1}); |
1421 | auto actual_output_list = actual_output.toTuple()->elements(); |
1422 | ASSERT_TRUE(actual_output_list[0].toTensor().equal(expect_output)); |
1423 | } |
1424 | |
1425 | TEST(FlatbufferUpgraderTest, DivTensorOutV2) { |
1426 | std::string filePath(__FILE__); |
1427 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1428 | test_model_file.append( |
1429 | "upgrader_models/test_versioned_div_tensor_out_v2.ptl.ff" ); |
1430 | /* |
1431 | (('__torch__.MyModule.forward', |
1432 | (('instructions', |
1433 | (('STOREN', 1, 4), |
1434 | ('DROPR', 1, 0), |
1435 | ('MOVE', 2, 0), |
1436 | ('MOVE', 3, 0), |
1437 | ('MOVE', 4, 0), |
1438 | ('OP', 0, 0), |
1439 | ('RET', 0, 0))), |
1440 | ('operators', (('aten::div', 'out'),)), |
1441 | ('constants', ()), |
1442 | ('types', ()), |
1443 | ('register_size', 4))),) |
1444 | */ |
1445 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1446 | |
1447 | auto intrsuction_list = |
1448 | m_module.get_method("forward" ).function().get_code().instructions_; |
1449 | uint64_t number_of_call_instruction = 0; |
1450 | for (auto& instruction : intrsuction_list) { |
1451 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1452 | } |
1453 | // One operator will use upgrader |
1454 | ASSERT_EQ(number_of_call_instruction, 1); |
1455 | |
1456 | std::vector<IValue> inputs{ |
1457 | IValue(6 * torch::ones({1})), |
1458 | IValue(3 * torch::ones({1})), |
1459 | IValue(torch::empty({1}))}; |
1460 | m_module.forward(inputs); |
1461 | auto expect_output = 2.0 * torch::ones({1}); |
1462 | auto actual_output = inputs[2].toTensor(); |
1463 | // The out argument will be overwritten with the output |
1464 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1465 | } |
1466 | |
1467 | TEST(FlatbufferUpgraderTest, DivTensorInplaceV2) { |
1468 | std::string filePath(__FILE__); |
1469 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1470 | test_model_file.append( |
1471 | "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl.ff" ); |
1472 | /* |
1473 | (('__torch__.MyModule.forward', |
1474 | (('instructions', |
1475 | (('STOREN', 1, 3), |
1476 | ('DROPR', 1, 0), |
1477 | ('MOVE', 2, 0), |
1478 | ('MOVE', 3, 0), |
1479 | ('OP', 0, 0), |
1480 | ('RET', 0, 0))), |
1481 | ('operators', (('aten::div_', 'Tensor'),)), |
1482 | ('constants', ()), |
1483 | ('types', ()), |
1484 | ('register_size', 3))),) |
1485 | */ |
1486 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1487 | |
1488 | auto intrsuction_list = |
1489 | m_module.get_method("forward" ).function().get_code().instructions_; |
1490 | uint64_t number_of_call_instruction = 0; |
1491 | for (auto& instruction : intrsuction_list) { |
1492 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1493 | } |
1494 | // One operator will use upgrader |
1495 | ASSERT_EQ(number_of_call_instruction, 1); |
1496 | |
1497 | std::vector<IValue> inputs{ |
1498 | IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))}; |
1499 | m_module.forward(inputs); |
1500 | auto expect_output = 2.0 * torch::ones({1}); |
1501 | auto actual_output = inputs[0].toTensor(); |
1502 | // The out argument will be overwritten with the output |
1503 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1504 | } |
1505 | |
1506 | TEST(FlatbufferUpgraderTest, DivScalarFloatV2) { |
1507 | std::string filePath(__FILE__); |
1508 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1509 | test_model_file.append( |
1510 | "upgrader_models/test_versioned_div_scalar_float_v2.ptl.ff" ); |
1511 | /* |
1512 | (('__torch__.MyModuleFloat.forward', |
1513 | (('instructions', |
1514 | (('STOREN', 1, 3), |
1515 | ('DROPR', 1, 0), |
1516 | ('MOVE', 2, 0), |
1517 | ('MOVE', 3, 0), |
1518 | ('OP', 0, 0), |
1519 | ('RET', 0, 0))), |
1520 | ('operators', (('aten::div', 'Scalar'),)), |
1521 | ('constants', ()), |
1522 | ('types', ()), |
1523 | ('register_size', 3))),) |
1524 | */ |
1525 | |
1526 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1527 | |
1528 | auto intrsuction_list = |
1529 | m_module.get_method("forward" ).function().get_code().instructions_; |
1530 | uint64_t number_of_call_instruction = 0; |
1531 | for (auto& instruction : intrsuction_list) { |
1532 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1533 | } |
1534 | // One operator will use upgrader |
1535 | ASSERT_EQ(number_of_call_instruction, 1); |
1536 | |
1537 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1538 | auto output = m_module.forward(inputs); |
1539 | auto expect_output = 2.0 * torch::ones({1}); |
1540 | auto actual_output = output.toTensor(); |
1541 | |
1542 | // The out argument will be overwritten with the output |
1543 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1544 | } |
1545 | |
1546 | TEST(FlatbufferUpgraderTest, DivScalarReciprocalFloatV2) { |
1547 | std::string filePath(__FILE__); |
1548 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1549 | test_model_file.append( |
1550 | "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl.ff" ); |
1551 | /* |
1552 | (('__torch__.MyModuleFloat.forward', |
1553 | (('instructions', |
1554 | (('STOREN', 1, 3), |
1555 | ('DROPR', 1, 0), |
1556 | ('MOVE', 2, 0), |
1557 | ('OP', 0, 0), |
1558 | ('MOVE', 3, 0), |
1559 | ('OP', 1, 0), |
1560 | ('RET', 0, 0))), |
1561 | ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))), |
1562 | ('constants', ()), |
1563 | ('types', ()), |
1564 | ('register_size', 3))),) |
1565 | */ |
1566 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1567 | |
1568 | auto intrsuction_list = |
1569 | m_module.get_method("forward" ).function().get_code().instructions_; |
1570 | uint64_t number_of_call_instruction = 0; |
1571 | for (auto& instruction : intrsuction_list) { |
1572 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1573 | } |
1574 | // No operator will use upgrader |
1575 | ASSERT_EQ(number_of_call_instruction, 0); |
1576 | |
1577 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1578 | auto output = m_module.forward(inputs); |
1579 | auto expect_output = 0.5 * torch::ones({1}); |
1580 | auto actual_output = output.toTensor(); |
1581 | // The out argument will be overwritten with the output |
1582 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1583 | } |
1584 | |
1585 | TEST(FlatbufferUpgraderTest, DivScalarReciprocalIntV2) { |
1586 | std::string filePath(__FILE__); |
1587 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1588 | test_model_file.append( |
1589 | "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl.ff" ); |
1590 | /* |
1591 | (('__torch__.MyModuleInt.forward', |
1592 | (('instructions', |
1593 | (('STOREN', 1, 3), |
1594 | ('DROPR', 1, 0), |
1595 | ('MOVE', 2, 0), |
1596 | ('OP', 0, 0), |
1597 | ('MOVE', 3, 0), |
1598 | ('OP', 1, 0), |
1599 | ('RET', 0, 0))), |
1600 | ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))), |
1601 | ('constants', ()), |
1602 | ('types', ()), |
1603 | ('register_size', 3))),) |
1604 | */ |
1605 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1606 | |
1607 | auto intrsuction_list = |
1608 | m_module.get_method("forward" ).function().get_code().instructions_; |
1609 | uint64_t number_of_call_instruction = 0; |
1610 | for (auto& instruction : intrsuction_list) { |
1611 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1612 | } |
1613 | // No operator will use upgrader |
1614 | ASSERT_EQ(number_of_call_instruction, 0); |
1615 | |
1616 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1617 | auto output = m_module.forward(inputs); |
1618 | auto expect_output = 0.5 * torch::ones({1}); |
1619 | auto actual_output = output.toTensor(); |
1620 | |
1621 | // The out argument will be overwritten with the output |
1622 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1623 | } |
1624 | |
1625 | TEST(FlatbufferUpgraderTest, DivScalarScalarV2) { |
1626 | std::string filePath(__FILE__); |
1627 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1628 | test_model_file.append( |
1629 | "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl.ff" ); |
1630 | /* |
1631 | (('__torch__.MyModule.forward', |
1632 | (('instructions', |
1633 | (('STOREN', 1, 5), |
1634 | ('DROPR', 1, 0), |
1635 | ('LOAD', 2, 0), |
1636 | ('LOAD', 3, 0), |
1637 | ('OP', 0, 0), |
1638 | ('MOVE', 2, 0), |
1639 | ('LOAD', 4, 0), |
1640 | ('OP', 1, 0), |
1641 | ('LOAD', 3, 0), |
1642 | ('MOVE', 4, 0), |
1643 | ('OP', 2, 0), |
1644 | ('MOVE', 3, 0), |
1645 | ('MOVE', 5, 0), |
1646 | ('OP', 3, 0), |
1647 | ('TUPLE_CONSTRUCT', 4, 0), |
1648 | ('RET', 0, 0))), |
1649 | ('operators', |
1650 | (('aten::div', ''), |
1651 | ('aten::div', 'float'), |
1652 | ('aten::div', ''), |
1653 | ('aten::div', 'int'))), |
1654 | ('constants', ()), |
1655 | ('types', ()), |
1656 | ('register_size', 5))),) |
1657 | */ |
1658 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1659 | auto intrsuction_list = |
1660 | m_module.get_method("forward" ).function().get_code().instructions_; |
1661 | uint64_t number_of_call_instruction = 0; |
1662 | for (auto& instruction : intrsuction_list) { |
1663 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1664 | } |
1665 | // No operator will use upgrader |
1666 | ASSERT_EQ(number_of_call_instruction, 0); |
1667 | |
1668 | std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)}; |
1669 | auto output = m_module.forward(inputs); |
1670 | auto output_list = output.toTupleRef().elements(); |
1671 | auto expect_output = std::vector<IValue>( |
1672 | {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)}); |
1673 | // auto actual_output = output.toTensor(); |
1674 | for (size_t i = 0; i < expect_output.size(); i++) { |
1675 | ASSERT_EQ(output_list[i], expect_output[i]); |
1676 | } |
1677 | } |
1678 | |
1679 | TEST(FlatbufferUpgraderTest, DivScalarIntV2) { |
1680 | std::string filePath(__FILE__); |
1681 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1682 | test_model_file.append( |
1683 | "upgrader_models/test_versioned_div_scalar_int_v2.ptl.ff" ); |
1684 | /* |
1685 | (('__torch__.MyModuleInt.forward', |
1686 | (('instructions', |
1687 | (('STOREN', 1, 3), |
1688 | ('DROPR', 1, 0), |
1689 | ('MOVE', 2, 0), |
1690 | ('MOVE', 3, 0), |
1691 | ('OP', 0, 0), |
1692 | ('RET', 0, 0))), |
1693 | ('operators', (('aten::div', 'Scalar'),)), |
1694 | ('constants', ()), |
1695 | ('types', ()), |
1696 | ('register_size', 3))),) |
1697 | */ |
1698 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1699 | |
1700 | auto intrsuction_list = |
1701 | m_module.get_method("forward" ).function().get_code().instructions_; |
1702 | uint64_t number_of_call_instruction = 0; |
1703 | for (auto& instruction : intrsuction_list) { |
1704 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1705 | } |
1706 | // One operator will use upgrader |
1707 | ASSERT_EQ(number_of_call_instruction, 1); |
1708 | |
1709 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)}; |
1710 | auto output = m_module.forward(inputs); |
1711 | auto expect_output = 2.0 * torch::ones({1}); |
1712 | auto actual_output = output.toTensor(); |
1713 | |
1714 | // The out argument will be overwritten with the output |
1715 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1716 | } |
1717 | |
1718 | TEST(FlatbufferUpgraderTest, DivScalarInplaceFloatV2) { |
1719 | std::string filePath(__FILE__); |
1720 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1721 | test_model_file.append( |
1722 | "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl.ff" ); |
1723 | /* |
1724 | (('__torch__.MyModuleFloat.forward', |
1725 | (('instructions', |
1726 | (('STOREN', 1, 3), |
1727 | ('DROPR', 1, 0), |
1728 | ('MOVE', 2, 0), |
1729 | ('MOVE', 3, 0), |
1730 | ('OP', 0, 0), |
1731 | ('RET', 0, 0))), |
1732 | ('operators', (('aten::div_', 'Scalar'),)), |
1733 | ('constants', ()), |
1734 | ('types', ()), |
1735 | ('register_size', 3))),) |
1736 | */ |
1737 | |
1738 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1739 | |
1740 | auto intrsuction_list = |
1741 | m_module.get_method("forward" ).function().get_code().instructions_; |
1742 | uint64_t number_of_call_instruction = 0; |
1743 | for (auto& instruction : intrsuction_list) { |
1744 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1745 | } |
1746 | // One operator will use upgrader |
1747 | ASSERT_EQ(number_of_call_instruction, 1); |
1748 | |
1749 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)}; |
1750 | auto output = m_module.forward(inputs); |
1751 | auto expect_output = 2.0 * torch::ones({1}); |
1752 | auto actual_output = output.toTensor(); |
1753 | |
1754 | // The out argument will be overwritten with the output |
1755 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1756 | } |
1757 | |
1758 | TEST(FlatbufferUpgraderTest, DivScalarInplaceIntV2) { |
1759 | std::string filePath(__FILE__); |
1760 | auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
1761 | test_model_file.append( |
1762 | "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl.ff" ); |
1763 | /* |
1764 | (('__torch__.MyModuleInt.forward', |
1765 | (('instructions', |
1766 | (('STOREN', 1, 3), |
1767 | ('DROPR', 1, 0), |
1768 | ('MOVE', 2, 0), |
1769 | ('MOVE', 3, 0), |
1770 | ('OP', 0, 0), |
1771 | ('RET', 0, 0))), |
1772 | ('operators', (('aten::div_', 'Scalar'),)), |
1773 | ('constants', ()), |
1774 | ('types', ()), |
1775 | ('register_size', 3))),) |
1776 | */ |
1777 | |
1778 | mobile::Module m_module = load_mobile_module_from_file(test_model_file); |
1779 | |
1780 | auto intrsuction_list = |
1781 | m_module.get_method("forward" ).function().get_code().instructions_; |
1782 | uint64_t number_of_call_instruction = 0; |
1783 | for (auto& instruction : intrsuction_list) { |
1784 | number_of_call_instruction += (instruction.op == OpCode::CALL); |
1785 | } |
1786 | // One operator will use upgrader |
1787 | ASSERT_EQ(number_of_call_instruction, 1); |
1788 | |
1789 | std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)}; |
1790 | auto output = m_module.forward(inputs); |
1791 | auto expect_output = 2.0 * torch::ones({1}); |
1792 | auto actual_output = output.toTensor(); |
1793 | |
1794 | // The out argument will be overwritten with the output |
1795 | ASSERT_TRUE(actual_output.equal(expect_output)); |
1796 | } |
1797 | |
1798 | #endif // !defined(FB_XPLAT_BUILD) |
1799 | |
1800 | // |
1801 | // Tests that need access to internal flatbuffers types/functions. |
1802 | // Do not add any other tests after this section. |
1803 | // |
1804 | |
1805 | } // namespace jit |
1806 | } // namespace torch |
1807 | namespace torch { |
1808 | namespace jit { |
1809 | |
1810 | /** |
1811 | * An Allocator that can only deallocate (using delete []), counting |
1812 | * the number of times that it has been asked to deallocate. |
1813 | */ |
1814 | class TestAllocator : public flatbuffers::Allocator { |
1815 | public: |
1816 | /** |
1817 | * *deallocate_call_count will be incremented whenever deallocate() is called. |
1818 | */ |
1819 | explicit TestAllocator(int* deallocate_call_count) |
1820 | : deallocate_call_count_(deallocate_call_count) {} |
1821 | |
1822 | void deallocate(uint8_t* p, size_t /*size*/) override { |
1823 | *deallocate_call_count_ += 1; |
1824 | delete[] p; |
1825 | } |
1826 | |
1827 | uint8_t* allocate(size_t) override { |
1828 | TORCH_CHECK(false, "allocate() should not be called" ); |
1829 | } |
1830 | uint8_t* reallocate_downward(uint8_t*, size_t, size_t, size_t, size_t) |
1831 | override { |
1832 | TORCH_CHECK(false, "reallocate_downward() should not be called" ); |
1833 | } |
1834 | |
1835 | private: |
1836 | int* deallocate_call_count_; |
1837 | }; |
1838 | |
1839 | /// Provides access to DetachedBuffer::destroy(). |
1840 | struct DetachedBufferTestingFriend { |
1841 | /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer. |
1842 | /// A copy of similar code in flatbuffer_serializer.cpp. |
1843 | static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer( |
1844 | DetachedBuffer* buf) { |
1845 | return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy); |
1846 | } |
1847 | }; |
1848 | |
1849 | TEST(FlatbufferTest, DetachedBufferSmoke) { |
1850 | // Use a custom Allocator to watch the lifecycle of a |
1851 | // flatbuffers::DetachedBuffer. |
1852 | int deallocate_call_count = 0; |
1853 | TestAllocator alloc(&deallocate_call_count); |
1854 | |
1855 | // Data for the buffer. TestAllocator will free it with `delete []`. |
1856 | constexpr size_t data_size = 4; |
1857 | uint8_t* data = new uint8_t[data_size]; |
1858 | |
1859 | // An internal buffer on the stack that owns the data. |
1860 | flatbuffers::DetachedBuffer fb_buf_local( |
1861 | &alloc, /*own_allocator=*/false, data, data_size, data, data_size); |
1862 | EXPECT_EQ(fb_buf_local.data(), data); |
1863 | EXPECT_EQ(fb_buf_local.size(), data_size); |
1864 | |
1865 | // Mimic the code inside save_mobile_module_to_bytes by transferring ownership |
1866 | // to a heap object. |
1867 | auto fb_buf_ptr = new flatbuffers::DetachedBuffer(std::move(fb_buf_local)); |
1868 | // The data should not have been deleted yet. |
1869 | EXPECT_EQ(deallocate_call_count, 0); |
1870 | // The new object points to the data. |
1871 | EXPECT_EQ(fb_buf_ptr->data(), data); |
1872 | EXPECT_EQ(fb_buf_ptr->size(), data_size); |
1873 | // The old object points to nothing. |
1874 | // @lint-ignore CLANGTIDY bugprone-use-after-move |
1875 | EXPECT_EQ(fb_buf_local.data(), nullptr); |
1876 | // @lint-ignore CLANGTIDY bugprone-use-after-move |
1877 | EXPECT_EQ(fb_buf_local.size(), 0); |
1878 | |
1879 | // The top-level torch::jit::DetachedBuffer. |
1880 | auto wrapped_buf = |
1881 | new DetachedBuffer(fb_buf_ptr->data(), fb_buf_ptr->size(), fb_buf_ptr); |
1882 | EXPECT_EQ(wrapped_buf->data(), data); |
1883 | EXPECT_EQ(wrapped_buf->size(), data_size); |
1884 | |
1885 | // The unique_ptr that owns the torch::jit::DetachedBuffer and its contents. |
1886 | { |
1887 | DetachedBuffer::UniqueDetachedBuffer unique_buf = |
1888 | DetachedBufferTestingFriend::make_unique_detached_buffer(wrapped_buf); |
1889 | EXPECT_EQ(unique_buf->data(), data); |
1890 | EXPECT_EQ(unique_buf->size(), data_size); |
1891 | |
1892 | // The data should not have been deleted yet. |
1893 | EXPECT_EQ(deallocate_call_count, 0); |
1894 | } |
1895 | |
1896 | // Now that the unique_ptr is out of scope, the data should have been deleted. |
1897 | EXPECT_EQ(deallocate_call_count, 1); |
1898 | } |
1899 | |
1900 | TEST(FlatbufferTest, DetachedBufferNullOwner) { |
1901 | // a torch::jit::DetachedBuffer with a null internal owner. |
1902 | std::vector<uint8_t> data(4); |
1903 | auto wrapped_buf = new DetachedBuffer(data.data(), data.size()); |
1904 | |
1905 | // A unique_ptr that owns the torch::jit::DetachedBuffer and its contents. |
1906 | { |
1907 | DetachedBuffer::UniqueDetachedBuffer unique_buf = |
1908 | DetachedBufferTestingFriend::make_unique_detached_buffer(wrapped_buf); |
1909 | EXPECT_EQ(unique_buf->data(), data.data()); |
1910 | EXPECT_EQ(unique_buf->size(), data.size()); |
1911 | } |
1912 | |
1913 | // The DetachedBuffer should have been destroyed when the UniqueDetachedBuffer |
1914 | // went out of scope. If we didn't crash or get any ASAN warnings, we should |
1915 | // be good. |
1916 | } |
1917 | |
1918 | // |
1919 | // Do not add tests here unless they require flatbuffers types. See comment at |
1920 | // the beginning of this section. |
1921 | // |
1922 | |
1923 | } // namespace jit |
1924 | } // namespace torch |
1925 | |