1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_IMPORTER_TFLITEMODELLOADER_H |
18 | #define GLOW_IMPORTER_TFLITEMODELLOADER_H |
19 | |
20 | #include "glow/Graph/Graph.h" |
21 | |
22 | #define FLATBUFFERS_LOCALE_INDEPENDENT 0 |
23 | #include "flatbuffers/flexbuffers.h" |
24 | #include "schema_generated.h" |
25 | |
26 | #include "llvm/ADT/ArrayRef.h" |
27 | #include "llvm/ADT/StringRef.h" |
28 | |
29 | namespace glow { |
30 | |
31 | /// Loads TensorFlowLite models. |
32 | class TFLiteModelLoader { |
33 | |
34 | /// TensorFlowLite model object. |
35 | const tflite::Model *model_{nullptr}; |
36 | |
37 | /// TensorFlowLite current graph object. |
38 | const tflite::SubGraph *graph_{nullptr}; |
39 | |
40 | /// TensorFlowLite model version. |
41 | size_t modelVersion_; |
42 | |
43 | /// TensorFlowLite model description. |
44 | std::string modelDescription_; |
45 | |
46 | /// The Glow function which is currently constructed from \ref graph_. |
47 | Function *F_{nullptr}; |
48 | |
49 | /// The Glow module containing the function(s) we are constructing. |
50 | Module &mod_; |
51 | |
52 | /// A map from names of the model inputs to placeholders. |
53 | llvm::StringMap<Placeholder *> inputPlaceholderByName_; |
54 | |
55 | /// A map from names of the model outputs to placeholders. |
56 | llvm::StringMap<Placeholder *> outputPlaceholderByName_; |
57 | |
58 | /// Vector with node values ordered by their corresponding tensor positions |
59 | /// in the original model. This vector contains only the node values (tensors) |
60 | /// registered in the original model since only those are needed for chaining |
61 | /// the model graph operators. Other node values created during graph loading |
62 | /// will not be registered in this vector. |
63 | std::vector<NodeValue> nodeValueByIndex_; |
64 | |
65 | /// \returns a tensor from the model using the index \p index. |
66 | Expected<const tflite::Tensor *> getTensorByIndex(size_t index); |
67 | |
68 | /// \returns the name of the tensor \p tensor. |
69 | std::string getTensorName(const tflite::Tensor *tensor); |
70 | |
71 | /// \returns the shape of the tensor \p tensor. |
72 | Expected<std::vector<dim_t>> getTensorShape(const tflite::Tensor *tensor); |
73 | |
74 | /// \returns whether the shape of the tensor \p tensor is undefined. |
75 | Expected<bool> isTensorShapeUndefined(const tflite::Tensor *tensor); |
76 | |
77 | /// \returns the element type of the tensor \p tensor. |
78 | Expected<ElemKind> getTensorElemKind(const tflite::Tensor *tensor); |
79 | |
80 | /// \returns whether the tensor \p tensor is quantized or not. |
81 | bool isTensorQuantized(const tflite::Tensor *tensor); |
82 | |
83 | /// \returns whether the tensor \p tensor is quantized per axis or not. |
84 | bool isTensorPerAxisQuantized(const tflite::Tensor *tensor); |
85 | |
86 | /// \returns the scale quantization parameter of the tensor \p tensor. |
87 | Expected<float> getTensorScale(const tflite::Tensor *tensor); |
88 | |
89 | /// \returns the offset quantization parameter of the tensor \p tensor. |
90 | Expected<int32_t> getTensorOffset(const tflite::Tensor *tensor); |
91 | |
92 | /// \returns the scales quantization parameters of the tensor \p tensor. |
93 | Expected<std::vector<float>> getTensorScales(const tflite::Tensor *tensor); |
94 | |
95 | /// \returns the offsets quantization parameters of the tensor \p tensor. |
96 | Expected<std::vector<int32_t>> getTensorOffsets(const tflite::Tensor *tensor); |
97 | |
98 | /// \returns the type of the tensor \p tensor. |
99 | Expected<Type> getTensorType(const tflite::Tensor *tensor); |
100 | |
101 | /// \returns the data pointer and the size of tensor \p tensor as a pair. |
102 | Expected<std::pair<const char *, size_t>> |
103 | getTensorDataAndSize(const tflite::Tensor *tensor); |
104 | |
105 | /// \returns the operator code of the operator \p op. |
106 | Expected<tflite::BuiltinOperator> getOperatorCode(const tflite::Operator *op); |
107 | |
108 | /// \returns the operator custom code of the operator \p op. |
109 | Expected<std::string> getOperatorCustomCode(const tflite::Operator *op); |
110 | |
111 | /// \returns the operator version of the operator \p op. |
112 | Expected<int32_t> getOperatorVersion(const tflite::Operator *op); |
113 | |
114 | /// \returns the operator type of the operator \p op. |
115 | Expected<std::string> getOperatorType(const tflite::Operator *op); |
116 | |
117 | /// \returns the operator name of the operator \p op. |
118 | Expected<std::string> getOperatorName(const tflite::Operator *op); |
119 | |
120 | /// \returns the operator custom options as a map. |
121 | Expected<flexbuffers::Map> getOperatorCustomOpts(const tflite::Operator *op); |
122 | |
123 | /// \returns the tensor index of the input operand with index \p inputIdx |
124 | /// of the operator \p op. This function returns a negative number if the |
125 | /// tensor does not exist (is not used). |
126 | Expected<int32_t> getOperatorInputTensorIdx(const tflite::Operator *op, |
127 | size_t inputIdx); |
128 | |
129 | /// \returns the tensor index of the output operand with index \p outputIdx |
130 | /// of the operator \p op. |
131 | Expected<size_t> getOperatorOutputTensorIdx(const tflite::Operator *op, |
132 | size_t outputIdx); |
133 | |
134 | /// \returns whether the output operand with index \p outputIdx of the |
135 | /// operator \p op is a final tensor (graph output placeholder). |
136 | Expected<bool> isOperatorOutputFinalTensor(const tflite::Operator *op, |
137 | size_t outputIdx); |
138 | |
139 | /// \returns Expected<NodeValue> if a node value is registered in the array |
140 | /// \ref nodeValueByIndex_ with \p index (tensor index) and Error otherwise. |
141 | Expected<NodeValue> getNodeValueByIndex(size_t index); |
142 | |
143 | /// Set a node value \p nodeValue using \p index (tensor index) in the array |
144 | /// \ref nodeValueByIndex_. \returns Error if \p index is invalid. |
145 | Error setNodeValueByIndex(size_t index, NodeValue nodeValue); |
146 | |
147 | /// \returns Expected<NodeValue> if an input node value with the given index |
148 | /// \p inputIdx (operator level index that is 0 for 1st input node value, etc) |
149 | /// is found for the operator \p op and Error otherwise. |
150 | Expected<NodeValue> getInputNodeValue(const tflite::Operator *op, |
151 | size_t inputIdx); |
152 | |
153 | /// Register the single output node value \p nodeValue for the operator \p op. |
154 | /// The flag \p checkType specifies whether the type of the given node value |
155 | /// is checked to be the same with the type registered in the model. |
156 | Error setOutputNodeValue(const tflite::Operator *op, NodeValue nodeValue, |
157 | bool checkType = true); |
158 | |
159 | /// Register multiple output node values \p nodeValues for the operator \p op. |
160 | /// The flag \p checkType specifies whether the type of the given node value |
161 | /// is checked to be the same with the type registered in the model. |
162 | Error setOutputNodeValues(const tflite::Operator *op, |
163 | llvm::ArrayRef<NodeValue> nodeValues, |
164 | bool checkType = true); |
165 | |
166 | /// \returns the output type for operator \p op with index \p outputIndex. |
167 | Expected<TypeRef> getOutputType(const tflite::Operator *op, |
168 | size_t outputIndex); |
169 | |
170 | /// \returns whether the output shape for operator \p op with index |
171 | /// \p outputIndex is undefined. |
172 | Expected<bool> isOutputShapeUndefined(const tflite::Operator *op, |
173 | size_t outputIndex); |
174 | |
175 | /// Initialize the node value array \ref nodeValueByIndex_. |
176 | void initializeNodeValues(); |
177 | |
178 | /// Load the input placeholders of the current graph. |
179 | Error loadInputPlaceholders(); |
180 | |
181 | /// Load the constant weights of the current graph. |
182 | Error loadConstants(); |
183 | |
184 | /// Load the operators of the current graph. |
185 | Error loadOperators(); |
186 | |
187 | /// Save the output placeholders of the current graph. |
188 | Error saveOutputPlaceholders(); |
189 | |
190 | /// Add an activation function to the node value \p value using the activation |
191 | /// type \p type. The node value is modified in-place. |
192 | Error addActivation(NodeValue &value, tflite::ActivationFunctionType type); |
193 | |
194 | /// Local definition of a POD structure with operator meta information. |
195 | struct OperatorInfo { |
196 | std::string name; |
197 | std::string type; |
198 | size_t index; |
199 | tflite::BuiltinOperator code; |
200 | int32_t version; |
201 | }; |
202 | |
203 | /// Utility function to extend the error message \p errMsg with the operator |
204 | /// context provided by \p opInfo. \returns the extended error message. |
205 | const std::string opErrMsg(const OperatorInfo &opInfo, |
206 | const std::string &errMsg); |
207 | |
208 | /// \returns the value of axis given the operator info \p opInfo, the node |
209 | /// value \p axis which stores the axis value and the node value \p value |
210 | /// which the axis refers to which is used to wrap the axis value if negative. |
211 | template <typename T> |
212 | Expected<T> loadAxis(const OperatorInfo &opInfo, NodeValue axis, |
213 | NodeValue value); |
214 | |
215 | /// \returns the value of axes given the operator info \p opInfo, the node |
216 | /// value \p axes which stores the axes values and the node value \p value |
217 | /// which the axes refer to which is used to wrap the axes values if negative. |
218 | template <typename T> |
219 | Expected<std::vector<T>> loadAxes(const OperatorInfo &opInfo, NodeValue axes, |
220 | NodeValue value); |
221 | |
222 | /// \returns the values stored in the node value \p value as a 1D array given |
223 | /// the operator info \p opInfo. The node value \p value must be a Constant. |
224 | template <typename T> |
225 | Expected<std::vector<T>> loadArray(const OperatorInfo &opInfo, |
226 | NodeValue value); |
227 | |
228 | /// Helper tool to verify whether the Conv2D or DepthwiseConv2D operator \p op |
229 | /// with the operator info \p opInfo is quantized per axis. \returns true if |
230 | /// the operator is quantized per axis and creates new graph constants by |
231 | /// setting the pointers \p filterScalesC, \p filterOffsetsC, \p biasScalesC |
232 | /// and \b biasOffsetsC and returns \p false otherwise. |
233 | Expected<bool> isConv2DPerAxisQuantized(const tflite::Operator *op, |
234 | const OperatorInfo &opInfo, |
235 | Constant *&filterScalesC, |
236 | Constant *&filterOffsetsC, |
237 | Constant *&biasScalesC, |
238 | Constant *&biasOffsetsC); |
239 | |
240 | /// Load the operator \p op into the current graph. \p opInfo provides meta |
241 | /// information about \p op. \returns Error if operator cannot be loaded. |
242 | Error loadOperator(const tflite::Operator *op, const OperatorInfo &opInfo); |
243 | |
244 | /// Load unary arithmetic operator. |
245 | Error loadUnaryArithmetic(const tflite::Operator *op, |
246 | const OperatorInfo &opInfo); |
247 | |
248 | /// Load binary arithmetic operator. |
249 | Error loadBinaryArithmetic(const tflite::Operator *op, |
250 | const OperatorInfo &opInfo); |
251 | |
252 | /// Load Pool2D operator (MaxPool2D or AvgPool2D). |
253 | Error loadPool2D(const tflite::Operator *op, const OperatorInfo &opInfo); |
254 | |
255 | /// Load Concatenation operator. |
256 | Error loadConcat(const tflite::Operator *op, const OperatorInfo &opInfo); |
257 | |
258 | /// Load Conv2D operator. |
259 | Error loadConv2D(const tflite::Operator *op, const OperatorInfo &opInfo); |
260 | |
261 | /// Load DepthwiseConv2D operator. |
262 | Error loadDepthwiseConv2D(const tflite::Operator *op, |
263 | const OperatorInfo &opInfo); |
264 | |
265 | /// Load FullyConnected operator. |
266 | Error loadFullyConnected(const tflite::Operator *op, |
267 | const OperatorInfo &opInfo); |
268 | |
269 | /// Load Reshape operator. |
270 | Error loadReshape(const tflite::Operator *op, const OperatorInfo &opInfo); |
271 | |
272 | /// Load Softmax operator. |
273 | Error loadSoftmax(const tflite::Operator *op, const OperatorInfo &opInfo); |
274 | |
275 | /// Load LogSoftmax operator. |
276 | Error loadLogSoftmax(const tflite::Operator *op, const OperatorInfo &opInfo); |
277 | |
278 | /// Load Pad operator. |
279 | Error loadPad(const tflite::Operator *op, const OperatorInfo &opInfo); |
280 | |
281 | /// Load Transpose operator. |
282 | Error loadTranspose(const tflite::Operator *op, const OperatorInfo &opInfo); |
283 | |
284 | /// Load Reduce operator. |
285 | Error loadReduce(const tflite::Operator *op, const OperatorInfo &opInfo); |
286 | |
287 | /// Load Split operator. |
288 | Error loadSplit(const tflite::Operator *op, const OperatorInfo &opInfo); |
289 | |
290 | /// Load Arg operator (ArgMax or ArgMin). |
291 | Error loadArg(const tflite::Operator *op, const OperatorInfo &opInfo); |
292 | |
293 | /// Load Shape operator. |
294 | Error loadShape(const tflite::Operator *op, const OperatorInfo &opInfo); |
295 | |
296 | /// Load Slice operator. |
297 | Error loadSlice(const tflite::Operator *op, const OperatorInfo &opInfo); |
298 | |
299 | /// Load StridedSlice operator. |
300 | Error loadStridedSlice(const tflite::Operator *op, |
301 | const OperatorInfo &opInfo); |
302 | |
303 | /// Load Resize Bilinear operator. |
304 | Error loadResizeBilinear(const tflite::Operator *op, |
305 | const OperatorInfo &opInfo); |
306 | |
307 | /// Load Resize Nearest operator. |
308 | Error loadResizeNearest(const tflite::Operator *op, |
309 | const OperatorInfo &opInfo); |
310 | |
311 | /// Load SpaceToDepth operator. |
312 | Error loadSpaceToDepth(const tflite::Operator *op, |
313 | const OperatorInfo &opInfo); |
314 | |
315 | /// Load DepthToSpace operator. |
316 | Error loadDepthToSpace(const tflite::Operator *op, |
317 | const OperatorInfo &opInfo); |
318 | |
319 | /// Load Cast operator. |
320 | Error loadCast(const tflite::Operator *op, const OperatorInfo &opInfo); |
321 | |
322 | /// Load Gather operator. |
323 | Error loadGather(const tflite::Operator *op, const OperatorInfo &opInfo); |
324 | |
325 | /// Load Gather ND operator. |
326 | Error loadGatherND(const tflite::Operator *op, const OperatorInfo &opInfo); |
327 | |
328 | /// Load Select operator. |
329 | Error loadSelect(const tflite::Operator *op, const OperatorInfo &opInfo); |
330 | |
331 | /// Load Space To Batch ND operator. |
332 | Error loadSpaceToBatchNd(const tflite::Operator *op, |
333 | const OperatorInfo &opInfo); |
334 | |
335 | /// Load Batch To Space ND operator. |
336 | Error loadBatchToSpaceNd(const tflite::Operator *op, |
337 | const OperatorInfo &opInfo); |
338 | |
339 | /// Load Tile operator. |
340 | Error loadTile(const tflite::Operator *op, const OperatorInfo &opInfo); |
341 | |
342 | /// Load Pack operator. |
343 | Error loadPack(const tflite::Operator *op, const OperatorInfo &opInfo); |
344 | |
345 | /// Load Unpack operator. |
346 | Error loadUnpack(const tflite::Operator *op, const OperatorInfo &opInfo); |
347 | |
348 | /// Load TFLite Detection PostProcess custom operator. |
349 | Error loadTFLiteDetectionPostProcess(const tflite::Operator *op, |
350 | const OperatorInfo &opInfo, |
351 | const flexbuffers::Map &opts); |
352 | |
353 | /// Load TFLite Audio Spectrogram custom operator. |
354 | Error loadTFLiteAudioSpectrogram(const tflite::Operator *op, |
355 | const OperatorInfo &opInfo, |
356 | const flexbuffers::Map &opts); |
357 | |
358 | /// Load TFLite MFCC custom operator. |
359 | Error loadTFLiteMFCC(const tflite::Operator *op, const OperatorInfo &opInfo, |
360 | const flexbuffers::Map &opts); |
361 | |
362 | public: |
363 | /// \returns the TensorFlowLite model version. |
364 | size_t getModelVersion() const { return modelVersion_; }; |
365 | |
366 | /// \returns the TensorFlowLite model description. |
367 | std::string getModelDescription() const { return modelDescription_; }; |
368 | |
369 | /// \returns a map between the model input names and the input placeholders. |
370 | const llvm::StringMap<Placeholder *> &getInputPlaceholderMap() const { |
371 | return inputPlaceholderByName_; |
372 | } |
373 | |
374 | /// \returns a map between the model output names and the output placeholders. |
375 | const llvm::StringMap<Placeholder *> &getOutputPlaceholderMap() const { |
376 | return outputPlaceholderByName_; |
377 | } |
378 | |
379 | /// Loads the TensorFlowLite model from the file \p modelFilename into the |
380 | /// function \p F. |
381 | TFLiteModelLoader(const std::string &modelFilename, Function *F); |
382 | }; |
383 | |
384 | } // namespace glow |
385 | |
386 | #endif // GLOW_IMPORTER_TFLITEMODELLOADER_H |
387 | |