1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
16#define TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
17
18#include <cstddef>
19#include <cstdint>
20#include <string>
21#include <vector>
22
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/core/subgraph.h"
25#include "tensorflow/lite/internal/signature_def.h"
26
27namespace tflite {
28class Interpreter; // Class for friend declarations.
29class SignatureRunnerJNIHelper; // Class for friend declarations.
30class TensorHandle; // Class for friend declarations.
31class SignatureRunnerHelper; // Class for friend declarations.
32
33/// WARNING: Experimental interface, subject to change
34///
35/// SignatureRunner class for running TFLite models using SignatureDef.
36///
37/// Usage:
38///
39/// <pre><code>
40/// // Create model from file. Note that the model instance must outlive the
41/// // interpreter instance.
42/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
43/// if (model == nullptr) {
44/// // Return error.
45/// }
46///
47/// // Create an Interpreter with an InterpreterBuilder.
48/// std::unique_ptr<tflite::Interpreter> interpreter;
49/// tflite::ops::builtin::BuiltinOpResolver resolver;
50/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
51/// // Return failure.
52/// }
53///
54/// // Get the list of signatures and check it.
55/// auto signature_defs = interpreter->signature_def_names();
56/// if (signature_defs.empty()) {
57/// // Return error.
58/// }
59///
/// // Get a pointer to the SignatureRunner instance corresponding to a
/// // signature. Note that the returned SignatureRunner instance has the same
/// // lifetime as the Interpreter instance.
63/// tflite::SignatureRunner* runner =
64/// interpreter->GetSignatureRunner(signature_defs[0]->c_str());
65/// if (runner == nullptr) {
66/// // Return error.
67/// }
68/// if (runner->AllocateTensors() != kTfLiteOk) {
69/// // Return failure.
70/// }
71///
/// // Set input data. Input tensors are looked up by signature input name. In
/// // this example, the first input tensor has float type.
/// float* input = runner->input_tensor(runner->input_names()[0])->data.f;
74/// for (int i = 0; i < input_size; i++) {
75/// input[i] = ...;
/// }
77/// runner->Invoke();
78/// </code></pre>
79///
80/// WARNING: This class is *not* thread-safe. The client is responsible for
81/// ensuring serialized interaction to avoid data races and undefined behavior.
82///
83/// SignatureRunner and Interpreter share the same underlying data. Calling
84/// methods on an Interpreter object will affect the state in corresponding
85/// SignatureRunner objects. Therefore, it is recommended not to call other
86/// Interpreter methods after calling GetSignatureRunner to create
87/// SignatureRunner instances.
88class SignatureRunner {
89 public:
90 /// Returns the key for the corresponding signature.
91 const std::string& signature_key() { return signature_def_->signature_key; }
92
93 /// Returns the number of inputs.
94 size_t input_size() const { return subgraph_->inputs().size(); }
95
96 /// Returns the number of outputs.
97 size_t output_size() const { return subgraph_->outputs().size(); }
98
99 /// Read-only access to list of signature input names.
100 const std::vector<const char*>& input_names() { return input_names_; }
101
102 /// Read-only access to list of signature output names.
103 const std::vector<const char*>& output_names() { return output_names_; }
104
105 /// Returns the input tensor identified by 'input_name' in the
106 /// given signature. Returns nullptr if the given name is not valid.
107 TfLiteTensor* input_tensor(const char* input_name);
108
109 /// Returns the output tensor identified by 'output_name' in the
110 /// given signature. Returns nullptr if the given name is not valid.
111 const TfLiteTensor* output_tensor(const char* output_name) const;
112
113 /// Change a dimensionality of a given tensor. Note, this is only acceptable
114 /// for tensors that are inputs.
115 /// Returns status of failure or success. Note that this doesn't actually
116 /// resize any existing buffers. A call to AllocateTensors() is required to
117 /// change the tensor input buffer.
118 TfLiteStatus ResizeInputTensor(const char* input_name,
119 const std::vector<int>& new_size);
120
121 /// Change the dimensionality of a given tensor. This is only acceptable for
122 /// tensor indices that are inputs or variables. Only unknown dimensions can
123 /// be resized with this function. Unknown dimensions are indicated as `-1` in
124 /// the `dims_signature` attribute of a TfLiteTensor.
125 /// Returns status of failure or success. Note that this doesn't actually
126 /// resize any existing buffers. A call to AllocateTensors() is required to
127 /// change the tensor input buffer.
128 TfLiteStatus ResizeInputTensorStrict(const char* input_name,
129 const std::vector<int>& new_size);
130
131 /// Updates allocations for all tensors, related to the given signature.
132 TfLiteStatus AllocateTensors() { return subgraph_->AllocateTensors(); }
133
134 /// Invokes the signature runner (run the graph identified by the given
135 /// signature in dependency order).
136 TfLiteStatus Invoke();
137
138 private:
139 // The life cycle of SignatureRunner depends on the life cycle of Subgraph,
140 // which is owned by an Interpreter. Therefore, the Interpreter will takes the
141 // responsibility to create and manage SignatureRunner objects to make sure
142 // SignatureRunner objects don't outlive their corresponding Subgraph objects.
143 SignatureRunner(const internal::SignatureDef* signature_def,
144 Subgraph* subgraph);
145 friend class Interpreter;
146 friend class SignatureRunnerJNIHelper;
147 friend class TensorHandle;
148 friend class SignatureRunnerHelper;
149
150 // The SignatureDef object is owned by the interpreter.
151 const internal::SignatureDef* signature_def_;
152 // The Subgraph object is owned by the interpreter.
153 Subgraph* subgraph_;
154 // The list of input tensor names.
155 std::vector<const char*> input_names_;
156 // The list of output tensor names.
157 std::vector<const char*> output_names_;
158};
159
160} // namespace tflite
161
162#endif // TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
163