1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Importer/ONNXModelLoader.h" |
18 | |
19 | #include "glow/Base/Tensor.h" |
20 | #include "glow/Flags/Flags.h" |
21 | #include "glow/Graph/Graph.h" |
22 | #include "glow/Graph/Nodes.h" |
23 | #include "glow/Importer/Caffe2ModelLoader.h" |
24 | #include "glow/Support/Support.h" |
25 | #include "glow/Support/ZipUtils.h" |
26 | |
27 | #include "llvm/Support/Casting.h" |
28 | #include "llvm/Support/CommandLine.h" |
29 | |
30 | #include "google/protobuf/io/coded_stream.h" |
31 | #include "google/protobuf/io/tokenizer.h" |
32 | #include "google/protobuf/io/zero_copy_stream_impl.h" |
33 | #include "google/protobuf/text_format.h" |
34 | |
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
41 | |
42 | using namespace glow; |
43 | using namespace glow::runtime; |
44 | using llvm::cast; |
45 | |
46 | namespace { |
47 | |
/// Command line option category grouping all ONNX model loader options.
llvm::cl::OptionCategory onnxModelLoaderCat("ONNX Model Loader Options");
49 | |
/// Backing storage for the -onnx-define-symbol option values. Each entry has
/// the form "<symbol_name>,<symbol_value>". Also settable programmatically
/// via glow::setOnnxDefineSymbol().
std::vector<std::string> onnxDefineSymbol;
llvm::cl::list<std::string, std::vector<std::string>> onnxDefineSymbolOpt(
    "onnx-define-symbol", llvm::cl::ZeroOrMore,
    llvm::cl::location(onnxDefineSymbol),
    llvm::cl::desc(
        "Define (replace) the undefined symbols from the tensor descriptions\n"
        "in the ONNX model with actual integer sizes. The undefined symbols \n"
        "are marked in the proto description with the 'dim_param' field. For\n"
        "example, if the model contains a tensor with the size described as \n"
        "'None' x 3 x 224 x 224, the symbol 'None' can be replaced with an \n"
        "actual integer size (for example 1) by using the following command \n"
        "line option: \n"
        " -onnx-define-symbol=None,1 \n"
        "Multiple symbols can be defined using this option, for example: \n"
        " -onnx-define-symbol=<symbol_name1>,<symbol_value1> \n"
        " -onnx-define-symbol=<symbol_name2>,<symbol_value2> \n"
        " ..................................................\n"),
    llvm::cl::value_desc("name,value"), llvm::cl::cat(onnxModelLoaderCat));
68 | |
/// When true, RNN/GRU/LSTM states are always exported as graph placeholders
/// (see the desc text below for the full semantics).
llvm::cl::opt<bool> onnxExportRnnStatesOpt(
    "onnx-export-rnn-states", llvm::cl::init(false), llvm::cl::Optional,
    llvm::cl::desc(
        "Option to export the states of the ONNX RNN operators (for example \n"
        "RNN, GRU, LSTM) as graph placeholders regardless of whether the \n"
        "states are explicitly set or not in the graph. The placeholders are\n"
        "also providing an automatic way for tracking the RNN states since \n"
        "the states are updated automatically with the new RNN states after \n"
        "each inference. Default is false."),
    llvm::cl::cat(onnxModelLoaderCat));
79 | |
/// Upper bound on the trip count the ONNX Loop operator loader is willing to
/// fully unroll; loops with more iterations are rejected.
llvm::cl::opt<unsigned> loopUnrollLimit(
    "loop-unroll-limit",
    llvm::cl::desc("Maximum unrollable iterations for the Loop operator"),
    llvm::cl::Optional, llvm::cl::init(20), llvm::cl::cat(onnxModelLoaderCat));
84 | |
85 | /// Parse the command line option and get the user defined map of symbols. |
86 | /// The command line option has the format <symbol_name>,<symbol_value>. |
87 | Expected<std::unordered_map<std::string, dim_t>> getSymbolMap() { |
88 | std::unordered_map<std::string, dim_t> symbolMap; |
89 | for (const auto &str : onnxDefineSymbol) { |
90 | auto strPair = llvm::StringRef(str).split(','); |
91 | llvm::StringRef name = strPair.first; |
92 | RETURN_ERR_IF_NOT(name.size() > 0, "ONNX defined symbol name is empty." ); |
93 | dim_t value; |
94 | RETURN_ERR_IF_NOT(!strPair.second.getAsInteger(0, value), |
95 | strFormat("ONNX defined symbol value '%s' is invalid." , |
96 | strPair.second.data())); |
97 | symbolMap[name.str()] = value; |
98 | } |
99 | return symbolMap; |
100 | } |
101 | |
102 | /// Get the shape of a TensorShapeProto given by \p shapeProto and return the |
103 | /// dimensions in the vector \p dim passed by reference. |
104 | Expected<std::vector<dim_t>> |
105 | getProtoShape(const ONNX_NAMESPACE::TensorShapeProto &shapeProto) { |
106 | std::vector<dim_t> dim; |
107 | for (auto d : shapeProto.dim()) { |
108 | if (d.has_dim_value()) { |
109 | // Proto shape has an explicit size given by the "dim_value" field. |
110 | dim.push_back(d.dim_value()); |
111 | } else if (d.has_dim_param()) { |
112 | // Proto shape has a symbolic size given by the "dim_param" field. Search |
113 | // the symbol in the user defined map of symbols. If the symbol is not |
114 | // found then raise an error. |
115 | auto symbolName = d.dim_param(); |
116 | std::unordered_map<std::string, dim_t> symbolMap; |
117 | ASSIGN_VALUE_OR_RETURN_ERR(symbolMap, getSymbolMap()); |
118 | if (symbolMap.count(symbolName) && symbolMap[symbolName] > 0) { |
119 | dim.push_back(symbolMap[symbolName]); |
120 | } else { |
121 | return MAKE_ERR(strFormat( |
122 | "ONNX model symbol '%s' is undefined. Define the symbol with the " |
123 | "following command line option: -onnx-define-symbol=%s,<value> and " |
124 | "each 'dim_value' of tensor shape proto must be greater than 0." , |
125 | symbolName.c_str(), symbolName.c_str())); |
126 | } |
127 | } else { |
128 | // Proto shape has no "dim_value" and no "dim_param" field. |
129 | return MAKE_ERR( |
130 | "Tensor shape proto has no 'dim_value' or 'dim_param' field!" ); |
131 | } |
132 | } |
133 | return dim; |
134 | } |
135 | |
136 | /// Given some \p onnxType, sets \p elemTy to a corresponding Glow |
137 | /// ElemKind. \returns whether an ElemKind was successfully selected. |
138 | Error onnxTensorDataTypeToElemKind(int32_t onnxType, ElemKind *elemTy) { |
139 | if (onnxType == ONNX_NAMESPACE::TensorProto::FLOAT) { |
140 | *elemTy = ElemKind::FloatTy; |
141 | return Error::success(); |
142 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::FLOAT16) { |
143 | *elemTy = ElemKind::Float16Ty; |
144 | return Error::success(); |
145 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::BFLOAT16) { |
146 | *elemTy = ElemKind::BFloat16Ty; |
147 | return Error::success(); |
148 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT64) { |
149 | *elemTy = ElemKind::Int64ITy; |
150 | return Error::success(); |
151 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT32) { |
152 | *elemTy = ElemKind::Int32ITy; |
153 | return Error::success(); |
154 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::UINT8) { |
155 | *elemTy = ElemKind::UInt8FusedQTy; |
156 | return Error::success(); |
157 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT8) { |
158 | *elemTy = ElemKind::Int8QTy; |
159 | return Error::success(); |
160 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::INT16) { |
161 | *elemTy = ElemKind::Int16QTy; |
162 | return Error::success(); |
163 | } else if (onnxType == ONNX_NAMESPACE::TensorProto::BOOL) { |
164 | *elemTy = ElemKind::BoolTy; |
165 | return Error::success(); |
166 | } else { |
167 | return MAKE_ERR(strFormat( |
168 | "Don't know how to convert ONNX tensor data type %d to ElemKind" , |
169 | onnxType)); |
170 | } |
171 | } |
172 | |
173 | /// Finds an attribute from the doc_string and \returns it. If it does not exist |
174 | /// then \returns Error. The expected structure here is that each attribute |
175 | /// starts with startChar and is separated from its value by a sepChar. |
176 | Expected<std::string> getAttrFromDocString(const std::string &attr, |
177 | const std::string &docStr) { |
178 | const std::string attrAndSep = attr + sepChar; |
179 | size_t begin = 0; |
180 | while (true) { |
181 | begin = docStr.find(startChar, begin); |
182 | if (begin == std::string::npos) { |
183 | return MAKE_ERR(strFormat("Didn't find PH attribute '%s'" , attr.c_str())); |
184 | } |
185 | |
186 | // Note: +1 here and following line to account for the leading startChar. |
187 | if (!docStr.compare(begin + 1, attrAndSep.size(), attrAndSep)) { |
188 | // If we found the attribute then set begin to just after attrAndSep. |
189 | begin += attrAndSep.size() + 1; |
190 | break; |
191 | } |
192 | // Move past the current non-matching attribute to try the next attribute. |
193 | begin = begin + attrAndSep.size(); |
194 | } |
195 | |
196 | return docStr.substr(begin, docStr.find(startChar, begin) - begin); |
197 | } |
198 | |
199 | Expected<std::pair<bool, std::string>> |
200 | getTrainableLayoutPairFromDocString(const std::string &docString, |
201 | bool useGlowCustomOps) { |
202 | std::string layout = ANY_LAYOUT; |
203 | std::string isTrainableStr = "0" ; |
204 | if (useGlowCustomOps) { |
205 | ASSIGN_VALUE_OR_RETURN_ERR( |
206 | isTrainableStr, getAttrFromDocString(trainableSignifier, docString)); |
207 | ASSIGN_VALUE_OR_RETURN_ERR( |
208 | layout, getAttrFromDocString(layoutSignifier, docString)); |
209 | } |
210 | return std::make_pair(isTrainableStr != "0" , layout); |
211 | } |
212 | |
213 | Expected<std::pair<float, int32_t>> |
214 | getQuantParamsFromDocString(const std::string &docStr) { |
215 | std::string scaleStr; |
216 | ASSIGN_VALUE_OR_RETURN_ERR(scaleStr, |
217 | getAttrFromDocString(qScaleSignifier, docStr)); |
218 | float scale = std::strtof(scaleStr.c_str(), NULL); |
219 | |
220 | std::string offsetStr; |
221 | ASSIGN_VALUE_OR_RETURN_ERR(offsetStr, |
222 | getAttrFromDocString(qOffsetSignifier, docStr)); |
223 | int32_t offset; |
224 | ASSIGN_VALUE_OR_RETURN_ERR(offset, getIntFromStr(offsetStr)); |
225 | return std::make_pair(scale, offset); |
226 | } |
227 | |
/// Parses the custom strides serialized in \p docStr and \returns them.
/// \returns an empty vector when the strides attribute is absent or empty,
/// meaning the tensor uses default (contiguous) strides.
ShapeVector getStridesFromDocString(const std::string &docStr) {
  ShapeVector strides;
  std::string stridesStr;
  auto stridesOrError = getAttrFromDocString(stridesSignifier, docStr);
  // A missing attribute is not an error here: consume the Error silently
  // (log=false) and fall back to default strides.
  if (ERR_TO_BOOL(stridesOrError.takeError(), /* log */ false)) {
    return strides;
  }
  stridesStr = std::move(stridesOrError.get());
  if (stridesStr.empty()) {
    return strides;
  }
  // Parse comma-delimited stride values.
  llvm::SmallVector<llvm::StringRef, max_tensor_dimensions> stridesStrSplit;
  llvm::StringRef stridesStrRef = llvm::StringRef(stridesStr);
  stridesStrRef.split(stridesStrSplit, ',');
  for (const auto &stride : stridesStrSplit) {
    strides.emplace_back(std::stoi(stride.str()));
  }
  return strides;
}
248 | |
/// Used for retrieving an attribute of type \p T from \p attr. Some
/// specializations use \p loader if necessary (e.g. to resolve NodeValue
/// names). The \p IsInteger flag routes all integral types to a single
/// shared specialization.
template <bool IsInteger, typename T> struct AttributeRetriever {
  static Expected<T> get(const ONNX_NAMESPACE::AttributeProto *attr,
                         const ProtobufLoader &loader);
};
255 | |
256 | /// Specialization for std::vector<float>. |
257 | template <> struct AttributeRetriever<false, std::vector<float>> { |
258 | static Expected<std::vector<float>> |
259 | get(const ONNX_NAMESPACE::AttributeProto *attr, |
260 | const ProtobufLoader & /* unused */) { |
261 | return getFloats(attr); |
262 | } |
263 | }; |
264 | |
265 | /// Specialization for std::vector<NodeValue>. |
266 | template <> struct AttributeRetriever<false, std::vector<NodeValue>> { |
267 | static Expected<std::vector<NodeValue>> |
268 | get(const ONNX_NAMESPACE::AttributeProto *attr, ProtobufLoader &loader) { |
269 | // Retrieve the names from the proto which map to NodeValues. |
270 | std::vector<std::string> strs; |
271 | ASSIGN_VALUE_OR_RETURN_ERR(strs, getStrings(attr)); |
272 | |
273 | // Get NodeValues corresponding to these names from the loader. |
274 | std::vector<NodeValue> NVs; |
275 | for (const auto &str : strs) { |
276 | NodeValue NV; |
277 | ASSIGN_VALUE_OR_RETURN_ERR(NV, loader.getNodeValueByName(str)); |
278 | NVs.push_back(NV); |
279 | } |
280 | return NVs; |
281 | } |
282 | }; |
283 | |
284 | /// Specialization for NodeValue. |
285 | template <> struct AttributeRetriever<false, NodeValue> { |
286 | static Expected<NodeValue> get(const ONNX_NAMESPACE::AttributeProto *attr, |
287 | ProtobufLoader &loader) { |
288 | // Retrieve the name from the proto, which is mapped to a NodeValue. |
289 | std::string str; |
290 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
291 | |
292 | // Get/return the corresponding NodeValue for this name from the loader. |
293 | NodeValue NV; |
294 | ASSIGN_VALUE_OR_RETURN_ERR(NV, loader.getNodeValueByName(str)); |
295 | return NV; |
296 | } |
297 | }; |
298 | |
299 | /// Specialization for std::vector<T>. Fall back for integer types. |
300 | template <typename T> struct AttributeRetriever<false, std::vector<T>> { |
301 | static Expected<std::vector<T>> |
302 | get(const ONNX_NAMESPACE::AttributeProto *attr, |
303 | const ProtobufLoader & /* unused */) { |
304 | return getShape<T>(attr, /* allowEmptyShape */ true); |
305 | } |
306 | }; |
307 | |
308 | /// Specialization for integer types. |
309 | template <typename T> struct AttributeRetriever<true, T> { |
310 | static Expected<T> get(const ONNX_NAMESPACE::AttributeProto *attr, |
311 | const ProtobufLoader & /* unused */) { |
312 | return loadInt(attr); |
313 | } |
314 | }; |
315 | |
316 | /// Specialization for LengthsMode. |
317 | template <> struct AttributeRetriever<false, LengthsMode> { |
318 | static Expected<LengthsMode> get(const ONNX_NAMESPACE::AttributeProto *attr, |
319 | const ProtobufLoader & /* unused */) { |
320 | std::string str; |
321 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
322 | if (str == "AllOne" ) { |
323 | return LengthsMode::AllOne; |
324 | } else if (str == "Variable" ) { |
325 | return LengthsMode::Variable; |
326 | } else { |
327 | return MAKE_ERR("Invalid LengthsMode" ); |
328 | } |
329 | } |
330 | }; |
331 | |
332 | /// Specialization for FusedActivation. |
333 | template <> struct AttributeRetriever<false, FusedActivation> { |
334 | static Expected<FusedActivation> |
335 | get(const ONNX_NAMESPACE::AttributeProto *attr, |
336 | const ProtobufLoader & /* unused */) { |
337 | std::string str; |
338 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
339 | if (str == "NONE" ) { |
340 | return FusedActivation::NONE; |
341 | } else if (str == "RELU" ) { |
342 | return FusedActivation::RELU; |
343 | } else if (str == "CLIP" ) { |
344 | return FusedActivation::CLIP; |
345 | } else if (str == "TANH" ) { |
346 | return FusedActivation::TANH; |
347 | } else if (str == "SIGMOID" ) { |
348 | return FusedActivation::SIGMOID; |
349 | } else if (str == "LEAKY_RELU" ) { |
350 | return FusedActivation::LEAKY_RELU; |
351 | } else { |
352 | return MAKE_ERR("Invalid FusedActivation" ); |
353 | } |
354 | } |
355 | }; |
356 | |
357 | /// Specialization for FusedActivation. |
358 | template <> struct AttributeRetriever<false, LUTOperator> { |
359 | static Expected<LUTOperator> get(const ONNX_NAMESPACE::AttributeProto *attr, |
360 | const ProtobufLoader & /* unused */) { |
361 | std::string str; |
362 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
363 | if (str == "NONE" ) { |
364 | return LUTOperator::NONE; |
365 | } else if (str == "RELU" ) { |
366 | return LUTOperator::RELU; |
367 | } else if (str == "CLIP" ) { |
368 | return LUTOperator::CLIP; |
369 | } else if (str == "TANH" ) { |
370 | return LUTOperator::TANH; |
371 | } else if (str == "SIGMOID" ) { |
372 | return LUTOperator::SIGMOID; |
373 | } else if (str == "LEAKY_RELU" ) { |
374 | return LUTOperator::LEAKY_RELU; |
375 | } else { |
376 | return MAKE_ERR("Invalid LUTOperator" ); |
377 | } |
378 | } |
379 | }; |
380 | |
381 | /// Specialization for ConvolutionLayout. |
382 | template <> struct AttributeRetriever<false, ConvolutionLayout> { |
383 | static Expected<ConvolutionLayout> |
384 | get(const ONNX_NAMESPACE::AttributeProto *attr, |
385 | const ProtobufLoader & /* unused */) { |
386 | std::string str; |
387 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
388 | if (str == "NHWC" ) { |
389 | return ConvolutionLayout::NHWC; |
390 | } else if (str == "NCHW" ) { |
391 | return ConvolutionLayout::NCHW; |
392 | } else if (str == "NTHWC" ) { |
393 | return ConvolutionLayout::NTHWC; |
394 | } else if (str == "NCTHW" ) { |
395 | return ConvolutionLayout::NCTHW; |
396 | } else { |
397 | return MAKE_ERR("Invalid ConvolutionLayout" ); |
398 | } |
399 | } |
400 | }; |
401 | |
402 | /// Specialization for PaddingMode. |
403 | template <> struct AttributeRetriever<false, PaddingMode> { |
404 | static Expected<PaddingMode> get(const ONNX_NAMESPACE::AttributeProto *attr, |
405 | const ProtobufLoader & /* unused */) { |
406 | std::string str; |
407 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
408 | if (str == "CONSTANT" ) { |
409 | return PaddingMode::CONSTANT; |
410 | } else if (str == "REFLECT" ) { |
411 | return PaddingMode::REFLECT; |
412 | } else if (str == "EDGE" ) { |
413 | return PaddingMode::EDGE; |
414 | } else { |
415 | return MAKE_ERR("Invalid PaddingMode" ); |
416 | } |
417 | } |
418 | }; |
419 | |
420 | /// Specialization for SplitEmbeddingPoolingMode. |
421 | template <> struct AttributeRetriever<false, SplitEmbeddingPoolingMode> { |
422 | static Expected<SplitEmbeddingPoolingMode> |
423 | get(const ONNX_NAMESPACE::AttributeProto *attr, |
424 | const ProtobufLoader & /* unused */) { |
425 | std::string poolingMode; |
426 | ASSIGN_VALUE_OR_RETURN_ERR(poolingMode, loadStr(attr)); |
427 | if (poolingMode == "0" ) { |
428 | return SplitEmbeddingPoolingMode::EP_SUM; |
429 | } else if (poolingMode == "1" ) { |
430 | return SplitEmbeddingPoolingMode::EP_MEAN; |
431 | } else if (poolingMode == "2" ) { |
432 | return SplitEmbeddingPoolingMode::EP_NONE; |
433 | } else { |
434 | return MAKE_ERR("Invalid SplitEmbeddingPoolingMode" ); |
435 | } |
436 | } |
437 | }; |
438 | |
439 | /// Specialization for SplitEmbeddingSparseType. |
440 | template <> struct AttributeRetriever<false, SplitEmbeddingSparseType> { |
441 | static Expected<SplitEmbeddingSparseType> |
442 | get(const ONNX_NAMESPACE::AttributeProto *attr, |
443 | const ProtobufLoader & /* unused */) { |
444 | std::string str; |
445 | ASSIGN_VALUE_OR_RETURN_ERR(str, loadStr(attr)); |
446 | if (str == "0" ) { |
447 | return SplitEmbeddingSparseType::EST_FLOAT; |
448 | } else if (str == "1" ) { |
449 | return SplitEmbeddingSparseType::EST_FLOAT16; |
450 | } else if (str == "2" ) { |
451 | return SplitEmbeddingSparseType::EST_INT8; |
452 | } else if (str == "3" ) { |
453 | return SplitEmbeddingSparseType::EST_INT4; |
454 | } else if (str == "4" ) { |
455 | return SplitEmbeddingSparseType::EST_INT2; |
456 | } else { |
457 | return MAKE_ERR("Invalid SplitEmbeddingSparseType" ); |
458 | } |
459 | } |
460 | }; |
461 | |
462 | /// Specialization for float. |
463 | template <> struct AttributeRetriever<false, float> { |
464 | static Expected<float> get(const ONNX_NAMESPACE::AttributeProto *attr, |
465 | const ProtobufLoader & /* unused */) { |
466 | return loadFloat(attr); |
467 | } |
468 | }; |
469 | |
470 | /// Specialization for std::string. |
471 | template <> struct AttributeRetriever<false, std::string> { |
472 | static Expected<std::string> get(const ONNX_NAMESPACE::AttributeProto *attr, |
473 | const ProtobufLoader & /* unused */) { |
474 | return loadStr(attr); |
475 | } |
476 | }; |
477 | |
478 | /// Forwards to the correct AttributeRetriever specialization. |
479 | template <typename T> |
480 | Expected<T> loadAttribute(const ONNX_NAMESPACE::AttributeProto *attr, |
481 | ProtobufLoader &loader) { |
482 | RETURN_ERR_IF_NOT(attr, "No such attribute" ); |
483 | return AttributeRetriever<std::numeric_limits<T>::is_integer, T>::get(attr, |
484 | loader); |
485 | } |
486 | |
487 | } // namespace |
488 | |
/// Maps an attribute name to its proto, giving random access to a node's
/// attributes (built by loadArgumentMap below).
using ArgumentDictionaryTy =
    std::unordered_map<std::string, const ONNX_NAMESPACE::AttributeProto *>;
491 | |
492 | /// \returns a type based on \p ty, but using the provided \p strides. |
493 | static Type getTypeWithCustomStrides(const Type &ty, |
494 | llvm::ArrayRef<dim_t> strides) { |
495 | if (strides.empty()) { |
496 | return ty; |
497 | } |
498 | return Type::newStrides(ty, strides); |
499 | } |
500 | |
501 | /// \returns a type in module \p mod, based on \p ty, but using the provided \p |
502 | /// strides. |
503 | static TypeRef getTypeWithCustomStrides(glow::Module &mod, const TypeRef ty, |
504 | llvm::ArrayRef<dim_t> strides) { |
505 | if (strides.empty()) { |
506 | return ty; |
507 | } |
508 | return mod.uniqueTypeWithNewStrides(ty, ty->dims(), strides); |
509 | } |
510 | |
511 | /// Given a docstring encoding \p str of a type and its dimension \p |
512 | /// dims, parses the string and \returns a Glow Type from it or Error if |
513 | /// parsing failed. Expected format of str is either elemKindSignifier or |
514 | /// "ElemKind:scale:offset". |
515 | Expected<Type> parseTypeFromDocString(const std::string &str, |
516 | llvm::ArrayRef<dim_t> dims, |
517 | bool useGlowCustomOps) { |
518 | float scale = 1.0; |
519 | int32_t offset = 0; |
520 | ElemKind elemKind = ElemKind::FloatTy; |
521 | ShapeVector strides; |
522 | |
523 | if (useGlowCustomOps) { |
524 | std::string elemKindStr; |
525 | ASSIGN_VALUE_OR_RETURN_ERR(elemKindStr, |
526 | getAttrFromDocString(elemKindSignifier, str)); |
527 | elemKind = Type::getElementKindFromName(elemKindStr); |
528 | |
529 | if (isQuantizedElemKind(elemKind)) { |
530 | std::pair<float, int32_t> scaleOffsetPair; |
531 | ASSIGN_VALUE_OR_RETURN_ERR(scaleOffsetPair, |
532 | getQuantParamsFromDocString(str)); |
533 | std::tie(scale, offset) = scaleOffsetPair; |
534 | } |
535 | strides = getStridesFromDocString(str); |
536 | } else { |
537 | size_t begin = 0; |
538 | |
539 | // Find Elemkind string |
540 | size_t end = str.find(':', begin); |
541 | |
542 | // If a ':' isn't found then assume the whole string is ElemKind (for |
543 | // backwards compatibility reasons) otherwise look for scale and offset |
544 | // strings. |
545 | std::string elemKindStr; |
546 | if (end == std::string::npos) { |
547 | elemKindStr = str.substr(0, str.size()); |
548 | } else { |
549 | elemKindStr = str.substr(begin, end - begin); |
550 | |
551 | // Get scale string. |
552 | begin = end + 1; |
553 | end = str.find(':', begin); |
554 | if (end == std::string::npos) { |
555 | return MAKE_ERR("scale not found" ); |
556 | } |
557 | std::string scaleStr = str.substr(begin, end - begin); |
558 | |
559 | // Get offset string. |
560 | begin = end + 1; |
561 | end = str.size(); |
562 | if (end - begin == 0) { |
563 | return MAKE_ERR("offset not found" ); |
564 | } |
565 | |
566 | std::string offsetStr = str.substr(begin, end - begin); |
567 | |
568 | scale = std::stof(scaleStr); |
569 | offset = std::stoi(offsetStr); |
570 | } |
571 | |
572 | elemKind = Type::getElementKindFromName(elemKindStr); |
573 | } |
574 | |
575 | Type ty; |
576 | if (isQuantizedElemKind(elemKind)) { |
577 | ty = Type(elemKind, dims, scale, offset); |
578 | } else { |
579 | ty = Type(elemKind, dims); |
580 | } |
581 | return getTypeWithCustomStrides(ty, strides); |
582 | } |
583 | |
584 | /// Translates the protocol buffer node \p op into a random access map. |
585 | static ArgumentDictionaryTy |
586 | loadArgumentMap(const ONNX_NAMESPACE::NodeProto &op) { |
587 | ArgumentDictionaryTy dict; |
588 | for (auto &arg : op.attribute()) { |
589 | dict[arg.name()] = &arg; |
590 | } |
591 | return dict; |
592 | } |
593 | |
/// Programmatic equivalent of the -onnx-define-symbol command line option:
/// replaces the current symbol definitions with \p strs, where each entry has
/// the format "<symbol_name>,<symbol_value>".
void glow::setOnnxDefineSymbol(const std::vector<std::string> &strs) {
  onnxDefineSymbol = strs;
}
597 | |
598 | ONNX_NAMESPACE::GraphProto glow::parseOnnxFile(const std::string &fileName) { |
599 | ::ONNX_NAMESPACE::GraphProto graphProto; |
600 | std::ifstream inputFileStream(fileName, std::ios::in | std::ios::binary); |
601 | CHECK(inputFileStream) << "Can't find the input file for " << fileName; |
602 | google::protobuf::io::IstreamInputStream protobufFileStream(&inputFileStream); |
603 | google::protobuf::io::CodedInputStream codedStream(&protobufFileStream); |
604 | #if GOOGLE_PROTOBUF_VERSION >= 3002000 |
605 | codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE); |
606 | #else |
607 | codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE); |
608 | #endif |
609 | bool parsedSuccessfully = graphProto.ParseFromCodedStream(&codedStream); |
610 | CHECK(parsedSuccessfully) << "Failed to parse GraphProto" ; |
611 | return graphProto; |
612 | } |
613 | |
/// Copies every initializer tensor of \p inputGroup into the matching
/// placeholder tensor in \p bindings (looked up by Glow-legalized name;
/// CHECK-fails if missing). Tensors whose serialized payload is smaller than
/// the placeholder's full type are "partial": when \p partialTensorPayloads
/// is non-null the partial payload is kept alive there and the bound tensor
/// aliases it with the full type; otherwise the missing tail is zero-padded.
/// \p usingGlowCustomOps selects the Glow custom-op docstring format when
/// deserializing tensors.
void glow::fillPlaceholders(const ONNX_NAMESPACE::GraphProto &inputGroup,
                            PlaceholderBindings *bindings,
                            std::vector<Tensor> *partialTensorPayloads,
                            bool usingGlowCustomOps) {
  for (const auto &tensorProto : inputGroup.initializer()) {
    const std::string glowLegalizedName =
        glow::legalizeName(tensorProto.name());
    auto *tensor =
        bindings->get(bindings->getPlaceholderByNameSlow(glowLegalizedName));
    CHECK(tensor) << "Missing " << tensorProto.name()
                  << ", Glow legalized name " << glowLegalizedName;
    // Record the full (placeholder) size/type before loadTensor resets the
    // tensor to the (possibly smaller) serialized type.
    size_t fullSize = tensor->getSizeInBytes();
    const auto fullType = tensor->getType();
    auto error = loadTensor(tensorProto, tensor, usingGlowCustomOps);
    bool hasError = ERR_TO_BOOL(std::move(error));
    CHECK(!hasError) << "Cannot load input tensor";
    size_t loadedSize = tensor->getSizeInBytes();
    if (loadedSize != fullSize) {
      if (partialTensorPayloads) {
        VLOG(1) << "Loading " << tensorProto.name() << ", Glow legalized name "
                << glowLegalizedName << " as a partial tensor: partial size="
                << tensor->getType().toString()
                << " full size=" << fullType.toString();
        // Unowned view over the partial payload, but typed with the full type.
        Tensor fullTensor(tensor->getUnsafePtr(), &fullType,
                          tensor->getSizeInBytes());
        // 'fullTensor' doesn't own the underlying data. 'tensor' does. So
        // we want to keep the original tensor object around until inference
        // is finished.
        partialTensorPayloads->emplace_back(std::move(*tensor));
        *tensor = std::move(fullTensor);
      } else {
        // pad with 0
        VLOG(1) << "Loading and padding " << tensorProto.name()
                << ", Glow legalized name " << glowLegalizedName
                << " as a partial tensor: partial size="
                << tensor->getType().toString()
                << " full size=" << fullType.toString();
        // Allocate a full-size tensor, copy the partial payload into its
        // prefix and zero the remainder.
        Tensor fullTensor(&fullType);
        std::memcpy(fullTensor.getUnsafePtr(), tensor->getUnsafePtr(),
                    tensor->getSizeInBytes());
        std::memset(fullTensor.getUnsafePtr() + tensor->getSizeInBytes(), 0,
                    fullTensor.getSizeInBytes() - tensor->getSizeInBytes());
        *tensor = std::move(fullTensor);
      }
    }
  }
}
661 | |
662 | void glow::fillPlaceholders(const std::string &fileName, |
663 | PlaceholderBindings *bindings, |
664 | std::vector<Tensor> *partialTensorPayloads, |
665 | bool usingGlowCustomOps) { |
666 | const ONNX_NAMESPACE::GraphProto &inputGroup = parseOnnxFile(fileName); |
667 | fillPlaceholders(inputGroup, bindings, partialTensorPayloads, |
668 | usingGlowCustomOps); |
669 | } |
670 | |
671 | /// Loads tensor \p T from the input \p in. |
672 | Error glow::loadTensor(const ONNX_NAMESPACE::TensorProto &in, Tensor *T, |
673 | bool useGlowCustomOps, const std::string &data) { |
674 | std::vector<dim_t> dim; |
675 | for (auto d : in.dims()) { |
676 | dim.push_back(d); |
677 | } |
678 | |
679 | ShapeVector strides; |
680 | if (in.has_doc_string()) { |
681 | strides = getStridesFromDocString(in.doc_string()); |
682 | } |
683 | |
684 | if (in.data_type() == ONNX_NAMESPACE::TensorProto::FLOAT) { |
685 | Type ty(ElemKind::FloatTy, dim); |
686 | T->reset(getTypeWithCustomStrides(ty, strides)); |
687 | |
688 | if (in.float_data_size() > 0) { |
689 | auto TH = T->getHandle<>(); |
690 | size_t i = 0; |
691 | for (auto f : in.float_data()) { |
692 | TH.raw(i++) = f; |
693 | } |
694 | } else if (in.has_raw_data() || !data.empty()) { |
695 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
696 | std::stringstream::binary); |
697 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(float)); |
698 | } else { |
699 | return MAKE_ERR("Unsupported Tensor format for FLOAT, name: " + in.name(), |
700 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
701 | } |
702 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::FLOAT16) { |
703 | Type ty(ElemKind::Float16Ty, dim); |
704 | T->reset(getTypeWithCustomStrides(ty, strides)); |
705 | if (in.has_raw_data() || !data.empty()) { |
706 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
707 | std::stringstream::binary); |
708 | inStream.read(T->getUnsafePtr(), T->actualSize() * (sizeof(float) / 2)); |
709 | } else { |
710 | return MAKE_ERR("Unsupported Tensor format for FLOAT16, name: " + |
711 | in.name(), |
712 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
713 | } |
714 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::BFLOAT16) { |
715 | Type ty(ElemKind::BFloat16Ty, dim); |
716 | T->reset(getTypeWithCustomStrides(ty, strides)); |
717 | if (in.has_raw_data() || !data.empty()) { |
718 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
719 | std::stringstream::binary); |
720 | inStream.read(T->getUnsafePtr(), T->actualSize() * (sizeof(float) / 2)); |
721 | } else { |
722 | return MAKE_ERR("Unsupported Tensor format for BFLOAT16, name: " + |
723 | in.name(), |
724 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
725 | } |
726 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT64) { |
727 | Type ty(ElemKind::Int64ITy, dim); |
728 | T->reset(getTypeWithCustomStrides(ty, strides)); |
729 | |
730 | if (in.int64_data_size() > 0) { |
731 | auto TH = T->getHandle<int64_t>(); |
732 | size_t i = 0; |
733 | for (auto f : in.int64_data()) { |
734 | TH.raw(i++) = f; |
735 | } |
736 | } else if (in.has_raw_data() || !data.empty()) { |
737 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
738 | std::stringstream::binary); |
739 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int64_t)); |
740 | } else { |
741 | return MAKE_ERR("Unsupported Tensor format for INT64, name: " + in.name(), |
742 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
743 | } |
744 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT8) { |
745 | if (in.has_doc_string()) { |
746 | Type ty; |
747 | ASSIGN_VALUE_OR_RETURN_ERR( |
748 | ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps)); |
749 | T->reset(ty); |
750 | } else { |
751 | // Onnx uses Int8 data type through operators like QuantizeLinear. |
752 | // Also data is passed through raw_data since TensorProto |
753 | // does not have data field for int8_t or uint8_t. |
754 | // scale is set to 1 and offset is set to 0 since both scale and offset |
755 | // themselves are operators inputs. |
756 | Type ty(ElemKind::Int8QTy, dim, 1 /* scale*/, 0 /* offset*/); |
757 | T->reset(getTypeWithCustomStrides(ty, strides)); |
758 | } |
759 | if (in.has_raw_data() || !data.empty()) { |
760 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
761 | std::stringstream::binary); |
762 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int8_t)); |
763 | } else { |
764 | return MAKE_ERR("Unsupported Tensor format for INT8, name: " + in.name(), |
765 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
766 | } |
767 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT16) { |
768 | Type ty; |
769 | ASSIGN_VALUE_OR_RETURN_ERR( |
770 | ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps)); |
771 | T->reset(ty); |
772 | |
773 | if (in.has_raw_data() || !data.empty()) { |
774 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
775 | std::stringstream::binary); |
776 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int16_t)); |
777 | } else { |
778 | return MAKE_ERR("Unsupported Tensor format for INT16, name: " + in.name(), |
779 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
780 | } |
781 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT32) { |
782 | if (in.has_doc_string()) { |
783 | Type ty; |
784 | ASSIGN_VALUE_OR_RETURN_ERR( |
785 | ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps)); |
786 | T->reset(ty); |
787 | } else { |
788 | // There are few cases when we will have int32 tensors. For example, the |
789 | // second output of Concat from Caffe2 concat op is int32 |
790 | Type ty(ElemKind::Int32ITy, dim); |
791 | T->reset(getTypeWithCustomStrides(ty, strides)); |
792 | } |
793 | |
794 | if (in.int32_data_size() > 0) { |
795 | auto TH = T->getHandle<int32_t>(); |
796 | size_t i = 0; |
797 | for (auto f : in.int32_data()) { |
798 | TH.raw(i++) = f; |
799 | } |
800 | } else if (in.has_raw_data() || !data.empty()) { |
801 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
802 | std::stringstream::binary); |
803 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(int32_t)); |
804 | } else { |
805 | return MAKE_ERR("Unsupported Tensor format for INT32, name: " + in.name(), |
806 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
807 | } |
808 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::UINT8) { |
809 | if (in.has_doc_string()) { |
810 | Type ty; |
811 | ASSIGN_VALUE_OR_RETURN_ERR( |
812 | ty, parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps)); |
813 | T->reset(ty); |
814 | } else { |
815 | // Onnx uses Int8 data type through operators like QuantizeLinear. |
816 | // Also data is passed through raw_data since TensorProto |
817 | // does not have data field for int8_t or uint8_t. |
818 | // scale is set to 1 and offset is set to 0 since both scale and offset |
819 | // themselves are operators inputs. |
820 | Type ty(ElemKind::UInt8QTy, dim, 1 /* scale*/, 0 /* offset*/); |
821 | T->reset(getTypeWithCustomStrides(ty, strides)); |
822 | } |
823 | |
824 | if (in.has_raw_data() || !data.empty()) { |
825 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
826 | std::stringstream::binary); |
827 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(uint8_t)); |
828 | } else { |
829 | return MAKE_ERR("Unsupported Tensor format for UINT8, name: " + in.name(), |
830 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
831 | } |
832 | } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::BOOL) { |
833 | Type ty(ElemKind::BoolTy, dim); |
834 | T->reset(getTypeWithCustomStrides(ty, strides)); |
835 | if (in.has_raw_data() || !data.empty()) { |
836 | std::istringstream inStream(data.empty() ? in.raw_data() : data, |
837 | std::stringstream::binary); |
838 | inStream.read(T->getUnsafePtr(), T->actualSize() * sizeof(bool)); |
839 | } else if (in.int32_data_size() > 0) { |
840 | // Some ONNX models use int32_data to initialize bool type (e.g., when |
841 | // converted from Keras). |
842 | auto TH = T->getHandle<bool>(); |
843 | size_t i = 0; |
844 | for (auto f : in.int32_data()) { |
845 | TH.raw(i++) = (bool)f; |
846 | } |
847 | } else { |
848 | return MAKE_ERR("Unsupported Tensor format for BOOL, name: " + in.name(), |
849 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
850 | } |
851 | } else { |
852 | return MAKE_ERR(strFormat("Unsupported tensor data type: %u" , |
853 | static_cast<unsigned>(in.data_type())), |
854 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
855 | } |
856 | return Error::success(); |
857 | } |
858 | |
859 | Expected<Type> |
860 | ONNXModelLoader::getTensorType(const ONNX_NAMESPACE::TensorProto &in) { |
861 | std::vector<dim_t> dim; |
862 | for (auto d : in.dims()) { |
863 | dim.push_back(d); |
864 | } |
865 | |
866 | switch (in.data_type()) { |
867 | case ONNX_NAMESPACE::TensorProto::FLOAT: |
868 | return Type(ElemKind::FloatTy, dim); |
869 | |
870 | case ONNX_NAMESPACE::TensorProto::FLOAT16: |
871 | return Type(ElemKind::Float16Ty, dim); |
872 | |
873 | case ONNX_NAMESPACE::TensorProto::BFLOAT16: |
874 | return Type(ElemKind::BFloat16Ty, dim); |
875 | |
876 | case ONNX_NAMESPACE::TensorProto::INT64: |
877 | return Type(ElemKind::Int64ITy, dim); |
878 | |
879 | case ONNX_NAMESPACE::TensorProto::UINT8: |
880 | case ONNX_NAMESPACE::TensorProto::INT8: |
881 | case ONNX_NAMESPACE::TensorProto::INT16: |
882 | return parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps_); |
883 | |
884 | case ONNX_NAMESPACE::TensorProto::INT32: |
885 | if (in.has_doc_string()) { |
886 | return parseTypeFromDocString(in.doc_string(), dim, useGlowCustomOps_); |
887 | } |
888 | return Type(ElemKind::Int32ITy, dim); |
889 | |
890 | case ONNX_NAMESPACE::TensorProto::BOOL: |
891 | return Type(ElemKind::BoolTy, dim); |
892 | |
893 | default: |
894 | return MAKE_ERR(strFormat("Unsupported tensor data type: %u" , |
895 | static_cast<unsigned>(in.data_type())), |
896 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
897 | } |
898 | llvm_unreachable("Unsupported tensor data type" ); |
899 | } |
900 | |
901 | Expected<Type> |
902 | ONNXModelLoader::getTensorType(const ONNX_NAMESPACE::ValueInfoProto &in) { |
903 | auto type = in.type(); |
904 | |
905 | std::vector<dim_t> dim; |
906 | ASSIGN_VALUE_OR_RETURN_ERR(dim, getProtoShape(type.tensor_type().shape())); |
907 | |
908 | ElemKind kind = ElemKind::FloatTy; |
909 | float scale = 1.0; |
910 | int32_t offset = 0; |
911 | ShapeVector strides; |
912 | if (useGlowCustomOps_) { |
913 | std::string elemKindStr; |
914 | ASSIGN_VALUE_OR_RETURN_ERR( |
915 | elemKindStr, getAttrFromDocString(elemKindSignifier, in.doc_string())); |
916 | kind = Type::getElementKindFromName(elemKindStr); |
917 | if (isQuantizedElemKind(kind)) { |
918 | std::pair<float, int32_t> scaleOffsetPair; |
919 | ASSIGN_VALUE_OR_RETURN_ERR(scaleOffsetPair, |
920 | getQuantParamsFromDocString(in.doc_string())); |
921 | std::tie(scale, offset) = scaleOffsetPair; |
922 | } |
923 | strides = getStridesFromDocString(in.doc_string()); |
924 | } else { |
925 | // Retrieve the ElemKind from the ONNX type, including considerations for |
926 | // whether the datatype is quantized. |
927 | RETURN_IF_ERR( |
928 | onnxTensorDataTypeToElemKind(type.tensor_type().elem_type(), &kind)); |
929 | } |
930 | |
931 | // If quantized then retrieve the scale and offset if provided (may not be for |
932 | // fused quantized types since they're ignored anyway). |
933 | if (isQuantizedElemKind(kind)) { |
934 | assert(useGlowCustomOps_ && |
935 | "Quantized loading not fully supported without custom Glow ops." ); |
936 | return getTypeWithCustomStrides(Type(kind, dim, scale, offset), strides); |
937 | } |
938 | return getTypeWithCustomStrides(Type(kind, dim), strides); |
939 | } |
940 | |
941 | Error ONNXModelLoader::getInputsNamesAndTypes( |
942 | std::vector<std::string> &inTensorNames, std::vector<Type> &inTypes, |
943 | const std::string &filename) { |
944 | // Creating GraphProto from modelfile |
945 | ONNX_NAMESPACE::ModelProto modelDef; |
946 | |
947 | ASSIGN_VALUE_OR_RETURN_ERR(modelDef, loadProto(filename, false, nullptr)); |
948 | |
949 | ONNX_NAMESPACE::GraphProto graphDef = modelDef.graph(); |
950 | |
951 | // GraphDef.input can have both inputs and intermediate tensors whereas |
952 | // initializers have info about intermediate tensors. Thus the difference |
953 | // betweem two is taken |
954 | std::vector<std::string> inputs; |
955 | for (auto &in : graphDef.input()) { |
956 | inputs.push_back(in.name()); |
957 | } |
958 | |
959 | for (const auto &in : graphDef.initializer()) { |
960 | auto position = std::find(inputs.begin(), inputs.end(), in.name()); |
961 | if (position != inputs.end()) { |
962 | inputs.erase(position); |
963 | } |
964 | } |
965 | |
966 | for (const auto &finalIn : inputs) { |
967 | for (const auto &in : graphDef.input()) { |
968 | if (finalIn.compare(in.name()) != 0) { |
969 | continue; |
970 | } |
971 | inTensorNames.push_back(in.name()); |
972 | const ONNX_NAMESPACE::ValueInfoProto &valueInfo = in; |
973 | auto type = valueInfo.type(); |
974 | |
975 | std::vector<dim_t> dim; |
976 | ASSIGN_VALUE_OR_RETURN_ERR(dim, |
977 | getProtoShape(type.tensor_type().shape())); |
978 | |
979 | ElemKind kind = ElemKind::FloatTy; |
980 | RETURN_IF_ERR( |
981 | onnxTensorDataTypeToElemKind(type.tensor_type().elem_type(), &kind)); |
982 | |
983 | inTypes.emplace_back(kind, dim); |
984 | } |
985 | } |
986 | |
987 | return Error::success(); |
988 | } |
989 | |
990 | Error ONNXModelLoader::verifyPreexistingStorage(const Storage *S, |
991 | const std::string &name, |
992 | const Type &ty, |
993 | const std::string &layout, |
994 | const bool trainable) { |
995 | RETURN_ERR_IF_NOT(S, "Storage did not exist in Module: " + name); |
996 | if (replaceDummyTQPs_ && ty.isQuantizedType() && |
997 | ty.getScale() == dummyScale) { |
998 | TensorQuantizationParams TQP; |
999 | ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(ty.getOffset())); |
1000 | // If we are replacing dummy TQPs with updated ones, then do verification |
1001 | // based on the updated type and not the base dummy type we found. |
1002 | Type updatedTy(ty.getElementType(), ty.dims(), TQP.scale, TQP.offset); |
1003 | RETURN_ERR_IF_NOT(S->getType()->isEqual(updatedTy), |
1004 | "Incorrect type for quant param updated existing " + |
1005 | S->getDebugDesc() + " " + "Expected type " + |
1006 | updatedTy.toString()); |
1007 | } else { |
1008 | RETURN_ERR_IF_NOT(S->getType()->isEqual(ty), |
1009 | "Incorrect type for existing " + S->getDebugDesc() + |
1010 | " " + "Expected type " + ty.toString()); |
1011 | } |
1012 | if (const Placeholder *PH = llvm::dyn_cast<Placeholder>(S)) { |
1013 | RETURN_ERR_IF_NOT(trainable == PH->isTraining(), |
1014 | "Incorrect trainability for existing Storage " + name); |
1015 | } |
1016 | RETURN_ERR_IF_NOT(layout == S->getLayout(), |
1017 | "Incorrect layout for existing Storage " + name); |
1018 | return Error::success(); |
1019 | } |
1020 | |
/// Loads the graph inputs of \p net: each non-weight input becomes either a
/// Placeholder (when \p loadInputsAsPlaceholdersForOnnx) or an empty Constant.
/// When loading into an existing Module, preexisting Placeholders are verified
/// against the expected type instead of being recreated.
Error ONNXModelLoader::loadInputs(ONNX_NAMESPACE::GraphProto &net,
                                  bool loadInputsAsPlaceholdersForOnnx) {
  for (const auto &in : net.input()) {
    // Skip static weights.
    if (getConstantByNameOrNull(in.name())) {
      continue;
    }

    const std::string &docString = in.doc_string();

    Type ty;
    ASSIGN_VALUE_OR_RETURN_ERR(ty, getTensorType(in));

    // Dummy quantization params (marked by dummyScale) are replaced with real
    // ones; the dummy type's offset is used as the lookup key for the update.
    if (replaceDummyTQPs_ && ty.isQuantizedType() &&
        ty.getScale() == dummyScale) {
      TensorQuantizationParams TQP;
      ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(ty.getOffset()));
      ty = Type(ty.getElementType(), ty.dims(), TQP.scale, TQP.offset);
    }

    // Trainability and layout are serialized into the doc string.
    std::pair<bool, std::string> trainableLayoutPair;
    ASSIGN_VALUE_OR_RETURN_ERR(
        trainableLayoutPair,
        getTrainableLayoutPairFromDocString(docString, useGlowCustomOps_));

    // If we already have the existing module then we may already have the input
    // Placeholder. If so, verify it has the correct type.
    if (loadIntoExistingModule_) {
      RETURN_ERR_IF_NOT(
          loadInputsAsPlaceholdersForOnnx,
          "Must load inputs as Placeholders when using existing Module.");
      if (Placeholder *PH = mod_.getPlaceholderByNameSlow(in.name())) {
        // Set Fused types of Placeholders if they were expected to be
        // fused. Necessary because Caffe2/ONNX protos do not have fused types
        // explicitly, so will be loaded initially as int8.
        if (ty.isFusedQuantizedType()) {
          RETURN_IF_ERR(setFusedTy(PH, mod_.uniqueType(ty)));
        }
        RETURN_IF_ERR(verifyPreexistingStorage(PH, in.name(), ty,
                                               trainableLayoutPair.second,
                                               trainableLayoutPair.first));
        nodeValueByName_[in.name()] = PH->getOutput();
        continue;
      }
    }

    // We must not have the input created yet, so do so.
    if (loadInputsAsPlaceholdersForOnnx) {
      RETURN_ERR_IF_NOT(!clipQuantRangeToFP16_ || !ty.isQuantizedType() ||
                            ty.isFusedQuantizedType(),
                        "Do not support clipQuantRangeToFP16 with unfused "
                        "quantized input Placeholders: " +
                            in.name());
      Placeholder *inPH;
      ASSIGN_VALUE_OR_RETURN_ERR(
          inPH, createAndRegisterPlaceholder(in.name(), mod_.uniqueType(ty),
                                             staticInputs_.count(in.name()),
                                             trainableLayoutPair.first,
                                             trainableLayoutPair.second));
      // Register the placeholder under its loader name from the doc string if
      // one is present, otherwise under the ONNX input name.
      auto loaderNameOrErr =
          getAttrFromDocString(loaderNameSignifier, docString);
      const std::string &loaderName =
          !ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)
              ? loaderNameOrErr.get()
              : in.name();
      RETURN_ERR_IF_NOT(inputVarsByName_.try_emplace(loaderName, inPH).second,
                        "Already had input placeholder by name " + loaderName);
    } else {
      // Not loading as Placeholders: register an empty Constant of the
      // resolved type instead.
      Tensor T(ty);
      RETURN_IF_ERR(createAndRegisterConstant(in.name(), std::move(T)));
    }
  }
  return Error::success();
}
1095 | |
1096 | Expected<bool> ONNXModelLoader::getBroadcast(ArgumentDictionaryTy &dict) { |
1097 | // Starting with opset 7, broadcasting is implicit and doesn't require any |
1098 | // attribute. |
1099 | if (opsetVersion_ > 6) { |
1100 | return true; |
1101 | } |
1102 | if (!dict.count("broadcast" )) { |
1103 | return false; |
1104 | } |
1105 | |
1106 | int broadcast; |
1107 | ASSIGN_VALUE_OR_RETURN_ERR(broadcast, loadInt(dict.at("broadcast" ))); |
1108 | return broadcast == 1; |
1109 | } |
1110 | |
1111 | bool ONNXModelLoader::hasMultidirectionalBroadcast( |
1112 | const llvm::StringRef typeName) { |
1113 | // Before opset 7, broadcasting was unidirectional. |
1114 | if (opsetVersion_ > 6) { |
1115 | // List of ops that support multidirectional broadcast can be found at |
1116 | // https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md |
1117 | if ((typeName == "Add" ) || (typeName == "Sub" ) || (typeName == "Mul" ) || |
1118 | (typeName == "Div" ) || (typeName == "Equal" ) || |
1119 | (typeName == "Greater" ) || (typeName == "Less" ) || |
1120 | (typeName == "Max" ) || (typeName == "Mean" ) || (typeName == "Min" ) || |
1121 | (typeName == "Or" ) || (typeName == "Pow" ) || (typeName == "Sum" ) || |
1122 | (typeName == "Xor" )) { |
1123 | return true; |
1124 | } |
1125 | } |
1126 | return false; |
1127 | } |
1128 | |
1129 | Expected<ElemKind> ONNXModelLoader::convertTensorProtoDataType( |
1130 | ONNX_NAMESPACE::TensorProto_DataType t) { |
1131 | switch (t) { |
1132 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: |
1133 | return ElemKind::FloatTy; |
1134 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: |
1135 | return ElemKind::Float16Ty; |
1136 | case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: |
1137 | return ElemKind::BFloat16Ty; |
1138 | case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: |
1139 | return ElemKind::Float64Ty; |
1140 | case ONNX_NAMESPACE::TensorProto_DataType_INT32: |
1141 | return ElemKind::Int32ITy; |
1142 | case ONNX_NAMESPACE::TensorProto_DataType_INT64: |
1143 | return ElemKind::Int64ITy; |
1144 | case ONNX_NAMESPACE::TensorProto_DataType_BOOL: |
1145 | return ElemKind::BoolTy; |
1146 | default:; |
1147 | } |
1148 | return MAKE_ERR("Non supported ONNX type" ); |
1149 | } |
1150 | |
1151 | Error ONNXModelLoader::setVersion(ONNX_NAMESPACE::ModelProto MP) { |
1152 | irVersion_ = MP.ir_version(); |
1153 | opsetVersion_ = 0; |
1154 | RETURN_ERR_IF_NOT( |
1155 | irVersion_ >= 3, |
1156 | "This ONNX model with ir_version < 3 is too old to be supported." , |
1157 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ONNX_VERSION); |
1158 | for (const auto &imp : MP.opset_import()) { |
1159 | if (!imp.has_domain() || imp.domain() == "" ) { |
1160 | opsetVersion_ = imp.version(); |
1161 | break; |
1162 | } |
1163 | } |
1164 | RETURN_ERR_IF_NOT(opsetVersion_ > 0, |
1165 | "The opset of this ONNX model is not supported." ); |
1166 | return Error::success(); |
1167 | } |
1168 | |
/// Parses a binary-serialized ModelProto from \p iStream.
Expected<ONNX_NAMESPACE::ModelProto>
ONNXModelLoader::loadProto(google::protobuf::io::ZeroCopyInputStream &iStream) {
  // Construct and configure a Coded Input Stream
  google::protobuf::io::CodedInputStream codedStream(&iStream);

  // Don't warn about large file sizes.
  // Protobuf >= 3.2 dropped the secondary warning-threshold argument of
  // SetTotalBytesLimit, hence the version guard.
#if GOOGLE_PROTOBUF_VERSION >= 3002000
  codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE);
#else
  codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE);
#endif
  ONNX_NAMESPACE::ModelProto MP;
  bool parseNet = MP.ParseFromCodedStream(&codedStream);
  RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto",
                    ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);

  return MP;
}
1187 | |
1188 | Expected<ONNX_NAMESPACE::ModelProto> |
1189 | ONNXModelLoader::loadProto(const void *onnxModel, size_t onnxModelSize) { |
1190 | google::protobuf::io::ArrayInputStream arrayStream(onnxModel, onnxModelSize); |
1191 | return loadProto(arrayStream); |
1192 | } |
1193 | |
1194 | Expected<ONNX_NAMESPACE::ModelProto> |
1195 | ONNXModelLoader::loadProto(const std::string &filename, bool zipMode, |
1196 | const std::string *inputStringPtr) { |
1197 | /// A helper class to report parsing errors when loading model protos. |
1198 | class OnnxTxtErrCollector : public google::protobuf::io::ErrorCollector { |
1199 | public: |
1200 | OnnxTxtErrCollector() = default; |
1201 | ~OnnxTxtErrCollector() override = default; |
1202 | void AddError(int line, int column, const std::string &message) override { |
1203 | llvm::errs() << strFormat("ONNX parsing error at [%d, %d]: " , line, |
1204 | column) |
1205 | << message << "\n" ; |
1206 | } |
1207 | void AddWarning(int line, int column, const std::string &message) override { |
1208 | llvm::errs() << strFormat("ONNX parsing warning at [%d, %d]: " , line, |
1209 | column) |
1210 | << message << "\n" ; |
1211 | } |
1212 | }; |
1213 | |
1214 | // Create a parser object and attach an error collector to it. |
1215 | OnnxTxtErrCollector errorCollector; |
1216 | google::protobuf::TextFormat::Parser parser; |
1217 | parser.RecordErrorsTo(&errorCollector); |
1218 | bool parseNet; |
1219 | ONNX_NAMESPACE::ModelProto MP; |
1220 | |
1221 | if (zipMode) { |
1222 | RETURN_ERR_IF_NOT( |
1223 | inputStringPtr == nullptr, |
1224 | "OnnxModelLoader load from string for zip mode not supported" ); |
1225 | ZipReader zip(filename); |
1226 | std::string buffer = zip.getRecord("model" ); |
1227 | // Try to parse as a protocol buffer first. |
1228 | parseNet = MP.ParseFromString(buffer); |
1229 | // Try to parse a textual representation first. |
1230 | if (!parseNet) { |
1231 | // If it is not a protocol buffer, try to parse as a text format. |
1232 | parseNet = parser.ParseFromString(buffer, &MP); |
1233 | } |
1234 | if (!parseNet) { |
1235 | RETURN_ERR_IF_NOT(false, "Failed to parse ModelProto" , |
1236 | ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); |
1237 | } |
1238 | size_t numWeights = 0; |
1239 | auto numWeightsStr = zip.getRecord("weights" ); |
1240 | numWeights = atoi(numWeightsStr.c_str()); |
1241 | for (size_t i = 0; i < numWeights; ++i) { |
1242 | std::stringstream ss; |
1243 | ss << "weight_" << i; |
1244 | buffer = zip.getRecord(ss.str()); |
1245 | auto *t = MP.mutable_graph()->add_initializer(); |
1246 | t->ParseFromString(buffer); |
1247 | } |
1248 | return MP; |
1249 | } |
1250 | |
1251 | std::ifstream ff(filename, std::ios::in | std::ios::binary); |
1252 | if (inputStringPtr == nullptr) { |
1253 | RETURN_ERR_IF_NOT(ff, |
1254 | strFormat("Can't find the model or network files for %s." , |
1255 | filename.c_str()), |
1256 | ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); |
1257 | } |
1258 | |
1259 | // TODO: intend to find a way to reuse the following function later |
1260 | // for the text format onnx model: |
1261 | // bool ONNXModelLoader::loadProto(ONNX_NAMESPACE::GraphProto &net, |
1262 | // google::protobuf::io::ZeroCopyInputStream &iStream) |
1263 | if (filename.find(".onnxtxt" ) != std::string::npos) { |
1264 | if (inputStringPtr == nullptr) { |
1265 | // Reserve a reasonably big buffer to make sure that reading of models |
1266 | // with serialized constants is fast enough. |
1267 | const std::streamsize bufferSize = 1024 * 1024 * 16; |
1268 | std::vector<char> buffer(bufferSize); |
1269 | ff.rdbuf()->pubsetbuf(buffer.data(), bufferSize); |
1270 | std::stringstream ss; |
1271 | ss << ff.rdbuf(); |
1272 | std::string str = ss.str(); |
1273 | parseNet = parser.ParseFromString(str, &MP); |
1274 | } else { |
1275 | parseNet = parser.ParseFromString(*inputStringPtr, &MP); |
1276 | } |
1277 | |
1278 | RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto" , |
1279 | ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); |
1280 | return MP; |
1281 | } |
1282 | |
1283 | if (inputStringPtr != nullptr) { |
1284 | std::istringstream iss(*inputStringPtr); |
1285 | google::protobuf::io::IstreamInputStream stringStream(&iss); |
1286 | return loadProto(stringStream); |
1287 | } |
1288 | google::protobuf::io::IstreamInputStream fileStream(&ff); |
1289 | return loadProto(fileStream); |
1290 | } |
1291 | |
/// \returns the ceiling of \p val converted to datatype T: the truncated
/// value, bumped by one when a positive fractional part was discarded.
template <typename T> T ceil(float val) {
  const T truncated = static_cast<T>(val);
  return (val - truncated) > 0 ? truncated + 1 : truncated;
}
1296 | |
namespace {
/// Helper type for pads.
/// Element order (see getPads below): {top, left, bottom, right} for 2D
/// inputs, {near, top, left, far, bottom, right} for 3D inputs.
using Pads = std::vector<unsigned_t>;
} // namespace
1301 | |
/// Get the Pads value based on setting for auto_pad.
/// \p kdim : kernel sizes (HW)
/// \p sdim: stride sizes (HW)
/// \p idim: input sizes (HW)
/// \returns {top, left, bottom, right} for 2D inputs, or
/// {near, top, left, far, bottom, right} for 3D inputs.
Expected<Pads> getPads(ArgumentDictionaryTy &dict,
                       llvm::ArrayRef<unsigned_t> kdim,
                       llvm::ArrayRef<unsigned_t> sdim,
                       llvm::ArrayRef<unsigned_t> idim) {
  // TODO: ONNX spec disallows using "pads" and "auto_pad" together. However,
  // the implementation allows mixing them and onnxruntime gives pads priority.
  if (dict.count("pads")) {
    if (dict.at("pads")->ints_size() == 2) { // For maxPool1D
      // 1D case: expand the (left, right) pair to 2D pads with zero height
      // padding.
      return Pads({0, (unsigned_t)dict.at("pads")->ints(0), 0,
                   (unsigned_t)dict.at("pads")->ints(1)});
    }
    return getShape<unsigned_t>(dict["pads"]);
  }

  // Set the default zero pads. Number of pads is dependent on input dims
  auto zeroPads = [idim]() {
    if (idim.size() == 3) {
      return Pads({0, 0, 0, 0, 0, 0});
    } else {
      return Pads({0, 0, 0, 0});
    }
  };

  if (dict.count("auto_pad")) {
    std::string padStr;
    ASSIGN_VALUE_OR_RETURN_ERR(padStr, loadStr(dict.at("auto_pad")));
    if (padStr == "VALID") {
      // Return default value 0 for pads.
      return zeroPads();
    } else if (padStr == "SAME_UPPER" || padStr == "SAME_LOWER") {
      unsigned_t near, far, top, left, bottom, right;
      // From https://arxiv.org/pdf/1603.07285.pdf 2.4,
      // o = floor((i + 2*p - k)/s) + 1
      // Also, from https://github.com/onnx/onnx/blob/master/docs/Operators.md
      // output_spatial_shape[i] =
      //     ceil(input_spatial_shape[i] / strides_spatial_shape[i])
      // pad_shape[i] =
      //     (output_spatial_shape[i] - 1) * strides_spatial_shape[i]
      //         + kernel_spatial_shape[i] - input_spatial_shape[i]
      // Use the smallest padding possible out of the possible options.
      unsigned_t odim;
      if (idim.size() == 2) {
        // 2D case: compute the total padding per spatial dimension, then
        // split it between the two sides of each dimension.
        llvm::SmallVector<unsigned_t, 2> pdim(2); // Total Paddding, HW.
        for (size_t i = 0, e = pdim.size(); i < e; i++) {
          odim = ceil<unsigned_t>((float)idim[i] / (float)sdim[i]);
          pdim[i] = sdim[i] * (odim - 1) + kdim[i] - idim[i];
        }
        if (padStr == "SAME_UPPER") {
          // SAME_UPPPER: if odd number for pdim[i], use extra padding at the
          // end.
          top = pdim[0] / 2;
          bottom = top + (pdim[0] & 0x1);
          left = pdim[1] / 2;
          right = left + (pdim[1] & 0x1);
        } else {
          // SAME_LOWER: if odd number for pdim[i], use extra padding at the
          // beginning.
          bottom = pdim[0] / 2;
          top = bottom + (pdim[0] & 0x1);
          right = pdim[1] / 2;
          left = right + (pdim[1] & 0x1);
        }
        return Pads({top, left, bottom, right});
      } else if (idim.size() == 3) {
        // 3D case: same computation with a leading near/far (depth)
        // dimension.
        llvm::SmallVector<unsigned_t, 3> pdim(3); // Total Paddding, HW.
        for (size_t i = 0, e = pdim.size(); i < e; i++) {
          odim = ceil<unsigned_t>((float)idim[i] / (float)sdim[i]);
          pdim[i] = sdim[i] * (odim - 1) + kdim[i] - idim[i];
        }
        if (padStr == "SAME_UPPER") {
          // SAME_UPPPER: if odd number for pdim[i], use extra padding at the
          // end.
          near = pdim[0] / 2;
          far = near + (pdim[0] & 0x1);
          top = pdim[1] / 2;
          bottom = top + (pdim[1] & 0x1);
          left = pdim[2] / 2;
          right = left + (pdim[2] & 0x1);
        } else {
          // SAME_LOWER: if odd number for pdim[i], use extra padding at the
          // beginning.
          far = pdim[0] / 2;
          near = far + (pdim[0] & 0x1);
          bottom = pdim[1] / 2;
          top = bottom + (pdim[1] & 0x1);
          right = pdim[2] / 2;
          left = right + (pdim[2] & 0x1);
        }
        return Pads({near, top, left, far, bottom, right});
      } else {
        return MAKE_ERR("getPads only works for 2D or 3D");
      }
    } else if (padStr == "NOTSET") {
      // We use explicit pads (if not given we assume its all zeros).
      // Note: unreachable in practice since an explicit "pads" attribute is
      // already handled at the top of this function.
      if (dict.count("pads")) {
        if (dict.at("pads")->ints_size() == 2) { // For maxPool1D
          return Pads({0, (unsigned_t)dict.at("pads")->ints(0), 0,
                       (unsigned_t)dict.at("pads")->ints(1)});
        }
        return getShape<unsigned_t>(dict["pads"]);
      } else {
        return zeroPads();
      }
    }
    return MAKE_ERR("Only auto_pad==VALID, SAME_UPPER, SAME_LOWER and NOTSET "
                    "are supported");
  }
  // Return default value 0 for pads.
  return zeroPads();
}
1416 | |
1417 | /// Get the Pads value based on setting for auto_pad. |
1418 | /// \p kdim : kernel sizes (HW) |
1419 | /// \p sdim: stride sizes (HW) |
1420 | /// \p idim: input sizes (HW) |
1421 | static Expected<Pads> getConvTransposePadsfromOutput( |
1422 | ArgumentDictionaryTy &dict, llvm::ArrayRef<unsigned_t> kdim, |
1423 | llvm::ArrayRef<unsigned_t> sdim, llvm::ArrayRef<unsigned_t> dilation, |
1424 | llvm::ArrayRef<unsigned_t> idim, llvm::ArrayRef<unsigned_t> odim) { |
1425 | |
1426 | llvm::SmallVector<unsigned_t, 2> pdim(2); // Total Paddding, HW. |
1427 | for (size_t i = 0, e = pdim.size(); i < e; i++) { |
1428 | pdim[i] = sdim[i] * (idim[i] - 1) /* + output_padding[0]*/ + |
1429 | ((kdim[i] - 1) * dilation[i] + 1) - odim[i]; |
1430 | } |
1431 | |
1432 | unsigned_t top, left, bottom, right; |
1433 | |
1434 | if (dict.count("auto_pad" )) { |
1435 | std::string padStr; |
1436 | ASSIGN_VALUE_OR_RETURN_ERR(padStr, loadStr(dict.at("auto_pad" ))); |
1437 | if (padStr == "SAME_UPPER" ) { |
1438 | // SAME_UPPER ONNX formula: |
1439 | // if odd number for pdim[i], use extra padding at the end. |
1440 | // pads[start_i] = total_padding[i] - (total_padding[i]/2); |
1441 | // pads[end_i] = (total_padding[i]/2). |
1442 | top = pdim[0] - pdim[0] / 2; |
1443 | bottom = pdim[0] / 2; |
1444 | left = pdim[1] - pdim[1] / 2; |
1445 | right = pdim[1] / 2; |
1446 | return Pads({top, left, bottom, right}); |
1447 | } |
1448 | } |
1449 | // !SAME_UPPER ONNX formula: |
1450 | // pads[start_i] = total_padding[i]/2; |
1451 | // pads[end_i] = total_padding[i] - (total_padding[i]/2) |
1452 | top = pdim[0] / 2; |
1453 | bottom = pdim[0] - pdim[0] / 2; |
1454 | left = pdim[1] / 2; |
1455 | right = pdim[1] - pdim[1] / 2; |
1456 | return Pads({top, left, bottom, right}); |
1457 | } |
1458 | |
1459 | const std::string ONNXModelLoader::opErrMsg(const ONNX_NAMESPACE::NodeProto &op, |
1460 | const std::string &errMsg) { |
1461 | const std::string &opName = loadOperatorName(op); |
1462 | return strFormat(" [Operator-'%s', opset_version-%d, ir_version-%d] : %s " , |
1463 | opName.c_str(), int(opsetVersion_), int(irVersion_), |
1464 | errMsg.c_str()); |
1465 | } |
1466 | |
1467 | Error ONNXModelLoader::loadConstant(const ONNX_NAMESPACE::NodeProto &op, |
1468 | ArgumentDictionaryTy &dict) { |
1469 | /* |
1470 | output: "Parameter6" |
1471 | name: "Parameter6" |
1472 | op_type: "Constant" |
1473 | attribute { |
1474 | name: "value" |
1475 | t { |
1476 | dims: 8 |
1477 | data_type: FLOAT |
1478 | float_data: -0.161539719 |
1479 | float_data: -0.433835655 |
1480 | float_data: 0.091641359 |
1481 | float_data: -0.0168522168 |
1482 | float_data: -0.0650264397 |
1483 | float_data: -0.131737873 |
1484 | float_data: 0.0204175506 |
1485 | float_data: -0.121110231 |
1486 | } |
1487 | type: TENSOR |
1488 | } |
1489 | doc_string: "" |
1490 | domain: "" |
1491 | */ |
1492 | |
1493 | const auto &name = op.output(0); |
1494 | // If the tensor is pre-populated by the user of this class then we don't |
1495 | // need to allocate a new tensor. |
1496 | if (getConstantByNameOrNull(name)) { |
1497 | return Error::success(); |
1498 | } |
1499 | |
1500 | const auto &type = dict.at("value" )->type(); |
1501 | RETURN_ERR_IF_NOT((type == ONNX_NAMESPACE::AttributeProto::TENSOR || |
1502 | type == ONNX_NAMESPACE::AttributeProto::INTS), |
1503 | "Only Tensor type constants are supported." , |
1504 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); |
1505 | |
1506 | Tensor T; |
1507 | if (type == ONNX_NAMESPACE::AttributeProto::TENSOR) { |
1508 | RETURN_IF_ERR(loadTensor(dict.at("value" )->t(), &T, useGlowCustomOps_)); |
1509 | } else { |
1510 | std::vector<int64_t> ints; |
1511 | ASSIGN_VALUE_OR_RETURN_ERR(ints, getShape<int64_t>(dict["value" ])); |
1512 | T = Tensor(ElemKind::Int64ITy, {(dim_t)ints.size()}); |
1513 | auto TH = T.getHandle<int64_t>(); |
1514 | for (dim_t i = 0, e = ints.size(); i < e; ++i) { |
1515 | TH.at({i}) = ints[i]; |
1516 | } |
1517 | } |
1518 | RETURN_IF_ERR(createAndRegisterConstant(name, std::move(T))); |
1519 | |
1520 | return Error::success(); |
1521 | } |
1522 | |
1523 | /// Retrieves data from a constant Tensor and stores it in a vector. |
1524 | template <typename T, typename datatype = ssize_t> |
1525 | static void helperSetter(Constant *constT, std::vector<datatype> &vec) { |
1526 | auto constH = constT->getPayload().getHandle<T>(); |
1527 | for (dim_t i = 0; i < constH.size(); ++i) { |
1528 | vec.push_back(constH.at({i})); |
1529 | } |
1530 | } |
1531 | |
/// Typed implementation of the ONNX Range operator: registers a Constant
/// holding the values [start, limit) stepped by delta. On entry \p constT
/// holds the "start" input; the "limit" and "delta" inputs are fetched from
/// \p op here.
template <typename T>
Error ONNXModelLoader::getRange(const ONNX_NAMESPACE::NodeProto &op,
                                Constant *constT) {
  T start = constT->getPayload().getHandle<T>().raw(0);

  // Note: constT is re-pointed at each scalar input in turn.
  ASSIGN_VALUE_OR_RETURN_ERR(constT, getConstantByName(op.input(1)));
  T limit = constT->getPayload().getHandle<T>().raw(0);

  ASSIGN_VALUE_OR_RETURN_ERR(constT, getConstantByName(op.input(2)));
  T delta = constT->getPayload().getHandle<T>().raw(0);

  // Accumulate the range values; delta's sign must match the direction from
  // start to limit, otherwise the loop would never terminate.
  std::vector<T> rangeValues;
  if (limit > start) {
    RETURN_ERR_IF_NOT(delta > 0, "delta should be positive");
    auto i = start;
    while (i < limit) {
      rangeValues.push_back(i);
      i += delta;
    }
  } else if (limit < start) {
    RETURN_ERR_IF_NOT(delta < 0, "delta should be negative");
    auto i = start;
    while (i > limit) {
      rangeValues.push_back(i);
      i += delta;
    }
  } else {
    return MAKE_ERR("limit and start value should be different");
  }

  // constT now points at the "delta" input; its element type determines the
  // output type, which assumes all three inputs share the same element type.
  Tensor rangeTensor(constT->getElementType(),
                     {static_cast<unsigned int>(rangeValues.size())});
  rangeTensor.getHandle<T>() = rangeValues;
  RETURN_IF_ERR(
      createAndRegisterConstant(op.output(0), std::move(rangeTensor)));
  return Error::success();
}
1569 | |
1570 | Error ONNXModelLoader::loadRange(const ONNX_NAMESPACE::NodeProto &op, |
1571 | ArgumentDictionaryTy &dict) { |
1572 | Constant *constT; |
1573 | ASSIGN_VALUE_OR_RETURN_ERR(constT, getConstantByName(op.input(0))); |
1574 | auto glowType = constT->getElementType(); |
1575 | if (glowType == ElemKind::Int64ITy) { |
1576 | return getRange<int64_t>(op, constT); |
1577 | } else if (glowType == ElemKind::Int32ITy) { |
1578 | return getRange<int32_t>(op, constT); |
1579 | } else if (glowType == ElemKind::FloatTy) { |
1580 | return getRange<float>(op, constT); |
1581 | } else { |
1582 | return MAKE_ERR("Data type not supported" ); |
1583 | } |
1584 | } |
1585 | |
1586 | Error ONNXModelLoader::loadPRelu(const ONNX_NAMESPACE::NodeProto &op, |
1587 | ArgumentDictionaryTy &dict) { |
1588 | const std::string &opName = loadOperatorName(op); |
1589 | |
1590 | NodeValue in; |
1591 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1592 | |
1593 | NodeValue slope; |
1594 | ASSIGN_VALUE_OR_RETURN_ERR(slope, getNodeValueByName(op.input(1))); |
1595 | |
1596 | // Do broadcasting. |
1597 | auto targetDim = in.dims(); |
1598 | // Sets the axis of each inputs so that the trailing-most dimensions of |
1599 | // input tensors and the target shape are aligned. |
1600 | int axis = targetDim.size() - slope.dims().size(); |
1601 | auto *finalSlope = G_->createBroadcast(opName, slope, targetDim, axis); |
1602 | auto *R = G_->createPRELU(opName, in, finalSlope); |
1603 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
1604 | return Error::success(); |
1605 | } |
1606 | |
1607 | Error ONNXModelLoader::loadSlice(const ONNX_NAMESPACE::NodeProto &op, |
1608 | ArgumentDictionaryTy &dict) { |
1609 | const std::string &opName = loadOperatorName(op); |
1610 | NodeValue data; |
1611 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1612 | auto dims = data.dims(); |
1613 | auto numDims = dims.size(); |
1614 | |
1615 | std::vector<ssize_t> starts; |
1616 | std::vector<ssize_t> ends; |
1617 | // Attribute 'axes' is optional. |
1618 | std::vector<ssize_t> axes; |
1619 | if (this->opsetVersion_ >= 10) { |
1620 | Constant *startsC = getConstantByNameOrNull(op.input(1)); |
1621 | Constant *endsC = getConstantByNameOrNull(op.input(2)); |
1622 | |
1623 | RETURN_ERR_IF_NOT(startsC, opErrMsg(op, "Starts Tensor is not Constant." )); |
1624 | RETURN_ERR_IF_NOT(endsC, opErrMsg(op, "Ends Tensor is not Constant." )); |
1625 | |
1626 | if (startsC->getElementType() == ElemKind::Int64ITy) { |
1627 | helperSetter<int64_t>(startsC, starts); |
1628 | } else if (startsC->getElementType() == ElemKind::Int32ITy) { |
1629 | helperSetter<int32_t>(startsC, starts); |
1630 | } else { |
1631 | RETURN_ERR_IF_NOT( |
1632 | false, |
1633 | opErrMsg( |
1634 | op, |
1635 | strFormat("Slice Starts Tensor has unsupported type '%s' " , |
1636 | startsC->getType()->getElementName().str().c_str()))); |
1637 | } |
1638 | |
1639 | if (endsC->getElementType() == ElemKind::Int64ITy) { |
1640 | helperSetter<int64_t>(endsC, ends); |
1641 | } else if (endsC->getElementType() == ElemKind::Int32ITy) { |
1642 | helperSetter<int32_t>(endsC, ends); |
1643 | } else { |
1644 | RETURN_ERR_IF_NOT( |
1645 | false, |
1646 | opErrMsg( |
1647 | op, strFormat("Slice Ends Tensor has unsupported type '%s' " , |
1648 | endsC->getType()->getElementName().str().c_str()))); |
1649 | } |
1650 | |
1651 | if (op.input_size() > 3) { |
1652 | Constant *axesC = getConstantByNameOrNull(op.input(3)); |
1653 | |
1654 | RETURN_ERR_IF_NOT(axesC, opErrMsg(op, "Axes Tensor is not Constant." )); |
1655 | |
1656 | if (axesC->getElementType() == ElemKind::Int64ITy) { |
1657 | helperSetter<int64_t>(axesC, axes); |
1658 | } else if (axesC->getElementType() == ElemKind::Int32ITy) { |
1659 | helperSetter<int32_t>(axesC, axes); |
1660 | } else { |
1661 | RETURN_ERR_IF_NOT( |
1662 | false, |
1663 | opErrMsg( |
1664 | op, |
1665 | strFormat("Slice Axes Tensor has unsupported type '%s' " , |
1666 | axesC->getType()->getElementName().str().c_str()))); |
1667 | } |
1668 | |
1669 | if (op.input_size() > 4) { |
1670 | std::vector<ssize_t> step; |
1671 | Constant *stepC = getConstantByNameOrNull(op.input(4)); |
1672 | |
1673 | RETURN_ERR_IF_NOT(stepC, opErrMsg(op, "Step tensor is not Constant." )); |
1674 | |
1675 | if (stepC->getElementType() == ElemKind::Int64ITy) { |
1676 | helperSetter<int64_t>(stepC, step); |
1677 | } else if (stepC->getElementType() == ElemKind::Int32ITy) { |
1678 | helperSetter<int32_t>(stepC, step); |
1679 | } else { |
1680 | RETURN_ERR_IF_NOT( |
1681 | false, |
1682 | opErrMsg( |
1683 | op, |
1684 | strFormat("Step Tensor has unsupported type '%s'" , |
1685 | stepC->getType()->getElementName().str().c_str()))); |
1686 | } |
1687 | |
1688 | // Step is interpreted 1 as default. |
1689 | for (size_t i = 0; i < step.size(); i++) { |
1690 | RETURN_ERR_IF_NOT(step[i] == 1, |
1691 | opErrMsg(op, "step!=1 is currently not supported" )); |
1692 | } |
1693 | } |
1694 | } |
1695 | } else { |
1696 | // Attributes 'starts' and 'ends' are mandatory and must be consistent. |
1697 | ASSIGN_VALUE_OR_RETURN_ERR(starts, getShape<ssize_t>(dict["starts" ])); |
1698 | ASSIGN_VALUE_OR_RETURN_ERR(ends, getShape<ssize_t>(dict["ends" ])); |
1699 | |
1700 | if (dict.count("axes" )) { |
1701 | // The ONNX spec is unclear so we consider that the 'axes' array may have |
1702 | // any size. The constraints are: |
1703 | // - the element value must be in range [0, numDims), |
1704 | // - 'starts' & 'ends' arrays must have the same size as the 'axes' array. |
1705 | // In case an axis is specified multiple times in 'axes', the later |
1706 | // parameters will simply overwrite the previous ones. |
1707 | ASSIGN_VALUE_OR_RETURN_ERR( |
1708 | axes, loadAxes<ssize_t>(dict["axes" ], data.dims().size())); |
1709 | } |
1710 | } |
1711 | RETURN_ERR_IF_NOT( |
1712 | (starts.size() == ends.size()), |
1713 | opErrMsg( |
1714 | op, |
1715 | strFormat("Slice: 'starts' and 'ends' arrays must have the same size." |
1716 | " but found starts %zu and ends %zu sizes " , |
1717 | starts.size(), ends.size()))); |
1718 | |
1719 | if (axes.empty()) { |
1720 | for (size_t i = 0; i < numDims; i++) { |
1721 | axes.push_back(ssize_t(i)); |
1722 | } |
1723 | } |
1724 | |
1725 | // The ONNX description is unclear and doesn't describe what to do when a |
1726 | // an axis index is not given in the axes array. An interpretation is that |
1727 | // for such an axis, the entire range is taken. Then, we initialize |
1728 | // newStarts and newEnds with the full range for all axes. |
1729 | std::vector<dim_t> newStarts(numDims); |
1730 | std::vector<dim_t> newEnds(numDims); |
1731 | for (size_t i = 0; i < numDims; i++) { |
1732 | newStarts[i] = 0; |
1733 | newEnds[i] = dims[i]; |
1734 | } |
1735 | |
1736 | // Determine the coordinates of the sub-tensor to extract. |
1737 | RETURN_ERR_IF_NOT(axes.size() == starts.size(), |
1738 | opErrMsg(op, strFormat("'axes' %zu and 'starts' %zu must" |
1739 | "be the same size." , |
1740 | axes.size(), starts.size()))); |
1741 | for (size_t i = 0; i < axes.size(); i++) { |
1742 | ssize_t newStart = starts[i]; |
1743 | ssize_t newEnd = ends[i]; |
1744 | ssize_t axisId = axes[i]; |
1745 | RETURN_ERR_IF_NOT( |
1746 | (axisId >= 0) && (axisId < ssize_t(numDims)), |
1747 | opErrMsg(op, "Axes indexes must be within the input tensor range." )); |
1748 | |
1749 | // ONNX: "If the value passed to start or end is larger than the n (the |
1750 | // number of elements in this dimension), it represents n". |
1751 | if (newStart > ssize_t(dims[axisId])) { |
1752 | newStart = ssize_t(dims[axisId]); |
1753 | } |
1754 | if (newEnd > ssize_t(dims[axisId])) { |
1755 | newEnd = ssize_t(dims[axisId]); |
1756 | } |
1757 | |
1758 | // The ONNX description is unclear and the numpy definition is more |
1759 | // accurate. |
1760 | // - ONNX: "Similar to numpy. [...]. If a negative value is passed for any |
1761 | // of the start or end indices, it represent number of elements before the |
1762 | // end of that dimension." |
1763 | // - Numpy: "Negative indices are interpreted as counting from the end of |
1764 | // the array (i.e., if n_i < 0, it means n_i + d_i)." |
1765 | if (newStart < 0) { |
1766 | newStart = ssize_t(dims[axisId]) + newStart; |
1767 | RETURN_ERR_IF_NOT( |
1768 | newStart >= 0, |
1769 | opErrMsg(op, strFormat("Slice: final start index %zu should " |
1770 | " never be negative." , |
1771 | newStart))); |
1772 | } |
1773 | if (newEnd < 0) { |
1774 | newEnd = ssize_t(dims[axisId]) + newEnd; |
1775 | RETURN_ERR_IF_NOT( |
1776 | newEnd >= 0, |
1777 | opErrMsg(op, strFormat("Slice: final end index %zu should " |
1778 | " never be negative." , |
1779 | newEnd))); |
1780 | } |
1781 | |
1782 | newStarts[axisId] = size_t(newStart); |
1783 | newEnds[axisId] = size_t(newEnd); |
1784 | } |
1785 | |
1786 | // Create the IR node. |
1787 | Node *SN = G_->createSlice(opName, data, newStarts, newEnds); |
1788 | RETURN_IF_ERR(addNodeAsOutput(op, SN)); |
1789 | |
1790 | return Error::success(); |
1791 | } |
1792 | |
1793 | Error ONNXModelLoader::loadTrigonometricOps(const std::string &typeName, |
1794 | const ONNX_NAMESPACE::NodeProto &op, |
1795 | ArgumentDictionaryTy &dict) { |
1796 | const std::string &opName = loadOperatorName(op); |
1797 | NodeValue in; |
1798 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1799 | Node *N; |
1800 | if (typeName == "Sin" ) { |
1801 | N = G_->createSin(opName, in); |
1802 | } else { |
1803 | N = G_->createCos(opName, in); |
1804 | } |
1805 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
1806 | return Error::success(); |
1807 | } |
1808 | |
1809 | Error ONNXModelLoader::loadErf(const ONNX_NAMESPACE::NodeProto &op, |
1810 | const ArgumentDictionaryTy &dict) { |
1811 | const std::string &opName = loadOperatorName(op); |
1812 | NodeValue in; |
1813 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1814 | Node *N = G_->createErf(opName, in); |
1815 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
1816 | return Error::success(); |
1817 | } |
1818 | |
1819 | Error ONNXModelLoader::loadConv(const ONNX_NAMESPACE::NodeProto &op, |
1820 | ArgumentDictionaryTy &dict) { |
1821 | // Load the inputs |
1822 | NodeValue in; |
1823 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1824 | |
1825 | if (in.dims().size() == 3) { |
1826 | return loadConv1D(op, dict); |
1827 | } else if (in.dims().size() == 4) { |
1828 | return loadConv2D(op, dict); |
1829 | } else if (in.dims().size() == 5) { |
1830 | return loadConv3D(op, dict); |
1831 | } else { |
1832 | return MAKE_ERR(strFormat( |
1833 | "Only 1D (3 dims), 2D (4 dims) and 3D (5 dims) convolution are " |
1834 | "supported by the ONNX loader but there are %zu input dims." , |
1835 | in.dims().size())); |
1836 | } |
1837 | } |
1838 | |
/// Loads an ONNX 1D Conv by lowering it to a 2D convolution: the input and
/// filter get a unit dimension inserted, a 2D conv runs, and the extra
/// dimension is squeezed away before transposing the result back to NCW.
Error ONNXModelLoader::loadConv1D(const ONNX_NAMESPACE::NodeProto &op,
                                  ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  // Load the attributes
  // The single ONNX stride goes in the W slot; the inserted H dim gets 1.
  std::vector<glow::unsigned_t> strides(2, 1);

  strides[1] = dict.count("strides") ? dict.at("strides")->ints(0) : 1;
  strides[0] = 1;

  unsigned_t group = 1;
  if (dict.count("group")) {
    ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));
  }

  std::vector<unsigned_t> dilations;
  ASSIGN_VALUE_OR_RETURN_ERR(dilations,
                             getDilations(dict, std::vector<unsigned_t>{1, 1}));
  // Expand dilations to length 2 since Glow treat conv1D as conv2D
  if (dilations.size() == 1) {
    dilations.push_back(dilations[0]);
  }

  // Load the inputs
  NodeValue in;
  // input == NCW ---> NCHW
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  in = G_->createExpandDims(opName, in, 2);
  // filtervalue == CKS ---> CKRS
  NodeValue filterValue;
  ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
  filterValue = G_->createExpandDims(opName, filterValue, 2);
  // Transpose the filter to the right format. Glow expects to read the
  // weights in the format CRSK. ONNX stores the operators as CKRS.
  // C - output_depth, R - filter_height, S - filter_width, K - input_depth.
  // filtervalue == CKRS ---> CRSK
  TransposeNode *filterTransposeNode =
      G_->createTranspose(opName, filterValue, NCHW2NHWC);
  // The structure of the conv weights is: CRSK. We take the C, which is the
  // number of filters. We use this value to calculate the size of the bias
  // if it is not specified.
  const NodeValue filterTransposedValue = filterTransposeNode->getResult();
  dim_t depth = filterTransposedValue.dims()[0];

  // Construct the Bias field.
  NodeValue B;
  // Check if we have a serialized bias vector.
  if (op.input_size() > 2) {
    auto &biasTensorName = op.input(2);
    // Load the serialized bias vector as NodeValue.
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName));
  }

  // If a serialized bias wasn't found then create a zero bias.
  if (op.input_size() == 2) {
    auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth});
    Tensor biasTensor(biasTy);
    biasTensor.zero();
    B = mod_.createConstant("conv.bias", std::move(biasTensor));
  }

  // ONNX passes the input as NCHW, and we expect the input to be NHWC.
  auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);
  // Calculate the size and allocate the output buffer.
  ShapeNHWC idim = ShapeNHWC(tr->getResult().dims());
  llvm::SmallVector<unsigned_t, 2> idimHW(2);
  // 'in' was expanded above, so dims()[2] is the inserted unit H dim and
  // dims()[3] is the original W dim.
  idimHW[0] = in.dims()[2];
  idimHW[1] = in.dims()[3];

  // Pads : {pad_top, pad_left, pad_bottom, pad_right}
  Pads pads;
  // Get the kernel shape.
  llvm::SmallVector<unsigned_t, 2> kernelShape(2);
  kernelShape[0] = filterTransposedValue.dims()[1];
  kernelShape[1] = filterTransposedValue.dims()[2];

  ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernelShape, strides, idimHW));
  auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernelShape, strides,
                                           pads, dilations);
  std::array<dim_t, 4> outDims = {{idim.n, outSz.first, outSz.second, depth}};
  auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims);
  auto *node = G_->createConv(opName, tr, filterTransposeNode, B, outTy,
                              kernelShape, strides, pads, group, dilations);

  // Squeeze out the unit H dimension that was inserted for the 2D lowering.
  auto *N = G_->createSqueeze(opName, node, 1 /*axes*/);
  // Transpose the output back
  auto *RR = G_->createTranspose(opName, N, {0, 2, 1});
  RETURN_IF_ERR(addNodeAsOutput(op, RR));
  return Error::success();
}
1928 | |
1929 | Error ONNXModelLoader::loadConv2D(const ONNX_NAMESPACE::NodeProto &op, |
1930 | ArgumentDictionaryTy &dict) { |
1931 | const std::string &opName = loadOperatorName(op); |
1932 | |
1933 | // Load the inputs |
1934 | NodeValue in; |
1935 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1936 | |
1937 | NodeValue filterValue; |
1938 | ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1))); |
1939 | |
1940 | // Load the attributes |
1941 | std::vector<unsigned_t> strides(2, 1); |
1942 | if (dict.count("strides" )) { |
1943 | ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides" ])); |
1944 | } |
1945 | unsigned_t group = 1; |
1946 | if (dict.count("group" )) { |
1947 | ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group" ))); |
1948 | } |
1949 | |
1950 | std::vector<unsigned_t> dilations; |
1951 | ASSIGN_VALUE_OR_RETURN_ERR(dilations, |
1952 | getDilations(dict, std::vector<unsigned_t>{1, 1})); |
1953 | RETURN_ERR_IF_NOT( |
1954 | dilations.size() == 2, |
1955 | opErrMsg(op, strFormat("2D Conv dilations must be specified for 2 axes " |
1956 | " found axes %zu" , |
1957 | dilations.size()))); |
1958 | |
1959 | // Transpose the filter to the right format. Glow expects to read the |
1960 | // weights in the format CRSK. ONNX stores the operators as KCRS. |
1961 | // C - output_depth, R - filter_height, S - filter_width, K - input_depth. |
1962 | TransposeNode *filterTransposeNode = |
1963 | G_->createTranspose(opName, filterValue, NCHW2NHWC); |
1964 | |
1965 | // The structure of the conv weights is: CRSK. We take the C, which is the |
1966 | // number of filters. We use this value to calculate the size of the bias |
1967 | // if it is not specified. |
1968 | const NodeValue filterTransposedValue = filterTransposeNode->getResult(); |
1969 | dim_t depth = filterTransposedValue.dims()[0]; |
1970 | |
1971 | // Get the kernel shape from the input. |
1972 | llvm::SmallVector<unsigned_t, 2> kernelShape(2); |
1973 | kernelShape[0] = filterTransposedValue.dims()[1]; |
1974 | kernelShape[1] = filterTransposedValue.dims()[2]; |
1975 | |
1976 | // Extra check when the 'kernel_shape' attribute exists. |
1977 | // The 'kernel_shape' attribute is redundant not mandatory. |
1978 | if (dict.count("kernel_shape" )) { |
1979 | std::vector<unsigned_t> kernelShapeAttribute; |
1980 | ASSIGN_VALUE_OR_RETURN_ERR(kernelShapeAttribute, |
1981 | getShape<unsigned_t>(dict["kernel_shape" ])); |
1982 | RETURN_ERR_IF_NOT((kernelShape[0] == kernelShapeAttribute[0] && |
1983 | kernelShape[1] == kernelShapeAttribute[1]), |
1984 | opErrMsg(op, "Conv The 'kernel_shape' attribute is not " |
1985 | "consistent with the actual " |
1986 | "convolution kernel shape." )); |
1987 | (void)kernelShapeAttribute; // Avoids compilation warning in release mode. |
1988 | } |
1989 | |
1990 | // Construct the Bias field. |
1991 | NodeValue B; |
1992 | // Check if we have a serialized bias vector. |
1993 | if (op.input_size() > 2) { |
1994 | auto &biasTensorName = op.input(2); |
1995 | // Load the serialized bias vector as NodeValue. |
1996 | ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName)); |
1997 | } |
1998 | |
1999 | // If a serialized bias wasn't found then create a zero bias. |
2000 | if (op.input_size() == 2) { |
2001 | auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth}); |
2002 | Tensor biasTensor(biasTy); |
2003 | biasTensor.zero(); |
2004 | B = mod_.createConstant("conv.bias" , std::move(biasTensor)); |
2005 | } |
2006 | |
2007 | // ONNX passes the input as NCHW, and we expect the input to be NHWC. |
2008 | auto *tr = G_->createTranspose(opName, in, NCHW2NHWC); |
2009 | |
2010 | // Calculate the size and allocate the output buffer. |
2011 | ShapeNHWC idim = ShapeNHWC(tr->getResult().dims()); |
2012 | |
2013 | llvm::SmallVector<unsigned_t, 2> idimHW(2); |
2014 | idimHW[0] = in.dims()[2]; |
2015 | idimHW[1] = in.dims()[3]; |
2016 | |
2017 | // Pads : {pad_top, pad_left, pad_bottom, pad_right} |
2018 | Pads pads; |
2019 | ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernelShape, strides, idimHW)); |
2020 | |
2021 | auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernelShape, strides, |
2022 | pads, dilations); |
2023 | std::array<dim_t, 4> outDims = {{idim.n, outSz.first, outSz.second, depth}}; |
2024 | auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims); |
2025 | |
2026 | auto *node = G_->createConv(opName, tr, filterTransposeNode, B, outTy, |
2027 | kernelShape, strides, pads, group, dilations); |
2028 | |
2029 | // Transpose the output back. |
2030 | auto *N = G_->createTranspose(opName, node, NHWC2NCHW); |
2031 | |
2032 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
2033 | |
2034 | return Error::success(); |
2035 | } |
2036 | |
2037 | Error ONNXModelLoader::loadConv3D(const ONNX_NAMESPACE::NodeProto &op, |
2038 | ArgumentDictionaryTy &dict) { |
2039 | const std::string &opName = loadOperatorName(op); |
2040 | |
2041 | // Load the inputs |
2042 | NodeValue in; |
2043 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2044 | |
2045 | NodeValue filterValue; |
2046 | ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1))); |
2047 | |
2048 | // Load the attributes |
2049 | std::vector<unsigned_t> strides(3, 1); |
2050 | if (dict.count("strides" )) { |
2051 | ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides" ])); |
2052 | } |
2053 | unsigned_t group = 1; |
2054 | if (dict.count("group" )) { |
2055 | ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group" ))); |
2056 | } |
2057 | |
2058 | std::vector<unsigned_t> dilations(3, 1); |
2059 | if (dict.count("dilations" )) { |
2060 | ASSIGN_VALUE_OR_RETURN_ERR( |
2061 | dilations, getDilations(dict, std::vector<unsigned_t>{1, 1, 1})); |
2062 | RETURN_ERR_IF_NOT( |
2063 | dilations.size() == 3, |
2064 | opErrMsg(op, strFormat("3D Conv dilations must be specified for 3 axes " |
2065 | "found %zu axes" , |
2066 | dilations.size()))); |
2067 | RETURN_ERR_IF_NOT( |
2068 | dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1, |
2069 | opErrMsg(op, strFormat("3D Conv dilations currently only support " |
2070 | "the default value of [1, 1, 1] but the model " |
2071 | "contains [%u, %u, %u]" , |
2072 | dilations[0], dilations[1], dilations[2]))); |
2073 | } |
2074 | |
2075 | // Transpose the filter to the right format. Glow expects to read the |
2076 | // weights in the format CRSK. ONNX stores the operators as KCRS. |
2077 | // C - output_depth, R - filter_height, S - filter_width, K - input_depth. |
2078 | TransposeNode *filterTransposeNode = |
2079 | G_->createTranspose(opName, filterValue, NCTHW2NTHWC); |
2080 | |
2081 | // The structure of the conv weights is: CRSK. We take the C, which is the |
2082 | // number of filters. We use this value to calculate the size of the bias |
2083 | // if it is not specified. |
2084 | const NodeValue filterTransposedValue = filterTransposeNode->getResult(); |
2085 | dim_t depth = filterTransposedValue.dims()[0]; |
2086 | |
2087 | // Get the kernel shape from the input. |
2088 | llvm::SmallVector<unsigned_t, 3> kernelShape(3); |
2089 | kernelShape[0] = filterTransposedValue.dims()[1]; |
2090 | kernelShape[1] = filterTransposedValue.dims()[2]; |
2091 | kernelShape[2] = filterTransposedValue.dims()[3]; |
2092 | |
2093 | // Extra check when the 'kernel_shape' attribute exists. |
2094 | // The 'kernel_shape' attribute is redundant not mandatory. |
2095 | if (dict.count("kernel_shape" )) { |
2096 | std::vector<unsigned_t> kernelShapeAttribute; |
2097 | ASSIGN_VALUE_OR_RETURN_ERR(kernelShapeAttribute, |
2098 | getShape<unsigned_t>(dict["kernel_shape" ])); |
2099 | RETURN_ERR_IF_NOT( |
2100 | (kernelShape[0] == kernelShapeAttribute[0] && |
2101 | kernelShape[1] == kernelShapeAttribute[1] && |
2102 | kernelShape[2] == kernelShapeAttribute[2]), |
2103 | opErrMsg( |
2104 | op, |
2105 | strFormat( |
2106 | "The 'kernel_shape' attribute [%d, %d, %d] is not consistent " |
2107 | "with the actual convolution kernel shape [%d, %d, %d]." , |
2108 | kernelShapeAttribute[0], kernelShapeAttribute[1], |
2109 | kernelShapeAttribute[2], kernelShape[0], kernelShape[1], |
2110 | kernelShape[2]))); |
2111 | (void)kernelShapeAttribute; // Avoids compilation warning in release mode. |
2112 | } |
2113 | |
2114 | // Construct the Bias field. |
2115 | NodeValue B; |
2116 | // Check if we have a serialized bias vector. |
2117 | if (op.input_size() > 2) { |
2118 | auto &biasTensorName = op.input(2); |
2119 | // Load the serialized bias vector as NodeValue. |
2120 | ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName)); |
2121 | } |
2122 | |
2123 | // If a serialized bias wasn't found then create a zero bias. |
2124 | if (op.input_size() == 2) { |
2125 | auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth}); |
2126 | Tensor biasTensor(biasTy); |
2127 | biasTensor.zero(); |
2128 | B = mod_.createConstant("conv.bias" , std::move(biasTensor)); |
2129 | } |
2130 | |
2131 | // ONNX passes the input as NCTHW, and we expect the input to be NTHWC. |
2132 | auto *tr = G_->createTranspose(opName, in, NCTHW2NTHWC); |
2133 | |
2134 | // Calculate the size and allocate the output buffer. |
2135 | ShapeNTHWC idim(tr->getResult().dims()); |
2136 | |
2137 | llvm::SmallVector<unsigned_t, 3> idimTHW(3); |
2138 | idimTHW = {static_cast<glow::unsigned_t>(idim.t), |
2139 | static_cast<glow::unsigned_t>(idim.h), |
2140 | static_cast<glow::unsigned_t>(idim.w)}; |
2141 | |
2142 | // Pads : {pad_near, pad_top, pad_left, pad_far, pad_bottom, pad_right} |
2143 | Pads tmpPads; |
2144 | ASSIGN_VALUE_OR_RETURN_ERR(tmpPads, |
2145 | getPads(dict, kernelShape, strides, idimTHW)); |
2146 | |
2147 | // Transpose padding from NTLFBR, which is the ONNX specified layout, to |
2148 | // NFTBLR, which is the glow internal layout per the interpreter |
2149 | // implementation. |
2150 | |
2151 | Pads pads = {tmpPads[0], tmpPads[3], tmpPads[1], |
2152 | tmpPads[4], tmpPads[2], tmpPads[5]}; |
2153 | |
2154 | auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, |
2155 | kernelShape, strides, pads); |
2156 | std::array<dim_t, 5> outDims = { |
2157 | {idim.n, outSz.temporal_frames, outSz.height, outSz.width, depth}}; |
2158 | auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims); |
2159 | |
2160 | auto *node = G_->createConv3D(opName, tr, filterTransposeNode, B, outTy, |
2161 | kernelShape, strides, pads, group); |
2162 | |
2163 | // Transpose the output back. |
2164 | auto *N = G_->createTranspose(opName, node, NTHWC2NCTHW); |
2165 | |
2166 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
2167 | |
2168 | return Error::success(); |
2169 | } |
2170 | |
2171 | Error ONNXModelLoader::loadTensorwiseQuantizedConvolution( |
2172 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
2173 | const std::string &opName = loadOperatorName(op); |
2174 | |
2175 | NodeValue input; |
2176 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
2177 | NodeValue filterValue; |
2178 | ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1))); |
2179 | NodeValue biasValue; |
2180 | ASSIGN_VALUE_OR_RETURN_ERR(biasValue, getNodeValueByName(op.input(2))); |
2181 | |
2182 | std::vector<unsigned_t> kernels; |
2183 | ASSIGN_VALUE_OR_RETURN_ERR(kernels, |
2184 | getShape<unsigned_t>(dict["kernel_shape" ])); |
2185 | std::vector<unsigned_t> strides; |
2186 | ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides" ])); |
2187 | std::vector<unsigned_t> pads; |
2188 | ASSIGN_VALUE_OR_RETURN_ERR(pads, getShape<unsigned_t>(dict["pads" ])); |
2189 | |
2190 | unsigned_t groups; |
2191 | ASSIGN_VALUE_OR_RETURN_ERR(groups, loadInt(dict.at("group" ))); |
2192 | |
2193 | float outScale; |
2194 | ASSIGN_VALUE_OR_RETURN_ERR(outScale, loadFloat(dict.at("out_scale" ))); |
2195 | int32_t outOffset; |
2196 | ASSIGN_VALUE_OR_RETURN_ERR(outOffset, loadInt(dict.at("out_offset" ))); |
2197 | |
2198 | ShapeNHWC idim(input.dims()); |
2199 | auto outSz = |
2200 | calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads); |
2201 | std::array<dim_t, 4> outDims = { |
2202 | {idim.n, outSz.first, outSz.second, biasValue.dims()[0]}}; |
2203 | auto outTy = mod_.uniqueType(ElemKind::Int8QTy, outDims, outScale, outOffset); |
2204 | |
2205 | auto *node = G_->createConv(opName, input, filterValue, biasValue, outTy, |
2206 | kernels, strides, pads, groups); |
2207 | |
2208 | return addNodeAsOutput(op, node); |
2209 | } |
2210 | |
/// Loads a channelwise-quantized convolution. Inputs are: activation, filter,
/// bias, per-channel scales and per-channel offsets. The filter is quantized
/// automatically if it is float; bias quantization is deferred to the backend.
Error ONNXModelLoader::loadChannelwiseQuantizedConvolution(
    const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
  NodeValue filterValue;
  ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));
  NodeValue biasValue;
  ASSIGN_VALUE_OR_RETURN_ERR(biasValue, getNodeValueByName(op.input(2)));
  NodeValue scalesValue;
  ASSIGN_VALUE_OR_RETURN_ERR(scalesValue, getNodeValueByName(op.input(3)));
  NodeValue offsetsValue;
  ASSIGN_VALUE_OR_RETURN_ERR(offsetsValue, getNodeValueByName(op.input(4)));

  // Convolution geometry attributes.
  std::vector<unsigned_t> kernels;
  ASSIGN_VALUE_OR_RETURN_ERR(kernels,
                             getShape<unsigned_t>(dict["kernel_shape"]));
  std::vector<unsigned_t> strides;
  ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides"]));
  std::vector<unsigned_t> pads;
  ASSIGN_VALUE_OR_RETURN_ERR(pads, getShape<unsigned_t>(dict["pads"]));

  unsigned_t group;
  ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group")));

  std::vector<unsigned_t> dilations;
  ASSIGN_VALUE_OR_RETURN_ERR(dilations,
                             getDilations(dict, std::vector<unsigned_t>{1, 1}));

  // Quantization parameters of the output tensor.
  float outScale;
  ASSIGN_VALUE_OR_RETURN_ERR(outScale, loadFloat(dict.at("out_scale")));
  int32_t outOffset;
  ASSIGN_VALUE_OR_RETURN_ERR(outOffset, loadInt(dict.at("out_offset")));

  // Compute the output shape (the input is interpreted as NHWC here); the
  // output channel count comes from the bias length.
  ShapeNHWC idim(input.dims());
  auto outSz =
      calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
  std::array<dim_t, 4> outDims = {
      {idim.n, outSz.first, outSz.second, biasValue.dims()[0]}};
  auto outTy = mod_.uniqueType(ElemKind::Int8QTy, outDims, outScale, outOffset);

  // Quantize the filter automatically (only if it is float). The bias is NOT
  // quantized automatically and is left at the disposal of each Backend to
  // quantize it later using custom logic.
  auto *node = G_->createChannelwiseQuantizedConv(
      opName, input, filterValue, biasValue, scalesValue, offsetsValue,
      /* biasScales */ nullptr, /* biasOffsets */ nullptr, outTy, kernels,
      strides, pads, group, dilations, /* quantizeFilter */ true,
      /* quantizeBias */ false);

  return addNodeAsOutput(op, node);
}
2264 | |
/// Loads an ONNX ConvTranspose operator.
/// Reads the optional 'strides', 'group', 'dilations', 'kernel_shape',
/// 'output_shape', 'pads' and 'output_padding' attributes from \p dict,
/// converts the filter and input from ONNX's NCHW layout to Glow's NHWC
/// layout, emits a ConvTranspose node, and transposes the result back to
/// NCHW before registering it as the output of \p op.
Error ONNXModelLoader::loadConvTranspose(const ONNX_NAMESPACE::NodeProto &op,
                                         ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  // Load the attributes
  std::vector<unsigned_t> strides(2, 1);
  if (dict.count("strides" )) {
    ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides" ]));
  }
  unsigned_t group = 1;
  if (dict.count("group" )) {
    ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict.at("group" )));
  }

  // Dilations default to {1, 1}; only 2D (two-axis) dilations are accepted.
  std::vector<unsigned_t> dilations;
  ASSIGN_VALUE_OR_RETURN_ERR(dilations,
                             getDilations(dict, std::vector<unsigned_t>{1, 1}));
  RETURN_ERR_IF_NOT(dilations.size() == 2,
                    opErrMsg(op, strFormat("ConvTranspose dilations must be "
                                           "specified for 2 axes, found %zu " ,
                                           dilations.size())));

  // Load the inputs
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  NodeValue filterValue;
  ASSIGN_VALUE_OR_RETURN_ERR(filterValue, getNodeValueByName(op.input(1)));

  // Transpose the filter to the right format. Glow expects to read the
  // weights in the format CRSK. ONNX stores the operators as KCRS.
  // C - output_depth, R - filter_height, S - filter_width, K - input_depth.
  TransposeNode *filterTransposeNode =
      G_->createTranspose(opName, filterValue, CNHW2NHWC /* flip matrix */);

  // The structure of the conv weigts is: NHWC. We take the C, which is the
  // number of filters. We use this value to calculate the size of the bias
  // if it is not specified.
  const NodeValue filterTransposedValue = filterTransposeNode->getResult();
  dim_t depth = filterTransposedValue.dims()[0] * group;

  // Get the kernel shape from the input.
  llvm::SmallVector<unsigned_t, 2> kernels(2);
  kernels[0] = filterTransposedValue.dims()[1];
  kernels[1] = filterTransposedValue.dims()[2];

  // Extra check when the 'kernel_shape' attribute exists.
  // The 'kernel_shape' attribute is redundant not mandatory.
  if (dict.count("kernel_shape" )) {
    std::vector<unsigned_t> kernelShapeAttribute;
    ASSIGN_VALUE_OR_RETURN_ERR(kernelShapeAttribute,
                               getShape<unsigned_t>(dict["kernel_shape" ]));
    RETURN_ERR_IF_NOT(
        (kernels[0] == kernelShapeAttribute[0] &&
         kernels[1] == kernelShapeAttribute[1]),
        opErrMsg(
            op,
            "The 'kernel_shape' attribute is not consistent with the actual "
            "convolution kernel shape." ));
    (void)kernelShapeAttribute; // Avoids compilation warning in release mode.
  }

  // Construct the Bias field.
  NodeValue B;
  // Check if we have a serialized bias vector.
  if (op.input_size() > 2) {
    auto &biasTensorName = op.input(2);
    // Load the serialized bias vector as constant or NodeValue.
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(biasTensorName));
  }

  // If a serialized bias wasn't found then create a zero bias.
  // The bias element type/quantization follows the input's type.
  if (op.input_size() == 2) {
    auto biasTy = mod_.uniqueTypeWithNewShape(in.getType(), {depth});
    Tensor biasTensor(biasTy);
    biasTensor.zero();
    B = mod_.createConstant("conv.bias" , std::move(biasTensor));
  }

  // ONNX passes the input as NCHW, and we expect the input to be NHWC.
  auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);

  // Calculate the size and allocate the output buffer.
  ShapeNHWC idim = ShapeNHWC(tr->getResult().dims());

  // Spatial sizes (H, W) of the original NCHW input, used for pad inference.
  llvm::SmallVector<unsigned_t, 2> idimHW(2);
  idimHW[0] = in.dims()[2];
  idimHW[1] = in.dims()[3];

  // Pads : {pad_top, pad_left, pad_bottom, pad_right}
  Pads pads;

  // Conv transpose output size (HxW) is either specified or calculated.
  std::pair<dim_t, dim_t> outSz;

  // Per spec, if output_shape is specified, pads are ignored.
  if (dict.count("output_shape" )) {
    std::vector<unsigned_t> outShape;
    ASSIGN_VALUE_OR_RETURN_ERR(outShape,
                               getShape<unsigned_t>(dict["output_shape" ]));
    // Pads are derived from the requested output shape, then cross-checked
    // below by recomputing the output dims from those pads.
    ASSIGN_VALUE_OR_RETURN_ERR(
        pads, getConvTransposePadsfromOutput(dict, kernels, strides, dilations,
                                             idimHW, outShape));
    outSz = {outShape[0], outShape[1]};

    std::pair<dim_t, dim_t> outSzTest = calculateConvTransposeOutputDims(
        idim.h, idim.w, kernels, strides, pads, dilations);
    RETURN_ERR_IF_NOT(
        (outShape[0] == outSzTest.first),
        opErrMsg(op, strFormat("ConvTranspose Expected %d /calculated %d "
                               "pads don't match " ,
                               int(outShape[0]), int(outSzTest.first))));
    RETURN_ERR_IF_NOT(
        (outShape[1] == outSzTest.second),
        opErrMsg(op, strFormat("ConvTranspose Expected %d /calculated %d "
                               "pads don't match " ,
                               int(outShape[1]), int(outSzTest.second))));
  } else {
    if (dict.count("output_padding" )) {
      std::vector<dim_t> outPad;
      ASSIGN_VALUE_OR_RETURN_ERR(outPad,
                                 getShape<dim_t>(dict["output_padding" ]));
      // NOTE(review): this only aborts when ALL output_padding values are
      // equal and non-zero; mixed non-zero values fall through silently and
      // are then ignored — confirm whether that is intended.
      if (std::equal(outPad.begin() + 1, outPad.end(), outPad.begin()) &&
          outPad[0] != 0) {
        LOG(FATAL)
            << "ConvTranspose argument 'output_padding' is not supported." ;
      }
    }
    ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernels, strides, idimHW));
    outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels, strides,
                                             pads, dilations);
  }
  std::array<dim_t, 4> outDims = {{idim.n, outSz.first, outSz.second, depth}};
  auto outTy = mod_.uniqueTypeWithNewShape(in.getType(), outDims);

  auto *node =
      G_->createConvTranspose(opName, tr, filterTransposeNode, B, outTy,
                              kernels, strides, pads, group, dilations);

  // Transpose the output back.
  auto *N = G_->createTranspose(opName, node, NHWC2NCHW);
  RETURN_IF_ERR(addNodeAsOutput(op, N));

  return Error::success();
}
2408 | |
/// Loads an ONNX MaxPool/AveragePool operator (\p typeName selects which).
/// Handles 3D (N, C, W) inputs as 1D pooling by inserting a singleton H axis,
/// and 4D (N, C, H, W) inputs as 2D pooling. The input is transposed to NHWC
/// for Glow's pool nodes and the result is transposed back afterwards.
/// A second declared output (Argmax) is only supported for MaxPool.
Error ONNXModelLoader::loadPool(const ONNX_NAMESPACE::NodeProto &op,
                                ArgumentDictionaryTy &dict,
                                llvm::StringRef typeName) {
  const std::string &opName = loadOperatorName(op);

  // Load the inputs:
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  std::vector<unsigned_t> strides(2, 1);

  size_t inDim = in.dims().size();

  std::vector<unsigned_t> kernelsShape;
  ASSIGN_VALUE_OR_RETURN_ERR(kernelsShape,
                             getShape<unsigned_t>(dict["kernel_shape" ]));

  size_t kerDim = kernelsShape.size();

  // Seed the 2D kernel as {1, W}: a 1D kernel is modeled as 1xW; for a true
  // 2D kernel the H component is overwritten below.
  std::vector<unsigned_t> kernels = {1, kernelsShape[kerDim - 1]};

  // 'count_include_pad' only affects the AvgPool path below.
  bool countIncludePads;
  ASSIGN_VALUE_OR_RETURN_ERR(
      countIncludePads, getCountIncludePads(dict, /* defaultValue */ false));

  // For maxPool1D inDim = 3
  if (inDim == 3) {
    // Insert a singleton H axis so 1D pooling runs as 2D pooling.
    in = G_->createExpandDims(opName, in, 2);
    if (kerDim != 1) {
      return MAKE_ERR(
          opErrMsg(op, strFormat("Glow handles 1D pooling with kernel "
                                 "dimenstion size 1, but found %d " ,
                                 int(kerDim))),
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
    } else {
      if (dict.count("strides" )) {
        // Only one stride is given for 1D pooling; apply it on the W axis.
        strides[1] = dict.at("strides" )->ints(0);
        strides[0] = 1;
      }
    }
  }

  if (kerDim == 2) { // For maxPool2D
    kernels[0] = kernelsShape[0];
    if (dict.count("strides" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(strides,
                                 getShape<unsigned_t>(dict["strides" ]));
    }
  }

  if (in.dims().size() != 4 || kernels.size() != 2) {
    // Glow only handles 2D pooling currently.
    return MAKE_ERR(opErrMsg(op, "Glow only handles 2D pooling currently." ),
                    ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
  }

  auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);

  // If 'global_pooling' is set then the operation will pool over the size of
  // the input by doing: kernel = height/width.
  if (dict.count("global_pooling" )) {
    auto Ty = in.getType();
    kernels[0] = Ty->dims()[2];
    kernels[1] = Ty->dims()[3];
  }

  // NHWC
  // NOTE(review): the values are read from dims()[2]/dims()[3] of the
  // pre-transpose input, i.e. H/W under NCHW (see the comment on the next
  // line) — the "NHWC" label above looks stale; confirm.
  llvm::SmallVector<unsigned_t, 2> idimHW(2);
  idimHW[0] = in.dims()[2]; // As per NCHW format
  idimHW[1] = in.dims()[3];

  Pads pads;
  ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernels, strides, idimHW));

  Node *node = nullptr;
  if (op.output_size() > 1) {
    // Two declared outputs: result + argmax, MaxPool only.
    if (typeName != "MaxPool" ) {
      return MAKE_ERR(
          opErrMsg(op, "Pool Argmax output is only supported for MaxPool!" ),
          ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
    }

    // Both results are transposed back to NCHW before being registered.
    node = G_->createMaxPool(opName, tr, kernels, strides, pads);
    auto *res = G_->createTranspose(opName, NodeValue(node, 0), NHWC2NCHW);
    auto *argmax = G_->createTranspose(opName, NodeValue(node, 1), NHWC2NCHW);
    RETURN_IF_ERR(assignNodeOutputs(op, {res, argmax}));
  } else {
    size_t idx = 0;
    if (typeName == "MaxPool" ) {
      node = G_->createMaxPool(opName, tr, kernels, strides, pads);
      idx = MaxPoolNode::ResultIdx;
    } else {
      node = G_->createAvgPool(opName, tr, kernels, strides, pads, NHWC,
                               countIncludePads);
      idx = AvgPoolNode::ResultIdx;
    }

    Node *N = nullptr;
    if (inDim == 3) { // For maxPool1D
      // Undo the earlier ExpandDims: drop the singleton H axis of the NHWC
      // result, then reorder (N, W, C) back to (N, C, W).
      auto *R = G_->createSqueeze(opName, NodeValue(node, idx), 1);
      N = G_->createTranspose(opName, R, {0, 2, 1});
    } else {
      N = G_->createTranspose(opName, NodeValue(node, idx), NHWC2NCHW);
    }

    RETURN_IF_ERR(addNodeAsOutput(op, N));
  }
  return Error::success();
}
2518 | |
2519 | Error ONNXModelLoader::loadTensorwiseQuantizedPool( |
2520 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict, |
2521 | llvm::StringRef typeName) { |
2522 | const std::string &opName = loadOperatorName(op); |
2523 | |
2524 | // Load the inputs: |
2525 | NodeValue in; |
2526 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2527 | |
2528 | std::vector<unsigned_t> kernels; |
2529 | ASSIGN_VALUE_OR_RETURN_ERR(kernels, |
2530 | getShape<unsigned_t>(dict["kernel_shape" ])); |
2531 | std::vector<unsigned_t> strides; |
2532 | ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides" ])); |
2533 | |
2534 | if (in.dims().size() != 4 || kernels.size() != 2) { |
2535 | // Glow only handles 2D pooling currently. |
2536 | return MAKE_ERR( |
2537 | opErrMsg(op, strFormat("TensorwiseQuantizedPool Glow only handles 2D " |
2538 | "pooling currently, but found kernel %zu " , |
2539 | kernels.size())), |
2540 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE); |
2541 | } |
2542 | |
2543 | bool countIncludePads; |
2544 | ASSIGN_VALUE_OR_RETURN_ERR( |
2545 | countIncludePads, getCountIncludePads(dict, /* defaultValue */ false)); |
2546 | |
2547 | // NHWC |
2548 | llvm::SmallVector<unsigned_t, 2> idimHW(2); |
2549 | idimHW[0] = in.dims()[1]; |
2550 | idimHW[1] = in.dims()[2]; |
2551 | |
2552 | Pads pads; |
2553 | ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict, kernels, strides, idimHW)); |
2554 | |
2555 | if (op.output_size() > 1) { |
2556 | if (typeName != "MaxPool" ) { |
2557 | return MAKE_ERR(opErrMsg(op, |
2558 | "TensorwiseQuantizedPool Argmax output is only " |
2559 | "supported for MaxPool!" ), |
2560 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR); |
2561 | } |
2562 | |
2563 | Node *maxpool = G_->createMaxPool(opName, in, kernels, strides, pads); |
2564 | auto res = maxpool->getNthResult(MaxPoolNode::ResultIdx); |
2565 | auto argmax = maxpool->getNthResult(MaxPoolNode::ArgmaxIdx); |
2566 | RETURN_IF_ERR(assignNodeOutputs(op, {res, argmax})); |
2567 | } else { |
2568 | Node *poolNode; |
2569 | if (typeName == "MaxPool" ) { |
2570 | poolNode = G_->createMaxPool(opName, in, kernels, strides, pads); |
2571 | } else { |
2572 | poolNode = G_->createAvgPool(opName, in, kernels, strides, pads, NHWC, |
2573 | countIncludePads); |
2574 | } |
2575 | RETURN_IF_ERR(addNodeAsOutput(op, poolNode)); |
2576 | } |
2577 | return Error::success(); |
2578 | } |
2579 | |
2580 | Error ONNXModelLoader::loadArgMinMax(const ONNX_NAMESPACE::NodeProto &op, |
2581 | ArgumentDictionaryTy &dict, bool isMin) { |
2582 | const std::string &opName = loadOperatorName(op); |
2583 | |
2584 | NodeValue in; |
2585 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2586 | size_t axis = 0; |
2587 | if (dict.count("axis" )) { |
2588 | ASSIGN_VALUE_OR_RETURN_ERR( |
2589 | axis, loadAxis<size_t>(dict.at("axis" ), in.dims().size())); |
2590 | } |
2591 | bool keepDims = true; |
2592 | if (dict.count("keepdims" )) { |
2593 | ASSIGN_VALUE_OR_RETURN_ERR(keepDims, loadInt(dict.at("keepdims" ))); |
2594 | } |
2595 | Node *node; |
2596 | if (isMin) { |
2597 | node = G_->createArgMin(opName, in, axis, keepDims); |
2598 | } else { |
2599 | node = G_->createArgMax(opName, in, axis, keepDims); |
2600 | } |
2601 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
2602 | return Error::success(); |
2603 | } |
2604 | |
/// Loads an ONNX Upsample operator (opsets 7-9 only; Resize replaces it in
/// opset 10+). Only "nearest" mode is supported. For opset 7 the scales come
/// from the 'scales' attribute; for opsets 8-9 from a float Constant second
/// input. 4D scales are handled via NCHW<->NHWC transposes, 5D scales via
/// NCTHW<->NTHWC transposes.
Error ONNXModelLoader::loadUpsample(const ONNX_NAMESPACE::NodeProto &op,
                                    ArgumentDictionaryTy &dict) {

  // NOTE(review): the two concatenated literals below render without a
  // separator between "9" and "use" — the message reads "...7 and 9use
  // resize...". Consider fixing the string in a follow-up.
  RETURN_ERR_IF_NOT(
      (opsetVersion_ < 10) && (opsetVersion_ > 6),
      opErrMsg(op, "Version mismatch issue found, Upsample operator is "
                   "supported for opset_version between 7 and 9"
                   "use resize operator if opset_version > 9" ));

  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  // Default mode of upsample operator is "nearest"
  std::string mode("nearest" );
  if (dict.count("mode" )) {
    ASSIGN_VALUE_OR_RETURN_ERR(mode, loadStr(dict.at("mode" )));
  }

  /// Only Nearest Mode is supported
  RETURN_ERR_IF_NOT(mode.compare("nearest" ) == 0,
                    opErrMsg(op, strFormat("UpSample Operator has nearest mode "
                                           "support only, found mode '%s' " ,
                                           mode.c_str())));

  /// Scale is always float as per onnx documentation
  std::vector<float> scales;

  // Opset 7 carries the scales as an attribute.
  if (opsetVersion_ == 7) {
    if (dict.count("scales" )) {
      /// As per onnx documentation this is a required field
      /// and if not present then onnx.checker.check_model file check to fail
      ASSIGN_VALUE_OR_RETURN_ERR(scales, getFloats(dict["scales" ]));
    } else {
      return MAKE_ERR(opErrMsg(op,
                               "UpSample Scales field is not present, expected "
                               "for Upsample opset_version==7" ));
    }
  }

  // Opsets 8-9 carry the scales as a second input, which must be a float
  // Constant at load time.
  if (opsetVersion_ > 7) {
    Constant *scale;
    ASSIGN_VALUE_OR_RETURN_ERR(scale, getConstantByName(op.input(1)));
    if (scale->getElementType() != ElemKind::FloatTy) {
      return MAKE_ERR(opErrMsg(op,
                               "UpSample Scales Tensor should have float type "
                               "for opset_version > 7" ));
    }
    auto constH = scale->getPayload().getHandle<float>();
    for (dim_t i = 0; i < constH.size(); ++i) {
      scales.push_back(constH.at({i}));
    }
  }

  /// Scales tensor format is NHWC for supported modes other than nearest.
  /// For nearest mode scale can be 4D or 5D.
  RETURN_ERR_IF_NOT(
      (scales.size() >= 4 && scales.size() <= 5 && mode == "nearest" ) ||
          scales.size() == 4,
      opErrMsg(
          op, strFormat(
                  "UpSample Scales dimension invalid. Mode: %s Scale Size: %zu" ,
                  mode.c_str(), scales.size())));

  // Downscaling (< 1) is not supported by this loader.
  for (auto &val : scales) {
    RETURN_ERR_IF_NOT(
        val >= 1,
        opErrMsg(op, strFormat("UpSample Scales value can only be "
                               " greater than or equal to 1, but found %d" ,
                               int(val))));
  }

  switch (scales.size()) {
  case 4: {
    // 4D: reorder scales to NHWC, resize in NHWC, transpose result back.
    vectorReorder(scales, {NHWC2NCHW});
    auto *intr = G_->createTranspose(opName, in, NCHW2NHWC);
    auto *node = G_->createResizeNearest(opName, intr, scales);
    auto *N = G_->createTranspose(opName, node, NHWC2NCHW);
    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }
  case 5: {
    // 5D: same scheme with the temporal layouts (NCTHW <-> NTHWC).
    vectorReorder(scales, {NTHWC2NCTHW});

    auto *intr = G_->createTranspose(opName, in, NCTHW2NTHWC);
    auto *node = G_->createResizeNearest(opName, intr, scales);
    auto *N = G_->createTranspose(opName, node, NTHWC2NCTHW);
    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }
  default:
    // Unreachable given the size check above; kept as a defensive error.
    RETURN_ERR_IF_NOT(
        false, opErrMsg(op, strFormat("UpSample Scales dimension invalid" )));
  }
}
2700 | |
2701 | Error ONNXModelLoader::loadResize(const ONNX_NAMESPACE::NodeProto &op, |
2702 | const ArgumentDictionaryTy &dict) { |
2703 | const std::string &opName = loadOperatorName(op); |
2704 | |
2705 | NodeValue in; |
2706 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2707 | |
2708 | std::string modeStr; |
2709 | ASSIGN_VALUE_OR_RETURN_ERR(modeStr, loadStr(dict.at("mode" ))); |
2710 | |
2711 | Constant *scalesC = nullptr; |
2712 | |
2713 | // Either scales or outDims will be populated (V11 can do either, V10 scales |
2714 | // only) |
2715 | std::vector<float> scales; |
2716 | std::vector<dim_t> outDims; |
2717 | |
2718 | int32_t scalesIdx = (this->opsetVersion_ >= 11) ? 2 : 1; |
2719 | scalesC = getConstantByNameOrNull(op.input(scalesIdx)); |
2720 | RETURN_ERR_IF_NOT( |
2721 | scalesC, |
2722 | opErrMsg(op, strFormat("Resize Scales Tensor '%s' is not Constant." , |
2723 | op.input(scalesIdx).c_str()))); |
2724 | if (scalesC->getElementType() != ElemKind::FloatTy) { |
2725 | return MAKE_ERR(opErrMsg( |
2726 | op, strFormat( |
2727 | "Resize Scales Tensor should have float type, but found '%s' " , |
2728 | scalesC->getType()->getElementName().str().c_str()))); |
2729 | } |
2730 | |
2731 | // For ONNX Resize v11, support attributes that are compatible with v10: |
2732 | // exclude_outside = 0 |
2733 | // extrapolation_value = 0.0 |
2734 | // nearest_mode = floor |
2735 | // coordinate_transformation_mode = asymmetric |
2736 | // mode = nearest, (bi)linear |
2737 | if (this->opsetVersion_ >= 11) { |
2738 | int32_t excludeOutside = 0; |
2739 | // attribute: exclude_outside. |
2740 | if (dict.count("exclude_outside" )) { |
2741 | ASSIGN_VALUE_OR_RETURN_ERR(excludeOutside, |
2742 | loadInt(dict.at("exclude_outside" ))); |
2743 | } |
2744 | RETURN_ERR_IF_NOT(excludeOutside == 0, |
2745 | opErrMsg(op, strFormat("ONNX Resize exclude outside " |
2746 | " not supported." ))); |
2747 | // attribute: extrapolation_value. |
2748 | float = 0.0; |
2749 | if (dict.count("extrapolation_value" )) { |
2750 | ASSIGN_VALUE_OR_RETURN_ERR(extrapolationValue, |
2751 | loadFloat(dict.at("extrapolation_value" ))); |
2752 | } |
2753 | RETURN_ERR_IF_NOT( |
2754 | extrapolationValue == 0.0, |
2755 | opErrMsg(op, strFormat("Resize extrapolation value 0 supported only, " |
2756 | "but found value %f" , |
2757 | extrapolationValue))); |
2758 | // attribute: nearest_mode. |
2759 | std::string nearestMode = "round_prefer_floor" ; |
2760 | if (dict.count("nearest_mode" )) { |
2761 | ASSIGN_VALUE_OR_RETURN_ERR(nearestMode, loadStr(dict.at("nearest_mode" ))); |
2762 | } |
2763 | if (modeStr == "nearest" && nearestMode != "floor" ) { |
2764 | return MAKE_ERR( |
2765 | opErrMsg(op, strFormat("Resize 'floor' and 'nearest' mode " |
2766 | "supported only, but found mode '%s' " , |
2767 | modeStr.c_str()))); |
2768 | } |
2769 | // attribute: coordinate_transformation_mode. |
2770 | std::string coordTransformMode = "half_pixel" ; |
2771 | if (dict.count("coordinate_transformation_mode" )) { |
2772 | ASSIGN_VALUE_OR_RETURN_ERR( |
2773 | coordTransformMode, |
2774 | loadStr(dict.at("coordinate_transformation_mode" ))); |
2775 | } |
2776 | RETURN_ERR_IF_NOT( |
2777 | coordTransformMode == "asymmetric" , |
2778 | opErrMsg(op, strFormat("Resize 'asymmetric' coordinate transformation " |
2779 | "mode supported only, but found %s" , |
2780 | coordTransformMode.c_str()))); |
2781 | |
2782 | // If no scales tensor, sizes tensor should be valid. |
2783 | if (scalesC->getPayload().getHandle().size() == 0) { |
2784 | Constant *sizesC; |
2785 | ASSIGN_VALUE_OR_RETURN_ERR(sizesC, getConstantByName(op.input(3))); |
2786 | RETURN_ERR_IF_NOT(sizesC, |
2787 | opErrMsg(op, strFormat("Resize Sizes Tensor '%s'" |
2788 | " is not Constant." , |
2789 | op.input(3).c_str()))); |
2790 | |
2791 | // Must be 1D tensor of int64_t. |
2792 | RETURN_ERR_IF_NOT( |
2793 | sizesC->dims().size() == 1, |
2794 | opErrMsg(op, strFormat("Resize Input must be a 1D vector." |
2795 | " but found vector size %zu " , |
2796 | sizesC->dims().size()))); |
2797 | RETURN_ERR_IF_NOT( |
2798 | sizesC->getType()->getElementType() == ElemKind::Int64ITy, |
2799 | opErrMsg(op, strFormat( |
2800 | "Resize Input element type must be Int64ITy, but " |
2801 | "found type '%s' " , |
2802 | sizesC->getType()->getElementName().str().c_str()))); |
2803 | |
2804 | auto sizesH = sizesC->getPayload().getHandle<int64_t>(); |
2805 | RETURN_ERR_IF_NOT( |
2806 | in.dims().size() == sizesH.size(), |
2807 | opErrMsg( |
2808 | op, |
2809 | strFormat("Data input %s and sizes input %s must match in size." , |
2810 | std::to_string(in.dims().size()).c_str(), |
2811 | std::to_string(sizesH.size()).c_str()))); |
2812 | // Now fill the output tensor |
2813 | for (dim_t i = 0; i < sizesH.size(); ++i) { |
2814 | outDims.push_back(sizesH.at({i})); |
2815 | } |
2816 | } else { |
2817 | RETURN_ERR_IF_NOT( |
2818 | op.input_size() == 3, |
2819 | opErrMsg(op, "Resize 'sizes' not valid with 'scales' input" )); |
2820 | } |
2821 | } // v11 processing. |
2822 | |
2823 | NodeValue outtr = nullptr; |
2824 | auto scalesH = scalesC->getPayload().getHandle(); |
2825 | |
2826 | // Check is scales is not empty - if yes, use it. |
2827 | if (scalesH.size()) { |
2828 | for (dim_t i = 0; i < scalesH.size(); ++i) { |
2829 | scales.push_back(scalesH.at({i})); |
2830 | } |
2831 | |
2832 | // Scales tensor format is NHWC for supported modes other than nearest. |
2833 | // For nearest mode scale can be 3D, 4D, 5D, or 6D. |
2834 | RETURN_ERR_IF_NOT( |
2835 | (scales.size() >= 3 && scales.size() <= 6 && modeStr == "nearest" ) || |
2836 | scales.size() == 4, |
2837 | opErrMsg( |
2838 | op, strFormat( |
2839 | "Resize Scales dimension invalid. Mode: %s Scale Size: %zu" , |
2840 | modeStr.c_str(), scales.size()))); |
2841 | |
2842 | for (auto &val : scales) { |
2843 | RETURN_ERR_IF_NOT( |
2844 | val > 0, |
2845 | opErrMsg( |
2846 | op, |
2847 | strFormat( |
2848 | "Resize Scale value must be greater than zero, but found %d" , |
2849 | int(val)))); |
2850 | } |
2851 | |
2852 | if (modeStr == "nearest" ) { |
2853 | outtr = G_->createResizeNearest(opName, in, scales); |
2854 | } else if (modeStr == "bilinear" || modeStr == "linear" ) { |
2855 | vectorReorder(scales, {NHWC2NCHW}); |
2856 | auto *intr = G_->createTranspose(opName, in, NCHW2NHWC); |
2857 | auto RN = G_->createResizeBilinear(opName, intr, scales); |
2858 | outtr = G_->createTranspose(opName, RN, NHWC2NCHW); |
2859 | } else { |
2860 | return MAKE_ERR( |
2861 | opErrMsg(op, strFormat("Resize Supports nearest or bilinear " |
2862 | "interpolation only, but found mode as '%s' " , |
2863 | modeStr.c_str())), |
2864 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); |
2865 | } |
2866 | } else if (outDims.size()) { |
2867 | if (modeStr == "nearest" ) { |
2868 | auto outTy = G_->getParent()->uniqueTypeWithNewShape( |
2869 | in.getType(), llvm::ArrayRef<dim_t>(outDims)); |
2870 | outtr = G_->createResizeNearest(opName, in, outTy); |
2871 | } else if (modeStr == "bilinear" || modeStr == "linear" ) { |
2872 | vectorReorder(outDims, {NHWC2NCHW}); |
2873 | auto outTy = G_->getParent()->uniqueTypeWithNewShape( |
2874 | in.getType(), llvm::ArrayRef<dim_t>(outDims)); |
2875 | auto *intr = G_->createTranspose(opName, in, NCHW2NHWC); |
2876 | auto RN = G_->createResizeBilinear(opName, intr, outTy); |
2877 | outtr = G_->createTranspose(opName, RN, NHWC2NCHW); |
2878 | } else { |
2879 | return MAKE_ERR( |
2880 | opErrMsg(op, strFormat( |
2881 | "Supporting nearest or (bi)linear interpolation only" |
2882 | " but found mode '%s' " , |
2883 | modeStr.c_str())), |
2884 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); |
2885 | } |
2886 | } else { |
2887 | return MAKE_ERR(opErrMsg(op, "Resize Neither scales or sizes are set." ), |
2888 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); |
2889 | } |
2890 | |
2891 | RETURN_IF_ERR(addNodeAsOutput(op, outtr)); |
2892 | return Error::success(); |
2893 | } |
2894 | |
2895 | Error ONNXModelLoader::loadGlobalAveragePool( |
2896 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
2897 | const std::string &opName = loadOperatorName(op); |
2898 | |
2899 | // Load the inputs: |
2900 | NodeValue in; |
2901 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2902 | std::vector<unsigned_t> strides(2, 1); |
2903 | if (dict.count("strides" )) { |
2904 | ASSIGN_VALUE_OR_RETURN_ERR(strides, getShape<unsigned_t>(dict["strides" ])); |
2905 | } |
2906 | |
2907 | llvm::SmallVector<unsigned_t, 2> kernels(2); |
2908 | kernels[0] = in.dims()[2]; |
2909 | kernels[1] = in.dims()[3]; |
2910 | |
2911 | Pads pads; |
2912 | ASSIGN_VALUE_OR_RETURN_ERR( |
2913 | pads, getPads(dict, kernels, strides, kernels /* input sizes*/)); |
2914 | |
2915 | bool countIncludePads; |
2916 | ASSIGN_VALUE_OR_RETURN_ERR( |
2917 | countIncludePads, getCountIncludePads(dict, /* defaultValue */ false)); |
2918 | |
2919 | auto *tr = G_->createTranspose(opName, in, NCHW2NHWC); |
2920 | Node *node = G_->createAvgPool(opName, tr, kernels, strides, pads, NHWC, |
2921 | countIncludePads); |
2922 | auto *N = G_->createTranspose(opName, node, NHWC2NCHW); |
2923 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
2924 | return Error::success(); |
2925 | } |
2926 | |
2927 | Error ONNXModelLoader::loadSqueeze(const ONNX_NAMESPACE::NodeProto &op, |
2928 | ArgumentDictionaryTy &dict) { |
2929 | const std::string &opName = loadOperatorName(op); |
2930 | |
2931 | NodeValue in; |
2932 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2933 | std::vector<dim_t> axes; |
2934 | ASSIGN_VALUE_OR_RETURN_ERR(axes, |
2935 | loadAxes<dim_t>(dict["axes" ], in.dims().size())); |
2936 | Node *node = G_->createSqueeze(opName, in, axes); |
2937 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
2938 | return Error::success(); |
2939 | } |
2940 | |
2941 | Error ONNXModelLoader::loadUnsqueeze(const ONNX_NAMESPACE::NodeProto &op, |
2942 | ArgumentDictionaryTy &dict) { |
2943 | const std::string &opName = loadOperatorName(op); |
2944 | |
2945 | NodeValue in; |
2946 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2947 | |
2948 | // Compute output rank. |
2949 | std::vector<int> axesTemp; |
2950 | ASSIGN_VALUE_OR_RETURN_ERR(axesTemp, getShape<int>(dict["axes" ])); |
2951 | int outputRank = in.dims().size() + axesTemp.size(); |
2952 | |
2953 | // Read again the axes and use the output rank to wrap negative axes. |
2954 | std::vector<dim_t> axes; |
2955 | ASSIGN_VALUE_OR_RETURN_ERR(axes, loadAxes<dim_t>(dict["axes" ], outputRank)); |
2956 | |
2957 | Node *node = G_->createExpandDims(opName, in, axes); |
2958 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
2959 | return Error::success(); |
2960 | } |
2961 | |
2962 | Error ONNXModelLoader::loadBatchNormalization( |
2963 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
2964 | const std::string &opName = loadOperatorName(op); |
2965 | |
2966 | NodeValue in; |
2967 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
2968 | NodeValue scale; |
2969 | ASSIGN_VALUE_OR_RETURN_ERR(scale, getNodeValueByName(op.input(1))); |
2970 | NodeValue bias; |
2971 | ASSIGN_VALUE_OR_RETURN_ERR(bias, getNodeValueByName(op.input(2))); |
2972 | NodeValue mean; |
2973 | ASSIGN_VALUE_OR_RETURN_ERR(mean, getNodeValueByName(op.input(3))); |
2974 | NodeValue var; |
2975 | ASSIGN_VALUE_OR_RETURN_ERR(var, getNodeValueByName(op.input(4))); |
2976 | float epsilon = 1e-5f; // default |
2977 | auto epsilonIt = dict.find("epsilon" ); |
2978 | if (epsilonIt != dict.end()) { |
2979 | ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(epsilonIt->second)); |
2980 | } |
2981 | |
2982 | auto *node = G_->createBatchNormalization(opName, in.getType(), in, bias, |
2983 | scale, mean, var, 1, epsilon); |
2984 | |
2985 | // BatchNormalization has 4 optional outputs that are not supported by glow. |
2986 | // Then: 1/ In case the optional outputs are present and used by other |
2987 | // operations of the model, then the import should fail. 2/ In case the |
2988 | // optional outputs are declared but not used, the import should succeed. By |
2989 | // registering only the mandatory output, we make sure the import will fail if |
2990 | // the non supported features are actually requested by the ONNX model. |
2991 | RETURN_IF_ERR(addNodeAsOutput(op, node, 1)); |
2992 | |
2993 | return Error::success(); |
2994 | } |
2995 | |
2996 | Error ONNXModelLoader::loadInstanceNormalization( |
2997 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
2998 | const std::string &opName = loadOperatorName(op); |
2999 | |
3000 | NodeValue in; |
3001 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
3002 | NodeValue scale; |
3003 | ASSIGN_VALUE_OR_RETURN_ERR(scale, getNodeValueByName(op.input(1))); |
3004 | NodeValue bias; |
3005 | ASSIGN_VALUE_OR_RETURN_ERR(bias, getNodeValueByName(op.input(2))); |
3006 | |
3007 | float epsilon = 1e-5f; // default |
3008 | auto epsilonIt = dict.find("epsilon" ); |
3009 | if (epsilonIt != dict.end()) { |
3010 | ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(epsilonIt->second)); |
3011 | } |
3012 | |
3013 | auto *node = |
3014 | G_->createInstanceNormalization(opName, in, bias, scale, 1, epsilon); |
3015 | RETURN_IF_ERR(addNodeAsOutput(op, node, 1)); |
3016 | |
3017 | return Error::success(); |
3018 | } |
3019 | |
3020 | Error ONNXModelLoader::loadConcat(const ONNX_NAMESPACE::NodeProto &op, |
3021 | ArgumentDictionaryTy &dict) { |
3022 | const std::string &opName = loadOperatorName(op); |
3023 | |
3024 | const unsigned numInputs = op.input_size(); |
3025 | llvm::SmallVector<NodeValue, 4> inputs; |
3026 | inputs.reserve(numInputs); |
3027 | for (unsigned i = 0; i < numInputs; i++) { |
3028 | NodeValue in; |
3029 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i))); |
3030 | inputs.push_back(in); |
3031 | } |
3032 | |
3033 | int axis; |
3034 | ASSIGN_VALUE_OR_RETURN_ERR( |
3035 | axis, loadAxis<int>(dict.at("axis" ), inputs.back().dims().size())); |
3036 | |
3037 | Node *node = G_->createConcat(opName, inputs, axis); |
3038 | |
3039 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
3040 | return Error::success(); |
3041 | } |
3042 | |
/// Loads a FullyConnected operator whose weights are stored pre-transposed.
/// Inputs: (data, weights, bias); weights and bias must be Constants.
/// Inputs of rank > 2 are flattened around the optional 'axis' attribute
/// (default 1); weights of rank > 2 are flattened around 'axis_w'
/// (default 1).
Error ONNXModelLoader::loadFCTransposed(const ONNX_NAMESPACE::NodeProto &op,
                                        ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  // Collapse a higher-rank input into a 2D matrix around 'axis'.
  if (in.getType()->dims().size() > 2) {
    size_t axis = 1;
    if (dict.count("axis" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<size_t>(dict.at("axis" ), in.dims().size()));
    }
    in = G_->createFlatten(opName + ".fc.in" , in, axis);
  }

  unsigned_t axis_w = 1;
  if (dict.count("axis_w" )) {
    ASSIGN_VALUE_OR_RETURN_ERR(axis_w, loadInt(dict.at("axis_w" )));
  }

  Constant *W;
  ASSIGN_VALUE_OR_RETURN_ERR(W, getConstantByName(op.input(1)));

  // w is stored already transposed. No need to additionally transpose it.
  // Flatten higher-rank weights to 2D via a raw copy into a new Constant.
  // NOTE(review): the copy target is hard-coded to FloatTy; non-float
  // weights of rank > 2 would be reinterpreted — confirm callers only reach
  // this with float weights.
  if (W->dims().size() > 2) {
    Tensor tmp;
    auto wDims = flattenCdr(W->dims(), axis_w);
    tmp.reset(ElemKind::FloatTy, {wDims.first, wDims.second});
    tmp.copyRawFrom(&W->getPayload());
    W = mod_.createConstant(W->getName(), tmp);
  }

  Constant *B;
  ASSIGN_VALUE_OR_RETURN_ERR(B, getConstantByName(op.input(2)));

  auto *node = G_->createFullyConnected(opName, in, W, B);

  RETURN_IF_ERR(addNodeAsOutput(op, node));
  return Error::success();
}
3083 | |
3084 | Error ONNXModelLoader::loadGemm(const ONNX_NAMESPACE::NodeProto &op, |
3085 | ArgumentDictionaryTy &dict) { |
3086 | const std::string &opName = loadOperatorName(op); |
3087 | NodeValue A; |
3088 | ASSIGN_VALUE_OR_RETURN_ERR(A, getNodeValueByName(op.input(0))); |
3089 | NodeValue B; |
3090 | ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(1))); |
3091 | NodeValue C = nullptr; |
3092 | if (op.input_size() > 2 && !op.input(2).empty()) { |
3093 | ASSIGN_VALUE_OR_RETURN_ERR(C, getNodeValueByName(op.input(2))); |
3094 | } |
3095 | |
3096 | float alpha = 1.0; |
3097 | if (dict.count("alpha" )) { |
3098 | ASSIGN_VALUE_OR_RETURN_ERR(alpha, loadFloat(dict.at("alpha" ))); |
3099 | } |
3100 | |
3101 | float beta = 1.0; |
3102 | if (dict.count("beta" )) { |
3103 | ASSIGN_VALUE_OR_RETURN_ERR(beta, loadFloat(dict.at("beta" ))); |
3104 | } |
3105 | |
3106 | bool transA = false; |
3107 | if (dict.count("transA" )) { |
3108 | ASSIGN_VALUE_OR_RETURN_ERR(transA, loadInt(dict.at("transA" ))); |
3109 | } |
3110 | |
3111 | bool transB = false; |
3112 | if (dict.count("transB" )) { |
3113 | ASSIGN_VALUE_OR_RETURN_ERR(transB, loadInt(dict.at("transB" ))); |
3114 | } |
3115 | |
3116 | Node *node = G_->createGemm(opName, A, B, C, alpha, beta, transA, transB); |
3117 | |
3118 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
3119 | return Error::success(); |
3120 | } |
3121 | |
3122 | Error ONNXModelLoader::loadMatMul(const ONNX_NAMESPACE::NodeProto &op, |
3123 | ArgumentDictionaryTy &dict) { |
3124 | const std::string &opName = loadOperatorName(op); |
3125 | |
3126 | NodeValue LHS; |
3127 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0))); |
3128 | NodeValue RHS; |
3129 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1))); |
3130 | |
3131 | /// For dimension greater than 2 use batchedMatMul |
3132 | if (LHS.dims().size() > 2) { |
3133 | Node *node = G_->createBatchMatMul(opName, LHS, RHS); |
3134 | const size_t numDimsLHS = LHS.dims().size(); |
3135 | if (numDimsLHS > 3) { |
3136 | const size_t numDimsRHS = RHS.dims().size(); |
3137 | std::vector<dim_t> finalShape; |
3138 | for (auto d : LHS.dims()) { |
3139 | finalShape.push_back(d); |
3140 | } |
3141 | finalShape[numDimsLHS - 1] = RHS.dims()[numDimsRHS - 1]; |
3142 | node = G_->createReshape(opName, node, finalShape); |
3143 | } |
3144 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
3145 | } else { |
3146 | Node *node = G_->createMatMul(opName, LHS, RHS); |
3147 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
3148 | } |
3149 | return Error::success(); |
3150 | } |
3151 | |
3152 | Error ONNXModelLoader::loadHardSigmoid(const ONNX_NAMESPACE::NodeProto &op, |
3153 | ArgumentDictionaryTy &dict) { |
3154 | const std::string &opName = loadOperatorName(op); |
3155 | |
3156 | // Input Type. |
3157 | NodeValue input; |
3158 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
3159 | |
3160 | float alphaVal = 0.2f; |
3161 | if (dict.count("alpha" )) { |
3162 | ASSIGN_VALUE_OR_RETURN_ERR(alphaVal, loadFloat(dict.at("alpha" ))); |
3163 | } |
3164 | float betaVal = 0.5f; |
3165 | if (dict.count("beta" )) { |
3166 | ASSIGN_VALUE_OR_RETURN_ERR(betaVal, loadFloat(dict.at("beta" ))); |
3167 | } |
3168 | |
3169 | // Create the node. |
3170 | Node *N = G_->createHardSigmoid(opName, input, alphaVal, betaVal); |
3171 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3172 | |
3173 | return Error::success(); |
3174 | } |
3175 | |
3176 | Error ONNXModelLoader::loadLeakyRelu(const ONNX_NAMESPACE::NodeProto &op, |
3177 | ArgumentDictionaryTy &dict) { |
3178 | const std::string &opName = loadOperatorName(op); |
3179 | |
3180 | // Input Type. |
3181 | NodeValue input; |
3182 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
3183 | |
3184 | // ONNX spec says default is 0.01, but doesn't explicitly say it's optional. |
3185 | // like for others. The default example just omits alpha. |
3186 | float alphaVal = 0.01f; |
3187 | if (dict.count("alpha" )) { |
3188 | ASSIGN_VALUE_OR_RETURN_ERR(alphaVal, loadFloat(dict.at("alpha" ))); |
3189 | } |
3190 | |
3191 | // Create the node. |
3192 | Node *N = G_->createLeakyRELU(opName, input, alphaVal); |
3193 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3194 | |
3195 | return Error::success(); |
3196 | } |
3197 | |
3198 | Error ONNXModelLoader::loadPad(const ONNX_NAMESPACE::NodeProto &op, |
3199 | ArgumentDictionaryTy &dict) { |
3200 | const std::string &opName = loadOperatorName(op); |
3201 | |
3202 | // Input |
3203 | NodeValue input; |
3204 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
3205 | auto inputDims = input.dims(); |
3206 | auto numDims = inputDims.size(); |
3207 | |
3208 | // Padding properties. |
3209 | unsigned_t mode = PaddingMode::CONSTANT; // default is constant. |
3210 | if (dict.count("mode" )) { |
3211 | std::string modeStr; |
3212 | ASSIGN_VALUE_OR_RETURN_ERR(modeStr, loadStr(dict.at("mode" ))); |
3213 | if (modeStr == "constant" ) { |
3214 | mode = PaddingMode::CONSTANT; |
3215 | } else if (modeStr == "reflect" ) { |
3216 | mode = PaddingMode::REFLECT; |
3217 | } else if (modeStr == "edge" ) { |
3218 | mode = PaddingMode::EDGE; |
3219 | } else { |
3220 | return MAKE_ERR( |
3221 | "Pad: Invalid mode" , |
3222 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); |
3223 | } |
3224 | } |
3225 | |
3226 | std::vector<int> pads; |
3227 | |
3228 | if (this->opsetVersion_ > 10) { |
3229 | // Get pads through input(1) from opset v11. |
3230 | RETURN_ERR_IF_NOT( |
3231 | op.input_size() > 1, |
3232 | "Pad: The 'pads' is mandatory as input(1) from opsetv11." ); |
3233 | Constant *padC; |
3234 | ASSIGN_VALUE_OR_RETURN_ERR(padC, getConstantByName(op.input(1))); |
3235 | RETURN_ERR_IF_NOT(padC, "Support only constant pad" ); |
3236 | helperSetter<int64_t, int>(padC, pads); |
3237 | } else { |
3238 | RETURN_ERR_IF_NOT(dict.count("pads" ), |
3239 | "Pad: The 'pads' property is mandatory" ); |
3240 | ASSIGN_VALUE_OR_RETURN_ERR(pads, getShape<int>(dict["pads" ])); |
3241 | } |
3242 | |
3243 | RETURN_ERR_IF_NOT( |
3244 | (pads.size() == 2 * numDims), |
3245 | opErrMsg(op, " The 'pads' array must contain 2 values per dimensions" )); |
3246 | |
3247 | float value = 0.f; |
3248 | if (this->opsetVersion_ > 10) { |
3249 | if (op.input_size() > 2) { |
3250 | Constant *valueC; |
3251 | ASSIGN_VALUE_OR_RETURN_ERR(valueC, getConstantByName(op.input(2))); |
3252 | RETURN_ERR_IF_NOT(valueC, "Support only constant value in Pad" ); |
3253 | RETURN_ERR_IF_NOT(valueC->getElementType() == ElemKind::FloatTy, |
3254 | "Value in Pad should be float type." ); |
3255 | value = valueC->getPayload().getHandle().raw(0); |
3256 | } |
3257 | } else { |
3258 | if (dict.count("value" )) { |
3259 | ASSIGN_VALUE_OR_RETURN_ERR(value, loadFloat(dict.at("value" ))); |
3260 | } |
3261 | } |
3262 | |
3263 | // Compute the output type. |
3264 | std::vector<dim_t> outDims(numDims); |
3265 | for (unsigned_t i = 0; i < numDims; i++) { |
3266 | auto new_dim = inputDims[i] + pads[i] + pads[i + numDims]; |
3267 | RETURN_ERR_IF_NOT( |
3268 | new_dim > 0, |
3269 | opErrMsg(op, "The padding can't remove all elements of a dimension" )); |
3270 | outDims[i] = new_dim; |
3271 | } |
3272 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, outDims); |
3273 | |
3274 | // Create the IR node. |
3275 | Node *N = G_->createPad(opName, input, outTy, mode, pads, value); |
3276 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3277 | |
3278 | return Error::success(); |
3279 | } |
3280 | |
3281 | Error ONNXModelLoader::loadCast(const ONNX_NAMESPACE::NodeProto &op, |
3282 | ArgumentDictionaryTy &dict) { |
3283 | const std::string &opName = loadOperatorName(op); |
3284 | |
3285 | // Input type |
3286 | NodeValue input; |
3287 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
3288 | ElemKind inputKind = input.getType()->getElementType(); |
3289 | |
3290 | // Target type. |
3291 | ElemKind targetKind; |
3292 | RETURN_ERR_IF_NOT(dict.count("to" ), |
3293 | opErrMsg(op, "Cast missing 'to' attribute" )); |
3294 | int toONNXTypeValue; |
3295 | ASSIGN_VALUE_OR_RETURN_ERR(toONNXTypeValue, loadInt(dict.at("to" ))); |
3296 | RETURN_ERR_IF_NOT( |
3297 | ONNX_NAMESPACE::TensorProto_DataType_IsValid(toONNXTypeValue), |
3298 | opErrMsg(op, "Cast invalid target type" ), |
3299 | ErrorValue::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); |
3300 | ASSIGN_VALUE_OR_RETURN_ERR( |
3301 | targetKind, convertTensorProtoDataType( |
3302 | ONNX_NAMESPACE::TensorProto_DataType(toONNXTypeValue))); |
3303 | |
3304 | // Only support non quantized types. |
3305 | RETURN_ERR_IF_NOT( |
3306 | (!isQuantizedElemKind(inputKind)) && (!isQuantizedElemKind(targetKind)), |
3307 | opErrMsg(op, |
3308 | "Cast Unsupported types (Supports only non quantized types)" )); |
3309 | |
3310 | // Create the IR node. |
3311 | Node *N = G_->createConvertTo(opName, input, targetKind); |
3312 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3313 | |
3314 | return Error::success(); |
3315 | } |
3316 | |
3317 | Error ONNXModelLoader::loadDepthToSpace(const ONNX_NAMESPACE::NodeProto &op, |
3318 | const ArgumentDictionaryTy &dict) { |
3319 | NodeValue input; |
3320 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
3321 | |
3322 | dim_t blockSize = 0; |
3323 | if (dict.count("blocksize" )) { |
3324 | ASSIGN_VALUE_OR_RETURN_ERR(blockSize, loadInt(dict.at("blocksize" ))); |
3325 | } else { |
3326 | return MAKE_ERR("DepthToSpace: missing 'blocksize' attribute" ); |
3327 | } |
3328 | |
3329 | std::string mode = "DCR" ; |
3330 | if (dict.count("mode" )) { |
3331 | ASSIGN_VALUE_OR_RETURN_ERR(mode, loadStr(dict.at("mode" ))); |
3332 | } |
3333 | |
3334 | auto inputDim = input.dims(); |
3335 | RETURN_ERR_IF_NOT(inputDim.size() == 4, |
3336 | "DepthToSpace: dimension size of 4 is expected." ); |
3337 | RETURN_ERR_IF_NOT(inputDim[1] % blockSize == 0, |
3338 | "DepthToSpace: depth should be divisible by block size." ); |
3339 | |
3340 | std::string opName = loadOperatorName(op); |
3341 | auto *TR1 = G_->createTranspose(opName, input, NCHW2NHWC); |
3342 | auto *D2S = G_->createDepthToSpace(opName, TR1, blockSize, mode == "CRD" ); |
3343 | auto *TR2 = G_->createTranspose(opName, D2S, NHWC2NCHW); |
3344 | |
3345 | RETURN_IF_ERR(addNodeAsOutput(op, TR2)); |
3346 | return Error::success(); |
3347 | } |
3348 | |
3349 | Error ONNXModelLoader::loadSpaceToDepth(const ONNX_NAMESPACE::NodeProto &op, |
3350 | ArgumentDictionaryTy &dict) { |
3351 | |
3352 | // Input Type |
3353 | NodeValue input; |
3354 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
3355 | |
3356 | int blockSize = 0; |
3357 | if (dict.count("blocksize" )) { |
3358 | ASSIGN_VALUE_OR_RETURN_ERR(blockSize, loadInt(dict.at("blocksize" ))); |
3359 | } else { |
3360 | return MAKE_ERR( |
3361 | opErrMsg(op, "SpaceToDepth: missing 'blocksize' attribute" )); |
3362 | } |
3363 | |
3364 | // Create the node. |
3365 | std::string opName = loadOperatorName(op); |
3366 | auto *tr = G_->createTranspose(opName, input, NCHW2NHWC); |
3367 | Node *nd = G_->createSpaceToDepth(opName, tr, blockSize); |
3368 | auto *N = G_->createTranspose(opName, nd, NHWC2NCHW); |
3369 | |
3370 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3371 | |
3372 | return Error::success(); |
3373 | } |
3374 | |
Error ONNXModelLoader::loadReduceL2(const ONNX_NAMESPACE::NodeProto &op,
                                    const ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  // ReduceL2 = sqrt(sum(x^2)) over the requested axes: square first, then
  // chain reductions, then take the square root via Pow(., 0.5).
  in = G_->createMul(opName, in, in);

  // ReduceAdd.
  // Resolve the reduction axes: either the (sorted, unique) 'axes' attribute
  // or, when absent, all dimensions of the input.
  std::vector<unsigned_t> shapeAxes = {};
  if (dict.count("axes")) {
    ASSIGN_VALUE_OR_RETURN_ERR(
        shapeAxes, loadAxes<unsigned_t>(dict.at("axes"), in.dims().size()));
    std::sort(shapeAxes.begin(), shapeAxes.end());
    if (shapeAxes.size() > 1) {
      // std::unique returning early means a duplicate was collapsed.
      auto it = std::unique(shapeAxes.begin(), shapeAxes.end());
      if (it != shapeAxes.end())
        return MAKE_ERR(opErrMsg(op, "ReduceL2 Axes values are not unique."),
                        ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
    }
  } else {
    shapeAxes.resize(in.dims().size());
    std::iota(shapeAxes.begin(), shapeAxes.end(), 0);
  }

  // 'keepdims' (Optional)(Default: 1): keep reduced dimensions as size 1.
  bool keepDims = true;
  if (dict.count("keepdims")) {
    int keepdims;
    ASSIGN_VALUE_OR_RETURN_ERR(keepdims, loadInt(dict.at("keepdims")));
    keepDims = (bool)keepdims;
  }

  // Reduceadd works only for single axis as of now.
  // Iterate highest axis first so earlier axis indices stay valid as
  // dimensions are removed; re-insert each reduced dim when keepDims is set.
  for (auto it = shapeAxes.rbegin(), e = shapeAxes.rend(); it != e; ++it) {
    in = G_->createBatchedReduceAdd(opName, in, llvm::makeArrayRef(*it));
    if (keepDims) {
      in = G_->createExpandDims(opName, in, *it);
    }
  }

  // Final square root of the summed squares.
  in = G_->createPow(opName, in, 0.5f);
  RETURN_IF_ERR(addNodeAsOutput(op, in));
  return Error::success();
}
3418 | |
Error ONNXModelLoader::loadConstantOfShape(const ONNX_NAMESPACE::NodeProto &op,
                                           ArgumentDictionaryTy &dict,
                                           bool isSplat) {
  // Fill value defaults to a single float 0.0; the 'value' attribute, when
  // present, overrides it.
  Tensor T(ElemKind::FloatTy, {1});
  T.getHandle().raw(0) = 0.0;

  if (dict.count("value")) {
    RETURN_IF_ERR(loadTensor(dict.at("value")->t(), &T, useGlowCustomOps_));
    if (!isSplat) {
      // Validate tensor only for ConstantOfShape operator.
      RETURN_ERR_IF_NOT(
          T.dims().size() == 1,
          opErrMsg(op, strFormat("ConstantOfShape Value must be "
                                 "a 1D vector, but found size %zu ",
                                 T.dims().size())));
      RETURN_ERR_IF_NOT(
          T.getType().getElementType() == ElemKind::FloatTy ||
              T.getType().getElementType() == ElemKind::Int64ITy ||
              T.getType().getElementType() == ElemKind::Int32ITy,
          T.getType().getElementName().str() + " type Value is not supported.");
    }
  }

  TypeRef ty;
  Node *SN = nullptr;
  if (op.input_size() > 0) {
    // Output shape is taken from input(0), which must be a constant.
    Constant *in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getConstantByName(op.input(0)));
    // Must be 1D tensor of int64_t.
    RETURN_ERR_IF_NOT(
        in->dims().size() == 1,
        opErrMsg(
            op,
            strFormat(
                "ConstantOfShape Input must be a 1D vector, but found size %d ",
                int(in->dims().size()))));
    RETURN_ERR_IF_NOT(
        in->getType()->getElementType() == ElemKind::Int64ITy,
        opErrMsg(op, "ConstantOfShape Input element type must be Int64ITy."));
    // Convert 1D tensor of int64_t into llvm::ArrayRef<dim_t>.
    auto TH = in->getPayload().getHandle<int64_t>();
    auto begin = &TH.raw(0);
    auto end = begin + TH.actualSize();
    ShapeVector outputDims(begin, end);

    ty = mod_.uniqueType(T.getType().getElementType(), outputDims);
    switch (T.getType().getElementType()) {
    case ElemKind::Int64ITy: {
      // createSplat takes the value through float; reject integers that do
      // not round-trip float -> int64 without loss.
      int64_t v = T.getHandle<int64_t>().raw(0);
      RETURN_ERR_IF_NOT(
          v == static_cast<int64_t>(static_cast<float>(v)),
          opErrMsg(
              op, "ConstantOfShape implementation may cause losses for value " +
                      std::to_string(v) + " ."));
      SN = G_->createSplat(loadOperatorName(op), ty, v);
      break;
    }
    case ElemKind::Int32ITy: {
      // Same float round-trip check for 32-bit integer fill values.
      int32_t v = T.getHandle<int32_t>().raw(0);
      RETURN_ERR_IF_NOT(
          v == static_cast<int32_t>(static_cast<float>(v)),
          opErrMsg(
              op, "ConstantOfShape implementation may cause losses for value " +
                      std::to_string(v) + " ."));
      SN = G_->createSplat(loadOperatorName(op), ty, v);
      break;
    }
    default:
      SN = G_->createSplat(loadOperatorName(op), ty, T.getHandle().raw(0));
    }
  } else {
    // No shape input: splat using the value tensor's own shape.
    ty = mod_.uniqueType(T.getType().getElementType(), T.dims());
    SN = G_->createSplat(loadOperatorName(op), ty, T.getHandle().raw(0));
  }
  RETURN_IF_ERR(addNodeAsOutput(op, SN));
  return Error::success();
}
3496 | |
3497 | Error ONNXModelLoader::loadTile(const ONNX_NAMESPACE::NodeProto &op, |
3498 | ArgumentDictionaryTy &dict) { |
3499 | const std::string &opName = loadOperatorName(op); |
3500 | NodeValue in, repeats; |
3501 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
3502 | ASSIGN_VALUE_OR_RETURN_ERR(repeats, getNodeValueByName(op.input(1))); |
3503 | if (!llvm::isa<Constant>(repeats)) { |
3504 | return MAKE_ERR(opErrMsg(op, "Tile Only constant Repeats is supported!" )); |
3505 | } |
3506 | |
3507 | if (repeats.dims().size() != 1) { |
3508 | return MAKE_ERR( |
3509 | opErrMsg(op, "Tile Repeats must be a single-dimensional tensor!" )); |
3510 | } |
3511 | |
3512 | if (repeats.dims()[0] != in.dims().size()) { |
3513 | return MAKE_ERR(opErrMsg( |
3514 | op, "Tile Repeats should have one value for each dimension of input!" )); |
3515 | } |
3516 | auto rh = llvm::cast<Constant>(repeats)->getPayload().getHandle<int64_t>(); |
3517 | Node *N = in; |
3518 | for (size_t i = 0; i < in.dims().size(); i++) { |
3519 | auto tiles = rh.raw(i); |
3520 | if (tiles != 1) { |
3521 | std::string name = opName + "." + std::to_string(i); |
3522 | N = G_->createTile(name, N, tiles, /*axis*/ i); |
3523 | } |
3524 | } |
3525 | |
3526 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3527 | return Error::success(); |
3528 | } |
3529 | |
3530 | Error ONNXModelLoader::loadExpand(const ONNX_NAMESPACE::NodeProto &op, |
3531 | const ArgumentDictionaryTy &dict) { |
3532 | const std::string &opName = loadOperatorName(op); |
3533 | NodeValue in; |
3534 | Constant *repeats; |
3535 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
3536 | ASSIGN_VALUE_OR_RETURN_ERR(repeats, getConstantByName(op.input(1))); |
3537 | |
3538 | std::vector<dim_t> tiles; |
3539 | helperSetter<int64_t, dim_t>(repeats, tiles); |
3540 | auto inputDimSize = (size_t)in.dims().size(); |
3541 | auto repeatSize = (size_t)tiles.size(); |
3542 | if (repeatSize > inputDimSize) { |
3543 | for (size_t i = 0, e = repeatSize - inputDimSize; i < e; i++) { |
3544 | in = G_->createExpandDims(opName + "_" + std::to_string(i), in, i); |
3545 | } |
3546 | } |
3547 | |
3548 | Node *N = in; |
3549 | for (size_t i = 0, e = tiles.size(); i < e; i++) { |
3550 | // Two corresponding dimension must have the same value, |
3551 | // or one of them is equal to 1. |
3552 | if (in.dims()[i] != 1 && tiles[i] != in.dims()[i] && tiles[i] != 1) { |
3553 | return MAKE_ERR(opErrMsg(op, "Expand Invalid repeat value" )); |
3554 | } |
3555 | if (tiles[i] != in.dims()[i] && tiles[i] != 1) { |
3556 | std::string name = opName + "_" + std::to_string(i); |
3557 | N = G_->createTile(name, N, tiles[i], /*axis*/ i); |
3558 | } |
3559 | } |
3560 | |
3561 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3562 | return Error::success(); |
3563 | } |
3564 | |
/// Tries to constant-fold \p op ahead of regular loading. Returns true when
/// the fold succeeded (the op's outputs are now constants), false otherwise.
Expected<bool>
ONNXModelLoader::foldOperator(const ONNX_NAMESPACE::NodeProto &op) {
  const unsigned numInputs = op.input_size();
  const std::string &typeName = op.op_type();
  llvm::SmallVector<NodeValue, 4> inputs;
  inputs.reserve(numInputs);
  for (unsigned i = 0; i < numInputs; i++) {
    // If the name of the input is empty then consider it to be unspecified,
    // which is valid for optional inputs, so simply skip. If it is necessary
    // for loading the op, then when we later try to load the proper error will
    // be propagated upward.
    if (op.input(i).empty()) {
      continue;
    }
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
    inputs.push_back(in);
  }

  // Not foldable (e.g. non-constant inputs): report without error.
  if (!isConstantFoldable(inputs, typeName)) {
    return false;
  }

  // Create a temporary lightweight loader to construct function representing
  // current Op, and then constant fold the function using Interp backend.
  Function *tmpF = mod_.createFunction("eval_const_fold__");
  ONNXModelLoader tmpLoader(*tmpF);
  tmpLoader.opsetVersion_ = opsetVersion_;
  // A failed fold is not an error here — it simply means "load normally" —
  // so suppress error logging and collapse the result to a bool.
  bool foldStatus = !ERR_TO_BOOL(
      constantFoldInLoader<ONNXModelLoader, ONNX_NAMESPACE::NodeProto>(
          tmpF, tmpLoader, this, op),
      /* log */ false);
  // The scratch function is no longer needed regardless of the outcome.
  mod_.eraseFunction(tmpF);
  return foldStatus;
}
3600 | |
3601 | Error ONNXModelLoader::loadWhere(const ONNX_NAMESPACE::NodeProto &op, |
3602 | ArgumentDictionaryTy &dict) { |
3603 | NodeValue cNV; |
3604 | ASSIGN_VALUE_OR_RETURN_ERR(cNV, getNodeValueByName(op.input(0))); |
3605 | NodeValue xNV; |
3606 | ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(1))); |
3607 | NodeValue yNV; |
3608 | ASSIGN_VALUE_OR_RETURN_ERR(yNV, getNodeValueByName(op.input(2))); |
3609 | |
3610 | std::string opName = loadOperatorName(op); |
3611 | |
3612 | // Passing -1 for multi directional broadcast, axis will be computed |
3613 | // automatically. |
3614 | Node *N = G_->createNodeWithBroadcast<SelectNode>(opName, -1, cNV, xNV, yNV); |
3615 | |
3616 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
3617 | return Error::success(); |
3618 | } |
3619 | |
3620 | /// Utility function to get the RNN, GRU or LSTM direction from the proto |
3621 | /// description. If not provided, the default direction is 'forward'. |
3622 | static Expected<Function::RnnDirection> |
3623 | getRnnDirection(const ONNX_NAMESPACE::NodeProto &op, |
3624 | ArgumentDictionaryTy &dict) { |
3625 | Function::RnnDirection direction = Function::RnnDirection::Forward; |
3626 | if (dict.count("direction" )) { |
3627 | std::string directionStr; |
3628 | ASSIGN_VALUE_OR_RETURN_ERR(directionStr, loadStr(dict.at("direction" ))); |
3629 | if (directionStr == "forward" ) { |
3630 | direction = Function::RnnDirection::Forward; |
3631 | } else if (directionStr == "reverse" ) { |
3632 | direction = Function::RnnDirection::Reverse; |
3633 | } else if (directionStr == "bidirectional" ) { |
3634 | direction = Function::RnnDirection::Bidirectional; |
3635 | } else { |
3636 | return MAKE_ERR( |
3637 | "ONNX " + op.op_type() + " 'direction' attribute is invalid!" , |
3638 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); |
3639 | } |
3640 | } |
3641 | return direction; |
3642 | } |
3643 | |
3644 | /// Relu activation function definition. |
3645 | static Function::RnnActivation RnnActivationRelu(Function &F) { |
3646 | return [&F](llvm::StringRef name, Node *input) { |
3647 | return F.createRELU(name, input); |
3648 | }; |
3649 | } |
3650 | |
3651 | /// Tanh activation function definition. |
3652 | static Function::RnnActivation RnnActivationTanh(Function &F) { |
3653 | return [&F](llvm::StringRef name, Node *input) { |
3654 | return F.createTanh(name, input); |
3655 | }; |
3656 | } |
3657 | |
3658 | /// Sigmoid activation function definition. |
3659 | static Function::RnnActivation RnnActivationSigmoid(Function &F) { |
3660 | return [&F](llvm::StringRef name, Node *input) { |
3661 | return F.createSigmoid(name, input); |
3662 | }; |
3663 | } |
3664 | |
/// Utility function to get the RNN, GRU or LSTM activation functions from the
/// proto description. The activation function array is assumed to be already
/// initialized with the default values upon entering this function so that the
/// purpose of this function is to overwrite the specific default values.
/// Currently only Sigmoid, Tanh and ReLU activations are supported.
static Error
getRnnActivations(const ONNX_NAMESPACE::NodeProto &op,
                  ArgumentDictionaryTy &dict, Function *F,
                  std::vector<Function::RnnActivation> &activations) {

  // Activation alpha not supported (Optional)(Default:activation dependent).
  RETURN_ERR_IF_NOT(!dict.count("activation_alpha"),
                    "ONNX " + op.op_type() +
                        " 'activation_alpha' attribute not supported!");

  // Activation beta not supported (Optional)(Default:activation dependent).
  RETURN_ERR_IF_NOT(!dict.count("activation_beta"),
                    "ONNX " + op.op_type() +
                        " 'activation_beta' attribute not supported!");

  // Get activations.
  if (dict.count("activations") && dict.at("activations")->strings_size()) {
    size_t actNum = dict.at("activations")->strings_size();
    size_t actNumExpected = activations.size();
    // The number of provided activations must match the number expected by
    // the caller (it depends on the cell type and direction).
    RETURN_ERR_IF_NOT(actNum == actNumExpected,
                      strFormat("ONNX %s 'activations' attribute has invalid "
                                "number of functions! Expected number is %d!",
                                op.op_type().c_str(), (int)actNumExpected));
    // Overwrite each default with the named activation builder.
    for (size_t actIdx = 0; actIdx < actNum; actIdx++) {
      std::string actStr = dict.at("activations")->strings().Get(actIdx);
      if (actStr == "Relu") {
        activations[actIdx] = RnnActivationRelu(*F);
      } else if (actStr == "Tanh") {
        activations[actIdx] = RnnActivationTanh(*F);
      } else if (actStr == "Sigmoid") {
        activations[actIdx] = RnnActivationSigmoid(*F);
      } else {
        return MAKE_ERR(
            "ONNX " + op.op_type() + " activation '" + actStr +
                "' not supported!",
            ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
      }
    }
  }
  return Error::success();
}
3711 | |
// Limitations:
// - Activation clipping not supported.
// - Variable sequence length not supported.
Error ONNXModelLoader::loadRNN(const ONNX_NAMESPACE::NodeProto &op,
                               ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);

  // ------------------------- Attributes -------------------------------------
  // Get direction (Optional)(Default:forward).
  Function::RnnDirection direction;
  ASSIGN_VALUE_OR_RETURN_ERR(direction, getRnnDirection(op, dict));
  dim_t numDirections =
      (direction == Function::RnnDirection::Bidirectional) ? 2 : 1;

  // Get activations as lambdas (Optional)(Default:f=Tanh).
  // Bidirectional cells need one activation per direction.
  std::vector<Function::RnnActivation> activations;
  if (direction == Function::RnnDirection::Bidirectional) {
    activations = {RnnActivationTanh(*G_), RnnActivationTanh(*G_)};
  } else {
    activations = {RnnActivationTanh(*G_)};
  }
  RETURN_IF_ERR(getRnnActivations(op, dict, G_, activations));

  // Activation clipping not supported (Optional)(Default: 0 for no clipping).
  RETURN_ERR_IF_NOT(!dict.count("clip"),
                    opErrMsg(op, "ONNX RNN 'clip' attribute not supported!"));

  // Get hidden size (Required).
  dim_t hiddenSize;
  RETURN_ERR_IF_NOT(
      dict.count("hidden_size"),
      opErrMsg(op, "ONNX RNN 'hidden_size' attribute is required!"));
  ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size")));

  // --------------------------- Inputs ---------------------------------------
  const int numInputs = op.input_size();
  RETURN_ERR_IF_NOT((3 <= numInputs) && (numInputs <= 6),
                    opErrMsg(op, strFormat("ONNX RNN should have minimum 3 and "
                                           "maximum 6 inputs, but found %d ",
                                           numInputs)));

  // Input0: X (Required).
  NodeValue X;
  ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));

  // Input1: W (Required).
  NodeValue W;
  ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));

  // Input2: R (Required).
  NodeValue R;
  ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2)));

  // Input3: B (Optional).
  NodeValue B = nullptr;
  if (numInputs > 3 && !op.input(3).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3)));
  }

  // Input4: sequence_lens (Optional). Ignored: variable sequence length is
  // not supported (see limitations above).
  if (numInputs > 4 && !op.input(4).empty()) {
    LOG(WARNING) << "sequence_lens ignored, will be inferred from shape of "
                    "ONNX RNN input.";
  }

  // Input5: initial_h (Optional).
  NodeValue initial_h = nullptr;
  if (numInputs > 5 && !op.input(5).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5)));
  }

  // -------------------------- Outputs ---------------------------------------
  // We allow creating placeholders for the RNN state variable Y_h for the
  // following reasons:
  // - expose the RNN state in the graph interface for accessibility (set
  // desired state, reset state, watch the state being updated automatically).
  // - since the RNN cells are unrolled (no graph loop primitive available
  // at this point), the optimal way to use the RNN within a model would be
  // to have it defined with only 1 time step and have the loop in the top
  // of the application while the RNN state will be automatically updated
  // from one iteration (time step) to the next through the placeholders.

  // Derived parameters.
  RETURN_ERR_IF_NOT(
      X.dims().size() == 3,
      opErrMsg(op, "ONNX RNN input 'X' should have 3 dimensions!"));
  dim_t batchSize = X.dims()[1];

  // Create Y_h (hidden state) output placeholder.
  Placeholder *Y_h_ph = nullptr;
  if (onnxExportRnnStatesOpt) {
    TypeRef Htype = mod_.uniqueTypeWithNewShape(
        X.getType(), {numDirections, batchSize, hiddenSize});
    std::string Hname = opName + ".Y_h";
    ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph,
                               createAndRegisterPlaceholder(Hname, Htype));
    inputVarsByName_.try_emplace(Hname, Y_h_ph);
  }

  // Set RNN input state: the exported placeholder when state export is on,
  // otherwise the model-provided initial_h (possibly null).
  NodeValue Y_h_init = onnxExportRnnStatesOpt ? Y_h_ph : initial_h;

  // Create ONNX RNN.
  NodeValue Y, Y_h;
  G_->createOnnxRNN(opName, X, W, R, B, Y_h_init, Y, Y_h, hiddenSize, direction,
                    activations);

  // Save RNN output state back into the exported placeholder.
  if (onnxExportRnnStatesOpt) {
    G_->createSave(opName + ".Y_h.save", Y_h, Y_h_ph);
  }

  // Add node. Map outputs: Y only, or Y and Y_h when two outputs are named.
  const int numOutputs = op.output_size();
  if (numOutputs == 1) {
    RETURN_IF_ERR(addNodeAsOutput(op, Y));
  } else if (numOutputs == 2) {
    RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h}));
  } else {
    return MAKE_ERR(opErrMsg(op, strFormat("ONNX RNN should have minimum 1 and "
                                           "maximum 2 outputs, but found %d ",
                                           numOutputs)));
  }
  return Error::success();
}
3838 | |
3839 | // Limitations: |
3840 | // - Activation clipping not supported. |
3841 | // - Variable sequence length not supported. |
3842 | Error ONNXModelLoader::loadGRU(const ONNX_NAMESPACE::NodeProto &op, |
3843 | ArgumentDictionaryTy &dict) { |
3844 | |
3845 | const std::string &opName = loadOperatorName(op); |
3846 | |
3847 | // ------------------------- Attributes ------------------------------------- |
3848 | // Get direction (Optional)(Default:forward). |
3849 | Function::RnnDirection direction; |
3850 | ASSIGN_VALUE_OR_RETURN_ERR(direction, getRnnDirection(op, dict)); |
3851 | dim_t numDirections = |
3852 | (direction == Function::RnnDirection::Bidirectional) ? 2 : 1; |
3853 | |
3854 | // Get activations as lambdas (Optional)(Default:f=Sigmoid, g=Tanh). |
3855 | std::vector<Function::RnnActivation> activations; |
3856 | if (direction == Function::RnnDirection::Bidirectional) { |
3857 | activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_), |
3858 | RnnActivationSigmoid(*G_), RnnActivationTanh(*G_)}; |
3859 | } else { |
3860 | activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_)}; |
3861 | } |
3862 | RETURN_IF_ERR(getRnnActivations(op, dict, G_, activations)); |
3863 | |
3864 | // Activation clipping not supported (Optional)(Default: 0 for no clipping). |
3865 | RETURN_ERR_IF_NOT(!dict.count("clip" ), |
3866 | opErrMsg(op, "ONNX GRU 'clip' attribute not supported!" )); |
3867 | |
3868 | // Get hidden size (Required). |
3869 | dim_t hiddenSize; |
3870 | RETURN_ERR_IF_NOT( |
3871 | dict.count("hidden_size" ), |
3872 | opErrMsg(op, "ONNX GRU 'hidden_size' attribute is required!" )); |
3873 | ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size" ))); |
3874 | ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size" ))); |
3875 | |
3876 | // Get linear_before_reset (Optional)(Default:0). |
3877 | int linearBeforeReset = 0; |
3878 | if (dict.count("linear_before_reset" ) && |
3879 | dict.at("linear_before_reset" )->has_i()) { |
3880 | linearBeforeReset = dict.at("linear_before_reset" )->i(); |
3881 | } |
3882 | |
3883 | // --------------------------- Inputs --------------------------------------- |
3884 | const int numInputs = op.input_size(); |
3885 | RETURN_ERR_IF_NOT((3 <= numInputs) && (numInputs <= 6), |
3886 | opErrMsg(op, strFormat("ONNX GRU should have minimum 3 and " |
3887 | "maximum 6 inputs, but found %d " , |
3888 | numInputs))); |
3889 | |
3890 | // Input0: X (Required). |
3891 | NodeValue X; |
3892 | ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0))); |
3893 | |
3894 | // Input1: W (Required). |
3895 | NodeValue W; |
3896 | ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1))); |
3897 | |
3898 | // Input2: R (Required). |
3899 | NodeValue R; |
3900 | ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2))); |
3901 | |
3902 | // Input3: B (Optional). |
3903 | NodeValue B = nullptr; |
3904 | if (numInputs > 3 && !op.input(3).empty()) { |
3905 | ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3))); |
3906 | } |
3907 | |
3908 | // Input4: sequence_lens (Optional). |
3909 | if (numInputs > 4 && !op.input(4).empty()) { |
3910 | LOG(WARNING) << "sequence_lens ignored, will be inferred from shape of " |
3911 | "ONNX GRU input." ; |
3912 | } |
3913 | |
3914 | // Input5: initial_h (Optional). |
3915 | NodeValue initial_h = nullptr; |
3916 | if (numInputs > 5 && !op.input(5).empty()) { |
3917 | ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5))); |
3918 | } |
3919 | |
3920 | // -------------------------- Outputs --------------------------------------- |
3921 | // We allow creating placeholders for the GRU state variable Y_h for the |
3922 | // following reasons: |
3923 | // - expose the GRU state in the graph interface for accessibility (set |
3924 | // desired state, reset state, watch the state being updated automatically). |
3925 | // - since the GRU cells are unrolled (no graph loop primitive available |
3926 | // at this point), the optimal way to use the GRU within a model would be |
3927 | // to have it defined with only 1 time step and have the loop in the top |
3928 | // of the application while the GRU state will be automatically updated |
3929 | // from one iteration (time step) to the next through the placeholders. |
3930 | |
3931 | // Derived parameters. |
3932 | RETURN_ERR_IF_NOT( |
3933 | X.dims().size() == 3, |
3934 | opErrMsg(op, "ONNX GRU input 'X' should have 3 dimensions!" )); |
3935 | dim_t batchSize = X.dims()[1]; |
3936 | |
3937 | // Create Y_h (hidden state) output placeholder. |
3938 | Placeholder *Y_h_ph = nullptr; |
3939 | if (onnxExportRnnStatesOpt) { |
3940 | TypeRef Htype = mod_.uniqueTypeWithNewShape( |
3941 | X.getType(), {numDirections, batchSize, hiddenSize}); |
3942 | std::string Hname = opName + ".Y_h" ; |
3943 | ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph, |
3944 | createAndRegisterPlaceholder(Hname, Htype)); |
3945 | inputVarsByName_.try_emplace(Hname, Y_h_ph); |
3946 | } |
3947 | |
3948 | // Set GRU input state. |
3949 | NodeValue Y_h_init = onnxExportRnnStatesOpt ? Y_h_ph : initial_h; |
3950 | |
3951 | // Create ONNX GRU. |
3952 | NodeValue Y, Y_h; |
3953 | G_->createOnnxGRU(opName, X, W, R, B, Y_h_init, Y, Y_h, hiddenSize, direction, |
3954 | activations, (bool)linearBeforeReset); |
3955 | |
3956 | // Save GRU output state. |
3957 | if (onnxExportRnnStatesOpt) { |
3958 | G_->createSave(opName + ".Y_h.save" , Y_h, Y_h_ph); |
3959 | } |
3960 | |
3961 | // Add node. |
3962 | const int numOutputs = op.output_size(); |
3963 | if (numOutputs == 1) { |
3964 | RETURN_IF_ERR(addNodeAsOutput(op, Y)); |
3965 | } else if (numOutputs == 2) { |
3966 | RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h})); |
3967 | } else { |
3968 | return MAKE_ERR(opErrMsg(op, strFormat("ONNX GRU should have minimum 1 and " |
3969 | "maximum 2 outputs, but found %d " , |
3970 | numOutputs))); |
3971 | } |
3972 | return Error::success(); |
3973 | } |
3974 | |
3975 | // Limitations: |
3976 | // - Activation clipping not supported. |
3977 | // - Variable sequence length not supported. |
/// Loads the ONNX LSTM operator into a Glow OnnxLSTM node and registers its
/// outputs (Y and, optionally, Y_h / Y_c) for the ONNX node.
Error ONNXModelLoader::loadLSTM(const ONNX_NAMESPACE::NodeProto &op,
                                ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);

  // ------------------------- Attributes -------------------------------------
  // Get direction (Optional)(Default:forward).
  Function::RnnDirection direction;
  ASSIGN_VALUE_OR_RETURN_ERR(direction, getRnnDirection(op, dict));
  dim_t numDirections =
      (direction == Function::RnnDirection::Bidirectional) ? 2 : 1;

  // Get activations as lambdas (Optional)(Default:f=Sigmoid, g=Tanh, h=Tanh).
  // A bidirectional LSTM needs one activation triple per direction, hence six
  // entries in that case and three otherwise.
  std::vector<Function::RnnActivation> activations;
  if (direction == Function::RnnDirection::Bidirectional) {
    activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_),
                   RnnActivationTanh(*G_), RnnActivationSigmoid(*G_),
                   RnnActivationTanh(*G_), RnnActivationTanh(*G_)};
  } else {
    activations = {RnnActivationSigmoid(*G_), RnnActivationTanh(*G_),
                   RnnActivationTanh(*G_)};
  }
  // Overwrite the defaults with any explicitly given activations.
  RETURN_IF_ERR(getRnnActivations(op, dict, G_, activations));

  // Activation clipping not supported (Optional)(Default: 0 for no clipping).
  RETURN_ERR_IF_NOT(!dict.count("clip"),
                    opErrMsg(op, "ONNX LSTM 'clip' attribute not supported!"));

  // Get hidden size (Required).
  dim_t hiddenSize;
  RETURN_ERR_IF_NOT(
      dict.count("hidden_size"),
      opErrMsg(op, "ONNX LSTM 'hidden_size' attribute is required!"));
  ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size")));

  // Get input forget (Optional)(Default:0). Forwarded verbatim to
  // createOnnxLSTM as the `inputForget` flag.
  int inputForget = 0;
  if (dict.count("input_forget") && dict.at("input_forget")->has_i()) {
    inputForget = dict.at("input_forget")->i();
  }

  // --------------------------- Inputs ---------------------------------------
  const int numInputs = op.input_size();
  RETURN_ERR_IF_NOT(
      (3 <= numInputs) && (numInputs <= 8),
      opErrMsg(op, strFormat("ONNX LSTM should have minimum 3 and maximum 8 "
                             "inputs, but found %d ",
                             numInputs)));

  // Input0: X (Required).
  NodeValue X;
  ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));

  // Input1: W (Required).
  NodeValue W;
  ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));

  // Input2: R (Required).
  NodeValue R;
  ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2)));

  // Input3: B (Optional). A default-initialized (null) NodeValue marks the
  // optional input as absent — presumably handled by createOnnxLSTM; confirm.
  NodeValue B = nullptr;
  if (numInputs > 3 && !op.input(3).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3)));
  }

  // Input4: sequence_lens (Optional). Variable sequence length is not
  // supported; the sequence length is taken from X's first dimension.
  if (numInputs > 4 && !op.input(4).empty()) {
    LOG(WARNING) << "sequence_lens ignored, will be inferred from shape of "
                    "ONNX LSTM input.";
  }

  // Input5: initial_h (Optional).
  NodeValue initial_h = nullptr;
  if (numInputs > 5 && !op.input(5).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5)));
  }

  // Input6: initial_c (Optional).
  NodeValue initial_c = nullptr;
  if (numInputs > 6 && !op.input(6).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(initial_c, getNodeValueByName(op.input(6)));
  }

  // Input7: P (Optional peephole weights).
  NodeValue P = nullptr;
  if (numInputs > 7 && !op.input(7).empty()) {
    ASSIGN_VALUE_OR_RETURN_ERR(P, getNodeValueByName(op.input(7)));
  }

  // -------------------------- Outputs ---------------------------------------
  // We allow creating placeholders for the LSTM state variables (Y_h and Y_c)
  // for the following reasons:
  // - expose the LSTM state in the graph interface for accessibility (set
  //   desired state, reset state, watch the state being updated automatically).
  // - since the LSTM cells are unrolled (no graph loop primitive available
  //   at this point), the optimal way to use the LSTM within a model would be
  //   to have it defined with only 1 time step and have the loop in the top
  //   of the application while the LSTM state will be automatically updated
  //   from one iteration (time step) to the next through the placeholders.

  // Derived parameters. X layout is assumed [seq_length, batch_size,
  // input_size] (batch is dims()[1]) — matches ONNX default layout.
  RETURN_ERR_IF_NOT(
      X.dims().size() == 3,
      opErrMsg(op, "ONNX LSTM input 'X' should have 3 dimensions!"));
  dim_t batchSize = X.dims()[1];

  // Create Y_h (hidden state) output placeholder, registered as a model input
  // so the application can set/reset the state between iterations.
  Placeholder *Y_h_ph = nullptr;
  if (onnxExportRnnStatesOpt) {
    TypeRef Htype = mod_.uniqueTypeWithNewShape(
        X.getType(), {numDirections, batchSize, hiddenSize});
    std::string Hname = opName + ".Y_h";
    ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph,
                               createAndRegisterPlaceholder(Hname, Htype));
    inputVarsByName_.try_emplace(Hname, Y_h_ph);
  }

  // Create Y_c (cell state) output placeholder, likewise registered.
  Placeholder *Y_c_ph = nullptr;
  if (onnxExportRnnStatesOpt) {
    TypeRef Ctype = mod_.uniqueTypeWithNewShape(
        X.getType(), {numDirections, batchSize, hiddenSize});
    std::string Cname = opName + ".Y_c";
    ASSIGN_VALUE_OR_RETURN_ERR(Y_c_ph,
                               createAndRegisterPlaceholder(Cname, Ctype));
    inputVarsByName_.try_emplace(Cname, Y_c_ph);
  }

  // Set LSTM input states: either the state placeholders (when exporting RNN
  // states) or the model-provided initial values (possibly null).
  NodeValue Y_h_init = onnxExportRnnStatesOpt ? Y_h_ph : initial_h;
  NodeValue Y_c_init = onnxExportRnnStatesOpt ? Y_c_ph : initial_c;

  // Create ONNX LSTM.
  NodeValue Y, Y_h, Y_c;
  G_->createOnnxLSTM(opName, X, W, R, B, Y_h_init, Y_c_init, P, Y, Y_h, Y_c,
                     hiddenSize, direction, activations, (bool)inputForget);

  // Save LSTM output states back into the placeholders.
  if (onnxExportRnnStatesOpt) {
    G_->createSave(opName + ".Y_h.save", Y_h, Y_h_ph);
    G_->createSave(opName + ".Y_c.save", Y_c, Y_c_ph);
  }

  // Add node. Map only as many Glow outputs as the ONNX node declares.
  const int numOutputs = op.output_size();
  if (numOutputs == 1) {
    RETURN_IF_ERR(addNodeAsOutput(op, Y));
  } else if (numOutputs == 2) {
    RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h}));
  } else if (numOutputs == 3) {
    RETURN_IF_ERR(assignNodeOutputs(op, {Y, Y_h, Y_c}));
  } else {
    return MAKE_ERR(
        opErrMsg(op, strFormat("ONNX LSTM should have minimum 1 and "
                               "maximum 3 outputs, but found %d ",
                               numOutputs)));
  }
  return Error::success();
}
4139 | |
4140 | Error ONNXModelLoader::loadClip(const ONNX_NAMESPACE::NodeProto &op, |
4141 | const ArgumentDictionaryTy &dict) { |
4142 | NodeValue in; |
4143 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4144 | |
4145 | float cmin = std::numeric_limits<float>::lowest(); |
4146 | if (opsetVersion_ > 10 && op.input_size() > 1 && !op.input(1).empty()) { |
4147 | // min value is optional and might not be supplied. |
4148 | Constant *minC = getConstantByNameOrNull(op.input(1)); |
4149 | RETURN_ERR_IF_NOT(minC, "Expect constant for min value in Clip operator." ); |
4150 | cmin = minC->getPayload().getHandle().raw(0); |
4151 | } else if (dict.count("min" )) { |
4152 | ASSIGN_VALUE_OR_RETURN_ERR(cmin, loadFloat(dict.find("min" )->second)); |
4153 | } |
4154 | |
4155 | // Windows headers define `max` macro, so have to wrap the function name in |
4156 | // parenthesis to avoid compilation error. |
4157 | float cmax = (std::numeric_limits<float>::max)(); |
4158 | if (opsetVersion_ > 10 && op.input_size() > 2 && !op.input(2).empty()) { |
4159 | // max value is optional and might not be supplied. |
4160 | Constant *maxC = getConstantByNameOrNull(op.input(2)); |
4161 | RETURN_ERR_IF_NOT(maxC, "Expect constant for max value in Clip operator." ); |
4162 | cmax = maxC->getPayload().getHandle().raw(0); |
4163 | } else if (dict.count("max" )) { |
4164 | ASSIGN_VALUE_OR_RETURN_ERR(cmax, loadFloat(dict.find("max" )->second)); |
4165 | } |
4166 | |
4167 | auto *node = G_->createClip(loadOperatorName(op), in, cmin, cmax); |
4168 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
4169 | return Error::success(); |
4170 | } |
4171 | |
4172 | Error ONNXModelLoader::loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op, |
4173 | ArgumentDictionaryTy &dict) { |
4174 | NodeValue LHS; |
4175 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0))); |
4176 | NodeValue RHS; |
4177 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1))); |
4178 | |
4179 | Node *N = G_->createNodeWithBroadcast<CmpEQNode>(loadOperatorName(op), |
4180 | /* axis */ -1, LHS, RHS); |
4181 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4182 | return Error::success(); |
4183 | } |
4184 | |
4185 | Error ONNXModelLoader::loadCmpLTE(const ONNX_NAMESPACE::NodeProto &op, |
4186 | ArgumentDictionaryTy &dict) { |
4187 | NodeValue LHS; |
4188 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0))); |
4189 | NodeValue RHS; |
4190 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1))); |
4191 | |
4192 | Node *N = G_->createNodeWithBroadcast<CmpLTENode>(loadOperatorName(op), |
4193 | /* axis */ -1, LHS, RHS); |
4194 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4195 | return Error::success(); |
4196 | } |
4197 | |
4198 | /// Takes a list of NodeValues \p inputs and broadcasts them to a common shape |
4199 | /// \p broadcastShape based on the maximum value along each dimension. |
4200 | static Error getShapeForBroadcast(llvm::ArrayRef<NodeValue> inputs, |
4201 | std::vector<dim_t> &broadcastShape) { |
4202 | std::vector<uint32_t> numDims; |
4203 | for (auto &N : inputs) { |
4204 | numDims.push_back(N.dims().size()); |
4205 | } |
4206 | const uint32_t outputNumDims = |
4207 | *std::max_element(numDims.begin(), numDims.end()); |
4208 | for (uint32_t i = 0; i < outputNumDims; i++) { |
4209 | std::vector<dim_t> dims; |
4210 | for (uint32_t j = 0; j < inputs.size(); j++) { |
4211 | auto vals = inputs[j].dims(); |
4212 | if (vals.size() > i) { |
4213 | dims.push_back(vals[vals.size() - 1 - i]); |
4214 | } |
4215 | } |
4216 | broadcastShape.insert(broadcastShape.begin(), |
4217 | *std::max_element(dims.begin(), dims.end())); |
4218 | } |
4219 | return Error::success(); |
4220 | } |
4221 | |
4222 | Error ONNXModelLoader::loadMean(const ONNX_NAMESPACE::NodeProto &op, |
4223 | ArgumentDictionaryTy &dict) { |
4224 | size_t numInputTensors = op.input_size(); |
4225 | if (numInputTensors == 1) { |
4226 | NodeValue in; |
4227 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4228 | RETURN_IF_ERR(addNodeAsOutput(op, in)); |
4229 | } else { |
4230 | const std::string &opName = loadOperatorName(op); |
4231 | llvm::SmallVector<NodeValue, 4> inputTensors; |
4232 | inputTensors.reserve(numInputTensors); |
4233 | for (unsigned i = 0; i < numInputTensors; i++) { |
4234 | NodeValue in; |
4235 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i))); |
4236 | inputTensors.push_back(in); |
4237 | } |
4238 | std::vector<dim_t> broadcastShape; |
4239 | RETURN_IF_ERR(getShapeForBroadcast(inputTensors, broadcastShape)); |
4240 | for (unsigned i = 0; i < numInputTensors; i++) { |
4241 | auto &in = inputTensors[i]; |
4242 | int axis = broadcastShape.size() - in.dims().size(); |
4243 | in = G_->createBroadcast(opName, in, broadcastShape, axis); |
4244 | in = G_->createExpandDims(opName, in, {0}); |
4245 | } |
4246 | ConcatNode *concat = G_->createConcat(opName, inputTensors, /* axis */ 0); |
4247 | Node *N = G_->createBatchedReduceMean(opName, concat, /* axis */ {0}); |
4248 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4249 | } |
4250 | return Error::success(); |
4251 | } |
4252 | |
4253 | Error ONNXModelLoader::loadSelect(const ONNX_NAMESPACE::NodeProto &op, |
4254 | ArgumentDictionaryTy &dict) { |
4255 | NodeValue Cond; |
4256 | ASSIGN_VALUE_OR_RETURN_ERR(Cond, getNodeValueByName(op.input(0))); |
4257 | NodeValue LHS; |
4258 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(1))); |
4259 | NodeValue RHS; |
4260 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(2))); |
4261 | |
4262 | std::vector<dim_t> shape; |
4263 | ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["shape" ])); |
4264 | |
4265 | auto outTy = mod_.uniqueType(LHS.getElementType(), shape); |
4266 | Node *N = G_->createSelect(loadOperatorName(op), outTy, Cond, LHS, RHS); |
4267 | |
4268 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4269 | return Error::success(); |
4270 | } |
4271 | |
4272 | Error ONNXModelLoader::loadNonZero(const ONNX_NAMESPACE::NodeProto &op, |
4273 | const ArgumentDictionaryTy &dict) { |
4274 | NodeValue input; |
4275 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
4276 | |
4277 | Constant *C = getConstantByNameOrNull(op.input(0)); |
4278 | RETURN_ERR_IF_NOT(C, |
4279 | opErrMsg(op, "NonZero Only constant shape is supported!" )); |
4280 | |
4281 | // output tensor. |
4282 | Tensor outT; |
4283 | |
4284 | // Fold NonZero operator. |
4285 | auto foldNonZero = [&C, &outT](auto dummy) -> Error { |
4286 | auto inH = C->getPayload().getHandle<decltype(dummy)>(); |
4287 | auto dims = C->dims(); |
4288 | |
4289 | // First pass over the input is used to find the number of non-zero elements |
4290 | // so we can create output tensor that will be filled in the 2nd pass. |
4291 | dim_t nonZeroCnt = 0; |
4292 | for (dim_t idx = 0, e = inH.size(); idx < e; idx++) { |
4293 | nonZeroCnt += (inH.raw(idx) != 0) ? 1 : 0; |
4294 | } |
4295 | |
4296 | // No need to support zero Tensor (empty output); we support constant input |
4297 | // only and such input likely means it's an invalid model or the part of |
4298 | // graph can be removed. |
4299 | RETURN_ERR_IF_NOT(nonZeroCnt > 0, |
4300 | "Non-Zero input with all zeroes is not supported." ); |
4301 | |
4302 | // Create output tensor. First dimension is the rank of input tensor, second |
4303 | // dimension is the number of non-zero elements. |
4304 | outT.reset(ElemKind::Int64ITy, {(dim_t)dims.size(), nonZeroCnt}); |
4305 | |
4306 | // Strides for each dimensions, needed to calculate NonZero output. |
4307 | std::vector<dim_t> strides; |
4308 | strides.resize(dims.size()); |
4309 | strides[dims.size() - 1] = 1; |
4310 | if (dims.size() > 1) { |
4311 | for (int i = dims.size() - 2; i >= 0; i--) { |
4312 | strides[i] = dims[i + 1] * strides[i + 1]; |
4313 | } |
4314 | } |
4315 | |
4316 | // Second pass over the input is used to fill the output tensor. For each |
4317 | // non-zero element we fill all the dimensions, at position determined |
4318 | // by the non-zero element's index when zero elements are ignored. |
4319 | auto outH = outT.getHandle<int64_t>(); |
4320 | for (dim_t idx = 0, pos = 0, e = inH.size(); idx < e; idx++) { |
4321 | if (inH.raw(idx) != 0) { |
4322 | for (dim_t dim = 0; dim < dims.size(); dim++) { |
4323 | outH.at({dim, pos}) = (idx / strides[dim]) % dims[dim]; |
4324 | } |
4325 | pos++; |
4326 | } |
4327 | } |
4328 | return Error::success(); |
4329 | }; |
4330 | |
4331 | std::string err; |
4332 | if (C->getElementType() == ElemKind::FloatTy) { |
4333 | RETURN_IF_ERR(foldNonZero((float)0)); |
4334 | } else if (C->getElementType() == ElemKind::Int64ITy) { |
4335 | RETURN_IF_ERR(foldNonZero((int64_t)0)); |
4336 | } else if (C->getElementType() == ElemKind::Int32ITy) { |
4337 | RETURN_IF_ERR(foldNonZero((int32_t)0)); |
4338 | } else { |
4339 | return MAKE_ERR( |
4340 | opErrMsg(op, "NonZero: Unsupported input type for NonZero operator." |
4341 | "(Supports Float, Int32 and Int64)" )); |
4342 | } |
4343 | |
4344 | Constant *outC = G_->getParent()->createConstant("nonZero" , std::move(outT)); |
4345 | RETURN_IF_ERR(addNodeAsOutput(op, outC)); |
4346 | return Error::success(); |
4347 | } |
4348 | |
4349 | Error ONNXModelLoader::loadQuantize(const ONNX_NAMESPACE::NodeProto &op, |
4350 | ArgumentDictionaryTy &dict) { |
4351 | NodeValue in; |
4352 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4353 | |
4354 | float scale; |
4355 | ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict.at("scale" ))); |
4356 | unsigned_t offset; |
4357 | ASSIGN_VALUE_OR_RETURN_ERR(offset, loadInt(dict.at("offset" ))); |
4358 | std::string elemKindStr; |
4359 | ASSIGN_VALUE_OR_RETURN_ERR(elemKindStr, loadStr(dict.at("elem_kind" ))); |
4360 | |
4361 | ElemKind elemKind = Type::getElementKindFromName(elemKindStr); |
4362 | |
4363 | auto outDims = in.getType()->dims(); |
4364 | auto outTy = mod_.uniqueType(elemKind, outDims, scale, offset); |
4365 | Node *N = G_->createQuantize(loadOperatorName(op), in, outTy); |
4366 | |
4367 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4368 | return Error::success(); |
4369 | } |
4370 | |
4371 | Error ONNXModelLoader::loadQuantizeLinear(const ONNX_NAMESPACE::NodeProto &op, |
4372 | ArgumentDictionaryTy &dict) { |
4373 | NodeValue in; |
4374 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4375 | |
4376 | // Onnx documentation expects input to be Float or Int32. |
4377 | if (!(in.getElementType() == ElemKind::FloatTy || |
4378 | in.getElementType() == ElemKind::Int32ITy)) { |
4379 | return MAKE_ERR( |
4380 | opErrMsg(op, "QuantizeLinear supports input to be Float or Int32." )); |
4381 | } |
4382 | |
4383 | // Only scale with constant is supported. |
4384 | Constant *scale; |
4385 | ASSIGN_VALUE_OR_RETURN_ERR(scale, getConstantByName(op.input(1))); |
4386 | |
4387 | // Glow supports only per layer scale. |
4388 | if (!((scale->getType()->dims().size() == 1 && |
4389 | scale->getType()->dims()[0] == 1) || |
4390 | scale->getType()->dims().size() == 0)) { |
4391 | return MAKE_ERR(opErrMsg(op, "QuantizeLinear: y_scale scalar value is only" |
4392 | " supported." )); |
4393 | } |
4394 | |
4395 | float scaleValue = scale->getPayload().getHandle<float>().raw(0); |
4396 | |
4397 | // Default values as per Onnx documentation. |
4398 | int32_t offsetValue = 0; |
4399 | auto type = ElemKind::UInt8QTy; |
4400 | |
4401 | // Check if we have a offset vector. |
4402 | if (op.input_size() > 2) { |
4403 | auto &offsetTensorName = op.input(2); |
4404 | // Only offset with constant is supported. |
4405 | Constant *offset = nullptr; |
4406 | // Load the serialized offset vector. |
4407 | ASSIGN_VALUE_OR_RETURN_ERR(offset, getConstantByName(offsetTensorName)); |
4408 | if (!((offset->getType()->dims().size() == 1 && |
4409 | offset->getType()->dims()[0] == 1) || |
4410 | offset->getType()->dims().size() == 0)) { |
4411 | return MAKE_ERR( |
4412 | opErrMsg(op, "QuantizeLinear: y_zero_point scalar value is only" |
4413 | " supported." )); |
4414 | } |
4415 | |
4416 | type = offset->getElementType(); |
4417 | // Only uint8 and int8 values are supported as per onnx. |
4418 | if (type == ElemKind::UInt8QTy) { |
4419 | offsetValue = static_cast<int32_t>( |
4420 | offset->getPayload().getHandle<uint8_t>().raw(0)); |
4421 | } else if (type == ElemKind::Int8QTy) { |
4422 | offsetValue = |
4423 | static_cast<int32_t>(offset->getPayload().getHandle<int8_t>().raw(0)); |
4424 | } else { |
4425 | // This condition is hit when there is onnx graph creation issue or |
4426 | // constant is not created correctly. |
4427 | return MAKE_ERR( |
4428 | opErrMsg(op, "QuantizeLinear: Supports only uint8 or int8 data in" |
4429 | " y_zero_point" )); |
4430 | } |
4431 | } |
4432 | |
4433 | auto outDims = in.getType()->dims(); |
4434 | auto outTy = mod_.uniqueType(type, outDims, scaleValue, offsetValue); |
4435 | Node *N = G_->createQuantize(loadOperatorName(op), in, outTy); |
4436 | |
4437 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4438 | return Error::success(); |
4439 | } |
4440 | |
4441 | Error ONNXModelLoader::loadConvertTo(const ONNX_NAMESPACE::NodeProto &op, |
4442 | ArgumentDictionaryTy &dict) { |
4443 | NodeValue in; |
4444 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4445 | |
4446 | const auto *attr = dict.at("shape" ); |
4447 | RETURN_ERR_IF_NOT( |
4448 | attr->has_t(), |
4449 | opErrMsg(op, "ConvertTo should have t() field as \"shape\"" )); |
4450 | const auto &t = attr->t(); |
4451 | std::vector<dim_t> shape; |
4452 | for (const auto d : t.dims()) { |
4453 | shape.push_back(d); |
4454 | } |
4455 | |
4456 | auto type = ElemKind::FloatTy; |
4457 | RETURN_IF_ERR(onnxTensorDataTypeToElemKind(t.data_type(), &type)); |
4458 | auto outTy = mod_.uniqueType(type, shape); |
4459 | Node *N = G_->createConvertTo(loadOperatorName(op), in, outTy); |
4460 | |
4461 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4462 | return Error::success(); |
4463 | } |
4464 | |
4465 | Error ONNXModelLoader::loadDequantize(const ONNX_NAMESPACE::NodeProto &op, |
4466 | ArgumentDictionaryTy &dict) { |
4467 | NodeValue in; |
4468 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4469 | |
4470 | Node *N = G_->createDequantize(loadOperatorName(op), in, ElemKind::FloatTy); |
4471 | |
4472 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4473 | return Error::success(); |
4474 | } |
4475 | |
4476 | Error ONNXModelLoader::loadRegression(const ONNX_NAMESPACE::NodeProto &op, |
4477 | ArgumentDictionaryTy &dict) { |
4478 | NodeValue in; |
4479 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4480 | NodeValue expected; |
4481 | ASSIGN_VALUE_OR_RETURN_ERR(expected, getNodeValueByName(op.input(1))); |
4482 | |
4483 | Node *N = G_->createRegression(loadOperatorName(op), in, expected); |
4484 | |
4485 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4486 | return Error::success(); |
4487 | } |
4488 | |
4489 | Error ONNXModelLoader::loadBatchedAdd(const ONNX_NAMESPACE::NodeProto &op, |
4490 | ArgumentDictionaryTy &dict) { |
4491 | NodeValue batch; |
4492 | ASSIGN_VALUE_OR_RETURN_ERR(batch, getNodeValueByName(op.input(0))); |
4493 | NodeValue sample; |
4494 | ASSIGN_VALUE_OR_RETURN_ERR(sample, getNodeValueByName(op.input(1))); |
4495 | |
4496 | Node *N = G_->createBatchedAdd(loadOperatorName(op), batch, sample); |
4497 | |
4498 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4499 | return Error::success(); |
4500 | } |
4501 | |
/// Loads the ONNX CumSum operator. Only axis == 0 is supported: the optional
/// axis input, when present, must be a Constant holding the value 0.
Error ONNXModelLoader::loadCumSum(const ONNX_NAMESPACE::NodeProto &op,
                                  ArgumentDictionaryTy &dict) {
  if (op.input_size() > 1) {
    // NOTE(review): if this lookup fails, the Expected's error is silently
    // discarded and the default axis (0) is used — confirm the unconsumed
    // Error is safe to drop under Glow's error-checking policy.
    Expected<NodeValue> axis = getNodeValueByName(op.input(1));
    if (axis) {
      if (auto *AC = llvm::dyn_cast<Constant>(axis->getNode())) {
        // The axis constant is expected as a single-element 1-D tensor
        // (despite the "0-D" wording in the error messages).
        RETURN_ERR_IF_NOT(AC->getPayload().dims().size() == 1,
                          opErrMsg(op, "CumSum axis must be 0-D"));
        RETURN_ERR_IF_NOT(AC->getPayload().dims()[0] == 1,
                          opErrMsg(op, "CumSum axis must be 0-D"));
        RETURN_ERR_IF_NOT(AC->getHandle<int32_t>().at(0) == 0,
                          opErrMsg(op, "CumSum only supports axis == 0"));
      } else {
        return MAKE_ERR(opErrMsg(op, "Axis must be Constant"));
      }

      // Axis default is 0, which is fine.
    }
  }

  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

  // Optional flags (both default to false per the ONNX spec's 0 defaults).
  bool exclusive = false;
  if (dict.count("exclusive")) {
    ASSIGN_VALUE_OR_RETURN_ERR(exclusive, loadInt(dict.at("exclusive")));
  }

  bool reverse = false;
  if (dict.count("reverse")) {
    ASSIGN_VALUE_OR_RETURN_ERR(reverse, loadInt(dict.at("reverse")));
  }

  // TODO: add axis/dim support
  Node *N =
      G_->createCumSum(loadOperatorName(op), input, 0, exclusive, reverse);
  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}
4540 | |
4541 | Error ONNXModelLoader::loadScatterAssign(const ONNX_NAMESPACE::NodeProto &op, |
4542 | ArgumentDictionaryTy &dict) { |
4543 | NodeValue data; |
4544 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
4545 | NodeValue indices; |
4546 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1))); |
4547 | NodeValue slices; |
4548 | ASSIGN_VALUE_OR_RETURN_ERR(slices, getNodeValueByName(op.input(2))); |
4549 | |
4550 | Node *N = G_->createScatterData(loadOperatorName(op), data, indices, slices); |
4551 | |
4552 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4553 | return Error::success(); |
4554 | } |
4555 | |
4556 | Error ONNXModelLoader::loadIntLookupTable(const ONNX_NAMESPACE::NodeProto &op, |
4557 | ArgumentDictionaryTy &dict) { |
4558 | NodeValue in; |
4559 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4560 | |
4561 | if (in.getType()->getElementType() == ElemKind::Int8QTy) { |
4562 | std::vector<int8_t> values; |
4563 | ASSIGN_VALUE_OR_RETURN_ERR(values, getShape<int8_t>(dict["values" ])); |
4564 | std::vector<dim_t> shape; |
4565 | ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["shape" ])); |
4566 | |
4567 | auto outTy = mod_.uniqueType(in.getElementType(), shape); |
4568 | Node *N = G_->createIntLookupTable<int8_t>(loadOperatorName(op), in, values, |
4569 | outTy); |
4570 | |
4571 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4572 | return Error::success(); |
4573 | } else if (in.getType()->getElementType() == ElemKind::Int16QTy) { |
4574 | std::vector<int16_t> values; |
4575 | ASSIGN_VALUE_OR_RETURN_ERR(values, getShape<int16_t>(dict["values" ])); |
4576 | std::vector<dim_t> shape; |
4577 | ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["shape" ])); |
4578 | |
4579 | auto outTy = mod_.uniqueType(in.getElementType(), shape); |
4580 | Node *N = G_->createIntLookupTable<int16_t>(loadOperatorName(op), in, |
4581 | values, outTy); |
4582 | |
4583 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4584 | return Error::success(); |
4585 | } else { |
4586 | return MAKE_ERR(strFormat("Lookup table type '%s' not supported!" , |
4587 | in.getType()->getElementName().str().c_str())); |
4588 | } |
4589 | } |
4590 | |
4591 | Error ONNXModelLoader::loadLengthsRangeFill(const ONNX_NAMESPACE::NodeProto &op, |
4592 | ArgumentDictionaryTy &dict) { |
4593 | NodeValue lengths; |
4594 | ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0))); |
4595 | unsigned_t size; |
4596 | ASSIGN_VALUE_OR_RETURN_ERR(size, loadInt(dict.at("size" ))); |
4597 | |
4598 | Node *N = G_->createLengthsRangeFill(loadOperatorName(op), lengths, size); |
4599 | |
4600 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4601 | return Error::success(); |
4602 | } |
4603 | |
4604 | Error ONNXModelLoader::loadRescaleQuantized(const ONNX_NAMESPACE::NodeProto &op, |
4605 | ArgumentDictionaryTy &dict) { |
4606 | NodeValue in; |
4607 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4608 | float scale; |
4609 | ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict.at("scale" ))); |
4610 | unsigned_t offset; |
4611 | ASSIGN_VALUE_OR_RETURN_ERR(offset, loadInt(dict.at("offset" ))); |
4612 | |
4613 | auto inTy = in.getType(); |
4614 | auto outTy = |
4615 | mod_.uniqueType(inTy->getElementType(), inTy->dims(), scale, offset); |
4616 | |
4617 | Node *N = G_->createRescaleQuantized(loadOperatorName(op), in, outTy); |
4618 | |
4619 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4620 | return Error::success(); |
4621 | } |
4622 | |
4623 | Error ONNXModelLoader::loadRowwiseQuantizedSparseLengthsWeightedSum( |
4624 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
4625 | Constant *data; |
4626 | ASSIGN_VALUE_OR_RETURN_ERR(data, getConstantByName(op.input(0))); |
4627 | Constant *scales; |
4628 | ASSIGN_VALUE_OR_RETURN_ERR(scales, getConstantByName(op.input(1))); |
4629 | Constant *offsets; |
4630 | ASSIGN_VALUE_OR_RETURN_ERR(offsets, getConstantByName(op.input(2))); |
4631 | NodeValue weights; |
4632 | ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(3))); |
4633 | NodeValue indices; |
4634 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(4))); |
4635 | NodeValue lengths; |
4636 | ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(5))); |
4637 | LengthsMode lengthsMode; |
4638 | ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict)); |
4639 | |
4640 | Node *N = G_->createRowwiseQuantizedSparseLengthsWeightedSum( |
4641 | loadOperatorName(op), data, scales, offsets, weights, indices, lengths, |
4642 | /* precision */ ElemKind::FloatTy, /* useFP16Accumulation */ false, |
4643 | lengthsMode); |
4644 | |
4645 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4646 | return Error::success(); |
4647 | } |
4648 | |
4649 | Error ONNXModelLoader::loadFusedRowwiseQuantizedSparseLengthsWeightedSum( |
4650 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
4651 | NodeValue data; |
4652 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
4653 | NodeValue weights; |
4654 | ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(1))); |
4655 | NodeValue indices; |
4656 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(2))); |
4657 | NodeValue lengths; |
4658 | ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(3))); |
4659 | LengthsMode lengthsMode; |
4660 | ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict)); |
4661 | |
4662 | Node *N = G_->createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
4663 | loadOperatorName(op), data, weights, indices, lengths, |
4664 | /* useFP16Accumulation */ false, lengthsMode); |
4665 | |
4666 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4667 | return Error::success(); |
4668 | } |
4669 | |
4670 | Error ONNXModelLoader::loadFusedRowwiseQuantizedSparseLengthsSum( |
4671 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
4672 | NodeValue data; |
4673 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
4674 | NodeValue indices; |
4675 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1))); |
4676 | NodeValue lengths; |
4677 | ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(2))); |
4678 | LengthsMode lengthsMode; |
4679 | ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict)); |
4680 | |
4681 | Storage *dataS = llvm::dyn_cast<Storage>(data); |
4682 | Node *N = G_->createFusedRowwiseQuantizedSparseLengthsSum( |
4683 | loadOperatorName(op), dataS, indices, lengths, |
4684 | /* useFP16Accumulation */ false, lengthsMode); |
4685 | |
4686 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4687 | return Error::success(); |
4688 | } |
4689 | |
4690 | Error ONNXModelLoader::loadFullyConnected(const ONNX_NAMESPACE::NodeProto &op, |
4691 | ArgumentDictionaryTy &dict) { |
4692 | NodeValue in; |
4693 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4694 | NodeValue w; |
4695 | ASSIGN_VALUE_OR_RETURN_ERR(w, getNodeValueByName(op.input(1))); |
4696 | NodeValue b; |
4697 | ASSIGN_VALUE_OR_RETURN_ERR(b, getNodeValueByName(op.input(2))); |
4698 | |
4699 | unsigned_t axis = 1; |
4700 | if (dict.count("axis" )) { |
4701 | ASSIGN_VALUE_OR_RETURN_ERR( |
4702 | axis, loadAxis<unsigned_t>(dict.at("axis" ), in.dims().size())); |
4703 | } |
4704 | |
4705 | Node *N = G_->createFullyConnected(loadOperatorName(op), in, w, b, axis); |
4706 | |
4707 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4708 | return Error::success(); |
4709 | } |
4710 | |
4711 | Error ONNXModelLoader::loadRowwiseQuantizedFullyConnected( |
4712 | const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict) { |
4713 | NodeValue input; |
4714 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
4715 | |
4716 | NodeValue weights; |
4717 | ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(1))); |
4718 | auto *weightsC = llvm::dyn_cast<Constant>(weights.getNode()); |
4719 | |
4720 | NodeValue scales; |
4721 | ASSIGN_VALUE_OR_RETURN_ERR(scales, getNodeValueByName(op.input(2))); |
4722 | auto *scalesC = llvm::dyn_cast<Constant>(scales.getNode()); |
4723 | |
4724 | NodeValue offsets; |
4725 | ASSIGN_VALUE_OR_RETURN_ERR(offsets, getNodeValueByName(op.input(3))); |
4726 | auto *offsetsC = llvm::dyn_cast<Constant>(offsets.getNode()); |
4727 | |
4728 | NodeValue bias; |
4729 | ASSIGN_VALUE_OR_RETURN_ERR(bias, getNodeValueByName(op.input(4))); |
4730 | auto *biasC = llvm::dyn_cast<Constant>(bias.getNode()); |
4731 | |
4732 | float outScale; |
4733 | ASSIGN_VALUE_OR_RETURN_ERR(outScale, loadFloat(dict.at("out_scale" ))); |
4734 | int32_t outOffset; |
4735 | ASSIGN_VALUE_OR_RETURN_ERR(outOffset, loadInt(dict.at("out_offset" ))); |
4736 | |
4737 | auto outTy = |
4738 | mod_.uniqueType(ElemKind::Int8QTy, {input.dims()[0], weights.dims()[0]}, |
4739 | outScale, outOffset); |
4740 | |
4741 | Node *N = G_->createRowwiseQuantizedFullyConnected( |
4742 | loadOperatorName(op) + ".rowwise_quantized_fc" , input, weightsC, scalesC, |
4743 | offsetsC, biasC, outTy); |
4744 | |
4745 | return addNodeAsOutput(op, N); |
4746 | } |
4747 | |
Error ONNXModelLoader::loadNonMaxSuppression(
    const ONNX_NAMESPACE::NodeProto &op, ArgumentDictionaryTy &dict,
    bool isV4) {
  // Inputs 0/1: candidate boxes and their scores.
  NodeValue boxesNV;
  ASSIGN_VALUE_OR_RETURN_ERR(boxesNV, getNodeValueByName(op.input(0)));
  NodeValue scoresNV;
  ASSIGN_VALUE_OR_RETURN_ERR(scoresNV, getNodeValueByName(op.input(1)));
  // Optional input 2: maximum number of boxes kept per class. It must be a
  // compile-time constant tensor (int32 or int64) since the value is baked
  // into the created node.
  unsigned maxOutputBoxesPerClass = 0;
  Constant *maxOutputBoxesPerClassC = nullptr;
  if (op.input_size() > 2 && !op.input(2).empty()) {
    maxOutputBoxesPerClassC = getConstantByNameOrNull(op.input(2));
    RETURN_ERR_IF_NOT(maxOutputBoxesPerClassC,
                      "NMS: maxOutputBoxesPerClass is not a constant tensor." );
    if (maxOutputBoxesPerClassC->getPayload().getElementType() ==
        ElemKind::Int64ITy) {
      maxOutputBoxesPerClass =
          maxOutputBoxesPerClassC->getPayload().getHandle<int64_t>().raw(0);
    } else if (maxOutputBoxesPerClassC->getPayload().getElementType() ==
               ElemKind::Int32ITy) {
      maxOutputBoxesPerClass =
          maxOutputBoxesPerClassC->getPayload().getHandle<int32_t>().raw(0);
    } else {
      return MAKE_ERR("NMS: Unsupported type for maxoutputboxesperclass." );
    }
  }
  // Optional input 3: IOU threshold, read from a scalar float constant.
  float iouThreshold = 0.0f;
  Constant *iouThresholdC = nullptr;
  if (op.input_size() > 3 && !op.input(3).empty()) {
    iouThresholdC = getConstantByNameOrNull(op.input(3));
    RETURN_ERR_IF_NOT(iouThresholdC,
                      "NMS: iouThreshold is not a constant tensor." );
    iouThreshold = iouThresholdC->getPayload().getHandle<float>().raw(0);
  }
  // Optional input 4: score threshold, read from a scalar float constant.
  float scoreThreshold = 0.0f;
  Constant *scoreThresholdC = nullptr;
  if (op.input_size() > 4 && !op.input(4).empty()) {
    scoreThresholdC = getConstantByNameOrNull(op.input(4));
    RETURN_ERR_IF_NOT(scoreThresholdC,
                      "NMS: scoreThreshold is not a constant tensor." );
    scoreThreshold = scoreThresholdC->getPayload().getHandle<float>().raw(0);
  }

  // Defaults to 0 which is the same representation as TF.
  unsigned centerPointBox = 0;
  if (dict.count("center_point_box" )) {
    ASSIGN_VALUE_OR_RETURN_ERR(centerPointBox,
                               loadInt(dict.at("center_point_box" )));
  }

  // V4 (TF-style) only: whether the outputs are padded to the maximum size.
  // Only the padded mode is accepted below.
  int32_t padToMaxOutputSize = 0;
  if (isV4) {
    if (dict.count("pad_to_max_output_size" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(padToMaxOutputSize,
                                 loadInt(dict.at("pad_to_max_output_size" )));
    }

    // Does it make sense within GLOW context to have no padding? Since Size has
    // to be compile time constant.
    RETURN_ERR_IF_NOT(
        padToMaxOutputSize == 1,
        opErrMsg(op, "NonMaxSuppressionV4 does not support non-padding mode." ));
  }

  // Create Node.
  std::string opName = loadOperatorName(op);
  Node *N = nullptr;

  if (isV4) {
    N = G_->createNonMaxSuppressionV4(opName, boxesNV, scoresNV, centerPointBox,
                                      maxOutputBoxesPerClass, iouThreshold,
                                      scoreThreshold);
  } else {
    N = G_->createNonMaxSuppressionONNX(opName, boxesNV, scoresNV,
                                        centerPointBox, maxOutputBoxesPerClass,
                                        iouThreshold, scoreThreshold);
  }
  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}
4827 | |
4828 | Error ONNXModelLoader::loadSplat(const ONNX_NAMESPACE::NodeProto &op, |
4829 | ArgumentDictionaryTy &dict) { |
4830 | return loadConstantOfShape(op, dict, true /* isSplat */); |
4831 | } |
4832 | |
4833 | Error ONNXModelLoader::loadInsertTensor(const ONNX_NAMESPACE::NodeProto &op, |
4834 | ArgumentDictionaryTy &dict) { |
4835 | NodeValue big; |
4836 | ASSIGN_VALUE_OR_RETURN_ERR(big, getNodeValueByName(op.input(0))); |
4837 | NodeValue small; |
4838 | ASSIGN_VALUE_OR_RETURN_ERR(small, getNodeValueByName(op.input(1))); |
4839 | |
4840 | std::vector<dim_t> start; |
4841 | ASSIGN_VALUE_OR_RETURN_ERR(start, getShape<dim_t>(dict["start" ])); |
4842 | |
4843 | unsigned_t count = 1; |
4844 | if (dict.count("count" )) { |
4845 | ASSIGN_VALUE_OR_RETURN_ERR(count, loadInt(dict.at("count" ))); |
4846 | } |
4847 | |
4848 | unsigned_t axis = 0; |
4849 | if (dict.count("axis" )) { |
4850 | ASSIGN_VALUE_OR_RETURN_ERR( |
4851 | axis, loadAxis<unsigned_t>(dict.at("axis" ), big.dims().size())); |
4852 | } |
4853 | |
4854 | Node *N = G_->createInsertTensor(loadOperatorName(op), big, small, start, |
4855 | count, axis); |
4856 | |
4857 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4858 | return Error::success(); |
4859 | } |
4860 | |
4861 | Error ONNXModelLoader::loadIdentity(const ONNX_NAMESPACE::NodeProto &op, |
4862 | ArgumentDictionaryTy &dict) { |
4863 | NodeValue in; |
4864 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
4865 | RETURN_IF_ERR(addNodeAsOutput(op, in)); |
4866 | return Error::success(); |
4867 | } |
4868 | |
4869 | Error ONNXModelLoader::loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op, |
4870 | ArgumentDictionaryTy &dict) { |
4871 | const std::string &opName = loadOperatorName(op); |
4872 | |
4873 | NodeValue input; |
4874 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
4875 | |
4876 | std::vector<unsigned_t> outputShape; |
4877 | ASSIGN_VALUE_OR_RETURN_ERR(outputShape, |
4878 | getShape<unsigned_t>(dict["output_size" ])); |
4879 | |
4880 | ShapeNHWC idim(input.dims()); |
4881 | |
4882 | auto outTy = mod_.uniqueTypeWithNewShape( |
4883 | input.getType(), {idim.n, outputShape[0], outputShape[1], idim.c}); |
4884 | |
4885 | Node *N = G_->createAdaptiveAvgPool(opName, input, outTy); |
4886 | |
4887 | return addNodeAsOutput(op, N); |
4888 | } |
4889 | |
4890 | Error ONNXModelLoader::loadFlip(const ONNX_NAMESPACE::NodeProto &op, |
4891 | ArgumentDictionaryTy &dict) { |
4892 | NodeValue input; |
4893 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
4894 | |
4895 | unsigned_t axis = 0; |
4896 | if (dict.count("axis" )) { |
4897 | ASSIGN_VALUE_OR_RETURN_ERR( |
4898 | axis, loadAxis<unsigned_t>(dict.at("axis" ), input.dims().size())); |
4899 | } |
4900 | |
4901 | Node *N = G_->createFlip(loadOperatorName(op) + ".flip" , input, axis); |
4902 | |
4903 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4904 | return Error::success(); |
4905 | } |
4906 | |
4907 | Error ONNXModelLoader::loadAudioSpectrogram(const ONNX_NAMESPACE::NodeProto &op, |
4908 | ArgumentDictionaryTy &dict) { |
4909 | NodeValue input; |
4910 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
4911 | |
4912 | // Get window size (Required). |
4913 | int64_t windowSize; |
4914 | RETURN_ERR_IF_NOT( |
4915 | dict.count("window_size" ), |
4916 | "ONNX AudioSpectrogram 'window_size' attribute is required!" ); |
4917 | ASSIGN_VALUE_OR_RETURN_ERR(windowSize, loadInt(dict.at("window_size" ))); |
4918 | |
4919 | // Get window stride (Required). |
4920 | int64_t windowStride; |
4921 | RETURN_ERR_IF_NOT(dict.count("stride" ), |
4922 | "ONNX AudioSpectrogram 'stride' attribute is required!" ); |
4923 | ASSIGN_VALUE_OR_RETURN_ERR(windowStride, loadInt(dict.at("stride" ))); |
4924 | |
4925 | // Get magnitude squared flag (Optional)(Default: 1). |
4926 | int magnitudeSquared = 1; |
4927 | if (dict.count("magnitude_squared" ) && |
4928 | dict.at("magnitude_squared" )->has_i()) { |
4929 | magnitudeSquared = dict.at("magnitude_squared" )->i(); |
4930 | } |
4931 | |
4932 | Node *N = G_->createAudioSpectrogram(loadOperatorName(op), input, windowSize, |
4933 | windowStride, (bool)magnitudeSquared); |
4934 | |
4935 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
4936 | return Error::success(); |
4937 | } |
4938 | |
4939 | Error ONNXModelLoader::loadROIAlign(const ONNX_NAMESPACE::NodeProto &op, |
4940 | ArgumentDictionaryTy &dict) { |
4941 | NodeValue featureMap; |
4942 | ASSIGN_VALUE_OR_RETURN_ERR(featureMap, getNodeValueByName(op.input(0))); |
4943 | NodeValue boxes; |
4944 | ASSIGN_VALUE_OR_RETURN_ERR(boxes, getNodeValueByName(op.input(1))); |
4945 | NodeValue batchIndices; |
4946 | ASSIGN_VALUE_OR_RETURN_ERR(batchIndices, getNodeValueByName(op.input(2))); |
4947 | |
4948 | PoolingMode mode = PoolingMode::AVG; |
4949 | if (dict.count("mode" )) { |
4950 | std::string modeStr; |
4951 | ASSIGN_VALUE_OR_RETURN_ERR(modeStr, loadStr(dict.at("mode" ))); |
4952 | if (modeStr == "avg" ) { |
4953 | mode = PoolingMode::AVG; |
4954 | } else if (modeStr == "max" ) { |
4955 | mode = PoolingMode::MAX; |
4956 | } else { |
4957 | return MAKE_ERR(strFormat("Invalid PoolingMode: %s" , modeStr.c_str())); |
4958 | } |
4959 | } |
4960 | |
4961 | bool rotated = false; |
4962 | if (dict.count("rotated" )) { |
4963 | ASSIGN_VALUE_OR_RETURN_ERR(rotated, loadInt(dict.at("rotated" ))); |
4964 | } |
4965 | |
4966 | bool aligned = false; |
4967 | if (dict.count("aligned" )) { |
4968 | ASSIGN_VALUE_OR_RETURN_ERR(aligned, loadInt(dict.at("aligned" ))); |
4969 | } |
4970 | |
4971 | uint32_t outputHeight = 1; |
4972 | if (dict.count("output_height" )) { |
4973 | ASSIGN_VALUE_OR_RETURN_ERR(outputHeight, loadInt(dict.at("output_height" ))); |
4974 | } |
4975 | |
4976 | uint32_t outputWidth = 1; |
4977 | if (dict.count("output_width" )) { |
4978 | ASSIGN_VALUE_OR_RETURN_ERR(outputWidth, loadInt(dict.at("output_width" ))); |
4979 | } |
4980 | |
4981 | uint32_t samplingRatio = 0; |
4982 | if (dict.count("sampling_ratio" )) { |
4983 | ASSIGN_VALUE_OR_RETURN_ERR(samplingRatio, |
4984 | loadInt(dict.at("sampling_ratio" ))); |
4985 | } |
4986 | |
4987 | float spatialScale = 1.0; |
4988 | if (dict.count("spatial_scale" )) { |
4989 | ASSIGN_VALUE_OR_RETURN_ERR(spatialScale, |
4990 | loadFloat(dict.at("spatial_scale" ))); |
4991 | } |
4992 | |
4993 | const std::string &opName = loadOperatorName(op); |
4994 | featureMap = G_->createTranspose(opName, featureMap, NCHW2NHWC); |
4995 | Node *N = G_->createROIAlign(opName, featureMap, boxes, batchIndices, |
4996 | outputHeight, outputWidth, samplingRatio, |
4997 | spatialScale, aligned, rotated, mode); |
4998 | N = G_->createTranspose(opName, N, NHWC2NCHW); |
4999 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
5000 | return Error::success(); |
5001 | } |
5002 | |
5003 | Error ONNXModelLoader::loadMFCC(const ONNX_NAMESPACE::NodeProto &op, |
5004 | ArgumentDictionaryTy &dict) { |
5005 | NodeValue spectrogram; |
5006 | ASSIGN_VALUE_OR_RETURN_ERR(spectrogram, getNodeValueByName(op.input(0))); |
5007 | |
5008 | // Get sample rate [Hz] (Required). |
5009 | float sampleRate; |
5010 | RETURN_ERR_IF_NOT(dict.count("sample_rate" ), |
5011 | "ONNX MFCC 'sample_rate' attribute is required!" ); |
5012 | ASSIGN_VALUE_OR_RETURN_ERR(sampleRate, loadFloat(dict.at("sample_rate" ))); |
5013 | |
5014 | // Get lower frequency [Hz] (Required). |
5015 | float lowerFrequency; |
5016 | RETURN_ERR_IF_NOT(dict.count("lower_frequency_limit" ), |
5017 | "ONNX MFCC 'lower_frequency_limit' attribute is required!" ); |
5018 | ASSIGN_VALUE_OR_RETURN_ERR(lowerFrequency, |
5019 | loadFloat(dict.at("lower_frequency_limit" ))); |
5020 | |
5021 | // Get upper frequency [Hz] (Required). |
5022 | float upperFrequency; |
5023 | RETURN_ERR_IF_NOT(dict.count("upper_frequency_limit" ), |
5024 | "ONNX MFCC 'upper_frequency_limit' attribute is required!" ); |
5025 | ASSIGN_VALUE_OR_RETURN_ERR(upperFrequency, |
5026 | loadFloat(dict.at("upper_frequency_limit" ))); |
5027 | |
5028 | // Get filter bank count (Required). |
5029 | int64_t filterBankCount; |
5030 | RETURN_ERR_IF_NOT( |
5031 | dict.count("filterbank_channel_count" ), |
5032 | "ONNX MFCC 'filterbank_channel_count' attribute is required!" ); |
5033 | ASSIGN_VALUE_OR_RETURN_ERR(filterBankCount, |
5034 | loadInt(dict.at("filterbank_channel_count" ))); |
5035 | |
5036 | // Get number of coefficients (Required). |
5037 | int64_t numCoefficients; |
5038 | RETURN_ERR_IF_NOT(dict.count("dct_coefficient_count" ), |
5039 | "ONNX MFCC 'dct_coefficient_count' attribute is required!" ); |
5040 | ASSIGN_VALUE_OR_RETURN_ERR(numCoefficients, |
5041 | loadInt(dict.at("dct_coefficient_count" ))); |
5042 | |
5043 | Node *N = G_->createMFCC(loadOperatorName(op), spectrogram, sampleRate, |
5044 | lowerFrequency, upperFrequency, filterBankCount, |
5045 | numCoefficients); |
5046 | |
5047 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
5048 | return Error::success(); |
5049 | } |
5050 | |
5051 | Error ONNXModelLoader::loadAsin(const ONNX_NAMESPACE::NodeProto &op, |
5052 | const ArgumentDictionaryTy &dict) { |
5053 | const std::string &opName = loadOperatorName(op); |
5054 | NodeValue in; |
5055 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
5056 | auto outTy = mod_.uniqueType(*(in.getType())); |
5057 | Node *node = G_->createAsin(opName, outTy, in); |
5058 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
5059 | return Error::success(); |
5060 | } |
5061 | |
5062 | Error ONNXModelLoader::loadAcos(const ONNX_NAMESPACE::NodeProto &op, |
5063 | const ArgumentDictionaryTy &dict) { |
5064 | const std::string &opName = loadOperatorName(op); |
5065 | NodeValue in; |
5066 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
5067 | auto outTy = mod_.uniqueType(*(in.getType())); |
5068 | Node *node = G_->createAcos(opName, outTy, in); |
5069 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
5070 | return Error::success(); |
5071 | } |
5072 | |
5073 | Error ONNXModelLoader::loadAtan(const ONNX_NAMESPACE::NodeProto &op, |
5074 | const ArgumentDictionaryTy &dict) { |
5075 | const std::string &opName = loadOperatorName(op); |
5076 | NodeValue in; |
5077 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
5078 | auto outTy = mod_.uniqueType(*(in.getType())); |
5079 | Node *node = G_->createAtan(opName, outTy, in); |
5080 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
5081 | return Error::success(); |
5082 | } |
5083 | |
5084 | Error ONNXModelLoader::loadSign(const ONNX_NAMESPACE::NodeProto &op, |
5085 | const ArgumentDictionaryTy &dict) { |
5086 | const std::string &opName = loadOperatorName(op); |
5087 | NodeValue in; |
5088 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
5089 | Node *node = G_->createSign(opName, in); |
5090 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
5091 | return Error::success(); |
5092 | } |
5093 | |
Error ONNXModelLoader::loadSoftmax(const ONNX_NAMESPACE::NodeProto &op,
                                   const ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  RETURN_ERR_IF_NOT(in.dims().size() >= 2, "SoftMax input dims must be >= 2" );

  // Create a constant to store labels to be used in SoftMaxGradNode.
  auto selected =
      mod_.createConstant(ElemKind::Int64ITy, {in.dims()[0], 1}, "selected" );

  // Opset-13 path: softmax is applied along a single axis (default: last).
  // Implemented by transposing the requested axis to the innermost position,
  // flattening, applying softmax, then undoing the reshape and transpose.
  // NOTE(review): only opsetVersion_ == 13 takes this path — confirm whether
  // opsets newer than 13 should as well.
  if (opsetVersion_ == 13) {
    int axis = in.dims().size() - 1;
    if (dict.count("axis" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis" ), in.dims().size()));
    }
    RETURN_ERR_IF_NOT(in.dims().size() == 4, "SoftMax 13 input dims must be 4" );
    // Compute the shuffle layout based on axis input.
    // `shuffle` moves `axis` to the innermost position; `shuffleBack` is the
    // inverse permutation.
    std::vector<unsigned_t> shuffle;
    std::vector<unsigned_t> shuffleBack;
    switch (axis) {
    case 0:
      shuffle = {1u, 2u, 3u, 0u};
      shuffleBack = {3u, 0u, 1u, 2u};
      break;

    case 1:
      shuffle = {0u, 2u, 3u, 1u};
      shuffleBack = {0u, 3u, 1u, 2u};
      break;

    case 2:
      shuffle = {0u, 1u, 3u, 2u};
      shuffleBack = {0u, 1u, 3u, 2u};
      break;

    case 3:
      // Axis is already innermost; both permutations are the identity.
      shuffle = {0u, 1u, 2u, 3u};
      shuffleBack = {0u, 1u, 2u, 3u};
      break;

    default:
      return MAKE_ERR("SoftMax Axis must be <=3" );
      break;
    }
    auto *NH = G_->createTranspose(opName, in, shuffle);
    auto *FN = G_->createFlattenV1("reshapeInput" , NH, axis);
    auto *SM = G_->createSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    auto origInDims = NH->getResult().dims();
    auto *RN = G_->createReshape("reshapeOutput" , SM, origInDims);
    auto *NC = G_->createTranspose(opName, RN, shuffleBack);
    RETURN_IF_ERR(addNodeAsOutput(op, NC));
  } else {
    // ONNX allows shapes like <N x 10 x 1 x 1 >. Flatten the inputs to the
    // softmax function. This is similar to a bitcast operation.
    int axis = 1;
    if (dict.count("axis" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis" ), in.dims().size()));
    }
    auto *FN = G_->createFlatten("reshapeInput" , in, axis);
    auto *SM = G_->createSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    auto origInDims = in.getType()->dims();
    auto *RN = G_->createReshape("reshapeOutput" , SM, origInDims);
    RETURN_IF_ERR(addNodeAsOutput(op, RN));
  }
  return Error::success();
}
5169 | |
Error ONNXModelLoader::loadLogSoftmax(const ONNX_NAMESPACE::NodeProto &op,
                                      const ArgumentDictionaryTy &dict) {

  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  RETURN_ERR_IF_NOT(in.dims().size() >= 2,
                    "LogSoftMax input dims must be >= 2" );

  // Create a constant to store labels to be used in SoftMaxGradNode.
  auto selected =
      mod_.createConstant(ElemKind::Int64ITy, {in.dims()[0], 1}, "selected" );

  // Opset-13 path: log-softmax is applied along a single axis (default:
  // last). Implemented by transposing the requested axis to the innermost
  // position, flattening, applying log-softmax, then undoing the reshape and
  // transpose. This mirrors loadSoftmax above.
  // NOTE(review): only opsetVersion_ == 13 takes this path — confirm whether
  // opsets newer than 13 should as well.
  if (opsetVersion_ == 13) {
    int axis = in.dims().size() - 1;
    if (dict.count("axis" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis" ), in.dims().size()));
    }
    RETURN_ERR_IF_NOT(in.dims().size() == 4,
                      "LogSoftMax 13 input dims must be 4" );
    // Compute the shuffle layout based on axis input.
    // `shuffle` moves `axis` to the innermost position; `shuffleBack` is the
    // inverse permutation.
    std::vector<unsigned_t> shuffle;
    std::vector<unsigned_t> shuffleBack;
    switch (axis) {
    case 0:
      shuffle = {1u, 2u, 3u, 0u};
      shuffleBack = {3u, 0u, 1u, 2u};
      break;

    case 1:
      shuffle = {0u, 2u, 3u, 1u};
      shuffleBack = {0u, 3u, 1u, 2u};
      break;

    case 2:
      shuffle = {0u, 1u, 3u, 2u};
      shuffleBack = {0u, 1u, 3u, 2u};
      break;

    case 3:
      // Axis is already innermost; both permutations are the identity.
      shuffle = {0u, 1u, 2u, 3u};
      shuffleBack = {0u, 1u, 2u, 3u};
      break;

    default:
      return MAKE_ERR("LogSoftMax Axis must be <=3" );
      break;
    }
    auto *NH = G_->createTranspose(opName, in, shuffle);
    auto *FN = G_->createFlattenV1("reshapeInput" , NH, axis);
    auto *SM = G_->createLogSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    auto origInDims = NH->getResult().dims();
    auto *RN = G_->createReshape("reshapeOutput" , SM, origInDims);
    auto *NC = G_->createTranspose(opName, RN, shuffleBack);
    RETURN_IF_ERR(addNodeAsOutput(op, NC));
  } else {
    // ONNX allows shapes like <N x 10 x 1 x 1 >. Flatten the inputs to the
    // logsoftmax function. This is similar to a bitcast operation.
    int axis = 1;
    if (dict.count("axis" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<int>(dict.at("axis" ), in.dims().size()));
    }
    auto *FN = G_->createFlatten("reshapeInput" , in, axis);
    auto *SM = G_->createLogSoftMax(opName, FN, selected);

    // The output should have the same shape as the original input.
    auto origInDims = in.getType()->dims();
    auto *RN = G_->createReshape("reshapeOutput" , SM, origInDims);
    RETURN_IF_ERR(addNodeAsOutput(op, RN));
  }
  return Error::success();
}
5247 | |
5248 | Error ONNXModelLoader::loadScatterData(const ONNX_NAMESPACE::NodeProto &op, |
5249 | const ArgumentDictionaryTy &dict) { |
5250 | |
5251 | const std::string &opName = loadOperatorName(op); |
5252 | NodeValue data; |
5253 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
5254 | NodeValue indices; |
5255 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1))); |
5256 | NodeValue values; |
5257 | ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2))); |
5258 | |
5259 | RETURN_ERR_IF_NOT(indices.dims().size() == 2, |
5260 | opErrMsg(op, "Indices must be a 2D tensor!" )); |
5261 | RETURN_ERR_IF_NOT(indices.dims()[0] == values.dims()[0], |
5262 | opErrMsg(op, "Indices and values must have same lengths!" )); |
5263 | |
5264 | bool cumulative = false; |
5265 | if (dict.count("cumulative" )) { |
5266 | ASSIGN_VALUE_OR_RETURN_ERR(cumulative, loadInt(dict.at("cumulative" ))); |
5267 | } |
5268 | |
5269 | Node *node = G_->createScatterData(opName, data, indices, values, cumulative); |
5270 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
5271 | return Error::success(); |
5272 | } |
5273 | |
5274 | Error ONNXModelLoader::loadTopK(const ONNX_NAMESPACE::NodeProto &op, |
5275 | ArgumentDictionaryTy &dict) { |
5276 | const std::string &opName = loadOperatorName(op); |
5277 | NodeValue in; |
5278 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
5279 | RETURN_ERR_IF_NOT( |
5280 | op.input_size() <= 2, |
5281 | opErrMsg( |
5282 | op, |
5283 | strFormat( |
5284 | "TopK: Maximum number of inputs is 2, but found input size %d " , |
5285 | op.input_size()))); |
5286 | unsigned_t k = 0; |
5287 | if (op.input_size() > 1) { |
5288 | Constant *kConst = getConstantByNameOrNull(op.input(1)); |
5289 | RETURN_ERR_IF_NOT( |
5290 | kConst, opErrMsg(op, "TopK: Non-constant k is not supported by Glow." )); |
5291 | RETURN_ERR_IF_NOT( |
5292 | kConst->getElementType() == ElemKind::Int64ITy, |
5293 | opErrMsg(op, |
5294 | strFormat("TopK: k input must be of type Int64, but found " |
5295 | "input type '%s' " , |
5296 | kConst->getType()->getElementName().str().c_str()))); |
5297 | auto constH = kConst->getPayload().getHandle<int64_t>(); |
5298 | k = constH.at({0}); |
5299 | } else { |
5300 | ASSIGN_VALUE_OR_RETURN_ERR(k, loadInt(dict["k" ])); |
5301 | } |
5302 | |
5303 | int lastDim = in.dims().size() - 1; |
5304 | int axis = lastDim; |
5305 | if (dict.count("axis" )) { |
5306 | ASSIGN_VALUE_OR_RETURN_ERR(axis, |
5307 | loadAxis<int>(dict["axis" ], in.dims().size())); |
5308 | } |
5309 | |
5310 | RETURN_ERR_IF_NOT( |
5311 | axis == lastDim, |
5312 | opErrMsg( |
5313 | op, |
5314 | strFormat( |
5315 | "TopK: Currently only support axis %d being last dimension %d " , |
5316 | axis, lastDim))); |
5317 | |
5318 | auto *R = G_->createTopK(opName, in, k); |
5319 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
5320 | return Error::success(); |
5321 | } |
5322 | |
/// Loads an ONNX Loop operator by fully unrolling its body subgraph at load
/// time. Requires that the maximum trip count M (input 0) and the initial
/// condition (input 1) are either omitted (empty input name) or compile-time
/// Constants, and that the exit condition produced by each unrolled iteration
/// constant-folds; otherwise loading fails. Unrolling is capped by
/// \ref loopUnrollLimit. Final loop-carried values and concatenated
/// scan_outputs are assigned as the op's outputs.
Error ONNXModelLoader::loadLoop(const ONNX_NAMESPACE::NodeProto &op,
                                const ArgumentDictionaryTy &dict) {
  // Input 0 (M) is the optional maximum trip count. An empty input name means
  // it was omitted, so only the condition limits iteration.
  int64_t maxTripCount;
  bool ignoreMaxTripCount = op.input(0).empty();
  if (!ignoreMaxTripCount) {
    Constant *M = getConstantByNameOrNull(op.input(0));
    RETURN_ERR_IF_NOT(M, "Loop operator M input must be a constant." );
    RETURN_ERR_IF_NOT(M->getElementType() == ElemKind::Int64ITy,
                      "Loop operator M input must be int64." );
    maxTripCount = (M->getPayload().getHandle<int64_t>()).raw(0);

    // NOTE(review): the check accepts 0 although the message says "positive";
    // a zero trip count is handled below as a no-iteration loop. Also, %ld
    // assumes int64_t == long — presumably fine on the supported targets, but
    // PRId64 would be portable; confirm.
    RETURN_ERR_IF_NOT(
        maxTripCount >= 0,
        strFormat("Loop operator trip count (%ld) must be positive" ,
                  maxTripCount));
  }

  // Input 1 (cond) is the optional initial condition; empty means ignored.
  bool condOrig = false;
  bool ignoreCond = op.input(1).empty();
  if (!ignoreCond) {
    Constant *cond = getConstantByNameOrNull(op.input(1));
    RETURN_ERR_IF_NOT(cond, "Loop operator cond input must be a constant." );
    RETURN_ERR_IF_NOT(cond->getElementType() == ElemKind::BoolTy,
                      "Loop operator cond input must be bool." );
    condOrig = (cond->getPayload().getHandle<bool>()).raw(0);
  }

  RETURN_ERR_IF_NOT(dict.count("body" ), "Loop body not found." );
  auto body = dict.at("body" )->g();

  // 2 + N (i.e., maximum trip-count, cond, and N)
  const int numLoopInputs = op.input_size();
  // N + K (final N loop carried dependency values then K scan_outputs)
  const int numLoopOutputs = op.output_size();
  // 2 + N (i.e., iteration_num, condition, and N loop carried dependencies)
  const int numBodyInputs = body.input_size();
  // 1 + N + K (i.e., condition, N loop carried dependencies, and K
  // scan_outputs)
  const int numBodyOutputs = body.output_size();

  RETURN_ERR_IF_NOT(numLoopInputs >= 2 && numLoopInputs == numBodyInputs &&
                        numLoopOutputs == numBodyOutputs - 1 &&
                        numLoopOutputs >= 1,
                    "Mismatched inputs/outputs of Loop and subgraph." );

  // Derived from the arithmetic above: K scan outputs, N carried deps.
  const int numK = numBodyOutputs - numBodyInputs + 1;
  const int numN = numLoopInputs - 2;

  CHECK_GE(numN, 0) << "Invalid number of v_initial in Loop operator : "
                    << numN;
  CHECK_GE(numK, 0) << "Invalid number of scan_outputs in Loop operator : "
                    << numK;

  // Handle a loop with no iterations.
  if ((!ignoreMaxTripCount && maxTripCount == 0) ||
      (!ignoreCond && !condOrig)) {
    // No need to load the subgraph, just connect loop's N v_initial to
    // N v_final and empty Tensor for K scan_outputs.
    llvm::SmallVector<NodeValue, 4> outputs;
    outputs.reserve(numLoopOutputs);
    // Connect N v_initial to N v_final.
    for (int i = 0; i < numN; ++i) {
      NodeValue in;
      ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i + 2)));
      outputs.push_back(in);
    }
    // Connect empty Tensors for K scan_outputs. Body outputs are laid out as
    // [cond, N carried deps, K scan_outputs], so scan_output j lives at body
    // output index (1 + N) + j == (body.input_size() - 1) + j.
    for (int k = 0; k < numK; ++k) {
      int ki = (body.input_size() - 1) + k;
      auto scan_output = body.output(ki);
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(ty, getTensorType(scan_output));
      Tensor T(ty);
      T.zero();
      Constant *c = G_->getParent()->createConstant("empty" , std::move(T));
      outputs.push_back(G_->createExpandDims("unsqueeze_K" , c, {0}));
    }
    RETURN_IF_ERR(assignNodeOutputs(op, outputs));
    return Error::success();
  }

  // Now, there is at least one iteration.
  llvm::SmallVector<NodeValue, 4> inputs;
  inputs.reserve(numBodyInputs);

  // Collect names of N loop carried dependencies from Loop op. It will be used
  // to connect Loop's inputs with subgraph's inputs for the first iteration.
  llvm::SmallVector<llvm::StringRef, 4> namesOfLoopNs;
  namesOfLoopNs.reserve(numN);
  for (int i = 0; i < numN; ++i) {
    namesOfLoopNs.push_back(op.input(i + 2));
  }

  // Collect output node names for the next iteration. It will be used when
  // connecting subgraph's outputs with subgraph's input between iteration. Add
  // names just for N loop carried dependencies. No need to add names for K
  // scan_outputs because they are not input of subgraph.
  llvm::SmallVector<llvm::StringRef, 4> namesOfBodyOutputNs;
  namesOfBodyOutputNs.reserve(numN);
  for (int i = 0; i < numN; ++i) {
    namesOfBodyOutputNs.push_back(body.output(i + 1).name());
  }

  // Per-scan-output accumulators: one vector of per-iteration values for each
  // of the K scan_outputs, concatenated at the end.
  llvm::SmallVector<llvm::SmallVector<NodeValue, 2>, 2> scanOutputKs;
  scanOutputKs.reserve(numK);
  for (int k = 0; k < numK; ++k) {
    llvm::SmallVector<NodeValue, 2> scanOutputIter;
    scanOutputIter.reserve(loopUnrollLimit);
    scanOutputKs.push_back(scanOutputIter);
  }

  // Builds a placeholder condition Constant, used only when the Loop was
  // given no cond input at all (the body still needs a condition input).
  auto getDummyCond = [&]() -> Expected<NodeValue> {
    RETURN_ERR_IF_NOT(ignoreCond, "Unexpected empty name in cond in Loop" );
    Tensor dummyCondT(ElemKind::BoolTy, {1});
    dummyCondT.zero();
    Constant *dummyCondNode =
        G_->getParent()->createConstant("dumpCond" , std::move(dummyCondT));
    return dummyCondNode->getOutput();
  };

  // Builds the int64 scalar Constant fed as the body's iteration_num input.
  auto getIterationNumConst = [&](int64_t val) -> Constant * {
    Tensor T(ElemKind::Int64ITy, {1});
    T.getHandle<int64_t>() = {val};
    Constant *C = G_->getParent()->createConstant("const" , std::move(T));
    return C;
  };

  // Binds the body's inputs for one iteration: [iteration_num, cond, N values
  // named by \p outputNames] (Loop inputs on iteration 0, previous body
  // outputs afterwards).
  auto prepareNextIteration =
      [&](int64_t iterationNum, NodeValue condNode,
          llvm::ArrayRef<llvm::StringRef> outputNames) -> Error {
    inputs.clear();
    // Set iteration_num.
    inputs.push_back(getIterationNumConst(iterationNum));
    // Set condition.
    inputs.push_back(condNode);
    // Set N loop carried dependencies.
    for (auto oName : outputNames) {
      NodeValue in;
      ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(oName));
      inputs.push_back(in);
    }
    RETURN_IF_ERR(assignGraphInputs(body, inputs));
    return Error::success();
  };

  // Records this iteration's K scan_output values (unsqueezed along a new
  // leading axis so they can be concatenated across iterations later).
  auto accumulateScanOutputs = [&]() -> Error {
    // Accumulate scan-outputs values.
    for (int k = 0; k < numK; ++k) {
      int ki = (body.input_size() - 1) + k;
      auto scan_output = body.output(ki);
      NodeValue outK;
      ASSIGN_VALUE_OR_RETURN_ERR(outK, getNodeValueByName(scan_output.name()));
      Node *unsqueezedOutK = G_->createExpandDims("unsqueezed_K" , outK, {0});
      scanOutputKs[k].push_back(unsqueezedOutK);
    }
    return Error::success();
  };

  // Deserializes one copy of the body subgraph into the current Function.
  auto loadSubgraph = [&](ONNX_NAMESPACE::GraphProto &graphDef) -> Error {
    RETURN_IF_ERR(loadInitializers(graphDef));
    RETURN_IF_ERR(loadNetwork(graphDef, false));
    return Error::success();
  };

  // Decides whether iteration \p iterNum should be unrolled: false once the
  // trip count or the (constant-folded) body exit condition stops the loop;
  // an Error if the exit condition is not a Constant or the unroll limit is
  // exceeded.
  auto canUnrollNextIter = [&](int64_t iterNum) -> Expected<bool> {
    if (!ignoreMaxTripCount && iterNum >= maxTripCount) {
      return false;
    }
    Constant *condOut = getConstantByNameOrNull(body.output(0).name());
    RETURN_ERR_IF_NOT(condOut,
                      "Loop exit condition is unpredictable to be unrolled." );
    bool cond = (condOut->getPayload().getHandle<bool>()).raw(0) || ignoreCond;
    if (!cond) {
      return false;
    }
    RETURN_ERR_IF_NOT(
        iterNum < loopUnrollLimit,
        strFormat("Exceed the unroll limit (%u) while unrolling Loop operator." ,
                  loopUnrollLimit.getValue()));
    return cond;
  };

  // Unroll the first iteration by connecting Loop's inputs to subgraph's
  // inputs.
  int64_t iterNum = 0;
  NodeValue condNode;
  ASSIGN_VALUE_OR_RETURN_ERR(condNode, op.input(1).empty()
                                           ? getDummyCond()
                                           : getNodeValueByName(op.input(1)));
  RETURN_IF_ERR(prepareNextIteration(iterNum, condNode, namesOfLoopNs));
  RETURN_IF_ERR(loadSubgraph(body));
  RETURN_IF_ERR(accumulateScanOutputs());
  ++iterNum;
  auto condCheck = canUnrollNextIter(iterNum);

  RETURN_ERR_IF_NOT(condCheck, ERR_TO_STRING(condCheck.takeError()));

  // Unroll remaining iterations by connecting outputs of previous iteration
  // with inputs of current iteration.
  while (condCheck && condCheck.get()) {
    NodeValue condOutNode;
    ASSIGN_VALUE_OR_RETURN_ERR(condOutNode,
                               getNodeValueByName(body.output(0).name()));
    RETURN_IF_ERR(
        prepareNextIteration(iterNum, condOutNode, namesOfBodyOutputNs));
    RETURN_IF_ERR(loadSubgraph(body));
    RETURN_IF_ERR(accumulateScanOutputs());
    ++iterNum;
    condCheck = canUnrollNextIter(iterNum);
    RETURN_ERR_IF_NOT(condCheck, ERR_TO_STRING(condCheck.takeError()));
  }

  // Hook final subgraph outputs to loop outputs.
  llvm::SmallVector<NodeValue, 4> outputs;
  outputs.reserve(numLoopOutputs);
  // Set outputs for N loop carried dependency values.
  for (int i = 0; i < numN; ++i) {
    NodeValue bodyout;
    ASSIGN_VALUE_OR_RETURN_ERR(bodyout,
                               getNodeValueByName(body.output(i + 1).name()));
    outputs.push_back(bodyout);
  }
  // Set outputs for K scan_outputs.
  for (int k = 0; k < numK; ++k) {
    outputs.push_back(G_->createConcat("concat" , scanOutputKs[k], 0));
  }
  RETURN_IF_ERR(assignNodeOutputs(op, outputs));
  return Error::success();
}
5553 | |
5554 | Expected<TypeRef> |
5555 | ONNXModelLoader::loadTypeFromAttributes(unsigned resNo, |
5556 | ArgumentDictionaryTy &dict) { |
5557 | // Load ElemKind. |
5558 | std::string elemKindStr; |
5559 | ASSIGN_VALUE_OR_RETURN_ERR( |
5560 | elemKindStr, loadStr(dict[getTypeAttrID(resNo, elemKindSignifier)])); |
5561 | const ElemKind k = Type::getElementKindFromName(elemKindStr); |
5562 | |
5563 | // Load Shape. Note that we allow for empty shapes here because 0 dimensional |
5564 | // shapes are allowed (representing scalars). |
5565 | std::vector<dim_t> shape; |
5566 | ASSIGN_VALUE_OR_RETURN_ERR( |
5567 | shape, getShape<dim_t>(dict[getTypeAttrID(resNo, shapeSignifier)], |
5568 | /* allowEmptyShape */ true)); |
5569 | |
5570 | // Load strides. Note that we allow for empty strides here because 0 |
5571 | // dimensional shapes are allowed (representing scalars). |
5572 | std::vector<dim_t> strides; |
5573 | auto stridesKey = getTypeAttrID(resNo, stridesSignifier); |
5574 | if (dict.count(stridesKey)) { |
5575 | ASSIGN_VALUE_OR_RETURN_ERR(strides, |
5576 | getShape<dim_t>(dict[stridesKey], |
5577 | /* allowEmptyShape */ true)); |
5578 | } |
5579 | |
5580 | // Create and return uniqued non-quantized Type. |
5581 | if (!isQuantizedElemKind(k)) { |
5582 | return getTypeWithCustomStrides(mod_, mod_.uniqueType(k, shape), strides); |
5583 | } |
5584 | |
5585 | // Must be quantized kind, so get scale/offset and create and return uniqued |
5586 | // quantized Type. |
5587 | float scale; |
5588 | ASSIGN_VALUE_OR_RETURN_ERR( |
5589 | scale, loadFloat(dict[getTypeAttrID(resNo, qScaleSignifier)])); |
5590 | int32_t offset; |
5591 | ASSIGN_VALUE_OR_RETURN_ERR( |
5592 | offset, loadInt(dict[getTypeAttrID(resNo, qOffsetSignifier)])); |
5593 | |
5594 | // If we have a scale of dummyScale, then this must be a dummy pair of |
5595 | // scale/offset. Look up the actual scale/offset to use as previously loaded, |
5596 | // using the offset as the key to updatedTQPs_. Skip fused kinds because |
5597 | // scales are already dummies. |
5598 | if (replaceDummyTQPs_ && scale == dummyScale && |
5599 | !isFusedQuantizedElemKind(k)) { |
5600 | TensorQuantizationParams TQP; |
5601 | ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(offset)); |
5602 | scale = TQP.scale; |
5603 | offset = TQP.offset; |
5604 | } |
5605 | |
5606 | return getTypeWithCustomStrides( |
5607 | mod_, mod_.uniqueType(k, shape, scale, offset), strides); |
5608 | } |
5609 | |
/// Attempts to load \p op as a Glow-custom (auto-exported) node named
/// \p typeName using attributes in \p dict. Returns the created Node on
/// success, nullptr if no generated case matched \p typeName, or an Error if
/// a matching case failed while loading.
Expected<Node *>
ONNXModelLoader::tryLoadGlowCustomOp(llvm::StringRef typeName,
                                     const ONNX_NAMESPACE::NodeProto &op,
                                     ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);

  // Try all automatically generated import cases. The included file expands
  // to a series of `if (typeName == ...) return ...;` cases that use
  // `opName`, `op`, and `dict` from this scope.
  #include "glow/AutoGenNodesImport.h"

  // If we get here then no case handled the op, so return nullptr.
  return nullptr;
}
5622 | |
5623 | /// Load Node options for \p loadedNode from \p dict and set in \p nodeInfo. |
5624 | /// These are specified in the format "NodeOpt_BACKENDNAME_OPTIONNAME". |
5625 | static Error loadPerNodeOptions(const Node *loadedNode, |
5626 | BackendSpecificNodeInfo &nodeInfo, |
5627 | ArgumentDictionaryTy &dict) { |
5628 | // Look through all attributes in the dict for ones that have NodeOpt_ prefix. |
5629 | for (const auto &attrPair : dict) { |
5630 | // Split across the first '_' and check if it has the "NodeOpt" prefix. |
5631 | auto splitPair = llvm::StringRef(attrPair.first).split('_'); |
5632 | if (splitPair.first == attrPair.first && splitPair.first == "" ) { |
5633 | // No '_' found, so continue. |
5634 | continue; |
5635 | } |
5636 | if (splitPair.first != nodeOptSignifier) { |
5637 | // Prefix is not "NodeOpt_", so continue. |
5638 | continue; |
5639 | } |
5640 | |
5641 | // Must have a NodeOpt, so check it has strings and load them into nodeInfo. |
5642 | const ONNX_NAMESPACE::AttributeProto *attr = attrPair.second; |
5643 | RETURN_ERR_IF_NOT(attr->strings_size() > 0, |
5644 | strFormat("%s in %s has no strings" , |
5645 | attrPair.first.c_str(), |
5646 | loadedNode->getName().data())); |
5647 | std::vector<std::string> &attrVals = |
5648 | nodeInfo[loadedNode->getParent()][loadedNode][splitPair.second]; |
5649 | for (const std::string &s : attr->strings()) { |
5650 | attrVals.push_back(s); |
5651 | } |
5652 | } |
5653 | return Error::success(); |
5654 | } |
5655 | |
5656 | Error ONNXModelLoader::loadIf(const ONNX_NAMESPACE::NodeProto &op, |
5657 | const ArgumentDictionaryTy &dict) { |
5658 | Constant *condition = getConstantByNameOrNull(op.input(0)); |
5659 | RETURN_ERR_IF_NOT(condition, "Only constant condition is supported!" ); |
5660 | RETURN_ERR_IF_NOT(condition->getElementType() == ElemKind::BoolTy, |
5661 | "Condition must be boolean!" ); |
5662 | |
5663 | RETURN_ERR_IF_NOT(dict.count("then_branch" ), "Then branch not found!" ); |
5664 | RETURN_ERR_IF_NOT(dict.count("else_branch" ), "Else branch not found!" ); |
5665 | RETURN_ERR_IF_NOT(dict.at("then_branch" )->type() == |
5666 | ONNX_NAMESPACE::AttributeProto::GRAPH, |
5667 | "Only Subgraph branches are supported." ); |
5668 | RETURN_ERR_IF_NOT(dict.at("else_branch" )->type() == |
5669 | ONNX_NAMESPACE::AttributeProto::GRAPH, |
5670 | "Only Subgraph branches are supported." ); |
5671 | |
5672 | auto loadSubgraph = [&](ONNX_NAMESPACE::GraphProto &graphDef) -> Error { |
5673 | RETURN_IF_ERR(loadInitializers(graphDef)); |
5674 | RETURN_IF_ERR(loadNetwork(graphDef, false)); |
5675 | return Error::success(); |
5676 | }; |
5677 | |
5678 | if (condition->getPayload().getHandle<bool>().isZero()) { |
5679 | auto ifFalse = dict.at("else_branch" )->g(); |
5680 | RETURN_ERR_IF_NOT(ifFalse.output_size() == 1, |
5681 | "Only single output 'else' subgraph is supported." ); |
5682 | RETURN_IF_ERR(loadSubgraph(ifFalse)); |
5683 | NodeValue ifFalseVal; |
5684 | ASSIGN_VALUE_OR_RETURN_ERR(ifFalseVal, |
5685 | getNodeValueByName(ifFalse.output(0).name())); |
5686 | RETURN_IF_ERR(addNodeAsOutput(op, ifFalseVal)); |
5687 | } else { |
5688 | auto ifTrue = dict.at("then_branch" )->g(); |
5689 | RETURN_ERR_IF_NOT(ifTrue.output_size() == 1, |
5690 | "Only single output 'then' subgraph is supported." ); |
5691 | RETURN_IF_ERR(loadSubgraph(ifTrue)); |
5692 | NodeValue ifTrueVal; |
5693 | ASSIGN_VALUE_OR_RETURN_ERR(ifTrueVal, |
5694 | getNodeValueByName(ifTrue.output(0).name())); |
5695 | RETURN_IF_ERR(addNodeAsOutput(op, ifTrueVal)); |
5696 | } |
5697 | return Error::success(); |
5698 | } |
5699 | |
/// Top-level operator dispatch: loads one ONNX NodeProto \p op into the
/// current Function. Tries, in order: Glow-custom auto-generated ops (when
/// useGlowCustomOps_ is set), the shared CommonOperatorLoader cases, then the
/// explicit per-op loaders below. \returns an Error for unsupported op types.
Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
  ArgumentDictionaryTy dict = loadArgumentMap(op);
  const std::string &typeName = op.op_type();
  mod_.registerOriginalName(op.name());

  if (useGlowCustomOps_) {
    Node *loadedNode;
    ASSIGN_VALUE_OR_RETURN_ERR(loadedNode,
                               tryLoadGlowCustomOp(typeName, op, dict));
    if (loadedNode) {
      // Optionally attach backend-specific per-node options from NodeOpt_*
      // attributes.
      if (!perNodeOpts_) {
        return Error::success();
      }
      return loadPerNodeOptions(loadedNode, *perNodeOpts_, dict);
    }

    // These are handled earlier when loading initializers and inputs and so can
    // be safely ignored here.
    if (typeName == constFoldSubgraphNodeName ||
        typeName == staticPHDummyNodeName) {
      return Error::success();
    }

    // Identity is the only official ONNX op used with useGlowCustomOps. Let it
    // fall through to logic to handle below, otherwise return error.
    if (typeName != "Identity" ) {
      return MAKE_ERR("Failed to load operator " + typeName + " ." ,
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
    }
  }

  // Check if operator is supported in parent class, CommonOperatorLoader.
  bool tryLoadCommonOperatorResult;
  ASSIGN_VALUE_OR_RETURN_ERR(tryLoadCommonOperatorResult,
                             tryLoadCommonOperator(typeName, op, dict));
  if (tryLoadCommonOperatorResult) {
    return Error::success();
  }
  // Explicit per-op dispatch. Each case delegates to a dedicated loader.
  if (typeName == "Loop" ) {
    return loadLoop(op, dict);
  }

  if (typeName == "Constant" ) {
    return loadConstant(op, dict);
  }
  if (typeName == "Range" ) {
    return loadRange(op, dict);
  }
  if (typeName == "PRelu" ) {
    return loadPRelu(op, dict);
  }
  if (typeName == "Slice" ) {
    return loadSlice(op, dict);
  }
  if (typeName == "Sin" || typeName == "Cos" ) {
    return loadTrigonometricOps(typeName, op, dict);
  }
  if (typeName == "Erf" ) {
    return loadErf(op, dict);
  }
  if (typeName == "Conv" ) {
    // If the Conv operator has quantized inputs and
    // dict contains the scale and offset params, use
    // loadTensorwiseQuantizedConvolution.
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    return in.getType()->isQuantizedType() && dict.count("out_scale" ) &&
                   dict.count("out_offset" )
               ? loadTensorwiseQuantizedConvolution(op, dict)
               : loadConv(op, dict);
  }
  if (typeName == "ChannelwiseQuantizedConvolution" ) {
    return loadChannelwiseQuantizedConvolution(op, dict);
  }
  if (typeName == "MaxPool" || typeName == "AveragePool" ) {
    // If the pool operator has quantized inputs, use
    // loadTensorwiseQuantizedPool.
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    return in.getType()->isQuantizedType() && dict.count("out_scale" ) &&
                   dict.count("out_offset" )
               ? loadTensorwiseQuantizedPool(op, dict, typeName)
               : loadPool(op, dict, typeName);
  }
  if (typeName == "GlobalAveragePool" ) {
    return loadGlobalAveragePool(op, dict);
  }
  if (typeName == "Squeeze" ) {
    return loadSqueeze(op, dict);
  }
  if (typeName == "Unsqueeze" ) {
    return loadUnsqueeze(op, dict);
  }
  if (typeName == "BatchNormalization" ) {
    return loadBatchNormalization(op, dict);
  }
  if (typeName == "InstanceNormalization" ) {
    return loadInstanceNormalization(op, dict);
  }
  if (typeName == "Concat" ) {
    return loadConcat(op, dict);
  }
  if (typeName == "FCTransposed" ) {
    return loadFCTransposed(op, dict);
  }
  if (typeName == "Gemm" ) {
    return loadGemm(op, dict);
  }
  if (typeName == "Transpose" ) {
    return loadTranspose(op, dict, "perm" );
  }
  if (typeName == "ReduceSumSquare" ) {
    return loadReduceOp(typeName, op, dict);
  }
  if (typeName == "MatMul" ) {
    return loadMatMul(op, dict);
  }
  if (typeName == "Pad" ) {
    return loadPad(op, dict);
  }
  if (typeName == "Cast" ) {
    return loadCast(op, dict);
  }
  if (typeName == "HardSigmoid" ) {
    return loadHardSigmoid(op, dict);
  }
  if (typeName == "LeakyRelu" ) {
    return loadLeakyRelu(op, dict);
  }
  if (typeName == "SpaceToDepth" ) {
    return loadSpaceToDepth(op, dict);
  }
  if (typeName == "DepthToSpace" ) {
    return loadDepthToSpace(op, dict);
  }
  if (typeName == "ReduceL2" ) {
    return loadReduceL2(op, dict);
  }
  if (typeName == "ConstantOfShape" ) {
    return loadConstantOfShape(op, dict, false /* isSplat */);
  }
  if (typeName == "Tile" ) {
    return loadTile(op, dict);
  }
  if (typeName == "Expand" ) {
    return loadExpand(op, dict);
  }
  if (typeName == "Where" ) {
    return loadWhere(op, dict);
  }
  if (typeName == "RNN" ) {
    return loadRNN(op, dict);
  }
  if (typeName == "GRU" ) {
    return loadGRU(op, dict);
  }
  if (typeName == "LSTM" ) {
    return loadLSTM(op, dict);
  }
  if (typeName == "Clip" ) {
    return loadClip(op, dict);
  }
  if (typeName == "Equal" ) {
    return loadCmpEQ(op, dict);
  }
  if (typeName == "CmpLTE" ) {
    return loadCmpLTE(op, dict);
  }
  if (typeName == "Mean" ) {
    return loadMean(op, dict);
  }
  if (typeName == "Select" ) {
    return loadSelect(op, dict);
  }
  if (typeName == "Quantize" ) {
    return loadQuantize(op, dict);
  }
  if (typeName == "QuantizeLinear" ) {
    return loadQuantizeLinear(op, dict);
  }
  if (typeName == "ConvertTo" ) {
    return loadConvertTo(op, dict);
  }
  if ((typeName == "Dequantize" ) || (typeName == "DequantizeLinear" )) {
    return loadDequantize(op, dict);
  }
  if (typeName == "Regression" ) {
    return loadRegression(op, dict);
  }
  if (typeName == "BatchedAdd" ) {
    return loadBatchedAdd(op, dict);
  }
  if (typeName == "CumSum" ) {
    return loadCumSum(op, dict);
  }
  if ((typeName == "ScatterAssign" ) || (typeName == "ScatterND" )) {
    return loadScatterAssign(op, dict);
  }
  if (typeName == "IntLookupTable" ) {
    return loadIntLookupTable(op, dict);
  }
  if (typeName == "LengthsRangeFill" ) {
    return loadLengthsRangeFill(op, dict);
  }
  if (typeName == "RescaleQuantized" ) {
    return loadRescaleQuantized(op, dict);
  }
  if (typeName == "RowwiseQuantizedSparseLengthsWeightedSum" ) {
    return loadRowwiseQuantizedSparseLengthsWeightedSum(op, dict);
  }
  if (typeName == "FusedRowwiseQuantizedSparseLengthsWeightedSum" ) {
    return loadFusedRowwiseQuantizedSparseLengthsWeightedSum(op, dict);
  }
  if (typeName == "FusedRowwiseQuantizedSparseLengthsSum" ) {
    return loadFusedRowwiseQuantizedSparseLengthsSum(op, dict);
  }
  if (typeName == "FullyConnected" ) {
    return loadFullyConnected(op, dict);
  }
  if (typeName == "RowwiseQuantizedFullyConnected" ) {
    return loadRowwiseQuantizedFullyConnected(op, dict);
  }
  if (typeName == "Splat" ) {
    return loadSplat(op, dict);
  }
  if (typeName == "InsertTensor" ) {
    return loadInsertTensor(op, dict);
  }
  if (typeName == "ArgMin" ) {
    return loadArgMinMax(op, dict, true);
  }
  if (typeName == "ArgMax" ) {
    return loadArgMinMax(op, dict, false);
  }
  if (typeName == "NonMaxSuppressionV4" ) {
    return loadNonMaxSuppression(op, dict, true);
  }
  if (typeName == "NonMaxSuppression" ) {
    return loadNonMaxSuppression(op, dict, false);
  }
  if (typeName == "ConvTranspose" ) {
    return loadConvTranspose(op, dict);
  }
  if (typeName == "If" ) {
    return loadIf(op, dict);
  }
  if (typeName == "AdaptiveAvgPool" ) {
    return loadAdaptiveAvgPool(op, dict);
  }
  if (typeName == "Flip" ) {
    return loadFlip(op, dict);
  }
  if (typeName == "AudioSpectrogram" ) {
    return loadAudioSpectrogram(op, dict);
  }
  if (typeName == "RoiAlign" ) {
    return loadROIAlign(op, dict);
  }
  if (typeName == "MFCC" ) {
    return loadMFCC(op, dict);
  }
  if (typeName == "Identity" ) {
    return loadIdentity(op, dict);
  }
  if (typeName == "Upsample" ) {
    return loadUpsample(op, dict);
  }
  if (typeName == "Resize" ) {
    return loadResize(op, dict);
  }
  if (typeName == "NonZero" ) {
    return loadNonZero(op, dict);
  }
  if (typeName == "Acos" ) {
    return loadAcos(op, dict);
  }
  if (typeName == "Asin" ) {
    return loadAsin(op, dict);
  }
  if (typeName == "Atan" ) {
    return loadAtan(op, dict);
  }
  if (typeName == "Sign" ) {
    return loadSign(op, dict);
  }
  if (typeName == "Softmax" ) {
    return loadSoftmax(op, dict);
  }
  if (typeName == "LogSoftmax" ) {
    return loadLogSoftmax(op, dict);
  }
  if (typeName == "ScatterData" ) {
    return loadScatterData(op, dict);
  }
  if (typeName == "TopK" ) {
    return loadTopK(op, dict);
  }

  // Nothing matched: the operator is unsupported.
  return MAKE_ERR("Failed to load operator " + typeName + " ." ,
                  ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
}
6001 | |
6002 | void ONNXModelLoader::deleteConstFoldFunctions() { |
6003 | for (Function *constFoldF : constFoldFuns_) { |
6004 | mod_.eraseFunction(constFoldF); |
6005 | } |
6006 | } |
6007 | |
/// Runs constant folding on the already-deserialized const-fold subgraph
/// output named \p outputName, renames the resulting Constant to
/// \p initializerName, and registers it in nodeValueByName_.
/// \returns the folded Constant, or an Error if folding did not occur or the
/// names collide with existing state.
Expected<Constant *>
ONNXModelLoader::runDeserializedConstFold(llvm::StringRef initializerName,
                                          llvm::StringRef outputName) {
  NodeValue NV;
  ASSIGN_VALUE_OR_RETURN_ERR(NV, getNodeValueByName(outputName));

  // Force folding single splats, because we're folding a constant folding
  // subgraph, and so we know the exported model already decided to fold it
  // (normally the backend decides whether to fold it or not).
  std::vector<Constant *> constResults =
      constantFold(NV.getNode(), /* foldSingleSplats */ true);
  RETURN_ERR_IF_NOT(constResults.size() > 0,
                    strFormat("Constant folding did not occur for %s" ,
                              NV.getNode()->getName().data()));
  // Folding a multi-output node yields one Constant per result; pick the one
  // matching the result number of the NodeValue we looked up.
  RETURN_ERR_IF_NOT(NV.getResNo() < constResults.size(),
                    strFormat("Needed result %u from const folding results, "
                              "but only got %lu results" ,
                              NV.getResNo(), constResults.size()));
  Constant *foldedC = constResults[NV.getResNo()];

  // Now we have the final Constant we want and it exists in the module. Set its
  // name to the actual initializer it came with if not already named that.
  if (foldedC->getName() != initializerName) {
    RETURN_ERR_IF_NOT(
        mod_.getConstantByName(initializerName) == nullptr,
        strFormat("Already had a Constant by name %s" , initializerName.data()));
    foldedC->setName(initializerName);
  }
  // Register the folded Constant under the initializer's name so later nodes
  // that reference the initializer resolve to it.
  RETURN_ERR_IF_NOT(
      nodeValueByName_.count(initializerName) == 0,
      strFormat("Should not have been a Constant by name %s registered yet" ,
                initializerName.data()));
  nodeValueByName_[initializerName] = foldedC->getOutput();

  return foldedC;
}
6044 | |
/// If initializer \p in carries "ConstFoldNodeName"/"ConstFoldResNo" external
/// data, finds the matching Glow__ConstFoldSubgraph node in \p net,
/// deserializes its subgraph into a dedicated Function (cached across calls),
/// runs it, and \returns the folded Constant for this initializer. \returns
/// nullptr when \p in has no const-fold record, or an Error on any mismatch.
Expected<Constant *> ONNXModelLoader::replaySerializedConstFold(
    const ONNX_NAMESPACE::TensorProto &in, ONNX_NAMESPACE::GraphProto &net) {
  // Check if \p in has a constant folding node associated with it.
  const char *constFoldNodeName = nullptr;
  int resNo = -1;
  for (const auto &keyVal : in.external_data()) {
    if (keyVal.key() == "ConstFoldNodeName" ) {
      constFoldNodeName = keyVal.value().data();
      continue;
    }
    if (keyVal.key() == "ConstFoldResNo" ) {
      ASSIGN_VALUE_OR_RETURN_ERR(resNo, getIntFromStr(keyVal.value()));
      continue;
    }
  }
  if (!constFoldNodeName) {
    return nullptr;
  }
  RETURN_ERR_IF_NOT(resNo >= 0,
                    "Require ConstFoldResNo for Glow__ConstFoldSubgraph" );

  // Look through the ops in the graph to find the Node we need by name.
  ONNX_NAMESPACE::NodeProto *op = nullptr;
  for (int i = 0; i < net.node_size(); i++) {
    auto *curOp = net.mutable_node(i);
    if (loadOperatorName(*curOp) == constFoldNodeName) {
      op = curOp;
      break;
    }
  }
  RETURN_ERR_IF_NOT(
      op, strFormat("Did not find Node by name %s" , constFoldNodeName));
  RETURN_ERR_IF_NOT(
      op->op_type() == constFoldSubgraphNodeName,
      strFormat("Node %s has type %s but expected Glow__ConstFoldSubgraph" ,
                constFoldNodeName, op->op_type().data()));

  // Now look through the Node's attributes to find the subgraph.
  ONNX_NAMESPACE::GraphProto *subgraph = nullptr;
  for (auto &arg : *op->mutable_attribute()) {
    if (arg.name() == "ConstFoldSubgraph" ) {
      subgraph = arg.mutable_g();
      break;
    }
  }

  RETURN_ERR_IF_NOT(subgraph, strFormat("Expected associated subgraph for %s" ,
                                        constFoldNodeName));

  // We have the constant folding subgraph proto and need to load it to run it.
  // One subgraph can feed several initializers, so the Function may already
  // exist from a previous call; reuse it and only deserialize once.
  const bool functionAlreadyLoaded = mod_.hasFunction(constFoldNodeName);
  Function *constFoldF = functionAlreadyLoaded
                             ? mod_.getFunction(constFoldNodeName)
                             : mod_.createFunction(constFoldNodeName);
  const auto insert = constFoldFuns_.insert(constFoldF);
  RETURN_ERR_IF_NOT(!(functionAlreadyLoaded && insert.second),
                    strFormat("Function %s should only be processed once" ,
                              constFoldNodeName));

  // Temporarily swap in state for the constant folding Function.
  Function *origF = G_;
  G_ = constFoldF;
  llvm::StringMap<Function *> partNameToFunBackup;
  std::swap(partNameToFun_, partNameToFunBackup);
  // Make sure to restore original state of the loader when exiting this scope.
  ScopeGuard restoreOrigStateGuard([&]() {
    G_ = origF;
    std::swap(partNameToFun_, partNameToFunBackup);
  });

  // Deserialize the Function if not already done.
  if (!functionAlreadyLoaded) {
    RETURN_IF_ERR(loadNetwork(*subgraph, /* loadingConstFoldSubgraph */ true));
  }

  // Now that we have the Function deserialized, actually run and return the
  // resulting Constant that is folded.
  RETURN_ERR_IF_NOT(subgraph->output_size() > resNo,
                    strFormat("ConstFoldResNo %d invalid output idx." , resNo));
  return runDeserializedConstFold(in.name(), subgraph->output(resNo).name());
}
6126 | |
/// Loads every initializer of \p net: replays serialized constant folding
/// where recorded, reuses pre-existing Constants when loading into an
/// existing Module, and otherwise creates a fresh Constant from the
/// initializer's tensor payload, registering each under its name.
Error ONNXModelLoader::loadInitializers(ONNX_NAMESPACE::GraphProto &net) {
  // Load the network initializers:
  for (const auto &in : net.initializer()) {
    // Replay any constant folding that occurred from previous optimization if
    // necessary. foldedC will be left as nullptr if no constant folding occurs.
    Constant *foldedC;
    ASSIGN_VALUE_OR_RETURN_ERR(foldedC, replaySerializedConstFold(in, net));

    // Glow-custom protos record the tensor layout in the doc string; other
    // protos default to ANY_LAYOUT.
    std::string layout = ANY_LAYOUT;
    if (useGlowCustomOps_) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          layout, getAttrFromDocString(layoutSignifier, in.doc_string()));
    }

    // If we already have an existing module then expect to find Constants
    // already existing for each initializer; similarly a folded Constant
    // already lives in the module. Verify and register it instead of creating
    // a new one.
    if (foldedC || loadIntoExistingModule_) {
      Constant *C = foldedC ? foldedC : mod_.getConstantByName(in.name());
      Type ty;
      ASSIGN_VALUE_OR_RETURN_ERR(ty, getTensorType(in));

      // If the expected type is fused, and we are processing an initializer
      // with payload that already exists in the Module, then set the type to
      // fused here. This is because Caffe2 and ONNX (non-Glow-custom) protos do
      // not support fused ElemKinds, so we should explicitly set them as we do
      // during Caffe2ModelLoading.
      if (!foldedC && loadIntoExistingModule_ && ty.isFusedQuantizedType()) {
        RETURN_IF_ERR(setFusedTy(C, mod_.uniqueType(ty)));
      }

      RETURN_IF_ERR(verifyPreexistingStorage(C, in.name(), ty, layout));
      nodeValueByName_[in.name()] = C->getOutput();
      continue;
    }

    // Otherwise no pre-existing Constant: load the tensor payload from the
    // proto and create/register a new Constant for it.
    Tensor T;
    RETURN_IF_ERR(loadTensor(in, &T, useGlowCustomOps_));
    RETURN_IF_ERR(createAndRegisterConstant(in.name(), std::move(T), layout));
  }

  return Error::success();
}
6171 | |
6172 | Error ONNXModelLoader::setOutputNodes(ONNX_NAMESPACE::GraphProto &net) { |
6173 | if (net.output_size() == 0) { |
6174 | return MAKE_ERR("Net output size must be greater than 0" ); |
6175 | } |
6176 | |
6177 | for (int i = 0; i < net.output_size(); i++) { |
6178 | const auto &outputName = net.output(i).name(); |
6179 | NodeValue r; |
6180 | ASSIGN_VALUE_OR_RETURN_ERR(r, getNodeValueByName(outputName)); |
6181 | |
6182 | const std::string &docString = net.output(i).doc_string(); |
6183 | |
6184 | Expected<std::string> saveName = |
6185 | getAttrFromDocString(saveNameSignifier, docString); |
6186 | |
6187 | const bool hasSpecifiedSaveName = |
6188 | !ERR_TO_BOOL(saveName.takeError(), /* log */ false); |
6189 | const std::string &saveNodeName = |
6190 | hasSpecifiedSaveName ? saveName.get() : outputName; |
6191 | |
6192 | std::pair<bool, std::string> trainableLayoutPair; |
6193 | ASSIGN_VALUE_OR_RETURN_ERR( |
6194 | trainableLayoutPair, |
6195 | getTrainableLayoutPairFromDocString(docString, useGlowCustomOps_)); |
6196 | |
6197 | // If loadIntoExistingModule_ then it's reasonable for there to be a savePH |
6198 | // already. If not then there shouldn't be one. |
6199 | Placeholder *savePH = mod_.getPlaceholderByNameSlow(outputName); |
6200 | if (!savePH) { |
6201 | savePH = mod_.createPlaceholder(r.getType(), outputName, |
6202 | trainableLayoutPair.first, |
6203 | trainableLayoutPair.second); |
6204 | } else { |
6205 | RETURN_ERR_IF_NOT(loadIntoExistingModule_, |
6206 | "Found pre-existing PH by name " + outputName); |
6207 | RETURN_IF_ERR(verifyPreexistingStorage(savePH, outputName, *r.getType(), |
6208 | trainableLayoutPair.second, |
6209 | trainableLayoutPair.first)); |
6210 | } |
6211 | G_->createSave(saveNodeName, r, savePH, hasSpecifiedSaveName); |
6212 | |
6213 | auto loaderNameOrErr = getAttrFromDocString(loaderNameSignifier, docString); |
6214 | const std::string &loaderName = |
6215 | !ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false) |
6216 | ? loaderNameOrErr.get() |
6217 | : outputName; |
6218 | RETURN_ERR_IF_NOT(outputVarsByName_.try_emplace(loaderName, savePH).second, |
6219 | "Already had output placeholder by name " + loaderName); |
6220 | } |
6221 | |
6222 | return Error::success(); |
6223 | } |
6224 | |
6225 | Error ONNXModelLoader::loadNetwork(ONNX_NAMESPACE::GraphProto &net, |
6226 | bool loadingConstFoldSubgraph) { |
6227 | /// Load the network operators: |
6228 | for (int i = 0; i < net.node_size(); i++) { |
6229 | auto &op = net.node(i); |
6230 | |
6231 | // Always ignore these since they're dummy nodes used to just carry meta |
6232 | // info that is processed via setupOrigStaticTypeMap(). |
6233 | if (op.op_type() == staticPHDummyNodeName) { |
6234 | continue; |
6235 | } |
6236 | |
6237 | // Set up current partition to load into if relevant. |
6238 | if (partNameToFun_.size() && !loadingConstFoldSubgraph && |
6239 | op.op_type() != constFoldSubgraphNodeName) { |
6240 | const ONNX_NAMESPACE::AttributeProto *pNameAttr = nullptr; |
6241 | for (auto &arg : op.attribute()) { |
6242 | if (arg.name() == "partitionName" ) { |
6243 | pNameAttr = &arg; |
6244 | break; |
6245 | } |
6246 | } |
6247 | RETURN_ERR_IF_NOT(pNameAttr, "partitionName not found for " + op.name()); |
6248 | std::string pName; |
6249 | ASSIGN_VALUE_OR_RETURN_ERR(pName, loadStr(pNameAttr)); |
6250 | auto it = partNameToFun_.find(pName); |
6251 | RETURN_ERR_IF_NOT(it != partNameToFun_.end(), |
6252 | "Did not find partition with name " + pName); |
6253 | G_ = it->second; |
6254 | } |
6255 | RETURN_ERR_IF_NOT(G_, "Internal Glow error; Graph was not valid." ); |
6256 | |
6257 | if (constFoldInLoader_) { |
6258 | auto tryFold = foldOperator(op); |
6259 | if (!tryFold) { |
6260 | // Error during constant folding; load the op normally below. |
6261 | const std::string errStr = |
6262 | ERR_TO_STRING(tryFold.takeError(), /* warning */ true); |
6263 | VLOG(1) << "Issue while trying to ConstantFold " << loadOperatorName(op) |
6264 | << ": " << errStr; |
6265 | } else if (tryFold.get()) { |
6266 | // Folded successfully, so skip loading the op below. |
6267 | continue; |
6268 | } |
6269 | } |
6270 | RETURN_IF_ERR(loadOperator(op)); |
6271 | } |
6272 | |
6273 | return Error::success(); |
6274 | } |
6275 | |
// Minimal constructor: no model is loaded; just initialize the base loader on
// \p F and drop any Constants in the Module that have no remaining users.
ONNXModelLoader::ONNXModelLoader(Function &F, Error *errPtr)
    : CommonOperatorLoader({}, {}, &F, errPtr) {
  deleteUnusedConstants();
}
6280 | |
6281 | static Error checkStaticPH(const ONNX_NAMESPACE::ValueInfoProto &valueInfo, |
6282 | std::unordered_set<std::string> &staticInputs, |
6283 | bool useGlowCustomOps) { |
6284 | const std::string &inputName = valueInfo.name(); |
6285 | if (useGlowCustomOps) { |
6286 | std::string isStatic; |
6287 | ASSIGN_VALUE_OR_RETURN_ERR( |
6288 | isStatic, |
6289 | getAttrFromDocString(staticSignifier, valueInfo.doc_string())); |
6290 | if (isStatic == "1" ) { |
6291 | staticInputs.emplace(inputName); |
6292 | } |
6293 | } else if (valueInfo.has_doc_string() && |
6294 | valueInfo.doc_string() == staticSignifier) { |
6295 | staticInputs.emplace(inputName); |
6296 | } |
6297 | return Error::success(); |
6298 | } |
6299 | |
6300 | Error ONNXModelLoader::collectStaticInputs(ONNX_NAMESPACE::GraphProto &net) { |
6301 | for (int i = 0; i < net.input_size(); i++) { |
6302 | RETURN_IF_ERR( |
6303 | checkStaticPH(net.input(i), staticInputs_, useGlowCustomOps_)); |
6304 | } |
6305 | return Error::success(); |
6306 | } |
6307 | |
6308 | Error ONNXModelLoader::checkInputs(ONNX_NAMESPACE::GraphProto &net, |
6309 | llvm::ArrayRef<const char *> tensorNames, |
6310 | llvm::ArrayRef<TypeRef> types) { |
6311 | for (size_t i = 0; i < tensorNames.size(); i++) { |
6312 | // Look if a corresponding input exists. |
6313 | for (int j = 0; j < net.input_size(); j++) { |
6314 | const ONNX_NAMESPACE::ValueInfoProto &valueInfo = net.input(j); |
6315 | const std::string &inputName = valueInfo.name(); |
6316 | |
6317 | if (inputName != tensorNames[i]) { |
6318 | continue; |
6319 | } |
6320 | |
6321 | // Get tensor shape. |
6322 | llvm::ArrayRef<dim_t> dims = types[i]->dims(); |
6323 | |
6324 | // Get proto shape. |
6325 | std::vector<dim_t> dimsProto; |
6326 | ASSIGN_VALUE_OR_RETURN_ERR( |
6327 | dimsProto, getProtoShape(valueInfo.type().tensor_type().shape())); |
6328 | |
6329 | // Check if the number of dimensions is consistent. |
6330 | RETURN_ERR_IF_NOT(dims.size() == dimsProto.size(), |
6331 | "Mismatch between input image and ONNX input shape" ); |
6332 | |
6333 | // Allow batch dimensions to be different. |
6334 | for (size_t k = 1; k < dims.size(); k++) { |
6335 | RETURN_ERR_IF_NOT(dims[k] == dimsProto[k], |
6336 | "Mismatch between input image and ONNX input shape" ); |
6337 | } |
6338 | |
6339 | RETURN_IF_ERR(checkStaticPH(valueInfo, staticInputs_, useGlowCustomOps_)); |
6340 | } |
6341 | } |
6342 | return Error::success(); |
6343 | } |
6344 | |
6345 | Error ONNXModelLoader::setupOrigStaticTypeMap(ONNX_NAMESPACE::GraphProto &net) { |
6346 | if (!staticPlaceholderTypes_) { |
6347 | return Error::success(); |
6348 | } |
6349 | |
6350 | for (int i = 0; i < net.node_size(); i++) { |
6351 | auto &op = net.node(i); |
6352 | ArgumentDictionaryTy dict = loadArgumentMap(op); |
6353 | if (op.op_type() != staticPHDummyNodeName) { |
6354 | continue; |
6355 | } |
6356 | RETURN_ERR_IF_NOT(staticInputs_.count(op.name()), |
6357 | "Expected static input for " + op.name()); |
6358 | TypeRef OT; |
6359 | ASSIGN_VALUE_OR_RETURN_ERR( |
6360 | OT, loadTypeFromAttributes(Storage::OutputIdx, dict)); |
6361 | staticPlaceholderTypes_->emplace(op.name(), *OT); |
6362 | } |
6363 | RETURN_ERR_IF_NOT( |
6364 | staticPlaceholderTypes_->size() == staticInputs_.size(), |
6365 | strFormat( |
6366 | "Expected to find types for all static Placeholders. %lu vs. %lu" , |
6367 | staticPlaceholderTypes_->size(), staticInputs_.size())); |
6368 | return Error::success(); |
6369 | } |
6370 | |
6371 | Error ONNXModelLoader::assignGraphInputs(const ONNX_NAMESPACE::GraphProto &net, |
6372 | llvm::ArrayRef<NodeValue> NVs) { |
6373 | RETURN_ERR_IF_NOT((dim_t)NVs.size() == (dim_t)net.input_size(), |
6374 | "Input size mismatch." ); |
6375 | for (size_t i = 0; i < NVs.size(); i++) { |
6376 | nodeValueByName_[net.input(i).name()] = NVs[i]; |
6377 | } |
6378 | return Error::success(); |
6379 | } |
6380 | |
// Drives the full load sequence for \p modelDef: version checks, input
// validation, initializers, operators, and outputs. The phase order below is
// significant — e.g. initializers must exist before operators reference them.
Error ONNXModelLoader::loadModel(ONNX_NAMESPACE::ModelProto &modelDef,
                                 llvm::ArrayRef<const char *> tensorNames,
                                 llvm::ArrayRef<TypeRef> types,
                                 const Backend *B,
                                 bool loadInputsAsPlaceholdersForOnnx) {
  // Models written by Glow's own ONNX writer carry extra attributes in doc
  // strings; detect that via the producer name.
  useGlowCustomOps_ = modelDef.producer_name() == "GlowONNXModelWriter";

  RETURN_IF_ERR(setVersion(modelDef));

  // NOTE(review): this copies the GraphProto; presumably fine for a one-time
  // load — confirm before switching to a reference to the proto's graph.
  ONNX_NAMESPACE::GraphProto graphDef = modelDef.graph();
  // Validate caller-provided input types and gather static inputs/types.
  RETURN_IF_ERR(checkInputs(graphDef, tensorNames, types));
  RETURN_IF_ERR(collectStaticInputs(graphDef));
  RETURN_IF_ERR(setupOrigStaticTypeMap(graphDef));

  // Create (or verify) Constants for all initializers.
  RETURN_IF_ERR(loadInitializers(graphDef));

  if (tensorNames.empty() && types.empty()) {
    // Detect inputs without initializers and create placeholders.
    RETURN_IF_ERR(loadInputs(graphDef, loadInputsAsPlaceholdersForOnnx));
  }

  // Deserialize all operators into the Function (or partition Functions).
  RETURN_IF_ERR(loadNetwork(graphDef, /* loadingConstFoldSubgraph */ false));

  // Create SaveNodes and output Placeholders for the graph outputs.
  RETURN_IF_ERR(setOutputNodes(graphDef));

  RETURN_ERR_IF_NOT(G_->verify(B), "Function verification failed.");

  // Drop Constants left unused (e.g. consumed only by folded ops).
  deleteUnusedConstants();

  // Remove the temporary Functions used to replay serialized const folding.
  deleteConstFoldFunctions();

  RETURN_IF_ERR(verifyDummyQParams());

  return Error::success();
}
6416 | |
6417 | ONNXModelLoader::ONNXModelLoader(const std::string &modelDescFilename, |
6418 | llvm::ArrayRef<const char *> tensorNames, |
6419 | llvm::ArrayRef<TypeRef> types, Function &F, |
6420 | Error *errPtr, bool zipMode, |
6421 | BackendSpecificNodeInfo *perNodeOpts, |
6422 | bool disableConstFoldInLoader, |
6423 | bool loadIntoExistingModule, const Backend *B, |
6424 | const std::string *inputStringPtr) |
6425 | : CommonOperatorLoader(tensorNames, types, &F, errPtr, |
6426 | loadIntoExistingModule), |
6427 | perNodeOpts_(perNodeOpts), staticPlaceholderTypes_(nullptr) { |
6428 | // if errPtr already contains an error then don't continue with constructor |
6429 | if (errPtr && *errPtr) { |
6430 | return; |
6431 | } |
6432 | |
6433 | if (disableConstFoldInLoader) { |
6434 | constFoldInLoader_ = false; |
6435 | } |
6436 | |
6437 | auto setup = [&]() -> Error { |
6438 | ONNX_NAMESPACE::ModelProto modelDef; |
6439 | ASSIGN_VALUE_OR_RETURN_ERR( |
6440 | modelDef, loadProto(modelDescFilename, zipMode, inputStringPtr)); |
6441 | RETURN_IF_ERR(loadModel(modelDef, tensorNames, types, B, |
6442 | /* loadInputsAsPlaceholdersForOnnx */ true)); |
6443 | return Error::success(); |
6444 | }; |
6445 | |
6446 | if (errPtr) { |
6447 | *errPtr = setup(); |
6448 | } else { |
6449 | EXIT_ON_ERR(setup()); |
6450 | } |
6451 | } |
6452 | |
6453 | /// \returns a metadata prop found at \p key in \p modelDef. |
6454 | static const char *getMetadataProp(const ONNX_NAMESPACE::ModelProto &modelDef, |
6455 | llvm::StringRef key) { |
6456 | for (const auto &keyVal : modelDef.metadata_props()) { |
6457 | if (keyVal.key() == key) { |
6458 | return keyVal.value().data(); |
6459 | } |
6460 | } |
6461 | return nullptr; |
6462 | } |
6463 | |
6464 | static Expected<int32_t> |
6465 | getIntMetadataProp(const ONNX_NAMESPACE::ModelProto &modelDef, |
6466 | llvm::StringRef key) { |
6467 | const char *intStr = getMetadataProp(modelDef, key); |
6468 | RETURN_ERR_IF_NOT(intStr, "Did not find value for " + std::string(key)); |
6469 | int32_t intVal; |
6470 | ASSIGN_VALUE_OR_RETURN_ERR(intVal, getIntFromStr(intStr)); |
6471 | return intVal; |
6472 | } |
6473 | |
// Reconstructs the pre-partitioned configuration \p PPC for \p numPartitions
// partitions from metadata props serialized into \p modelDef, creating (or
// reusing) one Function per partition under root name \p rootName.
Error ONNXModelLoader::setupPartitions(ONNX_NAMESPACE::ModelProto &modelDef,
                                       PrePartitionedConfig &PPC,
                                       llvm::StringRef rootName,
                                       int numPartitions) {
  PPC.funcName = rootName.str();
  PPC.resizeAndReserve(numPartitions);

  for (int i = 0; i < numPartitions; i++) {
    // All props for partition i are stored under a common key prefix.
    const std::string partIdPrefix = getPartitionIdPrefix(i);
    const char *pName = getMetadataProp(modelDef, partIdPrefix + nameSignifier);
    RETURN_ERR_IF_NOT(pName, "Didn't find expected partition name");

    // Load the partition name and create a Function with the same name.
    Function *PF = nullptr;
    if (loadIntoExistingModule_ && mod_.hasFunction(pName)) {
      // Reuse an existing Function only if it hasn't been populated yet.
      PF = mod_.getFunction(pName);
      RETURN_ERR_IF_NOT(PF->getNodes().size() == 0,
                        "Function must be empty to load into.");
    } else {
      PF = mod_.createFunction(pName);
    }
    partNameToFun_[pName] = PF;
    PPC.funcs.push_back(PF);

    // Load all logical devices for the partition.
    int32_t numLogicalDevices;
    ASSIGN_VALUE_OR_RETURN_ERR(
        numLogicalDevices,
        getIntMetadataProp(modelDef,
                           partIdPrefix + numLogicalDevicesSignifier));
    for (int j = 0; j < numLogicalDevices; j++) {
      DeviceIDTy ID;
      ASSIGN_VALUE_OR_RETURN_ERR(
          ID, getIntMetadataProp(modelDef,
                                 partIdPrefix + getLogicalDeviceSignfier(j)));
      PPC.logicalIDs[i].push_back(ID);
    }

    // Get backend name.
    const char *backendName =
        getMetadataProp(modelDef, partIdPrefix + backendNameSignifier);
    RETURN_ERR_IF_NOT(backendName, "Didn't find Backend name");
    PPC.backendNames.emplace_back(backendName);

    // Get backendHints.executionUnits. Note that we don't support serializing
    // SRAMPrioritization, so it's left empty.
    unsigned execUnits;
    ASSIGN_VALUE_OR_RETURN_ERR(
        execUnits,
        getIntMetadataProp(modelDef, partIdPrefix + executionUnitsSignifier));
    PPC.backendHints.push_back({execUnits, /* SRAMPrioritization */ {}});

    // Load all backend-specific options, serialized as key/value prop pairs.
    int32_t numBackendSpecificOpts;
    ASSIGN_VALUE_OR_RETURN_ERR(
        numBackendSpecificOpts,
        getIntMetadataProp(modelDef,
                           partIdPrefix + numBackendSpecificOptsSignifier));
    for (int j = 0; j < numBackendSpecificOpts; j++) {
      const char *optKey = getMetadataProp(
          modelDef, partIdPrefix + getBackendSpecificOptKeySignifier(j));
      RETURN_ERR_IF_NOT(optKey,
                        "Didn't find expected backend-specific option key");
      const char *optVal = getMetadataProp(
          modelDef, partIdPrefix + getBackendSpecificOptValSignifier(j));
      RETURN_ERR_IF_NOT(optVal,
                        "Didn't find expected backend-specific option val");
      PPC.backendSpecificOpts[i][optKey] = optVal;
    }

    // Get replicationCount.
    int32_t replicationCount;
    ASSIGN_VALUE_OR_RETURN_ERR(
        replicationCount,
        getIntMetadataProp(modelDef, partIdPrefix + replicationCountSignifier));
    PPC.replicationCounts.push_back(replicationCount);
  }

  return Error::success();
}
6554 | |
6555 | void ONNXModelLoader::setupPositionalIO( |
6556 | const ONNX_NAMESPACE::GraphProto &graph) { |
6557 | for (const auto &in : graph.input()) { |
6558 | if (staticInputs_.count(in.name())) { |
6559 | continue; |
6560 | } |
6561 | auto loaderNameOrErr = |
6562 | getAttrFromDocString(loaderNameSignifier, in.doc_string()); |
6563 | if (ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)) { |
6564 | positionalInputNames_.clear(); |
6565 | break; |
6566 | } |
6567 | positionalInputNames_.emplace_back(loaderNameOrErr.get()); |
6568 | } |
6569 | |
6570 | for (const auto &out : graph.output()) { |
6571 | auto loaderNameOrErr = |
6572 | getAttrFromDocString(loaderNameSignifier, out.doc_string()); |
6573 | if (ERR_TO_BOOL(loaderNameOrErr.takeError(), /* log */ false)) { |
6574 | positionalOutputNames_.clear(); |
6575 | break; |
6576 | } |
6577 | positionalOutputNames_.emplace_back(loaderNameOrErr.get()); |
6578 | } |
6579 | } |
6580 | |
// Builds updatedTQPs_ (offset -> TQP) used to replace dummy quantization
// params, by combining a serialized Caffe2 proto (for name -> TQP) with a
// serialized name -> unique-offset mapping from \p modelDef's metadata.
Error ONNXModelLoader::setupUpdatedTQPMap(
    ONNX_NAMESPACE::ModelProto &modelDef, uint32_t weightsCount,
    const onnxTensorDescriptorV1 *weightDescriptors) {
  // Check if we have the two strings in metadata props we need to do TQP
  // updating, and if not print a warning/return early.
  const char *originNameToUniqueOffsetMappingStr =
      getMetadataProp(modelDef, originNameToUniqueOffsetMappingSignifier);
  if (!originNameToUniqueOffsetMappingStr) {
    LOG(WARNING) << "Did not find \""
                 << originNameToUniqueOffsetMappingSignifier
                 << "\" in ONNX model, skipping setting updated TQP map.";
    return Error::success();
  }

  const char *qParamC2ProtoStr = getMetadataProp(modelDef, "C2_with_q_params");
  if (!qParamC2ProtoStr) {
    LOG(WARNING) << "Did not find \"C2_with_q_params\" in ONNX model, skipping "
                    "setting updated TQP map.";
    return Error::success();
  }

  // Now load the qParamC2ProtoStr into a temporary dummy Caffe2ModelLoader,
  // which fills in the originNameToTQPMap based on that model and the weights.
  // The loader and its Module are discarded; only the map is kept.
  OriginNameToTQPMap originNameToTQPMap;
  Error err(Error::success());
  Module dummyMod;
  Caffe2ModelLoader tmpLoader(qParamC2ProtoStr, weightsCount, weightDescriptors,
                              dummyMod, &err, &originNameToTQPMap,
                              clipQuantRangeToFP16_);
  RETURN_IF_ERR(err);

  // Now parse the originNameToUniqueOffsetMappingStr to find the original C2
  // name : unique offset pairs. These are formatted like
  // "name_op_a@0@@name_op_b@1@@", with @@ separating each pair, and @
  // separating name from unique offset.
  llvm::SmallVector<llvm::StringRef, 128> nameOffsetSplits;
  llvm::StringRef strRef = llvm::StringRef(originNameToUniqueOffsetMappingStr);
  strRef.split(nameOffsetSplits, offsetEndSig, /* MaxSplit */ -1,
               /* KeepEmpty */ false);

  // Store the mapping into updatedTQPs_, where each unique offset is used as
  // the index into updatedTQPs_ to the actual TQP to use. I.e. we essentially
  // already have c2_name -> offset, and c2_name -> TQP, so we change this to
  // offset -> TQP.
  updatedTQPs_.resize(nameOffsetSplits.size());
  for (auto &nameOffsetSplit : nameOffsetSplits) {
    // nameOffsetPair.first is the C2 name, .second the offset string.
    auto nameOffsetPair = nameOffsetSplit.split(offsetSepSig);
    int32_t idx;
    ASSIGN_VALUE_OR_RETURN_ERR(idx, getIntFromStr(nameOffsetPair.second));
    RETURN_ERR_IF_NOT(idx < int32_t(updatedTQPs_.size()),
                      strFormat("Provided offset index %d not inside size "
                                "of updatedTQPs_ %lu",
                                idx, updatedTQPs_.size()));

    auto it = originNameToTQPMap.find(nameOffsetPair.first.str());
    RETURN_ERR_IF_NOT(it != originNameToTQPMap.end(),
                      strFormat("Did not find matching TQP for %s",
                                nameOffsetPair.first.str().data()));
    updatedTQPs_[idx] = it->second;
  }
  return Error::success();
}
6643 | |
// Loads a (possibly pre-partitioned) model from a file or string into Module
// \p mod. When the proto carries a "numPartitions" metadata prop, partitions
// are reconstructed into \p PPC; otherwise a single Function \p funName is
// created.
ONNXModelLoader::ONNXModelLoader(
    const std::string &modelDescFilename,
    llvm::ArrayRef<const char *> tensorNames, llvm::ArrayRef<TypeRef> types,
    Module &mod, llvm::StringRef funName, PrePartitionedConfig *PPC,
    Error *errPtr, bool zipMode, BackendSpecificNodeInfo *perNodeOpts,
    bool loadIntoExistingModule, bool disableConstFoldInLoader,
    const Backend *B, const std::string *inputStringPtr)
    : CommonOperatorLoader(tensorNames, types, mod, errPtr,
                           loadIntoExistingModule),
      perNodeOpts_(perNodeOpts), staticPlaceholderTypes_(nullptr) {
  // if errPtr already contains an error then don't continue with constructor
  if (errPtr && *errPtr) {
    return;
  }

  if (disableConstFoldInLoader) {
    constFoldInLoader_ = false;
  }

  // Lambda so that any Error raised during setup can be routed either into
  // errPtr or treated as fatal below.
  auto setup = [&]() -> Error {
    ONNX_NAMESPACE::ModelProto modelDef;
    ASSIGN_VALUE_OR_RETURN_ERR(
        modelDef, loadProto(modelDescFilename, zipMode, inputStringPtr));

    // A missing "numPartitions" prop means a normal single-Function model.
    auto numPartitionsOrErr = getIntMetadataProp(modelDef, "numPartitions");
    if (!numPartitionsOrErr) {
      ERR_TO_VOID(numPartitionsOrErr.takeError(), /*log*/ false);
      G_ = mod_.createFunction(funName);
    } else {
      RETURN_ERR_IF_NOT(PPC, "No PrePartitionConfig to load partitions into");
      RETURN_IF_ERR(
          setupPartitions(modelDef, *PPC, funName, *numPartitionsOrErr));
    }

    RETURN_IF_ERR(loadModel(modelDef, tensorNames, types, B,
                            /* loadInputsAsPlaceholdersForOnnx */ true));

    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}
6690 | |
6691 | ONNXModelLoader::ONNXModelLoader( |
6692 | const void *model, uint32_t modelSize, uint32_t weightsCount, |
6693 | const onnxTensorDescriptorV1 *weightDescriptors, Function &F, |
6694 | bool loadInputsAsPlaceholdersForOnnx, Error *errPtr, bool constFoldInLoader, |
6695 | BackendSpecificNodeInfo *perNodeOpts) |
6696 | : CommonOperatorLoader({}, {}, &F, errPtr, true), perNodeOpts_(perNodeOpts), |
6697 | staticPlaceholderTypes_(nullptr) { |
6698 | // if errPtr already contains an error then don't continue with constructor |
6699 | if (errPtr && *errPtr) { |
6700 | return; |
6701 | } |
6702 | |
6703 | // Always override the default for folding in this constructor. |
6704 | constFoldInLoader_ = constFoldInLoader; |
6705 | |
6706 | // Lambda to setup the ONNXModelLoader and return any Errors that were |
6707 | // raised. |
6708 | auto setup = [&]() -> Error { |
6709 | ONNX_NAMESPACE::ModelProto modelDef; |
6710 | ASSIGN_VALUE_OR_RETURN_ERR(modelDef, loadProto(model, modelSize)); |
6711 | |
6712 | RETURN_IF_ERR(loadWeights(weightsCount, weightDescriptors)); |
6713 | |
6714 | RETURN_IF_ERR(loadModel(modelDef, {}, {}, /* B */ nullptr, |
6715 | loadInputsAsPlaceholdersForOnnx)); |
6716 | |
6717 | return Error::success(); |
6718 | }; |
6719 | |
6720 | if (errPtr) { |
6721 | *errPtr = setup(); |
6722 | } else { |
6723 | EXIT_ON_ERR(setup()); |
6724 | } |
6725 | } |
6726 | |
// Loads a (possibly pre-partitioned) model from an in-memory proto buffer
// plus ONNXIFI weight descriptors into Module \p mod, optionally replacing
// dummy quantization params and recording static Placeholder types.
ONNXModelLoader::ONNXModelLoader(
    const void *model, uint32_t modelSize, uint32_t weightsCount,
    const onnxTensorDescriptorV1 *weightDescriptors, Module &mod,
    llvm::StringRef funName, PrePartitionedConfig *PPC,
    bool loadInputsAsPlaceholdersForOnnx, Error *errPtr, bool constFoldInLoader,
    BackendSpecificNodeInfo *perNodeOpts,
    std::map<std::string, Type> *staticPlaceholderTypes, bool replaceDummyTQPs,
    bool clipQuantRangeToFP16)
    : CommonOperatorLoader({}, {}, mod, errPtr,
                           /* loadIntoExistingModule */ true,
                           /* originNameToTQPMap */ nullptr,
                           /* loadUniquedDummyQParams */ false,
                           replaceDummyTQPs, /* zeroScaleFP16Clip */ false,
                           clipQuantRangeToFP16),
      perNodeOpts_(perNodeOpts),
      staticPlaceholderTypes_(staticPlaceholderTypes) {
  // if errPtr already contains an error then don't continue with constructor
  if (errPtr && *errPtr) {
    return;
  }

  // Always override the default for folding in this constructor.
  constFoldInLoader_ = constFoldInLoader;

  // Lambda to setup the ONNXModelLoader and return any Errors that were
  // raised.
  auto setup = [&]() -> Error {
    ONNX_NAMESPACE::ModelProto modelDef;
    ASSIGN_VALUE_OR_RETURN_ERR(modelDef, loadProto(model, modelSize));

    // If we're going to be replacing dummy TQPs then setup the updated TQP map,
    // which is used later on when loading each op.
    if (replaceDummyTQPs_) {
      // Check if the model has specified that clipQuantRangeToFP16 should be
      // overridden via the clipQuantRangeToFP16Key metadata prop. Do this
      // before updating TQP map, because this will affect the qparams loaded.
      for (const auto &keyVal : modelDef.metadata_props()) {
        if (keyVal.key() == clipQuantRangeToFP16Key && keyVal.value() == "1") {
          LOG(INFO) << "ONNXModelLoader found enabled "
                    << clipQuantRangeToFP16Key;
          clipQuantRangeToFP16_ = true;
          break;
        }
      }

      RETURN_IF_ERR(
          setupUpdatedTQPMap(modelDef, weightsCount, weightDescriptors));
    }

    // Register the provided weights as Constants before loading the network.
    RETURN_IF_ERR(loadWeights(weightsCount, weightDescriptors));

    // A missing "numPartitions" prop means a normal single-Function model.
    auto numPartitionsOrErr = getIntMetadataProp(modelDef, "numPartitions");
    if (!numPartitionsOrErr) {
      ERR_TO_VOID(numPartitionsOrErr.takeError(), /*log*/ false);
      G_ = mod_.createFunction(funName);
    } else {
      RETURN_ERR_IF_NOT(PPC, "No PrePartitionConfig to load partitions into");
      RETURN_IF_ERR(
          setupPartitions(modelDef, *PPC, funName, *numPartitionsOrErr));
    }

    RETURN_IF_ERR(loadModel(modelDef, {}, {}, /* B */ nullptr,
                            loadInputsAsPlaceholdersForOnnx));

    // Positional I/O names only make sense when inputs became Placeholders.
    if (loadInputsAsPlaceholdersForOnnx) {
      setupPositionalIO(modelDef.graph());
    }

    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}
6804 | |