1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stdint.h>
16
17#include <algorithm>
18#include <iterator>
19#include <vector>
20
21#include "tensorflow/lite/c/common.h"
22#include "tensorflow/lite/kernels/internal/compatibility.h"
23#include "tensorflow/lite/kernels/internal/tensor.h"
24#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25#include "tensorflow/lite/kernels/kernel_util.h"
26
27namespace tflite {
28namespace ops {
29namespace builtin {
30namespace topk_v2 {
// Tensor index of the values to select from.
constexpr int kInputTensor = 0;
// Tensor index of the scalar `k` (number of top elements to return).
constexpr int kInputTopK = 1;
// Output tensor index of the top-k values.
constexpr int kOutputValues = 0;
// Output tensor index of the top-k indices (into the innermost dimension).
constexpr int kOutputIndexes = 1;
35
36namespace {
37TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
38 const TfLiteTensor* top_k;
39 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
40 // INT32 number of top results is supported.
41 TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
42 // Check that the tensor contains only one value.
43 TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
44 const int32 k = *GetTensorData<int32_t>(top_k);
45
46 const TfLiteTensor* input;
47 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
48 const int num_dimensions = NumDimensions(input);
49 // Check that input has one or more dimensions.
50 TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
51 "TopK k input must have 1 or more dimensions.");
52 // Check that k is less or equal the internal dimension.
53 TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
54 "TopK k is higher than the internal dimension.");
55
56 TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
57 TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
58 for (int i = 0; i < num_dimensions - 1; ++i) {
59 output_indexes_shape->data[i] = input->dims->data[i];
60 output_values_shape->data[i] = input->dims->data[i];
61 }
62 output_indexes_shape->data[num_dimensions - 1] = k;
63 output_values_shape->data[num_dimensions - 1] = k;
64 TfLiteTensor* output_indexes;
65 TF_LITE_ENSURE_OK(
66 context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
67 TfLiteTensor* output_values;
68 TF_LITE_ENSURE_OK(
69 context, GetOutputSafe(context, node, kOutputValues, &output_values));
70 // Force output types.
71 output_indexes->type = kTfLiteInt32;
72 output_values->type = input->type;
73 auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
74 TfLiteIntArray* delete_on_error) {
75 TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
76 if (status != kTfLiteOk) {
77 if (delete_on_error != nullptr) {
78 TfLiteIntArrayFree(delete_on_error);
79 }
80 }
81 return status;
82 };
83 TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
84 output_values_shape));
85 TF_LITE_ENSURE_OK(context,
86 resize_tensor(output_values, output_values_shape, nullptr));
87 return kTfLiteOk;
88}
89
90// Class that collects indices of top k values. Based on template
91// tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container.
// Class that collects indices of top k values. Based on template
// tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container
// across rows: call start_collecting() once per row, push() every candidate
// index, then sorted_result() for the indices in decreasing value order.
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // `k` is the number of top elements to keep; `row_size` bounds how many
  // indices will be pushed per row. Capacity is reserved up front so push()
  // never reallocates.
  TopContainer(int32_t k, int32_t row_size) : k_(k) {
    // Clamp at zero: a negative k (callers validate, but be defensive) would
    // otherwise convert to a huge unsigned reservation request.
    container_.reserve(
        static_cast<size_t>(std::max(0, std::min(k, row_size))) + 1);
  }

  // Starts a new row backed by `values`. Collected indices are discarded but
  // the vector's capacity is kept (this is the re-use optimization).
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
    is_heap_ = false;
  }

  // Offers index `a` as a candidate member of the top-k set.
  void push(int32_t a) {
    auto comparator = [this](int32_t a, int32_t b) { return compare_fun(a, b); };
    if (!is_heap_) {
      container_.push_back(a);
      // Cast before the +1: the original `k_ + 1` was a signed/unsigned
      // comparison and overflows (UB) when k_ == INT32_MAX.
      if (container_.size() == static_cast<size_t>(k_) + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
        container_.pop_back();
        is_heap_ = true;
      }
    } else if (comparator(a, container_.front())) {
      // Due to how we defined comparator / compare_fun, container_.front()
      // contains the index of the smallest of the top-k elements seen so far.
      //
      // If control reaches this point, we know that the current index a
      // corresponds to an element which is bigger than the smallest of the
      // top-k elements seen so far. Hence, we have to update the indices of
      // the top-k elements, by removing the index of the smallest top-k
      // element, adding a, and making sure container_[0:k] is still a heap.
      std::pop_heap(container_.begin(), container_.end(), comparator);
      container_.back() = a;
      std::push_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Returns the collected indices sorted in decreasing order of the
  // corresponding values (ties broken in favor of the lower index).
  const std::vector<int32_t>& sorted_result() {
    auto comparator = [this](int32_t a, int32_t b) { return compare_fun(a, b); };
    if (!is_heap_) {
      // Note: due to the way we defined compare_fun (see comments for that
      // function) std::sort puts the indices from container_ in decreasing
      // order of the corresponding elements.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      std::sort_heap(container_.begin(), container_.end(), comparator);
      // Once sorted, the container is no longer a heap; clear the flag so a
      // repeated call cannot invoke sort_heap on a non-heap (UB). std::sort
      // on the already-sorted data yields the same result.
      is_heap_ = false;
    }
    return container_;
  }

 private:
  const int32_t k_;

  // container_[0,k) holds the indices of the largest k elements from values_
  // seen so far. If more than k elements are pushed, then elements are
  // maintained in a min-heap order: container_.front() is
  // the index of the smallest of the top-k elements see so far.
  std::vector<int32_t> container_;

  // Once more than k elements are pushed, the container becomes a min heap,
  // and is_heap_ becomes true.
  bool is_heap_ = false;

  // Borrowed pointer to the current row's data; set by start_collecting().
  const T* values_ = nullptr;

  // Compares indices a and b based on the corresponding elements from values_.
  //
  // Intuitively, compare_fun(a, b) returns true iff values_[b] < values_[a]
  // (notice the inversion of direction, not a typo); ties (==) are broken in
  // favor of earlier elements (i.e., a < b).
  bool compare_fun(int32_t a, int32_t b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
174
175// Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
176template <typename T>
177void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
178 int32* output_indexes, T* output_values) {
179 TopContainer<T> topc(k, row_size);
180 for (int row = 0; row < num_rows; ++row) {
181 const T* values_row = data + row * row_size;
182 topc.start_collecting(values_row);
183 for (int32 c = 0; c < row_size; ++c) {
184 topc.push(c);
185 }
186
187 // Prepare output buffers.
188 int32* indexes_row = output_indexes + row * k;
189 T* output_row = output_values + row * k;
190 // We always assume that the output is sorted.
191 const auto& top_k = topc.sorted_result();
192 std::copy(top_k.begin(), top_k.end(), indexes_row);
193 std::transform(top_k.begin(), top_k.end(), output_row,
194 [values_row](const int32 loc) { return values_row[loc]; });
195 }
196}
197
198} // namespace
199
200TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
201 // Check that the inputs and outputs have the right sizes and types.
202 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
203 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
204
205 const TfLiteTensor* input;
206 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
207 TfLiteTensor* output_values;
208 TF_LITE_ENSURE_OK(
209 context, GetOutputSafe(context, node, kOutputValues, &output_values));
210 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);
211
212 const TfLiteTensor* top_k;
213 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
214 TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
215
216 // Set output dynamic if the `top_k` tensor is not constant, or the input has
217 // dynamic dimensions (indicated by dims signature).
218 if (IsConstantTensor(top_k) && !HasUnspecifiedDimension(input)) {
219 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
220 } else {
221 TfLiteTensor* output_indexes;
222 TF_LITE_ENSURE_OK(
223 context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
224 TfLiteTensor* output_values;
225 TF_LITE_ENSURE_OK(
226 context, GetOutputSafe(context, node, kOutputValues, &output_values));
227 SetTensorToDynamic(output_indexes);
228 SetTensorToDynamic(output_values);
229 }
230 return kTfLiteOk;
231}
232
233TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
234 TfLiteTensor* output_values;
235 TF_LITE_ENSURE_OK(
236 context, GetOutputSafe(context, node, kOutputValues, &output_values));
237 TfLiteTensor* output_indexes;
238 TF_LITE_ENSURE_OK(
239 context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
240 if (IsDynamicTensor(output_values)) {
241 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
242 }
243 const TfLiteTensor* top_k;
244 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
245 const int32 k = top_k->data.i32[0];
246 // The tensor can have more than 2 dimensions or even be a vector, the code
247 // anyway calls the internal dimension as row;
248 const TfLiteTensor* input;
249 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
250 const int32 row_size = input->dims->data[input->dims->size - 1];
251 int32 num_rows = 1;
252 for (int i = 0; i < input->dims->size - 1; ++i) {
253 num_rows *= input->dims->data[i];
254 }
255 switch (output_values->type) {
256 case kTfLiteFloat32:
257 TopK(row_size, num_rows, GetTensorData<float>(input), k,
258 output_indexes->data.i32, GetTensorData<float>(output_values));
259 break;
260 case kTfLiteUInt8:
261 TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
262 output_values->data.uint8);
263 break;
264 case kTfLiteInt8:
265 TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
266 output_values->data.int8);
267 break;
268 case kTfLiteInt32:
269 TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
270 output_values->data.i32);
271 break;
272 case kTfLiteInt64:
273 TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
274 output_values->data.i64);
275 break;
276 default:
277 TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by TopK.",
278 TfLiteTypeGetName(output_values->type));
279 return kTfLiteError;
280 }
281
282 return kTfLiteOk;
283}
284} // namespace topk_v2
285TfLiteRegistration* Register_TOPK_V2() {
286 static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare,
287 topk_v2::Eval};
288 return &r;
289}
290} // namespace builtin
291} // namespace ops
292} // namespace tflite
293