/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stddef.h>

#include <cstring>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace while_kernel {

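// Op data of the WHILE kernel. Created in Init(), stored in
// `node->user_data`, and freed in Free().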
struct OpData {
  int cond_subgraph_index;
  int body_subgraph_index;
  bool cond_has_dynamic_output_tensors;
  bool body_has_dynamic_output_tensors;
  bool body_use_shallow_copy;
  // Set to true once Prepare_impl() has been called.
  bool subgraphs_prepared;
};

namespace {

// Propagates tensor shapes and types from `src_tensor_indices` in
// `src_subgraph` to `dst_tensor_indices` in `dst_subgraph`.
//
// When `resize_subgraph_inputs` is true, the function calls the destination
// subgraph's `ResizeInputTensor`, which may trigger the memory planner to
// reallocate memory.
// When `resize_subgraph_inputs` is false, `context` must belong to
// `dst_subgraph`, and the function calls `context->ResizeTensor` instead.
// This happens when resizing the `While` op's outputs.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
                                     Subgraph* src_subgraph,
                                     const SrcVector& src_tensor_indices,
                                     Subgraph* dst_subgraph,
                                     const DstVector& dst_tensor_indices,
                                     bool resize_subgraph_inputs) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);

    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    if (resize_subgraph_inputs) {
      std::vector<int> dims(src_tensor->dims->data,
                            src_tensor->dims->data + src_tensor->dims->size);
      dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
    } else {
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, dst_tensor,
                                         TfLiteIntArrayCopy(src_tensor->dims)));
    }
    dst_tensor->type = src_tensor->type;
  }
  return kTfLiteOk;
}
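
// Note: CopyTensorsShapeAndType is called with `resize_subgraph_inputs ==
// true` when feeding the cond/body subgraphs, and with `false` when resizing
// the WHILE op's own outputs (see DeepCopyTensorsShapeTypeData below).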

// Copies the data of the tensors at `src_tensor_indices` in `src_subgraph`
// to the tensors at `dst_tensor_indices` in `dst_subgraph`.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph,
                             const SrcVector& src_tensor_indices,
                             Subgraph* dst_subgraph,
                             const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    if (IsDynamicTensor(dst_tensor)) {
      TfLiteTensorRealloc(src_tensor->bytes, dst_tensor);
    }
    TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes);
    TfLiteTensorCopy(src_tensor, dst_tensor);
  }
  return kTfLiteOk;
}

// Propagates tensor shapes and types from `src_tensor_indices` in
// `src_subgraph` to `dst_tensor_indices` in `dst_subgraph`, then copies
// the data deeply.
template <typename SrcVector, typename DstVector>
TfLiteStatus DeepCopyTensorsShapeTypeData(TfLiteContext* context,
                                          TfLiteNode* node,
                                          Subgraph* src_subgraph,
                                          const SrcVector& src_tensor_indices,
                                          Subgraph* dst_subgraph,
                                          const DstVector& dst_tensor_indices) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  if (op_data->body_has_dynamic_output_tensors) {
    Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
    bool resize_subgraph_inputs = (dst_subgraph != this_subgraph);
    TF_LITE_ENSURE_OK(
        context, CopyTensorsShapeAndType(
                     context, src_subgraph, src_tensor_indices, dst_subgraph,
                     dst_tensor_indices, resize_subgraph_inputs));
    if (resize_subgraph_inputs) {
      TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors());
    }
  }
  TF_LITE_ENSURE_OK(context,
                    CopyTensorsData(context, src_subgraph, src_tensor_indices,
                                    dst_subgraph, dst_tensor_indices));
  return kTfLiteOk;
}

// Propagates tensor shapes and types from `src_tensor_indices` in
// `src_subgraph` to `dst_tensor_indices` in `dst_subgraph`, then copies
// the data shallowly (the destination tensors alias the source buffers).
template <typename SrcVector, typename DstVector>
TfLiteStatus ShallowCopyTensorsShapeTypeData(
    TfLiteContext* context, TfLiteNode* node, Subgraph* src_subgraph,
    const SrcVector& src_tensor_indices, Subgraph* dst_subgraph,
    const DstVector& dst_tensor_indices) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  TF_LITE_ENSURE_EQ(context, op_data->body_has_dynamic_output_tensors, true);
  // Only allow shallow copying from the main node's inputs.
  TF_LITE_ENSURE_EQ(context, src_subgraph, this_subgraph);

  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  bool reallocation_needed = false;
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);

    if (!TfLiteIntArrayEqual(src_tensor->dims, dst_tensor->dims)) {
      reallocation_needed = true;
      TfLiteIntArrayFree(dst_tensor->dims);
      dst_tensor->dims = TfLiteIntArrayCopy(src_tensor->dims);
    }
    dst_tensor->type = src_tensor->type;
    dst_tensor->bytes = 0;  // Don't allocate memory with AllocateTensors().
    dst_tensor->data.raw = nullptr;
  }

  if (reallocation_needed && dst_subgraph != this_subgraph) {
    TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors());
  }

  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);

    dst_tensor->bytes = src_tensor->bytes;
    dst_tensor->data.raw = src_tensor->data.raw;
  }

  return kTfLiteOk;
}
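
// Note: after ShallowCopyTensorsShapeTypeData() the destination tensors point
// directly at the source buffers. Eval_dynamic() resets the body inputs'
// `data.raw` to nullptr once the loop finishes, so the aliased buffers are
// never released through the body subgraph.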

TfLiteStatus CheckCondOutput(TfLiteContext* context,
                             const TfLiteTensor* cond_output) {
  // The condition output must be a single boolean value.
  TF_LITE_ENSURE_TYPES_EQ(context, cond_output->type, kTfLiteBool);
  if (cond_output->dims->size == 0) {
    // It's okay if it's a 0-D scalar.
    return kTfLiteOk;
  }
  // Otherwise it must be 1-D with shape [1].
  TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1);
  return kTfLiteOk;
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const auto* params = reinterpret_cast<const TfLiteWhileParams*>(buffer);
  op_data->cond_subgraph_index = params->cond_subgraph_index;
  op_data->body_subgraph_index = params->body_subgraph_index;
  op_data->cond_has_dynamic_output_tensors = false;
  op_data->body_has_dynamic_output_tensors = false;
  op_data->body_use_shallow_copy = false;
  op_data->subgraphs_prepared = false;
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

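// Prepares the WHILE op: validates the cond/body subgraph signatures,
// propagates the node's input shapes and types into both subgraphs, allocates
// their tensors, and decides whether the body's outputs are static or dynamic
// (possibly enabling the shallow-copy path for large tensors).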
TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int num_inputs = node->inputs->size;
  // The number of outputs should be the same as the number of inputs.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);

  // Check subgraph indices and get subgraphs.
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context,
                 op_data->cond_subgraph_index != op_data->body_subgraph_index);

  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // Check the input & output count of the condition subgraph.
  TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);

  // Check the input & output count of the body subgraph.
  TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);

  // Remove unused inputs of the condition subgraph to skip copying
  // unnecessary inputs.
  cond_subgraph->RemoveUnusedInputs();

  // Prepare and check the condition subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(
                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
                   cond_subgraph, cond_subgraph->inputs(), true));
  TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
  TfLiteTensor* cond_output =
      cond_subgraph->tensor(cond_subgraph->outputs()[0]);
  // This should rarely happen. In most cases the output is static with shape
  // [1]. However, intermediate tensors in the cond subgraph can theoretically
  // be dynamic.
  if (IsDynamicTensor(cond_output)) {
    op_data->cond_has_dynamic_output_tensors = true;
  } else {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  // Prepare and check the body subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(
                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
                   body_subgraph, body_subgraph->inputs(), true));

  bool input_has_resource_or_variant_tensor = false;
  for (int i = 0; i < num_inputs; ++i) {
    if (IsResourceOrVariant(
            body_subgraph->tensor(body_subgraph->inputs()[i]))) {
      input_has_resource_or_variant_tensor = true;
      break;
    }
  }
  if (this_subgraph->ShouldOptimizeMemoryForLargeTensors() &&
      !input_has_resource_or_variant_tensor) {
    // The current shallow-copy path requires dynamic tensors, which introduce
    // additional overhead. Therefore, use this method only when dynamic
    // allocation is enabled.
    op_data->body_use_shallow_copy = true;
    op_data->body_has_dynamic_output_tensors = true;
    // Make the body inputs dynamic so Eval_dynamic() can use the shallow copy.
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      SetTensorToDynamic(body_input);
      body_input->bytes = 0;
    }
  }

  TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  if (body_subgraph->HasDynamicTensors()) {
    op_data->body_has_dynamic_output_tensors = true;
  } else {
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TF_LITE_ENSURE_TYPES_EQ(context, body_input->type, body_output->type);

      TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
      if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) {
        // If the output shape of the body subgraph is static w.r.t. a fixed
        // input size but differs from the input shape, it's still considered
        // dynamic. For example: if a subgraph keeps padding its input with a
        // fixed padding, the output shape is static w.r.t. the input shape and
        // padding, but running it in a loop will keep growing the tensor.
        op_data->body_has_dynamic_output_tensors = true;
        break;
      }
    }
  }
  for (int i = 0; i < num_inputs; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (op_data->body_has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }
  op_data->subgraphs_prepared = true;
  return kTfLiteOk;
}
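
// When memory optimization for large tensors is enabled, full preparation is
// deferred to the first Eval() call (lazy initialization); Prepare() then
// only marks the node's outputs dynamic.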
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  if (this_subgraph->ShouldOptimizeMemoryForLargeTensors()) {
    // Apply lazy initialization of the WHILE kernel: just make the node's
    // output tensors dynamic.
    int num_outputs = node->outputs->size;
    for (int i = 0; i < num_outputs; ++i) {
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
      SetTensorToDynamic(output);
    }
    return kTfLiteOk;
  }
  return Prepare_impl(context, node);
}

TfLiteStatus Prepare_lazy(TfLiteContext* context, TfLiteNode* node) {
  return Prepare_impl(context, node);
}

// Evaluates the cond subgraph and sets `cond_subgraph_output` to the result.
TfLiteStatus Eval_cond_subgraph(TfLiteContext* context, Subgraph* cond_subgraph,
                                bool cond_has_dynamic_output_tensors,
                                bool* cond_subgraph_output) {
  TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
  int cond_subgraph_output_index = cond_subgraph->outputs()[0];
  cond_subgraph->EnsureTensorDataIsReadable(cond_subgraph_output_index);
  TfLiteTensor* cond_output = cond_subgraph->tensor(cond_subgraph_output_index);
  if (cond_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  *cond_subgraph_output = (cond_output->data.b[0]);
  return kTfLiteOk;
}

// Evaluates the WHILE op when the body subgraph has dynamic outputs.
TfLiteStatus Eval_dynamic(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following graph illustrates the current implementation.
  //
  //  This Subgraph         Cond Subgraph          Body Subgraph
  //  +-----------+   (1)   +------------+         +------------+
  //  |   WHILE   |-------->|  SUBGRAPH  |         |  SUBGRAPH  |
  //  |   INPUT   |         |   INPUT    |         |   INPUT    |
  //  |           |         |    ----------------->|            |
  //  |           |         |   /        | <----   |            |
  //  +-----------+         +--/---------+    \    +------------+
  //        |               /      |            \        |
  //        | (2)      (4) /       | (3)    (6)  \       | (5)
  //        v             /        v              \      v
  //  +-----------+      /  +------------+         +------------+
  //  |   WHILE   |--/      |  SUBGRAPH  |         |  SUBGRAPH  |
  //  |  OUTPUT   |   (7)   |   OUTPUT   |         |   OUTPUT   |
  //  |           |<-------------------------------|            |
  //  +-----------+         +------------+         +------------+
  //
  // (1) Copy the inputs of the WHILE op to the inputs of the condition
  //     subgraph.
  // (2) Copy the inputs of the WHILE op to the outputs of the WHILE op.
  // (3) Invoke the condition subgraph.
  //     Exit the loop if the result is false.
  // (4) Copy the outputs of the WHILE op to the inputs of the body subgraph.
  // (5) Invoke the body subgraph.
  // (6) Copy the outputs of the body subgraph to the inputs of the condition
  //     subgraph.
  // (7) Copy the outputs of the body subgraph to the outputs of the WHILE op.
  //     Jump back to step 3!
  //
  // If the body subgraph has dynamically sized outputs, the tensors must be
  // resized before copying in steps 1, 2, 4, 6, and 7.
  //
  // Note that the flow is carefully designed to handle the dynamically sized
  // output case. The loop invariant is: the newest value is in the inputs of
  // the condition subgraph. This is always true before step 3.
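  //
  // As a concrete, purely illustrative example: with a scalar counter input
  // i = 0, a cond subgraph computing `i < 3`, and a body subgraph computing
  // `i + 1`, steps 1-2 seed the cond inputs and the WHILE outputs with 0;
  // each iteration then copies the body output (1, then 2, then 3) back into
  // the cond inputs and the WHILE outputs until the cond subgraph returns
  // false.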

  // Step 1. node->inputs -> cond->inputs (fast)
  TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData(
                                 context, node, this_subgraph,
                                 TfLiteIntArrayView(node->inputs),
                                 cond_subgraph, cond_subgraph->inputs()));

  // Step 2. node->inputs -> node->outputs
  TF_LITE_ENSURE_OK(
      context, DeepCopyTensorsShapeTypeData(context, node, this_subgraph,
                                            TfLiteIntArrayView(node->inputs),
                                            this_subgraph,
                                            TfLiteIntArrayView(node->outputs)));

  while (true) {
    // Step 3. Eval cond subgraph
    bool cond_subgraph_output;
    TF_LITE_ENSURE_OK(
        context, Eval_cond_subgraph(context, cond_subgraph,
                                    op_data->cond_has_dynamic_output_tensors,
                                    &cond_subgraph_output));
    if (!cond_subgraph_output) {
      break;
    }

    // Step 4. node->outputs -> body->inputs
    if (op_data->body_use_shallow_copy) {
      TF_LITE_ENSURE_OK(context, ShallowCopyTensorsShapeTypeData(
                                     context, node, this_subgraph,
                                     TfLiteIntArrayView(node->outputs),
                                     body_subgraph, body_subgraph->inputs()));
    } else {
      TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData(
                                     context, node, this_subgraph,
                                     TfLiteIntArrayView(node->outputs),
                                     body_subgraph, body_subgraph->inputs()));
    }

    // Step 5. Invoke body subgraph
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    // Step 6. body->outputs -> cond->inputs (fast)
    TF_LITE_ENSURE_OK(
        context, DeepCopyTensorsShapeTypeData(
                     context, node, body_subgraph, body_subgraph->outputs(),
                     cond_subgraph, cond_subgraph->inputs()));

    // Step 7. body->outputs -> node->outputs
    TF_LITE_ENSURE_OK(
        context, DeepCopyTensorsShapeTypeData(
                     context, node, body_subgraph, body_subgraph->outputs(),
                     this_subgraph, TfLiteIntArrayView(node->outputs)));
  }

  if (op_data->body_use_shallow_copy) {
    // Clean up the shallow-copied pointers of the body inputs so the aliased
    // buffers are not released through the body subgraph.
    for (int i = 0; i < body_subgraph->inputs().size(); ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      body_input->data.raw = nullptr;
    }
  }

  return kTfLiteOk;
}

// Evaluates the WHILE op when the body subgraph has static outputs.
TfLiteStatus Eval_static(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following graph illustrates the current implementation.
  //
  //  This Subgraph         Cond Subgraph          Body Subgraph
  //  +-----------+   (1)   +------------+         +------------+
  //  |   WHILE   |-------->|  SUBGRAPH  |         |  SUBGRAPH  |
  //  |   INPUT   |  (3-1)  |   INPUT    |         |   INPUT    |
  //  |           |------------------------------->|            |
  //  |           |         |            | <----   |            |
  //  +-----------+         +------------+    \    +------------+
  //                               |            \      |     ^
  //                               | (2)    (5)  \     | (4) | (3-2)
  //                               v              \    v     |
  //  +-----------+         +------------+         +------------+
  //  |   WHILE   |         |  SUBGRAPH  |         |  SUBGRAPH  |
  //  |  OUTPUT   |   (6)   |   OUTPUT   |         |   OUTPUT   |
  //  |           |<-------------------------------|            |
  //  +-----------+         +------------+         +------------+
  //
  // (1) Copy the inputs of the WHILE op to the inputs of the condition
  //     subgraph.
  // (2) Invoke the condition subgraph.
  //     Jump to step 6 if the result is false.
  // (3) If the body has never been invoked, run step 3-1; otherwise run
  //     step 3-2.
  // (3-1) Copy the inputs of the WHILE op to the inputs of the body subgraph.
  // (3-2) Copy the outputs of the body subgraph to the inputs of the body
  //       subgraph.
  // (4) Invoke the body subgraph.
  // (5) Copy the outputs of the body subgraph to the inputs of the condition
  //     subgraph.
  //     Jump back to step 2!
  // (6) Copy the outputs of the body subgraph to the outputs of the WHILE op.
  //
  // The body subgraph shouldn't have dynamically sized outputs.

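  // Unlike Eval_dynamic(), the WHILE op's outputs are written only once,
  // after the loop exits (step 6), instead of on every iteration.
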
  // Step 1. node->inputs -> cond->inputs (fast)
  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                      cond_subgraph, cond_subgraph->inputs()));

  bool body_invoked = false;
  while (true) {
    // Step 2. Eval cond subgraph
    bool cond_subgraph_output;
    TF_LITE_ENSURE_OK(
        context, Eval_cond_subgraph(context, cond_subgraph,
                                    op_data->cond_has_dynamic_output_tensors,
                                    &cond_subgraph_output));
    if (!cond_subgraph_output) {
      break;
    }

    if (body_invoked) {
      // Step 3-2. body->outputs -> body->inputs
      TF_LITE_ENSURE_OK(
          context,
          CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                          body_subgraph, body_subgraph->inputs()));
    } else {
      // Step 3-1. node->inputs -> body->inputs
      TF_LITE_ENSURE_OK(
          context, CopyTensorsData(context, this_subgraph,
                                   TfLiteIntArrayView(node->inputs),
                                   body_subgraph, body_subgraph->inputs()));
    }

    // Step 4. Invoke body subgraph
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    body_invoked = true;
    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    // Step 5. body->outputs -> cond->inputs (fast)
    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        cond_subgraph, cond_subgraph->inputs()));
  }

  if (body_invoked) {
    // Step 6. body->outputs -> node->outputs
    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        this_subgraph, TfLiteIntArrayView(node->outputs)));
  } else {
    // Copy node->inputs to node->outputs if the body is never invoked.
    TF_LITE_ENSURE_OK(
        context, CopyTensorsData(
                     context, this_subgraph, TfLiteIntArrayView(node->inputs),
                     this_subgraph, TfLiteIntArrayView(node->outputs)));
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  if (!op_data->subgraphs_prepared) {
    TF_LITE_ENSURE_OK(context, Prepare_lazy(context, node));
  } else {
    TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
    TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  }

  if (op_data->body_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_OK(context, Eval_dynamic(context, node));
  } else {
    TF_LITE_ENSURE_OK(context, Eval_static(context, node));
  }

  // Release the subgraphs' memory unless all tensors must be preserved
  // (e.g. for debugging).
  if (!this_subgraph->ShouldPreserveAllTensors()) {
    TF_LITE_ENSURE_OK(context, cond_subgraph->ReleaseMemory());
    TF_LITE_ENSURE_OK(context, body_subgraph->ReleaseMemory());
  }

  return kTfLiteOk;
}

}  // namespace while_kernel

TfLiteRegistration* Register_WHILE() {
  static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free,
                                 while_kernel::Prepare, while_kernel::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite