1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stdint.h> |
17 | #include <string.h> |
18 | |
19 | #include <string> |
20 | #include <utility> |
21 | |
22 | #include "tensorflow/lite/c/builtin_op_data.h" |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/core/subgraph.h" |
25 | #include "tensorflow/lite/experimental/resource/resource_variable.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor.h" |
27 | #include "tensorflow/lite/kernels/kernel_util.h" |
28 | |
29 | namespace tflite { |
30 | namespace ops { |
31 | namespace builtin { |
32 | namespace var_handle { |
// Per-node state for the VAR_HANDLE op: identifies the resource that the
// node's output handle refers to.
struct VarParams {
  // Id of the resource within the subgraph's resource-id map. Init assigns
  // non-negative ids (derived from the map size), so -1 marks a value that
  // has not been assigned yet instead of leaving the field indeterminate.
  int resource_id = -1;
};
37 | |
38 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
39 | const auto* var_params = |
40 | reinterpret_cast<const TfLiteVarHandleParams*>(buffer); |
41 | VarParams* params = new VarParams; |
42 | auto* subgraph = reinterpret_cast<Subgraph*>(context->impl_); |
43 | // Create a new entry if doesn't exist, return the existing one otherwise. |
44 | auto it = subgraph->resource_ids().insert(std::make_pair( |
45 | std::make_pair( |
46 | std::string(var_params->container ? var_params->container : "" ), |
47 | std::string(var_params->shared_name ? var_params->shared_name : "" )), |
48 | static_cast<int>(subgraph->resource_ids().size()))); |
49 | params->resource_id = it.first->second; |
50 | return params; |
51 | } |
52 | |
53 | void Free(TfLiteContext* context, void* buffer) { |
54 | delete reinterpret_cast<VarParams*>(buffer); |
55 | } |
56 | |
57 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
58 | TfLiteTensor* output; |
59 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
60 | |
61 | const int kBytesRequired = sizeof(int32_t); |
62 | TfLiteTensorRealloc(kBytesRequired, output); |
63 | output->bytes = kBytesRequired; |
64 | |
65 | return kTfLiteOk; |
66 | } |
67 | |
68 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
69 | auto* op_data = static_cast<VarParams*>(node->user_data); |
70 | TF_LITE_ENSURE(context, op_data != nullptr); |
71 | |
72 | TfLiteTensor* output; |
73 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
74 | |
75 | memcpy(output->data.raw, reinterpret_cast<char*>(&op_data->resource_id), |
76 | sizeof(op_data->resource_id)); |
77 | return kTfLiteOk; |
78 | } |
79 | |
80 | } // namespace var_handle |
81 | |
82 | TfLiteRegistration* Register_VAR_HANDLE() { |
83 | static TfLiteRegistration r = {var_handle::Init, var_handle::Free, |
84 | var_handle::Prepare, var_handle::Eval}; |
85 | return &r; |
86 | } |
87 | |
88 | } // namespace builtin |
89 | } // namespace ops |
90 | } // namespace tflite |
91 | |