#include <torch/csrc/jit/mobile/flatbuffer_loader.h>

#ifdef FLATBUFFERS_VERSION_MAJOR
#error "flatbuffer_loader.h must not include any flatbuffers headers"
#endif // FLATBUFFERS_VERSION_MAJOR

#include <array>
#include <istream>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/ScopeExit.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/custom_class.h>

#ifndef DISABLE_UPGRADER
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#endif

#ifdef _WIN32
#include <malloc.h>
#else
#include <cstdlib>
#endif

#if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
#include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
namespace flatbuffers = flatbuffers_fbsource;
#define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
#else
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
#endif

namespace torch {
namespace jit {

// Our own alignment requirement does not need to be exactly the same as what
// flatbuffers supports, but what flatbuffers supports needs to satisfy our
// requirement.
static_assert(
    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
    "Sizes must be compatible");
static_assert(
    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
        kFlatbufferDataAlignmentBytes,
    "Must be a power of 2");

namespace {

static constexpr c10::string_view kCustomClassPrefix =
    "__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";

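// FlatbufferLoader turns a deserialized mobile::serialization::Module (the
// flatbuffer root table) into a runnable mobile::Module. It keeps per-parse
// caches: parsed IValues, class types, lazily materialized storages, and
// memoized type annotations. A fresh loader should be used per module;
// parseModule() resets most of this state.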
class FlatbufferLoader final {
 public:
  FlatbufferLoader();

  typedef IValue (
      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
  void registerIValueParser(
      mobile::serialization::IValueUnion ivalue_type,
      IValueParser parser);
  mobile::Module parseModule(mobile::serialization::Module* module);

  void extractJitSourceAndConstants(
      ExtraFilesMap* jit_sources,
      std::vector<IValue>* constants);

  typedef TypePtr (*TypeResolver)(
      const std::string& type_str,
      std::shared_ptr<CompilationUnit> cu);

  void internal_registerTypeResolver(TypeResolver type_resolver);

  IValue& getIValue(uint32_t pos) {
    TORCH_CHECK(pos < all_ivalues_.size());
    return all_ivalues_[pos];
  }

  mobile::Function* getFunction(uint32_t pos) {
    return all_functions_[pos];
  }

  ClassTypePtr getType(uint32_t pos) {
    TORCH_CHECK(pos < all_types_.size());
    return all_types_[pos];
  }

  c10::Storage getStorage(uint32_t index);
  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
  ClassTypePtr getOrCreateClassTypeForObject(
      const mobile::serialization::Object* object);

  const mobile::serialization::Module* getCurrentFlatbufferInput() {
    return module_;
  }

  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
    should_copy_tensor_memory_ = should_copy_tensor_memory;
  }

  std::shared_ptr<mobile::CompilationUnit> mcu_;
  std::shared_ptr<CompilationUnit> cu_;

 private:
  IValue parseIValue(const mobile::serialization::IValue* ivalue);
  std::unique_ptr<mobile::Function> parseFunction(
      const mobile::serialization::Function* method);
  void parseAndPopulate(
      uint32_t i,
      const mobile::serialization::IValue* ivalue);

  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
  std::vector<ClassTypePtr> all_types_;
  std::unordered_set<uint32_t> initialized_types_;
  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
  std::vector<bool> storage_loaded_;
  std::vector<c10::Storage> storages_;
  std::vector<IValue> all_ivalues_;
  std::array<
      IValueParser,
      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
      ivalue_parsers_;
  TypeResolver type_resolver_ = nullptr;
  mobile::serialization::Module* module_ = nullptr;
  bool module_parsed_ = false;
  bool should_copy_tensor_memory_ = false;
  // Elements [0, mobile_ivalue_size_) of all_ivalues_ come from the mobile
  // module itself; elements past that come from JIT sources and constants.
  uint32_t mobile_ivalue_size_ = 0;
};

IValue parseList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTensor(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTuple(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDict(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseObject(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseEnum(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);

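// Default type resolver. Resolution order: custom classes
// ("__torch__.torch.classes.*") come from the custom-class registry;
// "__torch__.*" and "torch.jit.*" names map to (possibly newly created)
// class types in the CompilationUnit; everything else is parsed as a builtin
// annotation string, e.g. "List[int]" or "Dict[str, Tensor]".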
TypePtr resolveType(
    const std::string& type_string,
    std::shared_ptr<CompilationUnit> cu) {
  TypePtr type;
  c10::string_view type_str(type_string);
  if (type_str.starts_with(kCustomClassPrefix)) {
    type = getCustomClass(type_string);
    TORCH_CHECK(
        type, "The implementation of class ", type_string, " cannot be found.");
  } else if (
      type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
    c10::QualifiedName qn(type_string);
    if (cu->get_class(qn) == nullptr) {
      auto classtype = ClassType::create(qn, cu, true);
      cu->register_type(classtype);
      type = classtype;
    } else {
      type = cu->get_class(qn);
    }
  } else {
    type = c10::parseType(type_string);
  }
  return type;
}

FlatbufferLoader::FlatbufferLoader()
    : mcu_(std::make_shared<mobile::CompilationUnit>()),
      cu_(std::make_shared<CompilationUnit>()),
      ivalue_parsers_{nullptr} {
  registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
  registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
  registerIValueParser(
      mobile::serialization::IValueUnion::IntList, &parseIntList);
  registerIValueParser(
      mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
  registerIValueParser(
      mobile::serialization::IValueUnion::BoolList, &parseBoolList);
  registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
  registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
  registerIValueParser(
      mobile::serialization::IValueUnion::Object, &parseObject);
  registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::EnumValue, &parseEnum);
  internal_registerTypeResolver(&resolveType);
}

void FlatbufferLoader::registerIValueParser(
    mobile::serialization::IValueUnion ivalue_type,
    IValueParser parser) {
  ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
}

void FlatbufferLoader::internal_registerTypeResolver(
    TypeResolver type_resolver) {
  type_resolver_ = type_resolver;
}

void parseExtraFilesFromVector(
    const flatbuffers::Vector<flatbuffers::Offset<
        torch::jit::mobile::serialization::ExtraFile>>* files,
    ExtraFilesMap* extra_files) {
  for (uint32_t i = 0; i < files->size(); ++i) {
    const auto* extra_file = files->Get(i);
    (*extra_files)[extra_file->name()->str()] = extra_file->content()->str();
  }
}

void parseExtraFiles(
    mobile::serialization::Module* module,
    ExtraFilesMap& extra_files) {
  auto extra_files_offsets = module->extra_files();
  parseExtraFilesFromVector(extra_files_offsets, &extra_files);
}

void FlatbufferLoader::parseAndPopulate(
    uint32_t i,
    const mobile::serialization::IValue* ivalue) {
  if (const auto* func = ivalue->val_as_Function()) {
    auto func_ptr = parseFunction(func);
    all_functions_[i] = func_ptr.get();
    mcu_->register_function(std::move(func_ptr));
  } else {
    all_ivalues_[i] = parseIValue(ivalue);
  }
}

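// parseModule makes a single pass over the first mobile_ivalue_size_
// serialized IValues, materializing functions and values, then wires each
// parsed function onto its owning class type and wraps the module's state
// object in a mobile::Module.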
mobile::Module FlatbufferLoader::parseModule(
    mobile::serialization::Module* module) {
  module_ = module;
  all_ivalues_.clear();
  all_types_.clear();
  storages_.clear();
  storage_loaded_.clear();
  module_parsed_ = false;

  const auto* ivalues = module->ivalues();
  all_ivalues_.resize(ivalues->size());
  all_types_.resize(module->object_types()->size());
  storages_.resize(module->storage_data_size());
  storage_loaded_.resize(module->storage_data_size(), false);

  mobile_ivalue_size_ = module_->mobile_ivalue_size();
  if (mobile_ivalue_size_ == 0) {
    mobile_ivalue_size_ = ivalues->size();
  }

  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  IValue& module_ivalue = getIValue(module->state_obj());

  // register functions
  for (const auto& f : all_functions_) {
    uint32_t class_index =
        ivalues->Get(f.first)->val_as_Function()->class_type();
    ClassTypePtr class_type = all_types_[class_index];
    class_type->addMethod(f.second);
  }

  module_parsed_ = true;
  auto m = mobile::Module(module_ivalue.toObject(), mcu_);
  m.set_min_operator_version(module->operator_version());
  m.set_bytecode_version(module->bytecode_version());
  return m;
}

void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}

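// Reconstructs a mobile::Function from its serialized form: instructions,
// constants, operators, type annotations, register size, and (best-effort)
// schema. If the serialized operator version predates the runtime's
// kProducedFileFormatVersion, upgrader functions are applied.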
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    const mobile::serialization::Function* method) {
  auto function = std::make_unique<mobile::Function>(
      c10::QualifiedName(method->qn()->str()));
  // TODO(qihan) add debug handle
  // const auto* debug_handle = method->debug_info()->debug_handle();
  for (const auto* inst : *method->instructions()) {
    function->append_instruction(
        static_cast<OpCode>(inst->op()), inst->x(), inst->n());
  }

  for (uint32_t i : *method->constants()) {
    function->append_constant(getIValue(i));
  }

  appendUpgraderFunctions(function.get());
  // Decide whether an upgrader is needed: operators serialized with an older
  // version than the current runtime must be upgraded.
  const uint32_t operator_version = module_->operator_version();
  bool use_upgrader =
      (operator_version < caffe2::serialize::kProducedFileFormatVersion);

  for (const auto* op : *method->operators()) {
    c10::optional<int> num_args = c10::nullopt;
    if (op->num_args_serialized() > -1) {
      num_args = op->num_args_serialized();
    }

    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);
  }

  function->initialize_operators(true);

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
  }

  // If an upgrader is needed, change the OP instructions to CALL
  // instructions. (In a follow-up PR, use_upgrader will be passed to the
  // parseInstruction function, which will make the actual change.)
  if (use_upgrader) {
#ifndef DISABLE_UPGRADER
    applyUpgrader(function.get(), operator_version);
#endif
  }

  function->set_register_size(method->register_size());
  if (method->schema()) {
    try {
      auto parseArgList = [this](const auto* args_fb) {
        std::vector<c10::Argument> args;
        for (const auto* arg_tb : *args_fb) {
          IValue default_value = getIValue(arg_tb->default_value());
          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
          auto arg = c10::Argument(
              arg_tb->name()->str(),
              std::move(type_ptr),
              c10::nullopt /*N*/,
              std::move(default_value));
          args.emplace_back(std::move(arg));
        }
        return args;
      };
      c10::FunctionSchema schema(
          method->qn()->str(),
          "" /*overload_name*/,
          parseArgList(method->schema()->arguments()),
          parseArgList(method->schema()->returns()),
          false /*is_varargs*/,
          false /*is_varret*/);

      function->setSchema(std::move(schema));
    } catch (const c10::Error&) {
      // Schema parsing is best-effort; leave the schema unset on failure.
    }
  }
  return function;
}

IValue parseEnum(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* enum_val = ivalue.val_as_EnumValue();
  auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
                       ->cast<c10::EnumType>();
  AT_ASSERT(
      enum_type,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
  IValue val = loader.getIValue(enum_val->value());
  for (const auto& p : enum_type->enumNamesValues()) {
    if (p.second == val) {
      auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
          enum_type, p.first, p.second);
      return IValue(std::move(enum_holder));
    }
  }
  AT_ASSERT(
      false,
      "Enum value not found for type: " + enum_val->type_name()->str());
}

IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  switch (ivalue.val_type()) {
    case mobile::serialization::IValueUnion::NONE:
      return {};
    case mobile::serialization::IValueUnion::Int:
      return ivalue.val_as_Int()->int_val();
    case mobile::serialization::IValueUnion::Bool:
      return ivalue.val_as_Bool()->bool_val();
    case mobile::serialization::IValueUnion::Double:
      return ivalue.val_as_Double()->double_val();
    case mobile::serialization::IValueUnion::ComplexDouble: {
      const auto* comp = ivalue.val_as_ComplexDouble();
      return c10::complex<double>(comp->real(), comp->imag());
    }
    case mobile::serialization::IValueUnion::String:
      return ivalue.val_as_String()->data()->str();
    case mobile::serialization::IValueUnion::Device: {
      return c10::Device(ivalue.val_as_Device()->str()->str());
    }
    default:
      return {};
  }
}

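// Builds an at::Tensor from serialized metadata. Quantized tensors go through
// the matching _empty_*_quantized factory; every tensor then shares (or
// copies, see getStorage) the flatbuffer-backed storage, with sizes and
// strides applied directly on the TensorImpl.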
at::Tensor parseTensorFromMetadata(
    FlatbufferLoader* loader,
    const mobile::serialization::TensorMetadata* tensor_md) {
  at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
  auto options = at::CPU(type).options();
  at::Tensor tensor;
  if (tensor_md->quantized_schema() != nullptr) {
    // is quantized
    const auto* schema = tensor_md->quantized_schema();
    auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
    switch (qscheme_type) {
      case at::kPerTensorAffine: {
        tensor = at::_empty_affine_quantized(
            {0}, options, schema->scale(), schema->zero_point());
      } break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
        at::Tensor zero_points =
            parseTensorFromMetadata(loader, schema->zero_points());
        tensor = at::_empty_per_channel_affine_quantized(
            {0}, scales, zero_points, schema->axis(), options);
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(qscheme_type));
        break;
    }
  } else {
    tensor = at::empty({0}, options);
  }
  at::TensorImpl* impl = tensor.unsafeGetTensorImpl();

  c10::Storage storage;
  storage = loader->getStorage(tensor_md->storage_location_index());
  impl->set_storage_keep_dtype(storage);
  impl->set_storage_offset(tensor_md->storage_offset());

  std::vector<int64_t> size{
      tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
  std::vector<int64_t> stride{
      tensor_md->strides()->begin(), tensor_md->strides()->end()};
  impl->set_sizes_and_strides(size, stride);
#ifndef MIN_EDGE_RUNTIME
  tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
#endif
  return tensor;
}

IValue parseTensor(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::TensorMetadata* tensor_md =
      ivalue.val_as_TensorMetadata();
  return parseTensorFromMetadata(&loader, tensor_md);
}

IValue parseList(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::List* list = ivalue.val_as_List();
  auto res = c10::impl::GenericList(AnyType::get());
  for (int i : *list->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
  res.unsafeSetElementType(type->containedType(0));
  return res;
}

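// Copies a flatbuffers vector of primitives into a std::vector<T> in one
// shot; used by the Int/Double/Bool list parsers below.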
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  return {list->items()->begin(), list->items()->end()};
}

IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_IntList();
  return parseListNative<int64_t>(list);
}

IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_DoubleList();
  return parseListNative<double>(list);
}

IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_BoolList();
  std::vector<uint8_t> res = parseListNative<uint8_t>(list);
  c10::List<bool> boollist;
  for (auto x : res) {
    boollist.push_back(x);
  }
  return boollist;
}

IValue parseTuple(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto& tuple = ivalue.val_as_Tuple();
  std::vector<IValue> res;
  for (int i : *tuple->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  return c10::ivalue::Tuple::create(res);
}

IValue parseDict(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* dict = ivalue.val_as_Dict();
  auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
  const auto* keys = dict->keys();
  const auto* values = dict->values();
  for (size_t i = 0; i < keys->size(); ++i) {
    uint32_t key = keys->Get(i);
    uint32_t val = values->Get(i);
    result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
  }
  auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
  result.unsafeSetKeyType(type->containedType(0));
  result.unsafeSetValueType(type->containedType(1));
  return result;
}

ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
    const mobile::serialization::Object* object) {
  auto cls = getType(object->type_index());
  const mobile::serialization::ObjectType* obj_type =
      module_->object_types()->Get(object->type_index());
  if (cls == nullptr) {
    c10::string_view qn_str(
        obj_type->type_name()->c_str(), obj_type->type_name()->size());
    if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
      c10::QualifiedName qn(obj_type->type_name()->str());
      cls = cu_->get_class(qn);
      if (cls == nullptr) {
        cls = ClassType::create(qn, cu_, true);
        cu_->register_type(cls);
      }
    } else {
      cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
    }
    TORCH_CHECK(object->type_index() < all_types_.size());
    all_types_[object->type_index()] = cls;

    if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = getIValue(object->attrs()->Get(i));
        // Need to use concrete object's field's type to set type of field.
        cls->addAttribute(
            obj_type->attr_names()->Get(i)->str(),
            val.type<c10::DynamicType>());
      }
    }
    initialized_types_.insert(object->type_index());
  }
  return cls;
}

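// Objects are serialized in one of three shapes: plain attribute bags
// (CLASS_WITH_FIELD), classes restored via their __setstate__ method
// (CLASS_WITH_SETSTATE), and registered C++ custom classes (CUSTOM_CLASS),
// which also go through __setstate__.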
IValue parseObject(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::Object* object = ivalue.val_as_Object();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
  const auto* cur_input = loader.getCurrentFlatbufferInput();
  const mobile::serialization::ObjectType* obj_type =
      cur_input->object_types()->Get(object->type_index());
  auto cls = loader.getOrCreateClassTypeForObject(object);
  Stack stack;
  switch (obj_type->type()) {
    case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
      auto obj = c10::ivalue::Object::create(
          at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = loader.getIValue(object->attrs()->Get(i));
        obj->setSlot(i, std::move(val));
      }
      return obj;
    }
    case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
      IValue input = loader.getIValue(object->state());
      mobile::Function* setstate = loader.getFunction(object->setstate_func());
      auto obj =
          c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      setstate->run(stack);
      return obj;
    }
    case mobile::serialization::TypeType::CUSTOM_CLASS: {
      auto custom_class_type =
          torch::jit::getCustomClass(cls->name()->qualifiedName());
      IValue input = loader.getIValue(object->state());
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    }
    default:
      AT_ASSERT(false, "Unsupported TypeType for serialized object");
  }
}

IValue FlatbufferLoader::parseIValue(
    const mobile::serialization::IValue* ivalue) {
  return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
      *this, *ivalue);
}

void deleteNothing2(void*);
void deleteNothing2(void*) {}

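// Storages are materialized lazily and cached. With copying enabled, the
// bytes are duplicated into a CPU allocation; otherwise the storage aliases
// the flatbuffer memory directly, using the no-op deleter above, so the
// buffer must outlive every tensor that references it.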
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();

    at::DataPtr data;
    if (should_copy_tensor_memory_) {
      auto* allocator = at::GetCPUAllocator();
      data = allocator->allocate(size);
      memcpy(data.get(), storage->data()->data(), size);
    } else {
      void* ptr = static_cast<void*>(storage->mutable_data()->data());
      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
  }
  return storages_[index];
}

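// Type annotations are memoized by the flatbuffers::String pointer, which is
// stable for the lifetime of the buffer, so each distinct annotation string
// is resolved at most once per load.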
TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
    const flatbuffers::String* offset) {
  auto iter = type_annotations_.find(offset);
  if (iter != type_annotations_.end()) {
    return iter->second;
  }
  TypePtr type = type_resolver_(offset->str(), cu_);
  type_annotations_[offset] = type;
  return type;
}

void FlatbufferLoader::extractJitSourceAndConstants(
    ExtraFilesMap* jit_sources,
    std::vector<IValue>* constants) {
  AT_ASSERT(
      module_parsed_,
      "Need to first parse a flatbuffer file before extracting jit_sources");

  const auto* ivalues = module_->ivalues();
  for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  // register functions
  for (const auto& f : all_functions_) {
    if (f.first >= mobile_ivalue_size_) {
      uint32_t class_index =
          ivalues->Get(f.first)->val_as_Function()->class_type();
      ClassTypePtr class_type = all_types_[class_index];
      class_type->addMethod(f.second);
    }
  }
  const auto* jit_constants = module_->jit_constants();
  for (uint32_t i = 0; i < jit_constants->size(); ++i) {
    constants->emplace_back(getIValue(jit_constants->Get(i)));
  }
  parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}

} // namespace

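// Parses a flatbuffer-serialized module from `data` without taking ownership
// of the buffer. Unless should_copy_tensor_memory is set, the returned
// module's tensors alias `data`, so the caller must keep the buffer alive for
// the module's lifetime.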
mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
  // TODO(T128189662): If not copying, enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  FlatbufferLoader loader;
  loader.setShouldCopyTensorMemory(should_copy_tensor_memory);

  // Flatbuffer doesn't seem to have a way to provide the buffer size when
  // interacting with the buffer.
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  mobile::Module m = loader.parseModule(flatbuffer_module);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  return m;
}

mobile::Module parse_and_initialize_mobile_module(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  mobile::Module m = parse_and_initialize_mobile_module(
      data.get(),
      size,
      device,
      extra_files,
      /*should_copy_tensor_memory=*/false);
  m.set_delete_memory(std::move(data));
  return m;
}

mobile::Module parse_and_initialize_mobile_module_for_jit(
    void* data,
    size_t,
    ExtraFilesMap& jit_sources,
    std::vector<IValue>& jit_constants,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
  // TODO(T128189662): Enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  FlatbufferLoader loader;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  mobile::Module m = loader.parseModule(flatbuffer_module);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }

  loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
  return m;
}

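// Typical entry point. A minimal usage sketch (the file name and input shape
// are illustrative, not part of this API):
//
//   ExtraFilesMap extra_files;
//   mobile::Module m = load_mobile_module_from_file(
//       "model.ptl", c10::nullopt, &extra_files);
//   std::vector<c10::IValue> inputs{at::ones({1, 3, 224, 224})};
//   auto output = m.forward(inputs);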
mobile::Module load_mobile_module_from_file(
    const std::string& filename,
    c10::optional<c10::Device> device,
    ExtraFilesMap* extra_files) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_file_content(filename.c_str());
  return parse_and_initialize_mobile_module(
      std::move(data), size, device, extra_files);
}

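// Reads just the bytecode version from a serialized module (stream, file, or
// raw bytes) without constructing the module.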
uint64_t get_bytecode_version(std::istream& in) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_stream_content(in);
  return get_bytecode_version_from_bytes(data.get());
}

uint64_t get_bytecode_version(const std::string& filename) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_file_content(filename.c_str());
  return get_bytecode_version_from_bytes(data.get());
}

uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
      "Format error");
  auto* flatbuffer_module =
      mobile::serialization::GetMutableModule(flatbuffer_content);
  return flatbuffer_module->bytecode_version();
}

mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
  auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
  mobile::ModuleInfo minfo;
  minfo.operator_version = ff_module->operator_version();
  minfo.bytecode_version = ff_module->bytecode_version();

  uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size();
  if (mobile_ivalue_size == 0) {
    mobile_ivalue_size = ff_module->ivalues()->size();
  }

  std::vector<std::string> type_name_list;
  for (uint32_t i = 0; i < mobile_ivalue_size; i++) {
    const auto* ival = ff_module->ivalues()->Get(i);
    if (const auto* func = ival->val_as_Function()) {
      minfo.function_names.insert(func->qn()->str());
      for (const auto* op : *func->operators()) {
        at::OperatorName opname(op->name()->str(), op->overload_name()->str());
        minfo.opname_to_num_args[mobile::operator_str(opname)] =
            op->num_args_serialized();
      }
      for (const auto* type_ann : *func->type_annotations()) {
        type_name_list.push_back(type_ann->str());
      }
    }
  }
  c10::TypeParser parser(type_name_list);
  parser.parseList();
  minfo.type_names = parser.getContainedTypes();
  return minfo;
}

mobile::Module load_mobile_module_from_stream_with_copy(
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_stream_content(in);
  return parse_and_initialize_mobile_module(
      std::move(data), size, device, extra_files);
}

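// Variant of module parsing that never runs __setstate__ or custom-class
// logic: every serialized object is reconstructed as a plain attribute bag
// (the CLASS_WITH_FIELD path only). Useful for inspecting a module's state
// without executing any of its methods.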
mobile::Module parse_flatbuffer_no_object(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device) {
  (void)device;
  (void)size;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
  FlatbufferLoader loader;
  // Replace the default Object parser with one that handles only the
  // CLASS_WITH_FIELD case, so no __setstate__ function is invoked.
  loader.registerIValueParser(
      mobile::serialization::IValueUnion::Object,
      +[](FlatbufferLoader& loader,
          const mobile::serialization::IValue& ivalue) {
        const mobile::serialization::Object* object = ivalue.val_as_Object();
        auto cls = loader.getOrCreateClassTypeForObject(object);
        auto obj = c10::ivalue::Object::create(
            at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
        for (uint32_t i = 0; i < object->attrs()->size(); i++) {
          IValue val = loader.getIValue(object->attrs()->Get(i));
          obj->setSlot(i, std::move(val));
        }
        return static_cast<c10::IValue>(obj);
      });

  mobile::Module m = loader.parseModule(flatbuffer_module);
  m.set_delete_memory(std::move(data));
  return m;
}

bool register_flatbuffer_loader() {
  return true;
}

} // namespace jit
} // namespace torch