1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_IMPORTER_COMMONOPERATORLOADER_H |
18 | #define GLOW_IMPORTER_COMMONOPERATORLOADER_H |
19 | |
20 | #include "foxi/onnxifi.h" |
21 | |
22 | #include "glow/Importer/ProtobufLoader.h" |
23 | |
24 | #include "glow/Base/Tensor.h" |
25 | #include "glow/Graph/Graph.h" |
26 | |
27 | #include "llvm/ADT/ArrayRef.h" |
28 | #include "llvm/ADT/StringRef.h" |
29 | |
30 | #include <functional> |
31 | #include <numeric> |
32 | #include <string> |
33 | #include <unordered_map> |
34 | #include <vector> |
35 | |
36 | namespace glow { |
37 | |
/// Result of loading a weight, potentially with additional offsets and
/// scales tensors containing quantization parameters only if the loaded weight
/// was found to have multiple quantization parameters.
struct LoadWeightResult {
  /// Main Glow tensor, this is always non-null. For offline weights only the
  /// \ref type field is filled in and this tensor carries no payload.
  std::unique_ptr<Tensor> t;
  /// Glow tensor containing quantization offsets. This should only be non-null
  /// if there is more than 1 quantization parameter found.
  std::unique_ptr<Tensor> offsets;
  /// Glow tensor containing quantization scales. This should only be non-null
  /// if there is more than 1 quantization parameter found.
  std::unique_ptr<Tensor> scales;
  /// Type info of the weight, this is used for offline weights.
  Type type;
};
53 | |
/// Invokes \p functionName<T>(__VA_ARGS__) with T set to the signed integer
/// storage type matching the quantized element kind \p elemTy. Only
/// Int8QTy/Int16QTy/Int32QTy are supported; any other kind is a fatal error.
#define dispatchQuantizedImpl(functionName, elemTy, ...)                       \
  switch (elemTy) {                                                            \
  case ElemKind::Int8QTy:                                                      \
    functionName<int8_t>(__VA_ARGS__);                                         \
    break;                                                                     \
  case ElemKind::Int16QTy:                                                     \
    functionName<int16_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32QTy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }
68 | |
69 | template <typename eTy> |
70 | void rescaleQTensor(const Tensor &oldT, Tensor &rescaledT, float newMin, |
71 | float newMax) { |
72 | const Type &oldTy = oldT.getType(); |
73 | const TensorQuantizationParams oldQParams = {oldTy.getScale(), |
74 | oldTy.getOffset()}; |
75 | const TensorQuantizationParams newQParams = chooseQuantizationParams( |
76 | {newMin, newMax}, quantization::Asymmetric, oldTy.getElementType()); |
77 | |
78 | // Setup Tensor to copy rescaled Tensor into. |
79 | Type rescaledTy(oldTy.getElementType(), oldTy.dims(), newQParams.scale, |
80 | newQParams.offset); |
81 | rescaledT.reset(rescaledTy); |
82 | |
83 | auto srcH = oldT.getHandle<eTy>(); |
84 | auto destH = rescaledT.getHandle<eTy>(); |
85 | for (size_t i = 0, e = destH.size(); i < e; ++i) { |
86 | float val = quantization::dequantize(srcH.raw(i), oldQParams); |
87 | destH.raw(i) = quantization::quantize(val, newQParams); |
88 | } |
89 | } |
90 | |
91 | /// Given \p result, rescale it given \p newMin and \p newMax. |
92 | template <typename eTy> |
93 | void rescaleQTensorResult(LoadWeightResult &result, float newMin, |
94 | float newMax) { |
95 | // Get new type based on newMin/newMax and old elem kind. |
96 | auto rescaledT = glow::make_unique<Tensor>(); |
97 | rescaleQTensor<eTy>(*result.t, *rescaledT, newMin, newMax); |
98 | result.t = std::move(rescaledT); |
99 | result.type = result.t->getType(); |
100 | } |
101 | |
102 | /// Contains loaders for operators, which are common to ONNX and Caffe2 |
103 | /// formats. Every loader method adds necessary nodes to property G_, which |
104 | /// is inherited from ProtobufLoader class, therefore modifying the class |
105 | /// instance itself. |
106 | template <typename OpType, typename AttrType> |
107 | class CommonOperatorLoader : public ProtobufLoader { |
108 | /// Loads the onnxTensorDescriptorV1 \p in and \returns a LoadWeightResult |
109 | /// where result.t is the main contents of the the onnxTensorDescriptorV1 and |
110 | /// result.offsets and result.scales are the quantization scales and offsets |
111 | /// of the onnxTensorDescriptorV1 if there were more than 1. If there is |
112 | /// exactly 1 scale and offset then result.t will be a quantized glow tensor. |
113 | inline Expected<LoadWeightResult> |
114 | loadWeight(const onnxTensorDescriptorV1 &in) { |
115 | // Only support CPU memory tensors. |
116 | if (in.memoryType != ONNXIFI_MEMORY_TYPE_CPU) { |
117 | return MAKE_ERR("Only support CPU memory tensors." ); |
118 | } |
119 | |
120 | // Number of qparams in the onnxTensorDescriptor. |
121 | const dim_t qparams = static_cast<dim_t>(in.quantizationParams); |
122 | |
123 | // Only support quantizationAxis=1 for now. |
124 | if (qparams > 0 && in.quantizationAxis != 1) { |
125 | return MAKE_ERR(strFormat( |
126 | "Glow can only import quantized tensors with quantizationAxis=1 but " |
127 | "the tensor %s has quantizationAxis=%u" , |
128 | in.name, in.quantizationAxis)); |
129 | } |
130 | |
131 | LoadWeightResult result; |
132 | result.t = glow::make_unique<Tensor>(); |
133 | |
134 | std::vector<dim_t> dims; |
135 | for (unsigned i = 0; i < in.dimensions; ++i) { |
136 | dims.push_back(in.shape[i]); |
137 | } |
138 | |
139 | // Load unquantized tensor. |
140 | if (in.quantizationParams == 0) { |
141 | if (in.dataType == ONNXIFI_DATATYPE_FLOAT32) { |
142 | result.type = Type(ElemKind::FloatTy, dims); |
143 | } else if (in.dataType == ONNXIFI_DATATYPE_FLOAT16) { |
144 | result.type = Type(ElemKind::Float16Ty, dims); |
145 | } else if (in.dataType == ONNXIFI_DATATYPE_BFLOAT16) { |
146 | result.type = Type(ElemKind::BFloat16Ty, dims); |
147 | } else if (in.dataType == ONNXIFI_DATATYPE_INT32) { |
148 | result.type = Type(ElemKind::Int32ITy, dims); |
149 | } else if (in.dataType == ONNXIFI_DATATYPE_INT64) { |
150 | result.type = Type(ElemKind::Int64ITy, dims); |
151 | } else if (in.dataType == ONNXIFI_DATATYPE_UINT8) { |
152 | // UInt8 type is used for variety of rowwise quantized SLSs. |
153 | // Make dummy scale and offset for these cases. |
154 | result.type = Type(ElemKind::UInt8QTy, dims, 1.0, 0); |
155 | } else if (in.dataType == ONNXIFI_DATATYPE_UINT64) { |
156 | result.type = Type(ElemKind::Int64ITy, dims); |
157 | for (size_t i = 0; i < result.t->size(); ++i) { |
158 | RETURN_ERR_IF_NOT( |
159 | ((int64_t *)in.buffer)[i] >= 0, |
160 | "Disallow overflow of loaded UINT64 data into Int64ITy." ); |
161 | } |
162 | } else { |
163 | return MAKE_ERR(strFormat( |
164 | "Only float, index, and uint8 unquantized tensors are supported, " |
165 | "got input with ONNXIFI_DATATYPE: %zu" , |
166 | static_cast<size_t>(in.dataType))); |
167 | } |
168 | if (!in.isOffline) { |
169 | *result.t = Tensor((void *)in.buffer, &result.type); |
170 | } |
171 | return Expected<LoadWeightResult>(std::move(result)); |
172 | } |
173 | |
174 | // Load quantized tensor with either a single or multiple qparams. |
175 | float scale = 1.0; |
176 | int32_t offset = 0; |
177 | |
178 | // If multiple qparams are present then load them as tensors and use the |
179 | // the default qparams for the result.t otherwise use the first (only) |
180 | // qparams. |
181 | if (in.quantizationParams == 1) { |
182 | scale = in.scales[0]; |
183 | offset = in.biases[0]; |
184 | } else { |
185 | RETURN_ERR_IF_NOT(!loadUniquedDummyQParams_, |
186 | strFormat("Unsupported loading of uniqued qparams for " |
187 | "vector of scales/biases for %s" , |
188 | in.name)); |
189 | Type scalesTy(ElemKind::FloatTy, llvm::makeArrayRef({qparams})); |
190 | Type offsetsTy(ElemKind::Int32ITy, llvm::makeArrayRef({qparams})); |
191 | result.scales = glow::make_unique<Tensor>((void *)in.scales, &scalesTy); |
192 | result.offsets = glow::make_unique<Tensor>((void *)in.biases, &offsetsTy); |
193 | } |
194 | |
195 | // If we have a scale of dummyScale, then this must be a dummy pair of |
196 | // scale/offset. Look up the actual scale/offset to use as previously |
197 | // loaded, using the offset as the key to updatedTQPs_. |
198 | if (replaceDummyTQPs_ && scale == dummyScale) { |
199 | TensorQuantizationParams TQP; |
200 | ASSIGN_VALUE_OR_RETURN_ERR(TQP, getUpdatedTQP(offset)); |
201 | scale = TQP.scale; |
202 | offset = TQP.offset; |
203 | } |
204 | |
205 | if (in.dataType == ONNXIFI_DATATYPE_UINT8) { |
206 | TypeRef outTy; |
207 | ASSIGN_VALUE_OR_RETURN_ERR( |
208 | outTy, ProtobufLoader::loadQuantTy( |
209 | in.name, ElemKind::Int8QTy, dims, scale, offset, |
210 | /* shiftUInt8ToInt8 */ true, |
211 | /* skipClipQuantRangeToFP16 */ true)); |
212 | // Must copy the weights here because we will need to modify them by |
213 | // adjusting for UINT8_TO_INT8_SHIFT. |
214 | result.type = *outTy; |
215 | if (!in.isOffline) { |
216 | result.t->reset(result.type); |
217 | |
218 | auto TH = result.t->getHandle<int8_t>(); |
219 | uint8_t *data = (uint8_t *)in.buffer; |
220 | for (size_t i = 0; i < TH.size(); ++i) { |
221 | TH.raw(i) = (int8_t)(data[i] - UINT8_TO_INT8_SHIFT); |
222 | } |
223 | } |
224 | } else if (in.dataType == ONNXIFI_DATATYPE_INT32) { |
225 | TypeRef outTy; |
226 | ASSIGN_VALUE_OR_RETURN_ERR( |
227 | outTy, ProtobufLoader::loadQuantTy( |
228 | in.name, ElemKind::Int32QTy, dims, scale, offset, |
229 | /* shiftUInt8ToInt8 */ true, |
230 | /* skipClipQuantRangeToFP16 */ true)); |
231 | result.type = *outTy; |
232 | if (!in.isOffline) { |
233 | *result.t = Tensor((void *)in.buffer, &result.type); |
234 | } |
235 | } else if (in.dataType == ONNXIFI_DATATYPE_INT8) { |
236 | TypeRef outTy; |
237 | ASSIGN_VALUE_OR_RETURN_ERR( |
238 | outTy, ProtobufLoader::loadQuantTy( |
239 | in.name, ElemKind::Int8QTy, dims, scale, offset, |
240 | /* shiftUInt8ToInt8 */ false, |
241 | /* skipClipQuantRangeToFP16 */ true)); |
242 | result.type = *outTy; |
243 | if (!in.isOffline) { |
244 | *result.t = Tensor((void *)in.buffer, &result.type); |
245 | } |
246 | } else { |
247 | return MAKE_ERR( |
248 | strFormat("Only uint8, int32, and int8, quantized tensors are " |
249 | "supported, got input with ONNXIFI_DATATYPE: %zu" , |
250 | static_cast<size_t>(in.dataType))); |
251 | } |
252 | |
253 | // If we're clipping quantized ranges tp FP16, then we need to rescale the |
254 | // Tensor and update its type, plus the type in result. |
255 | if (clipQuantRangeToFP16_) { |
256 | const ElemKind k = result.type.getElementType(); |
257 | const auto qMinMax = getQuantizedValueRange(scale, offset, k); |
258 | const float newMin = std::max(qMinMax.first, kMinFP16); |
259 | const float newMax = std::min(qMinMax.second, kMaxFP16); |
260 | |
261 | // If min or max are clipped then create a new Tensor with the adjusted |
262 | // type, and rescale its payload. |
263 | if (newMin != qMinMax.first || newMax != qMinMax.second) { |
264 | RETURN_ERR_IF_NOT( |
265 | !in.isOffline, |
266 | strFormat("For clipQuantRangeToFP16, currently do " |
267 | "not support offline quantizated weights: %s" , |
268 | in.name)); |
269 | RETURN_ERR_IF_NOT(!result.offsets && !result.scales, |
270 | strFormat("For clipQuantRangeToFP16, currently do " |
271 | "not support multiple qparams: %s" , |
272 | in.name)); |
273 | |
274 | dispatchQuantizedImpl(rescaleQTensorResult, k, result, newMin, newMax); |
275 | } |
276 | } |
277 | |
278 | return Expected<LoadWeightResult>(std::move(result)); |
279 | } |
280 | |
281 | /// Merge shape \p shape into \p mergeShape, following multidirectional |
282 | /// broadcasting rules. |
283 | Error mergeMultidirectionalBroadcast(std::vector<dim_t> &mergeShape, |
284 | llvm::ArrayRef<dim_t> shape) { |
285 | size_t shift = mergeShape.size() - shape.size(); |
286 | for (size_t i = 0; i < shape.size(); i++) { |
287 | if (shape[i] != 1) { |
288 | RETURN_ERR_IF_NOT((shape[i] == mergeShape[shift + i]) || |
289 | (mergeShape[shift + i] == 1), |
290 | "Incompatible dimension for the broadcast" ); |
291 | mergeShape[shift + i] = shape[i]; |
292 | } |
293 | // Otherwise, just leave mergeShape[i] as it is. |
294 | } |
295 | return Error::success(); |
296 | } |
297 | |
298 | protected: |
  /// Constructs a loader that builds into Function \p F, registering
  /// placeholders for \p names with \p types. Delegates to ProtobufLoader
  /// with replaceDummyTQPs fixed to false. Errors are reported through
  /// \p errPtr when non-null.
  CommonOperatorLoader(llvm::ArrayRef<const char *> names,
                       llvm::ArrayRef<TypeRef> types, Function *F,
                       Error *errPtr = nullptr,
                       bool loadIntoExistingModule = false,
                       OriginNameToTQPMap *originNameToTQPMap = nullptr,
                       bool loadUniquedDummyQParams = false,
                       bool zeroScaleFP16Clip = false,
                       bool clipQuantRangeToFP16 = false)
      : ProtobufLoader(names, types, F, errPtr, loadIntoExistingModule,
                       originNameToTQPMap, loadUniquedDummyQParams,
                       /* replaceDummyTQPs */ false, zeroScaleFP16Clip,
                       clipQuantRangeToFP16) {}
311 | |
  /// Constructs a loader that builds into Module \p mod. Unlike the
  /// Function-based overload, this one also exposes \p replaceDummyTQPs,
  /// forwarding all flags to ProtobufLoader unchanged.
  CommonOperatorLoader(
      llvm::ArrayRef<const char *> names, llvm::ArrayRef<TypeRef> types,
      Module &mod, Error *errPtr = nullptr, bool loadIntoExistingModule = false,
      OriginNameToTQPMap *originNameToTQPMap = nullptr,
      bool loadUniquedDummyQParams = false, bool replaceDummyTQPs = false,
      bool zeroScaleFP16Clip = false, bool clipQuantRangeToFP16 = false)
      : ProtobufLoader(names, types, mod, errPtr, loadIntoExistingModule,
                       originNameToTQPMap, loadUniquedDummyQParams,
                       replaceDummyTQPs, zeroScaleFP16Clip,
                       clipQuantRangeToFP16) {}
322 | |
323 | using ArgumentDictionaryTy = |
324 | std::unordered_map<std::string, const AttrType *>; |
325 | |
326 | /// If we were replacing or loading dummy TQPs, \returns success if there |
327 | /// aren't any dummies left, or there are only dummies left. |
328 | Error verifyDummyQParams() { |
329 | RETURN_ERR_IF_NOT(!(replaceDummyTQPs_ && loadUniquedDummyQParams_), |
330 | "Cannot replace dummy TQPs when loading uniqued TQPs." ); |
331 | if (replaceDummyTQPs_ || loadUniquedDummyQParams_) { |
332 | RETURN_IF_ERR(mod_.verifyDummyQParams(loadUniquedDummyQParams_)); |
333 | } |
334 | return Error::success(); |
335 | } |
336 | |
337 | /// Helper to load quantization parameters from \p dict for op named \p name. |
338 | /// \returns a new TypeRef given \p k and \p dims. |
339 | Expected<TypeRef> loadQuantTy(const std::string &name, ElemKind k, |
340 | llvm::ArrayRef<dim_t> dims, |
341 | ArgumentDictionaryTy &dict, |
342 | bool skipClipQuantRangeToFP16 = false) { |
343 | RETURN_ERR_IF_NOT(dict.count("Y_scale" ), |
344 | "missing Y_scale for quantized output type for " + name); |
345 | RETURN_ERR_IF_NOT(dict.count("Y_zero_point" ), |
346 | "missing zero point for quantized output type for " + |
347 | name); |
348 | |
349 | float scale; |
350 | ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict["Y_scale" ])); |
351 | int32_t offset; |
352 | ASSIGN_VALUE_OR_RETURN_ERR(offset, loadInt(dict["Y_zero_point" ])); |
353 | |
354 | return ProtobufLoader::loadQuantTy(name, k, dims, scale, offset, |
355 | /* shiftUInt8ToInt8 */ true, |
356 | skipClipQuantRangeToFP16); |
357 | } |
358 | |
359 | /// \returns True if the operator has broadcasting activated. |
360 | virtual Expected<bool> getBroadcast(ArgumentDictionaryTy &dict) = 0; |
361 | |
362 | /// \returns True if the operator with the name \p typeName has support |
363 | /// for multidirectional broadcasting. |
364 | virtual bool hasMultidirectionalBroadcast(const llvm::StringRef typeName) = 0; |
365 | |
366 | inline Expected<LengthsMode> getLengthsMode(ArgumentDictionaryTy &dict) { |
367 | bool length1 = false; |
368 | if (dict.count("length1" )) { |
369 | ASSIGN_VALUE_OR_RETURN_ERR(length1, loadInt(dict["length1" ])); |
370 | } |
371 | if (length1) { |
372 | return LengthsMode::AllOne; |
373 | } |
374 | return LengthsMode::Variable; |
375 | } |
376 | |
377 | inline Expected<float> getAvgLength(ArgumentDictionaryTy &dict) { |
378 | float avgLength = NAN; |
379 | if (dict.count("average_lookup_length" )) { |
380 | ASSIGN_VALUE_OR_RETURN_ERR(avgLength, |
381 | loadFloat(dict["average_lookup_length" ])); |
382 | } |
383 | return avgLength; |
384 | } |
385 | |
386 | const std::string opErrMsg(const OpType &op, const std::string &errMsg) { |
387 | const std::string &opName = loadOperatorName(op); |
388 | return strFormat(" [Operator-'%s'] : %s " , opName.c_str(), errMsg.c_str()); |
389 | } |
390 | |
391 | /// Associate the name of operation outputs to a NodeValues corresponding to |
392 | /// node \p node. If \p numOutputs is lower than 0, then all outputs are |
393 | /// associated. Otherwise, the first \p numOutputs outputs are associated. |
394 | Error addNodeAsOutput(const OpType &op, Node *node, int numOutputs = -1) { |
395 | RETURN_ERR_IF_NOT(numOutputs <= op.output_size(), |
396 | "Can't register more than outputs in the operation." ); |
397 | numOutputs = (numOutputs < 0) ? op.output_size() : numOutputs; |
398 | for (int i = 0; i < numOutputs; i++) { |
399 | nodeValueByName_[op.output(i)] = NodeValue(node, i); |
400 | } |
401 | return Error::success(); |
402 | } |
403 | |
404 | /// Loads RELU operator, given its protobuf representation and parsed args. |
405 | Error loadRelu(const OpType &op, ArgumentDictionaryTy &dict) { |
406 | const std::string &opName = loadOperatorName(op); |
407 | NodeValue in; |
408 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
409 | auto *R = G_->createRELU(opName, in); |
410 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
411 | return Error::success(); |
412 | } |
413 | |
/// Defines a member function load##OPNAME that reads input 0, creates the
/// corresponding single-input Glow node via G_->create##OPNAME, and registers
/// it as the op's output.
#define LOAD_UNARY_OP(OPNAME)                                                  \
  Error load##OPNAME(const OpType &op, ArgumentDictionaryTy &dict) {           \
    const std::string &opName = loadOperatorName(op);                          \
    NodeValue in;                                                              \
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));           \
    auto *T = G_->create##OPNAME(opName, in);                                  \
    RETURN_IF_ERR(addNodeAsOutput(op, T));                                     \
    return Error::success();                                                   \
  }

  // Unary operators that all share the trivial load pattern above.
  LOAD_UNARY_OP(Sigmoid)
  LOAD_UNARY_OP(Tanh)
  LOAD_UNARY_OP(Exp)
  LOAD_UNARY_OP(Neg)
  LOAD_UNARY_OP(Floor)
  LOAD_UNARY_OP(Ceil)
  LOAD_UNARY_OP(Truncate)
  LOAD_UNARY_OP(Log)
