1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_IMPORTER_ONNXMODELLOADER_H |
18 | #define GLOW_IMPORTER_ONNXMODELLOADER_H |
19 | |
20 | #include "glow/Graph/Graph.h" |
21 | #include "glow/Importer/CommonOperatorLoader.h" |
22 | |
23 | #include "onnx/onnx_pb.h" |
24 | |
25 | #include "llvm/ADT/ArrayRef.h" |
26 | #include "llvm/ADT/StringRef.h" |
27 | |
28 | #include "google/protobuf/io/coded_stream.h" |
29 | #include "google/protobuf/io/zero_copy_stream_impl.h" |
30 | #include <fstream> |
31 | #include <string> |
32 | #include <unordered_set> |
33 | |
34 | namespace ONNX_NAMESPACE { |
35 | class AttributeProto; |
36 | class NodeProto; |
37 | class GraphProto; |
38 | class ModelProto; |
39 | class TensorProto; |
40 | } // namespace ONNX_NAMESPACE |
41 | |
42 | namespace glow { |
43 | |
44 | /// Loads tensor \p T from the input proto \p in. \p useGlowCustomOps changes the |
45 | /// doc_string format used for adding meta information. NOTE(review): the role of \p data is not described here — confirm in the .cpp. |
46 | Error loadTensor(const ONNX_NAMESPACE::TensorProto &in, Tensor *T, |
47 | bool useGlowCustomOps = false, const std::string &data = "" ); |
48 | |
49 | /// Parses the ONNX file whose name is given by \p fileName |
50 | /// and \returns the parsed GraphProto. |
51 | ONNX_NAMESPACE::GraphProto parseOnnxFile(const std::string &fileName); |
52 | |
53 | /// Takes an ONNX file in \p fileName, reads it, and loads the tensors |
54 | /// into \p bindings. If a tensor loaded from the underlying file |
55 | /// is smaller than what the placeholder for that tensor expects, it gets |
56 | /// padded with 0 if \p partialTensorPayloads is nullptr; otherwise |
57 | /// \p partialTensorPayloads holds the data for full tensors. |
58 | /// If \p usingGlowCustomOps then the custom Glow ONNX format will be |
59 | /// expected/used to load from the ONNX file. |
60 | void fillPlaceholders(const std::string &fileName, |
61 | PlaceholderBindings *bindings, |
62 | std::vector<Tensor> *partialTensorPayloads = nullptr, |
63 | bool usingGlowCustomOps = false); |
64 | |
65 | /// Overload that takes \p parsedFile, an already parsed GraphProto, instead of a file name. |
66 | void fillPlaceholders(const ONNX_NAMESPACE::GraphProto &parsedFile, |
67 | PlaceholderBindings *bindings, |
68 | std::vector<Tensor> *partialTensorPayloads = nullptr, |
69 | bool usingGlowCustomOps = false); |
70 | |
71 | /// Define undefined symbols loaded from an ONNX proto using \p lst. See |
72 | /// onnxDefineSymbolOpt in ONNXModelLoader.cpp. |
73 | void setOnnxDefineSymbol(const std::vector<std::string> &lst); |
74 | |
75 | /// Loads ONNX models (from a file or from memory) into a Glow Function or Module. |
76 | class ONNXModelLoader |
77 | : public CommonOperatorLoader<ONNX_NAMESPACE::NodeProto, |
78 | ONNX_NAMESPACE::AttributeProto> { |
79 | /// \returns True if the operator has broadcasting activated. |
80 | Expected<bool> getBroadcast(ArgumentDictionaryTy &dict) override; |
81 | |
82 | /// \returns True if the operator with the name \p typeName has support for |
83 | /// multidirectional broadcasting. |
84 | bool hasMultidirectionalBroadcast(const llvm::StringRef typeName) override; |
85 | |
86 | /// Converts an ONNX TensorProto DataType enum to the Glow element type. |
87 | /// Supports only non quantized and signed types. |
88 | Expected<ElemKind> |
89 | convertTensorProtoDataType(ONNX_NAMESPACE::TensorProto_DataType t); |
90 | |
91 | /// Load the operator \p op into the network. This creates one or more nodes |
92 | /// in the network. \returns Error if operator \p op cannot be loaded. |
93 | Error loadOperator(const ONNX_NAMESPACE::NodeProto &op); |
94 | |
95 | /// \returns a TypeRef found in \p dict which is loaded and uniqued into the |
96 | /// Module. The TypeRef is represented in the ONNX proto by concatenating the |
97 | /// relevant members of a type, (ElemKind, Shape, and Scale and Offset if |
98 | /// ElemKind is quantized) with \p resNo. |
99 | Expected<TypeRef> loadTypeFromAttributes(unsigned resNo, |
100 | ArgumentDictionaryTy &dict); |
101 | |
102 | /// If this is a custom Glow op that was exported via NodeGen automatic export |
103 | /// logic, try to load the op \p op of type \p typeName using \p dict. |
104 | /// \returns the single Node that was created on success, or an Error if one |
105 | /// occurred while trying to load. NOTE(review): the result for an |
106 | /// unsupported op type is not visible from this header — confirm in the .cpp. |
107 | Expected<Node *> tryLoadGlowCustomOp(llvm::StringRef typeName, |
108 | const ONNX_NAMESPACE::NodeProto &op, |
109 | ArgumentDictionaryTy &dict); |
110 | |
111 | /// \returns True if the operator \p op is successfully folded. |
112 | Expected<bool> foldOperator(const ONNX_NAMESPACE::NodeProto &op); |
113 | |
114 | /// Helper for better log information on operator failure; \returns a message string for \p op and \p errMsg. |
115 | const std::string opErrMsg(const ONNX_NAMESPACE::NodeProto &op, |
116 | const std::string &errMsg); |
117 | |
118 | /// ONNX model ir_version; zero-initialized so the getter never reads an indeterminate value. |
119 | size_t irVersion_{0}; |
120 | |
121 | /// ONNX model op_version; zero-initialized so the getter never reads an indeterminate value. |
122 | size_t opsetVersion_{0}; |
123 | |
124 | /// Whether we're loading an ONNX file exported using Glow custom ops. |
125 | bool useGlowCustomOps_{false}; |
126 | |
127 | /// A set of inputs which will be static placeholders. |
128 | std::unordered_set<std::string> staticInputs_; |
129 | |
130 | /// A set of Functions used for ConstantFolding to be deleted after loading. |
131 | std::unordered_set<Function *> constFoldFuns_; |
132 | |
133 | /// Load ONNX NonZero Operator. |
134 | /// Glow's requirement for static shapes results in required Constant |
135 | /// input. Thus, the operator will be folded in the Importer. |
136 | Error loadNonZero(const ONNX_NAMESPACE::NodeProto &op, |
137 | const ArgumentDictionaryTy &dict); |
138 | |
139 | /// Load Asin ONNX operator (trigonometric op). |
140 | Error loadAsin(const ONNX_NAMESPACE::NodeProto &op, |
141 | const ArgumentDictionaryTy &dict); |
142 | /// Load Acos ONNX operator (trigonometric op). |
143 | Error loadAcos(const ONNX_NAMESPACE::NodeProto &op, |
144 | const ArgumentDictionaryTy &dict); |
145 | |
146 | /// Load Erf ONNX operator |
147 | Error loadErf(const ONNX_NAMESPACE::NodeProto &op, |
148 | const ArgumentDictionaryTy &dict); |
149 | /// Load Atan ONNX operator (trigonometric op). |
150 | Error loadAtan(const ONNX_NAMESPACE::NodeProto &op, |
151 | const ArgumentDictionaryTy &dict); |
152 | |
153 | /// Load Constant ONNX operator. |
154 | Error loadConstant(const ONNX_NAMESPACE::NodeProto &op, |
155 | ArgumentDictionaryTy &dict); |
156 | |
157 | /// Helper function for ONNX range operator |
158 | template <typename T> |
159 | Error getRange(const ONNX_NAMESPACE::NodeProto &op, Constant *constT); |
160 | |
161 | /// Load Range ONNX operator |
162 | Error loadRange(const ONNX_NAMESPACE::NodeProto &op, |
163 | ArgumentDictionaryTy &dict); |
164 | |
165 | /// Load PRelu ONNX operator. |
166 | Error loadPRelu(const ONNX_NAMESPACE::NodeProto &op, |
167 | ArgumentDictionaryTy &dict); |
168 | |
169 | /// Load Slice ONNX operator. |
170 | Error loadSlice(const ONNX_NAMESPACE::NodeProto &op, |
171 | ArgumentDictionaryTy &dict); |
172 | |
173 | /// Load Trigonometric ONNX operators. |
174 | Error loadTrigonometricOps(const std::string &typeName, |
175 | const ONNX_NAMESPACE::NodeProto &op, |
176 | ArgumentDictionaryTy &dict); |
177 | |
178 | /// Load Sign ONNX operator |
179 | Error loadSign(const ONNX_NAMESPACE::NodeProto &op, |
180 | const ArgumentDictionaryTy &dict); |
181 | |
182 | /// Load Softmax ONNX operator |
183 | Error loadSoftmax(const ONNX_NAMESPACE::NodeProto &op, |
184 | const ArgumentDictionaryTy &dict); |
185 | |
186 | /// Load LogSoftmax ONNX operator |
187 | Error loadLogSoftmax(const ONNX_NAMESPACE::NodeProto &op, |
188 | const ArgumentDictionaryTy &dict); |
189 | |
190 | /// Load ScatterData ONNX operator |
191 | Error loadScatterData(const ONNX_NAMESPACE::NodeProto &op, |
192 | const ArgumentDictionaryTy &dict); |
193 | |
194 | /// Load TopK ONNX operator |
195 | Error loadTopK(const ONNX_NAMESPACE::NodeProto &op, |
196 | ArgumentDictionaryTy &dict); |
197 | |
198 | /// Load Conv ONNX operator. |
199 | Error loadConv(const ONNX_NAMESPACE::NodeProto &op, |
200 | ArgumentDictionaryTy &dict); |
201 | |
202 | /// Load ChannelwiseQuantizedConvolution Glow operator. |
203 | Error loadChannelwiseQuantizedConvolution(const ONNX_NAMESPACE::NodeProto &op, |
204 | ArgumentDictionaryTy &dict); |
205 | |
206 | /// Load Glow conv operator with quantized inputs. Since this isn't a normal |
207 | /// part of the ops supported by onnx, the assumption is that this op was |
208 | /// produced by Glow's own ONNXModelWriter and thus has NHWC layout for inputs. |
209 | Error loadTensorwiseQuantizedConvolution(const ONNX_NAMESPACE::NodeProto &op, |
210 | ArgumentDictionaryTy &dict); |
211 | /// Load ConvTranspose ONNX operator. |
212 | Error loadConvTranspose(const ONNX_NAMESPACE::NodeProto &op, |
213 | ArgumentDictionaryTy &dict); |
214 | |
215 | /// Load Conv1D operator. |
216 | /// As per conv operation definition at |
217 | /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv , |
218 | /// input is in format (NxCxD1x...xDn). If the input tensor dimension size is |
219 | /// 3, we have kernel size of only 1 dimension and we call such a conv |
220 | /// operation as conv1d. |
221 | /// Conv1d is implemented using Conv2d as follows: |
222 | /// a) Expand the input and kernel dimension to 4 using expand operator |
223 | /// b) Do the necessary tensor format conversion as required for Conv2d |
224 | /// c) Then use Conv2d for execution |
225 | /// d) To reduce the output tensor dimension from 4 to 3, Squeeze is used |
226 | Error loadConv1D(const ONNX_NAMESPACE::NodeProto &op, |
227 | ArgumentDictionaryTy &dict); |
228 | |
229 | /// Load Conv operator with 2D input |
230 | Error loadConv2D(const ONNX_NAMESPACE::NodeProto &op, |
231 | ArgumentDictionaryTy &dict); |
232 | |
233 | /// Load Conv operator with 3D input |
234 | Error loadConv3D(const ONNX_NAMESPACE::NodeProto &op, |
235 | ArgumentDictionaryTy &dict); |
236 | |
237 | /// Load MaxPool or AveragePool ONNX operator. \p typeName is the name of the |
238 | /// ONNX operator being loaded, either MaxPool or AveragePool. |
239 | Error loadPool(const ONNX_NAMESPACE::NodeProto &op, |
240 | ArgumentDictionaryTy &dict, llvm::StringRef typeName); |
241 | |
242 | /// Load Glow pooling operator with quantized inputs. Since this isn't a |
243 | /// normal part of the ops supported by onnx, the assumption is that this op |
244 | /// was produced by Glow's own ONNXModelWriter and thus has NHWC layout for |
245 | /// inputs. |
246 | Error loadTensorwiseQuantizedPool(const ONNX_NAMESPACE::NodeProto &op, |
247 | ArgumentDictionaryTy &dict, |
248 | llvm::StringRef typeName); |
249 | |
250 | /// Load GlobalAveragePool ONNX operator. |
251 | Error loadGlobalAveragePool(const ONNX_NAMESPACE::NodeProto &op, |
252 | ArgumentDictionaryTy &dict); |
253 | |
254 | /// Load Squeeze ONNX operator. |
255 | Error loadSqueeze(const ONNX_NAMESPACE::NodeProto &op, |
256 | ArgumentDictionaryTy &dict); |
257 | |
258 | /// Load Unsqueeze ONNX operator. |
259 | Error loadUnsqueeze(const ONNX_NAMESPACE::NodeProto &op, |
260 | ArgumentDictionaryTy &dict); |
261 | |
262 | /// Load ArgMax and ArgMin ONNX operators. |
263 | Error loadArgMinMax(const ONNX_NAMESPACE::NodeProto &op, |
264 | ArgumentDictionaryTy &dict, bool isMin); |
265 | |
266 | /// Load Upsample ONNX operator. |
267 | Error loadUpsample(const ONNX_NAMESPACE::NodeProto &op, |
268 | ArgumentDictionaryTy &dict); |
269 | |
270 | /// Load Resize ONNX Operator. |
271 | Error loadResize(const ONNX_NAMESPACE::NodeProto &op, |
272 | const ArgumentDictionaryTy &dict); |
273 | |
274 | /// Load BatchNormalization ONNX operator. |
275 | Error loadBatchNormalization(const ONNX_NAMESPACE::NodeProto &op, |
276 | ArgumentDictionaryTy &dict); |
277 | |
278 | /// Load InstanceNormalization ONNX operator. |
279 | Error loadInstanceNormalization(const ONNX_NAMESPACE::NodeProto &op, |
280 | ArgumentDictionaryTy &dict); |
281 | |
282 | /// Load Concat ONNX operator. |
283 | Error loadConcat(const ONNX_NAMESPACE::NodeProto &op, |
284 | ArgumentDictionaryTy &dict); |
285 | |
286 | /// Load FCTransposed ONNX operator. |
287 | Error loadFCTransposed(const ONNX_NAMESPACE::NodeProto &op, |
288 | ArgumentDictionaryTy &dict); |
289 | |
290 | /// Load Gemm ONNX operator. |
291 | Error loadGemm(const ONNX_NAMESPACE::NodeProto &op, |
292 | ArgumentDictionaryTy &dict); |
293 | |
294 | /// Load MatMul ONNX operator. |
295 | Error loadMatMul(const ONNX_NAMESPACE::NodeProto &op, |
296 | ArgumentDictionaryTy &dict); |
297 | |
298 | /// Load Pad ONNX operator. |
299 | Error loadPad(const ONNX_NAMESPACE::NodeProto &op, |
300 | ArgumentDictionaryTy &dict); |
301 | |
302 | /// Load Cast ONNX operator. |
303 | Error loadCast(const ONNX_NAMESPACE::NodeProto &op, |
304 | ArgumentDictionaryTy &dict); |
305 | |
306 | /// Load HardSigmoid ONNX operator. |
307 | Error loadHardSigmoid(const ONNX_NAMESPACE::NodeProto &op, |
308 | ArgumentDictionaryTy &dict); |
309 | |
310 | /// Load LeakyRelu ONNX operator. |
311 | Error loadLeakyRelu(const ONNX_NAMESPACE::NodeProto &op, |
312 | ArgumentDictionaryTy &dict); |
313 | |
314 | /// Load SpaceToDepth ONNX operator. |
315 | Error loadSpaceToDepth(const ONNX_NAMESPACE::NodeProto &op, |
316 | ArgumentDictionaryTy &dict); |
317 | |
318 | /// Load ReduceL2 ONNX operator |
319 | Error loadReduceL2(const ONNX_NAMESPACE::NodeProto &op, |
320 | const ArgumentDictionaryTy &dict); |
321 | |
322 | /// Load DepthToSpace ONNX operator. |
323 | Error loadDepthToSpace(const ONNX_NAMESPACE::NodeProto &op, |
324 | const ArgumentDictionaryTy &dict); |
325 | |
326 | /// Load ConstantOfShape ONNX operator. |
327 | Error loadConstantOfShape(const ONNX_NAMESPACE::NodeProto &op, |
328 | ArgumentDictionaryTy &dict, bool isSplat); |
329 | |
330 | /// Load Tile ONNX operator. |
331 | Error loadTile(const ONNX_NAMESPACE::NodeProto &op, |
332 | ArgumentDictionaryTy &dict); |
333 | |
334 | /// Load Expand ONNX operator. |
335 | Error loadExpand(const ONNX_NAMESPACE::NodeProto &op, |
336 | const ArgumentDictionaryTy &dict); |
337 | |
338 | /// Load Where ONNX operator. |
339 | Error loadWhere(const ONNX_NAMESPACE::NodeProto &op, |
340 | ArgumentDictionaryTy &dict); |
341 | |
342 | /// Load RNN ONNX operator. |
343 | Error loadRNN(const ONNX_NAMESPACE::NodeProto &op, |
344 | ArgumentDictionaryTy &dict); |
345 | |
346 | /// Load GRU ONNX operator. |
347 | Error loadGRU(const ONNX_NAMESPACE::NodeProto &op, |
348 | ArgumentDictionaryTy &dict); |
349 | |
350 | /// Load LSTM ONNX operator. |
351 | Error loadLSTM(const ONNX_NAMESPACE::NodeProto &op, |
352 | ArgumentDictionaryTy &dict); |
353 | |
354 | /// Load Clip ONNX operator. |
355 | Error loadClip(const ONNX_NAMESPACE::NodeProto &op, |
356 | const ArgumentDictionaryTy &dict); |
357 | |
358 | /// Load Glow specific operators, not defined in ONNX format. |
359 | /// Load Glow CmpEQ operator. |
360 | Error loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op, |
361 | ArgumentDictionaryTy &dict); |
362 | |
363 | /// Load Glow CmpLTE operator. |
364 | Error loadCmpLTE(const ONNX_NAMESPACE::NodeProto &op, |
365 | ArgumentDictionaryTy &dict); |
366 | |
367 | /// Load Mean ONNX operator |
368 | Error loadMean(const ONNX_NAMESPACE::NodeProto &op, |
369 | ArgumentDictionaryTy &dict); |
370 | |
371 | /// Load Glow Select operator. |
372 | Error loadSelect(const ONNX_NAMESPACE::NodeProto &op, |
373 | ArgumentDictionaryTy &dict); |
374 | |
375 | /// Load Glow Quantize operator. |
376 | Error loadQuantize(const ONNX_NAMESPACE::NodeProto &op, |
377 | ArgumentDictionaryTy &dict); |
378 | |
379 | /// Load Onnx QuantizeLinear operator. |
380 | Error loadQuantizeLinear(const ONNX_NAMESPACE::NodeProto &op, |
381 | ArgumentDictionaryTy &dict); |
382 | |
383 | /// Load Glow ConvertTo operator. |
384 | Error loadConvertTo(const ONNX_NAMESPACE::NodeProto &op, |
385 | ArgumentDictionaryTy &dict); |
386 | |
387 | /// Load Glow Dequantize operator. |
388 | Error loadDequantize(const ONNX_NAMESPACE::NodeProto &op, |
389 | ArgumentDictionaryTy &dict); |
390 | |
391 | /// Load Glow Regression operator. |
392 | Error loadRegression(const ONNX_NAMESPACE::NodeProto &op, |
393 | ArgumentDictionaryTy &dict); |
394 | |
395 | /// Load Glow BatchedAdd operator. |
396 | Error loadBatchedAdd(const ONNX_NAMESPACE::NodeProto &op, |
397 | ArgumentDictionaryTy &dict); |
398 | |
399 | /// Load Glow CumSum operator. |
400 | Error loadCumSum(const ONNX_NAMESPACE::NodeProto &op, |
401 | ArgumentDictionaryTy &dict); |
402 | |
403 | /// Load Glow ScatterAssign operator. |
404 | Error loadScatterAssign(const ONNX_NAMESPACE::NodeProto &op, |
405 | ArgumentDictionaryTy &dict); |
406 | |
407 | /// Load Glow IntLookupTable operator. |
408 | Error loadIntLookupTable(const ONNX_NAMESPACE::NodeProto &op, |
409 | ArgumentDictionaryTy &dict); |
410 | |
411 | /// Load Glow LengthsRangeFill operator. |
412 | Error loadLengthsRangeFill(const ONNX_NAMESPACE::NodeProto &op, |
413 | ArgumentDictionaryTy &dict); |
414 | |
415 | /// Load Glow RescaleQuantized operator. |
416 | Error loadRescaleQuantized(const ONNX_NAMESPACE::NodeProto &op, |
417 | ArgumentDictionaryTy &dict); |
418 | |
419 | /// Load Glow RowwiseQuantizedSparseLengthsWeightedSum operator. |
420 | Error loadRowwiseQuantizedSparseLengthsWeightedSum( |
421 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict); |
422 | |
423 | /// Load Glow FusedRowwiseQuantizedSparseLengthsWeightedSum operator. |
424 | Error loadFusedRowwiseQuantizedSparseLengthsWeightedSum( |
425 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict); |
426 | |
427 | /// Load Glow FusedRowwiseQuantizedSparseLengthsSum operator. |
428 | Error |
429 | loadFusedRowwiseQuantizedSparseLengthsSum(const ONNX_NAMESPACE::NodeProto &op, |
430 | ArgumentDictionaryTy &dict); |
431 | |
432 | /// Load Glow RowwiseQuantizedFullyConnected operator. |
433 | Error loadRowwiseQuantizedFullyConnected(const ONNX_NAMESPACE::NodeProto &op, |
434 | ArgumentDictionaryTy &dict); |
435 | |
436 | /// Load Glow FullyConnected operator. |
437 | Error loadFullyConnected(const ONNX_NAMESPACE::NodeProto &op, |
438 | ArgumentDictionaryTy &dict); |
439 | |
440 | /// Load ONNX Identity operator. |
441 | Error loadIdentity(const ONNX_NAMESPACE::NodeProto &op, |
442 | ArgumentDictionaryTy &dict); |
443 | |
444 | /// Load Glow Splat operator. |
445 | Error loadSplat(const ONNX_NAMESPACE::NodeProto &op, |
446 | ArgumentDictionaryTy &dict); |
447 | |
448 | /// Load NonMaxSuppression ONNX and TF NMSv4 operator. |
449 | /// The \p isV4 indicates whether this is ONNX or custom NMSv4 operator. |
450 | Error loadNonMaxSuppression(const ONNX_NAMESPACE::NodeProto &op, |
451 | ArgumentDictionaryTy &dict, bool isV4); |
452 | |
453 | /// Load Glow InsertTensor operator. |
454 | Error loadInsertTensor(const ONNX_NAMESPACE::NodeProto &op, |
455 | ArgumentDictionaryTy &dict); |
456 | |
457 | /// Load If ONNX operator. |
458 | Error loadIf(const ONNX_NAMESPACE::NodeProto &op, |
459 | const ArgumentDictionaryTy &dict); |
460 | |
461 | /// Load AdaptiveAvgPool Glow operator. |
462 | /// NOTE: since this operator is not a standard onnx op, assume this is from |
463 | /// OnnxModelWriter and is therefore in NHWC format. |
464 | Error loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op, |
465 | ArgumentDictionaryTy &dict); |
466 | |
467 | /// Load Flip Glow operator. |
468 | Error loadFlip(const ONNX_NAMESPACE::NodeProto &op, |
469 | ArgumentDictionaryTy &dict); |
470 | |
471 | /// Load AudioSpectrogram Glow operator. |
472 | Error loadAudioSpectrogram(const ONNX_NAMESPACE::NodeProto &op, |
473 | ArgumentDictionaryTy &dict); |
474 | |
475 | /// Load Loop operator. |
476 | Error loadLoop(const ONNX_NAMESPACE::NodeProto &op, |
477 | const ArgumentDictionaryTy &dict); |
478 | |
479 | /// Load MFCC Glow operator. |
480 | Error loadMFCC(const ONNX_NAMESPACE::NodeProto &op, |
481 | ArgumentDictionaryTy &dict); |
482 | |
483 | /// Load ROIAlign ONNX operator. |
484 | Error loadROIAlign(const ONNX_NAMESPACE::NodeProto &op, |
485 | ArgumentDictionaryTy &dict); |
486 | |
487 | protected: |
488 | /// Loads operators from \p net. If \p loadingConstFoldSubgraph then the |
489 | /// current Function \ref G_ is assumed to be the one to load into. |
490 | /// \returns Error if network cannot be loaded. |
491 | Error loadNetwork(ONNX_NAMESPACE::GraphProto &net, |
492 | bool loadingConstFoldSubgraph); |
493 | |
494 | /// Set the output nodes of the network \p net. Initializes the map from the |
495 | /// names of the outputs to the save nodes that save each output. |
496 | /// \returns Error if network cannot be loaded. |
497 | Error setOutputNodes(ONNX_NAMESPACE::GraphProto &net); |
498 | |
499 | /// Set ir version and op version. |
500 | Error setVersion(ONNX_NAMESPACE::ModelProto MP); |
501 | |
502 | /// \returns Expected<ModelProto> if a ModelProto can be loaded from the |
503 | /// stream \p iStream. |
504 | static Expected<ONNX_NAMESPACE::ModelProto> |
505 | loadProto(google::protobuf::io::ZeroCopyInputStream &iStream); |
506 | |
507 | /// Load the network initializers from the GraphProto. |
508 | Error loadInitializers(ONNX_NAMESPACE::GraphProto &net); |
509 | |
510 | /// Given some initializer \p in, check if it has some constant folding node |
511 | /// associated with it in \p net. If so, deserializes the Function if not |
512 | /// already done, performs the constant folding, and \returns the Constant |
513 | /// created as a result to be used for this initializer. |
514 | Expected<Constant *> |
515 | replaySerializedConstFold(const ONNX_NAMESPACE::TensorProto &in, |
516 | ONNX_NAMESPACE::GraphProto &net); |
517 | |
518 | /// Given some \p outputName that maps to a NodeValue that we want to constant |
519 | /// fold, run it and assign the resulting Constant \p initializerName. |
520 | Expected<Constant *> runDeserializedConstFold(llvm::StringRef initializerName, |
521 | llvm::StringRef outputName); |
522 | |
523 | /// Load the inputs from the GraphProto. If \p loadInputsAsPlaceholdersForOnnx |
524 | /// is true then this will load each graph input as a placeholder otherwise it |
525 | /// will create an empty tensor for each input. |
526 | Error loadInputs(ONNX_NAMESPACE::GraphProto &net, |
527 | bool loadInputsAsPlaceholdersForOnnx); |
528 | |
529 | /// \returns whether there's an issue with pre-existing \p S with name \p |
530 | /// name, \p ty, \p layout, and \p trainable (for Placeholders). |
531 | Error verifyPreexistingStorage(const Storage *S, const std::string &name, |
532 | const Type &ty, const std::string &layout, |
533 | const bool trainable = false); |
534 | |
535 | /// \returns Expected<ModelProto> if a ModelProto can be constructed from the |
536 | /// contents of the file \p filename and Error otherwise. |
537 | /// Loads ModelProto from the file containing serialized protobuf. |
538 | /// If \p zipMode then zip format will be expected/loaded. |
539 | static Expected<ONNX_NAMESPACE::ModelProto> |
540 | loadProto(const std::string &filename, bool zipMode, |
541 | const std::string *inputStringPtr); |
542 | |
543 | /// \returns Expected<ModelProto> if a ModelProto can be constructed from the |
544 | /// in-memory serialized protobuf. |
545 | /// Loads ModelProto from the in-memory serialized protobuf \p |
546 | /// onnxModel with the model size \p onnxModelSize. |
547 | static Expected<ONNX_NAMESPACE::ModelProto> loadProto(const void *onnxModel, |
548 | size_t onnxModelSize); |
549 | |
550 | /// Checks that the inputs tensors are compatible with the inputs declared in |
551 | /// the ONNX model. The input types in \p types match the list of names |
552 | /// \p tensorNames. |
553 | Error checkInputs(ONNX_NAMESPACE::GraphProto &net, |
554 | llvm::ArrayRef<const char *> tensorNames, |
555 | llvm::ArrayRef<TypeRef> types); |
556 | |
557 | /// Go through the ValueInfoProto of the inputs of the \p net and collect |
558 | /// static placeholders if it's marked in the ValueInfoProto. |
559 | Error collectStaticInputs(ONNX_NAMESPACE::GraphProto &net); |
560 | |
561 | /// Looks through all ops in \p net for any dummy static PH nodes carrying the |
562 | /// type that was used for loading deferred weights initially. If found then |
563 | /// they're added to \ref staticPlaceholderTypes_. If \ref |
564 | /// staticPlaceholderTypes_ is a nullptr then this method is a no-op. |
565 | Error setupOrigStaticTypeMap(ONNX_NAMESPACE::GraphProto &net); |
566 | |
567 | /// Associate all inputs of \p net with nodes in \p NVs. Number of inputs of |
568 | /// \p net should match the number of elements of \p NVs. |
569 | /// \returns error code in case of error. |
570 | Error assignGraphInputs(const ONNX_NAMESPACE::GraphProto &net, |
571 | llvm::ArrayRef<NodeValue> NVs); |
572 | |
573 | /// Creates an ONNX model loader to build \p F. |
574 | /// Loads the ONNXIFI \p model from memory of \p modelSize size, with |
575 | /// \p weightsCount weights described by \p weightDescriptors. Converts inputs |
576 | /// into placeholders if requested via \p loadInputsAsPlaceholdersForOnnx. |
577 | /// Reports success/failure through optional parameter \p errPtr. This |
578 | /// constructor always overrides the default constant folding in loader flag |
579 | /// with \p constFoldInLoader. |
580 | ONNXModelLoader(const void *model, uint32_t modelSize, uint32_t weightsCount, |
581 | const onnxTensorDescriptorV1 *weightDescriptors, Function &F, |
582 | bool loadInputsAsPlaceholdersForOnnx, Error *errPtr = nullptr, |
583 | bool constFoldInLoader = true, |
584 | BackendSpecificNodeInfo *perNodeOpts = nullptr); |
585 | |
586 | /// Creates an ONNX model loader to build \p mod. |
587 | /// Loads the ONNXIFI \p model from memory of \p modelSize size, with |
588 | /// \p weightsCount weights described by \p weightDescriptors. Converts inputs |
589 | /// into placeholders if requested via \p loadInputsAsPlaceholdersForOnnx. |
590 | /// Reports success/failure through optional parameter \p errPtr. Always |
591 | /// overrides the default constant folding in loader flag with \p constFoldInLoader. |
592 | /// Supports loading a DAG which was serialized, loading in DAGNode meta info |
593 | /// into \p PPC which can be later used to recreate the DAG. \p funName is |
594 | /// used to setup the DAG root node's name, or if the input model is not |
595 | /// partitioned then is used as the name of the single Function loaded. Loads |
596 | /// backend-specific node info annotations into \p perNodeOpts. |
597 | /// \p staticPlaceholderTypes will be filled with types to use for static |
598 | /// Placeholders if the proto being parsed contains such information; |
599 | /// otherwise it's left unchanged. If \p replaceDummyTQPs then any dummy TQPs |
600 | /// (represented by scale=0.f) will be replaced by updated TQPs found in |
601 | /// metadata_props, allowing for changing of TQPs of serialized models. |
602 | /// NOTE(review): \p clipQuantRangeToFP16 is not documented here — confirm its effect in the .cpp. |
603 | ONNXModelLoader(const void *model, uint32_t modelSize, uint32_t weightsCount, |
604 | const onnxTensorDescriptorV1 *weightDescriptors, Module &mod, |
605 | llvm::StringRef funName, runtime::PrePartitionedConfig *PPC, |
606 | bool loadInputsAsPlaceholdersForOnnx, Error *errPtr = nullptr, |
607 | bool constFoldInLoader = true, |
608 | BackendSpecificNodeInfo *perNodeOpts = nullptr, |
609 | std::map<std::string, Type> *staticPlaceholderTypes = nullptr, |
610 | bool replaceDummyTQPs = false, |
611 | bool clipQuantRangeToFP16 = false); |
612 | |
613 | friend class ONNXIFIModelLoader; |
614 | |
615 | /// \returns success if the folding of operator \p op in the loader |
616 | /// \p loader is successful. The folding utility uses temporary |
617 | /// loader \p tmpLoader, and associated temporary function \p F. |
618 | template <class LoaderType, class OpType> |
619 | friend Error constantFoldInLoader(Function *F, LoaderType &tmpLoader, |
620 | LoaderType *loader, const OpType &op); |
621 | |
622 | /// \returns a Type with the proper shape and element type given \p in. |
623 | Expected<Type> getTensorType(const ONNX_NAMESPACE::ValueInfoProto &in); |
624 | |
625 | /// \returns a Type with the proper shape and element type given \p in. |
626 | Expected<Type> getTensorType(const ONNX_NAMESPACE::TensorProto &in); |
627 | |
628 | /// Load a model \p modelDef given \p tensorNames, \p types, \p B, and |
629 | /// \p loadInputsAsPlaceholdersForOnnx. |
630 | Error loadModel(ONNX_NAMESPACE::ModelProto &modelDef, |
631 | llvm::ArrayRef<const char *> tensorNames, |
632 | llvm::ArrayRef<TypeRef> types, const Backend *B, |
633 | bool loadInputsAsPlaceholdersForOnnx); |
634 | |
635 | /// Setup partitions by creating Functions and loading metadata into \p PPC |
636 | /// from the metadata props found in \p modelDef given \p rootName and |
637 | /// \p numPartitions. |
638 | Error setupPartitions(ONNX_NAMESPACE::ModelProto &modelDef, |
639 | runtime::PrePartitionedConfig &PPC, |
640 | llvm::StringRef rootName, int numPartitions); |
641 | |
642 | /// Deletes the Functions in \ref constFoldFuns_ from \ref mod_. |
643 | void deleteConstFoldFunctions(); |
644 | |
645 | /// Sets up positional IO into \ref positionalInputNames_ and |
646 | /// \ref positionalOutputNames_ from \p graph. |
647 | void setupPositionalIO(const ONNX_NAMESPACE::GraphProto &graph); |
648 | |
649 | /// Sets up \ref updatedTQPs_ based on metadata props found in \p modelDef as |
650 | /// well as \p weightsCount weights in \p weightDescriptors. |
651 | Error setupUpdatedTQPMap(ONNX_NAMESPACE::ModelProto &modelDef, |
652 | uint32_t weightsCount, |
653 | const onnxTensorDescriptorV1 *weightDescriptors); |
654 | |
655 | public: |
656 | /// \returns ONNX model ir_version. |
657 | size_t getIrVersion() const { return irVersion_; } |
658 | |
659 | /// \returns ONNX model op_version. |
660 | size_t getOpSetVersion() const { return opsetVersion_; } |
661 | |
662 | /// \returns if the loader is loading a proto using custom Glow ops. |
663 | bool usingGlowCustomOps() const { return useGlowCustomOps_; } |
664 | |
665 | /// Creates an ONNX model loader to build \p F. |
666 | /// If \p errPtr is not null then if an error occurs it will get assigned |
667 | /// there otherwise if an error occurs it will abort. |
668 | ONNXModelLoader(Function &F, Error *errPtr = nullptr); |
669 | |
670 | /// Update \p inTensorNames and \p inTypes from the inputs of the ONNX model |
671 | /// in \p filename. |
672 | static Error getInputsNamesAndTypes(std::vector<std::string> &inTensorNames, |
673 | std::vector<Type> &inTypes, |
674 | const std::string &filename); |
675 | |
676 | /// Loads the ONNX model that's represented by a model description file, |
677 | /// serialized in \p modelDescFilename and populates the network into \p F. |
678 | /// The types in \p types match the list of names \p tensorNames and used as |
679 | /// inputs to the network. |
680 | /// If \p tensorNames and \p types are empty loader fills inputs automatically. |
681 | /// If \p errPtr is not null then if an error occurs it will get assigned |
682 | /// there otherwise if an error occurs it will abort. |
683 | /// If \p disableConstFoldInLoader then constant folding will be disabled |
684 | /// during loading. \p B will be used during function verification after |
685 | /// loading. If \p loadIntoExistingModule then all Functions and Storage are |
686 | /// expected to already exist, so they will be searched for according to the |
687 | /// proto being loaded instead of created as usual. |
688 | ONNXModelLoader(const std::string &modelDescFilename, |
689 | llvm::ArrayRef<const char *> tensorNames, |
690 | llvm::ArrayRef<TypeRef> types, Function &F, |
691 | Error *errPtr = nullptr, bool zipMode = false, |
692 | BackendSpecificNodeInfo *perNodeOpts = nullptr, |
693 | bool disableConstFoldInLoader = false, |
694 | bool loadIntoExistingModule = false, |
695 | const Backend *B = nullptr, |
696 | const std::string *inputStringPtr = nullptr); |
697 | |
698 | /// Loads the ONNX model that's represented by a model description file, |
699 | /// serialized in \p modelDescFilename and populates the network into \p mod. |
700 | /// Supports loading a DAG which was serialized, loading in DAGNode meta info |
701 | /// into \p PPC which can be later used to recreate the DAG. \p funName is |
702 | /// used to setup the DAG root node's name, or if the input model is not |
703 | /// partitioned then is used as the name of the single Function loaded. |
704 | /// The types in \p types match the list of names \p tensorNames and used as |
705 | /// inputs to the network. |
706 | /// If \p tensorNames and \p types are empty loader fills inputs automatically. |
707 | /// If \p errPtr is not null then if an error occurs it will get assigned |
708 | /// there otherwise if an error occurs it will abort. |
709 | /// If \p disableConstFoldInLoader then constant folding will be disabled |
710 | /// during loading. \p B will be used during function verification after |
711 | /// loading. If \p loadIntoExistingModule then all Functions and Storage are |
712 | /// expected to already exist, so they will be searched for according to the |
713 | /// proto being loaded instead of created as usual. |
714 | ONNXModelLoader(const std::string &modelDescFilename, |
715 | llvm::ArrayRef<const char *> tensorNames, |
716 | llvm::ArrayRef<TypeRef> types, Module &mod, |
717 | llvm::StringRef funName, |
718 | runtime::PrePartitionedConfig *PPC = nullptr, |
719 | Error *errPtr = nullptr, bool zipMode = false, |
720 | BackendSpecificNodeInfo *perNodeOpts = nullptr, |
721 | bool loadIntoExistingModule = false, |
722 | bool disableConstFoldInLoader = false, |
723 | const Backend *B = nullptr, |
724 | const std::string *inputStringPtr = nullptr); |
725 | |
726 | private: |
727 | /// Per-node options that may be specified in a proto. |
728 | BackendSpecificNodeInfo *perNodeOpts_{nullptr}; |
729 | /// Map from static PH names to the type it was originally loaded with. Defaults to nullptr so constructors that do not set it leave a well-defined "absent" value. |
730 | std::map<std::string, Type> *staticPlaceholderTypes_{nullptr}; |
731 | }; |
732 | |
733 | } // namespace glow |
734 | |
735 | #endif // GLOW_IMPORTER_ONNXMODELLOADER_H |
736 | |