1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | /// \file |
16 | /// Main abstraction controlling the tflite interpreter. |
17 | /// Do NOT include this file directly, |
18 | /// instead include third_party/tensorflow/lite/interpreter.h |
19 | /// See third_party/tensorflow/lite/c/common.h for the API for defining |
20 | /// operations (TfLiteRegistration). |
21 | #ifndef TENSORFLOW_LITE_CORE_IMPL_INTERPRETER_H_ |
22 | #define TENSORFLOW_LITE_CORE_IMPL_INTERPRETER_H_ |
23 | |
24 | // IWYU pragma: private, include "third_party/tensorflow/lite/interpreter.h" |
25 | // IWYU pragma: friend third_party/tensorflow/lite/interpreter.h |
26 | |
27 | #include <stddef.h> |
28 | #include <stdint.h> |
29 | |
30 | #include <complex> |
31 | #include <cstdio> |
32 | #include <cstdlib> |
33 | #include <functional> |
34 | #include <map> |
35 | #include <memory> |
36 | #include <string> |
37 | #include <utility> |
38 | #include <vector> |
39 | |
40 | #include "tensorflow/lite/allocation.h" |
41 | #include "tensorflow/lite/c/common.h" // IWYU pragma: export |
42 | #include "tensorflow/lite/core/api/error_reporter.h" |
43 | #include "tensorflow/lite/core/api/profiler.h" |
44 | #include "tensorflow/lite/core/subgraph.h" |
45 | #include "tensorflow/lite/experimental/resource/initialization_status.h" |
46 | #include "tensorflow/lite/experimental/resource/resource_base.h" |
47 | #include "tensorflow/lite/external_cpu_backend_context.h" |
48 | #include "tensorflow/lite/internal/signature_def.h" |
49 | #include "tensorflow/lite/interpreter_options.h" |
50 | #include "tensorflow/lite/portable_type_to_tflitetype.h" |
51 | #include "tensorflow/lite/profiling/root_profiler.h" |
52 | #include "tensorflow/lite/signature_runner.h" |
53 | #include "tensorflow/lite/stderr_reporter.h" |
54 | #include "tensorflow/lite/string_type.h" |
55 | #include "tensorflow/lite/type_to_tflitetype.h" |
56 | |
57 | namespace tflite { |
58 | |
59 | class InterpreterTest; // Class for friend declarations. |
60 | |
61 | namespace delegates { |
62 | class InterpreterUtils; // Class for friend declarations. |
63 | |
64 | namespace test_utils { |
65 | class TestDelegation; // Class for friend declarations. |
66 | } // namespace test_utils |
67 | } // namespace delegates |
68 | |
69 | namespace interpreter_wrapper { |
70 | class InterpreterWrapper; // Class for friend declarations. |
71 | } // namespace interpreter_wrapper |
72 | |
73 | /// An interpreter for a graph of nodes that input and output from tensors. |
74 | /// Each node of the graph processes a set of input tensors and produces a |
75 | /// set of output Tensors. All inputs/output tensors are referenced by index. |
76 | /// |
77 | /// Usage: |
78 | /// |
79 | /// <pre><code> |
80 | /// // Create model from file. Note that the model instance must outlive the |
81 | /// // interpreter instance. |
82 | /// auto model = tflite::FlatBufferModel::BuildFromFile(...); |
83 | /// if (model == nullptr) { |
84 | /// // Return error. |
85 | /// } |
86 | /// // Create an Interpreter with an InterpreterBuilder. |
87 | /// std::unique_ptr<tflite::Interpreter> interpreter; |
88 | /// tflite::ops::builtin::BuiltinOpResolver resolver; |
89 | /// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { |
90 | /// // Return failure. |
91 | /// } |
92 | /// if (interpreter->AllocateTensors() != kTfLiteOk) { |
93 | /// // Return failure. |
94 | /// } |
95 | /// |
96 | /// auto input = interpreter->typed_tensor<float>(0); |
97 | /// for (int i = 0; i < input_size; i++) { |
98 | /// input[i] = ...; |
99 | // } |
100 | /// interpreter->Invoke(); |
101 | /// </code></pre> |
102 | /// |
103 | /// Note: For nearly all practical use cases, one should not directly construct |
104 | /// an Interpreter object, but rather use the InterpreterBuilder. |
105 | /// |
106 | /// WARNING: This class is *not* thread-safe. The client is responsible for |
107 | /// ensuring serialized interaction to avoid data races and undefined behavior. |
108 | class Interpreter { |
109 | public: |
110 | // Instantiate an interpreter. All errors associated with reading and |
111 | // processing this model will be forwarded to the error_reporter object. |
112 | // |
113 | // Note, if error_reporter is nullptr, then a default StderrReporter is |
114 | // used. Ownership of 'error_reporter' remains with the caller. |
115 | // WARNING: Use of this constructor outside of an InterpreterBuilder is not |
116 | // recommended. |
117 | explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); |
118 | |
119 | ~Interpreter(); |
120 | |
121 | // Interpreters are not copyable as they have non-trivial memory semantics. |
122 | Interpreter(const Interpreter&) = delete; |
123 | Interpreter& operator=(const Interpreter&) = delete; |
124 | |
125 | // Functions to build interpreter |
126 | #ifndef DOXYGEN_SKIP |
127 | /// Provide a list of tensor indexes that are inputs to the model. |
  /// Each index is bounds-checked, and this modifies the consistent_ flag of
129 | /// interpreter. |
130 | TfLiteStatus SetInputs(std::vector<int> inputs); |
131 | |
132 | /// Provide a list of tensor indexes that are outputs to the model |
  /// Each index is bounds-checked, and this modifies the consistent_ flag of
134 | /// interpreter. |
135 | TfLiteStatus SetOutputs(std::vector<int> outputs); |
136 | |
137 | /// Provide a list of tensor indexes that are variable tensors. |
  /// Each index is bounds-checked, and this modifies the consistent_ flag of
139 | /// interpreter. |
140 | TfLiteStatus SetVariables(std::vector<int> variables); |
141 | |
142 | /// Adds a node with the given parameters and returns the index of the new |
143 | /// node in `node_index` (optionally). Interpreter will take ownership of |
144 | /// `builtin_data` and destroy it with `free`. Ownership of 'init_data' |
145 | /// remains with the caller. |
146 | TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs, |
147 | const std::vector<int>& outputs, |
148 | const char* init_data, |
149 | size_t init_data_size, void* builtin_data, |
150 | const TfLiteRegistration* registration, |
151 | int* node_index = nullptr); |
152 | |
153 | /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries. |
154 | /// The value pointed to by `first_new_tensor_index` will be set to the |
155 | /// index of the first new tensor if `first_new_tensor_index` is non-null. |
156 | TfLiteStatus AddTensors(int tensors_to_add, |
157 | int* first_new_tensor_index = nullptr); |
158 | |
159 | /// Set description of inputs/outputs/data/fptrs for node `node_index`. |
160 | /// This variant assumes an external buffer has been allocated of size |
161 | /// bytes. The lifetime of buffer must be ensured to be greater or equal |
162 | /// to Interpreter. |
163 | TfLiteStatus SetTensorParametersReadOnly( |
164 | int tensor_index, TfLiteType type, const char* name, |
165 | const std::vector<int>& dims, TfLiteQuantization quantization, |
166 | const char* buffer, size_t bytes, const Allocation* allocation = nullptr); |
167 | |
168 | /// Legacy. Deprecated in favor of above. |
169 | inline TfLiteStatus SetTensorParametersReadOnly( |
170 | int tensor_index, TfLiteType type, const char* name, |
171 | const std::vector<int>& dims, TfLiteQuantizationParams quantization, |
172 | const char* buffer, size_t bytes, |
173 | const Allocation* allocation = nullptr) { |
174 | return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), |
175 | dims.data(), quantization, buffer, bytes, |
176 | allocation); |
177 | } |
178 | |
179 | TfLiteStatus SetTensorParametersReadOnly( |
180 | int tensor_index, TfLiteType type, const char* name, const size_t rank, |
181 | const int* dims, TfLiteQuantizationParams quantization, |
182 | const char* buffer, size_t bytes, const Allocation* allocation = nullptr); |
183 | |
184 | /// Set description of inputs/outputs/data/fptrs for node `node_index`. |
185 | /// This variant assumes an external buffer has been allocated of size |
186 | /// bytes. The lifetime of buffer must be ensured to be greater or equal |
187 | /// to Interpreter. |
188 | TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type, |
189 | const char* name, |
190 | const std::vector<int>& dims, |
191 | TfLiteQuantization quantization, |
192 | bool is_variable = false); |
193 | |
194 | /// Legacy. Deprecated in favor of above. |
195 | inline TfLiteStatus SetTensorParametersReadWrite( |
196 | int tensor_index, TfLiteType type, const char* name, |
197 | const std::vector<int>& dims, TfLiteQuantizationParams quantization, |
198 | bool is_variable = false, |
199 | const std::vector<int>* dims_signature = nullptr) { |
200 | size_t rank_dims_signature = 0; |
201 | const int* dims_signature_pointer = nullptr; |
202 | if (dims_signature) { |
203 | rank_dims_signature = dims_signature->size(); |
204 | dims_signature_pointer = dims_signature->data(); |
205 | } |
206 | return SetTensorParametersReadWrite( |
207 | tensor_index, type, name, dims.size(), dims.data(), quantization, |
208 | is_variable, rank_dims_signature, dims_signature_pointer); |
209 | } |
210 | TfLiteStatus SetTensorParametersReadWrite( |
211 | int tensor_index, TfLiteType type, const char* name, const size_t rank, |
212 | const int* dims, TfLiteQuantizationParams quantization, |
213 | bool is_variable = false, const size_t rank_dims_signature = 0, |
214 | const int* dims_signature = nullptr); |
215 | #endif // DOXYGEN_SKIP |
216 | // Functions to access tensor data |
217 | |
218 | /// Read only access to list of inputs. |
219 | const std::vector<int>& inputs() const { return primary_subgraph().inputs(); } |
220 | |
221 | /// Return the name of a given input. The given index must be between 0 and |
222 | /// inputs().size(). |
  const char* GetInputName(int index) const {
    // Reads the name straight out of the TfLiteContext tensor table;
    // `index` is not bounds-checked here.
    return context_->tensors[inputs()[index]].name;
  }
