1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <map>
#include <memory>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
29 | |
30 | namespace tflite { |
31 | namespace ops { |
32 | namespace builtin { |
33 | namespace unique { |
34 | |
35 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
36 | return nullptr; |
37 | } |
38 | |
39 | void Free(TfLiteContext* context, void* buffer) {} |
40 | |
41 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
42 | static const int kOutputUniqueTensor = 0; |
43 | static const int kOutputIndexTensor = 1; |
44 | |
45 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
46 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); |
47 | const TfLiteTensor* input; |
48 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
49 | TfLiteTensor* output_unique_tensor; |
50 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor, |
51 | &output_unique_tensor)); |
52 | TfLiteTensor* output_index_tensor; |
53 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor, |
54 | &output_index_tensor)); |
55 | |
56 | // The op only supports 1D input. |
57 | TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); |
58 | TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims); |
59 | // The unique values are determined during evaluation, so we don't know yet |
60 | // the size of the output tensor. |
61 | SetTensorToDynamic(output_unique_tensor); |
62 | return context->ResizeTensor(context, output_index_tensor, |
63 | output_index_shape); |
64 | } |
65 | |
66 | namespace { |
67 | |
68 | // Actual evaluation for the unique op. |
69 | template <typename T, typename I> |
70 | TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input, |
71 | TfLiteNode* node) { |
72 | // Map from value, to index in the unique elements vector. |
73 | // Note that we prefer to use map than unordered_map as it showed less |
74 | // increase in the binary size. |
75 | std::map<T, int> unique_values; |
76 | TfLiteTensor* output_indexes; |
77 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes)); |
78 | std::vector<T> output_values; |
79 | I* indexes = GetTensorData<I>(output_indexes); |
80 | const T* data = GetTensorData<T>(input); |
81 | const int num_elements = NumElements(input); |
82 | |
83 | for (int i = 0; i < num_elements; ++i) { |
84 | const auto element_it = unique_values.find(data[i]); |
85 | if (element_it != unique_values.end()) { |
86 | indexes[i] = element_it->second; |
87 | } else { |
88 | const int unique_index = unique_values.size(); |
89 | unique_values[data[i]] = unique_index; |
90 | indexes[i] = unique_index; |
91 | output_values.push_back(data[i]); |
92 | } |
93 | } |
94 | // Allocate output tensor. |
95 | TfLiteTensor* unique_output; |
96 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output)); |
97 | std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape( |
98 | TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree); |
99 | shape->data[0] = unique_values.size(); |
100 | TF_LITE_ENSURE_STATUS( |
101 | context->ResizeTensor(context, unique_output, shape.release())); |
102 | // Set the values in the output tensor. |
103 | T* output_unique_values = GetTensorData<T>(unique_output); |
104 | for (int i = 0; i < output_values.size(); ++i) { |
105 | output_unique_values[i] = output_values[i]; |
106 | } |
107 | return kTfLiteOk; |
108 | } |
109 | |
110 | template <typename T> |
111 | TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input, |
112 | TfLiteNode* node) { |
113 | auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data); |
114 | if (params == nullptr) { |
115 | TF_LITE_KERNEL_LOG(context, "Null params passed" ); |
116 | return kTfLiteError; |
117 | } |
118 | switch (params->index_out_type) { |
119 | case kTfLiteInt32: |
120 | return EvalImpl<T, int32_t>(context, input, node); |
121 | case kTfLiteInt64: |
122 | return EvalImpl<T, int64_t>(context, input, node); |
123 | default: |
124 | TF_LITE_KERNEL_LOG( |
125 | context, |
126 | "Unique index output array can only be Int32 or In64, requested: %s" , |
127 | TfLiteTypeGetName(params->index_out_type)); |
128 | } |
129 | return kTfLiteError; |
130 | } |
131 | |
132 | } // namespace |
133 | |
134 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
135 | const TfLiteTensor* input; |
136 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
137 | TfLiteTensor* output_index_tensor; |
138 | TF_LITE_ENSURE_OK(context, |
139 | GetOutputSafe(context, node, 1, &output_index_tensor)); |
140 | TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor), |
141 | NumElements(input)); |
142 | |
143 | switch (input->type) { |
144 | case kTfLiteInt8: |
145 | TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node)); |
146 | break; |
147 | case kTfLiteInt16: |
148 | TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node)); |
149 | break; |
150 | case kTfLiteInt32: |
151 | TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node)); |
152 | break; |
153 | case kTfLiteInt64: |
154 | TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node)); |
155 | break; |
156 | case kTfLiteFloat32: |
157 | TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node)); |
158 | break; |
159 | case kTfLiteUInt8: |
160 | TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node)); |
161 | break; |
162 | default: |
163 | TF_LITE_KERNEL_LOG(context, "Currently Unique doesn't support type: %s" , |
164 | TfLiteTypeGetName(input->type)); |
165 | return kTfLiteError; |
166 | } |
167 | return kTfLiteOk; |
168 | } |
169 | |
170 | } // namespace unique |
171 | |
172 | TfLiteRegistration* Register_UNIQUE() { |
173 | static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare, |
174 | unique::Eval}; |
175 | return &r; |
176 | } |
177 | |
178 | } // namespace builtin |
179 | } // namespace ops |
180 | } // namespace tflite |
181 | |