1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Importer/ONNXModelLoader.h"
18
19#include "glow/Base/Tensor.h"
20#include "glow/Flags/Flags.h"
21#include "glow/Graph/Graph.h"
22#include "glow/Graph/Nodes.h"
23#include "glow/Importer/Caffe2ModelLoader.h"
24#include "glow/Support/Support.h"
25#include "glow/Support/ZipUtils.h"
26
27#include "llvm/Support/Casting.h"
28#include "llvm/Support/CommandLine.h"
29
30#include "google/protobuf/io/coded_stream.h"
31#include "google/protobuf/io/tokenizer.h"
32#include "google/protobuf/io/zero_copy_stream_impl.h"
33#include "google/protobuf/text_format.h"
34
35#include <cstddef>
36#include <cstdint>
37#include <fstream>
38#include <sstream>
39#include <string>
40#include <vector>
41
42using namespace glow;
43using namespace glow::runtime;
44using llvm::cast;
45
46namespace {
47
// Category grouping all command line options of the ONNX model loader.
llvm::cl::OptionCategory onnxModelLoaderCat("ONNX Model Loader Options");

// Backing storage for the -onnx-define-symbol option values. Also settable
// programmatically via glow::setOnnxDefineSymbol() below.
std::vector<std::string> onnxDefineSymbol;
llvm::cl::list<std::string, std::vector<std::string>> onnxDefineSymbolOpt(
    "onnx-define-symbol", llvm::cl::ZeroOrMore,
    llvm::cl::location(onnxDefineSymbol),
    llvm::cl::desc(
        "Define (replace) the undefined symbols from the tensor descriptions\n"
        "in the ONNX model with actual integer sizes. The undefined symbols \n"
        "are marked in the proto description with the 'dim_param' field. For\n"
        "example, if the model contains a tensor with the size described as \n"
        "'None' x 3 x 224 x 224, the symbol 'None' can be replaced with an \n"
        "actual integer size (for example 1) by using the following command \n"
        "line option: \n"
        "    -onnx-define-symbol=None,1 \n"
        "Multiple symbols can be defined using this option, for example: \n"
        "    -onnx-define-symbol=<symbol_name1>,<symbol_value1> \n"
        "    -onnx-define-symbol=<symbol_name2>,<symbol_value2> \n"
        "    ..................................................\n"),
    llvm::cl::value_desc("name,value"), llvm::cl::cat(onnxModelLoaderCat));

// When true, RNN/GRU/LSTM states are exported as graph placeholders even if
// the model does not set them explicitly (see description below).
llvm::cl::opt<bool> onnxExportRnnStatesOpt(
    "onnx-export-rnn-states", llvm::cl::init(false), llvm::cl::Optional,
    llvm::cl::desc(
        "Option to export the states of the ONNX RNN operators (for example \n"
        "RNN, GRU, LSTM) as graph placeholders regardless of whether the \n"
        "states are explicitly set or not in the graph. The placeholders are\n"
        "also providing an automatic way for tracking the RNN states since \n"
        "the states are updated automatically with the new RNN states after \n"
        "each inference. Default is false."),
    llvm::cl::cat(onnxModelLoaderCat));

// Upper bound on the number of iterations the ONNX Loop operator may be
// statically unrolled to.
llvm::cl::opt<unsigned> loopUnrollLimit(
    "loop-unroll-limit",
    llvm::cl::desc("Maximum unrollable iterations for the Loop operator"),
    llvm::cl::Optional, llvm::cl::init(20), llvm::cl::cat(onnxModelLoaderCat));
84
85/// Parse the command line option and get the user defined map of symbols.
86/// The command line option has the format <symbol_name>,<symbol_value>.
87Expected<std::unordered_map<std::string, dim_t>> getSymbolMap() {
88 std::unordered_map<std::string, dim_t> symbolMap;
89 for (const auto &str : onnxDefineSymbol) {
90 auto strPair = llvm::StringRef(str).split(',');
91 llvm::StringRef name = strPair.first;
92 RETURN_ERR_IF_NOT(name.size() > 0, "ONNX defined symbol name is empty.");
93 dim_t value;
94 RETURN_ERR_IF_NOT(!strPair.second.getAsInteger(0, value),
95 strFormat("ONNX defined symbol value '%s' is invalid.",
96 strPair.second.data()));
97 symbolMap[name.str()] = value;
98 }
99 return symbolMap;
100}
101
102/// Get the shape of a TensorShapeProto given by \p shapeProto and return the
103/// dimensions in the vector \p dim passed by reference.
104Expected<std::vector<dim_t>>
105getProtoShape(const ONNX_NAMESPACE::TensorShapeProto &shapeProto) {
106 std::vector<dim_t> dim;
107 for (auto d : shapeProto.dim()) {
108 if (d.has_dim_value()) {
109 // Proto shape has an explicit size given by the "dim_value" field.
110 dim.push_back(d.dim_value());
111 } else if (d.has_dim_param()) {
112 // Proto shape has a symbolic size given by the "dim_param" field. Search
113 // the symbol in the user defined map of symbols. If the symbol is not
114 // found then raise an error.
115 auto symbolName = d.dim_param();
116 std::unordered_map<std::string, dim_t> symbolMap;
117 ASSIGN_VALUE_OR_RETURN_ERR(symbolMap, getSymbolMap());
118 if (symbolMap.count(symbolName) && symbolMap[symbolName] > 0) {
119 dim.push_back(symbolMap[symbolName]);
120 } else {
121 return MAKE_ERR(strFormat(
122 "ONNX model symbol '%s' is undefined. Define the symbol with the "
123 "following command line option: -onnx-define-symbol=%s,<value> and "
124 "each 'dim_value' of tensor shape proto must be greater than 0.",
125 symbolName.c_str(), symbolName.c_str()));
126 }
127 } else {
128 // Proto shape has no "dim_value" and no "dim_param" field.
129 return MAKE_ERR(
130 "Tensor shape proto has no 'dim_value' or 'dim_param' field!");
131 }
132 }
133 return dim;
134}
135
136/// Given some \p onnxType, sets \p elemTy to a corresponding Glow
137/// ElemKind. \returns whether an ElemKind was successfully selected.
138Error onnxTensorDataTypeToElemKind(int32_t onnxType, ElemKind *elemTy) {
139 if (onnxType == ONNX_NAMESPACE::TensorProto::FLOAT) {
140 *elemTy = ElemKind::FloatTy;
141 return Error::success();
142 } else if (onnxType == ONNX_NAMESPACE::TensorProto::FLOAT16) {
143 *elemTy = ElemKind::Float16Ty;
144 return Error::success();
145 } else if (onnxType == ONNX_NAMESPACE::TensorProto::BFLOAT16) {
146 *elemTy = ElemKind::BFloat16Ty;
147 return Error::success();
148 } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT64) {
149 *elemTy = ElemKind::Int64ITy;
150 return Error::success();
151 } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT32) {
152 *elemTy = ElemKind::Int32ITy;
153 return Error::success();
154 } else if (onnxType == ONNX_NAMESPACE::TensorProto::UINT8) {
155 *elemTy = ElemKind::UInt8FusedQTy;
156 return Error::success();
157 } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT8) {
158 *elemTy = ElemKind::Int8QTy;
159 return Error::success();
160 } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT16) {
161 *elemTy = ElemKind::Int16QTy;
162 return Error::success();
163 } else if (onnxType == ONNX_NAMESPACE::TensorProto::BOOL) {
164 *elemTy = ElemKind::BoolTy;
165 return Error::success();
166 } else {
167 return MAKE_ERR(strFormat(
168 "Don't know how to convert ONNX tensor data type %d to ElemKind",
169 onnxType));
170 }
171}
172
/// Finds an attribute from the doc_string and \returns it. If it does not exist
/// then \returns Error. The expected structure here is that each attribute
/// starts with startChar and is separated from its value by a sepChar.
/// \p attr is the attribute name to search for; \p docStr is the encoded
/// doc string to search in. The returned value is the substring between the
/// attribute's sepChar and the next startChar (or end of string).
Expected<std::string> getAttrFromDocString(const std::string &attr,
                                           const std::string &docStr) {
  // Each attribute key is matched together with its trailing separator so a
  // key that is a prefix of another key cannot match incorrectly.
  const std::string attrAndSep = attr + sepChar;
  size_t begin = 0;
  while (true) {
    // Find the start marker of the next attribute.
    begin = docStr.find(startChar, begin);
    if (begin == std::string::npos) {
      // Scanned the whole doc string without finding the attribute.
      return MAKE_ERR(strFormat("Didn't find PH attribute '%s'", attr.c_str()));
    }

    // Note: +1 here and following line to account for the leading startChar.
    if (!docStr.compare(begin + 1, attrAndSep.size(), attrAndSep)) {
      // If we found the attribute then set begin to just after attrAndSep.
      begin += attrAndSep.size() + 1;
      break;
    }
    // Move past the current non-matching attribute to try the next attribute.
    begin = begin + attrAndSep.size();
  }

  // The value runs from just after the separator up to the next startChar
  // (npos - begin when this is the last attribute, i.e. the rest of the
  // string).
  return docStr.substr(begin, docStr.find(startChar, begin) - begin);
}
198
199Expected<std::pair<bool, std::string>>
200getTrainableLayoutPairFromDocString(const std::string &docString,
201 bool useGlowCustomOps) {
202 std::string layout = ANY_LAYOUT;
203 std::string isTrainableStr = "0";
204 if (useGlowCustomOps) {
205 ASSIGN_VALUE_OR_RETURN_ERR(
206 isTrainableStr, getAttrFromDocString(trainableSignifier, docString));
207 ASSIGN_VALUE_OR_RETURN_ERR(
208 layout, getAttrFromDocString(layoutSignifier, docString));
209 }
210 return std::make_pair(isTrainableStr != "0", layout);
211}
212
213Expected<std::pair<float, int32_t>>
214getQuantParamsFromDocString(const std::string &docStr) {
215 std::string scaleStr;
216 ASSIGN_VALUE_OR_RETURN_ERR(scaleStr,
217 getAttrFromDocString(qScaleSignifier, docStr));
218 float scale = std::strtof(scaleStr.c_str(), NULL);
219
220 std::string offsetStr;
221 ASSIGN_VALUE_OR_RETURN_ERR(offsetStr,
222 getAttrFromDocString(qOffsetSignifier, docStr));
223 int32_t offset;
224 ASSIGN_VALUE_OR_RETURN_ERR(offset, getIntFromStr(offsetStr));
225 return std::make_pair(scale, offset);
226}
227
228ShapeVector getStridesFromDocString(const std::string &docStr) {
229 ShapeVector strides;
230 std::string stridesStr;
231 auto stridesOrError = getAttrFromDocString(stridesSignifier, docStr);
232 if (ERR_TO_BOOL(stridesOrError.takeError(), /* log */ false)) {
233 return strides;
234 }
235 stridesStr = std::move(stridesOrError.get());
236 if (stridesStr.empty()) {
237 return strides;
238 }
239 // Parse comma-delimited stride values.
240 llvm::SmallVector<llvm::StringRef, max_tensor_dimensions> stridesStrSplit;
241 llvm::StringRef stridesStrRef = llvm::StringRef(stridesStr);
242 stridesStrRef.split(stridesStrSplit, ',');
243 for (const auto &stride : stridesStrSplit) {
244 strides.emplace_back(std::stoi(stride.str()));
245 }
246 return strides;
247}
248
/// Used for retrieving an attribute of type \p T from \p attr. Some
/// specializations use \p loader if necessary (e.g. to map attribute names to
/// NodeValues). The primary template is declared but never defined; every
/// supported T must have an explicit or partial specialization below.
template <bool IsInteger, typename T> struct AttributeRetriever {
  static Expected<T> get(const ONNX_NAMESPACE::AttributeProto *attr,
                         const ProtobufLoader &loader);
};
255
/// Specialization for std::vector<float>. Forwards directly to getFloats();
/// the loader is not needed.
template <> struct AttributeRetriever<false, std::vector<float>> {
  static Expected<std::vector<float>>
  get(const ONNX_NAMESPACE::AttributeProto *attr,
      const ProtobufLoader & /* unused */) {
    return getFloats(attr);
  }
};
264
265/// Specialization for std::vector<NodeValue>.
266template <> struct AttributeRetriever<false, std::vector<NodeValue>> {
267 static Expected<std::vector<NodeValue>>
268 get(const ONNX_NAMESPACE::AttributeProto *attr, ProtobufLoader &loader) {
269 // Retrieve the names from the proto which map to NodeValues.
270 std::vector<std::string> strs;
271 ASSIGN_VALUE_OR_RETURN_ERR(strs, getStrings(attr));
272
273 // Get NodeValues corresponding to these names from the loader.
274 std::vector<NodeValue> NVs;
275 for (const auto &str : strs) {
276 NodeValue NV;
277 ASSIGN_VALUE_OR_RETURN_ERR(NV, loader.getNodeValueByName(str));
278 NVs.push_back(NV);
279 }
280 return NVs;
281 }
282};
283
284/// Specialization for NodeValue.
285template <> struct AttributeRetriever<false, NodeValue> {
286 static Expected<NodeValue> get(const ONNX_NAMESPACE::AttributeProto *attr,
287 ProtobufLoader &loader) {
288 // Retrieve the name from the proto, which is mapped to a NodeValue.
289 std::string str;
290 ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
291
292 // Get/return the corresponding NodeValue for this name from the loader.
293 NodeValue NV;
294 ASSIGN_VALUE_OR_RETURN_ERR(NV, loader.getNodeValueByName(str));
295 return NV;
296 }
297};
298
/// Specialization for std::vector<T>. Fall back for integer types: the
/// attribute is interpreted as a shape. Empty shapes are permitted.
template <typename T> struct AttributeRetriever<false, std::vector<T>> {
  static Expected<std::vector<T>>
  get(const ONNX_NAMESPACE::AttributeProto *attr,
      const ProtobufLoader & /* unused */) {
    return getShape<T>(attr, /* allowEmptyShape */ true);
  }
};
307
/// Specialization for integer types (selected when
/// std::numeric_limits<T>::is_integer is true). Forwards to loadInt(); the
/// loader is not needed.
template <typename T> struct AttributeRetriever<true, T> {
  static Expected<T> get(const ONNX_NAMESPACE::AttributeProto *attr,
                         const ProtobufLoader & /* unused */) {
    return loadInt(attr);
  }
};
315
316/// Specialization for LengthsMode.
317template <> struct AttributeRetriever<false, LengthsMode> {
318 static Expected<LengthsMode> get(const ONNX_NAMESPACE::AttributeProto *attr,
319 const ProtobufLoader & /* unused */) {
320 std::string str;
321 ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
322 if (str == "AllOne") {
323 return LengthsMode::AllOne;
324 } else if (str == "Variable") {
325 return LengthsMode::Variable;
326 } else {
327 return MAKE_ERR("Invalid LengthsMode");
328 }
329 }
330};
331
332/// Specialization for FusedActivation.
333template <> struct AttributeRetriever<false, FusedActivation> {
334 static Expected<FusedActivation>
335 get(const ONNX_NAMESPACE::AttributeProto *attr,
336 const ProtobufLoader & /* unused */) {
337 std::string str;
338 ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
339 if (str == "NONE") {
340 return FusedActivation::NONE;
341 } else if (str == "RELU") {
342 return FusedActivation::RELU;
343 } else if (str == "CLIP") {
344 return FusedActivation::CLIP;
345 } else if (str == "TANH") {
346 return FusedActivation::TANH;
347 } else if (str == "SIGMOID") {
348 return FusedActivation::SIGMOID;
349 } else if (str == "LEAKY_RELU") {
350 return FusedActivation::LEAKY_RELU;
351 } else {
352 return MAKE_ERR("Invalid FusedActivation");
353 }
354 }
355};
356
/// Specialization for LUTOperator. Maps the textual attribute value onto the
/// LUTOperator enum; any other string is an error.
/// (Previous comment said "FusedActivation" — copy-paste error.)
template <> struct AttributeRetriever<false, LUTOperator> {
  static Expected<LUTOperator> get(const ONNX_NAMESPACE::AttributeProto *attr,
                                   const ProtobufLoader & /* unused */) {
    std::string str;
    ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
    if (str == "NONE") {
      return LUTOperator::NONE;
    } else if (str == "RELU") {
      return LUTOperator::RELU;
    } else if (str == "CLIP") {
      return LUTOperator::CLIP;
    } else if (str == "TANH") {
      return LUTOperator::TANH;
    } else if (str == "SIGMOID") {
      return LUTOperator::SIGMOID;
    } else if (str == "LEAKY_RELU") {
      return LUTOperator::LEAKY_RELU;
    } else {
      return MAKE_ERR("Invalid LUTOperator");
    }
  }
};
380
381/// Specialization for ConvolutionLayout.
382template <> struct AttributeRetriever<false, ConvolutionLayout> {
383 static Expected<ConvolutionLayout>
384 get(const ONNX_NAMESPACE::AttributeProto *attr,
385 const ProtobufLoader & /* unused */) {
386 std::string str;
387 ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
388 if (str == "NHWC") {
389 return ConvolutionLayout::NHWC;
390 } else if (str == "NCHW") {
391 return ConvolutionLayout::NCHW;
392 } else if (str == "NTHWC") {
393 return ConvolutionLayout::NTHWC;
394 } else if (str == "NCTHW") {
395 return ConvolutionLayout::NCTHW;
396 } else {
397 return MAKE_ERR("Invalid ConvolutionLayout");
398 }
399 }
400};
401
402/// Specialization for PaddingMode.
403template <> struct AttributeRetriever<false, PaddingMode> {
404 static Expected<PaddingMode> get(const ONNX_NAMESPACE::AttributeProto *attr,
405 const ProtobufLoader & /* unused */) {
406 std::string str;
407 ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
408 if (str == "CONSTANT") {
409 return PaddingMode::CONSTANT;
410 } else if (str == "REFLECT") {
411 return PaddingMode::REFLECT;
412 } else if (str == "EDGE") {
413 return PaddingMode::EDGE;
414 } else {
415 return MAKE_ERR("Invalid PaddingMode");
416 }
417 }
418};
419
420/// Specialization for SplitEmbeddingPoolingMode.
421template <> struct AttributeRetriever<false, SplitEmbeddingPoolingMode> {
422 static Expected<SplitEmbeddingPoolingMode>
423 get(const ONNX_NAMESPACE::AttributeProto *attr,
424 const ProtobufLoader & /* unused */) {
425 std::string poolingMode;
426 ASSIGN_VALUE_OR_RETURN_ERR(poolingMode, loadStr(attr));
427 if (poolingMode == "0") {
428 return SplitEmbeddingPoolingMode::EP_SUM;
429 } else if (poolingMode == "1") {
430 return SplitEmbeddingPoolingMode::EP_MEAN;
431 } else if (poolingMode == "2") {
432 return SplitEmbeddingPoolingMode::EP_NONE;
433 } else {
434 return MAKE_ERR("Invalid SplitEmbeddingPoolingMode");
435 }
436 }
437};
438
439/// Specialization for SplitEmbeddingSparseType.
440template <> struct AttributeRetriever<false, SplitEmbeddingSparseType> {
441 static Expected<SplitEmbeddingSparseType>
442 get(const ONNX_NAMESPACE::AttributeProto *attr,
443 const ProtobufLoader & /* unused */) {
444 std::string str;
445 ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr));
446 if (str == "0") {
447 return SplitEmbeddingSparseType::EST_FLOAT;
448 } else if (str == "1") {
449 return SplitEmbeddingSparseType::EST_FLOAT16;
450 } else if (str == "2") {
451 return SplitEmbeddingSparseType::EST_INT8;
452 } else if (str == "3") {
453 return SplitEmbeddingSparseType::EST_INT4;
454 } else if (str == "4") {
455 return SplitEmbeddingSparseType::EST_INT2;
456 } else {
457 return MAKE_ERR("Invalid SplitEmbeddingSparseType");
458 }
459 }
460};
461
/// Specialization for float. Forwards to loadFloat(); the loader is not
/// needed.
template <> struct AttributeRetriever<false, float> {
  static Expected<float> get(const ONNX_NAMESPACE::AttributeProto *attr,
                             const ProtobufLoader & /* unused */) {
    return loadFloat(attr);
  }
};
469
/// Specialization for std::string. Forwards to loadStr(); the loader is not
/// needed.
template <> struct AttributeRetriever<false, std::string> {
  static Expected<std::string> get(const ONNX_NAMESPACE::AttributeProto *attr,
                                   const ProtobufLoader & /* unused */) {
    return loadStr(attr);
  }
};
477
/// Forwards to the correct AttributeRetriever specialization.
/// \returns the value of \p attr as type \p T, or Error if \p attr is null or
/// cannot be converted. The integer/non-integer dispatch is resolved at
/// compile time via std::numeric_limits<T>::is_integer.
template <typename T>
Expected<T> loadAttribute(const ONNX_NAMESPACE::AttributeProto *attr,
                          ProtobufLoader &loader) {
  RETURN_ERR_IF_NOT(attr, "No such attribute");
  return AttributeRetriever<std::numeric_limits<T>::is_integer, T>::get(attr,
                                                                        loader);
}
486
487} // namespace
488
489using ArgumentDictionaryTy =
490 std::unordered_map<std::string, const ONNX_NAMESPACE::AttributeProto *>;
491
492/// \returns a type based on \p ty, but using the provided \p strides.
493static Type getTypeWithCustomStrides(const Type &ty,
494 llvm::ArrayRef<dim_t> strides) {
495 if (strides.empty()) {
496 return ty;
497 }
498 return Type::newStrides(ty, strides);
499}
500
501/// \returns a type in module \p mod, based on \p ty, but using the provided \p
502/// strides.
503static TypeRef getTypeWithCustomStrides(glow::Module &mod, const TypeRef ty,
504 llvm::ArrayRef<dim_t> strides) {
505 if (strides.empty()) {
506 return ty;
507 }
508 return mod.uniqueTypeWithNewStrides(ty, ty->dims(), strides);
509}
510
/// Given a docstring encoding \p str of a type and its dimension \p
/// dims, parses the string and \returns a Glow Type from it or Error if
/// parsing failed. Expected format of str is either elemKindSignifier or
/// "ElemKind:scale:offset". When \p useGlowCustomOps is true the element
/// kind, quantization parameters and optional strides are instead read as
/// named doc-string attributes.
Expected<Type> parseTypeFromDocString(const std::string &str,
                                      llvm::ArrayRef<dim_t> dims,
                                      bool useGlowCustomOps) {
  // Defaults: non-quantized float with unit scale, zero offset, no custom
  // strides.
  float scale = 1.0;
  int32_t offset = 0;
  ElemKind elemKind = ElemKind::FloatTy;
  ShapeVector strides;

  if (useGlowCustomOps) {
    // Glow custom ops encode everything as named attributes in the doc
    // string.
    std::string elemKindStr;
    ASSIGN_VALUE_OR_RETURN_ERR(elemKindStr,
                               getAttrFromDocString(elemKindSignifier, str));
    elemKind = Type::getElementKindFromName(elemKindStr);

    if (isQuantizedElemKind(elemKind)) {
      std::pair<float, int32_t> scaleOffsetPair;
      ASSIGN_VALUE_OR_RETURN_ERR(scaleOffsetPair,
                                 getQuantParamsFromDocString(str));
      std::tie(scale, offset) = scaleOffsetPair;
    }
    strides = getStridesFromDocString(str);
  } else {
    // Legacy colon-separated format: "ElemKind" or "ElemKind:scale:offset".
    size_t begin = 0;

    // Find Elemkind string
    size_t end = str.find(':', begin);

    // If a ':' isn't found then assume the whole string is ElemKind (for
    // backwards compatibility reasons) otherwise look for scale and offset
    // strings.
    std::string elemKindStr;
    if (end == std::string::npos) {
      elemKindStr = str.substr(0, str.size());
    } else {
      elemKindStr = str.substr(begin, end - begin);

      // Get scale string.
      begin = end + 1;
      end = str.find(':', begin);
      if (end == std::string::npos) {
        return MAKE_ERR("scale not found");
      }
      std::string scaleStr = str.substr(begin, end - begin);

      // Get offset string.
      begin = end + 1;
      end = str.size();
      if (end - begin == 0) {
        return MAKE_ERR("offset not found");
      }

      std::string offsetStr = str.substr(begin, end - begin);

      scale = std::stof(scaleStr);
      offset = std::stoi(offsetStr);
    }

    elemKind = Type::getElementKindFromName(elemKindStr);
  }

  // Quantized kinds carry scale/offset; others do not.
  Type ty;
  if (isQuantizedElemKind(elemKind)) {
    ty = Type(elemKind, dims, scale, offset);
  } else {
    ty = Type(elemKind, dims);
  }
  // Apply custom strides if any were found (empty strides leave ty as-is).
  return getTypeWithCustomStrides(ty, strides);
}
583
584/// Translates the protocol buffer node \p op into a random access map.
585static ArgumentDictionaryTy
586loadArgumentMap(const ONNX_NAMESPACE::NodeProto &op) {
587 ArgumentDictionaryTy dict;
588 for (auto &arg : op.attribute()) {
589 dict[arg.name()] = &arg;
590 }
591 return dict;
592}
593
/// Programmatically sets the global list of symbol definitions, equivalent to
/// passing -onnx-define-symbol entries on the command line. Used when
/// resolving symbolic 'dim_param' dimensions in tensor shape protos.
void glow::setOnnxDefineSymbol(const std::vector<std::string> &strs) {
  onnxDefineSymbol = strs;
}
597
/// Reads the file \p fileName from disk and parses it as a binary-serialized
/// GraphProto. Aborts via CHECK if the file cannot be opened or parsed.
ONNX_NAMESPACE::GraphProto glow::parseOnnxFile(const std::string &fileName) {
  ::ONNX_NAMESPACE::GraphProto graphProto;
  std::ifstream inputFileStream(fileName, std::ios::in | std::ios::binary);
  CHECK(inputFileStream) << "Can't find the input file for " << fileName;
  google::protobuf::io::IstreamInputStream protobufFileStream(&inputFileStream);
  google::protobuf::io::CodedInputStream codedStream(&protobufFileStream);
  // Raise protobuf's default message size limit so large models can load.
  // Protobuf >= 3.2 takes a single limit argument; older versions also take a
  // warning threshold.
#if GOOGLE_PROTOBUF_VERSION >= 3002000
  codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE);
#else
  codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE);
