1 | #pragma once |
2 | |
3 | #include <caffe2/serialize/inline_container.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/serialization/export_bytecode.h> |
7 | #include <torch/csrc/jit/serialization/flatbuffer_serializer.h> |
8 | #include <torch/csrc/jit/serialization/pickler.h> |
9 | #include <torch/csrc/jit/serialization/python_print.h> |
10 | #include <torch/csrc/jit/serialization/storage_context.h> |
11 | #include <torch/csrc/jit/serialization/type_name_uniquer.h> |
12 | #include <torch/csrc/onnx/onnx.h> |
13 | #include <ostream> |
14 | |
15 | namespace ONNX_NAMESPACE { |
16 | class ModelProto; |
17 | } |
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | |
// This map is used to keep track of parameters that should be exported
// externally. When `defer_weight_export` is true, the returned map contains
// kv pairs that map {external reference name} -> {at::Tensor to be exported}.
// It is the responsibility of the caller to export these appropriately.
//
// For example, when exporting to a zip archive, the caller may write out files
// for each entry in the export map, with the filename being the key and the
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;

// Maps a symbolic shape dimension to a string name. NOTE(review): presumably
// the dynamic-axis name used in the exported model — confirm in the
// implementation.
using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;

// Maps a graph Node to the string name assigned to it during export.
using NodeNameMap = std::unordered_map<const Node*, std::string>;

// Used for modularized export settling function and node attributes.
// Outer key: node; inner map: attribute name -> assigned name.
using NodeAttrNameMap = std::
    unordered_map<const Node*, std::unordered_map<std::string, std::string>>;
39 | |
// Exports `graph` to an ONNX model.
//
// Returns a tuple of:
//  - the serializable ModelProto;
//  - a RawDataExportMap of weights to export externally (see the comment on
//    RawDataExportMap above; relevant when `defer_weight_export` is true);
//  - a SymbolDimMap for symbolic shape dimensions;
//  - a bool flag (NOTE(review): presumably whether the external data format
//    was used — confirm against the implementation);
//  - a NodeNameMap of nodes to their assigned names.
//
// `dynamic_axes` maps input/output name -> {dim index -> dynamic dim name}.
// `custom_opsets` maps opset domain -> version for non-standard op sets.
// `onnx_file_path` is consulted when `use_external_data_format` is true
// (NOTE(review): inferred from parameter pairing — confirm in .cpp).
TORCH_API std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap,
    bool,
    NodeNameMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export = false,
    ::torch::onnx::OperatorExportTypes operator_export_type =
        ::torch::onnx::OperatorExportTypes::ONNX,
    bool strip_doc_string = true,
    bool keep_initializers_as_inputs = true,
    const std::map<std::string, int>& custom_opsets = {},
    bool add_node_names = true,
    bool use_external_data_format = false,
    const std::string& onnx_file_path = std::string(),
    const NodeAttrNameMap& node_attr_to_name = {});
63 | |
// Serializes `model_proto` into its binary wire-format string.
TORCH_API std::string serialize_model_proto_to_string(
    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);

