1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stddef.h> |
17 | |
18 | #include <cstring> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/lite/c/builtin_op_data.h" |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/context_util.h" |
24 | #include "tensorflow/lite/core/subgraph.h" |
25 | #include "tensorflow/lite/kernels/kernel_util.h" |
26 | |
27 | namespace tflite { |
28 | namespace ops { |
29 | namespace builtin { |
30 | namespace while_kernel { |
31 | |
// Per-node state for the WHILE op. Allocated in Init(), freed in Free().
struct OpData {
  int cond_subgraph_index;  // Index of the condition subgraph in the model.
  int body_subgraph_index;  // Index of the body subgraph in the model.
  // True when the condition subgraph's output is dynamic, so its shape/type
  // must be re-validated on every iteration instead of once in Prepare.
  bool cond_has_dynamic_output_tensors;
  // True when the body subgraph produces dynamic outputs (or outputs whose
  // shape differs from its inputs); selects the Eval_dynamic() path.
  bool body_has_dynamic_output_tensors;
  // True when body inputs share the caller's buffers (pointer copy) instead
  // of being deep-copied; enabled by the large-tensor memory optimization.
  bool body_use_shallow_copy;
  // set when Prepare_impl() is called.
  bool subgraphs_prepared;
};
41 | |
42 | namespace { |
43 | |
44 | // Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` |
45 | // to `dst_tensor_indices` in `dst_subgraph`. |
46 | // |
47 | // When `resize_subgraph_inputs` is true, the function calls subgraphs's |
48 | // `ResizeInputTensor` function, and it may trigger the memory planner to |
49 | // reallocate memory. |
50 | // When `resize_subgraph_inputs` is false, it implies `context` belongs to |
51 | // `dst_subgraph`. The function calls `context->ResizeTensor`. This happens |
52 | // when resizing `While` op's outputs. |
53 | template <typename SrcVector, typename DstVector> |
54 | TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context, |
55 | Subgraph* src_subgraph, |
56 | const SrcVector& src_tensor_indices, |
57 | Subgraph* dst_subgraph, |
58 | const DstVector& dst_tensor_indices, |
59 | bool resize_subgraph_inputs) { |
60 | TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), |
61 | dst_tensor_indices.size()); |
62 | for (int i = 0; i < src_tensor_indices.size(); ++i) { |
63 | // Skip copying unused destination tensors. |
64 | if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; |
65 | |
66 | const TfLiteTensor* src_tensor = |
67 | src_subgraph->tensor(src_tensor_indices[i]); |
68 | |
69 | TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); |
70 | if (resize_subgraph_inputs) { |
71 | std::vector<int> dims(src_tensor->dims->data, |
72 | src_tensor->dims->data + src_tensor->dims->size); |
73 | dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); |
74 | } else { |
75 | TF_LITE_ENSURE_OK( |
76 | context, context->ResizeTensor(context, dst_tensor, |
77 | TfLiteIntArrayCopy(src_tensor->dims))); |
78 | } |
79 | dst_tensor->type = src_tensor->type; |
80 | } |
81 | return kTfLiteOk; |
82 | } |
83 | |
84 | // Copy the tensors data from tensors `src_tensor_indices` in `src_subgraph` |
85 | // to `dst_tensor_indices` in `dst_subgraph`. |
86 | template <typename SrcVector, typename DstVector> |
87 | TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph, |
88 | const SrcVector& src_tensor_indices, |
89 | Subgraph* dst_subgraph, |
90 | const DstVector& dst_tensor_indices) { |
91 | TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), |
92 | dst_tensor_indices.size()); |
93 | for (int i = 0; i < src_tensor_indices.size(); ++i) { |
94 | // Skip copying unused destination tensors. |
95 | if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; |
96 | |
97 | const TfLiteTensor* src_tensor = |
98 | src_subgraph->tensor(src_tensor_indices[i]); |
99 | TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); |
100 | if (IsDynamicTensor(dst_tensor)) { |
101 | TfLiteTensorRealloc(src_tensor->bytes, dst_tensor); |
102 | } |
103 | TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes); |
104 | TfLiteTensorCopy(src_tensor, dst_tensor); |
105 | } |
106 | return kTfLiteOk; |
107 | } |
108 | |
109 | // Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` |
110 | // to `dst_tensor_indices` in `dst_subgraph` and copy data deeply. |
111 | template <typename SrcVector, typename DstVector> |
112 | TfLiteStatus (TfLiteContext* context, |
113 | TfLiteNode* node, |
114 | Subgraph* src_subgraph, |
115 | const SrcVector& src_tensor_indices, |
116 | Subgraph* dst_subgraph, |
117 | const DstVector& dst_tensor_indices) { |
118 | const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
119 | |
120 | if (op_data->body_has_dynamic_output_tensors) { |
121 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
122 | bool resize_subgraph_inputs = (dst_subgraph != this_subgraph); |
123 | TF_LITE_ENSURE_OK( |
124 | context, CopyTensorsShapeAndType( |
125 | context, src_subgraph, src_tensor_indices, dst_subgraph, |
126 | dst_tensor_indices, resize_subgraph_inputs)); |
127 | if (resize_subgraph_inputs) { |
128 | TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors()); |
129 | } |
130 | } |
131 | TF_LITE_ENSURE_OK(context, |
132 | CopyTensorsData(context, src_subgraph, src_tensor_indices, |
133 | dst_subgraph, dst_tensor_indices)); |
134 | return kTfLiteOk; |
135 | } |
136 | |
137 | // Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` |
138 | // to `dst_tensor_indices` in `dst_subgraph` and copy data shallowly. |
139 | template <typename SrcVector, typename DstVector> |
140 | TfLiteStatus ( |
141 | TfLiteContext* context, TfLiteNode* node, Subgraph* src_subgraph, |
142 | const SrcVector& src_tensor_indices, Subgraph* dst_subgraph, |
143 | const DstVector& dst_tensor_indices) { |
144 | const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
145 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
146 | TF_LITE_ENSURE_EQ(context, op_data->body_has_dynamic_output_tensors, true); |
147 | // Only allow shallow copy from main node input. |
148 | TF_LITE_ENSURE_EQ(context, src_subgraph, this_subgraph); |
149 | |
150 | TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), |
151 | dst_tensor_indices.size()); |
152 | bool reallocation_needed = false; |
153 | for (int i = 0; i < src_tensor_indices.size(); ++i) { |
154 | // Skip copying unused destination tensors. |
155 | if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; |
156 | |
157 | const TfLiteTensor* src_tensor = |
158 | src_subgraph->tensor(src_tensor_indices[i]); |
159 | TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); |
160 | |
161 | if (!TfLiteIntArrayEqual(src_tensor->dims, dst_tensor->dims)) { |
162 | reallocation_needed = true; |
163 | TfLiteIntArrayFree(dst_tensor->dims); |
164 | dst_tensor->dims = TfLiteIntArrayCopy(src_tensor->dims); |
165 | } |
166 | dst_tensor->type = src_tensor->type; |
167 | dst_tensor->bytes = 0; // Don't allocate memory with AllocateTensors(). |
168 | dst_tensor->data.raw = nullptr; |
169 | } |
170 | |
171 | if (reallocation_needed && dst_subgraph != this_subgraph) { |
172 | TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors()); |
173 | } |
174 | |
175 | for (int i = 0; i < src_tensor_indices.size(); ++i) { |
176 | // Skip copying unused destination tensors. |
177 | if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; |
178 | |
179 | const TfLiteTensor* src_tensor = |
180 | src_subgraph->tensor(src_tensor_indices[i]); |
181 | TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); |
182 | |
183 | dst_tensor->bytes = src_tensor->bytes; |
184 | dst_tensor->data.raw = src_tensor->data.raw; |
185 | } |
186 | |
187 | return kTfLiteOk; |
188 | } |
189 | |
190 | TfLiteStatus CheckCondOutput(TfLiteContext* context, |
191 | const TfLiteTensor* cond_output) { |
192 | // The condition output must be a single boolean value. |
193 | TF_LITE_ENSURE_TYPES_EQ(context, cond_output->type, kTfLiteBool); |
194 | if (cond_output->dims->size == 0) { |
195 | // It's okay if it's a 0D scalar. |
196 | return kTfLiteOk; |
197 | } |
198 | // Otherwise it must be 1D with shape [1]. |
199 | TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1); |
200 | TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1); |
201 | return kTfLiteOk; |
202 | } |
203 | |
204 | } // namespace |
205 | |
206 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
207 | auto* op_data = new OpData; |
208 | const auto* params = reinterpret_cast<const TfLiteWhileParams*>(buffer); |
209 | op_data->cond_subgraph_index = params->cond_subgraph_index; |
210 | op_data->body_subgraph_index = params->body_subgraph_index; |
211 | op_data->cond_has_dynamic_output_tensors = false; |
212 | op_data->body_has_dynamic_output_tensors = false; |
213 | op_data->body_use_shallow_copy = false; |
214 | op_data->subgraphs_prepared = false; |
215 | return op_data; |
216 | } |
217 | |
218 | void Free(TfLiteContext* context, void* buffer) { |
219 | delete reinterpret_cast<OpData*>(buffer); |
220 | } |
221 | |
// Full preparation of the WHILE node: validates the cond/body subgraphs,
// propagates input shapes into them, decides between the static and dynamic
// evaluation paths, and sizes (or marks dynamic) the node's outputs.
// Called from Prepare() directly, or lazily from Eval() via Prepare_lazy().
TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int num_inputs = node->inputs->size;
  // The number of outputs should be the same as number of inputs.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);

  // Check subgraph indices and get subgraphs.
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
  // cond and body must be distinct subgraphs.
  TF_LITE_ENSURE(context,
                 op_data->cond_subgraph_index != op_data->body_subgraph_index);

  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // Check input & output count of the condition subgraph.
  TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);

  // Check input & output count of the body subgraph.
  TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);

  // Remove unused inputs of the condition subgraph to skip copying unnecessary
  // inputs.
  cond_subgraph->RemoveUnusedInputs();

  // Prepare and check the condition subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(
                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
                   cond_subgraph, cond_subgraph->inputs(), true));
  TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
  TfLiteTensor* cond_output =
      cond_subgraph->tensor(cond_subgraph->outputs()[0]);
  // This should rarely happens. In most cases the output is static with shape
  // [1]. However theoretically intermediate tensors in the cond subgraph
  // can be dynamic.
  if (IsDynamicTensor(cond_output)) {
    op_data->cond_has_dynamic_output_tensors = true;
  } else {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  // Prepare and check the body subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(
                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
                   body_subgraph, body_subgraph->inputs(), true));

  // Shallow copy is disabled when any body input is a resource/variant
  // tensor, since those carry ownership semantics a pointer alias would break.
  bool input_has_resource_or_variant_tensor = false;
  for (int i = 0; i < num_inputs; ++i) {
    if (IsResourceOrVariant(
            body_subgraph->tensor(body_subgraph->inputs()[i]))) {
      input_has_resource_or_variant_tensor = true;
      break;
    }
  }
  if (this_subgraph->ShouldOptimizeMemoryForLargeTensors() &&
      !input_has_resource_or_variant_tensor) {
    // The current shallow copy requires to use dynamic tensors which introduces
    // additional overheads. Therefore, use the method only if dynamic
    // allocation is enabled.
    op_data->body_use_shallow_copy = true;
    op_data->body_has_dynamic_output_tensors = true;
    // Make body inputs dynamic to use shallow copy with Eval_dynamic().
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      SetTensorToDynamic(body_input);
      body_input->bytes = 0;  // Buffer is aliased later; nothing to allocate.
    }
  }

  TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  if (body_subgraph->HasDynamicTensors()) {
    op_data->body_has_dynamic_output_tensors = true;
  } else {
    // Even with static allocation, the body must map each input to an output
    // of identical type and shape, or the loop state would change size.
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TF_LITE_ENSURE_TYPES_EQ(context, body_input->type, body_output->type);

      TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
      if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) {
        // If the output shape of the body subgraph is static w.r.t. a fixed
        // input size, but it's different from input size, it's still considered
        // dynamic. For example: If a subgraph keeps padding its input with a
        // fixed padding, the output shape is static w.r.t the input shape and
        // padding, but running it in a loop will keep bloating the tensor.
        op_data->body_has_dynamic_output_tensors = true;
        break;
      }
    }
  }
  // Size the WHILE op's own outputs: dynamic when the body is dynamic,
  // otherwise resized to match the body subgraph's output shapes.
  for (int i = 0; i < num_inputs; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (op_data->body_has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }
  op_data->subgraphs_prepared = true;
  return kTfLiteOk;
}
337 | |
338 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
339 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
340 | if (this_subgraph->ShouldOptimizeMemoryForLargeTensors()) { |
341 | // Apply lazy initialization of WHILE kernel. |
342 | // Just make node output tensors dynamic. |
343 | int num_outputs = node->outputs->size; |
344 | for (int i = 0; i < num_outputs; ++i) { |
345 | TfLiteTensor* output; |
346 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
347 | SetTensorToDynamic(output); |
348 | } |
349 | return kTfLiteOk; |
350 | } |
351 | return Prepare_impl(context, node); |
352 | } |
353 | |
// Deferred preparation entry point, invoked from Eval() on the first run when
// Prepare() took the lazy-initialization path (see OpData::subgraphs_prepared).
TfLiteStatus Prepare_lazy(TfLiteContext* context, TfLiteNode* node) {
  return Prepare_impl(context, node);
}
357 | |
// Evaluate cond subgraph and set the result.
//
// Invokes `cond_subgraph` and writes its single boolean output into
// `*cond_subgraph_output`. When the cond output is dynamic its shape/type
// could not be validated in Prepare, so it is re-checked here after every
// invocation.
TfLiteStatus Eval_cond_subgraph(TfLiteContext* context, Subgraph* cond_subgraph,
                                bool cond_has_dynamic_output_tensors,
                                bool* cond_subgraph_output) {
  TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
  int cond_subgraph_output_index = cond_subgraph->outputs()[0];
  // Make sure the (possibly delegated) output data is readable from the CPU.
  cond_subgraph->EnsureTensorDataIsReadable(cond_subgraph_output_index);
  TfLiteTensor* cond_output = cond_subgraph->tensor(cond_subgraph_output_index);
  if (cond_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  // Read the single boolean element of the condition output.
  *cond_subgraph_output = (cond_output->data.b[0]);
  return kTfLiteOk;
}
373 | |
// Evaluate WHILE op when body subgraph has dynamic outputs.
TfLiteStatus Eval_dynamic(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following steps illustrate the current implementation (between
  // This Subgraph, the Cond Subgraph and the Body Subgraph):
  //
  // (1) Copy the inputs of WHILE op to the inputs of condition subgraph.
  // (2) Copy the inputs of WHILE op to the outputs of WHILE op
  // (3) Invoke condition subgraph.
  //     Exit the loop if the result is false.
  // (4) Copy the outputs of WHILE op to the inputs of body subgraph.
  // (5) Invoke body subgraph.
  // (6) Copy the outputs of body subgraph to the inputs condition subgraph.
  // (7) Copy the outputs of body subgraph to the outputs of WHILE op.
  //     Jump back to step 3!
  //
  // If the body subgraph has dynamic sized outputs, it's required to resize
  // the tensor before copying in step 1, 2, 4, 6 and 7.
  //
  // Note the flow is carefully designed to handle the dynamic sized output
  // case. The loop invariant is: The newest value is in the inputs of
  // condition subgraph. This is always true before step 3.

  // Step 1. node->inputs -> cond->inputs (fast)
  TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData(
                                 context, node, this_subgraph,
                                 TfLiteIntArrayView(node->inputs),
                                 cond_subgraph, cond_subgraph->inputs()));

  // Step 2. node->inputs -> node->outputs
  // Seeds the outputs so that a zero-iteration loop still yields the inputs.
  TF_LITE_ENSURE_OK(
      context, DeepCopyTensorsShapeTypeData(context, node, this_subgraph,
                                            TfLiteIntArrayView(node->inputs),
                                            this_subgraph,
                                            TfLiteIntArrayView(node->outputs)));

  while (true) {
    // Step 3. Eval cond subgraph
    bool cond_subgraph_output;
    TF_LITE_ENSURE_OK(
        context, Eval_cond_subgraph(context, cond_subgraph,
                                    op_data->cond_has_dynamic_output_tensors,
                                    &cond_subgraph_output));
    if (!cond_subgraph_output) {
      break;
    }

    // Step 4. node->outputs -> body->inputs
    // With shallow copy enabled the body inputs just alias the node outputs'
    // buffers; otherwise the bytes are copied.
    if (op_data->body_use_shallow_copy) {
      TF_LITE_ENSURE_OK(context, ShallowCopyTensorsShapeTypeData(
                                     context, node, this_subgraph,
                                     TfLiteIntArrayView(node->outputs),
                                     body_subgraph, body_subgraph->inputs()));
    } else {
      TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData(
                                     context, node, this_subgraph,
                                     TfLiteIntArrayView(node->outputs),
                                     body_subgraph, body_subgraph->inputs()));
    }

    // Step 5. Invoke body subgraph
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    // Step 6. body->outputs -> cond->inputs (fast)
    TF_LITE_ENSURE_OK(
        context, DeepCopyTensorsShapeTypeData(
                     context, node, body_subgraph, body_subgraph->outputs(),
                     cond_subgraph, cond_subgraph->inputs()));

    // Step 7. body->outputs -> node->outputs
    TF_LITE_ENSURE_OK(
        context, DeepCopyTensorsShapeTypeData(
                     context, node, body_subgraph, body_subgraph->outputs(),
                     this_subgraph, TfLiteIntArrayView(node->outputs)));
  }

  if (op_data->body_use_shallow_copy) {
    // Clean up shallow copied pointer of body inputs.
    // The aliased buffers are owned by this subgraph's output tensors, so the
    // body must not keep (or later free) these pointers.
    for (int i = 0; i < body_subgraph->inputs().size(); ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      body_input->data.raw = nullptr;
    }
  }

  return kTfLiteOk;
}
484 | |
// Evaluate WHILE op when body subgraph has static outputs.
TfLiteStatus Eval_static(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following steps illustrate the current implementation (between
  // This Subgraph, the Cond Subgraph and the Body Subgraph):
  //
  // (1) Copy the inputs of WHILE op to the inputs of condition subgraph.
  // (2) Invoke condition subgraph.
  //     Jump to step 6 if the result is false.
  // (3) If body is never invoked, run the step 3-1, else run the step 3-2.
  //   (3-1) Copy the inputs of WHILE op to the inputs of body subgraph.
  //   (3-2) Copy the outputs of body subgraph to the inputs of body subgraph.
  // (4) Invoke body subgraph.
  // (5) Copy the outputs of body subgraph to the inputs condition subgraph.
  //     Jump back to step 2!
  // (6) Copy the outputs of body subgraph to the outputs of WHILE op.
  //
  // The body subgraph shouldn't have dynamic sized outputs, so all tensors
  // were already sized in Prepare and only byte copies are needed here.

  // Step 1. node->inputs -> cond->inputs (fast)
  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                      cond_subgraph, cond_subgraph->inputs()));

  // Tracks whether the body ran at least once, which decides both where the
  // body reads its next inputs from (step 3) and what the final outputs are.
  bool body_invoked = false;
  while (true) {
    // Step 2. Eval cond subgraph
    bool cond_subgraph_output;
    TF_LITE_ENSURE_OK(
        context, Eval_cond_subgraph(context, cond_subgraph,
                                    op_data->cond_has_dynamic_output_tensors,
                                    &cond_subgraph_output));
    if (!cond_subgraph_output) {
      break;
    }

    if (body_invoked) {
      // Step 3-2. body->output -> body->inputs
      TF_LITE_ENSURE_OK(
          context,
          CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                          body_subgraph, body_subgraph->inputs()));
    } else {
      // Step 3-1. node->inputs -> body->inputs
      TF_LITE_ENSURE_OK(
          context, CopyTensorsData(context, this_subgraph,
                                   TfLiteIntArrayView(node->inputs),
                                   body_subgraph, body_subgraph->inputs()));
    }

    // Step 4. Invoke body subgraph
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    body_invoked = true;
    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    // Step 5. body->output -> cond->inputs (fast)
    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        cond_subgraph, cond_subgraph->inputs()));
  }

  if (body_invoked) {
    // Step 6. Copy body->output -> node->outputs
    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        this_subgraph, TfLiteIntArrayView(node->outputs)));
  } else {
    // Copy node->inputs if body is never invoked.
    TF_LITE_ENSURE_OK(
        context, CopyTensorsData(
                     context, this_subgraph, TfLiteIntArrayView(node->inputs),
                     this_subgraph, TfLiteIntArrayView(node->outputs)));
  }

  return kTfLiteOk;
}
586 | |
// Top-level Eval: finishes lazy preparation if needed, dispatches to the
// dynamic or static evaluation path, and optionally releases the subgraphs'
// memory afterwards.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  if (op_data->subgraphs_prepared == false) {
    // Prepare() took the lazy path; do the full preparation on first run.
    TF_LITE_ENSURE_OK(context, Prepare_lazy(context, node));
  } else {
    // Re-allocate in case a previous Eval released the subgraphs' memory.
    TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
    TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  }

  if (op_data->body_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_OK(context, Eval_dynamic(context, node));
  } else {
    TF_LITE_ENSURE_OK(context, Eval_static(context, node));
  }

  // Free subgraph memory between invocations unless tensors must be kept
  // (e.g. for debugging / intermediate inspection).
  if (!this_subgraph->ShouldPreserveAllTensors()) {
    TF_LITE_ENSURE_OK(context, cond_subgraph->ReleaseMemory());
    TF_LITE_ENSURE_OK(context, body_subgraph->ReleaseMemory());
  }

  return kTfLiteOk;
}
614 | |
615 | } // namespace while_kernel |
616 | |
617 | TfLiteRegistration* Register_WHILE() { |
618 | static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free, |
619 | while_kernel::Prepare, while_kernel::Eval}; |
620 | return &r; |
621 | } |
622 | |
623 | } // namespace builtin |
624 | } // namespace ops |
625 | } // namespace tflite |
626 | |