/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GLOW_IMPORTER_COMMONOPERATORLOADER_H
#define GLOW_IMPORTER_COMMONOPERATORLOADER_H

#include "foxi/onnxifi.h"

#include "glow/Importer/ProtobufLoader.h"

#include "glow/Base/Tensor.h"
#include "glow/Graph/Graph.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

namespace glow {

/// Result of loading a weight. It may include additional offsets and scales
/// tensors holding the quantization parameters, but only when the loaded
/// weight was found to have more than one set of quantization parameters.
struct LoadWeightResult {
  /// Main Glow tensor; this is always non-null.
  std::unique_ptr<Tensor> t;
  /// Glow tensor containing quantization offsets. This should only be
  /// non-null if there is more than one quantization parameter found.
  std::unique_ptr<Tensor> offsets;
  /// Glow tensor containing quantization scales. This should only be
  /// non-null if there is more than one quantization parameter found.
  std::unique_ptr<Tensor> scales;
  /// Type info of the weight; this is used for offline weights.
  Type type;
};
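// A sketch of a common case (an assumption for illustration, not required by
// the struct): a row-wise quantized weight with N rows usually arrives with N
// scales and N offsets, so loadWeight() below fills result.scales and
// result.offsets with {N}-shaped tensors while result.t holds the raw payload.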

#define dispatchQuantizedImpl(functionName, elemTy, ...)                       \
  switch (elemTy) {                                                            \
  case ElemKind::Int8QTy:                                                      \
    functionName<int8_t>(__VA_ARGS__);                                         \
    break;                                                                     \
  case ElemKind::Int16QTy:                                                     \
    functionName<int16_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32QTy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

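// Example usage of the dispatch macro above: the call
//   dispatchQuantizedImpl(rescaleQTensorResult, ElemKind::Int8QTy, result,
//                         newMin, newMax);
// expands to rescaleQTensorResult<int8_t>(result, newMin, newMax).

/// Requantize the payload of \p oldT into \p rescaledT using quantization
/// parameters chosen for the range [\p newMin, \p newMax]. Roughly, each
/// element is mapped as q' = round(s_old * (q - o_old) / s_new) + o_new,
/// clamped to the range of the element type.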
template <typename eTy>
void rescaleQTensor(const Tensor &oldT, Tensor &rescaledT, float newMin,
                    float newMax) {
  const Type &oldTy = oldT.getType();
  const TensorQuantizationParams oldQParams = {oldTy.getScale(),
                                               oldTy.getOffset()};
  const TensorQuantizationParams newQParams = chooseQuantizationParams(
      {newMin, newMax}, quantization::Asymmetric, oldTy.getElementType());

  // Setup Tensor to copy rescaled Tensor into.
  Type rescaledTy(oldTy.getElementType(), oldTy.dims(), newQParams.scale,
                  newQParams.offset);
  rescaledT.reset(rescaledTy);

  auto srcH = oldT.getHandle<eTy>();
  auto destH = rescaledT.getHandle<eTy>();
  for (size_t i = 0, e = destH.size(); i < e; ++i) {
    float val = quantization::dequantize(srcH.raw(i), oldQParams);
    destH.raw(i) = quantization::quantize(val, newQParams);
  }
}

/// Given \p result, rescale its tensor to the range [\p newMin, \p newMax]
/// and update its type accordingly.
template <typename eTy>
void rescaleQTensorResult(LoadWeightResult &result, float newMin,
                          float newMax) {
  // Get new type based on newMin/newMax and old elem kind.
  auto rescaledT = glow::make_unique<Tensor>();
  rescaleQTensor<eTy>(*result.t, *rescaledT, newMin, newMax);
  result.t = std::move(rescaledT);
  result.type = result.t->getType();
}

/// Contains loaders for operators which are common to the ONNX and Caffe2
/// formats. Every loader method adds the necessary nodes to the member G_,
/// which is inherited from the ProtobufLoader class, thereby modifying the
/// class instance itself.
template <typename OpType, typename AttrType>
class CommonOperatorLoader : public ProtobufLoader {
  /// Loads the onnxTensorDescriptorV1 \p in and \returns a LoadWeightResult
  /// where result.t is the main contents of the onnxTensorDescriptorV1 and
  /// result.offsets and result.scales are the quantization scales and offsets
  /// of the onnxTensorDescriptorV1 if there was more than one. If there is
  /// exactly one scale and offset then result.t will be a quantized Glow
  /// tensor.
  inline Expected<LoadWeightResult>
  loadWeight(const onnxTensorDescriptorV1 &in) {
    // Only support CPU memory tensors.
    if (in.memoryType != ONNXIFI_MEMORY_TYPE_CPU) {
      return MAKE_ERR("Only support CPU memory tensors.");
    }

    // Number of qparams in the onnxTensorDescriptor.
    const dim_t qparams = static_cast<dim_t>(in.quantizationParams);

    // Only support quantizationAxis=1 for now.
    if (qparams > 0 && in.quantizationAxis != 1) {
      return MAKE_ERR(strFormat(
          "Glow can only import quantized tensors with quantizationAxis=1 but "
          "the tensor %s has quantizationAxis=%u",
          in.name, in.quantizationAxis));
    }

    LoadWeightResult result;
    result.t = glow::make_unique<Tensor>();

    std::vector<dim_t> dims;
    for (unsigned i = 0; i < in.dimensions; ++i) {
      dims.push_back(in.shape[i]);
    }

    // Load unquantized tensor.
    if (in.quantizationParams == 0) {
      if (in.dataType == ONNXIFI_DATATYPE_FLOAT32) {
        result.type = Type(ElemKind::FloatTy, dims);
      } else if (in.dataType == ONNXIFI_DATATYPE_FLOAT16) {
        result.type = Type(ElemKind::Float16Ty, dims);
      } else if (in.dataType == ONNXIFI_DATATYPE_BFLOAT16) {
        result.type = Type(ElemKind::BFloat16Ty, dims);
      } else if (in.dataType == ONNXIFI_DATATYPE_INT32) {
        result.type = Type(ElemKind::Int32ITy, dims);
      } else if (in.dataType == ONNXIFI_DATATYPE_INT64) {
        result.type = Type(ElemKind::Int64ITy, dims);
      } else if (in.dataType == ONNXIFI_DATATYPE_UINT8) {
        // UInt8 type is used for a variety of rowwise quantized SLSs.
        // Make a dummy scale and offset for these cases.
        result.type = Type(ElemKind::UInt8QTy, dims, 1.0, 0);
      } else if (in.dataType == ONNXIFI_DATATYPE_UINT64) {
        result.type = Type(ElemKind::Int64ITy, dims);
        for (size_t i = 0, e = result.type.size(); i < e; ++i) {
          RETURN_ERR_IF_NOT(
              ((int64_t *)in.buffer)[i] >= 0,
              "Disallow overflow of loaded UINT64 data into Int64ITy.");
        }
      } else {
        return MAKE_ERR(strFormat(
            "Only float, index, and uint8 unquantized tensors are supported, "
            "got input with ONNXIFI_DATATYPE: %zu",
            static_cast<size_t>(in.dataType)));
      }
      if (!in.isOffline) {
        *result.t = Tensor((void *)in.buffer, &result.type);
      }
      return Expected<LoadWeightResult>(std::move(result));
    }

    // Load quantized tensor with either a single or multiple qparams.
    float scale = 1.0;
    int32_t offset = 0;

    // If multiple qparams are present then load them as tensors and use the
    // default qparams for result.t; otherwise use the first (only) qparams.
    if (in.quantizationParams == 1) {
      scale = in.scales[0];
      offset = in.biases[0];
    } else {
      RETURN_ERR_IF_NOT(!loadUniquedDummyQParams_,
                        strFormat("Unsupported loading of uniqued qparams for "
                                  "vector of scales/biases for %s",
                                  in.name));
      Type scalesTy(ElemKind::FloatTy, llvm::makeArrayRef({qparams}));
      Type offsetsTy(ElemKind::Int32ITy, llvm::makeArrayRef({qparams}));
      result.scales = glow::make_unique<Tensor>((void *)in.scales, &scalesTy);
      result.offsets = glow::make_unique<Tensor>((void *)in.biases, &offsetsTy);
    }

    // If we have a scale of dummyScale, then this must be a dummy pair of
    // scale/offset. Look up the actual scale/offset to use as previously
    // loaded, using the offset as the key to updatedTQPs_.
    if (replaceDummyTQPs_ && scale == dummyScale) {
      TensorQuantizationParams TQP;
      ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(offset));
      scale = TQP.scale;
      offset = TQP.offset;
    }

    if (in.dataType == ONNXIFI_DATATYPE_UINT8) {
      TypeRef outTy;
      ASSIGN_VALUE_OR_RETURN_ERR(
          outTy, ProtobufLoader::loadQuantTy(
                     in.name, ElemKind::Int8QTy, dims, scale, offset,
                     /* shiftUInt8ToInt8 */ true,
                     /* skipClipQuantRangeToFP16 */ true));
      // Must copy the weights here because we will need to modify them by
      // adjusting for UINT8_TO_INT8_SHIFT.
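      // (Illustration: with a shift of 128, a stored uint8 value of 200
      // becomes the int8 value 72; loadQuantTy compensates by lowering the
      // offset by the same shift.)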
      result.type = *outTy;
      if (!in.isOffline) {
        result.t->reset(result.type);

        auto TH = result.t->getHandle<int8_t>();
        uint8_t *data = (uint8_t *)in.buffer;
        for (size_t i = 0; i < TH.size(); ++i) {
          TH.raw(i) = (int8_t)(data[i] - UINT8_TO_INT8_SHIFT);
        }
      }
    } else if (in.dataType == ONNXIFI_DATATYPE_INT32) {
      TypeRef outTy;
      ASSIGN_VALUE_OR_RETURN_ERR(
          outTy, ProtobufLoader::loadQuantTy(
                     in.name, ElemKind::Int32QTy, dims, scale, offset,
                     /* shiftUInt8ToInt8 */ true,
                     /* skipClipQuantRangeToFP16 */ true));
      result.type = *outTy;
      if (!in.isOffline) {
        *result.t = Tensor((void *)in.buffer, &result.type);
      }
    } else if (in.dataType == ONNXIFI_DATATYPE_INT8) {
      TypeRef outTy;
      ASSIGN_VALUE_OR_RETURN_ERR(
          outTy, ProtobufLoader::loadQuantTy(
                     in.name, ElemKind::Int8QTy, dims, scale, offset,
                     /* shiftUInt8ToInt8 */ false,
                     /* skipClipQuantRangeToFP16 */ true));
      result.type = *outTy;
      if (!in.isOffline) {
        *result.t = Tensor((void *)in.buffer, &result.type);
      }
    } else {
      return MAKE_ERR(
          strFormat("Only uint8, int32, and int8 quantized tensors are "
                    "supported, got input with ONNXIFI_DATATYPE: %zu",
                    static_cast<size_t>(in.dataType)));
    }

    // If we're clipping quantized ranges to FP16, then we need to rescale the
    // Tensor and update its type, plus the type in result.
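    // The float range representable by a quantized type is
    //   [scale * (qmin - offset), scale * (qmax - offset)],
    // e.g. [scale * (-128 - offset), scale * (127 - offset)] for Int8QTy;
    // getQuantizedValueRange() below returns this pair.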
    if (clipQuantRangeToFP16_) {
      const ElemKind k = result.type.getElementType();
      const auto qMinMax = getQuantizedValueRange(scale, offset, k);
      const float newMin = std::max(qMinMax.first, kMinFP16);
      const float newMax = std::min(qMinMax.second, kMaxFP16);

      // If min or max are clipped then create a new Tensor with the adjusted
      // type, and rescale its payload.
      if (newMin != qMinMax.first || newMax != qMinMax.second) {
        RETURN_ERR_IF_NOT(
            !in.isOffline,
            strFormat("For clipQuantRangeToFP16, currently do "
                      "not support offline quantized weights: %s",
                      in.name));
        RETURN_ERR_IF_NOT(!result.offsets && !result.scales,
                          strFormat("For clipQuantRangeToFP16, currently do "
                                    "not support multiple qparams: %s",
                                    in.name));

        dispatchQuantizedImpl(rescaleQTensorResult, k, result, newMin, newMax);
      }
    }

    return Expected<LoadWeightResult>(std::move(result));
  }

  /// Merge shape \p shape into \p mergeShape, following multidirectional
  /// broadcasting rules.
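  /// For example, merging shape {4, 1} into mergeShape {2, 4, 5} leaves
  /// mergeShape as {2, 4, 5}, whereas merging {4, 6} fails because 6
  /// conflicts with 5.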
  Error mergeMultidirectionalBroadcast(std::vector<dim_t> &mergeShape,
                                       llvm::ArrayRef<dim_t> shape) {
    size_t shift = mergeShape.size() - shape.size();
    for (size_t i = 0; i < shape.size(); i++) {
      if (shape[i] != 1) {
        RETURN_ERR_IF_NOT((shape[i] == mergeShape[shift + i]) ||
                              (mergeShape[shift + i] == 1),
                          "Incompatible dimension for the broadcast");
        mergeShape[shift + i] = shape[i];
      }
      // Otherwise, just leave mergeShape[shift + i] as it is.
    }
    return Error::success();
  }

protected:
  CommonOperatorLoader(llvm::ArrayRef<const char *> names,
                       llvm::ArrayRef<TypeRef> types, Function *F,
                       Error *errPtr = nullptr,
                       bool loadIntoExistingModule = false,
                       OriginNameToTQPMap *originNameToTQPMap = nullptr,
                       bool loadUniquedDummyQParams = false,
                       bool zeroScaleFP16Clip = false,
                       bool clipQuantRangeToFP16 = false)
      : ProtobufLoader(names, types, F, errPtr, loadIntoExistingModule,
                       originNameToTQPMap, loadUniquedDummyQParams,
                       /* replaceDummyTQPs */ false, zeroScaleFP16Clip,
                       clipQuantRangeToFP16) {}

  CommonOperatorLoader(
      llvm::ArrayRef<const char *> names, llvm::ArrayRef<TypeRef> types,
      Module &mod, Error *errPtr = nullptr, bool loadIntoExistingModule = false,
      OriginNameToTQPMap *originNameToTQPMap = nullptr,
      bool loadUniquedDummyQParams = false, bool replaceDummyTQPs = false,
      bool zeroScaleFP16Clip = false, bool clipQuantRangeToFP16 = false)
      : ProtobufLoader(names, types, mod, errPtr, loadIntoExistingModule,
                       originNameToTQPMap, loadUniquedDummyQParams,
                       replaceDummyTQPs, zeroScaleFP16Clip,
                       clipQuantRangeToFP16) {}

  using ArgumentDictionaryTy =
      std::unordered_map<std::string, const AttrType *>;

  /// If we were replacing or loading dummy TQPs, \returns success only if the
  /// module is in the expected state: no dummies left when replacing them, or
  /// only dummies left when loading uniqued dummy qparams.
  Error verifyDummyQParams() {
    RETURN_ERR_IF_NOT(!(replaceDummyTQPs_ && loadUniquedDummyQParams_),
                      "Cannot replace dummy TQPs when loading uniqued TQPs.");
    if (replaceDummyTQPs_ || loadUniquedDummyQParams_) {
      RETURN_IF_ERR(mod_.verifyDummyQParams(loadUniquedDummyQParams_));
    }
    return Error::success();
  }

  /// Helper to load quantization parameters from \p dict for op named \p name.
  /// \returns a new TypeRef given \p k and \p dims.
  Expected<TypeRef> loadQuantTy(const std::string &name, ElemKind k,
                                llvm::ArrayRef<dim_t> dims,
                                ArgumentDictionaryTy &dict,
                                bool skipClipQuantRangeToFP16 = false) {
    RETURN_ERR_IF_NOT(dict.count("Y_scale"),
                      "missing Y_scale for quantized output type for " + name);
    RETURN_ERR_IF_NOT(dict.count("Y_zero_point"),
                      "missing zero point for quantized output type for " +
                          name);

    float scale;
    ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict["Y_scale"]));
    int32_t offset;
    ASSIGN_VALUE_OR_RETURN_ERR(offset, loadInt(dict["Y_zero_point"]));

    return ProtobufLoader::loadQuantTy(name, k, dims, scale, offset,
                                       /* shiftUInt8ToInt8 */ true,
                                       skipClipQuantRangeToFP16);
  }

  /// \returns True if the operator has broadcasting activated.
  virtual Expected<bool> getBroadcast(ArgumentDictionaryTy &dict) = 0;

  /// \returns True if the operator with the name \p typeName has support
  /// for multidirectional broadcasting.
  virtual bool hasMultidirectionalBroadcast(const llvm::StringRef typeName) = 0;

  inline Expected<LengthsMode> getLengthsMode(ArgumentDictionaryTy &dict) {
    bool length1 = false;
    if (dict.count("length1")) {
      ASSIGN_VALUE_OR_RETURN_ERR(length1, loadInt(dict["length1"]));
    }
    if (length1) {
      return LengthsMode::AllOne;
    }
    return LengthsMode::Variable;
  }

  inline Expected<float> getAvgLength(ArgumentDictionaryTy &dict) {
    float avgLength = NAN;
    if (dict.count("average_lookup_length")) {
      ASSIGN_VALUE_OR_RETURN_ERR(avgLength,
                                 loadFloat(dict["average_lookup_length"]));
    }
    return avgLength;
  }

  const std::string opErrMsg(const OpType &op, const std::string &errMsg) {
    const std::string &opName = loadOperatorName(op);
    return strFormat(" [Operator-'%s'] : %s ", opName.c_str(), errMsg.c_str());
  }

  /// Associates the names of the operation's outputs with NodeValues
  /// corresponding to node \p node. If \p numOutputs is lower than 0, then
  /// all outputs are associated. Otherwise, the first \p numOutputs outputs
  /// are associated.
  Error addNodeAsOutput(const OpType &op, Node *node, int numOutputs = -1) {
    RETURN_ERR_IF_NOT(numOutputs <= op.output_size(),
                      "Can't register more outputs than the operation has.");
    numOutputs = (numOutputs < 0) ? op.output_size() : numOutputs;
    for (int i = 0; i < numOutputs; i++) {
      nodeValueByName_[op.output(i)] = NodeValue(node, i);
    }
    return Error::success();
  }

  /// Loads RELU operator, given its protobuf representation and parsed args.
  Error loadRelu(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *R = G_->createRELU(opName, in);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

#define LOAD_UNARY_OP(OPNAME)                                                  \
  Error load##OPNAME(const OpType &op, ArgumentDictionaryTy &dict) {           \
    const std::string &opName = loadOperatorName(op);                          \
    NodeValue in;                                                              \
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));           \
    auto *T = G_->create##OPNAME(opName, in);                                  \
    RETURN_IF_ERR(addNodeAsOutput(op, T));                                     \
    return Error::success();                                                   \
  }

  LOAD_UNARY_OP(Sigmoid)
  LOAD_UNARY_OP(Tanh)
  LOAD_UNARY_OP(Exp)
  LOAD_UNARY_OP(Neg)
  LOAD_UNARY_OP(Floor)
  LOAD_UNARY_OP(Ceil)
  LOAD_UNARY_OP(Truncate)
  LOAD_UNARY_OP(Log)

  Error loadShape(const OpType &op, ArgumentDictionaryTy &dict) {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    // This is statically known data, and so we create a Tensor for it.
    Tensor T(ElemKind::Int64ITy, {(dim_t)in.dims().size()});
    T.getHandle<int64_t>() =
        std::vector<int64_t>(in.dims().begin(), in.dims().end());

    RETURN_IF_ERR(createAndRegisterConstant(op.output(0), std::move(T)));

    return Error::success();
  }

  /// Loads Pow operator, given its protobuf representation and parsed args.
  Error loadPow(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue base;
    ASSIGN_VALUE_OR_RETURN_ERR(base, getNodeValueByName(op.input(0)));
    NodeValue exp;
    ASSIGN_VALUE_OR_RETURN_ERR(exp, getNodeValueByName(op.input(1)));
    auto R = G_->createNodeWithBroadcast<PowNode>(opName, -1, base, exp);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  /// Loads Sqrt operator, given its protobuf representation and parsed args.
  Error loadSqrt(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *R = G_->createPow(opName, in, 0.5f);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  /// Loads Sqr operator, given its protobuf representation and parsed args.
  Error loadSqr(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *R = G_->createPow(opName, in, 2.0f);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  /// Loads Reciprocal operator, given its protobuf representation and parsed
  /// args.
  Error loadReciprocal(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *R = G_->createPow(opName, in, -1.0f);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  Error loadSum(const OpType &op, ArgumentDictionaryTy &dict) {
    if (op.input_size() == 1) {
      NodeValue in;
      ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
      RETURN_IF_ERR(addNodeAsOutput(op, in));
    } else if (op.input_size() == 2) {
      const std::string &opName = loadOperatorName(op);
      NodeValue in0;
      ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
      NodeValue in1;
      ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
      auto *node = G_->createAdd(opName, in0, in1);
      RETURN_IF_ERR(addNodeAsOutput(op, node));
    } else {
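      // General case (more than two inputs): give each input a leading unit
      // dimension, concatenate along it, then reduce-add it away. E.g. three
      // {4, 5} tensors become {1, 4, 5} each, are concatenated to {3, 4, 5},
      // and reduced back to {4, 5}.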
      const std::string &opName = loadOperatorName(op);
      const unsigned numInputs = op.input_size();
      llvm::SmallVector<NodeValue, 4> inputs;
      inputs.reserve(numInputs);
      for (unsigned i = 0; i < numInputs; i++) {
        NodeValue in;
        ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
        inputs.push_back(G_->createExpandDims(opName, in, {0}));
      }
      ConcatNode *concat = G_->createConcat(opName, inputs, /* axis */ 0);
      Node *node = G_->createBatchedReduceAdd(opName, concat, /* axis */ {0});
      RETURN_IF_ERR(addNodeAsOutput(op, node));
    }
    return Error::success();
  }

  Error loadLRN(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    size_t size;
    ASSIGN_VALUE_OR_RETURN_ERR(size, loadInt(dict["size"]));
    float alpha;
    ASSIGN_VALUE_OR_RETURN_ERR(alpha, loadFloat(dict["alpha"]));
    float beta;
    ASSIGN_VALUE_OR_RETURN_ERR(beta, loadFloat(dict["beta"]));
    float k;
    ASSIGN_VALUE_OR_RETURN_ERR(k, loadFloat(dict["bias"]));

    auto *tr = G_->createTranspose(opName, in, NCHW2NHWC);

    auto *node = G_->createLocalResponseNormalization(opName, tr, size / 2,
                                                      alpha, beta, k);

    auto *N = G_->createTranspose(opName, node, NHWC2NCHW);

    // LRN in Caffe2 has a scale_ output, but I believe it's unused for
    // inference. So explicitly only set output 0.
    nodeValueByName_[op.output(0)] = N->getResult();
    return Error::success();
  }

  Error loadMinMax(llvm::StringRef typeName, const OpType &op,
                   ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));

    Node *node = nullptr;
    if (typeName == "Min") {
      node = G_->createNodeWithBroadcast<MinNode>(opName, -1, in0, in1);
    } else if (typeName == "Max") {
      node = G_->createNodeWithBroadcast<MaxNode>(opName, -1, in0, in1);
    } else {
      return MAKE_ERR(opErrMsg(op, "Invalid min or max operator"));
    }

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

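  /// If \p dict marks \p key (e.g. "trans_a" or "trans_b") as transposed,
  /// \returns a transpose of \p input that swaps its last two dimensions;
  /// otherwise \returns \p input unchanged. For example, a rank-3 input of
  /// shape {B, N, M} is transposed with shuffle {0, 2, 1} into {B, M, N}.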
  static Expected<NodeValue> handleMatMulTranspose(Function *F,
                                                   ArgumentDictionaryTy &dict,
                                                   llvm::StringRef key,
                                                   NodeValue input) {
    if (!dict.count(key.str())) {
      return input;
    }

    int isTransposed;
    ASSIGN_VALUE_OR_RETURN_ERR(isTransposed, loadInt(dict[key.str()]));
    if (isTransposed == 1) {
      auto dimsSize = input.dims().size();
      RETURN_ERR_IF_NOT(dimsSize >= 2,
                        "C2 specs say rank of inputs must be >= 2");

      std::vector<unsigned_t> shuffle;
      unsigned_t i;
      for (i = 0; i < dimsSize - 2; ++i) {
        shuffle.push_back(i);
      }
      shuffle.push_back(i + 1);
      shuffle.push_back(i);

      return F->createTranspose(input.getNode()->getName().str() + ".transpose",
                                input, shuffle);
    }

    return input;
  }

  Error loadMatMul(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue LHS;
    ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0)));
    NodeValue RHS;
    ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1)));

    ASSIGN_VALUE_OR_RETURN_ERR(LHS,
                               handleMatMulTranspose(G_, dict, "trans_a", LHS));
    ASSIGN_VALUE_OR_RETURN_ERR(RHS,
                               handleMatMulTranspose(G_, dict, "trans_b", RHS));

    Node *node = G_->createMatMul(opName, LHS, RHS);

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadBatchMatMul(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue LHS;
    ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0)));
    NodeValue RHS;
    ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1)));

    ASSIGN_VALUE_OR_RETURN_ERR(LHS,
                               handleMatMulTranspose(G_, dict, "trans_a", LHS));
    ASSIGN_VALUE_OR_RETURN_ERR(RHS,
                               handleMatMulTranspose(G_, dict, "trans_b", RHS));

    const size_t numDimsLHS = LHS.dims().size();
    const size_t numDimsRHS = RHS.dims().size();
    RETURN_ERR_IF_NOT(
        numDimsLHS >= 2,
        opErrMsg(op, "BatchMatMul 1D operands are not yet supported."));
    RETURN_ERR_IF_NOT(
        numDimsRHS >= 2,
        opErrMsg(op, "BatchMatMul 1D operands are not yet supported."));

    // This is the simple case where we don't need any broadcasting.
    if (numDimsLHS == 2 && numDimsRHS == 2) {
      Node *node = G_->createMatMul(opName, LHS, RHS);
      RETURN_IF_ERR(addNodeAsOutput(op, node));
      return Error::success();
    }

    // In the rest of the function body we:
    // 1. normalize operands using broadcasting rules,
    // 2. convert normalized operands to 3D matrices, so they look like these:
    //    LHS = {numBatches, N, M}
    //    RHS = {numBatches, M, P}
    //    Result = {numBatches, N, P},
    // 3. multiply 3D matrices using createBatchMatMul(), result will be 3D,
    // 4. convert the result to the normalized broadcast shape.
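    // For example (a sketch): LHS {5, 2, N, M} and RHS {2, M, P} broadcast
    // to the batch shape {5, 2}, are reshaped to {10, N, M} and {10, M, P},
    // multiplied into {10, N, P}, and finally reshaped to {5, 2, N, P}.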

    const dim_t N = LHS.dims()[numDimsLHS - 2];
    const dim_t M = LHS.dims()[numDimsLHS - 1];
    const dim_t P = RHS.dims()[numDimsRHS - 1];

    RETURN_ERR_IF_NOT(
        RHS.dims()[numDimsRHS - 2] == M,
        opErrMsg(op, "BatchMatMul operands dimensions are invalid."));

    // Calculate broadcast shape and convert both operands to that shape.
    const std::vector<dim_t> originalDimsLHS{LHS.dims().begin(),
                                             LHS.dims().end()};
    const std::vector<dim_t> originalDimsRHS{RHS.dims().begin(),
                                             RHS.dims().end()};
    std::vector<dim_t> resultShape{P, N};
    resultShape.reserve(std::max(numDimsLHS, numDimsRHS));
    dim_t numBatches = 1;
    int indLHS = numDimsLHS - 3; // skip last two dims
    int indRHS = numDimsRHS - 3; // skip last two dims
    for (; indLHS >= 0 && indRHS >= 0; --indLHS, --indRHS) {
      const dim_t dimLHS = originalDimsLHS[indLHS];
      const dim_t dimRHS = originalDimsRHS[indRHS];

      RETURN_ERR_IF_NOT(
          (dimLHS == dimRHS || (dimLHS == 1) || dimRHS == 1),
          opErrMsg(op, "BatchMatMul dimensions cannot be broadcast."));
      dim_t dim = 1;
      if (dimLHS == dimRHS) {
        dim = dimLHS;
      } else if (dimLHS == 1) {
        dim = dimRHS;
        LHS = G_->createTile(opName + ".tileDim", LHS, dim, indLHS);
      } else {
        dim = dimLHS;
        RHS = G_->createTile(opName + ".tileDim", RHS, dim, indRHS);
      }
      resultShape.push_back(dim);
      numBatches *= dim;
    }
    for (; indLHS >= 0; --indLHS) {
      const dim_t dim = originalDimsLHS[indLHS];
      resultShape.push_back(dim);
      numBatches *= dim;
      RHS = G_->createExpandDims(opName + ".addDim", RHS, {0});
      RHS = G_->createTile(opName + ".tileDim", RHS, dim, 0);
    }
    for (; indRHS >= 0; --indRHS) {
      const dim_t dim = originalDimsRHS[indRHS];
      resultShape.push_back(dim);
      numBatches *= dim;
      LHS = G_->createExpandDims(opName + ".addDim", LHS, {0});
      LHS = G_->createTile(opName + ".tileDim", LHS, dim, 0);
    }
    std::reverse(resultShape.begin(), resultShape.end());

    // Broadcast shape might have more than 3 dims,
    // therefore, optionally, reshape the operands.
    if (resultShape.size() > 3) {
      LHS =
          G_->createReshape(opName + ".reshapeLHS3D", LHS, {numBatches, N, M});
      RHS =
          G_->createReshape(opName + ".reshapeRHS3D", RHS, {numBatches, M, P});
    }
    Node *node = G_->createBatchMatMul(opName, LHS, RHS);

    // Optionally, reshape result to broadcast shape.
    if (resultShape.size() != 3) {
      node = G_->createReshape(opName + ".reshapeResult", node, resultShape);
    }

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadArithmetic(llvm::StringRef typeName, const OpType &op,
                       ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));

    bool broadcast;
    ASSIGN_VALUE_OR_RETURN_ERR(broadcast, getBroadcast(dict));
    // Check implicit broadcast.
    if (!broadcast && in0.dims().size() != in1.dims().size()) {
      bool validBroadcast = true;
      auto dimsA = in0.dims();
      auto dimsB = in1.dims();
      for (int i = dimsA.size() - 1, j = dimsB.size() - 1; i >= 0 && j >= 0;) {
        auto a = dimsA[i];
        auto b = dimsB[j];
        if (!(a == b || a == 1 || b == 1)) {
          validBroadcast = false;
          break;
        }
        --i;
        --j;
      }
      if (!validBroadcast) {
        LOG(WARNING) << "Invalid broadcast rule for inputs of " << opName;
      }
      broadcast = validBroadcast;
    }

    int axis = -1;

    // Broadcasting can be:
    // - multidirectional (ONNX opset 7+), or
    // - unidirectional (ONNX opset 1->6, Caffe2).

    // Unidirectional broadcasting consists of broadcasting the right operand
    // (in1) so that it matches the shape of the left operand (in0).
    if (broadcast && !hasMultidirectionalBroadcast(typeName)) {
      // With unidirectional broadcasting, the 'axis' attribute specifies by
      // how much the right operand's shape must be shifted to the right.
      // - In Caffe2, the 'axis' attribute is optional. If not specified, axis
      //   must be automatically computed so that the trailing-most dimensions
      //   of in1 are aligned to the trailing-most dimension of in0.
      // - In ONNX, the 'axis' attribute is mandatory. axis == -1 is
      //   equivalent to no axis specified in Caffe2.
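      // For example, with in0 of shape {2, 3, 4, 5} and in1 of shape {4, 5},
      // the default axis = 4 - 2 = 2, aligning in1 with the last two
      // dimensions of in0.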

      if (dict.count("axis")) {
        ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
      }
      if (axis == -1) {
        // Align trailing most dimensions.
        axis = in0.dims().size() - in1.dims().size();
      }
    }

    Node *node = nullptr;
    if (broadcast) {
      if (typeName == "Mul") {
        node = G_->createNodeWithBroadcast<MulNode>(opName, axis, in0, in1);
      } else if (typeName == "Add") {
        node = G_->createNodeWithBroadcast<AddNode>(opName, axis, in0, in1);
      } else if (typeName == "Sub") {
        node = G_->createNodeWithBroadcast<SubNode>(opName, axis, in0, in1);
      } else if (typeName == "Div") {
        node = G_->createNodeWithBroadcast<DivNode>(opName, axis, in0, in1);
      } else if (typeName == "Pow") {
        node = G_->createNodeWithBroadcast<PowNode>(opName, axis, in0, in1);
      } else {
        return MAKE_ERR("Unsupported arithmetic typeName");
      }
    } else {
      if (typeName == "Mul") {
        node = G_->createMul(opName, in0, in1);
      } else if (typeName == "Add") {
        node = G_->createAdd(opName, in0, in1);
      } else if (typeName == "Sub") {
        node = G_->createSub(opName, in0, in1);
      } else if (typeName == "Div") {
        node = G_->createDiv(opName, in0, in1);
      } else if (typeName == "Pow") {
        node = G_->createPow(opName, in0, in1);
      } else {
        return MAKE_ERR("Unsupported arithmetic typeName");
      }
    }

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadSplit(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    size_t axis = 0;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis, loadAxis<size_t>(dict["axis"], in.dims().size()));
    }

    std::vector<dim_t> split;
    if (dict.count("split")) {
      ASSIGN_VALUE_OR_RETURN_ERR(split, getShape<dim_t>(dict["split"]));
    }

    std::vector<SliceNode *> outputs;
    G_->createSplit(opName, in, op.output_size(), axis, split, outputs);

    for (int i = 0, e = op.output_size(); i < e; i++) {
      // Each output from Split is a SliceNode which only has a single output,
      // so only use 0 here as the node value result.
      nodeValueByName_[op.output(i)] = outputs[i]->getResult();
    }
    return Error::success();
  }

  Error loadReshape(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    // Get the requested shape from the model.
    // First look at input tensors, then at the "shape" attribute.
    std::vector<dim_t> requestedDims;
    if (op.input_size() > 1) {
      if (!getConstantByNameOrNull(op.input(1))) {
        return MAKE_ERR(opErrMsg(
            op,
            "Reshape: Non-constant shape tensors are unsupported by Glow."));
      }
      const Constant *constShapeConst;
      ASSIGN_VALUE_OR_RETURN_ERR(constShapeConst,
                                 getConstantByName(op.input(1)));
      auto TH = constShapeConst->getPayload().getHandle<int64_t>();
      for (auto dim : TH) {
        requestedDims.push_back(dim);
      }
    } else if (dict.count("shape")) {
      RETURN_ERR_IF_NOT(
          op.input_size() == 1,
          opErrMsg(
              op,
              "Reshape: Cannot specify new shape by both argument and input."));
      std::vector<int64_t> protoDims;
      ASSIGN_VALUE_OR_RETURN_ERR(protoDims, getShape<int64_t>(dict["shape"]));

      for (auto dim : protoDims) {
        requestedDims.push_back(dim);
      }
    } else {
      return MAKE_ERR(opErrMsg(op,
                               "Reshape: Missing output shape information for "
                               "the Reshape operator."));
    }

    // Compute the actual new shape.
    ssize_t negOneIndex = -1;
    llvm::ArrayRef<dim_t> inputDims = in.dims();
    std::vector<dim_t> outputDims;
    int64_t dimProduct = 1;
    for (size_t i = 0, e = requestedDims.size(); i != e; i++) {
      dim_t newDim = requestedDims[i];
      if (newDim == 0) {
        // 0 means that corresponding input dimension should be propagated to
        // the output.
        newDim = inputDims[i];
      }
      if (newDim != (dim_t)-1) {
        dimProduct *= newDim;
        outputDims.push_back(newDim);
      } else {
        // -1 means that the corresponding dimension should be inferred
        // from all other dimensions, so that tensor size remains the same.
        RETURN_ERR_IF_NOT(
            negOneIndex < 0,
            opErrMsg(
                op,
                "Reshape: At most one dimension of the new shape can be -1."));
        negOneIndex = (ssize_t)i;
        // The -1 case value is handled later.
        outputDims.push_back(0);
      }
    }
    if (negOneIndex >= 0) {
      outputDims[negOneIndex] = in.getType()->size() / dimProduct;
    }
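    // For example, reshaping a {2, 3, 4} input with requested shape {0, -1}
    // propagates dimension 0 (2) and infers the -1 as 24 / 2 = 12, giving
    // an output shape of {2, 12}.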

    auto *node = G_->createReshape(opName, in, outputDims);

    // Caffe2 sometimes outputs old_shape which goes unused. We do not
    // currently support it, so explicitly only set the first output.
    nodeValueByName_[op.output(0)] = node->getResult();
    return Error::success();
  }

  Error loadTranspose(const OpType &op, ArgumentDictionaryTy &dict,
                      llvm::StringRef permArgName) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    // There is a difference between the ONNX and Caffe2 specs for Transpose:
    // one keeps the permutation under the name "perm", the other under the
    // argument name "axes". That's why the name is passed as a parameter.
    std::vector<unsigned_t> perm;
    if (dict.count(permArgName.str()))
      ASSIGN_VALUE_OR_RETURN_ERR(perm,
                                 getShape<unsigned_t>(dict[permArgName.str()]));

    if (perm.empty()) {
      // Empty permutation argument means reversing axes order.
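      // E.g. for a rank-4 input the permutation becomes {3, 2, 1, 0}.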
      size_t N = in.dims().size();
      for (int64_t i = N - 1; i >= 0; i--)
        perm.push_back(i);
    }

    auto *T = G_->createTranspose(opName, in, perm);

    RETURN_IF_ERR(addNodeAsOutput(op, T));
    return Error::success();
  }

  Error loadFlatten(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    int axis = 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis,
                                 loadAxis<int>(dict["axis"], in.dims().size()));
    }
    auto *node = G_->createFlatten(opName, in, axis);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadIdentity(const OpType &op, ArgumentDictionaryTy &dict) {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    // If loading partitioned DAG then check if this identity is used for an
    // intermediate, and if so create the Save+PH with the correct name.
    if (partNameToFun_.size()) {
      int intermediate = 0;
      if (dict.count("isIntermediateOutputForDAG")) {
        ASSIGN_VALUE_OR_RETURN_ERR(
            intermediate, loadInt(dict.at("isIntermediateOutputForDAG")));
      }

      if (intermediate) {
        const std::string &opName = loadOperatorName(op);
        Placeholder *PH = mod_.getPlaceholderByNameSlow(op.output(0));
        if (!PH) {
          PH = mod_.createPlaceholder(in.getType(), op.output(0),
                                      /* isTrainable */ false);
        } else {
          RETURN_ERR_IF_NOT(
              loadIntoExistingModule_,
              opErrMsg(op, "Found pre-existing PH by name " + op.output(0)));
          RETURN_ERR_IF_NOT(
              PH->getType()->isEqual(in.getType()),
              opErrMsg(op, "Mismatch on pre-existing intermediate PH type"));
        }
        G_->createSave(opName, in, PH, /* skipSuffix */ true);
        intermediatePHsByName_[op.output(0)] = PH;
        in = PH->getOutput();
      }
    }

    nodeValueByName_[op.output(0)] = in;
    return Error::success();
  }

  Error loadReduceOp(llvm::StringRef typeName, const OpType &op,
                     ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    std::vector<unsigned_t> shapeAxes = {};
    if (dict.count("axes")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          shapeAxes, loadAxes<unsigned_t>(dict["axes"], in.dims().size()));
    } else {
      shapeAxes.resize(in.dims().size());
      std::iota(shapeAxes.begin(), shapeAxes.end(), 0);
    }

    std::sort(shapeAxes.begin(), shapeAxes.end());

    llvm::ArrayRef<unsigned_t> axes(shapeAxes);

    // Check if axes elements are unique.
    if (axes.size() > 1) {
      auto it = std::unique(shapeAxes.begin(), shapeAxes.end());
      if (it != shapeAxes.end()) {
        return MAKE_ERR(opErrMsg(op, "ReduceOp Axes values are not unique."),
                        ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
      }
    }

    bool keepDims = true;
    if (dict.count("keepdims")) {
      int keepdims;
      ASSIGN_VALUE_OR_RETURN_ERR(keepdims, loadInt(dict["keepdims"]));
      keepDims = (bool)keepdims;
    }

    NodeValue node;
    if (typeName == "ReduceMean") {
      node = G_->createBatchedReduceMean(opName, in, axes);
    } else if (typeName == "ReduceSum") {
      node = G_->createBatchedReduceAdd(opName, in, axes);
    } else if (typeName == "ReduceMin") {
      node = G_->createBatchedReduceMin(opName, in, axes);
    } else if (typeName == "ReduceMax") {
      node = G_->createBatchedReduceMax(opName, in, axes);
    } else if (typeName == "ReduceProd") {
      node = G_->createBatchedReduceProd(opName, in, axes);
    } else if (typeName == "ReduceSumSquare") {
      node = G_->createBatchedReduceSumSquare(opName, in, axes);
    } else {
      return MAKE_ERR("Unsupported Reduce Op " + typeName.str());
    }

    // Our batched reduce add/mean does not keep the dim; reshape if necessary.
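    // For example, reducing a {2, 3, 4} input over axes {0, 2} yields {3};
    // with keepdims, the reshape below restores unit dims to give {1, 3, 1}.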
    if (keepDims) {

      std::vector<dim_t> shape = node.dims();

      // Add removed axes. Requires increasing order sort - done above.
      for (const auto &axis : shapeAxes) {
        shape.insert(shape.begin() + axis, 1);
      }
      node = G_->createReshape(opName, node, shape);
    }

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadBatchOneHot(const OpType &op) {
    const std::string &opName = loadOperatorName(op);
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue lengths;
    ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(1)));
    NodeValue values;
    ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2)));

    auto *node = G_->createBatchOneHot(opName, data, lengths, values);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadSparseLengthsSum(const OpType &op, ArgumentDictionaryTy &dict) {
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
    NodeValue in2;
    ASSIGN_VALUE_OR_RETURN_ERR(in2, getNodeValueByName(op.input(2)));
    LengthsMode lengthsMode;
    ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
    float avgLength;
    ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict));
    auto *node = G_->createSparseLengthsSum(loadOperatorName(op), in0, in1, in2,
                                            lengthsMode, avgLength);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadSparseLengthsWeightedSum(const OpType &op,
                                     ArgumentDictionaryTy &dict) {
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
    NodeValue in2;
    ASSIGN_VALUE_OR_RETURN_ERR(in2, getNodeValueByName(op.input(2)));
    NodeValue in3;
    ASSIGN_VALUE_OR_RETURN_ERR(in3, getNodeValueByName(op.input(3)));
    LengthsMode lengthsMode;
    ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
    float avgLength;
    ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict));
    auto *node = G_->createSparseLengthsWeightedSum(
        loadOperatorName(op), in0, in1, in2, in3, lengthsMode, avgLength);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadEmbedding(const OpType &op, ArgumentDictionaryTy &dict) {
    NodeValue weights;
    ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(0)));

    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));

    int64_t padIdx = -1;
    if (dict.count("padIdx")) {
      ASSIGN_VALUE_OR_RETURN_ERR(padIdx, loadInt(dict["padIdx"]));
    }
    bool scale = false;
    if (dict.count("scale")) {
      ASSIGN_VALUE_OR_RETURN_ERR(scale, loadInt(dict["scale"]));
      scale = (bool)scale;
      RETURN_ERR_IF_NOT(scale == false,
                        "Currently only support scale_grad_by_freq == 'false'");
    }
    bool sparse = false;
    if (dict.count("sparse")) {
      ASSIGN_VALUE_OR_RETURN_ERR(sparse, loadInt(dict["sparse"]));
      sparse = (bool)sparse;
      RETURN_ERR_IF_NOT(sparse == false,
                        "Currently only support sparse == 'false'");
    }
    auto *node = G_->createEmbedding(loadOperatorName(op), weights, indices,
                                     padIdx, scale, sparse);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadEmbeddingBag(const OpType &op, ArgumentDictionaryTy &dict) {
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
    NodeValue in2;
    ASSIGN_VALUE_OR_RETURN_ERR(in2, getNodeValueByName(op.input(2)));
    NodeValue in3;
    ASSIGN_VALUE_OR_RETURN_ERR(in3, getNodeValueByName(op.input(3)));
    LengthsMode lengthsMode;
    ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
    float avgLength;
    ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict));
    auto *node = G_->createEmbeddingBag(
        loadOperatorName(op), in0, in1, in2, in3,
        /* hasEndOffset */ false, lengthsMode, avgLength);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadLengthsToRanges(const OpType &op) {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *node = G_->createLengthsToRanges(loadOperatorName(op), in);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadBatchBoxCox(const OpType &op) {
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue lambda1;
    ASSIGN_VALUE_OR_RETURN_ERR(lambda1, getNodeValueByName(op.input(1)));
    NodeValue lambda2;
    ASSIGN_VALUE_OR_RETURN_ERR(lambda2, getNodeValueByName(op.input(2)));
    auto *node =
        G_->createBatchBoxCox(loadOperatorName(op), data, lambda1, lambda2);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadDotProduct(const OpType &op) {
    NodeValue X;
    ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));
    NodeValue Y;
    ASSIGN_VALUE_OR_RETURN_ERR(Y, getNodeValueByName(op.input(1)));
    RETURN_ERR_IF_NOT(X.dims() == Y.dims(),
                      opErrMsg(op, "X and Y must have the same dimensions"));
    auto *node = G_->createDotProduct(loadOperatorName(op), X, Y);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadReplaceNaN(const OpType &op, ArgumentDictionaryTy &dict) {
    // Load the input and NaN replacement value:
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
    auto valueIt = dict.find("value");
    float value = 0.0f;
    if (valueIt != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(value, loadFloat(valueIt->second));
    }
    auto *node = G_->createReplaceNaN(loadOperatorName(op), input, value);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadLengthsSum(const OpType &op) {
    const std::string &opName = loadOperatorName(op);
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue lengths;
    ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(1)));

    RETURN_ERR_IF_NOT(
        lengths.dims().size() == 1,
        opErrMsg(
            op,
            strFormat("LengthsSum: Lengths must be a 1D vector, but found %zu ",
                      lengths.dims().size())));

    auto *node = G_->createLengthsSum(opName, data, lengths);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadExpandDims(const OpType &op, ArgumentDictionaryTy &dict) {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    std::vector<dim_t> shape;
    ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["dims"]));

    Node *node = G_->createExpandDims(loadOperatorName(op), in, shape);
    RETURN_IF_ERR(addNodeAsOutput(op, node));

    return Error::success();
  }

  Error loadSparseToDense(const OpType &op, ArgumentDictionaryTy &dict) {
    if (op.input_size() != 3) {
      return MAKE_ERR(opErrMsg(
          op,
          strFormat(
              "SparseToDense operator must have three inputs, but found %d ",
              op.input_size())));
    }

    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(0)));
    NodeValue values;
    ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(1)));
    NodeValue dataToInferDim;
    ASSIGN_VALUE_OR_RETURN_ERR(dataToInferDim, getNodeValueByName(op.input(2)));

    RETURN_ERR_IF_NOT(indices.dims().size() == 1 || indices.dims().size() == 2,
                      opErrMsg(op, "Indices must be 1D or 2D tensor."));
    RETURN_ERR_IF_NOT(indices.getElementType() == ElemKind::Int32ITy ||
                          indices.getElementType() == ElemKind::Int64ITy,
                      opErrMsg(op, "Indices must be of int32 or int64 type."));

    const std::string &opName = loadOperatorName(op);

    if (indices.dims().size() == 1) {
      indices = G_->createReshape(opName + ".indices2D", indices,
                                  {indices.dims()[0], 1});
    } else {
      RETURN_ERR_IF_NOT(
          indices.dims()[1] == 1,
          opErrMsg(op, "Second dimension should be 1 in indices."));
    }

    ShapeVector outDims{values.dims().begin(), values.dims().end()};
    outDims[0] = dataToInferDim.dims()[0];
    auto outTy =
        G_->getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
    Node *data = G_->createSplat(opName + ".data", outTy, 0.f);

    // SparseToDense has very similar behavior to ScatterND from ONNX
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ScatterND,
    // therefore we can use ScatterND to implement SparseToDense.
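    // Sketch of the expansion: with indices {{0}, {2}}, values {{1, 2},
    // {3, 4}}, and a dataToInferDim whose first dimension is 4, the splat
    // below produces a {4, 2} zero tensor and the scatter fills rows 0 and
    // 2, yielding {{1, 2}, {0, 0}, {3, 4}, {0, 0}}.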
    Node *result = G_->createScatterData(opName + ".scatterData", data, indices,
                                         values, true);

    RETURN_IF_ERR(addNodeAsOutput(op, result));
    return Error::success();
  }

  Error loadSparseToDenseMask(const OpType &op, ArgumentDictionaryTy &dict) {
    size_t inputSize = op.input_size();
    if (inputSize != 3 && inputSize != 4) {
      return MAKE_ERR(
          opErrMsg(op, strFormat("SparseToDenseMask operator must have "
                                 "3 or 4 inputs, but found %zu ",
                                 inputSize)));
    }

    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(0)));
    NodeValue values;
    ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(1)));
    NodeValue defaultValue;
    ASSIGN_VALUE_OR_RETURN_ERR(defaultValue, getNodeValueByName(op.input(2)));

    NodeValue lengths;
    if (inputSize == 4) {
      ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(3)));
    } else {
      // If Lengths input is not present, create scalar containing number of
      // index-value pairs.
      auto *lengthsConstant =
          mod_.createConstant(ElemKind::Int32ITy, {}, "lengthsConstant");
      lengthsConstant->getPayloadMutable().template getHandle<int32_t>().raw(
          0) = indices.dims()[0];
      lengths = lengthsConstant->getOutput();
    }

    std::vector<dim_t> mask;
    ASSIGN_VALUE_OR_RETURN_ERR(mask, getShape<dim_t>(dict["mask"]));

    auto *node = G_->createSparseToDenseMask(
        loadOperatorName(op), indices, values, defaultValue, lengths, mask);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  Error loadGatherOps(const std::string &typeName, const OpType &op,
                      const ArgumentDictionaryTy &dict) {

    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
    size_t axis = typeName == "Gather" ? 0 : 1;

    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          axis,
          loadAxis<size_t>(dict.find("axis")->second, data.dims().size()));
    }

    if (indices.getElementType() != ElemKind::Int64ITy &&
        indices.getElementType() != ElemKind::Int32ITy) {
      // If the index type is not Int32 or Int64 insert a conversion layer to
      // introduce robustness against model problems. Constant Float indices
      // will get converted to integer indices via constant folding pass.
      indices = G_->createConvertTo(
          loadOperatorName(op) + "_idx_convertToi32", indices,
          G_->getParent()->uniqueType(ElemKind::Int32ITy, indices.dims()));
    }

    auto *GN = G_->createGather(loadOperatorName(op), data, indices, axis);
    RETURN_IF_ERR(addNodeAsOutput(op, GN));
    return Error::success();
  }

  Error loadGatherND(const std::string &typeName, const OpType &op,
                     const ArgumentDictionaryTy &dict) {

    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));

    if (indices.getElementType() != ElemKind::Int64ITy &&
        indices.getElementType() != ElemKind::Int32ITy) {
      // If the index type is not Int32 or Int64 insert a conversion layer to
      // introduce robustness against model problems. Constant Float indices
      // will get converted to integer indices via constant folding pass.
      indices = G_->createConvertTo(
          loadOperatorName(op) + "_idx_convertToi32", indices,
          G_->getParent()->uniqueType(ElemKind::Int32ITy, indices.dims()));
    }

    auto *GN = G_->createGatherND(loadOperatorName(op), data, indices);
    RETURN_IF_ERR(addNodeAsOutput(op, GN));
    return Error::success();
  }

  Error loadGatherRanges(const std::string &typeName, const OpType &op,
                         ArgumentDictionaryTy &dict) {
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(
        data.dims().size() == 1,
        opErrMsg(op, strFormat("GatherRanges: Data must be a 1D vector, but "
                               "found vector size %zu ",
                               data.dims().size())));

    NodeValue ranges;
    ASSIGN_VALUE_OR_RETURN_ERR(ranges, getNodeValueByName(op.input(1)));
    RETURN_ERR_IF_NOT(
        ranges.dims().size() == 3,
        opErrMsg(op, strFormat("GatherRanges: Ranges must be a 3D vector, but "
                               "found vector size %zu ",
                               ranges.dims().size())));
    RETURN_ERR_IF_NOT(
        ranges.dims()[2] == 2,
        opErrMsg(op, strFormat("GatherRanges: Last dimension of "
                               "ranges must be 2, but found %s",
                               std::to_string(ranges.dims()[2]).c_str())));

    auto maxOutputSizeIt = dict.find("maxOutputSize");
    RETURN_ERR_IF_NOT(maxOutputSizeIt != dict.end(),
                      opErrMsg(op, "GatherRanges: Require maxOutputSize when "
                                   "loading GatherRanges."));
    unsigned_t maxOutputSize;
    ASSIGN_VALUE_OR_RETURN_ERR(maxOutputSize, loadInt(maxOutputSizeIt->second));

    Node *GR = G_->createGatherRanges(loadOperatorName(op), data, ranges,
                                      maxOutputSize);
    RETURN_IF_ERR(addNodeAsOutput(op, GR));
    return Error::success();
  }

  // Loads Less operator. Internally it's a CmpLT node.
  Error loadLess(const OpType &op, ArgumentDictionaryTy &dict) {
    // Input Type.
    NodeValue xNV;
    ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0)));
    NodeValue yNV;
    ASSIGN_VALUE_OR_RETURN_ERR(yNV, getNodeValueByName(op.input(1)));

    std::string opName = loadOperatorName(op);

    auto *xNode = xNV.getNode();
    auto *yNode = yNV.getNode();

    Node *N = G_->createNodeWithBroadcast<CmpLTNode>(opName, /* axis */ -1,
                                                     xNode, yNode);

    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }

  Error loadLogicalOps(llvm::StringRef typeName, const OpType &op) {
    std::string opName = loadOperatorName(op);
    NodeValue xNV;
    ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0)));
    NodeValue yNV;
    ASSIGN_VALUE_OR_RETURN_ERR(yNV, getNodeValueByName(op.input(1)));
    constexpr int axis = -1;
    Node *N = nullptr;
    if (typeName == "And") {
      N = G_->createNodeWithBroadcast<AndNode>(opName, axis, xNV, yNV);
    } else if (typeName == "Or") {
      N = G_->createNodeWithBroadcast<OrNode>(opName, axis, xNV, yNV);
    } else if (typeName == "Xor") {
      N = G_->createNodeWithBroadcast<XorNode>(opName, axis, xNV, yNV);
    } else {
      return MAKE_ERR("Unsupported Logical Operator");
    }
    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }

  Error loadNotOp(llvm::StringRef typeName, const OpType &op) {
    std::string opName = loadOperatorName(op);
    NodeValue xNV;
    ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0)));
    Node *N = G_->createNot(opName, xNV);
    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }

  // Loads Abs operator.
  Error loadAbs(const OpType &op, ArgumentDictionaryTy &dict) {
    std::string opName = loadOperatorName(op);
    NodeValue xNV;
    ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0)));
    auto *input = xNV.getNode();

    auto *N = G_->createAbs(opName, input);
    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }

  /// If the operator type is supported, creates the new operator and returns
  /// Expected<true>. Returns Expected<false> if the operator type is not
  /// supported. Returns an Error if an error occurred while loading.
  Expected<bool> tryLoadCommonOperator(llvm::StringRef typeName,
                                       const OpType &op,
                                       ArgumentDictionaryTy &dict) {
    if (typeName == "Relu") {
      RETURN_IF_ERR(loadRelu(op, dict));
      return true;
    }
    if (typeName == "Sigmoid") {
      RETURN_IF_ERR(loadSigmoid(op, dict));
      return true;
    }
    if (typeName == "Tanh") {
      RETURN_IF_ERR(loadTanh(op, dict));
      return true;
    }
    if (typeName == "Exp") {
      RETURN_IF_ERR(loadExp(op, dict));
      return true;
    }
    if (typeName == "Log") {
      RETURN_IF_ERR(loadLog(op, dict));
      return true;
    }
    if (typeName == "Neg") {
      RETURN_IF_ERR(loadNeg(op, dict));
      return true;
    }
    if (typeName == "Abs") {
      RETURN_IF_ERR(loadAbs(op, dict));
      return true;
    }
    if (typeName == "Ceil") {
      RETURN_IF_ERR(loadCeil(op, dict));
      return true;
    }
    if (typeName == "Floor") {
      RETURN_IF_ERR(loadFloor(op, dict));
      return true;
    }
    if (typeName == "Shape") {
      RETURN_IF_ERR(loadShape(op, dict));
      return true;
    }
    if (typeName == "Sqrt") {
      RETURN_IF_ERR(loadSqrt(op, dict));
      return true;
    }
    if (typeName == "Sqr") {
      RETURN_IF_ERR(loadSqr(op, dict));
      return true;
    }
    if (typeName == "Reciprocal") {
      RETURN_IF_ERR(loadReciprocal(op, dict));
      return true;
    }
    if (typeName == "Sum") {
      RETURN_IF_ERR(loadSum(op, dict));
      return true;
    }
    if (typeName == "LRN") {
      RETURN_IF_ERR(loadLRN(op, dict));
      return true;
    }
    if (typeName == "Min" || typeName == "Max") {
      RETURN_IF_ERR(loadMinMax(typeName, op, dict));
      return true;
    }
    if (typeName == "Mul" || typeName == "Add" || typeName == "Sub" ||
        typeName == "Div" || typeName == "Pow") {
      RETURN_IF_ERR(loadArithmetic(typeName, op, dict));
      return true;
    }
    if (typeName == "Split") {
      RETURN_IF_ERR(loadSplit(op, dict));
      return true;
    }
    if (typeName == "Reshape") {
      RETURN_IF_ERR(loadReshape(op, dict));
      return true;
    }
    if (typeName == "Flatten") {
      RETURN_IF_ERR(loadFlatten(op, dict));
      return true;
    }
    if (typeName == "Dropout") {
      // Load Dropout as an identity operation. Note that the second output
      // (the mask) for Dropout in Caffe2 and ONNX is unused during inference,
      // and our Dropout does not currently implement it, thus we do not call
      // addNodeAsOutput() here.
      RETURN_IF_ERR(loadIdentity(op, dict));
      return true;
    }
    if (typeName == "Identity" || typeName == "Alias") {
      RETURN_IF_ERR(loadIdentity(op, dict));
      return true;
    }
    if (typeName == "ReduceMean" || typeName == "ReduceSum" ||
        typeName == "ReduceMin" || typeName == "ReduceMax" ||
        typeName == "ReduceProd") {
      RETURN_IF_ERR(loadReduceOp(typeName, op, dict));
      return true;
    }
    if (typeName == "BatchMatMul") {
      RETURN_IF_ERR(loadBatchMatMul(op, dict));
      return true;
    }
    if (typeName == "BatchOneHot") {
      RETURN_IF_ERR(loadBatchOneHot(op));
      return true;
    }
    if (typeName == "SparseLengthsSum") {
      RETURN_IF_ERR(loadSparseLengthsSum(op, dict));
      return true;
    }
    if (typeName == "SparseLengthsWeightedSum") {
      RETURN_IF_ERR(loadSparseLengthsWeightedSum(op, dict));
      return true;
    }
    if (typeName == "EmbeddingBag") {
      RETURN_IF_ERR(loadEmbeddingBag(op, dict));
      return true;
    }
    if (typeName == "Embedding") {
      RETURN_IF_ERR(loadEmbedding(op, dict));
      return true;
    }
    if (typeName == "LengthsToRanges") {
      RETURN_IF_ERR(loadLengthsToRanges(op));
      return true;
    }
    if (typeName == "BatchBoxCox") {
      RETURN_IF_ERR(loadBatchBoxCox(op));
      return true;
    }
    if (typeName == "DotProduct") {
      RETURN_IF_ERR(loadDotProduct(op));
      return true;
    }
    if (typeName == "ReplaceNaN") {
      RETURN_IF_ERR(loadReplaceNaN(op, dict));
      return true;
    }
    if (typeName == "LengthsSum") {
      RETURN_IF_ERR(loadLengthsSum(op));
      return true;
    }
    if (typeName == "ExpandDims") {
      RETURN_IF_ERR(loadExpandDims(op, dict));
      return true;
    }
    if (typeName == "SparseToDense") {
      RETURN_IF_ERR(loadSparseToDense(op, dict));
      return true;
    }
    if (typeName == "SparseToDenseMask") {
      RETURN_IF_ERR(loadSparseToDenseMask(op, dict));
      return true;
    }
    if (typeName == "Gather" || typeName == "BatchGather") {
      RETURN_IF_ERR(loadGatherOps(typeName.str(), op, dict));
      return true;
    }
    if (typeName == "GatherND") {
      RETURN_IF_ERR(loadGatherND(typeName.str(), op, dict));
      return true;
    }
    if (typeName == "GatherRanges") {
      RETURN_IF_ERR(loadGatherRanges(typeName.str(), op, dict));
      return true;
    }
    if (typeName == "Less") {
      RETURN_IF_ERR(loadLess(op, dict));
      return true;
    }
    if (typeName == "And" || typeName == "Or" || typeName == "Xor") {
      RETURN_IF_ERR(loadLogicalOps(typeName, op));
      return true;
    }
    if (typeName == "Not") {
      RETURN_IF_ERR(loadNotOp(typeName, op));
      return true;
    }

    return false;
  }

  /// Utility function which computes the resulting shape in case of
  /// multidirectional broadcasting.
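  ///
  /// For example, broadcasting shape0 = {3, 4, 1} against shape1 = {4, 5}
  /// aligns the shapes at the trailing dimension and takes the larger extent
  /// along each axis, yielding {3, 4, 5}.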
  Expected<std::vector<dim_t>>
  computeMultidirectionalBroadcast(llvm::ArrayRef<dim_t> shape0,
                                   llvm::ArrayRef<dim_t> shape1) {
    size_t numDims0 = shape0.size();
    size_t numDims1 = shape1.size();
    size_t newNumDims = numDims0 > numDims1 ? numDims0 : numDims1;
    // Start from all-ones dims, then merge both input shapes into them.
    std::vector<dim_t> reshapeDims(newNumDims, 1);

    RETURN_IF_ERR(mergeMultidirectionalBroadcast(reshapeDims, shape0));
    RETURN_IF_ERR(mergeMultidirectionalBroadcast(reshapeDims, shape1));

    return reshapeDims;
  }

  /// Associates all outputs of \p op with the NodeValues in \p NVs. The number
  /// of outputs of \p op must match the number of elements of \p NVs.
  /// \returns an error code in case of error.
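  ///
  /// For example, for an operator with outputs ("out0", "out1"), \p NVs must
  /// contain exactly two NodeValues: NVs[0] is registered under the name
  /// "out0" and NVs[1] under "out1".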
  Error assignNodeOutputs(const OpType &op, llvm::ArrayRef<NodeValue> NVs) {
    RETURN_ERR_IF_NOT((dim_t)NVs.size() == (dim_t)op.output_size(),
                      "Output size mismatch.");
    for (size_t i = 0; i < NVs.size(); i++) {
      nodeValueByName_[op.output(i)] = NVs[i];
    }
    return Error::success();
  }

  /// Load pre-trained weights from \p weightDescriptors.
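  ///
  /// Note that a weight named "w" that was loaded with multiple quantization
  /// parameters also registers the auxiliary Constants "w_loaded_offsets" and
  /// "w_loaded_scales" holding its offsets and scales (see below).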
  Error loadWeights(uint32_t weightsCount,
                    const onnxTensorDescriptorV1 *weightDescriptors) {
    for (uint32_t i = 0; i < weightsCount; ++i) {
      const char *name = weightDescriptors[i].name;

      LoadWeightResult loadResult;
      if (auto resOrErr = loadWeight(weightDescriptors[i])) {
        loadResult = std::move(*resOrErr);
      } else {
        RETURN_ERR(resOrErr.takeError());
      }

      // If the weight is offline, create a static placeholder; otherwise
      // create a constant.
      if (weightDescriptors[i].isOffline) {
        RETURN_ERR_IF_NOT(
            !clipQuantRangeToFP16_ ||
                !loadResult.t->getType().isQuantizedType() ||
                loadResult.t->getType().isFusedQuantizedType(),
            strFormat("Do not support clipQuantRangeToFP16 with unfused "
                      "quantized input Placeholders: %s",
                      name));
        Placeholder *pl;
        ASSIGN_VALUE_OR_RETURN_ERR(
            pl, createAndRegisterPlaceholder(name, &loadResult.type,
                                             /*isStatic*/ true));
        (void)pl;
      } else {
        RETURN_IF_ERR(
            createAndRegisterConstant(name, std::move(*loadResult.t)));
      }

      if (loadResult.offsets) {
        auto offsetsName = strFormat("%s_loaded_offsets", name);
        RETURN_IF_ERR(createAndRegisterConstant(
            offsetsName, std::move(*loadResult.offsets)));
      }

      if (loadResult.scales) {
        auto scalesName = strFormat("%s_loaded_scales", name);
        RETURN_IF_ERR(createAndRegisterConstant(
            scalesName, std::move(*loadResult.scales)));
      }
    }

    return Error::success();
  }

  /// Sets the type of \p S to have \p dstKind, using the same dims as \p S.
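  ///
  /// For example (a sketch): calling setFusedTy(S, ElemKind::UInt8FusedQTy)
  /// on a UInt8QTy Storage \p S with dims {N, K} uniques a {N, K}
  /// UInt8FusedQTy type with dummy scale/offset and applies it via the
  /// overload below.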
  Error setFusedTy(Storage *S, ElemKind dstKind) {
    // Use dummy 0.0/0 as scale/offset, since the actual scales/offsets
    // are fused inline with the data.
    TypeRef fusedTy = mod_.uniqueType(dstKind, S->dims(), 0.0, 0);
    return setFusedTy(S, fusedTy);
  }

  /// Sets the type of \p S to have \p fusedTy. If \p S already has type \p
  /// fusedTy, then this is a no-op. Otherwise, \p S is expected to originally
  /// be UInt8QTy. If \p S is a Constant, then this also sets the payload of
  /// the Constant to have the same type.
  /// The motivation here is that there is no fused quantized type in
  /// Caffe2/ONNX, so we always load such weights as UInt8QTy and then change
  /// them from UInt8QTy to one of the fused kinds here. This may not be
  /// necessary if another user has already changed the type, or if the type
  /// was already modified when loading into an existing module.
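  ///
  /// For example (a sketch): a fused rowwise-quantized embedding weight is
  /// first loaded as a plain UInt8QTy Constant, and once an operator that
  /// consumes it in fused form is loaded, its type is switched here to, e.g.,
  /// ElemKind::UInt8FusedQTy.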
  Error setFusedTy(Storage *S, TypeRef fusedTy) {
    assert(fusedTy->isFusedQuantizedType() && "Expected fused quantized type.");

    // If S already has the requested type then return early.
    if (S->getOutput().getType()->isEqual(*fusedTy)) {
      return Error::success();
    }

    RETURN_ERR_IF_NOT(S->getElementType() == ElemKind::UInt8QTy,
                      "Data must be UInt8QTy, but was " +
                          Type::getElementName(S->getElementType()).str());
    S->setType(Storage::OutputIdx, fusedTy);
    // If the node is a Constant, set the payload type as well.
    if (auto *C = llvm::dyn_cast<Constant>(S)) {
      C->setPayloadType(fusedTy);
    }

    return Error::success();
  }

  /// \returns the count_include_pad attribute from \p dict as a bool, or
  /// \p defaultValue if the attribute is absent.
  static Expected<bool> getCountIncludePads(ArgumentDictionaryTy &dict,
                                            bool defaultValue) {
    if (dict.count("count_include_pad")) {
      int countIncludePads;
      ASSIGN_VALUE_OR_RETURN_ERR(countIncludePads,
                                 loadInt(dict.at("count_include_pad")));
      return (bool)countIncludePads;
    }
    // Return the default value if the attribute is not found in the dict.
    return defaultValue;
  }

  /// \returns the dilations of a convolution-like operator from \p dict, or
  /// \p defaultValue if no dilation attribute is present.
  static Expected<std::vector<unsigned_t>>
  getDilations(ArgumentDictionaryTy &dict,
               const std::vector<unsigned_t> &defaultValue) {
    // For a Caffe2 model, the dilation can be given either as a single
    // integer (applied to every spatial axis), in which case the dict key is
    // `dilation`, or as one integer per axis, in which case the key is
    // `dilations`. For an ONNX model, only `dilations` is valid, and it must
    // be a list of integers.
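    //
    // For example (illustrative attribute values):
    //   {"dilation": 2}       -> returns {2, 2}
    //   {"dilations": [2, 3]} -> returns {2, 3}
    //   neither key present   -> returns defaultValue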
    if (dict.count("dilation")) {
      unsigned_t dilation;
      ASSIGN_VALUE_OR_RETURN_ERR(dilation, loadInt(dict.at("dilation")));
      return std::vector<unsigned_t>{dilation, dilation};
    }
    if (dict.count("dilations")) {
      std::vector<unsigned_t> shape;
      ASSIGN_VALUE_OR_RETURN_ERR(shape,
                                 getShape<unsigned_t>(dict["dilations"]));
      return shape;
    }

    return defaultValue;
  }
};

} // namespace glow

#endif // GLOW_IMPORTER_COMMONOPERATORLOADER_H