// Validates a serialized ONNX proto string. NOTE(review): failure behavior
// (throw vs. TORCH_CHECK) is defined in the implementation — confirm there.
TORCH_API void check_onnx_proto(const std::string& proto_string);
68 | |
69 | // Serializer for both oldsyle and unified format TorchScript serialization |
70 | class TORCH_API ScriptModuleSerializer { |
71 | public: |
72 | explicit ScriptModuleSerializer( |
73 | caffe2::serialize::PyTorchStreamWriter& export_writer) |
74 | : writer_(export_writer), current_source_range_tag_(0) {} |
75 | |
76 | void writeFiles(const std::string& code_dir); |
77 | void serialize( |
78 | const Module& module, |
79 | const ExtraFilesMap& , |
80 | bool bytecode_format, |
81 | bool save_mobile_debug_info); |
82 | void serialize_unified_format(Module& module, uint64_t script_module_id); |
83 | SerializationStorageContext& storage_context(); |
84 | |
85 | ~ScriptModuleSerializer() = default; |
86 | |
87 | private: |
88 | void convertNamedType(const c10::NamedTypePtr& class_type); |
89 | void convertTypes(const at::NamedTypePtr& root_type); |
90 | void (const Module& module, const ExtraFilesMap& ); |
91 | void writeByteCode(const Module& module, bool save_mobile_debug_info); |
92 | void writeArchive( |
93 | const IValue& value, |
94 | const std::string& archive_name, |
95 | const std::string& archive_dir, |
96 | const std::string& tensor_dir, |
97 | bool use_storage_context = false, |
98 | bool skip_tensor_data = false); |
99 | void updateSourceRangeTags(const SourceRangeRecords& ranges); |
100 | |
101 | caffe2::serialize::PyTorchStreamWriter& writer_; |
102 | std::vector<at::IValue> constant_table_; |
103 | |
104 | std::unordered_set<c10::NamedTypePtr> converted_types_; |
105 | PrintDepsTable class_deps_; |
106 | TypeNameUniquer type_name_uniquer_; |
107 | // qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be |
108 | // created |
109 | OrderedDict<std::string, PythonPrint> file_streams_; |
110 | // Used to keep references of storages around during serialization to solve |
111 | // for ABA memory reuse problem hit when storages are created/destroyed |
112 | // during serialization process. Also used to coordinate sharing of storages |
113 | // between Script and eager modules in torch.package. |
114 | SerializationStorageContext storage_context_; |
115 | |
116 | // Uniquely identifies a SourceRange in a model. |
117 | // SourceRanges are associated with Nodes of Graphs. |
118 | // However for mobile deployment we dont intend to ship |
119 | // full JIT with capabilities of reading code and constructing |
120 | // graphs. |
121 | // Instead we serialize the Code generated from graph of the methods. |
122 | // Code is serialized in bytecode format that contains instructions |
123 | // corresponding to the nodes of the graph. Since original graph is gone, the |
124 | // question is how do we identify where the ops, in serialized bytecode, come |
125 | // from in original model code. We do this in two parts. |
126 | // 1. Associate a unique tag to SourceRange. |
127 | // 2. Serialize this unique_tag. |
128 | // 2.1 Meaning save <byte_offset, source_range_tag, source range> instead of |
129 | // <byte_offset, source range> |
130 | // 3. During serializing model for mobile, i.e. bytecode generation, |
131 | // save unique tag of SourceRange corresponding to the Node. |
132 | // 4. During deserialization, read all the debug_pkl, to construct a map |
133 | // of <unique_tag, SourceRange> and use tag saved with OPs in bytecode |
134 | // to lookup the source range. |
135 | // Strictly speaking we will serialize InlinedCallStack directly, which |
136 | // contains SourceRange. This way we have access to entire callstack and not |
137 | // just source information about where the node is, since bytecode inlines the |
138 | // graph before saving it. |
139 | SourceRangeTagMap source_range_tags_; |
140 | int64_t current_source_range_tag_; |
141 | }; |
142 | |
// For testing purposes: returns a human-readable string rendering of the
// ONNX export of `graph`. Parameters mirror export_onnx above.
// `google_printer`: NOTE(review): presumably selects protobuf's own text
// printer — confirm in the implementation.
TORCH_API std::string pretty_print_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type =
        ::torch::onnx::OperatorExportTypes::ONNX,
    bool google_printer = false,
    bool keep_initializers_as_inputs = true,
    const std::map<std::string, int>& custom_opsets = {},
    bool add_node_names = true);