#endif
  bool parsedSuccessfully = graphProto.ParseFromCodedStream(&codedStream);
  CHECK(parsedSuccessfully) << "Failed to parse GraphProto";
  return graphProto;
}
613
/// Copies every initializer tensor of \p inputGroup into the matching
/// placeholder of \p bindings. Tensors smaller than their placeholder's full
/// type are either kept as partial tensors (payload parked in
/// \p partialTensorPayloads, when non-null) or zero-padded up to the full
/// size. Aborts via CHECK on a missing placeholder or a load failure.
void glow::fillPlaceholders(const ONNX_NAMESPACE::GraphProto &inputGroup,
                            PlaceholderBindings *bindings,
                            std::vector<Tensor> *partialTensorPayloads,
                            bool usingGlowCustomOps) {
  for (const auto &tensorProto : inputGroup.initializer()) {
    // Placeholders are registered under the Glow-legalized name, not the raw
    // ONNX name.
    const std::string glowLegalizedName =
        glow::legalizeName(tensorProto.name());
    auto *tensor =
        bindings->get(bindings->getPlaceholderByNameSlow(glowLegalizedName));
    CHECK(tensor) << "Missing " << tensorProto.name()
                  << ", Glow legalized name " << glowLegalizedName;
    // Remember the placeholder's full size/type before loadTensor() possibly
    // shrinks the tensor to the proto's (partial) size.
    size_t fullSize = tensor->getSizeInBytes();
    const auto fullType = tensor->getType();
    auto error = loadTensor(tensorProto, tensor, usingGlowCustomOps);
    bool hasError = ERR_TO_BOOL(std::move(error));
    CHECK(!hasError) << "Cannot load input tensor";
    size_t loadedSize = tensor->getSizeInBytes();
    if (loadedSize != fullSize) {
      if (partialTensorPayloads) {
        VLOG(1) << "Loading " << tensorProto.name() << ", Glow legalized name "
                << glowLegalizedName << " as a partial tensor: partial size="
                << tensor->getType().toString()
                << " full size=" << fullType.toString();
        Tensor fullTensor(tensor->getUnsafePtr(), &fullType,
                          tensor->getSizeInBytes());
        // 'fullTensor' doesn't own the underlying data. 'tensor' does. So
        // we want to keep the original tensor object around until inference
        // is finished.
        partialTensorPayloads->emplace_back(std::move(*tensor));
        *tensor = std::move(fullTensor);
      } else {
        // pad with 0
        VLOG(1) << "Loading and padding " << tensorProto.name()
                << ", Glow legalized name " << glowLegalizedName
                << " as a partial tensor: partial size="
                << tensor->getType().toString()
                << " full size=" << fullType.toString();
        // Allocate a full-size tensor, copy the partial payload in, and
        // zero-fill the remainder.
        Tensor fullTensor(&fullType);
        std::memcpy(fullTensor.getUnsafePtr(), tensor->getUnsafePtr(),
                    tensor->getSizeInBytes());
        std::memset(fullTensor.getUnsafePtr() + tensor->getSizeInBytes(), 0,
                    fullTensor.getSizeInBytes() - tensor->getSizeInBytes());
        *tensor = std::move(fullTensor);
      }
    }
  }
}
661
662void glow::fillPlaceholders(const std::string &fileName,
663 PlaceholderBindings *bindings,
664 std::vector<Tensor> *partialTensorPayloads,
665 bool usingGlowCustomOps) {
666 const ONNX_NAMESPACE::GraphProto &inputGroup = parseOnnxFile(fileName);
667 fillPlaceholders(inputGroup, bindings, partialTensorPayloads,
668 usingGlowCustomOps);
669}
670
/// Loads tensor \p T from the input \p in. If \p data is non-empty it is used
/// as the raw payload in place of in.raw_data(). \p useGlowCustomOps selects
/// the doc-string format used to recover element kind / quantization
/// parameters for the INT8/INT16/INT32/UINT8 cases. \returns Error when the
/// proto's data type is unsupported or carries its payload in a field this
/// loader does not handle for that type.
Error glow::loadTensor(const ONNX_NAMESPACE::TensorProto &in, Tensor *T,
                       bool useGlowCustomOps, const std::string &data) {
  // Tensor dimensions come from the proto's 'dims' field.
  std::vector<dim_t> dim;
  for (auto d : in.dims()) {
    dim.push_back(d);
  }

  // Optional custom strides may be encoded in the doc string; empty strides
  // mean "use the default contiguous layout".
  ShapeVector strides;
  if (in.has_doc_string()) {
    strides = getStridesFromDocString(in.doc_string());
  }

  if (in.data_type() == ONNX_NAMESPACE::TensorProto::FLOAT) {
    Type ty(ElemKind::FloatTy, dim);
    T->reset(getTypeWithCustomStrides(ty, strides));

    // Typed 'float_data' takes precedence over raw bytes.
    if (in.float_data_size() > 0) {
      auto TH = T->getHandle<>();
      size_t i = 0;
      for (auto f : in.float_data()) {
        TH.raw(i++) = f;
      }
    } else if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(float));
    } else {
      return MAKE_ERR("Unsupported Tensor format for FLOAT, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::FLOAT16) {
    Type ty(ElemKind::Float16Ty, dim);
    T->reset(getTypeWithCustomStrides(ty, strides));
    // FLOAT16 only arrives as raw bytes; each element is 2 bytes
    // (sizeof(float) / 2).
    if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * (sizeof(float) / 2));
    } else {
      return MAKE_ERR("Unsupported Tensor format for FLOAT16, name: " +
                          in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::BFLOAT16) {
    Type ty(ElemKind::BFloat16Ty, dim);
    T->reset(getTypeWithCustomStrides(ty, strides));
    // BFLOAT16 only arrives as raw bytes; each element is 2 bytes.
    if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * (sizeof(float) / 2));
    } else {
      return MAKE_ERR("Unsupported Tensor format for BFLOAT16, name: " +
                          in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT64) {
    Type ty(ElemKind::Int64ITy, dim);
    T->reset(getTypeWithCustomStrides(ty, strides));

    // Typed 'int64_data' takes precedence over raw bytes.
    if (in.int64_data_size() > 0) {
      auto TH = T->getHandle<int64_t>();
      size_t i = 0;
      for (auto f : in.int64_data()) {
        TH.raw(i++) = f;
      }
    } else if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int64_t));
    } else {
      return MAKE_ERR("Unsupported Tensor format for INT64, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT8) {
    // With a doc string the full (quantized) type, including scale/offset and
    // strides, is recovered from it; otherwise placeholder quantization
    // parameters are used (see comment below).
    if (in.has_doc_string()) {
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(
          ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps));
      T->reset(ty);
    } else {
      // Onnx uses Int8 data type through operators like QuantizeLinear.
      // Also data is passed through raw_data since TensorProto
      // does not have data field for int8_t or uint8_t.
      // scale is set to 1 and offset is set to 0 since both scale and offset
      // themselves are operators inputs.
      Type ty(ElemKind::Int8QTy, dim, 1 /* scale*/, 0 /* offset*/);
      T->reset(getTypeWithCustomStrides(ty, strides));
    }
    if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int8_t));
    } else {
      return MAKE_ERR("Unsupported Tensor format for INT8, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT16) {
    // INT16 always requires the doc string to recover the quantized type.
    Type ty;
    ASSIGN_VALUE_OR_RETURN_ERR(
        ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps));
    T->reset(ty);

    if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int16_t));
    } else {
      return MAKE_ERR("Unsupported Tensor format for INT16, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT32) {
    if (in.has_doc_string()) {
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(
          ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps));
      T->reset(ty);
    } else {
      // There are few cases when we will have int32 tensors. For example, the
      // second output of Concat from Caffe2 concat op is int32
      Type ty(ElemKind::Int32ITy, dim);
      T->reset(getTypeWithCustomStrides(ty, strides));
    }

    // Typed 'int32_data' takes precedence over raw bytes.
    if (in.int32_data_size() > 0) {
      auto TH = T->getHandle<int32_t>();
      size_t i = 0;
      for (auto f : in.int32_data()) {
        TH.raw(i++) = f;
      }
    } else if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int32_t));
    } else {
      return MAKE_ERR("Unsupported Tensor format for INT32, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::UINT8) {
    if (in.has_doc_string()) {
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(
          ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps));
      T->reset(ty);
    } else {
      // Onnx uses Int8 data type through operators like QuantizeLinear.
      // Also data is passed through raw_data since TensorProto
      // does not have data field for int8_t or uint8_t.
      // scale is set to 1 and offset is set to 0 since both scale and offset
      // themselves are operators inputs.
      Type ty(ElemKind::UInt8QTy, dim, 1 /* scale*/, 0 /* offset*/);
      T->reset(getTypeWithCustomStrides(ty, strides));
    }

    if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(uint8_t));
    } else {
      return MAKE_ERR("Unsupported Tensor format for UINT8, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::BOOL) {
    Type ty(ElemKind::BoolTy, dim);
    T->reset(getTypeWithCustomStrides(ty, strides));
    // Note: unlike the other cases, raw bytes are checked BEFORE typed data
    // here.
    if (in.has_raw_data() || !data.empty()) {
      std::istringstream inStream(data.empty() ? in.raw_data() : data,
                                  std::stringstream::binary);
      inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(bool));
    } else if (in.int32_data_size() > 0) {
      // Some ONNX models use int32_data to initialize bool type (e.g., when
      // converted from Keras).
      auto TH = T->getHandle<bool>();
      size_t i = 0;
      for (auto f : in.int32_data()) {
        TH.raw(i++) = (bool)f;
      }
    } else {
      return MAKE_ERR("Unsupported Tensor format for BOOL, name: " + in.name(),
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
    }
  } else {
    return MAKE_ERR(strFormat("Unsupported tensor data type: %u",
                              static_cast<unsigned>(in.data_type())),
                    ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
  }
  return Error::success();
}
858
859Expected<Type>
860ONNXModelLoader::getTensorType(const ONNX_NAMESPACE::TensorProto &in) {
861 std::vector<dim_t> dim;
862 for (auto d : in.dims()) {
863 dim.push_back(d);
864 }
865
866 switch (in.data_type()) {
867 case ONNX_NAMESPACE::TensorProto::FLOAT:
868 return Type(ElemKind::FloatTy, dim);
869
870 case ONNX_NAMESPACE::TensorProto::FLOAT16:
871 return Type(ElemKind::Float16Ty, dim);
872
873 case ONNX_NAMESPACE::TensorProto::BFLOAT16:
874 return Type(ElemKind::BFloat16Ty, dim);
875
876 case ONNX_NAMESPACE::TensorProto::INT64:
877 return Type(ElemKind::Int64ITy, dim);
878
879 case ONNX_NAMESPACE::TensorProto::UINT8:
880 case ONNX_NAMESPACE::TensorProto::INT8:
881 case ONNX_NAMESPACE::TensorProto::INT16:
882 return parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps_);
883
884 case ONNX_NAMESPACE::TensorProto::INT32:
885 if (in.has_doc_string()) {
886 return parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps_);
887 }
888 return Type(ElemKind::Int32ITy, dim);
889
890 case ONNX_NAMESPACE::TensorProto::BOOL:
891 return Type(ElemKind::BoolTy, dim);
892
893 default:
894 return MAKE_ERR(strFormat("Unsupported tensor data type: %u",
895 static_cast<unsigned>(in.data_type())),
896 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
897 }
898 llvm_unreachable("Unsupported tensor data type");
899}
900
/// Derive the Glow Type of the value described by \p in, honoring the
/// Glow-custom-ops doc-string encoding when enabled.
Expected<Type>
ONNXModelLoader::getTensorType(const ONNX_NAMESPACE::ValueInfoProto &in) {
  auto type = in.type();

  std::vector<dim_t> dim;
  ASSIGN_VALUE_OR_RETURN_ERR(dim, getProtoShape(type.tensor_type().shape()));

  ElemKind kind = ElemKind::FloatTy;
  float scale = 1.0;
  int32_t offset = 0;
  ShapeVector strides;
  if (useGlowCustomOps_) {
    // With Glow custom ops the element kind — and for quantized kinds the
    // scale/offset pair and custom strides — are serialized into doc_string.
    std::string elemKindStr;
    ASSIGN_VALUE_OR_RETURN_ERR(
        elemKindStr, getAttrFromDocString(elemKindSignifier, in.doc_string()));
    kind = Type::getElementKindFromName(elemKindStr);
    if (isQuantizedElemKind(kind)) {
      std::pair<float, int32_t> scaleOffsetPair;
      ASSIGN_VALUE_OR_RETURN_ERR(scaleOffsetPair,
                                 getQuantParamsFromDocString(in.doc_string()));
      std::tie(scale, offset) = scaleOffsetPair;
    }
    strides = getStridesFromDocString(in.doc_string());
  } else {
    // Retrieve the ElemKind from the ONNX type, including considerations for
    // whether the datatype is quantized.
    RETURN_IF_ERR(
        onnxTensorDataTypeToElemKind(type.tensor_type().elem_type(), &kind));
  }

  // If quantized then retrieve the scale and offset if provided (may not be for
  // fused quantized types since they're ignored anyway).
  if (isQuantizedElemKind(kind)) {
    assert(useGlowCustomOps_ &&
           "Quantized loading not fully supported without custom Glow ops.");
    return getTypeWithCustomStrides(Type(kind, dim, scale, offset), strides);
  }
  return getTypeWithCustomStrides(Type(kind, dim), strides);
}
940
941Error ONNXModelLoader::getInputsNamesAndTypes(
942 std::vector<std::string> &inTensorNames, std::vector<Type> &inTypes,
943 const std::string &filename) {
944 // Creating GraphProto from modelfile
945 ONNX_NAMESPACE::ModelProto modelDef;
946
947 ASSIGN_VALUE_OR_RETURN_ERR(modelDef, loadProto(filename, false, nullptr));
948
949 ONNX_NAMESPACE::GraphProto graphDef = modelDef.graph();
950
951 // GraphDef.input can have both inputs and intermediate tensors whereas
952 // initializers have info about intermediate tensors. Thus the difference
953 // betweem two is taken
954 std::vector<std::string> inputs;
955 for (auto &in : graphDef.input()) {
956 inputs.push_back(in.name());
957 }
958
959 for (const auto &in : graphDef.initializer()) {
960 auto position = std::find(inputs.begin(), inputs.end(), in.name());
961 if (position != inputs.end()) {
962 inputs.erase(position);
963 }
964 }
965
966 for (const auto &finalIn : inputs) {
967 for (const auto &in : graphDef.input()) {
968 if (finalIn.compare(in.name()) != 0) {
969 continue;
970 }
971 inTensorNames.push_back(in.name());
972 const ONNX_NAMESPACE::ValueInfoProto &valueInfo = in;
973 auto type = valueInfo.type();
974
975 std::vector<dim_t> dim;
976 ASSIGN_VALUE_OR_RETURN_ERR(dim,
977 getProtoShape(type.tensor_type().shape()));
978
979 ElemKind kind = ElemKind::FloatTy;
980 RETURN_IF_ERR(
981 onnxTensorDataTypeToElemKind(type.tensor_type().elem_type(), &kind));
982
983 inTypes.emplace_back(kind, dim);
984 }
985 }
986
987 return Error::success();
988}
989
/// Verify that the pre-existing Storage \p S (looked up under \p name) agrees
/// with the expected type \p ty, layout \p layout and trainability
/// \p trainable from the proto being loaded.
Error ONNXModelLoader::verifyPreexistingStorage(const Storage *S,
                                                const std::string &name,
                                                const Type &ty,
                                                const std::string &layout,
                                                const bool trainable) {
  RETURN_ERR_IF_NOT(S, "Storage did not exist in Module: " + name);
  // When dummy quant params are being replaced, the dummy type's offset is
  // the key used to look up the real TQP (see getUpdatedTQP usage below).
  if (replaceDummyTQPs_ && ty.isQuantizedType() &&
      ty.getScale() == dummyScale) {
    TensorQuantizationParams TQP;
    ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(ty.getOffset()));
    // If we are replacing dummy TQPs with updated ones, then do verification
    // based on the updated type and not the base dummy type we found.
    Type updatedTy(ty.getElementType(), ty.dims(), TQP.scale, TQP.offset);
    RETURN_ERR_IF_NOT(S->getType()->isEqual(updatedTy),
                      "Incorrect type for quant param updated existing " +
                          S->getDebugDesc() + " " + "Expected type " +
                          updatedTy.toString());
  } else {
    RETURN_ERR_IF_NOT(S->getType()->isEqual(ty),
                      "Incorrect type for existing " + S->getDebugDesc() +
                          " " + "Expected type " + ty.toString());
  }
  // Trainability only applies to Placeholders, not Constants.
  if (const Placeholder *PH = llvm::dyn_cast<Placeholder>(S)) {
    RETURN_ERR_IF_NOT(trainable == PH->isTraining(),
                      "Incorrect trainability for existing Storage " + name);
  }
  RETURN_ERR_IF_NOT(layout == S->getLayout(),
                    "Incorrect layout for existing Storage " + name);
  return Error::success();
}
1020
/// Register the graph inputs of \p net: as Placeholders when
/// \p loadInputsAsPlaceholdersForOnnx is set, otherwise as empty Constants.
/// Inputs already present as Constants (static weights) are skipped; when
/// loading into an existing module, matching pre-existing Placeholders are
/// verified and reused instead of recreated.
Error ONNXModelLoader::loadInputs(ONNX_NAMESPACE::GraphProto &net,
                                  bool loadInputsAsPlaceholdersForOnnx) {
  for (const auto &in : net.input()) {
    // Skip static weights.
    if (getConstantByNameOrNull(in.name())) {
      continue;
    }

    const std::string &docString = in.doc_string();

    Type ty;
    ASSIGN_VALUE_OR_RETURN_ERR(ty, getTensorType(in));

    // Dummy quant params encode the lookup key in the offset; swap in the
    // real scale/offset before creating or verifying the Placeholder.
    if (replaceDummyTQPs_ && ty.isQuantizedType() &&
        ty.getScale() == dummyScale) {
      TensorQuantizationParams TQP;
      ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(ty.getOffset()));
      ty = Type(ty.getElementType(), ty.dims(), TQP.scale, TQP.offset);
    }

    std::pair<bool, std::string> trainableLayoutPair;
    ASSIGN_VALUE_OR_RETURN_ERR(
        trainableLayoutPair,
        getTrainableLayoutPairFromDocString(docString, useGlowCustomOps_));

    // If we already have the existing module then we may already have the input
    // Placeholder. If so, verify it has the correct type.
    if (loadIntoExistingModule_) {
      RETURN_ERR_IF_NOT(
          loadInputsAsPlaceholdersForOnnx,
          "Must load inputs as Placeholders when using existing Module.");
      if (Placeholder *PH = mod_.getPlaceholderByNameSlow(in.name())) {
        // Set Fused types of Placeholders if they were expected to be
        // fused. Necessary because Caffe2/ONNX protos do not have fused types
        // explicitly, so will be loaded initially as int8.
        if (ty.isFusedQuantizedType()) {
          RETURN_IF_ERR(setFusedTy(PH, mod_.uniqueType(ty)));
        }
        RETURN_IF_ERR(verifyPreexistingStorage(PH, in.name(), ty,
                                               trainableLayoutPair.second,
                                               trainableLayoutPair.first));
        nodeValueByName_[in.name()] = PH->getOutput();
        continue;
      }
    }

    // We must not have the input created yet, so do so.
    if (loadInputsAsPlaceholdersForOnnx) {
      RETURN_ERR_IF_NOT(!clipQuantRangeToFP16_ || !ty.isQuantizedType() ||
                            ty.isFusedQuantizedType(),
                        "Do not support clipQuantRangeToFP16 with unfused "
                        "quantized input Placeholders: " +
                            in.name());
      Placeholder *inPH;
      ASSIGN_VALUE_OR_RETURN_ERR(
          inPH, createAndRegisterPlaceholder(in.name(), mod_.uniqueType(ty),
                                             staticInputs_.count(in.name()),
                                             trainableLayoutPair.first,
                                             trainableLayoutPair.second));
      // Prefer the loader name from the doc string when present; otherwise
      // fall back to the proto name for the inputVarsByName_ key.
      auto loaderNameOrErr =
          getAttrFromDocString(loaderNameSignifier, docString);
      const std::string &loaderName =
          !ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)
              ? loaderNameOrErr.get()
              : in.name();
      RETURN_ERR_IF_NOT(inputVarsByName_.try_emplace(loaderName, inPH).second,
                        "Already had input placeholder by name " + loaderName);
    } else {
      Tensor T(ty);
      RETURN_IF_ERR(createAndRegisterConstant(in.name(), std::move(T)));
    }
  }
  return Error::success();
}
1095
1096Expected<bool> ONNXModelLoader::getBroadcast(ArgumentDictionaryTy &dict) {
1097 // Starting with opset 7, broadcasting is implicit and doesn't require any
1098 // attribute.
1099 if (opsetVersion_ > 6) {
1100 return true;
1101 }
1102 if (!dict.count("broadcast")) {
1103 return false;
1104 }
1105
1106 int broadcast;
1107 ASSIGN_VALUE_OR_RETURN_ERR(broadcast, loadInt(dict.at("broadcast")));
1108 return broadcast == 1;
1109}
1110
1111bool ONNXModelLoader::hasMultidirectionalBroadcast(
1112 const llvm::StringRef typeName) {
1113 // Before opset 7, broadcasting was unidirectional.
1114 if (opsetVersion_ > 6) {
1115 // List of ops that support multidirectional broadcast can be found at
1116 // https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
1117 if ((typeName == "Add") || (typeName == "Sub") || (typeName == "Mul") ||
1118 (typeName == "Div") || (typeName == "Equal") ||
1119 (typeName == "Greater") || (typeName == "Less") ||
1120 (typeName == "Max") || (typeName == "Mean") || (typeName == "Min") ||
1121 (typeName == "Or") || (typeName == "Pow") || (typeName == "Sum") ||
1122 (typeName == "Xor")) {
1123 return true;
1124 }
1125 }
1126 return false;
1127}
1128
1129Expected<ElemKind> ONNXModelLoader::convertTensorProtoDataType(
1130 ONNX_NAMESPACE::TensorProto_DataType t) {
1131 switch (t) {
1132 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
1133 return ElemKind::FloatTy;
1134 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
1135 return ElemKind::Float16Ty;
1136 case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
1137 return ElemKind::BFloat16Ty;
1138 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
1139 return ElemKind::Float64Ty;
1140 case ONNX_NAMESPACE::TensorProto_DataType_INT32:
1141 return ElemKind::Int32ITy;
1142 case ONNX_NAMESPACE::TensorProto_DataType_INT64:
1143 return ElemKind::Int64ITy;
1144 case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
1145 return ElemKind::BoolTy;
1146 default:;
1147 }
1148 return MAKE_ERR("Non supported ONNX type");
1149}
1150
1151Error ONNXModelLoader::setVersion(ONNX_NAMESPACE::ModelProto MP) {
1152 irVersion_ = MP.ir_version();
1153 opsetVersion_ = 0;
1154 RETURN_ERR_IF_NOT(
1155 irVersion_ >= 3,
1156 "This ONNX model with ir_version < 3 is too old to be supported.",
1157 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ONNX_VERSION);
1158 for (const auto &imp : MP.opset_import()) {
1159 if (!imp.has_domain() || imp.domain() == "") {
1160 opsetVersion_ = imp.version();
1161 break;
1162 }
1163 }
1164 RETURN_ERR_IF_NOT(opsetVersion_ > 0,
1165 "The opset of this ONNX model is not supported.");
1166 return Error::success();
1167}
1168
/// Parse a binary ONNX ModelProto from \p iStream.
Expected<ONNX_NAMESPACE::ModelProto>
ONNXModelLoader::loadProto(google::protobuf::io::ZeroCopyInputStream &iStream) {
  // Construct and configure a Coded Input Stream
  google::protobuf::io::CodedInputStream codedStream(&iStream);

  // Don't warn about large file sizes. Protobuf >= 3.2 takes a single limit
  // argument; older versions also require a warning threshold.
#if GOOGLE_PROTOBUF_VERSION >= 3002000
  codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE);
#else
  codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE);