226 | |
227 | /// Read only access to list of outputs. |
228 | const std::vector<int>& outputs() const { |
229 | return primary_subgraph().outputs(); |
230 | } |
231 | |
232 | /// Read only access to list of variable tensors. |
233 | const std::vector<int>& variables() const { |
234 | return primary_subgraph().variables(); |
235 | } |
236 | |
237 | /// Return the name of a given output. The given index must be between 0 and |
238 | /// outputs().size(). |
  const char* GetOutputName(int index) const {
    // Reads the name straight out of the TfLiteContext tensor table;
    // `index` is not bounds-checked here.
    return context_->tensors[outputs()[index]].name;
  }
242 | |
243 | /// Return the number of tensors in the model. |
244 | size_t tensors_size() const { return context_->tensors_size; } |
245 | |
246 | /// Return the number of ops in the model. |
247 | size_t nodes_size() const { return primary_subgraph().nodes_size(); } |
248 | |
249 | /// WARNING: Experimental interface, subject to change |
250 | const std::vector<int>& execution_plan() const { |
251 | return primary_subgraph().execution_plan(); |
252 | } |
253 | |
254 | /// Get a mutable tensor data structure. |
255 | // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this |
256 | // read/write access to structure |
257 | TfLiteTensor* tensor(int tensor_index) { |
258 | return primary_subgraph().tensor(tensor_index); |
259 | } |
260 | |
261 | /// Get an immutable tensor data structure. |
262 | const TfLiteTensor* tensor(int tensor_index) const { |
263 | return primary_subgraph().tensor(tensor_index); |
264 | } |
265 | |
266 | /// Returns a pointer to an operation and registration data structure if in |
267 | /// bounds from the primary subgraph(subgraph_[0]). |
268 | const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration( |
269 | int node_index) const { |
270 | return primary_subgraph().node_and_registration(node_index); |
271 | } |
272 | |
273 | /// Returns a pointer to an operation and registration data structure if in |
274 | /// bounds. |
275 | const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration( |
276 | int subgraph_index, int node_index) const { |
277 | return subgraph(subgraph_index)->node_and_registration(node_index); |
278 | } |
279 | |
280 | /// Perform a checked cast to the appropriate tensor type (mutable pointer |
281 | /// version). |
282 | template <class T> |
283 | T* typed_tensor(int tensor_index) { |
284 | if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { |
285 | if (tensor_ptr->type == typeToTfLiteType<T>()) { |
286 | return reinterpret_cast<T*>(tensor_ptr->data.raw); |
287 | } |
288 | } |
289 | return nullptr; |
290 | } |
291 | |
292 | /// Perform a checked cast to the appropriate tensor type (immutable pointer |
293 | /// version). |
294 | template <class T> |
295 | const T* typed_tensor(int tensor_index) const { |
296 | if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { |
297 | if (tensor_ptr->type == typeToTfLiteType<T>()) { |
298 | return reinterpret_cast<const T*>(tensor_ptr->data.raw); |
299 | } |
300 | } |
301 | return nullptr; |
302 | } |
303 | |
304 | /// WARNING: Experimental interface, subject to change |
305 | /// Returns list of all keys of different method signatures defined in the |
306 | /// model. |
307 | /// Note, pointers returned have lifetime same as the Interpreter object. |
308 | std::vector<const std::string*> signature_keys() const { |
309 | std::vector<const std::string*> signature_keys; |
310 | signature_keys.reserve(signature_defs_.size()); |
311 | for (const auto& sig_def : signature_defs_) { |
312 | signature_keys.emplace_back(&sig_def.signature_key); |
313 | } |
314 | return signature_keys; |
315 | } |
316 | |
317 | /// WARNING: Experimental interface, subject to change |
318 | /// Returns a pointer to the SignatureRunner instance to run the part of the |
319 | /// graph identified by a SignatureDef. The nullptr is returned if the given |
320 | /// signature key is not valid. |
321 | /// If you need to specify delegates, you have to do that before calling this |
322 | /// function. This function will additionally apply default delegates. Thus, |
323 | /// applying delegates after that might lead to undesirable behaviors. |
324 | /// Note, the pointed instance has lifetime same as the Interpreter object |
325 | /// and the SignatureRunner class is *not* thread-safe. |
326 | SignatureRunner* GetSignatureRunner(const char* signature_key); |
327 | |
328 | /// WARNING: Experimental interface, subject to change |
329 | /// Return the subgraph index that corresponds to a SignatureDef, defined by |
330 | /// 'signature_key'. |
331 | /// If invalid name passed, -1 will be returned. |
  int GetSubgraphIndexFromSignature(const char* signature_key) const {
    // Linear scan is fine: models typically carry only a handful of
    // signatures.
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key == signature_key) {
        return signature.subgraph_index;
      }
    }
    return -1;
  }
