1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stddef.h> |
17 | |
18 | #include <cstring> |
19 | #include <memory> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/lite/c/builtin_op_data.h" |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/core/subgraph.h" |
25 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
26 | #include "tensorflow/lite/kernels/kernel_util.h" |
27 | |
28 | namespace tflite { |
29 | namespace ops { |
30 | namespace builtin { |
31 | namespace if_kernel { |
32 | |
// Per-node state for the IF op, allocated in Init() and released in Free().
// Holds the indices (into the model's subgraph list) of the two branch
// subgraphs selected by the boolean condition input.
struct OpData {
  // Subgraph executed when the condition tensor is true.
  int then_subgraph_index;
  // Subgraph executed when the condition tensor is false.
  int else_subgraph_index;
};
37 | |
38 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
39 | auto* op_data = new OpData; |
40 | const auto* params = reinterpret_cast<const TfLiteIfParams*>(buffer); |
41 | op_data->then_subgraph_index = params->then_subgraph_index; |
42 | op_data->else_subgraph_index = params->else_subgraph_index; |
43 | return op_data; |
44 | } |
45 | |
46 | void Free(TfLiteContext* context, void* buffer) { |
47 | delete reinterpret_cast<OpData*>(buffer); |
48 | } |
49 | |
50 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
51 | const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
52 | |
53 | TF_LITE_ENSURE(context, node->inputs->size > 0); |
54 | |
55 | // The first input is the condition. |
56 | const TfLiteTensor* cond; |
57 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond)); |
58 | // Currently only bool is supported. |
59 | // TODO(ycling): Support other types since TensorFlow also support |
60 | // non-bool types as condition. |
61 | TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool); |
62 | TF_LITE_ENSURE_EQ(context, NumElements(cond), 1); |
63 | |
64 | // The first input of the node is the condition. The rest of inputs are |
65 | // passed to the branch subgraphs. Therefore, the number of subgraph inputs |
66 | // will be the number of node inputs - 1. |
67 | int num_inputs = node->inputs->size - 1; |
68 | int num_outputs = node->outputs->size; |
69 | |
70 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
71 | auto* subgraphs = this_subgraph->GetSubgraphs(); |
72 | TF_LITE_ENSURE(context, op_data->then_subgraph_index < subgraphs->size()); |
73 | TF_LITE_ENSURE(context, op_data->else_subgraph_index < subgraphs->size()); |
74 | |
75 | Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get(); |
76 | Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get(); |
77 | |
78 | for (auto* subgraph : {then_subgraph, else_subgraph}) { |
79 | TF_LITE_ENSURE_EQ(context, num_inputs, subgraph->inputs().size()); |
80 | TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size()); |
81 | } |
82 | |
83 | bool has_dynamic_output_tensors = false; |
84 | for (auto* subgraph : {then_subgraph, else_subgraph}) { |
85 | for (int i = 0; i < num_inputs; ++i) { |
86 | // The first input of the node is the condition. The indices of the inputs |
87 | // passed to the subgraphs are offset by 1. |
88 | const TfLiteTensor* input; |
89 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input)); |
90 | std::vector<int> dims(input->dims->data, |
91 | input->dims->data + input->dims->size); |
92 | subgraph->ResizeInputTensor(i, dims); |
93 | TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]); |
94 | if (IsDynamicTensor(input)) { |
95 | SetTensorToDynamic(subgraph_input); |
96 | } |
97 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type); |
98 | } |
99 | // Note: The `Prepare` function is responsible to run `AllocateTensors` on |
100 | // both subgraphs. It's intentionally not to break out of the loop when |
101 | // finding a dynamic output tensor. |
102 | TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors()); |
103 | has_dynamic_output_tensors |= subgraph->HasDynamicTensors(); |
104 | } |
105 | |
106 | if (!has_dynamic_output_tensors) { |
107 | for (int i = 0; i < num_outputs; ++i) { |
108 | TfLiteTensor* then_output = |
109 | then_subgraph->tensor(then_subgraph->outputs()[i]); |
110 | TfLiteTensor* else_output = |
111 | else_subgraph->tensor(else_subgraph->outputs()[i]); |
112 | // If the 2 subgraphs have static but different output shapes, the output |
113 | // tensors of the IF op have dynamic sizes. |
114 | if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) { |
115 | has_dynamic_output_tensors = true; |
116 | break; |
117 | } |
118 | } |
119 | } |
120 | |
121 | for (int i = 0; i < num_outputs; ++i) { |
122 | TfLiteTensor* output; |
123 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
124 | if (has_dynamic_output_tensors) { |
125 | SetTensorToDynamic(output); |
126 | } else { |
127 | // When there's no dynamic output tensors, the 2 subgraph has exactly |
128 | // the same static sized outputs. |
129 | TfLiteTensor* then_output = |
130 | then_subgraph->tensor(then_subgraph->outputs()[i]); |
131 | TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims); |
132 | TF_LITE_ENSURE_OK(context, |
133 | context->ResizeTensor(context, output, output_size)); |
134 | } |
135 | } |
136 | |
137 | return kTfLiteOk; |
138 | } |
139 | |
140 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
141 | const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
142 | |
143 | const TfLiteTensor* cond; |
144 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond)); |
145 | bool cond_value = cond->data.b[0]; |
146 | |
147 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
148 | auto* subgraphs = this_subgraph->GetSubgraphs(); |
149 | |
150 | // Currently we copy the input / output between the subgraphs. This isn't |
151 | // optimized yet. |
152 | // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. |
153 | int active_branch_subgraph_index = |
154 | cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index; |
155 | Subgraph& active_branch_subgraph = |
156 | *(*subgraphs)[active_branch_subgraph_index]; |
157 | |
158 | // We release memory of the subgraph at the end of evaluation to save memory. |
159 | // So it's required to call AllocateTensors() for the second run. |
160 | TF_LITE_ENSURE_OK(context, active_branch_subgraph.AllocateTensors()); |
161 | |
162 | for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) { |
163 | const TfLiteTensor* input; |
164 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input)); |
165 | TfLiteTensor* subgraph_input = |
166 | active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]); |
167 | |
168 | if (IsDynamicTensor(subgraph_input)) { |
169 | TfLiteTensorRealloc(input->bytes, subgraph_input); |
170 | } |
171 | |
172 | TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes); |
173 | TfLiteTensorCopy(input, subgraph_input); |
174 | } |
175 | |
176 | TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke()); |
177 | |
178 | for (int tensor_index : active_branch_subgraph.outputs()) { |
179 | active_branch_subgraph.EnsureTensorDataIsReadable(tensor_index); |
180 | } |
181 | |
182 | bool has_dynamic_output_tensors = false; |
183 | for (int i = 0; i < node->outputs->size; ++i) { |
184 | TfLiteTensor* output; |
185 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
186 | if (IsDynamicTensor(output)) { |
187 | has_dynamic_output_tensors = true; |
188 | break; |
189 | } |
190 | } |
191 | |
192 | if (has_dynamic_output_tensors) { |
193 | for (int i = 0; i < node->outputs->size; ++i) { |
194 | TfLiteTensor* output; |
195 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
196 | TfLiteTensor* subgraph_output = |
197 | active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]); |
198 | TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims); |
199 | TF_LITE_ENSURE_OK(context, |
200 | context->ResizeTensor(context, output, output_size)); |
201 | } |
202 | } |
203 | |
204 | for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) { |
205 | const TfLiteTensor* subgraph_output = |
206 | active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]); |
207 | TfLiteTensor* output; |
208 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
209 | |
210 | if (IsDynamicTensor(output)) { |
211 | TfLiteTensorRealloc(subgraph_output->bytes, output); |
212 | } |
213 | |
214 | TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes); |
215 | TfLiteTensorCopy(subgraph_output, output); |
216 | } |
217 | |
218 | // Release memory of subgraphs to save the memory. Though it impacts latency, |
219 | // actual impacts looks very little, so no additional option is introduced for |
220 | // the feature until we find a different case. |
221 | Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get(); |
222 | Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get(); |
223 | TF_LITE_ENSURE_OK(context, then_subgraph->ReleaseMemory()); |
224 | TF_LITE_ENSURE_OK(context, else_subgraph->ReleaseMemory()); |
225 | |
226 | return kTfLiteOk; |
227 | } |
228 | |
229 | } // namespace if_kernel |
230 | |
231 | TfLiteRegistration* Register_IF() { |
232 | static TfLiteRegistration r = {if_kernel::Init, if_kernel::Free, |
233 | if_kernel::Prepare, if_kernel::Eval}; |
234 | return &r; |
235 | } |
236 | |
237 | } // namespace builtin |
238 | } // namespace ops |
239 | } // namespace tflite |
240 | |