155 | |
// Serializes `module` to the stream `out`.
// `metadata`: extra files to embed alongside the model.
// `bytecode_format`: also emit mobile bytecode; `save_mobile_debug_info`:
// include mobile debug info; `use_flatbuffer`: use the flatbuffer container
// instead of the zip container (NOTE(review): flag semantics inferred from
// names and the flatbuffer_serializer include — confirm in .cpp).
TORCH_API void ExportModule(
    const Module& module,
    std::ostream& out,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

// Same as above, writing to the file at `filename`.
TORCH_API void ExportModule(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

// Same as above, but all serialized bytes are delivered to `writer_func`,
// which returns the number of bytes it consumed.
TORCH_API void ExportModule(
    const Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);
179 | |
// Write the bytes of a pickle archive and the tensors referenced inside that
// archive.
// `pickle_bytes`/`size`: the already-pickled payload for `archive_name`.
// `tensors`: tensors whose data is written alongside the pickle via `out`.
TORCH_API void writeArchiveAndTensors(
    const std::string& archive_name,
    const char* pickle_bytes,
    size_t size,
    const std::vector<at::Tensor>& tensors,
    caffe2::serialize::PyTorchStreamWriter& out);
188 | |
189 | // Surrounding system can install an additional hook to produce extra files |
190 | // with metadata based on environment every time a module is serialized. |
191 | using = std::function<ExtraFilesMap(const Module&)>; |
192 | TORCH_API void (ExportModuleExtraFilesHook hook); |
193 | |
194 | /** |
195 | * Generates new bytecode for a Script module and returns what the op list |
196 | * would be for a LiteScriptModule based off the current code base. If you |
197 | * have a LiteScriptModule and want to get the currently present |
198 | * list of ops call _export_operator_list instead. |
199 | */ |
200 | TORCH_API std::vector<std::string> export_opnames(const Module& m); |
201 | |
// Global toggles controlling how bytecode is emitted for mobile export.
// Each flag has a static getter/setter pair; prefer BytecodeEmitModeGuard
// (below) to set and automatically restore them. See the comment on the
// guard for the meaning of the first two flags.
struct TORCH_API BytecodeEmitMode {
  static bool is_default_value_for_unspecified_arg_enabled();
  static void set_default_value_for_unspecified_arg_enabled(bool enabled);

  static bool is_default_args_before_out_args_enabled();
  static void set_default_args_before_out_args_enabled(bool enabled);

  // NOTE(review): "promoted ops" semantics are defined in the implementation;
  // also note the getter/setter naming is asymmetric (is_emit_... vs
  // set_default_emit_...) — kept as-is since callers depend on these names.
  static bool is_emit_promoted_ops_enabled();
  static void set_default_emit_promoted_ops_enabled(bool enabled);
};
212 | |
213 | // RAII guard to switch the way JIT emits the bytecode for inputs. |
214 | // default_value_for_unspecified_arg: |
215 | // true: instruction of default argument values (like LOADC) is emitted. |
216 | // false: instruction of default argument values are not emitted. Instead |
217 | // they are fetched from operator schema. |
218 | // default_args_before_out_args (to forward compatibile support |
219 | // operators allowing out arguments and default arguments): |
220 | // true: the number of specified arguments will deserialized to (#all_args - |
221 | // #default_args). false: the number of specified arguments will deserialized to |
222 | // (#all_args). |
223 | struct TORCH_API BytecodeEmitModeGuard { |
224 | BytecodeEmitModeGuard( |
225 | bool enable_default_value_for_unspecified_arg, |
226 | bool enable_default_args_before_out_args, |
227 | bool enable_emit_promoted_ops) |
228 | : prev_default_value_for_unspecified_arg_mode( |
229 | BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()), |
230 | prev_default_args_before_out_args( |
231 | BytecodeEmitMode::is_default_args_before_out_args_enabled()), |
232 | prev_default_emit_promoted_ops( |
233 | BytecodeEmitMode::is_emit_promoted_ops_enabled()) { |
234 | BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled( |
235 | enable_default_value_for_unspecified_arg); |
236 | BytecodeEmitMode::set_default_args_before_out_args_enabled( |
237 | enable_default_args_before_out_args); |
238 | BytecodeEmitMode::set_default_emit_promoted_ops_enabled( |
239 | enable_emit_promoted_ops); |
240 | } |
241 | ~BytecodeEmitModeGuard() { |
242 | BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled( |
243 | prev_default_value_for_unspecified_arg_mode); |
244 | BytecodeEmitMode::set_default_args_before_out_args_enabled( |
245 | prev_default_args_before_out_args); |
246 | BytecodeEmitMode::set_default_emit_promoted_ops_enabled( |
247 | prev_default_emit_promoted_ops); |
248 | } |
249 | bool prev_default_value_for_unspecified_arg_mode; |
250 | bool prev_default_args_before_out_args; |
251 | bool prev_default_emit_promoted_ops; |
252 | }; |
253 | |
// Packs `ivalues` into a single tuple IValue (NOTE(review): inferred from the
// name — confirm in the implementation).
TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
// Builds an IValue from (name, value) pairs, used for serialization tables.
TORCH_API IValue
Table(const std::vector<std::pair<std::string, IValue>>& entries);

// TODO remove these switches once interface call is rolled out.
TORCH_API void enableMobileInterfaceCallExport();
// NOTE(review): unlike the setter above this getter lacks TORCH_API —
// possibly intentional (library-internal use only); confirm before exporting.
bool getMobileInterfaceCallExport();

// Builds CompilationOptions from process-global state (NOTE(review):
// presumably the BytecodeEmitMode flags — confirm in the implementation).
TORCH_API CompilationOptions getOptionsFromGlobal();
263 | |
264 | TORCH_API void save_jit_module( |
265 | const Module& module, |
266 | const std::string& filename, |
267 | const ExtraFilesMap& = ExtraFilesMap()); |
268 | |
269 | TORCH_API DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes( |
270 | const Module& module, |
271 | const ExtraFilesMap& = ExtraFilesMap()); |
272 | |
273 | TORCH_API void save_jit_module_to_write_func( |
274 | const Module& module, |
275 | const ExtraFilesMap& , |
276 | bool save_mobile_debug_info, |
277 | const std::function<size_t(const void*, size_t)>& writer_func); |
278 | |
279 | } // namespace jit |
280 | } // namespace torch |
281 | |