#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <unordered_set>

#if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
#include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
namespace flatbuffers = flatbuffers_fbsource;
#define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
#else
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
#endif

// Tests go in torch::jit
namespace torch {
namespace jit {

namespace {
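// Parses a flatbuffer-serialized mobile module from a raw byte buffer. When
// `should_copy_tensor_memory` is false, the parsed module may reference
// tensor data inside `data`, so the buffer must outlive the returned module.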
mobile::Module parse_mobile_module(
    void* data,
    size_t size,
    bool should_copy_tensor_memory = false) {
  return parse_and_initialize_mobile_module(
      static_cast<char*>(data),
      size,
      /*device=*/c10::nullopt,
      /*extra_files=*/nullptr,
      should_copy_tensor_memory);
}
} // namespace

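// Round-trip pattern used by most tests in this file: run the JIT module,
// convert it to a mobile module, serialize the mobile module to flatbuffer
// bytes, parse the bytes back, and check that every path produces the same
// output.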
TEST(FlatbufferTest, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto res2 = bc2.forward(inputs);
  auto resd2 = res2.toTensor();
  ASSERT_TRUE(resd2.equal(refd));
}

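// Same as UpsampleNearest2d above, but asks the parser to copy tensor memory
// out of the flatbuffer so the parsed module does not depend on the
// serialized buffer staying alive.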
TEST(FlatbufferTest, UpsampleNearest2dWithCopyTensorMemory) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true);

  auto res2 = bc2.forward(inputs);
  auto resd2 = res2.toTensor();
  ASSERT_TRUE(resd2.equal(refd));
}

TEST(FlatbufferTest, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(mobile_optimized);
  m.setattr("mobile_optimized", false);
  bc = jitModuleToMobile(m, options);
  mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(!mobile_optimized);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto mobile_optimized2 = bc2.attr("mobile_optimized", false).toBool();
  AT_ASSERT(!mobile_optimized2);
}

TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
        def test_func(self, x, b : int = 4):
          return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
        def add_with_default_arg(self, x, b : int = 4):
          return self.foo + x + b
        def test_func(self, x):
          return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
        def test_func(self, x):
          b = 4
          return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& test_func = bc.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);

    auto buff = save_mobile_module_to_bytes(bc);
    mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
    const auto& test_func2 = bc2.get_method("test_func");
    IValue res2;
    for (int i = 0; i < 3; ++i) {
      res2 = test_func2({minput});
    }
    auto resd2 = res2.toTensor().item<float>();
    AT_ASSERT(resd2 == refd);
  }
}

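// _backport_for_mobile rewrites a serialized module down to an older bytecode
// version (version 5 here) so that models exported with a newer toolchain can
// still be loaded by older runtimes.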
#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, FlatbufferBackPortTest) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");
  std::stringstream ss;
  m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);

  std::stringstream oss;
  bool backPortSuccess = _backport_for_mobile(ss, oss, 5);
  ASSERT_TRUE(backPortSuccess);
}
#endif // !defined(FB_XPLAT_BUILD)

TEST(FlatbufferTest, ExtraFiles) {
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);
  std::unordered_map<std::string, std::string> extra_files;
  extra_files["metadata.json"] = "abc";
  extra_files["mobile_info.json"] = "{\"key\": 23}";

  std::unordered_map<std::string, std::string> loaded_extra_files;
  std::stringstream ss;
  module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);

  loaded_extra_files["metadata.json"] = "";
  auto mobile_module = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);

  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

  // load it twice using the same stream
  auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);

  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
}

TEST(FlatbufferTest, Conv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(FlatbufferTest, ConvWithCopyTensorMemory) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true);

  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(FlatbufferTest, Inline) {
  Module m("m");
  m.define(R"JIT(
    def foo1(self, x):
      return x + 1

    def foo2(self, x):
      return self.foo1(x) + 2

    def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  std::vector<torch::jit::IValue> inputs2({torch::ones({})});
  output = bc2.get_method("foo3")(inputs2);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(FlatbufferTest, InlineWithCopyTensorMemory) {
  Module m("m");
  m.define(R"JIT(
    def foo1(self, x):
      return x + 1

    def foo2(self, x):
      return self.foo1(x) + 2

    def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true);
  std::vector<torch::jit::IValue> inputs2({torch::ones({})});
  output = bc2.get_method("foo3")(inputs2);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(FlatbufferTest, Tuple) {
  Module m("m");
  m.define(R"JIT(
    def foo(self, x):
      return (1, 2, x + 3)

    def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  output = bc2.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
}

TEST(FlatbufferTest, Dict) {
  Module m("m");
  m.define(R"JIT(
    def foo(self, x):
      return {"result": x + 1}

    def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  output = bc2.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}

TEST(FlatbufferTest, Prim) {
  Module m("m");
  m.define(R"JIT(
    def forward(self, x):
      return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc2.get_method("forward")(bcinputs);
  }
  auto resi2 = res.toInt();
  AT_ASSERT(resi2 == refi);
}

TEST(FlatbufferTest, PrimScalar) {
  Module m("m");
  m.define(R"JIT(
    def forward(self, x):
      return int(x.item())
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc2.get_method("forward")(bcinputs);
  }
  auto resi2 = res.toInt();
  AT_ASSERT(resi2 == refi);
}

TEST(FlatbufferTest, WrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  ASSERT_THROWS_WITH_MESSAGE(
      bc2.get_method("forward")(inputs), "is not defined");
}

TEST(FlatbufferTest, SetState) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc2.get_method("forward")(bcinputs);
  }

  auto resd2 = res.toTensor().item<float>();
  AT_ASSERT(resd2 == refd);
}

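// Custom class used to exercise TorchBind custom-class serialization through
// the flatbuffer path; it is registered with torch::class_ further below.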
class TorchBindFlatbufferTestStruct : public torch::jit::CustomClassHolder {
 public:
  std::string get(at::Tensor t) {
    std::stringstream ss;
    ss << "Hello! Your tensor has ";
    ss << t.numel();
    ss << " elements!";
    return ss.str();
  }
};

namespace {
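// Minimal resolver machinery so that m.define() in the BuiltinClass test can
// resolve "__torch__.torch.classes..." qualified names to custom classes.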
struct ClassNamespaceValue : public SugaredValue {
  explicit ClassNamespaceValue(c10::QualifiedName name)
      : basename_(std::move(name)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& name) override {
    const auto fullName = c10::QualifiedName(basename_, name);

    // Check to see if it is a custom class.
    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
      return std::make_shared<ClassValue>(custom_class);
    }

    // If it's not a custom class, assume it's another namespace
    // NOLINTNEXTLINE(performance-move-const-arg)
    return std::make_shared<ClassNamespaceValue>(std::move(fullName));
  }

  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
};

struct TestModuleResolver : public Resolver {
  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      GraphFunction& m,
      const SourceRange& loc) override {
    if (name == "torch") {
      return std::make_shared<BuiltinModule>("aten");
    } else if (name == "__torch__") {
      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
    }

    return nullptr;
  }

  TypePtr resolveType(const std::string& name, const SourceRange& loc)
      override {
    return nullptr;
  }
};
} // namespace

TEST(FlatbufferTest, BuiltinClass) {
  script::Module m("m");

  auto cls = getCustomClass(
      "__torch__.torch.classes._TorchScriptTesting._FlatbufferTest");
  TORCH_INTERNAL_ASSERT(cls);
  c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
  m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));

  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def __getstate__(self):
      return 1
    def __setstate__(self, a):
      self.my_obj = __torch__.torch.classes._TorchScriptTesting._FlatbufferTest()

    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )",
      std::make_shared<TestModuleResolver>());

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  std::string expected = "Hello! Your tensor has 12 elements!";
  auto res =
      bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  const auto& str2 = res.toStringRef();
  AT_ASSERT(str2 == expected);
}

TEST(FlatbufferTest, BuiltinFunction) {
  script::Module m("m");
  auto custom_class_obj = make_custom_class<TorchBindFlatbufferTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto res =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto str = res.toStringRef();
  std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(str == expected);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  res = bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  str = res.toStringRef();
  AT_ASSERT(str == expected);
}

TEST(FlatbufferTest, Eval) {
  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));
  m.eval();
  auto outputref = m.forward(inputs).toTensor();

  // save m in training mode to make sure that mobile eval() will correctly
  // change back to eval mode
  m.train();
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.eval();
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  bc2.eval();
  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(FlatbufferTest, FindWrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  ASSERT_TRUE(bc2.find_method("forward") == c10::nullopt);
}

TEST(FlatbufferTest, FindAndRunMethod) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.get_method("add_it")(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc.find_method("add_it");
    AT_ASSERT(method != c10::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());

  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc2.find_method("add_it");
    AT_ASSERT(method != c10::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  resd = res.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(FlatbufferTest, RunMethodVariadic) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_three(self, x, y):
      return self.foo + x + y
  )");

  std::vector<IValue> inputs;
  auto inputx = 5 * torch::ones({});
  auto inputy = 4 * torch::ones({});
  auto ref = m.run_method("add_three", inputx, inputy);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res = bc.run_method("add_three", inputx, inputy);

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  res = bc2.run_method("add_three", inputx, inputy);
  resd = res.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(FlatbufferTest, DuplicateSetState) {
  Module m("M");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  Module b("B");
  b.register_module("M0", m);
  b.register_module("M1", m);
  b.define(R"(
    def forward(self, x):
      return self.M0.forward(x) + self.M1.forward(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(b, options);
  const auto methods = bc.get_methods();
  const size_t expected_n = 3;
  ASSERT_EQ(methods.size(), expected_n);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  const auto methods2 = bc2.get_methods();
  ASSERT_EQ(methods2.size(), expected_n);
}

TEST(FlatbufferTest, OpNameExportFetchRootOperators) {
  torch::jit::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  m.eval();

  CompilationOptions options;
  mobile::Module ptl_model = jitModuleToMobile(m, options);
  std::set<std::string> operator_names =
      torch::jit::mobile::_export_operator_list(ptl_model);
  std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";

  auto buff = save_mobile_module_to_bytes(ptl_model);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  operator_names = torch::jit::mobile::_export_operator_list(bc2);
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";
}

TEST(FlatbufferTest, DefaultArgsConv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  inputs.emplace_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 1; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 1; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

namespace {
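// Shared helper: runs `method_name` on the JIT module, on its mobile
// conversion, and on a flatbuffer round-tripped copy, and checks that all
// three produce equal result tensors.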
void testLiteModuleCompareResultTensors(
    Module& m,
    const std::vector<torch::jit::IValue>& inputs,
    const std::string& method_name = "forward") {
  auto outputref = m.get_method(method_name)(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method(method_name)(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method(method_name)(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

static void testDefaultArgsPinv(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True)
    )");
  }

  std::vector<torch::jit::IValue> inputs;
  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 1; // a more stable matrix
  input = input.view({N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}
} // namespace

#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, DefaultArgsPinv) {
  // Test with different number of specified arguments.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinv(num_args);
  }

  // bytecode with one specified argument:
  // (6,
  //     ('__torch__.m.forward',
  //         (('instructions',
  //             (('STOREN', 1, 2),
  //                 ('DROPR', 1, 0),
  //                 ('MOVE', 2, 0),
  //                 ('OP', 0, 0),
  //                 ('RET', 0, 0))),
  //             ('operators', (('aten::linalg_pinv', '', 1),)),
  //             ('constants', (False, 1e-15)), # default constants are not used
  //             ('types', ()),
  //             ('register_size', 2)),
  //         (('arguments',
  //             ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //                 (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //             ('returns',
  //                 ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))

  // bytecode with 2 specified arguments:
  // (6,
  //     ('__torch__.m.forward',
  //         (('instructions',
  //             (('STOREN', 1, 2),
  //                 ('DROPR', 1, 0),
  //                 ('MOVE', 2, 0),
  //                 ('LOADC', 1, 0), # added LOADC for specified argument
  //                 ('OP', 0, 0),
  //                 ('RET', 0, 0))),
  //             ('operators', (('aten::linalg_pinv', '', 2),)),
  //             ('constants', (False, 1e-05)), # updated constant table
  //             ('types', ()),
  //             ('register_size', 2)),
  //         (('arguments',
  //             ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //                 (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //             ('returns',
  //                 ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))

  // bytecode with 3 specified arguments:
  // (6,
  //     ('__torch__.m.forward',
  //         (('instructions',
  //             (('STOREN', 1, 2),
  //                 ('DROPR', 1, 0),
  //                 ('MOVE', 2, 0),
  //                 ('LOADC', 1, 0),
  //                 ('LOADC', 0, 0),
  //                 ('OP', 0, 0),
  //                 ('RET', 0, 0))),
  //             ('operators', (('aten::linalg_pinv', '', 3),)),
  //             ('constants', (True, 1e-05)),
  //             ('types', ()),
  //             ('register_size', 2)),
  //         (('arguments',
  //             ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //                 (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //             ('returns',
  //                 ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
}

TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is specified, but the value is the same as the default
  // value. It's treated as "not specified" since the value can be fetched from
  // the schema.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");
  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
  auto arg_nums = code.op_to_num_specified_args();
  ASSERT_EQ(arg_nums.size(), 1);
  ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
  std::vector<torch::jit::IValue> inputs;
  const int N = 4;
  auto input = torch::rand({N, N, N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}

static void testDefaultArgsPinvWithOutArg(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, out=input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, out=input)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True, out=input)
    )");
  }

  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 10000; // a more stable matrix
  input = input.view({N, N});
  auto ref = m.run_method("forward", input);
  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
  TORCH_CHECK(input.equal(ref.toTensor()));
}

TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
  // Test with different number of specified arguments + out arg.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinvWithOutArg(num_args);
  }
}

TEST(FlatbufferTest, DefaultArgsWithOutArg) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  std::vector<IValue> inputs;
  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  auto ref = m.run_method("forward", input_x, input_h);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(4 * torch::ones({})));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto input_x2 = 2 * torch::ones({});
  auto input_h2 = torch::ones({});
  m.run_method("forward", input_x2, input_h2);
  bc2.run_method("forward", input_x2, input_h2);
  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
}

#endif // !defined(FB_XPLAT_BUILD)

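// Registers _TorchScriptTesting._FlatbufferTest so that the BuiltinClass test
// above can look it up via getCustomClass().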
namespace {
static auto reg =
    torch::class_<TorchBindFlatbufferTestStruct>(
        "_TorchScriptTesting",
        "_FlatbufferTest")
        .def(torch::init<>())
        .def("get", &TorchBindFlatbufferTestStruct::get)
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<TorchBindFlatbufferTestStruct>& self)
                -> int64_t { return 0; },
            // __setstate__
            [](int64_t state) {
              return c10::make_intrusive<TorchBindFlatbufferTestStruct>();
            });

} // namespace

TEST(FlatbufferTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Create 3 methods:
  //
  // 1. forward() returns a tensor with dtype=torch.int64 (4)
  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32 but
  //    the dtype is inferred by the input tensor's dtype
  //
  // If caching works correctly, then the result from the full-jit
  // module and the lite module will be the same. Otherwise, it
  // will be different if we don't correctly ignore the cache
  // entry for an operator that has a different number of
  // arguments.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  std::vector<torch::jit::IValue> inputs;
  testLiteModuleCompareResultTensors(m, inputs, "forward");
  testLiteModuleCompareResultTensors(m, inputs, "forward2");
  testLiteModuleCompareResultTensors(m, inputs, "forward3");
}

TEST(FlatbufferTest, OperatorSize1) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto& func = bc.get_method("forward").function();
  ASSERT_EQ(
      func.get_code().operator_input_sizes_.size(),
      func.get_code().operators_.size());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  const auto& func2 = bc2.get_method("forward").function();
  ASSERT_EQ(
      func2.get_code().operator_input_sizes_.size(),
      func2.get_code().operators_.size());
}

TEST(FlatbufferTest, BoolAndDoubleList) {
  Module m("m");
  c10::List<bool> boollist;
  boollist.push_back(false);
  IValue boollist_ival = boollist;
  IValue doublelist = std::vector<double>{2.0};
  m.register_attribute("bool_list", boollist_ival.type(), boollist_ival);
  m.register_attribute("double_list", doublelist.type(), doublelist);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());

  // If the attributes are read back as the wrong type, the conversion below
  // will raise an exception.
  auto boolval = bc2.attr("bool_list", {}).toBoolList().get(0);
  auto doubleval = bc2.attr("double_list", {}).toDoubleList().get(0);

  ASSERT_EQ(boolval, false);
  ASSERT_EQ(doubleval, 2.0);
}

TEST(FlatbufferTest, OperatorTest2) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
        def test_func(self, x, b : int = 4):
          return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
        def add_with_default_arg(self, x, b : int = 4):
          return self.foo + x + b
        def test_func(self, x):
          return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
        def test_func(self, x):
          b = 4
          return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& func = bc.get_method("test_func").function();
    ASSERT_EQ(
        func.get_code().operator_input_sizes_.size(),
        func.get_code().operators_.size());

    auto buff = save_mobile_module_to_bytes(bc);
    mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
    const auto& func2 = bc2.get_method("test_func").function();
    ASSERT_EQ(
        func2.get_code().operator_input_sizes_.size(),
        func2.get_code().operators_.size());
  }
}

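// Helper for the TestSourceFlatbuffer tests below: reloads a full JIT module
// (not just mobile bytecode) from flatbuffer bytes.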
Module jitModuleFromBuffer(void* data, size_t size) {
  // Make a copy of the data so we can use the existing API, which takes
  // ownership. The `data` param might point into the middle of a buffer, so we
  // can't safely take ownership of it directly.
  // @nolint CLANGTIDY cppcoreguidelines-no-malloc
  std::shared_ptr<char> copy(static_cast<char*>(malloc(size)), free);
  memcpy(copy.get(), data, size);

  ExtraFilesMap extra_files;
  return parse_and_initialize_jit_module(std::move(copy), size, extra_files);
}

TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  std::stringstream ss;
  m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);
  auto mm = _load_for_mobile(ss);
  auto m2 = load(ss);

  auto res = m2.forward(inputs);
  auto resm = mm.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  auto resmd = resm.toTensor();
  ASSERT_TRUE(resd.equal(refd));
  ASSERT_TRUE(resmd.equal(refd));
}

TEST(TestSourceFlatbuffer, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);
  auto data = save_jit_module_to_bytes(m);
  Module m2 = jitModuleFromBuffer(data->data(), data->size());
  bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
  AT_ASSERT(mobile_optimized);
  mobile::Module m3 = parse_mobile_module(data->data(), data->size());
  mobile_optimized = m3.attr("mobile_optimized", false).toBool();
  AT_ASSERT(mobile_optimized);
}

TEST(TestSourceFlatbuffer,
     MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
        def test_func(self, x, b : int = 4):
          return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
        def add_with_default_arg(self, x, b : int = 4):
          return self.foo + x + b
        def test_func(self, x):
          return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
        def test_func(self, x):
          b = 4
          return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    auto data = save_jit_module_to_bytes(m);
    Module m2 = jitModuleFromBuffer(data->data(), data->size());
    const auto& test_func = m2.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }
    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);

    mobile::Module m3 = parse_mobile_module(data->data(), data->size());
    const auto& test_func3 = m3.get_method("test_func");
    for (int i = 0; i < 3; ++i) {
      res = test_func3({minput});
    }
    resd = res.toTensor().item<float>();
    refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}

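// The upgrader tests below load models serialized at operator version 2 and
// check that the runtime applies operator upgraders: each upgraded call site
// shows up as a CALL instruction in the loaded bytecode.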
#if !defined(FB_XPLAT_BUILD)
// The following tests run in fbcode only
TEST(FlatbufferUpgraderTest, DivTensorV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 2, 0),
       ('TUPLE_CONSTRUCT', 3, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);
  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // All 3 operators will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 3);

  std::vector<IValue> inputs = {
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  auto actual_output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output_list = actual_output.toTuple()->elements();
  ASSERT_TRUE(actual_output_list[0].toTensor().equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivTensorOutV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_out_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 4),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'out'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 4))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})),
      IValue(3 * torch::ones({1})),
      IValue(torch::empty({1}))};
  m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = inputs[2].toTensor();
  // The out argument will be overwritten with the output
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivTensorInplaceV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Tensor'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = inputs[0].toTensor();
  // The self argument is modified in place with the result
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarReciprocalFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarReciprocalIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarScalarV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 5),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('MOVE', 2, 0),
       ('LOAD', 4, 0),
       ('OP', 1, 0),
       ('LOAD', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 5, 0),
       ('OP', 3, 0),
       ('TUPLE_CONSTRUCT', 4, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', ''),
       ('aten::div', 'float'),
       ('aten::div', ''),
       ('aten::div', 'int'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 5))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);
  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator will use the upgrader
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)};
  auto output = m_module.forward(inputs);
  auto output_list = output.toTupleRef().elements();
  auto expect_output = std::vector<IValue>(
      {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)});
  for (size_t i = 0; i < expect_output.size(); i++) {
    ASSERT_EQ(output_list[i], expect_output[i]);
  }
}

1679TEST(FlatbufferUpgraderTest, DivScalarIntV2) {
1680 std::string filePath(__FILE__);
1681 auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1682 test_model_file.append(
1683 "upgrader_models/test_versioned_div_scalar_int_v2.ptl.ff");
1684 /*
1685 (('__torch__.MyModuleInt.forward',
1686 (('instructions',
1687 (('STOREN', 1, 3),
1688 ('DROPR', 1, 0),
1689 ('MOVE', 2, 0),
1690 ('MOVE', 3, 0),
1691 ('OP', 0, 0),
1692 ('RET', 0, 0))),
1693 ('operators', (('aten::div', 'Scalar'),)),
1694 ('constants', ()),
1695 ('types', ()),
1696 ('register_size', 3))),)
1697 */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator in this model requires an upgrader, which is invoked via a
  // CALL instruction.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // The upgraded op should produce the expected result.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarInplaceFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
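  // A hypothetical TorchScript source consistent with the bytecode above
  // (a sketch; note the in-place aten::div_ op, and that the names x/s are
  // invented):
  //
  //   class MyModuleFloat(torch.nn.Module):
  //       def forward(self, x: Tensor, s: float):
  //           return x.div_(s)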

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator in this model requires an upgrader, which is invoked via a
  // CALL instruction.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // Since the op is in-place, the input tensor is overwritten with the result.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarInplaceIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
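  // As with the float variant above, a hypothetical source sketch with an
  // int divisor (names invented):
  //
  //   class MyModuleInt(torch.nn.Module):
  //       def forward(self, x: Tensor, n: int):
  //           return x.div_(n)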

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator in this model requires an upgrader, which is invoked via a
  // CALL instruction.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // Since the op is in-place, the input tensor is overwritten with the result.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

#endif // !defined(FB_XPLAT_BUILD)

//
// Tests that need access to internal flatbuffers types/functions.
// Do not add any other tests after this section.
//

} // namespace jit
} // namespace torch
namespace torch {
namespace jit {

/**
 * An Allocator that can only deallocate (using delete []), counting
 * the number of times that it has been asked to deallocate.
 */
class TestAllocator : public flatbuffers::Allocator {
 public:
  /**
   * *deallocate_call_count will be incremented whenever deallocate() is called.
   */
  explicit TestAllocator(int* deallocate_call_count)
      : deallocate_call_count_(deallocate_call_count) {}

  void deallocate(uint8_t* p, size_t /*size*/) override {
    *deallocate_call_count_ += 1;
    delete[] p;
  }

  uint8_t* allocate(size_t) override {
    TORCH_CHECK(false, "allocate() should not be called");
  }
  uint8_t* reallocate_downward(uint8_t*, size_t, size_t, size_t, size_t)
      override {
    TORCH_CHECK(false, "reallocate_downward() should not be called");
  }

 private:
  int* deallocate_call_count_;
};

/// Provides access to DetachedBuffer::destroy().
struct DetachedBufferTestingFriend {
  /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer.
  /// A copy of similar code in flatbuffer_serializer.cpp.
  static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer(
      DetachedBuffer* buf) {
    return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy);
  }
};

TEST(FlatbufferTest, DetachedBufferSmoke) {
  // Use a custom Allocator to watch the lifecycle of a
  // flatbuffers::DetachedBuffer.
  int deallocate_call_count = 0;
  TestAllocator alloc(&deallocate_call_count);

  // Data for the buffer. TestAllocator will free it with `delete []`.
  constexpr size_t data_size = 4;
  uint8_t* data = new uint8_t[data_size];

  // An internal buffer on the stack that owns the data.
  flatbuffers::DetachedBuffer fb_buf_local(
      &alloc, /*own_allocator=*/false, data, data_size, data, data_size);
  EXPECT_EQ(fb_buf_local.data(), data);
  EXPECT_EQ(fb_buf_local.size(), data_size);

  // Mimic the code inside save_mobile_module_to_bytes by transferring
  // ownership to a heap object.
  auto fb_buf_ptr = new flatbuffers::DetachedBuffer(std::move(fb_buf_local));
  // The data should not have been deleted yet.
  EXPECT_EQ(deallocate_call_count, 0);
  // The new object points to the data.
  EXPECT_EQ(fb_buf_ptr->data(), data);
  EXPECT_EQ(fb_buf_ptr->size(), data_size);
  // The old object points to nothing.
  // @lint-ignore CLANGTIDY bugprone-use-after-move
  EXPECT_EQ(fb_buf_local.data(), nullptr);
  // @lint-ignore CLANGTIDY bugprone-use-after-move
  EXPECT_EQ(fb_buf_local.size(), 0);

  // The top-level torch::jit::DetachedBuffer.
  auto wrapped_buf =
      new DetachedBuffer(fb_buf_ptr->data(), fb_buf_ptr->size(), fb_buf_ptr);
  EXPECT_EQ(wrapped_buf->data(), data);
  EXPECT_EQ(wrapped_buf->size(), data_size);

  // The unique_ptr that owns the torch::jit::DetachedBuffer and its contents.
  {
    DetachedBuffer::UniqueDetachedBuffer unique_buf =
        DetachedBufferTestingFriend::make_unique_detached_buffer(wrapped_buf);
    EXPECT_EQ(unique_buf->data(), data);
    EXPECT_EQ(unique_buf->size(), data_size);

    // The data should not have been deleted yet.
    EXPECT_EQ(deallocate_call_count, 0);
  }

  // Now that the unique_ptr is out of scope, the data should have been
  // deleted.
  EXPECT_EQ(deallocate_call_count, 1);
}

TEST(FlatbufferTest, DetachedBufferNullOwner) {
  // A torch::jit::DetachedBuffer with a null internal owner.
  std::vector<uint8_t> data(4);
  auto wrapped_buf = new DetachedBuffer(data.data(), data.size());

  // A unique_ptr that owns the torch::jit::DetachedBuffer and its contents.
  {
    DetachedBuffer::UniqueDetachedBuffer unique_buf =
        DetachedBufferTestingFriend::make_unique_detached_buffer(wrapped_buf);
    EXPECT_EQ(unique_buf->data(), data.data());
    EXPECT_EQ(unique_buf->size(), data.size());
  }

  // The DetachedBuffer should have been destroyed when the
  // UniqueDetachedBuffer went out of scope. If we didn't crash or get any
  // ASAN warnings, we should be good.
}

//
// Do not add tests here unless they require flatbuffers types. See comment at
// the beginning of this section.
//

} // namespace jit
} // namespace torch