#pragma once

#include <istream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

#include <caffe2/serialize/file_adapter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>

10 | namespace torch { |
11 | namespace jit { |
12 | using caffe2::serialize::FileAdapter; |
13 | using caffe2::serialize::IStreamAdapter; |
14 | using caffe2::serialize::ReadAdapterInterface; |
15 | using = std::unordered_map<std::string, std::string>; |
16 | |
17 | constexpr const char* kArchiveNameBytecode = "bytecode" ; |
18 | constexpr const char* kArchiveNameConstants = "constants" ; |
19 | constexpr const char* kArchiveNameVersion = "version" ; |
20 | |
21 | // The family of methods below load a serialized Mobile Module |
22 | // into a mobile::Module object. |
23 | TORCH_API mobile::Module _load_for_mobile( |
24 | std::istream& in, |
25 | c10::optional<at::Device> device, |
26 | ExtraFilesMap& ); |
27 | |
28 | TORCH_API mobile::Module _load_for_mobile( |
29 | const std::string& filename, |
30 | c10::optional<at::Device> device, |
31 | ExtraFilesMap& ); |
32 | |
33 | TORCH_API mobile::Module _load_for_mobile( |
34 | std::unique_ptr<ReadAdapterInterface> rai, |
35 | c10::optional<c10::Device> device, |
36 | ExtraFilesMap& , |
37 | uint64_t module_load_options = kDefaultMobileLoadOptions); |
38 | |
39 | TORCH_API mobile::Module _load_for_mobile( |
40 | const std::string& filename, |
41 | c10::optional<at::Device> device, |
42 | ExtraFilesMap& , |
43 | uint64_t module_load_options); |
44 | |
45 | TORCH_API mobile::Module _load_for_mobile( |
46 | std::istream& in, |
47 | c10::optional<at::Device> device = c10::nullopt); |
48 | |
49 | TORCH_API mobile::Module _load_for_mobile( |
50 | const std::string& filename, |
51 | c10::optional<at::Device> device = c10::nullopt); |
52 | |
53 | TORCH_API mobile::Module _load_for_mobile( |
54 | std::unique_ptr<ReadAdapterInterface> rai, |
55 | c10::optional<c10::Device> device = c10::nullopt); |
56 | |
57 | /** |
58 | * Load only the contents of the "extra/" files whose names are |
59 | * passed in the map (extra_files). Populate the corresponding values |
60 | * with the contents of those files. Do not attempt to load the entire |
61 | * model, and stop once the extra files have been extracted. |
62 | * |
63 | * This API is needed to be able to load GPU models on linux CPU |
64 | * machines and extract only the extra files so that we can inspect |
65 | * the metadata that was added to the .ptl archive when it was |
66 | * generated. |
67 | * |
68 | */ |
69 | void ( |
70 | const std::string& filename, |
71 | c10::optional<at::Device> device, |
72 | ExtraFilesMap& ); |
73 | |
74 | // Currently used by both mobile/import.cpp and model_compatibility.cpp. |
75 | // Should be removed after model_compatibility.cpp start using simplified |
76 | // version type_resolver and obj_loader. |
77 | at::TypePtr resolveTypeNameMobile( |
78 | const c10::QualifiedName& qn, |
79 | std::shared_ptr<CompilationUnit> compilation_unit); |
80 | c10::StrongTypePtr typeResolverMobile( |
81 | const c10::QualifiedName& qn, |
82 | std::shared_ptr<CompilationUnit> compilation_unit); |
83 | c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile( |
84 | const at::StrongTypePtr& type, |
85 | const at::IValue& input, |
86 | mobile::CompilationUnit& mobile_compilation_unit); |
87 | |
88 | // Given a reader, which has access to a model file, |
89 | // return true if there exists tensors in `bytecode` archive |
90 | bool isTensorInBytecodeArchive( |
91 | caffe2::serialize::PyTorchStreamReader& stream_reader); |
92 | |
93 | namespace mobile { |
94 | |
95 | /** |
96 | * Given a torch::jit::mobile::Module, return a set of operator names |
97 | * (with overload name) that are used by any method in this mobile |
98 | * Mobile. This method runs through the bytecode for all methods |
99 | * in the specified model (module), and extracts all the root |
100 | * operator names. Root operators are operators that are called |
101 | * directly by the model (as opposed to non-root operators, which |
102 | * may be called transitively by the root operators). |
103 | * |
104 | */ |
105 | TORCH_API std::set<std::string> _export_operator_list( |
106 | torch::jit::mobile::Module& module); |
107 | |
108 | } // namespace mobile |
109 | |
110 | } // namespace jit |
111 | } // namespace torch |
112 | |