432 | |
433 | Error loadShape(const OpType &op, ArgumentDictionaryTy &dict) { |
434 | NodeValue in; |
435 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
436 | |
437 | // This is statically known data, and so we create a Tensor for it. |
438 | Tensor T(ElemKind::Int64ITy, {(dim_t)in.dims().size()}); |
439 | T.getHandle<int64_t>() = |
440 | std::vector<int64_t>(in.dims().begin(), in.dims().end()); |
441 | |
442 | RETURN_IF_ERR(createAndRegisterConstant(op.output(0), std::move(T))); |
443 | |
444 | return Error::success(); |
445 | } |
446 | |
447 | /// Loads Pow operator, given its protobuf representation and parsed args. |
448 | Error loadPow(const OpType &op, ArgumentDictionaryTy &dict) { |
449 | const std::string &opName = loadOperatorName(op); |
450 | NodeValue base; |
451 | ASSIGN_VALUE_OR_RETURN_ERR(base, getNodeValueByName(op.input(0))); |
452 | NodeValue exp; |
453 | ASSIGN_VALUE_OR_RETURN_ERR(exp, getNodeValueByName(op.input(1))); |
454 | auto R = G_->createNodeWithBroadcast<PowNode>(opName, -1, base, exp); |
455 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
456 | return Error::success(); |
457 | } |
458 | |
459 | /// Loads Sqrt operator, given its protobuf representation and parsed args. |
460 | Error loadSqrt(const OpType &op, ArgumentDictionaryTy &dict) { |
461 | const std::string &opName = loadOperatorName(op); |
462 | NodeValue in; |
463 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
464 | auto *R = G_->createPow(opName, in, 0.5f); |
465 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
466 | return Error::success(); |
467 | } |
468 | |
469 | /// Loads Sqr operator, given its protobuf representation and parsed args. |
470 | Error loadSqr(const OpType &op, ArgumentDictionaryTy &dict) { |
471 | const std::string &opName = loadOperatorName(op); |
472 | NodeValue in; |
473 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
474 | auto *R = G_->createPow(opName, in, 2.0f); |
475 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
476 | return Error::success(); |
477 | } |
478 | |
479 | /// Loads Reciprocal operator, given its protobuf representation and parsed |
480 | /// args. |
481 | Error loadReciprocal(const OpType &op, ArgumentDictionaryTy &dict) { |
482 | const std::string &opName = loadOperatorName(op); |
483 | NodeValue in; |
484 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
485 | auto *R = G_->createPow(opName, in, -1.0f); |
486 | RETURN_IF_ERR(addNodeAsOutput(op, R)); |
487 | return Error::success(); |
488 | } |
489 | |
490 | Error loadSum(const OpType &op, ArgumentDictionaryTy &dict) { |
491 | if (op.input_size() == 1) { |
492 | NodeValue in; |
493 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
494 | RETURN_IF_ERR(addNodeAsOutput(op, in)); |
495 | } else if (op.input_size() == 2) { |
496 | const std::string &opName = loadOperatorName(op); |
497 | NodeValue in0; |
498 | ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0))); |
499 | NodeValue in1; |
500 | ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1))); |
501 | auto *node = G_->createAdd(opName, in0, in1); |
502 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
503 | } else { |
504 | const std::string &opName = loadOperatorName(op); |
505 | const unsigned numInputs = op.input_size(); |
506 | llvm::SmallVector<NodeValue, 4> inputs; |
507 | inputs.reserve(numInputs); |
508 | for (unsigned i = 0; i < numInputs; i++) { |
509 | NodeValue in; |
510 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i))); |
511 | inputs.push_back(G_->createExpandDims(opName, in, {0})); |
512 | } |
513 | ConcatNode *concat = G_->createConcat(opName, inputs, /* axis */ 0); |
514 | Node *node = G_->createBatchedReduceAdd(opName, concat, /* axis */ {0}); |
515 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
516 | } |
517 | return Error::success(); |
518 | } |
519 | |
520 | Error loadLRN(const OpType &op, ArgumentDictionaryTy &dict) { |
521 | const std::string &opName = loadOperatorName(op); |
522 | NodeValue in; |
523 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
524 | |
525 | size_t size; |
526 | ASSIGN_VALUE_OR_RETURN_ERR(size, loadInt(dict["size" ])); |
527 | float alpha; |
528 | ASSIGN_VALUE_OR_RETURN_ERR(alpha, loadFloat(dict["alpha" ])); |
529 | float beta; |
530 | ASSIGN_VALUE_OR_RETURN_ERR(beta, loadFloat(dict["beta" ])); |
531 | float k; |
532 | ASSIGN_VALUE_OR_RETURN_ERR(k, loadFloat(dict["bias" ])); |
533 | |
534 | auto *tr = G_->createTranspose(opName, in, NCHW2NHWC); |
535 | |
536 | auto *node = G_->createLocalResponseNormalization(opName, tr, size / 2, |
537 | alpha, beta, k); |
538 | |
539 | auto *N = G_->createTranspose(opName, node, NHWC2NCHW); |
540 | |
541 | // LRN in Caffe2 has a scale_ output, but I believe it's unused for |
542 | // inference. So explicitly only set output 0. |
543 | nodeValueByName_[op.output(0)] = N->getResult(); |
544 | return Error::success(); |
545 | } |
546 | |
547 | Error loadMinMax(llvm::StringRef typeName, const OpType &op, |
548 | ArgumentDictionaryTy &dict) { |
549 | const std::string &opName = loadOperatorName(op); |
550 | NodeValue in0; |
551 | ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0))); |
552 | NodeValue in1; |
553 | ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1))); |
554 | |
555 | Node *node = nullptr; |
556 | if (typeName == "Min" ) { |
557 | node = G_->createNodeWithBroadcast<MinNode>(opName, -1, in0, in1); |
558 | } else if (typeName == "Max" ) { |
559 | node = G_->createNodeWithBroadcast<MaxNode>(opName, -1, in0, in1); |
560 | } else { |
561 | return MAKE_ERR(opErrMsg(op, "Invalid min or max operator" )); |
562 | } |
563 | |
564 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
565 | return Error::success(); |
566 | } |
567 | |
568 | static Expected<NodeValue> handleMatMulTranspose(Function *F, |
569 | ArgumentDictionaryTy &dict, |
570 | llvm::StringRef key, |
571 | NodeValue input) { |
572 | if (!dict.count(key.str())) { |
573 | return input; |
574 | } |
575 | |
576 | int isTransposed; |
577 | ASSIGN_VALUE_OR_RETURN_ERR(isTransposed, loadInt(dict[key.str()])); |
578 | if (isTransposed == 1) { |
579 | auto dimsSize = input.dims().size(); |
580 | RETURN_ERR_IF_NOT(dimsSize >= 2, |
581 | "C2 specs say rank of inputs must be >= 2" ); |
582 | |
583 | std::vector<unsigned_t> shuffle; |
584 | unsigned_t i; |
585 | for (i = 0; i < dimsSize - 2; ++i) { |
586 | shuffle.push_back(i); |
587 | } |
588 | shuffle.push_back(i + 1); |
589 | shuffle.push_back(i); |
590 | |
591 | return F->createTranspose(input.getNode()->getName().str() + ".transpose" , |
592 | input, shuffle); |
593 | } |
594 | |
595 | return input; |
596 | } |
597 | |
598 | Error loadMatMul(const OpType &op, ArgumentDictionaryTy &dict) { |
599 | const std::string &opName = loadOperatorName(op); |
600 | NodeValue LHS; |
601 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0))); |
602 | NodeValue RHS; |
603 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1))); |
604 | |
605 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, |
606 | handleMatMulTranspose(G_, dict, "trans_a" , LHS)); |
607 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, |
608 | handleMatMulTranspose(G_, dict, "trans_b" , RHS)); |
609 | |
610 | Node *node = G_->createMatMul(opName, LHS, RHS); |
611 | |
612 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
613 | return Error::success(); |
614 | } |
615 | |
  /// Loads a BatchMatMul operator, supporting multidirectional broadcasting
  /// of the batch dimensions and the Caffe2 trans_a/trans_b attributes.
  Error loadBatchMatMul(const OpType &op, ArgumentDictionaryTy &dict) {
    const std::string &opName = loadOperatorName(op);
    NodeValue LHS;
    ASSIGN_VALUE_OR_RETURN_ERR(LHS, getNodeValueByName(op.input(0)));
    NodeValue RHS;
    ASSIGN_VALUE_OR_RETURN_ERR(RHS, getNodeValueByName(op.input(1)));

    ASSIGN_VALUE_OR_RETURN_ERR(LHS,
                               handleMatMulTranspose(G_, dict, "trans_a", LHS));
    ASSIGN_VALUE_OR_RETURN_ERR(RHS,
                               handleMatMulTranspose(G_, dict, "trans_b", RHS));

    const size_t numDimsLHS = LHS.dims().size();
    const size_t numDimsRHS = RHS.dims().size();
    RETURN_ERR_IF_NOT(
        numDimsLHS >= 2,
        opErrMsg(op, "BatchMatMul 1D operands are not yet supported."));
    RETURN_ERR_IF_NOT(
        numDimsRHS >= 2,
        opErrMsg(op, "BatchMatMul 1D operands are not yet supported."));

    // This is a very simple case when we don't need any broadcasting
    if (numDimsLHS == 2 && numDimsRHS == 2) {
      Node *node = G_->createMatMul(opName, LHS, RHS);
      RETURN_IF_ERR(addNodeAsOutput(op, node));
      return Error::success();
    }

    // In the rest of the function body we:
    // 1. normalize operands using broadcasting rules,
    // 2. convert normalized operands to 3D matrices, so they look like these:
    //    LHS = {numBatches, N, M}
    //    RHS = {numBatches, M, P}
    //    Result = {numBatches, N, P},
    // 3. multiply 3D matrices using createBatchMatMul(), result will be 3D,
    // 4. convert the result to the normalized broadcast shape.

    const dim_t N = LHS.dims()[numDimsLHS - 2];
    const dim_t M = LHS.dims()[numDimsLHS - 1];
    const dim_t P = RHS.dims()[numDimsRHS - 1];

    // Inner dimensions of the two operands must agree.
    RETURN_ERR_IF_NOT(
        RHS.dims()[numDimsRHS - 2] == M,
        opErrMsg(op, "BatchMatMul operands dimensions are invalid."));

    // Calculate broadcast shape and convert both operands to that shape
    const std::vector<dim_t> originalDimsLHS{LHS.dims().begin(),
                                             LHS.dims().end()};
    const std::vector<dim_t> originalDimsRHS{RHS.dims().begin(),
                                             RHS.dims().end()};
    // resultShape is accumulated innermost-first ({P, N} then the batch dims
    // walking outwards) and reversed into the conventional order at the end.
    std::vector<dim_t> resultShape{P, N};
    resultShape.reserve(std::max(numDimsLHS, numDimsRHS));
    dim_t numBatches = 1;
    int indLHS = numDimsLHS - 3; // skip last two dims
    int indRHS = numDimsRHS - 3; // skip last two dims
    // Walk the batch dims of both operands from innermost to outermost,
    // tiling whichever side has a size-1 dim up to the other's size.
    for (; indLHS >= 0 && indRHS >= 0; --indLHS, --indRHS) {
      const dim_t dimLHS = originalDimsLHS[indLHS];
      const dim_t dimRHS = originalDimsRHS[indRHS];

      RETURN_ERR_IF_NOT(
          (dimLHS == dimRHS || (dimLHS == 1) || dimRHS == 1),
          opErrMsg(op, "BatchMatMul dimensions cannot be broadcast."));
      dim_t dim = 1;
      if (dimLHS == dimRHS) {
        dim = dimLHS;
      } else if (dimLHS == 1) {
        dim = dimRHS;
        LHS = G_->createTile(opName + ".tileDim", LHS, dim, indLHS);
      } else {
        dim = dimLHS;
        RHS = G_->createTile(opName + ".tileDim", RHS, dim, indRHS);
      }
      resultShape.push_back(dim);
      numBatches *= dim;
    }
    // LHS has extra leading batch dims: prepend and tile RHS to match.
    for (; indLHS >= 0; --indLHS) {
      const dim_t dim = originalDimsLHS[indLHS];
      resultShape.push_back(dim);
      numBatches *= dim;
      RHS = G_->createExpandDims(opName + ".addDim", RHS, {0});
      RHS = G_->createTile(opName + ".tileDim", RHS, dim, 0);
    }
    // RHS has extra leading batch dims: prepend and tile LHS to match.
    for (; indRHS >= 0; --indRHS) {
      const dim_t dim = originalDimsRHS[indRHS];
      resultShape.push_back(dim);
      numBatches *= dim;
      LHS = G_->createExpandDims(opName + ".addDim", LHS, {0});
      LHS = G_->createTile(opName + ".tileDim", LHS, dim, 0);
    }
    std::reverse(resultShape.begin(), resultShape.end());

    // Broadcast shape might have more than 3 dims,
    // therefore, optionally, reshape the operands
    if (resultShape.size() > 3) {
      LHS =
          G_->createReshape(opName + ".reshapeLHS3D", LHS, {numBatches, N, M});
      RHS =
          G_->createReshape(opName + ".reshapeRHS3D", RHS, {numBatches, M, P});
    }
    Node *node = G_->createBatchMatMul(opName, LHS, RHS);

    // Optionally, reshape result to broadcast shape
    if (resultShape.size() != 3) {
      node = G_->createReshape(opName + ".reshapeResult", node, resultShape);
    }

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }
725 | |
/// Loads a binary element-wise arithmetic operator (Mul/Add/Sub/Div/Pow)
/// selected by \p typeName, reading its two operands from \p op and its
/// attributes from \p dict. Handles both explicit broadcast attributes and
/// implicitly detected broadcasting, covering ONNX (multidirectional and
/// unidirectional) as well as Caffe2 broadcast semantics.
Error loadArithmetic(llvm::StringRef typeName, const OpType &op,
                     ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  NodeValue in0;
  ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
  NodeValue in1;
  ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));

  bool broadcast;
  ASSIGN_VALUE_OR_RETURN_ERR(broadcast, getBroadcast(dict));
  // Check implicit broadcast
  // Even when the broadcast attribute is absent, mismatched ranks may still
  // be numpy-style broadcastable: walk the trailing dimensions of both
  // inputs and accept when each pair is equal or one side is 1.
  if (!broadcast && in0.dims().size() != in1.dims().size()) {
    bool validBroadcast = true;
    auto dimsA = in0.dims();
    auto dimsB = in1.dims();
    for (int i = dimsA.size() - 1, j = dimsB.size() - 1; i >= 0 && j >= 0;) {
      auto a = dimsA[i];
      auto b = dimsB[j];
      if (!(a == b || a == 1 || b == 1)) {
        validBroadcast = false;
        break;
      }
      --i;
      --j;
    }
    // Invalid shapes only produce a warning here; broadcast stays false and
    // the non-broadcast node creation below will surface any real mismatch.
    if (!validBroadcast) {
      LOG(WARNING) << "Invalid broadcast rule for inputs of " << opName;
    }
    broadcast = validBroadcast;
  }

  int axis = -1;

  // Broadcasting can be:
  // - multidirectional (ONNX opset 7+), or
  // - unidirectional (ONNX opset 1->6, Caffe2).

  // Unidirectional broadcasting consists of broadcasting the right operand
  // (in1) so that it matches the shape of the left operand (in0).
  if (broadcast && !hasMultidirectionalBroadcast(typeName)) {
    // With unidirectional broadcasting, the 'axis' attribute specifies
    // from how much the right operand shape must be 'shifted' right.
    // - In Caffe2, the 'axis' attribute is optional. If not specified, axis
    //   must be automatically computed so that the trailing-most dimensions
    //   of in1 is aligned to the trailing-most dimension of in0.
    // - In ONNX, the 'axis' attribute is mandatory. axis == -1 is
    //   equivalent to no axis specified in Caffe2.

    if (dict.count("axis" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis" ]));
    }
    if (axis == -1) {
      // Align trailing most dimensions.
      axis = in0.dims().size() - in1.dims().size();
    }
  }

  // Create either the broadcasting or the plain variant of the requested op.
  Node *node = nullptr;
  if (broadcast) {
    if (typeName == "Mul" ) {
      node = G_->createNodeWithBroadcast<MulNode>(opName, axis, in0, in1);
    } else if (typeName == "Add" ) {
      node = G_->createNodeWithBroadcast<AddNode>(opName, axis, in0, in1);
    } else if (typeName == "Sub" ) {
      node = G_->createNodeWithBroadcast<SubNode>(opName, axis, in0, in1);
    } else if (typeName == "Div" ) {
      node = G_->createNodeWithBroadcast<DivNode>(opName, axis, in0, in1);
    } else if (typeName == "Pow" ) {
      node = G_->createNodeWithBroadcast<PowNode>(opName, axis, in0, in1);
    } else {
      return MAKE_ERR("Unsupported arithmetic typeName" );
    }
  } else {
    if (typeName == "Mul" ) {
      node = G_->createMul(opName, in0, in1);
    } else if (typeName == "Add" ) {
      node = G_->createAdd(opName, in0, in1);
    } else if (typeName == "Sub" ) {
      node = G_->createSub(opName, in0, in1);
    } else if (typeName == "Div" ) {
      node = G_->createDiv(opName, in0, in1);
    } else if (typeName == "Pow" ) {
      node = G_->createPow(opName, in0, in1);
    } else {
      return MAKE_ERR("Unsupported arithmetic typeName" );
    }
  }

  RETURN_IF_ERR(addNodeAsOutput(op, node));
  return Error::success();
}
817 | |
818 | Error loadSplit(const OpType &op, ArgumentDictionaryTy &dict) { |
819 | const std::string &opName = loadOperatorName(op); |
820 | NodeValue in; |
821 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
822 | size_t axis = 0; |
823 | if (dict.count("axis" )) { |
824 | ASSIGN_VALUE_OR_RETURN_ERR( |
825 | axis, loadAxis<size_t>(dict["axis" ], in.dims().size())); |
826 | } |
827 | |
828 | std::vector<dim_t> split; |
829 | if (dict.count("split" )) { |
830 | ASSIGN_VALUE_OR_RETURN_ERR(split, getShape<dim_t>(dict["split" ])); |
831 | } |
832 | |
833 | std::vector<SliceNode *> outputs; |
834 | G_->createSplit(opName, in, op.output_size(), axis, split, outputs); |
835 | |
836 | for (int i = 0, e = op.output_size(); i < e; i++) { |
837 | // Each output from Split is a SliceNode which only has a single output, |
838 | // so only use 0 here as the node value result. |
839 | nodeValueByName_[op.output(i)] = outputs[i]->getResult(); |
840 | } |
841 | return Error::success(); |
842 | } |
843 | |
/// Loads a Reshape operator. The target shape comes either from a constant
/// second input tensor or from the "shape" attribute. A 0 entry keeps the
/// corresponding input dimension, and a single -1 entry is inferred so that
/// the total element count is preserved.
Error loadReshape(const OpType &op, ArgumentDictionaryTy &dict) {
  const std::string &opName = loadOperatorName(op);
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  // Get the requested shape from the model.
  // First look at input tensors, then at the "shape" attribute.
  std::vector<dim_t> requestedDims;
  if (op.input_size() > 1) {
    // The shape input must be a compile-time constant; Glow needs static
    // shapes when building the graph.
    if (!getConstantByNameOrNull(op.input(1))) {
      return MAKE_ERR(opErrMsg(
          op,
          "Reshape: Non-constant shape tensors are unsupported by Glow." ));
    }
    const Constant *constShapeConst;
    ASSIGN_VALUE_OR_RETURN_ERR(constShapeConst,
                               getConstantByName(op.input(1)));
    auto TH = constShapeConst->getPayload().getHandle<int64_t>();
    for (auto dim : TH) {
      requestedDims.push_back(dim);
    }
  } else if (dict.count("shape" )) {
    RETURN_ERR_IF_NOT(
        op.input_size() == 1,
        opErrMsg(
            op,
            "Reshape: Cannot specify new shape by both argument and input." ));
    std::vector<int64_t> protoDims;
    ASSIGN_VALUE_OR_RETURN_ERR(protoDims, getShape<int64_t>(dict["shape" ]));

    for (auto dim : protoDims) {
      requestedDims.push_back(dim);
    }
  } else {
    return MAKE_ERR(opErrMsg(op,
                             "Reshape: Missing output shape information for "
                             "the Reshape operator." ));
  }

  // Compute the actual new shape
  ssize_t negOneIndex = -1;
  llvm::ArrayRef<dim_t> inputDims = in.dims();
  std::vector<dim_t> outputDims;
  // Product of all explicitly-specified output dims; used to infer the -1
  // dimension below.
  int64_t dimProduct = 1;
  for (size_t i = 0, e = requestedDims.size(); i != e; i++) {
    dim_t newDim = requestedDims[i];
    if (newDim == 0) {
      // 0 means that corresponding input dimension should be propagated to
      // the output.
      newDim = inputDims[i];
    }
    if (newDim != (dim_t)-1) {
      dimProduct *= newDim;
      outputDims.push_back(newDim);
    } else {
      // -1 means that the corresponding dimension should be inferred
      // from all other dimensions, so that tensor size remains the same.
      RETURN_ERR_IF_NOT(
          negOneIndex < 0,
          opErrMsg(
              op,
              "Reshape: At most one dimension of the new shape can be -1." ));
      negOneIndex = (ssize_t)i;
      // The -1 case value is handled later.
      outputDims.push_back(0);
    }
  }
  // Fill in the inferred dimension: total input size divided by the product
  // of the explicit output dimensions.
  if (negOneIndex >= 0) {
    outputDims[negOneIndex] = in.getType()->size() / dimProduct;
  }

  auto *node = G_->createReshape(opName, in, outputDims);

  // Caffe2 sometimes outputs old_shape which goes unused. We do not currently
  // support it, so explicitly only set the first output.
  nodeValueByName_[op.output(0)] = node->getResult();
  return Error::success();
}
922 | |
923 | Error loadTranspose(const OpType &op, ArgumentDictionaryTy &dict, |
924 | llvm::StringRef permArgName) { |
925 | const std::string &opName = loadOperatorName(op); |
926 | NodeValue in; |
927 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
928 | |
929 | // There is a difference between ONNX and Caffe2 specs for Transpose: |
930 | // one contains permutation under name "perm", the other contains it under |
931 | // argument name "axes". That's why the name is passed as a parameter. |
932 | std::vector<unsigned_t> perm; |
933 | if (dict.count(permArgName.str())) |
934 | ASSIGN_VALUE_OR_RETURN_ERR(perm, |
935 | getShape<unsigned_t>(dict[permArgName.str()])); |
936 | |
937 | if (perm.empty()) { |
938 | // Empty permutation argument means reversing axes order. |
939 | size_t N = in.dims().size(); |
940 | for (int64_t i = N - 1; i >= 0; i--) |
941 | perm.push_back(i); |
942 | } |
943 | |
944 | auto *T = G_->createTranspose(opName, in, perm); |
945 | |
946 | RETURN_IF_ERR(addNodeAsOutput(op, T)); |
947 | return Error::success(); |
948 | } |
949 | |
950 | Error loadFlatten(const OpType &op, ArgumentDictionaryTy &dict) { |
951 | const std::string &opName = loadOperatorName(op); |
952 | NodeValue in; |
953 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
954 | int axis = 1; |
955 | if (dict.count("axis" )) { |
956 | ASSIGN_VALUE_OR_RETURN_ERR(axis, |
957 | loadAxis<int>(dict["axis" ], in.dims().size())); |
958 | } |
959 | auto *node = G_->createFlatten(opName, in, axis); |
960 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
961 | return Error::success(); |
962 | } |
963 | |
/// Loads an Identity-like operator (Identity/Alias, and Dropout in
/// inference mode) by mapping the output name directly onto the input
/// value. When loading a partitioned DAG, an identity tagged with
/// "isIntermediateOutputForDAG" additionally creates a Save into a
/// placeholder so the value can cross partition boundaries.
Error loadIdentity(const OpType &op, ArgumentDictionaryTy &dict) {
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

  // If loading partitioned DAG then check if this identity is used for an
  // intermediate, and if so create the Save+PH with the correct name.
  if (partNameToFun_.size()) {
    int intermediate = 0;
    if (dict.count("isIntermediateOutputForDAG" )) {
      ASSIGN_VALUE_OR_RETURN_ERR(
          intermediate, loadInt(dict.at("isIntermediateOutputForDAG" )));
    }

    if (intermediate) {
      const std::string &opName = loadOperatorName(op);
      // Reuse a pre-existing placeholder with this name if present
      // (only legal when loading into an existing module), otherwise
      // create a fresh non-trainable one.
      Placeholder *PH = mod_.getPlaceholderByNameSlow(op.output(0));
      if (!PH) {
        PH = mod_.createPlaceholder(in.getType(), op.output(0),
                                    /* isTrainable */ false);
      } else {
        RETURN_ERR_IF_NOT(
            loadIntoExistingModule_,
            opErrMsg(op, "Found pre-existing PH by name " + op.output(0)));
        RETURN_ERR_IF_NOT(
            PH->getType()->isEqual(in.getType()),
            opErrMsg(op, "Mismatch on pre-existing intermediate PH type" ));
      }
      G_->createSave(opName, in, PH, /* skipSuffix */ true);
      intermediatePHsByName_[op.output(0)] = PH;
      // Downstream consumers must read through the placeholder.
      in = PH->getOutput();
    }
  }

  nodeValueByName_[op.output(0)] = in;
  return Error::success();
}
1000 | |
1001 | Error loadReduceOp(llvm::StringRef typeName, const OpType &op, |
1002 | ArgumentDictionaryTy &dict) { |
1003 | const std::string &opName = loadOperatorName(op); |
1004 | NodeValue in; |
1005 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1006 | |
1007 | std::vector<unsigned_t> shapeAxes = {}; |
1008 | if (dict.count("axes" )) { |
1009 | ASSIGN_VALUE_OR_RETURN_ERR( |
1010 | shapeAxes, loadAxes<unsigned_t>(dict["axes" ], in.dims().size())); |
1011 | } else { |
1012 | shapeAxes.resize(in.dims().size()); |
1013 | std::iota(shapeAxes.begin(), shapeAxes.end(), 0); |
1014 | } |
1015 | |
1016 | std::sort(shapeAxes.begin(), shapeAxes.end()); |
1017 | |
1018 | llvm::ArrayRef<unsigned_t> axes(shapeAxes); |
1019 | |
1020 | // Check if axes elements are unique. |
1021 | if (axes.size() > 1) { |
1022 | auto it = std::unique(shapeAxes.begin(), shapeAxes.end()); |
1023 | if (it != shapeAxes.end()) { |
1024 | return MAKE_ERR(opErrMsg(op, "ReduceOp Axes values are not unique." ), |
1025 | ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE); |
1026 | } |
1027 | } |
1028 | |
1029 | bool keepDims = true; |
1030 | if (dict.count("keepdims" )) { |
1031 | int keepdims; |
1032 | ASSIGN_VALUE_OR_RETURN_ERR(keepdims, loadInt(dict["keepdims" ])); |
1033 | keepDims = (bool)keepdims; |
1034 | } |
1035 | |
1036 | NodeValue node; |
1037 | if (typeName == "ReduceMean" ) { |
1038 | node = G_->createBatchedReduceMean(opName, in, axes); |
1039 | } else if (typeName == "ReduceSum" ) { |
1040 | node = G_->createBatchedReduceAdd(opName, in, axes); |
1041 | } else if (typeName == "ReduceMin" ) { |
1042 | node = G_->createBatchedReduceMin(opName, in, axes); |
1043 | } else if (typeName == "ReduceMax" ) { |
1044 | node = G_->createBatchedReduceMax(opName, in, axes); |
1045 | } else if (typeName == "ReduceProd" ) { |
1046 | node = G_->createBatchedReduceProd(opName, in, axes); |
1047 | } else if (typeName == "ReduceSumSquare" ) { |
1048 | node = G_->createBatchedReduceSumSquare(opName, in, axes); |
1049 | } else { |
1050 | return MAKE_ERR("Unsupported Reduce Op " + typeName.str()); |
1051 | } |
1052 | |
1053 | // Our batched reduce add/mean does not keep the dim; reshape if necessary. |
1054 | if (keepDims) { |
1055 | |
1056 | std::vector<dim_t> shape = node.dims(); |
1057 | |
1058 | // Add removed axes. Requires increasing order sort - done above. |
1059 | for (const auto &axis : shapeAxes) { |
1060 | shape.insert(shape.begin() + axis, 1); |
1061 | } |
1062 | node = G_->createReshape(opName, node, shape); |
1063 | } |
1064 | |
1065 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1066 | return Error::success(); |
1067 | } |
1068 | |
1069 | Error loadBatchOneHot(const OpType &op) { |
1070 | const std::string &opName = loadOperatorName(op); |
1071 | NodeValue data; |
1072 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1073 | NodeValue lengths; |
1074 | ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(1))); |
1075 | NodeValue values; |
1076 | ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2))); |
1077 | |
1078 | auto *node = G_->createBatchOneHot(opName, data, lengths, values); |
1079 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1080 | return Error::success(); |
1081 | } |
1082 | |
1083 | Error loadSparseLengthsSum(const OpType &op, ArgumentDictionaryTy &dict) { |
1084 | NodeValue in0; |
1085 | ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0))); |
1086 | NodeValue in1; |
1087 | ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1))); |
1088 | NodeValue in2; |
1089 | ASSIGN_VALUE_OR_RETURN_ERR(in2, getNodeValueByName(op.input(2))); |
1090 | LengthsMode lengthsMode; |
1091 | ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict)); |
1092 | float avgLength; |
1093 | ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict)); |
1094 | auto *node = G_->createSparseLengthsSum(loadOperatorName(op), in0, in1, in2, |
1095 | lengthsMode, avgLength); |
1096 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1097 | return Error::success(); |
1098 | } |
1099 | |
1100 | Error loadSparseLengthsWeightedSum(const OpType &op, |
1101 | ArgumentDictionaryTy &dict) { |
1102 | NodeValue in0; |
1103 | ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0))); |
1104 | NodeValue in1; |
1105 | ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1))); |
1106 | NodeValue in2; |
1107 | ASSIGN_VALUE_OR_RETURN_ERR(in2, getNodeValueByName(op.input(2))); |
1108 | NodeValue in3; |
1109 | ASSIGN_VALUE_OR_RETURN_ERR(in3, getNodeValueByName(op.input(3))); |
1110 | LengthsMode lengthsMode; |
1111 | ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict)); |
1112 | float avgLength; |
1113 | ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict)); |
1114 | auto *node = G_->createSparseLengthsWeightedSum( |
1115 | loadOperatorName(op), in0, in1, in2, in3, lengthsMode, avgLength); |
1116 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1117 | return Error::success(); |
1118 | } |
1119 | |
1120 | Error loadEmbedding(const OpType &op, ArgumentDictionaryTy &dict) { |
1121 | NodeValue weights; |
1122 | ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(0))); |
1123 | |
1124 | NodeValue indices; |
1125 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1))); |
1126 | |
1127 | int64_t padIdx = -1; |
1128 | if (dict.count("padIdx" )) { |
1129 | ASSIGN_VALUE_OR_RETURN_ERR(padIdx, loadInt(dict["padIdx" ])); |
1130 | } |
1131 | bool scale = false; |
1132 | if (dict.count("scale" )) { |
1133 | ASSIGN_VALUE_OR_RETURN_ERR(scale, loadInt(dict["scale" ])); |
1134 | scale = (bool)scale; |
1135 | RETURN_ERR_IF_NOT(scale == false, |
1136 | "Currently only support scale_grad_by_freq == 'false'" ); |
1137 | } |
1138 | bool sparse = false; |
1139 | if (dict.count("sparse" )) { |
1140 | ASSIGN_VALUE_OR_RETURN_ERR(sparse, loadInt(dict["sparse" ])); |
1141 | sparse = (bool)sparse; |
1142 | RETURN_ERR_IF_NOT(sparse == false, |
1143 | "Currently only support sparse == 'false'" ); |
1144 | } |
1145 | auto *node = G_->createEmbedding(loadOperatorName(op), weights, indices, |
1146 | padIdx, scale, sparse); |
1147 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1148 | return Error::success(); |
1149 | } |
1150 | |
1151 | Error loadEmbeddingBag(const OpType &op, ArgumentDictionaryTy &dict) { |
1152 | NodeValue in0; |
1153 | ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0))); |
1154 | NodeValue in1; |
1155 | ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1))); |
1156 | NodeValue in2; |
1157 | ASSIGN_VALUE_OR_RETURN_ERR(in2, getNodeValueByName(op.input(2))); |
1158 | NodeValue in3; |
1159 | ASSIGN_VALUE_OR_RETURN_ERR(in3, getNodeValueByName(op.input(3))); |
1160 | LengthsMode lengthsMode; |
1161 | ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict)); |
1162 | float avgLength; |
1163 | ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict)); |
1164 | auto *node = G_->createEmbeddingBag( |
1165 | loadOperatorName(op), in0, in1, in2, in3, |
1166 | /* hasEndOffset */ false, lengthsMode, avgLength); |
1167 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1168 | return Error::success(); |
1169 | } |
1170 | |
1171 | Error loadLengthsToRanges(const OpType &op) { |
1172 | NodeValue in; |
1173 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1174 | auto *node = G_->createLengthsToRanges(loadOperatorName(op), in); |
1175 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1176 | return Error::success(); |
1177 | } |
1178 | |
1179 | Error loadBatchBoxCox(const OpType &op) { |
1180 | NodeValue data; |
1181 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1182 | NodeValue lambda1; |
1183 | ASSIGN_VALUE_OR_RETURN_ERR(lambda1, getNodeValueByName(op.input(1))); |
1184 | NodeValue lambda2; |
1185 | ASSIGN_VALUE_OR_RETURN_ERR(lambda2, getNodeValueByName(op.input(2))); |
1186 | auto *node = |
1187 | G_->createBatchBoxCox(loadOperatorName(op), data, lambda1, lambda2); |
1188 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1189 | return Error::success(); |
1190 | } |
1191 | |
1192 | Error loadDotProduct(const OpType &op) { |
1193 | NodeValue X; |
1194 | ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0))); |
1195 | NodeValue Y; |
1196 | ASSIGN_VALUE_OR_RETURN_ERR(Y, getNodeValueByName(op.input(1))); |
1197 | RETURN_ERR_IF_NOT(X.dims() == Y.dims(), |
1198 | opErrMsg(op, "X and Y must have the same dimensions" )); |
1199 | auto *node = G_->createDotProduct(loadOperatorName(op), X, Y); |
1200 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1201 | return Error::success(); |
1202 | } |
1203 | |
1204 | Error loadReplaceNaN(const OpType &op, ArgumentDictionaryTy &dict) { |
1205 | // Load the input and NaN replacement value: |
1206 | NodeValue input; |
1207 | ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); |
1208 | auto valueIt = dict.find("value" ); |
1209 | float value = 0.0f; |
1210 | if (valueIt != dict.end()) { |
1211 | ASSIGN_VALUE_OR_RETURN_ERR(value, loadFloat(valueIt->second)); |
1212 | } |
1213 | auto *node = G_->createReplaceNaN(loadOperatorName(op), input, value); |
1214 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1215 | return Error::success(); |
1216 | } |
1217 | |
1218 | Error loadLengthsSum(const OpType &op) { |
1219 | const std::string &opName = loadOperatorName(op); |
1220 | NodeValue data; |
1221 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1222 | NodeValue lengths; |
1223 | ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(1))); |
1224 | |
1225 | RETURN_ERR_IF_NOT( |
1226 | lengths.dims().size() == 1, |
1227 | opErrMsg( |
1228 | op, |
1229 | strFormat("LengthsSum: Lengths must be a 1D vector, but found %zu " , |
1230 | lengths.dims().size()))); |
1231 | |
1232 | auto *node = G_->createLengthsSum(opName, data, lengths); |
1233 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1234 | return Error::success(); |
1235 | } |
1236 | |
1237 | Error loadExpandDims(const OpType &op, ArgumentDictionaryTy &dict) { |
1238 | NodeValue in; |
1239 | ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0))); |
1240 | std::vector<dim_t> shape; |
1241 | ASSIGN_VALUE_OR_RETURN_ERR(shape, getShape<dim_t>(dict["dims" ])); |
1242 | |
1243 | Node *node = G_->createExpandDims(loadOperatorName(op), in, shape); |
1244 | RETURN_IF_ERR(addNodeAsOutput(op, node)); |
1245 | |
1246 | return Error::success(); |
1247 | } |
1248 | |
/// Loads a SparseToDense operator by lowering it to a zero-filled Splat
/// followed by ScatterData, which writes `values` rows at the positions
/// given by `indices`. The first dimension of the dense output is taken
/// from the third input (dataToInferDim).
Error loadSparseToDense(const OpType &op, ArgumentDictionaryTy &dict) {
  if (op.input_size() != 3) {
    return MAKE_ERR(opErrMsg(
        op,
        strFormat(
            "SparseToDense operator must have three inputs, but found %d " ,
            op.input_size())));
  }

  NodeValue indices;
  ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(0)));
  NodeValue values;
  ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(1)));
  NodeValue dataToInferDim;
  ASSIGN_VALUE_OR_RETURN_ERR(dataToInferDim, getNodeValueByName(op.input(2)));

  RETURN_ERR_IF_NOT(indices.dims().size() == 1 || indices.dims().size() == 2,
                    opErrMsg(op, "Indices must be 1D or 2D tensor." ));
  RETURN_ERR_IF_NOT(indices.getElementType() == ElemKind::Int32ITy ||
                        indices.getElementType() == ElemKind::Int64ITy,
                    opErrMsg(op, "Indices must be of int32 or int64 type." ));

  const std::string &opName = loadOperatorName(op);

  // Normalize indices to shape {N, 1}, the layout ScatterData expects.
  if (indices.dims().size() == 1) {
    indices = G_->createReshape(opName + ".indices2D" , indices,
                                {indices.dims()[0], 1});
  } else {
    RETURN_ERR_IF_NOT(
        indices.dims()[1] == 1,
        opErrMsg(op, "Second dimension should be 1 in indices." ));
  }

  // Output type: first dimension from dataToInferDim, remaining dimensions
  // and element type from values.
  ShapeVector outDims{values.dims().begin(), values.dims().end()};
  outDims[0] = dataToInferDim.dims()[0];
  auto outTy =
      G_->getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
  // Start from an all-zeros dense tensor and scatter the sparse rows in.
  Node *data = G_->createSplat(opName + ".data" , outTy, 0.f);

  // SparseToDense has very similar behavior to ScatterND from ONNX
  // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ScatterND,
  // therefore we can use ScatterND to implement SparseToDense.
  Node *result = G_->createScatterData(opName + ".scatterData" , data, indices,
                                       values, true);

  RETURN_IF_ERR(addNodeAsOutput(op, result));
  return Error::success();
}
1297 | |
/// Loads a SparseToDenseMask operator from (indices, values, defaultValue)
/// plus an optional 4th lengths input; the "mask" attribute lists the
/// indices to densify. When lengths is absent, a scalar constant holding
/// the number of index-value pairs is synthesized in its place.
Error loadSparseToDenseMask(const OpType &op, ArgumentDictionaryTy &dict) {
  size_t inputSize = op.input_size();
  if (inputSize != 3 && inputSize != 4) {
    return MAKE_ERR(
        opErrMsg(op, strFormat("SparseToDenseMask operator must have "
                               "3 or 4 inputs, but found %zu " ,
                               inputSize)));
  }

  NodeValue indices;
  ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(0)));
  NodeValue values;
  ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(1)));
  NodeValue defaultValue;
  ASSIGN_VALUE_OR_RETURN_ERR(defaultValue, getNodeValueByName(op.input(2)));

  NodeValue lengths;
  if (inputSize == 4) {
    ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(3)));
  } else {
    // If Lengths input is not present, create scalar containing number of
    // index-value pairs.
    auto *lengthsConstant =
        mod_.createConstant(ElemKind::Int32ITy, {}, "lengthsConstant" );
    lengthsConstant->getPayloadMutable().template getHandle<int32_t>().raw(
        0) = indices.dims()[0];
    lengths = lengthsConstant->getOutput();
  }

  // "mask" is a mandatory attribute listing the sparse indices to keep.
  std::vector<dim_t> mask;
  ASSIGN_VALUE_OR_RETURN_ERR(mask, getShape<dim_t>(dict["mask" ]));

  auto *node = G_->createSparseToDenseMask(
      loadOperatorName(op), indices, values, defaultValue, lengths, mask);
  RETURN_IF_ERR(addNodeAsOutput(op, node));
  return Error::success();
}
1335 | |
1336 | Error loadGatherOps(const std::string &typeName, const OpType &op, |
1337 | const ArgumentDictionaryTy &dict) { |
1338 | |
1339 | NodeValue data; |
1340 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1341 | NodeValue indices; |
1342 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1))); |
1343 | size_t axis = typeName == "Gather" ? 0 : 1; |
1344 | |
1345 | if (dict.count("axis" )) { |
1346 | ASSIGN_VALUE_OR_RETURN_ERR( |
1347 | axis, |
1348 | loadAxis<size_t>(dict.find("axis" )->second, data.dims().size())); |
1349 | } |
1350 | |
1351 | if (indices.getElementType() != ElemKind::Int64ITy && |
1352 | indices.getElementType() != ElemKind::Int32ITy) { |
1353 | // If the index type is not Int32 or Int64 insert a conversion layer to |
1354 | // introduce robustness against model problems. Constant Float indices |
1355 | // will get converted to integer indices via constant folding pass. |
1356 | indices = G_->createConvertTo( |
1357 | loadOperatorName(op) + "_idx_convertToi32" , indices, |
1358 | G_->getParent()->uniqueType(ElemKind::Int32ITy, indices.dims())); |
1359 | } |
1360 | |
1361 | auto *GN = G_->createGather(loadOperatorName(op), data, indices, axis); |
1362 | RETURN_IF_ERR(addNodeAsOutput(op, GN)); |
1363 | return Error::success(); |
1364 | } |
1365 | |
1366 | Error loadGatherND(const std::string &typeName, const OpType &op, |
1367 | const ArgumentDictionaryTy &dict) { |
1368 | |
1369 | NodeValue data; |
1370 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1371 | NodeValue indices; |
1372 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1))); |
1373 | |
1374 | if (indices.getElementType() != ElemKind::Int64ITy && |
1375 | indices.getElementType() != ElemKind::Int32ITy) { |
1376 | // If the index type is not Int32 or Int64 insert a conversion layer to |
1377 | // introduce robustness against model problems. Constant Float indices |
1378 | // will get converted to integer indices via constant folding pass. |
1379 | indices = G_->createConvertTo( |
1380 | loadOperatorName(op) + "_idx_convertToi32" , indices, |
1381 | G_->getParent()->uniqueType(ElemKind::Int32ITy, indices.dims())); |
1382 | } |
1383 | |
1384 | auto *GN = G_->createGatherND(loadOperatorName(op), data, indices); |
1385 | RETURN_IF_ERR(addNodeAsOutput(op, GN)); |
1386 | return Error::success(); |
1387 | } |
1388 | |
1389 | Error loadGatherRanges(const std::string &typeName, const OpType &op, |
1390 | ArgumentDictionaryTy &dict) { |
1391 | NodeValue data; |
1392 | ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0))); |
1393 | RETURN_ERR_IF_NOT( |
1394 | data.dims().size() == 1, |
1395 | opErrMsg(op, strFormat("GatherRanges: Data must be a 1D vector, but " |
1396 | "found vector size %zu " , |
1397 | data.dims().size()))); |
1398 | |
1399 | NodeValue ranges; |
1400 | ASSIGN_VALUE_OR_RETURN_ERR(ranges, getNodeValueByName(op.input(1))); |
1401 | RETURN_ERR_IF_NOT( |
1402 | ranges.dims().size() == 3, |
1403 | opErrMsg(op, strFormat("GatherRanges: Ranges must be a 3D vector, but " |
1404 | "found vector size %zu " , |
1405 | ranges.dims().size()))); |
1406 | RETURN_ERR_IF_NOT( |
1407 | ranges.dims()[2] == 2, |
1408 | opErrMsg(op, strFormat("GatherRanges: Last dimension of " |
1409 | "ranges must be 2, but found %s" , |
1410 | std::to_string(ranges.dims()[2]).c_str()))); |
1411 | |
1412 | auto maxOutputSizeIt = dict.find("maxOutputSize" ); |
1413 | RETURN_ERR_IF_NOT(maxOutputSizeIt != dict.end(), |
1414 | opErrMsg(op, "GatherRanges: Require maxOutputSize when " |
1415 | "loading LengthsRangeFill." )); |
1416 | unsigned_t maxOutputSize; |
1417 | ASSIGN_VALUE_OR_RETURN_ERR(maxOutputSize, loadInt(maxOutputSizeIt->second)); |
1418 | |
1419 | Node *GR = G_->createGatherRanges(loadOperatorName(op), data, ranges, |
1420 | maxOutputSize); |
1421 | RETURN_IF_ERR(addNodeAsOutput(op, GR)); |
1422 | return Error::success(); |
1423 | } |
1424 | |
1425 | // Loads Less operator. Internally it's a cmpLT Node. |
1426 | Error loadLess(const OpType &op, ArgumentDictionaryTy &dict) { |
1427 | // Input Type. |
1428 | NodeValue xNV; |
1429 | ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0))); |
1430 | NodeValue yNV; |
1431 | ASSIGN_VALUE_OR_RETURN_ERR(yNV, getNodeValueByName(op.input(1))); |
1432 | |
1433 | std::string opName = loadOperatorName(op); |
1434 | |
1435 | auto *xNode = xNV.getNode(); |
1436 | auto *yNode = yNV.getNode(); |
1437 | |
1438 | Node *N = G_->createNodeWithBroadcast<CmpLTNode>(opName, /* axis */ -1, |
1439 | xNode, yNode); |
1440 | |
1441 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
1442 | return Error::success(); |
1443 | } |
1444 | |
1445 | Error loadLogicalOps(llvm::StringRef typeName, const OpType &op) { |
1446 | std::string opName = loadOperatorName(op); |
1447 | NodeValue xNV; |
1448 | ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0))); |
1449 | NodeValue yNV; |
1450 | ASSIGN_VALUE_OR_RETURN_ERR(yNV, getNodeValueByName(op.input(1))); |
1451 | constexpr int axis = -1; |
1452 | Node *N = nullptr; |
1453 | if (typeName == "And" ) { |
1454 | N = G_->createNodeWithBroadcast<AndNode>(opName, axis, xNV, yNV); |
1455 | } else if (typeName == "Or" ) { |
1456 | N = G_->createNodeWithBroadcast<OrNode>(opName, axis, xNV, yNV); |
1457 | } else if (typeName == "Xor" ) { |
1458 | N = G_->createNodeWithBroadcast<XorNode>(opName, axis, xNV, yNV); |
1459 | } else { |
1460 | return MAKE_ERR("Unsupported Logical Operator" ); |
1461 | } |
1462 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
1463 | return Error::success(); |
1464 | } |
1465 | |
1466 | Error loadNotOp(llvm::StringRef typeName, const OpType &op) { |
1467 | std::string opName = loadOperatorName(op); |
1468 | NodeValue xNV; |
1469 | ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0))); |
1470 | Node *N = G_->createNot(opName, xNV); |
1471 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
1472 | return Error::success(); |
1473 | } |
1474 | |
1475 | // Loads Abs operator |
1476 | Error loadAbs(const OpType &op, ArgumentDictionaryTy &dict) { |
1477 | std::string opName = loadOperatorName(op); |
1478 | NodeValue xNV; |
1479 | ASSIGN_VALUE_OR_RETURN_ERR(xNV, getNodeValueByName(op.input(0))); |
1480 | auto *input = xNV.getNode(); |
1481 | |
1482 | auto *N = G_->createAbs(opName, input); |
1483 | RETURN_IF_ERR(addNodeAsOutput(op, N)); |
1484 | return Error::success(); |
1485 | } |
1486 | |
1487 | /// If operator type is supported, returns Expected<true> and creates new |
1488 | /// operator. Returns Operator<false> if operator type is not supported. |
1489 | /// Returns Error if an error occurred |
1490 | Expected<bool> tryLoadCommonOperator(llvm::StringRef typeName, |
1491 | const OpType &op, |
1492 | ArgumentDictionaryTy &dict) { |
1493 | if (typeName == "Relu" ) { |
1494 | RETURN_IF_ERR(loadRelu(op, dict)); |
1495 | return true; |
1496 | } |
1497 | if (typeName == "Sigmoid" ) { |
1498 | RETURN_IF_ERR(loadSigmoid(op, dict)); |
1499 | return true; |
1500 | } |
1501 | if (typeName == "Tanh" ) { |
1502 | RETURN_IF_ERR(loadTanh(op, dict)); |
1503 | return true; |
1504 | } |
1505 | if (typeName == "Exp" ) { |
1506 | RETURN_IF_ERR(loadExp(op, dict)); |
1507 | return true; |
1508 | } |
1509 | if (typeName == "Log" ) { |
1510 | RETURN_IF_ERR(loadLog(op, dict)); |
1511 | return true; |
1512 | } |
1513 | if (typeName == "Neg" ) { |
1514 | RETURN_IF_ERR(loadNeg(op, dict)); |
1515 | return true; |
1516 | } |
1517 | if (typeName == "Abs" ) { |
1518 | RETURN_IF_ERR(loadAbs(op, dict)); |
1519 | return true; |
1520 | } |
1521 | if (typeName == "Ceil" ) { |
1522 | RETURN_IF_ERR(loadCeil(op, dict)); |
1523 | return true; |
1524 | } |
1525 | if (typeName == "Floor" ) { |
1526 | RETURN_IF_ERR(loadFloor(op, dict)); |
1527 | return true; |
1528 | } |
1529 | if (typeName == "Shape" ) { |
1530 | RETURN_IF_ERR(loadShape(op, dict)); |
1531 | return true; |
1532 | } |
1533 | if (typeName == "Sqrt" ) { |
1534 | RETURN_IF_ERR(loadSqrt(op, dict)); |
1535 | return true; |
1536 | } |
1537 | if (typeName == "Sqr" ) { |
1538 | RETURN_IF_ERR(loadSqr(op, dict)); |
1539 | return true; |
1540 | } |
1541 | if (typeName == "Reciprocal" ) { |
1542 | RETURN_IF_ERR(loadReciprocal(op, dict)); |
1543 | return true; |
1544 | } |
1545 | if (typeName == "Sum" ) { |
1546 | RETURN_IF_ERR(loadSum(op, dict)); |
1547 | return true; |
1548 | } |
1549 | if (typeName == "LRN" ) { |
1550 | RETURN_IF_ERR(loadLRN(op, dict)); |
1551 | return true; |
1552 | } |
1553 | if (typeName == "Min" || typeName == "Max" ) { |
1554 | RETURN_IF_ERR(loadMinMax(typeName, op, dict)); |
1555 | return true; |
1556 | } |
1557 | if (typeName == "Mul" || typeName == "Add" || typeName == "Sub" || |
1558 | typeName == "Div" || typeName == "Pow" ) { |
1559 | RETURN_IF_ERR(loadArithmetic(typeName, op, dict)); |
1560 | return true; |
1561 | } |
1562 | if (typeName == "Split" ) { |
1563 | RETURN_IF_ERR(loadSplit(op, dict)); |
1564 | return true; |
1565 | } |
1566 | if (typeName == "Reshape" ) { |
1567 | RETURN_IF_ERR(loadReshape(op, dict)); |
1568 | return true; |
1569 | } |
1570 | if (typeName == "Flatten" ) { |
1571 | RETURN_IF_ERR(loadFlatten(op, dict)); |
1572 | return true; |
1573 | } |
1574 | if (typeName == "Dropout" ) { |
1575 | // Save the identity operation. Note the second output (mask) for Dropout |
1576 | // in Caffe2 and ONNX is unused during inference, and our Dropout does not |
1577 | // currently implement it, thus we do not call addNodeAsOutput() here. |
1578 | RETURN_IF_ERR(loadIdentity(op, dict)); |
1579 | return true; |
1580 | } |
1581 | if (typeName == "Identity" || typeName == "Alias" ) { |
1582 | RETURN_IF_ERR(loadIdentity(op, dict)); |
1583 | return true; |
1584 | } |
1585 | if (typeName == "ReduceMean" || typeName == "ReduceSum" || |
1586 | typeName == "ReduceMin" || typeName == "ReduceMax" || |
1587 | typeName == "ReduceProd" ) { |
1588 | RETURN_IF_ERR(loadReduceOp(typeName, op, dict)); |
1589 | return true; |
1590 | } |
1591 | if (typeName == "BatchMatMul" ) { |
1592 | RETURN_IF_ERR(loadBatchMatMul(op, dict)); |
1593 | return true; |
1594 | } |
1595 | if (typeName == "BatchOneHot" ) { |
1596 | RETURN_IF_ERR(loadBatchOneHot(op)); |
1597 | return true; |
1598 | } |
1599 | if (typeName == "SparseLengthsSum" ) { |
1600 | RETURN_IF_ERR(loadSparseLengthsSum(op, dict)); |
1601 | return true; |
1602 | } |
1603 | if (typeName == "SparseLengthsWeightedSum" ) { |
1604 | RETURN_IF_ERR(loadSparseLengthsWeightedSum(op, dict)); |
1605 | return true; |
1606 | } |
1607 | if (typeName == "EmbeddingBag" ) { |
1608 | RETURN_IF_ERR(loadEmbeddingBag(op, dict)); |
1609 | return true; |
1610 | } |
1611 | if (typeName == "Embedding" ) { |
1612 | RETURN_IF_ERR(loadEmbedding(op, dict)); |
1613 | return true; |
1614 | } |
1615 | if (typeName == "LengthsToRanges" ) { |
1616 | RETURN_IF_ERR(loadLengthsToRanges(op)); |
1617 | return true; |
1618 | } |
1619 | if (typeName == "BatchBoxCox" ) { |
1620 | RETURN_IF_ERR(loadBatchBoxCox(op)); |
1621 | return true; |
1622 | } |
1623 | if (typeName == "DotProduct" ) { |
1624 | RETURN_IF_ERR(loadDotProduct(op)); |
1625 | return true; |
1626 | } |
1627 | if (typeName == "ReplaceNaN" ) { |
1628 | RETURN_IF_ERR(loadReplaceNaN(op, dict)); |
1629 | return true; |
1630 | } |
1631 | if (typeName == "LengthsSum" ) { |
1632 | RETURN_IF_ERR(loadLengthsSum(op)); |
1633 | return true; |
1634 | } |
1635 | if (typeName == "ExpandDims" ) { |
1636 | RETURN_IF_ERR(loadExpandDims(op, dict)); |
1637 | return true; |
1638 | } |
1639 | if (typeName == "SparseToDense" ) { |
1640 | RETURN_IF_ERR(loadSparseToDense(op, dict)); |
1641 | return true; |
1642 | } |
1643 | if (typeName == "SparseToDenseMask" ) { |
1644 | RETURN_IF_ERR(loadSparseToDenseMask(op, dict)); |
1645 | return true; |
1646 | } |
1647 | if (typeName == "Gather" || typeName == "BatchGather" ) { |
1648 | RETURN_IF_ERR(loadGatherOps(typeName.str(), op, dict)); |
1649 | return true; |
1650 | } |
1651 | if (typeName == "GatherND" ) { |
1652 | RETURN_IF_ERR(loadGatherND(typeName.str(), op, dict)); |
1653 | return true; |
1654 | } |
1655 | if (typeName == "GatherRanges" ) { |
1656 | RETURN_IF_ERR(loadGatherRanges(typeName.str(), op, dict)); |
1657 | return true; |
1658 | } |
1659 | if (typeName == "Less" ) { |
1660 | RETURN_IF_ERR(loadLess(op, dict)); |
1661 | return true; |
1662 | } |
1663 | if (typeName == "And" || typeName == "Or" || typeName == "Xor" ) { |
1664 | RETURN_IF_ERR(loadLogicalOps(typeName, op)); |
1665 | return true; |
1666 | } |
1667 | if (typeName == "Not" ) { |
1668 | RETURN_IF_ERR(loadNotOp(typeName, op)); |
1669 | return true; |
1670 | } |
1671 | if (typeName == "Pow" ) { |
1672 | RETURN_IF_ERR(loadPow(op, dict)); |
1673 | return true; |
1674 | } |
1675 | |
1676 | return false; |
1677 | } |
1678 | |
1679 | /// Utility function which computes the resulting shape in case of |
1680 | /// multidirectional broadcasting. |
1681 | Expected<std::vector<dim_t>> |
1682 | computeMultidirectionalBroadcast(llvm::ArrayRef<dim_t> shape0, |
1683 | llvm::ArrayRef<dim_t> shape1) { |
1684 | size_t numDims0 = shape0.size(); |
1685 | size_t numDims1 = shape1.size(); |
1686 | size_t newNumDims = numDims0 > numDims1 ? numDims0 : numDims1; |
1687 | std::vector<dim_t> reshapeDims(newNumDims); |
1688 | |
1689 | for (size_t i = 0; i < newNumDims; i++) { |
1690 | reshapeDims[i] = 1; |
1691 | } |
1692 | RETURN_IF_ERR(mergeMultidirectionalBroadcast(reshapeDims, shape0)); |
1693 | RETURN_IF_ERR(mergeMultidirectionalBroadcast(reshapeDims, shape1)); |
1694 | |
1695 | return reshapeDims; |
1696 | } |
1697 | |
1698 | /// Associate all outputs of \p op with nodes in \p NVs. Number of outputs of |
1699 | /// \p op should match the number of elements of \p NVs. |
1700 | /// \returns error code in case of error. |
1701 | Error assignNodeOutputs(const OpType &op, llvm::ArrayRef<NodeValue> NVs) { |
1702 | RETURN_ERR_IF_NOT((dim_t)NVs.size() == (dim_t)op.output_size(), |
1703 | "Output size mismatch." ); |
1704 | for (size_t i = 0; i < NVs.size(); i++) { |
1705 | nodeValueByName_[op.output(i)] = NVs[i]; |
1706 | } |
1707 | return Error::success(); |
1708 | } |
1709 | |
1710 | /// Load pre-trained weights from \p weightDescriptors. |
1711 | Error loadWeights(uint32_t weightsCount, |
1712 | const onnxTensorDescriptorV1 *weightDescriptors) { |
1713 | for (uint32_t i = 0; i < weightsCount; ++i) { |
1714 | const char *name = weightDescriptors[i].name; |
1715 | |
1716 | LoadWeightResult loadResult; |
1717 | if (auto resOrErr = loadWeight(weightDescriptors[i])) { |
1718 | loadResult = std::move(*resOrErr); |
1719 | } else { |
1720 | RETURN_ERR(resOrErr.takeError()); |
1721 | } |
1722 | |
1723 | // If the weight is offline create a static placeholder, otherwise create |
1724 | // a constant. |
1725 | if (weightDescriptors[i].isOffline) { |
1726 | RETURN_ERR_IF_NOT( |
1727 | !clipQuantRangeToFP16_ || |
1728 | !loadResult.t->getType().isQuantizedType() || |
1729 | loadResult.t->getType().isFusedQuantizedType(), |
1730 | strFormat("Do not support clipQuantRangeToFP16 with unfused " |
1731 | "quantized input Placeholders: %s" , |
1732 | name)); |
1733 | Placeholder *pl; |
1734 | ASSIGN_VALUE_OR_RETURN_ERR( |
1735 | pl, createAndRegisterPlaceholder(name, &loadResult.type, |
1736 | /*isStatic*/ true)); |
1737 | (void)pl; |
1738 | } else { |
1739 | RETURN_IF_ERR( |
1740 | createAndRegisterConstant(name, std::move(*loadResult.t))); |
1741 | } |
1742 | |
1743 | if (loadResult.offsets) { |
1744 | auto offsetsName = strFormat("%s_loaded_offsets" , name); |
1745 | RETURN_IF_ERR(createAndRegisterConstant( |
1746 | offsetsName, std::move(*loadResult.offsets))); |
1747 | } |
1748 | |
1749 | if (loadResult.scales) { |
1750 | auto scalesName = strFormat("%s_loaded_scales" , name); |
1751 | RETURN_IF_ERR(createAndRegisterConstant(scalesName, |
1752 | std::move(*loadResult.scales))); |
1753 | } |
1754 | } |
1755 | |
1756 | return Error::success(); |
1757 | } |
1758 | |
1759 | /// Sets the type of \p S to have \p dstKind, using the same dims as S. |
1760 | Error setFusedTy(Storage *S, ElemKind dstKind) { |
1761 | // Use dummy 0.0/0 as scale/offset, since the actual scales/offsets |
1762 | // are fused inline with the data. |
1763 | TypeRef fusedTy = mod_.uniqueType(dstKind, S->dims(), 0.0, 0); |
1764 | return setFusedTy(S, fusedTy); |
1765 | } |
1766 | |
1767 | /// Sets the type of \p S to have \p fusedTy. If \p S already has type \p |
1768 | /// fusedTy, then this is a noop. Otherwise, expected that the original S is |
1769 | /// UInt8QTy. If \p S is a Constant, then also sets the payload of the |
1770 | /// Constant to have the same type. |
1771 | /// The motivation here is that there is no fused quantized type in |
1772 | /// Caffe2/ONNX, so we will always load them in UInt8QTy. We then change it |
1773 | /// from UInt8QTy to one of the fused kinds here. This may not be necessary if |
1774 | /// another user has already changed it, or the type may already have been |
1775 | /// modified in the case of loading into an existing module. |
1776 | Error setFusedTy(Storage *S, TypeRef fusedTy) { |
1777 | assert(fusedTy->isFusedQuantizedType() && "Expected fused quantized type." ); |
1778 | |
1779 | // If S already has the requested type then return early. |
1780 | if (S->getOutput().getType()->isEqual(*fusedTy)) { |
1781 | return Error::success(); |
1782 | } |
1783 | |
1784 | RETURN_ERR_IF_NOT(S->getElementType() == ElemKind::UInt8QTy, |
1785 | "Data must be UInt8QTy, but was " + |
1786 | Type::getElementName(S->getElementType()).str()); |
1787 | S->setType(Storage::OutputIdx, fusedTy); |
1788 | // If the node is a Constant set the payload type as well. |
1789 | if (auto *C = llvm::dyn_cast<Constant>(S)) { |
1790 | C->setPayloadType(fusedTy); |
1791 | } |
1792 | |
1793 | return Error::success(); |
1794 | } |
1795 | |
1796 | static Expected<bool> getCountIncludePads(ArgumentDictionaryTy &dict, |
1797 | bool defaultValue) { |
1798 | if (dict.count("count_include_pad" )) { |
1799 | int countIncludePads; |
1800 | ASSIGN_VALUE_OR_RETURN_ERR(countIncludePads, |
1801 | loadInt(dict.at("count_include_pad" ))); |
1802 | return (bool)countIncludePads; |
1803 | } |
1804 | // Return default value if can't find in the dict |
1805 | return defaultValue; |
1806 | } |
1807 | |
1808 | static Expected<std::vector<unsigned_t>> |
1809 | getDilations(ArgumentDictionaryTy &dict, |
1810 | const std::vector<unsigned_t> &defaultValue) { |
1811 | // For Caffe2 Model, `dilation` field can be either one integer or multiple |
1812 | // integers (one for each axis). When it's one integer the field in the dict |
1813 | // will be `dilation`. Otherwise, the field in the dict will be `dilations`. |
1814 | |
1815 | // For Onnx Model, it can only be `dilations` and it must be a list of |
1816 | // integers. |
1817 | if (dict.count("dilation" )) { |
1818 | unsigned_t dilation; |
1819 | ASSIGN_VALUE_OR_RETURN_ERR(dilation, loadInt(dict.at("dilation" ))); |
1820 | return std::vector<unsigned_t>{dilation, dilation}; |
1821 | } |
1822 | if (dict.count("dilations" )) { |
1823 | std::vector<unsigned_t> shape; |
1824 | ASSIGN_VALUE_OR_RETURN_ERR(shape, |
1825 | getShape<unsigned_t>(dict["dilations" ])); |
1826 | return shape; |
1827 | } |
1828 | |
1829 | return defaultValue; |
1830 | } |
1831 | }; |
1832 | |
1833 | } // namespace glow |
1834 | |
1835 | #endif // GLOW_IMPORTER_COMMONOPERATORLOADER_H |
1836 | |