1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stddef.h> |
17 | |
18 | #include <cstring> |
19 | #include <memory> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/lite/c/builtin_op_data.h" |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/core/subgraph.h" |
25 | #include "tensorflow/lite/experimental/resource/initialization_status.h" |
26 | #include "tensorflow/lite/kernels/kernel_util.h" |
27 | |
28 | namespace tflite { |
29 | namespace ops { |
30 | namespace builtin { |
31 | namespace call_once_kernel { |
32 | |
// CallOnce is a control flow op that invokes another subgraph of the model to
// perform the model's initialization tasks, for example hash table
// initialization and variable initialization.
//
// The op invokes the initialization subgraph on its first successful run and
// becomes a no-op for the rest of the interpreter's life cycle.
39 | |
// Per-node state for the CallOnce op: allocated in Init() and deleted in
// Free().
struct OpData {
  // Index of the subgraph to be invoked once in a life cycle by this CallOnce
  // op. Defaults to -1, which is not a valid subgraph index, so the member is
  // never left indeterminate before Init() assigns it.
  int init_subgraph_index = -1;
};
44 | |
45 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
46 | auto* op_data = new OpData; |
47 | const auto* params = reinterpret_cast<const TfLiteCallOnceParams*>(buffer); |
48 | op_data->init_subgraph_index = params->init_subgraph_index; |
49 | return op_data; |
50 | } |
51 | |
52 | void Free(TfLiteContext* context, void* buffer) { |
53 | delete reinterpret_cast<OpData*>(buffer); |
54 | } |
55 | |
56 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
57 | const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
58 | |
59 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
60 | |
61 | // Return early if the initialization graph is already invoked. |
62 | resource::InitializationStatusMap* map = |
63 | &this_subgraph->initialization_status_map(); |
64 | resource::InitializationStatus* status = |
65 | resource::GetInitializationStatus(map, op_data->init_subgraph_index); |
66 | if (status->IsInitialized()) return kTfLiteOk; |
67 | |
68 | auto* subgraphs = this_subgraph->GetSubgraphs(); |
69 | |
70 | TF_LITE_ENSURE_EQ(context, node->inputs->size, 0); |
71 | TF_LITE_ENSURE_EQ(context, node->outputs->size, 0); |
72 | |
73 | TF_LITE_ENSURE(context, op_data->init_subgraph_index < subgraphs->size()); |
74 | |
75 | // Ensures that there are no input and output tensors in the subgraph. |
76 | Subgraph* init_subgraph = (*subgraphs)[op_data->init_subgraph_index].get(); |
77 | TF_LITE_ENSURE_EQ(context, init_subgraph->inputs().size(), 0); |
78 | TF_LITE_ENSURE_EQ(context, init_subgraph->outputs().size(), 0); |
79 | return kTfLiteOk; |
80 | } |
81 | |
82 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
83 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
84 | |
85 | Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
86 | |
87 | // The initialization graph should be invoked once in a life cycle. |
88 | resource::InitializationStatusMap* map = |
89 | &this_subgraph->initialization_status_map(); |
90 | resource::InitializationStatus* status = |
91 | resource::GetInitializationStatus(map, op_data->init_subgraph_index); |
92 | if (status->IsInitialized()) return kTfLiteOk; |
93 | |
94 | auto* subgraphs = this_subgraph->GetSubgraphs(); |
95 | Subgraph& init_subgraph = *(*subgraphs)[op_data->init_subgraph_index]; |
96 | |
97 | TF_LITE_ENSURE_OK(context, init_subgraph.AllocateTensors()); |
98 | TF_LITE_ENSURE_OK(context, init_subgraph.Invoke()); |
99 | TF_LITE_ENSURE_OK(context, init_subgraph.ReleaseNonPersistentMemory()); |
100 | |
101 | // Mark the invocation completed. |
102 | status->MarkInitializationIsDone(); |
103 | return kTfLiteOk; |
104 | } |
105 | |
106 | } // namespace call_once_kernel |
107 | |
108 | TfLiteRegistration* Register_CALL_ONCE() { |
109 | static TfLiteRegistration r = {call_once_kernel::Init, call_once_kernel::Free, |
110 | call_once_kernel::Prepare, |
111 | call_once_kernel::Eval}; |
112 | return &r; |
113 | } |
114 | |
115 | } // namespace builtin |
116 | } // namespace ops |
117 | } // namespace tflite |
118 | |