1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stddef.h>
17#include <stdint.h>
18
19#include <map>
20#include <memory>
21#include <vector>
22
23#include "tensorflow/lite/c/builtin_op_data.h"
24#include "tensorflow/lite/c/common.h"
25#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26#include "tensorflow/lite/kernels/internal/tensor.h"
27#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28#include "tensorflow/lite/kernels/kernel_util.h"
29
30namespace tflite {
31namespace ops {
32namespace builtin {
33namespace unique {
34
35void* Init(TfLiteContext* context, const char* buffer, size_t length) {
36 return nullptr;
37}
38
39void Free(TfLiteContext* context, void* buffer) {}
40
41TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42 static const int kOutputUniqueTensor = 0;
43 static const int kOutputIndexTensor = 1;
44
45 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
46 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
47 const TfLiteTensor* input;
48 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
49 TfLiteTensor* output_unique_tensor;
50 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor,
51 &output_unique_tensor));
52 TfLiteTensor* output_index_tensor;
53 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor,
54 &output_index_tensor));
55
56 // The op only supports 1D input.
57 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
58 TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims);
59 // The unique values are determined during evaluation, so we don't know yet
60 // the size of the output tensor.
61 SetTensorToDynamic(output_unique_tensor);
62 return context->ResizeTensor(context, output_index_tensor,
63 output_index_shape);
64}
65
66namespace {
67
68// Actual evaluation for the unique op.
69template <typename T, typename I>
70TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
71 TfLiteNode* node) {
72 // Map from value, to index in the unique elements vector.
73 // Note that we prefer to use map than unordered_map as it showed less
74 // increase in the binary size.
75 std::map<T, int> unique_values;
76 TfLiteTensor* output_indexes;
77 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes));
78 std::vector<T> output_values;
79 I* indexes = GetTensorData<I>(output_indexes);
80 const T* data = GetTensorData<T>(input);
81 const int num_elements = NumElements(input);
82
83 for (int i = 0; i < num_elements; ++i) {
84 const auto element_it = unique_values.find(data[i]);
85 if (element_it != unique_values.end()) {
86 indexes[i] = element_it->second;
87 } else {
88 const int unique_index = unique_values.size();
89 unique_values[data[i]] = unique_index;
90 indexes[i] = unique_index;
91 output_values.push_back(data[i]);
92 }
93 }
94 // Allocate output tensor.
95 TfLiteTensor* unique_output;
96 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output));
97 std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
98 TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
99 shape->data[0] = unique_values.size();
100 TF_LITE_ENSURE_STATUS(
101 context->ResizeTensor(context, unique_output, shape.release()));
102 // Set the values in the output tensor.
103 T* output_unique_values = GetTensorData<T>(unique_output);
104 for (int i = 0; i < output_values.size(); ++i) {
105 output_unique_values[i] = output_values[i];
106 }
107 return kTfLiteOk;
108}
109
110template <typename T>
111TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
112 TfLiteNode* node) {
113 auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data);
114 if (params == nullptr) {
115 TF_LITE_KERNEL_LOG(context, "Null params passed");
116 return kTfLiteError;
117 }
118 switch (params->index_out_type) {
119 case kTfLiteInt32:
120 return EvalImpl<T, int32_t>(context, input, node);
121 case kTfLiteInt64:
122 return EvalImpl<T, int64_t>(context, input, node);
123 default:
124 TF_LITE_KERNEL_LOG(
125 context,
126 "Unique index output array can only be Int32 or In64, requested: %s",
127 TfLiteTypeGetName(params->index_out_type));
128 }
129 return kTfLiteError;
130}
131
132} // namespace
133
134TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
135 const TfLiteTensor* input;
136 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
137 TfLiteTensor* output_index_tensor;
138 TF_LITE_ENSURE_OK(context,
139 GetOutputSafe(context, node, 1, &output_index_tensor));
140 TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
141 NumElements(input));
142
143 switch (input->type) {
144 case kTfLiteInt8:
145 TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node));
146 break;
147 case kTfLiteInt16:
148 TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node));
149 break;
150 case kTfLiteInt32:
151 TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node));
152 break;
153 case kTfLiteInt64:
154 TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node));
155 break;
156 case kTfLiteFloat32:
157 TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node));
158 break;
159 case kTfLiteUInt8:
160 TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node));
161 break;
162 default:
163 TF_LITE_KERNEL_LOG(context, "Currently Unique doesn't support type: %s",
164 TfLiteTypeGetName(input->type));
165 return kTfLiteError;
166 }
167 return kTfLiteOk;
168}
169
170} // namespace unique
171
172TfLiteRegistration* Register_UNIQUE() {
173 static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare,
174 unique::Eval};
175 return &r;
176}
177
178} // namespace builtin
179} // namespace ops
180} // namespace tflite
181