#endif
  ONNX_NAMESPACE::ModelProto MP;
  bool parseNet = MP.ParseFromCodedStream(&codedStream);
  RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto",
                    ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);

  return MP;
}
1187
/// Parse a binary ONNX ModelProto from the in-memory buffer \p onnxModel of
/// \p onnxModelSize bytes.
Expected<ONNX_NAMESPACE::ModelProto>
ONNXModelLoader::loadProto(const void *onnxModel, size_t onnxModelSize) {
  // Wrap the raw buffer in a zero-copy stream and delegate to the stream
  // overload.
  google::protobuf::io::ArrayInputStream arrayStream(onnxModel, onnxModelSize);
  return loadProto(arrayStream);
}
1193
1194Expected<ONNX_NAMESPACE::ModelProto>
1195ONNXModelLoader::loadProto(const std::string &filename, bool zipMode,
1196 const std::string *inputStringPtr) {
1197 /// A helper class to report parsing errors when loading model protos.
1198 class OnnxTxtErrCollector : public google::protobuf::io::ErrorCollector {
1199 public:
1200 OnnxTxtErrCollector() = default;
1201 ~OnnxTxtErrCollector() override = default;
1202 void AddError(int line, int column, const std::string &message) override {
1203 llvm::errs() << strFormat("ONNX parsing error at [%d, %d]: ", line,
1204 column)
1205 << message << "\n";
1206 }
1207 void AddWarning(int line, int column, const std::string &message) override {
1208 llvm::errs() << strFormat("ONNX parsing warning at [%d, %d]: ", line,
1209 column)
1210 << message << "\n";
1211 }
1212 };
1213
1214 // Create a parser object and attach an error collector to it.
1215 OnnxTxtErrCollector errorCollector;
1216 google::protobuf::TextFormat::Parser parser;
1217 parser.RecordErrorsTo(&errorCollector);
1218 bool parseNet;
1219 ONNX_NAMESPACE::ModelProto MP;
1220
1221 if (zipMode) {
1222 RETURN_ERR_IF_NOT(
1223 inputStringPtr == nullptr,
1224 "OnnxModelLoader load from string for zip mode not supported");
1225 ZipReader zip(filename);
1226 std::string buffer = zip.getRecord("model");
1227 // Try to parse as a protocol buffer first.
1228 parseNet = MP.ParseFromString(buffer);
1229 // Try to parse a textual representation first.
1230 if (!parseNet) {
1231 // If it is not a protocol buffer, try to parse as a text format.
1232 parseNet = parser.ParseFromString(buffer, &MP);
1233 }
1234 if (!parseNet) {
1235 RETURN_ERR_IF_NOT(false, "Failed to parse ModelProto",
1236 ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);
1237 }
1238 size_t numWeights = 0;
1239 auto numWeightsStr = zip.getRecord("weights");
1240 numWeights = atoi(numWeightsStr.c_str());
1241 for (size_t i = 0; i < numWeights; ++i) {
1242 std::stringstream ss;
1243 ss << "weight_" << i;
1244 buffer = zip.getRecord(ss.str());
1245 auto *t = MP.mutable_graph()->add_initializer();
1246 t->ParseFromString(buffer);
1247 }
1248 return MP;
1249 }
1250
1251 std::ifstream ff(filename, std::ios::in | std::ios::binary);
1252 if (inputStringPtr == nullptr) {
1253 RETURN_ERR_IF_NOT(ff,
1254 strFormat("Can't find the model or network files for %s.",
1255 filename.c_str()),
1256 ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);
1257 }
1258
1259 // TODO: intend to find a way to reuse the following function later
1260 // for the text format onnx model:
1261 // bool ONNXModelLoader::loadProto(ONNX_NAMESPACE::GraphProto &net,
1262 // google::protobuf::io::ZeroCopyInputStream &iStream)
1263 if (filename.find(".onnxtxt") != std::string::npos) {
1264 if (inputStringPtr == nullptr) {
1265 // Reserve a reasonably big buffer to make sure that reading of models
1266 // with serialized constants is fast enough.
1267 const std::streamsize bufferSize = 1024 * 1024 * 16;
1268 std::vector<char> buffer(bufferSize);
1269 ff.rdbuf()->pubsetbuf(buffer.data(), bufferSize);
1270 std::stringstream ss;
1271 ss << ff.rdbuf();
1272 std::string str = ss.str();
1273 parseNet = parser.ParseFromString(str, &MP);
1274 } else {
1275 parseNet = parser.ParseFromString(*inputStringPtr, &MP);
1276 }
1277
1278 RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto",
1279 ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);
1280 return MP;
1281 }
1282
1283 if (inputStringPtr != nullptr) {
1284 std::istringstream iss(*inputStringPtr);
1285 google::protobuf::io::IstreamInputStream stringStream(&iss);
1286 return loadProto(stringStream);
1287 }
1288 google::protobuf::io::IstreamInputStream fileStream(&ff);
1289 return loadProto(fileStream);
1290}
1291
/// Compute the ceiling of \p val and return it converted to datatype \p T.
template <typename T> T ceil(float val) {
  const T truncated = (T)val;
  return (val > (float)truncated) ? (T)(truncated + 1) : truncated;
}
1296
namespace {
/// Helper type for pads: per spatial dimension, the begin-side pad values
/// followed by the end-side ones (e.g. {top, left, bottom, right} in 2D).
using Pads = std::vector<unsigned_t>;
} // namespace
1301
1302/// Get the Pads value based on setting for auto_pad.
1303/// \p kdim : kernel sizes (HW)
1304/// \p sdim: stride sizes (HW)
1305/// \p idim: input sizes (HW)
1306Expected<Pads> getPads(ArgumentDictionaryTy &dict,
1307 llvm::ArrayRef<unsigned_t> kdim,
1308 llvm::ArrayRef<unsigned_t> sdim,
1309 llvm::ArrayRef<unsigned_t> idim) {
1310 // TODO: ONNX spec disallows using "pads" and "auto_pad" together. However,
1311 // the implementation allows mixing them and onnxruntime gives pads priority.
1312 if (dict.count("pads")) {
1313 if (dict.at("pads")->ints_size() == 2) { // For maxPool1D
1314 return Pads({0, (unsigned_t)dict.at("pads")->ints(0), 0,
1315 (unsigned_t)dict.at("pads")->ints(1)});
1316 }
1317 return getShape<unsigned_t>(dict["pads"]);
1318 }
1319
1320 // Set the default zero pads. Number of pads is dependent on input dims
1321 auto zeroPads = [idim]() {
1322 if (idim.size() == 3) {
1323 return Pads({0, 0, 0, 0, 0, 0});
1324 } else {
1325 return Pads({0, 0, 0, 0});
1326 }
1327 };
1328
1329 if (dict.count("auto_pad")) {
1330 std::string padStr;
1331 ASSIGN_VALUE_OR_RETURN_ERR(padStr, loadStr(dict.at("auto_pad")));
1332 if (padStr == "VALID") {
1333 // Return default value 0 for pads.
1334 return zeroPads();
1335 } else if (padStr == "SAME_UPPER" || padStr == "SAME_LOWER") {
1336 unsigned_t near, far, top, left, bottom, right;
1337 // From https://arxiv.org/pdf/1603.07285.pdf 2.4,
1338 // o = floor((i + 2*p - k)/s) + 1
1339 // Also, from https://github.com/onnx/onnx/blob/master/docs/Operators.md
1340 // output_spatial_shape[i] =
1341 // ceil(input_spatial_shape[i] / strides_spatial_shape[i])
1342 // pad_shape[i] =
1343 // (output_spatial_shape[i] - 1) * strides_spatial_shape[i]
1344 // + kernel_spatial_shape[i] - input_spatial_shape[i]
1345 // Use the smallest padding possible out of the possible options.
1346 unsigned_t odim;
1347 if (idim.size() == 2) {
1348 llvm::SmallVector<unsigned_t, 2> pdim(2); // Total Paddding, HW.
1349 for (size_t i = 0, e = pdim.size(); i < e; i++) {
1350 odim = ceil<unsigned_t>((float)idim[i] / (float)sdim[i]);
1351 pdim[i] = sdim[i] * (odim - 1) + kdim[i] - idim[i];
1352 }
1353 if (padStr == "SAME_UPPER") {
1354 // SAME_UPPPER: if odd number for pdim[i], use extra padding at the
1355 // end.
1356 top = pdim[0] / 2;
1357 bottom = top + (pdim[0] & 0x1);
1358 left = pdim[1] / 2;
1359 right = left + (pdim[1] & 0x1);
1360 } else {
1361 // SAME_LOWER: if odd number for pdim[i], use extra padding at the
1362 // beginning.
1363 bottom = pdim[0] / 2;
1364 top = bottom + (pdim[0] & 0x1);
1365 right = pdim[1] / 2;
1366 left = right + (pdim[1] & 0x1);
1367 }
1368 return Pads({top, left, bottom, right});
1369 } else if (idim.size() == 3) {
1370 llvm::SmallVector<unsigned_t, 3> pdim(3); // Total Paddding, HW.
1371 for (size_t i = 0, e = pdim.size(); i < e; i++) {
1372 odim = ceil<unsigned_t>((float)idim[i] / (float)sdim[i]);
1373 pdim[i] = sdim[i] * (odim - 1) + kdim[i] - idim[i];
1374 }
1375 if (padStr == "SAME_UPPER") {
1376 // SAME_UPPPER: if odd number for pdim[i], use extra padding at the
1377 // end.
1378 near = pdim[0] / 2;
1379 far = near + (pdim[0] & 0x1);
1380 top = pdim[1] / 2;
1381 bottom = top + (pdim[1] & 0x1);
1382 left = pdim[2] / 2;
1383 right = left + (pdim[2] & 0x1);
1384 } else {
1385 // SAME_LOWER: if odd number for pdim[i], use extra padding at the
1386 // beginning.
1387 far = pdim[0] / 2;
1388 near = far + (pdim[0] & 0x1);
1389 bottom = pdim[1] / 2;
1390 top = bottom + (pdim[1] & 0x1);
1391 right = pdim[2] / 2;
1392 left = right + (pdim[2] & 0x1);
1393 }
1394 return Pads({near, top, left, far, bottom, right});
1395 } else {
1396 return MAKE_ERR("getPads only works for 2D or 3D");
1397 }
1398 } else if (padStr == "NOTSET") {
1399 // We use explicit pads (if not given we assume its all zeros).
1400 if (dict.count("pads")) {
1401 if (dict.at("pads")->ints_size() == 2) { // For maxPool1D
1402 return Pads({0, (unsigned_t)dict.at("pads")->ints(0), 0,
1403 (unsigned_t)dict.at("pads")->ints(1)});
1404 }
1405 return getShape<unsigned_t>(dict["pads"]);
1406 } else {
1407 return zeroPads();
1408 }
1409 }
1410 return MAKE_ERR("Only auto_pad==VALID, SAME_UPPER, SAME_LOWER and NOTSET "
1411 "are supported");
1412 }
1413 // Return default value 0 for pads.
1414 return zeroPads();
1415}
1416
1417/// Get the Pads value based on setting for auto_pad.
1418/// \p kdim : kernel sizes (HW)
1419/// \p sdim: stride sizes (HW)
1420/// \p idim: input sizes (HW)
1421static Expected<Pads> getConvTransposePadsfromOutput(
1422 ArgumentDictionaryTy &dict, llvm::ArrayRef<unsigned_t> kdim,
1423 llvm::ArrayRef<unsigned_t> sdim, llvm::ArrayRef<unsigned_t> dilation,
1424 llvm::ArrayRef<unsigned_t> idim, llvm::ArrayRef<unsigned_t> odim) {
1425
1426 llvm::SmallVector<unsigned_t, 2> pdim(2); // Total Paddding, HW.
1427 for (size_t i = 0, e = pdim.size(); i < e; i++) {
1428 pdim[i] = sdim[i] * (idim[i] - 1) /* + output_padding[0]*/ +
1429 ((kdim[i] - 1) * dilation[i] + 1) - odim[i];
1430 }
1431
1432 unsigned_t top, left, bottom, right;
1433
1434 if (dict.count("auto_pad")) {
1435 std::string padStr;
1436 ASSIGN_VALUE_OR_RETURN_ERR(padStr, loadStr(dict.at("auto_pad")));
1437 if (padStr == "SAME_UPPER") {
1438 // SAME_UPPER ONNX formula:
1439 // if odd number for pdim[i], use extra padding at the end.
1440 // pads[start_i] = total_padding[i] - (total_padding[i]/2);
1441 // pads[end_i] = (total_padding[i]/2).
1442 top = pdim[0] - pdim[0] / 2;
1443 bottom = pdim[0] / 2;
1444 left = pdim[1] - pdim[1] / 2;
1445 right = pdim[1] / 2;
1446 return Pads({top, left, bottom, right});
1447 }
1448 }
1449 // !SAME_UPPER ONNX formula:
1450 // pads[start_i] = total_padding[i]/2;
1451 // pads[end_i] = total_padding[i] - (total_padding[i]/2)
1452 top = pdim[0] / 2;
1453 bottom = pdim[0] - pdim[0] / 2;
1454 left = pdim[1] / 2;
1455 right = pdim[1] - pdim[1] / 2;
1456 return Pads({top, left, bottom, right});
1457}
1458
1459const std::string ONNXModelLoader::opErrMsg(const ONNX_NAMESPACE::NodeProto &op,
1460 const std::string &errMsg) {
1461 const std::string &opName = loadOperatorName(op);
1462 return strFormat(" [Operator-'%s', opset_version-%d, ir_version-%d] : %s ",
1463 opName.c_str(), int(opsetVersion_), int(irVersion_),
1464 errMsg.c_str());
1465}
1466
1467Error ONNXModelLoader::loadConstant(const ONNX_NAMESPACE::NodeProto &op,
1468 ArgumentDictionaryTy &dict) {
1469 /*
1470 output: "Parameter6"
1471 name: "Parameter6"
1472 op_type: "Constant"
1473 attribute {
1474 name: "value"
1475 t {
1476 dims: 8
1477 data_type: FLOAT
1478 float_data: -0.161539719
1479 float_data: -0.433835655
1480 float_data: 0.091641359
1481 float_data: -0.0168522168
1482 float_data: -0.0650264397
1483 float_data: -0.131737873
1484 float_data: 0.0204175506
1485 float_data: -0.121110231
1486 }
1487 type: TENSOR
1488 }
1489 doc_string: ""
1490 domain: ""
1491 */
1492
1493 const auto &name = op.output(0);
1494 // If the tensor is pre-populated by the user of this class then we don't
1495 // need to allocate a new tensor.
1496 if (getConstantByNameOrNull(name)) {
1497 return Error::success();
1498 }
1499
1500 const auto &type = dict.at("value")->type();
1501 RETURN_ERR_IF_NOT((type == ONNX_NAMESPACE::AttributeProto::TENSOR ||
1502 type == ONNX_NAMESPACE::AttributeProto::INTS),
1503 "Only Tensor type constants are supported.",
1504 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
1505
1506 Tensor T;
1507 if (type == ONNX_NAMESPACE::AttributeProto::TENSOR) {
1508 RETURN_IF_ERR(loadTensor(dict.at("value")->t(), &T, useGlowCustomOps_));
1509 } else {
1510 std::vector<int64_t> ints;
1511 ASSIGN_VALUE_OR_RETURN_ERR(ints, getShape<int64_t>(dict["value"]));
1512 T = Tensor(ElemKind::Int64ITy, {(dim_t)ints.size()});
1513 auto TH = T.getHandle<int64_t>();
1514 for (dim_t i = 0, e = ints.size(); i < e; ++i) {
1515 TH.at({i}) = ints[i];
1516 }
1517 }
1518 RETURN_IF_ERR(createAndRegisterConstant(name, std::move(T)));
1519
1520 return Error::success();
1521}
1522
1523/// Retrieves data from a constant Tensor and stores it in a vector.
1524template <typename T, typename datatype = ssize_t>
1525static void helperSetter(Constant *constT, std::vector<datatype> &vec) {
1526 auto constH = constT->getPayload().getHandle<T>();
1527 for (dim_t i = 0; i < constH.size(); ++i) {
1528 vec.push_back(constH.at({i}));
1529 }
1530}
1531
1532template <typename T>
1533Error ONNXModelLoader::getRange(const ONNX_NAMESPACE::NodeProto &op,
1534 Constant *constT) {
1535 T start = constT->getPayload().getHandle<T>().raw(0);
1536
1537 ASSIGN_VALUE_OR_RETURN_ERR(constT, getConstantByName(op.input(1)));
1538 T limit = constT->getPayload().getHandle<T>().raw(0);
1539
1540 ASSIGN_VALUE_OR_RETURN_ERR(constT, getConstantByName(op.input(2)));
1541 T delta = constT->getPayload().getHandle<T>().raw(0);
1542
1543 std::vector<T> rangeValues;
1544 if (limit > start) {
1545 RETURN_ERR_IF_NOT(delta > 0, "delta should be positive");
1546 auto i = start;
1547 while (i < limit) {
1548 rangeValues.push_back(i);
1549 i += delta;
1550 }
1551 } else if (limit < start) {
1552 RETURN_ERR_IF_NOT(delta < 0, "delta should be negative");
1553 auto i = start;
1554 while (i > limit) {
1555 rangeValues.push_back(i);
1556 i += delta;
1557 }
1558 } else {
1559 return MAKE_ERR("limit and start value should be different");
1560 }
1561
1562 Tensor rangeTensor(constT->getElementType(),
1563 {static_cast<unsigned int>(rangeValues.size())});
1564 rangeTensor.getHandle<T>() = rangeValues;
1565 RETURN_IF_ERR(
1566 createAndRegisterConstant(op.output(0), std::move(rangeTensor)));
1567 return Error::success();
1568}
1569
1570Error ONNXModelLoader::loadRange(const ONNX_NAMESPACE::NodeProto &op,
1571 ArgumentDictionaryTy &dict) {
1572 Constant *constT;
1573 ASSIGN_VALUE_OR_RETURN_ERR(constT, getConstantByName(op.input(0)));
1574 auto glowType = constT->getElementType();
1575 if (glowType == ElemKind::Int64ITy) {
1576 return getRange<int64_t>(op, constT);
1577 } else if (glowType == ElemKind::Int32ITy) {
1578 return getRange<int32_t>(op, constT);
1579 } else if (glowType == ElemKind::FloatTy) {
1580 return getRange<float>(op, constT);
1581 } else {
1582 return MAKE_ERR("Data type not supported");
1583 }
1584}
1585
1586Error ONNXModelLoader::loadPRelu(const ONNX_NAMESPACE::NodeProto &op,
1587 ArgumentDictionaryTy &dict) {
1588 const std::string &opName = loadOperatorName(op);
1589
1590 NodeValue in;
1591 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
1592
1593 NodeValue slope;
1594 ASSIGN_VALUE_OR_RETURN_ERR(slope, getNodeValueByName(op.input(1)));
1595
1596 // Do broadcasting.
1597 auto targetDim = in.dims();
1598 // Sets the axis of each inputs so that the trailing-most dimensions of
1599 // input tensors and the target shape are aligned.
1600 int axis = targetDim.size() - slope.dims().size();
1601 auto *finalSlope = G_->createBroadcast(opName, slope, targetDim, axis);
1602 auto *R = G_->createPRELU(opName, in, finalSlope);
1603 RETURN_IF_ERR(addNodeAsOutput(op, R));
1604 return Error::success();
1605}
1606
1607Error ONNXModelLoader::loadSlice(const ONNX_NAMESPACE::NodeProto &op,
1608 ArgumentDictionaryTy &dict) {
1609 const std::string &opName = loadOperatorName(op);
1610 NodeValue data;
1611 ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
1612 auto dims = data.dims();
1613 auto numDims = dims.size();
1614
1615 std::vector<ssize_t> starts;
1616 std::vector<ssize_t> ends;
1617 // Attribute 'axes' is optional.
1618 std::vector<ssize_t> axes;
1619 if (this->opsetVersion_ >= 10) {
1620 Constant *startsC = getConstantByNameOrNull(op.input(1));
1621 Constant *endsC = getConstantByNameOrNull(op.input(2));
1622
1623 RETURN_ERR_IF_NOT(startsC, opErrMsg(op, "Starts Tensor is not Constant."));
1624 RETURN_ERR_IF_NOT(endsC, opErrMsg(op, "Ends Tensor is not Constant."));
1625
1626 if (startsC->getElementType() == ElemKind::Int64ITy) {
1627 helperSetter<int64_t>(startsC, starts);
1628 } else if (startsC->getElementType() == ElemKind::Int32ITy) {
1629 helperSetter<int32_t>(startsC, starts);
1630 } else {
1631 RETURN_ERR_IF_NOT(
1632 false,
1633 opErrMsg(
1634 op,
1635 strFormat("Slice Starts Tensor has unsupported type '%s' ",
1636 startsC->getType()->getElementName().str().c_str())));
1637 }
1638
1639 if (endsC->getElementType() == ElemKind::Int64ITy) {
1640 helperSetter<int64_t>(endsC, ends);
1641 } else if (endsC->getElementType() == ElemKind::Int32ITy) {
1642 helperSetter<int32_t>(endsC, ends);
1643 } else {
1644 RETURN_ERR_IF_NOT(
1645 false,
1646 opErrMsg(
1647 op, strFormat("Slice Ends Tensor has unsupported type '%s' ",
1648 endsC->getType()->getElementName().str().c_str())));
1649 }
1650
1651 if (op.input_size() > 3) {
1652 Constant *axesC = getConstantByNameOrNull(op.input(3));
1653
1654 RETURN_ERR_IF_NOT(axesC, opErrMsg(op, "Axes Tensor is not Constant."));
1655
1656 if (axesC->getElementType() == ElemKind::Int64ITy) {
1657 helperSetter<int64_t>(axesC, axes);
1658 } else if (axesC->getElementType() == ElemKind::Int32ITy) {
1659 helperSetter<int32_t>(axesC, axes);
1660 } else {
1661 RETURN_ERR_IF_NOT(
1662 false,
1663 opErrMsg(
1664 op,
1665 strFormat("Slice Axes Tensor has unsupported type '%s' ",
1666 axesC->getType()->getElementName().str().c_str())));
1667 }
1668
1669 if (op.input_size() > 4) {
1670 std::vector<ssize_t> step;
1671 Constant *stepC = getConstantByNameOrNull(op.input(4));
1672
1673 RETURN_ERR_IF_NOT(stepC, opErrMsg(op, "Step tensor is not Constant."));
1674
1675 if (stepC->getElementType() == ElemKind::Int64ITy) {
1676 helperSetter<int64_t>(stepC, step);
1677 } else if (stepC->getElementType() == ElemKind::Int32ITy) {
1678 helperSetter<int32_t>(stepC, step);
1679 } else {
1680 RETURN_ERR_IF_NOT(
1681 false,
1682 opErrMsg(
1683 op,
1684 strFormat("Step Tensor has unsupported type '%s'",
1685 stepC->getType()->getElementName().str().c_str())));
1686 }
1687
1688 // Step is interpreted 1 as default.
1689 for (size_t i = 0; i < step.size(); i++) {
1690 RETURN_ERR_IF_NOT(step[i] == 1,
1691 opErrMsg(op, "step!=1 is currently not supported"));
1692 }
1693 }
1694 }
1695 } else {
1696 // Attributes 'starts' and 'ends' are mandatory and must be consistent.
1697 ASSIGN_VALUE_OR_RETURN_ERR(starts, getShape<ssize_t>(dict["starts"]));
1698 ASSIGN_VALUE_OR_RETURN_ERR(ends, getShape<ssize_t>(dict["ends"]));
1699
1700 if (dict.count("axes")) {
1701 // The ONNX spec is unclear so we consider that the 'axes' array may have
1702 // any size. The constraints are:
1703 // - the element value must be in range [0, numDims),
1704 // - 'starts' & 'ends' arrays must have the same size as the 'axes' array.
1705 // In case an axis is specified multiple times in 'axes', the later
1706 // parameters will simply overwrite the previous ones.
1707 ASSIGN_VALUE_OR_RETURN_ERR(
1708 axes, loadAxes<ssize_t>(dict["axes"], data.dims().size()));
1709 }
1710 }
1711 RETURN_ERR_IF_NOT(
1712 (starts.size() == ends.size()),
1713 opErrMsg(
1714 op,
1715 strFormat("Slice: 'starts' and 'ends' arrays must have the same size."
1716 " but found starts %zu and ends %zu sizes ",
1717 starts.size(), ends.size())));
1718
1719 if (axes.empty()) {
1720 for (size_t i = 0; i < numDims; i++) {
1721 axes.push_back(ssize_t(i));
1722 }
1723 }
1724
1725 // The ONNX description is unclear and doesn't describe what to do when a
1726 // an axis index is not given in the axes array. An interpretation is that
1727 // for such an axis, the entire range is taken. Then, we initialize
1728 // newStarts and newEnds with the full range for all axes.
1729 std::vector<dim_t> newStarts(numDims);
1730 std::vector<dim_t> newEnds(numDims);
1731 for (size_t i = 0; i < numDims; i++) {
1732 newStarts[i] = 0;
1733 newEnds[i] = dims[i];
1734 }
1735
1736 // Determine the coordinates of the sub-tensor to extract.
1737 RETURN_ERR_IF_NOT(axes.size() == starts.size(),
1738 opErrMsg(op, strFormat("'axes' %zu and 'starts' %zu must"
1739 "be the same size.",
1740 axes.size(), starts.size())));
1741 for (size_t i = 0; i < axes.size(); i++) {
1742 ssize_t newStart = starts[i];
1743 ssize_t newEnd = ends[i];
1744 ssize_t axisId = axes[i];
1745 RETURN_ERR_IF_NOT(
1746 (axisId >= 0) && (axisId < ssize_t(numDims)),
1747 opErrMsg(op, "Axes indexes must be within the input tensor range."));
1748
1749 // ONNX: "If the value passed to start or end is larger than the n (the
1750 // number of elements in this dimension), it represents n".
1751 if (newStart > ssize_t(dims[axisId])) {
1752 newStart = ssize_t(dims[axisId]);
1753 }
1754 if (newEnd > ssize_t(dims[axisId])) {
1755 newEnd = ssize_t(dims[axisId]);
1756 }
1757
1758 // The ONNX description is unclear and the numpy definition is more
1759 // accurate.
1760 // - ONNX: "Similar to numpy. [...]. If a negative value is passed for any
1761 // of the start or end indices, it represent number of elements before the
1762 // end of that dimension."
1763 // - Numpy: "Negative indices are interpreted as counting from the end of
1764 // the array (i.e., if n_i < 0, it means n_i + d_i)."
1765 if (newStart < 0) {
1766 newStart = ssize_t(dims[axisId]) + newStart;
1767 RETURN_ERR_IF_NOT(
1768 newStart >= 0,
1769 opErrMsg(op, strFormat("Slice: final start index %zu should "
1770 " never be negative.",
1771 newStart)));
1772 }
1773 if (newEnd < 0) {
1774 newEnd = ssize_t(dims[axisId]) + newEnd;
1775 RETURN_ERR_IF_NOT(
1776 newEnd >= 0,
1777 opErrMsg(op, strFormat("Slice: final end index %zu should "
1778 " never be negative.",
1779 newEnd)));
1780 }
1781
1782 newStarts[axisId] = size_t(newStart);
1783 newEnds[axisId] = size_t(newEnd);
1784 }
1785
1786 // Create the IR node.
1787 Node *SN = G_->createSlice(opName, data, newStarts, newEnds);
1788 RETURN_IF_ERR(addNodeAsOutput(op, SN));
1789
1790 return Error::success();
1791}
1792
1793Error ONNXModelLoader::loadTrigonometricOps(const std::string &typeName,
1794 const ONNX_NAMESPACE::NodeProto &op,
1795 ArgumentDictionaryTy &dict) {
1796 const std::string &opName = loadOperatorName(op);
1797 NodeValue in;
1798 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
1799 Node *N;
1800 if (typeName == "Sin") {
1801 N = G_->createSin(opName, in);
1802 } else {
1803 N = G_->createCos(opName, in);
1804 }
1805 RETURN_IF_ERR(addNodeAsOutput(op, N));
1806 return Error::success();
1807}
1808
1809Error ONNXModelLoader::loadErf(const ONNX_NAMESPACE::NodeProto &op,
1810 const ArgumentDictionaryTy &dict) {
1811 const std::string &opName = loadOperatorName(op);
1812 NodeValue in;
1813 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
1814 Node *N = G_->createErf(opName, in);
1815 RETURN_IF_ERR(addNodeAsOutput(op, N));
1816 return Error::success();
1817}
1818
1819Error ONNXModelLoader::loadConv(const ONNX_NAMESPACE::NodeProto &op,
1820 ArgumentDictionaryTy &dict) {
1821 // Load the inputs
1822 NodeValue in;
1823 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
1824
1825 if (in.dims().size() == 3) {
1826 return loadConv1D(op, dict);
1827 } else if (in.dims().size() == 4) {
1828 return loadConv2D(op, dict);
1829 } else if (in.dims().size() == 5) {
1830 return loadConv3D(op, dict);
1831 } else {
1832 return MAKE_ERR(strFormat(
1833 "Only 1D (3 dims), 2D (4 dims) and 3D (5 dims) convolution are "
1834 "supported by the ONNX loader but there are %zu input dims.",
1835 in.dims().size()));
1836 }
1837}
1838
/// Loads an ONNX 1D convolution ('Conv' with a 3-D NCW input) by expanding it
/// to an equivalent 2D convolution with a synthetic H axis of size 1, running
/// Glow's 2D Conv, then squeezing the result back to NCW.
Error ONNXModelLoader::loadConv1D(const ONNX_NAMESPACE::NodeProto &op,
                                  ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  // Load the attributes
  std::vector<glow::unsigned_t> strides(2, 1);

  // Conv1D has a single stride (along W); the synthetic H axis gets stride 1.
  strides[1] = dict.count("strides") ? dict.at("strides")->ints(0) : 1;
  strides[0] = 1;

  unsigned_t group = 1;
  if (dict.count("group")) {
    ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));
  }

  std::vector<unsigned_t> dilations;
  ASSIGN_VALUE_OR_RETURN_ERR(dilations,
                             getDilations(dict, std::vector<unsigned_t>{1, 1}));
  // Expand dilations to length 2 since Glow treat conv1D as conv2D
  if (dilations.size() == 1) {
    dilations.push_back(dilations[0]);
  }

  // Load the inputs
  NodeValue in;
  // input == NCW ---> NCHW
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  in = G_->createExpandDims(opName, in, 2);
  // filtervalue == CKS ---> CKRS
  NodeValue filterValue;
  ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
  filterValue = G_->createExpandDims(opName, filterValue, 2);
  // Transpose the filter to the right format. Glow expects to read the
  // weights in the format CRSK. ONNX stores the operators as CKRS.
  // C - output_depth, R - filter_height, S - filter_width, K - input_depth.
  // filtervalue == CKRS ---> CRSK
  TransposeNode *filterTransposeNode =
      G_->createTranspose(opName, filterValue, NCHW2NHWC);
  // The structure of the conv weights is: CRSK. We take the C, which is the
  // number of filters. We use this value to calculate the size of the bias
  // if it is not specified.
  const NodeValue filterTransposedValue = filterTransposeNode->getResult();
  dim_t depth = filterTransposedValue.dims()[0];

  // Construct the Bias field.
  NodeValue B;
  // Check if we have a serialized bias vector.
  if (op.input_size() > 2) {
    auto &biasTensorName = op.input(2);
    // Load the serialized bias vector as NodeValue.
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName));
  }

  // If a serialized bias wasn't found then create a zero bias.
  if (op.input_size() == 2) {
    auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth});
    Tensor biasTensor(biasTy);
    biasTensor.zero();
    B = mod_.createConstant("conv.bias", std::move(biasTensor));
  }

  // ONNX passes the input as NCHW, and we expect the input to be NHWC.
  auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);
  // Calculate the size and allocate the output buffer.
  ShapeNHWC idim = ShapeNHWC(tr->getResult().dims());
  llvm::SmallVector<unsigned_t, 2> idimHW(2);
  // After the expand-dims above 'in' is NCHW with H == 1, so this picks up
  // the synthetic height (1) and the original width.
  idimHW[0] = in.dims()[2];
  idimHW[1] = in.dims()[3];

  // Pads : {pad_top, pad_left, pad_bottom, pad_right}
  Pads pads;
  // Get the kernel shape.
  llvm::SmallVector<unsigned_t, 2> kernelShape(2);
  kernelShape[0] = filterTransposedValue.dims()[1];
  kernelShape[1] = filterTransposedValue.dims()[2];

  ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernelShape, strides, idimHW));
  auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernelShape, strides,
                                           pads, dilations);
  std::array<dim_t, 4> outDims = {{idim.n, outSz.first, outSz.second, depth}};
  auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims);
  auto *node = G_->createConv(opName, tr, filterTransposeNode, B, outTy,
                              kernelShape, strides, pads, group, dilations);

  // Squeeze out the synthetic H axis again.
  auto *N = G_->createSqueeze(opName, node, 1 /*axes*/);
  // Transpose the output back
  auto *RR = G_->createTranspose(opName, N, {0, 2, 1});
  RETURN_IF_ERR(addNodeAsOutput(op, RR));
  return Error::success();
}
1928
1929Error ONNXModelLoader::loadConv2D(const ONNX_NAMESPACE::NodeProto &op,
1930 ArgumentDictionaryTy &dict) {
1931 const std::string &opName = loadOperatorName(op);
1932
1933 // Load the inputs
1934 NodeValue in;
1935 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
1936
1937 NodeValue filterValue;
1938 ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
1939
1940 // Load the attributes
1941 std::vector<unsigned_t> strides(2, 1);
1942 if (dict.count("strides")) {
1943 ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
1944 }
1945 unsigned_t group = 1;
1946 if (dict.count("group")) {
1947 ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));
1948 }
1949
1950 std::vector<unsigned_t> dilations;
1951 ASSIGN_VALUE_OR_RETURN_ERR(dilations,
1952 getDilations(dict, std::vector<unsigned_t>{1, 1}));
1953 RETURN_ERR_IF_NOT(
1954 dilations.size() == 2,
1955 opErrMsg(op, strFormat("2D Conv dilations must be specified for 2 axes "
1956 " found axes %zu",
1957 dilations.size())));
1958
1959 // Transpose the filter to the right format. Glow expects to read the
1960 // weights in the format CRSK. ONNX stores the operators as KCRS.
1961 // C - output_depth, R - filter_height, S - filter_width, K - input_depth.
1962 TransposeNode *filterTransposeNode =
1963 G_->createTranspose(opName, filterValue, NCHW2NHWC);
1964
1965 // The structure of the conv weights is: CRSK. We take the C, which is the
1966 // number of filters. We use this value to calculate the size of the bias
1967 // if it is not specified.
1968 const NodeValue filterTransposedValue = filterTransposeNode->getResult();
1969 dim_t depth = filterTransposedValue.dims()[0];
1970
1971 // Get the kernel shape from the input.
1972 llvm::SmallVector<unsigned_t, 2> kernelShape(2);
1973 kernelShape[0] = filterTransposedValue.dims()[1];
1974 kernelShape[1] = filterTransposedValue.dims()[2];
1975
1976 // Extra check when the 'kernel_shape' attribute exists.
1977 // The 'kernel_shape' attribute is redundant not mandatory.
1978 if (dict.count("kernel_shape")) {
1979 std::vector<unsigned_t> kernelShapeAttribute;
1980 ASSIGN_VALUE_OR_RETURN_ERR(kernelShapeAttribute,
1981 getShape<unsigned_t>(dict["kernel_shape"]));
1982 RETURN_ERR_IF_NOT((kernelShape[0] == kernelShapeAttribute[0] &&
1983 kernelShape[1] == kernelShapeAttribute[1]),
1984 opErrMsg(op, "Conv The 'kernel_shape' attribute is not "
1985 "consistent with the actual "
1986 "convolution kernel shape."));
1987 (void)kernelShapeAttribute; // Avoids compilation warning in release mode.
1988 }
1989
1990 // Construct the Bias field.
1991 NodeValue B;
1992 // Check if we have a serialized bias vector.
1993 if (op.input_size() > 2) {
1994 auto &biasTensorName = op.input(2);
1995 // Load the serialized bias vector as NodeValue.
1996 ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName));
1997 }
1998
1999 // If a serialized bias wasn't found then create a zero bias.
2000 if (op.input_size() == 2) {
2001 auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth});
2002 Tensor biasTensor(biasTy);
2003 biasTensor.zero();
2004 B = mod_.createConstant("conv.bias", std::move(biasTensor));
2005 }
2006
2007 // ONNX passes the input as NCHW, and we expect the input to be NHWC.
2008 auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);
2009
2010 // Calculate the size and allocate the output buffer.
2011 ShapeNHWC idim = ShapeNHWC(tr->getResult().dims());
2012
2013 llvm::SmallVector<unsigned_t, 2> idimHW(2);
2014 idimHW[0] = in.dims()[2];
2015 idimHW[1] = in.dims()[3];
2016
2017 // Pads : {pad_top, pad_left, pad_bottom, pad_right}
2018 Pads pads;
2019 ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernelShape, strides, idimHW));
2020
2021 auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernelShape, strides,
2022 pads, dilations);
2023 std::array<dim_t, 4> outDims = {{idim.n, outSz.first, outSz.second, depth}};
2024 auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims);
2025
2026 auto *node = G_->createConv(opName, tr, filterTransposeNode, B, outTy,
2027 kernelShape, strides, pads, group, dilations);
2028
2029 // Transpose the output back.
2030 auto *N = G_->createTranspose(opName, node, NHWC2NCHW);
2031
2032 RETURN_IF_ERR(addNodeAsOutput(op, N));
2033
2034 return Error::success();
2035}
2036
/// Loads an ONNX 3D 'Conv' operator (NCTHW input). Only the default dilation
/// [1, 1, 1] is accepted since Glow's Conv3D node takes no dilation argument.
Error ONNXModelLoader::loadConv3D(const ONNX_NAMESPACE::NodeProto &op,
                                  ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  // Load the inputs
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  NodeValue filterValue;
  ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));

  // Load the attributes
  std::vector<unsigned_t> strides(3, 1);
  if (dict.count("strides")) {
    ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
  }
  unsigned_t group = 1;
  if (dict.count("group")) {
    ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));
  }

  std::vector<unsigned_t> dilations(3, 1);
  if (dict.count("dilations")) {
    ASSIGN_VALUE_OR_RETURN_ERR(
        dilations, getDilations(dict, std::vector<unsigned_t>{1, 1, 1}));
    RETURN_ERR_IF_NOT(
        dilations.size() == 3,
        opErrMsg(op, strFormat("3D Conv dilations must be specified for 3 axes "
                               "found %zu axes",
                               dilations.size())));
    // createConv3D below takes no dilation parameter, so anything other than
    // the default must be rejected.
    RETURN_ERR_IF_NOT(
        dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
        opErrMsg(op, strFormat("3D Conv dilations currently only support "
                               "the default value of [1, 1, 1] but the model "
                               "contains [%u, %u, %u]",
                               dilations[0], dilations[1], dilations[2])));
  }

  // Transpose the filter to the right format. Glow expects to read the
  // weights in the format CRSK. ONNX stores the operators as KCRS.
  // C - output_depth, R - filter_height, S - filter_width, K - input_depth.
  TransposeNode *filterTransposeNode =
      G_->createTranspose(opName, filterValue, NCTHW2NTHWC);

  // The structure of the conv weights is: CRSK. We take the C, which is the
  // number of filters. We use this value to calculate the size of the bias
  // if it is not specified.
  const NodeValue filterTransposedValue = filterTransposeNode->getResult();
  dim_t depth = filterTransposedValue.dims()[0];

  // Get the kernel shape from the input.
  llvm::SmallVector<unsigned_t, 3> kernelShape(3);
  kernelShape[0] = filterTransposedValue.dims()[1];
  kernelShape[1] = filterTransposedValue.dims()[2];
  kernelShape[2] = filterTransposedValue.dims()[3];

  // Extra check when the 'kernel_shape' attribute exists.
  // The 'kernel_shape' attribute is redundant not mandatory.
  if (dict.count("kernel_shape")) {
    std::vector<unsigned_t> kernelShapeAttribute;
    ASSIGN_VALUE_OR_RETURN_ERR(kernelShapeAttribute,
                               getShape<unsigned_t>(dict["kernel_shape"]));
    RETURN_ERR_IF_NOT(
        (kernelShape[0] == kernelShapeAttribute[0] &&
         kernelShape[1] == kernelShapeAttribute[1] &&
         kernelShape[2] == kernelShapeAttribute[2]),
        opErrMsg(
            op,
            strFormat(
                "The 'kernel_shape' attribute [%d, %d, %d] is not consistent "
                "with the actual convolution kernel shape [%d, %d, %d].",
                kernelShapeAttribute[0], kernelShapeAttribute[1],
                kernelShapeAttribute[2], kernelShape[0], kernelShape[1],
                kernelShape[2])));
    (void)kernelShapeAttribute; // Avoids compilation warning in release mode.
  }

  // Construct the Bias field.
  NodeValue B;
  // Check if we have a serialized bias vector.
  if (op.input_size() > 2) {
    auto &biasTensorName = op.input(2);
    // Load the serialized bias vector as NodeValue.
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName));
  }

  // If a serialized bias wasn't found then create a zero bias.
  if (op.input_size() == 2) {
    auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth});
    Tensor biasTensor(biasTy);
    biasTensor.zero();
    B = mod_.createConstant("conv.bias", std::move(biasTensor));
  }

  // ONNX passes the input as NCTHW, and we expect the input to be NTHWC.
  auto *tr = G_->createTranspose(opName, in, NCTHW2NTHWC);

  // Calculate the size and allocate the output buffer.
  ShapeNTHWC idim(tr->getResult().dims());

  llvm::SmallVector<unsigned_t, 3> idimTHW(3);
  idimTHW = {static_cast<glow::unsigned_t>(idim.t),
             static_cast<glow::unsigned_t>(idim.h),
             static_cast<glow::unsigned_t>(idim.w)};

  // Pads : {pad_near, pad_top, pad_left, pad_far, pad_bottom, pad_right}
  Pads tmpPads;
  ASSIGN_VALUE_OR_RETURN_ERR(tmpPads,
                             getPads(dict, kernelShape, strides, idimTHW));

  // Transpose padding from NTLFBR, which is the ONNX specified layout, to
  // NFTBLR, which is the glow internal layout per the interpreter
  // implementation.

  Pads pads = {tmpPads[0], tmpPads[3], tmpPads[1],
               tmpPads[4], tmpPads[2], tmpPads[5]};

  auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w,
                                             kernelShape, strides, pads);
  std::array<dim_t, 5> outDims = {
      {idim.n, outSz.temporal_frames, outSz.height, outSz.width, depth}};
  auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims);

  auto *node = G_->createConv3D(opName, tr, filterTransposeNode, B, outTy,
                                kernelShape, strides, pads, group);

  // Transpose the output back.
  auto *N = G_->createTranspose(opName, node, NTHWC2NCTHW);

  RETURN_IF_ERR(addNodeAsOutput(op, N));

  return Error::success();
}
2170
2171Error ONNXModelLoader::loadTensorwiseQuantizedConvolution(
2172 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
2173 const std::string &opName = loadOperatorName(op);
2174
2175 NodeValue input;
2176 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
2177 NodeValue filterValue;
2178 ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
2179 NodeValue biasValue;
2180 ASSIGN_VALUE_OR_RETURN_ERR(biasValue, getNodeValueByName(op.input(2)));
2181
2182 std::vector<unsigned_t> kernels;
2183 ASSIGN_VALUE_OR_RETURN_ERR(kernels,
2184 getShape<unsigned_t>(dict["kernel_shape"]));
2185 std::vector<unsigned_t> strides;
2186 ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
2187 std::vector<unsigned_t> pads;
2188 ASSIGN_VALUE_OR_RETURN_ERR(pads, getShape<unsigned_t>(dict["pads"]));
2189
2190 unsigned_t groups;
2191 ASSIGN_VALUE_OR_RETURN_ERR(groups, loadInt(dict.at("group")));
2192
2193 float outScale;
2194 ASSIGN_VALUE_OR_RETURN_ERR(outScale, loadFloat(dict.at("out_scale")));
2195 int32_t outOffset;
2196 ASSIGN_VALUE_OR_RETURN_ERR(outOffset, loadInt(dict.at("out_offset")));
2197
2198 ShapeNHWC idim(input.dims());
2199 auto outSz =
2200 calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
2201 std::array<dim_t, 4> outDims = {
2202 {idim.n, outSz.first, outSz.second, biasValue.dims()[0]}};
2203 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, outDims, outScale, outOffset);
2204
2205 auto *node = G_->createConv(opName, input, filterValue, biasValue, outTy,
2206 kernels, strides, pads, groups);
2207
2208 return addNodeAsOutput(op, node);
2209}
2210
2211Error ONNXModelLoader::loadChannelwiseQuantizedConvolution(
2212 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
2213 const std::string &opName = loadOperatorName(op);
2214
2215 NodeValue input;
2216 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
2217 NodeValue filterValue;
2218 ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
2219 NodeValue biasValue;
2220 ASSIGN_VALUE_OR_RETURN_ERR(biasValue, getNodeValueByName(op.input(2)));
2221 NodeValue scalesValue;
2222 ASSIGN_VALUE_OR_RETURN_ERR(scalesValue, getNodeValueByName(op.input(3)));
2223 NodeValue offsetsValue;
2224 ASSIGN_VALUE_OR_RETURN_ERR(offsetsValue, getNodeValueByName(op.input(4)));
2225
2226 std::vector<unsigned_t> kernels;
2227 ASSIGN_VALUE_OR_RETURN_ERR(kernels,
2228 getShape<unsigned_t>(dict["kernel_shape"]));
2229 std::vector<unsigned_t> strides;
2230 ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
2231 std::vector<unsigned_t> pads;
2232 ASSIGN_VALUE_OR_RETURN_ERR(pads, getShape<unsigned_t>(dict["pads"]));
2233
2234 unsigned_t group;
2235 ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));
2236
2237 std::vector<unsigned_t> dilations;
2238 ASSIGN_VALUE_OR_RETURN_ERR(dilations,
2239 getDilations(dict, std::vector<unsigned_t>{1, 1}));
2240
2241 float outScale;
2242 ASSIGN_VALUE_OR_RETURN_ERR(outScale, loadFloat(dict.at("out_scale")));
2243 int32_t outOffset;
2244 ASSIGN_VALUE_OR_RETURN_ERR(outOffset, loadInt(dict.at("out_offset")));
2245
2246 ShapeNHWC idim(input.dims());
2247 auto outSz =
2248 calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
2249 std::array<dim_t, 4> outDims = {
2250 {idim.n, outSz.first, outSz.second, biasValue.dims()[0]}};
2251 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, outDims, outScale, outOffset);
2252
2253 // Quantize the filter automatically (only if it is float). The bias is NOT
2254 // quantized automatically and is left at the disposal of each Backend to
2255 // quantize it later using custom logic.
2256 auto *node = G_->createChannelwiseQuantizedConv(
2257 opName, input, filterValue, biasValue, scalesValue, offsetsValue,
2258 /* biasScales */ nullptr, /* biasOffsets */ nullptr, outTy, kernels,
2259 strides, pads, group, dilations, /* quantizeFilter */ true,
2260 /* quantizeBias */ false);
2261
2262 return addNodeAsOutput(op, node);
2263}
2264
2265Error ONNXModelLoader::loadConvTranspose(const ONNX_NAMESPACE::NodeProto &op,
2266 ArgumentDictionaryTy &dict) {
2267 const std::string &opName = loadOperatorName(op);
2268 // Load the attributes
2269 std::vector<unsigned_t> strides(2, 1);
2270 if (dict.count("strides")) {
2271 ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
2272 }
2273 unsigned_t group = 1;
2274 if (dict.count("group")) {
2275 ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));
2276 }
2277
2278 std::vector<unsigned_t> dilations;
2279 ASSIGN_VALUE_OR_RETURN_ERR(dilations,
2280 getDilations(dict, std::vector<unsigned_t>{1, 1}));
2281 RETURN_ERR_IF_NOT(dilations.size() == 2,
2282 opErrMsg(op, strFormat("ConvTranspose dilations must be "
2283 "specified for 2 axes, found %zu ",
2284 dilations.size())));
2285
2286 // Load the inputs
2287 NodeValue in;
2288 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2289 NodeValue filterValue;
2290 ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
2291
2292 // Transpose the filter to the right format. Glow expects to read the
2293 // weights in the format CRSK. ONNX stores the operators as KCRS.
2294 // C - output_depth, R - filter_height, S - filter_width, K - input_depth.
2295 TransposeNode *filterTransposeNode =
2296 G_->createTranspose(opName, filterValue, CNHW2NHWC /* flip matrix */);
2297
2298 // The structure of the conv weigts is: NHWC. We take the C, which is the
2299 // number of filters. We use this value to calculate the size of the bias
2300 // if it is not specified.
2301 const NodeValue filterTransposedValue = filterTransposeNode->getResult();
2302 dim_t depth = filterTransposedValue.dims()[0] * group;
2303
2304 // Get the kernel shape from the input.
2305 llvm::SmallVector<unsigned_t, 2> kernels(2);
2306 kernels[0] = filterTransposedValue.dims()[1];
2307 kernels[1] = filterTransposedValue.dims()[2];
2308
2309 // Extra check when the 'kernel_shape' attribute exists.
2310 // The 'kernel_shape' attribute is redundant not mandatory.
2311 if (dict.count("kernel_shape")) {
2312 std::vector<unsigned_t> kernelShapeAttribute;
2313 ASSIGN_VALUE_OR_RETURN_ERR(kernelShapeAttribute,
2314 getShape<unsigned_t>(dict["kernel_shape"]));
2315 RETURN_ERR_IF_NOT(
2316 (kernels[0] == kernelShapeAttribute[0] &&
2317 kernels[1] == kernelShapeAttribute[1]),
2318 opErrMsg(
2319 op,
2320 "The 'kernel_shape' attribute is not consistent with the actual "
2321 "convolution kernel shape."));
2322 (void)kernelShapeAttribute; // Avoids compilation warning in release mode.
2323 }
2324
2325 // Construct the Bias field.
2326 NodeValue B;
2327 // Check if we have a serialized bias vector.
2328 if (op.input_size() > 2) {
2329 auto &biasTensorName = op.input(2);
2330 // Load the serialized bias vector as constant or NodeValue.
2331 ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName));
2332 }
2333
2334 // If a serialized bias wasn't found then create a zero bias.
2335 if (op.input_size() == 2) {
2336 auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth});
2337 Tensor biasTensor(biasTy);
2338 biasTensor.zero();
2339 B = mod_.createConstant("conv.bias", std::move(biasTensor));
2340 }
2341
2342 // ONNX passes the input as NCHW, and we expect the input to be NHWC.
2343 auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);
2344
2345 // Calculate the size and allocate the output buffer.
2346 ShapeNHWC idim = ShapeNHWC(tr->getResult().dims());
2347
2348 llvm::SmallVector<unsigned_t, 2> idimHW(2);
2349 idimHW[0] = in.dims()[2];
2350 idimHW[1] = in.dims()[3];
2351
2352 // Pads : {pad_top, pad_left, pad_bottom, pad_right}
2353 Pads pads;
2354
2355 // Conv transpose output size (HxW) is either specified or calculated.
2356 std::pair<dim_t, dim_t> outSz;
2357
2358 // Per spec, if output_shape is specified, pads are ignored.
2359 if (dict.count("output_shape")) {
2360 std::vector<unsigned_t> outShape;
2361 ASSIGN_VALUE_OR_RETURN_ERR(outShape,
2362 getShape<unsigned_t>(dict["output_shape"]));
2363 ASSIGN_VALUE_OR_RETURN_ERR(
2364 pads, getConvTransposePadsfromOutput(dict, kernels, strides, dilations,
2365 idimHW, outShape));
2366 outSz = {outShape[0], outShape[1]};
2367
2368 std::pair<dim_t, dim_t> outSzTest = calculateConvTransposeOutputDims(
2369 idim.h, idim.w, kernels, strides, pads, dilations);
2370 RETURN_ERR_IF_NOT(
2371 (outShape[0] == outSzTest.first),
2372 opErrMsg(op, strFormat("ConvTranspose Expected %d /calculated %d "
2373 "pads don't match ",
2374 int(outShape[0]), int(outSzTest.first))));
2375 RETURN_ERR_IF_NOT(
2376 (outShape[1] == outSzTest.second),
2377 opErrMsg(op, strFormat("ConvTranspose Expected %d /calculated %d "
2378 "pads don't match ",
2379 int(outShape[1]), int(outSzTest.second))));
2380 } else {
2381 if (dict.count("output_padding")) {
2382 std::vector<dim_t> outPad;
2383 ASSIGN_VALUE_OR_RETURN_ERR(outPad,
2384 getShape<dim_t>(dict["output_padding"]));
2385 if (std::equal(outPad.begin() + 1, outPad.end(), outPad.begin()) &&
2386 outPad[0] != 0) {
2387 LOG(FATAL)
2388 << "ConvTranspose argument 'output_padding' is not supported.";
2389 }
2390 }
2391 ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernels, strides, idimHW));
2392 outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels, strides,
2393 pads, dilations);
2394 }
2395 std::array<dim_t, 4> outDims = {{idim.n, outSz.first, outSz.second, depth}};
2396 auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims);
2397
2398 auto *node =
2399 G_->createConvTranspose(opName, tr, filterTransposeNode, B, outTy,
2400 kernels, strides, pads, group, dilations);
2401
2402 // Transpose the output back.
2403 auto *N = G_->createTranspose(opName, node, NHWC2NCHW);
2404 RETURN_IF_ERR(addNodeAsOutput(op, N));
2405
2406 return Error::success();
2407}
2408
/// Loads an ONNX MaxPool/AveragePool operator. 1D inputs (NCW) are expanded
/// to 2D with a synthetic H axis of size 1 since Glow only implements 2D
/// pooling, and the result is squeezed back. MaxPool additionally supports
/// the optional second (Argmax) output.
Error ONNXModelLoader::loadPool(const ONNX_NAMESPACE::NodeProto &op,
                                ArgumentDictionaryTy &dict,
                                llvm::StringRef typeName) {
  const std::string &opName = loadOperatorName(op);

  // Load the inputs:
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  std::vector<unsigned_t> strides(2, 1);

  size_t inDim = in.dims().size();

  std::vector<unsigned_t> kernelsShape;
  ASSIGN_VALUE_OR_RETURN_ERR(kernelsShape,
                             getShape<unsigned_t>(dict["kernel_shape"]));

  size_t kerDim = kernelsShape.size();

  // Synthesizes {1, W} for a 1D kernel; for 2D kernels, kernels[0] is
  // overwritten with the real height below.
  // NOTE(review): assumes 'kernel_shape' is non-empty — an empty attribute
  // would index out of bounds here; verify upstream validation.
  std::vector<unsigned_t> kernels = {1, kernelsShape[kerDim - 1]};

  bool countIncludePads;
  ASSIGN_VALUE_OR_RETURN_ERR(
      countIncludePads, getCountIncludePads(dict, /* defaultValue */ false));

  // For maxPool1D inDim = 3
  if (inDim == 3) {
    // Insert the synthetic H axis: NCW -> NCHW with H == 1.
    in = G_->createExpandDims(opName, in, 2);
    if (kerDim != 1) {
      return MAKE_ERR(
          opErrMsg(op, strFormat("Glow handles 1D pooling with kernel "
                                 "dimenstion size 1, but found %d ",
                                 int(kerDim))),
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
    } else {
      if (dict.count("strides")) {
        // The single ONNX stride applies to W; the synthetic H axis keeps
        // stride 1.
        strides[1] = dict.at("strides")->ints(0);
        strides[0] = 1;
      }
    }
  }

  if (kerDim == 2) { // For maxPool2D
    kernels[0] = kernelsShape[0];
    if (dict.count("strides")) {
      ASSIGN_VALUE_OR_RETURN_ERR(strides,
                                 getShape<unsigned_t>(dict["strides"]));
    }
  }

  if (in.dims().size() != 4 || kernels.size() != 2) {
    // Glow only handles 2D pooling currently.
    return MAKE_ERR(opErrMsg(op, "Glow only handles 2D pooling currently."),
                    ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
  }

  // ONNX input is NCHW; Glow pooling expects NHWC.
  auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);

  // If 'global_pooling' is set then the operation will pool over the size of
  // the input by doing: kernel = height/width.
  if (dict.count("global_pooling")) {
    auto Ty = in.getType();
    kernels[0] = Ty->dims()[2];
    kernels[1] = Ty->dims()[3];
  }

  // NHWC
  llvm::SmallVector<unsigned_t, 2> idimHW(2);
  idimHW[0] = in.dims()[2]; // As per NCHW format
  idimHW[1] = in.dims()[3];

  Pads pads;
  ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernels, strides, idimHW));

  Node *node = nullptr;
  if (op.output_size() > 1) {
    // A second output (Argmax indices) only exists for MaxPool.
    if (typeName != "MaxPool") {
      return MAKE_ERR(
          opErrMsg(op, "Pool Argmax output is only supported for MaxPool!"),
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
    }

    node = G_->createMaxPool(opName, tr, kernels, strides, pads);
    auto *res = G_->createTranspose(opName, NodeValue(node, 0), NHWC2NCHW);
    auto *argmax = G_->createTranspose(opName, NodeValue(node, 1), NHWC2NCHW);
    RETURN_IF_ERR(assignNodeOutputs(op, {res, argmax}));
  } else {
    size_t idx = 0;
    if (typeName == "MaxPool") {
      node = G_->createMaxPool(opName, tr, kernels, strides, pads);
      idx = MaxPoolNode::ResultIdx;
    } else {
      node = G_->createAvgPool(opName, tr, kernels, strides, pads, NHWC,
                               countIncludePads);
      idx = AvgPoolNode::ResultIdx;
    }

    Node *N = nullptr;
    if (inDim == 3) { // For maxPool1D
      // Drop the synthetic H axis and restore the NCW layout.
      auto *R = G_->createSqueeze(opName, NodeValue(node, idx), 1);
      N = G_->createTranspose(opName, R, {0, 2, 1});
    } else {
      N = G_->createTranspose(opName, NodeValue(node, idx), NHWC2NCHW);
    }

    RETURN_IF_ERR(addNodeAsOutput(op, N));
  }
  return Error::success();
}
2518
2519Error ONNXModelLoader::loadTensorwiseQuantizedPool(
2520 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict,
2521 llvm::StringRef typeName) {
2522 const std::string &opName = loadOperatorName(op);
2523
2524 // Load the inputs:
2525 NodeValue in;
2526 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2527
2528 std::vector<unsigned_t> kernels;
2529 ASSIGN_VALUE_OR_RETURN_ERR(kernels,
2530 getShape<unsigned_t>(dict["kernel_shape"]));
2531 std::vector<unsigned_t> strides;
2532 ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
2533
2534 if (in.dims().size() != 4 || kernels.size() != 2) {
2535 // Glow only handles 2D pooling currently.
2536 return MAKE_ERR(
2537 opErrMsg(op, strFormat("TensorwiseQuantizedPool Glow only handles 2D "
2538 "pooling currently, but found kernel %zu ",
2539 kernels.size())),
2540 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
2541 }
2542
2543 bool countIncludePads;
2544 ASSIGN_VALUE_OR_RETURN_ERR(
2545 countIncludePads, getCountIncludePads(dict, /* defaultValue */ false));
2546
2547 // NHWC
2548 llvm::SmallVector<unsigned_t, 2> idimHW(2);
2549 idimHW[0] = in.dims()[1];
2550 idimHW[1] = in.dims()[2];
2551
2552 Pads pads;
2553 ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernels, strides, idimHW));
2554
2555 if (op.output_size() > 1) {
2556 if (typeName != "MaxPool") {
2557 return MAKE_ERR(opErrMsg(op,
2558 "TensorwiseQuantizedPool Argmax output is only "
2559 "supported for MaxPool!"),
2560 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
2561 }
2562
2563 Node *maxpool = G_->createMaxPool(opName, in, kernels, strides, pads);
2564 auto res = maxpool->getNthResult(MaxPoolNode::ResultIdx);
2565 auto argmax = maxpool->getNthResult(MaxPoolNode::ArgmaxIdx);
2566 RETURN_IF_ERR(assignNodeOutputs(op, {res, argmax}));
2567 } else {
2568 Node *poolNode;
2569 if (typeName == "MaxPool") {
2570 poolNode = G_->createMaxPool(opName, in, kernels, strides, pads);
2571 } else {
2572 poolNode = G_->createAvgPool(opName, in, kernels, strides, pads, NHWC,
2573 countIncludePads);
2574 }
2575 RETURN_IF_ERR(addNodeAsOutput(op, poolNode));
2576 }
2577 return Error::success();
2578}
2579
2580Error ONNXModelLoader::loadArgMinMax(const ONNX_NAMESPACE::NodeProto &op,
2581 ArgumentDictionaryTy &dict, bool isMin) {
2582 const std::string &opName = loadOperatorName(op);
2583
2584 NodeValue in;
2585 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2586 size_t axis = 0;
2587 if (dict.count("axis")) {
2588 ASSIGN_VALUE_OR_RETURN_ERR(
2589 axis, loadAxis<size_t>(dict.at("axis"), in.dims().size()));
2590 }
2591 bool keepDims = true;
2592 if (dict.count("keepdims")) {
2593 ASSIGN_VALUE_OR_RETURN_ERR(keepDims, loadInt(dict.at("keepdims")));
2594 }
2595 Node *node;
2596 if (isMin) {
2597 node = G_->createArgMin(opName, in, axis, keepDims);
2598 } else {
2599 node = G_->createArgMax(opName, in, axis, keepDims);
2600 }
2601 RETURN_IF_ERR(addNodeAsOutput(op, node));
2602 return Error::success();
2603}
2604
2605Error ONNXModelLoader::loadUpsample(const ONNX_NAMESPACE::NodeProto &op,
2606 ArgumentDictionaryTy &dict) {
2607
2608 RETURN_ERR_IF_NOT(
2609 (opsetVersion_ < 10) && (opsetVersion_ > 6),
2610 opErrMsg(op, "Version mismatch issue found, Upsample operator is "
2611 "supported for opset_version between 7 and 9"
2612 "use resize operator if opset_version > 9"));
2613
2614 const std::string &opName = loadOperatorName(op);
2615 NodeValue in;
2616 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2617
2618 // Default mode of upsample operator is "nearest"
2619 std::string mode("nearest");
2620 if (dict.count("mode")) {
2621 ASSIGN_VALUE_OR_RETURN_ERR(mode, loadStr(dict.at("mode")));
2622 }
2623
2624 /// Only Nearest Mode is supported
2625 RETURN_ERR_IF_NOT(mode.compare("nearest") == 0,
2626 opErrMsg(op, strFormat("UpSample Operator has nearest mode "
2627 "support only, found mode '%s' ",
2628 mode.c_str())));
2629
2630 /// Scale is always float as per onnx documentation
2631 std::vector<float> scales;
2632
2633 if (opsetVersion_ == 7) {
2634 if (dict.count("scales")) {
2635 /// As per onnx documentation this is a required field
2636 /// and if not present then onnx.checker.check_model file check to fail
2637 ASSIGN_VALUE_OR_RETURN_ERR(scales, getFloats(dict["scales"]));
2638 } else {
2639 return MAKE_ERR(opErrMsg(op,
2640 "UpSample Scales field is not present, expected "
2641 "for Upsample opset_version==7"));
2642 }
2643 }
2644
2645 if (opsetVersion_ > 7) {
2646 Constant *scale;
2647 ASSIGN_VALUE_OR_RETURN_ERR(scale, getConstantByName(op.input(1)));
2648 if (scale->getElementType() != ElemKind::FloatTy) {
2649 return MAKE_ERR(opErrMsg(op,
2650 "UpSample Scales Tensor should have float type "
2651 "for opset_version > 7"));
2652 }
2653 auto constH = scale->getPayload().getHandle<float>();
2654 for (dim_t i = 0; i < constH.size(); ++i) {
2655 scales.push_back(constH.at({i}));
2656 }
2657 }
2658
2659 /// Scales tensor format is NHWC for supported modes other than nearest.
2660 /// For nearest mode scale can be 4D or 5D.
2661 RETURN_ERR_IF_NOT(
2662 (scales.size() >= 4 && scales.size() <= 5 && mode == "nearest") ||
2663 scales.size() == 4,
2664 opErrMsg(
2665 op, strFormat(
2666 "UpSample Scales dimension invalid. Mode: %s Scale Size: %zu",
2667 mode.c_str(), scales.size())));
2668
2669 for (auto &val : scales) {
2670 RETURN_ERR_IF_NOT(
2671 val >= 1,
2672 opErrMsg(op, strFormat("UpSample Scales value can only be "
2673 " greater than or equal to 1, but found %d",
2674 int(val))));
2675 }
2676
2677 switch (scales.size()) {
2678 case 4: {
2679 vectorReorder(scales, {NHWC2NCHW});
2680 auto *intr = G_->createTranspose(opName, in, NCHW2NHWC);
2681 auto *node = G_->createResizeNearest(opName, intr, scales);
2682 auto *N = G_->createTranspose(opName, node, NHWC2NCHW);
2683 RETURN_IF_ERR(addNodeAsOutput(op, N));
2684 return Error::success();
2685 }
2686 case 5: {
2687 vectorReorder(scales, {NTHWC2NCTHW});
2688
2689 auto *intr = G_->createTranspose(opName, in, NCTHW2NTHWC);
2690 auto *node = G_->createResizeNearest(opName, intr, scales);
2691 auto *N = G_->createTranspose(opName, node, NTHWC2NCTHW);
2692 RETURN_IF_ERR(addNodeAsOutput(op, N));
2693 return Error::success();
2694 }
2695 default:
2696 RETURN_ERR_IF_NOT(
2697 false, opErrMsg(op, strFormat("UpSample Scales dimension invalid")));
2698 }
2699}
2700
/// Loads the ONNX Resize operator. v10 takes resize factors from the
/// 'scales' input; v11+ may instead leave 'scales' empty and supply
/// explicit output dimensions through the 'sizes' input. Only the subset of
/// v11 attributes that matches v10 behavior is accepted. Both 'scales' and
/// 'sizes' must be compile-time Constants.
Error ONNXModelLoader::loadResize(const ONNX_NAMESPACE::NodeProto &op,
                                  const ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  std::string modeStr;
  ASSIGN_VALUE_OR_RETURN_ERR(modeStr, loadStr(dict.at("mode")));

  Constant *scalesC = nullptr;

  // Either scales or outDims will be populated (V11 can do either, V10 scales
  // only).
  std::vector<float> scales;
  std::vector<dim_t> outDims;

  // In v11+ input(1) is the 'roi' tensor, which moves 'scales' to input(2).
  int32_t scalesIdx = (this->opsetVersion_ >= 11) ? 2 : 1;
  scalesC = getConstantByNameOrNull(op.input(scalesIdx));
  RETURN_ERR_IF_NOT(
      scalesC,
      opErrMsg(op, strFormat("Resize Scales Tensor '%s' is not Constant.",
                             op.input(scalesIdx).c_str())));
  if (scalesC->getElementType() != ElemKind::FloatTy) {
    return MAKE_ERR(opErrMsg(
        op, strFormat(
                "Resize Scales Tensor should have float type, but found '%s' ",
                scalesC->getType()->getElementName().str().c_str())));
  }

  // For ONNX Resize v11, support attributes that are compatible with v10:
  //   exclude_outside = 0
  //   extrapolation_value = 0.0
  //   nearest_mode = floor
  //   coordinate_transformation_mode = asymmetric
  //   mode = nearest, (bi)linear
  if (this->opsetVersion_ >= 11) {
    int32_t excludeOutside = 0;
    // attribute: exclude_outside.
    if (dict.count("exclude_outside")) {
      ASSIGN_VALUE_OR_RETURN_ERR(excludeOutside,
                                 loadInt(dict.at("exclude_outside")));
    }
    RETURN_ERR_IF_NOT(excludeOutside == 0,
                      opErrMsg(op, strFormat("ONNX Resize exclude outside "
                                             " not supported.")));
    // attribute: extrapolation_value.
    float extrapolationValue = 0.0;
    if (dict.count("extrapolation_value")) {
      ASSIGN_VALUE_OR_RETURN_ERR(extrapolationValue,
                                 loadFloat(dict.at("extrapolation_value")));
    }
    RETURN_ERR_IF_NOT(
        extrapolationValue == 0.0,
        opErrMsg(op, strFormat("Resize extrapolation value 0 supported only, "
                               "but found value %f",
                               extrapolationValue)));
    // attribute: nearest_mode.
    std::string nearestMode = "round_prefer_floor";
    if (dict.count("nearest_mode")) {
      ASSIGN_VALUE_OR_RETURN_ERR(nearestMode, loadStr(dict.at("nearest_mode")));
    }
    // Only "floor" rounding is supported when mode is "nearest".
    if (modeStr == "nearest" && nearestMode != "floor") {
      return MAKE_ERR(
          opErrMsg(op, strFormat("Resize 'floor' and 'nearest' mode "
                                 "supported only, but found mode '%s' ",
                                 modeStr.c_str())));
    }
    // attribute: coordinate_transformation_mode.
    std::string coordTransformMode = "half_pixel";
    if (dict.count("coordinate_transformation_mode")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          coordTransformMode,
          loadStr(dict.at("coordinate_transformation_mode")));
    }
    RETURN_ERR_IF_NOT(
        coordTransformMode == "asymmetric",
        opErrMsg(op, strFormat("Resize 'asymmetric' coordinate transformation "
                               "mode supported only, but found %s",
                               coordTransformMode.c_str())));

    // If no scales tensor, sizes tensor should be valid.
    if (scalesC->getPayload().getHandle().size() == 0) {
      Constant *sizesC;
      ASSIGN_VALUE_OR_RETURN_ERR(sizesC, getConstantByName(op.input(3)));
      RETURN_ERR_IF_NOT(sizesC,
                        opErrMsg(op, strFormat("Resize Sizes Tensor '%s'"
                                               " is not Constant.",
                                               op.input(3).c_str())));

      // Must be 1D tensor of int64_t.
      RETURN_ERR_IF_NOT(
          sizesC->dims().size() == 1,
          opErrMsg(op, strFormat("Resize Input must be a 1D vector."
                                 " but found vector size %zu ",
                                 sizesC->dims().size())));
      RETURN_ERR_IF_NOT(
          sizesC->getType()->getElementType() == ElemKind::Int64ITy,
          opErrMsg(op, strFormat(
                           "Resize Input element type must be Int64ITy, but "
                           "found type '%s' ",
                           sizesC->getType()->getElementName().str().c_str())));

      auto sizesH = sizesC->getPayload().getHandle<int64_t>();
      // One explicit output size per input dimension.
      RETURN_ERR_IF_NOT(
          in.dims().size() == sizesH.size(),
          opErrMsg(
              op,
              strFormat("Data input %s and sizes input %s must match in size.",
                        std::to_string(in.dims().size()).c_str(),
                        std::to_string(sizesH.size()).c_str())));
      // Now fill the output tensor dimensions.
      for (dim_t i = 0; i < sizesH.size(); ++i) {
        outDims.push_back(sizesH.at({i}));
      }
    } else {
      // 'scales' is populated, so a 'sizes' input must not also be given.
      RETURN_ERR_IF_NOT(
          op.input_size() == 3,
          opErrMsg(op, "Resize 'sizes' not valid with 'scales' input"));
    }
  } // v11 processing.

  NodeValue outtr = nullptr;
  auto scalesH = scalesC->getPayload().getHandle();

  // Check if scales is non-empty - if so, use it.
  if (scalesH.size()) {
    for (dim_t i = 0; i < scalesH.size(); ++i) {
      scales.push_back(scalesH.at({i}));
    }

    // Scales tensor format is NHWC for supported modes other than nearest.
    // For nearest mode scale can be 3D, 4D, 5D, or 6D.
    RETURN_ERR_IF_NOT(
        (scales.size() >= 3 && scales.size() <= 6 && modeStr == "nearest") ||
            scales.size() == 4,
        opErrMsg(
            op, strFormat(
                    "Resize Scales dimension invalid. Mode: %s Scale Size: %zu",
                    modeStr.c_str(), scales.size())));

    for (auto &val : scales) {
      RETURN_ERR_IF_NOT(
          val > 0,
          opErrMsg(
              op,
              strFormat(
                  "Resize Scale value must be greater than zero, but found %d",
                  int(val))));
    }

    if (modeStr == "nearest") {
      outtr = G_->createResizeNearest(opName, in, scales);
    } else if (modeStr == "bilinear" || modeStr == "linear") {
      // Bilinear resize runs in NHWC: transpose in, resize, transpose back.
      vectorReorder(scales, {NHWC2NCHW});
      auto *intr = G_->createTranspose(opName, in, NCHW2NHWC);
      auto RN = G_->createResizeBilinear(opName, intr, scales);
      outtr = G_->createTranspose(opName, RN, NHWC2NCHW);
    } else {
      return MAKE_ERR(
          opErrMsg(op, strFormat("Resize Supports nearest or bilinear "
                                 "interpolation only, but found mode as '%s' ",
                                 modeStr.c_str())),
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
    }
  } else if (outDims.size()) {
    // No scales: build the output type directly from the requested sizes.
    if (modeStr == "nearest") {
      auto outTy = G_->getParent()->uniqueTypeWithNewShape(
          in.getType(), llvm::ArrayRef<dim_t>(outDims));
      outtr = G_->createResizeNearest(opName, in, outTy);
    } else if (modeStr == "bilinear" || modeStr == "linear") {
      vectorReorder(outDims, {NHWC2NCHW});
      auto outTy = G_->getParent()->uniqueTypeWithNewShape(
          in.getType(), llvm::ArrayRef<dim_t>(outDims));
      auto *intr = G_->createTranspose(opName, in, NCHW2NHWC);
      auto RN = G_->createResizeBilinear(opName, intr, outTy);
      outtr = G_->createTranspose(opName, RN, NHWC2NCHW);
    } else {
      return MAKE_ERR(
          opErrMsg(op, strFormat(
                           "Supporting nearest or (bi)linear interpolation only"
                           " but found mode '%s' ",
                           modeStr.c_str())),
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
    }
  } else {
    return MAKE_ERR(opErrMsg(op, "Resize Neither scales or sizes are set."),
                    ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
  }

  RETURN_IF_ERR(addNodeAsOutput(op, outtr));
  return Error::success();
}
2894
2895Error ONNXModelLoader::loadGlobalAveragePool(
2896 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
2897 const std::string &opName = loadOperatorName(op);
2898
2899 // Load the inputs:
2900 NodeValue in;
2901 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2902 std::vector<unsigned_t> strides(2, 1);
2903 if (dict.count("strides")) {
2904 ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
2905 }
2906
2907 llvm::SmallVector<unsigned_t, 2> kernels(2);
2908 kernels[0] = in.dims()[2];
2909 kernels[1] = in.dims()[3];
2910
2911 Pads pads;
2912 ASSIGN_VALUE_OR_RETURN_ERR(
2913 pads, getPads(dict, kernels, strides, kernels /* input sizes*/));
2914
2915 bool countIncludePads;
2916 ASSIGN_VALUE_OR_RETURN_ERR(
2917 countIncludePads, getCountIncludePads(dict, /* defaultValue */ false));
2918
2919 auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);
2920 Node *node = G_->createAvgPool(opName, tr, kernels, strides, pads, NHWC,
2921 countIncludePads);
2922 auto *N = G_->createTranspose(opName, node, NHWC2NCHW);
2923 RETURN_IF_ERR(addNodeAsOutput(op, N));
2924 return Error::success();
2925}
2926
2927Error ONNXModelLoader::loadSqueeze(const ONNX_NAMESPACE::NodeProto &op,
2928 ArgumentDictionaryTy &dict) {
2929 const std::string &opName = loadOperatorName(op);
2930
2931 NodeValue in;
2932 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2933 std::vector<dim_t> axes;
2934 ASSIGN_VALUE_OR_RETURN_ERR(axes,
2935 loadAxes<dim_t>(dict["axes"], in.dims().size()));
2936 Node *node = G_->createSqueeze(opName, in, axes);
2937 RETURN_IF_ERR(addNodeAsOutput(op, node));
2938 return Error::success();
2939}
2940
2941Error ONNXModelLoader::loadUnsqueeze(const ONNX_NAMESPACE::NodeProto &op,
2942 ArgumentDictionaryTy &dict) {
2943 const std::string &opName = loadOperatorName(op);
2944
2945 NodeValue in;
2946 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2947
2948 // Compute output rank.
2949 std::vector<int> axesTemp;
2950 ASSIGN_VALUE_OR_RETURN_ERR(axesTemp, getShape<int>(dict["axes"]));
2951 int outputRank = in.dims().size() + axesTemp.size();
2952
2953 // Read again the axes and use the output rank to wrap negative axes.
2954 std::vector<dim_t> axes;
2955 ASSIGN_VALUE_OR_RETURN_ERR(axes, loadAxes<dim_t>(dict["axes"], outputRank));
2956
2957 Node *node = G_->createExpandDims(opName, in, axes);
2958 RETURN_IF_ERR(addNodeAsOutput(op, node));
2959 return Error::success();
2960}
2961
2962Error ONNXModelLoader::loadBatchNormalization(
2963 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
2964 const std::string &opName = loadOperatorName(op);
2965
2966 NodeValue in;
2967 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
2968 NodeValue scale;
2969 ASSIGN_VALUE_OR_RETURN_ERR(scale, getNodeValueByName(op.input(1)));
2970 NodeValue bias;
2971 ASSIGN_VALUE_OR_RETURN_ERR(bias, getNodeValueByName(op.input(2)));
2972 NodeValue mean;
2973 ASSIGN_VALUE_OR_RETURN_ERR(mean, getNodeValueByName(op.input(3)));
2974 NodeValue var;
2975 ASSIGN_VALUE_OR_RETURN_ERR(var, getNodeValueByName(op.input(4)));
2976 float epsilon = 1e-5f; // default
2977 auto epsilonIt = dict.find("epsilon");
2978 if (epsilonIt != dict.end()) {
2979 ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(epsilonIt->second));
2980 }
2981
2982 auto *node = G_->createBatchNormalization(opName, in.getType(), in, bias,
2983 scale, mean, var, 1, epsilon);
2984
2985 // BatchNormalization has 4 optional outputs that are not supported by glow.
2986 // Then: 1/ In case the optional outputs are present and used by other
2987 // operations of the model, then the import should fail. 2/ In case the
2988 // optional outputs are declared but not used, the import should succeed. By
2989 // registering only the mandatory output, we make sure the import will fail if
2990 // the non supported features are actually requested by the ONNX model.
2991 RETURN_IF_ERR(addNodeAsOutput(op, node, 1));
2992
2993 return Error::success();
2994}
2995
2996Error ONNXModelLoader::loadInstanceNormalization(
2997 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
2998 const std::string &opName = loadOperatorName(op);
2999
3000 NodeValue in;
3001 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
3002 NodeValue scale;
3003 ASSIGN_VALUE_OR_RETURN_ERR(scale, getNodeValueByName(op.input(1)));
3004 NodeValue bias;
3005 ASSIGN_VALUE_OR_RETURN_ERR(bias, getNodeValueByName(op.input(2)));
3006
3007 float epsilon = 1e-5f; // default
3008 auto epsilonIt = dict.find("epsilon");
3009 if (epsilonIt != dict.end()) {
3010 ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(epsilonIt->second));
3011 }
3012
3013 auto *node =
3014 G_->createInstanceNormalization(opName, in, bias, scale, 1, epsilon);
3015 RETURN_IF_ERR(addNodeAsOutput(op, node, 1));
3016
3017 return Error::success();
3018}
3019
3020Error ONNXModelLoader::loadConcat(const ONNX_NAMESPACE::NodeProto &op,
3021 ArgumentDictionaryTy &dict) {
3022 const std::string &opName = loadOperatorName(op);
3023
3024 const unsigned numInputs = op.input_size();
3025 llvm::SmallVector<NodeValue, 4> inputs;
3026 inputs.reserve(numInputs);
3027 for (unsigned i = 0; i < numInputs; i++) {
3028 NodeValue in;
3029 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
3030 inputs.push_back(in);
3031 }
3032
3033 int axis;
3034 ASSIGN_VALUE_OR_RETURN_ERR(
3035 axis, loadAxis<int>(dict.at("axis"), inputs.back().dims().size()));
3036
3037 Node *node = G_->createConcat(opName, inputs, axis);
3038
3039 RETURN_IF_ERR(addNodeAsOutput(op, node));
3040 return Error::success();
3041}
3042
/// Loads a fully-connected layer whose weight tensor is stored already
/// transposed, so it can be fed to createFullyConnected directly.
Error ONNXModelLoader::loadFCTransposed(const ONNX_NAMESPACE::NodeProto &op,
                                        ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  // Inputs with rank > 2 are flattened to a 2D matrix around 'axis'
  // (default 1).
  if (in.getType()->dims().size() > 2) {
    size_t axis = 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<size_t>(dict.at("axis"), in.dims().size()));
    }
    in = G_->createFlatten(opName + ".fc.in", in, axis);
  }

  // 'axis_w' plays the same flattening role for the weight tensor.
  unsigned_t axis_w = 1;
  if (dict.count("axis_w")) {
    ASSIGN_VALUE_OR_RETURN_ERR(axis_w, loadInt(dict.at("axis_w")));
  }

  Constant *W;
  ASSIGN_VALUE_OR_RETURN_ERR(W, getConstantByName(op.input(1)));

  // w is stored already transposed. No need to additionally transpose it.
  if (W->dims().size() > 2) {
    // Flatten the weights to 2D around axis_w by copying the raw payload
    // into a freshly-shaped tensor.
    // NOTE(review): the temporary is created as FloatTy, so this path assumes
    // float weights — confirm non-float weights never reach here.
    Tensor tmp;
    auto wDims = flattenCdr(W->dims(), axis_w);
    tmp.reset(ElemKind::FloatTy, {wDims.first, wDims.second});
    tmp.copyRawFrom(&W->getPayload());
    W = mod_.createConstant(W->getName(), tmp);
  }

  Constant *B;
  ASSIGN_VALUE_OR_RETURN_ERR(B, getConstantByName(op.input(2)));

  auto *node = G_->createFullyConnected(opName, in, W, B);

  RETURN_IF_ERR(addNodeAsOutput(op, node));
  return Error::success();
}
3083
3084Error ONNXModelLoader::loadGemm(const ONNX_NAMESPACE::NodeProto &op,
3085 ArgumentDictionaryTy &dict) {
3086 const std::string &opName = loadOperatorName(op);
3087 NodeValue A;
3088 ASSIGN_VALUE_OR_RETURN_ERR(A, getNodeValueByName(op.input(0)));
3089 NodeValue B;
3090 ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(1)));
3091 NodeValue C = nullptr;
3092 if (op.input_size() > 2 && !op.input(2).empty()) {
3093 ASSIGN_VALUE_OR_RETURN_ERR(C, getNodeValueByName(op.input(2)));
3094 }
3095
3096 float alpha = 1.0;
3097 if (dict.count("alpha")) {
3098 ASSIGN_VALUE_OR_RETURN_ERR(alpha, loadFloat(dict.at("alpha")));
3099 }
3100
3101 float beta = 1.0;
3102 if (dict.count("beta")) {
3103 ASSIGN_VALUE_OR_RETURN_ERR(beta, loadFloat(dict.at("beta")));
3104 }
3105
3106 bool transA = false;
3107 if (dict.count("transA")) {
3108 ASSIGN_VALUE_OR_RETURN_ERR(transA, loadInt(dict.at("transA")));
3109 }
3110
3111 bool transB = false;
3112 if (dict.count("transB")) {
3113 ASSIGN_VALUE_OR_RETURN_ERR(transB, loadInt(dict.at("transB")));
3114 }
3115
3116 Node *node = G_->createGemm(opName, A, B, C, alpha, beta, transA, transB);
3117
3118 RETURN_IF_ERR(addNodeAsOutput(op, node));
3119 return Error::success();
3120}
3121
3122Error ONNXModelLoader::loadMatMul(const ONNX_NAMESPACE::NodeProto &op,
3123 ArgumentDictionaryTy &dict) {
3124 const std::string &opName = loadOperatorName(op);
3125
3126 NodeValue LHS;
3127 ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0)));
3128 NodeValue RHS;
3129 ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1)));
3130
3131 /// For dimension greater than 2 use batchedMatMul
3132 if (LHS.dims().size() > 2) {
3133 Node *node = G_->createBatchMatMul(opName, LHS, RHS);
3134 const size_t numDimsLHS = LHS.dims().size();
3135 if (numDimsLHS > 3) {
3136 const size_t numDimsRHS = RHS.dims().size();
3137 std::vector<dim_t> finalShape;
3138 for (auto d : LHS.dims()) {
3139 finalShape.push_back(d);
3140 }
3141 finalShape[numDimsLHS - 1] = RHS.dims()[numDimsRHS - 1];
3142 node = G_->createReshape(opName, node, finalShape);
3143 }
3144 RETURN_IF_ERR(addNodeAsOutput(op, node));
3145 } else {
3146 Node *node = G_->createMatMul(opName, LHS, RHS);
3147 RETURN_IF_ERR(addNodeAsOutput(op, node));
3148 }
3149 return Error::success();
3150}
3151
3152Error ONNXModelLoader::loadHardSigmoid(const ONNX_NAMESPACE::NodeProto &op,
3153 ArgumentDictionaryTy &dict) {
3154 const std::string &opName = loadOperatorName(op);
3155
3156 // Input Type.
3157 NodeValue input;
3158 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
3159
3160 float alphaVal = 0.2f;
3161 if (dict.count("alpha")) {
3162 ASSIGN_VALUE_OR_RETURN_ERR(alphaVal, loadFloat(dict.at("alpha")));
3163 }
3164 float betaVal = 0.5f;
3165 if (dict.count("beta")) {
3166 ASSIGN_VALUE_OR_RETURN_ERR(betaVal, loadFloat(dict.at("beta")));
3167 }
3168
3169 // Create the node.
3170 Node *N = G_->createHardSigmoid(opName, input, alphaVal, betaVal);
3171 RETURN_IF_ERR(addNodeAsOutput(op, N));
3172
3173 return Error::success();
3174}
3175
3176Error ONNXModelLoader::loadLeakyRelu(const ONNX_NAMESPACE::NodeProto &op,
3177 ArgumentDictionaryTy &dict) {
3178 const std::string &opName = loadOperatorName(op);
3179
3180 // Input Type.
3181 NodeValue input;
3182 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
3183
3184 // ONNX spec says default is 0.01, but doesn't explicitly say it's optional.
3185 // like for others. The default example just omits alpha.
3186 float alphaVal = 0.01f;
3187 if (dict.count("alpha")) {
3188 ASSIGN_VALUE_OR_RETURN_ERR(alphaVal, loadFloat(dict.at("alpha")));
3189 }
3190
3191 // Create the node.
3192 Node *N = G_->createLeakyRELU(opName, input, alphaVal);
3193 RETURN_IF_ERR(addNodeAsOutput(op, N));
3194
3195 return Error::success();
3196}
3197
/// Loads the ONNX Pad operator. Up to opset 10 the pads and fill value come
/// from attributes; from opset 11 they are supplied as constant inputs.
Error ONNXModelLoader::loadPad(const ONNX_NAMESPACE::NodeProto &op,
                               ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  // Input
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
  auto inputDims = input.dims();
  auto numDims = inputDims.size();

  // Padding properties.
  unsigned_t mode = PaddingMode::CONSTANT; // default is constant.
  if (dict.count("mode")) {
    std::string modeStr;
    ASSIGN_VALUE_OR_RETURN_ERR(modeStr, loadStr(dict.at("mode")));
    if (modeStr == "constant") {
      mode = PaddingMode::CONSTANT;
    } else if (modeStr == "reflect") {
      mode = PaddingMode::REFLECT;
    } else if (modeStr == "edge") {
      mode = PaddingMode::EDGE;
    } else {
      return MAKE_ERR(
          "Pad: Invalid mode",
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
    }
  }

  // Pad amounts, laid out as [begin_0..begin_{n-1}, end_0..end_{n-1}].
  std::vector<int> pads;

  if (this->opsetVersion_ > 10) {
    // Get pads through input(1) from opset v11.
    RETURN_ERR_IF_NOT(
        op.input_size() > 1,
        "Pad: The 'pads' is mandatory as input(1) from opsetv11.");
    Constant *padC;
    ASSIGN_VALUE_OR_RETURN_ERR(padC, getConstantByName(op.input(1)));
    RETURN_ERR_IF_NOT(padC, "Support only constant pad");
    helperSetter<int64_t, int>(padC, pads);
  } else {
    // Pre-v11: pads come from the mandatory 'pads' attribute.
    RETURN_ERR_IF_NOT(dict.count("pads"),
                      "Pad: The 'pads' property is mandatory");
    ASSIGN_VALUE_OR_RETURN_ERR(pads, getShape<int>(dict["pads"]));
  }

  RETURN_ERR_IF_NOT(
      (pads.size() == 2 * numDims),
      opErrMsg(op, " The 'pads' array must contain 2 values per dimensions"));

  // Fill value for CONSTANT mode; 0.0 unless overridden.
  float value = 0.f;
  if (this->opsetVersion_ > 10) {
    // v11+: the optional constant input(2) holds the fill value.
    if (op.input_size() > 2) {
      Constant *valueC;
      ASSIGN_VALUE_OR_RETURN_ERR(valueC, getConstantByName(op.input(2)));
      RETURN_ERR_IF_NOT(valueC, "Support only constant value in Pad");
      RETURN_ERR_IF_NOT(valueC->getElementType() == ElemKind::FloatTy,
                        "Value in Pad should be float type.");
      value = valueC->getPayload().getHandle().raw(0);
    }
  } else {
    if (dict.count("value")) {
      ASSIGN_VALUE_OR_RETURN_ERR(value, loadFloat(dict.at("value")));
    }
  }

  // Compute the output type.
  std::vector<dim_t> outDims(numDims);
  for (unsigned_t i = 0; i < numDims; i++) {
    // NOTE(review): inputDims[i] is unsigned, so a negative pad larger than
    // the dimension wraps around before the `> 0` check below — confirm such
    // pads are rejected upstream.
    auto new_dim = inputDims[i] + pads[i] + pads[i + numDims];
    RETURN_ERR_IF_NOT(
        new_dim > 0,
        opErrMsg(op, "The padding can't remove all elements of a dimension"));
    outDims[i] = new_dim;
  }
  auto outTy = mod_.uniqueType(ElemKind::FloatTy, outDims);

  // Create the IR node.
  Node *N = G_->createPad(opName, input, outTy, mode, pads, value);
  RETURN_IF_ERR(addNodeAsOutput(op, N));

  return Error::success();
}
3280
3281Error ONNXModelLoader::loadCast(const ONNX_NAMESPACE::NodeProto &op,
3282 ArgumentDictionaryTy &dict) {
3283 const std::string &opName = loadOperatorName(op);
3284
3285 // Input type
3286 NodeValue input;
3287 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
3288 ElemKind inputKind = input.getType()->getElementType();
3289
3290 // Target type.
3291 ElemKind targetKind;
3292 RETURN_ERR_IF_NOT(dict.count("to"),
3293 opErrMsg(op, "Cast missing 'to' attribute"));
3294 int toONNXTypeValue;
3295 ASSIGN_VALUE_OR_RETURN_ERR(toONNXTypeValue, loadInt(dict.at("to")));
3296 RETURN_ERR_IF_NOT(
3297 ONNX_NAMESPACE::TensorProto_DataType_IsValid(toONNXTypeValue),
3298 opErrMsg(op, "Cast invalid target type"),
3299 ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);
3300 ASSIGN_VALUE_OR_RETURN_ERR(
3301 targetKind, convertTensorProtoDataType(
3302 ONNX_NAMESPACE::TensorProto_DataType(toONNXTypeValue)));
3303
3304 // Only support non quantized types.
3305 RETURN_ERR_IF_NOT(
3306 (!isQuantizedElemKind(inputKind)) && (!isQuantizedElemKind(targetKind)),
3307 opErrMsg(op,
3308 "Cast Unsupported types (Supports only non quantized types)"));
3309
3310 // Create the IR node.
3311 Node *N = G_->createConvertTo(opName, input, targetKind);
3312 RETURN_IF_ERR(addNodeAsOutput(op, N));
3313
3314 return Error::success();
3315}
3316
3317Error ONNXModelLoader::loadDepthToSpace(const ONNX_NAMESPACE::NodeProto &op,
3318 const ArgumentDictionaryTy &dict) {
3319 NodeValue input;
3320 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
3321
3322 dim_t blockSize = 0;
3323 if (dict.count("blocksize")) {
3324 ASSIGN_VALUE_OR_RETURN_ERR(blockSize, loadInt(dict.at("blocksize")));
3325 } else {
3326 return MAKE_ERR("DepthToSpace: missing 'blocksize' attribute");
3327 }
3328
3329 std::string mode = "DCR";
3330 if (dict.count("mode")) {
3331 ASSIGN_VALUE_OR_RETURN_ERR(mode, loadStr(dict.at("mode")));
3332 }
3333
3334 auto inputDim = input.dims();
3335 RETURN_ERR_IF_NOT(inputDim.size() == 4,
3336 "DepthToSpace: dimension size of 4 is expected.");
3337 RETURN_ERR_IF_NOT(inputDim[1] % blockSize == 0,
3338 "DepthToSpace: depth should be divisible by block size.");
3339
3340 std::string opName = loadOperatorName(op);
3341 auto *TR1 = G_->createTranspose(opName, input, NCHW2NHWC);
3342 auto *D2S = G_->createDepthToSpace(opName, TR1, blockSize, mode == "CRD");
3343 auto *TR2 = G_->createTranspose(opName, D2S, NHWC2NCHW);
3344
3345 RETURN_IF_ERR(addNodeAsOutput(op, TR2));
3346 return Error::success();
3347}
3348
3349Error ONNXModelLoader::loadSpaceToDepth(const ONNX_NAMESPACE::NodeProto &op,
3350 ArgumentDictionaryTy &dict) {
3351
3352 // Input Type
3353 NodeValue input;
3354 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
3355
3356 int blockSize = 0;
3357 if (dict.count("blocksize")) {
3358 ASSIGN_VALUE_OR_RETURN_ERR(blockSize, loadInt(dict.at("blocksize")));
3359 } else {
3360 return MAKE_ERR(
3361 opErrMsg(op, "SpaceToDepth: missing 'blocksize' attribute"));
3362 }
3363
3364 // Create the node.
3365 std::string opName = loadOperatorName(op);
3366 auto *tr = G_->createTranspose(opName, input, NCHW2NHWC);
3367 Node *nd = G_->createSpaceToDepth(opName, tr, blockSize);
3368 auto *N = G_->createTranspose(opName, nd, NHWC2NCHW);
3369
3370 RETURN_IF_ERR(addNodeAsOutput(op, N));
3371
3372 return Error::success();
3373}
3374
/// Loads ONNX ReduceL2, lowered as sqrt(reduce_add(x * x)) over the
/// requested axes.
Error ONNXModelLoader::loadReduceL2(const ONNX_NAMESPACE::NodeProto &op,
                                    const ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  // Square the input first: L2 = sqrt(sum(x^2)).
  in = G_->createMul(opName, in, in);

  // ReduceAdd.
  std::vector<unsigned_t> shapeAxes = {};
  if (dict.count("axes")) {
    ASSIGN_VALUE_OR_RETURN_ERR(
        shapeAxes, loadAxes<unsigned_t>(dict.at("axes"), in.dims().size()));
    std::sort(shapeAxes.begin(), shapeAxes.end());
    // Duplicate axes are rejected: after sorting, std::unique returns an
    // iterator before end() exactly when duplicates were present.
    if (shapeAxes.size() > 1) {
      auto it = std::unique(shapeAxes.begin(), shapeAxes.end());
      if (it != shapeAxes.end())
        return MAKE_ERR(opErrMsg(op, "ReduceL2 Axes values are not unique."),
                        ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
    }
  } else {
    // No axes attribute: reduce over every dimension.
    shapeAxes.resize(in.dims().size());
    std::iota(shapeAxes.begin(), shapeAxes.end(), 0);
  }

  bool keepDims = true;
  if (dict.count("keepdims")) {
    int keepdims;
    ASSIGN_VALUE_OR_RETURN_ERR(keepdims, loadInt(dict.at("keepdims")));
    keepDims = (bool)keepdims;
  }

  // Reduceadd works only for single axis as of now, so reduce one axis at a
  // time, iterating the sorted axes in descending order.
  for (auto it = shapeAxes.rbegin(), e = shapeAxes.rend(); it != e; ++it) {
    in = G_->createBatchedReduceAdd(opName, in, llvm::makeArrayRef(*it));
    if (keepDims) {
      // Re-insert the reduced dimension as size 1.
      in = G_->createExpandDims(opName, in, *it);
    }
  }

  // Final square root via pow(x, 0.5).
  in = G_->createPow(opName, in, 0.5f);
  RETURN_IF_ERR(addNodeAsOutput(op, in));
  return Error::success();
}
3418
/// Loads an ONNX ConstantOfShape (or Splat, when \p isSplat is true) operator
/// as a Glow SplatNode. The splat value comes from the "value" attribute
/// (default: a single float 0.0); the output shape comes from the first input
/// (a 1D Int64ITy constant) when present, otherwise from the value tensor.
Error ONNXModelLoader::loadConstantOfShape(const ONNX_NAMESPACE::NodeProto &op,
                                           ArgumentDictionaryTy &dict,
                                           bool isSplat) {
  // Default splat value when no "value" attribute is given.
  Tensor T(ElemKind::FloatTy, {1});
  T.getHandle().raw(0) = 0.0;

  if (dict.count("value")) {
    RETURN_IF_ERR(loadTensor(dict.at("value")->t(), &T, useGlowCustomOps_));
    if (!isSplat) {
      // Validate tensor only for ConstantOfShape operator.
      RETURN_ERR_IF_NOT(
          T.dims().size() == 1,
          opErrMsg(op, strFormat("ConstantOfShape Value must be "
                                 "a 1D vector, but found size %zu ",
                                 T.dims().size())));
      // Only float and integer splat values are supported.
      RETURN_ERR_IF_NOT(
          T.getType().getElementType() == ElemKind::FloatTy ||
              T.getType().getElementType() == ElemKind::Int64ITy ||
              T.getType().getElementType() == ElemKind::Int32ITy,
          T.getType().getElementName().str() + " type Value is not supported.");
    }
  }

  TypeRef ty;
  Node *SN = nullptr;
  if (op.input_size() > 0) {
    // The output shape is provided by the first input, which must be a
    // compile-time constant.
    Constant *in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getConstantByName(op.input(0)));
    // Must be 1D tensor of int64_t.
    RETURN_ERR_IF_NOT(
        in->dims().size() == 1,
        opErrMsg(
            op,
            strFormat(
                "ConstantOfShape Input must be a 1D vector, but found size %d ",
                int(in->dims().size()))));
    RETURN_ERR_IF_NOT(
        in->getType()->getElementType() == ElemKind::Int64ITy,
        opErrMsg(op, "ConstantOfShape Input element type must be Int64ITy."));
    // Convert 1D tensor of int64_t into llvm::ArrayRef<dim_t>.
    auto TH = in->getPayload().getHandle<int64_t>();
    auto begin = &TH.raw(0);
    auto end = begin + TH.actualSize();
    ShapeVector outputDims(begin, end);

    ty = mod_.uniqueType(T.getType().getElementType(), outputDims);
    switch (T.getType().getElementType()) {
    case ElemKind::Int64ITy: {
      // Guard against integer values that do not survive an int -> float ->
      // int round trip (the splat value would otherwise lose precision).
      int64_t v = T.getHandle<int64_t>().raw(0);
      RETURN_ERR_IF_NOT(
          v == static_cast<int64_t>(static_cast<float>(v)),
          opErrMsg(
              op, "ConstantOfShape implementation may cause losses for value " +
                      std::to_string(v) + " ."));
      SN = G_->createSplat(loadOperatorName(op), ty, v);
      break;
    }
    case ElemKind::Int32ITy: {
      // Same round-trip precision guard as for Int64ITy above.
      int32_t v = T.getHandle<int32_t>().raw(0);
      RETURN_ERR_IF_NOT(
          v == static_cast<int32_t>(static_cast<float>(v)),
          opErrMsg(
              op, "ConstantOfShape implementation may cause losses for value " +
                      std::to_string(v) + " ."));
      SN = G_->createSplat(loadOperatorName(op), ty, v);
      break;
    }
    default:
      SN = G_->createSplat(loadOperatorName(op), ty, T.getHandle().raw(0));
    }
  } else {
    // No shape input: the output has the shape of the value tensor itself.
    ty = mod_.uniqueType(T.getType().getElementType(), T.dims());
    SN = G_->createSplat(loadOperatorName(op), ty, T.getHandle().raw(0));
  }
  RETURN_IF_ERR(addNodeAsOutput(op, SN));
  return Error::success();
}
3496
3497Error ONNXModelLoader::loadTile(const ONNX_NAMESPACE::NodeProto &op,
3498 ArgumentDictionaryTy &dict) {
3499 const std::string &opName = loadOperatorName(op);
3500 NodeValue in, repeats;
3501 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
3502 ASSIGN_VALUE_OR_RETURN_ERR(repeats, getNodeValueByName(op.input(1)));
3503 if (!llvm::isa<Constant>(repeats)) {
3504 return MAKE_ERR(opErrMsg(op, "Tile Only constant Repeats is supported!"));
3505 }
3506
3507 if (repeats.dims().size() != 1) {
3508 return MAKE_ERR(
3509 opErrMsg(op, "Tile Repeats must be a single-dimensional tensor!"));
3510 }
3511
3512 if (repeats.dims()[0] != in.dims().size()) {
3513 return MAKE_ERR(opErrMsg(
3514 op, "Tile Repeats should have one value for each dimension of input!"));
3515 }
3516 auto rh = llvm::cast<Constant>(repeats)->getPayload().getHandle<int64_t>();
3517 Node *N = in;
3518 for (size_t i = 0; i < in.dims().size(); i++) {
3519 auto tiles = rh.raw(i);
3520 if (tiles != 1) {
3521 std::string name = opName + "." + std::to_string(i);
3522 N = G_->createTile(name, N, tiles, /*axis*/ i);
3523 }
3524 }
3525
3526 RETURN_IF_ERR(addNodeAsOutput(op, N));
3527 return Error::success();
3528}
3529
3530Error ONNXModelLoader::loadExpand(const ONNX_NAMESPACE::NodeProto &op,
3531 const ArgumentDictionaryTy &dict) {
3532 const std::string &opName = loadOperatorName(op);
3533 NodeValue in;
3534 Constant *repeats;
3535 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
3536 ASSIGN_VALUE_OR_RETURN_ERR(repeats, getConstantByName(op.input(1)));
3537
3538 std::vector<dim_t> tiles;
3539 helperSetter<int64_t, dim_t>(repeats, tiles);
3540 auto inputDimSize = (size_t)in.dims().size();
3541 auto repeatSize = (size_t)tiles.size();
3542 if (repeatSize > inputDimSize) {
3543 for (size_t i = 0, e = repeatSize - inputDimSize; i < e; i++) {
3544 in = G_->createExpandDims(opName + "_" + std::to_string(i), in, i);
3545 }
3546 }
3547
3548 Node *N = in;
3549 for (size_t i = 0, e = tiles.size(); i < e; i++) {
3550 // Two corresponding dimension must have the same value,
3551 // or one of them is equal to 1.
3552 if (in.dims()[i] != 1 && tiles[i] != in.dims()[i] && tiles[i] != 1) {
3553 return MAKE_ERR(opErrMsg(op, "Expand Invalid repeat value"));
3554 }
3555 if (tiles[i] != in.dims()[i] && tiles[i] != 1) {
3556 std::string name = opName + "_" + std::to_string(i);
3557 N = G_->createTile(name, N, tiles[i], /*axis*/ i);
3558 }
3559 }
3560
3561 RETURN_IF_ERR(addNodeAsOutput(op, N));
3562 return Error::success();
3563}
3564
3565Expected<bool>
3566ONNXModelLoader::foldOperator(const ONNX_NAMESPACE::NodeProto &op) {
3567 const unsigned numInputs = op.input_size();
3568 const std::string &typeName = op.op_type();
3569 llvm::SmallVector<NodeValue, 4> inputs;
3570 inputs.reserve(numInputs);
3571 for (unsigned i = 0; i < numInputs; i++) {
3572 // If the name of the input is empty then consider it to be unspecified,
3573 // which is valid for optional inputs, so simply skip. If it is necessary
3574 // for loading the op, then when we later try to load the proper error will
3575 // be propagated upward.
3576 if (op.input(i).empty()) {
3577 continue;
3578 }
3579 NodeValue in;
3580 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
3581 inputs.push_back(in);
3582 }
3583
3584 if (!isConstantFoldable(inputs, typeName)) {
3585 return false;
3586 }
3587
3588 // Create a temporary lightweight loader to construct function representing
3589 // current Op, and then constant fold the function using Interp backend.
3590 Function *tmpF = mod_.createFunction("eval_const_fold__");
3591 ONNXModelLoader tmpLoader(*tmpF);
3592 tmpLoader.opsetVersion_ = opsetVersion_;
3593 bool foldStatus = !ERR_TO_BOOL(
3594 constantFoldInLoader<ONNXModelLoader, ONNX_NAMESPACE::NodeProto>(
3595 tmpF, tmpLoader, this, op),
3596 /* log */ false);
3597 mod_.eraseFunction(tmpF);
3598 return foldStatus;
3599}
3600
3601Error ONNXModelLoader::loadWhere(const ONNX_NAMESPACE::NodeProto &op,
3602 ArgumentDictionaryTy &dict) {
3603 NodeValue cNV;
3604 ASSIGN_VALUE_OR_RETURN_ERR(cNV, getNodeValueByName(op.input(0)));
3605 NodeValue xNV;
3606 ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(1)));
3607 NodeValue yNV;
3608 ASSIGN_VALUE_OR_RETURN_ERR(yNV, getNodeValueByName(op.input(2)));
3609
3610 std::string opName = loadOperatorName(op);
3611
3612 // Passing -1 for multi directional broadcast, axis will be computed
3613 // automatically.
3614 Node *N = G_->createNodeWithBroadcast<SelectNode>(opName, -1, cNV, xNV, yNV);
3615
3616 RETURN_IF_ERR(addNodeAsOutput(op, N));
3617 return Error::success();
3618}
3619
3620/// Utility function to get the RNN, GRU or LSTM direction from the proto
3621/// description. If not provided, the default direction is 'forward'.
3622static Expected<Function::RnnDirection>
3623getRnnDirection(const ONNX_NAMESPACE::NodeProto &op,
3624 ArgumentDictionaryTy &dict) {
3625 Function::RnnDirection direction = Function::RnnDirection::Forward;
3626 if (dict.count("direction")) {
3627 std::string directionStr;
3628 ASSIGN_VALUE_OR_RETURN_ERR(directionStr, loadStr(dict.at("direction")));
3629 if (directionStr == "forward") {
3630 direction = Function::RnnDirection::Forward;
3631 } else if (directionStr == "reverse") {
3632 direction = Function::RnnDirection::Reverse;
3633 } else if (directionStr == "bidirectional") {
3634 direction = Function::RnnDirection::Bidirectional;
3635 } else {
3636 return MAKE_ERR(
3637 "ONNX " + op.op_type() + " 'direction' attribute is invalid!",
3638 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
3639 }
3640 }
3641 return direction;
3642}
3643
3644/// Relu activation function definition.
3645static Function::RnnActivation RnnActivationRelu(Function &F) {
3646 return [&F](llvm::StringRef name, Node *input) {
3647 return F.createRELU(name, input);
3648 };
3649}
3650
3651/// Tanh activation function definition.
3652static Function::RnnActivation RnnActivationTanh(Function &F) {
3653 return [&F](llvm::StringRef name, Node *input) {
3654 return F.createTanh(name, input);
3655 };
3656}
3657
3658/// Sigmoid activation function definition.
3659static Function::RnnActivation RnnActivationSigmoid(Function &F) {
3660 return [&F](llvm::StringRef name, Node *input) {
3661 return F.createSigmoid(name, input);
3662 };
3663}
3664
3665/// Utility function to get the RNN, GRU or LSTM activation functions from the
3666/// proto description. The activation function array is assumed to be already
3667/// initialized with the default values upon entering this function so that the
3668/// purpose of this function is to overwrite the specific default values.
/// Currently only Sigmoid, Tanh and ReLU activations are supported.
3670static Error
3671getRnnActivations(const ONNX_NAMESPACE::NodeProto &op,
3672 ArgumentDictionaryTy &dict, Function *F,
3673 std::vector<Function::RnnActivation> &activations) {
3674
3675 // Activation alpha not supported (Optional)(Default:activation dependent).
3676 RETURN_ERR_IF_NOT(!dict.count("activation_alpha"),
3677 "ONNX " + op.op_type() +
3678 " 'activation_alpha' attribute not supported!");
3679
3680 // Activation beta not supported (Optional)(Default:activation dependent).
3681 RETURN_ERR_IF_NOT(!dict.count("activation_beta"),
3682 "ONNX " + op.op_type() +
3683 " 'activation_beta' attribute not supported!");
3684
3685 // Get activations.
3686 if (dict.count("activations") && dict.at("activations")->strings_size()) {
3687 size_t actNum = dict.at("activations")->strings_size();
3688 size_t actNumExpected = activations.size();
3689 RETURN_ERR_IF_NOT(actNum == actNumExpected,
3690 strFormat("ONNX %s 'activations' attribute has invalid "
3691 "number of functions! Expected number is %d!",
3692 op.op_type().c_str(), (int)actNumExpected));
3693 for (size_t actIdx = 0; actIdx < actNum; actIdx++) {
3694 std::string actStr = dict.at("activations")->strings().Get(actIdx);
3695 if (actStr == "Relu") {
3696 activations[actIdx] = RnnActivationRelu(*F);
3697 } else if (actStr == "Tanh") {
3698 activations[actIdx] = RnnActivationTanh(*F);
3699 } else if (actStr == "Sigmoid") {
3700 activations[actIdx] = RnnActivationSigmoid(*F);
3701 } else {
3702 return MAKE_ERR(
3703 "ONNX " + op.op_type() + " activation '" + actStr +
3704 "' not supported!",
3705 ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
3706 }
3707 }
3708 }
3709 return Error::success();
3710}
3711
3712// Limitations:
3713// - Activation clipping not supported.
3714// - Variable sequence length not supported.
/// Loads the ONNX RNN operator as a Glow OnnxRNN node. Supports 3 to 6
/// inputs (X, W, R and optional B, sequence_lens, initial_h) and 1 or 2
/// outputs (Y and optional Y_h). sequence_lens is ignored; activation
/// clipping and variable sequence length are not supported.
Error ONNXModelLoader::loadRNN(const ONNX_NAMESPACE::NodeProto &op,
                               ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);

  // ------------------------- Attributes -------------------------------------
  // Get direction (Optional)(Default:forward).
  Function::RnnDirection direction;
  ASSIGN_VALUE_OR_RETURN_ERR(direction, getRnnDirection(op, dict));
  dim_t numDirections =
      (direction == Function::RnnDirection::Bidirectional) ? 2 : 1;

  // Get activations as lambdas (Optional)(Default:f=Tanh).
  // Bidirectional mode needs one default activation per direction.
  std::vector<Function::RnnActivation> activations;
  if (direction == Function::RnnDirection::Bidirectional) {
    activations = {RnnActivationTanh(*G_), RnnActivationTanh(*G_)};
  } else {
    activations = {RnnActivationTanh(*G_)};
  }
  RETURN_IF_ERR(getRnnActivations(op, dict, G_, activations));

  // Activation clipping not supported (Optional)(Default: 0 for no clipping).
  RETURN_ERR_IF_NOT(!dict.count("clip"),
                    opErrMsg(op, "ONNX RNN 'clip' attribute not supported!"));

  // Get hidden size (Required).
  dim_t hiddenSize;
  RETURN_ERR_IF_NOT(
      dict.count("hidden_size"),
      opErrMsg(op, "ONNX RNN 'hidden_size' attribute is required!"));
  ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size")));

  // --------------------------- Inputs ---------------------------------------
  const int numInputs = op.input_size();
  RETURN_ERR_IF_NOT((3 <= numInputs) && (numInputs <= 6),
                    opErrMsg(op, strFormat("ONNX RNN should have minimum 3 and "
                                           "maximum 6 inputs, but found %d ",
                                           numInputs)));

  // Input0: X (Required).
  NodeValue X;
  ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));

  // Input1: W (Required).
  NodeValue W;
  ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));

  // Input2: R (Required).
  NodeValue R;
  ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2)));

  // Input3: B (Optional).
  NodeValue B = nullptr;
  if (numInputs > 3 && !op.input(3).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3)));
  }

  // Input4: sequence_lens (Optional). Ignored: the sequence length is taken
  // from the shape of X instead.
  if (numInputs > 4 && !op.input(4).empty()) {
    LOG(WARNING) << "sequence_lens ignored, will be inferred from shape of "
                    "ONNX RNN input.";
  }

  // Input5: initial_h (Optional).
  NodeValue initial_h = nullptr;
  if (numInputs > 5 && !op.input(5).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5)));
  }

  // -------------------------- Outputs ---------------------------------------
  // We allow creating placeholders for the RNN state variable Y_h for the
  // following reasons:
  // - expose the RNN state in the graph interface for accessibility (set
  //   desired state, reset state, watch the state being updated automatically).
  // - since the RNN cells are unrolled (no graph loop primitive available
  //   at this point), the optimal way to use the RNN within a model would be
  //   to have it defined with only 1 time step and have the loop in the top
  //   of the application while the RNN state will be automatically updated
  //   from one iteration (time step) to the next through the placeholders.

  // Derived parameters.
  RETURN_ERR_IF_NOT(
      X.dims().size() == 3,
      opErrMsg(op, "ONNX RNN input 'X' should have 3 dimensions!"));
  dim_t batchSize = X.dims()[1];

  // Create Y_h (hidden state) output placeholder.
  Placeholder *Y_h_ph = nullptr;
  if (onnxExportRnnStatesOpt) {
    TypeRef Htype = mod_.uniqueTypeWithNewShape(
        X.getType(), {numDirections, batchSize, hiddenSize});
    std::string Hname = opName + ".Y_h";
    ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph,
                               createAndRegisterPlaceholder(Hname, Htype));
    inputVarsByName_.try_emplace(Hname, Y_h_ph);
  }

  // Set RNN input state: the exported placeholder when states are exposed,
  // otherwise the (possibly null) initial_h input.
  NodeValue Y_h_init = onnxExportRnnStatesOpt ? Y_h_ph : initial_h;

  // Create ONNX RNN.
  NodeValue Y, Y_h;
  G_->createOnnxRNN(opName, X, W, R, B, Y_h_init, Y, Y_h, hiddenSize, direction,
                    activations);

  // Save RNN output state back into the placeholder.
  if (onnxExportRnnStatesOpt) {
    G_->createSave(opName + ".Y_h.save", Y_h, Y_h_ph);
  }

  // Add node.
  const int numOutputs = op.output_size();
  if (numOutputs == 1) {
    RETURN_IF_ERR(addNodeAsOutput(op, Y));
  } else if (numOutputs == 2) {
    RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h}));
  } else {
    return MAKE_ERR(opErrMsg(op, strFormat("ONNX RNN should have minimum 1 and "
                                           "maximum 2 outputs, but found %d ",
                                           numOutputs)));
  }
  return Error::success();
}
3838
3839// Limitations:
3840// - Activation clipping not supported.
3841// - Variable sequence length not supported.
3842Error ONNXModelLoader::loadGRU(const ONNX_NAMESPACE::NodeProto &op,
3843 ArgumentDictionaryTy &dict) {
3844
3845 const std::string &opName = loadOperatorName(op);
3846
3847 // ------------------------- Attributes -------------------------------------
3848 // Get direction (Optional)(Default:forward).
3849 Function::RnnDirection direction;
3850 ASSIGN_VALUE_OR_RETURN_ERR(direction, getRnnDirection(op, dict));
3851 dim_t numDirections =
3852 (direction == Function::RnnDirection::Bidirectional) ? 2 : 1;
3853
3854 // Get activations as lambdas (Optional)(Default:f=Sigmoid, g=Tanh).
3855 std::vector<Function::RnnActivation> activations;
3856 if (direction == Function::RnnDirection::Bidirectional) {
3857 activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_),
3858 RnnActivationSigmoid(*G_), RnnActivationTanh(*G_)};
3859 } else {
3860 activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_)};
3861 }
3862 RETURN_IF_ERR(getRnnActivations(op, dict, G_, activations));
3863
3864 // Activation clipping not supported (Optional)(Default: 0 for no clipping).
3865 RETURN_ERR_IF_NOT(!dict.count("clip"),
3866 opErrMsg(op, "ONNX GRU 'clip' attribute not supported!"));
3867
3868 // Get hidden size (Required).
3869 dim_t hiddenSize;
3870 RETURN_ERR_IF_NOT(
3871 dict.count("hidden_size"),
3872 opErrMsg(op, "ONNX GRU 'hidden_size' attribute is required!"));
3873 ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size")));
3874 ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size")));
3875
3876 // Get linear_before_reset (Optional)(Default:0).
3877 int linearBeforeReset = 0;
3878 if (dict.count("linear_before_reset") &&
3879 dict.at("linear_before_reset")->has_i()) {
3880 linearBeforeReset = dict.at("linear_before_reset")->i();
3881 }
3882
3883 // --------------------------- Inputs ---------------------------------------
3884 const int numInputs = op.input_size();
3885 RETURN_ERR_IF_NOT((3 <= numInputs) && (numInputs <= 6),
3886 opErrMsg(op, strFormat("ONNX GRU should have minimum 3 and "
3887 "maximum 6 inputs, but found %d ",
3888 numInputs)));
3889
3890 // Input0: X (Required).
3891 NodeValue X;
3892 ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));
3893
3894 // Input1: W (Required).
3895 NodeValue W;
3896 ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));
3897
3898 // Input2: R (Required).
3899 NodeValue R;
3900 ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2)));
3901
3902 // Input3: B (Optional).
3903 NodeValue B = nullptr;
3904 if (numInputs > 3 && !op.input(3).empty()) {
3905 ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3)));
3906 }
3907
3908 // Input4: sequence_lens (Optional).
3909 if (numInputs > 4 && !op.input(4).empty()) {
3910 LOG(WARNING) << "sequence_lens ignored, will be inferred from shape of "
3911 "ONNX GRU input.";
3912 }
3913
3914 // Input5: initial_h (Optional).
3915 NodeValue initial_h = nullptr;
3916 if (numInputs > 5 && !op.input(5).empty()) {
3917 ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5)));
3918 }
3919
3920 // -------------------------- Outputs ---------------------------------------
3921 // We allow creating placeholders for the GRU state variable Y_h for the
3922 // following reasons:
3923 // - expose the GRU state in the graph interface for accessibility (set
3924 // desired state, reset state, watch the state being updated automatically).
3925 // - since the GRU cells are unrolled (no graph loop primitive available
3926 // at this point), the optimal way to use the GRU within a model would be
3927 // to have it defined with only 1 time step and have the loop in the top
3928 // of the application while the GRU state will be automatically updated
3929 // from one iteration (time step) to the next through the placeholders.
3930
3931 // Derived parameters.
3932 RETURN_ERR_IF_NOT(
3933 X.dims().size() == 3,
3934 opErrMsg(op, "ONNX GRU input 'X' should have 3 dimensions!"));
3935 dim_t batchSize = X.dims()[1];
3936
3937 // Create Y_h (hidden state) output placeholder.
3938 Placeholder *Y_h_ph = nullptr;
3939 if (onnxExportRnnStatesOpt) {
3940 TypeRef Htype = mod_.uniqueTypeWithNewShape(
3941 X.getType(), {numDirections, batchSize, hiddenSize});
3942 std::string Hname = opName + ".Y_h";
3943 ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph,
3944 createAndRegisterPlaceholder(Hname, Htype));
3945 inputVarsByName_.try_emplace(Hname, Y_h_ph);
3946 }
3947
3948 // Set GRU input state.
3949 NodeValue Y_h_init = onnxExportRnnStatesOpt ? Y_h_ph : initial_h;
3950
3951 // Create ONNX GRU.
3952 NodeValue Y, Y_h;
3953 G_->createOnnxGRU(opName, X, W, R, B, Y_h_init, Y, Y_h, hiddenSize, direction,
3954 activations, (bool)linearBeforeReset);
3955
3956 // Save GRU output state.
3957 if (onnxExportRnnStatesOpt) {
3958 G_->createSave(opName + ".Y_h.save", Y_h, Y_h_ph);
3959 }
3960
3961 // Add node.
3962 const int numOutputs = op.output_size();
3963 if (numOutputs == 1) {
3964 RETURN_IF_ERR(addNodeAsOutput(op, Y));
3965 } else if (numOutputs == 2) {
3966 RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h}));
3967 } else {
3968 return MAKE_ERR(opErrMsg(op, strFormat("ONNX GRU should have minimum 1 and "
3969 "maximum 2 outputs, but found %d ",
3970 numOutputs)));
3971 }
3972 return Error::success();
3973}
3974
3975// Limitations:
3976// - Activation clipping not supported.
3977// - Variable sequence length not supported.
/// Loads the ONNX LSTM operator as a Glow OnnxLSTM node. Supports 3 to 8
/// inputs (X, W, R and optional B, sequence_lens, initial_h, initial_c, P)
/// and 1 to 3 outputs (Y and optional Y_h, Y_c). sequence_lens is ignored;
/// activation clipping and variable sequence length are not supported.
Error ONNXModelLoader::loadLSTM(const ONNX_NAMESPACE::NodeProto &op,
                                ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);

  // ------------------------- Attributes -------------------------------------
  // Get direction (Optional)(Default:forward).
  Function::RnnDirection direction;
  ASSIGN_VALUE_OR_RETURN_ERR(direction, getRnnDirection(op, dict));
  dim_t numDirections =
      (direction == Function::RnnDirection::Bidirectional) ? 2 : 1;

  // Get activations as lambdas (Optional)(Default:f=Sigmoid, g=Tanh, h=Tanh).
  // Bidirectional mode needs one default activation triple per direction.
  std::vector<Function::RnnActivation> activations;
  if (direction == Function::RnnDirection::Bidirectional) {
    activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_),
                   RnnActivationTanh(*G_), RnnActivationSigmoid(*G_),
                   RnnActivationTanh(*G_), RnnActivationTanh(*G_)};
  } else {
    activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_),
                   RnnActivationTanh(*G_)};
  }
  RETURN_IF_ERR(getRnnActivations(op, dict, G_, activations));

  // Activation clipping not supported (Optional)(Default: 0 for no clipping).
  RETURN_ERR_IF_NOT(!dict.count("clip"),
                    opErrMsg(op, "ONNX LSTM 'clip' attribute not supported!"));

  // Get hidden size (Required).
  dim_t hiddenSize;
  RETURN_ERR_IF_NOT(
      dict.count("hidden_size"),
      opErrMsg(op, "ONNX LSTM 'hidden_size' attribute is required!"));
  ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size")));

  // Get input forget (Optional)(Default:0).
  int inputForget = 0;
  if (dict.count("input_forget") && dict.at("input_forget")->has_i()) {
    inputForget = dict.at("input_forget")->i();
  }

  // --------------------------- Inputs ---------------------------------------
  const int numInputs = op.input_size();
  RETURN_ERR_IF_NOT(
      (3 <= numInputs) && (numInputs <= 8),
      opErrMsg(op, strFormat("ONNX LSTM should have minimum 3 and maximum 8 "
                             "inputs, but found %d ",
                             numInputs)));

  // Input0: X (Required).
  NodeValue X;
  ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));

  // Input1: W (Required).
  NodeValue W;
  ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));

  // Input2: R (Required).
  NodeValue R;
  ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2)));

  // Input3: B (Optional).
  NodeValue B = nullptr;
  if (numInputs > 3 && !op.input(3).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3)));
  }

  // Input4: sequence_lens (Optional). Ignored: the sequence length is taken
  // from the shape of X instead.
  if (numInputs > 4 && !op.input(4).empty()) {
    LOG(WARNING) << "sequence_lens ignored, will be inferred from shape of "
                    "ONNX LSTM input.";
  }

  // Input5: initial_h (Optional).
  NodeValue initial_h = nullptr;
  if (numInputs > 5 && !op.input(5).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5)));
  }

  // Input6: initial_c (Optional).
  NodeValue initial_c = nullptr;
  if (numInputs > 6 && !op.input(6).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(initial_c, getNodeValueByName(op.input(6)));
  }

  // Input7: P (Optional).
  NodeValue P = nullptr;
  if (numInputs > 7 && !op.input(7).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(P, getNodeValueByName(op.input(7)));
  }

  // -------------------------- Outputs ---------------------------------------
  // We allow creating placeholders for the LSTM state variables (Y_h and Y_c)
  // for the following reasons:
  // - expose the LSTM state in the graph interface for accessibility (set
  //   desired state, reset state, watch the state being updated automatically).
  // - since the LSTM cells are unrolled (no graph loop primitive available
  //   at this point), the optimal way to use the LSTM within a model would be
  //   to have it defined with only 1 time step and have the loop in the top
  //   of the application while the LSTM state will be automatically updated
  //   from one iteration (time step) to the next through the placeholders.

  // Derived parameters.
  RETURN_ERR_IF_NOT(
      X.dims().size() == 3,
      opErrMsg(op, "ONNX LSTM input 'X' should have 3 dimensions!"));
  dim_t batchSize = X.dims()[1];

  // Create Y_h (hidden state) output placeholder.
  Placeholder *Y_h_ph = nullptr;
  if (onnxExportRnnStatesOpt) {
    TypeRef Htype = mod_.uniqueTypeWithNewShape(
        X.getType(), {numDirections, batchSize, hiddenSize});
    std::string Hname = opName + ".Y_h";
    ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph,
                               createAndRegisterPlaceholder(Hname, Htype));
    inputVarsByName_.try_emplace(Hname, Y_h_ph);
  }

  // Create Y_c (cell state) output placeholder.
  Placeholder *Y_c_ph = nullptr;
  if (onnxExportRnnStatesOpt) {
    TypeRef Ctype = mod_.uniqueTypeWithNewShape(
        X.getType(), {numDirections, batchSize, hiddenSize});
    std::string Cname = opName + ".Y_c";
    ASSIGN_VALUE_OR_RETURN_ERR(Y_c_ph,
                               createAndRegisterPlaceholder(Cname, Ctype));
    inputVarsByName_.try_emplace(Cname, Y_c_ph);
  }

  // Set LSTM input states: the exported placeholders when states are exposed,
  // otherwise the (possibly null) initial_h / initial_c inputs.
  NodeValue Y_h_init = onnxExportRnnStatesOpt ? Y_h_ph : initial_h;
  NodeValue Y_c_init = onnxExportRnnStatesOpt ? Y_c_ph : initial_c;

  // Create ONNX LSTM.
  NodeValue Y, Y_h, Y_c;
  G_->createOnnxLSTM(opName, X, W, R, B, Y_h_init, Y_c_init, P, Y, Y_h, Y_c,
                     hiddenSize, direction, activations, (bool)inputForget);

  // Save LSTM output states back into the placeholders.
  if (onnxExportRnnStatesOpt) {
    G_->createSave(opName + ".Y_h.save", Y_h, Y_h_ph);
    G_->createSave(opName + ".Y_c.save", Y_c, Y_c_ph);
  }

  // Add node.
  const int numOutputs = op.output_size();
  if (numOutputs == 1) {
    RETURN_IF_ERR(addNodeAsOutput(op, Y));
  } else if (numOutputs == 2) {
    RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h}));
  } else if (numOutputs == 3) {
    RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h, Y_c}));
  } else {
    return MAKE_ERR(
        opErrMsg(op, strFormat("ONNX LSTM should have minimum 1 and "
                               "maximum 3 outputs, but found %d ",
                               numOutputs)));
  }
  return Error::success();
}
4139
4140Error ONNXModelLoader::loadClip(const ONNX_NAMESPACE::NodeProto &op,
4141 const ArgumentDictionaryTy &dict) {
4142 NodeValue in;
4143 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4144
4145 float cmin = std::numeric_limits<float>::lowest();
4146 if (opsetVersion_ > 10 && op.input_size() > 1 && !op.input(1).empty()) {
4147 // min value is optional and might not be supplied.
4148 Constant *minC = getConstantByNameOrNull(op.input(1));
4149 RETURN_ERR_IF_NOT(minC, "Expect constant for min value in Clip operator.");
4150 cmin = minC->getPayload().getHandle().raw(0);
4151 } else if (dict.count("min")) {
4152 ASSIGN_VALUE_OR_RETURN_ERR(cmin, loadFloat(dict.find("min")->second));
4153 }
4154
4155 // Windows headers define `max` macro, so have to wrap the function name in
4156 // parenthesis to avoid compilation error.
4157 float cmax = (std::numeric_limits<float>::max)();
4158 if (opsetVersion_ > 10 && op.input_size() > 2 && !op.input(2).empty()) {
4159 // max value is optional and might not be supplied.
4160 Constant *maxC = getConstantByNameOrNull(op.input(2));
4161 RETURN_ERR_IF_NOT(maxC, "Expect constant for max value in Clip operator.");
4162 cmax = maxC->getPayload().getHandle().raw(0);
4163 } else if (dict.count("max")) {
4164 ASSIGN_VALUE_OR_RETURN_ERR(cmax, loadFloat(dict.find("max")->second));
4165 }
4166
4167 auto *node = G_->createClip(loadOperatorName(op), in, cmin, cmax);
4168 RETURN_IF_ERR(addNodeAsOutput(op, node));
4169 return Error::success();
4170}
4171
4172Error ONNXModelLoader::loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op,
4173 ArgumentDictionaryTy &dict) {
4174 NodeValue LHS;
4175 ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0)));
4176 NodeValue RHS;
4177 ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1)));
4178
4179 Node *N = G_->createNodeWithBroadcast<CmpEQNode>(loadOperatorName(op),
4180 /* axis */ -1, LHS, RHS);
4181 RETURN_IF_ERR(addNodeAsOutput(op, N));
4182 return Error::success();
4183}
4184
4185Error ONNXModelLoader::loadCmpLTE(const ONNX_NAMESPACE::NodeProto &op,
4186 ArgumentDictionaryTy &dict) {
4187 NodeValue LHS;
4188 ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0)));
4189 NodeValue RHS;
4190 ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1)));
4191
4192 Node *N = G_->createNodeWithBroadcast<CmpLTENode>(loadOperatorName(op),
4193 /* axis */ -1, LHS, RHS);
4194 RETURN_IF_ERR(addNodeAsOutput(op, N));
4195 return Error::success();
4196}
4197
/// Takes a list of NodeValues \p inputs and computes in \p broadcastShape the
/// common shape they broadcast to, taking the maximum size along each
/// dimension.
4200static Error getShapeForBroadcast(llvm::ArrayRef<NodeValue> inputs,
4201 std::vector<dim_t> &broadcastShape) {
4202 std::vector<uint32_t> numDims;
4203 for (auto &N : inputs) {
4204 numDims.push_back(N.dims().size());
4205 }
4206 const uint32_t outputNumDims =
4207 *std::max_element(numDims.begin(), numDims.end());
4208 for (uint32_t i = 0; i < outputNumDims; i++) {
4209 std::vector<dim_t> dims;
4210 for (uint32_t j = 0; j < inputs.size(); j++) {
4211 auto vals = inputs[j].dims();
4212 if (vals.size() > i) {
4213 dims.push_back(vals[vals.size() - 1 - i]);
4214 }
4215 }
4216 broadcastShape.insert(broadcastShape.begin(),
4217 *std::max_element(dims.begin(), dims.end()));
4218 }
4219 return Error::success();
4220}
4221
4222Error ONNXModelLoader::loadMean(const ONNX_NAMESPACE::NodeProto &op,
4223 ArgumentDictionaryTy &dict) {
4224 size_t numInputTensors = op.input_size();
4225 if (numInputTensors == 1) {
4226 NodeValue in;
4227 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4228 RETURN_IF_ERR(addNodeAsOutput(op, in));
4229 } else {
4230 const std::string &opName = loadOperatorName(op);
4231 llvm::SmallVector<NodeValue, 4> inputTensors;
4232 inputTensors.reserve(numInputTensors);
4233 for (unsigned i = 0; i < numInputTensors; i++) {
4234 NodeValue in;
4235 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
4236 inputTensors.push_back(in);
4237 }
4238 std::vector<dim_t> broadcastShape;
4239 RETURN_IF_ERR(getShapeForBroadcast(inputTensors, broadcastShape));
4240 for (unsigned i = 0; i < numInputTensors; i++) {
4241 auto &in = inputTensors[i];
4242 int axis = broadcastShape.size() - in.dims().size();
4243 in = G_->createBroadcast(opName, in, broadcastShape, axis);
4244 in = G_->createExpandDims(opName, in, {0});
4245 }
4246 ConcatNode *concat = G_->createConcat(opName, inputTensors, /* axis */ 0);
4247 Node *N = G_->createBatchedReduceMean(opName, concat, /* axis */ {0});
4248 RETURN_IF_ERR(addNodeAsOutput(op, N));
4249 }
4250 return Error::success();
4251}
4252
4253Error ONNXModelLoader::loadSelect(const ONNX_NAMESPACE::NodeProto &op,
4254 ArgumentDictionaryTy &dict) {
4255 NodeValue Cond;
4256 ASSIGN_VALUE_OR_RETURN_ERR(Cond, getNodeValueByName(op.input(0)));
4257 NodeValue LHS;
4258 ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(1)));
4259 NodeValue RHS;
4260 ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(2)));
4261
4262 std::vector<dim_t> shape;
4263 ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["shape"]));
4264
4265 auto outTy = mod_.uniqueType(LHS.getElementType(), shape);
4266 Node *N = G_->createSelect(loadOperatorName(op), outTy, Cond, LHS, RHS);
4267
4268 RETURN_IF_ERR(addNodeAsOutput(op, N));
4269 return Error::success();
4270}
4271
4272Error ONNXModelLoader::loadNonZero(const ONNX_NAMESPACE::NodeProto &op,
4273 const ArgumentDictionaryTy &dict) {
4274 NodeValue input;
4275 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
4276
4277 Constant *C = getConstantByNameOrNull(op.input(0));
4278 RETURN_ERR_IF_NOT(C,
4279 opErrMsg(op, "NonZero Only constant shape is supported!"));
4280
4281 // output tensor.
4282 Tensor outT;
4283
4284 // Fold NonZero operator.
4285 auto foldNonZero = [&C, &outT](auto dummy) -> Error {
4286 auto inH = C->getPayload().getHandle<decltype(dummy)>();
4287 auto dims = C->dims();
4288
4289 // First pass over the input is used to find the number of non-zero elements
4290 // so we can create output tensor that will be filled in the 2nd pass.
4291 dim_t nonZeroCnt = 0;
4292 for (dim_t idx = 0, e = inH.size(); idx < e; idx++) {
4293 nonZeroCnt += (inH.raw(idx) != 0) ? 1 : 0;
4294 }
4295
4296 // No need to support zero Tensor (empty output); we support constant input
4297 // only and such input likely means it's an invalid model or the part of
4298 // graph can be removed.
4299 RETURN_ERR_IF_NOT(nonZeroCnt > 0,
4300 "Non-Zero input with all zeroes is not supported.");
4301
4302 // Create output tensor. First dimension is the rank of input tensor, second
4303 // dimension is the number of non-zero elements.
4304 outT.reset(ElemKind::Int64ITy, {(dim_t)dims.size(), nonZeroCnt});
4305
4306 // Strides for each dimensions, needed to calculate NonZero output.
4307 std::vector<dim_t> strides;
4308 strides.resize(dims.size());
4309 strides[dims.size() - 1] = 1;
4310 if (dims.size() > 1) {
4311 for (int i = dims.size() - 2; i >= 0; i--) {
4312 strides[i] = dims[i + 1] * strides[i + 1];
4313 }
4314 }
4315
4316 // Second pass over the input is used to fill the output tensor. For each
4317 // non-zero element we fill all the dimensions, at position determined
4318 // by the non-zero element's index when zero elements are ignored.
4319 auto outH = outT.getHandle<int64_t>();
4320 for (dim_t idx = 0, pos = 0, e = inH.size(); idx < e; idx++) {
4321 if (inH.raw(idx) != 0) {
4322 for (dim_t dim = 0; dim < dims.size(); dim++) {
4323 outH.at({dim, pos}) = (idx / strides[dim]) % dims[dim];
4324 }
4325 pos++;
4326 }
4327 }
4328 return Error::success();
4329 };
4330
4331 std::string err;
4332 if (C->getElementType() == ElemKind::FloatTy) {
4333 RETURN_IF_ERR(foldNonZero((float)0));
4334 } else if (C->getElementType() == ElemKind::Int64ITy) {
4335 RETURN_IF_ERR(foldNonZero((int64_t)0));
4336 } else if (C->getElementType() == ElemKind::Int32ITy) {
4337 RETURN_IF_ERR(foldNonZero((int32_t)0));
4338 } else {
4339 return MAKE_ERR(
4340 opErrMsg(op, "NonZero: Unsupported input type for NonZero operator."
4341 "(Supports Float, Int32 and Int64)"));
4342 }
4343
4344 Constant *outC = G_->getParent()->createConstant("nonZero", std::move(outT));
4345 RETURN_IF_ERR(addNodeAsOutput(op, outC));
4346 return Error::success();
4347}
4348
4349Error ONNXModelLoader::loadQuantize(const ONNX_NAMESPACE::NodeProto &op,
4350 ArgumentDictionaryTy &dict) {
4351 NodeValue in;
4352 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4353
4354 float scale;
4355 ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict.at("scale")));
4356 unsigned_t offset;
4357 ASSIGN_VALUE_OR_RETURN_ERR(offset, loadInt(dict.at("offset")));
4358 std::string elemKindStr;
4359 ASSIGN_VALUE_OR_RETURN_ERR(elemKindStr, loadStr(dict.at("elem_kind")));
4360
4361 ElemKind elemKind = Type::getElementKindFromName(elemKindStr);
4362
4363 auto outDims = in.getType()->dims();
4364 auto outTy = mod_.uniqueType(elemKind, outDims, scale, offset);
4365 Node *N = G_->createQuantize(loadOperatorName(op), in, outTy);
4366
4367 RETURN_IF_ERR(addNodeAsOutput(op, N));
4368 return Error::success();
4369}
4370
/// Loads an ONNX QuantizeLinear operator. Quantizes the Float/Int32 input
/// using a scalar y_scale (input 1) and an optional scalar y_zero_point
/// (input 2); both must be Constants. Only per-tensor (per-layer)
/// quantization is supported; the zero-point's element type determines the
/// quantized output kind (default UInt8QTy with offset 0).
Error ONNXModelLoader::loadQuantizeLinear(const ONNX_NAMESPACE::NodeProto &op,
                                          ArgumentDictionaryTy &dict) {
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  // Onnx documentation expects input to be Float or Int32.
  if (!(in.getElementType() == ElemKind::FloatTy ||
        in.getElementType() == ElemKind::Int32ITy)) {
    return MAKE_ERR(
        opErrMsg(op, "QuantizeLinear supports input to be Float or Int32."));
  }

  // Only scale with constant is supported.
  Constant *scale;
  ASSIGN_VALUE_OR_RETURN_ERR(scale, getConstantByName(op.input(1)));

  // Glow supports only per layer scale.
  // Accept either a true scalar (rank 0) or a 1-D tensor with one element.
  if (!((scale->getType()->dims().size() == 1 &&
         scale->getType()->dims()[0] == 1) ||
        scale->getType()->dims().size() == 0)) {
    return MAKE_ERR(opErrMsg(op, "QuantizeLinear: y_scale scalar value is only"
                                 " supported."));
  }

  float scaleValue = scale->getPayload().getHandle<float>().raw(0);

  // Default values as per Onnx documentation.
  int32_t offsetValue = 0;
  auto type = ElemKind::UInt8QTy;

  // Check if we have a offset vector.
  if (op.input_size() > 2) {
    auto &offsetTensorName = op.input(2);
    // Only offset with constant is supported.
    Constant *offset = nullptr;
    // Load the serialized offset vector.
    ASSIGN_VALUE_OR_RETURN_ERR(offset, getConstantByName(offsetTensorName));
    // Same scalar-only restriction as for y_scale above.
    if (!((offset->getType()->dims().size() == 1 &&
           offset->getType()->dims()[0] == 1) ||
          offset->getType()->dims().size() == 0)) {
      return MAKE_ERR(
          opErrMsg(op, "QuantizeLinear: y_zero_point scalar value is only"
                       " supported."));
    }

    // The zero-point's element type decides the quantized output kind.
    type = offset->getElementType();
    // Only uint8 and int8 values are supported as per onnx.
    if (type == ElemKind::UInt8QTy) {
      offsetValue = static_cast<int32_t>(
          offset->getPayload().getHandle<uint8_t>().raw(0));
    } else if (type == ElemKind::Int8QTy) {
      offsetValue =
          static_cast<int32_t>(offset->getPayload().getHandle<int8_t>().raw(0));
    } else {
      // This condition is hit when there is onnx graph creation issue or
      // constant is not created correctly.
      return MAKE_ERR(
          opErrMsg(op, "QuantizeLinear: Supports only uint8 or int8 data in"
                       " y_zero_point"));
    }
  }

  // The output keeps the input shape; only kind/scale/offset change.
  auto outDims = in.getType()->dims();
  auto outTy = mod_.uniqueType(type, outDims, scaleValue, offsetValue);
  Node *N = G_->createQuantize(loadOperatorName(op), in, outTy);

  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}
