1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include <algorithm> |
18 | #include <iterator> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/lite/c/common.h" |
22 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
23 | #include "tensorflow/lite/kernels/internal/tensor.h" |
24 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
25 | #include "tensorflow/lite/kernels/kernel_util.h" |
26 | |
27 | namespace tflite { |
28 | namespace ops { |
29 | namespace builtin { |
30 | namespace topk_v2 { |
// Positions of this op's tensors in the node's input and output lists.
constexpr int kInputTensor{0};    // Tensor the top values are taken from.
constexpr int kInputTopK{1};      // Single-element tensor holding k.
constexpr int kOutputValues{0};   // Receives the selected values.
constexpr int kOutputIndexes{1};  // Receives the indices of those values.
35 | |
36 | namespace { |
37 | TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { |
38 | const TfLiteTensor* top_k; |
39 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k)); |
40 | // INT32 number of top results is supported. |
41 | TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32); |
42 | // Check that the tensor contains only one value. |
43 | TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1); |
44 | const int32 k = *GetTensorData<int32_t>(top_k); |
45 | |
46 | const TfLiteTensor* input; |
47 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
48 | const int num_dimensions = NumDimensions(input); |
49 | // Check that input has one or more dimensions. |
50 | TF_LITE_ENSURE_MSG(context, input->dims->size >= 1, |
51 | "TopK k input must have 1 or more dimensions." ); |
52 | // Check that k is less or equal the internal dimension. |
53 | TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1], |
54 | "TopK k is higher than the internal dimension." ); |
55 | |
56 | TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions); |
57 | TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions); |
58 | for (int i = 0; i < num_dimensions - 1; ++i) { |
59 | output_indexes_shape->data[i] = input->dims->data[i]; |
60 | output_values_shape->data[i] = input->dims->data[i]; |
61 | } |
62 | output_indexes_shape->data[num_dimensions - 1] = k; |
63 | output_values_shape->data[num_dimensions - 1] = k; |
64 | TfLiteTensor* output_indexes; |
65 | TF_LITE_ENSURE_OK( |
66 | context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes)); |
67 | TfLiteTensor* output_values; |
68 | TF_LITE_ENSURE_OK( |
69 | context, GetOutputSafe(context, node, kOutputValues, &output_values)); |
70 | // Force output types. |
71 | output_indexes->type = kTfLiteInt32; |
72 | output_values->type = input->type; |
73 | auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, |
74 | TfLiteIntArray* delete_on_error) { |
75 | TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); |
76 | if (status != kTfLiteOk) { |
77 | if (delete_on_error != nullptr) { |
78 | TfLiteIntArrayFree(delete_on_error); |
79 | } |
80 | } |
81 | return status; |
82 | }; |
83 | TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape, |
84 | output_values_shape)); |
85 | TF_LITE_ENSURE_OK(context, |
86 | resize_tensor(output_values, output_values_shape, nullptr)); |
87 | return kTfLiteOk; |
88 | } |
89 | |
// Class that collects indices of top k values. Based on template
// tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container.
//
// Usage per row: call start_collecting() with the row's values, push() every
// candidate index, then read sorted_result(). A non-positive k is accepted and
// always yields an empty result.
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;

  // `k` is the number of indices to retain. `row_size` is only used to avoid
  // reserving more storage than a row can fill.
  TopContainer(int32_t k, int32_t row_size) : k_(k) {
    const int32_t retained = std::min(k, row_size);
    if (retained > 0) {
      // One extra slot: push() briefly stores k + 1 indices before evicting
      // the smallest one. The cast happens before the addition so that the sum
      // cannot overflow int32_t.
      container_.reserve(static_cast<SizeType>(retained) + 1);
    }
  }

  // Starts a new row: subsequent indices refer to elements of `values`.
  // `values` is not copied and must stay valid until the results are read.
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
    is_heap_ = false;
  }

  // Offers index `a` (into the current row) as a top-k candidate.
  void push(int32_t a) {
    // Nothing is ever retained for k <= 0. Without this guard, k == 0 would
    // empty the container on the first push and then read front() of an empty
    // vector on the next one.
    if (k_ <= 0) {
      return;
    }
    auto comparator = [this](int32_t a, int32_t b) {
      return compare_fun(a, b);
    };
    if (!is_heap_) {
      container_.push_back(a);
      if (container_.size() == static_cast<SizeType>(k_) + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
        container_.pop_back();
        is_heap_ = true;
      }
    } else if (comparator(a, container_.front())) {
      // Due to how we defined comparator / compare_fun, container_.front()
      // contains the index of the smallest of the top-k elements seen so far.
      //
      // If control reaches this point, we know that the current index a
      // corresponds to an element which is bigger than the smallest of the
      // top-k elements seen so far. Hence, we have to update the indices of
      // the top-k elements, by removing the index of the smallest top-k
      // element, adding a, and making sure container_[0:k] is still a heap.
      std::pop_heap(container_.begin(), container_.end(), comparator);
      container_.back() = a;
      std::push_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Sorts the retained indices in place and returns them, largest element
  // first. The reference is invalidated by the next start_collecting()/push().
  const std::vector<int32_t>& sorted_result() {
    auto comparator = [this](int32_t a, int32_t b) {
      return compare_fun(a, b);
    };
    if (!is_heap_) {
      // Note: due to the way we defined compare_fun (see comments for that
      // function) std::sort puts the indices from container_ in decreasing
      // order of the corresponding elements.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      std::sort_heap(container_.begin(), container_.end(), comparator);
    }
    return container_;
  }

 private:
  using SizeType = std::vector<int32_t>::size_type;

  const int32_t k_;

  // container_[0,k) holds the indices of the largest k elements from values_
  // seen so far. If more than k elements are pushed, then elements are
  // maintained in a min-heap order: container_.front() is
  // the index of the smallest of the top-k elements seen so far.
  std::vector<int32_t> container_;

  // Once more than k elements are pushed, the container becomes a min heap,
  // and is_heap_ becomes true.
  bool is_heap_ = false;

  const T* values_ = nullptr;

  // Compares indices a and b based on the corresponding elements from values_.
  //
  // Intuitively, compare_fun(a, b) returns true iff values_[b] < values_[a]
  // (notice the inversion of direction, not a typo); ties (==) are broken in
  // favor of earlier elements (i.e., a < b).
  bool compare_fun(int32_t a, int32_t b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
174 | |
175 | // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU. |
176 | template <typename T> |
177 | void TopK(int32 row_size, int32 num_rows, const T* data, int32 k, |
178 | int32* output_indexes, T* output_values) { |
179 | TopContainer<T> topc(k, row_size); |
180 | for (int row = 0; row < num_rows; ++row) { |
181 | const T* values_row = data + row * row_size; |
182 | topc.start_collecting(values_row); |
183 | for (int32 c = 0; c < row_size; ++c) { |
184 | topc.push(c); |
185 | } |
186 | |
187 | // Prepare output buffers. |
188 | int32* indexes_row = output_indexes + row * k; |
189 | T* output_row = output_values + row * k; |
190 | // We always assume that the output is sorted. |
191 | const auto& top_k = topc.sorted_result(); |
192 | std::copy(top_k.begin(), top_k.end(), indexes_row); |
193 | std::transform(top_k.begin(), top_k.end(), output_row, |
194 | [values_row](const int32 loc) { return values_row[loc]; }); |
195 | } |
196 | } |
197 | |
198 | } // namespace |
199 | |
200 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
201 | // Check that the inputs and outputs have the right sizes and types. |
202 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
203 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); |
204 | |
205 | const TfLiteTensor* input; |
206 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
207 | TfLiteTensor* output_values; |
208 | TF_LITE_ENSURE_OK( |
209 | context, GetOutputSafe(context, node, kOutputValues, &output_values)); |
210 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type); |
211 | |
212 | const TfLiteTensor* top_k; |
213 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k)); |
214 | TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32); |
215 | |
216 | // Set output dynamic if the `top_k` tensor is not constant, or the input has |
217 | // dynamic dimensions (indicated by dims signature). |
218 | if (IsConstantTensor(top_k) && !HasUnspecifiedDimension(input)) { |
219 | TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); |
220 | } else { |
221 | TfLiteTensor* output_indexes; |
222 | TF_LITE_ENSURE_OK( |
223 | context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes)); |
224 | TfLiteTensor* output_values; |
225 | TF_LITE_ENSURE_OK( |
226 | context, GetOutputSafe(context, node, kOutputValues, &output_values)); |
227 | SetTensorToDynamic(output_indexes); |
228 | SetTensorToDynamic(output_values); |
229 | } |
230 | return kTfLiteOk; |
231 | } |
232 | |
233 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
234 | TfLiteTensor* output_values; |
235 | TF_LITE_ENSURE_OK( |
236 | context, GetOutputSafe(context, node, kOutputValues, &output_values)); |
237 | TfLiteTensor* output_indexes; |
238 | TF_LITE_ENSURE_OK( |
239 | context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes)); |
240 | if (IsDynamicTensor(output_values)) { |
241 | TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); |
242 | } |
243 | const TfLiteTensor* top_k; |
244 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k)); |
245 | const int32 k = top_k->data.i32[0]; |
246 | // The tensor can have more than 2 dimensions or even be a vector, the code |
247 | // anyway calls the internal dimension as row; |
248 | const TfLiteTensor* input; |
249 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
250 | const int32 row_size = input->dims->data[input->dims->size - 1]; |
251 | int32 num_rows = 1; |
252 | for (int i = 0; i < input->dims->size - 1; ++i) { |
253 | num_rows *= input->dims->data[i]; |
254 | } |
255 | switch (output_values->type) { |
256 | case kTfLiteFloat32: |
257 | TopK(row_size, num_rows, GetTensorData<float>(input), k, |
258 | output_indexes->data.i32, GetTensorData<float>(output_values)); |
259 | break; |
260 | case kTfLiteUInt8: |
261 | TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32, |
262 | output_values->data.uint8); |
263 | break; |
264 | case kTfLiteInt8: |
265 | TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32, |
266 | output_values->data.int8); |
267 | break; |
268 | case kTfLiteInt32: |
269 | TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32, |
270 | output_values->data.i32); |
271 | break; |
272 | case kTfLiteInt64: |
273 | TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32, |
274 | output_values->data.i64); |
275 | break; |
276 | default: |
277 | TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by TopK." , |
278 | TfLiteTypeGetName(output_values->type)); |
279 | return kTfLiteError; |
280 | } |
281 | |
282 | return kTfLiteOk; |
283 | } |
284 | } // namespace topk_v2 |
285 | TfLiteRegistration* Register_TOPK_V2() { |
286 | static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare, |
287 | topk_v2::Eval}; |
288 | return &r; |
289 | } |
290 | } // namespace builtin |
291 | } // namespace ops |
292 | } // namespace tflite |
293 | |