1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef GLOW_EXPORTER_ONNXMODELWRITER_H
18#define GLOW_EXPORTER_ONNXMODELWRITER_H
19
#include "glow/Exporter/CommonOperatorWriter.h"
#include "glow/Graph/Graph.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Runtime/RuntimeTypes.h"

#include "onnx/onnx_pb.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"

#include <list>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
32
33/// ONNX traits for protobuf types.
34struct ONNX_TRAITS {
35 using GraphProto = ONNX_NAMESPACE::GraphProto;
36};
37
38namespace glow {
39
40/// Unique set of visited nodes.
41using ReportedNodes = std::unordered_set<const Node *>;
42
43/// Writes ONNX models.
44class ONNXModelWriter : public CommonOperatorWriter<ONNX_TRAITS> {
45 // Declare shorter aliases.
46 using GraphType = typename ONNX_TRAITS::GraphProto;
47 using NodeType = ONNX_NAMESPACE::NodeProto;
48 using TensorType = ONNX_NAMESPACE::TensorProto;
49 using AttrType = ONNX_NAMESPACE::AttributeProto;
50 using ValueInfoType = ONNX_NAMESPACE::ValueInfoProto;
51
52 // ModelProto that we are writing to.
53 ONNX_NAMESPACE::ModelProto modelProto_;
54 // GraphProto that we are writing to.
55 ONNX_TRAITS::GraphProto *graphProto_;
56 // Root GraphProto that we are writing to. Equal to \ref graphProto_ unless
57 // when writing a constant folding subgraph, when graphProto_ is temporarily
58 // changed.
59 ONNX_TRAITS::GraphProto *graphProtoRoot_;
60 /// Current IR version of ONNX.
61 const size_t irVersion_;
62 /// Current version of ONNX standard.
63 const size_t opsetVersion_;
64 /// Keeps the track of already visited or processed nodes.
65 ReportedNodes reportedNodes_;
66 /// Whether we use zip mode or not
67 const bool zipMode_;
68 /// Whether we use text mode or not
69 const bool textMode_;
70 /// Whether to include Constant (initializer) data in the exported proto.
71 const bool includeConstantData_;
72 /// Extra metadata properties to add to the ONNX file
73 const llvm::StringMap<std::string> &extraMetadataProps_;
74 /// Whether to use custom ONNX ops.
75 const bool useGlowCustomOps_;
76 /// Whether we are writing a DAG.
77 const bool dagMode_;
78 /// A map containing a record of what constant folding took place, to record
79 /// in serialized DAGs.
80 const ConstantFoldingRecordMap &constFoldRecord_;
81 /// Backend-specific node info to include in the exported model.
82 const BackendSpecificNodeInfo &backendSpecificNodeInfo_;
83 /// Map from Placeholders in the Module to the symbolic name they were loaded
84 /// with from the input model. If not null, included in IO doc_string info.
85 const LoadedPlaceholderNameMap *loadedPHNames_;
86 /// Map from static PH names to the type it was originally loaded with.
87 const std::map<std::string, Type> *staticPlaceholderTypes_;
88 /// A dedicated list of initializers in case the tensors get too big and don't
89 /// fit into the model.
90 std::list<TensorType> initializers_;
91 /// Holds all Functions from a DAG that are being written when in dagMode_.
92 llvm::SmallSet<Function *, 6> functionsFromDAG_;
93 /// Holds all constant folding Functions that have been processed.
94 llvm::SmallSet<Function *, 6> processedConstFoldFunctions_;
95 /// Maps from all non-static input PHs to the generated proto. It's used to
96 /// buffer protos; later on written out in order based on \ref loadedPHNames_.
97 std::unordered_map<const Placeholder *, ValueInfoType> inputValueInfos_;
98 /// Maps from all output PHs to the generated proto. It's used to buffer
99 /// protos; later on written out in order based on \ref loadedPHNames_.
100 std::unordered_map<const Placeholder *, ValueInfoType> outputValueInfos_;
101 /// Output string. Null value indicates that the output is to be written to a
102 /// file.
103 std::string *outputStringPtr_;
104
105 /// Creates and \returns a new ValueInfoType for \p PH based on \p isInput.
106 /// It's added either directy to \ref graphProto_, or to \ref inputValueInfos_
107 /// / \ref outputValueInfos_, depending on whether there's an order we need to
108 /// serialize the IO in (order comes from \ref loadedPHNames_ if non-null).
109 Expected<ValueInfoType *> createProtoForIO(const Placeholder *PH,
110 bool isInput);
111 /// Writes all inputs and outputs with operator name \p opName from give Node
112 /// \p node into protobuf \p proto.
113 static Error writeAllWithNode(const std::string &opName, const Node *node,
114 GraphType &graph, NodeType *proto);
115 /// Writes all inputs and outputs with operator name \p opName from give Node
116 /// \p node into created node protobuf using \p graph.
117 static Error writeAll(const std::string &opName, const Node *node,
118 GraphType &graph);
119
120 /// Add an initializer. Depending on \ref zipMode_, it will add directly to
121 /// the \p graph or to a separate list.
122 TensorType *addInitializer(GraphType &graph);
123
124 /// Special case node writer for Glow convolutions with quantized inputs and
125 /// outputs.
126 Error writeTensorwiseQuantizedConvolution(const ConvolutionNode *node,
127 GraphType &graph);
128
129 /// Write \p node to \p graph using custom Glow Nodes, exported via
130 /// auto-generated export logic in NodeGen.
131 Error writeGlowCustomOperator(const Node *node, GraphType &graph);
132
133 /// Setup a new proto \ref modelProto_ and \ref graphProto_.
134 void setupNewProto();
135
136 /// Write the current Function \ref F_ to \ref graphProto_. \returns if there
137 /// was an issue during iteration or writing.
138 Error writeFunction();
139
140 /// Given a Constant \p C that was previously created during Constant folding,
141 /// Serializes the constant folding Function saved by \p SN, where the
142 /// Function is the parent of \p SN. The function is written to an attribute
143 /// in a Glow__ConstFoldSubgraph NodeProto. \returns if an Error occurs.
144 Error writeConstantFoldingSubgraph(const Constant *C, SaveNode *SN);
145
146 /// \returns whether currently writing a constant folding subgraph.
147 bool isWritingConstFoldSubgraph();
148
149 /// Finalize the written function and write it out to \p filename. \returns if
150 /// there is an error encountered.
151 Error finalizeAndWriteProto(llvm::StringRef filename);
152
153 /// Adds a metadata prop with \p key and \p val to \ref modelProto_.
154 void addMetadataProp(const std::string &key, const std::string &val);
155
156 /// Write out the Functions and metadata for all DAGNodes in \p postOrder
157 /// given parent \p mod.
158 Error writePartitionAndMetadataProps(
159 Module &mod, llvm::ArrayRef<const runtime::DAGNode *> postOrder);
160
161 /// \returns whether \p PH is an intermediate PH for the DAG being written
162 /// (i.e. both input and an output for Functions in \ref functionsFromDAG_).
163 bool isIntermediatePHForDAG(const Placeholder *PH);
164
165 /// \returns True if the operator with the name \p typeName has support for
166 /// multidirectional broadcasting.
167 bool hasMultidirectionalBroadcast(const llvm::StringRef typeName);
168
169public:
170 /// Inserts the mapping in \p map into \p extraMetadataProps. \returns an
171 /// error if the key already exists for the map in \p extraMetadataProps.
172 static Error insertLoaderNameUniqueOffsetMetadata(
173 llvm::StringMap<std::string> &extraMetadataProps,
174 const OriginNameToTQPMap &map);
175
176 /// Converts \p glowType to \p protoType.
177 static typename TensorType::DataType convertType(const Type &glowType);
178 /// Writes Glow tensor \p T to proto output \p out. Depending on
179 /// \p useGlowCustomOps meta info will be annotated differently.
180 /// If \p includeData then the data from \p T will be included; otherwise only
181 /// the type info and name will be.
182 static void writeTensor(const Tensor &T, TensorType *out,
183 bool useGlowCustomOps = false,
184 bool includeData = true);
185
186 /// Creates an ONNX model writer to serialize \p F graph into file
187 /// \p modelFilename, writing \p irVersion and \p opsetVersion.
188 /// If \p errPtr is not null then if an error occurs it will get assigned
189 /// there otherwise if an error occurs it will abort. It also supports
190 /// serialization with text format or binary format depending on \p textMode.
191 /// If \p zipMode is true, it will save weights into individual TensorProto
192 /// file along with the model file and package them into a zip file. If
193 /// \p useGlowCustomOps then it will use auto-generated export logic via
194 /// NodeGen to export all Glow Nodes as is via custom ops, instead of trying
195 /// to abide by the official ONNX ops. If \p includeConstantData then data for
196 /// Constants will be serialized in the written model, otherwise it will be
197 /// skipped (but initializers will still exist, they will just have no data).
198 /// \p extraMetadataProps is a mapping of key value pairs which are added to
199 /// the metadata props portion of the ONNX.
200 /// \p constFoldRecord contains any records of constant folding that should be
201 /// included in the serialized model.
202 /// \p backendSpecificNodeInfo contains attributes to add onto Nodes when
203 /// exporting if found.
204 ONNXModelWriter(const std::string &modelFilename, Function &F,
205 size_t irVersion, size_t opsetVersion,
206 Error *errPtr = nullptr, bool textMode = false,
207 bool zipMode = false, bool useGlowCustomOps = false,
208 bool includeConstantData = true,
209 const llvm::StringMap<std::string> &extraMetadataProps =
210 llvm::StringMap<std::string>(),
211 const ConstantFoldingRecordMap &constFoldRecord =
212 ConstantFoldingRecordMap(),
213 const BackendSpecificNodeInfo &backendSpecificNodeInfo = {},
214 std::string *outputStringPtr = nullptr);
215
216 /// Creates an ONNX model writer to serialize \p dagList into file
217
218 /// \p modelFilename, writing \p irVersion and \p opsetVersion. Each partition
219 /// from \p dagList will be annotated with the name of the partition to the
220 /// op. This exporter requires using \ref useGlowCustomOps_ and sets it true
221 /// as such. If \p errPtr is not null then if an error occurs it will get
222 /// assigned there otherwise if an error occurs it will abort. It also
223 /// supports serialization with text format or binary format depending on
224 /// \p textMode. If \p zipMode is true, it will save weights into individual
225 /// TensorProto file along with the model file and package them into a zip
226 /// file. If \p includeConstantData then data for Constants will be serialized
227 /// in the written model, otherwise it will be skipped (but initializers will
228 /// still exist, they will just have no data). \p extraMetadataProps is
229 /// a mapping of key value pairs which are added to the metadata props portion
230 /// of the ONNX. \p constFoldRecord contains any records of constant folding
231 /// that should be included in the serialized model.
232 /// \p backendSpecificNodeInfo contains attributes to add onto Nodes when
233 /// exporting if found.
234
235 ONNXModelWriter(
236 const std::string &modelFilename, runtime::DAGListTy &dagList,
237 size_t irVersion, size_t opsetVersion, Error *errPtr = nullptr,
238 bool textMode = false, bool zipMode = false,
239 bool includeConstantData = true,
240 const llvm::StringMap<std::string> &extraMetadataProps =
241 llvm::StringMap<std::string>(),
242 const ConstantFoldingRecordMap &constFoldRecord =
243 ConstantFoldingRecordMap(),
244 const BackendSpecificNodeInfo &backendSpecificNodeInfo = {},
245 const LoadedPlaceholderNameMap *loadedPHNames = nullptr,
246 const std::map<std::string, Type> *staticPlaceholderTypes = nullptr,
247 std::string *outputStringPtr = nullptr);
248
249private:
250 /// \returns error for the unexpected node kind.
251 static Error writeUnexpectedKind(const Node *node) {
252 return MAKE_ERR(
253 strFormat("Glow can not export node %s, unsupported kind: %s.",
254 node->getName().str().c_str(), node->getKindName()));
255 }
256
257 /// Declares the overriden all pure virtual methods, declared in base class.
258#define DEF_NODE(CLASS, NAME) \
259 Error write##NAME(const CLASS *, GraphType &) override;
260#include "glow/AutoGenNodes.def"
261};
262
263} // namespace glow
264
265#endif // GLOW_EXPORTER_ONNXMODELWRITER_H
266