4440
4441Error ONNXModelLoader::loadConvertTo(const ONNX_NAMESPACE::NodeProto &op,
4442 ArgumentDictionaryTy &dict) {
4443 NodeValue in;
4444 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4445
4446 const auto *attr = dict.at("shape");
4447 RETURN_ERR_IF_NOT(
4448 attr->has_t(),
4449 opErrMsg(op, "ConvertTo should have t() field as \"shape\""));
4450 const auto &t = attr->t();
4451 std::vector<dim_t> shape;
4452 for (const auto d : t.dims()) {
4453 shape.push_back(d);
4454 }
4455
4456 auto type = ElemKind::FloatTy;
4457 RETURN_IF_ERR(onnxTensorDataTypeToElemKind(t.data_type(), &type));
4458 auto outTy = mod_.uniqueType(type, shape);
4459 Node *N = G_->createConvertTo(loadOperatorName(op), in, outTy);
4460
4461 RETURN_IF_ERR(addNodeAsOutput(op, N));
4462 return Error::success();
4463}
4464
4465Error ONNXModelLoader::loadDequantize(const ONNX_NAMESPACE::NodeProto &op,
4466 ArgumentDictionaryTy &dict) {
4467 NodeValue in;
4468 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4469
4470 Node *N = G_->createDequantize(loadOperatorName(op), in, ElemKind::FloatTy);
4471
4472 RETURN_IF_ERR(addNodeAsOutput(op, N));
4473 return Error::success();
4474}
4475
4476Error ONNXModelLoader::loadRegression(const ONNX_NAMESPACE::NodeProto &op,
4477 ArgumentDictionaryTy &dict) {
4478 NodeValue in;
4479 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4480 NodeValue expected;
4481 ASSIGN_VALUE_OR_RETURN_ERR(expected, getNodeValueByName(op.input(1)));
4482
4483 Node *N = G_->createRegression(loadOperatorName(op), in, expected);
4484
4485 RETURN_IF_ERR(addNodeAsOutput(op, N));
4486 return Error::success();
4487}
4488
4489Error ONNXModelLoader::loadBatchedAdd(const ONNX_NAMESPACE::NodeProto &op,
4490 ArgumentDictionaryTy &dict) {
4491 NodeValue batch;
4492 ASSIGN_VALUE_OR_RETURN_ERR(batch, getNodeValueByName(op.input(0)));
4493 NodeValue sample;
4494 ASSIGN_VALUE_OR_RETURN_ERR(sample, getNodeValueByName(op.input(1)));
4495
4496 Node *N = G_->createBatchedAdd(loadOperatorName(op), batch, sample);
4497
4498 RETURN_IF_ERR(addNodeAsOutput(op, N));
4499 return Error::success();
4500}
4501
/// Loads an ONNX CumSum operator. Glow currently supports cumulative sum
/// along axis 0 only: the optional axis input (input 1), when present and
/// constant, must hold the value 0. Honors the ONNX "exclusive" and
/// "reverse" attributes.
Error ONNXModelLoader::loadCumSum(const ONNX_NAMESPACE::NodeProto &op,
                                  ArgumentDictionaryTy &dict) {
  if (op.input_size() > 1) {
    Expected<NodeValue> axis = getNodeValueByName(op.input(1));
    if (axis) {
      if (auto *AC = llvm::dyn_cast<Constant>(axis->getNode())) {
        // NOTE(review): a 1-D single-element tensor is accepted as the
        // "0-D" axis here, and the handle read assumes the axis constant is
        // stored as int32 — confirm against exporters that emit int64.
        RETURN_ERR_IF_NOT(AC->getPayload().dims().size() == 1,
                          opErrMsg(op, "CumSum axis must be 0-D"));
        RETURN_ERR_IF_NOT(AC->getPayload().dims()[0] == 1,
                          opErrMsg(op, "CumSum axis must be 0-D"));
        RETURN_ERR_IF_NOT(AC->getHandle<int32_t>().at(0) == 0,
                          opErrMsg(op, "CumSum only supports axis == 0"));
      } else {
        return MAKE_ERR(opErrMsg(op, "Axis must be Constant"));
      }

      // Axis default is 0, which is fine.
    }
  }

  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
  // "exclusive" attribute (per the ONNX spec: exclude the top element).
  bool exclusive = false;
  if (dict.count("exclusive")) {
    ASSIGN_VALUE_OR_RETURN_ERR(exclusive, loadInt(dict.at("exclusive")));
  }

  // "reverse" attribute (per the ONNX spec: accumulate in reverse order).
  bool reverse = false;
  if (dict.count("reverse")) {
    ASSIGN_VALUE_OR_RETURN_ERR(reverse, loadInt(dict.at("reverse")));
  }

  // TODO: add axis/dim support
  Node *N =
      G_->createCumSum(loadOperatorName(op), input, 0, exclusive, reverse);
  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}
