1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <math.h> |
16 | #include <stddef.h> |
17 | #include <stdint.h> |
18 | |
19 | #include <algorithm> |
20 | #include <initializer_list> |
21 | #include <numeric> |
22 | #include <vector> |
23 | |
24 | #include "flatbuffers/flexbuffers.h" // from @flatbuffers |
25 | #include "tensorflow/lite/c/common.h" |
26 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
27 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
28 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
29 | #include "tensorflow/lite/kernels/internal/tensor.h" |
30 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
31 | #include "tensorflow/lite/kernels/kernel_util.h" |
32 | |
33 | namespace tflite { |
34 | namespace ops { |
35 | namespace custom { |
36 | namespace detection_postprocess { |
37 | |
// Input tensor indices.
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Output tensor indices.
// When max_classes_per_detection > 1, detection boxes will be replicated by the
// number of detected classes of that box. Dummy data will be appended if the
// number of classes is smaller than max_classes_per_detection.
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

// Number of coordinates describing one box (in either encoding below).
constexpr int kNumCoordBox = 4;
// Only a batch dimension of 1 is supported by this kernel.
constexpr int kBatchSize = 1;

// Default used when the optional "detections_per_class" option is absent
// (see Init).
constexpr int kNumDetectionsPerClass = 100;
56 | |
// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner stores the corner with the minimum coordinates (ymin, xmin) and
// the corner with the maximum coordinates (ymax, xmax).
// CenterSize stores the center (ycenter, xcenter), height and width.
// A CenterSize box encoding is decoded against its anchor into BoxCorner form:
//   ycenter = y / y_scale * anchor.h + anchor.y;
//   xcenter = x / x_scale * anchor.w + anchor.x;
//   half_h  = 0.5 * exp(h / h_scale) * anchor.h;
//   half_w  = 0.5 * exp(w / w_scale) * anchor.w;
//   ymin = ycenter - half_h
//   ymax = ycenter + half_h
//   xmin = xcenter - half_w
//   xmax = xcenter + half_w
//
// Tensor buffers of floats are reinterpret_cast to arrays of these structs, so
// the field order and the absence of padding are part of their contract.
struct BoxCornerEncoding {
  float ymin;
  float xmin;
  float ymax;
  float xmax;
};

struct CenterSizeEncoding {
  float y;
  float x;
  float h;
  float w;
};
// We make sure that the memory allocations are contiguous with static assert.
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
              "Size of BoxCornerEncoding is 4 float values");
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
              "Size of CenterSizeEncoding is 4 float values");
88 | |
// Per-node state, filled from the custom options by Init().
struct OpData {
  int max_detections;
  int max_classes_per_detection;  // Fast Non-Max-Suppression
  int detections_per_class;       // Regular Non-Max-Suppression
  // Boxes scoring below this are discarded before NMS.
  float non_max_suppression_score_threshold;
  // A lower-scoring box is suppressed when its IoU exceeds this value.
  float intersection_over_union_threshold;
  int num_classes;
  bool use_regular_non_max_suppression;
  // Divisors applied to the raw (y, x, h, w) box encodings when decoding.
  CenterSizeEncoding scale_values;
  // Indices of Temporary tensors in context->tensors (created in Init()).
  int decoded_boxes_index;
  int scores_index;
};
102 | |
103 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
104 | auto* op_data = new OpData; |
105 | const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); |
106 | const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); |
107 | op_data->max_detections = m["max_detections" ].AsInt32(); |
108 | op_data->max_classes_per_detection = m["max_classes_per_detection" ].AsInt32(); |
109 | if (m["detections_per_class" ].IsNull()) |
110 | op_data->detections_per_class = kNumDetectionsPerClass; |
111 | else |
112 | op_data->detections_per_class = m["detections_per_class" ].AsInt32(); |
113 | if (m["use_regular_nms" ].IsNull()) |
114 | op_data->use_regular_non_max_suppression = false; |
115 | else |
116 | op_data->use_regular_non_max_suppression = m["use_regular_nms" ].AsBool(); |
117 | |
118 | op_data->non_max_suppression_score_threshold = |
119 | m["nms_score_threshold" ].AsFloat(); |
120 | op_data->intersection_over_union_threshold = m["nms_iou_threshold" ].AsFloat(); |
121 | op_data->num_classes = m["num_classes" ].AsInt32(); |
122 | op_data->scale_values.y = m["y_scale" ].AsFloat(); |
123 | op_data->scale_values.x = m["x_scale" ].AsFloat(); |
124 | op_data->scale_values.h = m["h_scale" ].AsFloat(); |
125 | op_data->scale_values.w = m["w_scale" ].AsFloat(); |
126 | context->AddTensors(context, 1, &op_data->decoded_boxes_index); |
127 | context->AddTensors(context, 1, &op_data->scores_index); |
128 | return op_data; |
129 | } |
130 | |
131 | void Free(TfLiteContext* context, void* buffer) { |
132 | delete static_cast<OpData*>(buffer); |
133 | } |
134 | |
135 | TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor, |
136 | std::initializer_list<int> values) { |
137 | TfLiteIntArray* size = TfLiteIntArrayCreate(values.size()); |
138 | int index = 0; |
139 | for (const auto& v : values) { |
140 | size->data[index] = v; |
141 | ++index; |
142 | } |
143 | return context->ResizeTensor(context, tensor, size); |
144 | } |
145 | |
146 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
147 | auto* op_data = static_cast<OpData*>(node->user_data); |
148 | // Inputs: box_encodings, scores, anchors |
149 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); |
150 | const TfLiteTensor* input_box_encodings; |
151 | TF_LITE_ENSURE_OK(context, |
152 | GetInputSafe(context, node, kInputTensorBoxEncodings, |
153 | &input_box_encodings)); |
154 | const TfLiteTensor* input_class_predictions; |
155 | TF_LITE_ENSURE_OK(context, |
156 | GetInputSafe(context, node, kInputTensorClassPredictions, |
157 | &input_class_predictions)); |
158 | const TfLiteTensor* input_anchors; |
159 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors, |
160 | &input_anchors)); |
161 | TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3); |
162 | TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3); |
163 | TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2); |
164 | // number of detected boxes |
165 | const int num_detected_boxes = |
166 | op_data->max_detections * op_data->max_classes_per_detection; |
167 | |
168 | // Outputs: detection_boxes, detection_scores, detection_classes, |
169 | // num_detections |
170 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4); |
171 | // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4) |
172 | TfLiteTensor* detection_boxes; |
173 | TF_LITE_ENSURE_OK(context, |
174 | GetOutputSafe(context, node, kOutputTensorDetectionBoxes, |
175 | &detection_boxes)); |
176 | detection_boxes->type = kTfLiteFloat32; |
177 | SetTensorSizes(context, detection_boxes, |
178 | {kBatchSize, num_detected_boxes, kNumCoordBox}); |
179 | |
180 | // Output Tensor detection_classes: size is set to (1, num_detected_boxes) |
181 | TfLiteTensor* detection_classes; |
182 | TF_LITE_ENSURE_OK(context, |
183 | GetOutputSafe(context, node, kOutputTensorDetectionClasses, |
184 | &detection_classes)); |
185 | detection_classes->type = kTfLiteFloat32; |
186 | SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes}); |
187 | |
188 | // Output Tensor detection_scores: size is set to (1, num_detected_boxes) |
189 | TfLiteTensor* detection_scores; |
190 | TF_LITE_ENSURE_OK(context, |
191 | GetOutputSafe(context, node, kOutputTensorDetectionScores, |
192 | &detection_scores)); |
193 | detection_scores->type = kTfLiteFloat32; |
194 | SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes}); |
195 | |
196 | // Output Tensor num_detections: size is set to 1 |
197 | TfLiteTensor* num_detections; |
198 | TF_LITE_ENSURE_OK(context, |
199 | GetOutputSafe(context, node, kOutputTensorNumDetections, |
200 | &num_detections)); |
201 | num_detections->type = kTfLiteFloat32; |
202 | SetTensorSizes(context, num_detections, {1}); |
203 | |
204 | // Temporary tensors |
205 | TfLiteIntArrayFree(node->temporaries); |
206 | node->temporaries = TfLiteIntArrayCreate(2); |
207 | node->temporaries->data[0] = op_data->decoded_boxes_index; |
208 | node->temporaries->data[1] = op_data->scores_index; |
209 | |
210 | // decoded_boxes |
211 | TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; |
212 | decoded_boxes->type = kTfLiteFloat32; |
213 | decoded_boxes->allocation_type = kTfLiteArenaRw; |
214 | SetTensorSizes(context, decoded_boxes, |
215 | {input_box_encodings->dims->data[1], kNumCoordBox}); |
216 | |
217 | // scores |
218 | TfLiteTensor* scores = &context->tensors[op_data->scores_index]; |
219 | scores->type = kTfLiteFloat32; |
220 | scores->allocation_type = kTfLiteArenaRw; |
221 | SetTensorSizes(context, scores, |
222 | {input_class_predictions->dims->data[1], |
223 | input_class_predictions->dims->data[2]}); |
224 | |
225 | return kTfLiteOk; |
226 | } |
227 | |
// Affine dequantization of a uint8 value: (x - zero_point) * scale.
class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}
  // const so the functor can be used through const references; uint8_t is the
  // same type as TFLite's legacy `uint8` alias.
  float operator()(uint8_t x) const {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
240 | |
241 | void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, |
242 | float quant_zero_point, float quant_scale, |
243 | int length_box_encoding, |
244 | CenterSizeEncoding* box_centersize) { |
245 | const uint8* boxes = |
246 | GetTensorData<uint8>(input_box_encodings) + length_box_encoding * idx; |
247 | Dequantizer dequantize(quant_zero_point, quant_scale); |
248 | // See definition of the KeyPointBoxCoder at |
249 | // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py |
250 | // The first four elements are the box coordinates, which is the same as the |
251 | // FastRnnBoxCoder at |
252 | // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py |
253 | box_centersize->y = dequantize(boxes[0]); |
254 | box_centersize->x = dequantize(boxes[1]); |
255 | box_centersize->h = dequantize(boxes[2]); |
256 | box_centersize->w = dequantize(boxes[3]); |
257 | } |
258 | |
259 | template <class T> |
260 | T ReInterpretTensor(const TfLiteTensor* tensor) { |
261 | const float* tensor_base = GetTensorData<float>(tensor); |
262 | return reinterpret_cast<T>(tensor_base); |
263 | } |
264 | |
265 | template <class T> |
266 | T ReInterpretTensor(TfLiteTensor* tensor) { |
267 | float* tensor_base = GetTensorData<float>(tensor); |
268 | return reinterpret_cast<T>(tensor_base); |
269 | } |
270 | |
271 | TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, |
272 | OpData* op_data) { |
273 | // Parse input tensor boxencodings |
274 | const TfLiteTensor* input_box_encodings; |
275 | TF_LITE_ENSURE_OK(context, |
276 | GetInputSafe(context, node, kInputTensorBoxEncodings, |
277 | &input_box_encodings)); |
278 | TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize); |
279 | const int num_boxes = input_box_encodings->dims->data[1]; |
280 | TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox); |
281 | const TfLiteTensor* input_anchors; |
282 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors, |
283 | &input_anchors)); |
284 | |
285 | // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors |
286 | CenterSizeEncoding box_centersize; |
287 | CenterSizeEncoding scale_values = op_data->scale_values; |
288 | CenterSizeEncoding anchor; |
289 | for (int idx = 0; idx < num_boxes; ++idx) { |
290 | switch (input_box_encodings->type) { |
291 | // Quantized |
292 | case kTfLiteUInt8: |
293 | DequantizeBoxEncodings( |
294 | input_box_encodings, idx, |
295 | static_cast<float>(input_box_encodings->params.zero_point), |
296 | static_cast<float>(input_box_encodings->params.scale), |
297 | input_box_encodings->dims->data[2], &box_centersize); |
298 | DequantizeBoxEncodings( |
299 | input_anchors, idx, |
300 | static_cast<float>(input_anchors->params.zero_point), |
301 | static_cast<float>(input_anchors->params.scale), kNumCoordBox, |
302 | &anchor); |
303 | break; |
304 | // Float |
305 | case kTfLiteFloat32: { |
306 | // Please see DequantizeBoxEncodings function for the support detail. |
307 | const int box_encoding_idx = idx * input_box_encodings->dims->data[2]; |
308 | const float* boxes = |
309 | &(GetTensorData<float>(input_box_encodings)[box_encoding_idx]); |
310 | box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes); |
311 | TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32); |
312 | anchor = |
313 | ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx]; |
314 | break; |
315 | } |
316 | default: |
317 | // Unsupported type. |
318 | return kTfLiteError; |
319 | } |
320 | |
321 | float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) / |
322 | static_cast<double>(scale_values.y) * |
323 | static_cast<double>(anchor.h) + |
324 | static_cast<double>(anchor.y)); |
325 | |
326 | float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) / |
327 | static_cast<double>(scale_values.x) * |
328 | static_cast<double>(anchor.w) + |
329 | static_cast<double>(anchor.x)); |
330 | |
331 | float half_h = |
332 | static_cast<float>(0.5 * |
333 | (std::exp(static_cast<double>(box_centersize.h) / |
334 | static_cast<double>(scale_values.h))) * |
335 | static_cast<double>(anchor.h)); |
336 | float half_w = |
337 | static_cast<float>(0.5 * |
338 | (std::exp(static_cast<double>(box_centersize.w) / |
339 | static_cast<double>(scale_values.w))) * |
340 | static_cast<double>(anchor.w)); |
341 | |
342 | TfLiteTensor* decoded_boxes = |
343 | &context->tensors[op_data->decoded_boxes_index]; |
344 | TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); |
345 | auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx]; |
346 | box.ymin = ycenter - half_h; |
347 | box.xmin = xcenter - half_w; |
348 | box.ymax = ycenter + half_h; |
349 | box.xmax = xcenter + half_w; |
350 | } |
351 | return kTfLiteOk; |
352 | } |
353 | |
354 | void DecreasingPartialArgSort(const float* values, int num_values, |
355 | int num_to_sort, int* indices) { |
356 | if (num_to_sort == 1) { |
357 | indices[0] = optimized_ops::ArgMaxVector(values, num_values); |
358 | } else { |
359 | std::iota(indices, indices + num_values, 0); |
360 | std::partial_sort( |
361 | indices, indices + num_to_sort, indices + num_values, |
362 | [&values](const int i, const int j) { return values[i] > values[j]; }); |
363 | } |
364 | } |
365 | |
// Fills indices[0, num_values) with the permutation of 0..num_values-1 that
// orders `values` from largest to smallest. The sort is stable (equal values
// keep ascending index order), so the output is fully defined. This lets TFL
// and TFLM stay bit-exact.
void DecreasingArgSort(const float* values, int num_values, int* indices) {
  int* const indices_end = indices + num_values;
  std::iota(indices, indices_end, 0);
  const auto by_decreasing_value = [values](int lhs, int rhs) {
    return values[lhs] > values[rhs];
  };
  std::stable_sort(indices, indices_end, by_decreasing_value);
}
375 | |
// Appends (without clearing) every value >= `threshold` to `keep_values` and
// its position in `values` to `keep_indices`, preserving input order.
void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
                                         const float threshold,
                                         std::vector<float>* keep_values,
                                         std::vector<int>* keep_indices) {
  for (std::size_t position = 0; position < values.size(); ++position) {
    const float value = values[position];
    // Written as !(>=) so NaN scores are skipped exactly as before.
    if (!(value >= threshold)) continue;
    keep_values->push_back(value);
    keep_indices->push_back(static_cast<int>(position));
  }
}
387 | |
388 | bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) { |
389 | for (int i = 0; i < num_boxes; ++i) { |
390 | auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i]; |
391 | // Note: `ComputeIntersectionOverUnion` properly handles degenerated boxes |
392 | // (xmin == xmax and/or ymin == ymax) as it just returns 0 in case the box |
393 | // area is <= 0. |
394 | if (box.ymin > box.ymax || box.xmin > box.xmax) { |
395 | return false; |
396 | } |
397 | } |
398 | return true; |
399 | } |
400 | |
401 | float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes, |
402 | const int i, const int j) { |
403 | auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i]; |
404 | auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j]; |
405 | const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); |
406 | const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); |
407 | if (area_i <= 0 || area_j <= 0) return 0.0; |
408 | const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin); |
409 | const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin); |
410 | const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax); |
411 | const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax); |
412 | const float intersection_area = |
413 | std::max<float>(intersection_ymax - intersection_ymin, 0.0) * |
414 | std::max<float>(intersection_xmax - intersection_xmin, 0.0); |
415 | return intersection_area / (area_i + area_j - intersection_area); |
416 | } |
417 | |
418 | // NonMaxSuppressionSingleClass() prunes out the box locations with high overlap |
419 | // before selecting the highest scoring boxes (max_detections in number) |
420 | // It assumes all boxes are good in beginning and sorts based on the scores. |
421 | // If lower-scoring box has too much overlap with a higher-scoring box, |
422 | // we get rid of the lower-scoring box. |
423 | // Complexity is O(N^2) pairwise comparison between boxes |
424 | TfLiteStatus NonMaxSuppressionSingleClassHelper( |
425 | TfLiteContext* context, TfLiteNode* node, OpData* op_data, |
426 | const std::vector<float>& scores, int max_detections, |
427 | std::vector<int>* selected) { |
428 | const TfLiteTensor* input_box_encodings; |
429 | TF_LITE_ENSURE_OK(context, |
430 | GetInputSafe(context, node, kInputTensorBoxEncodings, |
431 | &input_box_encodings)); |
432 | const TfLiteTensor* decoded_boxes = |
433 | &context->tensors[op_data->decoded_boxes_index]; |
434 | const int num_boxes = input_box_encodings->dims->data[1]; |
435 | const float non_max_suppression_score_threshold = |
436 | op_data->non_max_suppression_score_threshold; |
437 | const float intersection_over_union_threshold = |
438 | op_data->intersection_over_union_threshold; |
439 | // Maximum detections should be positive. |
440 | TF_LITE_ENSURE(context, (max_detections >= 0)); |
441 | // intersection_over_union_threshold should be positive |
442 | // and should be less than 1. |
443 | TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && |
444 | (intersection_over_union_threshold <= 1.0f)); |
445 | // Validate boxes |
446 | TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); |
447 | TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); |
448 | |
449 | // threshold scores |
450 | std::vector<int> keep_indices; |
451 | // TODO(b/177068807): Remove the dynamic allocation and replace it |
452 | // with temporaries, esp for std::vector<float> |
453 | std::vector<float> keep_scores; |
454 | SelectDetectionsAboveScoreThreshold( |
455 | scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices); |
456 | |
457 | int num_scores_kept = keep_scores.size(); |
458 | std::vector<int> sorted_indices; |
459 | sorted_indices.resize(num_scores_kept); |
460 | DecreasingArgSort(keep_scores.data(), num_scores_kept, sorted_indices.data()); |
461 | |
462 | const int num_boxes_kept = num_scores_kept; |
463 | const int output_size = std::min(num_boxes_kept, max_detections); |
464 | selected->clear(); |
465 | int num_active_candidate = num_boxes_kept; |
466 | std::vector<uint8_t> active_box_candidate(num_boxes_kept, 1); |
467 | |
468 | for (int i = 0; i < num_boxes_kept; ++i) { |
469 | if (num_active_candidate == 0 || selected->size() >= output_size) break; |
470 | if (active_box_candidate[i] == 1) { |
471 | selected->push_back(keep_indices[sorted_indices[i]]); |
472 | active_box_candidate[i] = 0; |
473 | num_active_candidate--; |
474 | } else { |
475 | continue; |
476 | } |
477 | for (int j = i + 1; j < num_boxes_kept; ++j) { |
478 | if (active_box_candidate[j] == 1) { |
479 | TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); |
480 | float intersection_over_union = ComputeIntersectionOverUnion( |
481 | decoded_boxes, keep_indices[sorted_indices[i]], |
482 | keep_indices[sorted_indices[j]]); |
483 | |
484 | if (intersection_over_union > intersection_over_union_threshold) { |
485 | active_box_candidate[j] = 0; |
486 | num_active_candidate--; |
487 | } |
488 | } |
489 | } |
490 | } |
491 | return kTfLiteOk; |
492 | } |
493 | |
// One candidate detection produced by regular NMS. Kept a trivially copyable
// aggregate: per-thread results are bulk-copied when they are merged.
struct BoxInfo {
  // Flattened score position: box * num_classes_with_background + class
  // column (label_offset included). Split back into anchor and class index
  // when the outputs are written.
  int index;
  // Score of that box for that class; BoxInfo runs are kept sorted by this
  // value in decreasing order.
  float score;
};
498 | |
// Inputs shared by every regular-NMS unit of work, on both the single-threaded
// and the threadpool paths. It is aggregate-initialized positionally, so the
// field order is part of its interface.
struct NMSTaskParam {
  // Caller retains the ownership of `context`, `node`, `op_data` and `scores`.
  // Caller should ensure their lifetime is longer than NMSTaskParam instance.
  TfLiteContext* context;
  TfLiteNode* node;
  OpData* op_data;
  // Class scores, read with a stride of num_classes_with_background per box.
  const float* scores;

