1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stddef.h>
17
18#include <cstring>
19#include <memory>
20#include <vector>
21
22#include "tensorflow/lite/c/builtin_op_data.h"
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/core/subgraph.h"
25#include "tensorflow/lite/kernels/internal/compatibility.h"
26#include "tensorflow/lite/kernels/kernel_util.h"
27
28namespace tflite {
29namespace ops {
30namespace builtin {
31namespace if_kernel {
32
// Per-node state for the IF op: positions, in the list returned by
// Subgraph::GetSubgraphs(), of the branch subgraph to run when the condition
// is true (`then`) and when it is false (`else`).
struct OpData {
  // Default to -1, an invalid subgraph index, so that a freshly allocated
  // OpData never holds indeterminate values before Init() fills it in.
  int then_subgraph_index = -1;
  int else_subgraph_index = -1;
};
37
38void* Init(TfLiteContext* context, const char* buffer, size_t length) {
39 auto* op_data = new OpData;
40 const auto* params = reinterpret_cast<const TfLiteIfParams*>(buffer);
41 op_data->then_subgraph_index = params->then_subgraph_index;
42 op_data->else_subgraph_index = params->else_subgraph_index;
43 return op_data;
44}
45
46void Free(TfLiteContext* context, void* buffer) {
47 delete reinterpret_cast<OpData*>(buffer);
48}
49
50TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
51 const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
52
53 TF_LITE_ENSURE(context, node->inputs->size > 0);
54
55 // The first input is the condition.
56 const TfLiteTensor* cond;
57 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
58 // Currently only bool is supported.
59 // TODO(ycling): Support other types since TensorFlow also support
60 // non-bool types as condition.
61 TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool);
62 TF_LITE_ENSURE_EQ(context, NumElements(cond), 1);
63
64 // The first input of the node is the condition. The rest of inputs are
65 // passed to the branch subgraphs. Therefore, the number of subgraph inputs
66 // will be the number of node inputs - 1.
67 int num_inputs = node->inputs->size - 1;
68 int num_outputs = node->outputs->size;
69
70 Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
71 auto* subgraphs = this_subgraph->GetSubgraphs();
72 TF_LITE_ENSURE(context, op_data->then_subgraph_index < subgraphs->size());
73 TF_LITE_ENSURE(context, op_data->else_subgraph_index < subgraphs->size());
74
75 Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get();
76 Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get();
77
78 for (auto* subgraph : {then_subgraph, else_subgraph}) {
79 TF_LITE_ENSURE_EQ(context, num_inputs, subgraph->inputs().size());
80 TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size());
81 }
82
83 bool has_dynamic_output_tensors = false;
84 for (auto* subgraph : {then_subgraph, else_subgraph}) {
85 for (int i = 0; i < num_inputs; ++i) {
86 // The first input of the node is the condition. The indices of the inputs
87 // passed to the subgraphs are offset by 1.
88 const TfLiteTensor* input;
89 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
90 std::vector<int> dims(input->dims->data,
91 input->dims->data + input->dims->size);
92 subgraph->ResizeInputTensor(i, dims);
93 TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
94 if (IsDynamicTensor(input)) {
95 SetTensorToDynamic(subgraph_input);
96 }
97 TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type);
98 }
99 // Note: The `Prepare` function is responsible to run `AllocateTensors` on
100 // both subgraphs. It's intentionally not to break out of the loop when
101 // finding a dynamic output tensor.
102 TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors());
103 has_dynamic_output_tensors |= subgraph->HasDynamicTensors();
104 }
105
106 if (!has_dynamic_output_tensors) {
107 for (int i = 0; i < num_outputs; ++i) {
108 TfLiteTensor* then_output =
109 then_subgraph->tensor(then_subgraph->outputs()[i]);
110 TfLiteTensor* else_output =
111 else_subgraph->tensor(else_subgraph->outputs()[i]);
112 // If the 2 subgraphs have static but different output shapes, the output
113 // tensors of the IF op have dynamic sizes.
114 if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) {
115 has_dynamic_output_tensors = true;
116 break;
117 }
118 }
119 }
120
121 for (int i = 0; i < num_outputs; ++i) {
122 TfLiteTensor* output;
123 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
124 if (has_dynamic_output_tensors) {
125 SetTensorToDynamic(output);
126 } else {
127 // When there's no dynamic output tensors, the 2 subgraph has exactly
128 // the same static sized outputs.
129 TfLiteTensor* then_output =
130 then_subgraph->tensor(then_subgraph->outputs()[i]);
131 TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims);
132 TF_LITE_ENSURE_OK(context,
133 context->ResizeTensor(context, output, output_size));
134 }
135 }
136
137 return kTfLiteOk;
138}
139
140TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
141 const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
142
143 const TfLiteTensor* cond;
144 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
145 bool cond_value = cond->data.b[0];
146
147 Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
148 auto* subgraphs = this_subgraph->GetSubgraphs();
149
150 // Currently we copy the input / output between the subgraphs. This isn't
151 // optimized yet.
152 // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
153 int active_branch_subgraph_index =
154 cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index;
155 Subgraph& active_branch_subgraph =
156 *(*subgraphs)[active_branch_subgraph_index];
157
158 // We release memory of the subgraph at the end of evaluation to save memory.
159 // So it's required to call AllocateTensors() for the second run.
160 TF_LITE_ENSURE_OK(context, active_branch_subgraph.AllocateTensors());
161
162 for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
163 const TfLiteTensor* input;
164 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
165 TfLiteTensor* subgraph_input =
166 active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
167
168 if (IsDynamicTensor(subgraph_input)) {
169 TfLiteTensorRealloc(input->bytes, subgraph_input);
170 }
171
172 TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
173 TfLiteTensorCopy(input, subgraph_input);
174 }
175
176 TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke());
177
178 for (int tensor_index : active_branch_subgraph.outputs()) {
179 active_branch_subgraph.EnsureTensorDataIsReadable(tensor_index);
180 }
181
182 bool has_dynamic_output_tensors = false;
183 for (int i = 0; i < node->outputs->size; ++i) {
184 TfLiteTensor* output;
185 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
186 if (IsDynamicTensor(output)) {
187 has_dynamic_output_tensors = true;
188 break;
189 }
190 }
191
192 if (has_dynamic_output_tensors) {
193 for (int i = 0; i < node->outputs->size; ++i) {
194 TfLiteTensor* output;
195 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
196 TfLiteTensor* subgraph_output =
197 active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
198 TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
199 TF_LITE_ENSURE_OK(context,
200 context->ResizeTensor(context, output, output_size));
201 }
202 }
203
204 for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
205 const TfLiteTensor* subgraph_output =
206 active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
207 TfLiteTensor* output;
208 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
209
210 if (IsDynamicTensor(output)) {
211 TfLiteTensorRealloc(subgraph_output->bytes, output);
212 }
213
214 TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
215 TfLiteTensorCopy(subgraph_output, output);
216 }
217
218 // Release memory of subgraphs to save the memory. Though it impacts latency,
219 // actual impacts looks very little, so no additional option is introduced for
220 // the feature until we find a different case.
221 Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get();
222 Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get();
223 TF_LITE_ENSURE_OK(context, then_subgraph->ReleaseMemory());
224 TF_LITE_ENSURE_OK(context, else_subgraph->ReleaseMemory());
225
226 return kTfLiteOk;
227}
228
229} // namespace if_kernel
230
231TfLiteRegistration* Register_IF() {
232 static TfLiteRegistration r = {if_kernel::Init, if_kernel::Free,
233 if_kernel::Prepare, if_kernel::Eval};
234 return &r;
235}
236
237} // namespace builtin
238} // namespace ops
239} // namespace tflite
240