340 | |
341 | /// WARNING: Experimental interface, subject to change |
342 | /// Returns the mapping of inputs to tensor index in the signature |
343 | /// specified through 'signature_key'. |
344 | /// If invalid name passed, an empty list will be returned. |
345 | const std::map<std::string, uint32_t>& signature_inputs( |
346 | const char* signature_key) const { |
347 | for (const auto& sig_def : signature_defs_) { |
348 | if (sig_def.signature_key == signature_key) return sig_def.inputs; |
349 | } |
350 | static const std::map<std::string, uint32_t>* default_empty_list = |
351 | new std::map<std::string, uint32_t>(); |
352 | return *default_empty_list; |
353 | } |
354 | |
355 | /// WARNING: Experimental interface, subject to change |
356 | /// Returns the mapping of outputs to tensor index in the signature |
357 | /// specified through 'signature_key'. |
358 | /// If invalid name passed, an empty list will be returned. |
359 | const std::map<std::string, uint32_t>& signature_outputs( |
360 | const char* signature_key) const { |
361 | for (const auto& sig_def : signature_defs_) { |
362 | if (sig_def.signature_key == signature_key) return sig_def.outputs; |
363 | } |
364 | static const std::map<std::string, uint32_t>* default_empty_list = |
365 | new std::map<std::string, uint32_t>(); |
366 | return *default_empty_list; |
367 | } |
368 | |
369 | /// WARNING: Experimental interface, subject to change |
370 | /// Returns the input tensor identified by 'signature_input_name' in the |
371 | /// signature identified by 'signature_key'. |
372 | /// Returns nullptr if not found. |
373 | TfLiteTensor* input_tensor_by_signature(const char* signature_input_name, |
374 | const char* signature_key) { |
375 | const int subgraph_index = GetSubgraphIndexFromSignature(signature_key); |
376 | if (subgraph_index == -1) return nullptr; |
377 | const int tensor_index = GetTensorIndexFromSignature( |
378 | signature_input_name, signature_key, /*is_input=*/true); |
379 | if (tensor_index == -1) return nullptr; |
380 | return subgraph(subgraph_index)->tensor(tensor_index); |
381 | } |
382 | |
383 | /// WARNING: Experimental interface, subject to change |
384 | /// Returns the output tensor identified by 'signature_output_name' in the |
385 | /// signature identified by 'signature_key'. |
386 | /// Returns nullptr if not found. |
387 | const TfLiteTensor* output_tensor_by_signature( |
388 | const char* signature_output_name, const char* signature_key) const { |
389 | const int subgraph_index = GetSubgraphIndexFromSignature(signature_key); |
390 | if (subgraph_index == -1) return nullptr; |
391 | const int tensor_index = GetTensorIndexFromSignature( |
392 | signature_output_name, signature_key, /*is_input=*/false); |
393 | if (tensor_index == -1) return nullptr; |
394 | return subgraph(subgraph_index)->tensor(tensor_index); |
395 | } |
396 | |
397 | /// Return a mutable pointer to the given input tensor. The given index must |
398 | /// be between 0 and inputs().size(). |
399 | TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); } |
400 | |
401 | /// Return an immutable pointer to the given input tensor. The given index |
402 | /// must be between 0 and inputs().size(). |
403 | const TfLiteTensor* input_tensor(size_t index) const { |
404 | return tensor(inputs()[index]); |
405 | } |
406 | |
407 | /// Return a mutable pointer into the data of a given input tensor. The given |
408 | /// index must be between 0 and inputs().size(). |
409 | template <class T> |
410 | T* typed_input_tensor(int index) { |
411 | return typed_tensor<T>(inputs()[index]); |
412 | } |
413 | |
414 | /// Return an immutable pointer into the data of a given input tensor. The |
415 | /// given index must be between 0 and inputs().size(). |
416 | template <class T> |
417 | const T* typed_input_tensor(int index) const { |
418 | return typed_tensor<T>(inputs()[index]); |
419 | } |
420 | |
421 | /// Return a mutable pointer to the given output tensor. The given index must |
422 | /// be between 0 and outputs().size(). |
423 | TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); } |
424 | |
425 | /// Return an immutable pointer to the given output tensor. The given index |
426 | /// must be between 0 and outputs().size(). |
427 | const TfLiteTensor* output_tensor(size_t index) const { |
428 | return tensor(outputs()[index]); |
429 | } |
430 | |
431 | /// Return a mutable pointer into the data of a given output tensor. The given |
432 | /// index must be between 0 and outputs().size(). |
433 | template <class T> |
434 | T* typed_output_tensor(int index) { |
435 | return typed_tensor<T>(outputs()[index]); |
436 | } |
437 | |
438 | /// Return an immutable pointer into the data of a given output tensor. The |
439 | /// given index must be between 0 and outputs().size(). |
440 | template <class T> |
441 | const T* typed_output_tensor(int index) const { |
442 | return typed_tensor<T>(outputs()[index]); |
443 | } |
444 | |
445 | /// Change the dimensionality of a given tensor. Note, this is only acceptable |
446 | /// for tensor indices that are inputs or variables. |
447 | /// Returns status of failure or success. Note that this doesn't actually |
448 | /// resize any existing buffers. A call to AllocateTensors() is required to |
449 | /// change the tensor input buffer. |
450 | TfLiteStatus ResizeInputTensor(int tensor_index, |
451 | const std::vector<int>& dims); |
452 | |
453 | /// Change the dimensionality of a given tensor. This is only acceptable for |
454 | /// tensor indices that are inputs or variables. Only unknown dimensions can |
455 | /// be resized with this function. Unknown dimensions are indicated as `-1` in |
456 | /// the `dims_signature` attribute of a `TfLiteTensor`. Returns status of |
457 | /// failure or success. Note that this doesn't actually resize any existing |
458 | /// buffers. A call to AllocateTensors() is required to change the tensor |
459 | /// input buffer. |
460 | TfLiteStatus ResizeInputTensorStrict(int tensor_index, |
461 | const std::vector<int>& dims); |
462 | |
463 | /// This releases memory held by non-persistent tensors. It does NOT |
464 | /// re-perform memory planning. AllocateTensors needs to be called before next |
465 | /// invocation. WARNING: Experimental interface, subject to change |
466 | TfLiteStatus ReleaseNonPersistentMemory(); |
467 | |
468 | /// Update allocations for all tensors. This will redim dependent tensors |
469 | /// using the input tensor dimensionality as given. This is relatively |
470 | /// expensive. This *must be* called after the interpreter has been created |
471 | /// and before running inference (and accessing tensor buffers), and *must be* |
472 | /// called again if (and only if) an input tensor is resized. Returns status |
473 | /// of success or failure. Will fail if any of the ops in the model (other |
474 | /// than those which were rewritten by delegates, if any) are not supported by |
475 | /// the Interpreter's OpResolver. |
476 | TfLiteStatus AllocateTensors(); |
477 | |
478 | /// Invoke the interpreter (run the whole graph in dependency order). |
479 | /// |
480 | /// NOTE: It is possible that the interpreter is not in a ready state |
481 | /// to evaluate (i.e. if a ResizeTensor() has been performed without an |
  /// AllocateTensors()).
