1 | #ifdef FLATBUFFERS_VERSION_MAJOR |
2 | #error "flatbuffer_loader.h must not include any flatbuffers headers" |
3 | #endif // FLATBUFFERS_VERSION_MAJOR |
4 | |
5 | #include <array> |
6 | #include <istream> |
7 | #include <memory> |
8 | #include <string> |
9 | #include <tuple> |
10 | #include <unordered_map> |
11 | #include <unordered_set> |
12 | #include <utility> |
13 | #include <vector> |
14 | |
15 | #include <ATen/ATen.h> |
16 | #include <ATen/core/dynamic_type.h> |
17 | #include <ATen/core/ivalue.h> |
18 | #include <ATen/core/qualified_name.h> |
19 | #include <c10/core/CPUAllocator.h> |
20 | #include <c10/core/impl/alloc_cpu.h> |
21 | #include <c10/util/Exception.h> |
22 | #include <c10/util/Optional.h> |
23 | #include <c10/util/ScopeExit.h> |
24 | #include <caffe2/serialize/inline_container.h> |
25 | #include <torch/csrc/jit/mobile/file_format.h> |
26 | #include <torch/csrc/jit/mobile/flatbuffer_loader.h> |
27 | #include <torch/csrc/jit/mobile/function.h> |
28 | #include <torch/csrc/jit/mobile/import.h> |
29 | #include <torch/csrc/jit/mobile/interpreter.h> |
30 | #include <torch/csrc/jit/mobile/module.h> |
31 | #include <torch/csrc/jit/mobile/observer.h> |
32 | #include <torch/csrc/jit/mobile/type_parser.h> |
33 | #include <torch/csrc/jit/runtime/instruction.h> |
34 | #include <torch/csrc/jit/serialization/export_bytecode.h> |
35 | #include <torch/csrc/jit/serialization/import_export_constants.h> |
36 | #include <torch/csrc/jit/serialization/import_read.h> |
37 | #include <torch/custom_class.h> |
38 | |
39 | #ifndef DISABLE_UPGRADER |
40 | #include <torch/csrc/jit/mobile/parse_bytecode.h> |
41 | #include <torch/csrc/jit/mobile/upgrader_mobile.h> |
42 | #endif |
43 | |
44 | #ifdef _WIN32 |
45 | #include <malloc.h> |
46 | #else |
47 | #include <cstdlib> |
48 | #endif |
49 | |
50 | #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2) |
51 | #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT |
52 | namespace flatbuffers = flatbuffers_fbsource; |
53 | #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT |
54 | #else |
55 | #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT |
56 | #endif |
57 | |
58 | namespace torch { |
59 | namespace jit { |
60 | |
61 | // Our own alignment requirement does not need to be exactly the same as what |
62 | // flatbuffers supports, but what flatbuffers supports needs to satisfy our |
63 | // requirement. |
// The loader may hand out non-owning pointers into the flatbuffer for tensor
// data, so flatbuffers' maximum alignment must cover our alignment contract.
static_assert(
    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
    "Sizes must be compatible");
// Alignment math below (and in the serializer) assumes a power-of-two value;
// x & ~(x - 1) isolates the lowest set bit, which equals x only for powers of 2.
static_assert(
    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
        kFlatbufferDataAlignmentBytes,
    "Must be a power of 2");
71 | |
72 | namespace { |
73 | |
// Qualified-name prefixes used to classify serialized type strings:
// custom (torchbind) classes, regular TorchScript classes, and jit types.
static constexpr c10::string_view kCustomClassPrefix =
    "__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
78 | |
// Stateful deserializer that turns a flatbuffer-encoded
// mobile::serialization::Module into a runnable mobile::Module.
// Parsing fills per-index caches (ivalues, class types, storages) keyed by
// the integer handles used inside the flatbuffer. A loader instance is not
// thread-safe and is intended to parse one module at a time.
class FlatbufferLoader final {
 public:
  FlatbufferLoader();

  // Handler invoked to decode one IValueUnion variant.
  typedef IValue (
      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
  // Installs `parser` as the handler for `ivalue_type`, replacing any
  // previously registered handler.
  void registerIValueParser(
      mobile::serialization::IValueUnion ivalue_type,
      IValueParser parser);
  // Parses `module` into a mobile::Module. Resets and repopulates all
  // internal caches; `module` must stay alive while the loader is used.
  mobile::Module parseModule(mobile::serialization::Module* module);

  // After parseModule(), decodes the ivalues past the mobile range and
  // returns jit source files and jit constants through the out-parameters.
  void extractJitSourceAndConstants(
      ExtraFilesMap* jit_sources,
      std::vector<IValue>* constants);

  // Callback used to turn a serialized type string into a TypePtr.
  typedef TypePtr (*TypeResolver)(
      const std::string& type_str,
      std::shared_ptr<CompilationUnit> cu);

  void internal_registerTypeResolver(TypeResolver type_resolver);

  // Returns the parsed IValue stored at flatbuffer index `pos`
  // (bounds-checked).
  IValue& getIValue(uint32_t pos) {
    TORCH_CHECK(pos < all_ivalues_.size());
    return all_ivalues_[pos];
  }

  // Returns the function parsed from flatbuffer index `pos`. NOTE: this is
  // unordered_map::operator[], so an unseen index default-inserts and
  // returns nullptr rather than throwing.
  mobile::Function* getFunction(uint32_t pos) {
    return all_functions_[pos];
  }

  // Returns the class type cached for object-type index `pos`
  // (bounds-checked; may be null if not yet created).
  ClassTypePtr getType(uint32_t pos) {
    TORCH_CHECK(pos < all_types_.size());
    return all_types_[pos];
  }

  // Lazily materializes (and caches) the storage at `index`.
  c10::Storage getStorage(uint32_t index);
  // Resolves a type-annotation string, memoized by flatbuffer string pointer.
  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
  // Looks up or creates the ClassType backing a serialized object.
  ClassTypePtr getOrCreateClassTypeForObject(
      const mobile::serialization::Object* object);

  const mobile::serialization::Module* getCurrentFlatbufferInput() {
    return module_;
  }

  // When true, tensor data is copied out of the flatbuffer; when false
  // (default), storages alias the flatbuffer memory, which must then outlive
  // the returned module.
  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
    should_copy_tensor_memory_ = should_copy_tensor_memory;
  }

  // Compilation units receiving the parsed functions and class types.
  std::shared_ptr<mobile::CompilationUnit> mcu_;
  std::shared_ptr<CompilationUnit> cu_;

 private:
  // Dispatches to the registered IValueParser for the ivalue's variant tag.
  IValue parseIValue(const mobile::serialization::IValue* ivalue);
  std::unique_ptr<mobile::Function> parseFunction(
      const mobile::serialization::Function* method);
  // Parses the ivalue at index `i` into either all_functions_ or all_ivalues_.
  void parseAndPopulate(
      uint32_t i,
      const mobile::serialization::IValue* ivalue);

  // Caches indexed by the flatbuffer's integer handles.
  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
  std::vector<ClassTypePtr> all_types_;
  std::unordered_set<uint32_t> initialized_types_;
  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
  std::vector<bool> storage_loaded_;
  std::vector<c10::Storage> storages_;
  std::vector<IValue> all_ivalues_;
  // One parser slot per IValueUnion variant tag.
  std::array<
      IValueParser,
      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
      ivalue_parsers_;
  TypeResolver type_resolver_ = nullptr;
  mobile::serialization::Module* module_ = nullptr;
  bool module_parsed_ = false;
  bool should_copy_tensor_memory_ = false;
  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
  uint32_t mobile_ivalue_size_ = 0;
};
156 | |
// Forward declarations of the per-variant parser callbacks registered in the
// FlatbufferLoader constructor; definitions follow below.
IValue parseList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTensor(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTuple(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDict(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseObject(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseEnum(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
187 | |
188 | TypePtr resolveType( |
189 | const std::string& type_string, |
190 | std::shared_ptr<CompilationUnit> cu) { |
191 | TypePtr type; |
192 | c10::string_view type_str(type_string); |
193 | if (type_str.starts_with(kCustomClassPrefix)) { |
194 | type = getCustomClass(type_string); |
195 | TORCH_CHECK( |
196 | type, "The implementation of class " , type_string, " cannot be found." ); |
197 | } else if ( |
198 | type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) { |
199 | c10::QualifiedName qn(type_string); |
200 | if (cu->get_class(qn) == nullptr) { |
201 | auto classtype = ClassType::create(qn, cu, true); |
202 | cu->register_type(classtype); |
203 | type = classtype; |
204 | } else { |
205 | type = cu->get_class(qn); |
206 | } |
207 | } else { |
208 | type = c10::parseType(type_string); |
209 | } |
210 | return type; |
211 | } |
212 | |
213 | FlatbufferLoader::FlatbufferLoader() |
214 | : mcu_(std::make_shared<mobile::CompilationUnit>()), |
215 | cu_(std::make_shared<CompilationUnit>()), |
216 | ivalue_parsers_{nullptr} { |
217 | registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic); |
218 | registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic); |
219 | registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic); |
220 | registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic); |
221 | registerIValueParser( |
222 | mobile::serialization::IValueUnion::ComplexDouble, &parseBasic); |
223 | registerIValueParser( |
224 | mobile::serialization::IValueUnion::TensorMetadata, &parseTensor); |
225 | registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic); |
226 | registerIValueParser(mobile::serialization::IValueUnion::List, &parseList); |
227 | registerIValueParser( |
228 | mobile::serialization::IValueUnion::IntList, &parseIntList); |
229 | registerIValueParser( |
230 | mobile::serialization::IValueUnion::DoubleList, &parseDoubleList); |
231 | registerIValueParser( |
232 | mobile::serialization::IValueUnion::BoolList, &parseBoolList); |
233 | registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple); |
234 | registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict); |
235 | registerIValueParser( |
236 | mobile::serialization::IValueUnion::Object, &parseObject); |
237 | registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic); |
238 | registerIValueParser( |
239 | mobile::serialization::IValueUnion::EnumValue, &parseEnum); |
240 | internal_registerTypeResolver(&resolveType); |
241 | } |
242 | |
// Installs `parser` in the dispatch table slot for `ivalue_type`,
// overwriting any previous handler (used e.g. by parse_flatbuffer_no_object
// to substitute a simplified Object parser).
void FlatbufferLoader::registerIValueParser(
    mobile::serialization::IValueUnion ivalue_type,
    IValueParser parser) {
  ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
}
248 | |
// Replaces the callback used by getOrCreateTypeAnnotations to turn type
// strings into TypePtrs (defaults to resolveType, set in the constructor).
void FlatbufferLoader::internal_registerTypeResolver(
    TypeResolver type_resolver) {
  type_resolver_ = type_resolver;
}
253 | |
254 | void ( |
255 | const flatbuffers::Vector<flatbuffers::Offset< |
256 | torch::jit::mobile::serialization::ExtraFile>>* files, |
257 | ExtraFilesMap* ) { |
258 | for (uint32_t i = 0; i < files->size(); ++i) { |
259 | const auto* = files->Get(i); |
260 | (*extra_files)[extra_file->name()->str()] = extra_file->content()->str(); |
261 | } |
262 | } |
263 | |
264 | void ( |
265 | mobile::serialization::Module* module, |
266 | ExtraFilesMap& ) { |
267 | auto = module->extra_files(); |
268 | parseExtraFilesFromVector(extra_files_offsets, &extra_files); |
269 | } |
270 | |
// Decodes the ivalue at flatbuffer index `i`. Function entries are parsed,
// indexed in all_functions_, and handed to the mobile CompilationUnit (which
// takes ownership); every other variant is parsed into all_ivalues_[i].
void FlatbufferLoader::parseAndPopulate(
    uint32_t i,
    const mobile::serialization::IValue* ivalue) {
  if (const auto* func = ivalue->val_as_Function()) {
    auto func_ptr = parseFunction(func);
    // Keep the raw pointer for method registration; ownership moves to mcu_.
    all_functions_[i] = func_ptr.get();
    mcu_->register_function(std::move(func_ptr));
  } else {
    all_ivalues_[i] = parseIValue(ivalue);
  }
}
282 | |
// Parses the flatbuffer module into a mobile::Module:
//  1. resets all per-parse caches (the loader is reusable),
//  2. decodes the mobile portion of the ivalue table,
//  3. attaches each parsed function to its owning class type,
//  4. wraps the root state object into a mobile::Module.
mobile::Module FlatbufferLoader::parseModule(
    mobile::serialization::Module* module) {
  module_ = module;
  all_ivalues_.clear();
  all_types_.clear();
  storages_.clear();
  storage_loaded_.clear();
  module_parsed_ = false;

  const auto* ivalues = module->ivalues();
  all_ivalues_.resize(ivalues->size());
  all_types_.resize(module->object_types()->size());
  storages_.resize(module->storage_data_size());
  storage_loaded_.resize(module->storage_data_size(), false);

  // 0 means the field was not recorded: treat every ivalue as mobile.
  mobile_ivalue_size_ = module_->mobile_ivalue_size();
  if (mobile_ivalue_size_ == 0) {
    mobile_ivalue_size_ = ivalues->size();
  }

  // Decode only the mobile range; the tail (jit sources/constants) is left
  // for extractJitSourceAndConstants.
  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  IValue& module_ivalue = getIValue(module->state_obj());

  // register functions
  for (const auto& f : all_functions_) {
    // f.first is the flatbuffer index of the function; its class_type field
    // indexes into all_types_ (populated while parsing objects above).
    uint32_t class_index =
        ivalues->Get(f.first)->val_as_Function()->class_type();
    ClassTypePtr class_type = all_types_[class_index];
    class_type->addMethod(f.second);
  }

  module_parsed_ = true;
  auto m = mobile::Module(module_ivalue.toObject(), mcu_);
  m.set_min_operator_version(module->operator_version());
  m.set_bytecode_version(module->bytecode_version());
  return m;
}
323 | |
// Appends all known bytecode-upgrader functions to `function` so that old
// operator calls can be redirected to upgraders. No-op when the build
// disables the upgrader machinery.
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
331 | |
// Decodes one serialized function: instructions, constants, operators, type
// annotations, register count and (optionally) its schema. Also wires in
// bytecode upgraders when the file was produced by an older runtime.
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    const mobile::serialization::Function* method) {
  auto function = std::make_unique<mobile::Function>(
      c10::QualifiedName(method->qn()->str()));
  // TODO(qihan) add debug handle
  // const auto* debug_handle = method->debug_info()->debug_handle();
  for (const auto* inst : *method->instructions()) {
    function->append_instruction(
        static_cast<OpCode>(inst->op()), inst->x(), inst->n());
  }

  // Constants are stored as indices into the module-wide ivalue table.
  for (uint32_t i : *method->constants()) {
    function->append_constant(getIValue(i));
  }

  appendUpgraderFunctions(function.get());
  // 2. Decides if upgrader is needed
  const uint32_t operator_version = module_->operator_version();
  bool use_upgrader =
      (operator_version < caffe2::serialize::kProducedFileFormatVersion);

  for (const auto* op : *method->operators()) {
    // A serialized arg count of -1 means "unknown"; keep nullopt in that case.
    c10::optional<int> num_args = c10::nullopt;
    if (op->num_args_serialized() > -1) {
      num_args = op->num_args_serialized();
    }

    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);
  }

  function->initialize_operators(true);

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
  }

  // 3. If upgrader is needed, change the OP instruction to a CALL
  // instruction (In next PR, use_upgrader will be parsed to parseInstruction
  // function and do the actual change)
  if (use_upgrader) {
#ifndef DISABLE_UPGRADER
    applyUpgrader(function.get(), operator_version);
#endif
  }

  function->set_register_size(method->register_size());
  if (method->schema()) {
    try {
      // Builds c10::Arguments from a serialized arg/return list; each entry
      // carries a name, a type annotation and a default value index.
      auto parseArgList = [this](const auto* args_fb) {
        std::vector<c10::Argument> args;
        for (const auto* arg_tb : *args_fb) {
          IValue default_value = getIValue(arg_tb->default_value());
          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
          auto arg = c10::Argument(
              arg_tb->name()->str(),
              std::move(type_ptr),
              c10::nullopt /*N*/,
              std::move(default_value));
          args.emplace_back(std::move(arg));
        }
        return args;
      };
      c10::FunctionSchema schema(
          method->qn()->str(),
          "" /*overload_name*/,
          parseArgList(method->schema()->arguments()),
          parseArgList(method->schema()->returns()),
          false /*is_varargs*/,
          false /*is_varret*/);

      function->setSchema(std::move(schema));
    } catch (const c10::Error& e) {
      // Schema reconstruction is best-effort: a function without a schema is
      // still runnable, so parsing errors are deliberately swallowed here.
    }
  }
  return function;
}
409 | |
// Reconstructs an enum IValue: resolves the enum's type, then matches the
// serialized payload against the type's (name, value) pairs. Asserts if the
// type is not an EnumType or the value is not a member of it.
IValue parseEnum(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* enum_val = ivalue.val_as_EnumValue();
  auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
                       ->cast<c10::EnumType>();
  AT_ASSERT(
      enum_type,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
  IValue val = loader.getIValue(enum_val->value());
  for (const auto& p : enum_type->enumNamesValues()) {
    if (p.second == val) {
      auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
          enum_type, p.first, p.second);
      return IValue(std::move(enum_holder));
    }
  }
  // Value not present among the enum's members.
  AT_ASSERT(
      false, "Enum with type: " + enum_val->type_name()->str() + " not found.");
}
430 | |
// Decodes the scalar-like variants (None, Int, Bool, Double, ComplexDouble,
// String, Device). Unknown tags fall through to an empty (None) IValue; the
// non-scalar tags are handled by their dedicated parsers, never here.
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  switch (ivalue.val_type()) {
    case mobile::serialization::IValueUnion::NONE:
      return {};
    case mobile::serialization::IValueUnion::Int:
      return ivalue.val_as_Int()->int_val();
    case mobile::serialization::IValueUnion::Bool:
      return ivalue.val_as_Bool()->bool_val();
    case mobile::serialization::IValueUnion::Double:
      return ivalue.val_as_Double()->double_val();
    case mobile::serialization::IValueUnion::ComplexDouble: {
      const auto* comp = ivalue.val_as_ComplexDouble();
      return c10::complex<double>(comp->real(), comp->imag());
    }
    case mobile::serialization::IValueUnion::String:
      return ivalue.val_as_String()->data()->str();
    case mobile::serialization::IValueUnion::Device: {
      // Device is serialized as its string form, e.g. "cpu" or "cuda:0".
      return c10::Device(ivalue.val_as_Device()->str()->str());
    }
    default:
      return {};
  }
}
456 | |
// Rebuilds a CPU tensor from its serialized metadata: allocates an empty
// (possibly quantized) tensor of the right scalar type, then points it at
// the storage indexed by the metadata and restores offset/sizes/strides.
// Per-channel quantization recurses to rebuild the scales/zero_points
// tensors first.
at::Tensor parseTensorFromMetadata(
    FlatbufferLoader* loader,
    const mobile::serialization::TensorMetadata* tensor_md) {
  at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
  auto options = at::CPU(type).options();
  at::Tensor tensor;
  if (tensor_md->quantized_schema() != nullptr) {
    // is quantized
    const auto* schema = tensor_md->quantized_schema();
    auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
    switch (qscheme_type) {
      case at::kPerTensorAffine: {
        tensor = at::_empty_affine_quantized(
            {0}, options, schema->scale(), schema->zero_point());
      } break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // scales/zero_points are themselves serialized tensors.
        at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
        at::Tensor zero_points =
            parseTensorFromMetadata(loader, schema->zero_points());
        tensor = at::_empty_per_channel_affine_quantized(
            {0}, scales, zero_points, schema->axis(), options);
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(qscheme_type));
        break;
    }
  } else {
    tensor = at::empty({0}, options);
  }
  at::TensorImpl* impl = tensor.unsafeGetTensorImpl();

  // Swap in the real storage; getStorage may alias flatbuffer memory or copy,
  // depending on the loader's should_copy_tensor_memory_ setting.
  c10::Storage storage;
  storage = loader->getStorage(tensor_md->storage_location_index());
  impl->set_storage_keep_dtype(storage);
  impl->set_storage_offset(tensor_md->storage_offset());

  std::vector<int64_t> size{
      tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
  std::vector<int64_t> stride{
      tensor_md->strides()->begin(), tensor_md->strides()->end()};
  impl->set_sizes_and_strides(size, stride);
#ifndef MIN_EDGE_RUNTIME
  // Full runtime: wrap as a Variable to restore the requires_grad flag.
  tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
#endif
  return tensor;
}
507 | |
508 | IValue parseTensor( |
509 | FlatbufferLoader& loader, |
510 | const mobile::serialization::IValue& ivalue) { |
511 | const mobile::serialization::TensorMetadata* tensor_md = |
512 | ivalue.val_as_TensorMetadata(); |
513 | return parseTensorFromMetadata(&loader, tensor_md); |
514 | } |
515 | |
// Rebuilds a generic c10::List. Elements are stored as indices into the
// module ivalue table; the element type comes from the serialized annotation
// string and is patched in after the list is filled.
IValue parseList(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::List* list = ivalue.val_as_List();
  auto res = c10::impl::GenericList(AnyType::get());
  for (int i : *list->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
  // The annotation is List[T]; its first contained type is the element type.
  res.unsafeSetElementType(type->containedType(0));
  return res;
}
528 | |
// Copies a flatbuffer scalar list (any table with an items() vector) into a
// std::vector<T>, converting element-wise.
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  return {list->items()->begin(), list->items()->end()};
}
534 | |
535 | IValue parseIntList( |
536 | FlatbufferLoader&, |
537 | const mobile::serialization::IValue& ivalue) { |
538 | const auto& list = ivalue.val_as_IntList(); |
539 | return parseListNative<int64_t>(list); |
540 | } |
541 | |
542 | IValue parseDoubleList( |
543 | FlatbufferLoader&, |
544 | const mobile::serialization::IValue& ivalue) { |
545 | const auto& list = ivalue.val_as_DoubleList(); |
546 | return parseListNative<double>(list); |
547 | } |
548 | |
549 | IValue parseBoolList( |
550 | FlatbufferLoader&, |
551 | const mobile::serialization::IValue& ivalue) { |
552 | const auto& list = ivalue.val_as_BoolList(); |
553 | std::vector<uint8_t> res = parseListNative<uint8_t>(list); |
554 | c10::List<bool> boollist; |
555 | for (auto x : res) { |
556 | boollist.push_back(x); |
557 | } |
558 | return boollist; |
559 | } |
560 | |
561 | IValue parseTuple( |
562 | FlatbufferLoader& loader, |
563 | const mobile::serialization::IValue& ivalue) { |
564 | const auto& tuple = ivalue.val_as_Tuple(); |
565 | std::vector<IValue> res; |
566 | for (int i : *tuple->items()) { |
567 | res.emplace_back(loader.getIValue(i)); |
568 | } |
569 | return c10::ivalue::Tuple::create(res); |
570 | } |
571 | |
// Rebuilds a generic dict from parallel key/value index vectors, then patches
// in the key and value types from the serialized Dict[K, V] annotation.
IValue parseDict(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* dict = ivalue.val_as_Dict();
  auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
  const auto* keys = dict->keys();
  const auto* values = dict->values();
  // keys and values are parallel arrays of ivalue-table indices.
  for (size_t i = 0; i < keys->size(); ++i) {
    uint32_t key = keys->Get(i);
    uint32_t val = values->Get(i);
    result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
  }
  auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
  result.unsafeSetKeyType(type->containedType(0));
  result.unsafeSetValueType(type->containedType(1));
  return result;
}
589 | |
590 | ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject( |
591 | const mobile::serialization::Object* object) { |
592 | auto cls = getType(object->type_index()); |
593 | const mobile::serialization::ObjectType* obj_type = |
594 | module_->object_types()->Get(object->type_index()); |
595 | if (cls == nullptr) { |
596 | c10::string_view qn_str( |
597 | obj_type->type_name()->c_str(), obj_type->type_name()->size()); |
598 | if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) { |
599 | c10::QualifiedName qn(obj_type->type_name()->str()); |
600 | cls = cu_->get_class(qn); |
601 | if (cls == nullptr) { |
602 | cls = ClassType::create(qn, cu_, true); |
603 | cu_->register_type(cls); |
604 | } |
605 | } else { |
606 | cls = c10::parseType(std::string(qn_str))->cast<ClassType>(); |
607 | } |
608 | TORCH_CHECK(object->type_index() < all_ivalues_.size()); |
609 | all_types_[object->type_index()] = cls; |
610 | |
611 | if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) { |
612 | for (uint32_t i = 0; i < object->attrs()->size(); i++) { |
613 | IValue val = getIValue(object->attrs()->Get(i)); |
614 | // Need to use concrete object's field's type to set type of field. |
615 | cls->addAttribute( |
616 | obj_type->attr_names()->Get(i)->str(), |
617 | val.type<c10::DynamicType>()); |
618 | } |
619 | } |
620 | initialized_types_.insert(object->type_index()); |
621 | } |
622 | return cls; |
623 | } |
624 | |
// Rebuilds an object IValue. Three serialized layouts are supported:
//  - CLASS_WITH_FIELD: attributes are stored directly as ivalue indices;
//  - CLASS_WITH_SETSTATE: the object is reconstructed by running its
//    serialized __setstate__ mobile function on the stored state;
//  - CUSTOM_CLASS: a registered torchbind class, restored via its
//    __setstate__ method.
IValue parseObject(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::Object* object = ivalue.val_as_Object();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
  const auto* cur_input = loader.getCurrentFlatbufferInput();
  const mobile::serialization::ObjectType* obj_type =
      cur_input->object_types()->Get(object->type_index());
  auto cls = loader.getOrCreateClassTypeForObject(object);
  Stack stack;
  switch (obj_type->type()) {
    case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
      auto obj = c10::ivalue::Object::create(
          at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
      // Slot order matches the serialized attribute order.
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = loader.getIValue(object->attrs()->Get(i));
        obj->setSlot(i, std::move(val));
      }
      return obj;
    }
    case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
      IValue input = loader.getIValue(object->state());
      mobile::Function* setstate = loader.getFunction(object->setstate_func());
      auto obj =
          c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
      // __setstate__(self, state) mutates obj in place.
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      setstate->run(stack);
      return obj;
    }
    case mobile::serialization::TypeType::CUSTOM_CLASS: {
      auto custom_class_type =
          torch::jit::getCustomClass(cls->name()->qualifiedName());
      IValue input = loader.getIValue(object->state());
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    }
    default:
      AT_ASSERT(false, "need to be object");
  }
}
670 | |
// Dispatches to the parser registered for this ivalue's union tag.
// The table is fully populated in the constructor, so the slot is valid for
// every tag a well-formed file can contain.
IValue FlatbufferLoader::parseIValue(
    const mobile::serialization::IValue* ivalue) {
  return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
      *this, *ivalue);
}
676 | |
// No-op deleter for DataPtrs that alias flatbuffer memory the loader does not
// own (the buffer's lifetime is managed elsewhere, e.g. via set_delete_memory).
void deleteNothing2(void*);
void deleteNothing2(void*) {}
679 | |
// Lazily materializes the storage at `index` and caches it. Depending on
// should_copy_tensor_memory_, the bytes are either copied into a fresh CPU
// allocation or aliased in place with a no-op deleter (in which case the
// flatbuffer must outlive every tensor using the storage).
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();

    at::DataPtr data;
    if (should_copy_tensor_memory_) {
      auto* allocator = at::GetCPUAllocator();
      data = allocator->allocate(size);
      memcpy(data.get(), storage->data()->data(), size);
    } else {
      // Alias the flatbuffer bytes directly; deleteNothing2 ensures the
      // Storage never frees memory it does not own.
      void* ptr = static_cast<void*>(storage->mutable_data()->data());
      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
  }
  return storages_[index];
}
702 | |
703 | TypePtr FlatbufferLoader::getOrCreateTypeAnnotations( |
704 | const flatbuffers::String* offset) { |
705 | auto iter = type_annotations_.find(offset); |
706 | if (iter != type_annotations_.end()) { |
707 | return iter->second; |
708 | } |
709 | TypePtr type = type_resolver_(offset->str(), cu_); |
710 | type_annotations_[offset] = type; |
711 | return type; |
712 | } |
713 | |
714 | void FlatbufferLoader::extractJitSourceAndConstants( |
715 | ExtraFilesMap* jit_sources, |
716 | std::vector<IValue>* constants) { |
717 | AT_ASSERT( |
718 | module_parsed_, |
719 | "Need to first parse a flatbuffer file before extracting jit_sources" ); |
720 | |
721 | const auto* ivalues = module_->ivalues(); |
722 | for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) { |
723 | const auto* ival = ivalues->Get(i); |
724 | parseAndPopulate(i, ival); |
725 | } |
726 | // register functions |
727 | for (const auto& f : all_functions_) { |
728 | if (f.first >= mobile_ivalue_size_) { |
729 | uint32_t class_index = |
730 | ivalues->Get(f.first)->val_as_Function()->class_type(); |
731 | ClassTypePtr class_type = all_types_[class_index]; |
732 | class_type->addMethod(f.second); |
733 | } |
734 | } |
735 | const auto* jit_constants = module_->jit_constants(); |
736 | for (auto i = 0; i < jit_constants->size(); ++i) { |
737 | constants->emplace_back(getIValue(jit_constants->Get(i))); |
738 | } |
739 | parseExtraFilesFromVector(module_->jit_sources(), jit_sources); |
740 | } |
741 | |
742 | } // namespace |
743 | |
744 | mobile::Module parse_and_initialize_mobile_module( |
745 | void* data, |
746 | size_t, |
747 | c10::optional<at::Device>, |
748 | ExtraFilesMap* , |
749 | bool should_copy_tensor_memory) { |
750 | TORCH_CHECK( |
751 | mobile::serialization::ModuleBufferHasIdentifier(data), "Format error" ); |
752 | // TODO(T128189662): If not copying, enforce that data is aligned to |
753 | // kFlatbufferDataAlignmentBytes, and add unit tests. |
754 | |
755 | FlatbufferLoader loader; |
756 | loader.setShouldCopyTensorMemory(should_copy_tensor_memory); |
757 | |
758 | // Flatbuffer doesn't seem to have a way to provide the buffer size when |
759 | // interacting with the buffer. |
760 | auto* flatbuffer_module = mobile::serialization::GetMutableModule(data); |
761 | mobile::Module m = loader.parseModule(flatbuffer_module); |
762 | if (extra_files != nullptr) { |
763 | parseExtraFiles(flatbuffer_module, *extra_files); |
764 | } |
765 | return m; |
766 | } |
767 | |
768 | mobile::Module parse_and_initialize_mobile_module( |
769 | std::shared_ptr<char> data, |
770 | size_t size, |
771 | c10::optional<at::Device> device, |
772 | ExtraFilesMap* ) { |
773 | mobile::Module m = parse_and_initialize_mobile_module( |
774 | data.get(), |
775 | size, |
776 | device, |
777 | extra_files, |
778 | /*should_copy_tensor_memory=*/false); |
779 | m.set_delete_memory(std::move(data)); |
780 | return m; |
781 | } |
782 | |
783 | mobile::Module parse_and_initialize_mobile_module_for_jit( |
784 | void* data, |
785 | size_t, |
786 | ExtraFilesMap& jit_sources, |
787 | std::vector<IValue>& jit_constants, |
788 | c10::optional<at::Device>, |
789 | ExtraFilesMap* ) { |
790 | TORCH_CHECK( |
791 | mobile::serialization::ModuleBufferHasIdentifier(data), "Format error" ); |
792 | // TODO(T128189662): Enforce that data is aligned to |
793 | // kFlatbufferDataAlignmentBytes, and add unit tests. |
794 | |
795 | FlatbufferLoader loader; |
796 | auto* flatbuffer_module = mobile::serialization::GetMutableModule(data); |
797 | mobile::Module m = loader.parseModule(flatbuffer_module); |
798 | if (extra_files != nullptr) { |
799 | parseExtraFiles(flatbuffer_module, *extra_files); |
800 | } |
801 | |
802 | loader.extractJitSourceAndConstants(&jit_sources, &jit_constants); |
803 | return m; |
804 | } |
805 | |
806 | mobile::Module load_mobile_module_from_file( |
807 | const std::string& filename, |
808 | c10::optional<c10::Device> device, |
809 | ExtraFilesMap* ) { |
810 | std::shared_ptr<char> data; |
811 | size_t size = 0; |
812 | std::tie(data, size) = get_file_content(filename.c_str()); |
813 | return parse_and_initialize_mobile_module( |
814 | std::move(data), size, device, extra_files); |
815 | } |
816 | |
817 | uint64_t get_bytecode_version(std::istream& in) { |
818 | std::shared_ptr<char> data; |
819 | size_t size = 0; |
820 | std::tie(data, size) = get_stream_content(in); |
821 | return get_bytecode_version_from_bytes(data.get()); |
822 | } |
823 | |
824 | uint64_t get_bytecode_version(const std::string& filename) { |
825 | std::shared_ptr<char> data; |
826 | size_t size = 0; |
827 | std::tie(data, size) = get_file_content(filename.c_str()); |
828 | return get_bytecode_version_from_bytes(data.get()); |
829 | } |
830 | |
// Validates the flatbuffer identifier and returns the module's bytecode
// version without fully parsing the module.
uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
      "Format error");
  auto* flatbuffer_module =
      mobile::serialization::GetMutableModule(flatbuffer_content);
  return flatbuffer_module->bytecode_version();
}
839 | |
// Collects lightweight metadata (versions, function names, operators with
// arg counts, contained type names) from a flatbuffer module without
// constructing tensors or objects.
mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
  auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
  mobile::ModuleInfo minfo;
  minfo.operator_version = ff_module->operator_version();
  minfo.bytecode_version = ff_module->bytecode_version();

  // 0 means "not recorded": scan the entire ivalue table.
  uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size();
  if (mobile_ivalue_size == 0) {
    mobile_ivalue_size = ff_module->ivalues()->size();
  }

  std::vector<std::string> type_name_list;
  for (uint32_t i = 0; i < mobile_ivalue_size; i++) {
    const auto* ival = ff_module->ivalues()->Get(i);
    // Only function entries carry operators and type annotations.
    if (const auto* func = ival->val_as_Function()) {
      minfo.function_names.insert(func->qn()->str());
      for (const auto* op : *func->operators()) {
        at::OperatorName opname(op->name()->str(), op->overload_name()->str());
        minfo.opname_to_num_args[mobile::operator_str(opname)] =
            op->num_args_serialized();
      }
      for (const auto* type_ann : *func->type_annotations()) {
        type_name_list.push_back(type_ann->str());
      }
    }
  }
  // Parse the collected annotation strings to enumerate contained types.
  c10::TypeParser parser(type_name_list);
  parser.parseList();
  minfo.type_names = parser.getContainedTypes();
  return minfo;
}
871 | |
872 | mobile::Module load_mobile_module_from_stream_with_copy( |
873 | std::istream& in, |
874 | c10::optional<at::Device> device, |
875 | ExtraFilesMap* ) { |
876 | std::shared_ptr<char> data; |
877 | size_t size = 0; |
878 | std::tie(data, size) = get_stream_content(in); |
879 | return parse_and_initialize_mobile_module( |
880 | std::move(data), size, device, extra_files); |
881 | } |
882 | |
// Parses a flatbuffer module while treating every serialized object as a
// plain field-bearing class: the default Object parser (which may run
// __setstate__ or resolve custom classes) is replaced with a lambda that
// only fills attribute slots. The buffer's lifetime is tied to the module.
mobile::Module parse_flatbuffer_no_object(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device) {
  (void)device;
  (void)size;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
  FlatbufferLoader loader;
  // replace parserObject with to handle only class with field case
  // function.
  loader.registerIValueParser(
      mobile::serialization::IValueUnion::Object,
      +[](FlatbufferLoader& loader,
          const mobile::serialization::IValue& ivalue) {
        const mobile::serialization::Object* object = ivalue.val_as_Object();
        auto cls = loader.getOrCreateClassTypeForObject(object);
        auto obj = c10::ivalue::Object::create(
            at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
        for (uint32_t i = 0; i < object->attrs()->size(); i++) {
          IValue val = loader.getIValue(object->attrs()->Get(i));
          obj->setSlot(i, std::move(val));
        }
        return static_cast<c10::IValue>(obj);
      });

  mobile::Module m = loader.parseModule(flatbuffer_module);
  // Tensor storages alias the buffer, so the module must own it.
  m.set_delete_memory(std::move(data));
  return m;
}
912 | |
// Registration hook referenced at static-init time to force this translation
// unit to be linked in; always succeeds.
bool register_flatbuffer_loader() {
  return true;
}
916 | |
917 | } // namespace jit |
918 | } // namespace torch |
919 | |