#pragma once

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
#include <torch/csrc/onnx/onnx.h>
#include <ostream>

namespace ONNX_NAMESPACE {
class ModelProto;
}

namespace torch {
namespace jit {
// This map is used to keep track of parameters that should be exported
// externally. When `defer_weight_export` is true, the returned map contains
// kv pairs that map {external reference name} -> {at::Tensor to be exported}.
// It is the responsibility of the caller to export these appropriately.
//
// For example, when exporting to a zip archive, the caller may write out files
// for each entry in the export map, with the filename being the key and the
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
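
// Example (a hedged sketch of the zip-archive convention described above;
// `write_file_to_archive` is a hypothetical helper supplied by the caller):
//
//   for (const auto& entry : export_map) {
//     // entry.first is the external reference name, entry.second the tensor.
//     write_file_to_archive(entry.first, entry.second);
//   }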

using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;

using NodeNameMap = std::unordered_map<const Node*, std::string>;

// Used for modularized export, settling function and node attribute names.
using NodeAttrNameMap = std::
    unordered_map<const Node*, std::unordered_map<std::string, std::string>>;

TORCH_API std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap,
    bool,
    NodeNameMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export = false,
    ::torch::onnx::OperatorExportTypes operator_export_type =
        ::torch::onnx::OperatorExportTypes::ONNX,
    bool strip_doc_string = true,
    bool keep_initializers_as_inputs = true,
    const std::map<std::string, int>& custom_opsets = {},
    bool add_node_names = true,
    bool use_external_data_format = false,
    const std::string& onnx_file_path = std::string(),
    const NodeAttrNameMap& node_attr_to_name = {});

TORCH_API std::string serialize_model_proto_to_string(
    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);

TORCH_API void check_onnx_proto(const std::string& proto_string);
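
// Example (a hedged sketch; assumes `graph` was produced by tracing or
// scripting elsewhere and `weights` holds its parameter tensors):
//
//   auto [model_proto, export_map, symbol_map, deferred, node_names] =
//       export_onnx(graph, weights, /*onnx_opset_version=*/14,
//                   /*dynamic_axes=*/{});
//   std::string bytes = serialize_model_proto_to_string(model_proto);
//   check_onnx_proto(bytes);  // validate before writing to disk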

// Serializer for both the old-style and the unified-format TorchScript
// serialization.
class TORCH_API ScriptModuleSerializer {
 public:
  explicit ScriptModuleSerializer(
      caffe2::serialize::PyTorchStreamWriter& export_writer)
      : writer_(export_writer), current_source_range_tag_(0) {}

  void writeFiles(const std::string& code_dir);
  void serialize(
      const Module& module,
      const ExtraFilesMap& extra_files,
      bool bytecode_format,
      bool save_mobile_debug_info);
  void serialize_unified_format(Module& module, uint64_t script_module_id);
  SerializationStorageContext& storage_context();

  ~ScriptModuleSerializer() = default;

 private:
  void convertNamedType(const c10::NamedTypePtr& class_type);
  void convertTypes(const at::NamedTypePtr& root_type);
  void writeExtraFiles(const Module& module, const ExtraFilesMap& extra_files);
  void writeByteCode(const Module& module, bool save_mobile_debug_info);
  void writeArchive(
      const IValue& value,
      const std::string& archive_name,
      const std::string& archive_dir,
      const std::string& tensor_dir,
      bool use_storage_context = false,
      bool skip_tensor_data = false);
  void updateSourceRangeTags(const SourceRangeRecords& ranges);

  caffe2::serialize::PyTorchStreamWriter& writer_;
  std::vector<at::IValue> constant_table_;

  std::unordered_set<c10::NamedTypePtr> converted_types_;
  PrintDepsTable class_deps_;
  TypeNameUniquer type_name_uniquer_;
  // Qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
  // created.
  OrderedDict<std::string, PythonPrint> file_streams_;
  // Used to keep references to storages alive during serialization, to solve
  // the ABA memory-reuse problem hit when storages are created/destroyed
  // during the serialization process. Also used to coordinate sharing of
  // storages between Script and eager modules in torch.package.
  SerializationStorageContext storage_context_;

  // Uniquely identifies a SourceRange in a model.
  // SourceRanges are associated with Nodes of Graphs.
  // However, for mobile deployment we don't intend to ship the full JIT with
  // its capabilities of reading code and constructing graphs.
  // Instead we serialize the Code generated from the graph of each method.
  // Code is serialized in a bytecode format that contains instructions
  // corresponding to the nodes of the graph. Since the original graph is gone,
  // the question is how we identify where the ops in the serialized bytecode
  // came from in the original model code. We do this in two parts.
  // 1. Associate a unique tag with each SourceRange.
  // 2. Serialize this unique tag.
  //    2.1 Meaning, save <byte_offset, source_range_tag, source_range> instead
  //        of <byte_offset, source_range>.
  // 3. During serialization of the model for mobile, i.e. bytecode generation,
  //    save the unique tag of the SourceRange corresponding to the Node.
  // 4. During deserialization, read all the debug_pkl to construct a map
  //    of <unique_tag, SourceRange>, and use the tag saved with the ops in the
  //    bytecode to look up the source range.
  // Strictly speaking, we serialize the InlinedCallStack directly, which
  // contains the SourceRange. This way we have access to the entire call stack
  // and not just source information about where the node is, since bytecode
  // inlines the graph before saving it.
  SourceRangeTagMap source_range_tags_;
  int64_t current_source_range_tag_;
};
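
// Example (a hedged sketch; "output.pt" is a placeholder destination and
// `module` is assumed to be a scripted Module):
//
//   caffe2::serialize::PyTorchStreamWriter writer("output.pt");
//   ScriptModuleSerializer serializer(writer);
//   serializer.serialize(
//       module,
//       /*extra_files=*/{},
//       /*bytecode_format=*/false,
//       /*save_mobile_debug_info=*/false);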

// For testing purposes
TORCH_API std::string pretty_print_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type =
        ::torch::onnx::OperatorExportTypes::ONNX,
    bool google_printer = false,
    bool keep_initializers_as_inputs = true,
    const std::map<std::string, int>& custom_opsets = {},
    bool add_node_names = true);
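
// Example (a hedged sketch for tests; reuses the `graph` and `weights`
// placeholders from the export_onnx example above):
//
//   std::string text = pretty_print_onnx(
//       graph, weights, /*onnx_opset_version=*/14,
//       /*defer_weight_export=*/false);
//   std::cout << text << '\n';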

TORCH_API void ExportModule(
    const Module& module,
    std::ostream& out,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

TORCH_API void ExportModule(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

TORCH_API void ExportModule(
    const Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);
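
// Example (a hedged sketch; streams the serialized module into an in-memory
// buffer via the writer-function overload):
//
//   std::ostringstream buffer;
//   ExportModule(module, [&buffer](const void* data, size_t size) {
//     buffer.write(static_cast<const char*>(data), size);
//     return size;
//   });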

// Write the bytes of a pickle archive and the tensors referenced inside that
// archive.
TORCH_API void writeArchiveAndTensors(
    const std::string& archive_name,
    const char* pickle_bytes,
    size_t size,
    const std::vector<at::Tensor>& tensors,
    caffe2::serialize::PyTorchStreamWriter& out);

// The surrounding system can install an additional hook to produce extra
// files with metadata based on the environment every time a module is
// serialized.
using ExportModuleExtraFilesHook = std::function<ExtraFilesMap(const Module&)>;
TORCH_API void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook);
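
// Example (a hedged sketch; `kBuildIdValue` is a hypothetical placeholder for
// whatever environment metadata the embedding system wants to record):
//
//   SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
//     return {{"build_id", kBuildIdValue}};
//   });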

/**
 * Generates new bytecode for a Script module and returns what the op list
 * would be for a LiteScriptModule based on the current code base. If you
 * have a LiteScriptModule and want to get the currently present
 * list of ops, call _export_operator_list instead.
 */
TORCH_API std::vector<std::string> export_opnames(const Module& m);
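
// Example (a hedged sketch; prints the operators a lite interpreter build
// would need in order to run `module`):
//
//   for (const std::string& op : export_opnames(module)) {
//     std::cout << op << '\n';
//   }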

struct TORCH_API BytecodeEmitMode {
  static bool is_default_value_for_unspecified_arg_enabled();
  static void set_default_value_for_unspecified_arg_enabled(bool enabled);

  static bool is_default_args_before_out_args_enabled();
  static void set_default_args_before_out_args_enabled(bool enabled);

  static bool is_emit_promoted_ops_enabled();
  static void set_default_emit_promoted_ops_enabled(bool enabled);
};

// RAII guard to switch the way the JIT emits bytecode for inputs.
// default_value_for_unspecified_arg:
//   true: instructions for default argument values (like LOADC) are emitted.
//   false: instructions for default argument values are not emitted; instead
//     they are fetched from the operator schema.
// default_args_before_out_args (for forward-compatible support of operators
// that allow both out arguments and default arguments):
//   true: the number of specified arguments is deserialized as
//     (#all_args - #default_args).
//   false: the number of specified arguments is deserialized as (#all_args).
struct TORCH_API BytecodeEmitModeGuard {
  BytecodeEmitModeGuard(
      bool enable_default_value_for_unspecified_arg,
      bool enable_default_args_before_out_args,
      bool enable_emit_promoted_ops)
      : prev_default_value_for_unspecified_arg_mode(
            BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
        prev_default_args_before_out_args(
            BytecodeEmitMode::is_default_args_before_out_args_enabled()),
        prev_default_emit_promoted_ops(
            BytecodeEmitMode::is_emit_promoted_ops_enabled()) {
    BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
        enable_default_value_for_unspecified_arg);
    BytecodeEmitMode::set_default_args_before_out_args_enabled(
        enable_default_args_before_out_args);
    BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
        enable_emit_promoted_ops);
  }
  ~BytecodeEmitModeGuard() {
    BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
        prev_default_value_for_unspecified_arg_mode);
    BytecodeEmitMode::set_default_args_before_out_args_enabled(
        prev_default_args_before_out_args);
    BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
        prev_default_emit_promoted_ops);
  }
  bool prev_default_value_for_unspecified_arg_mode;
  bool prev_default_args_before_out_args;
  bool prev_default_emit_promoted_ops;
};
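
// Example (a hedged sketch; temporarily emits bytecode without default-value
// instructions for the duration of one export, "model_for_mobile.ptl" being a
// placeholder filename):
//
//   {
//     BytecodeEmitModeGuard guard(
//         /*enable_default_value_for_unspecified_arg=*/false,
//         /*enable_default_args_before_out_args=*/true,
//         /*enable_emit_promoted_ops=*/true);
//     ExportModule(
//         module, "model_for_mobile.ptl", {}, /*bytecode_format=*/true);
//   }  // previous emit modes are restored here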

TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
TORCH_API IValue
Table(const std::vector<std::pair<std::string, IValue>>& entries);

// TODO: remove these switches once interface call export is rolled out.
TORCH_API void enableMobileInterfaceCallExport();
bool getMobileInterfaceCallExport();

TORCH_API CompilationOptions getOptionsFromGlobal();

TORCH_API void save_jit_module(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files = ExtraFilesMap());

TORCH_API DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes(
    const Module& module,
    const ExtraFilesMap& extra_files = ExtraFilesMap());

TORCH_API void save_jit_module_to_write_func(
    const Module& module,
    const ExtraFilesMap& extra_files,
    bool save_mobile_debug_info,
    const std::function<size_t(const void*, size_t)>& writer_func);
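
// Example (a hedged sketch; saves a module in flatbuffer format to a
// placeholder path and also grabs the raw bytes for in-memory use):
//
//   save_jit_module(module, "model.ff");
//   auto buffer = save_jit_module_to_bytes(module);
//   // The returned DetachedBuffer owns the serialized bytes.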

} // namespace jit
} // namespace torch