4540
4541Error ONNXModelLoader::loadScatterAssign(const ONNX_NAMESPACE::NodeProto &op,
4542 ArgumentDictionaryTy &dict) {
4543 NodeValue data;
4544 ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
4545 NodeValue indices;
4546 ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
4547 NodeValue slices;
4548 ASSIGN_VALUE_OR_RETURN_ERR(slices, getNodeValueByName(op.input(2)));
4549
4550 Node *N = G_->createScatterData(loadOperatorName(op), data, indices, slices);
4551
4552 RETURN_IF_ERR(addNodeAsOutput(op, N));
4553 return Error::success();
4554}
4555
4556Error ONNXModelLoader::loadIntLookupTable(const ONNX_NAMESPACE::NodeProto &op,
4557 ArgumentDictionaryTy &dict) {
4558 NodeValue in;
4559 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4560
4561 if (in.getType()->getElementType() == ElemKind::Int8QTy) {
4562 std::vector<int8_t> values;
4563 ASSIGN_VALUE_OR_RETURN_ERR(values, getShape<int8_t>(dict["values"]));
4564 std::vector<dim_t> shape;
4565 ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["shape"]));
4566
4567 auto outTy = mod_.uniqueType(in.getElementType(), shape);
4568 Node *N = G_->createIntLookupTable<int8_t>(loadOperatorName(op), in, values,
4569 outTy);
4570
4571 RETURN_IF_ERR(addNodeAsOutput(op, N));
4572 return Error::success();
4573 } else if (in.getType()->getElementType() == ElemKind::Int16QTy) {
4574 std::vector<int16_t> values;
4575 ASSIGN_VALUE_OR_RETURN_ERR(values, getShape<int16_t>(dict["values"]));
4576 std::vector<dim_t> shape;
4577 ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["shape"]));
4578
4579 auto outTy = mod_.uniqueType(in.getElementType(), shape);
4580 Node *N = G_->createIntLookupTable<int16_t>(loadOperatorName(op), in,
4581 values, outTy);
4582
4583 RETURN_IF_ERR(addNodeAsOutput(op, N));
4584 return Error::success();
4585 } else {
4586 return MAKE_ERR(strFormat("Lookup table type '%s' not supported!",
4587 in.getType()->getElementName().str().c_str()));
4588 }
4589}
4590
4591Error ONNXModelLoader::loadLengthsRangeFill(const ONNX_NAMESPACE::NodeProto &op,
4592 ArgumentDictionaryTy &dict) {
4593 NodeValue lengths;
4594 ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0)));
4595 unsigned_t size;
4596 ASSIGN_VALUE_OR_RETURN_ERR(size, loadInt(dict.at("size")));
4597
4598 Node *N = G_->createLengthsRangeFill(loadOperatorName(op), lengths, size);
4599
4600 RETURN_IF_ERR(addNodeAsOutput(op, N));
4601 return Error::success();
4602}
4603
4604Error ONNXModelLoader::loadRescaleQuantized(const ONNX_NAMESPACE::NodeProto &op,
4605 ArgumentDictionaryTy &dict) {
4606 NodeValue in;
4607 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4608 float scale;
4609 ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict.at("scale")));
4610 unsigned_t offset;
4611 ASSIGN_VALUE_OR_RETURN_ERR(offset, loadInt(dict.at("offset")));
4612
4613 auto inTy = in.getType();
4614 auto outTy =
4615 mod_.uniqueType(inTy->getElementType(), inTy->dims(), scale, offset);
4616
4617 Node *N = G_->createRescaleQuantized(loadOperatorName(op), in, outTy);
4618
4619 RETURN_IF_ERR(addNodeAsOutput(op, N));
4620 return Error::success();
4621}
4622
4623Error ONNXModelLoader::loadRowwiseQuantizedSparseLengthsWeightedSum(
4624 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
4625 Constant *data;
4626 ASSIGN_VALUE_OR_RETURN_ERR(data, getConstantByName(op.input(0)));
4627 Constant *scales;
4628 ASSIGN_VALUE_OR_RETURN_ERR(scales, getConstantByName(op.input(1)));
4629 Constant *offsets;
4630 ASSIGN_VALUE_OR_RETURN_ERR(offsets, getConstantByName(op.input(2)));
4631 NodeValue weights;
4632 ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(3)));
4633 NodeValue indices;
4634 ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(4)));
4635 NodeValue lengths;
4636 ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(5)));
4637 LengthsMode lengthsMode;
4638 ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
4639
4640 Node *N = G_->createRowwiseQuantizedSparseLengthsWeightedSum(
4641 loadOperatorName(op), data, scales, offsets, weights, indices, lengths,
4642 /* precision */ ElemKind::FloatTy, /* useFP16Accumulation */ false,
4643 lengthsMode);
4644
4645 RETURN_IF_ERR(addNodeAsOutput(op, N));
4646 return Error::success();
4647}
4648
4649Error ONNXModelLoader::loadFusedRowwiseQuantizedSparseLengthsWeightedSum(
4650 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
4651 NodeValue data;
4652 ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
4653 NodeValue weights;
4654 ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(1)));
4655 NodeValue indices;
4656 ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(2)));
4657 NodeValue lengths;
4658 ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(3)));
4659 LengthsMode lengthsMode;
4660 ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
4661
4662 Node *N = G_->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
4663 loadOperatorName(op), data, weights, indices, lengths,
4664 /* useFP16Accumulation */ false, lengthsMode);
4665
4666 RETURN_IF_ERR(addNodeAsOutput(op, N));
4667 return Error::success();
4668}
4669
4670Error ONNXModelLoader::loadFusedRowwiseQuantizedSparseLengthsSum(
4671 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
4672 NodeValue data;
4673 ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
4674 NodeValue indices;
4675 ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
4676 NodeValue lengths;
4677 ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(2)));
4678 LengthsMode lengthsMode;
4679 ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
4680
4681 Storage *dataS = llvm::dyn_cast<Storage>(data);
4682 Node *N = G_->createFusedRowwiseQuantizedSparseLengthsSum(
4683 loadOperatorName(op), dataS, indices, lengths,
4684 /* useFP16Accumulation */ false, lengthsMode);
4685
4686 RETURN_IF_ERR(addNodeAsOutput(op, N));
4687 return Error::success();
4688}
4689
4690Error ONNXModelLoader::loadFullyConnected(const ONNX_NAMESPACE::NodeProto &op,
4691 ArgumentDictionaryTy &dict) {
4692 NodeValue in;
4693 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4694 NodeValue w;
4695 ASSIGN_VALUE_OR_RETURN_ERR(w, getNodeValueByName(op.input(1)));
4696 NodeValue b;
4697 ASSIGN_VALUE_OR_RETURN_ERR(b, getNodeValueByName(op.input(2)));
4698
4699 unsigned_t axis = 1;
4700 if (dict.count("axis")) {
4701 ASSIGN_VALUE_OR_RETURN_ERR(
4702 axis, loadAxis<unsigned_t>(dict.at("axis"), in.dims().size()));
4703 }
4704
4705 Node *N = G_->createFullyConnected(loadOperatorName(op), in, w, b, axis);
4706
4707 RETURN_IF_ERR(addNodeAsOutput(op, N));
4708 return Error::success();
4709}
4710
4711Error ONNXModelLoader::loadRowwiseQuantizedFullyConnected(
4712 const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
4713 NodeValue input;
4714 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
4715
4716 NodeValue weights;
4717 ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(1)));
4718 auto *weightsC = llvm::dyn_cast<Constant>(weights.getNode());
4719
4720 NodeValue scales;
4721 ASSIGN_VALUE_OR_RETURN_ERR(scales, getNodeValueByName(op.input(2)));
4722 auto *scalesC = llvm::dyn_cast<Constant>(scales.getNode());
4723
4724 NodeValue offsets;
4725 ASSIGN_VALUE_OR_RETURN_ERR(offsets, getNodeValueByName(op.input(3)));
4726 auto *offsetsC = llvm::dyn_cast<Constant>(offsets.getNode());
4727
4728 NodeValue bias;
4729 ASSIGN_VALUE_OR_RETURN_ERR(bias, getNodeValueByName(op.input(4)));
4730 auto *biasC = llvm::dyn_cast<Constant>(bias.getNode());
4731
4732 float outScale;
4733 ASSIGN_VALUE_OR_RETURN_ERR(outScale, loadFloat(dict.at("out_scale")));
4734 int32_t outOffset;
4735 ASSIGN_VALUE_OR_RETURN_ERR(outOffset, loadInt(dict.at("out_offset")));
4736
4737 auto outTy =
4738 mod_.uniqueType(ElemKind::Int8QTy, {input.dims()[0], weights.dims()[0]},
4739 outScale, outOffset);
4740
4741 Node *N = G_->createRowwiseQuantizedFullyConnected(
4742 loadOperatorName(op) + ".rowwise_quantized_fc", input, weightsC, scalesC,
4743 offsetsC, biasC, outTy);
4744
4745 return addNodeAsOutput(op, N);
4746}
4747
/// Loads an ONNX NonMaxSuppression operator (or the TF-style V4 variant when
/// \p isV4 is true). The optional inputs 2-4 (max_output_boxes_per_class,
/// iou_threshold, score_threshold) must be Constants when present; missing
/// ones fall back to zero defaults. V4 additionally requires
/// pad_to_max_output_size == 1 because Glow output sizes are compile-time
/// constants.
Error ONNXModelLoader::loadNonMaxSuppression(
    const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict,
    bool isV4) {
  NodeValue boxesNV;
  ASSIGN_VALUE_OR_RETURN_ERR(boxesNV, getNodeValueByName(op.input(0)));
  NodeValue scoresNV;
  ASSIGN_VALUE_OR_RETURN_ERR(scoresNV, getNodeValueByName(op.input(1)));
  // Optional input 2: max boxes kept per class. Accepts int64 or int32
  // constants; defaults to 0 when absent or named empty.
  unsigned maxOutputBoxesPerClass = 0;
  Constant *maxOutputBoxesPerClassC = nullptr;
  if (op.input_size() > 2 && !op.input(2).empty()) {
    maxOutputBoxesPerClassC = getConstantByNameOrNull(op.input(2));
    RETURN_ERR_IF_NOT(maxOutputBoxesPerClassC,
                      "NMS: maxOutputBoxesPerClass is not a constant tensor.");
    if (maxOutputBoxesPerClassC->getPayload().getElementType() ==
        ElemKind::Int64ITy) {
      maxOutputBoxesPerClass =
          maxOutputBoxesPerClassC->getPayload().getHandle<int64_t>().raw(0);
    } else if (maxOutputBoxesPerClassC->getPayload().getElementType() ==
               ElemKind::Int32ITy) {
      maxOutputBoxesPerClass =
          maxOutputBoxesPerClassC->getPayload().getHandle<int32_t>().raw(0);
    } else {
      return MAKE_ERR("NMS: Unsupported type for maxoutputboxesperclass.");
    }
  }
  // Optional input 3: IOU threshold (float constant, default 0).
  float iouThreshold = 0.0f;
  Constant *iouThresholdC = nullptr;
  if (op.input_size() > 3 && !op.input(3).empty()) {
    iouThresholdC = getConstantByNameOrNull(op.input(3));
    RETURN_ERR_IF_NOT(iouThresholdC,
                      "NMS: iouThreshold is not a constant tensor.");
    iouThreshold = iouThresholdC->getPayload().getHandle<float>().raw(0);
  }
  // Optional input 4: score threshold (float constant, default 0).
  float scoreThreshold = 0.0f;
  Constant *scoreThresholdC = nullptr;
  if (op.input_size() > 4 && !op.input(4).empty()) {
    scoreThresholdC = getConstantByNameOrNull(op.input(4));
    RETURN_ERR_IF_NOT(scoreThresholdC,
                      "NMS: scoreThreshold is not a constant tensor.");
    scoreThreshold = scoreThresholdC->getPayload().getHandle<float>().raw(0);
  }

  // Defaults to 0 which is the same representation as TF.
  unsigned centerPointBox = 0;
  if (dict.count("center_point_box")) {
    ASSIGN_VALUE_OR_RETURN_ERR(centerPointBox,
                               loadInt(dict.at("center_point_box")));
  }

  int32_t padToMaxOutputSize = 0;
  if (isV4) {
    if (dict.count("pad_to_max_output_size")) {
      ASSIGN_VALUE_OR_RETURN_ERR(padToMaxOutputSize,
                                 loadInt(dict.at("pad_to_max_output_size")));
    }

    // Does it make sense within GLOW context to have no padding? Since Size has
    // to be compile time constant.
    RETURN_ERR_IF_NOT(
        padToMaxOutputSize == 1,
        opErrMsg(op, "NonMaxSuppressionV4 does not support non-padding mode."));
  }

  // Create Node.
  std::string opName = loadOperatorName(op);
  Node *N = nullptr;

  if (isV4) {
    N = G_->createNonMaxSuppressionV4(opName, boxesNV, scoresNV, centerPointBox,
                                      maxOutputBoxesPerClass, iouThreshold,
                                      scoreThreshold);
  } else {
    N = G_->createNonMaxSuppressionONNX(opName, boxesNV, scoresNV,
                                        centerPointBox, maxOutputBoxesPerClass,
                                        iouThreshold, scoreThreshold);
  }
  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}
