1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef GLOW_IMPORTER_TFLITEMODELLOADER_H
18#define GLOW_IMPORTER_TFLITEMODELLOADER_H
19
20#include "glow/Graph/Graph.h"
21
22#define FLATBUFFERS_LOCALE_INDEPENDENT 0
23#include "flatbuffers/flexbuffers.h"
24#include "schema_generated.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/StringRef.h"
28
29namespace glow {
30
31/// Loads TensorFlowLite models.
32class TFLiteModelLoader {
33
34 /// TensorFlowLite model object.
35 const tflite::Model *model_{nullptr};
36
37 /// TensorFlowLite current graph object.
38 const tflite::SubGraph *graph_{nullptr};
39
40 /// TensorFlowLite model version.
41 size_t modelVersion_;
42
43 /// TensorFlowLite model description.
44 std::string modelDescription_;
45
46 /// The Glow function which is currently constructed from \ref graph_.
47 Function *F_{nullptr};
48
49 /// The Glow module containing the function(s) we are constructing.
50 Module &mod_;
51
52 /// A map from names of the model inputs to placeholders.
53 llvm::StringMap<Placeholder *> inputPlaceholderByName_;
54
55 /// A map from names of the model outputs to placeholders.
56 llvm::StringMap<Placeholder *> outputPlaceholderByName_;
57
58 /// Vector with node values ordered by their corresponding tensor positions
59 /// in the original model. This vector contains only the node values (tensors)
60 /// registered in the original model since only those are needed for chaining
61 /// the model graph operators. Other node values created during graph loading
62 /// will not be registered in this vector.
63 std::vector<NodeValue> nodeValueByIndex_;
64
65 /// \returns a tensor from the model using the index \p index.
66 Expected<const tflite::Tensor *> getTensorByIndex(size_t index);
67
68 /// \returns the name of the tensor \p tensor.
69 std::string getTensorName(const tflite::Tensor *tensor);
70
71 /// \returns the shape of the tensor \p tensor.
72 Expected<std::vector<dim_t>> getTensorShape(const tflite::Tensor *tensor);
73
74 /// \returns whether the shape of the tensor \p tensor is undefined.
75 Expected<bool> isTensorShapeUndefined(const tflite::Tensor *tensor);
76
77 /// \returns the element type of the tensor \p tensor.
78 Expected<ElemKind> getTensorElemKind(const tflite::Tensor *tensor);
79
80 /// \returns whether the tensor \p tensor is quantized or not.
81 bool isTensorQuantized(const tflite::Tensor *tensor);
82
83 /// \returns whether the tensor \p tensor is quantized per axis or not.
84 bool isTensorPerAxisQuantized(const tflite::Tensor *tensor);
85
86 /// \returns the scale quantization parameter of the tensor \p tensor.
87 Expected<float> getTensorScale(const tflite::Tensor *tensor);
88
89 /// \returns the offset quantization parameter of the tensor \p tensor.
90 Expected<int32_t> getTensorOffset(const tflite::Tensor *tensor);
91
92 /// \returns the scales quantization parameters of the tensor \p tensor.
93 Expected<std::vector<float>> getTensorScales(const tflite::Tensor *tensor);
94
95 /// \returns the offsets quantization parameters of the tensor \p tensor.
96 Expected<std::vector<int32_t>> getTensorOffsets(const tflite::Tensor *tensor);
97
98 /// \returns the type of the tensor \p tensor.
99 Expected<Type> getTensorType(const tflite::Tensor *tensor);
100
101 /// \returns the data pointer and the size of tensor \p tensor as a pair.
102 Expected<std::pair<const char *, size_t>>
103 getTensorDataAndSize(const tflite::Tensor *tensor);
104
105 /// \returns the operator code of the operator \p op.
106 Expected<tflite::BuiltinOperator> getOperatorCode(const tflite::Operator *op);
107
108 /// \returns the operator custom code of the operator \p op.
109 Expected<std::string> getOperatorCustomCode(const tflite::Operator *op);
110
111 /// \returns the operator version of the operator \p op.
112 Expected<int32_t> getOperatorVersion(const tflite::Operator *op);
113
114 /// \returns the operator type of the operator \p op.
115 Expected<std::string> getOperatorType(const tflite::Operator *op);
116
117 /// \returns the operator name of the operator \p op.
118 Expected<std::string> getOperatorName(const tflite::Operator *op);
119
120 /// \returns the operator custom options as a map.
121 Expected<flexbuffers::Map> getOperatorCustomOpts(const tflite::Operator *op);
122
123 /// \returns the tensor index of the input operand with index \p inputIdx
124 /// of the operator \p op. This function returns a negative number if the
125 /// tensor does not exist (is not used).
126 Expected<int32_t> getOperatorInputTensorIdx(const tflite::Operator *op,
127 size_t inputIdx);
128
129 /// \returns the tensor index of the output operand with index \p outputIdx
130 /// of the operator \p op.
131 Expected<size_t> getOperatorOutputTensorIdx(const tflite::Operator *op,
132 size_t outputIdx);
133
134 /// \returns whether the output operand with index \p outputIdx of the
135 /// operator \p op is a final tensor (graph output placeholder).
136 Expected<bool> isOperatorOutputFinalTensor(const tflite::Operator *op,
137 size_t outputIdx);
138
139 /// \returns Expected<NodeValue> if a node value is registered in the array
140 /// \ref nodeValueByIndex_ with \p index (tensor index) and Error otherwise.
141 Expected<NodeValue> getNodeValueByIndex(size_t index);
142
143 /// Set a node value \p nodeValue using \p index (tensor index) in the array
144 /// \ref nodeValueByIndex_. \returns Error if \p index is invalid.
145 Error setNodeValueByIndex(size_t index, NodeValue nodeValue);
146
147 /// \returns Expected<NodeValue> if an input node value with the given index
148 /// \p inputIdx (operator level index that is 0 for 1st input node value, etc)
149 /// is found for the operator \p op and Error otherwise.
150 Expected<NodeValue> getInputNodeValue(const tflite::Operator *op,
151 size_t inputIdx);
152
153 /// Register the single output node value \p nodeValue for the operator \p op.
154 /// The flag \p checkType specifies whether the type of the given node value
155 /// is checked to be the same with the type registered in the model.
156 Error setOutputNodeValue(const tflite::Operator *op, NodeValue nodeValue,
157 bool checkType = true);
158
159 /// Register multiple output node values \p nodeValues for the operator \p op.
160 /// The flag \p checkType specifies whether the type of the given node value
161 /// is checked to be the same with the type registered in the model.
162 Error setOutputNodeValues(const tflite::Operator *op,
163 llvm::ArrayRef<NodeValue> nodeValues,
164 bool checkType = true);
165
166 /// \returns the output type for operator \p op with index \p outputIndex.
167 Expected<TypeRef> getOutputType(const tflite::Operator *op,
168 size_t outputIndex);
169
170 /// \returns whether the output shape for operator \p op with index
171 /// \p outputIndex is undefined.
172 Expected<bool> isOutputShapeUndefined(const tflite::Operator *op,
173 size_t outputIndex);
174
175 /// Initialize the node value array \ref nodeValueByIndex_.
176 void initializeNodeValues();
177
178 /// Load the input placeholders of the current graph.
179 Error loadInputPlaceholders();
180
181 /// Load the constant weights of the current graph.
182 Error loadConstants();
183
184 /// Load the operators of the current graph.
185 Error loadOperators();
186
187 /// Save the output placeholders of the current graph.
188 Error saveOutputPlaceholders();
189
190 /// Add an activation function to the node value \p value using the activation
191 /// type \p type. The node value is modified in-place.
192 Error addActivation(NodeValue &value, tflite::ActivationFunctionType type);
193
194 /// Local definition of a POD structure with operator meta information.
195 struct OperatorInfo {
196 std::string name;
197 std::string type;
198 size_t index;
199 tflite::BuiltinOperator code;
200 int32_t version;
201 };
202
203 /// Utility function to extend the error message \p errMsg with the operator
204 /// context provided by \p opInfo. \returns the extended error message.
205 const std::string opErrMsg(const OperatorInfo &opInfo,
206 const std::string &errMsg);
207
208 /// \returns the value of axis given the operator info \p opInfo, the node
209 /// value \p axis which stores the axis value and the node value \p value
210 /// which the axis refers to which is used to wrap the axis value if negative.
211 template <typename T>
212 Expected<T> loadAxis(const OperatorInfo &opInfo, NodeValue axis,
213 NodeValue value);
214
215 /// \returns the value of axes given the operator info \p opInfo, the node
216 /// value \p axes which stores the axes values and the node value \p value
217 /// which the axes refer to which is used to wrap the axes values if negative.
218 template <typename T>
219 Expected<std::vector<T>> loadAxes(const OperatorInfo &opInfo, NodeValue axes,
220 NodeValue value);
221
222 /// \returns the values stored in the node value \p value as a 1D array given
223 /// the operator info \p opInfo. The node value \p value must be a Constant.
224 template <typename T>
225 Expected<std::vector<T>> loadArray(const OperatorInfo &opInfo,
226 NodeValue value);
227
228 /// Helper tool to verify whether the Conv2D or DepthwiseConv2D operator \p op
229 /// with the operator info \p opInfo is quantized per axis. \returns true if
230 /// the operator is quantized per axis and creates new graph constants by
231 /// setting the pointers \p filterScalesC, \p filterOffsetsC, \p biasScalesC
232 /// and \b biasOffsetsC and returns \p false otherwise.
233 Expected<bool> isConv2DPerAxisQuantized(const tflite::Operator *op,
234 const OperatorInfo &opInfo,
235 Constant *&filterScalesC,
236 Constant *&filterOffsetsC,
237 Constant *&biasScalesC,
238 Constant *&biasOffsetsC);
239
240 /// Load the operator \p op into the current graph. \p opInfo provides meta
241 /// information about \p op. \returns Error if operator cannot be loaded.
242 Error loadOperator(const tflite::Operator *op, const OperatorInfo &opInfo);
243
244 /// Load unary arithmetic operator.
245 Error loadUnaryArithmetic(const tflite::Operator *op,
246 const OperatorInfo &opInfo);
247
248 /// Load binary arithmetic operator.
249 Error loadBinaryArithmetic(const tflite::Operator *op,
250 const OperatorInfo &opInfo);
251
252 /// Load Pool2D operator (MaxPool2D or AvgPool2D).
253 Error loadPool2D(const tflite::Operator *op, const OperatorInfo &opInfo);
254
255 /// Load Concatenation operator.
256 Error loadConcat(const tflite::Operator *op, const OperatorInfo &opInfo);
257
258 /// Load Conv2D operator.
259 Error loadConv2D(const tflite::Operator *op, const OperatorInfo &opInfo);
260
261 /// Load DepthwiseConv2D operator.
262 Error loadDepthwiseConv2D(const tflite::Operator *op,
263 const OperatorInfo &opInfo);
264
265 /// Load FullyConnected operator.
266 Error loadFullyConnected(const tflite::Operator *op,
267 const OperatorInfo &opInfo);
268
269 /// Load Reshape operator.
270 Error loadReshape(const tflite::Operator *op, const OperatorInfo &opInfo);
271
272 /// Load Softmax operator.
273 Error loadSoftmax(const tflite::Operator *op, const OperatorInfo &opInfo);
274
275 /// Load LogSoftmax operator.
276 Error loadLogSoftmax(const tflite::Operator *op, const OperatorInfo &opInfo);
277
278 /// Load Pad operator.
279 Error loadPad(const tflite::Operator *op, const OperatorInfo &opInfo);
280
281 /// Load Transpose operator.
282 Error loadTranspose(const tflite::Operator *op, const OperatorInfo &opInfo);
283
284 /// Load Reduce operator.
285 Error loadReduce(const tflite::Operator *op, const OperatorInfo &opInfo);
286
287 /// Load Split operator.
288 Error loadSplit(const tflite::Operator *op, const OperatorInfo &opInfo);
289
290 /// Load Arg operator (ArgMax or ArgMin).
291 Error loadArg(const tflite::Operator *op, const OperatorInfo &opInfo);
292
293 /// Load Shape operator.
294 Error loadShape(const tflite::Operator *op, const OperatorInfo &opInfo);
295
296 /// Load Slice operator.
297 Error loadSlice(const tflite::Operator *op, const OperatorInfo &opInfo);
298
299 /// Load StridedSlice operator.
300 Error loadStridedSlice(const tflite::Operator *op,
301 const OperatorInfo &opInfo);
302
303 /// Load Resize Bilinear operator.
304 Error loadResizeBilinear(const tflite::Operator *op,
305 const OperatorInfo &opInfo);
306
307 /// Load Resize Nearest operator.
308 Error loadResizeNearest(const tflite::Operator *op,
309 const OperatorInfo &opInfo);
310
311 /// Load SpaceToDepth operator.
312 Error loadSpaceToDepth(const tflite::Operator *op,
313 const OperatorInfo &opInfo);
314
315 /// Load DepthToSpace operator.
316 Error loadDepthToSpace(const tflite::Operator *op,
317 const OperatorInfo &opInfo);
318
319 /// Load Cast operator.
320 Error loadCast(const tflite::Operator *op, const OperatorInfo &opInfo);
321
322 /// Load Gather operator.
323 Error loadGather(const tflite::Operator *op, const OperatorInfo &opInfo);
324
325 /// Load Gather ND operator.
326 Error loadGatherND(const tflite::Operator *op, const OperatorInfo &opInfo);
327
328 /// Load Select operator.
329 Error loadSelect(const tflite::Operator *op, const OperatorInfo &opInfo);
330
331 /// Load Space To Batch ND operator.
332 Error loadSpaceToBatchNd(const tflite::Operator *op,
333 const OperatorInfo &opInfo);
334
335 /// Load Batch To Space ND operator.
336 Error loadBatchToSpaceNd(const tflite::Operator *op,
337 const OperatorInfo &opInfo);
338
339 /// Load Tile operator.
340 Error loadTile(const tflite::Operator *op, const OperatorInfo &opInfo);
341
342 /// Load Pack operator.
343 Error loadPack(const tflite::Operator *op, const OperatorInfo &opInfo);
344
345 /// Load Unpack operator.
346 Error loadUnpack(const tflite::Operator *op, const OperatorInfo &opInfo);
347
348 /// Load TFLite Detection PostProcess custom operator.
349 Error loadTFLiteDetectionPostProcess(const tflite::Operator *op,
350 const OperatorInfo &opInfo,
351 const flexbuffers::Map &opts);
352
353 /// Load TFLite Audio Spectrogram custom operator.
354 Error loadTFLiteAudioSpectrogram(const tflite::Operator *op,
355 const OperatorInfo &opInfo,
356 const flexbuffers::Map &opts);
357
358 /// Load TFLite MFCC custom operator.
359 Error loadTFLiteMFCC(const tflite::Operator *op, const OperatorInfo &opInfo,
360 const flexbuffers::Map &opts);
361
362public:
363 /// \returns the TensorFlowLite model version.
364 size_t getModelVersion() const { return modelVersion_; };
365
366 /// \returns the TensorFlowLite model description.
367 std::string getModelDescription() const { return modelDescription_; };
368
369 /// \returns a map between the model input names and the input placeholders.
370 const llvm::StringMap<Placeholder *> &getInputPlaceholderMap() const {
371 return inputPlaceholderByName_;
372 }
373
374 /// \returns a map between the model output names and the output placeholders.
375 const llvm::StringMap<Placeholder *> &getOutputPlaceholderMap() const {
376 return outputPlaceholderByName_;
377 }
378
379 /// Loads the TensorFlowLite model from the file \p modelFilename into the
380 /// function \p F.
381 TFLiteModelLoader(const std::string &modelFilename, Function *F);
382};
383
384} // namespace glow
385
386#endif // GLOW_IMPORTER_TFLITEMODELLOADER_H
387