  int num_classes;
  int num_boxes;
  // num_classes_with_background - num_classes; added to every class column.
  int label_offset;
  int num_classes_with_background;
  // Upper bound on the boxes single-class NMS keeps for one class.
  int num_detections_per_class;
  // Upper bound on the size of the merged cross-class result.
  int max_detections;
  // Referenced, not owned. NOTE(review): not read by any code visible here.
  std::vector<int>& num_selected;
};
515 | |
516 | void InplaceMergeBoxInfo(std::vector<BoxInfo>& boxes, int mid_index, |
517 | int end_index) { |
518 | std::inplace_merge( |
519 | boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index, |
520 | [](const BoxInfo& a, const BoxInfo& b) { return a.score >= b.score; }); |
521 | } |
522 | |
523 | TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin, |
524 | int col_end, int& sorted_indices_size, |
525 | std::vector<BoxInfo>& resulted_sorted_box_info) { |
526 | std::vector<float> class_scores(nms_task_param.num_boxes); |
527 | std::vector<int> selected; |
528 | selected.reserve(nms_task_param.num_detections_per_class); |
529 | |
530 | for (int col = col_begin; col <= col_end; ++col) { |
531 | const float* scores_base = |
532 | nms_task_param.scores + col + nms_task_param.label_offset; |
533 | for (int row = 0; row < nms_task_param.num_boxes; row++) { |
534 | // Get scores of boxes corresponding to all anchors for single class |
535 | class_scores[row] = *scores_base; |
536 | scores_base += nms_task_param.num_classes_with_background; |
537 | } |
538 | |
539 | // Perform non-maximal suppression on single class |
540 | selected.clear(); |
541 | TF_LITE_ENSURE_OK( |
542 | nms_task_param.context, |
543 | NonMaxSuppressionSingleClassHelper( |
544 | nms_task_param.context, nms_task_param.node, nms_task_param.op_data, |
545 | class_scores, nms_task_param.num_detections_per_class, &selected)); |
546 | if (selected.empty()) { |
547 | continue; |
548 | } |
549 | |
550 | for (int i = 0; i < selected.size(); ++i) { |
551 | resulted_sorted_box_info[sorted_indices_size + i].score = |
552 | class_scores[selected[i]]; |
553 | resulted_sorted_box_info[sorted_indices_size + i].index = |
554 | (selected[i] * nms_task_param.num_classes_with_background + col + |
555 | nms_task_param.label_offset); |
556 | } |
557 | |
558 | // In-place merge the original boxes and new selected boxes which are both |
559 | // sorted by scores. |
560 | InplaceMergeBoxInfo(resulted_sorted_box_info, sorted_indices_size, |
561 | sorted_indices_size + selected.size()); |
562 | |
563 | sorted_indices_size = |
564 | std::min(sorted_indices_size + static_cast<int>(selected.size()), |
565 | nms_task_param.max_detections); |
566 | } |
567 | return kTfLiteOk; |
568 | } |
569 | |
570 | struct NonMaxSuppressionWorkerTask : cpu_backend_threadpool::Task { |
571 | NonMaxSuppressionWorkerTask(NMSTaskParam& nms_task_param, |
572 | std::atomic<int>& next_col, int col_begin) |
573 | : nms_task_param(nms_task_param), |
574 | next_col(next_col), |
575 | col_begin(col_begin), |
576 | sorted_indices_size(0) {} |
577 | void Run() override { |
578 | sorted_box_info.resize(nms_task_param.num_detections_per_class + |
579 | nms_task_param.max_detections); |
580 | for (int col = col_begin; col < nms_task_param.num_classes; |
581 | col = (++next_col)) { |
582 | if (ComputeNMSResult(nms_task_param, col, col, sorted_indices_size, |
583 | sorted_box_info) != kTfLiteOk) { |
584 | break; |
585 | } |
586 | } |
587 | } |
588 | NMSTaskParam& nms_task_param; |
589 | // A shared atomic variable across threads, representing the next col this |
590 | // task will work on after completing the work for 'col_begin' |
591 | std::atomic<int>& next_col; |
592 | const int col_begin; |
593 | int sorted_indices_size; |
594 | std::vector<BoxInfo> sorted_box_info; |
595 | }; |
596 | |
597 | // This function implements a regular version of Non Maximal Suppression (NMS) |
598 | // for multiple classes where |
599 | // 1) we do NMS separately for each class across all anchors and |
600 | // 2) keep only the highest anchor scores across all classes |
601 | // 3) The worst runtime of the regular NMS is O(K*N^2) |
602 | // where N is the number of anchors and K the number of |
603 | // classes. |
604 | TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, |
605 | TfLiteNode* node, |
606 | OpData* op_data, |
607 | const float* scores) { |
608 | const TfLiteTensor* input_box_encodings; |
609 | TF_LITE_ENSURE_OK(context, |
610 | GetInputSafe(context, node, kInputTensorBoxEncodings, |
611 | &input_box_encodings)); |
612 | const TfLiteTensor* input_class_predictions; |
613 | TF_LITE_ENSURE_OK(context, |
614 | GetInputSafe(context, node, kInputTensorClassPredictions, |
615 | &input_class_predictions)); |
616 | const TfLiteTensor* decoded_boxes = |
617 | &context->tensors[op_data->decoded_boxes_index]; |
618 | |
619 | TfLiteTensor* detection_boxes; |
620 | TF_LITE_ENSURE_OK(context, |
621 | GetOutputSafe(context, node, kOutputTensorDetectionBoxes, |
622 | &detection_boxes)); |
623 | TfLiteTensor* detection_classes; |
624 | TF_LITE_ENSURE_OK(context, |
625 | GetOutputSafe(context, node, kOutputTensorDetectionClasses, |
626 | &detection_classes)); |
627 | TfLiteTensor* detection_scores; |
628 | TF_LITE_ENSURE_OK(context, |
629 | GetOutputSafe(context, node, kOutputTensorDetectionScores, |
630 | &detection_scores)); |
631 | TfLiteTensor* num_detections; |
632 | TF_LITE_ENSURE_OK(context, |
633 | GetOutputSafe(context, node, kOutputTensorNumDetections, |
634 | &num_detections)); |
635 | |
636 | const int num_boxes = input_box_encodings->dims->data[1]; |
637 | const int num_classes = op_data->num_classes; |
638 | const int num_detections_per_class = |
639 | std::min(op_data->detections_per_class, op_data->max_detections); |
640 | const int max_detections = op_data->max_detections; |
641 | const int num_classes_with_background = |
642 | input_class_predictions->dims->data[2]; |
643 | // The row index offset is 1 if background class is included and 0 otherwise. |
644 | int label_offset = num_classes_with_background - num_classes; |
645 | TF_LITE_ENSURE(context, num_detections_per_class > 0); |
646 | |
647 | int sorted_indices_size = 0; |
648 | std::vector<BoxInfo> box_info_after_regular_non_max_suppression( |
649 | max_detections + num_detections_per_class); |
650 | std::vector<int> num_selected(num_classes); |
651 | |
652 | NMSTaskParam nms_task_param{context, |
653 | node, |
654 | op_data, |
655 | scores, |
656 | num_classes, |
657 | num_boxes, |
658 | label_offset, |
659 | num_classes_with_background, |
660 | num_detections_per_class, |
661 | max_detections, |
662 | num_selected}; |
663 | |
664 | int num_threads = |
665 | CpuBackendContext::GetFromContext(context)->max_num_threads(); |
666 | if (num_threads == 1) { |
667 | // For each class, perform non-max suppression. |
668 | TF_LITE_ENSURE_OK( |
669 | context, ComputeNMSResult(nms_task_param, /* col_begin= */ 0, |
670 | num_classes - 1, sorted_indices_size, |
671 | box_info_after_regular_non_max_suppression)); |
672 | } else { |
673 | std::atomic<int> next_col(num_threads); |
674 | std::vector<NonMaxSuppressionWorkerTask> tasks; |
675 | tasks.reserve(num_threads); |
676 | for (int i = 0; i < num_threads; ++i) { |
677 | tasks.emplace_back( |
678 | NonMaxSuppressionWorkerTask(nms_task_param, next_col, i)); |
679 | } |
680 | cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), |
681 | CpuBackendContext::GetFromContext(context)); |
682 | |
683 | // Merge results from tasks. |
684 | for (int j = 0; j < tasks.size(); ++j) { |
685 | if (tasks[j].sorted_indices_size == 0) { |
686 | continue; |
687 | } |
688 | memcpy(&box_info_after_regular_non_max_suppression[sorted_indices_size], |
689 | &tasks[j].sorted_box_info[0], |
690 | sizeof(BoxInfo) * tasks[j].sorted_indices_size); |
691 | InplaceMergeBoxInfo(box_info_after_regular_non_max_suppression, |
692 | sorted_indices_size, |
693 | sorted_indices_size + tasks[j].sorted_indices_size); |
694 | sorted_indices_size = std::min( |
695 | sorted_indices_size + tasks[j].sorted_indices_size, max_detections); |
696 | } |
697 | } |
698 | |
699 | // Allocate output tensors |
700 | for (int output_box_index = 0; output_box_index < max_detections; |
701 | output_box_index++) { |
702 | if (output_box_index < sorted_indices_size) { |
703 | const int anchor_index = floor( |
704 | box_info_after_regular_non_max_suppression[output_box_index].index / |
705 | num_classes_with_background); |
706 | const int class_index = |
707 | box_info_after_regular_non_max_suppression[output_box_index].index - |
708 | anchor_index * num_classes_with_background - label_offset; |
709 | const float selected_score = |
710 | box_info_after_regular_non_max_suppression[output_box_index].score; |
711 | // detection_boxes |
712 | TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); |
713 | TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); |
714 | ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] = |
715 | ReInterpretTensor<const BoxCornerEncoding*>( |
716 | decoded_boxes)[anchor_index]; |
717 | // detection_classes |
718 | GetTensorData<float>(detection_classes)[output_box_index] = class_index; |
719 | // detection_scores |
720 | GetTensorData<float>(detection_scores)[output_box_index] = selected_score; |
721 | } else { |
722 | TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); |
723 | ReInterpretTensor<BoxCornerEncoding*>( |
724 | detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f}; |
725 | // detection_classes |
726 | GetTensorData<float>(detection_classes)[output_box_index] = 0.0f; |
727 | // detection_scores |
728 | GetTensorData<float>(detection_scores)[output_box_index] = 0.0f; |
729 | } |
730 | } |
731 | GetTensorData<float>(num_detections)[0] = sorted_indices_size; |
732 | box_info_after_regular_non_max_suppression.clear(); |
733 | return kTfLiteOk; |
734 | } |
735 | |
736 | // This function implements a fast version of Non Maximal Suppression for |
737 | // multiple classes where |
738 | // 1) we keep the top-k scores for each anchor and |
739 | // 2) during NMS, each anchor only uses the highest class score for sorting. |
740 | // 3) Compared to standard NMS, the worst runtime of this version is O(N^2) |
741 | // instead of O(KN^2) where N is the number of anchors and K the number of |
742 | // classes. |
743 | TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, |
744 | TfLiteNode* node, |
745 | OpData* op_data, |
746 | const float* scores) { |
747 | const TfLiteTensor* input_box_encodings; |
748 | TF_LITE_ENSURE_OK(context, |
749 | GetInputSafe(context, node, kInputTensorBoxEncodings, |
750 | &input_box_encodings)); |
751 | const TfLiteTensor* input_class_predictions; |
752 | TF_LITE_ENSURE_OK(context, |
753 | GetInputSafe(context, node, kInputTensorClassPredictions, |
754 | &input_class_predictions)); |
755 | const TfLiteTensor* decoded_boxes = |
756 | &context->tensors[op_data->decoded_boxes_index]; |
757 | |
758 | TfLiteTensor* detection_boxes; |
759 | TF_LITE_ENSURE_OK(context, |
760 | GetOutputSafe(context, node, kOutputTensorDetectionBoxes, |
761 | &detection_boxes)); |
762 | TfLiteTensor* detection_classes; |
763 | TF_LITE_ENSURE_OK(context, |
764 | GetOutputSafe(context, node, kOutputTensorDetectionClasses, |
765 | &detection_classes)); |
766 | TfLiteTensor* detection_scores; |
767 | TF_LITE_ENSURE_OK(context, |
768 | GetOutputSafe(context, node, kOutputTensorDetectionScores, |
769 | &detection_scores)); |
770 | TfLiteTensor* num_detections; |
771 | TF_LITE_ENSURE_OK(context, |
772 | GetOutputSafe(context, node, kOutputTensorNumDetections, |
773 | &num_detections)); |
774 | |
775 | const int num_boxes = input_box_encodings->dims->data[1]; |
776 | const int num_classes = op_data->num_classes; |
777 | const int max_categories_per_anchor = op_data->max_classes_per_detection; |
778 | const int num_classes_with_background = |
779 | input_class_predictions->dims->data[2]; |
780 | // The row index offset is 1 if background class is included and 0 otherwise. |
781 | int label_offset = num_classes_with_background - num_classes; |
782 | TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); |
783 | const int num_categories_per_anchor = |
784 | std::min(max_categories_per_anchor, num_classes); |
785 | std::vector<float> max_scores; |
786 | max_scores.resize(num_boxes); |
787 | std::vector<int> sorted_class_indices; |
788 | sorted_class_indices.resize(num_boxes * num_classes); |
789 | for (int row = 0; row < num_boxes; row++) { |
790 | const float* box_scores = |
791 | scores + row * num_classes_with_background + label_offset; |
792 | int* class_indices = sorted_class_indices.data() + row * num_classes; |
793 | DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, |
794 | class_indices); |
795 | max_scores[row] = box_scores[class_indices[0]]; |
796 | } |
797 | // Perform non-maximal suppression on max scores |
798 | std::vector<int> selected; |
799 | TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( |
800 | context, node, op_data, max_scores, op_data->max_detections, &selected)); |
801 | // Allocate output tensors |
802 | int output_box_index = 0; |
803 | for (const auto& selected_index : selected) { |
804 | const float* box_scores = |
805 | scores + selected_index * num_classes_with_background + label_offset; |
806 | const int* class_indices = |
807 | sorted_class_indices.data() + selected_index * num_classes; |
808 | |
809 | for (int col = 0; col < num_categories_per_anchor; ++col) { |
810 | int box_offset = max_categories_per_anchor * output_box_index + col; |
811 | // detection_boxes |
812 | TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); |
813 | TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); |
814 | ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] = |
815 | ReInterpretTensor<const BoxCornerEncoding*>( |
816 | decoded_boxes)[selected_index]; |
817 | // detection_classes |
818 | GetTensorData<float>(detection_classes)[box_offset] = class_indices[col]; |
819 | // detection_scores |
820 | GetTensorData<float>(detection_scores)[box_offset] = |
821 | box_scores[class_indices[col]]; |
822 | } |
823 | output_box_index++; |
824 | } |
825 | GetTensorData<float>(num_detections)[0] = output_box_index; |
826 | return kTfLiteOk; |
827 | } |
828 | |
829 | void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, |
830 | const int num_boxes, |
831 | const int num_classes_with_background, |
832 | TfLiteTensor* scores) { |
833 | float quant_zero_point = |
834 | static_cast<float>(input_class_predictions->params.zero_point); |
835 | float quant_scale = static_cast<float>(input_class_predictions->params.scale); |
836 | tflite::DequantizationParams op_params; |
837 | op_params.zero_point = quant_zero_point; |
838 | op_params.scale = quant_scale; |
839 | const auto shape = RuntimeShape(1, num_boxes * num_classes_with_background); |
840 | optimized_ops::Dequantize(op_params, shape, |
841 | GetTensorData<uint8>(input_class_predictions), |
842 | shape, GetTensorData<float>(scores)); |
843 | } |
844 | |
845 | TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, |
846 | TfLiteNode* node, OpData* op_data) { |
847 | // Get the input tensors |
848 | const TfLiteTensor* input_box_encodings; |
849 | TF_LITE_ENSURE_OK(context, |
850 | GetInputSafe(context, node, kInputTensorBoxEncodings, |
851 | &input_box_encodings)); |
852 | const TfLiteTensor* input_class_predictions; |
853 | TF_LITE_ENSURE_OK(context, |
854 | GetInputSafe(context, node, kInputTensorClassPredictions, |
855 | &input_class_predictions)); |
856 | const int num_boxes = input_box_encodings->dims->data[1]; |
857 | const int num_classes = op_data->num_classes; |
858 | TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0], |
859 | kBatchSize); |
860 | TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes); |
861 | const int num_classes_with_background = |
862 | input_class_predictions->dims->data[2]; |
863 | |
864 | TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1)); |
865 | TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes)); |
866 | |
867 | const TfLiteTensor* scores; |
868 | switch (input_class_predictions->type) { |
869 | case kTfLiteUInt8: { |
870 | TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index]; |
871 | DequantizeClassPredictions(input_class_predictions, num_boxes, |
872 | num_classes_with_background, temporary_scores); |
873 | scores = temporary_scores; |
874 | } break; |
875 | case kTfLiteFloat32: |
876 | scores = input_class_predictions; |
877 | break; |
878 | default: |
879 | // Unsupported type. |
880 | return kTfLiteError; |
881 | } |
882 | if (op_data->use_regular_non_max_suppression) |
883 | TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper( |
884 | context, node, op_data, GetTensorData<float>(scores))); |
885 | else |
886 | TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassFastHelper( |
887 | context, node, op_data, GetTensorData<float>(scores))); |
888 | |
889 | return kTfLiteOk; |
890 | } |
891 | |
892 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
893 | // TODO(b/177068051): Generalize for any batch size. |
894 | TF_LITE_ENSURE(context, (kBatchSize == 1)); |
895 | auto* op_data = static_cast<OpData*>(node->user_data); |
896 | // These two functions correspond to two blocks in the Object Detection model. |
897 | // In future, we would like to break the custom op in two blocks, which is |
898 | // currently not feasible because we would like to input quantized inputs |
899 | // and do all calculations in float. Mixed quantized/float calculations are |
900 | // currently not supported in TFLite. |
901 | |
902 | // This fills in temporary decoded_boxes |
903 | // by transforming input_box_encodings and input_anchors from |
904 | // CenterSizeEncodings to BoxCornerEncoding |
905 | TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data)); |
906 | // This fills in the output tensors |
907 | // by choosing effective set of decoded boxes |
908 | // based on Non Maximal Suppression, i.e. selecting |
909 | // highest scoring non-overlapping boxes. |
910 | TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data)); |
911 | |
912 | return kTfLiteOk; |
913 | } |
914 | } // namespace detection_postprocess |
915 | |
916 | TfLiteRegistration* Register_DETECTION_POSTPROCESS() { |
917 | static TfLiteRegistration r = { |
918 | detection_postprocess::Init, detection_postprocess::Free, |
919 | detection_postprocess::Prepare, detection_postprocess::Eval}; |
920 | return &r; |
921 | } |
922 | |
923 | // Since the op is named "TFLite_Detection_PostProcess", the selective build |
924 | // tool will assume the register function is named |
925 | // "Register_TFLITE_DETECTION_POST_PROCESS". |
926 | TfLiteRegistration* Register_TFLITE_DETECTION_POST_PROCESS() { |
927 | return Register_DETECTION_POSTPROCESS(); |
928 | } |
929 | |
930 | } // namespace custom |
931 | } // namespace ops |
932 | } // namespace tflite |
933 | |