4827
/// Loads a Splat operator; implemented as a ConstantOfShape whose single
/// value is splatted over the whole output.
Error ONNXModelLoader::loadSplat(const ONNX_NAMESPACE::NodeProto &op,
                                 ArgumentDictionaryTy &dict) {
  return loadConstantOfShape(op, dict, true /* isSplat */);
}
4832
4833Error ONNXModelLoader::loadInsertTensor(const ONNX_NAMESPACE::NodeProto &op,
4834 ArgumentDictionaryTy &dict) {
4835 NodeValue big;
4836 ASSIGN_VALUE_OR_RETURN_ERR(big, getNodeValueByName(op.input(0)));
4837 NodeValue small;
4838 ASSIGN_VALUE_OR_RETURN_ERR(small, getNodeValueByName(op.input(1)));
4839
4840 std::vector<dim_t> start;
4841 ASSIGN_VALUE_OR_RETURN_ERR(start, getShape<dim_t>(dict["start"]));
4842
4843 unsigned_t count = 1;
4844 if (dict.count("count")) {
4845 ASSIGN_VALUE_OR_RETURN_ERR(count, loadInt(dict.at("count")));
4846 }
4847
4848 unsigned_t axis = 0;
4849 if (dict.count("axis")) {
4850 ASSIGN_VALUE_OR_RETURN_ERR(
4851 axis, loadAxis<unsigned_t>(dict.at("axis"), big.dims().size()));
4852 }
4853
4854 Node *N = G_->createInsertTensor(loadOperatorName(op), big, small, start,
4855 count, axis);
4856
4857 RETURN_IF_ERR(addNodeAsOutput(op, N));
4858 return Error::success();
4859}
4860
4861Error ONNXModelLoader::loadIdentity(const ONNX_NAMESPACE::NodeProto &op,
4862 ArgumentDictionaryTy &dict) {
4863 NodeValue in;
4864 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
4865 RETURN_IF_ERR(addNodeAsOutput(op, in));
4866 return Error::success();
4867}
4868
4869Error ONNXModelLoader::loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op,
4870 ArgumentDictionaryTy &dict) {
4871 const std::string &opName = loadOperatorName(op);
4872
4873 NodeValue input;
4874 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
4875
4876 std::vector<unsigned_t> outputShape;
4877 ASSIGN_VALUE_OR_RETURN_ERR(outputShape,
4878 getShape<unsigned_t>(dict["output_size"]));
4879
4880 ShapeNHWC idim(input.dims());
4881
4882 auto outTy = mod_.uniqueTypeWithNewShape(
4883 input.getType(), {idim.n, outputShape[0], outputShape[1], idim.c});
4884
4885 Node *N = G_->createAdaptiveAvgPool(opName, input, outTy);
4886
4887 return addNodeAsOutput(op, N);
4888}
4889
4890Error ONNXModelLoader::loadFlip(const ONNX_NAMESPACE::NodeProto &op,
4891 ArgumentDictionaryTy &dict) {
4892 NodeValue input;
4893 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
4894
4895 unsigned_t axis = 0;
4896 if (dict.count("axis")) {
4897 ASSIGN_VALUE_OR_RETURN_ERR(
4898 axis, loadAxis<unsigned_t>(dict.at("axis"), input.dims().size()));
4899 }
4900
4901 Node *N = G_->createFlip(loadOperatorName(op) + ".flip", input, axis);
4902
4903 RETURN_IF_ERR(addNodeAsOutput(op, N));
4904 return Error::success();
4905}
4906
4907Error ONNXModelLoader::loadAudioSpectrogram(const ONNX_NAMESPACE::NodeProto &op,
4908 ArgumentDictionaryTy &dict) {
4909 NodeValue input;
4910 ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
4911
4912 // Get window size (Required).
4913 int64_t windowSize;
4914 RETURN_ERR_IF_NOT(
4915 dict.count("window_size"),
4916 "ONNX AudioSpectrogram 'window_size' attribute is required!");
4917 ASSIGN_VALUE_OR_RETURN_ERR(windowSize, loadInt(dict.at("window_size")));
4918
4919 // Get window stride (Required).
4920 int64_t windowStride;
4921 RETURN_ERR_IF_NOT(dict.count("stride"),
4922 "ONNX AudioSpectrogram 'stride' attribute is required!");
4923 ASSIGN_VALUE_OR_RETURN_ERR(windowStride, loadInt(dict.at("stride")));
4924
4925 // Get magnitude squared flag (Optional)(Default: 1).
4926 int magnitudeSquared = 1;
4927 if (dict.count("magnitude_squared") &&
4928 dict.at("magnitude_squared")->has_i()) {
4929 magnitudeSquared = dict.at("magnitude_squared")->i();
4930 }
4931
4932 Node *N = G_->createAudioSpectrogram(loadOperatorName(op), input, windowSize,
4933 windowStride, (bool)magnitudeSquared);
4934
4935 RETURN_IF_ERR(addNodeAsOutput(op, N));
4936 return Error::success();
4937}
4938
/// Loads an ONNX ROIAlign operator. Inputs: feature map (NCHW), boxes, and
/// batch indices. The feature map is transposed to NHWC for Glow's ROIAlign
/// node and the result is transposed back to NCHW. Supported attributes:
/// mode ("avg"/"max"), rotated, aligned, output_height, output_width,
/// sampling_ratio, spatial_scale.
Error ONNXModelLoader::loadROIAlign(const ONNX_NAMESPACE::NodeProto &op,
                                    ArgumentDictionaryTy &dict) {
  NodeValue featureMap;
  ASSIGN_VALUE_OR_RETURN_ERR(featureMap, getNodeValueByName(op.input(0)));
  NodeValue boxes;
  ASSIGN_VALUE_OR_RETURN_ERR(boxes, getNodeValueByName(op.input(1)));
  NodeValue batchIndices;
  ASSIGN_VALUE_OR_RETURN_ERR(batchIndices, getNodeValueByName(op.input(2)));

  // Pooling mode within each ROI bin; ONNX default is "avg".
  PoolingMode mode = PoolingMode::AVG;
  if (dict.count("mode")) {
    std::string modeStr;
    ASSIGN_VALUE_OR_RETURN_ERR(modeStr, loadStr(dict.at("mode")));
    if (modeStr == "avg") {
      mode = PoolingMode::AVG;
    } else if (modeStr == "max") {
      mode = PoolingMode::MAX;
    } else {
      return MAKE_ERR(strFormat("Invalid PoolingMode: %s", modeStr.c_str()));
    }
  }

  bool rotated = false;
  if (dict.count("rotated")) {
    ASSIGN_VALUE_OR_RETURN_ERR(rotated, loadInt(dict.at("rotated")));
  }

  bool aligned = false;
  if (dict.count("aligned")) {
    ASSIGN_VALUE_OR_RETURN_ERR(aligned, loadInt(dict.at("aligned")));
  }

  uint32_t outputHeight = 1;
  if (dict.count("output_height")) {
    ASSIGN_VALUE_OR_RETURN_ERR(outputHeight, loadInt(dict.at("output_height")));
  }

  uint32_t outputWidth = 1;
  if (dict.count("output_width")) {
    ASSIGN_VALUE_OR_RETURN_ERR(outputWidth, loadInt(dict.at("output_width")));
  }

  uint32_t samplingRatio = 0;
  if (dict.count("sampling_ratio")) {
    ASSIGN_VALUE_OR_RETURN_ERR(samplingRatio,
                               loadInt(dict.at("sampling_ratio")));
  }

  float spatialScale = 1.0;
  if (dict.count("spatial_scale")) {
    ASSIGN_VALUE_OR_RETURN_ERR(spatialScale,
                               loadFloat(dict.at("spatial_scale")));
  }

  // Glow's ROIAlign expects NHWC; wrap the node in a transpose pair so the
  // operator keeps ONNX's NCHW layout externally.
  const std::string &opName = loadOperatorName(op);
  featureMap = G_->createTranspose(opName, featureMap, NCHW2NHWC);
  Node *N = G_->createROIAlign(opName, featureMap, boxes, batchIndices,
                               outputHeight, outputWidth, samplingRatio,
                               spatialScale, aligned, rotated, mode);
  N = G_->createTranspose(opName, N, NHWC2NCHW);
  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}
