1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_EXPORTER_ONNXMODELWRITER_H |
18 | #define GLOW_EXPORTER_ONNXMODELWRITER_H |
19 | |
#include "glow/Exporter/CommonOperatorWriter.h"
#include "glow/Graph/Graph.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Runtime/RuntimeTypes.h"

#include "onnx/onnx_pb.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"

#include <list>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
32 | |
33 | /// ONNX traits for protobuf types. |
34 | struct ONNX_TRAITS { |
35 | using GraphProto = ONNX_NAMESPACE::GraphProto; |
36 | }; |
37 | |
38 | namespace glow { |
39 | |
40 | /// Unique set of visited nodes. |
41 | using ReportedNodes = std::unordered_set<const Node *>; |
42 | |
43 | /// Writes ONNX models. |
44 | class ONNXModelWriter : public CommonOperatorWriter<ONNX_TRAITS> { |
45 | // Declare shorter aliases. |
46 | using GraphType = typename ONNX_TRAITS::GraphProto; |
47 | using NodeType = ONNX_NAMESPACE::NodeProto; |
48 | using TensorType = ONNX_NAMESPACE::TensorProto; |
49 | using AttrType = ONNX_NAMESPACE::AttributeProto; |
50 | using ValueInfoType = ONNX_NAMESPACE::ValueInfoProto; |
51 | |
52 | // ModelProto that we are writing to. |
53 | ONNX_NAMESPACE::ModelProto modelProto_; |
54 | // GraphProto that we are writing to. |
55 | ONNX_TRAITS::GraphProto *graphProto_; |
56 | // Root GraphProto that we are writing to. Equal to \ref graphProto_ unless |
57 | // when writing a constant folding subgraph, when graphProto_ is temporarily |
58 | // changed. |
59 | ONNX_TRAITS::GraphProto *graphProtoRoot_; |
60 | /// Current IR version of ONNX. |
61 | const size_t irVersion_; |
62 | /// Current version of ONNX standard. |
63 | const size_t opsetVersion_; |
64 | /// Keeps the track of already visited or processed nodes. |
65 | ReportedNodes reportedNodes_; |
66 | /// Whether we use zip mode or not |
67 | const bool zipMode_; |
68 | /// Whether we use text mode or not |
69 | const bool textMode_; |
70 | /// Whether to include Constant (initializer) data in the exported proto. |
71 | const bool includeConstantData_; |
72 | /// Extra metadata properties to add to the ONNX file |
73 | const llvm::StringMap<std::string> &; |
74 | /// Whether to use custom ONNX ops. |
75 | const bool useGlowCustomOps_; |
76 | /// Whether we are writing a DAG. |
77 | const bool dagMode_; |
78 | /// A map containing a record of what constant folding took place, to record |
79 | /// in serialized DAGs. |
80 | const ConstantFoldingRecordMap &constFoldRecord_; |
81 | /// Backend-specific node info to include in the exported model. |
82 | const BackendSpecificNodeInfo &backendSpecificNodeInfo_; |
83 | /// Map from Placeholders in the Module to the symbolic name they were loaded |
84 | /// with from the input model. If not null, included in IO doc_string info. |
85 | const LoadedPlaceholderNameMap *loadedPHNames_; |
86 | /// Map from static PH names to the type it was originally loaded with. |
87 | const std::map<std::string, Type> *staticPlaceholderTypes_; |
88 | /// A dedicated list of initializers in case the tensors get too big and don't |
89 | /// fit into the model. |
90 | std::list<TensorType> initializers_; |
91 | /// Holds all Functions from a DAG that are being written when in dagMode_. |
92 | llvm::SmallSet<Function *, 6> functionsFromDAG_; |
93 | /// Holds all constant folding Functions that have been processed. |
94 | llvm::SmallSet<Function *, 6> processedConstFoldFunctions_; |
95 | /// Maps from all non-static input PHs to the generated proto. It's used to |
96 | /// buffer protos; later on written out in order based on \ref loadedPHNames_. |
97 | std::unordered_map<const Placeholder *, ValueInfoType> inputValueInfos_; |
98 | /// Maps from all output PHs to the generated proto. It's used to buffer |
99 | /// protos; later on written out in order based on \ref loadedPHNames_. |
100 | std::unordered_map<const Placeholder *, ValueInfoType> outputValueInfos_; |
101 | /// Output string. Null value indicates that the output is to be written to a |
102 | /// file. |
103 | std::string *outputStringPtr_; |
104 | |
105 | /// Creates and \returns a new ValueInfoType for \p PH based on \p isInput. |
106 | /// It's added either directy to \ref graphProto_, or to \ref inputValueInfos_ |
107 | /// / \ref outputValueInfos_, depending on whether there's an order we need to |
108 | /// serialize the IO in (order comes from \ref loadedPHNames_ if non-null). |
109 | Expected<ValueInfoType *> createProtoForIO(const Placeholder *PH, |
110 | bool isInput); |
111 | /// Writes all inputs and outputs with operator name \p opName from give Node |
112 | /// \p node into protobuf \p proto. |
113 | static Error writeAllWithNode(const std::string &opName, const Node *node, |
114 | GraphType &graph, NodeType *proto); |
115 | /// Writes all inputs and outputs with operator name \p opName from give Node |
116 | /// \p node into created node protobuf using \p graph. |
117 | static Error writeAll(const std::string &opName, const Node *node, |
118 | GraphType &graph); |
119 | |
120 | /// Add an initializer. Depending on \ref zipMode_, it will add directly to |
121 | /// the \p graph or to a separate list. |
122 | TensorType *addInitializer(GraphType &graph); |
123 | |
124 | /// Special case node writer for Glow convolutions with quantized inputs and |
125 | /// outputs. |
126 | Error writeTensorwiseQuantizedConvolution(const ConvolutionNode *node, |
127 | GraphType &graph); |
128 | |
129 | /// Write \p node to \p graph using custom Glow Nodes, exported via |
130 | /// auto-generated export logic in NodeGen. |
131 | Error writeGlowCustomOperator(const Node *node, GraphType &graph); |
132 | |
133 | /// Setup a new proto \ref modelProto_ and \ref graphProto_. |
134 | void setupNewProto(); |
135 | |
136 | /// Write the current Function \ref F_ to \ref graphProto_. \returns if there |
137 | /// was an issue during iteration or writing. |
138 | Error writeFunction(); |
139 | |
140 | /// Given a Constant \p C that was previously created during Constant folding, |
141 | /// Serializes the constant folding Function saved by \p SN, where the |
142 | /// Function is the parent of \p SN. The function is written to an attribute |
143 | /// in a Glow__ConstFoldSubgraph NodeProto. \returns if an Error occurs. |
144 | Error writeConstantFoldingSubgraph(const Constant *C, SaveNode *SN); |
145 | |
146 | /// \returns whether currently writing a constant folding subgraph. |
147 | bool isWritingConstFoldSubgraph(); |
148 | |
149 | /// Finalize the written function and write it out to \p filename. \returns if |
150 | /// there is an error encountered. |
151 | Error finalizeAndWriteProto(llvm::StringRef filename); |
152 | |
153 | /// Adds a metadata prop with \p key and \p val to \ref modelProto_. |
154 | void addMetadataProp(const std::string &key, const std::string &val); |
155 | |
156 | /// Write out the Functions and metadata for all DAGNodes in \p postOrder |
157 | /// given parent \p mod. |
158 | Error writePartitionAndMetadataProps( |
159 | Module &mod, llvm::ArrayRef<const runtime::DAGNode *> postOrder); |
160 | |
161 | /// \returns whether \p PH is an intermediate PH for the DAG being written |
162 | /// (i.e. both input and an output for Functions in \ref functionsFromDAG_). |
163 | bool isIntermediatePHForDAG(const Placeholder *PH); |
164 | |
165 | /// \returns True if the operator with the name \p typeName has support for |
166 | /// multidirectional broadcasting. |
167 | bool hasMultidirectionalBroadcast(const llvm::StringRef typeName); |
168 | |
169 | public: |
170 | /// Inserts the mapping in \p map into \p extraMetadataProps. \returns an |
171 | /// error if the key already exists for the map in \p extraMetadataProps. |
172 | static Error insertLoaderNameUniqueOffsetMetadata( |
173 | llvm::StringMap<std::string> &, |
174 | const OriginNameToTQPMap &map); |
175 | |
176 | /// Converts \p glowType to \p protoType. |
177 | static typename TensorType::DataType convertType(const Type &glowType); |
178 | /// Writes Glow tensor \p T to proto output \p out. Depending on |
179 | /// \p useGlowCustomOps meta info will be annotated differently. |
180 | /// If \p includeData then the data from \p T will be included; otherwise only |
181 | /// the type info and name will be. |
182 | static void writeTensor(const Tensor &T, TensorType *out, |
183 | bool useGlowCustomOps = false, |
184 | bool includeData = true); |
185 | |
186 | /// Creates an ONNX model writer to serialize \p F graph into file |
187 | /// \p modelFilename, writing \p irVersion and \p opsetVersion. |
188 | /// If \p errPtr is not null then if an error occurs it will get assigned |
189 | /// there otherwise if an error occurs it will abort. It also supports |
190 | /// serialization with text format or binary format depending on \p textMode. |
191 | /// If \p zipMode is true, it will save weights into individual TensorProto |
192 | /// file along with the model file and package them into a zip file. If |
193 | /// \p useGlowCustomOps then it will use auto-generated export logic via |
194 | /// NodeGen to export all Glow Nodes as is via custom ops, instead of trying |
195 | /// to abide by the official ONNX ops. If \p includeConstantData then data for |
196 | /// Constants will be serialized in the written model, otherwise it will be |
197 | /// skipped (but initializers will still exist, they will just have no data). |
198 | /// \p extraMetadataProps is a mapping of key value pairs which are added to |
199 | /// the metadata props portion of the ONNX. |
200 | /// \p constFoldRecord contains any records of constant folding that should be |
201 | /// included in the serialized model. |
202 | /// \p backendSpecificNodeInfo contains attributes to add onto Nodes when |
203 | /// exporting if found. |
204 | ONNXModelWriter(const std::string &modelFilename, Function &F, |
205 | size_t irVersion, size_t opsetVersion, |
206 | Error *errPtr = nullptr, bool textMode = false, |
207 | bool zipMode = false, bool useGlowCustomOps = false, |
208 | bool includeConstantData = true, |
209 | const llvm::StringMap<std::string> & = |
210 | llvm::StringMap<std::string>(), |
211 | const ConstantFoldingRecordMap &constFoldRecord = |
212 | ConstantFoldingRecordMap(), |
213 | const BackendSpecificNodeInfo &backendSpecificNodeInfo = {}, |
214 | std::string *outputStringPtr = nullptr); |
215 | |
216 | /// Creates an ONNX model writer to serialize \p dagList into file |
217 | |
218 | /// \p modelFilename, writing \p irVersion and \p opsetVersion. Each partition |
219 | /// from \p dagList will be annotated with the name of the partition to the |
220 | /// op. This exporter requires using \ref useGlowCustomOps_ and sets it true |
221 | /// as such. If \p errPtr is not null then if an error occurs it will get |
222 | /// assigned there otherwise if an error occurs it will abort. It also |
223 | /// supports serialization with text format or binary format depending on |
224 | /// \p textMode. If \p zipMode is true, it will save weights into individual |
225 | /// TensorProto file along with the model file and package them into a zip |
226 | /// file. If \p includeConstantData then data for Constants will be serialized |
227 | /// in the written model, otherwise it will be skipped (but initializers will |
228 | /// still exist, they will just have no data). \p extraMetadataProps is |
229 | /// a mapping of key value pairs which are added to the metadata props portion |
230 | /// of the ONNX. \p constFoldRecord contains any records of constant folding |
231 | /// that should be included in the serialized model. |
232 | /// \p backendSpecificNodeInfo contains attributes to add onto Nodes when |
233 | /// exporting if found. |
234 | |
235 | ONNXModelWriter( |
236 | const std::string &modelFilename, runtime::DAGListTy &dagList, |
237 | size_t irVersion, size_t opsetVersion, Error *errPtr = nullptr, |
238 | bool textMode = false, bool zipMode = false, |
239 | bool includeConstantData = true, |
240 | const llvm::StringMap<std::string> & = |
241 | llvm::StringMap<std::string>(), |
242 | const ConstantFoldingRecordMap &constFoldRecord = |
243 | ConstantFoldingRecordMap(), |
244 | const BackendSpecificNodeInfo &backendSpecificNodeInfo = {}, |
245 | const LoadedPlaceholderNameMap *loadedPHNames = nullptr, |
246 | const std::map<std::string, Type> *staticPlaceholderTypes = nullptr, |
247 | std::string *outputStringPtr = nullptr); |
248 | |
249 | private: |
250 | /// \returns error for the unexpected node kind. |
251 | static Error writeUnexpectedKind(const Node *node) { |
252 | return MAKE_ERR( |
253 | strFormat("Glow can not export node %s, unsupported kind: %s." , |
254 | node->getName().str().c_str(), node->getKindName())); |
255 | } |
256 | |
257 | /// Declares the overriden all pure virtual methods, declared in base class. |
258 | #define DEF_NODE(CLASS, NAME) \ |
259 | Error write##NAME(const CLASS *, GraphType &) override; |
260 | #include "glow/AutoGenNodes.def" |
261 | }; |
262 | |
263 | } // namespace glow |
264 | |
265 | #endif // GLOW_EXPORTER_ONNXMODELWRITER_H |
266 | |