483 | /// Returns status of success or failure. |
484 | TfLiteStatus Invoke(); |
485 | |
486 | /// Set the number of threads available to the interpreter. |
487 | /// |
488 | /// NOTE: `num_threads` should be >= -1. Setting `num_threads` to 0 has the |
489 | /// effect to disable multithreading, which is equivalent to setting |
490 | /// `num_threads` to 1. If set to the value -1, the number of threads used |
491 | /// will be implementation-defined and platform-dependent. |
492 | /// |
493 | /// As TfLite interpreter could internally apply a TfLite delegate by default |
494 | /// (i.e. XNNPACK), the number of threads that are available to the default |
495 | /// delegate *should be* set via InterpreterBuilder APIs as follows: |
496 | /// |
497 | /// std::unique_ptr<tflite::Interpreter> interpreter; |
498 | /// tflite::InterpreterBuilder builder(tflite model, op resolver); |
499 | /// builder.SetNumThreads(...) |
500 | /// ASSERT_EQ(builder(&interpreter), kTfLiteOk); |
501 | /// |
502 | /// WARNING: This API is deprecated: prefer using |
503 | /// `InterpreterBuilder::SetNumThreads`, as documented above. |
504 | TfLiteStatus SetNumThreads(int num_threads); |
505 | |
506 | /// Allow float16 precision for FP32 calculation when possible. |
507 | /// Default: not allow. |
508 | /// |
509 | /// WARNING: This API is deprecated: prefer controlling this via delegate |
510 | /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16' or |
511 | /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. |
512 | /// This method will be removed in a future release. |
513 | void SetAllowFp16PrecisionForFp32(bool allow); |
514 | |
515 | /// Get the half precision flag. |
516 | /// WARNING: This is an experimental API and subject to change. |
517 | bool GetAllowFp16PrecisionForFp32() const { |
518 | return context_->allow_fp32_relax_to_fp16; |
519 | } |
520 | |
521 | /// Sets the cancellation function pointer in order to cancel a request in the |
522 | /// middle of a call to Invoke(). The interpreter queries this function during |
523 | /// inference, between op invocations; when it returns true, the interpreter |
524 | /// will abort execution and return `kTfLiteError`. The `data` parameter |
525 | /// contains any data used by the cancellation function, and if non-null, |
526 | /// remains owned by the caller. |
527 | /// WARNING: This is an experimental API and subject to change. |
528 | void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*)); |
529 | |
530 | /// Allow a delegate to look at the graph and modify the graph to handle |
531 | /// parts of the graph themselves. After this is called, the graph may |
  /// contain new nodes that replace one or more nodes.
