1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef GLOW_IMPORTER_ONNXMODELLOADER_H
18#define GLOW_IMPORTER_ONNXMODELLOADER_H
19
20#include "glow/Graph/Graph.h"
21#include "glow/Importer/CommonOperatorLoader.h"
22
23#include "onnx/onnx_pb.h"
24
25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/ADT/StringRef.h"
27
28#include "google/protobuf/io/coded_stream.h"
29#include "google/protobuf/io/zero_copy_stream_impl.h"
30#include <fstream>
31#include <string>
32#include <unordered_set>
33
// Forward declarations of the ONNX protobuf message classes that the
// declarations in this header refer to.
namespace ONNX_NAMESPACE {
class AttributeProto;
class GraphProto;
class ModelProto;
class NodeProto;
class TensorProto;
} // namespace ONNX_NAMESPACE
41
42namespace glow {
43
44/// Loads tensor \p T from the input \p in. \p useGlowCustomOps changes the
45/// format for doc_string format for adding meta information.
46Error loadTensor(const ONNX_NAMESPACE::TensorProto &in, Tensor *T,
47 bool useGlowCustomOps = false, const std::string &data = "");
48
49/// Parses as input file name \p fileName which is an ONNX file
50/// and \returns a parsed GraphProto.
51ONNX_NAMESPACE::GraphProto parseOnnxFile(const std::string &fileName);
52
53/// Takes an ONNX file in \p fileName reads it and loads the tensors
54/// in \p bindings. If the tensors loaded from the underlying file
55/// Are smaller than what the placeholder for that tensor expects it gets
56/// Padded with 0 if \p partialTensorPayloads is nullptr other wise
57/// \p partialTensorPayloads holds the data for full tensors.
58/// If \p usingGlowCustomOps then the custom Glow ONNX format will be
59/// expected/used to load from the ONNX file.
60void fillPlaceholders(const std::string &fileName,
61 PlaceholderBindings *bindings,
62 std::vector<Tensor> *partialTensorPayloads = nullptr,
63 bool usingGlowCustomOps = false);
64
65/// Override that takes \p parsedFile as a parsed file instead of file name.
66void fillPlaceholders(const ONNX_NAMESPACE::GraphProto &parsedFile,
67 PlaceholderBindings *bindings,
68 std::vector<Tensor> *partialTensorPayloads = nullptr,
69 bool usingGlowCustomOps = false);
70
/// Defines symbols that are left undefined in a loaded ONNX proto, using the
/// entries of \p lst. See onnxDefineSymbolOpt in ONNXModelLoader.cpp.
void setOnnxDefineSymbol(const std::vector<std::string> &lst);
74
75/// Loads ONNX models.
76class ONNXModelLoader
77 : public CommonOperatorLoader<ONNX_NAMESPACE::NodeProto,
78 ONNX_NAMESPACE::AttributeProto> {
79 /// \returns True if the operator has broadcasting activated.
80 Expected<bool> getBroadcast(ArgumentDictionaryTy &dict) override;
81
82 /// \returns True if the operator with the name \p typeName has support for
83 /// multidirectional broadcasting.
84 bool hasMultidirectionalBroadcast(const llvm::StringRef typeName) override;
85
86 /// Converts a ONNX TensorProto DataType enum to the Glow element type.
87 /// Supports only non quantized and signed types.
88 Expected<ElemKind>
89 convertTensorProtoDataType(ONNX_NAMESPACE::TensorProto_DataType t);
90
91 /// Load the operator \p op into the network. This creates one or more nodes
92 /// in the network. \returns Error if operator \p op cannot be loaded.
93 Error loadOperator(const ONNX_NAMESPACE::NodeProto &op);
94
95 /// \returns a TypeRef found in \p dict which is loaded and uniqued into the
96 /// Module. The TypeRef is represented in the ONNX proto by concatenating the
97 /// relevant members of a type, (ElemKind, Shape, and Scale and Offset if
98 /// ElemKind is quantized) with \p resNo.
99 Expected<TypeRef> loadTypeFromAttributes(unsigned resNo,
100 ArgumentDictionaryTy &dict);
101
102 /// If this is a custom Glow op that was exported via NodeGen automatic export
103 /// logic, try to load the op. \returns Expected<true> if the op is
104 /// successfully loaded. \returns Expected<false> if op type is not supported.
105 /// \returns an Error if an error occurred while trying to load, or otherwise
106 /// the single Node that was created.
107 Expected<Node *> tryLoadGlowCustomOp(llvm::StringRef typeName,
108 const ONNX_NAMESPACE::NodeProto &op,
109 ArgumentDictionaryTy &dict);
110
111 /// \returns True if the operator\ op is successfully folded.
112 Expected<bool> foldOperator(const ONNX_NAMESPACE::NodeProto &op);
113
114 /// Helper function to print better log information for operator failure cases
115 const std::string opErrMsg(const ONNX_NAMESPACE::NodeProto &op,
116 const std::string &errMsg);
117
118 /// ONNX model ir_version;
119 size_t irVersion_;
120
121 /// ONNX model op_version;
122 size_t opsetVersion_;
123
124 /// Whether we're loading an ONNX file exported using Glow custom ops.
125 bool useGlowCustomOps_{false};
126
127 /// A set of inputs which will be static placeholders.
128 std::unordered_set<std::string> staticInputs_;
129
130 /// A set of Functions used for ConstantFolding to be deleted after loading.
131 std::unordered_set<Function *> constFoldFuns_;
132
133 /// Load ONNX NonZero Operator.
134 /// Glow's requirement for static shapes results in required Constant
135 /// input. Thus, the operator will be folded in the Importer.
136 Error loadNonZero(const ONNX_NAMESPACE::NodeProto &op,
137 const ArgumentDictionaryTy &dict);
138
139 /// Load Trigonometric Ops
140 Error loadAsin(const ONNX_NAMESPACE::NodeProto &op,
141 const ArgumentDictionaryTy &dict);
142
143 Error loadAcos(const ONNX_NAMESPACE::NodeProto &op,
144 const ArgumentDictionaryTy &dict);
145
146 /// Load Erf ONNX operator
147 Error loadErf(const ONNX_NAMESPACE::NodeProto &op,
148 const ArgumentDictionaryTy &dict);
149
150 Error loadAtan(const ONNX_NAMESPACE::NodeProto &op,
151 const ArgumentDictionaryTy &dict);
152
153 /// Load Constant ONNX operator.
154 Error loadConstant(const ONNX_NAMESPACE::NodeProto &op,
155 ArgumentDictionaryTy &dict);
156
157 /// Helper function for ONNX range operator
158 template <typename T>
159 Error getRange(const ONNX_NAMESPACE::NodeProto &op, Constant *constT);
160
161 /// Load Range ONNX operator
162 Error loadRange(const ONNX_NAMESPACE::NodeProto &op,
163 ArgumentDictionaryTy &dict);
164
165 /// Load PRelu ONNX operator.
166 Error loadPRelu(const ONNX_NAMESPACE::NodeProto &op,
167 ArgumentDictionaryTy &dict);
168
169 /// Load Slice ONNX operator.
170 Error loadSlice(const ONNX_NAMESPACE::NodeProto &op,
171 ArgumentDictionaryTy &dict);
172
173 /// Load Trignometric ONNX operators.
174 Error loadTrigonometricOps(const std::string &typeName,
175 const ONNX_NAMESPACE::NodeProto &op,
176 ArgumentDictionaryTy &dict);
177
178 /// Load Sign ONNX operator
179 Error loadSign(const ONNX_NAMESPACE::NodeProto &op,
180 const ArgumentDictionaryTy &dict);
181
182 /// Load Softmax ONNX operator
183 Error loadSoftmax(const ONNX_NAMESPACE::NodeProto &op,
184 const ArgumentDictionaryTy &dict);
185
186 /// Load LogSoftmax ONNX operator
187 Error loadLogSoftmax(const ONNX_NAMESPACE::NodeProto &op,
188 const ArgumentDictionaryTy &dict);
189
190 /// Load ScatterData ONNX operator
191 Error loadScatterData(const ONNX_NAMESPACE::NodeProto &op,
192 const ArgumentDictionaryTy &dict);
193
194 /// Load TopK ONNX operator
195 Error loadTopK(const ONNX_NAMESPACE::NodeProto &op,
196 ArgumentDictionaryTy &dict);
197
198 /// Load Conv ONNX operator.
199 Error loadConv(const ONNX_NAMESPACE::NodeProto &op,
200 ArgumentDictionaryTy &dict);
201
202 /// Load ChannelwiseQuantizedConvolution Glow operator.
203 Error loadChannelwiseQuantizedConvolution(const ONNX_NAMESPACE::NodeProto &op,
204 ArgumentDictionaryTy &dict);
205
206 /// Load Glow conv operator with quantized inputs. Since this isn't a normal
207 /// part of the ops supported by onnx, the assumption is that this op was
208 /// produced by Glow's on ONNXModelWriter and thus has NHWC layout for inputs.
209 Error loadTensorwiseQuantizedConvolution(const ONNX_NAMESPACE::NodeProto &op,
210 ArgumentDictionaryTy &dict);
211 /// Load ConvTranspose ONNX operator.
212 Error loadConvTranspose(const ONNX_NAMESPACE::NodeProto &op,
213 ArgumentDictionaryTy &dict);
214
215 /// Load Conv1D operator.
216 /// As per conv operation definition at
217 /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv ,
218 /// input is in format (NxCxD1x...xDn). If the input tensor dimension size is
219 /// 3 , we have kernel size of only 1 dimension and we call such a conv
220 /// operation as conv1d.
221 /// Conv1d is implemented using Conv2d as follows:
222 /// a) Expand the input and kernel dimension to 4 using expand operator
223 /// b) Do the necessary tensor format conversion as required for Conv2d
224 /// c) Then use Conv2d for execution
225 /// d) To reduce the output tensor dimension from 4 to 3, Squeeze is used
226 Error loadConv1D(const ONNX_NAMESPACE::NodeProto &op,
227 ArgumentDictionaryTy &dict);
228
229 /// Load Conv operator with 2D input
230 Error loadConv2D(const ONNX_NAMESPACE::NodeProto &op,
231 ArgumentDictionaryTy &dict);
232
233 /// Load Conv operator with 3D input
234 Error loadConv3D(const ONNX_NAMESPACE::NodeProto &op,
235 ArgumentDictionaryTy &dict);
236
237 /// Load MaxPool or AveragePool ONNX operator. \p typeName is the name of the
238 /// ONNX operator being loaded, either MaxPool or AveragePool.
239 Error loadPool(const ONNX_NAMESPACE::NodeProto &op,
240 ArgumentDictionaryTy &dict, llvm::StringRef typeName);
241
242 /// Load Glow pooling operator with quantized inputs. Since this isn't a
243 /// normal part of the ops supported by onnx, the assumption is that this op
244 /// was produced by Glow's on ONNXModelWriter and thus has NHWC layout for
245 /// inputs.
246 Error loadTensorwiseQuantizedPool(const ONNX_NAMESPACE::NodeProto &op,
247 ArgumentDictionaryTy &dict,
248 llvm::StringRef typeName);
249
250 /// Load GlobalAveragePool ONNX operator.
251 Error loadGlobalAveragePool(const ONNX_NAMESPACE::NodeProto &op,
252 ArgumentDictionaryTy &dict);
253
254 /// Load Squeeze ONNX operator.
255 Error loadSqueeze(const ONNX_NAMESPACE::NodeProto &op,
256 ArgumentDictionaryTy &dict);
257
258 /// Load Unsqueeze ONNX operator.
259 Error loadUnsqueeze(const ONNX_NAMESPACE::NodeProto &op,
260 ArgumentDictionaryTy &dict);
261
262 /// Load ArgMax and ArgMin ONNX operators.
263 Error loadArgMinMax(const ONNX_NAMESPACE::NodeProto &op,
264 ArgumentDictionaryTy &dict, bool isMin);
265
266 /// Load Upsample ONNX operator.
267 Error loadUpsample(const ONNX_NAMESPACE::NodeProto &op,
268 ArgumentDictionaryTy &dict);
269
270 /// Load Resize ONNX Operator.
271 Error loadResize(const ONNX_NAMESPACE::NodeProto &op,
272 const ArgumentDictionaryTy &dict);
273
274 /// Load BatchNormalization ONNX operator.
275 Error loadBatchNormalization(const ONNX_NAMESPACE::NodeProto &op,
276 ArgumentDictionaryTy &dict);
277
278 /// Load InstanceNormalization ONNX operator.
279 Error loadInstanceNormalization(const ONNX_NAMESPACE::NodeProto &op,
280 ArgumentDictionaryTy &dict);
281
282 /// Load Concat ONNX operator.
283 Error loadConcat(const ONNX_NAMESPACE::NodeProto &op,
284 ArgumentDictionaryTy &dict);
285
286 /// Load FCTransposed ONNX operator.
287 Error loadFCTransposed(const ONNX_NAMESPACE::NodeProto &op,
288 ArgumentDictionaryTy &dict);
289
290 /// Load Gemm ONNX operator.
291 Error loadGemm(const ONNX_NAMESPACE::NodeProto &op,
292 ArgumentDictionaryTy &dict);
293
294 /// Load MatMul ONNX operator.
295 Error loadMatMul(const ONNX_NAMESPACE::NodeProto &op,
296 ArgumentDictionaryTy &dict);
297
298 /// Load Pad ONNX operator.
299 Error loadPad(const ONNX_NAMESPACE::NodeProto &op,
300 ArgumentDictionaryTy &dict);
301
302 /// Load Cast ONNX operator.
303 Error loadCast(const ONNX_NAMESPACE::NodeProto &op,
304 ArgumentDictionaryTy &dict);
305
306 /// Load HardSigmoid ONNX operator.
307 Error loadHardSigmoid(const ONNX_NAMESPACE::NodeProto &op,
308 ArgumentDictionaryTy &dict);
309
310 /// Load LeakyRelu ONNX operator.
311 Error loadLeakyRelu(const ONNX_NAMESPACE::NodeProto &op,
312 ArgumentDictionaryTy &dict);
313
314 /// Load SpaceToDepth ONNX operator.
315 Error loadSpaceToDepth(const ONNX_NAMESPACE::NodeProto &op,
316 ArgumentDictionaryTy &dict);
317
318 /// Load ReduceL2 ONNX operator
319 Error loadReduceL2(const ONNX_NAMESPACE::NodeProto &op,
320 const ArgumentDictionaryTy &dict);
321
322 /// Load DepthToSpace ONNX operator.
323 Error loadDepthToSpace(const ONNX_NAMESPACE::NodeProto &op,
324 const ArgumentDictionaryTy &dict);
325
326 /// Load ConstantOfShape ONNX operator.
327 Error loadConstantOfShape(const ONNX_NAMESPACE::NodeProto &op,
328 ArgumentDictionaryTy &dict, bool isSplat);
329
330 /// Load Tile ONNX operator.
331 Error loadTile(const ONNX_NAMESPACE::NodeProto &op,
332 ArgumentDictionaryTy &dict);
333
334 /// Load Expand ONNX operator.
335 Error loadExpand(const ONNX_NAMESPACE::NodeProto &op,
336 const ArgumentDictionaryTy &dict);
337
338 /// Load Where ONNX operator.
339 Error loadWhere(const ONNX_NAMESPACE::NodeProto &op,
340 ArgumentDictionaryTy &dict);
341
342 /// Load RNN ONNX operator.
343 Error loadRNN(const ONNX_NAMESPACE::NodeProto &op,
344 ArgumentDictionaryTy &dict);
345
346 /// Load GRU ONNX operator.
347 Error loadGRU(const ONNX_NAMESPACE::NodeProto &op,
348 ArgumentDictionaryTy &dict);
349
350 /// Load LSTM ONNX operator.
351 Error loadLSTM(const ONNX_NAMESPACE::NodeProto &op,
352 ArgumentDictionaryTy &dict);
353
354 /// Load Clip ONNX operator.
355 Error loadClip(const ONNX_NAMESPACE::NodeProto &op,
356 const ArgumentDictionaryTy &dict);
357
358 /// Load Glow specific operators, not defined in ONNX format
359 /// Load Glow CmpEQ operator.
360 Error loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op,
361 ArgumentDictionaryTy &dict);
362
363 /// Load Glow CmpLTE operator.
364 Error loadCmpLTE(const ONNX_NAMESPACE::NodeProto &op,
365 ArgumentDictionaryTy &dict);
366
367 /// Load Mean ONNX operator
368 Error loadMean(const ONNX_NAMESPACE::NodeProto &op,
369 ArgumentDictionaryTy &dict);
370
371 /// Load Glow Select operator.
372 Error loadSelect(const ONNX_NAMESPACE::NodeProto &op,
373 ArgumentDictionaryTy &dict);
374
375 /// Load Glow Quantize operator.
376 Error loadQuantize(const ONNX_NAMESPACE::NodeProto &op,
377 ArgumentDictionaryTy &dict);
378
379 /// Load Onnx QuantizeLinear operator.
380 Error loadQuantizeLinear(const ONNX_NAMESPACE::NodeProto &op,
381 ArgumentDictionaryTy &dict);
382
383 /// Load Glow ConvertTo operator.
384 Error loadConvertTo(const ONNX_NAMESPACE::NodeProto &op,
385 ArgumentDictionaryTy &dict);
386
387 /// Load Glow Dequantize operator.
388 Error loadDequantize(const ONNX_NAMESPACE::NodeProto &op,
389 ArgumentDictionaryTy &dict);
390
391 /// Load Glow Regression operator.
392 Error loadRegression(const ONNX_NAMESPACE::NodeProto &op,
393 ArgumentDictionaryTy &dict);
394
395 /// Load Glow BatchedAdd operator.
396 Error loadBatchedAdd(const ONNX_NAMESPACE::NodeProto &op,
397 ArgumentDictionaryTy &dict);
398
399 /// Load Glow CumSum operator.
400 Error loadCumSum(const ONNX_NAMESPACE::NodeProto &op,
401 ArgumentDictionaryTy &dict);
402
403 /// Load Glow ScatterAssign operator.
404 Error loadScatterAssign(const ONNX_NAMESPACE::NodeProto &op,
405 ArgumentDictionaryTy &dict);
406
407 /// Load Glow IntLookupTable operator.
408 Error loadIntLookupTable(const ONNX_NAMESPACE::NodeProto &op,
409 ArgumentDictionaryTy &dict);
410
411 /// Load Glow LengthsRangeFill operator.
412 Error loadLengthsRangeFill(const ONNX_NAMESPACE::NodeProto &op,
413 ArgumentDictionaryTy &dict);
414
415 /// Load Glow RescaleQuantized operator.
416 Error loadRescaleQuantized(const ONNX_NAMESPACE::NodeProto &op,
417 ArgumentDictionaryTy &dict);
418
419 /// Load Glow RowwiseQuantizedSparseLengthsWeightedSum operator.
420 Error loadRowwiseQuantizedSparseLengthsWeightedSum(
421 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict);
422
423 /// Load Glow FusedRowwiseQuantizedSparseLengthsWeightedSum operator.
424 Error loadFusedRowwiseQuantizedSparseLengthsWeightedSum(
425 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict);
426
427 /// Load Glow FusedRowwiseQuantizedSparseLengthsSum operator.
428 Error
429 loadFusedRowwiseQuantizedSparseLengthsSum(const ONNX_NAMESPACE::NodeProto &op,
430 ArgumentDictionaryTy &dict);
431
432 /// Load Glow RowwiseQuantizedFullyConnected operator.
433 Error loadRowwiseQuantizedFullyConnected(const ONNX_NAMESPACE::NodeProto &op,
434 ArgumentDictionaryTy &dict);
435
436 /// Load Glow FullyConnected operator.
437 Error loadFullyConnected(const ONNX_NAMESPACE::NodeProto &op,
438 ArgumentDictionaryTy &dict);
439
440 /// Load ONNX Identity operator.
441 Error loadIdentity(const ONNX_NAMESPACE::NodeProto &op,
442 ArgumentDictionaryTy &dict);
443
444 /// Load Glow Splat operator.
445 Error loadSplat(const ONNX_NAMESPACE::NodeProto &op,
446 ArgumentDictionaryTy &dict);
447
448 /// Load NonMaxSuppression ONNX and TF NMSv4 operator.
449 /// The \p isV4 indicates whether this is ONNX or custom NMSv4 operator.
450 Error loadNonMaxSuppression(const ONNX_NAMESPACE::NodeProto &op,
451 ArgumentDictionaryTy &dict, bool isV4);
452
453 /// Load Glow InsertTensor operator.
454 Error loadInsertTensor(const ONNX_NAMESPACE::NodeProto &op,
455 ArgumentDictionaryTy &dict);
456
457 /// Load If ONNX operator.
458 Error loadIf(const ONNX_NAMESPACE::NodeProto &op,
459 const ArgumentDictionaryTy &dict);
460
461 /// Load AdaptiveAvgPool Glow operator.
462 /// NOTE: since this operator is not a standard onnx op, assume this is from
463 /// OnnxModelWriter and is therefore in NHWC format.
464 Error loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op,
465 ArgumentDictionaryTy &dict);
466
467 /// Load Flip Glow operator.
468 Error loadFlip(const ONNX_NAMESPACE::NodeProto &op,
469 ArgumentDictionaryTy &dict);
470
471 /// Load AudioSpectrogram Glow operator.
472 Error loadAudioSpectrogram(const ONNX_NAMESPACE::NodeProto &op,
473 ArgumentDictionaryTy &dict);
474
475 /// Load Loop operator.
476 Error loadLoop(const ONNX_NAMESPACE::NodeProto &op,
477 const ArgumentDictionaryTy &dict);
478
479 /// Load MFCC Glow operator.
480 Error loadMFCC(const ONNX_NAMESPACE::NodeProto &op,
481 ArgumentDictionaryTy &dict);
482
483 // Load ROIAlign ONNX operator
484 Error loadROIAlign(const ONNX_NAMESPACE::NodeProto &op,
485 ArgumentDictionaryTy &dict);
486
487protected:
488 /// Loads operators from \p net. If \p loadingConstFoldSubgraph then the
489 /// current Function \ref G_ is assumed to be the one to load into.
490 /// \returns Error if network cannot be loaded.
491 Error loadNetwork(ONNX_NAMESPACE::GraphProto &net,
492 bool loadingConstFoldSubgraph);
493
494 /// Set the output nodes of the network \p net. Initializes the map from the
495 /// names of the outputs to the save nodes that save each output.
496 /// \returns Error if network cannot be loaded.
497 Error setOutputNodes(ONNX_NAMESPACE::GraphProto &net);
498
499 /// Set ir verion and op version.
500 Error setVersion(ONNX_NAMESPACE::ModelProto MP);
501
502 /// \returns Expected<ModelProto> if a ModelProto can be loaded from the
503 /// stream \p iStream.
504 static Expected<ONNX_NAMESPACE::ModelProto>
505 loadProto(google::protobuf::io::ZeroCopyInputStream &iStream);
506
507 /// Load the network initializers from the GraphProto.
508 Error loadInitializers(ONNX_NAMESPACE::GraphProto &net);
509
510 /// Given some initializer \p in, check if it has some constant folding node
511 /// associated with it in \p net. If so, deserializes the Function if not
512 /// already done, performs the constant folding, and \returns the Constant
513 /// created as a result to be used for this initializer.
514 Expected<Constant *>
515 replaySerializedConstFold(const ONNX_NAMESPACE::TensorProto &in,
516 ONNX_NAMESPACE::GraphProto &net);
517
518 /// Given some \p outputName that maps to a NodeValue that we want to constant
519 /// fold, run it and assign the resulting Constant \p initializerName.
520 Expected<Constant *> runDeserializedConstFold(llvm::StringRef initializerName,
521 llvm::StringRef outputName);
522
523 /// Load the inputs from the GraphProto. If \p loadInputsAsPlaceholdersForOnnx
524 /// is true then this will load each graph input as a placeholder otherwise it
525 /// will create an empty tensor for each input.
526 Error loadInputs(ONNX_NAMESPACE::GraphProto &net,
527 bool loadInputsAsPlaceholdersForOnnx);
528
529 /// \returns whether there's an issue with pre-existing \p S with name \p
530 /// name, \p ty, \p layout, and \p trainable (for Placeholders).
531 Error verifyPreexistingStorage(const Storage *S, const std::string &name,
532 const Type &ty, const std::string &layout,
533 const bool trainable = false);
534
535 /// \returns Expected<ModelProto> if a ModelProto can be constructed from the
536 /// contents of the file \p filename and Error otherwise.
537 /// Loads ModelProto from the file containing serialized protobuf.
538 /// If \p zipMode then zip format will be expected/loaded.
539 static Expected<ONNX_NAMESPACE::ModelProto>
540 loadProto(const std::string &filename, bool zipMode,
541 const std::string *inputStringPtr);
542
543 /// \returns Expected<ModelProto> if a ModelProto can be constructed from the
544 /// in-memory serialized protobuf.
545 /// Loads ModelProto from the in-memory serialized protobuf \p
546 /// onnxModel with the model size \p onnxModelSize.
547 static Expected<ONNX_NAMESPACE::ModelProto> loadProto(const void *onnxModel,
548 size_t onnxModelSize);
549
550 /// Checks that the inputs tensors are compatible with the inputs declared in
551 /// the ONNX model. The input types in \p types match the list of names
552 /// \p tensorNames.
553 Error checkInputs(ONNX_NAMESPACE::GraphProto &net,
554 llvm::ArrayRef<const char *> tensorNames,
555 llvm::ArrayRef<TypeRef> types);
556
557 /// Go through the ValueInfoProto of the inputs of the \p net and collect
558 /// static placeholders if it's marked in the ValueInfoProto.
559 Error collectStaticInputs(ONNX_NAMESPACE::GraphProto &net);
560
561 /// Looks through all ops in \p net for any dummy static PH nodes carrying the
562 /// type that was used for loading deferred weights initially. If found then
563 /// they're added to \ref staticPlaceholderTypes_. If \ref
564 /// staticPlaceholderTypes_ is a nullptr then this method is a no-op.
565 Error setupOrigStaticTypeMap(ONNX_NAMESPACE::GraphProto &net);
566
567 /// Associate all inputs of \p net with nodes in \p NVs. Number of inputs of
568 /// \p net should match the number of elements of \p NVs.
569 /// \returns error code in case of error.
570 Error assignGraphInputs(const ONNX_NAMESPACE::GraphProto &net,
571 llvm::ArrayRef<NodeValue> NVs);
572
573 /// Creates a ONNX model loader to build \p F.
574 /// Loads the ONNIXFI \p model from memory of \p modelSize size,
575 /// and \p weightsCount, and \p onnxTensorDescriptorV1 correspondent
576 /// descriptors. Converts inputs into placeholder if requested \p
577 /// loadInputsAsPlaceholdersForOnnx. Reports success/failure through optional
578 /// parameter \p errPtr. This constructor always overrides the default
579 /// constant folding in loader flag with \p constFoldInLoader.
580 ONNXModelLoader(const void *model, uint32_t modelSize, uint32_t weightsCount,
581 const onnxTensorDescriptorV1 *weightDescriptors, Function &F,
582 bool loadInputsAsPlaceholdersForOnnx, Error *errPtr = nullptr,
583 bool constFoldInLoader = true,
584 BackendSpecificNodeInfo *perNodeOpts = nullptr);
585
586 /// Creates a ONNX model loader to build \p mod.
587 /// Loads the ONNIXFI \p model from memory of \p modelSize size,
588 /// and \p weightsCount, and \p onnxTensorDescriptorV1 correspondent
589 /// descriptors. Converts inputs into placeholder if requested \p
590 /// loadInputsAsPlaceholdersForOnnx. Reports success/failure through optional
591 /// parameter \p errPtr. This constructor always overrides the default
592 /// constant folding in loader flag with \p constFoldInLoader.
593 /// Supports loading a DAG which was serialized, loading in DAGNode meta info
594 /// into \p PPC which can be later used to recreated the DAG. \p funName is
595 /// used to setup the DAG root node's name, or if the input model is not
596 /// partitioned then is used as the name of the single Function loaded. Loads
597 /// backend-specific node info annotations into \p perNodeOpts.
598 /// \p staticPlaceholderTypes will be filled with types to use for static
599 /// Placeholders if the proto being parsed contains such information;
600 /// otherwise it's left unchanged. If \p replaceDummyTQPs then any dummy TQPs
601 /// (represented by scale=0.f) will be replaced by updated TQPs found in
602 /// metadata_props, allowing for changing of TQPs of serialized models.
603 ONNXModelLoader(const void *model, uint32_t modelSize, uint32_t weightsCount,
604 const onnxTensorDescriptorV1 *weightDescriptors, Module &mod,
605 llvm::StringRef funName, runtime::PrePartitionedConfig *PPC,
606 bool loadInputsAsPlaceholdersForOnnx, Error *errPtr = nullptr,
607 bool constFoldInLoader = true,
608 BackendSpecificNodeInfo *perNodeOpts = nullptr,
609 std::map<std::string, Type> *staticPlaceholderTypes = nullptr,
610 bool replaceDummyTQPs = false,
611 bool clipQuantRangeToFP16 = false);
612
613 friend class ONNXIFIModelLoader;
614
615 /// \returns success if the folding of operator \p op in the loader
616 /// \p loader is successful. The folding utility uses temporary
617 /// loader \p tmpLoader, and associated temporary function \p F.
618 template <class LoaderType, class OpType>
619 friend Error constantFoldInLoader(Function *F, LoaderType &tmpLoader,
620 LoaderType *loader, const OpType &op);
621
622 /// \returns a Type with the proper shape and element type given \p in.
623 Expected<Type> getTensorType(const ONNX_NAMESPACE::ValueInfoProto &in);
624
625 /// \returns a Type with the proper shape and element type given \p in.
626 Expected<Type> getTensorType(const ONNX_NAMESPACE::TensorProto &in);
627
628 /// Load a model \p modelDef given \p tensorNames, \p types, \p B, and
629 /// \p loadInputsAsPlaceholdersForOnnx.
630 Error loadModel(ONNX_NAMESPACE::ModelProto &modelDef,
631 llvm::ArrayRef<const char *> tensorNames,
632 llvm::ArrayRef<TypeRef> types, const Backend *B,
633 bool loadInputsAsPlaceholdersForOnnx);
634
635 /// Setup partitions by creating Functions and loading metadata into \p PPC
636 /// from the metadata props found in \p modelDef given \p rootName and
637 /// \p numPartitions.
638 Error setupPartitions(ONNX_NAMESPACE::ModelProto &modelDef,
639 runtime::PrePartitionedConfig &PPC,
640 llvm::StringRef rootName, int numPartitions);
641
642 /// Deletes the Functions in \ref constFoldFuns_ from \ref mod_.
643 void deleteConstFoldFunctions();
644
645 /// Sets up positional IO into \ref positionalInputNames_ and
646 /// \ref positionalOutputNames_ from \p graph.
647 void setupPositionalIO(const ONNX_NAMESPACE::GraphProto &graph);
648
649 /// Sets up \ref updatedTQPs_ based on metadata props found in \p modelDef as
650 /// well as \p weightsCount weights in \p weightDescriptors.
651 Error setupUpdatedTQPMap(ONNX_NAMESPACE::ModelProto &modelDef,
652 uint32_t weightsCount,
653 const onnxTensorDescriptorV1 *weightDescriptors);
654
655public:
656 /// \returns ONNX model ir_version;
657 size_t getIrVersion() const { return irVersion_; };
658
659 /// \returns ONNX model op_version;
660 size_t getOpSetVersion() const { return opsetVersion_; };
661
662 /// \returns if the loader is loading a proto using custom Glow ops.
663 bool usingGlowCustomOps() const { return useGlowCustomOps_; };
664
665 /// Creates a ONNX model loader to build \p F.
666 /// If \p errPtr is not null then if an error occurs it will get assigned
667 /// there otherwise if an error occurs it will abort.
668 ONNXModelLoader(Function &F, Error *errPtr = nullptr);
669
670 /// Update \p inTensorNames and \p inTypes from inputs of onnx model from
671 /// filename
672 static Error getInputsNamesAndTypes(std::vector<std::string> &inTensorNames,
673 std::vector<Type> &inTypes,
674 const std::string &filename);
675
676 /// Loads the ONNX model that's represented by a model description file,
677 /// serialized in \p modelDescFilename and populates the network into \p F.
678 /// The types in \p types match the list of names \p tensorNames and used as
679 /// inputs to the network.
680 /// If \p names and \p types are empty loader fills inputs automatically.
681 /// If \p errPtr is not null then if an error occurs it will get assigned
682 /// there otherwise if an error occurs it will abort.
683 /// If \p disableConstFoldInLoader then constant folding will be disabled
684 /// during loading. \p B will be used during function verification after
685 /// loading. If \p loadIntoExistingModule then all Functions and Storage is
686 /// expected to already exist, so they will be searched for according to the
687 /// proto being loaded instead of created as usual.
688 ONNXModelLoader(const std::string &modelDescFilename,
689 llvm::ArrayRef<const char *> tensorNames,
690 llvm::ArrayRef<TypeRef> types, Function &F,
691 Error *errPtr = nullptr, bool zipMode = false,
692 BackendSpecificNodeInfo *perNodeOpts = nullptr,
693 bool disableConstFoldInLoader = false,
694 bool loadIntoExistingModule = false,
695 const Backend *B = nullptr,
696 const std::string *inputStringPtr = nullptr);
697
698 /// Loads the ONNX model that's represented by a model description file,
699 /// serialized in \p modelDescFilename and populates the network into \p mod.
700 /// Supports loading a DAG which was serialized, loading in DAGNode meta info
701 /// into \p PPC which can be later used to recreated the DAG. \p funName is
702 /// used to setup the DAG root node's name, or if the input model is not
703 /// partitioned then is used as the name of the single Function loaded.
704 /// The types in \p types match the list of names \p tensorNames and used as
705 /// inputs to the network.
706 /// If \p names and \p types are empty loader fills inputs automatically.
707 /// If \p errPtr is not null then if an error occurs it will get assigned
708 /// there otherwise if an error occurs it will abort.
709 /// If \p disableConstFoldInLoader then constant folding will be disabled
710 /// during loading. \p B will be used during function verification after
711 /// loading. If \p loadIntoExistingModule then all Functions and Storage is
712 /// expected to already exist, so they will be searched for according to the
713 /// proto being loaded instead of created as usual.
714 ONNXModelLoader(const std::string &modelDescFilename,
715 llvm::ArrayRef<const char *> tensorNames,
716 llvm::ArrayRef<TypeRef> types, Module &mod,
717 llvm::StringRef funName,
718 runtime::PrePartitionedConfig *PPC = nullptr,
719 Error *errPtr = nullptr, bool zipMode = false,
720 BackendSpecificNodeInfo *perNodeOpts = nullptr,
721 bool loadIntoExistingModule = false,
722 bool disableConstFoldInLoader = false,
723 const Backend *B = nullptr,
724 const std::string *inputStringPtr = nullptr);
725
726private:
727 /// Per-node options that may be specified in a proto.
728 BackendSpecificNodeInfo *perNodeOpts_{nullptr};
729 /// Map from static PH names to the type it was originally loaded with.
730 std::map<std::string, Type> *staticPlaceholderTypes_;
731};
732
733} // namespace glow
734
735#endif // GLOW_IMPORTER_ONNXMODELLOADER_H
736