1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/lite/signature_runner.h" |
17 | |
18 | #include "tensorflow/lite/c/c_api_types.h" |
19 | |
20 | namespace tflite { |
21 | |
22 | SignatureRunner::SignatureRunner(const internal::SignatureDef* signature_def, |
23 | Subgraph* subgraph) |
24 | : signature_def_(signature_def), subgraph_(subgraph) { |
25 | // Collects the list of input and output tensor names. |
26 | for (const auto& it : signature_def_->inputs) { |
27 | input_names_.push_back(it.first.c_str()); |
28 | } |
29 | for (const auto& it : signature_def_->outputs) { |
30 | output_names_.push_back(it.first.c_str()); |
31 | } |
32 | } |
33 | |
34 | TfLiteTensor* SignatureRunner::input_tensor(const char* input_name) { |
35 | const auto& it = signature_def_->inputs.find(input_name); |
36 | if (it == signature_def_->inputs.end()) { |
37 | subgraph_->ReportError("Input name %s was not found" , input_name); |
38 | return nullptr; |
39 | } |
40 | return subgraph_->tensor(it->second); |
41 | } |
42 | |
43 | const TfLiteTensor* SignatureRunner::output_tensor( |
44 | const char* output_name) const { |
45 | const auto& it = signature_def_->outputs.find(output_name); |
46 | if (it == signature_def_->outputs.end()) { |
47 | subgraph_->ReportError("Output name %s was not found" , output_name); |
48 | return nullptr; |
49 | } |
50 | return subgraph_->tensor(it->second); |
51 | } |
52 | |
53 | TfLiteStatus SignatureRunner::ResizeInputTensor( |
54 | const char* input_name, const std::vector<int>& new_size) { |
55 | const auto& it = signature_def_->inputs.find(input_name); |
56 | if (it == signature_def_->inputs.end()) { |
57 | subgraph_->ReportError("Input name %s was not found" , input_name); |
58 | return kTfLiteError; |
59 | } |
60 | return subgraph_->ResizeInputTensor(it->second, new_size); |
61 | } |
62 | |
63 | TfLiteStatus SignatureRunner::ResizeInputTensorStrict( |
64 | const char* input_name, const std::vector<int>& new_size) { |
65 | const auto& it = signature_def_->inputs.find(input_name); |
66 | if (it == signature_def_->inputs.end()) { |
67 | subgraph_->ReportError("Input name %s was not found" , input_name); |
68 | return kTfLiteError; |
69 | } |
70 | return subgraph_->ResizeInputTensorStrict(it->second, new_size); |
71 | } |
72 | |
73 | TfLiteStatus SignatureRunner::Invoke() { |
74 | TF_LITE_ENSURE_STATUS(subgraph_->Invoke()); |
75 | |
76 | // Makes sure output tensors are readable. |
77 | for (int tensor_index : subgraph_->outputs()) { |
78 | TF_LITE_ENSURE_STATUS(subgraph_->EnsureTensorDataIsReadable(tensor_index)); |
79 | } |
80 | return kTfLiteOk; |
81 | } |
82 | |
83 | } // namespace tflite |
84 | |