533 | /// 'delegate' must outlive the interpreter. |
534 | /// Returns one of the following status codes: |
535 | /// 1. kTfLiteOk: Success. |
536 | /// 2. kTfLiteDelegateError: Delegation failed due to an error in the |
537 | /// delegate, or the delegate parameter was null. The Interpreter has been |
538 | /// restored to its pre-delegation state. |
539 | /// NOTE: This undoes all delegates previously applied to the Interpreter. |
540 | /// 3. kTfLiteApplicationError : Delegation failed to be applied due to the |
541 | /// incompatibility with the TfLite runtime, e.g., the model graph is already |
542 | /// immutable when applying the delegate. However, the interpreter could still |
543 | /// be invoked. |
544 | /// 4. kTfLiteUnresolvedOps: Delegation failed because the model has an |
545 | /// operator that cannot be resolved. This can happen when the op is not |
546 | /// registered or built with the TF Lite framework. |
547 | /// 5. kTfLiteError: Unexpected/runtime failure. |
548 | /// WARNING: This is an experimental API and subject to change. |
549 | TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); |
550 | |
551 | // Owning handle to a TfLiteDelegate instance. |
552 | using TfLiteDelegatePtr = |
553 | std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; |
554 | |
555 | /// Same as ModifyGraphWithDelegate except this interpreter takes |
556 | /// ownership of the provided delegate. |
557 | /// WARNING: This is an experimental API and subject to change. |
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<Delegate, Deleter> delegate) {
    // Copy the custom deleter out before releasing the pointer, so the
    // type-erasing lambda below can destroy the delegate correctly.
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          // Cast back to the concrete pointer type before invoking the
          // original deleter; TfLiteDelegate itself is a C struct with no
          // virtual destructor (see the deleted overload below).
          deleter(
              static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
                  delegate_to_delete));
        });
    // Apply via the non-owning raw-pointer overload.
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }
573 | |
574 | /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has no |
575 | /// virtual destructor. The default deleter of the unique_ptr does not know |
576 | /// how to delete C++ objects deriving from TfLiteDelegate. |
577 | TfLiteStatus ModifyGraphWithDelegate( |
578 | std::unique_ptr<TfLiteDelegate> delegate) = delete; |
579 | |
580 | /// Ensure the data in `tensor.data` is readable. In case delegate is used, |
581 | /// it might require to copy the data from delegate buffer to raw memory. |
582 | /// WARNING: This is an experimental API and subject to change. |
583 | TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) { |
584 | return primary_subgraph().EnsureTensorDataIsReadable(tensor_index); |
585 | } |
586 | |
587 | /// Set the delegate buffer handle to a tensor. It can be called in the |
588 | /// following cases: |
589 | /// 1. Set the buffer handle to a tensor that's not being written by a |
590 | /// delegate. For example, feeding an OpenGL texture as the input of the |
591 | /// inference graph. |
592 | /// 2. Set the buffer handle to a tensor that uses the same delegate. |
593 | /// For example, set an OpenGL texture as the output of inference, while |
594 | /// the node which produces output is an OpenGL delegate node. |
595 | /// WARNING: This is an experimental API and subject to change. |
596 | TfLiteStatus SetBufferHandle(int tensor_index, |
597 | TfLiteBufferHandle buffer_handle, |
598 | TfLiteDelegate* delegate); |
599 | |
600 | /// Get the delegate buffer handle, and the delegate which can process the |
601 | /// buffer handle. |
602 | /// WARNING: This is an experimental API and subject to change. |
603 | TfLiteStatus GetBufferHandle(int tensor_index, |
604 | TfLiteBufferHandle* buffer_handle, |
605 | TfLiteDelegate** delegate); |
606 | |
607 | /// Sets the profiler to tracing execution. The caller retains ownership |
608 | /// of the profiler and must ensure its validity. |
609 | /// Previously registered profilers will be unregistered. |
610 | /// If `profiler` is nullptr, all previously installed profilers will be |
611 | /// removed. |
612 | /// WARNING: This is an experimental API and subject to change. |
613 | void SetProfiler(Profiler* profiler); |
614 | |
615 | /// Same as SetProfiler except this interpreter takes ownership |
616 | /// of the provided profiler. |
617 | /// Previously registered profilers will be unregistered. |
618 | /// If `profiler` is nullptr, all previously installed profilers will be |
619 | /// removed. |
620 | /// WARNING: This is an experimental API and subject to change. |
621 | void SetProfiler(std::unique_ptr<Profiler> profiler); |
622 | |
623 | /// Adds the profiler to tracing execution. The caller retains ownership |
624 | /// of the profiler and must ensure its validity. |
625 | /// nullptr `profiler` will be ignored. |
626 | /// WARNING: This is an experimental API and subject to change. |
627 | void AddProfiler(Profiler* profiler); |
628 | |
629 | /// Gets the profiler used for op tracing. |
630 | /// WARNING: This is an experimental API and subject to change. |
631 | Profiler* GetProfiler(); |
632 | |
633 | // The default capacity of `tensors_` vector. |
634 | static constexpr int kTensorsReservedCapacity = 128; |
635 | /// The capacity headroom of `tensors_` vector before calling ops' |
636 | /// `prepare` and `invoke` function. In these functions, it's guaranteed |
637 | /// allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate |
638 | /// pointers to existing tensors. |
639 | static constexpr int kTensorsCapacityHeadroom = 16; |
640 | |
  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, Interpreter will make the data of output
  /// tensors available in `tensor->data` by default. If the application can
  /// consume the buffer handle directly (e.g. reading output from OpenGL
  /// texture), it can set this flag to true, so Interpreter won't copy the
  /// data from buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }
652 | |
  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
  /// to the value of the buffer.
  /// NOTE(review): error conditions are defined in interpreter.cc -- confirm
  /// which failures produce a non-ok status.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();
659 | |
660 | /// Retrieve an operator's description of its work, for profiling purposes. |
661 | const char* OpProfilingString(const TfLiteRegistration& op_reg, |
662 | const TfLiteNode* node) const { |
663 | if (op_reg.profiling_string == nullptr) return nullptr; |
664 | return op_reg.profiling_string(context_, node); |
665 | } |
666 | |
  // Set the value of an external context. The TFLite interpreter does not
  // take ownership of the external context 'ctx'; the context must outlive
  // the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  /// Assigns (or reassigns) a custom memory allocation for the given tensor.
  /// `flags` is a bitmask, see TfLiteCustomAllocationFlags.
  /// The runtime does NOT take ownership of the underlying memory.
  ///
  /// NOTE: User needs to call AllocateTensors() after this.
  /// Invalid/insufficient buffers will cause an error during AllocateTensors or
  /// Invoke (in case of dynamic shapes in the graph).
  ///
  /// Parameters should satisfy the following conditions:
  /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  ///    In general, this is true for I/O tensors & variable tensors.
  /// 2. allocation->data has the appropriate permissions for runtime access
  ///    (Read-only for inputs, Read-Write for others), and outlives
  ///    Interpreter.
  /// 3. allocation->bytes >= tensor->bytes.
  ///    This condition is checked again if any tensors are resized.
  /// 4. allocation->data should be aligned to kDefaultTensorAlignment
  ///    defined in lite/util.h. (Currently 64 bytes)
  ///    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  ///    set through `flags`.
  ///
  /// WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);

  /// Apply InterpreterOptions which tunes behavior of the interpreter.
  /// WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ApplyOptions(InterpreterOptions* options);
