1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_LITE_SIGNATURE_RUNNER_H_ |
16 | #define TENSORFLOW_LITE_SIGNATURE_RUNNER_H_ |
17 | |
18 | #include <cstddef> |
19 | #include <cstdint> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/core/subgraph.h" |
25 | #include "tensorflow/lite/internal/signature_def.h" |
26 | |
27 | namespace tflite { |
// Forward declarations of the classes that are granted friend access to
// SignatureRunner; their definitions live in other headers.
class Interpreter;
class SignatureRunnerHelper;
class SignatureRunnerJNIHelper;
class TensorHandle;
32 | |
33 | /// WARNING: Experimental interface, subject to change |
34 | /// |
35 | /// SignatureRunner class for running TFLite models using SignatureDef. |
36 | /// |
37 | /// Usage: |
38 | /// |
39 | /// <pre><code> |
40 | /// // Create model from file. Note that the model instance must outlive the |
41 | /// // interpreter instance. |
42 | /// auto model = tflite::FlatBufferModel::BuildFromFile(...); |
43 | /// if (model == nullptr) { |
44 | /// // Return error. |
45 | /// } |
46 | /// |
47 | /// // Create an Interpreter with an InterpreterBuilder. |
48 | /// std::unique_ptr<tflite::Interpreter> interpreter; |
49 | /// tflite::ops::builtin::BuiltinOpResolver resolver; |
50 | /// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { |
51 | /// // Return failure. |
52 | /// } |
53 | /// |
54 | /// // Get the list of signatures and check it. |
55 | /// auto signature_defs = interpreter->signature_def_names(); |
56 | /// if (signature_defs.empty()) { |
57 | /// // Return error. |
58 | /// } |
59 | /// |
60 | /// // Get pointer to the SignatureRunner instance corresponding to a signature. |
61 | /// // Note that the pointed SignatureRunner instance has lifetime same as the |
62 | /// // Interpreter instance. |
63 | /// tflite::SignatureRunner* runner = |
64 | /// interpreter->GetSignatureRunner(signature_defs[0]->c_str()); |
65 | /// if (runner == nullptr) { |
66 | /// // Return error. |
67 | /// } |
68 | /// if (runner->AllocateTensors() != kTfLiteOk) { |
69 | /// // Return failure. |
70 | /// } |
71 | /// |
72 | /// // Set input data. In this example, the input tensor has float type. |
73 | /// float* input = runner->input_tensor(0)->data.f; |
74 | /// for (int i = 0; i < input_size; i++) { |
75 | /// input[i] = ...; |
76 | // } |
77 | /// runner->Invoke(); |
78 | /// </code></pre> |
79 | /// |
80 | /// WARNING: This class is *not* thread-safe. The client is responsible for |
81 | /// ensuring serialized interaction to avoid data races and undefined behavior. |
82 | /// |
83 | /// SignatureRunner and Interpreter share the same underlying data. Calling |
84 | /// methods on an Interpreter object will affect the state in corresponding |
85 | /// SignatureRunner objects. Therefore, it is recommended not to call other |
86 | /// Interpreter methods after calling GetSignatureRunner to create |
87 | /// SignatureRunner instances. |
88 | class SignatureRunner { |
89 | public: |
90 | /// Returns the key for the corresponding signature. |
91 | const std::string& signature_key() { return signature_def_->signature_key; } |
92 | |
93 | /// Returns the number of inputs. |
94 | size_t input_size() const { return subgraph_->inputs().size(); } |
95 | |
96 | /// Returns the number of outputs. |
97 | size_t output_size() const { return subgraph_->outputs().size(); } |
98 | |
99 | /// Read-only access to list of signature input names. |
100 | const std::vector<const char*>& input_names() { return input_names_; } |
101 | |
102 | /// Read-only access to list of signature output names. |
103 | const std::vector<const char*>& output_names() { return output_names_; } |
104 | |
105 | /// Returns the input tensor identified by 'input_name' in the |
106 | /// given signature. Returns nullptr if the given name is not valid. |
107 | TfLiteTensor* input_tensor(const char* input_name); |
108 | |
109 | /// Returns the output tensor identified by 'output_name' in the |
110 | /// given signature. Returns nullptr if the given name is not valid. |
111 | const TfLiteTensor* output_tensor(const char* output_name) const; |
112 | |
113 | /// Change a dimensionality of a given tensor. Note, this is only acceptable |
114 | /// for tensors that are inputs. |
115 | /// Returns status of failure or success. Note that this doesn't actually |
116 | /// resize any existing buffers. A call to AllocateTensors() is required to |
117 | /// change the tensor input buffer. |
118 | TfLiteStatus ResizeInputTensor(const char* input_name, |
119 | const std::vector<int>& new_size); |
120 | |
121 | /// Change the dimensionality of a given tensor. This is only acceptable for |
122 | /// tensor indices that are inputs or variables. Only unknown dimensions can |
123 | /// be resized with this function. Unknown dimensions are indicated as `-1` in |
124 | /// the `dims_signature` attribute of a TfLiteTensor. |
125 | /// Returns status of failure or success. Note that this doesn't actually |
126 | /// resize any existing buffers. A call to AllocateTensors() is required to |
127 | /// change the tensor input buffer. |
128 | TfLiteStatus ResizeInputTensorStrict(const char* input_name, |
129 | const std::vector<int>& new_size); |
130 | |
131 | /// Updates allocations for all tensors, related to the given signature. |
132 | TfLiteStatus AllocateTensors() { return subgraph_->AllocateTensors(); } |
133 | |
134 | /// Invokes the signature runner (run the graph identified by the given |
135 | /// signature in dependency order). |
136 | TfLiteStatus Invoke(); |
137 | |
138 | private: |
139 | // The life cycle of SignatureRunner depends on the life cycle of Subgraph, |
140 | // which is owned by an Interpreter. Therefore, the Interpreter will takes the |
141 | // responsibility to create and manage SignatureRunner objects to make sure |
142 | // SignatureRunner objects don't outlive their corresponding Subgraph objects. |
143 | SignatureRunner(const internal::SignatureDef* signature_def, |
144 | Subgraph* subgraph); |
145 | friend class Interpreter; |
146 | friend class SignatureRunnerJNIHelper; |
147 | friend class TensorHandle; |
148 | friend class SignatureRunnerHelper; |
149 | |
150 | // The SignatureDef object is owned by the interpreter. |
151 | const internal::SignatureDef* signature_def_; |
152 | // The Subgraph object is owned by the interpreter. |
153 | Subgraph* subgraph_; |
154 | // The list of input tensor names. |
155 | std::vector<const char*> input_names_; |
156 | // The list of output tensor names. |
157 | std::vector<const char*> output_names_; |
158 | }; |
159 | |
160 | } // namespace tflite |
161 | |
162 | #endif // TENSORFLOW_LITE_SIGNATURE_RUNNER_H_ |
163 | |