5002
5003Error ONNXModelLoader::loadMFCC(const ONNX_NAMESPACE::NodeProto &op,
5004 ArgumentDictionaryTy &dict) {
5005 NodeValue spectrogram;
5006 ASSIGN_VALUE_OR_RETURN_ERR(spectrogram, getNodeValueByName(op.input(0)));
5007
5008 // Get sample rate [Hz] (Required).
5009 float sampleRate;
5010 RETURN_ERR_IF_NOT(dict.count("sample_rate"),
5011 "ONNX MFCC 'sample_rate' attribute is required!");
5012 ASSIGN_VALUE_OR_RETURN_ERR(sampleRate, loadFloat(dict.at("sample_rate")));
5013
5014 // Get lower frequency [Hz] (Required).
5015 float lowerFrequency;
5016 RETURN_ERR_IF_NOT(dict.count("lower_frequency_limit"),
5017 "ONNX MFCC 'lower_frequency_limit' attribute is required!");
5018 ASSIGN_VALUE_OR_RETURN_ERR(lowerFrequency,
5019 loadFloat(dict.at("lower_frequency_limit")));
5020
5021 // Get upper frequency [Hz] (Required).
5022 float upperFrequency;
5023 RETURN_ERR_IF_NOT(dict.count("upper_frequency_limit"),
5024 "ONNX MFCC 'upper_frequency_limit' attribute is required!");
5025 ASSIGN_VALUE_OR_RETURN_ERR(upperFrequency,
5026 loadFloat(dict.at("upper_frequency_limit")));
5027
5028 // Get filter bank count (Required).
5029 int64_t filterBankCount;
5030 RETURN_ERR_IF_NOT(
5031 dict.count("filterbank_channel_count"),
5032 "ONNX MFCC 'filterbank_channel_count' attribute is required!");
5033 ASSIGN_VALUE_OR_RETURN_ERR(filterBankCount,
5034 loadInt(dict.at("filterbank_channel_count")));
5035
5036 // Get number of coefficients (Required).
5037 int64_t numCoefficients;
5038 RETURN_ERR_IF_NOT(dict.count("dct_coefficient_count"),
5039 "ONNX MFCC 'dct_coefficient_count' attribute is required!");
5040 ASSIGN_VALUE_OR_RETURN_ERR(numCoefficients,
5041 loadInt(dict.at("dct_coefficient_count")));
5042
5043 Node *N = G_->createMFCC(loadOperatorName(op), spectrogram, sampleRate,
5044 lowerFrequency, upperFrequency, filterBankCount,
5045 numCoefficients);
5046
5047 RETURN_IF_ERR(addNodeAsOutput(op, N));
5048 return Error::success();
5049}
5050
5051Error ONNXModelLoader::loadAsin(const ONNX_NAMESPACE::NodeProto &op,
5052 const ArgumentDictionaryTy &dict) {
5053 const std::string &opName = loadOperatorName(op);
5054 NodeValue in;
5055 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
5056 auto outTy = mod_.uniqueType(*(in.getType()));
5057 Node *node = G_->createAsin(opName, outTy, in);
5058 RETURN_IF_ERR(addNodeAsOutput(op, node));
5059 return Error::success();
5060}
5061
5062Error ONNXModelLoader::loadAcos(const ONNX_NAMESPACE::NodeProto &op,
5063 const ArgumentDictionaryTy &dict) {
5064 const std::string &opName = loadOperatorName(op);
5065 NodeValue in;
5066 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
5067 auto outTy = mod_.uniqueType(*(in.getType()));
5068 Node *node = G_->createAcos(opName, outTy, in);
5069 RETURN_IF_ERR(addNodeAsOutput(op, node));
5070 return Error::success();
5071}
5072
5073Error ONNXModelLoader::loadAtan(const ONNX_NAMESPACE::NodeProto &op,
5074 const ArgumentDictionaryTy &dict) {
5075 const std::string &opName = loadOperatorName(op);
5076 NodeValue in;
5077 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
5078 auto outTy = mod_.uniqueType(*(in.getType()));
5079 Node *node = G_->createAtan(opName, outTy, in);
5080 RETURN_IF_ERR(addNodeAsOutput(op, node));
5081 return Error::success();
5082}
5083
5084Error ONNXModelLoader::loadSign(const ONNX_NAMESPACE::NodeProto &op,
5085 const ArgumentDictionaryTy &dict) {
5086 const std::string &opName = loadOperatorName(op);
5087 NodeValue in;
5088 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
5089 Node *node = G_->createSign(opName, in);
5090 RETURN_IF_ERR(addNodeAsOutput(op, node));
5091 return Error::success();
5092}
5093
/// Loads the ONNX Softmax operator. Glow's SoftMax reduces along the
/// innermost dimension of a 2D input, so the input is flattened (and, for
/// opset 13, transposed) to put the softmax axis last, then reshaped back.
Error ONNXModelLoader::loadSoftmax(const ONNX_NAMESPACE::NodeProto &op,
                                   const ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  RETURN_ERR_IF_NOT(in.dims().size() >= 2, "SoftMax input dims must be >= 2");

  // Create a constant to store labels to be used in SoftMaxGradNode.
  auto selected =
      mod_.createConstant(ElemKind::Int64ITy, {in.dims()[0], 1}, "selected");

  if (opsetVersion_ == 13) {
    // Opset 13 computes softmax along a single axis (default: the last one)
    // rather than flattening everything after 'axis'. Only 4D inputs are
    // supported here; the requested axis is moved to the innermost position
    // via a transpose and restored afterwards.
    int axis = in.dims().size() - 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis"), in.dims().size()));
    }
    RETURN_ERR_IF_NOT(in.dims().size() == 4, "SoftMax 13 input dims must be 4");
    // Compute the shuffle layout based on axis input. 'shuffle' moves the
    // softmax axis to the back; 'shuffleBack' is its inverse permutation.
    std::vector<unsigned_t> shuffle;
    std::vector<unsigned_t> shuffleBack;
    switch (axis) {
    case 0:
      shuffle = {1u, 2u, 3u, 0u};
      shuffleBack = {3u, 0u, 1u, 2u};
      break;

    case 1:
      shuffle = {0u, 2u, 3u, 1u};
      shuffleBack = {0u, 3u, 1u, 2u};
      break;

    case 2:
      shuffle = {0u, 1u, 3u, 2u};
      shuffleBack = {0u, 1u, 3u, 2u};
      break;

    case 3:
      // Axis already innermost; both permutations are the identity.
      shuffle = {0u, 1u, 2u, 3u};
      shuffleBack = {0u, 1u, 2u, 3u};
      break;

    default:
      return MAKE_ERR("SoftMax Axis must be <=3");
      break;
    }
    auto *NH = G_->createTranspose(opName, in, shuffle);
    auto *FN = G_->createFlattenV1("reshapeInput", NH, axis);
    auto *SM = G_->createSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    // Note: reshape to the *transposed* dims first, then transpose back.
    auto origInDims = NH->getResult().dims();
    auto *RN = G_->createReshape("reshapeOutput", SM, origInDims);
    auto *NC = G_->createTranspose(opName, RN, shuffleBack);
    RETURN_IF_ERR(addNodeAsOutput(op, NC));
  } else {
    // ONNX allows shapes like <N x 10 x 1 x 1 >. Flatten the inputs to the
    // softmax function. This is similar to a bitcast operation.
    int axis = 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis"), in.dims().size()));
    }
    auto *FN = G_->createFlatten("reshapeInput", in, axis);
    auto *SM = G_->createSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    auto origInDims = in.getType()->dims();
    auto *RN = G_->createReshape("reshapeOutput", SM, origInDims);
    RETURN_IF_ERR(addNodeAsOutput(op, RN));
  }
  return Error::success();
}
5169
/// Loads the ONNX LogSoftmax operator. Mirrors loadSoftmax: the input is
/// flattened (and, for opset 13, transposed) so the reduction axis is
/// innermost, then the result is reshaped/transposed back.
Error ONNXModelLoader::loadLogSoftmax(const ONNX_NAMESPACE::NodeProto &op,
                                      const ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  RETURN_ERR_IF_NOT(in.dims().size() >= 2,
                    "LogSoftMax input dims must be >= 2");

  // Create a constant to store labels to be used in SoftMaxGradNode.
  auto selected =
      mod_.createConstant(ElemKind::Int64ITy, {in.dims()[0], 1}, "selected");

  if (opsetVersion_ == 13) {
    // Opset 13 computes log-softmax along a single axis (default: the last
    // one). Only 4D inputs are supported; the requested axis is moved to the
    // innermost position via a transpose and restored afterwards.
    int axis = in.dims().size() - 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis"), in.dims().size()));
    }
    RETURN_ERR_IF_NOT(in.dims().size() == 4,
                      "LogSoftMax 13 input dims must be 4");
    // Compute the shuffle layout based on axis input. 'shuffle' moves the
    // reduction axis to the back; 'shuffleBack' is its inverse permutation.
    std::vector<unsigned_t> shuffle;
    std::vector<unsigned_t> shuffleBack;
    switch (axis) {
    case 0:
      shuffle = {1u, 2u, 3u, 0u};
      shuffleBack = {3u, 0u, 1u, 2u};
      break;

    case 1:
      shuffle = {0u, 2u, 3u, 1u};
      shuffleBack = {0u, 3u, 1u, 2u};
      break;

    case 2:
      shuffle = {0u, 1u, 3u, 2u};
      shuffleBack = {0u, 1u, 3u, 2u};
      break;

    case 3:
      // Axis already innermost; both permutations are the identity.
      shuffle = {0u, 1u, 2u, 3u};
      shuffleBack = {0u, 1u, 2u, 3u};
      break;

    default:
      return MAKE_ERR("LogSoftMax Axis must be <=3");
      break;
    }
    auto *NH = G_->createTranspose(opName, in, shuffle);
    auto *FN = G_->createFlattenV1("reshapeInput", NH, axis);
    auto *SM = G_->createLogSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    // Note: reshape to the *transposed* dims first, then transpose back.
    auto origInDims = NH->getResult().dims();
    auto *RN = G_->createReshape("reshapeOutput", SM, origInDims);
    auto *NC = G_->createTranspose(opName, RN, shuffleBack);
    RETURN_IF_ERR(addNodeAsOutput(op, NC));
  } else {
    // ONNX allows shapes like <N x 10 x 1 x 1 >. Flatten the inputs to the
    // logsoftmax function. This is similar to a bitcast operation.
    int axis = 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis"), in.dims().size()));
    }
    auto *FN = G_->createFlatten("reshapeInput", in, axis);
    auto *SM = G_->createLogSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    auto origInDims = in.getType()->dims();
    auto *RN = G_->createReshape("reshapeOutput", SM, origInDims);
    RETURN_IF_ERR(addNodeAsOutput(op, RN));
  }
  return Error::success();
}
5247
5248Error ONNXModelLoader::loadScatterData(const ONNX_NAMESPACE::NodeProto &op,
5249 const ArgumentDictionaryTy &dict) {
5250
5251 const std::string &opName = loadOperatorName(op);
5252 NodeValue data;
5253 ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
5254 NodeValue indices;
5255 ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
5256 NodeValue values;
5257 ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2)));
5258
5259 RETURN_ERR_IF_NOT(indices.dims().size() == 2,
5260 opErrMsg(op, "Indices must be a 2D tensor!"));
5261 RETURN_ERR_IF_NOT(indices.dims()[0] == values.dims()[0],
5262 opErrMsg(op, "Indices and values must have same lengths!"));
5263
5264 bool cumulative = false;
5265 if (dict.count("cumulative")) {
5266 ASSIGN_VALUE_OR_RETURN_ERR(cumulative, loadInt(dict.at("cumulative")));
5267 }
5268
5269 Node *node = G_->createScatterData(opName, data, indices, values, cumulative);
5270 RETURN_IF_ERR(addNodeAsOutput(op, node));
5271 return Error::success();
5272}
5273
5274Error ONNXModelLoader::loadTopK(const ONNX_NAMESPACE::NodeProto &op,
5275 ArgumentDictionaryTy &dict) {
5276 const std::string &opName = loadOperatorName(op);
5277 NodeValue in;
5278 ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
5279 RETURN_ERR_IF_NOT(
5280 op.input_size() <= 2,
5281 opErrMsg(
5282 op,
5283 strFormat(
5284 "TopK: Maximum number of inputs is 2, but found input size %d ",
5285 op.input_size())));
5286 unsigned_t k = 0;
5287 if (op.input_size() > 1) {
5288 Constant *kConst = getConstantByNameOrNull(op.input(1));
5289 RETURN_ERR_IF_NOT(
5290 kConst, opErrMsg(op, "TopK: Non-constant k is not supported by Glow."));
5291 RETURN_ERR_IF_NOT(
5292 kConst->getElementType() == ElemKind::Int64ITy,
5293 opErrMsg(op,
5294 strFormat("TopK: k input must be of type Int64, but found "
5295 "input type '%s' ",
5296 kConst->getType()->getElementName().str().c_str())));
5297 auto constH = kConst->getPayload().getHandle<int64_t>();
5298 k = constH.at({0});
5299 } else {
5300 ASSIGN_VALUE_OR_RETURN_ERR(k, loadInt(dict["k"]));
5301 }
5302
5303 int lastDim = in.dims().size() - 1;
5304 int axis = lastDim;
5305 if (dict.count("axis")) {
5306 ASSIGN_VALUE_OR_RETURN_ERR(axis,
5307 loadAxis<int>(dict["axis"], in.dims().size()));
5308 }
5309
5310 RETURN_ERR_IF_NOT(
5311 axis == lastDim,
5312 opErrMsg(
5313 op,
5314 strFormat(
5315 "TopK: Currently only support axis %d being last dimension %d ",
5316 axis, lastDim)));
5317
5318 auto *R = G_->createTopK(opName, in, k);
5319 RETURN_IF_ERR(addNodeAsOutput(op, R));
5320 return Error::success();
5321}
5322
/// Loads the ONNX Loop operator by fully unrolling it at load time: the body
/// subgraph is instantiated once per iteration, chained through its loop
/// carried dependencies. This requires the trip count / exit condition to be
/// statically resolvable (constants), and unrolling is capped by
/// loopUnrollLimit.
Error ONNXModelLoader::loadLoop(const ONNX_NAMESPACE::NodeProto &op,
                                const ArgumentDictionaryTy &dict) {
  // Maximum trip count M: optional (empty input name means "run until the
  // condition turns false"); when present it must be a constant int64 scalar.
  int64_t maxTripCount;
  bool ignoreMaxTripCount = op.input(0).empty();
  if (!ignoreMaxTripCount) {
    Constant *M = getConstantByNameOrNull(op.input(0));
    RETURN_ERR_IF_NOT(M, "Loop operator M input must be a constant.");
    RETURN_ERR_IF_NOT(M->getElementType() == ElemKind::Int64ITy,
                      "Loop operator M input must be int64.");
    maxTripCount = (M->getPayload().getHandle<int64_t>()).raw(0);

    RETURN_ERR_IF_NOT(
        maxTripCount >= 0,
        strFormat("Loop operator trip count (%ld) must be positive",
                  maxTripCount));
  }

  // Initial condition: optional as well; when present it must be a constant
  // bool scalar.
  bool condOrig = false;
  bool ignoreCond = op.input(1).empty();
  if (!ignoreCond) {
    Constant *cond = getConstantByNameOrNull(op.input(1));
    RETURN_ERR_IF_NOT(cond, "Loop operator cond input must be a constant.");
    RETURN_ERR_IF_NOT(cond->getElementType() == ElemKind::BoolTy,
                      "Loop operator cond input must be bool.");
    condOrig = (cond->getPayload().getHandle<bool>()).raw(0);
  }

  RETURN_ERR_IF_NOT(dict.count("body"), "Loop body not found.");
  auto body = dict.at("body")->g();

  // 2 + N (i.e., maximum trip-count, cond, and N)
  const int numLoopInputs = op.input_size();
  // N + K (final N loop carried dependency values then K scan_outputs)
  const int numLoopOutputs = op.output_size();
  // 2 + N (i.e., iteration_num, condition, and N loop carried dependencies)
  const int numBodyInputs = body.input_size();
  // 1 + N + K (i.e., condition, N loop carried dependencies, and K
  // scan_outputs)
  const int numBodyOutputs = body.output_size();

  RETURN_ERR_IF_NOT(numLoopInputs >= 2 && numLoopInputs == numBodyInputs &&
                        numLoopOutputs == numBodyOutputs - 1 &&
                        numLoopOutputs >= 1,
                    "Mismatched inputs/outputs of Loop and subgraph.");

  // Derive N (loop carried dependencies) and K (scan outputs) from the
  // arithmetic documented above.
  const int numK = numBodyOutputs - numBodyInputs + 1;
  const int numN = numLoopInputs - 2;

  CHECK_GE(numN, 0) << "Invalid number of v_initial in Loop operator : "
                    << numN;
  CHECK_GE(numK, 0) << "Invalid number of scan_outputs in Loop operator : "
                    << numK;

  // Handle a loop with no iterations.
  if ((!ignoreMaxTripCount && maxTripCount == 0) ||
      (!ignoreCond && !condOrig)) {
    // No need to load the subgraph, just connect loop's N v_initial to
    // N v_final and empty Tensor for K scan_outputs.
    llvm::SmallVector<NodeValue, 4> outputs;
    outputs.reserve(numLoopOutputs);
    // Connect N v_initial to N v_final.
    for (int i = 0; i < numN; ++i) {
      NodeValue in;
      ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i + 2)));
      outputs.push_back(in);
    }
    // Connect empty Tensors for K scan_outputs.
    for (int k = 0; k < numK; ++k) {
      // The K scan outputs are the last K body outputs, after the condition
      // and the N carried dependencies.
      int ki = (body.input_size() - 1) + k;
      auto scan_output = body.output(ki);
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(ty, getTensorType(scan_output));
      Tensor T(ty);
      T.zero();
      Constant *c = G_->getParent()->createConstant("empty", std::move(T));
      outputs.push_back(G_->createExpandDims("unsqueeze_K", c, {0}));
    }
    RETURN_IF_ERR(assignNodeOutputs(op, outputs));
    return Error::success();
  }

  // Now, there is at least one iteration.
  llvm::SmallVector<NodeValue, 4> inputs;
  inputs.reserve(numBodyInputs);

  // Collect names of N loop carried dependencies from Loop op. It will be used
  // to connect Loop's inputs with subgraph's inputs for the first iteration.
  llvm::SmallVector<llvm::StringRef, 4> namesOfLoopNs;
  namesOfLoopNs.reserve(numN);
  for (int i = 0; i < numN; ++i) {
    namesOfLoopNs.push_back(op.input(i + 2));
  }

  // Collect output node names for the next iteration. It will be used when
  // connecting subgraph's outputs with subgraph's input between iteration. Add
  // names just for N loop carried dependencies. No need to add names for K
  // scan_outputs because they are not input of subgraph.
  llvm::SmallVector<llvm::StringRef, 4> namesOfBodyOutputNs;
  namesOfBodyOutputNs.reserve(numN);
  for (int i = 0; i < numN; ++i) {
    namesOfBodyOutputNs.push_back(body.output(i + 1).name());
  }

  // Collect names of K scan_outputs.
  llvm::SmallVector<llvm::SmallVector<NodeValue, 2>, 2> scanOutputKs;
  scanOutputKs.reserve(numK);
  for (int k = 0; k < numK; ++k) {
    llvm::SmallVector<NodeValue, 2> scanOutputIter;
    scanOutputIter.reserve(loopUnrollLimit);
    scanOutputKs.push_back(scanOutputIter);
  }

  // Produces a constant 'false' to stand in for the condition when the Loop
  // op has no cond input; only valid when ignoreCond is set.
  auto getDummyCond = [&]() -> Expected<NodeValue> {
    RETURN_ERR_IF_NOT(ignoreCond, "Unexpected empty name in cond in Loop");
    Tensor dummyCondT(ElemKind::BoolTy, {1});
    dummyCondT.zero();
    Constant *dummyCondNode =
        G_->getParent()->createConstant("dumpCond", std::move(dummyCondT));
    return dummyCondNode->getOutput();
  };

  // Materializes the iteration counter as an int64 scalar constant.
  auto getIterationNumConst = [&](int64_t val) -> Constant * {
    Tensor T(ElemKind::Int64ITy, {1});
    T.getHandle<int64_t>() = {val};
    Constant *C = G_->getParent()->createConstant("const", std::move(T));
    return C;
  };

  // Binds the body subgraph's inputs (iteration_num, cond, N carried deps)
  // for the next unrolled iteration, taking the carried values from
  // outputNames (Loop inputs for iteration 0, body outputs afterwards).
  auto prepareNextIteration =
      [&](int64_t iterationNum, NodeValue condNode,
          llvm::ArrayRef<llvm::StringRef> outputNames) -> Error {
    inputs.clear();
    // Set iteration_num.
    inputs.push_back(getIterationNumConst(iterationNum));
    // Set condition.
    inputs.push_back(condNode);
    // Set N loop carried dependencies.
    for (auto oName : outputNames) {
      NodeValue in;
      ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(oName));
      inputs.push_back(in);
    }
    RETURN_IF_ERR(assignGraphInputs(body, inputs));
    return Error::success();
  };

  // Appends this iteration's K scan outputs (unsqueezed along dim 0 so they
  // can later be concatenated across iterations).
  auto accumulateScanOutputs = [&]() -> Error {
    // Accumulate scan-outputs values.
    for (int k = 0; k < numK; ++k) {
      int ki = (body.input_size() - 1) + k;
      auto scan_output = body.output(ki);
      NodeValue outK;
      ASSIGN_VALUE_OR_RETURN_ERR(outK, getNodeValueByName(scan_output.name()));
      Node *unsqueezedOutK = G_->createExpandDims("unsqueezed_K", outK, {0});
      scanOutputKs[k].push_back(unsqueezedOutK);
    }
    return Error::success();
  };

  // Instantiates one copy of the body subgraph into the current function.
  auto loadSubgraph = [&](ONNX_NAMESPACE::GraphProto &graphDef) -> Error {
    RETURN_IF_ERR(loadInitializers(graphDef));
    RETURN_IF_ERR(loadNetwork(graphDef, false));
    return Error::success();
  };

  // Decides whether another iteration should be unrolled: stops at the trip
  // count, requires the body's exit condition to be statically known, and
  // errors out past loopUnrollLimit.
  auto canUnrollNextIter = [&](int64_t iterNum) -> Expected<bool> {
    if (!ignoreMaxTripCount && iterNum >= maxTripCount) {
      return false;
    }
    Constant *condOut = getConstantByNameOrNull(body.output(0).name());
    RETURN_ERR_IF_NOT(condOut,
                      "Loop exit condition is unpredictable to be unrolled.");
    bool cond = (condOut->getPayload().getHandle<bool>()).raw(0) || ignoreCond;
    if (!cond) {
      return false;
    }
    RETURN_ERR_IF_NOT(
        iterNum < loopUnrollLimit,
        strFormat("Exceed the unroll limit (%u) while unrolling Loop operator.",
                  loopUnrollLimit.getValue()));
    return cond;
  };

  // Unroll the first iteration by connecting Loop's inputs to subgraph's
  // inputs.
  int64_t iterNum = 0;
  NodeValue condNode;
  ASSIGN_VALUE_OR_RETURN_ERR(condNode, op.input(1).empty()
                                           ? getDummyCond()
                                           : getNodeValueByName(op.input(1)));
  RETURN_IF_ERR(prepareNextIteration(iterNum, condNode, namesOfLoopNs));
  RETURN_IF_ERR(loadSubgraph(body));
  RETURN_IF_ERR(accumulateScanOutputs());
  ++iterNum;
  auto condCheck = canUnrollNextIter(iterNum);

  RETURN_ERR_IF_NOT(condCheck, ERR_TO_STRING(condCheck.takeError()));

  // Unroll remaining iterations by connecting outputs of previous iteration
  // with inputs of current iteration.
  while (condCheck && condCheck.get()) {
    NodeValue condOutNode;
    ASSIGN_VALUE_OR_RETURN_ERR(condOutNode,
                               getNodeValueByName(body.output(0).name()));
    RETURN_IF_ERR(
        prepareNextIteration(iterNum, condOutNode, namesOfBodyOutputNs));
    RETURN_IF_ERR(loadSubgraph(body));
    RETURN_IF_ERR(accumulateScanOutputs());
    ++iterNum;
    condCheck = canUnrollNextIter(iterNum);
    RETURN_ERR_IF_NOT(condCheck, ERR_TO_STRING(condCheck.takeError()));
  }

  // Hook final subgraph outputs to loop outputs.
  llvm::SmallVector<NodeValue, 4> outputs;
  outputs.reserve(numLoopOutputs);
  // Set outputs for N loop carried dependency values.
  for (int i = 0; i < numN; ++i) {
    NodeValue bodyout;
    ASSIGN_VALUE_OR_RETURN_ERR(bodyout,
                               getNodeValueByName(body.output(i + 1).name()));
    outputs.push_back(bodyout);
  }
  // Set outputs for K scan_outputs.
  for (int k = 0; k < numK; ++k) {
    outputs.push_back(G_->createConcat("concat", scanOutputKs[k], 0));
  }
  RETURN_IF_ERR(assignNodeOutputs(op, outputs));
  return Error::success();
}
5553
/// Reconstructs the uniqued Type of result \p resNo of a Glow-custom op from
/// the serialized attributes in \p dict (element kind, shape, optional
/// strides, and scale/offset for quantized kinds).
Expected<TypeRef>
ONNXModelLoader::loadTypeFromAttributes(unsigned resNo,
                                        ArgumentDictionaryTy &dict) {
  // Load ElemKind.
  std::string elemKindStr;
  ASSIGN_VALUE_OR_RETURN_ERR(
      elemKindStr, loadStr(dict[getTypeAttrID(resNo, elemKindSignifier)]));
  const ElemKind k = Type::getElementKindFromName(elemKindStr);

  // Load Shape. Note that we allow for empty shapes here because 0 dimensional
  // shapes are allowed (representing scalars).
  std::vector<dim_t> shape;
  ASSIGN_VALUE_OR_RETURN_ERR(
      shape, getShape<dim_t>(dict[getTypeAttrID(resNo, shapeSignifier)],
                             /* allowEmptyShape */ true));

  // Load strides. Note that we allow for empty strides here because 0
  // dimensional shapes are allowed (representing scalars).
  // The strides attribute is optional; an empty vector means default strides.
  std::vector<dim_t> strides;
  auto stridesKey = getTypeAttrID(resNo, stridesSignifier);
  if (dict.count(stridesKey)) {
    ASSIGN_VALUE_OR_RETURN_ERR(strides,
                               getShape<dim_t>(dict[stridesKey],
                                               /* allowEmptyShape */ true));
  }

  // Create and return uniqued non-quantized Type.
  if (!isQuantizedElemKind(k)) {
    return getTypeWithCustomStrides(mod_, mod_.uniqueType(k, shape), strides);
  }

  // Must be quantized kind, so get scale/offset and create and return uniqued
  // quantized Type.
  float scale;
  ASSIGN_VALUE_OR_RETURN_ERR(
      scale, loadFloat(dict[getTypeAttrID(resNo, qScaleSignifier)]));
  int32_t offset;
  ASSIGN_VALUE_OR_RETURN_ERR(
      offset, loadInt(dict[getTypeAttrID(resNo, qOffsetSignifier)]));

  // If we have a scale of dummyScale, then this must be a dummy pair of
  // scale/offset. Look up the actual scale/offset to use as previously loaded,
  // using the offset as the key to updatedTQPs_. Skip fused kinds because
  // scales are already dummies.
  if (replaceDummyTQPs_ && scale == dummyScale &&
      !isFusedQuantizedElemKind(k)) {
    TensorQuantizationParams TQP;
    ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(offset));
    scale = TQP.scale;
    offset = TQP.offset;
  }

  return getTypeWithCustomStrides(
      mod_, mod_.uniqueType(k, shape, scale, offset), strides);
}
5609
/// Attempts to load \p op (whose type name is \p typeName) through the
/// auto-generated Glow-custom-op import cases. Returns the created node, or
/// nullptr when no generated case matched so the caller can fall back to the
/// standard ONNX loading paths.
Expected<Node *>
ONNXModelLoader::tryLoadGlowCustomOp(llvm::StringRef typeName,
                                     const ONNX_NAMESPACE::NodeProto &op,
                                     ArgumentDictionaryTy &dict) {
  // opName, typeName, op, and dict are all referenced by the generated code
  // included below.
  const std::string &opName = loadOperatorName(op);

// Try all automatically generated import cases.
#include "glow/AutoGenNodesImport.h"

  // If we get here then no case handled the op, so return nullptr.
  return nullptr;
}
5622
5623/// Load Node options for \p loadedNode from \p dict and set in \p nodeInfo.
5624/// These are specified in the format "NodeOpt_BACKENDNAME_OPTIONNAME".
5625static Error loadPerNodeOptions(const Node *loadedNode,
5626 BackendSpecificNodeInfo &nodeInfo,
5627 ArgumentDictionaryTy &dict) {
5628 // Look through all attributes in the dict for ones that have NodeOpt_ prefix.
5629 for (const auto &attrPair : dict) {
5630 // Split across the first '_' and check if it has the "NodeOpt" prefix.
5631 auto splitPair = llvm::StringRef(attrPair.first).split('_');
5632 if (splitPair.first == attrPair.first && splitPair.first == "") {
5633 // No '_' found, so continue.
5634 continue;
5635 }
5636 if (splitPair.first != nodeOptSignifier) {
5637 // Prefix is not "NodeOpt_", so continue.
5638 continue;
5639 }
5640
5641 // Must have a NodeOpt, so check it has strings and load them into nodeInfo.
5642 const ONNX_NAMESPACE::AttributeProto *attr = attrPair.second;
5643 RETURN_ERR_IF_NOT(attr->strings_size() > 0,
5644 strFormat("%s in %s has no strings",
5645 attrPair.first.c_str(),
5646 loadedNode->getName().data()));
5647 std::vector<std::string> &attrVals =
5648 nodeInfo[loadedNode->getParent()][loadedNode][splitPair.second];
5649 for (const std::string &s : attr->strings()) {
5650 attrVals.push_back(s);
5651 }
5652 }
5653 return Error::success();
5654}
5655
5656Error ONNXModelLoader::loadIf(const ONNX_NAMESPACE::NodeProto &op,
5657 const ArgumentDictionaryTy &dict) {
5658 Constant *condition = getConstantByNameOrNull(op.input(0));
5659 RETURN_ERR_IF_NOT(condition, "Only constant condition is supported!");
5660 RETURN_ERR_IF_NOT(condition->getElementType() == ElemKind::BoolTy,
5661 "Condition must be boolean!");
5662
5663 RETURN_ERR_IF_NOT(dict.count("then_branch"), "Then branch not found!");
5664 RETURN_ERR_IF_NOT(dict.count("else_branch"), "Else branch not found!");
5665 RETURN_ERR_IF_NOT(dict.at("then_branch")->type() ==
5666 ONNX_NAMESPACE::AttributeProto::GRAPH,
5667 "Only Subgraph branches are supported.");
5668 RETURN_ERR_IF_NOT(dict.at("else_branch")->type() ==
5669 ONNX_NAMESPACE::AttributeProto::GRAPH,
5670 "Only Subgraph branches are supported.");
5671
5672 auto loadSubgraph = [&](ONNX_NAMESPACE::GraphProto &graphDef) -> Error {
5673 RETURN_IF_ERR(loadInitializers(graphDef));
5674 RETURN_IF_ERR(loadNetwork(graphDef, false));
5675 return Error::success();
5676 };
5677
5678 if (condition->getPayload().getHandle<bool>().isZero()) {
5679 auto ifFalse = dict.at("else_branch")->g();
5680 RETURN_ERR_IF_NOT(ifFalse.output_size() == 1,
5681 "Only single output 'else' subgraph is supported.");
5682 RETURN_IF_ERR(loadSubgraph(ifFalse));
5683 NodeValue ifFalseVal;
5684 ASSIGN_VALUE_OR_RETURN_ERR(ifFalseVal,
5685 getNodeValueByName(ifFalse.output(0).name()));
5686 RETURN_IF_ERR(addNodeAsOutput(op, ifFalseVal));
5687 } else {
5688 auto ifTrue = dict.at("then_branch")->g();
5689 RETURN_ERR_IF_NOT(ifTrue.output_size() == 1,
5690 "Only single output 'then' subgraph is supported.");
5691 RETURN_IF_ERR(loadSubgraph(ifTrue));
5692 NodeValue ifTrueVal;
5693 ASSIGN_VALUE_OR_RETURN_ERR(ifTrueVal,
5694 getNodeValueByName(ifTrue.output(0).name()));
5695 RETURN_IF_ERR(addNodeAsOutput(op, ifTrueVal));
5696 }
5697 return Error::success();
5698}
5699
/// Dispatches loading of a single ONNX node \p op by its op_type. Order
/// matters: Glow-custom ops are tried first (when enabled), then the shared
/// CommonOperatorLoader, then the ONNX-specific loaders below.
Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
  ArgumentDictionaryTy dict = loadArgumentMap(op);
  const std::string &typeName = op.op_type();
  mod_.registerOriginalName(op.name());

  if (useGlowCustomOps_) {
    Node *loadedNode;
    ASSIGN_VALUE_OR_RETURN_ERR(loadedNode,
                               tryLoadGlowCustomOp(typeName, op, dict));
    if (loadedNode) {
      // Custom op loaded; optionally attach backend-specific per-node options.
      if (!perNodeOpts_) {
        return Error::success();
      }
      return loadPerNodeOptions(loadedNode, *perNodeOpts_, dict);
    }

    // These are handled earlier when loading initializers and inputs and so can
    // be safely ignored here.
    if (typeName == constFoldSubgraphNodeName ||
        typeName == staticPHDummyNodeName) {
      return Error::success();
    }

    // Identity is the only official ONNX op used with useGlowCustomOps. Let it
    // fall through to logic to handle below, otherwise return error.
    if (typeName != "Identity") {
      return MAKE_ERR("Failed to load operator " + typeName + " .",
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
    }
  }

  // Check if operator is supported in parent class, CommonOperatorLoader.
  bool tryLoadCommonOperatorResult;
  ASSIGN_VALUE_OR_RETURN_ERR(tryLoadCommonOperatorResult,
                             tryLoadCommonOperator(typeName, op, dict));
  if (tryLoadCommonOperatorResult) {
    return Error::success();
  }
  // ONNX-specific loaders, dispatched by op_type string.
  if (typeName == "Loop") {
    return loadLoop(op, dict);
  }

  if (typeName == "Constant") {
    return loadConstant(op, dict);
  }
  if (typeName == "Range") {
    return loadRange(op, dict);
  }
  if (typeName == "PRelu") {
    return loadPRelu(op, dict);
  }
  if (typeName == "Slice") {
    return loadSlice(op, dict);
  }
  if (typeName == "Sin" || typeName == "Cos") {
    return loadTrigonometricOps(typeName, op, dict);
  }
  if (typeName == "Erf") {
    return loadErf(op, dict);
  }
  if (typeName == "Conv") {
    // If the Conv operator has quantized inputs and
    // dict contains the scale and offset params, use
    // loadTensorwiseQuantizedConvolution.
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    return in.getType()->isQuantizedType() && dict.count("out_scale") &&
                   dict.count("out_offset")
               ? loadTensorwiseQuantizedConvolution(op, dict)
               : loadConv(op, dict);
  }
  if (typeName == "ChannelwiseQuantizedConvolution") {
    return loadChannelwiseQuantizedConvolution(op, dict);
  }
  if (typeName == "MaxPool" || typeName == "AveragePool") {
    // If the pool operator has quantized inputs, use
    // loadTensorwiseQuantizedPool.
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    return in.getType()->isQuantizedType() && dict.count("out_scale") &&
                   dict.count("out_offset")
               ? loadTensorwiseQuantizedPool(op, dict, typeName)
               : loadPool(op, dict, typeName);
  }
  if (typeName == "GlobalAveragePool") {
    return loadGlobalAveragePool(op, dict);
  }
  if (typeName == "Squeeze") {
    return loadSqueeze(op, dict);
  }
  if (typeName == "Unsqueeze") {
    return loadUnsqueeze(op, dict);
  }
  if (typeName == "BatchNormalization") {
    return loadBatchNormalization(op, dict);
  }
  if (typeName == "InstanceNormalization") {
    return loadInstanceNormalization(op, dict);
  }
  if (typeName == "Concat") {
    return loadConcat(op, dict);
  }
  if (typeName == "FCTransposed") {
    return loadFCTransposed(op, dict);
  }
  if (typeName == "Gemm") {
    return loadGemm(op, dict);
  }
  if (typeName == "Transpose") {
    return loadTranspose(op, dict, "perm");
  }
  if (typeName == "ReduceSumSquare") {
    return loadReduceOp(typeName, op, dict);
  }
  if (typeName == "MatMul") {
    return loadMatMul(op, dict);
  }
  if (typeName == "Pad") {
    return loadPad(op, dict);
  }
  if (typeName == "Cast") {
    return loadCast(op, dict);
  }
  if (typeName == "HardSigmoid") {
    return loadHardSigmoid(op, dict);
  }
  if (typeName == "LeakyRelu") {
    return loadLeakyRelu(op, dict);
  }
  if (typeName == "SpaceToDepth") {
    return loadSpaceToDepth(op, dict);
  }
  if (typeName == "DepthToSpace") {
    return loadDepthToSpace(op, dict);
  }
  if (typeName == "ReduceL2") {
    return loadReduceL2(op, dict);
  }
  if (typeName == "ConstantOfShape") {
    return loadConstantOfShape(op, dict, false /* isSplat */);
  }
  if (typeName == "Tile") {
    return loadTile(op, dict);
  }
  if (typeName == "Expand") {
    return loadExpand(op, dict);
  }
  if (typeName == "Where") {
    return loadWhere(op, dict);
  }
  if (typeName == "RNN") {
    return loadRNN(op, dict);
  }
  if (typeName == "GRU") {
    return loadGRU(op, dict);
  }
  if (typeName == "LSTM") {
    return loadLSTM(op, dict);
  }
  if (typeName == "Clip") {
    return loadClip(op, dict);
  }
  if (typeName == "Equal") {
    return loadCmpEQ(op, dict);
  }
  if (typeName == "CmpLTE") {
    return loadCmpLTE(op, dict);
  }
  if (typeName == "Mean") {
    return loadMean(op, dict);
  }
  if (typeName == "Select") {
    return loadSelect(op, dict);
  }
  if (typeName == "Quantize") {
    return loadQuantize(op, dict);
  }
  if (typeName == "QuantizeLinear") {
    return loadQuantizeLinear(op, dict);
  }
  if (typeName == "ConvertTo") {
    return loadConvertTo(op, dict);
  }
  if ((typeName == "Dequantize") || (typeName == "DequantizeLinear")) {
    return loadDequantize(op, dict);
  }
  if (typeName == "Regression") {
    return loadRegression(op, dict);
  }
  if (typeName == "BatchedAdd") {
    return loadBatchedAdd(op, dict);
  }
  if (typeName == "CumSum") {
    return loadCumSum(op, dict);
  }
  if ((typeName == "ScatterAssign") || (typeName == "ScatterND")) {
    return loadScatterAssign(op, dict);
  }
  if (typeName == "IntLookupTable") {
    return loadIntLookupTable(op, dict);
  }
  if (typeName == "LengthsRangeFill") {
    return loadLengthsRangeFill(op, dict);
  }
  if (typeName == "RescaleQuantized") {
    return loadRescaleQuantized(op, dict);
  }
  if (typeName == "RowwiseQuantizedSparseLengthsWeightedSum") {
    return loadRowwiseQuantizedSparseLengthsWeightedSum(op, dict);
  }
  if (typeName == "FusedRowwiseQuantizedSparseLengthsWeightedSum") {
    return loadFusedRowwiseQuantizedSparseLengthsWeightedSum(op, dict);
  }
  if (typeName == "FusedRowwiseQuantizedSparseLengthsSum") {
    return loadFusedRowwiseQuantizedSparseLengthsSum(op, dict);
  }
  if (typeName == "FullyConnected") {
    return loadFullyConnected(op, dict);
  }
  if (typeName == "RowwiseQuantizedFullyConnected") {
    return loadRowwiseQuantizedFullyConnected(op, dict);
  }
  if (typeName == "Splat") {
    return loadSplat(op, dict);
  }
  if (typeName == "InsertTensor") {
    return loadInsertTensor(op, dict);
  }
  if (typeName == "ArgMin") {
    return loadArgMinMax(op, dict, true);
  }
  if (typeName == "ArgMax") {
    return loadArgMinMax(op, dict, false);
  }
  if (typeName == "NonMaxSuppressionV4") {
    return loadNonMaxSuppression(op, dict, true);
  }
  if (typeName == "NonMaxSuppression") {
    return loadNonMaxSuppression(op, dict, false);
  }
  if (typeName == "ConvTranspose") {
    return loadConvTranspose(op, dict);
  }
  if (typeName == "If") {
    return loadIf(op, dict);
  }
  if (typeName == "AdaptiveAvgPool") {
    return loadAdaptiveAvgPool(op, dict);
  }
  if (typeName == "Flip") {
    return loadFlip(op, dict);
  }
  if (typeName == "AudioSpectrogram") {
    return loadAudioSpectrogram(op, dict);
  }
  if (typeName == "RoiAlign") {
    return loadROIAlign(op, dict);
  }
  if (typeName == "MFCC") {
    return loadMFCC(op, dict);
  }
  if (typeName == "Identity") {
    return loadIdentity(op, dict);
  }
  if (typeName == "Upsample") {
    return loadUpsample(op, dict);
  }
  if (typeName == "Resize") {
    return loadResize(op, dict);
  }
  if (typeName == "NonZero") {
    return loadNonZero(op, dict);
  }
  if (typeName == "Acos") {
    return loadAcos(op, dict);
  }
  if (typeName == "Asin") {
    return loadAsin(op, dict);
  }
  if (typeName == "Atan") {
    return loadAtan(op, dict);
  }
  if (typeName == "Sign") {
    return loadSign(op, dict);
  }
  if (typeName == "Softmax") {
    return loadSoftmax(op, dict);
  }
  if (typeName == "LogSoftmax") {
    return loadLogSoftmax(op, dict);
  }
  if (typeName == "ScatterData") {
    return loadScatterData(op, dict);
  }
  if (typeName == "TopK") {
    return loadTopK(op, dict);
  }

  // No loader matched: report the operator as unsupported.
  return MAKE_ERR("Failed to load operator " + typeName + " .",
                  ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
}
6001
6002void ONNXModelLoader::deleteConstFoldFunctions() {
6003 for (Function *constFoldF : constFoldFuns_) {
6004 mod_.eraseFunction(constFoldF);
6005 }
6006}
6007
/// Runs constant folding on the node producing \p outputName (the output of a
/// deserialized const-fold subgraph) and registers the folded Constant under
/// \p initializerName. Returns the folded Constant.
Expected<Constant *>
ONNXModelLoader::runDeserializedConstFold(llvm::StringRef initializerName,
                                          llvm::StringRef outputName) {
  NodeValue NV;
  ASSIGN_VALUE_OR_RETURN_ERR(NV, getNodeValueByName(outputName));

  // Force folding single splats, because we're folding a constant folding
  // subgraph, and so we know the exported model already decided to fold it
  // (normally the backend decides whether to fold it or not).
  std::vector<Constant *> constResults =
      constantFold(NV.getNode(), /* foldSingleSplats */ true);
  RETURN_ERR_IF_NOT(constResults.size() > 0,
                    strFormat("Constant folding did not occur for %s",
                              NV.getNode()->getName().data()));
  // The folded node may have multiple results; pick the one this NodeValue
  // referred to.
  RETURN_ERR_IF_NOT(NV.getResNo() < constResults.size(),
                    strFormat("Needed result %u from const folding results, "
                              "but only got %lu results",
                              NV.getResNo(), constResults.size()));
  Constant *foldedC = constResults[NV.getResNo()];

  // Now we have the final Constant we want and it exists in the module. Set its
  // name to the actual initializer it came with if not already named that.
  if (foldedC->getName() != initializerName) {
    RETURN_ERR_IF_NOT(
        mod_.getConstantByName(initializerName) == nullptr,
        strFormat("Already had a Constant by name %s", initializerName.data()));
    foldedC->setName(initializerName);
  }
  // Register the Constant under the initializer's name so later lookups by
  // name resolve to the folded value.
  RETURN_ERR_IF_NOT(
      nodeValueByName_.count(initializerName) == 0,
      strFormat("Should not have been a Constant by name %s registered yet",
                initializerName.data()));
  nodeValueByName_[initializerName] = foldedC->getOutput();

  return foldedC;
}
6044
/// Checks whether initializer \p in carries metadata pointing at a serialized
/// constant-folding subgraph inside \p net; if so, deserializes (once) and
/// runs that subgraph, and \returns the resulting Constant. \returns nullptr
/// when \p in has no const-fold metadata, or an Error on failure.
Expected<Constant *> ONNXModelLoader::replaySerializedConstFold(
    const ONNX_NAMESPACE::TensorProto &in, ONNX_NAMESPACE::GraphProto &net) {
  // Check if ins has a constant folding node associated with it.
  const char *constFoldNodeName = nullptr;
  int resNo = -1;
  for (const auto &keyVal : in.external_data()) {
    if (keyVal.key() == "ConstFoldNodeName") {
      constFoldNodeName = keyVal.value().data();
      continue;
    }
    if (keyVal.key() == "ConstFoldResNo") {
      ASSIGN_VALUE_OR_RETURN_ERR(resNo, getIntFromStr(keyVal.value()));
      continue;
    }
  }
  // No metadata: this initializer was not produced by constant folding.
  if (!constFoldNodeName) {
    return nullptr;
  }
  RETURN_ERR_IF_NOT(resNo >= 0,
                    "Require ConstFoldResNo for Glow__ConstFoldSubgraph");

  // Look through the ops in the graph to find the Node we need by name.
  ONNX_NAMESPACE::NodeProto *op = nullptr;
  for (int i = 0; i < net.node_size(); i++) {
    auto *curOp = net.mutable_node(i);
    if (loadOperatorName(*curOp) == constFoldNodeName) {
      op = curOp;
      break;
    }
  }
  RETURN_ERR_IF_NOT(
      op, strFormat("Did not find Node by name %s", constFoldNodeName));
  RETURN_ERR_IF_NOT(
      op->op_type() == constFoldSubgraphNodeName,
      strFormat("Node %s has type %s but expected Glow__ConstFoldSubgraph",
                constFoldNodeName, op->op_type().data()));

  // Now look through the Node's attributes to find the subgraph.
  ONNX_NAMESPACE::GraphProto *subgraph = nullptr;
  for (auto &arg : *op->mutable_attribute()) {
    if (arg.name() == "ConstFoldSubgraph") {
      subgraph = arg.mutable_g();
      break;
    }
  }

  RETURN_ERR_IF_NOT(subgraph, strFormat("Expected associated subgraph for %s",
                                        constFoldNodeName));

  // We have the constant folding subgraph proto and need to load it to run it.
  // Reuse the Function if this subgraph was deserialized by a previous call
  // (several initializers may share one const-fold node).
  const bool functionAlreadyLoaded = mod_.hasFunction(constFoldNodeName);
  Function *constFoldF = functionAlreadyLoaded
                             ? mod_.getFunction(constFoldNodeName)
                             : mod_.createFunction(constFoldNodeName);
  const auto insert = constFoldFuns_.insert(constFoldF);
  RETURN_ERR_IF_NOT(!(functionAlreadyLoaded && insert.second),
                    strFormat("Function %s should only be processed once",
                              constFoldNodeName));

  // Temporarily swap in state for the constant folding Function.
  Function *origF = G_;
  G_ = constFoldF;
  llvm::StringMap<Function *> partNameToFunBackup;
  std::swap(partNameToFun_, partNameToFunBackup);
  // Make sure to restore original state of the loader when exiting this scope.
  ScopeGuard restoreOrigStateGuard([&]() {
    G_ = origF;
    std::swap(partNameToFun_, partNameToFunBackup);
  });

  // Deserialize the Function if not already done.
  if (!functionAlreadyLoaded) {
    RETURN_IF_ERR(loadNetwork(*subgraph, /* loadingConstFoldSubgraph */ true));
  }

  // Now that we have the Function deserialized, actually run and return the
  // resulting Constant that is folded.
  RETURN_ERR_IF_NOT(subgraph->output_size() > resNo,
                    strFormat("ConstFoldResNo %d invalid output idx.", resNo));
  return runDeserializedConstFold(in.name(), subgraph->output(resNo).name());
}
6126
/// Loads all initializers (weights) from \p net: replays serialized constant
/// folding when metadata is present, reuses pre-existing Constants when
/// loading into an existing Module, and otherwise creates fresh Constants
/// from the tensor payloads.
Error ONNXModelLoader::loadInitializers(ONNX_NAMESPACE::GraphProto &net) {
  // Load the network initializers:
  for (const auto &in : net.initializer()) {
    // Replay any constant folding that occurred from previous optimization if
    // necessary. foldedC will be left as nullptr if no constant folding occurs.
    Constant *foldedC;
    ASSIGN_VALUE_OR_RETURN_ERR(foldedC, replaySerializedConstFold(in, net));

    std::string layout = ANY_LAYOUT;
    if (useGlowCustomOps_) {
      // Glow-exported protos carry the tensor layout in the doc_string.
      ASSIGN_VALUE_OR_RETURN_ERR(
          layout, getAttrFromDocString(layoutSignifier, in.doc_string()));
    }

    // If we already have an existing Module (or just folded a Constant), then
    // expect to find a matching Constant for this initializer.
    if (foldedC || loadIntoExistingModule_) {
      Constant *C = foldedC ? foldedC : mod_.getConstantByName(in.name());
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(ty, getTensorType(in));

      // If the expected type is fused, and we are processing an initializer
      // with payload that already exists in the Module, then set the type to
      // fused here. This is because Caffe2 and ONNX (non-Glow-custom) protos do
      // not support fused ElemKinds, so we should explicitly set them as we do
      // during Caffe2ModelLoading.
      if (!foldedC && loadIntoExistingModule_ && ty.isFusedQuantizedType()) {
        RETURN_IF_ERR(setFusedTy(C, mod_.uniqueType(ty)));
      }

      RETURN_IF_ERR(verifyPreexistingStorage(C, in.name(), ty, layout));
      nodeValueByName_[in.name()] = C->getOutput();
      continue;
    }

    // Otherwise this is a brand-new Constant: load the tensor payload from
    // the proto and register it under the initializer's name.
    Tensor T;
    RETURN_IF_ERR(loadTensor(in, &T, useGlowCustomOps_));
    RETURN_IF_ERR(createAndRegisterConstant(in.name(), std::move(T), layout));
  }

  return Error::success();
}
6171
/// Creates a Save node and output Placeholder for every graph output in
/// \p net, reusing a pre-existing Placeholder when loading into an existing
/// Module, and registers each output in outputVarsByName_.
Error ONNXModelLoader::setOutputNodes(ONNX_NAMESPACE::GraphProto &net) {
  if (net.output_size() == 0) {
    return MAKE_ERR("Net output size must be greater than 0");
  }

  for (int i = 0; i < net.output_size(); i++) {
    const auto &outputName = net.output(i).name();
    NodeValue r;
    ASSIGN_VALUE_OR_RETURN_ERR(r, getNodeValueByName(outputName));

    const std::string &docString = net.output(i).doc_string();

    // A Glow-exported model may embed the original Save node's name in the
    // doc_string; fall back to the output name when it is absent.
    Expected<std::string> saveName =
        getAttrFromDocString(saveNameSignifier, docString);

    const bool hasSpecifiedSaveName =
        !ERR_TO_BOOL(saveName.takeError(), /* log */ false);
    const std::string &saveNodeName =
        hasSpecifiedSaveName ? saveName.get() : outputName;

    // Trainability and layout for the output Placeholder are also carried in
    // the doc_string.
    std::pair<bool, std::string> trainableLayoutPair;
    ASSIGN_VALUE_OR_RETURN_ERR(
        trainableLayoutPair,
        getTrainableLayoutPairFromDocString(docString, useGlowCustomOps_));

    // If loadIntoExistingModule_ then it's reasonable for there to be a savePH
    // already. If not then there shouldn't be one.
    Placeholder *savePH = mod_.getPlaceholderByNameSlow(outputName);
    if (!savePH) {
      savePH = mod_.createPlaceholder(r.getType(), outputName,
                                      trainableLayoutPair.first,
                                      trainableLayoutPair.second);
    } else {
      RETURN_ERR_IF_NOT(loadIntoExistingModule_,
                        "Found pre-existing PH by name " + outputName);
      RETURN_IF_ERR(verifyPreexistingStorage(savePH, outputName, *r.getType(),
                                             trainableLayoutPair.second,
                                             trainableLayoutPair.first));
    }
    G_->createSave(saveNodeName, r, savePH, hasSpecifiedSaveName);

    // Register the Placeholder under its loader name (when specified in the
    // doc_string) or the plain output name; duplicates are an error.
    auto loaderNameOrErr = getAttrFromDocString(loaderNameSignifier, docString);
    const std::string &loaderName =
        !ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)
            ? loaderNameOrErr.get()
            : outputName;
    RETURN_ERR_IF_NOT(outputVarsByName_.try_emplace(loaderName, savePH).second,
                      "Already had output placeholder by name " + loaderName);
  }

  return Error::success();
}
6224
/// Loads every operator of \p net into the loader's Function(s). Dummy
/// static-PH meta nodes are skipped; for pre-partitioned models each op is
/// routed to the Function named by its "partitionName" attribute (routing is
/// bypassed when \p loadingConstFoldSubgraph). Ops may be eagerly
/// constant-folded when constFoldInLoader_ is enabled.
Error ONNXModelLoader::loadNetwork(ONNX_NAMESPACE::GraphProto &net,
                                   bool loadingConstFoldSubgraph) {
  // Load the network operators:
  for (int i = 0; i < net.node_size(); i++) {
    auto &op = net.node(i);

    // Always ignore these since they're dummy nodes used to just carry meta
    // info that is processed via setupOrigStaticTypeMap().
    if (op.op_type() == staticPHDummyNodeName) {
      continue;
    }

    // Set up current partition to load into if relevant.
    if (partNameToFun_.size() && !loadingConstFoldSubgraph &&
        op.op_type() != constFoldSubgraphNodeName) {
      const ONNX_NAMESPACE::AttributeProto *pNameAttr = nullptr;
      for (auto &arg : op.attribute()) {
        if (arg.name() == "partitionName") {
          pNameAttr = &arg;
          break;
        }
      }
      RETURN_ERR_IF_NOT(pNameAttr, "partitionName not found for " + op.name());
      std::string pName;
      ASSIGN_VALUE_OR_RETURN_ERR(pName, loadStr(pNameAttr));
      auto it = partNameToFun_.find(pName);
      RETURN_ERR_IF_NOT(it != partNameToFun_.end(),
                        "Did not find partition with name " + pName);
      G_ = it->second;
    }
    RETURN_ERR_IF_NOT(G_, "Internal Glow error; Graph was not valid.");

    // Attempt eager constant folding of the op; a folding failure is only
    // logged, and the op is then loaded normally.
    if (constFoldInLoader_) {
      auto tryFold = foldOperator(op);
      if (!tryFold) {
        // Error during constant folding; load the op normally below.
        const std::string errStr =
            ERR_TO_STRING(tryFold.takeError(), /* warning */ true);
        VLOG(1) << "Issue while trying to ConstantFold " << loadOperatorName(op)
                << ": " << errStr;
      } else if (tryFold.get()) {
        // Folded successfully, so skip loading the op below.
        continue;
      }
    }
    RETURN_IF_ERR(loadOperator(op));
  }

  return Error::success();
}
6275
/// Constructs a loader over Function \p F without reading any model; any
/// construction error is reported through \p errPtr.
ONNXModelLoader::ONNXModelLoader(Function &F, Error *errPtr)
    : CommonOperatorLoader({}, {}, &F, errPtr) {
  // Remove Constants in the Module that have no users.
  deleteUnusedConstants();
}
6280
6281static Error checkStaticPH(const ONNX_NAMESPACE::ValueInfoProto &valueInfo,
6282 std::unordered_set<std::string> &staticInputs,
6283 bool useGlowCustomOps) {
6284 const std::string &inputName = valueInfo.name();
6285 if (useGlowCustomOps) {
6286 std::string isStatic;
6287 ASSIGN_VALUE_OR_RETURN_ERR(
6288 isStatic,
6289 getAttrFromDocString(staticSignifier, valueInfo.doc_string()));
6290 if (isStatic == "1") {
6291 staticInputs.emplace(inputName);
6292 }
6293 } else if (valueInfo.has_doc_string() &&
6294 valueInfo.doc_string() == staticSignifier) {
6295 staticInputs.emplace(inputName);
6296 }
6297 return Error::success();
6298}
6299
6300Error ONNXModelLoader::collectStaticInputs(ONNX_NAMESPACE::GraphProto &net) {
6301 for (int i = 0; i < net.input_size(); i++) {
6302 RETURN_IF_ERR(
6303 checkStaticPH(net.input(i), staticInputs_, useGlowCustomOps_));
6304 }
6305 return Error::success();
6306}
6307
6308Error ONNXModelLoader::checkInputs(ONNX_NAMESPACE::GraphProto &net,
6309 llvm::ArrayRef<const char *> tensorNames,
6310 llvm::ArrayRef<TypeRef> types) {
6311 for (size_t i = 0; i < tensorNames.size(); i++) {
6312 // Look if a corresponding input exists.
6313 for (int j = 0; j < net.input_size(); j++) {
6314 const ONNX_NAMESPACE::ValueInfoProto &valueInfo = net.input(j);
6315 const std::string &inputName = valueInfo.name();
6316
6317 if (inputName != tensorNames[i]) {
6318 continue;
6319 }
6320
6321 // Get tensor shape.
6322 llvm::ArrayRef<dim_t> dims = types[i]->dims();
6323
6324 // Get proto shape.
6325 std::vector<dim_t> dimsProto;
6326 ASSIGN_VALUE_OR_RETURN_ERR(
6327 dimsProto, getProtoShape(valueInfo.type().tensor_type().shape()));
6328
6329 // Check if the number of dimensions is consistent.
6330 RETURN_ERR_IF_NOT(dims.size() == dimsProto.size(),
6331 "Mismatch between input image and ONNX input shape");
6332
6333 // Allow batch dimensions to be different.
6334 for (size_t k = 1; k < dims.size(); k++) {
6335 RETURN_ERR_IF_NOT(dims[k] == dimsProto[k],
6336 "Mismatch between input image and ONNX input shape");
6337 }
6338
6339 RETURN_IF_ERR(checkStaticPH(valueInfo, staticInputs_, useGlowCustomOps_));
6340 }
6341 }
6342 return Error::success();
6343}
6344
6345Error ONNXModelLoader::setupOrigStaticTypeMap(ONNX_NAMESPACE::GraphProto &net) {
6346 if (!staticPlaceholderTypes_) {
6347 return Error::success();
6348 }
6349
6350 for (int i = 0; i < net.node_size(); i++) {
6351 auto &op = net.node(i);
6352 ArgumentDictionaryTy dict = loadArgumentMap(op);
6353 if (op.op_type() != staticPHDummyNodeName) {
6354 continue;
6355 }
6356 RETURN_ERR_IF_NOT(staticInputs_.count(op.name()),
6357 "Expected static input for " + op.name());
6358 TypeRef OT;
6359 ASSIGN_VALUE_OR_RETURN_ERR(
6360 OT, loadTypeFromAttributes(Storage::OutputIdx, dict));
6361 staticPlaceholderTypes_->emplace(op.name(), *OT);
6362 }
6363 RETURN_ERR_IF_NOT(
6364 staticPlaceholderTypes_->size() == staticInputs_.size(),
6365 strFormat(
6366 "Expected to find types for all static Placeholders. %lu vs. %lu",
6367 staticPlaceholderTypes_->size(), staticInputs_.size()));
6368 return Error::success();
6369}
6370
6371Error ONNXModelLoader::assignGraphInputs(const ONNX_NAMESPACE::GraphProto &net,
6372 llvm::ArrayRef<NodeValue> NVs) {
6373 RETURN_ERR_IF_NOT((dim_t)NVs.size() == (dim_t)net.input_size(),
6374 "Input size mismatch.");
6375 for (size_t i = 0; i < NVs.size(); i++) {
6376 nodeValueByName_[net.input(i).name()] = NVs[i];
6377 }
6378 return Error::success();
6379}
6380
/// Drives the full load of \p modelDef: validates the caller-provided inputs,
/// collects static inputs and their types, loads initializers, operators and
/// outputs, verifies the resulting Function(s) against \p B (if given), and
/// cleans up temporary loading artifacts.
Error ONNXModelLoader::loadModel(ONNX_NAMESPACE::ModelProto &modelDef,
                                 llvm::ArrayRef<const char *> tensorNames,
                                 llvm::ArrayRef<TypeRef> types,
                                 const Backend *B,
                                 bool loadInputsAsPlaceholdersForOnnx) {
  // Glow-exported models are identified by the producer name; this enables
  // Glow-custom-op handling throughout loading.
  useGlowCustomOps_ = modelDef.producer_name() == "GlowONNXModelWriter";

  RETURN_IF_ERR(setVersion(modelDef));

  // Work on a local copy of the graph; const-fold replay mutates node protos
  // during loading, leaving modelDef itself untouched.
  ONNX_NAMESPACE::GraphProto graphDef = modelDef.graph();
  RETURN_IF_ERR(checkInputs(graphDef, tensorNames, types));
  RETURN_IF_ERR(collectStaticInputs(graphDef));
  RETURN_IF_ERR(setupOrigStaticTypeMap(graphDef));

  RETURN_IF_ERR(loadInitializers(graphDef));

  if (tensorNames.empty() && types.empty()) {
    // Detect inputs without initializers and create placeholders.
    RETURN_IF_ERR(loadInputs(graphDef, loadInputsAsPlaceholdersForOnnx));
  }

  RETURN_IF_ERR(loadNetwork(graphDef, /* loadingConstFoldSubgraph */ false));

  RETURN_IF_ERR(setOutputNodes(graphDef));

  RETURN_ERR_IF_NOT(G_->verify(B), "Function verification failed.");

  // Remove temporary artifacts of loading: unused Constants and the
  // Functions created to replay serialized constant folding.
  deleteUnusedConstants();

  deleteConstFoldFunctions();

  RETURN_IF_ERR(verifyDummyQParams());

  return Error::success();
}
6416
/// Constructs a loader that reads an ONNX model from \p modelDescFilename
/// (or from \p inputStringPtr when non-null; \p zipMode selects zip-archive
/// parsing) into Function \p F, validating \p tensorNames / \p types against
/// the model's declared inputs. Errors are reported via \p errPtr when
/// provided, otherwise they are fatal.
ONNXModelLoader::ONNXModelLoader(const std::string &modelDescFilename,
                                 llvm::ArrayRef<const char *> tensorNames,
                                 llvm::ArrayRef<TypeRef> types, Function &F,
                                 Error *errPtr, bool zipMode,
                                 BackendSpecificNodeInfo *perNodeOpts,
                                 bool disableConstFoldInLoader,
                                 bool loadIntoExistingModule, const Backend *B,
                                 const std::string *inputStringPtr)
    : CommonOperatorLoader(tensorNames, types, &F, errPtr,
                           loadIntoExistingModule),
      perNodeOpts_(perNodeOpts), staticPlaceholderTypes_(nullptr) {
  // if errPtr already contains an error then don't continue with constructor
  if (errPtr && *errPtr) {
    return;
  }

  // Optionally turn off eager constant folding during loading.
  if (disableConstFoldInLoader) {
    constFoldInLoader_ = false;
  }

  // Parse the proto and load the model; any raised Error is surfaced below.
  auto setup = [&]() -> Error {
    ONNX_NAMESPACE::ModelProto modelDef;
    ASSIGN_VALUE_OR_RETURN_ERR(
        modelDef, loadProto(modelDescFilename, zipMode, inputStringPtr));
    RETURN_IF_ERR(loadModel(modelDef, tensorNames, types, B,
                            /* loadInputsAsPlaceholdersForOnnx */ true));
    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}
6452
6453/// \returns a metadata prop found at \p key in \p modelDef.
6454static const char *getMetadataProp(const ONNX_NAMESPACE::ModelProto &modelDef,
6455 llvm::StringRef key) {
6456 for (const auto &keyVal : modelDef.metadata_props()) {
6457 if (keyVal.key() == key) {
6458 return keyVal.value().data();
6459 }
6460 }
6461 return nullptr;
6462}
6463
6464static Expected<int32_t>
6465getIntMetadataProp(const ONNX_NAMESPACE::ModelProto &modelDef,
6466 llvm::StringRef key) {
6467 const char *intStr = getMetadataProp(modelDef, key);
6468 RETURN_ERR_IF_NOT(intStr, "Did not find value for " + std::string(key));
6469 int32_t intVal;
6470 ASSIGN_VALUE_OR_RETURN_ERR(intVal, getIntFromStr(intStr));
6471 return intVal;
6472}
6473
/// Reconstructs the pre-partitioned configuration \p PPC for \p numPartitions
/// partitions from metadata props of \p modelDef: one Function per partition
/// (created, or reused when loading into an existing Module), plus each
/// partition's logical device IDs, backend name, backend hints,
/// backend-specific options, and replication count. \p rootName becomes the
/// config's root function name.
Error ONNXModelLoader::setupPartitions(ONNX_NAMESPACE::ModelProto &modelDef,
                                       PrePartitionedConfig &PPC,
                                       llvm::StringRef rootName,
                                       int numPartitions) {
  PPC.funcName = rootName.str();
  PPC.resizeAndReserve(numPartitions);

  for (int i = 0; i < numPartitions; i++) {
    // All metadata props for partition i share a common key prefix.
    const std::string partIdPrefix = getPartitionIdPrefix(i);
    const char *pName = getMetadataProp(modelDef, partIdPrefix + nameSignifier);
    RETURN_ERR_IF_NOT(pName, "Didn't find expected partition name");

    // Load the partition name and create a Function with the same name.
    Function *PF = nullptr;
    if (loadIntoExistingModule_ && mod_.hasFunction(pName)) {
      PF = mod_.getFunction(pName);
      RETURN_ERR_IF_NOT(PF->getNodes().size() == 0,
                        "Function must be empty to load into.");
    } else {
      PF = mod_.createFunction(pName);
    }
    partNameToFun_[pName] = PF;
    PPC.funcs.push_back(PF);

    // Load all logical devices for the partition.
    int32_t numLogicalDevices;
    ASSIGN_VALUE_OR_RETURN_ERR(
        numLogicalDevices,
        getIntMetadataProp(modelDef,
                           partIdPrefix + numLogicalDevicesSignifier));
    for (int j = 0; j < numLogicalDevices; j++) {
      DeviceIDTy ID;
      ASSIGN_VALUE_OR_RETURN_ERR(
          ID, getIntMetadataProp(modelDef,
                                 partIdPrefix + getLogicalDeviceSignfier(j)));
      PPC.logicalIDs[i].push_back(ID);
    }

    // Get backend name.
    const char *backendName =
        getMetadataProp(modelDef, partIdPrefix + backendNameSignifier);
    RETURN_ERR_IF_NOT(backendName, "Didn't find Backend name");
    PPC.backendNames.emplace_back(backendName);

    // Get backendHints.executionUnits. Note that we don't support serializing
    // SRAMPrioritization, so it's left empty.
    unsigned execUnits;
    ASSIGN_VALUE_OR_RETURN_ERR(
        execUnits,
        getIntMetadataProp(modelDef, partIdPrefix + executionUnitsSignifier));
    PPC.backendHints.push_back({execUnits, /* SRAMPrioritization */ {}});

    // Load all backend-specific options.
    int32_t numBackendSpecificOpts;
    ASSIGN_VALUE_OR_RETURN_ERR(
        numBackendSpecificOpts,
        getIntMetadataProp(modelDef,
                           partIdPrefix + numBackendSpecificOptsSignifier));
    for (int j = 0; j < numBackendSpecificOpts; j++) {
      const char *optKey = getMetadataProp(
          modelDef, partIdPrefix + getBackendSpecificOptKeySignifier(j));
      RETURN_ERR_IF_NOT(optKey,
                        "Didn't find expected backend-specific option key");
      const char *optVal = getMetadataProp(
          modelDef, partIdPrefix + getBackendSpecificOptValSignifier(j));
      RETURN_ERR_IF_NOT(optVal,
                        "Didn't find expected backend-specific option val");
      PPC.backendSpecificOpts[i][optKey] = optVal;
    }

    // Get replicationCount.
    int32_t replicationCount;
    ASSIGN_VALUE_OR_RETURN_ERR(
        replicationCount,
        getIntMetadataProp(modelDef, partIdPrefix + replicationCountSignifier));
    PPC.replicationCounts.push_back(replicationCount);
  }

  return Error::success();
}
6554
/// Fills positionalInputNames_ / positionalOutputNames_ with the loader names
/// embedded in the doc_strings of \p graph's inputs and outputs. Static
/// inputs are skipped. If any name is missing, the corresponding positional
/// list is cleared entirely (positional IO is all-or-nothing per direction).
void ONNXModelLoader::setupPositionalIO(
    const ONNX_NAMESPACE::GraphProto &graph) {
  for (const auto &in : graph.input()) {
    if (staticInputs_.count(in.name())) {
      continue;
    }
    auto loaderNameOrErr =
        getAttrFromDocString(loaderNameSignifier, in.doc_string());
    if (ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)) {
      // One input without a loader name disables positional inputs entirely.
      positionalInputNames_.clear();
      break;
    }
    positionalInputNames_.emplace_back(loaderNameOrErr.get());
  }

  for (const auto &out : graph.output()) {
    auto loaderNameOrErr =
        getAttrFromDocString(loaderNameSignifier, out.doc_string());
    if (ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)) {
      // Likewise, one output without a loader name disables positional
      // outputs.
      positionalOutputNames_.clear();
      break;
    }
    positionalOutputNames_.emplace_back(loaderNameOrErr.get());
  }
}
6580
/// Builds updatedTQPs_, which maps each unique qparam offset to the actual
/// TensorQuantizationParams to use, by loading a serialized Caffe2 proto
/// (stored in \p modelDef's metadata props) together with \p weightsCount /
/// \p weightDescriptors, and combining the resulting origin-name-to-TQP map
/// with the serialized origin-name-to-offset mapping. Missing metadata is a
/// soft failure: a warning is logged and success returned.
Error ONNXModelLoader::setupUpdatedTQPMap(
    ONNX_NAMESPACE::ModelProto &modelDef, uint32_t weightsCount,
    const onnxTensorDescriptorV1 *weightDescriptors) {
  // Check if we have the two strings in metadata props we need to do TQP
  // updating, and if not print a warning/return early.
  const char *originNameToUniqueOffsetMappingStr =
      getMetadataProp(modelDef, originNameToUniqueOffsetMappingSignifier);
  if (!originNameToUniqueOffsetMappingStr) {
    LOG(WARNING) << "Did not find \""
                 << originNameToUniqueOffsetMappingSignifier
                 << "\" in ONNX model, skipping setting updated TQP map.";
    return Error::success();
  }

  const char *qParamC2ProtoStr = getMetadataProp(modelDef, "C2_with_q_params");
  if (!qParamC2ProtoStr) {
    LOG(WARNING) << "Did not find \"C2_with_q_params\" in ONNX model, skipping "
                    "setting updated TQP map.";
    return Error::success();
  }

  // Now load the qParamC2ProtoStr into a temporary dummy Caffe2ModelLoader,
  // which fills in the originNameToTQPMap based on that model and the weights.
  OriginNameToTQPMap originNameToTQPMap;
  Error err(Error::success());
  Module dummyMod;
  Caffe2ModelLoader tmpLoader(qParamC2ProtoStr, weightsCount, weightDescriptors,
                              dummyMod, &err, &originNameToTQPMap,
                              clipQuantRangeToFP16_);
  RETURN_IF_ERR(err);

  // Now parse the originNameToUniqueOffsetMappingStr to find the original C2
  // name : unique offset pairs. These are formatted like
  // "name_op_a@0@@name_op_b@1@@", with @@ separating each pair, and @
  // separating name from unique offset.
  llvm::SmallVector<llvm::StringRef, 128> nameOffsetSplits;
  llvm::StringRef strRef = llvm::StringRef(originNameToUniqueOffsetMappingStr);
  strRef.split(nameOffsetSplits, offsetEndSig, /* MaxSplit */ -1,
               /* KeepEmpty */ false);

  // Store the mapping into updatedTQPs_, where each unique offset is used as
  // the index into updatedTQPs_ to the actual TQP to use. I.e. we essentially
  // already have c2_name -> offset, and c2_name -> TQP, so we change this to
  // offset -> TQP.
  updatedTQPs_.resize(nameOffsetSplits.size());
  for (auto &nameOffsetSplit : nameOffsetSplits) {
    auto nameOffsetPair = nameOffsetSplit.split(offsetSepSig);
    int32_t idx;
    ASSIGN_VALUE_OR_RETURN_ERR(idx, getIntFromStr(nameOffsetPair.second));
    RETURN_ERR_IF_NOT(idx < int32_t(updatedTQPs_.size()),
                      strFormat("Provided offset index %d not inside size "
                                "of updatedTQPs_ %lu",
                                idx, updatedTQPs_.size()));

    auto it = originNameToTQPMap.find(nameOffsetPair.first.str());
    RETURN_ERR_IF_NOT(it != originNameToTQPMap.end(),
                      strFormat("Did not find matching TQP for %s",
                                nameOffsetPair.first.str().data()));
    updatedTQPs_[idx] = it->second;
  }
  return Error::success();
}
6643
/// Constructs a loader that reads an ONNX model from \p modelDescFilename
/// (or \p inputStringPtr) into Module \p mod. When the model carries a
/// "numPartitions" metadata prop, the partitions are reconstructed into
/// \p PPC; otherwise a single Function named \p funName is created. Errors
/// are reported via \p errPtr when provided, otherwise they are fatal.
ONNXModelLoader::ONNXModelLoader(
    const std::string &modelDescFilename,
    llvm::ArrayRef<const char *> tensorNames, llvm::ArrayRef<TypeRef> types,
    Module &mod, llvm::StringRef funName, PrePartitionedConfig *PPC,
    Error *errPtr, bool zipMode, BackendSpecificNodeInfo *perNodeOpts,
    bool loadIntoExistingModule, bool disableConstFoldInLoader,
    const Backend *B, const std::string *inputStringPtr)
    : CommonOperatorLoader(tensorNames, types, mod, errPtr,
                           loadIntoExistingModule),
      perNodeOpts_(perNodeOpts), staticPlaceholderTypes_(nullptr) {
  // if errPtr already contains an error then don't continue with constructor
  if (errPtr && *errPtr) {
    return;
  }

  // Optionally turn off eager constant folding during loading.
  if (disableConstFoldInLoader) {
    constFoldInLoader_ = false;
  }

  auto setup = [&]() -> Error {
    ONNX_NAMESPACE::ModelProto modelDef;
    ASSIGN_VALUE_OR_RETURN_ERR(
        modelDef, loadProto(modelDescFilename, zipMode, inputStringPtr));

    // A missing "numPartitions" prop simply means a non-partitioned model.
    auto numPartitionsOrErr = getIntMetadataProp(modelDef, "numPartitions");
    if (!numPartitionsOrErr) {
      ERR_TO_VOID(numPartitionsOrErr.takeError(), /*log*/ false);
      G_ = mod_.createFunction(funName);
    } else {
      RETURN_ERR_IF_NOT(PPC, "No PrePartitionConfig to load partitions into");
      RETURN_IF_ERR(
          setupPartitions(modelDef, *PPC, funName, *numPartitionsOrErr));
    }

    RETURN_IF_ERR(loadModel(modelDef, tensorNames, types, B,
                            /* loadInputsAsPlaceholdersForOnnx */ true));

    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}
6690
/// Constructs a loader that parses an in-memory model buffer (\p model of
/// \p modelSize bytes) plus externally provided weights
/// (\p weightDescriptors, \p weightsCount entries) into Function \p F.
/// \p constFoldInLoader explicitly overrides the default eager-folding
/// behavior. Errors are reported via \p errPtr when provided.
ONNXModelLoader::ONNXModelLoader(
    const void *model, uint32_t modelSize, uint32_t weightsCount,
    const onnxTensorDescriptorV1 *weightDescriptors, Function &F,
    bool loadInputsAsPlaceholdersForOnnx, Error *errPtr, bool constFoldInLoader,
    BackendSpecificNodeInfo *perNodeOpts)
    : CommonOperatorLoader({}, {}, &F, errPtr, true), perNodeOpts_(perNodeOpts),
      staticPlaceholderTypes_(nullptr) {
  // if errPtr already contains an error then don't continue with constructor
  if (errPtr && *errPtr) {
    return;
  }

  // Always override the default for folding in this constructor.
  constFoldInLoader_ = constFoldInLoader;

  // Lambda to setup the ONNXModelLoader and return any Errors that were
  // raised.
  auto setup = [&]() -> Error {
    ONNX_NAMESPACE::ModelProto modelDef;
    ASSIGN_VALUE_OR_RETURN_ERR(modelDef, loadProto(model, modelSize));

    // Register the externally provided weights before loading the graph.
    RETURN_IF_ERR(loadWeights(weightsCount, weightDescriptors));

    RETURN_IF_ERR(loadModel(modelDef, {}, {}, /* B */ nullptr,
                            loadInputsAsPlaceholdersForOnnx));

    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}
6726
/// Constructs a loader that parses an in-memory model buffer (\p model of
/// \p modelSize bytes) plus external weights (\p weightDescriptors) into
/// Module \p mod, reconstructing partitions into \p PPC when the model is
/// pre-partitioned (otherwise creating Function \p funName). Optionally
/// records static Placeholder types into \p staticPlaceholderTypes, replaces
/// dummy quantization params (\p replaceDummyTQPs), and clips quantization
/// ranges to FP16 (\p clipQuantRangeToFP16). Errors are reported via
/// \p errPtr when provided.
ONNXModelLoader::ONNXModelLoader(
    const void *model, uint32_t modelSize, uint32_t weightsCount,
    const onnxTensorDescriptorV1 *weightDescriptors, Module &mod,
    llvm::StringRef funName, PrePartitionedConfig *PPC,
    bool loadInputsAsPlaceholdersForOnnx, Error *errPtr, bool constFoldInLoader,
    BackendSpecificNodeInfo *perNodeOpts,
    std::map<std::string, Type> *staticPlaceholderTypes, bool replaceDummyTQPs,
    bool clipQuantRangeToFP16)
    : CommonOperatorLoader({}, {}, mod, errPtr,
                           /* loadIntoExistingModule */ true,
                           /* originNameToTQPMap */ nullptr,
                           /* loadUniquedDummyQParams */ false,
                           replaceDummyTQPs, /* zeroScaleFP16Clip */ false,
                           clipQuantRangeToFP16),
      perNodeOpts_(perNodeOpts),
      staticPlaceholderTypes_(staticPlaceholderTypes) {
  // if errPtr already contains an error then don't continue with constructor
  if (errPtr && *errPtr) {
    return;
  }

  // Always override the default for folding in this constructor.
  constFoldInLoader_ = constFoldInLoader;

  // Lambda to setup the ONNXModelLoader and return any Errors that were
  // raised.
  auto setup = [&]() -> Error {
    ONNX_NAMESPACE::ModelProto modelDef;
    ASSIGN_VALUE_OR_RETURN_ERR(modelDef, loadProto(model, modelSize));

    // If we're going to be replacing dummy TQPs then setup the updated TQP map,
    // which is used later on when loading each op.
    if (replaceDummyTQPs_) {
      // Check if the model has specified that clipQuantRangeToFP16 should be
      // overridden via the clipQuantRangeToFP16Key metadata prop. Do this
      // before updating TQP map, because this will affect the qparams loaded.
      for (const auto &keyVal : modelDef.metadata_props()) {
        if (keyVal.key() == clipQuantRangeToFP16Key && keyVal.value() == "1") {
          LOG(INFO) << "ONNXModelLoader found enabled "
                    << clipQuantRangeToFP16Key;
          clipQuantRangeToFP16_ = true;
          break;
        }
      }

      RETURN_IF_ERR(
          setupUpdatedTQPMap(modelDef, weightsCount, weightDescriptors));
    }

    // Register the externally provided weights before loading the graph.
    RETURN_IF_ERR(loadWeights(weightsCount, weightDescriptors));

    // A missing "numPartitions" prop simply means a non-partitioned model.
    auto numPartitionsOrErr = getIntMetadataProp(modelDef, "numPartitions");
    if (!numPartitionsOrErr) {
      ERR_TO_VOID(numPartitionsOrErr.takeError(), /*log*/ false);
      G_ = mod_.createFunction(funName);
    } else {
      RETURN_ERR_IF_NOT(PPC, "No PrePartitionConfig to load partitions into");
      RETURN_IF_ERR(
          setupPartitions(modelDef, *PPC, funName, *numPartitionsOrErr));
    }

    RETURN_IF_ERR(loadModel(modelDef, {}, {}, /* B */ nullptr,
                            loadInputsAsPlaceholdersForOnnx));

    // Positional IO names are only meaningful when inputs became Placeholders.
    if (loadInputsAsPlaceholdersForOnnx) {
      setupPositionalIO(modelDef.graph());
    }

    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}
6804