1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Importer/TFLiteModelLoader.h" |
18 | #include "glow/Base/Tensor.h" |
19 | #include "glow/Graph/Graph.h" |
20 | #include "glow/Graph/Nodes.h" |
21 | #include "glow/Importer/CommonOperatorLoader.h" |
22 | #include "glow/Support/Support.h" |
23 | |
24 | #include "llvm/Support/Casting.h" |
25 | #include "llvm/Support/CommandLine.h" |
26 | |
27 | #include <fstream> |
28 | #include <sstream> |
29 | #include <string> |
30 | #include <vector> |
31 | |
32 | using namespace glow; |
33 | using llvm::cast; |
34 | |
35 | namespace { |
36 | |
37 | llvm::cl::OptionCategory |
38 | tfliteModelLoaderCat("TensorFlowLite Model Loader Options" ); |
39 | |
40 | llvm::cl::opt<bool> tfliteUint8ToInt8Opt( |
41 | "tflite-uint8-to-int8" , |
42 | llvm::cl::desc("TensorFlowLite loader option to convert the model from " |
43 | "UINT8 data type to INT8 data type." ), |
44 | llvm::cl::init(true), llvm::cl::Optional, |
45 | llvm::cl::cat(tfliteModelLoaderCat)); |
46 | |
47 | llvm::cl::opt<bool> tfliteFloatInputsOpt( |
48 | "tflite-float-inputs" , |
49 | llvm::cl::desc("TensorFlowLite loader option to replace the quantized " |
50 | "inputs with floating point inputs." ), |
51 | llvm::cl::init(false), llvm::cl::Optional, |
52 | llvm::cl::cat(tfliteModelLoaderCat)); |
53 | |
54 | llvm::cl::opt<bool> tfliteFloatOutputsOpt( |
55 | "tflite-float-outputs" , |
56 | llvm::cl::desc("TensorFlowLite loader option to replace the quantized " |
57 | "outputs with floating point outputs." ), |
58 | llvm::cl::init(false), llvm::cl::Optional, |
59 | llvm::cl::cat(tfliteModelLoaderCat)); |
60 | |
61 | llvm::cl::opt<bool> tfliteFloatSoftmaxOpt( |
62 | "tflite-float-softmax" , |
63 | llvm::cl::desc("TensorFlowLite loader option to replace a quantized Softmax" |
64 | "with a floating point Softmax." ), |
65 | llvm::cl::init(false), llvm::cl::Optional, |
66 | llvm::cl::cat(tfliteModelLoaderCat)); |
67 | |
68 | llvm::cl::opt<float> tfliteBiasScaleCheckMaxErrorOpt( |
69 | "tflite-bias-scale-check-max-error" , |
70 | llvm::cl::desc( |
71 | "TensorFlowLite mandates that for quantized operators like Conv2D the " |
72 | "bias quantization parameter biasScale = inputScale * weightsScale but " |
73 | "some pre-quantized models do not EXACTLY satisfy this relation but " |
74 | "with very small relative errors (around 1e-8). Hence we allow a " |
75 | "tolerance of 1e-6 which, if satisfied, then we adjust the bias to " |
76 | "conform to the restriction." ), |
77 | llvm::cl::init(1e-6), llvm::cl::Optional, |
78 | llvm::cl::cat(tfliteModelLoaderCat)); |
79 | |
80 | llvm::cl::opt<bool> tfliteBiasScaleCheckThrowErrorOpt( |
81 | "tflite-bias-scale-check-throw-error" , |
82 | llvm::cl::desc( |
83 | "TensorFlowLite mandates that for quantized operators like Conv2D the " |
84 | "bias quantization parameter biasScale = inputScale * weightsScale. If " |
85 | "this contraint is not met within the given tolerance then an error " |
86 | "will be thrown if this option is enabled." ), |
87 | llvm::cl::init(true), llvm::cl::Optional, |
88 | llvm::cl::cat(tfliteModelLoaderCat)); |
89 | |
90 | llvm::cl::opt<float> tfliteMfccSampleRateOpt( |
91 | "tflite-mfcc-sample-rate" , |
92 | llvm::cl::desc( |
93 | "When the TensorFlowLite model has a MFCC node (Mel Frequency Cepstral " |
94 | "Coefficient) this option is used to set the sample rate (in Hz) used " |
95 | "by the node when no such attribute is specified." ), |
96 | llvm::cl::init(16000.0), llvm::cl::Optional, |
97 | llvm::cl::cat(tfliteModelLoaderCat)); |
98 | |
99 | /// Function to read a TensorFlowLite model from the file \p modelFilename into |
100 | /// the data buffer \p modelData provided by the caller. The \p modelData buffer |
101 | /// is allocated and initialized by this function but the caller must ensure its |
102 | /// existence through the graph loading process. \returns the TensorFlowLite |
103 | /// model object or Error in case something went wrong. |
104 | Expected<const tflite::Model *> readModel(std::vector<char> &modelData, |
105 | const std::string &modelFilename) { |
106 | // Open file. |
107 | std::ifstream modelFile; |
108 | modelFile.open(modelFilename, std::ios::binary); |
109 | RETURN_ERR_IF_NOT(modelFile.is_open(), |
110 | strFormat("TensorFlowLite: Error opening model file '%s'!" , |
111 | modelFilename.c_str())); |
112 | // Get model size. |
113 | modelFile.seekg(0, std::ios::end); |
114 | std::streamsize modelSize = modelFile.tellg(); |
115 | modelFile.seekg(0, std::ios::beg); |
116 | // Read model data. |
117 | modelData = std::vector<char>(modelSize); |
118 | RETURN_ERR_IF_NOT(modelFile.read(modelData.data(), modelSize), |
119 | strFormat("TensorFlowLite: Error reading model file '%s'!" , |
120 | modelFilename.c_str())); |
121 | modelFile.close(); |
122 | // Return model object. |
123 | return tflite::GetModel(modelData.data()); |
124 | } |
125 | |
126 | /// Function to convert the UINT8 data from the buffer \p inpPtr to INT8 format |
127 | /// into the buffer \p outPtr. The buffer size is given by \p numElem. This |
128 | /// function is used to transform the UINT8 weights of a TensorFlowLite model to |
129 | /// INT8 format which is the format preferred and supported by Glow. |
130 | void convertUint8ToInt8(const uint8_t *inpPtr, int8_t *outPtr, size_t numElem) { |
131 | for (size_t idx = 0, idxEnd = numElem; idx < idxEnd; ++idx) { |
132 | int32_t val = inpPtr[idx]; |
133 | val -= UINT8_TO_INT8_SHIFT; |
134 | outPtr[idx] = static_cast<int8_t>(val); |
135 | } |
136 | } |
137 | |
138 | /// Function to compute the padding along a single dimension for the given input |
139 | /// size \p inputSize and output size \p outputSize and for the given filter |
140 | /// (kernel) size \p kernel \p stride, \p dilation and padding type \p padding. |
141 | /// \returns a pair with the explicit padding values to be used for the input |
142 | /// before and after the actual input data. |
143 | std::pair<unsigned_t, unsigned_t> |
144 | getConvPads(dim_t inputSize, dim_t outputSize, unsigned_t kernel, |
145 | unsigned_t stride, unsigned_t dilation, tflite::Padding padding) { |
146 | if (padding == tflite::Padding::Padding_VALID) { |
147 | // For VALID padding we do not use padding. |
148 | return std::pair<unsigned_t, unsigned_t>(0, 0); |
149 | } else if (padding == tflite::Padding::Padding_SAME) { |
150 | // Effective dilated filter (kernel) size. |
151 | unsigned_t effKernel = (kernel - 1) * dilation + 1; |
152 | // Compute the total padding size while saturating above 0. |
153 | unsigned_t padTotal = (outputSize - 1) * stride + effKernel; |
154 | padTotal = |
155 | std::max(padTotal, static_cast<unsigned_t>(inputSize)) - inputSize; |
156 | // We split the total padding evenly before/after. If the padding is odd |
157 | // then the "after" part gets the extra unit. |
158 | unsigned_t padBefore = padTotal / 2; |
159 | unsigned_t padAfter = padTotal - padBefore; |
160 | return std::pair<unsigned_t, unsigned_t>(padBefore, padAfter); |
161 | } |
162 | llvm_unreachable("Padding parameter invalid!" ); |
163 | } |
164 | |
165 | /// Function used to compute the output size for convolution and pooling kernels |
166 | /// for the given \p inputSize, \p kernel, \p stride, \p dilation, \p padding. |
167 | /// This function is used to infer the output shape when not defined. |
168 | dim_t getConvOutputSize(dim_t inputSize, unsigned_t kernel, unsigned_t stride, |
169 | unsigned_t dilation, tflite::Padding padding) { |
170 | if (padding == tflite::Padding::Padding_VALID) { |
171 | // Effective dilated filter (kernel) size. |
172 | unsigned_t effKernel = (kernel - 1) * dilation + 1; |
173 | // We compute the output size as CEIL((inputSize - effKernel) / stride) + 1. |
174 | return (inputSize - effKernel + stride - 1) / stride + 1; |
175 | } else if (padding == tflite::Padding::Padding_SAME) { |
176 | // For SAME padding the output size is computed as CEIL(inputSize / stride). |
177 | return (inputSize + stride - 1) / stride; |
178 | } |
179 | llvm_unreachable("Padding parameter invalid!" ); |
180 | } |
181 | |
182 | /// Retrieves data from a constant Tensor and stores it in a vector. |
183 | template <typename T, typename datatype = ssize_t> |
184 | static void helperSetter(Constant *constT, std::vector<datatype> &vec) { |
185 | auto constH = constT->getPayload().getHandle<T>(); |
186 | for (dim_t i = 0; i < constH.size(); ++i) { |
187 | vec.push_back(constH.at({i})); |
188 | } |
189 | } |
190 | |
/// Function to verify the quantization parameters of the bias operand. The
/// TensorFlowLite format mandates that the bias scale must be equal to the
/// product inputScale * weightsScale and the bias offset must be 0. This
/// function is provided with the module \p mod and the node values \p input,
/// \p weights, \p bias and \returns Error::success() if the bias parameters
/// are valid and Error otherwise. When the bias scale mismatch is within the
/// tolerance set by the 'tflite-bias-scale-check-max-error' option, the bias
/// type is adjusted in-place (side effect) to satisfy the constraint.
Error checkBiasQuantizationParams(Module &mod, NodeValue input,
                                  NodeValue weights, NodeValue bias) {
  auto inputTy = input.getType();
  auto weightsTy = weights.getType();
  auto biasTy = bias.getType();
  // The constraint only applies when all three operands are quantized.
  if (inputTy->isQuantizedType() && weightsTy->isQuantizedType() &&
      biasTy->isQuantizedType()) {
    float inputScale = inputTy->getScale();
    float weightsScale = weightsTy->getScale();
    float matMulScale = inputScale * weightsScale;
    float biasScale = biasTy->getScale();
    // Check bias scale relative error to inputScale * weightsScale.
    if (biasScale != matMulScale) {
      float relErr = std::abs(matMulScale - biasScale) / matMulScale;
      // Always warn on a mismatch, even if it is tolerated below.
      llvm::errs() << strFormat(
          "TensorFlowLite: WARNING: Per tensor BIAS scale value was expected "
          "to be exactly %E (inputScale * weightsScale) but found "
          "%E instead! Relative absolute error is %E!\n",
          matMulScale, biasScale, relErr);
      if (relErr < tfliteBiasScaleCheckMaxErrorOpt) {
        // Mismatch is within tolerance: force the bias scale to the mandated
        // value (inputScale * weightsScale).
        // Set new bias type.
        TypeRef newBiasTy =
            mod.uniqueType(biasTy->getElementType(), biasTy->dims(),
                           matMulScale, biasTy->getOffset());
        bias.setType(newBiasTy);
        // If bias is constant we must also change the payload type.
        if (auto *biasC = llvm::dyn_cast<Constant>(bias.getNode())) {
          biasC->setPayloadType(newBiasTy);
        }
      } else if (tfliteBiasScaleCheckThrowErrorOpt) {
        // Mismatch too large: report an error when the option requests it.
        return MAKE_ERR(strFormat(
            "TensorFlowLite: ERROR: Per tensor BIAS scale value was "
            "expected to be exactly %E (inputScale * weightsScale) but "
            "found %E instead! Relative absolute error is %E!\n",
            matMulScale, biasScale, relErr));
      }
    }
    // The bias offset (zero point) is required to be exactly 0.
    int32_t biasOffset = biasTy->getOffset();
    if (biasOffset != 0) {
      return MAKE_ERR(
          strFormat("TensorFlowLite: Bias offset value was expected to "
                    "be 0 but found %d instead!",
                    biasOffset));
    }
  }
  return Error::success();
}
244 | |
245 | } // namespace |
246 | |
247 | ///===---------------------------------------------------------------------===// |
248 | /// Tensor Utilities |
249 | ///===---------------------------------------------------------------------===// |
250 | Expected<const tflite::Tensor *> |
251 | TFLiteModelLoader::getTensorByIndex(size_t index) { |
252 | auto *tensors = graph_->tensors(); |
253 | RETURN_ERR_IF_NOT( |
254 | index < tensors->size(), |
255 | strFormat("TensorFlowLite: Tensor index %zu out of range!" , index)); |
256 | return (*tensors)[index]; |
257 | } |
258 | |
259 | std::string TFLiteModelLoader::getTensorName(const tflite::Tensor *tensor) { |
260 | return tensor->name()->str(); |
261 | } |
262 | |
263 | Expected<std::vector<dim_t>> |
264 | TFLiteModelLoader::getTensorShape(const tflite::Tensor *tensor) { |
265 | // If tensor shape is NULL we use a 1D shape with size 1. |
266 | if (!tensor->shape()) { |
267 | return std::vector<dim_t>({1}); |
268 | } |
269 | std::vector<dim_t> shape; |
270 | for (auto dim : *(tensor->shape())) { |
271 | RETURN_ERR_IF_NOT(dim > 0, |
272 | strFormat("TensorFlowLite: Tensor '%s' has invalid shape " |
273 | "element '%d'!" , |
274 | getTensorName(tensor).c_str(), dim)); |
275 | shape.push_back(static_cast<dim_t>(dim)); |
276 | } |
277 | // If tensor shape is empty (scalar) we use a 1D shape with size 1. |
278 | if (shape.empty()) { |
279 | shape = {1}; |
280 | } |
281 | return shape; |
282 | } |
283 | |
284 | Expected<bool> |
285 | TFLiteModelLoader::isTensorShapeUndefined(const tflite::Tensor *tensor) { |
286 | // If tensor shape is NULL. |
287 | if (!tensor->shape()) { |
288 | return true; |
289 | } |
290 | // If tensor shape is empty (scalar). |
291 | if (tensor->shape()->size() == 0) { |
292 | return true; |
293 | } |
294 | return false; |
295 | } |
296 | |
297 | Expected<ElemKind> |
298 | TFLiteModelLoader::getTensorElemKind(const tflite::Tensor *tensor) { |
299 | bool isQuantized = isTensorQuantized(tensor); |
300 | switch (tensor->type()) { |
301 | case tflite::TensorType_FLOAT32: { |
302 | RETURN_ERR_IF_NOT( |
303 | !isQuantized, |
304 | "TensorFlowLite: FLOAT32 type should have no quantization parameters!" ); |
305 | return ElemKind::FloatTy; |
306 | } |
307 | case tflite::TensorType_FLOAT16: { |
308 | RETURN_ERR_IF_NOT( |
309 | !isQuantized, |
310 | "TensorFlowLite: FLOAT16 type should have no quantization parameters!" ); |
311 | return ElemKind::Float16Ty; |
312 | } |
313 | case tflite::TensorType_INT8: { |
314 | if (isQuantized) { |
315 | return ElemKind::Int8QTy; |
316 | } else { |
317 | return MAKE_ERR("TensorFlowLite: Non-quantized INT8 type not supported!" ); |
318 | } |
319 | } |
320 | case tflite::TensorType_UINT8: { |
321 | if (isQuantized) { |
322 | // Convert UINT8 element type to INT8 element type. |
323 | if (tfliteUint8ToInt8Opt) { |
324 | return ElemKind::Int8QTy; |
325 | } else { |
326 | return ElemKind::UInt8QTy; |
327 | } |
328 | } else { |
329 | return MAKE_ERR( |
330 | "TensorFlowLite: Non-quantized UINT8 type not supported!" ); |
331 | } |
332 | } |
333 | case tflite::TensorType_INT16: { |
334 | if (isQuantized) { |
335 | return ElemKind::Int16QTy; |
336 | } else { |
337 | return MAKE_ERR( |
338 | "TensorFlowLite: Non-quantized INT16 type not supported!" ); |
339 | } |
340 | } |
341 | case tflite::TensorType_INT32: { |
342 | if (isQuantized) { |
343 | return ElemKind::Int32QTy; |
344 | } else { |
345 | return ElemKind::Int32ITy; |
346 | } |
347 | } |
348 | case tflite::TensorType_INT64: { |
349 | if (isQuantized) { |
350 | return MAKE_ERR("TensorFlowLite: Quantized INT64 type not supported!" ); |
351 | } else { |
352 | return ElemKind::Int64ITy; |
353 | } |
354 | } |
355 | case tflite::TensorType_BOOL: { |
356 | RETURN_ERR_IF_NOT( |
357 | !isQuantized, |
358 | "TensorFlowLite: BOOL type should have no quantization parameters!" ); |
359 | return ElemKind::BoolTy; |
360 | } |
361 | default: |
362 | return MAKE_ERR( |
363 | strFormat("TensorFlowLite: Tensor '%s' type '%s' not supported!" , |
364 | getTensorName(tensor).c_str(), |
365 | tflite::EnumNameTensorType(tensor->type()))); |
366 | } |
367 | } |
368 | |
369 | bool TFLiteModelLoader::isTensorQuantized(const tflite::Tensor *tensor) { |
370 | auto *tensorQParams = tensor->quantization(); |
371 | if (!tensorQParams) { |
372 | return false; |
373 | } |
374 | auto *scales = tensorQParams->scale(); |
375 | auto *offsets = tensorQParams->zero_point(); |
376 | if (!(scales && offsets)) { |
377 | return false; |
378 | } |
379 | if (!(scales->size() && offsets->size())) { |
380 | return false; |
381 | } |
382 | return true; |
383 | } |
384 | |
385 | bool TFLiteModelLoader::isTensorPerAxisQuantized(const tflite::Tensor *tensor) { |
386 | if (!isTensorQuantized(tensor)) { |
387 | return false; |
388 | } |
389 | auto *tensorQParams = tensor->quantization(); |
390 | auto *scales = tensorQParams->scale(); |
391 | auto *offsets = tensorQParams->zero_point(); |
392 | return (scales->size() > 1) && (offsets->size() > 1); |
393 | } |
394 | |
395 | Expected<float> |
396 | TFLiteModelLoader::getTensorScale(const tflite::Tensor *tensor) { |
397 | auto *tensorQParams = tensor->quantization(); |
398 | RETURN_ERR_IF_NOT( |
399 | isTensorQuantized(tensor), |
400 | strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!" , |
401 | getTensorName(tensor).c_str())); |
402 | RETURN_ERR_IF_NOT( |
403 | tensorQParams->details_type() == tflite::QuantizationDetails_NONE, |
404 | strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is " |
405 | "not supported!" , |
406 | getTensorName(tensor).c_str())); |
407 | auto *scales = tensorQParams->scale(); |
408 | RETURN_ERR_IF_NOT(scales->size() == 1, |
409 | strFormat("TensorFlowLite: Tensor '%s' has %d quantization " |
410 | "parameters but only one was expected!" , |
411 | getTensorName(tensor).c_str(), scales->size())); |
412 | float scale = (*scales)[0]; |
413 | return scale; |
414 | } |
415 | |
Expected<int32_t>
TFLiteModelLoader::getTensorOffset(const tflite::Tensor *tensor) {
  auto *tensorQParams = tensor->quantization();
  // The tensor must be quantized per tensor (exactly one zero point).
  RETURN_ERR_IF_NOT(
      isTensorQuantized(tensor),
      strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!",
                getTensorName(tensor).c_str()));
  // Custom quantization schemes are not supported.
  RETURN_ERR_IF_NOT(
      tensorQParams->details_type() == tflite::QuantizationDetails_NONE,
      strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is "
                "not supported!",
                getTensorName(tensor).c_str()));
  auto *offsets = tensorQParams->zero_point();
  RETURN_ERR_IF_NOT(offsets->size() == 1,
                    strFormat("TensorFlowLite: Tensor '%s' has %d quantization "
                              "parameters but only one was expected!",
                              getTensorName(tensor).c_str(), offsets->size()));
  // TensorFlowLite defines the offset as int64 since it also supports int64
  // quantized type. Since Glow defines the offset as int32 we perform a cast
  // here and also validate that the offset is within the int32 range.
  int64_t offsetInt64 = (*offsets)[0];
  RETURN_ERR_IF_NOT(
      (std::numeric_limits<int32_t>::min() <= offsetInt64) &&
          (offsetInt64 <= std::numeric_limits<int32_t>::max()),
      strFormat(
          "TensorFlowLite: Tensor '%s' has an offset out of the int32 range!",
          getTensorName(tensor).c_str()));
  int32_t offset = static_cast<int32_t>(offsetInt64);
  // Convert UINT8 offset to INT8 offset using the same shift applied to the
  // UINT8 -> INT8 data conversion.
  if (tfliteUint8ToInt8Opt && (tensor->type() == tflite::TensorType_UINT8)) {
    offset -= UINT8_TO_INT8_SHIFT;
  }
  return offset;
}
450 | |
451 | Expected<std::vector<float>> |
452 | TFLiteModelLoader::getTensorScales(const tflite::Tensor *tensor) { |
453 | auto *tensorQParams = tensor->quantization(); |
454 | RETURN_ERR_IF_NOT( |
455 | isTensorQuantized(tensor), |
456 | strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!" , |
457 | getTensorName(tensor).c_str())); |
458 | RETURN_ERR_IF_NOT( |
459 | tensorQParams->details_type() == tflite::QuantizationDetails_NONE, |
460 | strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is " |
461 | "not supported!" , |
462 | getTensorName(tensor).c_str())); |
463 | auto *scales = tensorQParams->scale(); |
464 | RETURN_ERR_IF_NOT(scales->size() > 1, |
465 | strFormat("TensorFlowLite: Tensor '%s' has %d quantization " |
466 | "parameters but at least one was expected!" , |
467 | getTensorName(tensor).c_str(), scales->size())); |
468 | std::vector<float> scalesVec = |
469 | std::vector<float>(scales->begin(), scales->end()); |
470 | return scalesVec; |
471 | } |
472 | |
Expected<std::vector<int32_t>>
TFLiteModelLoader::getTensorOffsets(const tflite::Tensor *tensor) {
  auto *tensorQParams = tensor->quantization();
  // The tensor must be quantized per axis (multiple zero points).
  RETURN_ERR_IF_NOT(
      isTensorQuantized(tensor),
      strFormat("TensorFlowLite: Tensor '%s' has no quantization parameters!",
                getTensorName(tensor).c_str()));
  // Custom quantization schemes are not supported.
  RETURN_ERR_IF_NOT(
      tensorQParams->details_type() == tflite::QuantizationDetails_NONE,
      strFormat("TensorFlowLite: Tensor '%s' has custom quantization which is "
                "not supported!",
                getTensorName(tensor).c_str()));
  auto *offsets = tensorQParams->zero_point();
  RETURN_ERR_IF_NOT(offsets->size() > 1,
                    strFormat("TensorFlowLite: Tensor '%s' has %d quantization "
                              "parameters but at least one was expected!",
                              getTensorName(tensor).c_str(), offsets->size()));
  // TensorFlowLite defines the offset as int64 since it also supports int64
  // quantized type. Since Glow defines the offset as int32 we perform a cast
  // here and also validate that the offset is within the int32 range.
  std::vector<int32_t> offsetsVec;
  for (auto offsetInt64 : *offsets) {
    RETURN_ERR_IF_NOT(
        (std::numeric_limits<int32_t>::min() <= offsetInt64) &&
            (offsetInt64 <= std::numeric_limits<int32_t>::max()),
        strFormat(
            "TensorFlowLite: Tensor '%s' has an offset out of the int32 range!",
            getTensorName(tensor).c_str()));
    int32_t offset = static_cast<int32_t>(offsetInt64);
    // Convert UINT8 offset to INT8 offset using the same shift applied to
    // the UINT8 -> INT8 data conversion.
    if (tfliteUint8ToInt8Opt && (tensor->type() == tflite::TensorType_UINT8)) {
      offset -= UINT8_TO_INT8_SHIFT;
    }
    offsetsVec.push_back(offset);
  }
  return offsetsVec;
}
510 | |
511 | Expected<Type> TFLiteModelLoader::getTensorType(const tflite::Tensor *tensor) { |
512 | ElemKind elemKind; |
513 | ASSIGN_VALUE_OR_RETURN_ERR(elemKind, getTensorElemKind(tensor)); |
514 | std::vector<dim_t> shape; |
515 | ASSIGN_VALUE_OR_RETURN_ERR(shape, getTensorShape(tensor)); |
516 | if (isQuantizedElemKind(elemKind)) { |
517 | // If tensor is quantized per-axis we use a dummy scale 1.0 and offset 0. |
518 | float scale = 1.0; |
519 | int32_t offset = 0; |
520 | if (!isTensorPerAxisQuantized(tensor)) { |
521 | ASSIGN_VALUE_OR_RETURN_ERR(scale, getTensorScale(tensor)); |
522 | ASSIGN_VALUE_OR_RETURN_ERR(offset, getTensorOffset(tensor)); |
523 | } |
524 | return Type(elemKind, shape, scale, offset); |
525 | } else { |
526 | return Type(elemKind, shape); |
527 | } |
528 | } |
529 | |
530 | Expected<std::pair<const char *, size_t>> |
531 | TFLiteModelLoader::getTensorDataAndSize(const tflite::Tensor *tensor) { |
532 | uint32_t tensorBufferIdx = tensor->buffer(); |
533 | auto *modelBuffers = model_->buffers(); |
534 | RETURN_ERR_IF_NOT(tensorBufferIdx < modelBuffers->size(), |
535 | strFormat("TensorFlowLite: Tensor '%s' has a buffer index " |
536 | "out of range!" , |
537 | getTensorName(tensor).c_str())); |
538 | const char *tensorData = nullptr; |
539 | size_t tensorSize = 0; |
540 | if (auto *buffer = (*modelBuffers)[tensorBufferIdx]) { |
541 | if (auto *array = buffer->data()) { |
542 | if (array->size()) { |
543 | tensorData = |
544 | const_cast<char *>(reinterpret_cast<const char *>(array->data())); |
545 | tensorSize = array->size(); |
546 | } |
547 | } |
548 | } |
549 | return std::pair<const char *, size_t>(tensorData, tensorSize); |
550 | } |
551 | |
552 | ///===---------------------------------------------------------------------===// |
553 | /// Operator Utilities |
554 | ///===---------------------------------------------------------------------===// |
555 | Expected<tflite::BuiltinOperator> |
556 | TFLiteModelLoader::getOperatorCode(const tflite::Operator *op) { |
557 | const auto *modelOpCodes = model_->operator_codes(); |
558 | auto opCodeIdx = op->opcode_index(); |
559 | RETURN_ERR_IF_NOT(opCodeIdx < modelOpCodes->size(), |
560 | strFormat("TensorFlowLite: Missing registration for " |
561 | "opcode_index %d!" , |
562 | opCodeIdx)); |
563 | auto *opCode = (*modelOpCodes)[opCodeIdx]; |
564 | auto builtinCode = opCode->builtin_code(); |
565 | RETURN_ERR_IF_NOT( |
566 | (tflite::BuiltinOperator_MIN <= builtinCode) && |
567 | (builtinCode <= tflite::BuiltinOperator_MAX), |
568 | strFormat( |
569 | "TensorFlowLite: Operator builtin_code %d out of the supported " |
570 | "range! You might be using a newer model than currently supported!" , |
571 | builtinCode)); |
572 | return builtinCode; |
573 | } |
574 | |
575 | Expected<std::string> |
576 | TFLiteModelLoader::getOperatorCustomCode(const tflite::Operator *op) { |
577 | const auto *modelOpCodes = model_->operator_codes(); |
578 | auto opCodeIdx = op->opcode_index(); |
579 | RETURN_ERR_IF_NOT(opCodeIdx < modelOpCodes->size(), |
580 | strFormat("TensorFlowLite: Missing registration for " |
581 | "opcode_index %d!" , |
582 | opCodeIdx)); |
583 | auto *opCode = (*modelOpCodes)[opCodeIdx]; |
584 | auto customCode = opCode->custom_code(); |
585 | RETURN_ERR_IF_NOT(customCode, |
586 | strFormat("TensorFlowLite: Missing custom code for " |
587 | "opcode_index %d!" , |
588 | opCodeIdx)); |
589 | return customCode->str(); |
590 | } |
591 | |
592 | Expected<int32_t> |
593 | TFLiteModelLoader::getOperatorVersion(const tflite::Operator *op) { |
594 | const auto *modelOpCodes = model_->operator_codes(); |
595 | auto opCodeIdx = op->opcode_index(); |
596 | RETURN_ERR_IF_NOT(opCodeIdx < modelOpCodes->size(), |
597 | strFormat("TensorFlowLite: Missing registration for " |
598 | "opcode_index %d!" , |
599 | opCodeIdx)); |
600 | auto *opCode = (*modelOpCodes)[opCodeIdx]; |
601 | return opCode->version(); |
602 | } |
603 | |
604 | Expected<std::string> |
605 | TFLiteModelLoader::getOperatorType(const tflite::Operator *op) { |
606 | tflite::BuiltinOperator opCode; |
607 | ASSIGN_VALUE_OR_RETURN_ERR(opCode, getOperatorCode(op)); |
608 | return std::string(tflite::EnumNameBuiltinOperator(opCode)); |
609 | } |
610 | |
611 | Expected<std::string> |
612 | TFLiteModelLoader::getOperatorName(const tflite::Operator *op) { |
613 | std::string opType; |
614 | ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op)); |
615 | const auto *opOutputs = op->outputs(); |
616 | // If operator has no outputs then we return the operator type name. |
617 | if (opOutputs->size() == 0) { |
618 | return opType; |
619 | } |
620 | // If the first output tensor corresponds to an output placeholder then we use |
621 | // the operator type name in order to preserve the output placeholder name. |
622 | size_t opOutIdx = static_cast<size_t>((*opOutputs)[0]); |
623 | const auto *graphOutputs = graph_->outputs(); |
624 | for (auto graphOutIdx : (*graphOutputs)) { |
625 | if (int(opOutIdx) == graphOutIdx) { |
626 | return opType; |
627 | } |
628 | } |
629 | // Return the name of the first output tensor. |
630 | const tflite::Tensor *tensor; |
631 | ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(opOutIdx)); |
632 | return getTensorName(tensor); |
633 | } |
634 | |
635 | Expected<flexbuffers::Map> |
636 | TFLiteModelLoader::getOperatorCustomOpts(const tflite::Operator *op) { |
637 | size_t optsSize = op->custom_options()->size(); |
638 | auto *customOpts = op->custom_options(); |
639 | RETURN_ERR_IF_NOT(customOpts, |
640 | strFormat("TensorFlowLite: Missing custom options for " |
641 | "opcode_index %d!" , |
642 | op->opcode_index())); |
643 | const uint8_t *optsAddr = |
644 | reinterpret_cast<const uint8_t *>(customOpts->data()); |
645 | RETURN_ERR_IF_NOT(optsAddr, |
646 | strFormat("TensorFlowLite: Missing custom options for " |
647 | "opcode_index %d!" , |
648 | op->opcode_index())); |
649 | return flexbuffers::GetRoot(optsAddr, optsSize).AsMap(); |
650 | } |
651 | |
652 | Expected<int32_t> |
653 | TFLiteModelLoader::getOperatorInputTensorIdx(const tflite::Operator *op, |
654 | size_t inputIdx) { |
655 | std::string opType; |
656 | ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op)); |
657 | const auto *opInputs = op->inputs(); |
658 | RETURN_ERR_IF_NOT(opInputs, |
659 | strFormat("TensorFlowLite: Operator '%s' has no inputs!" , |
660 | opType.c_str())); |
661 | RETURN_ERR_IF_NOT(inputIdx < opInputs->size(), |
662 | strFormat("TensorFlowLite: Operator '%s' input index %zu " |
663 | "is out of range! Operator has %d inputs!" , |
664 | opType.c_str(), inputIdx, opInputs->size())); |
665 | return (*opInputs)[inputIdx]; |
666 | } |
667 | |
668 | Expected<size_t> |
669 | TFLiteModelLoader::getOperatorOutputTensorIdx(const tflite::Operator *op, |
670 | size_t outputIdx) { |
671 | std::string opType; |
672 | ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op)); |
673 | const auto *opOutputs = op->outputs(); |
674 | RETURN_ERR_IF_NOT(opOutputs, |
675 | strFormat("TensorFlowLite: Operator '%s' has no outputs!" , |
676 | opType.c_str())); |
677 | RETURN_ERR_IF_NOT(outputIdx < opOutputs->size(), |
678 | strFormat("TensorFlowLite: Operator '%s' output index %zu " |
679 | "is out of range! Operator has %d outputs!" , |
680 | opType.c_str(), outputIdx, opOutputs->size())); |
681 | return static_cast<size_t>((*opOutputs)[outputIdx]); |
682 | } |
683 | |
684 | Expected<bool> |
685 | TFLiteModelLoader::isOperatorOutputFinalTensor(const tflite::Operator *op, |
686 | size_t outputIdx) { |
687 | size_t tensorIdx; |
688 | ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx, |
689 | getOperatorOutputTensorIdx(op, outputIdx)); |
690 | const auto *graphOutputs = graph_->outputs(); |
691 | for (auto graphOutIdx : (*graphOutputs)) { |
692 | if (int(tensorIdx) == graphOutIdx) { |
693 | return true; |
694 | } |
695 | } |
696 | return false; |
697 | } |
698 | |
699 | Expected<NodeValue> TFLiteModelLoader::getNodeValueByIndex(size_t index) { |
700 | RETURN_ERR_IF_NOT(!nodeValueByIndex_.empty(), |
701 | "TensorFlowLite: Node value array not initialized!" ); |
702 | RETURN_ERR_IF_NOT( |
703 | index < nodeValueByIndex_.size(), |
704 | strFormat("TensorFlowLite: Node value index %zu is out of range!" , |
705 | index)); |
706 | NodeValue nodeValue = nodeValueByIndex_[index]; |
707 | RETURN_ERR_IF_NOT(nodeValue.getNode(), |
708 | strFormat("TensorFlowLite: Node value with index %zu is " |
709 | "null (not initialized)!" , |
710 | index)); |
711 | return nodeValue; |
712 | } |
713 | |
714 | Error TFLiteModelLoader::setNodeValueByIndex(size_t index, |
715 | NodeValue nodeValue) { |
716 | RETURN_ERR_IF_NOT(!nodeValueByIndex_.empty(), |
717 | "TensorFlowLite: Node value array not initialized!" ); |
718 | RETURN_ERR_IF_NOT( |
719 | index < nodeValueByIndex_.size(), |
720 | strFormat("TensorFlowLite: Node value index %zu is out of range!" , |
721 | index)); |
722 | nodeValueByIndex_[index] = nodeValue; |
723 | return Error::success(); |
724 | } |
725 | |
726 | Expected<NodeValue> |
727 | TFLiteModelLoader::getInputNodeValue(const tflite::Operator *op, |
728 | size_t inputIdx) { |
729 | int32_t tensorIdx; |
730 | ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx, |
731 | getOperatorInputTensorIdx(op, inputIdx)); |
732 | // If the tensor index is negative this means the operand does not exist. |
733 | if (tensorIdx < 0) { |
734 | return NodeValue(nullptr); |
735 | } |
736 | return getNodeValueByIndex(tensorIdx); |
737 | } |
738 | |
739 | Error TFLiteModelLoader::setOutputNodeValue(const tflite::Operator *op, |
740 | NodeValue nodeValue, |
741 | bool checkType) { |
742 | std::vector<NodeValue> nodeValues = {nodeValue}; |
743 | return setOutputNodeValues(op, nodeValues, checkType); |
744 | } |
745 | |
746 | Error TFLiteModelLoader::setOutputNodeValues( |
747 | const tflite::Operator *op, llvm::ArrayRef<NodeValue> nodeValues, |
748 | bool checkType) { |
749 | std::string opType; |
750 | ASSIGN_VALUE_OR_RETURN_ERR(opType, getOperatorType(op)); |
751 | const auto *opOutputs = op->outputs(); |
752 | RETURN_ERR_IF_NOT( |
753 | opOutputs->size() == nodeValues.size(), |
754 | strFormat("TensorFlowLite: Operator '%s' has %d outputs but %zu are set!" , |
755 | opType.c_str(), opOutputs->size(), nodeValues.size())); |
756 | for (size_t idx = 0, idxEnd = nodeValues.size(); idx < idxEnd; ++idx) { |
757 | NodeValue outNodeValue = nodeValues[idx]; |
758 | // Verify the output type of the node value matches the type registered in |
759 | // the model with the exception of the final tensors which are allowed to |
760 | // be modified (for example for the Softmax output when it is a final node). |
761 | // Tensors with undefined shapes are also not checked. |
762 | if (checkType) { |
763 | TypeRef outTy; |
764 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, idx)); |
765 | bool isUndefined; |
766 | ASSIGN_VALUE_OR_RETURN_ERR(isUndefined, isOutputShapeUndefined(op, idx)); |
767 | bool isFinal; |
768 | ASSIGN_VALUE_OR_RETURN_ERR(isFinal, isOperatorOutputFinalTensor(op, idx)); |
769 | RETURN_ERR_IF_NOT(isUndefined || isFinal || |
770 | outTy->isEqual(outNodeValue.getType()), |
771 | strFormat("TensorFlowLite: Operator '%s' modifies the " |
772 | "output type registered in the model!" , |
773 | opType.c_str())); |
774 | } |
775 | // Register the output node value. |
776 | size_t tensorIdx = static_cast<size_t>((*opOutputs)[idx]); |
777 | RETURN_IF_ERR(setNodeValueByIndex(tensorIdx, outNodeValue)); |
778 | } |
779 | return Error::success(); |
780 | } |
781 | |
782 | Expected<TypeRef> TFLiteModelLoader::getOutputType(const tflite::Operator *op, |
783 | size_t outputIndex) { |
784 | size_t tensorIdx; |
785 | ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx, |
786 | getOperatorOutputTensorIdx(op, outputIndex)); |
787 | const tflite::Tensor *tensor; |
788 | ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(tensorIdx)); |
789 | Type type; |
790 | ASSIGN_VALUE_OR_RETURN_ERR(type, getTensorType(tensor)); |
791 | return mod_.uniqueType(type); |
792 | } |
793 | |
794 | Expected<bool> |
795 | TFLiteModelLoader::isOutputShapeUndefined(const tflite::Operator *op, |
796 | size_t outputIndex) { |
797 | size_t tensorIdx; |
798 | ASSIGN_VALUE_OR_RETURN_ERR(tensorIdx, |
799 | getOperatorOutputTensorIdx(op, outputIndex)); |
800 | const tflite::Tensor *tensor; |
801 | ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(tensorIdx)); |
802 | bool undefined; |
803 | ASSIGN_VALUE_OR_RETURN_ERR(undefined, isTensorShapeUndefined(tensor)); |
804 | return undefined; |
805 | } |
806 | |
807 | void TFLiteModelLoader::initializeNodeValues() { |
808 | auto numTensors = graph_->tensors()->size(); |
809 | nodeValueByIndex_ = std::vector<NodeValue>(numTensors, nullptr); |
810 | } |
811 | |
812 | Error TFLiteModelLoader::loadInputPlaceholders() { |
813 | for (auto inpIdx : *(graph_->inputs())) { |
814 | // Get input placeholder name and type. |
815 | const tflite::Tensor *tensor; |
816 | ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(inpIdx)); |
817 | std::string name = getTensorName(tensor); |
818 | Type type; |
819 | ASSIGN_VALUE_OR_RETURN_ERR(type, getTensorType(tensor)); |
820 | // Create input placeholder. If the input type is quantized and a float |
821 | // input is requested then we create a float placeholder and a Quantize |
822 | // node, otherwise we create directly the quantized placeholder. |
823 | Placeholder *inpPH; |
824 | NodeValue inpNV; |
825 | if (tfliteFloatInputsOpt && type.isQuantizedType()) { |
826 | TypeRef floatType = mod_.uniqueType(ElemKind::FloatTy, type.dims()); |
827 | inpPH = mod_.createPlaceholder(floatType, name, /*isTrainable*/ false, |
828 | ANY_LAYOUT); |
829 | inpNV = F_->createQuantize(name + ".Quantize" , inpPH, &type); |
830 | } else { |
831 | inpPH = mod_.createPlaceholder(&type, name, /*isTrainable*/ false, |
832 | ANY_LAYOUT); |
833 | inpNV = inpPH; |
834 | } |
835 | // Register placeholder by model input name. |
836 | inputPlaceholderByName_.try_emplace(name, inpPH); |
837 | // Set input node value. |
838 | RETURN_IF_ERR(setNodeValueByIndex(inpIdx, inpNV)); |
839 | } |
840 | return Error::success(); |
841 | } |
842 | |
843 | Error TFLiteModelLoader::loadConstants() { |
844 | const auto *tensors = graph_->tensors(); |
845 | for (size_t idx = 0, idxEnd = tensors->size(); idx < idxEnd; ++idx) { |
846 | // Get tensor data and size. A TensorFlowLite model tensor is a constant |
847 | // if it has data stored in the model. |
848 | const tflite::Tensor *tensor = (*tensors)[idx]; |
849 | std::pair<const char *, size_t> dataAndSize; |
850 | ASSIGN_VALUE_OR_RETURN_ERR(dataAndSize, getTensorDataAndSize(tensor)); |
851 | if (dataAndSize.first == nullptr) { |
852 | continue; |
853 | } |
854 | // Create tensor and initialize data. |
855 | std::string name = getTensorName(tensor); |
856 | Type type; |
857 | ASSIGN_VALUE_OR_RETURN_ERR(type, getTensorType(tensor)); |
858 | RETURN_ERR_IF_NOT( |
859 | type.getSizeInBytes() == dataAndSize.second, |
860 | strFormat("TensorFlowLite: Tensor '%s' mismatch between shape based " |
861 | "size (%zu bytes) and actual data size (%lu bytes)!" , |
862 | name.c_str(), type.getSizeInBytes(), dataAndSize.second)); |
863 | Tensor T = Tensor(type); |
864 | T.copyRawFrom(dataAndSize.first); |
865 | // Convert UINT8 data to INT8 data. |
866 | if (tfliteUint8ToInt8Opt && (tensor->type() == tflite::TensorType_UINT8)) { |
867 | convertUint8ToInt8(reinterpret_cast<uint8_t *>(T.getUnsafePtr()), |
868 | reinterpret_cast<int8_t *>(T.getUnsafePtr()), |
869 | dataAndSize.second); |
870 | } |
871 | // Create constant. |
872 | Constant *node = mod_.createConstant(name, std::move(T), ANY_LAYOUT); |
873 | // Register node value. |
874 | RETURN_IF_ERR(setNodeValueByIndex(idx, node->getOutput())); |
875 | } |
876 | return Error::success(); |
877 | } |
878 | |
879 | Error TFLiteModelLoader::loadOperators() { |
880 | OperatorInfo opInfo; |
881 | auto *graphOperators = graph_->operators(); |
882 | for (size_t opIdx = 0, opIdxEnd = graphOperators->size(); opIdx < opIdxEnd; |
883 | ++opIdx) { |
884 | // Get operator meta data. |
885 | const tflite::Operator *op = (*graphOperators)[opIdx]; |
886 | ASSIGN_VALUE_OR_RETURN_ERR(opInfo.name, getOperatorName(op)); |
887 | ASSIGN_VALUE_OR_RETURN_ERR(opInfo.type, getOperatorType(op)); |
888 | ASSIGN_VALUE_OR_RETURN_ERR(opInfo.code, getOperatorCode(op)); |
889 | ASSIGN_VALUE_OR_RETURN_ERR(opInfo.version, getOperatorVersion(op)); |
890 | opInfo.index = opIdx; |
891 | // Load operator. |
892 | mod_.registerOriginalName(opInfo.name); |
893 | RETURN_IF_ERR(loadOperator(op, opInfo)); |
894 | } |
895 | return Error::success(); |
896 | } |
897 | |
Error TFLiteModelLoader::saveOutputPlaceholders() {
  // Creates a Save node (and output placeholder) for every graph output and
  // registers the placeholder under the model's output tensor name.
  for (auto outIdx : *(graph_->outputs())) {
    // Get placeholder name.
    const tflite::Tensor *tensor;
    ASSIGN_VALUE_OR_RETURN_ERR(tensor, getTensorByIndex(outIdx));
    std::string name = getTensorName(tensor);
    // Save output placeholder. If the output type is quantized and a float
    // output is requested then we create a Dequantize node and save
    // into a float placeholder, otherwise we save the quantized placeholder.
    NodeValue outNodeValue;
    ASSIGN_VALUE_OR_RETURN_ERR(outNodeValue, getNodeValueByIndex(outIdx));
    if (tfliteFloatOutputsOpt && outNodeValue.getType()->isQuantizedType()) {
      outNodeValue = F_->createDequantize(name + ".Dequantize" , outNodeValue,
                                          ElemKind::FloatTy);
    }
    auto *saveNode = F_->createSave(name, outNodeValue);
    // Register placeholder by model output name.
    outputPlaceholderByName_.try_emplace(name, saveNode->getPlaceholder());
  }
  return Error::success();
}
919 | |
920 | Error TFLiteModelLoader::addActivation(NodeValue &value, |
921 | tflite::ActivationFunctionType type) { |
922 | std::string nodeName = value.getNode()->getName().str(); |
923 | std::string actType = EnumNameActivationFunctionType(type); |
924 | std::string actName = nodeName + "." + actType; |
925 | if (type == tflite::ActivationFunctionType_NONE) { |
926 | return Error::success(); |
927 | } |
928 | if (type == tflite::ActivationFunctionType_RELU) { |
929 | value = F_->createRELU(actName, value); |
930 | return Error::success(); |
931 | } |
932 | if (type == tflite::ActivationFunctionType_RELU_N1_TO_1) { |
933 | value = F_->createClip(actName, value, -1.0, 1.0); |
934 | return Error::success(); |
935 | } |
936 | if (type == tflite::ActivationFunctionType_RELU6) { |
937 | value = F_->createClip(actName, value, 0.0, 6.0); |
938 | return Error::success(); |
939 | } |
940 | if (type == tflite::ActivationFunctionType_TANH) { |
941 | value = F_->createTanh(actName, value); |
942 | return Error::success(); |
943 | } |
944 | return MAKE_ERR( |
945 | strFormat("TensorFlowLite: Activation type '%s' is not supported!" , |
946 | actType.c_str())); |
947 | } |
948 | |
const std::string TFLiteModelLoader::opErrMsg(const OperatorInfo &opInfo,
                                              const std::string &errMsg) {
  // Prefix \p errMsg with the operator type, index and opcode so that load
  // errors can be traced back to the offending model operator.
  return strFormat("TensorFlowLite: Operator '%s' (Index %zu, Code %u): %s" ,
                   opInfo.type.c_str(), opInfo.index, opInfo.code,
                   errMsg.c_str());
}
955 | |
956 | template <typename T> |
957 | Expected<T> TFLiteModelLoader::loadAxis(const OperatorInfo &opInfo, |
958 | NodeValue axis, NodeValue value) { |
959 | auto *axisC = llvm::dyn_cast<Constant>(axis.getNode()); |
960 | RETURN_ERR_IF_NOT(axisC, |
961 | opErrMsg(opInfo, "Non constant axis not supported!" )); |
962 | RETURN_ERR_IF_NOT(axisC->getType()->size() == 1, |
963 | opErrMsg(opInfo, "Axis should have 1 element!" )); |
964 | T axisVal; |
965 | auto elemType = axisC->getType()->getElementType(); |
966 | if (elemType == ElemKind::Int32ITy) { |
967 | auto axisH = axisC->getPayload().getHandle<int32_t>(); |
968 | ASSIGN_VALUE_OR_RETURN_ERR( |
969 | axisVal, getPositiveAxis<T>(static_cast<int>(axisH.raw(0)), value)); |
970 | } else if (elemType == ElemKind::Int64ITy) { |
971 | auto axisH = axisC->getPayload().getHandle<int64_t>(); |
972 | ASSIGN_VALUE_OR_RETURN_ERR( |
973 | axisVal, getPositiveAxis<T>(static_cast<int>(axisH.raw(0)), value)); |
974 | } else { |
975 | return MAKE_ERR(opErrMsg(opInfo, "Axis should have INT32 or INT64 type!" )); |
976 | } |
977 | return axisVal; |
978 | } |
979 | |
980 | template <typename T> |
981 | Expected<std::vector<T>> TFLiteModelLoader::loadAxes(const OperatorInfo &opInfo, |
982 | NodeValue axes, |
983 | NodeValue value) { |
984 | Constant *axesC = llvm::dyn_cast<Constant>(axes.getNode()); |
985 | RETURN_ERR_IF_NOT(axesC, |
986 | opErrMsg(opInfo, "Non constant axis not supported!" )); |
987 | RETURN_ERR_IF_NOT(axesC->getType()->size() >= 1, |
988 | opErrMsg(opInfo, "Axis should have at least 1 element!" )); |
989 | std::vector<T> axesVal = std::vector<T>(axesC->getType()->size()); |
990 | auto elemType = axesC->getType()->getElementType(); |
991 | for (size_t idx = 0; idx < axesC->getType()->size(); ++idx) { |
992 | if (elemType == ElemKind::Int32ITy) { |
993 | auto axesH = axesC->getPayload().getHandle<int32_t>(); |
994 | ASSIGN_VALUE_OR_RETURN_ERR( |
995 | axesVal[idx], |
996 | getPositiveAxis<T>(static_cast<int>(axesH.raw(idx)), value)); |
997 | } else if (elemType == ElemKind::Int64ITy) { |
998 | auto axesH = axesC->getPayload().getHandle<int64_t>(); |
999 | ASSIGN_VALUE_OR_RETURN_ERR( |
1000 | axesVal[idx], |
1001 | getPositiveAxis<T>(static_cast<int>(axesH.raw(idx)), value)); |
1002 | } else { |
1003 | return MAKE_ERR( |
1004 | opErrMsg(opInfo, "Axis should have INT32 or INT64 type!" )); |
1005 | } |
1006 | } |
1007 | return axesVal; |
1008 | } |
1009 | |
1010 | template <typename T> |
1011 | Expected<std::vector<T>> |
1012 | TFLiteModelLoader::loadArray(const OperatorInfo &opInfo, NodeValue value) { |
1013 | Constant *valueC = llvm::dyn_cast<Constant>(value.getNode()); |
1014 | RETURN_ERR_IF_NOT(valueC, |
1015 | opErrMsg(opInfo, "Non constant array not supported!" )); |
1016 | auto valueSize = valueC->getType()->size(); |
1017 | RETURN_ERR_IF_NOT(valueSize >= 1, |
1018 | opErrMsg(opInfo, "Array should have at least 1 element!" )); |
1019 | std::vector<T> valueV = std::vector<T>(valueSize); |
1020 | auto elemType = valueC->getType()->getElementType(); |
1021 | for (size_t idx = 0; idx < valueSize; ++idx) { |
1022 | if (elemType == ElemKind::FloatTy) { |
1023 | auto valueH = valueC->getPayload().getHandle<float>(); |
1024 | valueV[idx] = static_cast<T>(valueH.raw(idx)); |
1025 | } else if (elemType == ElemKind::Int32ITy) { |
1026 | auto valueH = valueC->getPayload().getHandle<int32_t>(); |
1027 | valueV[idx] = static_cast<T>(valueH.raw(idx)); |
1028 | } else if (elemType == ElemKind::Int64ITy) { |
1029 | auto valueH = valueC->getPayload().getHandle<int64_t>(); |
1030 | valueV[idx] = static_cast<T>(valueH.raw(idx)); |
1031 | } else { |
1032 | return MAKE_ERR(opErrMsg(opInfo, "Array type not supported!" )); |
1033 | } |
1034 | } |
1035 | return valueV; |
1036 | } |
1037 | |
1038 | Expected<bool> TFLiteModelLoader::isConv2DPerAxisQuantized( |
1039 | const tflite::Operator *op, const OperatorInfo &opInfo, |
1040 | Constant *&filterScalesC, Constant *&filterOffsetsC, Constant *&biasScalesC, |
1041 | Constant *&biasOffsetsC) { |
1042 | // Get filter/bias tensors. |
1043 | NodeValue input; |
1044 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
1045 | TypeRef outTy; |
1046 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
1047 | int32_t filterTensorIdx; |
1048 | ASSIGN_VALUE_OR_RETURN_ERR(filterTensorIdx, getOperatorInputTensorIdx(op, 1)); |
1049 | int32_t biasTensorIdx; |
1050 | ASSIGN_VALUE_OR_RETURN_ERR(biasTensorIdx, getOperatorInputTensorIdx(op, 2)); |
1051 | const tflite::Tensor *filterTensor; |
1052 | ASSIGN_VALUE_OR_RETURN_ERR(filterTensor, getTensorByIndex(filterTensorIdx)); |
1053 | const tflite::Tensor *biasTensor; |
1054 | ASSIGN_VALUE_OR_RETURN_ERR(biasTensor, getTensorByIndex(biasTensorIdx)); |
1055 | |
1056 | bool isPerAxisQuantized = isTensorPerAxisQuantized(filterTensor) && |
1057 | isTensorPerAxisQuantized(biasTensor); |
1058 | |
1059 | // If it is not per-axis quantized return directly. |
1060 | if (!isPerAxisQuantized) { |
1061 | filterScalesC = nullptr; |
1062 | filterOffsetsC = nullptr; |
1063 | biasScalesC = nullptr; |
1064 | biasOffsetsC = nullptr; |
1065 | return false; |
1066 | } |
1067 | |
1068 | dim_t numChannels = outTy->dims().back(); |
1069 | |
1070 | // Get filter/bias quantization parameters. |
1071 | std::vector<float> filterScalesV; |
1072 | ASSIGN_VALUE_OR_RETURN_ERR(filterScalesV, getTensorScales(filterTensor)); |
1073 | std::vector<int32_t> filterOffsetsV; |
1074 | ASSIGN_VALUE_OR_RETURN_ERR(filterOffsetsV, getTensorOffsets(filterTensor)); |
1075 | std::vector<float> biasScalesV; |
1076 | ASSIGN_VALUE_OR_RETURN_ERR(biasScalesV, getTensorScales(biasTensor)); |
1077 | std::vector<int32_t> biasOffsetsV; |
1078 | ASSIGN_VALUE_OR_RETURN_ERR(biasOffsetsV, getTensorOffsets(biasTensor)); |
1079 | |
1080 | // Create filter/bias quantization parameters graph constants. |
1081 | filterScalesC = |
1082 | mod_.createConstant(ElemKind::FloatTy, {numChannels}, "filterScales" ); |
1083 | filterOffsetsC = |
1084 | mod_.createConstant(ElemKind::Int32ITy, {numChannels}, "filterOffsets" ); |
1085 | biasScalesC = |
1086 | mod_.createConstant(ElemKind::FloatTy, {numChannels}, "biasScales" ); |
1087 | biasOffsetsC = |
1088 | mod_.createConstant(ElemKind::Int32ITy, {numChannels}, "biasOffsets" ); |
1089 | |
1090 | RETURN_ERR_IF_NOT( |
1091 | filterScalesV.size() == numChannels, |
1092 | opErrMsg(opInfo, |
1093 | "Weights scales length should match the output channels!" )); |
1094 | RETURN_ERR_IF_NOT( |
1095 | filterOffsetsV.size() == numChannels, |
1096 | opErrMsg(opInfo, |
1097 | "Weights offsets length should match the output channels!" )); |
1098 | RETURN_ERR_IF_NOT( |
1099 | biasScalesV.size() == numChannels, |
1100 | opErrMsg(opInfo, "Bias scales length should match the output channels!" )); |
1101 | RETURN_ERR_IF_NOT( |
1102 | biasOffsetsV.size() == numChannels, |
1103 | opErrMsg(opInfo, |
1104 | "Bias offsets length should match the output channels!" )); |
1105 | |
1106 | filterScalesC->getPayloadMutable().copyRawFrom( |
1107 | reinterpret_cast<const char *>(filterScalesV.data())); |
1108 | filterOffsetsC->getPayloadMutable().copyRawFrom( |
1109 | reinterpret_cast<const char *>(filterOffsetsV.data())); |
1110 | biasScalesC->getPayloadMutable().copyRawFrom( |
1111 | reinterpret_cast<const char *>(biasScalesV.data())); |
1112 | biasOffsetsC->getPayloadMutable().copyRawFrom( |
1113 | reinterpret_cast<const char *>(biasOffsetsV.data())); |
1114 | |
1115 | // Validate filter/bias quantization parameters. |
1116 | float inputScale = input.getType()->getScale(); |
1117 | auto filterScalesH = filterScalesC->getPayload().getHandle<float>(); |
1118 | auto filterOffsetsH = filterOffsetsC->getPayload().getHandle<int32_t>(); |
1119 | auto biasScalesH = biasScalesC->getPayload().getHandle<float>(); |
1120 | auto biasOffsetsH = biasOffsetsC->getPayload().getHandle<int32_t>(); |
1121 | for (size_t idx = 0; idx < numChannels; ++idx) { |
1122 | // TensorFlowLite mandates that filterOffset and biasOffset are 0. |
1123 | RETURN_ERR_IF_NOT(filterOffsetsH.raw(idx) == 0, |
1124 | opErrMsg(opInfo, "Filter offset was expected to be 0!" )); |
1125 | RETURN_ERR_IF_NOT(biasOffsetsH.raw(idx) == 0, |
1126 | opErrMsg(opInfo, "Bias offset was expected to be 0!" )); |
1127 | |
1128 | float filterScale = filterScalesH.raw(idx); |
1129 | float matMulScale = inputScale * filterScale; |
1130 | float biasScale = biasScalesH.raw(idx); |
1131 | |
1132 | // Check bias scale relative error to inputScale * filterScale. |
1133 | if (biasScale != matMulScale) { |
1134 | float relErr = std::abs(matMulScale - biasScale) / matMulScale; |
1135 | llvm::errs() << opErrMsg( |
1136 | opInfo, |
1137 | strFormat("WARNING: Per channel BIAS scale value was expected " |
1138 | "to be exactly %E (inputScale * weightsScale) but found " |
1139 | "%E instead! Relative absolute error is %E!\n" , |
1140 | matMulScale, biasScale, relErr)); |
1141 | if (relErr < tfliteBiasScaleCheckMaxErrorOpt) { |
1142 | // Modify bias scale. |
1143 | biasScalesH.raw(idx) = matMulScale; |
1144 | } else if (tfliteBiasScaleCheckThrowErrorOpt) { |
1145 | return MAKE_ERR(opErrMsg( |
1146 | opInfo, |
1147 | strFormat("ERROR: Per channel BIAS scale value was expected " |
1148 | "to be exactly %E (inputScale * weightsScale) but found " |
1149 | "%E instead! Relative absolute error is %E!\n" , |
1150 | matMulScale, biasScale, relErr))); |
1151 | } |
1152 | } |
1153 | } |
1154 | |
1155 | return true; |
1156 | } |
1157 | |
/// Dispatches the TensorFlowLite operator \p op (described by \p opInfo) to
/// the dedicated loader routine based on its builtin opcode. Custom (non
/// builtin) operators are further dispatched on their custom code string.
/// Returns an error for any operator that is not supported.
Error TFLiteModelLoader::loadOperator(const tflite::Operator *op,
                                      const OperatorInfo &opInfo) {
  // Opcodes are treated in increasing order to allow easy tracking
  // for which operators are supported and which are not.
  auto opCode = opInfo.code;
  if (opCode == tflite::BuiltinOperator_ADD) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
    return loadPool2D(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_CONCATENATION) {
    return loadConcat(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_CONV_2D) {
    return loadConv2D(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_DEPTHWISE_CONV_2D) {
    return loadDepthwiseConv2D(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_DEQUANTIZE) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_HARD_SWISH) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_FLOOR) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_FULLY_CONNECTED) {
    return loadFullyConnected(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LOGISTIC) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_MAX_POOL_2D) {
    return loadPool2D(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_MUL) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RELU) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RELU_N1_TO_1) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RELU6) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RESHAPE) {
    return loadReshape(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SOFTMAX) {
    return loadSoftmax(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LOG_SOFTMAX) {
    return loadLogSoftmax(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_TANH) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_PAD) {
    return loadPad(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_TRANSPOSE) {
    return loadTranspose(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_MEAN) {
    return loadReduce(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SUB) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_DIV) {
    return loadBinaryArithmetic(op, opInfo);
  }
  // Note: SQUEEZE and EXPAND_DIMS are pure shape changes, hence they share
  // the Reshape loader with RESHAPE.
  if (opCode == tflite::BuiltinOperator_SQUEEZE) {
    return loadReshape(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_STRIDED_SLICE) {
    return loadStridedSlice(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_EXP) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SPLIT) {
    return loadSplit(op, opInfo);
  }
  // Note: PRELU and the comparison/logical operators below take two inputs,
  // hence they share the binary arithmetic loader.
  if (opCode == tflite::BuiltinOperator_PRELU) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_MAXIMUM) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_ARG_MAX) {
    return loadArg(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_MINIMUM) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LESS) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_NEG) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_GREATER) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_GREATER_EQUAL) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LESS_EQUAL) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SLICE) {
    return loadSlice(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RESIZE_BILINEAR) {
    return loadResizeBilinear(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) {
    return loadResizeNearest(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SPACE_TO_DEPTH) {
    return loadSpaceToDepth(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_DEPTH_TO_SPACE) {
    return loadDepthToSpace(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_CAST) {
    return loadCast(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_GATHER) {
    return loadGather(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_GATHER_ND) {
    return loadGatherND(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SELECT) {
    return loadSelect(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SPACE_TO_BATCH_ND) {
    return loadSpaceToBatchNd(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_BATCH_TO_SPACE_ND) {
    return loadBatchToSpaceNd(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SIN) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_TILE) {
    return loadTile(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_EXPAND_DIMS) {
    return loadReshape(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_EQUAL) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_NOT_EQUAL) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LOG) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SQRT) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_RSQRT) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SHAPE) {
    return loadShape(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_POW) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_ARG_MIN) {
    return loadArg(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_PACK) {
    return loadPack(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LOGICAL_OR) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LOGICAL_AND) {
    return loadBinaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LOGICAL_NOT) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_UNPACK) {
    return loadUnpack(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_SQUARE) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_LEAKY_RELU) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_ABS) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_CEIL) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_COS) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_QUANTIZE) {
    return loadUnaryArithmetic(op, opInfo);
  }
  if (opCode == tflite::BuiltinOperator_ROUND) {
    return loadUnaryArithmetic(op, opInfo);
  }
  // Load custom operators.
  if (opCode == tflite::BuiltinOperator_CUSTOM) {
    // Get custom operator code.
    std::string customOpCode;
    ASSIGN_VALUE_OR_RETURN_ERR(customOpCode, getOperatorCustomCode(op));
    // Get custom operator options.
    flexbuffers::Map opts = flexbuffers::Map::EmptyMap();
    ASSIGN_VALUE_OR_RETURN_ERR(opts, getOperatorCustomOpts(op));
    // Load custom operator.
    if (customOpCode == "TFLite_Detection_PostProcess" ) {
      return loadTFLiteDetectionPostProcess(op, opInfo, opts);
    }
    if (customOpCode == "AudioSpectrogram" ) {
      return loadTFLiteAudioSpectrogram(op, opInfo, opts);
    }
    if (customOpCode == "Mfcc" ) {
      return loadTFLiteMFCC(op, opInfo, opts);
    }
  }
  return MAKE_ERR(
      strFormat("TensorFlowLite: Operator type '%s' is not supported!" ,
                opInfo.type.c_str()));
}
1399 | |
1400 | Error TFLiteModelLoader::loadUnaryArithmetic(const tflite::Operator *op, |
1401 | const OperatorInfo &opInfo) { |
1402 | NodeValue input; |
1403 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
1404 | TypeRef outTy; |
1405 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
1406 | |
1407 | auto opCode = opInfo.code; |
1408 | NodeValue output; |
1409 | if (opCode == tflite::BuiltinOperator_LOGISTIC) { |
1410 | output = F_->createSigmoid(opInfo.name, outTy, input); |
1411 | } else if (opCode == tflite::BuiltinOperator_HARD_SWISH) { |
1412 | output = F_->createHardSwish(opInfo.name, outTy, input); |
1413 | } else if (opCode == tflite::BuiltinOperator_RELU) { |
1414 | output = F_->createRELU(opInfo.name, input, outTy); |
1415 | } else if (opCode == tflite::BuiltinOperator_RELU_N1_TO_1) { |
1416 | output = F_->createClip(opInfo.name, input, outTy, -1.0, 1.0); |
1417 | } else if (opCode == tflite::BuiltinOperator_RELU6) { |
1418 | output = F_->createClip(opInfo.name, input, outTy, 0.0, 6.0); |
1419 | } else if (opCode == tflite::BuiltinOperator_TANH) { |
1420 | output = F_->createTanh(opInfo.name, outTy, input); |
1421 | } else if (opCode == tflite::BuiltinOperator_EXP) { |
1422 | output = F_->createExp(opInfo.name, input); |
1423 | } else if (opCode == tflite::BuiltinOperator_LOG) { |
1424 | output = F_->createLog(opInfo.name, input, outTy); |
1425 | } else if (opCode == tflite::BuiltinOperator_LEAKY_RELU) { |
1426 | const auto *opts = op->builtin_options_as_LeakyReluOptions(); |
1427 | float alpha = opts->alpha(); |
1428 | output = F_->createLeakyRELU(opInfo.name, outTy, input, alpha); |
1429 | } else if (opCode == tflite::BuiltinOperator_SQUARE) { |
1430 | output = F_->createSquare(opInfo.name, outTy, input); |
1431 | } else if (opCode == tflite::BuiltinOperator_ABS) { |
1432 | output = F_->createAbs(opInfo.name, outTy, input); |
1433 | } else if (opCode == tflite::BuiltinOperator_NEG) { |
1434 | output = F_->createNeg(opInfo.name, outTy, input); |
1435 | } else if (opCode == tflite::BuiltinOperator_FLOOR) { |
1436 | output = F_->createFloor(opInfo.name, outTy, input); |
1437 | } else if (opCode == tflite::BuiltinOperator_CEIL) { |
1438 | output = F_->createCeil(opInfo.name, outTy, input); |
1439 | } else if (opCode == tflite::BuiltinOperator_ROUND) { |
1440 | output = F_->createRound(opInfo.name, outTy, input); |
1441 | } else if (opCode == tflite::BuiltinOperator_SQRT) { |
1442 | output = F_->createSqrt(opInfo.name, outTy, input); |
1443 | } else if (opCode == tflite::BuiltinOperator_RSQRT) { |
1444 | output = F_->createRsqrt(opInfo.name, outTy, input); |
1445 | } else if (opCode == tflite::BuiltinOperator_SIN) { |
1446 | output = F_->createSin(opInfo.name, outTy, input); |
1447 | } else if (opCode == tflite::BuiltinOperator_COS) { |
1448 | output = F_->createCos(opInfo.name, outTy, input); |
1449 | } else if (opCode == tflite::BuiltinOperator_LOGICAL_NOT) { |
1450 | output = F_->createNot(opInfo.name, input); |
1451 | } else if (opCode == tflite::BuiltinOperator_QUANTIZE) { |
1452 | if (input.getType()->isFPType()) { |
1453 | output = F_->createQuantize(opInfo.name, input, outTy); |
1454 | } else { |
1455 | output = F_->createRescaleQuantized(opInfo.name, input, outTy); |
1456 | } |
1457 | } else if (opCode == tflite::BuiltinOperator_DEQUANTIZE) { |
1458 | output = F_->createDequantize(opInfo.name, input, outTy); |
1459 | } else { |
1460 | return MAKE_ERR(opErrMsg(opInfo, "Unsupported unary arithmetic operator!" )); |
1461 | } |
1462 | return setOutputNodeValue(op, output); |
1463 | } |
1464 | |
1465 | Error TFLiteModelLoader::loadBinaryArithmetic(const tflite::Operator *op, |
1466 | const OperatorInfo &opInfo) { |
1467 | NodeValue LHS; |
1468 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getInputNodeValue(op, 0)); |
1469 | NodeValue RHS; |
1470 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getInputNodeValue(op, 1)); |
1471 | TypeRef outTy; |
1472 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
1473 | |
1474 | auto opCode = opInfo.code; |
1475 | |
1476 | // skip operators with proper broadcast (no TFLite and Operator Tests for |
1477 | // others yet). |
1478 | if (opCode != tflite::BuiltinOperator_ADD && |
1479 | opCode != tflite::BuiltinOperator_SUB && |
1480 | opCode != tflite::BuiltinOperator_MUL && |
1481 | opCode != tflite::BuiltinOperator_DIV && |
1482 | opCode != tflite::BuiltinOperator_MIN && |
1483 | opCode != tflite::BuiltinOperator_MAX) { |
1484 | |
1485 | // LHS operand broadcasting. |
1486 | if (LHS.dims().size() < RHS.dims().size()) { |
1487 | unsigned_t axis = RHS.dims().size() - LHS.dims().size(); |
1488 | LHS = F_->createBroadcast(opInfo.name + ".Broadcast" , LHS, RHS.dims(), |
1489 | axis); |
1490 | } |
1491 | |
1492 | // RHS operand broadcasting. |
1493 | if (RHS.dims().size() < LHS.dims().size()) { |
1494 | unsigned_t axis = LHS.dims().size() - RHS.dims().size(); |
1495 | RHS = F_->createBroadcast(opInfo.name + ".Broadcast" , RHS, LHS.dims(), |
1496 | axis); |
1497 | } |
1498 | } |
1499 | |
1500 | NodeValue output; |
1501 | if (opCode == tflite::BuiltinOperator_ADD) { |
1502 | const auto *opts = op->builtin_options_as_AddOptions(); |
1503 | output = F_->createNodeWithBroadcastOutTy<AddNode>(opInfo.name, -1, outTy, |
1504 | LHS, RHS); |
1505 | RETURN_IF_ERR(addActivation(output, opts->fused_activation_function())); |
1506 | } else if (opCode == tflite::BuiltinOperator_MUL) { |
1507 | const auto *opts = op->builtin_options_as_MulOptions(); |
1508 | output = F_->createNodeWithBroadcastOutTy<MulNode>(opInfo.name, -1, outTy, |
1509 | LHS, RHS); |
1510 | RETURN_IF_ERR(addActivation(output, opts->fused_activation_function())); |
1511 | } else if (opCode == tflite::BuiltinOperator_SUB) { |
1512 | const auto *opts = op->builtin_options_as_SubOptions(); |
1513 | output = F_->createNodeWithBroadcastOutTy<SubNode>(opInfo.name, -1, outTy, |
1514 | LHS, RHS); |
1515 | RETURN_IF_ERR(addActivation(output, opts->fused_activation_function())); |
1516 | } else if (opCode == tflite::BuiltinOperator_DIV) { |
1517 | const auto *opts = op->builtin_options_as_DivOptions(); |
1518 | output = F_->createNodeWithBroadcastOutTy<DivNode>(opInfo.name, -1, outTy, |
1519 | LHS, RHS); |
1520 | RETURN_IF_ERR(addActivation(output, opts->fused_activation_function())); |
1521 | } else if (opCode == tflite::BuiltinOperator_POW) { |
1522 | output = F_->createPow(opInfo.name, outTy, LHS, RHS); |
1523 | } else if (opCode == tflite::BuiltinOperator_PRELU) { |
1524 | NodeValue slope = |
1525 | F_->createReshape(opInfo.name + ".reshape" , RHS, outTy->dims()); |
1526 | output = F_->createPRELU(opInfo.name, LHS, slope, outTy); |
1527 | } else if (opCode == tflite::BuiltinOperator_MAXIMUM) { |
1528 | output = F_->createNodeWithBroadcastOutTy<MaxNode>(opInfo.name, -1, outTy, |
1529 | LHS, RHS); |
1530 | } else if (opCode == tflite::BuiltinOperator_MINIMUM) { |
1531 | output = F_->createNodeWithBroadcastOutTy<MinNode>(opInfo.name, -1, outTy, |
1532 | LHS, RHS); |
1533 | } else if (opCode == tflite::BuiltinOperator_EQUAL) { |
1534 | output = F_->createCmpEQ(opInfo.name, LHS, RHS); |
1535 | } else if (opCode == tflite::BuiltinOperator_NOT_EQUAL) { |
1536 | output = F_->createCmpNEQ(opInfo.name, LHS, RHS); |
1537 | } else if (opCode == tflite::BuiltinOperator_LESS) { |
1538 | output = F_->createCmpLT(opInfo.name, LHS, RHS); |
1539 | } else if (opCode == tflite::BuiltinOperator_LESS_EQUAL) { |
1540 | output = F_->createCmpLTE(opInfo.name, LHS, RHS); |
1541 | } else if (opCode == tflite::BuiltinOperator_GREATER) { |
1542 | output = F_->createCmpGT(opInfo.name, LHS, RHS); |
1543 | } else if (opCode == tflite::BuiltinOperator_GREATER_EQUAL) { |
1544 | output = F_->createCmpGTE(opInfo.name, LHS, RHS); |
1545 | } else if (opCode == tflite::BuiltinOperator_LOGICAL_AND) { |
1546 | output = F_->createAnd(opInfo.name, LHS, RHS); |
1547 | } else if (opCode == tflite::BuiltinOperator_LOGICAL_OR) { |
1548 | output = F_->createOr(opInfo.name, LHS, RHS); |
1549 | } else { |
1550 | return MAKE_ERR( |
1551 | opErrMsg(opInfo, "Unsupported binary arithmetic operator!" )); |
1552 | } |
1553 | return setOutputNodeValue(op, output); |
1554 | } |
1555 | |
/// Load a TFLite 2D pooling operator (AVERAGE_POOL_2D or MAX_POOL_2D) given
/// by \p op and \p opInfo. The data layout is NHWC. A fused activation is
/// appended if the options define one. \returns an Error if the operator
/// code is not a supported pooling variant.
Error TFLiteModelLoader::loadPool2D(const tflite::Operator *op,
                                    const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_Pool2DOptions();
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  // Output shape inference when not defined. Pooling preserves the batch and
  // channel dimensions; the spatial sizes are derived from filter size,
  // stride and padding mode (pooling has no dilation, hence 1).
  bool outShapeUndefined;
  ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
  if (outShapeUndefined) {
    dim_t outN = input.dims()[0];
    dim_t outH =
        getConvOutputSize(input.dims()[1], opts->filter_height(),
                          opts->stride_h(), /* dilation */ 1, opts->padding());
    dim_t outW =
        getConvOutputSize(input.dims()[2], opts->filter_width(),
                          opts->stride_w(), /* dilation */ 1, opts->padding());
    dim_t outC = input.dims()[3];
    outTy = F_->getParent()->uniqueTypeWithNewShape(outTy,
                                                    {outN, outH, outW, outC});
  }

  ShapeNHWC inputShape = ShapeNHWC(input.dims());
  ShapeNHWC outputShape = ShapeNHWC(outTy->dims());

  // Kernel and stride are given as {height, width}.
  std::vector<unsigned_t> kernels = {
      static_cast<unsigned_t>(opts->filter_height()),
      static_cast<unsigned_t>(opts->filter_width())};

  std::vector<unsigned_t> strides = {
      static_cast<unsigned_t>(opts->stride_h()),
      static_cast<unsigned_t>(opts->stride_w()),
  };

  // Derive explicit pads {top, left, bottom, right} from the TFLite padding
  // mode so they are consistent with the expected output spatial sizes.
  auto padsTB = getConvPads(inputShape.h, outputShape.h, kernels[0], strides[0],
                            /* dilation */ 1, opts->padding());
  auto padsLR = getConvPads(inputShape.w, outputShape.w, kernels[1], strides[1],
                            /* dilation */ 1, opts->padding());
  std::vector<unsigned_t> pads = {padsTB.first, padsLR.first, padsTB.second,
                                  padsLR.second};

  auto opCode = opInfo.code;
  NodeValue output;
  if (opCode == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
    // TFLite AvgPool does NOT include padded regions when normalizing.
    auto *node = F_->createAvgPool(opInfo.name, input, kernels, strides, pads,
                                   ConvolutionLayout::NHWC,
                                   /* countIncludePads */ false);
    output = node->getResult();
  } else if (opCode == tflite::BuiltinOperator_MAX_POOL_2D) {
    auto *node = F_->createMaxPool(opInfo.name, input, kernels, strides, pads);
    output = node->getResult();
  } else {
    return MAKE_ERR(opErrMsg(opInfo, "Unsupported Pool2D operator!"));
  }

  RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
  return setOutputNodeValue(op, output);
}
1617 | |
1618 | Error TFLiteModelLoader::loadConcat(const tflite::Operator *op, |
1619 | const OperatorInfo &opInfo) { |
1620 | const auto *opts = op->builtin_options_as_ConcatenationOptions(); |
1621 | TypeRef outTy; |
1622 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
1623 | |
1624 | const size_t numInputs = op->inputs()->size(); |
1625 | llvm::SmallVector<NodeValue, 4> inputs; |
1626 | inputs.reserve(numInputs); |
1627 | for (size_t idx = 0; idx < numInputs; ++idx) { |
1628 | NodeValue input; |
1629 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, idx)); |
1630 | inputs.push_back(input); |
1631 | } |
1632 | |
1633 | // If this node is quantized and there is a mismatch between the input and |
1634 | // output quantization parameters then we pull Rescale nodes from the inputs |
1635 | // to match the output quantization parameters. |
1636 | if (outTy->isQuantizedType()) { |
1637 | for (size_t idx = 0; idx < numInputs; ++idx) { |
1638 | NodeValue input = inputs[idx]; |
1639 | TypeRef inpTy = input.getType(); |
1640 | RETURN_ERR_IF_NOT( |
1641 | inpTy->isQuantizedType(), |
1642 | opErrMsg(opInfo, "Mixed precision for input/output not supported!" )); |
1643 | if ((inpTy->getScale() != outTy->getScale()) || |
1644 | (inpTy->getOffset() != outTy->getOffset())) { |
1645 | TypeRef inpTyNew = mod_.uniqueTypeWithNewShape(outTy, inpTy->dims()); |
1646 | auto *rescaleNode = F_->createRescaleQuantized( |
1647 | opInfo.name + ".Rescale" + std::to_string(idx), input, inpTyNew); |
1648 | inputs[idx] = rescaleNode->getResult(); |
1649 | } |
1650 | } |
1651 | } |
1652 | |
1653 | unsigned_t axis; |
1654 | ASSIGN_VALUE_OR_RETURN_ERR( |
1655 | axis, getPositiveAxis<unsigned_t>(opts->axis(), outTy->dims().size())); |
1656 | |
1657 | NodeValue output = F_->createConcat(opInfo.name, inputs, axis, outTy); |
1658 | RETURN_IF_ERR(addActivation(output, opts->fused_activation_function())); |
1659 | return setOutputNodeValue(op, output); |
1660 | } |
1661 | |
/// Load a TFLite CONV_2D operator given by \p op and \p opInfo. Handles both
/// the common path (ConvolutionNode) and the per-axis quantized path
/// (ChannelwiseQuantizedConvolutionNode). The data layout is NHWC and the
/// filter layout has the output channels as dimension 0.
Error TFLiteModelLoader::loadConv2D(const tflite::Operator *op,
                                    const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_Conv2DOptions();
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  NodeValue filter;
  ASSIGN_VALUE_OR_RETURN_ERR(filter, getInputNodeValue(op, 1));
  NodeValue bias;
  ASSIGN_VALUE_OR_RETURN_ERR(bias, getInputNodeValue(op, 2));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  // Output shape inference when not defined: batch from the input, spatial
  // sizes from the convolution geometry, channels from the filter count
  // (filter dimension 0).
  bool outShapeUndefined;
  ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
  if (outShapeUndefined) {
    dim_t outN = input.dims()[0];
    dim_t outH =
        getConvOutputSize(input.dims()[1], filter.dims()[1], opts->stride_h(),
                          opts->dilation_h_factor(), opts->padding());
    dim_t outW =
        getConvOutputSize(input.dims()[2], filter.dims()[2], opts->stride_w(),
                          opts->dilation_w_factor(), opts->padding());
    dim_t outC = filter.dims()[0];
    outTy = F_->getParent()->uniqueTypeWithNewShape(outTy,
                                                    {outN, outH, outW, outC});
  }

  ShapeNHWC inputShape = ShapeNHWC(input.dims());
  ShapeNHWC filterShape = ShapeNHWC(filter.dims());
  ShapeNHWC outputShape = ShapeNHWC(outTy->dims());

  // Kernel sizes come from the filter's spatial dimensions.
  std::vector<unsigned_t> kernels = {static_cast<unsigned_t>(filterShape.h),
                                     static_cast<unsigned_t>(filterShape.w)};

  std::vector<unsigned_t> strides = {
      static_cast<unsigned_t>(opts->stride_h()),
      static_cast<unsigned_t>(opts->stride_w()),
  };

  std::vector<unsigned_t> dilations = {
      static_cast<unsigned_t>(opts->dilation_h_factor()),
      static_cast<unsigned_t>(opts->dilation_w_factor()),
  };

  // Derive explicit pads {top, left, bottom, right} from the TFLite padding
  // mode so they are consistent with the expected output spatial sizes.
  auto padsTB = getConvPads(inputShape.h, outputShape.h, kernels[0], strides[0],
                            dilations[0], opts->padding());
  auto padsLR = getConvPads(inputShape.w, outputShape.w, kernels[1], strides[1],
                            dilations[1], opts->padding());
  std::vector<unsigned_t> pads = {padsTB.first, padsLR.first, padsTB.second,
                                  padsLR.second};

  // There are TensorFlowLite models which have only the weights quantized
  // to INT8 (the rest of the operands being FLOAT32). Since Glow does not
  // support mixed precision operation we dequantize the weights.
  if (input.getType()->isFPType() && filter.getType()->isQuantizedType() &&
      bias.getType()->isFPType() && outTy->isFPType()) {
    filter = F_->createDequantize(opInfo.name + ".Dequantize", filter,
                                  outTy->getElementType());
  }

  // Check whether this operator is quantized per axis. When it is, the
  // per-channel scales/offsets for filter and bias are returned as Constants.
  bool isPerAxisQuantized;
  Constant *filterScales = nullptr;
  Constant *filterOffsets = nullptr;
  Constant *biasScales = nullptr;
  Constant *biasOffsets = nullptr;
  ASSIGN_VALUE_OR_RETURN_ERR(isPerAxisQuantized,
                             isConv2DPerAxisQuantized(op, opInfo, filterScales,
                                                      filterOffsets, biasScales,
                                                      biasOffsets));

  // Create convolution node.
  NodeValue output;
  if (isPerAxisQuantized) {
    // Check that filter and bias are constants (required by the
    // channelwise quantized convolution).
    RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(filter.getNode()),
                      opErrMsg(opInfo, "Filter must be constant!"));
    RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(bias.getNode()),
                      opErrMsg(opInfo, "Bias must be constant!"));
    // Create ChannelwiseQuantizedConvolution node. Filter and bias are
    // already quantized, hence quantizeFilter/quantizeBias are false.
    output = F_->createChannelwiseQuantizedConv(
        opInfo.name, input, filter, bias, filterScales, filterOffsets,
        biasScales, biasOffsets, outTy, kernels, strides, pads, /* group */ 1,
        dilations, /* quantizeFilter */ false, /* quantizeBias */ false);
  } else {
    // Check bias quantization parameters.
    RETURN_IF_ERR(checkBiasQuantizationParams(mod_, input, filter, bias));
    // Create Convolution node.
    output = F_->createConv(opInfo.name, input, filter, bias, outTy, kernels,
                            strides, pads, /* group */ 1, dilations);
  }

  RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
  return setOutputNodeValue(op, output);
}
1758 | |
/// Load a TFLite DEPTHWISE_CONV_2D operator given by \p op and \p opInfo.
/// The operator is lowered to a grouped convolution with group equal to the
/// number of input channels. Handles both the common path and the per-axis
/// quantized path. The data layout is NHWC.
Error TFLiteModelLoader::loadDepthwiseConv2D(const tflite::Operator *op,
                                             const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_DepthwiseConv2DOptions();
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  NodeValue filter;
  ASSIGN_VALUE_OR_RETURN_ERR(filter, getInputNodeValue(op, 1));
  NodeValue bias;
  ASSIGN_VALUE_OR_RETURN_ERR(bias, getInputNodeValue(op, 2));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  // Output shape inference when not defined: batch from the input, spatial
  // sizes from the convolution geometry, channels are the input channels
  // multiplied by the depth multiplier.
  bool outShapeUndefined;
  ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
  if (outShapeUndefined) {
    dim_t outN = input.dims()[0];
    dim_t outH =
        getConvOutputSize(input.dims()[1], filter.dims()[1], opts->stride_h(),
                          opts->dilation_h_factor(), opts->padding());
    dim_t outW =
        getConvOutputSize(input.dims()[2], filter.dims()[2], opts->stride_w(),
                          opts->dilation_w_factor(), opts->padding());
    dim_t outC = input.dims()[3] * opts->depth_multiplier();
    outTy = F_->getParent()->uniqueTypeWithNewShape(outTy,
                                                    {outN, outH, outW, outC});
  }

  ShapeNHWC inputShape = ShapeNHWC(input.dims());
  ShapeNHWC filterShape = ShapeNHWC(filter.dims());
  ShapeNHWC outputShape = ShapeNHWC(outTy->dims());

  // Kernel sizes come from the filter's spatial dimensions.
  std::vector<unsigned_t> kernels = {static_cast<unsigned_t>(filterShape.h),
                                     static_cast<unsigned_t>(filterShape.w)};

  std::vector<unsigned_t> strides = {
      static_cast<unsigned_t>(opts->stride_h()),
      static_cast<unsigned_t>(opts->stride_w()),
  };

  // Dilation factors were only added in version 2 of this operator; older
  // models implicitly use a dilation of 1.
  std::vector<unsigned_t> dilations = {1, 1};
  if (opInfo.version >= 2) {
    dilations = {static_cast<unsigned_t>(opts->dilation_h_factor()),
                 static_cast<unsigned_t>(opts->dilation_w_factor())};
  }

  // Derive explicit pads {top, left, bottom, right} from the TFLite padding
  // mode so they are consistent with the expected output spatial sizes.
  auto padsTB = getConvPads(inputShape.h, outputShape.h, kernels[0], strides[0],
                            dilations[0], opts->padding());
  auto padsLR = getConvPads(inputShape.w, outputShape.w, kernels[1], strides[1],
                            dilations[1], opts->padding());
  std::vector<unsigned_t> pads = {padsTB.first, padsLR.first, padsTB.second,
                                  padsLR.second};

  // Convolution group is inputChannels / filterChannels = inputChannels.
  unsigned_t group = input.dims().back();

  // There are TensorFlowLite models which have only the weights quantized
  // to INT8 (the rest of the operands being FLOAT32). Since Glow does not
  // support mixed precision operation we dequantize the weights.
  if (input.getType()->isFPType() && filter.getType()->isQuantizedType() &&
      bias.getType()->isFPType() && outTy->isFPType()) {
    filter = F_->createDequantize(opInfo.name + ".Dequantize", filter,
                                  outTy->getElementType());
  }

  // Check whether this operator is quantized per axis. When it is, the
  // per-channel scales/offsets for filter and bias are returned as Constants.
  bool isPerAxisQuantized;
  Constant *filterScales = nullptr;
  Constant *filterOffsets = nullptr;
  Constant *biasScales = nullptr;
  Constant *biasOffsets = nullptr;
  ASSIGN_VALUE_OR_RETURN_ERR(isPerAxisQuantized,
                             isConv2DPerAxisQuantized(op, opInfo, filterScales,
                                                      filterOffsets, biasScales,
                                                      biasOffsets));

  // Transpose the filter from CHWN to NHWC. For the per-axis quantized path
  // the transposed data is materialized directly into a new Constant (rather
  // than going through a Transpose node) because further down the
  // ChannelwiseQuantizedConvolution requires the filter to be a Constant.
  RETURN_ERR_IF_NOT(filter.dims().size() == 4,
                    opErrMsg(opInfo, "Filter should be 4D!"));
  if (isPerAxisQuantized) {
    Constant *filterC = llvm::dyn_cast<Constant>(filter.getNode());
    RETURN_ERR_IF_NOT(filterC, opErrMsg(opInfo, "Filter must be constant!"));
    TypeRef filterTy = filterC->getType();
    auto filterDims = filterTy->dims();
    TypeRef newFilterTy = mod_.uniqueTypeWithNewShape(
        filterTy, {filterDims[3], filterDims[1], filterDims[2], filterDims[0]});
    Tensor newFilterT = Tensor(newFilterTy);
    filterC->getPayload().transpose(&newFilterT, {3, 1, 2, 0});
    Constant *newFilterC = mod_.createConstant(
        filterC->getName().str() + ".Reshape", std::move(newFilterT), "NHWC");
    filter = newFilterC->getOutput();
  } else {
    filter = F_->createTranspose(opInfo.name + ".Transpose", filter,
                                 {3, 1, 2, 0}, "NHWC");
  }
  // After the transpose a depthwise filter has exactly 1 channel.
  RETURN_ERR_IF_NOT(filter.dims().back() == 1,
                    opErrMsg(opInfo, "Filter should have 1 channel!"));

  // Create convolution node.
  NodeValue output;
  if (isPerAxisQuantized) {
    // Check that filter and bias are constants.
    RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(filter.getNode()),
                      opErrMsg(opInfo, "Filter must be constant!"));
    RETURN_ERR_IF_NOT(llvm::dyn_cast<Constant>(bias.getNode()),
                      opErrMsg(opInfo, "Bias must be constant!"));
    // Create ChannelwiseQuantizedConvolution node. Filter and bias are
    // already quantized, hence quantizeFilter/quantizeBias are false.
    output = F_->createChannelwiseQuantizedConv(
        opInfo.name, input, filter, bias, filterScales, filterOffsets,
        biasScales, biasOffsets, outTy, kernels, strides, pads, group,
        dilations, /* quantizeFilter */ false, /* quantizeBias */ false);
  } else {
    // Check bias quantization parameters.
    RETURN_IF_ERR(checkBiasQuantizationParams(mod_, input, filter, bias));
    // Create Convolution node.
    output = F_->createConv(opInfo.name, input, filter, bias, outTy, kernels,
                            strides, pads, group, dilations);
  }

  RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));
  return setOutputNodeValue(op, output);
}
1883 | |
/// Load a TFLite FULLY_CONNECTED operator given by \p op and \p opInfo.
/// Synthesizes a zero bias when the model provides none, transposes the
/// weights into the layout expected by Glow and honors the 'keep_num_dims'
/// option (version >= 5) by reshaping the output back to the input rank.
Error TFLiteModelLoader::loadFullyConnected(const tflite::Operator *op,
                                            const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_FullyConnectedOptions();
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  NodeValue weights;
  ASSIGN_VALUE_OR_RETURN_ERR(weights, getInputNodeValue(op, 1));
  NodeValue bias;
  ASSIGN_VALUE_OR_RETURN_ERR(bias, getInputNodeValue(op, 2));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  // If bias is not used we create one initialized with 0. For the quantized
  // case the bias scale is inputScale * weightsScale with zero offset so the
  // quantized arithmetic stays exact.
  if (!bias.getNode()) {
    if (input.getType()->isQuantizedType()) {
      float biasScale =
          input.getType()->getScale() * weights.getType()->getScale();
      int32_t biasOffset = 0;
      auto *biasC =
          mod_.createConstant(ElemKind::Int32QTy, {weights.dims()[0]},
                              biasScale, biasOffset, opInfo.name + ".bias");
      biasC->getPayloadMutable().zero();
      bias = biasC;
    } else {
      auto *biasC = mod_.createConstant(ElemKind::FloatTy, {weights.dims()[0]},
                                        opInfo.name + ".bias");
      biasC->getPayloadMutable().zero();
      bias = biasC;
    }
  }

  RETURN_IF_ERR(checkBiasQuantizationParams(mod_, input, weights, bias));

  // Output shape inference when not defined: batch from the input, output
  // channels from the weights (dimension 0, before the transpose below).
  bool outShapeUndefined;
  ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
  if (outShapeUndefined) {
    dim_t outN = input.dims()[0];
    dim_t outC = weights.dims()[0];
    outTy = F_->getParent()->uniqueTypeWithNewShape(outTy, {outN, outC});
  }

  // There are TensorFlowLite models which have only the weights quantized
  // to INT8 (the rest of the operands being FLOAT32). Since Glow does not
  // support mixed precision operation we dequantize the weights.
  if (input.getType()->isFPType() && weights.getType()->isQuantizedType() &&
      bias.getType()->isFPType() && outTy->isFPType()) {
    weights = F_->createDequantize(opInfo.name + ".Dequantize", weights,
                                   outTy->getElementType());
  }

  // The 'weights_format' option only exists from version 2 onwards; only the
  // default (non-shuffled) format is supported.
  if (opInfo.version >= 2) {
    RETURN_ERR_IF_NOT(
        opts->weights_format() ==
            tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
        opErrMsg(opInfo, "Only default weights format is supported!"));
  }

  // The 'keep_num_dims' option only exists from version 5 onwards.
  bool keepDims = false;
  if (opInfo.version >= 5) {
    keepDims = opts->keep_num_dims();
  }

  // Transpose weights (TFLite stores them as [out, in]; Glow expects
  // [in, out]).
  RETURN_ERR_IF_NOT(weights.dims().size() == 2,
                    opErrMsg(opInfo, "Weights should be 2D!"));
  weights = F_->createTranspose(opInfo.name + ".Transpose", weights, {1, 0});

  // For an input with shape [D(0), D(1), ... , D(N-1)] if:
  // keep_num_dims is FALSE then we flatten the input into:
  // [D(0), D(1) x D(2) x ... x D(N-1)] (axis = 1).
  // keep_num_dims options is TRUE then we flatten the input into:
  // [D(0) x D(1) x ... x D(N-2), D(N-1)] (axis = N-1).
  unsigned_t axis = keepDims ? (input.dims().size() - 1) : 1;
  NodeValue output =
      F_->createFullyConnected(opInfo.name, input, weights, bias, outTy, axis);
  RETURN_IF_ERR(addActivation(output, opts->fused_activation_function()));

  // Expand output dims if necessary: restore the input rank with the last
  // dimension replaced by the number of output channels.
  if (keepDims) {
    std::vector<dim_t> outputDims = input.dims();
    outputDims.back() = output.dims().back();
    output = F_->createReshape(opInfo.name + ".Reshape", output, outputDims);
  }
  return setOutputNodeValue(op, output);
}
1970 | |
1971 | Error TFLiteModelLoader::loadReshape(const tflite::Operator *op, |
1972 | const OperatorInfo &opInfo) { |
1973 | NodeValue input; |
1974 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
1975 | TypeRef outTy; |
1976 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
1977 | |
1978 | // Output shape inference when not defined. |
1979 | bool outShapeUndefined; |
1980 | ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0)); |
1981 | if (outShapeUndefined) { |
1982 | if (const auto *opts = op->builtin_options_as_ReshapeOptions()) { |
1983 | auto *newShape = opts->new_shape(); |
1984 | std::vector<dim_t> outDims(newShape->size()); |
1985 | dim_t dimProd = 1; |
1986 | for (size_t idx = 0; idx < newShape->size(); ++idx) { |
1987 | auto newDim = (*newShape)[idx]; |
1988 | if (newDim != -1) { |
1989 | outDims[idx] = newDim; |
1990 | dimProd *= newDim; |
1991 | } |
1992 | } |
1993 | for (size_t idx = 0; idx < newShape->size(); ++idx) { |
1994 | auto newDim = (*newShape)[idx]; |
1995 | if (newDim == -1) { |
1996 | outDims[idx] = input.getType()->size() / dimProd; |
1997 | } |
1998 | } |
1999 | outTy = F_->getParent()->uniqueTypeWithNewShape(outTy, outDims); |
2000 | } |
2001 | } |
2002 | |
2003 | // Note: The Reshape node has a second input operand which provides |
2004 | // the new shape but the documentation states that is should be ignored |
2005 | // and the 'new_shape' attribute should be used instead. Moreover, in |
2006 | // this case we are not using not even the 'new_shape' attribute because |
2007 | // we have the output type directly available. We are using this logic |
2008 | // also for loading other operators: Squeeze, ExpandDims. |
2009 | NodeValue output = F_->createReshape(opInfo.name, input, outTy->dims()); |
2010 | return setOutputNodeValue(op, output); |
2011 | } |
2012 | |
/// Load a TFLite SOFTMAX operator given by \p op and \p opInfo. When the
/// 'tflite-float-softmax' option is enabled the Softmax is computed in float
/// regardless of the model precision (dequantizing the input and, unless the
/// output is a final placeholder, re-quantizing the result).
Error TFLiteModelLoader::loadSoftmax(const tflite::Operator *op,
                                     const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_SoftmaxOptions();
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  // Output shape inference when not defined: Softmax preserves the input
  // shape.
  bool outShapeUndefined;
  ASSIGN_VALUE_OR_RETURN_ERR(outShapeUndefined, isOutputShapeUndefined(op, 0));
  if (outShapeUndefined) {
    outTy = F_->getParent()->uniqueTypeWithNewShape(outTy, input.dims());
  }

  RETURN_ERR_IF_NOT(input.dims().size() >= 2,
                    opErrMsg(opInfo, "Input rank must be >= 2!"));
  float beta = opts->beta();

  // Create a constant to store labels to be used in SoftMaxGradNode.
  auto selected =
      mod_.createConstant(ElemKind::Int64ITy, {input.dims()[0], 1}, "selected");

  NodeValue output;
  if (tfliteFloatSoftmaxOpt) {
    // We dequantize the input if it is quantized type.
    if (input.getType()->isQuantizedType()) {
      input = F_->createDequantize(opInfo.name + ".Dequantize", input,
                                   ElemKind::FloatTy);
    }

    // Create float Softmax regardless of the type defined in the model.
    output = F_->createSoftMax(opInfo.name, input, selected, nullptr, beta);

    // If target output type is quantized we quantize the float output of the
    // Softmax but only if it is not an output placeholder in which case we
    // allow the output placeholder to remain float even though it was defined
    // as quantized in the original model.
    bool isFinal;
    ASSIGN_VALUE_OR_RETURN_ERR(isFinal, isOperatorOutputFinalTensor(op, 0));
    if (outTy->isQuantizedType() && !isFinal) {
      output = F_->createQuantize(opInfo.name + ".Quantize", output, outTy);
    }
  } else {
    output = F_->createSoftMax(opInfo.name, input, selected, outTy, beta);
  }
  return setOutputNodeValue(op, output);
}
2061 | |
2062 | Error TFLiteModelLoader::loadLogSoftmax(const tflite::Operator *op, |
2063 | const OperatorInfo &opInfo) { |
2064 | NodeValue input; |
2065 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2066 | TypeRef outTy; |
2067 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2068 | |
2069 | // Create a constant to store labels to be used in SoftMaxGradNode. |
2070 | auto selected = |
2071 | mod_.createConstant(ElemKind::Int64ITy, {input.dims()[0], 1}, "selected" ); |
2072 | |
2073 | NodeValue output = F_->createLogSoftMax(opInfo.name, input, selected, outTy); |
2074 | return setOutputNodeValue(op, output); |
2075 | } |
2076 | |
2077 | Error TFLiteModelLoader::loadPad(const tflite::Operator *op, |
2078 | const OperatorInfo &opInfo) { |
2079 | NodeValue input; |
2080 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2081 | NodeValue pads; |
2082 | ASSIGN_VALUE_OR_RETURN_ERR(pads, getInputNodeValue(op, 1)); |
2083 | TypeRef outTy; |
2084 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2085 | |
2086 | // Validate paddings shape. |
2087 | auto numDims = input.dims().size(); |
2088 | RETURN_ERR_IF_NOT(pads.dims().size() == 2, |
2089 | opErrMsg(opInfo, "Paddings should be 2D!" )); |
2090 | RETURN_ERR_IF_NOT(pads.dims()[0] == numDims, |
2091 | opErrMsg(opInfo, "Paddings 1st dimension should match the " |
2092 | "input rank!" )); |
2093 | RETURN_ERR_IF_NOT(pads.dims()[1] == 2, |
2094 | opErrMsg(opInfo, "Paddings 2nd dimensions should be 2!" )); |
2095 | |
2096 | // TFLite paddings are stored as start(D1),stop(D1),start(D2),stop(D2),etc. |
2097 | Constant *padsC = llvm::dyn_cast<Constant>(pads.getNode()); |
2098 | RETURN_ERR_IF_NOT(padsC, |
2099 | opErrMsg(opInfo, "Non constant 'paddings' not supported!" )); |
2100 | RETURN_ERR_IF_NOT(padsC->getType()->getElementType() == ElemKind::Int32ITy, |
2101 | opErrMsg(opInfo, "Paddings should have INT32 type!" )); |
2102 | auto padsH = padsC->getPayload().getHandle<int32_t>(); |
2103 | std::vector<int> padsVec(padsH.size()); |
2104 | for (dim_t dim = 0; dim < numDims; ++dim) { |
2105 | auto padStart = padsH.at({dim, 0}); |
2106 | auto padStop = padsH.at({dim, 1}); |
2107 | RETURN_ERR_IF_NOT((padStart >= 0) && (padStop >= 0), |
2108 | opErrMsg(opInfo, "Invalid negative padding value!" )); |
2109 | padsVec[0 * numDims + dim] = static_cast<int>(padStart); |
2110 | padsVec[1 * numDims + dim] = static_cast<int>(padStop); |
2111 | } |
2112 | |
2113 | NodeValue output = F_->createPad(opInfo.name, input, outTy, |
2114 | PaddingMode::CONSTANT, padsVec, 0.f); |
2115 | return setOutputNodeValue(op, output); |
2116 | } |
2117 | |
2118 | Error TFLiteModelLoader::loadTranspose(const tflite::Operator *op, |
2119 | const OperatorInfo &opInfo) { |
2120 | NodeValue input; |
2121 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2122 | NodeValue perm; |
2123 | ASSIGN_VALUE_OR_RETURN_ERR(perm, getInputNodeValue(op, 1)); |
2124 | |
2125 | Constant *permC = llvm::dyn_cast<Constant>(perm.getNode()); |
2126 | RETURN_ERR_IF_NOT( |
2127 | permC, opErrMsg(opInfo, "Non constant permutation not supported!" )); |
2128 | RETURN_ERR_IF_NOT(permC->getType()->getElementType() == ElemKind::Int32ITy, |
2129 | opErrMsg(opInfo, "Permutation should have INT32 type!" )); |
2130 | auto permH = permC->getPayload().getHandle<int32_t>(); |
2131 | |
2132 | std::vector<unsigned_t> shuffle; |
2133 | for (size_t idx = 0; idx < permH.size(); ++idx) { |
2134 | int32_t dim = permH.raw(idx); |
2135 | RETURN_ERR_IF_NOT(dim >= 0, opErrMsg(opInfo, "Invalid permutation value!" )); |
2136 | shuffle.push_back(static_cast<unsigned_t>(dim)); |
2137 | } |
2138 | |
2139 | NodeValue output = F_->createTranspose(opInfo.name, input, shuffle); |
2140 | return setOutputNodeValue(op, output); |
2141 | } |
2142 | |
Error TFLiteModelLoader::loadReduce(const tflite::Operator *op,
                                    const OperatorInfo &opInfo) {
  const auto *opts = op->builtin_options_as_ReducerOptions();
  // Operand 0 is the tensor being reduced; operand 1 holds the axes.
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0));
  NodeValue axes;
  ASSIGN_VALUE_OR_RETURN_ERR(axes, getInputNodeValue(op, 1));
  TypeRef outTy;
  ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0));

  std::vector<unsigned_t> axesVal;
  ASSIGN_VALUE_OR_RETURN_ERR(axesVal,
                             loadAxes<unsigned_t>(opInfo, axes, input));

  // When true the reduced dimensions are kept in the output with size 1.
  bool keepDims = opts->keep_dims();

  // Try to load ReduceMean as AveragePool if equivalent.
  // TODO: Move this into the GraphOptimizer once Glow supports reduce
  // operators with multiple axes.
  // The pattern recognized here is a 4D input reduced over axes {1, 2}
  // (H and W for NHWC), which is equivalent to a global average pool.
  if (opInfo.code == tflite::BuiltinOperator_MEAN && axesVal.size() == 2 &&
      axesVal.at(0) == 1 && axesVal.at(1) == 2 && input.dims().size() == 4) {
    // The kernel covers the full spatial extent; a single window is produced
    // so strides are 1 and no padding is used.
    std::vector<unsigned_t> kernels = {
        static_cast<unsigned_t>(input.dims()[1]),
        static_cast<unsigned_t>(input.dims()[2])};
    std::vector<unsigned_t> strides = {1, 1};
    std::vector<unsigned_t> pads = {0, 0, 0, 0};
    NodeValue output = F_->createAvgPool(opInfo.name, input, kernels, strides,
                                         pads, ConvolutionLayout::NHWC,
                                         /* countIncludePads */ false);
    if (!keepDims) {
      // Drop the singleton H/W dimensions produced by the pooling.
      output = F_->createSqueeze(opInfo.name + ".Squeeze" , output, {1, 2});
    }
    return setOutputNodeValue(op, output);
  }

  // Currently the Glow reduce operators do not support multiple axes so we
  // create chained reduce operators with single axis.
  // TODO: When Glow supports reduce operators with multiple axes remove this!
  auto opCode = opInfo.code;
  NodeValue output = input;
  for (size_t idx = 0, end = axesVal.size(); idx < end; ++idx) {
    // Current axis value.
    unsigned_t axisVal = axesVal[idx];
    if (!keepDims) {
      // Each previous reduction removed one dimension, so shift the axis left
      // by the number of axes already processed.
      // NOTE(review): this assumes axesVal is sorted in ascending order —
      // confirm loadAxes guarantees this ordering.
      axisVal = axisVal - idx;
    }
    // Current output type.
    // The intermediate type is derived by erasing the reduced dimension from
    // the running output shape while keeping outTy's element type (and
    // quantization parameters, if any).
    ShapeVector outDimsCurr(output.dims().begin(), output.dims().end());
    outDimsCurr.erase(outDimsCurr.begin() + axisVal);
    auto outTypeCurr = mod_.uniqueTypeWithNewShape(outTy, outDimsCurr);
    // Create reduce operator.
    if (opCode == tflite::BuiltinOperator_MEAN) {
      output = F_->createBatchedReduceMean(opInfo.name, outTypeCurr, output,
                                           {axisVal});
      // The BatchedReduceMean reduces the output dimension and hence we expand
      // the output dimensions if keepDims is true.
      if (keepDims) {
        output =
            F_->createExpandDims(opInfo.name + ".Expand" , output, {axisVal});
      }
    } else {
      return MAKE_ERR(opErrMsg(opInfo, "Unsupported Reduce operator!" ));
    }
  }
  return setOutputNodeValue(op, output);
}
2209 | |
2210 | Error TFLiteModelLoader::loadSplit(const tflite::Operator *op, |
2211 | const OperatorInfo &opInfo) { |
2212 | const auto *opts = op->builtin_options_as_SplitOptions(); |
2213 | NodeValue axis; |
2214 | ASSIGN_VALUE_OR_RETURN_ERR(axis, getInputNodeValue(op, 0)); |
2215 | NodeValue input; |
2216 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 1)); |
2217 | |
2218 | unsigned_t axisVal; |
2219 | ASSIGN_VALUE_OR_RETURN_ERR(axisVal, |
2220 | loadAxis<unsigned_t>(opInfo, axis, input)); |
2221 | |
2222 | unsigned_t numSplits = static_cast<unsigned_t>(opts->num_splits()); |
2223 | RETURN_ERR_IF_NOT( |
2224 | input.dims()[axisVal] % numSplits == 0, |
2225 | opErrMsg( |
2226 | opInfo, |
2227 | "Input dimension should be divisible by 'num_splits' along axis!" )); |
2228 | |
2229 | std::vector<SliceNode *> outputNodes; |
2230 | F_->createSplit(opInfo.name, input, numSplits, axisVal, {}, outputNodes); |
2231 | std::vector<NodeValue> outputNodeValues(outputNodes.size(), nullptr); |
2232 | for (size_t idx = 0, end = outputNodes.size(); idx < end; ++idx) { |
2233 | outputNodeValues[idx] = outputNodes[idx]->getResult(); |
2234 | } |
2235 | return setOutputNodeValues(op, outputNodeValues); |
2236 | } |
2237 | |
2238 | Error TFLiteModelLoader::loadArg(const tflite::Operator *op, |
2239 | const OperatorInfo &opInfo) { |
2240 | NodeValue input; |
2241 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2242 | NodeValue axis; |
2243 | ASSIGN_VALUE_OR_RETURN_ERR(axis, getInputNodeValue(op, 1)); |
2244 | TypeRef outTy; |
2245 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2246 | |
2247 | unsigned_t axisVal; |
2248 | ASSIGN_VALUE_OR_RETURN_ERR(axisVal, |
2249 | loadAxis<unsigned_t>(opInfo, axis, input)); |
2250 | |
2251 | auto opCode = opInfo.code; |
2252 | NodeValue output = nullptr; |
2253 | if (opCode == tflite::BuiltinOperator_ARG_MAX) { |
2254 | output = F_->createArgMax(opInfo.name, input, axisVal, /* keepDims */ false, |
2255 | outTy->getElementType()); |
2256 | } else if (opCode == tflite::BuiltinOperator_ARG_MIN) { |
2257 | output = F_->createArgMin(opInfo.name, input, axisVal, /* keepDims */ false, |
2258 | outTy->getElementType()); |
2259 | } else { |
2260 | return MAKE_ERR(opErrMsg(opInfo, "Unsupported Arg operator!" )); |
2261 | } |
2262 | return setOutputNodeValue(op, output); |
2263 | } |
2264 | |
2265 | Error TFLiteModelLoader::loadShape(const tflite::Operator *op, |
2266 | const OperatorInfo &opInfo) { |
2267 | NodeValue input; |
2268 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2269 | TypeRef outTy; |
2270 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2271 | |
2272 | Constant *shapeC = F_->getParent()->createConstant(outTy, opInfo.name); |
2273 | auto inputDims = input.getType()->dims(); |
2274 | RETURN_ERR_IF_NOT(outTy->dims().size() == 1, |
2275 | opErrMsg(opInfo, "Output should be 1D!" )); |
2276 | RETURN_ERR_IF_NOT(outTy->dims()[0] == inputDims.size(), |
2277 | opErrMsg(opInfo, "Output length should match input rank!" )); |
2278 | if (outTy->getElementType() == ElemKind::Int32ITy) { |
2279 | auto shapeH = shapeC->getPayloadMutable().getHandle<int32_t>(); |
2280 | for (size_t idx = 0; idx < inputDims.size(); ++idx) { |
2281 | shapeH.raw(idx) = static_cast<int32_t>(inputDims[idx]); |
2282 | } |
2283 | } else if (outTy->getElementType() == ElemKind::Int64ITy) { |
2284 | auto shapeH = shapeC->getPayloadMutable().getHandle<int64_t>(); |
2285 | for (size_t idx = 0; idx < inputDims.size(); ++idx) { |
2286 | shapeH.raw(idx) = static_cast<int64_t>(inputDims[idx]); |
2287 | } |
2288 | } else { |
2289 | return MAKE_ERR(opErrMsg(opInfo, "Output should be INT32 or INT64!" )); |
2290 | } |
2291 | |
2292 | return setOutputNodeValue(op, shapeC); |
2293 | } |
2294 | |
2295 | Error TFLiteModelLoader::loadSlice(const tflite::Operator *op, |
2296 | const OperatorInfo &opInfo) { |
2297 | NodeValue input; |
2298 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2299 | NodeValue begin; |
2300 | ASSIGN_VALUE_OR_RETURN_ERR(begin, getInputNodeValue(op, 1)); |
2301 | TypeRef outTy; |
2302 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2303 | // Note: Slice has a third input operand 'size' which provides the size of |
2304 | // the output slice. We derive here the output slice size based on outTy. |
2305 | |
2306 | Constant *beginC = llvm::dyn_cast<Constant>(begin.getNode()); |
2307 | RETURN_ERR_IF_NOT(beginC, |
2308 | opErrMsg(opInfo, "Non constant begin not supported!" )); |
2309 | RETURN_ERR_IF_NOT(beginC->getType()->getElementType() == ElemKind::Int32ITy, |
2310 | opErrMsg(opInfo, "Begin should have INT32 type!" )); |
2311 | auto beginH = beginC->getPayload().getHandle<int32_t>(); |
2312 | |
2313 | std::vector<dim_t> start; |
2314 | for (size_t idx = 0; idx < beginH.size(); ++idx) { |
2315 | int32_t dimStart = beginH.raw(idx); |
2316 | RETURN_ERR_IF_NOT(dimStart >= 0, opErrMsg(opInfo, "Invalid begin value!" )); |
2317 | start.push_back(static_cast<dim_t>(dimStart)); |
2318 | } |
2319 | |
2320 | NodeValue output = F_->createSlice(opInfo.name, input, start, outTy); |
2321 | return setOutputNodeValue(op, output); |
2322 | } |
2323 | |
2324 | Error TFLiteModelLoader::loadStridedSlice(const tflite::Operator *op, |
2325 | const OperatorInfo &opInfo) { |
2326 | const auto *opts = op->builtin_options_as_StridedSliceOptions(); |
2327 | NodeValue input; |
2328 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2329 | NodeValue begin; |
2330 | ASSIGN_VALUE_OR_RETURN_ERR(begin, getInputNodeValue(op, 1)); |
2331 | NodeValue end; |
2332 | ASSIGN_VALUE_OR_RETURN_ERR(end, getInputNodeValue(op, 2)); |
2333 | NodeValue strides; |
2334 | ASSIGN_VALUE_OR_RETURN_ERR(strides, getInputNodeValue(op, 3)); |
2335 | |
2336 | // You can find more information about this operator here: |
2337 | // https://www.tensorflow.org/api_docs/python/tf/strided_slice |
2338 | // https://www.tensorflow.org/mlir/tfl_ops#tflstrided_slice_tflstridedsliceop |
2339 | |
2340 | // We only support strided slice if begin/end/strides are constants. This |
2341 | // is because Glow is statically typed. |
2342 | auto inpDims = input.dims(); |
2343 | std::vector<dim_t> beginV; |
2344 | ASSIGN_VALUE_OR_RETURN_ERR(beginV, loadArray<dim_t>(opInfo, begin)); |
2345 | std::vector<dim_t> endV; |
2346 | ASSIGN_VALUE_OR_RETURN_ERR(endV, loadArray<dim_t>(opInfo, end)); |
2347 | std::vector<int32_t> stridesV; |
2348 | ASSIGN_VALUE_OR_RETURN_ERR(stridesV, loadArray<int32_t>(opInfo, strides)); |
2349 | |
2350 | // The strides must be non-zero. |
2351 | for (size_t idx = 0; idx < stridesV.size(); ++idx) { |
2352 | RETURN_ERR_IF_NOT(stridesV[idx] != 0, |
2353 | opErrMsg(opInfo, "Strides must be non-zero!" )); |
2354 | } |
2355 | |
2356 | // The begin/end/strides must have same size. |
2357 | RETURN_ERR_IF_NOT( |
2358 | beginV.size() == endV.size(), |
2359 | opErrMsg(opInfo, "Begin/end/strides should have same length!" )); |
2360 | RETURN_ERR_IF_NOT( |
2361 | beginV.size() == stridesV.size(), |
2362 | opErrMsg(opInfo, "Begin/end/strides should have same length!" )); |
2363 | |
2364 | // The begin/end/strides length must be equal to the input rank. |
2365 | RETURN_ERR_IF_NOT(beginV.size() == inpDims.size(), |
2366 | opErrMsg(opInfo, "Begin/end/strides length invalid!" )); |
2367 | |
2368 | // Get attributes. |
2369 | int32_t begin_mask = opts->begin_mask(); |
2370 | int32_t end_mask = opts->end_mask(); |
2371 | int32_t ellipsis_mask = opts->ellipsis_mask(); |
2372 | int32_t new_axis_mask = opts->new_axis_mask(); |
2373 | int32_t shrink_axis_mask = opts->shrink_axis_mask(); |
2374 | |
2375 | // Utility to extract the Nth bit from a mask. Returns 0 or 1. |
2376 | auto getMaskBit = [](uint32_t mask, size_t n) -> int { |
2377 | assert((0 <= n) && (n <= 31) && "Bit number exceeded!" ); |
2378 | return (mask >> n) & 0x01; |
2379 | }; |
2380 | |
2381 | // If the ith bit of begin_mask is set, begin[i] is ignored and the fullest |
2382 | // possible range in that dimension is used instead. |
2383 | if (begin_mask) { |
2384 | for (size_t idx = 0; idx < beginV.size(); ++idx) { |
2385 | if (getMaskBit(begin_mask, idx)) { |
2386 | // If stride is positive we start from 0 otherwise from dimension size. |
2387 | beginV[idx] = (stridesV[idx] > 0) ? 0 : (inpDims[idx] - 1); |
2388 | } |
2389 | } |
2390 | } |
2391 | |
2392 | // If the ith bit of end_mask is set, end[i] is ignored and the fullest |
2393 | // possible range in that dimension is used instead. |
2394 | if (end_mask) { |
2395 | for (size_t idx = 0; idx < endV.size(); ++idx) { |
2396 | if (getMaskBit(end_mask, idx)) { |
2397 | // If stride is positive we end at dimension size otherwise at 0. |
2398 | endV[idx] = (stridesV[idx] > 0) ? inpDims[idx] : 0; |
2399 | } |
2400 | } |
2401 | } |
2402 | |
2403 | // If the ith bit of ellipsis_mask is set, as many unspecified dimensions as |
2404 | // needed will be inserted between other dimensions. Only one non-zero bit is |
2405 | // allowed in ellipsis_mask. |
2406 | if (ellipsis_mask) { |
2407 | size_t ellipsisIdx = 0; |
2408 | for (size_t idx = 0; idx < 32; ++idx) { |
2409 | if (getMaskBit(ellipsis_mask, idx)) { |
2410 | ellipsisIdx = idx; |
2411 | break; |
2412 | } |
2413 | } |
2414 | // Note: It is unclear from the TFLite specification how to derive the |
2415 | // number of dimensions associated to the ellipsis. We use the fact that |
2416 | // when ellipsis is used the associated dimensions from "begin" and "end" |
2417 | // arrays are commonly marked both with 0. |
2418 | size_t idx = ellipsisIdx; |
2419 | while ((idx < beginV.size()) && (beginV[idx] == 0) && (endV[idx] == 0)) { |
2420 | beginV[idx] = (stridesV[idx] > 0) ? 0 : (inpDims[idx] - 1); |
2421 | endV[idx] = (stridesV[idx] > 0) ? inpDims[idx] : 0; |
2422 | idx++; |
2423 | } |
2424 | } |
2425 | |
2426 | // If the ith bit of new_axis_mask is set, then begin, end, and stride are |
2427 | // ignored and a new length 1 dimension is added at this point in the output |
2428 | // tensor. |
2429 | if (new_axis_mask) { |
2430 | size_t inpIdx = 0; |
2431 | size_t outIdx = 0; |
2432 | std::vector<dim_t> newOutDims; |
2433 | while (inpIdx < inpDims.size()) { |
2434 | if (getMaskBit(new_axis_mask, outIdx)) { |
2435 | newOutDims.push_back(1); |
2436 | } else { |
2437 | newOutDims.push_back(inpDims[inpIdx++]); |
2438 | } |
2439 | outIdx++; |
2440 | } |
2441 | NodeValue output = F_->createReshape(opInfo.name, input, newOutDims); |
2442 | return setOutputNodeValue(op, output); |
2443 | } |
2444 | |
2445 | // If the ith bit of shrink_axis_mask is set, it implies that the ith |
2446 | // specification shrinks the dimensionality by 1, taking on the value at |
2447 | // index begin[i]. end[i] and strides[i] are ignored in this case. |
2448 | if (shrink_axis_mask) { |
2449 | for (size_t idx = 0; idx < beginV.size(); ++idx) { |
2450 | if (getMaskBit(shrink_axis_mask, idx)) { |
2451 | endV[idx] = beginV[idx] + 1; |
2452 | stridesV[idx] = 1; |
2453 | } |
2454 | } |
2455 | } |
2456 | |
2457 | // Currently we only support strides of 1. |
2458 | // TODO: Add support for strides different than 1 (positive or negative) once |
2459 | // supported in Glow. |
2460 | for (size_t idx = 0; idx < stridesV.size(); ++idx) { |
2461 | RETURN_ERR_IF_NOT( |
2462 | stridesV[idx] == 1, |
2463 | opErrMsg(opInfo, "Only stride 1 is currently supported!" )); |
2464 | } |
2465 | |
2466 | // Create Slice node. |
2467 | NodeValue output = F_->createSlice(opInfo.name, input, beginV, endV); |
2468 | |
2469 | // Reshape output if some dimensions were shrunk. |
2470 | if (shrink_axis_mask) { |
2471 | std::vector<dim_t> oldOutDims = output.dims(); |
2472 | std::vector<dim_t> newOutDims; |
2473 | for (size_t idx = 0; idx < oldOutDims.size(); ++idx) { |
2474 | if (getMaskBit(shrink_axis_mask, idx)) { |
2475 | assert(oldOutDims[idx] == 1 && "Shrunk dimension should be 1!" ); |
2476 | } else { |
2477 | newOutDims.push_back(oldOutDims[idx]); |
2478 | } |
2479 | } |
2480 | if (newOutDims.empty()) { |
2481 | newOutDims.push_back(1); |
2482 | } |
2483 | output = F_->createReshape(opInfo.name + ".reshape" , output, newOutDims); |
2484 | } |
2485 | |
2486 | return setOutputNodeValue(op, output); |
2487 | } |
2488 | |
2489 | Error TFLiteModelLoader::loadResizeBilinear(const tflite::Operator *op, |
2490 | const OperatorInfo &opInfo) { |
2491 | const auto *opts = op->builtin_options_as_ResizeBilinearOptions(); |
2492 | NodeValue input; |
2493 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2494 | TypeRef outTy; |
2495 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2496 | |
2497 | bool alignCorners = opts->align_corners(); |
2498 | if (alignCorners) { |
2499 | LOG(WARNING) << opErrMsg(opInfo, "Option 'align_corners' is ignored!" ); |
2500 | } |
2501 | |
2502 | bool halfPixelCenters = opts->half_pixel_centers(); |
2503 | if (halfPixelCenters) { |
2504 | LOG(WARNING) << opErrMsg(opInfo, "Option 'half_pixel_centers' is ignored!" ); |
2505 | } |
2506 | |
2507 | NodeValue output = F_->createResizeBilinear(opInfo.name, input, outTy); |
2508 | return setOutputNodeValue(op, output); |
2509 | } |
2510 | |
2511 | Error TFLiteModelLoader::loadResizeNearest(const tflite::Operator *op, |
2512 | const OperatorInfo &opInfo) { |
2513 | const auto *opts = op->builtin_options_as_ResizeNearestNeighborOptions(); |
2514 | NodeValue input; |
2515 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2516 | TypeRef outTy; |
2517 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2518 | |
2519 | bool alignCorners = opts->align_corners(); |
2520 | if (alignCorners) { |
2521 | LOG(WARNING) << opErrMsg(opInfo, "Option 'align_corners' is ignored!" ); |
2522 | } |
2523 | |
2524 | bool halfPixelCenters = opts->half_pixel_centers(); |
2525 | if (halfPixelCenters) { |
2526 | LOG(WARNING) << opErrMsg(opInfo, "Option 'half_pixel_centers' is ignored!" ); |
2527 | } |
2528 | |
2529 | NodeValue output = F_->createResizeNearest(opInfo.name, input, outTy); |
2530 | return setOutputNodeValue(op, output); |
2531 | } |
2532 | |
2533 | Error TFLiteModelLoader::loadSpaceToDepth(const tflite::Operator *op, |
2534 | const OperatorInfo &opInfo) { |
2535 | const auto *opts = op->builtin_options_as_SpaceToDepthOptions(); |
2536 | NodeValue input; |
2537 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2538 | TypeRef outTy; |
2539 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2540 | |
2541 | int32_t blockSize = opts->block_size(); |
2542 | |
2543 | NodeValue output = F_->createSpaceToDepth(opInfo.name, input, blockSize); |
2544 | RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy), |
2545 | opErrMsg(opInfo, "Expected output type incorrect!" )); |
2546 | return setOutputNodeValue(op, output); |
2547 | } |
2548 | |
2549 | Error TFLiteModelLoader::loadDepthToSpace(const tflite::Operator *op, |
2550 | const OperatorInfo &opInfo) { |
2551 | const auto *opts = op->builtin_options_as_DepthToSpaceOptions(); |
2552 | NodeValue input; |
2553 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2554 | TypeRef outTy; |
2555 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2556 | |
2557 | int32_t blockSize = opts->block_size(); |
2558 | |
2559 | NodeValue output = F_->createDepthToSpace(opInfo.name, input, blockSize); |
2560 | RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy), |
2561 | opErrMsg(opInfo, "Expected output type incorrect!" )); |
2562 | return setOutputNodeValue(op, output); |
2563 | } |
2564 | |
2565 | Error TFLiteModelLoader::loadCast(const tflite::Operator *op, |
2566 | const OperatorInfo &opInfo) { |
2567 | NodeValue input; |
2568 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2569 | TypeRef outTy; |
2570 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2571 | // The Cast operator has two attributes "in_data_type" and "out_data_type" but |
2572 | // are not used because the input and output types are already available. |
2573 | NodeValue output = F_->createConvertTo(opInfo.name, input, outTy); |
2574 | return setOutputNodeValue(op, output); |
2575 | } |
2576 | |
2577 | Error TFLiteModelLoader::loadGather(const tflite::Operator *op, |
2578 | const OperatorInfo &opInfo) { |
2579 | const auto *opts = op->builtin_options_as_GatherOptions(); |
2580 | NodeValue data; |
2581 | ASSIGN_VALUE_OR_RETURN_ERR(data, getInputNodeValue(op, 0)); |
2582 | NodeValue indices; |
2583 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getInputNodeValue(op, 1)); |
2584 | TypeRef outTy; |
2585 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2586 | |
2587 | unsigned_t axis; |
2588 | ASSIGN_VALUE_OR_RETURN_ERR( |
2589 | axis, getPositiveAxis<unsigned_t>(opts->axis(), data.dims().size())); |
2590 | |
2591 | NodeValue output = F_->createGather(opInfo.name, data, indices, axis); |
2592 | RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy), |
2593 | opErrMsg(opInfo, "Expected output type incorrect!" )); |
2594 | return setOutputNodeValue(op, output); |
2595 | } |
2596 | |
2597 | Error TFLiteModelLoader::loadGatherND(const tflite::Operator *op, |
2598 | const OperatorInfo &opInfo) { |
2599 | NodeValue data; |
2600 | ASSIGN_VALUE_OR_RETURN_ERR(data, getInputNodeValue(op, 0)); |
2601 | NodeValue indices; |
2602 | ASSIGN_VALUE_OR_RETURN_ERR(indices, getInputNodeValue(op, 1)); |
2603 | TypeRef outTy; |
2604 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2605 | |
2606 | NodeValue output = F_->createGatherND(opInfo.name, data, indices); |
2607 | RETURN_ERR_IF_NOT(output.getType()->isEqual(outTy), |
2608 | opErrMsg(opInfo, "Expected output type incorrect!" )); |
2609 | return setOutputNodeValue(op, output); |
2610 | } |
2611 | |
2612 | Error TFLiteModelLoader::loadSelect(const tflite::Operator *op, |
2613 | const OperatorInfo &opInfo) { |
2614 | NodeValue cond; |
2615 | ASSIGN_VALUE_OR_RETURN_ERR(cond, getInputNodeValue(op, 0)); |
2616 | NodeValue LHS; |
2617 | ASSIGN_VALUE_OR_RETURN_ERR(LHS, getInputNodeValue(op, 1)); |
2618 | NodeValue RHS; |
2619 | ASSIGN_VALUE_OR_RETURN_ERR(RHS, getInputNodeValue(op, 2)); |
2620 | TypeRef outTy; |
2621 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2622 | |
2623 | NodeValue output = F_->createSelect(opInfo.name, outTy, cond, LHS, RHS); |
2624 | return setOutputNodeValue(op, output); |
2625 | } |
2626 | |
2627 | Error TFLiteModelLoader::loadSpaceToBatchNd(const tflite::Operator *op, |
2628 | const OperatorInfo &opInfo) { |
2629 | NodeValue input; |
2630 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2631 | NodeValue block; |
2632 | ASSIGN_VALUE_OR_RETURN_ERR(block, getInputNodeValue(op, 1)); |
2633 | NodeValue pads; |
2634 | ASSIGN_VALUE_OR_RETURN_ERR(pads, getInputNodeValue(op, 2)); |
2635 | TypeRef outTy; |
2636 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2637 | |
2638 | // Implemented as Pad -> Transpose N-D -> SpaceToDepth -> Transpose N-D |
2639 | |
2640 | // Get Block Size dimensionality. Should be 2 but pretend it's arbitrary like |
2641 | // in tensorflow. |
2642 | int32_t blockSize = block.dims()[0]; |
2643 | |
2644 | // Validate block size input. |
2645 | RETURN_ERR_IF_NOT(block.dims().size() == 1, |
2646 | opErrMsg(opInfo, "Block Size should be 1D" )); |
2647 | auto *blockC = llvm::dyn_cast<Constant>(block.getNode()); |
2648 | RETURN_ERR_IF_NOT( |
2649 | blockC, opErrMsg(opInfo, "Non constant 'Block Size' not supported" )); |
2650 | RETURN_ERR_IF_NOT(blockC->getType()->getElementType() == ElemKind::Int32ITy, |
2651 | opErrMsg(opInfo, "Block Size should have INT32 type" )); |
2652 | |
2653 | // Validate paddings. |
2654 | auto *padsC = llvm::dyn_cast<Constant>(pads.getNode()); |
2655 | RETURN_ERR_IF_NOT(padsC, |
2656 | opErrMsg(opInfo, "Non constant 'paddings' not supported" )); |
2657 | RETURN_ERR_IF_NOT(padsC->getType()->getElementType() == ElemKind::Int32ITy, |
2658 | opErrMsg(opInfo, "Paddings should have INT32 type" )); |
2659 | RETURN_ERR_IF_NOT(pads.dims().size() == 2, |
2660 | opErrMsg(opInfo, "Paddings should be 2D" )); |
2661 | for (dim_t i = 0; i < pads.dims().size(); i++) { |
2662 | RETURN_ERR_IF_NOT( |
2663 | pads.dims()[i] == (dim_t)blockSize, |
2664 | opErrMsg(opInfo, |
2665 | "Each Padding should have Block Size number of values" )); |
2666 | } |
2667 | |
2668 | // Get Block Size values. |
2669 | std::vector<int32_t> blockV; |
2670 | helperSetter<int32_t>(blockC, blockV); |
2671 | auto elemEqual = std::equal(blockV.begin() + 1, blockV.end(), blockV.begin()); |
2672 | RETURN_ERR_IF_NOT( |
2673 | elemEqual, |
2674 | opErrMsg(opInfo, "Different Block Size values not supported yet." )); |
2675 | |
2676 | auto inDims = input.getType()->dims(); |
2677 | |
2678 | // Create Padding Node. |
2679 | auto padsH = padsC->getPayload().getHandle<int32_t>(); |
2680 | std::vector<int> padsVec(2 * inDims.size()); |
2681 | for (int32_t i = 0; i < blockSize; i++) { |
2682 | // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds |
2683 | padsVec[i + 1] = padsH.raw(2 * i + 0); |
2684 | padsVec[i + 1 + inDims.size()] = padsH.raw(2 * i + 1); |
2685 | } |
2686 | ShapeVector padDims(inDims.begin(), inDims.end()); |
2687 | for (dim_t i = 0; i < padDims.size(); i++) { |
2688 | // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds |
2689 | padDims[i] += (padsVec[i] + padsVec[i + inDims.size()]); |
2690 | } |
2691 | auto OT = mod_.uniqueTypeWithNewShape(input.getType(), padDims); |
2692 | NodeValue pad = F_->createPad(opInfo.name, input, OT, PaddingMode::CONSTANT, |
2693 | padsVec, 0.f); |
2694 | |
2695 | // Check expected dimensions. |
2696 | auto padInDims = pad.getType()->dims(); |
2697 | auto outDims = outTy->dims(); |
2698 | int32_t calcblockSize = 0; |
2699 | for (int32_t i = 0; i < blockSize; i++) { |
2700 | RETURN_ERR_IF_NOT( |
2701 | padInDims[i + 1] % outDims[i + 1] == 0, |
2702 | opErrMsg(opInfo, strFormat("Input value %d must be a multiple of " |
2703 | "output value %d for space dim %d" , |
2704 | (int)padInDims[1], (int)outDims[1], i))); |
2705 | int32_t newBlockSize = padInDims[i + 1] / outDims[i + 1]; |
2706 | if (i > 1) { |
2707 | RETURN_ERR_IF_NOT( |
2708 | calcblockSize == newBlockSize, |
2709 | opErrMsg(opInfo, "Block Size for H & W must be the same." )); |
2710 | calcblockSize = newBlockSize; |
2711 | } |
2712 | } |
2713 | |
2714 | // Transpose N and D and perform SpaceToDepth. |
2715 | auto *trIn = |
2716 | F_->createTranspose(opInfo.name + ".trIn" , pad, {3, 1, 2, 0}, "NHWC" ); |
2717 | // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds |
2718 | auto *S2D = F_->createSpaceToDepth(opInfo.name + ".S2D" , trIn, blockV[0]); |
2719 | auto *output = |
2720 | F_->createTranspose(opInfo.name + ".trOut" , S2D, {3, 1, 2, 0}, "NHWC" ); |
2721 | |
2722 | return setOutputNodeValue(op, output); |
2723 | } |
2724 | |
2725 | Error TFLiteModelLoader::loadBatchToSpaceNd(const tflite::Operator *op, |
2726 | const OperatorInfo &opInfo) { |
2727 | NodeValue input; |
2728 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2729 | NodeValue block; |
2730 | ASSIGN_VALUE_OR_RETURN_ERR(block, getInputNodeValue(op, 1)); |
2731 | NodeValue crop; |
2732 | ASSIGN_VALUE_OR_RETURN_ERR(crop, getInputNodeValue(op, 2)); |
2733 | TypeRef outTy; |
2734 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2735 | |
2736 | // Implemented as Transpose N-D -> DepthToSpace -> Transpose N-D -> Crop |
2737 | |
2738 | // Get Block Size dimensionality. Should be 2 but pretend it's arbitrary like |
2739 | // in tensorflow. |
2740 | int32_t blockSize = block.dims()[0]; |
2741 | |
2742 | // Validate block input. |
2743 | RETURN_ERR_IF_NOT(block.dims().size() == 1, |
2744 | opErrMsg(opInfo, "Block Size should be 1D" )); |
2745 | auto *blockC = llvm::dyn_cast<Constant>(block.getNode()); |
2746 | RETURN_ERR_IF_NOT( |
2747 | blockC, opErrMsg(opInfo, "Non constant 'Block Size' not supported" )); |
2748 | RETURN_ERR_IF_NOT(blockC->getType()->getElementType() == ElemKind::Int32ITy, |
2749 | opErrMsg(opInfo, "Block Size should have INT32 type" )); |
2750 | |
2751 | // Validate crop input. |
2752 | auto *cropC = llvm::dyn_cast<Constant>(crop.getNode()); |
2753 | RETURN_ERR_IF_NOT(cropC, |
2754 | opErrMsg(opInfo, "Non constant 'crops' not supported" )); |
2755 | RETURN_ERR_IF_NOT(cropC->getType()->getElementType() == ElemKind::Int32ITy, |
2756 | opErrMsg(opInfo, "Paddings should have INT32 type" )); |
2757 | RETURN_ERR_IF_NOT(crop.dims().size() == 2, |
2758 | opErrMsg(opInfo, "Crops should be 2D" )); |
2759 | for (dim_t i = 0; i < crop.dims().size(); i++) { |
2760 | RETURN_ERR_IF_NOT( |
2761 | crop.dims()[i] == (dim_t)blockSize, |
2762 | opErrMsg(opInfo, "Each crop should have Block Size number of values" )); |
2763 | } |
2764 | |
2765 | // Get Block Size values. |
2766 | std::vector<int32_t> blockV; |
2767 | helperSetter<int32_t>(blockC, blockV); |
2768 | auto elemEqual = std::equal(blockV.begin() + 1, blockV.end(), blockV.begin()); |
2769 | RETURN_ERR_IF_NOT( |
2770 | elemEqual, |
2771 | opErrMsg(opInfo, "Different Block Size values not supported yet." )); |
2772 | |
2773 | // Transpose N and D and perform DepthToSpace. |
2774 | auto *tr1 = |
2775 | F_->createTranspose(opInfo.name + ".trPre" , input, {3, 1, 2, 0}, "NHWC" ); |
2776 | // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds |
2777 | auto *D2S = F_->createDepthToSpace(opInfo.name + ".D2S" , tr1, blockV[0]); |
2778 | auto *tr2 = |
2779 | F_->createTranspose(opInfo.name + ".trPost" , D2S, {3, 1, 2, 0}, "NHWC" ); |
2780 | |
2781 | // Create Crop Node. |
2782 | RETURN_ERR_IF_NOT(cropC->getType()->getElementType() == ElemKind::Int32ITy, |
2783 | opErrMsg(opInfo, "Paddings should have INT32 type" )); |
2784 | auto cropH = cropC->getPayload().getHandle<int32_t>(); |
2785 | std::vector<dim_t> cropVec(input.getType()->dims().size()); |
2786 | for (int32_t i = 0; i < blockSize; i++) { |
2787 | // @lint-ignore CLANGTIDY LocalUncheckedArrayBounds |
2788 | cropVec[i + 1] = cropH.raw(2 * i + 0); |
2789 | } |
2790 | NodeValue output = F_->createSlice(opInfo.name, tr2, cropVec, outTy); |
2791 | return setOutputNodeValue(op, output); |
2792 | } |
2793 | |
2794 | Error TFLiteModelLoader::loadTile(const tflite::Operator *op, |
2795 | const OperatorInfo &opInfo) { |
2796 | NodeValue input; |
2797 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2798 | NodeValue multiples; |
2799 | ASSIGN_VALUE_OR_RETURN_ERR(multiples, getInputNodeValue(op, 1)); |
2800 | |
2801 | auto numDims = input.getType()->dims().size(); |
2802 | std::vector<unsigned_t> numTiles; |
2803 | ASSIGN_VALUE_OR_RETURN_ERR(numTiles, |
2804 | loadArray<unsigned_t>(opInfo, multiples)); |
2805 | RETURN_ERR_IF_NOT(numTiles.size() == numDims, |
2806 | opErrMsg(opInfo, "Input operand 'multiples' length should " |
2807 | "match the number of input dimensions!" )); |
2808 | |
2809 | NodeValue output = input; |
2810 | for (unsigned_t axis = 0; axis < numDims; ++axis) { |
2811 | unsigned_t tiles = numTiles[axis]; |
2812 | if (tiles != 1) { |
2813 | output = F_->createTile(opInfo.name + std::to_string(axis), output, tiles, |
2814 | axis); |
2815 | } |
2816 | } |
2817 | return setOutputNodeValue(op, output); |
2818 | } |
2819 | |
2820 | Error TFLiteModelLoader::loadPack(const tflite::Operator *op, |
2821 | const OperatorInfo &opInfo) { |
2822 | const auto *opts = op->builtin_options_as_PackOptions(); |
2823 | TypeRef outTy; |
2824 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2825 | |
2826 | const size_t numInputs = op->inputs()->size(); |
2827 | RETURN_ERR_IF_NOT(int(numInputs) == opts->values_count(), |
2828 | opErrMsg(opInfo, "Attribute 'values_count' does not match " |
2829 | "the number of operator inputs!" )); |
2830 | llvm::SmallVector<NodeValue, 4> inputs; |
2831 | inputs.reserve(numInputs); |
2832 | for (size_t idx = 0; idx < numInputs; ++idx) { |
2833 | NodeValue input; |
2834 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, idx)); |
2835 | inputs.push_back(input); |
2836 | } |
2837 | |
2838 | unsigned_t axis; |
2839 | ASSIGN_VALUE_OR_RETURN_ERR( |
2840 | axis, getPositiveAxis<unsigned_t>(opts->axis(), outTy->dims().size())); |
2841 | |
2842 | // Validate that all inputs have same shape. |
2843 | for (size_t idx = 0; idx < numInputs; ++idx) { |
2844 | RETURN_ERR_IF_NOT( |
2845 | inputs[idx].getType()->isEqual(inputs[0].getType()), |
2846 | opErrMsg(opInfo, "Operator inputs do not have same type/shape!" )); |
2847 | } |
2848 | |
2849 | // Reshape all inputs from [D0,D1,...,DN] to [D0,D1,...,1,...,DN] where the |
2850 | // the singular dimension 1 is on the axis position. |
2851 | std::vector<dim_t> inputDimsReshaped = inputs[0].dims(); |
2852 | inputDimsReshaped.insert(inputDimsReshaped.begin() + axis, 1); |
2853 | for (size_t idx = 0; idx < numInputs; ++idx) { |
2854 | inputs[idx] = |
2855 | F_->createReshape(opInfo.name + ".Reshape" + std::to_string(idx), |
2856 | inputs[idx], inputDimsReshaped); |
2857 | } |
2858 | |
2859 | // Concatenate all inputs along axis. |
2860 | NodeValue output = |
2861 | F_->createConcat(opInfo.name + ".Concat" , inputs, axis, outTy); |
2862 | return setOutputNodeValue(op, output); |
2863 | } |
2864 | |
2865 | Error TFLiteModelLoader::loadUnpack(const tflite::Operator *op, |
2866 | const OperatorInfo &opInfo) { |
2867 | const auto *opts = op->builtin_options_as_UnpackOptions(); |
2868 | NodeValue input; |
2869 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2870 | TypeRef outTy; |
2871 | ASSIGN_VALUE_OR_RETURN_ERR(outTy, getOutputType(op, 0)); |
2872 | |
2873 | unsigned_t axis; |
2874 | ASSIGN_VALUE_OR_RETURN_ERR(axis, |
2875 | getPositiveAxis<unsigned_t>(opts->axis(), input)); |
2876 | |
2877 | unsigned_t num = static_cast<unsigned_t>(opts->num()); |
2878 | RETURN_ERR_IF_NOT( |
2879 | num == input.dims()[axis], |
2880 | opErrMsg(opInfo, |
2881 | "Attribute 'num' should be equal to input size along axis!" )); |
2882 | |
2883 | // Split input. |
2884 | std::vector<SliceNode *> outputNodes; |
2885 | F_->createSplit(opInfo.name, input, num, axis, {}, outputNodes); |
2886 | |
2887 | // Reshape outputs. |
2888 | std::vector<NodeValue> outputNodeValues(outputNodes.size(), nullptr); |
2889 | for (size_t idx = 0, end = outputNodes.size(); idx < end; ++idx) { |
2890 | outputNodeValues[idx] = outputNodes[idx]->getResult(); |
2891 | outputNodeValues[idx] = |
2892 | F_->createReshape(opInfo.name + ".Reshape" + std::to_string(idx), |
2893 | outputNodeValues[idx], outTy->dims()); |
2894 | } |
2895 | return setOutputNodeValues(op, outputNodeValues); |
2896 | } |
2897 | |
2898 | Error TFLiteModelLoader::loadTFLiteDetectionPostProcess( |
2899 | const tflite::Operator *op, const OperatorInfo &opInfo, |
2900 | const flexbuffers::Map &opts) { |
2901 | NodeValue boxes; |
2902 | ASSIGN_VALUE_OR_RETURN_ERR(boxes, getInputNodeValue(op, 0)); |
2903 | NodeValue scores; |
2904 | ASSIGN_VALUE_OR_RETURN_ERR(scores, getInputNodeValue(op, 1)); |
2905 | NodeValue anchors; |
2906 | ASSIGN_VALUE_OR_RETURN_ERR(anchors, getInputNodeValue(op, 2)); |
2907 | |
2908 | // Note: We cannot use the output types of the node because they are dynamic. |
2909 | // We create instead static types for this node with fixed sizes. |
2910 | |
2911 | // Get operator attributes. |
2912 | int32_t numClasses = opts["num_classes" ].AsInt32(); |
2913 | int32_t maxDetections = opts["max_detections" ].AsInt32(); |
2914 | int32_t maxClassesPerDetection = opts["max_classes_per_detection" ].AsInt32(); |
2915 | constexpr int32_t defaultNumDetectionsPerClass = 100; |
2916 | int32_t maxDetectionsPerClass = (opts["detections_per_class" ].IsNull()) |
2917 | ? defaultNumDetectionsPerClass |
2918 | : opts["detections_per_class" ].AsInt32(); |
2919 | float iouThreshold = opts["nms_iou_threshold" ].AsFloat(); |
2920 | float scoreThreshold = opts["nms_score_threshold" ].AsFloat(); |
2921 | float xScale = opts["x_scale" ].AsFloat(); |
2922 | float yScale = opts["y_scale" ].AsFloat(); |
2923 | float hScale = opts["h_scale" ].AsFloat(); |
2924 | float wScale = opts["w_scale" ].AsFloat(); |
2925 | bool regularNMS = (opts["use_regular_nms" ].IsNull()) |
2926 | ? false |
2927 | : opts["use_regular_nms" ].AsBool(); |
2928 | |
2929 | // Create node. |
2930 | auto *node = F_->createTFLiteDetectionPostProcess( |
2931 | opInfo.name, boxes, scores, anchors, numClasses, maxDetections, |
2932 | maxClassesPerDetection, maxDetectionsPerClass, iouThreshold, |
2933 | scoreThreshold, xScale, yScale, hScale, wScale, regularNMS); |
2934 | std::vector<NodeValue> outputNodeValues = { |
2935 | node->getDetectionBoxes(), node->getDetectionClasses(), |
2936 | node->getDetectionScores(), node->getNumDetections()}; |
2937 | return setOutputNodeValues(op, outputNodeValues); |
2938 | } |
2939 | |
2940 | Error TFLiteModelLoader::loadTFLiteAudioSpectrogram( |
2941 | const tflite::Operator *op, const OperatorInfo &opInfo, |
2942 | const flexbuffers::Map &opts) { |
2943 | NodeValue input; |
2944 | ASSIGN_VALUE_OR_RETURN_ERR(input, getInputNodeValue(op, 0)); |
2945 | |
2946 | // Get operator attributes. |
2947 | int32_t windowSize = opts["window_size" ].AsInt32(); |
2948 | int32_t windowStride = opts["stride" ].AsInt32(); |
2949 | bool magnitudeSquared = opts["magnitude_squared" ].AsBool(); |
2950 | |
2951 | // Create node. |
2952 | NodeValue output = F_->createAudioSpectrogram(opInfo.name, input, windowSize, |
2953 | windowStride, magnitudeSquared); |
2954 | return setOutputNodeValue(op, output, /* checkType */ false); |
2955 | } |
2956 | |
2957 | Error TFLiteModelLoader::loadTFLiteMFCC(const tflite::Operator *op, |
2958 | const OperatorInfo &opInfo, |
2959 | const flexbuffers::Map &opts) { |
2960 | NodeValue spectrogram; |
2961 | ASSIGN_VALUE_OR_RETURN_ERR(spectrogram, getInputNodeValue(op, 0)); |
2962 | |
2963 | // Get operator attributes. |
2964 | float sampleRate = (opts["sample_rate" ].IsNull()) |
2965 | ? tfliteMfccSampleRateOpt |
2966 | : opts["sample_rate" ].AsFloat(); |
2967 | float lowerFrequency = opts["lower_frequency_limit" ].AsFloat(); |
2968 | float upperFrequency = opts["upper_frequency_limit" ].AsFloat(); |
2969 | int32_t filterBankCount = opts["filterbank_channel_count" ].AsInt32(); |
2970 | int32_t numCoefficients = opts["dct_coefficient_count" ].AsInt32(); |
2971 | |
2972 | // Create node. |
2973 | NodeValue output = |
2974 | F_->createMFCC(opInfo.name, spectrogram, sampleRate, lowerFrequency, |
2975 | upperFrequency, filterBankCount, numCoefficients); |
2976 | return setOutputNodeValue(op, output, /* checkType */ false); |
2977 | } |
2978 | |
2979 | TFLiteModelLoader::TFLiteModelLoader(const std::string &modelFilename, |
2980 | Function *F) |
2981 | : F_(F), mod_(*F->getParent()) { |
2982 | auto setup = [&]() -> Error { |
2983 | // Read model. |
2984 | std::vector<char> modelData; |
2985 | ASSIGN_VALUE_OR_RETURN_ERR(model_, readModel(modelData, modelFilename)); |
2986 | |
2987 | // TODO: Verify model integrity using flatbuffers::Verifier class. |
2988 | |
2989 | // Get model info. |
2990 | modelVersion_ = model_->version(); |
2991 | modelDescription_ = model_->description()->str(); |
2992 | |
2993 | // Get model graph. |
2994 | const auto *modelGraphs = model_->subgraphs(); |
2995 | RETURN_ERR_IF_NOT( |
2996 | modelGraphs->size() == 1, |
2997 | "TensorFlowLite: Only one model subgraph is currently supported!" ); |
2998 | graph_ = (*modelGraphs)[0]; |
2999 | |
3000 | // Initialize graph node values. |
3001 | initializeNodeValues(); |
3002 | |
3003 | // Load graph input placeholders. |
3004 | RETURN_IF_ERR(loadInputPlaceholders()); |
3005 | |
3006 | // Load graph constants. |
3007 | RETURN_IF_ERR(loadConstants()); |
3008 | |
3009 | // Load graph operators. |
3010 | RETURN_IF_ERR(loadOperators()); |
3011 | |
3012 | // Save graph output placeholders. |
3013 | RETURN_IF_ERR(saveOutputPlaceholders()); |
3014 | |
3015 | // Verify function. |
3016 | RETURN_ERR_IF_NOT(F_->verify(), |
3017 | "TensorFlowLite: Function verification failed!" ); |
3018 | |
3019 | return Error::success(); |
3020 | }; |
3021 | |
3022 | EXIT_ON_ERR(setup()); |
3023 | } |
3024 | |