1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Importer/TFLiteModelLoader.h"
18#include "glow/Base/Tensor.h"
19#include "glow/Graph/Graph.h"
20#include "glow/Graph/Nodes.h"
21#include "glow/Importer/CommonOperatorLoader.h"
22#include "glow/Support/Support.h"
23
24#include "llvm/Support/Casting.h"
25#include "llvm/Support/CommandLine.h"
26
27#include <fstream>
28#include <sstream>
29#include <string>
30#include <vector>
31
32using namespace glow;
33using llvm::cast;
34
35namespace {
36
37llvm::cl::OptionCategory
38 tfliteModelLoaderCat("TensorFlowLite Model Loader Options");
39
40llvm::cl::opt<bool> tfliteUint8ToInt8Opt(
41 "tflite-uint8-to-int8",
42 llvm::cl::desc("TensorFlowLite loader option to convert the model from "
43 "UINT8 data type to INT8 data type."),
44 llvm::cl::init(true), llvm::cl::Optional,
45 llvm::cl::cat(tfliteModelLoaderCat));
46
47llvm::cl::opt<bool> tfliteFloatInputsOpt(
48 "tflite-float-inputs",
49 llvm::cl::desc("TensorFlowLite loader option to replace the quantized "
50 "inputs with floating point inputs."),
51 llvm::cl::init(false), llvm::cl::Optional,
52 llvm::cl::cat(tfliteModelLoaderCat));
53
54llvm::cl::opt<bool> tfliteFloatOutputsOpt(
55 "tflite-float-outputs",
56 llvm::cl::desc("TensorFlowLite loader option to replace the quantized "
57 "outputs with floating point outputs."),
58 llvm::cl::init(false), llvm::cl::Optional,
59 llvm::cl::cat(tfliteModelLoaderCat));
60
llvm::cl::opt<bool> tfliteFloatSoftmaxOpt(
    "tflite-float-softmax",
    llvm::cl::desc("TensorFlowLite loader option to replace a quantized "
                   "Softmax with a floating point Softmax."),
    llvm::cl::init(false), llvm::cl::Optional,
    llvm::cl::cat(tfliteModelLoaderCat));
67
llvm::cl::opt<float> tfliteBiasScaleCheckMaxErrorOpt(
    "tflite-bias-scale-check-max-error",
    llvm::cl::desc(
        "TensorFlowLite mandates that for quantized operators like Conv2D the "
        "bias quantization parameter biasScale = inputScale * weightsScale, "
        "but some pre-quantized models deviate from this relation by very "
        "small relative errors (around 1e-8). Hence we allow a relative "
        "tolerance (1e-6 by default) within which the bias scale is adjusted "
        "to conform to the restriction."),
    llvm::cl::init(1e-6), llvm::cl::Optional,
    llvm::cl::cat(tfliteModelLoaderCat));
79
llvm::cl::opt<bool> tfliteBiasScaleCheckThrowErrorOpt(
    "tflite-bias-scale-check-throw-error",
    llvm::cl::desc(
        "TensorFlowLite mandates that for quantized operators like Conv2D the "
        "bias quantization parameter biasScale = inputScale * weightsScale. If "
        "this constraint is not met within the given tolerance then an error "
        "will be thrown if this option is enabled."),
    llvm::cl::init(true), llvm::cl::Optional,
    llvm::cl::cat(tfliteModelLoaderCat));
89
90llvm::cl::opt<float> tfliteMfccSampleRateOpt(
91 "tflite-mfcc-sample-rate",
92 llvm::cl::desc(
93 "When the TensorFlowLite model has a MFCC node (Mel Frequency Cepstral "
94 "Coefficient) this option is used to set the sample rate (in Hz) used "
95 "by the node when no such attribute is specified."),
96 llvm::cl::init(16000.0), llvm::cl::Optional,
97 llvm::cl::cat(tfliteModelLoaderCat));
98
/// Function to read a TensorFlowLite model from the file \p modelFilename into
/// the data buffer \p modelData provided by the caller. The \p modelData buffer
/// is allocated and initialized by this function but the caller must keep it
/// alive throughout the graph loading process. \returns the TensorFlowLite
/// model object or Error in case something went wrong.
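/// A minimal (hypothetical) usage sketch: the caller owns the backing buffer,
/// e.g. `std::vector<char> data; auto modelOrErr = readModel(data, path);`,
/// and must keep `data` alive for as long as the returned model pointer is
/// dereferenced, since the flatbuffer model points directly into the buffer.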
104Expected<const tflite::Model *> readModel(std::vector<char> &modelData,
105 const std::string &modelFilename) {
106 // Open file.
107 std::ifstream modelFile;
108 modelFile.open(modelFilename, std::ios::binary);
109 RETURN_ERR_IF_NOT(modelFile.is_open(),
110 strFormat("TensorFlowLite: Error opening model file '%s'!",
111 modelFilename.c_str()));
112 // Get model size.
113 modelFile.seekg(0, std::ios::end);
114 std::streamsize modelSize = modelFile.tellg();
115 modelFile.seekg(0, std::ios::beg);
116 // Read model data.
117 modelData = std::vector<char>(modelSize);
118 RETURN_ERR_IF_NOT(modelFile.read(modelData.data(), modelSize),
119 strFormat("TensorFlowLite: Error reading model file '%s'!",
120 modelFilename.c_str()));
121 modelFile.close();
122 // Return model object.
123 return tflite::GetModel(modelData.data());
124}
125
/// Function to convert the UINT8 data from the buffer \p inpPtr to INT8 format
/// into the buffer \p outPtr. The number of elements is given by \p numElem.
/// This function is used to transform the UINT8 weights of a TensorFlowLite
/// model to INT8 format which is the format preferred and supported by Glow.
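/// For example, assuming UINT8_TO_INT8_SHIFT is 128, an unsigned weight value
/// of 200 is mapped to the signed value 200 - 128 = 72.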
130void convertUint8ToInt8(const uint8_t *inpPtr, int8_t *outPtr, size_t numElem) {
131 for (size_t idx = 0, idxEnd = numElem; idx < idxEnd; ++idx) {
132 int32_t val = inpPtr[idx];
133 val -= UINT8_TO_INT8_SHIFT;
134 outPtr[idx] = static_cast<int8_t>(val);
135 }
136}
137
/// Function to compute the padding along a single dimension for the given input
/// size \p inputSize and output size \p outputSize and for the given filter
/// (kernel) size \p kernel, \p stride, \p dilation and padding type \p padding.
/// \returns a pair with the explicit padding values to be used for the input
/// before and after the actual input data.
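/// For example, for SAME padding with inputSize = 5, outputSize = 5,
/// kernel = 3, stride = 1 and dilation = 1 the total padding of 2 is split as
/// (padBefore, padAfter) = (1, 1); with kernel = 4 the total padding of 3 is
/// split as (1, 2) since the odd extra unit goes after the data.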
143std::pair<unsigned_t, unsigned_t>
144getConvPads(dim_t inputSize, dim_t outputSize, unsigned_t kernel,
145 unsigned_t stride, unsigned_t dilation, tflite::Padding padding) {
146 if (padding == tflite::Padding::Padding_VALID) {
    // For VALID padding no explicit padding is used.
148 return std::pair<unsigned_t, unsigned_t>(0, 0);
149 } else if (padding == tflite::Padding::Padding_SAME) {
150 // Effective dilated filter (kernel) size.
151 unsigned_t effKernel = (kernel - 1) * dilation + 1;
    // Compute the total padding size, clamping the result to be non-negative.
153 unsigned_t padTotal = (outputSize - 1) * stride + effKernel;
154 padTotal =
155 std::max(padTotal, static_cast<unsigned_t>(inputSize)) - inputSize;
156 // We split the total padding evenly before/after. If the padding is odd
157 // then the "after" part gets the extra unit.
158 unsigned_t padBefore = padTotal / 2;
159 unsigned_t padAfter = padTotal - padBefore;
160 return std::pair<unsigned_t, unsigned_t>(padBefore, padAfter);
161 }
162 llvm_unreachable("Padding parameter invalid!");
163}
164
165/// Function used to compute the output size for convolution and pooling kernels
166/// for the given \p inputSize, \p kernel, \p stride, \p dilation, \p padding.
167/// This function is used to infer the output shape when not defined.
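/// For example, with inputSize = 10, kernel = 4, stride = 2 and dilation = 1
/// the output size is (10 - 4) / 2 + 1 = 4 for VALID padding and
/// CEIL(10 / 2) = 5 for SAME padding.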
168dim_t getConvOutputSize(dim_t inputSize, unsigned_t kernel, unsigned_t stride,
169 unsigned_t dilation, tflite::Padding padding) {
170 if (padding == tflite::Padding::Padding_VALID) {
171 // Effective dilated filter (kernel) size.
172 unsigned_t effKernel = (kernel - 1) * dilation + 1;
173 // We compute the output size as CEIL((inputSize - effKernel) / stride) + 1.
174 return (inputSize - effKernel + stride - 1) / stride + 1;
175 } else if (padding == tflite::Padding::Padding_SAME) {
176 // For SAME padding the output size is computed as CEIL(inputSize / stride).
177 return (inputSize + stride - 1) / stride;
178 }
179 llvm_unreachable("Padding parameter invalid!");
180}
181
182/// Retrieves data from a constant Tensor and stores it in a vector.
183template <typename T, typename datatype = ssize_t>
184static void helperSetter(Constant *constT, std::vector<datatype> &vec) {
185 auto constH = constT->getPayload().getHandle<T>();
186 for (dim_t i = 0; i < constH.size(); ++i) {
187 vec.push_back(constH.at({i}));
188 }
189}
190
191/// Function to verify the quantization parameters of the bias operand. The
192/// TensorFlowLite format mandates that the bias scale must be equal to the
193/// product inputScale * weightsScale and the bias offset must be 0. This
194/// function is provided with the module \p mod and the node values \p input,
195/// \p weights, \p bias and \returns Error::success() if the bias parameters
196/// are valid and Error otherwise.
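/// For example, assuming inputScale = 0.5 and weightsScale = 0.02, the bias
/// scale is expected to be 0.5 * 0.02 = 0.01 (within the relative tolerance
/// given by -tflite-bias-scale-check-max-error) and the bias offset must be 0.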
197Error checkBiasQuantizationParams(Module &mod, NodeValue input,
198 NodeValue weights, NodeValue bias) {
199 auto inputTy = input.getType();
200 auto weightsTy = weights.getType();
201 auto biasTy = bias.getType();
202 if (inputTy->isQuantizedType() && weightsTy->isQuantizedType() &&
203 biasTy->isQuantizedType()) {
204 float inputScale = inputTy->getScale();
205 float weightsScale = weightsTy->getScale();
206 float matMulScale = inputScale * weightsScale;
207 float biasScale = biasTy->getScale();
208 // Check bias scale relative error to inputScale * weightsScale.
209 if (biasScale != matMulScale) {
210 float relErr = std::abs(matMulScale - biasScale) / matMulScale;
211 llvm::errs() << strFormat(
212 "TensorFlowLite: WARNING: Per tensor BIAS scale value was expected "
213 "to be exactly %E (inputScale * weightsScale) but found "
214 "%E instead! Relative absolute error is %E!\n",
215 matMulScale, biasScale, relErr);
216 if (relErr < tfliteBiasScaleCheckMaxErrorOpt) {
217 // Set new bias type.
218 TypeRef newBiasTy =
219 mod.uniqueType(biasTy->getElementType(), biasTy->dims(),
220 matMulScale, biasTy->getOffset());
221 bias.setType(newBiasTy);
222 // If bias is constant we must also change the payload type.
223 if (auto *biasC = llvm::dyn_cast<Constant>(bias.getNode())) {
224 biasC->setPayloadType(newBiasTy);
225 }
226 } else if (tfliteBiasScaleCheckThrowErrorOpt) {
227 return MAKE_ERR(strFormat(
228 "TensorFlowLite: ERROR: Per tensor BIAS scale value was "
229 "expected to be exactly %E (inputScale * weightsScale) but "
230 "found %E instead! Relative absolute error is %E!\n",
231 matMulScale, biasScale, relErr));
232 }
233 }
234 int32_t biasOffset = biasTy->getOffset();
235 if (biasOffset != 0) {
236 return MAKE_ERR(
237 strFormat("TensorFlowLite: Bias offset value was expected to "
238 "be 0 but found %d instead!",
239 biasOffset));
240 }
241 }
242 return Error::success();
243}
244
245} // namespace
246
247///===---------------------------------------------------------------------===//
248/// Tensor Utilities
249///===---------------------------------------------------------------------===//
250Expected<const tflite::Tensor *>
251TFLiteModelLoader::getTensorByIndex(size_t index) {
252 auto *tensors = graph_->tensors();
253 RETURN_ERR_IF_NOT(
254 index < tensors->size(),
255 strFormat("TensorFlowLite: Tensor index %zu out of range!", index));
256 return (*tensors)[index];
257}
258
259std::string TFLiteModelLoader::getTensorName(const tflite::Tensor *tensor) {
260 return tensor->name()->str();
261}
262
263Expected<std::vector<dim_t>>
264TFLiteModelLoader::getTensorShape(const tflite::Tensor *tensor) {
265 // If tensor shape is NULL we use a 1D shape with size 1.
266 if (!tensor->shape()) {
267 return std::vector<dim_t>({1});
268 }
269 std::vector<dim_t> shape;
270 for (auto dim : *(tensor->shape())) {
271 RETURN_ERR_IF_NOT(dim > 0,
272 strFormat("TensorFlowLite: Tensor '%s' has invalid shape "
273 "element '%d'!",
274 getTensorName(tensor).c_str(), dim));
275 shape.push_back(static_cast<dim_t>(dim));
276 }
277 // If tensor shape is empty (scalar) we use a 1D shape with size 1.
278 if (shape.empty()) {
279 shape = {1};
280 }
281 return shape;
282}
283
284Expected<bool>
285TFLiteModelLoader::isTensorShapeUndefined(const tflite::Tensor *tensor) {
286 // If tensor shape is NULL.
287 if (!tensor->shape()) {
288 return true;
289 }
290 // If tensor shape is empty (scalar).
291 if (tensor->shape()->size() == 0) {
292 return true;
293 }
294 return false;
295}
296
297Expected<ElemKind>
298TFLiteModelLoader::getTensorElemKind(const tflite::Tensor *tensor) {
299 bool isQuantized = isTensorQuantized(tensor);
300 switch (tensor->type()) {
301 case tflite::TensorType_FLOAT32: {
302 RETURN_ERR_IF_NOT(
303 !isQuantized,
304 "TensorFlowLite: FLOAT32 type should have no quantization parameters!");
305 return ElemKind::FloatTy;
306 }
307 case tflite::TensorType_FLOAT16: {
308 RETURN_ERR_IF_NOT(
309 !isQuantized,
310 "TensorFlowLite: FLOAT16 type should have no quantization parameters!");
311 return ElemKind::Float16Ty;
312 }
313 case tflite::TensorType_INT8: {
314 if (isQuantized) {
315 return ElemKind::Int8QTy;
316 } else {
317 return MAKE_ERR("TensorFlowLite: Non-quantized INT8 type not supported!");
318 }
319 }
320 case tflite::TensorType_UINT8: {
321 if (isQuantized) {
322 // Convert UINT8 element type to INT8 element type.
323 if (tfliteUint8ToInt8Opt) {
324 return ElemKind::Int8QTy;
325 } else {
326 return ElemKind::UInt8QTy;
327 }
328 } else {
329 return MAKE_ERR(
330 "TensorFlowLite: Non-quantized UINT8 type not supported!");
331 }
332 }
333 case tflite::TensorType_INT16: {
334 if (isQuantized) {
335 return ElemKind::Int16QTy;
336 } else {
337 return MAKE_ERR(
338 "TensorFlowLite: Non-quantized INT16 type not supported!");
339 }
340 }
341 case tflite::TensorType_INT32: {
342 if (isQuantized) {
343 return ElemKind::Int32QTy;
344 } else {
345 return ElemKind::Int32ITy;
346 }
347 }
348 case tflite::TensorType_INT64: {
349 if (isQuantized) {
350 return MAKE_ERR("TensorFlowLite: Quantized INT64 type not supported!");
351 } else {
352 return ElemKind::Int64ITy;
353 }
354 }
355 case tflite::TensorType_BOOL: {
356 RETURN_ERR_IF_NOT(
357 !isQuantized,
358 "TensorFlowLite: BOOL type should have no quantization parameters!");
359 return ElemKind::BoolTy;
360 }
361 default:
362 return MAKE_ERR(
363 strFormat("TensorFlowLite: Tensor '%s' type '%s' not supported!",
364 getTensorName(tensor).c_str(),
365 tflite::EnumNameTensorType(tensor->type())));
366 }
367}
368
369bool TFLiteModelLoader::isTensorQuantized(const tflite::Tensor *tensor) {
370 auto *tensorQParams = tensor->quantization();
371 if (!tensorQParams) {
372 return false;
373 }
374 auto *scales = tensorQParams->scale();
375 auto *offsets = tensorQParams->zero_point();
376 if (!(scales && offsets)) {
377 return false;
378 }
379 if (!(scales->size() && offsets->size())) {
380 return false;
381 }
382 return true;
383}
384
385bool TFLiteModelLoader::isTensorPerAxisQuantized(const tflite::Tensor *tensor) {
386 if (!isTensorQuantized(tensor)) {
387 return false;
388 }
389 auto *tensorQParams = tensor->quantization();
390 auto *scales = tensorQParams->scale();
391 auto *offsets = tensorQParams->zero_point();
392 return (scales->size() > 1) && (offsets->size() > 1);
393}
394
395Expected<float>
396TFLiteModelLoader::getTensorScale(const tflite::Tensor *tensor) {
397 auto *tensorQParams = tensor->quantization();
398 RETURN_ERR_IF_NOT(
399 isTensorQuantized(tensor),
400 strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!",
401 getTensorName(tensor).c_str()));
402 RETURN_ERR_IF_NOT(
403 tensorQParams->details_type() == tflite::QuantizationDetails_NONE,
404 strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is "
405 "not supported!",
406 getTensorName(tensor).c_str()));
407 auto *scales = tensorQParams->scale();
408 RETURN_ERR_IF_NOT(scales->size() == 1,
409 strFormat("TensorFlowLite: Tensor '%s' has %d quantization "
410 "parameters but only one was expected!",
411 getTensorName(tensor).c_str(), scales->size()));
412 float scale = (*scales)[0];
413 return scale;
414}
415
416Expected<int32_t>
417TFLiteModelLoader::getTensorOffset(const tflite::Tensor *tensor) {
418 auto *tensorQParams = tensor->quantization();
419 RETURN_ERR_IF_NOT(
420 isTensorQuantized(tensor),
421 strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!",
422 getTensorName(tensor).c_str()));
423 RETURN_ERR_IF_NOT(
424 tensorQParams->details_type() == tflite::QuantizationDetails_NONE,
425 strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is "
426 "not supported!",
427 getTensorName(tensor).c_str()));
428 auto *offsets = tensorQParams->zero_point();
429 RETURN_ERR_IF_NOT(offsets->size() == 1,
430 strFormat("TensorFlowLite: Tensor '%s' has %d quantization "
431 "parameters but only one was expected!",
432 getTensorName(tensor).c_str(), offsets->size()));
433 // TensorFlowLite defines the offset as int64 since it also supports int64
434 // quantized type. Since Glow defines the offset as int32 we perform a cast
435 // here and also validate that the offset is within the int32 range.
436 int64_t offsetInt64 = (*offsets)[0];
437 RETURN_ERR_IF_NOT(
438 (std::numeric_limits<int32_t>::min() <= offsetInt64) &&
439 (offsetInt64 <= std::numeric_limits<int32_t>::max()),
440 strFormat(
441 "TensorFlowLite: Tensor '%s' has an offset out of the int32 range!",
442 getTensorName(tensor).c_str()));
443 int32_t offset = static_cast<int32_t>(offsetInt64);
444 // Convert UINT8 offset to INT8 offset.
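  // For example, assuming a shift of 128, a UINT8 zero point of 128 becomes an
  // INT8 offset of 0, matching the data conversion done in convertUint8ToInt8.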
445 if (tfliteUint8ToInt8Opt && (tensor->type() == tflite::TensorType_UINT8)) {
446 offset -= UINT8_TO_INT8_SHIFT;
447 }
448 return offset;
449}
450
451Expected<std::vector<float>>
452TFLiteModelLoader::getTensorScales(const tflite::Tensor *tensor) {
453 auto *tensorQParams = tensor->quantization();
454 RETURN_ERR_IF_NOT(
455 isTensorQuantized(tensor),
456 strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!",
457 getTensorName(tensor).c_str()));
458 RETURN_ERR_IF_NOT(
459 tensorQParams->details_type() == tflite::QuantizationDetails_NONE,
460 strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is "
461 "not supported!",
462 getTensorName(tensor).c_str()));
463 auto *scales = tensorQParams->scale();
  RETURN_ERR_IF_NOT(scales->size() > 1,
                    strFormat("TensorFlowLite: Tensor '%s' has %d quantization "
                              "parameters but more than one was expected!",
                              getTensorName(tensor).c_str(), scales->size()));
468 std::vector<float> scalesVec =
469 std::vector<float>(scales->begin(), scales->end());
470 return scalesVec;
471}
472
473Expected<std::vector<int32_t>>
474TFLiteModelLoader::getTensorOffsets(const tflite::Tensor *tensor) {
475 auto *tensorQParams = tensor->quantization();
476 RETURN_ERR_IF_NOT(
477 isTensorQuantized(tensor),
478 strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!",
479 getTensorName(tensor).c_str()));
480 RETURN_ERR_IF_NOT(
481 tensorQParams->details_type() == tflite::QuantizationDetails_NONE,
482 strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is "
483 "not supported!",
484 getTensorName(tensor).c_str()));
485 auto *offsets = tensorQParams->zero_point();
  RETURN_ERR_IF_NOT(offsets->size() > 1,
                    strFormat("TensorFlowLite: Tensor '%s' has %d quantization "
                              "parameters but more than one was expected!",
                              getTensorName(tensor).c_str(), offsets->size()));
490 // TensorFlowLite defines the offset as int64 since it also supports int64
491 // quantized type. Since Glow defines the offset as int32 we perform a cast
492 // here and also validate that the offset is within the int32 range.
493 std::vector<int32_t> offsetsVec;
494 for (auto offsetInt64 : *offsets) {
495 RETURN_ERR_IF_NOT(
496 (std::numeric_limits<int32_t>::min() <= offsetInt64) &&
497 (offsetInt64 <= std::numeric_limits<int32_t>::max()),
498 strFormat(
499 "TensorFlowLite: Tensor '%s' has an offset out of the int32 range!",
500 getTensorName(tensor).c_str()));
501 int32_t offset = static_cast<int32_t>(offsetInt64);
502 // Convert UINT8 offset to INT8 offset.
503 if (tfliteUint8ToInt8Opt && (tensor->type() == tflite::TensorType_UINT8)) {
504 offset -= UINT8_TO_INT8_SHIFT;
505 }
506 offsetsVec.push_back(offset);
507 }
508 return offsetsVec;
509}
510
511Expected<Type> TFLiteModelLoader::getTensorType(const tflite::Tensor *tensor) {
512 ElemKind elemKind;
513 ASSIGN_VALUE_OR_RETURN_ERR(elemKind, getTensorElemKind(tensor));
514 std::vector<dim_t> shape;
515 ASSIGN_VALUE_OR_RETURN_ERR(shape, getTensorShape(tensor));
516 if (isQuantizedElemKind(elemKind)) {
517 // If tensor is quantized per-axis we use a dummy scale 1.0 and offset 0.
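    // The actual per-axis parameters are retrieved separately via
    // getTensorScales/getTensorOffsets (see isConv2DPerAxisQuantized).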
518 float scale = 1.0;
519 int32_t offset = 0;
520 if (!isTensorPerAxisQuantized(tensor)) {
521 ASSIGN_VALUE_OR_RETURN_ERR(scale, getTensorScale(tensor));
522 ASSIGN_VALUE_OR_RETURN_ERR(offset, getTensorOffset(tensor));
523 }
524 return Type(elemKind, shape, scale, offset);
525 } else {
526 return Type(elemKind, shape);
527 }
528}
529
530Expected<std::pair<const char *, size_t>>
531TFLiteModelLoader::getTensorDataAndSize(const tflite::Tensor *tensor) {
532 uint32_t tensorBufferIdx = tensor->buffer();
533 auto *modelBuffers = model_->buffers();
534 RETURN_ERR_IF_NOT(tensorBufferIdx < modelBuffers->size(),
535 strFormat("TensorFlowLite: Tensor '%s' has a buffer index "
536 "out of range!",
537 getTensorName(tensor).c_str()));
538 const char *tensorData = nullptr;
539 size_t tensorSize = 0;
540 if (auto *buffer = (*modelBuffers)[tensorBufferIdx]) {
541 if (auto *array = buffer->data()) {
542 if (array->size()) {
543 tensorData =
544 const_cast<char *>(reinterpret_cast<const char *>(array->data()));
545 tensorSize = array->size();
546 }
547 }
548 }
549 return std::pair<const char *, size_t>(tensorData, tensorSize);
550}
551
552///===---------------------------------------------------------------------===//
553/// Operator Utilities
554///===---------------------------------------------------------------------===//
555Expected<tflite::BuiltinOperator>
556TFLiteModelLoader::getOperatorCode(const tflite::Operator *op) {
557 const auto *modelOpCodes = model_->operator_codes();
558 auto opCodeIdx = op->opcode_index();
559 RETURN_ERR_IF_NOT(opCodeIdx < modelOpCodes->size(),
560 strFormat("TensorFlowLite: Missing registration for "
561 "opcode_index %d!",
562 opCodeIdx));
563 auto *opCode = (*modelOpCodes)[opCodeIdx];
564 auto builtinCode = opCode->builtin_code();
565 RETURN_ERR_IF_NOT(
566 (tflite::BuiltinOperator_MIN <= builtinCode) &&
567 (builtinCode <= tflite::BuiltinOperator_MAX),
568 strFormat(
569 "TensorFlowLite: Operator builtin_code %d out of the supported "
570 "range! You might be using a newer model than currently supported!",
571 builtinCode));
572 return builtinCode;
573}
574
575Expected<std::string>
576TFLiteModelLoader::getOperatorCustomCode(const tflite::Operator *op) {
577 const auto *modelOpCodes = model_->operator_codes();
578 auto opCodeIdx = op->opcode_index();
579 RETURN_ERR_IF_NOT(opCodeIdx < modelOpCodes->size(),
580 strFormat("TensorFlowLite: Missing registration for "
581 "opcode_index %d!",
582 opCodeIdx));
583 auto *opCode = (*modelOpCodes)[opCodeIdx];
584 auto customCode = opCode->custom_code();
585 RETURN_ERR_IF_NOT(customCode,
586 strFormat("TensorFlowLite: Missing custom code for "
587 "opcode_index %d!",
588 opCodeIdx));
589 return customCode->str();
590}
591
592Expected<int32_t>
593TFLiteModelLoader::getOperatorVersion(const tflite::Operator *op) {
594 const auto *modelOpCodes = model_->operator_codes();
595 auto opCodeIdx = op->opcode_index();
596 RETURN_ERR_IF_NOT(opCodeIdx < modelOpCodes->size(),
597 strFormat("TensorFlowLite: Missing registration for "
598 "opcode_index %d!",
599 opCodeIdx));
600 auto *opCode = (*modelOpCodes)[opCodeIdx];
601 return opCode->version();
602}
603
604Expected<std::string>
605TFLiteModelLoader::getOperatorType(const tflite::Operator *op) {
606 tflite::BuiltinOperator opCode;
607 ASSIGN_VALUE_OR_RETURN_ERR(opCode, getOperatorCode(op));
608 return std::string(tflite::EnumNameBuiltinOperator(opCode));
609}
610
611Expected<std::string>
612TFLiteModelLoader::getOperatorName(const tflite::Operator *op) {
613 std::string opType;
614 ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op));
615 const auto *opOutputs = op->outputs();
616 // If operator has no outputs then we return the operator type name.
617 if (opOutputs->size() == 0) {
618 return opType;
619 }
620 // If the first output tensor corresponds to an output placeholder then we use
621 // the operator type name in order to preserve the output placeholder name.
622 size_t opOutIdx = static_cast<size_t>((*opOutputs)[0]);
623 const auto *graphOutputs = graph_->outputs();
624 for (auto graphOutIdx : (*graphOutputs)) {
625 if (int(opOutIdx) == graphOutIdx) {
626 return opType;
627 }
628 }
629 // Return the name of the first output tensor.
630 const tflite::Tensor *tensor;
631 ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(opOutIdx));
632 return getTensorName(tensor);
633}
634
635Expected<flexbuffers::Map>
636TFLiteModelLoader::getOperatorCustomOpts(const tflite::Operator *op) {
  auto *customOpts = op->custom_options();
639 RETURN_ERR_IF_NOT(customOpts,
640 strFormat("TensorFlowLite: Missing custom options for "
641 "opcode_index %d!",
642 op->opcode_index()));
643 const uint8_t *optsAddr =
644 reinterpret_cast<const uint8_t *>(customOpts->data());
645 RETURN_ERR_IF_NOT(optsAddr,
646 strFormat("TensorFlowLite: Missing custom options for "
647 "opcode_index %d!",
648 op->opcode_index()));
  return flexbuffers::GetRoot(optsAddr, customOpts->size()).AsMap();
650}
651
652Expected<int32_t>
653TFLiteModelLoader::getOperatorInputTensorIdx(const tflite::Operator *op,
654 size_t inputIdx) {
655 std::string opType;
656 ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op));
657 const auto *opInputs = op->inputs();
658 RETURN_ERR_IF_NOT(opInputs,
659 strFormat("TensorFlowLite: Operator '%s' has no inputs!",
660 opType.c_str()));
661 RETURN_ERR_IF_NOT(inputIdx < opInputs->size(),
662 strFormat("TensorFlowLite: Operator '%s' input index %zu "
663 "is out of range! Operator has %d inputs!",
664 opType.c_str(), inputIdx, opInputs->size()));
665 return (*opInputs)[inputIdx];
666}
667
668Expected<size_t>
669TFLiteModelLoader::getOperatorOutputTensorIdx(const tflite::Operator *op,
670 size_t outputIdx) {
671 std::string opType;
672 ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op));
673 const auto *opOutputs = op->outputs();
674 RETURN_ERR_IF_NOT(opOutputs,
675 strFormat("TensorFlowLite: Operator '%s' has no outputs!",
676 opType.c_str()));
677 RETURN_ERR_IF_NOT(outputIdx < opOutputs->size(),
678 strFormat("TensorFlowLite: Operator '%s' output index %zu "
679 "is out of range! Operator has %d outputs!",
680 opType.c_str(), outputIdx, opOutputs->size()));
681 return static_cast<size_t>((*opOutputs)[outputIdx]);
682}
683
684Expected<bool>
685TFLiteModelLoader::isOperatorOutputFinalTensor(const tflite::Operator *op,
686 size_t outputIdx) {
687 size_t tensorIdx;
688 ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx,
689 getOperatorOutputTensorIdx(op, outputIdx));
690 const auto *graphOutputs = graph_->outputs();
691 for (auto graphOutIdx : (*graphOutputs)) {
692 if (int(tensorIdx) == graphOutIdx) {
693 return true;
694 }
695 }
696 return false;
697}
698
699Expected<NodeValue> TFLiteModelLoader::getNodeValueByIndex(size_t index) {
700 RETURN_ERR_IF_NOT(!nodeValueByIndex_.empty(),
701 "TensorFlowLite: Node value array not initialized!");
702 RETURN_ERR_IF_NOT(
703 index < nodeValueByIndex_.size(),
704 strFormat("TensorFlowLite: Node value index %zu is out of range!",
705 index));
706 NodeValue nodeValue = nodeValueByIndex_[index];
707 RETURN_ERR_IF_NOT(nodeValue.getNode(),
708 strFormat("TensorFlowLite: Node value with index %zu is "
709 "null (not initialized)!",
710 index));
711 return nodeValue;
712}
713
714Error TFLiteModelLoader::setNodeValueByIndex(size_t index,
715 NodeValue nodeValue) {
716 RETURN_ERR_IF_NOT(!nodeValueByIndex_.empty(),
717 "TensorFlowLite: Node value array not initialized!");
718 RETURN_ERR_IF_NOT(
719 index < nodeValueByIndex_.size(),
720 strFormat("TensorFlowLite: Node value index %zu is out of range!",
721 index));
722 nodeValueByIndex_[index] = nodeValue;
723 return Error::success();
724}
725
726Expected<NodeValue>
727TFLiteModelLoader::getInputNodeValue(const tflite::Operator *op,
728 size_t inputIdx) {
729 int32_t tensorIdx;
730 ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx,
731 getOperatorInputTensorIdx(op, inputIdx));
732 // If the tensor index is negative this means the operand does not exist.
733 if (tensorIdx < 0) {
734 return NodeValue(nullptr);
735 }
736 return getNodeValueByIndex(tensorIdx);
737}
738
739Error TFLiteModelLoader::setOutputNodeValue(const tflite::Operator *op,
740 NodeValue nodeValue,
741 bool checkType) {
742 std::vector<NodeValue> nodeValues = {nodeValue};
743 return setOutputNodeValues(op, nodeValues, checkType);
744}
745
746Error TFLiteModelLoader::setOutputNodeValues(
747 const tflite::Operator *op, llvm::ArrayRef<NodeValue> nodeValues,
748 bool checkType) {
749 std::string opType;
750 ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op));
751 const auto *opOutputs = op->outputs();
752 RETURN_ERR_IF_NOT(
753 opOutputs->size() == nodeValues.size(),
754 strFormat("TensorFlowLite: Operator '%s' has %d outputs but %zu are set!",
755 opType.c_str(), opOutputs->size(), nodeValues.size()));
756 for (size_t idx = 0, idxEnd = nodeValues.size(); idx < idxEnd; ++idx) {
757 NodeValue outNodeValue = nodeValues[idx];
758 // Verify the output type of the node value matches the type registered in
759 // the model with the exception of the final tensors which are allowed to
760 // be modified (for example for the Softmax output when it is a final node).
761 // Tensors with undefined shapes are also not checked.
762 if (checkType) {
763 TypeRef outTy;
764 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, idx));
765 bool isUndefined;
766 ASSIGN_VALUE_OR_RETURN_ERR(isUndefined, isOutputShapeUndefined(op, idx));
767 bool isFinal;
768 ASSIGN_VALUE_OR_RETURN_ERR(isFinal, isOperatorOutputFinalTensor(op, idx));
769 RETURN_ERR_IF_NOT(isUndefined || isFinal ||
770 outTy->isEqual(outNodeValue.getType()),
771 strFormat("TensorFlowLite: Operator '%s' modifies the "
772 "output type registered in the model!",
773 opType.c_str()));
774 }
775 // Register the output node value.
776 size_t tensorIdx = static_cast<size_t>((*opOutputs)[idx]);
777 RETURN_IF_ERR(setNodeValueByIndex(tensorIdx, outNodeValue));
778 }
779 return Error::success();
780}
781
782Expected<TypeRef> TFLiteModelLoader::getOutputType(const tflite::Operator *op,
783 size_t outputIndex) {
784 size_t tensorIdx;
785 ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx,
786 getOperatorOutputTensorIdx(op, outputIndex));
787 const tflite::Tensor *tensor;
788 ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(tensorIdx));
789 Type type;
790 ASSIGN_VALUE_OR_RETURN_ERR(type, getTensorType(tensor));
791 return mod_.uniqueType(type);
792}
793
794Expected<bool>
795TFLiteModelLoader::isOutputShapeUndefined(const tflite::Operator *op,
796 size_t outputIndex) {
797 size_t tensorIdx;
798 ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx,
799 getOperatorOutputTensorIdx(op, outputIndex));
800 const tflite::Tensor *tensor;
801 ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(tensorIdx));
802 bool undefined;
803 ASSIGN_VALUE_OR_RETURN_ERR(undefined, isTensorShapeUndefined(tensor));
804 return undefined;
805}
806
807void TFLiteModelLoader::initializeNodeValues() {
808 auto numTensors = graph_->tensors()->size();
809 nodeValueByIndex_ = std::vector<NodeValue>(numTensors, nullptr);
810}
811
812Error TFLiteModelLoader::loadInputPlaceholders() {
813 for (auto inpIdx : *(graph_->inputs())) {
814 // Get input placeholder name and type.
815 const tflite::Tensor *tensor;
816 ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(inpIdx));
817 std::string name = getTensorName(tensor);
818 Type type;
819 ASSIGN_VALUE_OR_RETURN_ERR(type, getTensorType(tensor));
820 // Create input placeholder. If the input type is quantized and a float
821 // input is requested then we create a float placeholder and a Quantize
    // node, otherwise we directly create the quantized placeholder.
823 Placeholder *inpPH;
824 NodeValue inpNV;
825 if (tfliteFloatInputsOpt && type.isQuantizedType()) {
826 TypeRef floatType = mod_.uniqueType(ElemKind::FloatTy, type.dims());
827 inpPH = mod_.createPlaceholder(floatType, name, /*isTrainable*/ false,
828 ANY_LAYOUT);
829 inpNV = F_->createQuantize(name + ".Quantize", inpPH, &type);
830 } else {
831 inpPH = mod_.createPlaceholder(&type, name, /*isTrainable*/ false,
832 ANY_LAYOUT);
833 inpNV = inpPH;
834 }
835 // Register placeholder by model input name.
836 inputPlaceholderByName_.try_emplace(name, inpPH);
837 // Set input node value.
838 RETURN_IF_ERR(setNodeValueByIndex(inpIdx, inpNV));
839 }
840 return Error::success();
841}
842
843Error TFLiteModelLoader::loadConstants() {
844 const auto *tensors = graph_->tensors();
845 for (size_t idx = 0, idxEnd = tensors->size(); idx < idxEnd; ++idx) {
846 // Get tensor data and size. A TensorFlowLite model tensor is a constant
847 // if it has data stored in the model.
848 const tflite::Tensor *tensor = (*tensors)[idx];
849 std::pair<const char *, size_t> dataAndSize;
850 ASSIGN_VALUE_OR_RETURN_ERR(dataAndSize, getTensorDataAndSize(tensor));
851 if (dataAndSize.first == nullptr) {
852 continue;
853 }
854 // Create tensor and initialize data.
855 std::string name = getTensorName(tensor);
856 Type type;
857 ASSIGN_VALUE_OR_RETURN_ERR(type, getTensorType(tensor));
858 RETURN_ERR_IF_NOT(
859 type.getSizeInBytes() == dataAndSize.second,
860 strFormat("TensorFlowLite: Tensor '%s' mismatch between shape based "
861 "size (%zu bytes) and actual data size (%lu bytes)!",
862 name.c_str(), type.getSizeInBytes(), dataAndSize.second));
863 Tensor T = Tensor(type);
864 T.copyRawFrom(dataAndSize.first);
865 // Convert UINT8 data to INT8 data.
866 if (tfliteUint8ToInt8Opt && (tensor->type() == tflite::TensorType_UINT8)) {
867 convertUint8ToInt8(reinterpret_cast<uint8_t *>(T.getUnsafePtr()),
868 reinterpret_cast<int8_t *>(T.getUnsafePtr()),
869 dataAndSize.second);
870 }
871 // Create constant.
872 Constant *node = mod_.createConstant(name, std::move(T), ANY_LAYOUT);
873 // Register node value.
874 RETURN_IF_ERR(setNodeValueByIndex(idx, node->getOutput()));
875 }
876 return Error::success();
877}
878
879Error TFLiteModelLoader::loadOperators() {
880 OperatorInfo opInfo;
881 auto *graphOperators = graph_->operators();
882 for (size_t opIdx = 0, opIdxEnd = graphOperators->size(); opIdx < opIdxEnd;
883 ++opIdx) {
884 // Get operator meta data.
885 const tflite::Operator *op = (*graphOperators)[opIdx];
886 ASSIGN_VALUE_OR_RETURN_ERR(opInfo.name, getOperatorName(op));
887 ASSIGN_VALUE_OR_RETURN_ERR(opInfo.type, getOperatorType(op));
888 ASSIGN_VALUE_OR_RETURN_ERR(opInfo.code, getOperatorCode(op));
889 ASSIGN_VALUE_OR_RETURN_ERR(opInfo.version, getOperatorVersion(op));
890 opInfo.index = opIdx;
891 // Load operator.
892 mod_.registerOriginalName(opInfo.name);
893 RETURN_IF_ERR(loadOperator(op, opInfo));
894 }
895 return Error::success();
896}
897
898Error TFLiteModelLoader::saveOutputPlaceholders() {
899 for (auto outIdx : *(graph_->outputs())) {
900 // Get placeholder name.
901 const tflite::Tensor *tensor;
902 ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(outIdx));
903 std::string name = getTensorName(tensor);
    // Save output placeholder. If the output type is quantized and a float
    // output is requested then we create a Dequantize node and save its result
    // into a float placeholder, otherwise we save the quantized placeholder.
907 NodeValue outNodeValue;
908 ASSIGN_VALUE_OR_RETURN_ERR(outNodeValue, getNodeValueByIndex(outIdx));
909 if (tfliteFloatOutputsOpt && outNodeValue.getType()->isQuantizedType()) {
910 outNodeValue = F_->createDequantize(name + ".Dequantize", outNodeValue,
911 ElemKind::FloatTy);
912 }
913 auto *saveNode = F_->createSave(name, outNodeValue);
914 // Register placeholder by model output name.
915 outputPlaceholderByName_.try_emplace(name, saveNode->getPlaceholder());
916 }
917 return Error::success();
918}
919
920Error TFLiteModelLoader::addActivation(NodeValue &value,
921 tflite::ActivationFunctionType type) {
922 std::string nodeName = value.getNode()->getName().str();
923 std::string actType = EnumNameActivationFunctionType(type);
924 std::string actName = nodeName + "." + actType;
925 if (type == tflite::ActivationFunctionType_NONE) {
926 return Error::success();
927 }
928 if (type == tflite::ActivationFunctionType_RELU) {
929 value = F_->createRELU(actName, value);
930 return Error::success();
931 }
932 if (type == tflite::ActivationFunctionType_RELU_N1_TO_1) {
933 value = F_->createClip(actName, value, -1.0, 1.0);
934 return Error::success();
935 }
936 if (type == tflite::ActivationFunctionType_RELU6) {
937 value = F_->createClip(actName, value, 0.0, 6.0);
938 return Error::success();
939 }
940 if (type == tflite::ActivationFunctionType_TANH) {
941 value = F_->createTanh(actName, value);
942 return Error::success();
943 }
944 return MAKE_ERR(
945 strFormat("TensorFlowLite: Activation type '%s' is not supported!",
946 actType.c_str()));
947}
948
949const std::string TFLiteModelLoader::opErrMsg(const OperatorInfo &opInfo,
950 const std::string &errMsg) {
951 return strFormat("TensorFlowLite: Operator '%s' (Index %zu, Code %u): %s",
952 opInfo.type.c_str(), opInfo.index, opInfo.code,
953 errMsg.c_str());
954}
955
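// For example, for an input value with 4 dimensions, a constant axis of -1 is
// expected to be converted to the positive axis 3, assuming getPositiveAxis
// maps negative axes by adding the input rank.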
956template <typename T>
957Expected<T> TFLiteModelLoader::loadAxis(const OperatorInfo &opInfo,
958 NodeValue axis, NodeValue value) {
959 auto *axisC = llvm::dyn_cast<Constant>(axis.getNode());
960 RETURN_ERR_IF_NOT(axisC,
961 opErrMsg(opInfo, "Non constant axis not supported!"));
962 RETURN_ERR_IF_NOT(axisC->getType()->size() == 1,
963 opErrMsg(opInfo, "Axis should have 1 element!"));
964 T axisVal;
965 auto elemType = axisC->getType()->getElementType();
966 if (elemType == ElemKind::Int32ITy) {
967 auto axisH = axisC->getPayload().getHandle<int32_t>();
968 ASSIGN_VALUE_OR_RETURN_ERR(
969 axisVal, getPositiveAxis<T>(static_cast<int>(axisH.raw(0)), value));
970 } else if (elemType == ElemKind::Int64ITy) {
971 auto axisH = axisC->getPayload().getHandle<int64_t>();
972 ASSIGN_VALUE_OR_RETURN_ERR(
973 axisVal, getPositiveAxis<T>(static_cast<int>(axisH.raw(0)), value));
974 } else {
975 return MAKE_ERR(opErrMsg(opInfo, "Axis should have INT32 or INT64 type!"));
976 }
977 return axisVal;
978}
979
980template <typename T>
981Expected<std::vector<T>> TFLiteModelLoader::loadAxes(const OperatorInfo &opInfo,
982 NodeValue axes,
983 NodeValue value) {
984 Constant *axesC = llvm::dyn_cast<Constant>(axes.getNode());
985 RETURN_ERR_IF_NOT(axesC,
986 opErrMsg(opInfo, "Non constant axis not supported!"));
987 RETURN_ERR_IF_NOT(axesC->getType()->size() >= 1,
988 opErrMsg(opInfo, "Axis should have at least 1 element!"));
989 std::vector<T> axesVal = std::vector<T>(axesC->getType()->size());
990 auto elemType = axesC->getType()->getElementType();
991 for (size_t idx = 0; idx < axesC->getType()->size(); ++idx) {
992 if (elemType == ElemKind::Int32ITy) {
993 auto axesH = axesC->getPayload().getHandle<int32_t>();
994 ASSIGN_VALUE_OR_RETURN_ERR(
995 axesVal[idx],
996 getPositiveAxis<T>(static_cast<int>(axesH.raw(idx)), value));
997 } else if (elemType == ElemKind::Int64ITy) {
998 auto axesH = axesC->getPayload().getHandle<int64_t>();
999 ASSIGN_VALUE_OR_RETURN_ERR(
1000 axesVal[idx],
1001 getPositiveAxis<T>(static_cast<int>(axesH.raw(idx)), value));
1002 } else {
1003 return MAKE_ERR(
1004 opErrMsg(opInfo, "Axis should have INT32 or INT64 type!"));
1005 }
1006 }
1007 return axesVal;
1008}
1009
1010template <typename T>
1011Expected<std::vector<T>>
1012TFLiteModelLoader::loadArray(const OperatorInfo &opInfo, NodeValue value) {
1013 Constant *valueC = llvm::dyn_cast<Constant>(value.getNode());
1014 RETURN_ERR_IF_NOT(valueC,
1015 opErrMsg(opInfo, "Non constant array not supported!"));
1016 auto valueSize = valueC->getType()->size();
1017 RETURN_ERR_IF_NOT(valueSize >= 1,
1018 opErrMsg(opInfo, "Array should have at least 1 element!"));
1019 std::vector<T> valueV = std::vector<T>(valueSize);
1020 auto elemType = valueC->getType()->getElementType();
1021 for (size_t idx = 0; idx < valueSize; ++idx) {
1022 if (elemType == ElemKind::FloatTy) {
1023 auto valueH = valueC->getPayload().getHandle<float>();
1024 valueV[idx] = static_cast<T>(valueH.raw(idx));
1025 } else if (elemType == ElemKind::Int32ITy) {
1026 auto valueH = valueC->getPayload().getHandle<int32_t>();
1027 valueV[idx] = static_cast<T>(valueH.raw(idx));
1028 } else if (elemType == ElemKind::Int64ITy) {
1029 auto valueH = valueC->getPayload().getHandle<int64_t>();
1030 valueV[idx] = static_cast<T>(valueH.raw(idx));
1031 } else {
1032 return MAKE_ERR(opErrMsg(opInfo, "Array type not supported!"));
1033 }
1034 }
1035 return valueV;
1036}
1037
1038Expected<bool> TFLiteModelLoader::isConv2DPerAxisQuantized(
1039 const tflite::Operator *op, const OperatorInfo &opInfo,
1040 Constant *&filterScalesC, Constant *&filterOffsetsC, Constant *&biasScalesC,
1041 Constant *&biasOffsetsC) {
1042 // Get filter/bias tensors.
1043 NodeValue input;
1044 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1045 TypeRef outTy;
1046 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1047 int32_t filterTensorIdx;
1048 ASSIGN_VALUE_OR_RETURN_ERR(filterTensorIdx, getOperatorInputTensorIdx(op, 1));
1049 int32_t biasTensorIdx;
1050 ASSIGN_VALUE_OR_RETURN_ERR(biasTensorIdx, getOperatorInputTensorIdx(op, 2));
1051 const tflite::Tensor *filterTensor;
1052 ASSIGN_VALUE_OR_RETURN_ERR(filterTensor, getTensorByIndex(filterTensorIdx));
1053 const tflite::Tensor *biasTensor;
1054 ASSIGN_VALUE_OR_RETURN_ERR(biasTensor, getTensorByIndex(biasTensorIdx));
1055
1056 bool isPerAxisQuantized = isTensorPerAxisQuantized(filterTensor) &&
1057 isTensorPerAxisQuantized(biasTensor);
1058
  // If it is not per-axis quantized, return early.
1060 if (!isPerAxisQuantized) {
1061 filterScalesC = nullptr;
1062 filterOffsetsC = nullptr;
1063 biasScalesC = nullptr;
1064 biasOffsetsC = nullptr;
1065 return false;
1066 }
1067
1068 dim_t numChannels = outTy->dims().back();
1069
1070 // Get filter/bias quantization parameters.
1071 std::vector<float> filterScalesV;
1072 ASSIGN_VALUE_OR_RETURN_ERR(filterScalesV, getTensorScales(filterTensor));
1073 std::vector<int32_t> filterOffsetsV;
1074 ASSIGN_VALUE_OR_RETURN_ERR(filterOffsetsV, getTensorOffsets(filterTensor));
1075 std::vector<float> biasScalesV;
1076 ASSIGN_VALUE_OR_RETURN_ERR(biasScalesV, getTensorScales(biasTensor));
1077 std::vector<int32_t> biasOffsetsV;
1078 ASSIGN_VALUE_OR_RETURN_ERR(biasOffsetsV, getTensorOffsets(biasTensor));
1079
1080 // Create filter/bias quantization parameters graph constants.
1081 filterScalesC =
1082 mod_.createConstant(ElemKind::FloatTy, {numChannels}, "filterScales");
1083 filterOffsetsC =
1084 mod_.createConstant(ElemKind::Int32ITy, {numChannels}, "filterOffsets");
1085 biasScalesC =
1086 mod_.createConstant(ElemKind::FloatTy, {numChannels}, "biasScales");
1087 biasOffsetsC =
1088 mod_.createConstant(ElemKind::Int32ITy, {numChannels}, "biasOffsets");
1089
1090 RETURN_ERR_IF_NOT(
1091 filterScalesV.size() == numChannels,
1092 opErrMsg(opInfo,
1093 "Weights scales length should match the output channels!"));
1094 RETURN_ERR_IF_NOT(
1095 filterOffsetsV.size() == numChannels,
1096 opErrMsg(opInfo,
1097 "Weights offsets length should match the output channels!"));
1098 RETURN_ERR_IF_NOT(
1099 biasScalesV.size() == numChannels,
1100 opErrMsg(opInfo, "Bias scales length should match the output channels!"));
1101 RETURN_ERR_IF_NOT(
1102 biasOffsetsV.size() == numChannels,
1103 opErrMsg(opInfo,
1104 "Bias offsets length should match the output channels!"));
1105
1106 filterScalesC->getPayloadMutable().copyRawFrom(
1107 reinterpret_cast<const char *>(filterScalesV.data()));
1108 filterOffsetsC->getPayloadMutable().copyRawFrom(
1109 reinterpret_cast<const char *>(filterOffsetsV.data()));
1110 biasScalesC->getPayloadMutable().copyRawFrom(
1111 reinterpret_cast<const char *>(biasScalesV.data()));
1112 biasOffsetsC->getPayloadMutable().copyRawFrom(
1113 reinterpret_cast<const char *>(biasOffsetsV.data()));
1114
1115 // Validate filter/bias quantization parameters.
1116 float inputScale = input.getType()->getScale();
1117 auto filterScalesH = filterScalesC->getPayload().getHandle<float>();
1118 auto filterOffsetsH = filterOffsetsC->getPayload().getHandle<int32_t>();
1119 auto biasScalesH = biasScalesC->getPayload().getHandle<float>();
1120 auto biasOffsetsH = biasOffsetsC->getPayload().getHandle<int32_t>();
1121 for (size_t idx = 0; idx < numChannels; ++idx) {
1122 // TensorFlowLite mandates that filterOffset and biasOffset are 0.
1123 RETURN_ERR_IF_NOT(filterOffsetsH.raw(idx) == 0,
1124 opErrMsg(opInfo, "Filter offset was expected to be 0!"));
1125 RETURN_ERR_IF_NOT(biasOffsetsH.raw(idx) == 0,
1126 opErrMsg(opInfo, "Bias offset was expected to be 0!"));
1127
1128 float filterScale = filterScalesH.raw(idx);
1129 float matMulScale = inputScale * filterScale;
1130 float biasScale = biasScalesH.raw(idx);
1131
1132 // Check bias scale relative error to inputScale * filterScale.
1133 if (biasScale != matMulScale) {
1134 float relErr = std::abs(matMulScale - biasScale) / matMulScale;
1135 llvm::errs() << opErrMsg(
1136 opInfo,
1137 strFormat("WARNING: Per channel BIAS scale value was expected "
1138 "to be exactly %E (inputScale * weightsScale) but found "
1139 "%E instead! Relative absolute error is %E!\n",
1140 matMulScale, biasScale, relErr));
1141 if (relErr < tfliteBiasScaleCheckMaxErrorOpt) {
1142 // Modify bias scale.
1143 biasScalesH.raw(idx) = matMulScale;
1144 } else if (tfliteBiasScaleCheckThrowErrorOpt) {
1145 return MAKE_ERR(opErrMsg(
1146 opInfo,
1147 strFormat("ERROR: Per channel BIAS scale value was expected "
1148 "to be exactly %E (inputScale * weightsScale) but found "
1149 "%E instead! Relative absolute error is %E!\n",
1150 matMulScale, biasScale, relErr)));
1151 }
1152 }
1153 }
1154
1155 return true;
1156}
1157
1158Error TFLiteModelLoader::loadOperator(const tflite::Operator *op,
1159 const OperatorInfo &opInfo) {
  // Opcodes are treated in increasing order to allow easy tracking of which
  // operators are supported and which are not.
1162 auto opCode = opInfo.code;
1163 if (opCode == tflite::BuiltinOperator_ADD) {
1164 return loadBinaryArithmetic(op, opInfo);
1165 }
1166 if (opCode == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
1167 return loadPool2D(op, opInfo);
1168 }
1169 if (opCode == tflite::BuiltinOperator_CONCATENATION) {
1170 return loadConcat(op, opInfo);
1171 }
1172 if (opCode == tflite::BuiltinOperator_CONV_2D) {
1173 return loadConv2D(op, opInfo);
1174 }
1175 if (opCode == tflite::BuiltinOperator_DEPTHWISE_CONV_2D) {
1176 return loadDepthwiseConv2D(op, opInfo);
1177 }
1178 if (opCode == tflite::BuiltinOperator_DEQUANTIZE) {
1179 return loadUnaryArithmetic(op, opInfo);
1180 }
1181 if (opCode == tflite::BuiltinOperator_HARD_SWISH) {
1182 return loadUnaryArithmetic(op, opInfo);
1183 }
1184 if (opCode == tflite::BuiltinOperator_FLOOR) {
1185 return loadUnaryArithmetic(op, opInfo);
1186 }
1187 if (opCode == tflite::BuiltinOperator_FULLY_CONNECTED) {
1188 return loadFullyConnected(op, opInfo);
1189 }
1190 if (opCode == tflite::BuiltinOperator_LOGISTIC) {
1191 return loadUnaryArithmetic(op, opInfo);
1192 }
1193 if (opCode == tflite::BuiltinOperator_MAX_POOL_2D) {
1194 return loadPool2D(op, opInfo);
1195 }
1196 if (opCode == tflite::BuiltinOperator_MUL) {
1197 return loadBinaryArithmetic(op, opInfo);
1198 }
1199 if (opCode == tflite::BuiltinOperator_RELU) {
1200 return loadUnaryArithmetic(op, opInfo);
1201 }
1202 if (opCode == tflite::BuiltinOperator_RELU_N1_TO_1) {
1203 return loadUnaryArithmetic(op, opInfo);
1204 }
1205 if (opCode == tflite::BuiltinOperator_RELU6) {
1206 return loadUnaryArithmetic(op, opInfo);
1207 }
1208 if (opCode == tflite::BuiltinOperator_RESHAPE) {
1209 return loadReshape(op, opInfo);
1210 }
1211 if (opCode == tflite::BuiltinOperator_SOFTMAX) {
1212 return loadSoftmax(op, opInfo);
1213 }
1214 if (opCode == tflite::BuiltinOperator_LOG_SOFTMAX) {
1215 return loadLogSoftmax(op, opInfo);
1216 }
1217 if (opCode == tflite::BuiltinOperator_TANH) {
1218 return loadUnaryArithmetic(op, opInfo);
1219 }
1220 if (opCode == tflite::BuiltinOperator_PAD) {
1221 return loadPad(op, opInfo);
1222 }
1223 if (opCode == tflite::BuiltinOperator_TRANSPOSE) {
1224 return loadTranspose(op, opInfo);
1225 }
1226 if (opCode == tflite::BuiltinOperator_MEAN) {
1227 return loadReduce(op, opInfo);
1228 }
1229 if (opCode == tflite::BuiltinOperator_SUB) {
1230 return loadBinaryArithmetic(op, opInfo);
1231 }
1232 if (opCode == tflite::BuiltinOperator_DIV) {
1233 return loadBinaryArithmetic(op, opInfo);
1234 }
1235 if (opCode == tflite::BuiltinOperator_SQUEEZE) {
1236 return loadReshape(op, opInfo);
1237 }
1238 if (opCode == tflite::BuiltinOperator_STRIDED_SLICE) {
1239 return loadStridedSlice(op, opInfo);
1240 }
1241 if (opCode == tflite::BuiltinOperator_EXP) {
1242 return loadUnaryArithmetic(op, opInfo);
1243 }
1244 if (opCode == tflite::BuiltinOperator_SPLIT) {
1245 return loadSplit(op, opInfo);
1246 }
1247 if (opCode == tflite::BuiltinOperator_PRELU) {
1248 return loadBinaryArithmetic(op, opInfo);
1249 }
1250 if (opCode == tflite::BuiltinOperator_MAXIMUM) {
1251 return loadBinaryArithmetic(op, opInfo);
1252 }
1253 if (opCode == tflite::BuiltinOperator_ARG_MAX) {
1254 return loadArg(op, opInfo);
1255 }
1256 if (opCode == tflite::BuiltinOperator_MINIMUM) {
1257 return loadBinaryArithmetic(op, opInfo);
1258 }
1259 if (opCode == tflite::BuiltinOperator_LESS) {
1260 return loadBinaryArithmetic(op, opInfo);
1261 }
1262 if (opCode == tflite::BuiltinOperator_NEG) {
1263 return loadUnaryArithmetic(op, opInfo);
1264 }
1265 if (opCode == tflite::BuiltinOperator_GREATER) {
1266 return loadBinaryArithmetic(op, opInfo);
1267 }
1268 if (opCode == tflite::BuiltinOperator_GREATER_EQUAL) {
1269 return loadBinaryArithmetic(op, opInfo);
1270 }
1271 if (opCode == tflite::BuiltinOperator_LESS_EQUAL) {
1272 return loadBinaryArithmetic(op, opInfo);
1273 }
1274 if (opCode == tflite::BuiltinOperator_SLICE) {
1275 return loadSlice(op, opInfo);
1276 }
1277 if (opCode == tflite::BuiltinOperator_RESIZE_BILINEAR) {
1278 return loadResizeBilinear(op, opInfo);
1279 }
1280 if (opCode == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) {
1281 return loadResizeNearest(op, opInfo);
1282 }
1283 if (opCode == tflite::BuiltinOperator_SPACE_TO_DEPTH) {
1284 return loadSpaceToDepth(op, opInfo);
1285 }
1286 if (opCode == tflite::BuiltinOperator_DEPTH_TO_SPACE) {
1287 return loadDepthToSpace(op, opInfo);
1288 }
1289 if (opCode == tflite::BuiltinOperator_CAST) {
1290 return loadCast(op, opInfo);
1291 }
1292 if (opCode == tflite::BuiltinOperator_GATHER) {
1293 return loadGather(op, opInfo);
1294 }
1295 if (opCode == tflite::BuiltinOperator_GATHER_ND) {
1296 return loadGatherND(op, opInfo);
1297 }
1298 if (opCode == tflite::BuiltinOperator_SELECT) {
1299 return loadSelect(op, opInfo);
1300 }
1301 if (opCode == tflite::BuiltinOperator_SPACE_TO_BATCH_ND) {
1302 return loadSpaceToBatchNd(op, opInfo);
1303 }
1304 if (opCode == tflite::BuiltinOperator_BATCH_TO_SPACE_ND) {
1305 return loadBatchToSpaceNd(op, opInfo);
1306 }
1307 if (opCode == tflite::BuiltinOperator_SIN) {
1308 return loadUnaryArithmetic(op, opInfo);
1309 }
1310 if (opCode == tflite::BuiltinOperator_TILE) {
1311 return loadTile(op, opInfo);
1312 }
1313 if (opCode == tflite::BuiltinOperator_EXPAND_DIMS) {
1314 return loadReshape(op, opInfo);
1315 }
1316 if (opCode == tflite::BuiltinOperator_EQUAL) {
1317 return loadBinaryArithmetic(op, opInfo);
1318 }
1319 if (opCode == tflite::BuiltinOperator_NOT_EQUAL) {
1320 return loadBinaryArithmetic(op, opInfo);
1321 }
1322 if (opCode == tflite::BuiltinOperator_LOG) {
1323 return loadUnaryArithmetic(op, opInfo);
1324 }
1325 if (opCode == tflite::BuiltinOperator_SQRT) {
1326 return loadUnaryArithmetic(op, opInfo);
1327 }
1328 if (opCode == tflite::BuiltinOperator_RSQRT) {
1329 return loadUnaryArithmetic(op, opInfo);
1330 }
1331 if (opCode == tflite::BuiltinOperator_SHAPE) {
1332 return loadShape(op, opInfo);
1333 }
1334 if (opCode == tflite::BuiltinOperator_POW) {
1335 return loadBinaryArithmetic(op, opInfo);
1336 }
1337 if (opCode == tflite::BuiltinOperator_ARG_MIN) {
1338 return loadArg(op, opInfo);
1339 }
1340 if (opCode == tflite::BuiltinOperator_PACK) {
1341 return loadPack(op, opInfo);
1342 }
1343 if (opCode == tflite::BuiltinOperator_LOGICAL_OR) {
1344 return loadBinaryArithmetic(op, opInfo);
1345 }
1346 if (opCode == tflite::BuiltinOperator_LOGICAL_AND) {
1347 return loadBinaryArithmetic(op, opInfo);
1348 }
1349 if (opCode == tflite::BuiltinOperator_LOGICAL_NOT) {
1350 return loadUnaryArithmetic(op, opInfo);
1351 }
1352 if (opCode == tflite::BuiltinOperator_UNPACK) {
1353 return loadUnpack(op, opInfo);
1354 }
1355 if (opCode == tflite::BuiltinOperator_SQUARE) {
1356 return loadUnaryArithmetic(op, opInfo);
1357 }
1358 if (opCode == tflite::BuiltinOperator_LEAKY_RELU) {
1359 return loadUnaryArithmetic(op, opInfo);
1360 }
1361 if (opCode == tflite::BuiltinOperator_ABS) {
1362 return loadUnaryArithmetic(op, opInfo);
1363 }
1364 if (opCode == tflite::BuiltinOperator_CEIL) {
1365 return loadUnaryArithmetic(op, opInfo);
1366 }
1367 if (opCode == tflite::BuiltinOperator_COS) {
1368 return loadUnaryArithmetic(op, opInfo);
1369 }
1370 if (opCode == tflite::BuiltinOperator_QUANTIZE) {
1371 return loadUnaryArithmetic(op, opInfo);
1372 }
1373 if (opCode == tflite::BuiltinOperator_ROUND) {
1374 return loadUnaryArithmetic(op, opInfo);
1375 }
1376 // Load custom operators.
1377 if (opCode == tflite::BuiltinOperator_CUSTOM) {
1378 // Get custom operator code.
1379 std::string customOpCode;
1380 ASSIGN_VALUE_OR_RETURN_ERR(customOpCode, getOperatorCustomCode(op));
1381 // Get custom operator options.
1382 flexbuffers::Map opts = flexbuffers::Map::EmptyMap();
1383 ASSIGN_VALUE_OR_RETURN_ERR(opts, getOperatorCustomOpts(op));
1384 // Load custom operator.
1385 if (customOpCode == "TFLite_Detection_PostProcess") {
1386 return loadTFLiteDetectionPostProcess(op, opInfo, opts);
1387 }
1388 if (customOpCode == "AudioSpectrogram") {
1389 return loadTFLiteAudioSpectrogram(op, opInfo, opts);
1390 }
1391 if (customOpCode == "Mfcc") {
1392 return loadTFLiteMFCC(op, opInfo, opts);
1393 }
1394 }
1395 return MAKE_ERR(
1396 strFormat("TensorFlowLite: Operator type '%s' is not supported!",
1397 opInfo.type.c_str()));
1398}
1399
1400Error TFLiteModelLoader::loadUnaryArithmetic(const tflite::Operator *op,
1401 const OperatorInfo &opInfo) {
1402 NodeValue input;
1403 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1404 TypeRef outTy;
1405 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1406
1407 auto opCode = opInfo.code;
1408 NodeValue output;
1409 if (opCode == tflite::BuiltinOperator_LOGISTIC) {
1410 output = F_->createSigmoid(opInfo.name, outTy, input);
1411 } else if (opCode == tflite::BuiltinOperator_HARD_SWISH) {
1412 output = F_->createHardSwish(opInfo.name, outTy, input);
1413 } else if (opCode == tflite::BuiltinOperator_RELU) {
1414 output = F_->createRELU(opInfo.name, input, outTy);
1415 } else if (opCode == tflite::BuiltinOperator_RELU_N1_TO_1) {
1416 output = F_->createClip(opInfo.name, input, outTy, -1.0, 1.0);
1417 } else if (opCode == tflite::BuiltinOperator_RELU6) {
1418 output = F_->createClip(opInfo.name, input, outTy, 0.0, 6.0);
1419 } else if (opCode == tflite::BuiltinOperator_TANH) {
1420 output = F_->createTanh(opInfo.name, outTy, input);
1421 } else if (opCode == tflite::BuiltinOperator_EXP) {
1422 output = F_->createExp(opInfo.name, input);
1423 } else if (opCode == tflite::BuiltinOperator_LOG) {
1424 output = F_->createLog(opInfo.name, input, outTy);
1425 } else if (opCode == tflite::BuiltinOperator_LEAKY_RELU) {
1426 const auto *opts = op->builtin_options_as_LeakyReluOptions();
1427 float alpha = opts->alpha();
1428 output = F_->createLeakyRELU(opInfo.name, outTy, input, alpha);
1429 } else if (opCode == tflite::BuiltinOperator_SQUARE) {
1430 output = F_->createSquare(opInfo.name, outTy, input);
1431 } else if (opCode == tflite::BuiltinOperator_ABS) {
1432 output = F_->createAbs(opInfo.name, outTy, input);
1433 } else if (opCode == tflite::BuiltinOperator_NEG) {
1434 output = F_->createNeg(opInfo.name, outTy, input);
1435 } else if (opCode == tflite::BuiltinOperator_FLOOR) {
1436 output = F_->createFloor(opInfo.name, outTy, input);
1437 } else if (opCode == tflite::BuiltinOperator_CEIL) {
1438 output = F_->createCeil(opInfo.name, outTy, input);
1439 } else if (opCode == tflite::BuiltinOperator_ROUND) {
1440 output = F_->createRound(opInfo.name, outTy, input);
1441 } else if (opCode == tflite::BuiltinOperator_SQRT) {
1442 output = F_->createSqrt(opInfo.name, outTy, input);
1443 } else if (opCode == tflite::BuiltinOperator_RSQRT) {
1444 output = F_->createRsqrt(opInfo.name, outTy, input);
1445 } else if (opCode == tflite::BuiltinOperator_SIN) {
1446 output = F_->createSin(opInfo.name, outTy, input);
1447 } else if (opCode == tflite::BuiltinOperator_COS) {
1448 output = F_->createCos(opInfo.name, outTy, input);
1449 } else if (opCode == tflite::BuiltinOperator_LOGICAL_NOT) {
1450 output = F_->createNot(opInfo.name, input);
1451 } else if (opCode == tflite::BuiltinOperator_QUANTIZE) {
1452 if (input.getType()->isFPType()) {
1453 output = F_->createQuantize(opInfo.name, input, outTy);
1454 } else {
1455 output = F_->createRescaleQuantized(opInfo.name, input, outTy);
1456 }
1457 } else if (opCode == tflite::BuiltinOperator_DEQUANTIZE) {
1458 output = F_->createDequantize(opInfo.name, input, outTy);
1459 } else {
1460 return MAKE_ERR(opErrMsg(opInfo, "Unsupported unary arithmetic operator!"));
1461 }
1462 return setOutputNodeValue(op, output);
1463}
1464
1465Error TFLiteModelLoader::loadBinaryArithmetic(const tflite::Operator *op,
1466 const OperatorInfo &opInfo) {
1467 NodeValue LHS;
1468 ASSIGN_VALUE_OR_RETURN_ERR(LHS, getInputNodeValue(op, 0));
1469 NodeValue RHS;
1470 ASSIGN_VALUE_OR_RETURN_ERR(RHS, getInputNodeValue(op, 1));
1471 TypeRef outTy;
1472 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1473
1474 auto opCode = opInfo.code;
1475
  // Perform explicit broadcasting only for operators that are NOT loaded
  // below with implicit broadcast support (createNodeWithBroadcastOutTy);
  // there are no TFLite operator tests for the others yet.
1478 if (opCode != tflite::BuiltinOperator_ADD &&
1479 opCode != tflite::BuiltinOperator_SUB &&
1480 opCode != tflite::BuiltinOperator_MUL &&
1481 opCode != tflite::BuiltinOperator_DIV &&
      opCode != tflite::BuiltinOperator_MINIMUM &&
      opCode != tflite::BuiltinOperator_MAXIMUM) {
1484
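    // Worked example (for illustration): if LHS has dims {5} and RHS has dims
    // {2, 3, 5}, the axis below is 3 - 1 = 2 and LHS is broadcast to
    // {2, 3, 5} with its dimension aligned to the trailing dimension of RHS.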
1485 // LHS operand broadcasting.
1486 if (LHS.dims().size() < RHS.dims().size()) {
1487 unsigned_t axis = RHS.dims().size() - LHS.dims().size();
1488 LHS = F_->createBroadcast(opInfo.name + ".Broadcast", LHS, RHS.dims(),
1489 axis);
1490 }
1491
1492 // RHS operand broadcasting.
1493 if (RHS.dims().size() < LHS.dims().size()) {
1494 unsigned_t axis = LHS.dims().size() - RHS.dims().size();
1495 RHS = F_->createBroadcast(opInfo.name + ".Broadcast", RHS, LHS.dims(),
1496 axis);
1497 }
1498 }
1499
1500 NodeValue output;
1501 if (opCode == tflite::BuiltinOperator_ADD) {
1502 const auto *opts = op->builtin_options_as_AddOptions();
1503 output = F_->createNodeWithBroadcastOutTy<AddNode>(opInfo.name, -1, outTy,
1504 LHS, RHS);
1505 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1506 } else if (opCode == tflite::BuiltinOperator_MUL) {
1507 const auto *opts = op->builtin_options_as_MulOptions();
1508 output = F_->createNodeWithBroadcastOutTy<MulNode>(opInfo.name, -1, outTy,
1509 LHS, RHS);
1510 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1511 } else if (opCode == tflite::BuiltinOperator_SUB) {
1512 const auto *opts = op->builtin_options_as_SubOptions();
1513 output = F_->createNodeWithBroadcastOutTy<SubNode>(opInfo.name, -1, outTy,
1514 LHS, RHS);
1515 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1516 } else if (opCode == tflite::BuiltinOperator_DIV) {
1517 const auto *opts = op->builtin_options_as_DivOptions();
1518 output = F_->createNodeWithBroadcastOutTy<DivNode>(opInfo.name, -1, outTy,
1519 LHS, RHS);
1520 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1521 } else if (opCode == tflite::BuiltinOperator_POW) {
1522 output = F_->createPow(opInfo.name, outTy, LHS, RHS);
1523 } else if (opCode == tflite::BuiltinOperator_PRELU) {
1524 NodeValue slope =
1525 F_->createReshape(opInfo.name + ".reshape", RHS, outTy->dims());
1526 output = F_->createPRELU(opInfo.name, LHS, slope, outTy);
1527 } else if (opCode == tflite::BuiltinOperator_MAXIMUM) {
1528 output = F_->createNodeWithBroadcastOutTy<MaxNode>(opInfo.name, -1, outTy,
1529 LHS, RHS);
1530 } else if (opCode == tflite::BuiltinOperator_MINIMUM) {
1531 output = F_->createNodeWithBroadcastOutTy<MinNode>(opInfo.name, -1, outTy,
1532 LHS, RHS);
1533 } else if (opCode == tflite::BuiltinOperator_EQUAL) {
1534 output = F_->createCmpEQ(opInfo.name, LHS, RHS);
1535 } else if (opCode == tflite::BuiltinOperator_NOT_EQUAL) {
1536 output = F_->createCmpNEQ(opInfo.name, LHS, RHS);
1537 } else if (opCode == tflite::BuiltinOperator_LESS) {
1538 output = F_->createCmpLT(opInfo.name, LHS, RHS);
1539 } else if (opCode == tflite::BuiltinOperator_LESS_EQUAL) {
1540 output = F_->createCmpLTE(opInfo.name, LHS, RHS);
1541 } else if (opCode == tflite::BuiltinOperator_GREATER) {
1542 output = F_->createCmpGT(opInfo.name, LHS, RHS);
1543 } else if (opCode == tflite::BuiltinOperator_GREATER_EQUAL) {
1544 output = F_->createCmpGTE(opInfo.name, LHS, RHS);
1545 } else if (opCode == tflite::BuiltinOperator_LOGICAL_AND) {
1546 output = F_->createAnd(opInfo.name, LHS, RHS);
1547 } else if (opCode == tflite::BuiltinOperator_LOGICAL_OR) {
1548 output = F_->createOr(opInfo.name, LHS, RHS);
1549 } else {
1550 return MAKE_ERR(
1551 opErrMsg(opInfo, "Unsupported binary arithmetic operator!"));
1552 }
1553 return setOutputNodeValue(op, output);
1554}
1555
1556Error TFLiteModelLoader::loadPool2D(const tflite::Operator *op,
1557 const OperatorInfo &opInfo) {
1558 const auto *opts = op->builtin_options_as_Pool2DOptions();
1559 NodeValue input;
1560 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1561 TypeRef outTy;
1562 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1563
1564 // Output shape inference when not defined.
1565 bool outShapeUndefined;
1566 ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
1567 if (outShapeUndefined) {
1568 dim_t outN = input.dims()[0];
1569 dim_t outH =
1570 getConvOutputSize(input.dims()[1], opts->filter_height(),
1571 opts->stride_h(), /* dilation */ 1, opts->padding());
1572 dim_t outW =
1573 getConvOutputSize(input.dims()[2], opts->filter_width(),
1574 opts->stride_w(), /* dilation */ 1, opts->padding());
1575 dim_t outC = input.dims()[3];
1576 outTy = F_->getParent()->uniqueTypeWithNewShape(outTy,
1577 {outN, outH, outW, outC});
1578 }
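  // For reference, assuming getConvOutputSize follows the usual TFLite
  // convention: SAME padding gives out = ceil(in / stride) and VALID padding
  // gives out = ceil((in - (kernel - 1) * dilation) / stride).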
1579
1580 ShapeNHWC inputShape = ShapeNHWC(input.dims());
1581 ShapeNHWC outputShape = ShapeNHWC(outTy->dims());
1582
1583 std::vector<unsigned_t> kernels = {
1584 static_cast<unsigned_t>(opts->filter_height()),
1585 static_cast<unsigned_t>(opts->filter_width())};
1586
1587 std::vector<unsigned_t> strides = {
1588 static_cast<unsigned_t>(opts->stride_h()),
1589 static_cast<unsigned_t>(opts->stride_w()),
1590 };
1591
1592 auto padsTB = getConvPads(inputShape.h, outputShape.h, kernels[0], strides[0],
1593 /* dilation */ 1, opts->padding());
1594 auto padsLR = getConvPads(inputShape.w, outputShape.w, kernels[1], strides[1],
1595 /* dilation */ 1, opts->padding());
1596 std::vector<unsigned_t> pads = {padsTB.first, padsLR.first, padsTB.second,
1597 padsLR.second};
1598
1599 auto opCode = opInfo.code;
1600 NodeValue output;
1601 if (opCode == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
1602 // TFLite AvgPool does NOT include padded regions when normalizing.
1603 auto *node = F_->createAvgPool(opInfo.name, input, kernels, strides, pads,
1604 ConvolutionLayout::NHWC,
1605 /* countIncludePads */ false);
1606 output = node->getResult();
1607 } else if (opCode == tflite::BuiltinOperator_MAX_POOL_2D) {
1608 auto *node = F_->createMaxPool(opInfo.name, input, kernels, strides, pads);
1609 output = node->getResult();
1610 } else {
1611 return MAKE_ERR(opErrMsg(opInfo, "Unsupported Pool2D operator!"));
1612 }
1613
1614 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1615 return setOutputNodeValue(op, output);
1616}
1617
1618Error TFLiteModelLoader::loadConcat(const tflite::Operator *op,
1619 const OperatorInfo &opInfo) {
1620 const auto *opts = op->builtin_options_as_ConcatenationOptions();
1621 TypeRef outTy;
1622 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1623
1624 const size_t numInputs = op->inputs()->size();
1625 llvm::SmallVector<NodeValue, 4> inputs;
1626 inputs.reserve(numInputs);
1627 for (size_t idx = 0; idx < numInputs; ++idx) {
1628 NodeValue input;
1629 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, idx));
1630 inputs.push_back(input);
1631 }
1632
1633 // If this node is quantized and there is a mismatch between the input and
1634 // output quantization parameters then we pull Rescale nodes from the inputs
1635 // to match the output quantization parameters.
1636 if (outTy->isQuantizedType()) {
1637 for (size_t idx = 0; idx < numInputs; ++idx) {
1638 NodeValue input = inputs[idx];
1639 TypeRef inpTy = input.getType();
1640 RETURN_ERR_IF_NOT(
1641 inpTy->isQuantizedType(),
1642 opErrMsg(opInfo, "Mixed precision for input/output not supported!"));
1643 if ((inpTy->getScale() != outTy->getScale()) ||
1644 (inpTy->getOffset() != outTy->getOffset())) {
1645 TypeRef inpTyNew = mod_.uniqueTypeWithNewShape(outTy, inpTy->dims());
1646 auto *rescaleNode = F_->createRescaleQuantized(
1647 opInfo.name + ".Rescale" + std::to_string(idx), input, inpTyNew);
1648 inputs[idx] = rescaleNode->getResult();
1649 }
1650 }
1651 }
1652
1653 unsigned_t axis;
1654 ASSIGN_VALUE_OR_RETURN_ERR(
1655 axis, getPositiveAxis<unsigned_t>(opts->axis(), outTy->dims().size()));
1656
1657 NodeValue output = F_->createConcat(opInfo.name, inputs, axis, outTy);
1658 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1659 return setOutputNodeValue(op, output);
1660}
1661
1662Error TFLiteModelLoader::loadConv2D(const tflite::Operator *op,
1663 const OperatorInfo &opInfo) {
1664 const auto *opts = op->builtin_options_as_Conv2DOptions();
1665 NodeValue input;
1666 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1667 NodeValue filter;
1668 ASSIGN_VALUE_OR_RETURN_ERR(filter, getInputNodeValue(op, 1));
1669 NodeValue bias;
1670 ASSIGN_VALUE_OR_RETURN_ERR(bias, getInputNodeValue(op, 2));
1671 TypeRef outTy;
1672 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1673
1674 // Output shape inference when not defined.
1675 bool outShapeUndefined;
1676 ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
1677 if (outShapeUndefined) {
1678 dim_t outN = input.dims()[0];
1679 dim_t outH =
1680 getConvOutputSize(input.dims()[1], filter.dims()[1], opts->stride_h(),
1681 opts->dilation_h_factor(), opts->padding());
1682 dim_t outW =
1683 getConvOutputSize(input.dims()[2], filter.dims()[2], opts->stride_w(),
1684 opts->dilation_w_factor(), opts->padding());
1685 dim_t outC = filter.dims()[0];
1686 outTy = F_->getParent()->uniqueTypeWithNewShape(outTy,
1687 {outN, outH, outW, outC});
1688 }
1689
1690 ShapeNHWC inputShape = ShapeNHWC(input.dims());
1691 ShapeNHWC filterShape = ShapeNHWC(filter.dims());
1692 ShapeNHWC outputShape = ShapeNHWC(outTy->dims());
1693
1694 std::vector<unsigned_t> kernels = {static_cast<unsigned_t>(filterShape.h),
1695 static_cast<unsigned_t>(filterShape.w)};
1696
1697 std::vector<unsigned_t> strides = {
1698 static_cast<unsigned_t>(opts->stride_h()),
1699 static_cast<unsigned_t>(opts->stride_w()),
1700 };
1701
1702 std::vector<unsigned_t> dilations = {
1703 static_cast<unsigned_t>(opts->dilation_h_factor()),
1704 static_cast<unsigned_t>(opts->dilation_w_factor()),
1705 };
1706
1707 auto padsTB = getConvPads(inputShape.h, outputShape.h, kernels[0], strides[0],
1708 dilations[0], opts->padding());
1709 auto padsLR = getConvPads(inputShape.w, outputShape.w, kernels[1], strides[1],
1710 dilations[1], opts->padding());
1711 std::vector<unsigned_t> pads = {padsTB.first, padsLR.first, padsTB.second,
1712 padsLR.second};
1713
1714 // There are TensorFlowLite models which have only the weights quantized
1715 // to INT8 (the rest of the operands being FLOAT32). Since Glow does not
1716 // support mixed precision operation we dequantize the weights.
1717 if (input.getType()->isFPType() && filter.getType()->isQuantizedType() &&
1718 bias.getType()->isFPType() && outTy->isFPType()) {
1719 filter = F_->createDequantize(opInfo.name + ".Dequantize", filter,
1720 outTy->getElementType());
1721 }
1722
1723 // Check whether this operator is quantized per axis.
1724 bool isPerAxisQuantized;
1725 Constant *filterScales = nullptr;
1726 Constant *filterOffsets = nullptr;
1727 Constant *biasScales = nullptr;
1728 Constant *biasOffsets = nullptr;
1729 ASSIGN_VALUE_OR_RETURN_ERR(isPerAxisQuantized,
1730 isConv2DPerAxisQuantized(op, opInfo, filterScales,
1731 filterOffsets, biasScales,
1732 biasOffsets));
1733
1734 // Create convolution node.
1735 NodeValue output;
1736 if (isPerAxisQuantized) {
1737 // Check that filter and bias are constants.
1738 RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(filter.getNode()),
1739 opErrMsg(opInfo, "Filter must be constant!"));
1740 RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(bias.getNode()),
1741 opErrMsg(opInfo, "Bias must be constant!"));
1742 // Create ChannelwiseQuantizedConvolution node.
1743 output = F_->createChannelwiseQuantizedConv(
1744 opInfo.name, input, filter, bias, filterScales, filterOffsets,
1745 biasScales, biasOffsets, outTy, kernels, strides, pads, /* group */ 1,
1746 dilations, /* quantizeFilter */ false, /* quantizeBias */ false);
1747 } else {
1748 // Check bias quantization parameters.
1749 RETURN_IF_ERR(checkBiasQuantizationParams(mod_, input, filter, bias));
1750 // Create Convolution node.
1751 output = F_->createConv(opInfo.name, input, filter, bias, outTy, kernels,
1752 strides, pads, /* group */ 1, dilations);
1753 }
1754
1755 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1756 return setOutputNodeValue(op, output);
1757}
1758
1759Error TFLiteModelLoader::loadDepthwiseConv2D(const tflite::Operator *op,
1760 const OperatorInfo &opInfo) {
1761 const auto *opts = op->builtin_options_as_DepthwiseConv2DOptions();
1762 NodeValue input;
1763 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1764 NodeValue filter;
1765 ASSIGN_VALUE_OR_RETURN_ERR(filter, getInputNodeValue(op, 1));
1766 NodeValue bias;
1767 ASSIGN_VALUE_OR_RETURN_ERR(bias, getInputNodeValue(op, 2));
1768 TypeRef outTy;
1769 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1770
1771 // Output shape inference when not defined.
1772 bool outShapeUndefined;
1773 ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
1774 if (outShapeUndefined) {
1775 dim_t outN = input.dims()[0];
1776 dim_t outH =
1777 getConvOutputSize(input.dims()[1], filter.dims()[1], opts->stride_h(),
1778 opts->dilation_h_factor(), opts->padding());
1779 dim_t outW =
1780 getConvOutputSize(input.dims()[2], filter.dims()[2], opts->stride_w(),
1781 opts->dilation_w_factor(), opts->padding());
1782 dim_t outC = input.dims()[3] * opts->depth_multiplier();
1783 outTy = F_->getParent()->uniqueTypeWithNewShape(outTy,
1784 {outN, outH, outW, outC});
1785 }
1786
1787 ShapeNHWC inputShape = ShapeNHWC(input.dims());
1788 ShapeNHWC filterShape = ShapeNHWC(filter.dims());
1789 ShapeNHWC outputShape = ShapeNHWC(outTy->dims());
1790
1791 std::vector<unsigned_t> kernels = {static_cast<unsigned_t>(filterShape.h),
1792 static_cast<unsigned_t>(filterShape.w)};
1793
1794 std::vector<unsigned_t> strides = {
1795 static_cast<unsigned_t>(opts->stride_h()),
1796 static_cast<unsigned_t>(opts->stride_w()),
1797 };
1798
1799 std::vector<unsigned_t> dilations = {1, 1};
1800 if (opInfo.version >= 2) {
1801 dilations = {static_cast<unsigned_t>(opts->dilation_h_factor()),
1802 static_cast<unsigned_t>(opts->dilation_w_factor())};
1803 }
1804
1805 auto padsTB = getConvPads(inputShape.h, outputShape.h, kernels[0], strides[0],
1806 dilations[0], opts->padding());
1807 auto padsLR = getConvPads(inputShape.w, outputShape.w, kernels[1], strides[1],
1808 dilations[1], opts->padding());
1809 std::vector<unsigned_t> pads = {padsTB.first, padsLR.first, padsTB.second,
1810 padsLR.second};
1811
1812 // Convolution group is inputChannels / filterChannels = inputChannels.
1813 unsigned_t group = input.dims().back();
1814
1815 // There are TensorFlowLite models which have only the weights quantized
1816 // to INT8 (the rest of the operands being FLOAT32). Since Glow does not
1817 // support mixed precision operation we dequantize the weights.
1818 if (input.getType()->isFPType() && filter.getType()->isQuantizedType() &&
1819 bias.getType()->isFPType() && outTy->isFPType()) {
1820 filter = F_->createDequantize(opInfo.name + ".Dequantize", filter,
1821 outTy->getElementType());
1822 }
1823
1824 // Check whether this operator is quantized per axis.
1825 bool isPerAxisQuantized;
1826 Constant *filterScales = nullptr;
1827 Constant *filterOffsets = nullptr;
1828 Constant *biasScales = nullptr;
1829 Constant *biasOffsets = nullptr;
1830 ASSIGN_VALUE_OR_RETURN_ERR(isPerAxisQuantized,
1831 isConv2DPerAxisQuantized(op, opInfo, filterScales,
1832 filterOffsets, biasScales,
1833 biasOffsets));
1834
  // Transpose the filter from CHWN to NHWC. For the per-axis quantized case
  // the constant payload is transposed directly (instead of inserting a
  // Transpose node) because the ChannelwiseQuantizedConvolution created
  // further down requires the filter to be a Constant.
1838 RETURN_ERR_IF_NOT(filter.dims().size() == 4,
1839 opErrMsg(opInfo, "Filter should be 4D!"));
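  // Worked example (for illustration): a TFLite depthwise filter of shape
  // {1, 3, 3, 32} (1HWC) becomes {32, 3, 3, 1} after the {3, 1, 2, 0}
  // shuffle, i.e. NHWC with a single channel per group.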
1840 if (isPerAxisQuantized) {
1841 Constant *filterC = llvm::dyn_cast<Constant>(filter.getNode());
1842 RETURN_ERR_IF_NOT(filterC, opErrMsg(opInfo, "Filter must be constant!"));
1843 TypeRef filterTy = filterC->getType();
1844 auto filterDims = filterTy->dims();
1845 TypeRef newFilterTy = mod_.uniqueTypeWithNewShape(
1846 filterTy, {filterDims[3], filterDims[1], filterDims[2], filterDims[0]});
1847 Tensor newFilterT = Tensor(newFilterTy);
1848 filterC->getPayload().transpose(&newFilterT, {3, 1, 2, 0});
1849 Constant *newFilterC = mod_.createConstant(
1850 filterC->getName().str() + ".Reshape", std::move(newFilterT), "NHWC");
1851 filter = newFilterC->getOutput();
1852 } else {
1853 filter = F_->createTranspose(opInfo.name + ".Transpose", filter,
1854 {3, 1, 2, 0}, "NHWC");
1855 }
1856 RETURN_ERR_IF_NOT(filter.dims().back() == 1,
1857 opErrMsg(opInfo, "Filter should have 1 channel!"));
1858
1859 // Create convolution node.
1860 NodeValue output;
1861 if (isPerAxisQuantized) {
1862 // Check that filter and bias are constants.
1863 RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(filter.getNode()),
1864 opErrMsg(opInfo, "Filter must be constant!"));
1865 RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(bias.getNode()),
1866 opErrMsg(opInfo, "Bias must be constant!"));
1867 // Create ChannelwiseQuantizedConvolution node.
1868 output = F_->createChannelwiseQuantizedConv(
1869 opInfo.name, input, filter, bias, filterScales, filterOffsets,
1870 biasScales, biasOffsets, outTy, kernels, strides, pads, group,
1871 dilations, /* quantizeFilter */ false, /* quantizeBias */ false);
1872 } else {
1873 // Check bias quantization parameters.
1874 RETURN_IF_ERR(checkBiasQuantizationParams(mod_, input, filter, bias));
1875 // Create Convolution node.
1876 output = F_->createConv(opInfo.name, input, filter, bias, outTy, kernels,
1877 strides, pads, group, dilations);
1878 }
1879
1880 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1881 return setOutputNodeValue(op, output);
1882}
1883
1884Error TFLiteModelLoader::loadFullyConnected(const tflite::Operator *op,
1885 const OperatorInfo &opInfo) {
1886 const auto *opts = op->builtin_options_as_FullyConnectedOptions();
1887 NodeValue input;
1888 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1889 NodeValue weights;
1890 ASSIGN_VALUE_OR_RETURN_ERR(weights, getInputNodeValue(op, 1));
1891 NodeValue bias;
1892 ASSIGN_VALUE_OR_RETURN_ERR(bias, getInputNodeValue(op, 2));
1893 TypeRef outTy;
1894 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1895
1896 // If bias is not used we create one initialized with 0.
1897 if (!bias.getNode()) {
1898 if (input.getType()->isQuantizedType()) {
1899 float biasScale =
1900 input.getType()->getScale() * weights.getType()->getScale();
1901 int32_t biasOffset = 0;
1902 auto *biasC =
1903 mod_.createConstant(ElemKind::Int32QTy, {weights.dims()[0]},
1904 biasScale, biasOffset, opInfo.name + ".bias");
1905 biasC->getPayloadMutable().zero();
1906 bias = biasC;
1907 } else {
1908 auto *biasC = mod_.createConstant(ElemKind::FloatTy, {weights.dims()[0]},
1909 opInfo.name + ".bias");
1910 biasC->getPayloadMutable().zero();
1911 bias = biasC;
1912 }
1913 }
1914
1915 RETURN_IF_ERR(checkBiasQuantizationParams(mod_, input, weights, bias));
1916
1917 // Output shape inference when not defined.
1918 bool outShapeUndefined;
1919 ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
1920 if (outShapeUndefined) {
1921 dim_t outN = input.dims()[0];
1922 dim_t outC = weights.dims()[0];
1923 outTy = F_->getParent()->uniqueTypeWithNewShape(outTy, {outN, outC});
1924 }
1925
1926 // There are TensorFlowLite models which have only the weights quantized
1927 // to INT8 (the rest of the operands being FLOAT32). Since Glow does not
1928 // support mixed precision operation we dequantize the weights.
1929 if (input.getType()->isFPType() && weights.getType()->isQuantizedType() &&
1930 bias.getType()->isFPType() && outTy->isFPType()) {
1931 weights = F_->createDequantize(opInfo.name + ".Dequantize", weights,
1932 outTy->getElementType());
1933 }
1934
1935 if (opInfo.version >= 2) {
1936 RETURN_ERR_IF_NOT(
1937 opts->weights_format() ==
1938 tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
1939 opErrMsg(opInfo, "Only default weights format is supported!"));
1940 }
1941
1942 bool keepDims = false;
1943 if (opInfo.version >= 5) {
1944 keepDims = opts->keep_num_dims();
1945 }
1946
1947 // Transpose weights.
1948 RETURN_ERR_IF_NOT(weights.dims().size() == 2,
1949 opErrMsg(opInfo, "Weights should be 2D!"));
1950 weights = F_->createTranspose(opInfo.name + ".Transpose", weights, {1, 0});
1951
1952 // For an input with shape [D(0), D(1), ... , D(N-1)] if:
1953 // keep_num_dims is FALSE then we flatten the input into:
1954 // [D(0), D(1) x D(2) x ... x D(N-1)] (axis = 1).
  //   keep_num_dims is TRUE then we flatten the input into:
1956 // [D(0) x D(1) x ... x D(N-2), D(N-1)] (axis = N-1).
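  // Worked example (for illustration): for an input of shape {2, 3, 4},
  // keep_num_dims = FALSE flattens it to {2, 12} (axis = 1) while
  // keep_num_dims = TRUE flattens it to {6, 4} (axis = 2).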
1957 unsigned_t axis = keepDims ? (input.dims().size() - 1) : 1;
1958 NodeValue output =
1959 F_->createFullyConnected(opInfo.name, input, weights, bias, outTy, axis);
1960 RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
1961
1962 // Expand output dims if necessary.
1963 if (keepDims) {
1964 std::vector<dim_t> outputDims = input.dims();
1965 outputDims.back() = output.dims().back();
1966 output = F_->createReshape(opInfo.name + ".Reshape", output, outputDims);
1967 }
1968 return setOutputNodeValue(op, output);
1969}
1970
1971Error TFLiteModelLoader::loadReshape(const tflite::Operator *op,
1972 const OperatorInfo &opInfo) {
1973 NodeValue input;
1974 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
1975 TypeRef outTy;
1976 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
1977
1978 // Output shape inference when not defined.
1979 bool outShapeUndefined;
1980 ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
1981 if (outShapeUndefined) {
1982 if (const auto *opts = op->builtin_options_as_ReshapeOptions()) {
1983 auto *newShape = opts->new_shape();
1984 std::vector<dim_t> outDims(newShape->size());
1985 dim_t dimProd = 1;
1986 for (size_t idx = 0; idx < newShape->size(); ++idx) {
1987 auto newDim = (*newShape)[idx];
1988 if (newDim != -1) {
1989 outDims[idx] = newDim;
1990 dimProd *= newDim;
1991 }
1992 }
1993 for (size_t idx = 0; idx < newShape->size(); ++idx) {
1994 auto newDim = (*newShape)[idx];
1995 if (newDim == -1) {
1996 outDims[idx] = input.getType()->size() / dimProd;
1997 }
1998 }
1999 outTy = F_->getParent()->uniqueTypeWithNewShape(outTy, outDims);
2000 }
2001 }
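  // Worked example for the shape inference above (for illustration): for an
  // input with 24 elements and new_shape = {2, -1, 3}, dimProd = 6 and the
  // inferred dimension is 24 / 6 = 4, giving outDims = {2, 4, 3}.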
2002
  // Note: The Reshape operator has a second input operand which provides
  // the new shape but the documentation states that it should be ignored
  // and the 'new_shape' attribute should be used instead. Moreover, in
  // this case we do not even use the 'new_shape' attribute because the
  // output type is directly available. The same logic is used when loading
  // other operators: Squeeze, ExpandDims.
2009 NodeValue output = F_->createReshape(opInfo.name, input, outTy->dims());
2010 return setOutputNodeValue(op, output);
2011}
2012
2013Error TFLiteModelLoader::loadSoftmax(const tflite::Operator *op,
2014 const OperatorInfo &opInfo) {
2015 const auto *opts = op->builtin_options_as_SoftmaxOptions();
2016 NodeValue input;
2017 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2018 TypeRef outTy;
2019 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2020
2021 // Output shape inference when not defined.
2022 bool outShapeUndefined;
2023 ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
2024 if (outShapeUndefined) {
2025 outTy = F_->getParent()->uniqueTypeWithNewShape(outTy, input.dims());
2026 }
2027
2028 RETURN_ERR_IF_NOT(input.dims().size() >= 2,
2029 opErrMsg(opInfo, "Input rank must be >= 2!"));
2030 float beta = opts->beta();
2031
2032 // Create a constant to store labels to be used in SoftMaxGradNode.
2033 auto selected =
2034 mod_.createConstant(ElemKind::Int64ITy, {input.dims()[0], 1}, "selected");
2035
2036 NodeValue output;
2037 if (tfliteFloatSoftmaxOpt) {
2038 // We dequantize the input if it is quantized type.
2039 if (input.getType()->isQuantizedType()) {
2040 input = F_->createDequantize(opInfo.name + ".Dequantize", input,
2041 ElemKind::FloatTy);
2042 }
2043
2044 // Create float Softmax regardless of the type defined in the model.
2045 output = F_->createSoftMax(opInfo.name, input, selected, nullptr, beta);
2046
    // If the target output type is quantized we quantize the float output of
    // the Softmax, unless it is a final output placeholder, in which case we
    // allow the placeholder to remain float even though it was defined as
    // quantized in the original model.
2051 bool isFinal;
2052 ASSIGN_VALUE_OR_RETURN_ERR(isFinal, isOperatorOutputFinalTensor(op, 0));
2053 if (outTy->isQuantizedType() && !isFinal) {
2054 output = F_->createQuantize(opInfo.name + ".Quantize", output, outTy);
2055 }
2056 } else {
2057 output = F_->createSoftMax(opInfo.name, input, selected, outTy, beta);
2058 }
2059 return setOutputNodeValue(op, output);
2060}
2061
2062Error TFLiteModelLoader::loadLogSoftmax(const tflite::Operator *op,
2063 const OperatorInfo &opInfo) {
2064 NodeValue input;
2065 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2066 TypeRef outTy;
2067 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2068
2069 // Create a constant to store labels to be used in SoftMaxGradNode.
2070 auto selected =
2071 mod_.createConstant(ElemKind::Int64ITy, {input.dims()[0], 1}, "selected");
2072
2073 NodeValue output = F_->createLogSoftMax(opInfo.name, input, selected, outTy);
2074 return setOutputNodeValue(op, output);
2075}
2076
2077Error TFLiteModelLoader::loadPad(const tflite::Operator *op,
2078 const OperatorInfo &opInfo) {
2079 NodeValue input;
2080 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2081 NodeValue pads;
2082 ASSIGN_VALUE_OR_RETURN_ERR(pads, getInputNodeValue(op, 1));
2083 TypeRef outTy;
2084 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2085
2086 // Validate paddings shape.
2087 auto numDims = input.dims().size();
2088 RETURN_ERR_IF_NOT(pads.dims().size() == 2,
2089 opErrMsg(opInfo, "Paddings should be 2D!"));
2090 RETURN_ERR_IF_NOT(pads.dims()[0] == numDims,
2091 opErrMsg(opInfo, "Paddings 1st dimension should match the "
2092 "input rank!"));
2093 RETURN_ERR_IF_NOT(pads.dims()[1] == 2,
2094 opErrMsg(opInfo, "Paddings 2nd dimensions should be 2!"));
2095
2096 // TFLite paddings are stored as start(D1),stop(D1),start(D2),stop(D2),etc.
2097 Constant *padsC = llvm::dyn_cast<Constant>(pads.getNode());
2098 RETURN_ERR_IF_NOT(padsC,
2099 opErrMsg(opInfo, "Non constant 'paddings' not supported!"));
2100 RETURN_ERR_IF_NOT(padsC->getType()->getElementType() == ElemKind::Int32ITy,
2101 opErrMsg(opInfo, "Paddings should have INT32 type!"));
2102 auto padsH = padsC->getPayload().getHandle<int32_t>();
2103 std::vector<int> padsVec(padsH.size());
2104 for (dim_t dim = 0; dim < numDims; ++dim) {
2105 auto padStart = padsH.at({dim, 0});
2106 auto padStop = padsH.at({dim, 1});
2107 RETURN_ERR_IF_NOT((padStart >= 0) && (padStop >= 0),
2108 opErrMsg(opInfo, "Invalid negative padding value!"));
2109 padsVec[0 * numDims + dim] = static_cast<int>(padStart);
2110 padsVec[1 * numDims + dim] = static_cast<int>(padStop);
2111 }
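  // Worked example (for illustration): for a 2D input with TFLite paddings
  // {{1, 2}, {3, 4}} the Glow pads vector becomes {1, 3, 2, 4}, that is all
  // the start values followed by all the stop values.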
2112
2113 NodeValue output = F_->createPad(opInfo.name, input, outTy,
2114 PaddingMode::CONSTANT, padsVec, 0.f);
2115 return setOutputNodeValue(op, output);
2116}
2117
2118Error TFLiteModelLoader::loadTranspose(const tflite::Operator *op,
2119 const OperatorInfo &opInfo) {
2120 NodeValue input;
2121 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2122 NodeValue perm;
2123 ASSIGN_VALUE_OR_RETURN_ERR(perm, getInputNodeValue(op, 1));
2124
2125 Constant *permC = llvm::dyn_cast<Constant>(perm.getNode());
2126 RETURN_ERR_IF_NOT(
2127 permC, opErrMsg(opInfo, "Non constant permutation not supported!"));
2128 RETURN_ERR_IF_NOT(permC->getType()->getElementType() == ElemKind::Int32ITy,
2129 opErrMsg(opInfo, "Permutation should have INT32 type!"));
2130 auto permH = permC->getPayload().getHandle<int32_t>();
2131
2132 std::vector<unsigned_t> shuffle;
2133 for (size_t idx = 0; idx < permH.size(); ++idx) {
2134 int32_t dim = permH.raw(idx);
2135 RETURN_ERR_IF_NOT(dim >= 0, opErrMsg(opInfo, "Invalid permutation value!"));
2136 shuffle.push_back(static_cast<unsigned_t>(dim));
2137 }
2138
2139 NodeValue output = F_->createTranspose(opInfo.name, input, shuffle);
2140 return setOutputNodeValue(op, output);
2141}
2142
2143Error TFLiteModelLoader::loadReduce(const tflite::Operator *op,
2144 const OperatorInfo &opInfo) {
2145 const auto *opts = op->builtin_options_as_ReducerOptions();
2146 NodeValue input;
2147 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2148 NodeValue axes;
2149 ASSIGN_VALUE_OR_RETURN_ERR(axes, getInputNodeValue(op, 1));
2150 TypeRef outTy;
2151 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2152
2153 std::vector<unsigned_t> axesVal;
2154 ASSIGN_VALUE_OR_RETURN_ERR(axesVal,
2155 loadAxes<unsigned_t>(opInfo, axes, input));
2156
2157 bool keepDims = opts->keep_dims();
2158
2159 // Try to load ReduceMean as AveragePool if equivalent.
2160 // TODO: Move this into the GraphOptimizer once Glow supports reduce
2161 // operators with multiple axes.
2162 if (opInfo.code == tflite::BuiltinOperator_MEAN && axesVal.size() == 2 &&
2163 axesVal.at(0) == 1 && axesVal.at(1) == 2 && input.dims().size() == 4) {
2164 std::vector<unsigned_t> kernels = {
2165 static_cast<unsigned_t>(input.dims()[1]),
2166 static_cast<unsigned_t>(input.dims()[2])};
2167 std::vector<unsigned_t> strides = {1, 1};
2168 std::vector<unsigned_t> pads = {0, 0, 0, 0};
2169 NodeValue output = F_->createAvgPool(opInfo.name, input, kernels, strides,
2170 pads, ConvolutionLayout::NHWC,
2171 /* countIncludePads */ false);
2172 if (!keepDims) {
2173 output = F_->createSqueeze(opInfo.name + ".Squeeze", output, {1, 2});
2174 }
2175 return setOutputNodeValue(op, output);
2176 }
2177
2178 // Currently the Glow reduce operators do not support multiple axes so we
2179 // create chained reduce operators with single axis.
2180 // TODO: When Glow supports reduce operators with multiple axes remove this!
2181 auto opCode = opInfo.code;
2182 NodeValue output = input;
2183 for (size_t idx = 0, end = axesVal.size(); idx < end; ++idx) {
2184 // Current axis value.
2185 unsigned_t axisVal = axesVal[idx];
2186 if (!keepDims) {
2187 axisVal = axisVal - idx;
2188 }
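    // Illustration (assuming 'axesVal' is sorted in ascending order): for a
    // rank 4 input with axesVal = {1, 2} and keepDims = false, reducing axis
    // 1 first shifts the old axis 2 down to position 1, hence the
    // 'axisVal - idx' adjustment above.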
2189 // Current output type.
2190 ShapeVector outDimsCurr(output.dims().begin(), output.dims().end());
2191 outDimsCurr.erase(outDimsCurr.begin() + axisVal);
2192 auto outTypeCurr = mod_.uniqueTypeWithNewShape(outTy, outDimsCurr);
2193 // Create reduce operator.
2194 if (opCode == tflite::BuiltinOperator_MEAN) {
2195 output = F_->createBatchedReduceMean(opInfo.name, outTypeCurr, output,
2196 {axisVal});
2197 // The BatchedReduceMean reduces the output dimension and hence we expand
2198 // the output dimensions if keepDims is true.
2199 if (keepDims) {
2200 output =
2201 F_->createExpandDims(opInfo.name + ".Expand", output, {axisVal});
2202 }
2203 } else {
2204 return MAKE_ERR(opErrMsg(opInfo, "Unsupported Reduce operator!"));
2205 }
2206 }
2207 return setOutputNodeValue(op, output);
2208}
2209
2210Error TFLiteModelLoader::loadSplit(const tflite::Operator *op,
2211 const OperatorInfo &opInfo) {
2212 const auto *opts = op->builtin_options_as_SplitOptions();
2213 NodeValue axis;
2214 ASSIGN_VALUE_OR_RETURN_ERR(axis, getInputNodeValue(op, 0));
2215 NodeValue input;
2216 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 1));
2217
2218 unsigned_t axisVal;
2219 ASSIGN_VALUE_OR_RETURN_ERR(axisVal,
2220 loadAxis<unsigned_t>(opInfo, axis, input));
2221
2222 unsigned_t numSplits = static_cast<unsigned_t>(opts->num_splits());
2223 RETURN_ERR_IF_NOT(
2224 input.dims()[axisVal] % numSplits == 0,
2225 opErrMsg(
2226 opInfo,
2227 "Input dimension should be divisible by 'num_splits' along axis!"));
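  // Worked example (for illustration): an input of shape {6, 4} split along
  // axis 0 with num_splits = 3 produces three slices of shape {2, 4}.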
2228
2229 std::vector<SliceNode *> outputNodes;
2230 F_->createSplit(opInfo.name, input, numSplits, axisVal, {}, outputNodes);
2231 std::vector<NodeValue> outputNodeValues(outputNodes.size(), nullptr);
2232 for (size_t idx = 0, end = outputNodes.size(); idx < end; ++idx) {
2233 outputNodeValues[idx] = outputNodes[idx]->getResult();
2234 }
2235 return setOutputNodeValues(op, outputNodeValues);
2236}
2237
2238Error TFLiteModelLoader::loadArg(const tflite::Operator *op,
2239 const OperatorInfo &opInfo) {
2240 NodeValue input;
2241 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2242 NodeValue axis;
2243 ASSIGN_VALUE_OR_RETURN_ERR(axis, getInputNodeValue(op, 1));
2244 TypeRef outTy;
2245 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2246
2247 unsigned_t axisVal;
2248 ASSIGN_VALUE_OR_RETURN_ERR(axisVal,
2249 loadAxis<unsigned_t>(opInfo, axis, input));
2250
2251 auto opCode = opInfo.code;
2252 NodeValue output = nullptr;
2253 if (opCode == tflite::BuiltinOperator_ARG_MAX) {
2254 output = F_->createArgMax(opInfo.name, input, axisVal, /* keepDims */ false,
2255 outTy->getElementType());
2256 } else if (opCode == tflite::BuiltinOperator_ARG_MIN) {
2257 output = F_->createArgMin(opInfo.name, input, axisVal, /* keepDims */ false,
2258 outTy->getElementType());
2259 } else {
2260 return MAKE_ERR(opErrMsg(opInfo, "Unsupported Arg operator!"));
2261 }
2262 return setOutputNodeValue(op, output);
2263}
2264
2265Error TFLiteModelLoader::loadShape(const tflite::Operator *op,
2266 const OperatorInfo &opInfo) {
2267 NodeValue input;
2268 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2269 TypeRef outTy;
2270 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2271
2272 Constant *shapeC = F_->getParent()->createConstant(outTy, opInfo.name);
2273 auto inputDims = input.getType()->dims();
2274 RETURN_ERR_IF_NOT(outTy->dims().size() == 1,
2275 opErrMsg(opInfo, "Output should be 1D!"));
2276 RETURN_ERR_IF_NOT(outTy->dims()[0] == inputDims.size(),
2277 opErrMsg(opInfo, "Output length should match input rank!"));
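  // Worked example (for illustration): for an input of shape {1, 224, 224, 3}
  // the output is a 1D constant of length 4 holding the values 1, 224, 224, 3.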
2278 if (outTy->getElementType() == ElemKind::Int32ITy) {
2279 auto shapeH = shapeC->getPayloadMutable().getHandle<int32_t>();
2280 for (size_t idx = 0; idx < inputDims.size(); ++idx) {
2281 shapeH.raw(idx) = static_cast<int32_t>(inputDims[idx]);
2282 }
2283 } else if (outTy->getElementType() == ElemKind::Int64ITy) {
2284 auto shapeH = shapeC->getPayloadMutable().getHandle<int64_t>();
2285 for (size_t idx = 0; idx < inputDims.size(); ++idx) {
2286 shapeH.raw(idx) = static_cast<int64_t>(inputDims[idx]);
2287 }
2288 } else {
2289 return MAKE_ERR(opErrMsg(opInfo, "Output should be INT32 or INT64!"));
2290 }
2291
2292 return setOutputNodeValue(op, shapeC);
2293}
2294
2295Error TFLiteModelLoader::loadSlice(const tflite::Operator *op,
2296 const OperatorInfo &opInfo) {
2297 NodeValue input;
2298 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2299 NodeValue begin;
2300 ASSIGN_VALUE_OR_RETURN_ERR(begin, getInputNodeValue(op, 1));
2301 TypeRef outTy;
2302 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2303 // Note: Slice has a third input operand 'size' which provides the size of
2304 // the output slice. We derive here the output slice size based on outTy.
2305
2306 Constant *beginC = llvm::dyn_cast<Constant>(begin.getNode());
2307 RETURN_ERR_IF_NOT(beginC,
2308 opErrMsg(opInfo, "Non constant begin not supported!"));
2309 RETURN_ERR_IF_NOT(beginC->getType()->getElementType() == ElemKind::Int32ITy,
2310 opErrMsg(opInfo, "Begin should have INT32 type!"));
2311 auto beginH = beginC->getPayload().getHandle<int32_t>();
2312
2313 std::vector<dim_t> start;
2314 for (size_t idx = 0; idx < beginH.size(); ++idx) {
2315 int32_t dimStart = beginH.raw(idx);
2316 RETURN_ERR_IF_NOT(dimStart >= 0, opErrMsg(opInfo, "Invalid begin value!"));
2317 start.push_back(static_cast<dim_t>(dimStart));
2318 }
2319
2320 NodeValue output = F_->createSlice(opInfo.name, input, start, outTy);
2321 return setOutputNodeValue(op, output);
2322}
2323
2324Error TFLiteModelLoader::loadStridedSlice(const tflite::Operator *op,
2325 const OperatorInfo &opInfo) {
2326 const auto *opts = op->builtin_options_as_StridedSliceOptions();
2327 NodeValue input;
2328 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2329 NodeValue begin;
2330 ASSIGN_VALUE_OR_RETURN_ERR(begin, getInputNodeValue(op, 1));
2331 NodeValue end;
2332 ASSIGN_VALUE_OR_RETURN_ERR(end, getInputNodeValue(op, 2));
2333 NodeValue strides;
2334 ASSIGN_VALUE_OR_RETURN_ERR(strides, getInputNodeValue(op, 3));
2335
2336 // You can find more information about this operator here:
2337 // https://www.tensorflow.org/api_docs/python/tf/strided_slice
2338 // https://www.tensorflow.org/mlir/tfl_ops#tflstrided_slice_tflstridedsliceop
2339
2340 // We only support strided slice if begin/end/strides are constants. This
2341 // is because Glow is statically typed.
2342 auto inpDims = input.dims();
2343 std::vector<dim_t> beginV;
2344 ASSIGN_VALUE_OR_RETURN_ERR(beginV, loadArray<dim_t>(opInfo, begin));
2345 std::vector<dim_t> endV;
2346 ASSIGN_VALUE_OR_RETURN_ERR(endV, loadArray<dim_t>(opInfo, end));
2347 std::vector<int32_t> stridesV;
2348 ASSIGN_VALUE_OR_RETURN_ERR(stridesV, loadArray<int32_t>(opInfo, strides));
2349
2350 // The strides must be non-zero.
2351 for (size_t idx = 0; idx < stridesV.size(); ++idx) {
2352 RETURN_ERR_IF_NOT(stridesV[idx] != 0,
2353 opErrMsg(opInfo, "Strides must be non-zero!"));
2354 }
2355
2356 // The begin/end/strides must have same size.
2357 RETURN_ERR_IF_NOT(
2358 beginV.size() == endV.size(),
2359 opErrMsg(opInfo, "Begin/end/strides should have same length!"));
2360 RETURN_ERR_IF_NOT(
2361 beginV.size() == stridesV.size(),
2362 opErrMsg(opInfo, "Begin/end/strides should have same length!"));
2363
2364 // The begin/end/strides length must be equal to the input rank.
2365 RETURN_ERR_IF_NOT(beginV.size() == inpDims.size(),
2366 opErrMsg(opInfo, "Begin/end/strides length invalid!"));
2367
2368 // Get attributes.
2369 int32_t begin_mask = opts->begin_mask();
2370 int32_t end_mask = opts->end_mask();
2371 int32_t ellipsis_mask = opts->ellipsis_mask();
2372 int32_t new_axis_mask = opts->new_axis_mask();
2373 int32_t shrink_axis_mask = opts->shrink_axis_mask();
2374
2375 // Utility to extract the Nth bit from a mask. Returns 0 or 1.
2376 auto getMaskBit = [](uint32_t mask, size_t n) -> int {
    assert((n <= 31) && "Bit number exceeded!");
2378 return (mask >> n) & 0x01;
2379 };
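  // Worked example (for illustration): begin_mask = 0b0101 has bits 0 and 2
  // set, so begin[0] and begin[2] are ignored and replaced with the fullest
  // possible range for those dimensions.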
2380
2381 // If the ith bit of begin_mask is set, begin[i] is ignored and the fullest
2382 // possible range in that dimension is used instead.
2383 if (begin_mask) {
2384 for (size_t idx = 0; idx < beginV.size(); ++idx) {
2385 if (getMaskBit(begin_mask, idx)) {
2386 // If stride is positive we start from 0 otherwise from dimension size.
2387 beginV[idx] = (stridesV[idx] > 0) ? 0 : (inpDims[idx] - 1);
2388 }
2389 }
2390 }
2391
2392 // If the ith bit of end_mask is set, end[i] is ignored and the fullest
2393 // possible range in that dimension is used instead.
2394 if (end_mask) {
2395 for (size_t idx = 0; idx < endV.size(); ++idx) {
2396 if (getMaskBit(end_mask, idx)) {
2397 // If stride is positive we end at dimension size otherwise at 0.
2398 endV[idx] = (stridesV[idx] > 0) ? inpDims[idx] : 0;
2399 }
2400 }
2401 }
2402
2403 // If the ith bit of ellipsis_mask is set, as many unspecified dimensions as
2404 // needed will be inserted between other dimensions. Only one non-zero bit is
2405 // allowed in ellipsis_mask.
2406 if (ellipsis_mask) {
2407 size_t ellipsisIdx = 0;
2408 for (size_t idx = 0; idx < 32; ++idx) {
2409 if (getMaskBit(ellipsis_mask, idx)) {
2410 ellipsisIdx = idx;
2411 break;
2412 }
2413 }
    // Note: It is unclear from the TFLite specification how to derive the
    // number of dimensions associated with the ellipsis. We rely on the fact
    // that when the ellipsis is used, the associated entries in the "begin"
    // and "end" arrays are commonly both marked with 0.
2418 size_t idx = ellipsisIdx;
2419 while ((idx < beginV.size()) && (beginV[idx] == 0) && (endV[idx] == 0)) {
2420 beginV[idx] = (stridesV[idx] > 0) ? 0 : (inpDims[idx] - 1);
2421 endV[idx] = (stridesV[idx] > 0) ? inpDims[idx] : 0;
2422 idx++;
2423 }
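    // Worked example (for illustration): for a rank 4 input with the ellipsis
    // bit set at position 1, begin = {2, 0, 0, 5} and end = {4, 0, 0, 7},
    // dimensions 1 and 2 receive their full range while dimensions 0 and 3
    // keep the given begin/end values.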
2424 }
2425
2426 // If the ith bit of new_axis_mask is set, then begin, end, and stride are
2427 // ignored and a new length 1 dimension is added at this point in the output
2428 // tensor.
2429 if (new_axis_mask) {
2430 size_t inpIdx = 0;
2431 size_t outIdx = 0;
2432 std::vector<dim_t> newOutDims;
2433 while (inpIdx < inpDims.size()) {
2434 if (getMaskBit(new_axis_mask, outIdx)) {
2435 newOutDims.push_back(1);
2436 } else {
2437 newOutDims.push_back(inpDims[inpIdx++]);
2438 }
2439 outIdx++;
2440 }
2441 NodeValue output = F_->createReshape(opInfo.name, input, newOutDims);
2442 return setOutputNodeValue(op, output);
2443 }
2444
2445 // If the ith bit of shrink_axis_mask is set, it implies that the ith
2446 // specification shrinks the dimensionality by 1, taking on the value at
2447 // index begin[i]. end[i] and strides[i] are ignored in this case.
2448 if (shrink_axis_mask) {
2449 for (size_t idx = 0; idx < beginV.size(); ++idx) {
2450 if (getMaskBit(shrink_axis_mask, idx)) {
2451 endV[idx] = beginV[idx] + 1;
2452 stridesV[idx] = 1;
2453 }
2454 }
2455 }
2456
2457 // Currently we only support strides of 1.
2458 // TODO: Add support for strides different than 1 (positive or negative) once
2459 // supported in Glow.
2460 for (size_t idx = 0; idx < stridesV.size(); ++idx) {
2461 RETURN_ERR_IF_NOT(
2462 stridesV[idx] == 1,
2463 opErrMsg(opInfo, "Only stride 1 is currently supported!"));
2464 }
2465
2466 // Create Slice node.
2467 NodeValue output = F_->createSlice(opInfo.name, input, beginV, endV);
2468
2469 // Reshape output if some dimensions were shrunk.
2470 if (shrink_axis_mask) {
2471 std::vector<dim_t> oldOutDims = output.dims();
2472 std::vector<dim_t> newOutDims;
2473 for (size_t idx = 0; idx < oldOutDims.size(); ++idx) {
2474 if (getMaskBit(shrink_axis_mask, idx)) {
2475 assert(oldOutDims[idx] == 1 && "Shrunk dimension should be 1!");
2476 } else {
2477 newOutDims.push_back(oldOutDims[idx]);
2478 }
2479 }
2480 if (newOutDims.empty()) {
2481 newOutDims.push_back(1);
2482 }
2483 output = F_->createReshape(opInfo.name + ".reshape", output, newOutDims);
2484 }
2485
2486 return setOutputNodeValue(op, output);
2487}
2488
2489Error TFLiteModelLoader::loadResizeBilinear(const tflite::Operator *op,
2490 const OperatorInfo &opInfo) {
2491 const auto *opts = op->builtin_options_as_ResizeBilinearOptions();
2492 NodeValue input;
2493 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2494 TypeRef outTy;
2495 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2496
2497 bool alignCorners = opts->align_corners();
2498 if (alignCorners) {
2499 LOG(WARNING) << opErrMsg(opInfo, "Option 'align_corners' is ignored!");
2500 }
2501
2502 bool halfPixelCenters = opts->half_pixel_centers();
2503 if (halfPixelCenters) {
2504 LOG(WARNING) << opErrMsg(opInfo, "Option 'half_pixel_centers' is ignored!");
2505 }
2506
2507 NodeValue output = F_->createResizeBilinear(opInfo.name, input, outTy);
2508 return setOutputNodeValue(op, output);
2509}
2510
2511Error TFLiteModelLoader::loadResizeNearest(const tflite::Operator *op,
2512 const OperatorInfo &opInfo) {
2513 const auto *opts = op->builtin_options_as_ResizeNearestNeighborOptions();
2514 NodeValue input;
2515 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2516 TypeRef outTy;
2517 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2518
2519 bool alignCorners = opts->align_corners();
2520 if (alignCorners) {
2521 LOG(WARNING) << opErrMsg(opInfo, "Option 'align_corners' is ignored!");
2522 }
2523
2524 bool halfPixelCenters = opts->half_pixel_centers();
2525 if (halfPixelCenters) {
2526 LOG(WARNING) << opErrMsg(opInfo, "Option 'half_pixel_centers' is ignored!");
2527 }
2528
2529 NodeValue output = F_->createResizeNearest(opInfo.name, input, outTy);
2530 return setOutputNodeValue(op, output);
2531}
2532
2533Error TFLiteModelLoader::loadSpaceToDepth(const tflite::Operator *op,
2534 const OperatorInfo &opInfo) {
2535 const auto *opts = op->builtin_options_as_SpaceToDepthOptions();
2536 NodeValue input;
2537 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2538 TypeRef outTy;
2539 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2540
2541 int32_t blockSize = opts->block_size();
2542
2543 NodeValue output = F_->createSpaceToDepth(opInfo.name, input, blockSize);
2544 RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy),
2545 opErrMsg(opInfo, "Expected output type incorrect!"));
2546 return setOutputNodeValue(op, output);
2547}
2548
2549Error TFLiteModelLoader::loadDepthToSpace(const tflite::Operator *op,
2550 const OperatorInfo &opInfo) {
2551 const auto *opts = op->builtin_options_as_DepthToSpaceOptions();
2552 NodeValue input;
2553 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2554 TypeRef outTy;
2555 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2556
2557 int32_t blockSize = opts->block_size();
2558
2559 NodeValue output = F_->createDepthToSpace(opInfo.name, input, blockSize);
2560 RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy),
2561 opErrMsg(opInfo, "Expected output type incorrect!"));
2562 return setOutputNodeValue(op, output);
2563}
2564
2565Error TFLiteModelLoader::loadCast(const tflite::Operator *op,
2566 const OperatorInfo &opInfo) {
2567 NodeValue input;
2568 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2569 TypeRef outTy;
2570 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
  // The Cast operator has two attributes, "in_data_type" and "out_data_type",
  // which are not used here because the input and output types are already
  // available.
2573 NodeValue output = F_->createConvertTo(opInfo.name, input, outTy);
2574 return setOutputNodeValue(op, output);
2575}
2576
2577Error TFLiteModelLoader::loadGather(const tflite::Operator *op,
2578 const OperatorInfo &opInfo) {
2579 const auto *opts = op->builtin_options_as_GatherOptions();
2580 NodeValue data;
2581 ASSIGN_VALUE_OR_RETURN_ERR(data, getInputNodeValue(op, 0));
2582 NodeValue indices;
2583 ASSIGN_VALUE_OR_RETURN_ERR(indices, getInputNodeValue(op, 1));
2584 TypeRef outTy;
2585 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2586
2587 unsigned_t axis;
2588 ASSIGN_VALUE_OR_RETURN_ERR(
2589 axis, getPositiveAxis<unsigned_t>(opts->axis(), data.dims().size()));
2590
2591 NodeValue output = F_->createGather(opInfo.name, data, indices, axis);
2592 RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy),
2593 opErrMsg(opInfo, "Expected output type incorrect!"));
2594 return setOutputNodeValue(op, output);
2595}
2596
2597Error TFLiteModelLoader::loadGatherND(const tflite::Operator *op,
2598 const OperatorInfo &opInfo) {
2599 NodeValue data;
2600 ASSIGN_VALUE_OR_RETURN_ERR(data, getInputNodeValue(op, 0));
2601 NodeValue indices;
2602 ASSIGN_VALUE_OR_RETURN_ERR(indices, getInputNodeValue(op, 1));
2603 TypeRef outTy;
2604 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2605
2606 NodeValue output = F_->createGatherND(opInfo.name, data, indices);
2607 RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy),
2608 opErrMsg(opInfo, "Expected output type incorrect!"));
2609 return setOutputNodeValue(op, output);
2610}
2611
2612Error TFLiteModelLoader::loadSelect(const tflite::Operator *op,
2613 const OperatorInfo &opInfo) {
2614 NodeValue cond;
2615 ASSIGN_VALUE_OR_RETURN_ERR(cond, getInputNodeValue(op, 0));
2616 NodeValue LHS;
2617 ASSIGN_VALUE_OR_RETURN_ERR(LHS, getInputNodeValue(op, 1));
2618 NodeValue RHS;
2619 ASSIGN_VALUE_OR_RETURN_ERR(RHS, getInputNodeValue(op, 2));
2620 TypeRef outTy;
2621 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2622
2623 NodeValue output = F_->createSelect(opInfo.name, outTy, cond, LHS, RHS);
2624 return setOutputNodeValue(op, output);
2625}
2626
2627Error TFLiteModelLoader::loadSpaceToBatchNd(const tflite::Operator *op,
2628 const OperatorInfo &opInfo) {
2629 NodeValue input;
2630 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2631 NodeValue block;
2632 ASSIGN_VALUE_OR_RETURN_ERR(block, getInputNodeValue(op, 1));
2633 NodeValue pads;
2634 ASSIGN_VALUE_OR_RETURN_ERR(pads, getInputNodeValue(op, 2));
2635 TypeRef outTy;
2636 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2637
2638 // Implemented as Pad -> Transpose N-D -> SpaceToDepth -> Transpose N-D
2639
2640 // Get Block Size dimensionality. Should be 2 but pretend it's arbitrary like
2641 // in tensorflow.
2642 int32_t blockSize = block.dims()[0];
2643
2644 // Validate block size input.
2645 RETURN_ERR_IF_NOT(block.dims().size() == 1,
2646 opErrMsg(opInfo, "Block Size should be 1D"));
2647 auto *blockC = llvm::dyn_cast<Constant>(block.getNode());
2648 RETURN_ERR_IF_NOT(
2649 blockC, opErrMsg(opInfo, "Non constant 'Block Size' not supported"));
2650 RETURN_ERR_IF_NOT(blockC->getType()->getElementType() == ElemKind::Int32ITy,
2651 opErrMsg(opInfo, "Block Size should have INT32 type"));
2652
2653 // Validate paddings.
2654 auto *padsC = llvm::dyn_cast<Constant>(pads.getNode());
2655 RETURN_ERR_IF_NOT(padsC,
2656 opErrMsg(opInfo, "Non constant 'paddings' not supported"));
2657 RETURN_ERR_IF_NOT(padsC->getType()->getElementType() == ElemKind::Int32ITy,
2658 opErrMsg(opInfo, "Paddings should have INT32 type"));
2659 RETURN_ERR_IF_NOT(pads.dims().size() == 2,
2660 opErrMsg(opInfo, "Paddings should be 2D"));
2661 for (dim_t i = 0; i < pads.dims().size(); i++) {
2662 RETURN_ERR_IF_NOT(
2663 pads.dims()[i] == (dim_t)blockSize,
2664 opErrMsg(opInfo,
2665 "Each Padding should have Block Size number of values"));
2666 }
2667
2668 // Get Block Size values.
2669 std::vector<int32_t> blockV;
2670 helperSetter<int32_t>(blockC, blockV);
2671 auto elemEqual = std::equal(blockV.begin() + 1, blockV.end(), blockV.begin());
2672 RETURN_ERR_IF_NOT(
2673 elemEqual,
2674 opErrMsg(opInfo, "Different Block Size values not supported yet."));
2675
2676 auto inDims = input.getType()->dims();
2677
2678 // Create Padding Node.
2679 auto padsH = padsC->getPayload().getHandle<int32_t>();
2680 std::vector<int> padsVec(2 * inDims.size());
2681 for (int32_t i = 0; i < blockSize; i++) {
2682 // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds
2683 padsVec[i + 1] = padsH.raw(2 * i + 0);
2684 padsVec[i + 1 + inDims.size()] = padsH.raw(2 * i + 1);
2685 }
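  // Worked example (for illustration): for a 4D input and TFLite paddings
  // {{pt, pb}, {pl, pr}} the Glow pads vector becomes
  // {0, pt, pl, 0, 0, pb, pr, 0}, i.e. only the two spatial dimensions are
  // padded.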
2686 ShapeVector padDims(inDims.begin(), inDims.end());
2687 for (dim_t i = 0; i < padDims.size(); i++) {
2688 // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds
2689 padDims[i] += (padsVec[i] + padsVec[i + inDims.size()]);
2690 }
2691 auto OT = mod_.uniqueTypeWithNewShape(input.getType(), padDims);
2692 NodeValue pad = F_->createPad(opInfo.name, input, OT, PaddingMode::CONSTANT,
2693 padsVec, 0.f);
2694
2695 // Check expected dimensions.
2696 auto padInDims = pad.getType()->dims();
2697 auto outDims = outTy->dims();
2698 int32_t calcblockSize = 0;
2699 for (int32_t i = 0; i < blockSize; i++) {
    RETURN_ERR_IF_NOT(
        padInDims[i + 1] % outDims[i + 1] == 0,
        opErrMsg(opInfo, strFormat("Input value %d must be a multiple of "
                                   "output value %d for space dim %d",
                                   (int)padInDims[i + 1], (int)outDims[i + 1],
                                   i)));
    int32_t newBlockSize = padInDims[i + 1] / outDims[i + 1];
    if (i > 0) {
      RETURN_ERR_IF_NOT(
          calcblockSize == newBlockSize,
          opErrMsg(opInfo, "Block Size for H & W must be the same."));
    }
    calcblockSize = newBlockSize;
2712 }
2713
2714 // Transpose N and D and perform SpaceToDepth.
2715 auto *trIn =
2716 F_->createTranspose(opInfo.name + ".trIn", pad, {3, 1, 2, 0}, "NHWC");
2717 // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds
2718 auto *S2D = F_->createSpaceToDepth(opInfo.name + ".S2D", trIn, blockV[0]);
2719 auto *output =
2720 F_->createTranspose(opInfo.name + ".trOut", S2D, {3, 1, 2, 0}, "NHWC");
2721
2722 return setOutputNodeValue(op, output);
2723}
2724
2725Error TFLiteModelLoader::loadBatchToSpaceNd(const tflite::Operator *op,
2726 const OperatorInfo &opInfo) {
2727 NodeValue input;
2728 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2729 NodeValue block;
2730 ASSIGN_VALUE_OR_RETURN_ERR(block, getInputNodeValue(op, 1));
2731 NodeValue crop;
2732 ASSIGN_VALUE_OR_RETURN_ERR(crop, getInputNodeValue(op, 2));
2733 TypeRef outTy;
2734 ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));
2735
2736 // Implemented as Transpose N-D -> DepthToSpace -> Transpose N-D -> Crop
2737
2738 // Get Block Size dimensionality. Should be 2 but pretend it's arbitrary like
2739 // in tensorflow.
2740 int32_t blockSize = block.dims()[0];
2741
2742 // Validate block input.
2743 RETURN_ERR_IF_NOT(block.dims().size() == 1,
2744 opErrMsg(opInfo, "Block Size should be 1D"));
2745 auto *blockC = llvm::dyn_cast<Constant>(block.getNode());
2746 RETURN_ERR_IF_NOT(
2747 blockC, opErrMsg(opInfo, "Non constant 'Block Size' not supported"));
2748 RETURN_ERR_IF_NOT(blockC->getType()->getElementType() == ElemKind::Int32ITy,
2749 opErrMsg(opInfo, "Block Size should have INT32 type"));
2750
2751 // Validate crop input.
2752 auto *cropC = llvm::dyn_cast<Constant>(crop.getNode());
2753 RETURN_ERR_IF_NOT(cropC,
2754 opErrMsg(opInfo, "Non constant 'crops' not supported"));
2755 RETURN_ERR_IF_NOT(cropC->getType()->getElementType() == ElemKind::Int32ITy,
                    opErrMsg(opInfo, "Crops should have INT32 type"));
2757 RETURN_ERR_IF_NOT(crop.dims().size() == 2,
2758 opErrMsg(opInfo, "Crops should be 2D"));
2759 for (dim_t i = 0; i < crop.dims().size(); i++) {
2760 RETURN_ERR_IF_NOT(
2761 crop.dims()[i] == (dim_t)blockSize,
2762 opErrMsg(opInfo, "Each crop should have Block Size number of values"));
2763 }
2764
2765 // Get Block Size values.
2766 std::vector<int32_t> blockV;
2767 helperSetter<int32_t>(blockC, blockV);
2768 auto elemEqual = std::equal(blockV.begin() + 1, blockV.end(), blockV.begin());
2769 RETURN_ERR_IF_NOT(
2770 elemEqual,
2771 opErrMsg(opInfo, "Different Block Size values not supported yet."));
2772
2773 // Transpose N and D and perform DepthToSpace.
2774 auto *tr1 =
2775 F_->createTranspose(opInfo.name + ".trPre", input, {3, 1, 2, 0}, "NHWC");
2776 // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds
2777 auto *D2S = F_->createDepthToSpace(opInfo.name + ".D2S", tr1, blockV[0]);
2778 auto *tr2 =
2779 F_->createTranspose(opInfo.name + ".trPost", D2S, {3, 1, 2, 0}, "NHWC");
2780
2781 // Create Crop Node.
2782 RETURN_ERR_IF_NOT(cropC->getType()->getElementType() == ElemKind::Int32ITy,
                    opErrMsg(opInfo, "Crops should have INT32 type"));
2784 auto cropH = cropC->getPayload().getHandle<int32_t>();
2785 std::vector<dim_t> cropVec(input.getType()->dims().size());
2786 for (int32_t i = 0; i < blockSize; i++) {
2787 // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds
2788 cropVec[i + 1] = cropH.raw(2 * i + 0);
2789 }
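  // Worked example (for illustration): for a 4D result and TFLite crops
  // {{ct, cb}, {cl, cr}} the slice starts at {0, ct, cl, 0}; the slice end is
  // implied by the output type 'outTy'.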
2790 NodeValue output = F_->createSlice(opInfo.name, tr2, cropVec, outTy);
2791 return setOutputNodeValue(op, output);
2792}
2793
2794Error TFLiteModelLoader::loadTile(const tflite::Operator *op,
2795 const OperatorInfo &opInfo) {
2796 NodeValue input;
2797 ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
2798 NodeValue multiples;
2799 ASSIGN_VALUE_OR_RETURN_ERR(multiples, getInputNodeValue(op, 1));
2800
2801 auto numDims = input.getType()->dims().size();
2802 std::vector<unsigned_t> numTiles;
2803 ASSIGN_VALUE_OR_RETURN_ERR(numTiles,
2804 loadArray<unsigned_t>(opInfo, multiples));
2805 RETURN_ERR_IF_NOT(numTiles.size() == numDims,
2806 opErrMsg(opInfo, "Input operand 'multiples' length should "
2807 "match the number of input dimensions!"));
2808
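  // Tile is decomposed into a chain of single-axis Tile nodes; axes with a
  // multiple of 1 are skipped. For example, an input of shape [2, 3] with
  // multiples [2, 1] yields a single Tile along axis 0 and an output of
  // shape [4, 3].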
  NodeValue output = input;
  for (unsigned_t axis = 0; axis < numDims; ++axis) {
    unsigned_t tiles = numTiles[axis];
    if (tiles != 1) {
      output = F_->createTile(opInfo.name + std::to_string(axis), output, tiles,
                              axis);
    }
  }
  return setOutputNodeValue(op, output);
}

Error TFLiteModelLoader::loadPack(const tflite::Operator *op,
                                  const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_PackOptions();
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  const size_t numInputs = op->inputs()->size();
  RETURN_ERR_IF_NOT(int(numInputs) == opts->values_count(),
                    opErrMsg(opInfo, "Attribute 'values_count' does not match "
                                     "the number of operator inputs!"));
  llvm::SmallVector<NodeValue, 4> inputs;
  inputs.reserve(numInputs);
  for (size_t idx = 0; idx < numInputs; ++idx) {
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, idx));
    inputs.push_back(input);
  }

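  // The axis is normalized against the output rank since the packed output
  // has one more dimension than each input.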
  unsigned_t axis;
  ASSIGN_VALUE_OR_RETURN_ERR(
      axis, getPositiveAxis<unsigned_t>(opts->axis(), outTy->dims().size()));

  // Validate that all inputs have the same type and shape.
  for (size_t idx = 0; idx < numInputs; ++idx) {
    RETURN_ERR_IF_NOT(
        inputs[idx].getType()->isEqual(inputs[0].getType()),
        opErrMsg(opInfo, "Operator inputs do not have same type/shape!"));
  }

  // Reshape all inputs from [D0,D1,...,DN] to [D0,D1,...,1,...,DN] where the
  // singleton dimension 1 is inserted at the axis position.
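  // For example, packing three inputs of shape [2, 3] along axis 1 reshapes
  // each input to [2, 1, 3] and concatenates them into a [2, 3, 3] output.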
  std::vector<dim_t> inputDimsReshaped = inputs[0].dims();
  inputDimsReshaped.insert(inputDimsReshaped.begin() + axis, 1);
  for (size_t idx = 0; idx < numInputs; ++idx) {
    inputs[idx] =
        F_->createReshape(opInfo.name + ".Reshape" + std::to_string(idx),
                          inputs[idx], inputDimsReshaped);
  }

  // Concatenate all inputs along axis.
  NodeValue output =
      F_->createConcat(opInfo.name + ".Concat", inputs, axis, outTy);
  return setOutputNodeValue(op, output);
}

Error TFLiteModelLoader::loadUnpack(const tflite::Operator *op,
                                    const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_UnpackOptions();
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  unsigned_t axis;
  ASSIGN_VALUE_OR_RETURN_ERR(axis,
                             getPositiveAxis<unsigned_t>(opts->axis(), input));

  unsigned_t num = static_cast<unsigned_t>(opts->num());
  RETURN_ERR_IF_NOT(
      num == input.dims()[axis],
      opErrMsg(opInfo,
               "Attribute 'num' should be equal to input size along axis!"));

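  // Unpack is implemented as a Split along the given axis followed by a
  // Reshape of each slice that drops the unpacked dimension. For example, an
  // input of shape [2, 3, 4] unpacked along axis 1 is split into 3 slices of
  // shape [2, 1, 4], each reshaped to [2, 4].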
  // Split input.
  std::vector<SliceNode *> outputNodes;
  F_->createSplit(opInfo.name, input, num, axis, {}, outputNodes);

  // Reshape outputs.
  std::vector<NodeValue> outputNodeValues(outputNodes.size(), nullptr);
  for (size_t idx = 0, end = outputNodes.size(); idx < end; ++idx) {
    outputNodeValues[idx] = outputNodes[idx]->getResult();
    outputNodeValues[idx] =
        F_->createReshape(opInfo.name + ".Reshape" + std::to_string(idx),
                          outputNodeValues[idx], outTy->dims());
  }
  return setOutputNodeValues(op, outputNodeValues);
}

Error TFLiteModelLoader::loadTFLiteDetectionPostProcess(
    const tflite::Operator *op, const OperatorInfo &opInfo,
    const flexbuffers::Map &opts) {
  NodeValue boxes;
  ASSIGN_VALUE_OR_RETURN_ERR(boxes, getInputNodeValue(op, 0));
  NodeValue scores;
  ASSIGN_VALUE_OR_RETURN_ERR(scores, getInputNodeValue(op, 1));
  NodeValue anchors;
  ASSIGN_VALUE_OR_RETURN_ERR(anchors, getInputNodeValue(op, 2));

  // Note: We cannot use the output types of the node because they are dynamic.
  // Instead we create static types with fixed sizes for this node.
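  // The node produces four statically sized outputs: detection boxes, classes,
  // scores and the number of detections.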

  // Get operator attributes.
  int32_t numClasses = opts["num_classes"].AsInt32();
  int32_t maxDetections = opts["max_detections"].AsInt32();
  int32_t maxClassesPerDetection = opts["max_classes_per_detection"].AsInt32();
  constexpr int32_t defaultNumDetectionsPerClass = 100;
  int32_t maxDetectionsPerClass = (opts["detections_per_class"].IsNull())
                                      ? defaultNumDetectionsPerClass
                                      : opts["detections_per_class"].AsInt32();
  float iouThreshold = opts["nms_iou_threshold"].AsFloat();
  float scoreThreshold = opts["nms_score_threshold"].AsFloat();
  float xScale = opts["x_scale"].AsFloat();
  float yScale = opts["y_scale"].AsFloat();
  float hScale = opts["h_scale"].AsFloat();
  float wScale = opts["w_scale"].AsFloat();
  bool regularNMS = (opts["use_regular_nms"].IsNull())
                        ? false
                        : opts["use_regular_nms"].AsBool();

  // Create node.
  auto *node = F_->createTFLiteDetectionPostProcess(
      opInfo.name, boxes, scores, anchors, numClasses, maxDetections,
      maxClassesPerDetection, maxDetectionsPerClass, iouThreshold,
      scoreThreshold, xScale, yScale, hScale, wScale, regularNMS);
  std::vector<NodeValue> outputNodeValues = {
      node->getDetectionBoxes(), node->getDetectionClasses(),
      node->getDetectionScores(), node->getNumDetections()};
  return setOutputNodeValues(op, outputNodeValues);
}

Error TFLiteModelLoader::loadTFLiteAudioSpectrogram(
    const tflite::Operator *op, const OperatorInfo &opInfo,
    const flexbuffers::Map &opts) {
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));

  // Get operator attributes.
  int32_t windowSize = opts["window_size"].AsInt32();
  int32_t windowStride = opts["stride"].AsInt32();
  bool magnitudeSquared = opts["magnitude_squared"].AsBool();

  // Create node.
  NodeValue output = F_->createAudioSpectrogram(opInfo.name, input, windowSize,
                                                windowStride, magnitudeSquared);
  return setOutputNodeValue(op, output, /* checkType */ false);
}

Error TFLiteModelLoader::loadTFLiteMFCC(const tflite::Operator *op,
                                        const OperatorInfo &opInfo,
                                        const flexbuffers::Map &opts) {
  NodeValue spectrogram;
  ASSIGN_VALUE_OR_RETURN_ERR(spectrogram, getInputNodeValue(op, 0));

  // Get operator attributes.
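  // If the model does not provide the 'sample_rate' attribute we fall back to
  // the value given by the 'tflite-mfcc-sample-rate' command line option.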
  float sampleRate = (opts["sample_rate"].IsNull())
                         ? tfliteMfccSampleRateOpt
                         : opts["sample_rate"].AsFloat();
  float lowerFrequency = opts["lower_frequency_limit"].AsFloat();
  float upperFrequency = opts["upper_frequency_limit"].AsFloat();
  int32_t filterBankCount = opts["filterbank_channel_count"].AsInt32();
  int32_t numCoefficients = opts["dct_coefficient_count"].AsInt32();

  // Create node.
  NodeValue output =
      F_->createMFCC(opInfo.name, spectrogram, sampleRate, lowerFrequency,
                     upperFrequency, filterBankCount, numCoefficients);
  return setOutputNodeValue(op, output, /* checkType */ false);
}

TFLiteModelLoader::TFLiteModelLoader(const std::string &modelFilename,
                                     Function *F)
    : F_(F), mod_(*F->getParent()) {
  auto setup = [&]() -> Error {
    // Read model.
    std::vector<char> modelData;
    ASSIGN_VALUE_OR_RETURN_ERR(model_, readModel(modelData, modelFilename));

    // TODO: Verify model integrity using flatbuffers::Verifier class.

    // Get model info.
    modelVersion_ = model_->version();
    modelDescription_ = model_->description()->str();

    // Get model graph.
    const auto *modelGraphs = model_->subgraphs();
    RETURN_ERR_IF_NOT(
        modelGraphs->size() == 1,
        "TensorFlowLite: Only one model subgraph is currently supported!");
    graph_ = (*modelGraphs)[0];

    // Initialize graph node values.
    initializeNodeValues();

    // Load graph input placeholders.
    RETURN_IF_ERR(loadInputPlaceholders());

    // Load graph constants.
    RETURN_IF_ERR(loadConstants());

    // Load graph operators.
    RETURN_IF_ERR(loadOperators());

    // Save graph output placeholders.
    RETURN_IF_ERR(saveOutputPlaceholders());

    // Verify function.
    RETURN_ERR_IF_NOT(F_->verify(),
                      "TensorFlowLite: Function verification failed!");

    return Error::success();
  };

  EXIT_ON_ERR(setup());
}