702 | |
#ifndef DOXYGEN_SKIP
  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }
707 | |
708 | /// Get a pointer to a subgraph if in bounds. |
709 | /// WARNING: This is an experimental API and subject to change. |
710 | const Subgraph* subgraph(int subgraph_index) const { |
711 | if (subgraph_index < 0 || |
712 | static_cast<size_t>(subgraph_index) >= subgraphs_size()) { |
713 | return nullptr; |
714 | } |
715 | return subgraphs_[subgraph_index].get(); |
716 | } |
717 | |
718 | /// WARNING: This is an experimental API and subject to change. |
719 | Subgraph* subgraph(int subgraph_index) { |
720 | return const_cast<Subgraph*>( |
721 | static_cast<const Interpreter*>(this)->subgraph(subgraph_index)); |
722 | } |
723 | |
724 | /// WARNING: Experimental interface, subject to change |
725 | Subgraph& primary_subgraph() { |
726 | return *subgraphs_.front(); /// Safe as subgraphs_ always has 1 entry. |
727 | } |
728 | |
729 | /// WARNING: Experimental interface, subject to change |
730 | const Subgraph& primary_subgraph() const { |
731 | return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry. |
732 | } |
733 | #endif // DOXYGEN_SKIP |
734 | |
  /// Get the error reporter associated with this interpreter.
  /// WARNING: Experimental interface, subject to change
  ErrorReporter* error_reporter() const { return error_reporter_; }
738 | |
 private:
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegation;
  friend class tflite::interpreter_wrapper::InterpreterWrapper;

  /// Set the value of an external context.
  /// NOTE(review): static overload taking a TfLiteContext* -- presumably used
  /// as the C-API callback that forwards to the member function; confirm in
  /// interpreter.cc.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);
750 | |
751 | // Helper method that return the tensor index that corresponds to |
752 | // a name in a SignatureDef. Defined by 'signature_key', and |
753 | // 'signature_tensor_name'. |
754 | // If 'is_input' is true then the tensor is checked in input tensors, |
755 | // otherwise it will be checked in output tensors. |
756 | // Returns -1 if the tensor is not found. |
757 | int GetTensorIndexFromSignature(const char* signature_tensor_name, |
758 | const char* signature_key, |
759 | bool is_input) const { |
760 | // Iterate directly and don't use other methods to avoid extra allocation. |
761 | for (const auto& signature : signature_defs_) { |
762 | if (signature.signature_key != signature_key) continue; |
763 | auto& signature_list = (is_input ? signature.inputs : signature.outputs); |
764 | auto tensor_iter = signature_list.find(signature_tensor_name); |
765 | if (tensor_iter == signature_list.end()) return -1; |
766 | return tensor_iter->second; |
767 | } |
768 | return -1; |
769 | } |
770 | |
  // Applies TFLite default delegates.
  // NOTE(review): presumably consumes `lazy_delegate_providers_` (see the
  // member comment below) -- confirm in interpreter.cc.
  TfLiteStatus ApplyLazyDelegateProviders();

  // Private non-experimental implementation of ModifyGraphWithDelegate.
  // Unlike ModifyGraphWithDelegate, ModifyGraphWithDelegateImpl is defined in
  // interpreter.cc rather than in interpreter_experimental.cc, so it can be
  // used to implement other non-experimental methods.
  TfLiteStatus ModifyGraphWithDelegateImpl(TfLiteDelegate* delegate);
779 | |
780 | // Same as ModifyGraphWithDelegateImpl except that it takes ownership of the |
781 | // delegate. |
782 | template <typename Delegate, typename Deleter> |
783 | inline TfLiteStatus ModifyGraphWithDelegateImpl( |
784 | std::unique_ptr<Delegate, Deleter>&& delegate) { |
785 | Deleter deleter = std::move(delegate.get_deleter()); |
786 | |
787 | // Note that we retain ownership of the delegate even if graph modification |
788 | // fails, as delegate use will be in an indeterminate state at that point. |
789 | owned_delegates_.emplace_back( |
790 | delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) { |
791 | deleter( |
792 | static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>( |
793 | delegate_to_delete)); |
794 | }); |
795 | return ModifyGraphWithDelegateImpl(owned_delegates_.back().get()); |
796 | } |
797 | |
  // Overrides execution plan. This bounds checks indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Sets the profiler to all subgraphs.
  void SetSubgraphProfiler();

  // Remove delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the model has been fully delegated.
  bool IsFullyDelegated() const;

  // Returns true if cancellation function returns true.
  bool IsCancelled();
817 | |
  // Sets the list of signature defs in the model.
  // Takes the vector by value and moves it into `signature_defs_`.
  void SetSignatureDef(std::vector<internal::SignatureDef> signature_defs) {
    signature_defs_ = std::move(signature_defs);
  }
822 | |
  // Sets model metadata as a mapping of name (key) and buffer (value) strings.
  // Used by InterpreterBuilder, should be called after setting up subgraphs.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>& metadata);

  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set to
  /// the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Implementation of SetProfiler.
  /// Unlike SetProfiler, this is defined in interpreter.cc rather than in
  /// interpreter_experimental.cc, so it can be used by interpreter_builder.cc.
  void SetProfilerImpl(std::unique_ptr<Profiler> profiler);

  // Implementation of ApplyOptions.
  // NOTE(review): presumably split out for the same interpreter.cc /
  // interpreter_experimental.cc reason as SetProfilerImpl above -- confirm.
  TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options);
840 | |
  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph context.
  TfLiteContext* context_ = nullptr;

  // The error reporter delegate to which TFLite forwards error reports.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // A root profiler that holds a list of attached profiler implementations.
  // Will be nullptr if there's no child profiler registered.
  std::unique_ptr<profiling::RootProfiler> root_profiler_;

  // Whether output tensors may be left as buffer handles (no CPU copy after
  // Invoke). See SetAllowBufferHandleOutput.
  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to point
  // to this object. However, if this element value is overwritten via calling
  // 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will reset this to
  // nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by interpreter and shared by multiple subgraphs.
  resource::ResourceMap resources_;

  // A map of resource Ids. Owned by interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap resource_ids_;

  // A map of initialization statuses, that indicate whether the initialization
  // subgraph invocation is done or not. Owned by interpreter and shared by
  // multiple subgraphs.
  resource::InitializationStatusMap initialization_status_map_;

  // Indicating delegates that the TFLite interpreter will apply by default.
  // An empty one means there's no delegate to be applied by default or
  // delegates have been applied and don't need to be applied again.
  using TfLiteDelegateCreator =
      std::function<TfLiteDelegatePtr(int /*num_threads*/)>;
  using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
  TfLiteDelegateCreators lazy_delegate_providers_;

  // List of SignatureDefs obtained from the model.
  std::vector<internal::SignatureDef> signature_defs_;

  // Map of signature key to its corresponding SignatureRunner object.
  // A SignatureRunner is basically a wrapper of the Subgraph corresponding to
  // its SignatureDef.
  std::map<std::string, SignatureRunner> signature_runner_map_;

  // Model metadata stored as mapping of name (key) to buffer (value).
  // Data is mapped from the Metadata in TFLite flatbuffer model.
  std::map<std::string, std::string> metadata_;

  // InterpreterOptions object which is being used.
  std::unique_ptr<InterpreterOptions> options_;
911 | }; |
912 | |
913 | } // namespace tflite |
914 | #endif // TENSORFLOW_LITE_CORE_IMPL_INTERPRETER_H_ |
915 | |