1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stddef.h>
17
18#include <cstring>
19#include <memory>
20#include <vector>
21
22#include "tensorflow/lite/c/builtin_op_data.h"
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/core/subgraph.h"
25#include "tensorflow/lite/experimental/resource/initialization_status.h"
26#include "tensorflow/lite/kernels/kernel_util.h"
27
28namespace tflite {
29namespace ops {
30namespace builtin {
31namespace call_once_kernel {
32
// The CallOnce operator is a control-flow op that invokes another subgraph of
// the model to perform the graph's initialization tasks, for example hash
// table initialization and variable initialization.
//
// This operator invokes the initialization subgraph on its first run and
// becomes a no-op on every later run within an interpreter's life cycle.
39
// Per-node state for a CALL_ONCE op: allocated in Init() and released in
// Free().
struct OpData {
  // Subgraph index to be invoked once in a life cycle by this CallOnce op.
  // Defaults to -1, which can never index a subgraph, so a default-constructed
  // OpData never carries an indeterminate value; Init() overwrites it with the
  // index taken from the op's builtin params.
  int init_subgraph_index = -1;
};
44
45void* Init(TfLiteContext* context, const char* buffer, size_t length) {
46 auto* op_data = new OpData;
47 const auto* params = reinterpret_cast<const TfLiteCallOnceParams*>(buffer);
48 op_data->init_subgraph_index = params->init_subgraph_index;
49 return op_data;
50}
51
52void Free(TfLiteContext* context, void* buffer) {
53 delete reinterpret_cast<OpData*>(buffer);
54}
55
56TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
57 const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
58
59 Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
60
61 // Return early if the initialization graph is already invoked.
62 resource::InitializationStatusMap* map =
63 &this_subgraph->initialization_status_map();
64 resource::InitializationStatus* status =
65 resource::GetInitializationStatus(map, op_data->init_subgraph_index);
66 if (status->IsInitialized()) return kTfLiteOk;
67
68 auto* subgraphs = this_subgraph->GetSubgraphs();
69
70 TF_LITE_ENSURE_EQ(context, node->inputs->size, 0);
71 TF_LITE_ENSURE_EQ(context, node->outputs->size, 0);
72
73 TF_LITE_ENSURE(context, op_data->init_subgraph_index < subgraphs->size());
74
75 // Ensures that there are no input and output tensors in the subgraph.
76 Subgraph* init_subgraph = (*subgraphs)[op_data->init_subgraph_index].get();
77 TF_LITE_ENSURE_EQ(context, init_subgraph->inputs().size(), 0);
78 TF_LITE_ENSURE_EQ(context, init_subgraph->outputs().size(), 0);
79 return kTfLiteOk;
80}
81
82TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
83 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
84
85 Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
86
87 // The initialization graph should be invoked once in a life cycle.
88 resource::InitializationStatusMap* map =
89 &this_subgraph->initialization_status_map();
90 resource::InitializationStatus* status =
91 resource::GetInitializationStatus(map, op_data->init_subgraph_index);
92 if (status->IsInitialized()) return kTfLiteOk;
93
94 auto* subgraphs = this_subgraph->GetSubgraphs();
95 Subgraph& init_subgraph = *(*subgraphs)[op_data->init_subgraph_index];
96
97 TF_LITE_ENSURE_OK(context, init_subgraph.AllocateTensors());
98 TF_LITE_ENSURE_OK(context, init_subgraph.Invoke());
99 TF_LITE_ENSURE_OK(context, init_subgraph.ReleaseNonPersistentMemory());
100
101 // Mark the invocation completed.
102 status->MarkInitializationIsDone();
103 return kTfLiteOk;
104}
105
106} // namespace call_once_kernel
107
108TfLiteRegistration* Register_CALL_ONCE() {
109 static TfLiteRegistration r = {call_once_kernel::Init, call_once_kernel::Free,
110 call_once_kernel::Prepare,
111 call_once_kernel::Eval};
112 return &r;
113}
114
115} // namespace builtin
116} // namespace ops
117} // namespace tflite
118