/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Main abstraction controlling the tflite interpreter.
/// Do NOT include this file directly; instead include
/// third_party/tensorflow/lite/interpreter.h.
/// See third_party/tensorflow/lite/c/common.h for the API for defining
/// operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_CORE_IMPL_INTERPRETER_H_
#define TENSORFLOW_LITE_CORE_IMPL_INTERPRETER_H_

// IWYU pragma: private, include "third_party/tensorflow/lite/interpreter.h"
// IWYU pragma: friend third_party/tensorflow/lite/interpreter.h

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/initialization_status.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/internal/signature_def.h"
#include "tensorflow/lite/interpreter_options.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/profiling/root_profiler.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/type_to_tflitetype.h"
namespace tflite {

class InterpreterTest;  // Class for friend declarations.

namespace delegates {
class InterpreterUtils;  // Class for friend declarations.

namespace test_utils {
class TestDelegation;  // Class for friend declarations.
}  // namespace test_utils
}  // namespace delegates

namespace interpreter_wrapper {
class InterpreterWrapper;  // Class for friend declarations.
}  // namespace interpreter_wrapper
/// An interpreter for a graph of nodes that input and output from tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create model from file. Note that the model instance must outlive the
/// // interpreter instance.
/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
/// if (model == nullptr) {
///   // Return error.
/// }
/// // Create an Interpreter with an InterpreterBuilder.
/// std::unique_ptr<tflite::Interpreter> interpreter;
/// tflite::ops::builtin::BuiltinOpResolver resolver;
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
///   // Return failure.
/// }
/// if (interpreter->AllocateTensors() != kTfLiteOk) {
///   // Return failure.
/// }
///
/// auto input = interpreter->typed_tensor<float>(0);
/// for (int i = 0; i < input_size; i++) {
///   input[i] = ...;
/// }
/// interpreter->Invoke();
/// </code></pre>
///
/// Note: For nearly all practical use cases, one should not directly construct
/// an Interpreter object, but rather use the InterpreterBuilder.
///
/// WARNING: This class is *not* thread-safe. The client is responsible for
/// ensuring serialized interaction to avoid data races and undefined behavior.
class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used. Ownership of 'error_reporter' remains with the caller.
  // WARNING: Use of this constructor outside of an InterpreterBuilder is not
  // recommended.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;
  // Functions to build interpreter
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs of the model.
  /// Each index is bounds checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). The interpreter takes ownership of
  /// `builtin_data` and destroys it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set description of inputs/outputs/data/fptrs for tensor `tensor_index`.
  /// This variant assumes an external buffer of size `bytes` has been
  /// allocated. The lifetime of `buffer` must be at least as long as that of
  /// the Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
  /// Set description of inputs/outputs/data/fptrs for tensor `tensor_index`.
  /// This variant lets the interpreter allocate and manage the tensor's
  /// buffer.
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false,
      const std::vector<int>* dims_signature = nullptr) {
    size_t rank_dims_signature = 0;
    const int* dims_signature_pointer = nullptr;
    if (dims_signature) {
      rank_dims_signature = dims_signature->size();
      dims_signature_pointer = dims_signature->data();
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, rank_dims_signature, dims_signature_pointer);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);
#endif  // DOXYGEN_SKIP
  // Functions to access tensor data

  /// Read only access to list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read only access to list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read only access to list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds from the primary subgraph (subgraphs_[0]).
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int subgraph_index, int node_index) const {
    return subgraph(subgraph_index)->node_and_registration(node_index);
  }

  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }
303
304 /// WARNING: Experimental interface, subject to change
305 /// Returns list of all keys of different method signatures defined in the
306 /// model.
307 /// Note, pointers returned have lifetime same as the Interpreter object.
308 std::vector<const std::string*> signature_keys() const {
309 std::vector<const std::string*> signature_keys;
310 signature_keys.reserve(signature_defs_.size());
311 for (const auto& sig_def : signature_defs_) {
312 signature_keys.emplace_back(&sig_def.signature_key);
313 }
314 return signature_keys;
315 }
316
317 /// WARNING: Experimental interface, subject to change
318 /// Returns a pointer to the SignatureRunner instance to run the part of the
319 /// graph identified by a SignatureDef. The nullptr is returned if the given
320 /// signature key is not valid.
321 /// If you need to specify delegates, you have to do that before calling this
322 /// function. This function will additionally apply default delegates. Thus,
323 /// applying delegates after that might lead to undesirable behaviors.
324 /// Note, the pointed instance has lifetime same as the Interpreter object
325 /// and the SignatureRunner class is *not* thread-safe.
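  ///
  /// A minimal usage sketch (the signature key "serving_default" and the
  /// tensor names "x" and "y" are hypothetical; they depend on the model):
  ///
  /// <pre><code>
  /// tflite::SignatureRunner* runner =
  ///     interpreter->GetSignatureRunner("serving_default");
  /// if (runner == nullptr) { /* Return error. */ }
  /// if (runner->AllocateTensors() != kTfLiteOk) { /* Return error. */ }
  /// TfLiteTensor* x = runner->input_tensor("x");
  /// // ... fill x->data ...
  /// if (runner->Invoke() != kTfLiteOk) { /* Return error. */ }
  /// const TfLiteTensor* y = runner->output_tensor("y");
  /// </code></pre>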
  SignatureRunner* GetSignatureRunner(const char* signature_key);

  /// WARNING: Experimental interface, subject to change
  /// Return the subgraph index that corresponds to a SignatureDef, defined by
  /// 'signature_key'.
  /// Returns -1 if an invalid signature key is passed.
  int GetSubgraphIndexFromSignature(const char* signature_key) const {
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key == signature_key) {
        return signature.subgraph_index;
      }
    }
    return -1;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping from input name to tensor index in the signature
  /// specified through 'signature_key'.
  /// Returns an empty map if an invalid signature key is passed.
  const std::map<std::string, uint32_t>& signature_inputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.inputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping from output name to tensor index in the signature
  /// specified through 'signature_key'.
  /// Returns an empty map if an invalid signature key is passed.
  const std::map<std::string, uint32_t>& signature_outputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.outputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the input tensor identified by 'signature_input_name' in the
  /// signature identified by 'signature_key'.
  /// Returns nullptr if not found.
  TfLiteTensor* input_tensor_by_signature(const char* signature_input_name,
                                          const char* signature_key) {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_input_name, signature_key, /*is_input=*/true);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the output tensor identified by 'signature_output_name' in the
  /// signature identified by 'signature_key'.
  /// Returns nullptr if not found.
  const TfLiteTensor* output_tensor_by_signature(
      const char* signature_output_name, const char* signature_key) const {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_output_name, signature_key, /*is_input=*/false);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The given
  /// index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index must
  /// be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
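  ///
  /// For example, reading a float output after Invoke() (a minimal sketch,
  /// assuming output 0 is a float tensor):
  ///
  /// <pre><code>
  /// if (interpreter->Invoke() == kTfLiteOk) {
  ///   const float* output = interpreter->typed_output_tensor<float>(0);
  ///   // ... consume the output values ...
  /// }
  /// </code></pre>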
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Change the dimensionality of a given tensor. Note: this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns the status of failure or success. Note that this doesn't actually
  /// resize any existing buffers. A call to AllocateTensors() is required to
  /// change the tensor input buffer.
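  ///
  /// For example (a minimal sketch; the dimensions are hypothetical and must
  /// match what the model expects):
  ///
  /// <pre><code>
  /// interpreter->ResizeInputTensor(interpreter->inputs()[0],
  ///                                {1, 224, 224, 3});
  /// if (interpreter->AllocateTensors() != kTfLiteOk) { /* Return error. */ }
  /// </code></pre>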
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  /// Change the dimensionality of a given tensor. This is only acceptable for
  /// tensor indices that are inputs or variables. Only unknown dimensions can
  /// be resized with this function. Unknown dimensions are indicated as `-1`
  /// in the `dims_signature` attribute of a `TfLiteTensor`. Returns the status
  /// of failure or success. Note that this doesn't actually resize any
  /// existing buffers. A call to AllocateTensors() is required to change the
  /// tensor input buffer.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);

  /// This releases memory held by non-persistent tensors. It does NOT
  /// re-perform memory planning. AllocateTensors needs to be called before the
  /// next invocation. WARNING: Experimental interface, subject to change
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. This *must be* called after the interpreter has been created
  /// and before running inference (and accessing tensor buffers), and *must
  /// be* called again if (and only if) an input tensor is resized. Returns the
  /// status of success or failure. Will fail if any of the ops in the model
  /// (other than those which were rewritten by delegates, if any) are not
  /// supported by the Interpreter's OpResolver.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state
  /// to evaluate (i.e. if a ResizeTensor() has been performed without an
  /// AllocateTensors()).
  /// Returns the status of success or failure.
  TfLiteStatus Invoke();

  /// Set the number of threads available to the interpreter.
  ///
  /// NOTE: `num_threads` should be >= -1. Setting `num_threads` to 0 has the
  /// effect of disabling multithreading, which is equivalent to setting
  /// `num_threads` to 1. If set to the value -1, the number of threads used
  /// will be implementation-defined and platform-dependent.
  ///
  /// As the TfLite interpreter may internally apply a TfLite delegate by
  /// default (i.e. XNNPACK), the number of threads that are available to the
  /// default delegate *should be* set via InterpreterBuilder APIs as follows:
  ///
  ///   std::unique_ptr<tflite::Interpreter> interpreter;
  ///   tflite::InterpreterBuilder builder(tflite model, op resolver);
  ///   builder.SetNumThreads(...)
  ///   ASSERT_EQ(builder(&interpreter), kTfLiteOk);
  ///
  /// WARNING: This API is deprecated: prefer using
  /// `InterpreterBuilder::SetNumThreads`, as documented above.
  TfLiteStatus SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  ///
  /// WARNING: This API is deprecated: prefer controlling this via delegate
  /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
  /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
  /// This method will be removed in a future release.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
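  ///
  /// For example, a flag-based hook (a minimal sketch; the std::atomic flag
  /// is hypothetical caller-owned state and must outlive the interpreter):
  ///
  /// <pre><code>
  /// static std::atomic<bool> cancelled{false};
  /// interpreter->SetCancellationFunction(&cancelled, [](void* data) -> bool {
  ///   return static_cast<std::atomic<bool>*>(data)->load();
  /// });
  /// // Setting `cancelled = true;` from another thread aborts the current
  /// // Invoke() between op invocations.
  /// </code></pre>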
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph themselves. After this is called, the graph may
  /// contain new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// Returns one of the following status codes:
  /// 1. kTfLiteOk: Success.
  /// 2. kTfLiteDelegateError: Delegation failed due to an error in the
  /// delegate, or the delegate parameter was null. The Interpreter has been
  /// restored to its pre-delegation state.
  /// NOTE: This undoes all delegates previously applied to the Interpreter.
  /// 3. kTfLiteApplicationError: Delegation failed to be applied due to
  /// incompatibility with the TfLite runtime, e.g., the model graph is already
  /// immutable when applying the delegate. However, the interpreter could
  /// still be invoked.
  /// 4. kTfLiteUnresolvedOps: Delegation failed because the model has an
  /// operator that cannot be resolved. This can happen when the op is not
  /// registered or built with the TF Lite framework.
  /// 5. kTfLiteError: Unexpected/runtime failure.
  /// WARNING: This is an experimental API and subject to change.
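  ///
  /// For example, applying the XNNPACK delegate (a minimal sketch, assuming
  /// the XNNPACK delegate header is available; with this overload the caller
  /// keeps ownership and must destroy the delegate only after the
  /// interpreter):
  ///
  /// <pre><code>
  /// TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
  /// TfLiteDelegate* xnnpack = TfLiteXNNPackDelegateCreate(&opts);
  /// if (interpreter->ModifyGraphWithDelegate(xnnpack) != kTfLiteOk) {
  ///   // Return error.
  /// }
  /// // ... AllocateTensors() / Invoke() ...
  /// // After the interpreter has been destroyed:
  /// TfLiteXNNPackDelegateDelete(xnnpack);
  /// </code></pre>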
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except that this interpreter takes
  /// ownership of the provided delegate.
  /// WARNING: This is an experimental API and subject to change.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<Delegate, Deleter> delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(
              static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }

  /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has
  /// no virtual destructor. The default deleter of the unique_ptr does not
  /// know how to delete C++ objects deriving from TfLiteDelegate.
  TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<TfLiteDelegate> delegate) = delete;

  /// Ensure the data in `tensor.data` is readable. If a delegate is used,
  /// this may require copying the data from the delegate buffer to raw memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  /// Set the delegate buffer handle to a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle to a tensor that's not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle to a tensor that uses the same delegate.
  ///    For example, set an OpenGL texture as the output of inference, while
  ///    the node which produces output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// Previously registered profilers will be unregistered.
  /// If `profiler` is nullptr, all previously installed profilers will be
  /// removed.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Same as SetProfiler except that this interpreter takes ownership
  /// of the provided profiler.
  /// Previously registered profilers will be unregistered.
  /// If `profiler` is nullptr, all previously installed profilers will be
  /// removed.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(std::unique_ptr<Profiler> profiler);

  /// Adds a profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// A nullptr `profiler` will be ignored.
  /// WARNING: This is an experimental API and subject to change.
  void AddProfiler(Profiler* profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();

  /// The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of the `tensors_` vector before calling ops'
  /// `prepare` and `invoke` functions. Within these functions, it's guaranteed
  /// that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  /// invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, the Interpreter will make the data of
  /// output tensors available in `tensor->data` by default. If the application
  /// can consume the buffer handle directly (e.g. reading output from an
  /// OpenGL texture), it can set this flag to false, so the Interpreter won't
  /// copy the data from the buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  /// it to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context. The TFLite interpreter doesn't take
  // ownership of this external context 'ctx', and the context should outlive
  // the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  /// Assigns (or reassigns) a custom memory allocation for the given tensor.
  /// `flags` is a bitmask; see TfLiteCustomAllocationFlags.
  /// The runtime does NOT take ownership of the underlying memory.
  ///
  /// NOTE: The user must call AllocateTensors() after this.
  /// Invalid/insufficient buffers will cause an error during AllocateTensors
  /// or Invoke (in case of dynamic shapes in the graph).
  ///
  /// Parameters should satisfy the following conditions:
  /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  ///    In general, this is true for I/O tensors & variable tensors.
  /// 2. allocation->data has the appropriate permissions for runtime access
  ///    (read-only for inputs, read-write for others), and outlives the
  ///    Interpreter.
  /// 3. allocation->bytes >= tensor->bytes.
  ///    This condition is checked again if any tensors are resized.
  /// 4. allocation->data should be aligned to kDefaultTensorAlignment
  ///    defined in lite/util.h. (Currently 64 bytes.)
  ///    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  ///    set through `flags`.
  ///
  /// WARNING: This is an experimental interface that is subject to change.
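  ///
  /// For example (a minimal sketch, assuming C++17 std::aligned_alloc is
  /// available; the size is rounded up to a multiple of the 64-byte
  /// alignment, which also satisfies condition 3 above):
  ///
  /// <pre><code>
  /// const TfLiteTensor* t = interpreter->input_tensor(0);
  /// size_t bytes = (t->bytes + 63) & ~size_t{63};  // Round up to 64.
  /// TfLiteCustomAllocation alloc{std::aligned_alloc(64, bytes), bytes};
  /// interpreter->SetCustomAllocationForTensor(interpreter->inputs()[0],
  ///                                           alloc);
  /// if (interpreter->AllocateTensors() != kTfLiteOk) { /* Return error. */ }
  /// </code></pre>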
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);

  /// Apply InterpreterOptions which tunes behavior of the interpreter.
  /// WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ApplyOptions(InterpreterOptions* options);

#ifndef DOXYGEN_SKIP
  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  const Subgraph* subgraph(int subgraph_index) const {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size()) {
      return nullptr;
    }
    return subgraphs_[subgraph_index].get();
  }

  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    return const_cast<Subgraph*>(
        static_cast<const Interpreter*>(this)->subgraph(subgraph_index));
  }

  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }
#endif  // DOXYGEN_SKIP

  /// WARNING: Experimental interface, subject to change
  /// Get the error reporter associated with this interpreter.
  ErrorReporter* error_reporter() const { return error_reporter_; }

 private:
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegation;
  friend class tflite::interpreter_wrapper::InterpreterWrapper;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // Helper method that returns the tensor index corresponding to a name in a
  // SignatureDef, identified by 'signature_key' and 'signature_tensor_name'.
  // If 'is_input' is true, the tensor is looked up among the input tensors;
  // otherwise it is looked up among the output tensors.
  // Returns -1 if the tensor is not found.
  int GetTensorIndexFromSignature(const char* signature_tensor_name,
                                  const char* signature_key,
                                  bool is_input) const {
    // Iterate directly and don't use other methods to avoid extra allocation.
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key != signature_key) continue;
      auto& signature_list = (is_input ? signature.inputs : signature.outputs);
      auto tensor_iter = signature_list.find(signature_tensor_name);
      if (tensor_iter == signature_list.end()) return -1;
      return tensor_iter->second;
    }
    return -1;
  }

  // Applies TFLite default delegates.
  TfLiteStatus ApplyLazyDelegateProviders();

  // Private non-experimental implementation of ModifyGraphWithDelegate.
  // Unlike ModifyGraphWithDelegate, ModifyGraphWithDelegateImpl is defined in
  // interpreter.cc rather than in interpreter_experimental.cc, so it can be
  // used to implement other non-experimental methods.
  TfLiteStatus ModifyGraphWithDelegateImpl(TfLiteDelegate* delegate);

  // Same as ModifyGraphWithDelegateImpl except that it takes ownership of the
  // delegate.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegateImpl(
      std::unique_ptr<Delegate, Deleter>&& delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(
              static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegateImpl(owned_delegates_.back().get());
  }

  // Overrides the execution plan. This bounds checks indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Sets the profiler to all subgraphs.
  void SetSubgraphProfiler();

  // Remove delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the model has been fully delegated.
  bool IsFullyDelegated() const;

  // Returns true if the cancellation function returns true.
  bool IsCancelled();

  // Sets the list of signature defs in the model.
  void SetSignatureDef(std::vector<internal::SignatureDef> signature_defs) {
    signature_defs_ = std::move(signature_defs);
  }

  // Sets model metadata as a mapping of name (key) and buffer (value) strings.
  // Used by InterpreterBuilder; should be called after setting up subgraphs.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>& metadata);

  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set
  /// to the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Implementation of SetProfiler.
  /// Unlike SetProfiler, this is defined in interpreter.cc rather than in
  /// interpreter_experimental.cc, so it can be used by interpreter_builder.cc.
  void SetProfilerImpl(std::unique_ptr<Profiler> profiler);

  TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph context.
  TfLiteContext* context_ = nullptr;

  // The error reporter to which tflite forwards errors encountered while
  // processing the model.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // A root profiler that holds a list of attached profiler implementations.
  // Will be nullptr if there's no child profiler registered.
  std::unique_ptr<profiling::RootProfiler> root_profiler_;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to
  // point to this object. However, if this element value is overwritten via
  // calling 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will reset
  // this to nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceMap resources_;

  // A map of resource IDs. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap resource_ids_;

  // A map of initialization statuses that indicate whether the initialization
  // subgraph invocation is done or not. Owned by the interpreter and shared by
  // multiple subgraphs.
  resource::InitializationStatusMap initialization_status_map_;

  // Delegates that the TFLite interpreter will apply by default.
  // An empty list means there are no delegates to be applied by default, or
  // that the delegates have already been applied and don't need to be applied
  // again.
  using TfLiteDelegateCreator =
      std::function<TfLiteDelegatePtr(int /*num_threads*/)>;
  using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
  TfLiteDelegateCreators lazy_delegate_providers_;

  // List of SignatureDefs obtained from the model.
  std::vector<internal::SignatureDef> signature_defs_;

  // Map of signature key to its corresponding SignatureRunner object.
  // A SignatureRunner is essentially a wrapper around the Subgraph
  // corresponding to its SignatureDef.
  std::map<std::string, SignatureRunner> signature_runner_map_;

  // Model metadata stored as a mapping of name (key) to buffer (value).
  // Data is mapped from the Metadata in the TFLite flatbuffer model.
  std::map<std::string, std::string> metadata_;

  // The InterpreterOptions object that is being used.
  std::unique_ptr<InterpreterOptions> options_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_CORE_IMPL_INTERPRETER_H_