/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <atomic>
#include <cstring>
#include <initializer_list>
#include <numeric>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace detection_postprocess {

// Input tensors
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Output tensors
// When max_classes_per_detection > 1, each detection box is replicated once
// for every detected class of that box. Dummy data is appended if the number
// of detected classes is smaller than max_classes_per_detection.
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

constexpr int kNumCoordBox = 4;
constexpr int kBatchSize = 1;

constexpr int kNumDetectionsPerClass = 100;
// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner represents the lower left corner (xmin, ymin) and
// the upper right corner (xmax, ymax).
// CenterSize represents the center (xcenter, ycenter), height and width.
// BoxCornerEncoding and CenterSizeEncoding are related as follows:
// ycenter = y / y_scale * anchor.h + anchor.y;
// xcenter = x / x_scale * anchor.w + anchor.x;
// half_h = 0.5 * exp(h / h_scale) * anchor.h;
// half_w = 0.5 * exp(w / w_scale) * anchor.w;
// ymin = ycenter - half_h
// ymax = ycenter + half_h
// xmin = xcenter - half_w
// xmax = xcenter + half_w
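//
// Worked example of the formulas above (illustrative numbers only): with
// scale values (y_scale, x_scale, h_scale, w_scale) = (10, 10, 5, 5), an
// anchor (y, x, h, w) = (0.5, 0.5, 1.0, 1.0) and an all-zero encoding:
//   ycenter = 0 / 10 * 1.0 + 0.5 = 0.5
//   half_h  = 0.5 * exp(0 / 5) * 1.0 = 0.5
// so the decoded corners are ymin = 0.0, ymax = 1.0 (and likewise for x).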
struct BoxCornerEncoding {
  float ymin;
  float xmin;
  float ymax;
  float xmax;
};

struct CenterSizeEncoding {
  float y;
  float x;
  float h;
  float w;
};
// We use static asserts to make sure the memory layout of these structs is
// contiguous.
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
              "Size of BoxCornerEncoding is 4 float values");
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
              "Size of CenterSizeEncoding is 4 float values");

struct OpData {
  int max_detections;
  int max_classes_per_detection;  // Fast Non-Max-Suppression
  int detections_per_class;       // Regular Non-Max-Suppression
  float non_max_suppression_score_threshold;
  float intersection_over_union_threshold;
  int num_classes;
  bool use_regular_non_max_suppression;
  CenterSizeEncoding scale_values;
  // Indices of temporary tensors
  int decoded_boxes_index;
  int scores_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  op_data->max_detections = m["max_detections"].AsInt32();
  op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
  if (m["detections_per_class"].IsNull())
    op_data->detections_per_class = kNumDetectionsPerClass;
  else
    op_data->detections_per_class = m["detections_per_class"].AsInt32();
  if (m["use_regular_nms"].IsNull())
    op_data->use_regular_non_max_suppression = false;
  else
    op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();

  op_data->non_max_suppression_score_threshold =
      m["nms_score_threshold"].AsFloat();
  op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
  op_data->num_classes = m["num_classes"].AsInt32();
  op_data->scale_values.y = m["y_scale"].AsFloat();
  op_data->scale_values.x = m["x_scale"].AsFloat();
  op_data->scale_values.h = m["h_scale"].AsFloat();
  op_data->scale_values.w = m["w_scale"].AsFloat();
  context->AddTensors(context, 1, &op_data->decoded_boxes_index);
  context->AddTensors(context, 1, &op_data->scores_index);
  return op_data;
}
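
// For reference, the flexbuffer options map decoded by Init() above carries
// exactly the keys read there; a typical map might look like the following
// (the values are illustrative only, not defaults enforced by this kernel):
//   {"max_detections": 10, "max_classes_per_detection": 1,
//    "detections_per_class": 100, "use_regular_nms": false,
//    "nms_score_threshold": 1e-8, "nms_iou_threshold": 0.5,
//    "num_classes": 90,
//    "y_scale": 10.0, "x_scale": 10.0, "h_scale": 5.0, "w_scale": 5.0}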

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
                            std::initializer_list<int> values) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
  int index = 0;
  for (const auto& v : values) {
    size->data[index] = v;
    ++index;
  }
  return context->ResizeTensor(context, tensor, size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = static_cast<OpData*>(node->user_data);
  // Inputs: box_encodings, scores, anchors
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const TfLiteTensor* input_anchors;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
                                          &input_anchors));
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
  // Number of detected boxes.
  const int num_detected_boxes =
      op_data->max_detections * op_data->max_classes_per_detection;

  // Outputs: detection_boxes, detection_scores, detection_classes,
  // num_detections
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
  // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
  TfLiteTensor* detection_boxes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
                                  &detection_boxes));
  detection_boxes->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_boxes,
                 {kBatchSize, num_detected_boxes, kNumCoordBox});

  // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_classes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
                                  &detection_classes));
  detection_classes->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});

  // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_scores;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
                                  &detection_scores));
  detection_scores->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});

  // Output Tensor num_detections: size is set to 1
  TfLiteTensor* num_detections;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorNumDetections,
                                  &num_detections));
  num_detections->type = kTfLiteFloat32;
  SetTensorSizes(context, num_detections, {1});

  // Temporary tensors
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(2);
  node->temporaries->data[0] = op_data->decoded_boxes_index;
  node->temporaries->data[1] = op_data->scores_index;

  // decoded_boxes
  TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
  decoded_boxes->type = kTfLiteFloat32;
  decoded_boxes->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, decoded_boxes,
                 {input_box_encodings->dims->data[1], kNumCoordBox});

  // scores
  TfLiteTensor* scores = &context->tensors[op_data->scores_index];
  scores->type = kTfLiteFloat32;
  scores->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, scores,
                 {input_class_predictions->dims->data[1],
                  input_class_predictions->dims->data[2]});

  return kTfLiteOk;
}

class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}
  float operator()(uint8 x) {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
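
// Example of the affine dequantization above (illustrative numbers): with
// zero_point = 128 and scale = 0.0078125, a quantized value of 200 maps to
// (200 - 128) * 0.0078125 = 0.5625f.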

void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
                            float quant_zero_point, float quant_scale,
                            int length_box_encoding,
                            CenterSizeEncoding* box_centersize) {
  const uint8* boxes =
      GetTensorData<uint8>(input_box_encodings) + length_box_encoding * idx;
  Dequantizer dequantize(quant_zero_point, quant_scale);
  // See the definition of the KeypointBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
  // The first four elements are the box coordinates, encoded the same way as
  // in the FasterRcnnBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
  box_centersize->y = dequantize(boxes[0]);
  box_centersize->x = dequantize(boxes[1]);
  box_centersize->h = dequantize(boxes[2]);
  box_centersize->w = dequantize(boxes[3]);
}

template <class T>
T ReInterpretTensor(const TfLiteTensor* tensor) {
  const float* tensor_base = GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

template <class T>
T ReInterpretTensor(TfLiteTensor* tensor) {
  float* tensor_base = GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                   OpData* op_data) {
  // Parse the input tensor box_encodings.
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
  const int num_boxes = input_box_encodings->dims->data[1];
  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
  const TfLiteTensor* input_anchors;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
                                          &input_anchors));

  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors.
  CenterSizeEncoding box_centersize;
  CenterSizeEncoding scale_values = op_data->scale_values;
  CenterSizeEncoding anchor;
  for (int idx = 0; idx < num_boxes; ++idx) {
    switch (input_box_encodings->type) {
      // Quantized
      case kTfLiteUInt8:
        DequantizeBoxEncodings(
            input_box_encodings, idx,
            static_cast<float>(input_box_encodings->params.zero_point),
            static_cast<float>(input_box_encodings->params.scale),
            input_box_encodings->dims->data[2], &box_centersize);
        DequantizeBoxEncodings(
            input_anchors, idx,
            static_cast<float>(input_anchors->params.zero_point),
            static_cast<float>(input_anchors->params.scale), kNumCoordBox,
            &anchor);
        break;
      // Float
      case kTfLiteFloat32: {
        // Please see the DequantizeBoxEncodings function for the supported
        // encoding layout.
        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
        const float* boxes =
            &(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
        TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
        anchor =
            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
        break;
      }
      default:
        // Unsupported type.
        return kTfLiteError;
    }

    float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) /
                                           static_cast<double>(scale_values.y) *
                                           static_cast<double>(anchor.h) +
                                       static_cast<double>(anchor.y));

    float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) /
                                           static_cast<double>(scale_values.x) *
                                           static_cast<double>(anchor.w) +
                                       static_cast<double>(anchor.x));

    float half_h =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.h) /
                                     static_cast<double>(scale_values.h))) *
                           static_cast<double>(anchor.h));
    float half_w =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.w) /
                                     static_cast<double>(scale_values.w))) *
                           static_cast<double>(anchor.w));

    TfLiteTensor* decoded_boxes =
        &context->tensors[op_data->decoded_boxes_index];
    TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
    auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
    box.ymin = ycenter - half_h;
    box.xmin = xcenter - half_w;
    box.ymax = ycenter + half_h;
    box.xmax = xcenter + half_w;
  }
  return kTfLiteOk;
}

void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  if (num_to_sort == 1) {
    indices[0] = optimized_ops::ArgMaxVector(values, num_values);
  } else {
    std::iota(indices, indices + num_values, 0);
    std::partial_sort(
        indices, indices + num_to_sort, indices + num_values,
        [&values](const int i, const int j) { return values[i] > values[j]; });
  }
}
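
// Example (illustrative): for values = {0.1f, 0.9f, 0.4f} and num_to_sort = 2,
// the first two entries of `indices` become {1, 2}, i.e. the indices of the
// two largest values in decreasing order.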

void DecreasingArgSort(const float* values, int num_values, int* indices) {
  std::iota(indices, indices + num_values, 0);

  // We want a stable sort here in order to get a completely defined output.
  // That way TFL and TFLM can be bit-exact.
  std::stable_sort(
      indices, indices + num_values,
      [&values](const int i, const int j) { return values[i] > values[j]; });
}

void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
                                         const float threshold,
                                         std::vector<float>* keep_values,
                                         std::vector<int>* keep_indices) {
  for (int i = 0; i < values.size(); i++) {
    if (values[i] >= threshold) {
      keep_values->emplace_back(values[i]);
      keep_indices->emplace_back(i);
    }
  }
}

bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
  for (int i = 0; i < num_boxes; ++i) {
    auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
    // Note: `ComputeIntersectionOverUnion` properly handles degenerate boxes
    // (xmin == xmax and/or ymin == ymax), as it simply returns 0 when the box
    // area is <= 0.
    if (box.ymin > box.ymax || box.xmin > box.xmax) {
      return false;
    }
  }
  return true;
}

float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
                                   const int i, const int j) {
  auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
  auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
  if (area_i <= 0 || area_j <= 0) return 0.0;
  const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
  const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
  const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
  const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
  const float intersection_area =
      std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
      std::max<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}
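
// Worked IoU example (illustrative): boxes (ymin, xmin, ymax, xmax) =
// (0, 0, 2, 2) and (1, 1, 3, 3), each of area 4, intersect in the unit square
// (1, 1, 2, 2) of area 1, so IoU = 1 / (4 + 4 - 1) = 1/7 ≈ 0.143.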

// NonMaxSuppressionSingleClass() prunes out box locations with high overlap
// before selecting the highest-scoring boxes (max_detections of them).
// It assumes all boxes are good at the beginning and sorts them by score.
// If a lower-scoring box overlaps too much with a higher-scoring box,
// the lower-scoring box is dropped.
// Complexity is O(N^2): pairwise comparison between boxes.
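//
// Example (illustrative): for boxes A, B, C with scores {0.9, 0.8, 0.7},
// IoU(A, B) = 0.6, IoU(A, C) = 0.1 and an IoU threshold of 0.5: A is selected
// first, B is suppressed by A, then C is selected.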
TfLiteStatus NonMaxSuppressionSingleClassHelper(
    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
    const std::vector<float>& scores, int max_detections,
    std::vector<int>* selected) {
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];
  const int num_boxes = input_box_encodings->dims->data[1];
  const float non_max_suppression_score_threshold =
      op_data->non_max_suppression_score_threshold;
  const float intersection_over_union_threshold =
      op_data->intersection_over_union_threshold;
  // Maximum detections should be non-negative.
  TF_LITE_ENSURE(context, (max_detections >= 0));
  // intersection_over_union_threshold should be positive
  // and should be less than or equal to 1.
  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
                              (intersection_over_union_threshold <= 1.0f));
  // Validate boxes.
  TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));

  // Threshold scores.
  std::vector<int> keep_indices;
  // TODO(b/177068807): Remove the dynamic allocation and replace it
  // with temporaries, esp. for std::vector<float>.
  std::vector<float> keep_scores;
  SelectDetectionsAboveScoreThreshold(
      scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);

  int num_scores_kept = keep_scores.size();
  std::vector<int> sorted_indices;
  sorted_indices.resize(num_scores_kept);
  DecreasingArgSort(keep_scores.data(), num_scores_kept, sorted_indices.data());

  const int num_boxes_kept = num_scores_kept;
  const int output_size = std::min(num_boxes_kept, max_detections);
  selected->clear();
  int num_active_candidate = num_boxes_kept;
  std::vector<uint8_t> active_box_candidate(num_boxes_kept, 1);

  for (int i = 0; i < num_boxes_kept; ++i) {
    if (num_active_candidate == 0 || selected->size() >= output_size) break;
    if (active_box_candidate[i] == 1) {
      selected->push_back(keep_indices[sorted_indices[i]]);
      active_box_candidate[i] = 0;
      num_active_candidate--;
    } else {
      continue;
    }
    for (int j = i + 1; j < num_boxes_kept; ++j) {
      if (active_box_candidate[j] == 1) {
        TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
        float intersection_over_union = ComputeIntersectionOverUnion(
            decoded_boxes, keep_indices[sorted_indices[i]],
            keep_indices[sorted_indices[j]]);

        if (intersection_over_union > intersection_over_union_threshold) {
          active_box_candidate[j] = 0;
          num_active_candidate--;
        }
      }
    }
  }
  return kTfLiteOk;
}

struct BoxInfo {
  int index;
  float score;
};

struct NMSTaskParam {
  // Caller retains ownership of `context`, `node`, `op_data` and `scores`, and
  // should ensure their lifetime is longer than the NMSTaskParam instance.
  TfLiteContext* context;
  TfLiteNode* node;
  OpData* op_data;
  const float* scores;

  int num_classes;
  int num_boxes;
  int label_offset;
  int num_classes_with_background;
  int num_detections_per_class;
  int max_detections;
  std::vector<int>& num_selected;
};

void InplaceMergeBoxInfo(std::vector<BoxInfo>& boxes, int mid_index,
                         int end_index) {
  std::inplace_merge(
      boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index,
      [](const BoxInfo& a, const BoxInfo& b) { return a.score >= b.score; });
}
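
// Example of the merge above (illustrative): merging the two score-sorted
// runs {0.9, 0.5} and {0.8, 0.6} held in one vector yields
// {0.9, 0.8, 0.6, 0.5}.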

TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin,
                              int col_end, int& sorted_indices_size,
                              std::vector<BoxInfo>& resulted_sorted_box_info) {
  std::vector<float> class_scores(nms_task_param.num_boxes);
  std::vector<int> selected;
  selected.reserve(nms_task_param.num_detections_per_class);

  for (int col = col_begin; col <= col_end; ++col) {
    const float* scores_base =
        nms_task_param.scores + col + nms_task_param.label_offset;
    for (int row = 0; row < nms_task_param.num_boxes; row++) {
      // Get the scores of the boxes corresponding to all anchors for a single
      // class.
      class_scores[row] = *scores_base;
      scores_base += nms_task_param.num_classes_with_background;
    }

    // Perform non-maximal suppression on a single class.
    selected.clear();
    TF_LITE_ENSURE_OK(
        nms_task_param.context,
        NonMaxSuppressionSingleClassHelper(
            nms_task_param.context, nms_task_param.node, nms_task_param.op_data,
            class_scores, nms_task_param.num_detections_per_class, &selected));
    if (selected.empty()) {
      continue;
    }

    for (int i = 0; i < selected.size(); ++i) {
      resulted_sorted_box_info[sorted_indices_size + i].score =
          class_scores[selected[i]];
      resulted_sorted_box_info[sorted_indices_size + i].index =
          (selected[i] * nms_task_param.num_classes_with_background + col +
           nms_task_param.label_offset);
    }

    // In-place merge the original boxes and the newly selected boxes, which
    // are both sorted by score.
    InplaceMergeBoxInfo(resulted_sorted_box_info, sorted_indices_size,
                        sorted_indices_size + selected.size());

    sorted_indices_size =
        std::min(sorted_indices_size + static_cast<int>(selected.size()),
                 nms_task_param.max_detections);
  }
  return kTfLiteOk;
}

struct NonMaxSuppressionWorkerTask : cpu_backend_threadpool::Task {
  NonMaxSuppressionWorkerTask(NMSTaskParam& nms_task_param,
                              std::atomic<int>& next_col, int col_begin)
      : nms_task_param(nms_task_param),
        next_col(next_col),
        col_begin(col_begin),
        sorted_indices_size(0) {}
  void Run() override {
    sorted_box_info.resize(nms_task_param.num_detections_per_class +
                           nms_task_param.max_detections);
    for (int col = col_begin; col < nms_task_param.num_classes;
         col = (++next_col)) {
      if (ComputeNMSResult(nms_task_param, col, col, sorted_indices_size,
                           sorted_box_info) != kTfLiteOk) {
        break;
      }
    }
  }
  NMSTaskParam& nms_task_param;
  // A shared atomic variable across threads, representing the next col this
  // task will work on after completing the work for 'col_begin'.
  std::atomic<int>& next_col;
  const int col_begin;
  int sorted_indices_size;
  std::vector<BoxInfo> sorted_box_info;
};

// This function implements a regular version of Non Maximal Suppression (NMS)
// for multiple classes where
// 1) we do NMS separately for each class across all anchors and
// 2) keep only the highest anchor scores across all classes.
// The worst-case runtime of the regular NMS is O(K*N^2), where N is the
// number of anchors and K the number of classes.
TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
                                                      TfLiteNode* node,
                                                      OpData* op_data,
                                                      const float* scores) {
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
                                  &detection_boxes));
  TfLiteTensor* detection_classes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
                                  &detection_classes));
  TfLiteTensor* detection_scores;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
                                  &detection_scores));
  TfLiteTensor* num_detections;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorNumDetections,
                                  &num_detections));

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int num_detections_per_class =
      std::min(op_data->detections_per_class, op_data->max_detections);
  const int max_detections = op_data->max_detections;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];
  // The row index offset is 1 if background class is included and 0 otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, num_detections_per_class > 0);

  int sorted_indices_size = 0;
  std::vector<BoxInfo> box_info_after_regular_non_max_suppression(
      max_detections + num_detections_per_class);
  std::vector<int> num_selected(num_classes);

  NMSTaskParam nms_task_param{context,
                              node,
                              op_data,
                              scores,
                              num_classes,
                              num_boxes,
                              label_offset,
                              num_classes_with_background,
                              num_detections_per_class,
                              max_detections,
                              num_selected};

  int num_threads =
      CpuBackendContext::GetFromContext(context)->max_num_threads();
  if (num_threads == 1) {
    // For each class, perform non-max suppression.
    TF_LITE_ENSURE_OK(
        context, ComputeNMSResult(nms_task_param, /* col_begin= */ 0,
                                  num_classes - 1, sorted_indices_size,
                                  box_info_after_regular_non_max_suppression));
  } else {
    std::atomic<int> next_col(num_threads);
    std::vector<NonMaxSuppressionWorkerTask> tasks;
    tasks.reserve(num_threads);
    for (int i = 0; i < num_threads; ++i) {
      tasks.emplace_back(
          NonMaxSuppressionWorkerTask(nms_task_param, next_col, i));
    }
    cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                    CpuBackendContext::GetFromContext(context));

    // Merge results from tasks.
    for (int j = 0; j < tasks.size(); ++j) {
      if (tasks[j].sorted_indices_size == 0) {
        continue;
      }
      memcpy(&box_info_after_regular_non_max_suppression[sorted_indices_size],
             &tasks[j].sorted_box_info[0],
             sizeof(BoxInfo) * tasks[j].sorted_indices_size);
      InplaceMergeBoxInfo(box_info_after_regular_non_max_suppression,
                          sorted_indices_size,
                          sorted_indices_size + tasks[j].sorted_indices_size);
      sorted_indices_size = std::min(
          sorted_indices_size + tasks[j].sorted_indices_size, max_detections);
    }
  }

  // Allocate output tensors.
  for (int output_box_index = 0; output_box_index < max_detections;
       output_box_index++) {
    if (output_box_index < sorted_indices_size) {
      const int anchor_index = floor(
          box_info_after_regular_non_max_suppression[output_box_index].index /
          num_classes_with_background);
      const int class_index =
          box_info_after_regular_non_max_suppression[output_box_index].index -
          anchor_index * num_classes_with_background - label_offset;
      const float selected_score =
          box_info_after_regular_non_max_suppression[output_box_index].score;
      // detection_boxes
      TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[anchor_index];
      // detection_classes
      GetTensorData<float>(detection_classes)[output_box_index] = class_index;
      // detection_scores
      GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
    } else {
      TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
      ReInterpretTensor<BoxCornerEncoding*>(
          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
      // detection_classes
      GetTensorData<float>(detection_classes)[output_box_index] = 0.0f;
      // detection_scores
      GetTensorData<float>(detection_scores)[output_box_index] = 0.0f;
    }
  }
  GetTensorData<float>(num_detections)[0] = sorted_indices_size;
  box_info_after_regular_non_max_suppression.clear();
  return kTfLiteOk;
}

// This function implements a fast version of Non Maximal Suppression for
// multiple classes where
// 1) we keep the top-k scores for each anchor and
// 2) during NMS, each anchor only uses the highest class score for sorting.
// Compared to standard NMS, the worst-case runtime of this version is O(N^2)
// instead of O(K*N^2), where N is the number of anchors and K the number of
// classes.
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                   TfLiteNode* node,
                                                   OpData* op_data,
                                                   const float* scores) {
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
                                  &detection_boxes));
  TfLiteTensor* detection_classes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
                                  &detection_classes));
  TfLiteTensor* detection_scores;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
                                  &detection_scores));
  TfLiteTensor* num_detections;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorNumDetections,
                                  &num_detections));

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int max_categories_per_anchor = op_data->max_classes_per_detection;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];
  // The row index offset is 1 if background class is included and 0 otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
  const int num_categories_per_anchor =
      std::min(max_categories_per_anchor, num_classes);
  std::vector<float> max_scores;
  max_scores.resize(num_boxes);
  std::vector<int> sorted_class_indices;
  sorted_class_indices.resize(num_boxes * num_classes);
  for (int row = 0; row < num_boxes; row++) {
    const float* box_scores =
        scores + row * num_classes_with_background + label_offset;
    int* class_indices = sorted_class_indices.data() + row * num_classes;
    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
                             class_indices);
    max_scores[row] = box_scores[class_indices[0]];
  }
  // Perform non-maximal suppression on the max scores.
  std::vector<int> selected;
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
      context, node, op_data, max_scores, op_data->max_detections, &selected));
  // Allocate output tensors.
  int output_box_index = 0;
  for (const auto& selected_index : selected) {
    const float* box_scores =
        scores + selected_index * num_classes_with_background + label_offset;
    const int* class_indices =
        sorted_class_indices.data() + selected_index * num_classes;

    for (int col = 0; col < num_categories_per_anchor; ++col) {
      int box_offset = max_categories_per_anchor * output_box_index + col;
      // detection_boxes
      TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[selected_index];
      // detection_classes
      GetTensorData<float>(detection_classes)[box_offset] = class_indices[col];
      // detection_scores
      GetTensorData<float>(detection_scores)[box_offset] =
          box_scores[class_indices[col]];
    }
    output_box_index++;
  }
  GetTensorData<float>(num_detections)[0] = output_box_index;
  return kTfLiteOk;
}

void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
                                const int num_boxes,
                                const int num_classes_with_background,
                                TfLiteTensor* scores) {
  float quant_zero_point =
      static_cast<float>(input_class_predictions->params.zero_point);
  float quant_scale = static_cast<float>(input_class_predictions->params.scale);
  tflite::DequantizationParams op_params;
  op_params.zero_point = quant_zero_point;
  op_params.scale = quant_scale;
  const auto shape = RuntimeShape(1, num_boxes * num_classes_with_background);
  optimized_ops::Dequantize(op_params, shape,
                            GetTensorData<uint8>(input_class_predictions),
                            shape, GetTensorData<float>(scores));
}

TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  // Get the input tensors.
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));

  const TfLiteTensor* scores;
  switch (input_class_predictions->type) {
    case kTfLiteUInt8: {
      TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index];
      DequantizeClassPredictions(input_class_predictions, num_boxes,
                                 num_classes_with_background, temporary_scores);
      scores = temporary_scores;
    } break;
    case kTfLiteFloat32:
      scores = input_class_predictions;
      break;
    default:
      // Unsupported type.
      return kTfLiteError;
  }
  if (op_data->use_regular_non_max_suppression)
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
        context, node, op_data, GetTensorData<float>(scores)));
  else
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassFastHelper(
        context, node, op_data, GetTensorData<float>(scores)));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // TODO(b/177068051): Generalize for any batch size.
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  auto* op_data = static_cast<OpData*>(node->user_data);
  // These two functions correspond to two blocks in the Object Detection
  // model. In the future, we would like to split the custom op into those two
  // blocks, which is currently not feasible because we would like to take
  // quantized inputs but do all calculations in float, and mixed
  // quantized/float calculations are currently not supported in TFLite.

  // This fills in the temporary decoded_boxes tensor
  // by transforming input_box_encodings and input_anchors from
  // CenterSizeEncoding to BoxCornerEncoding.
  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
  // This fills in the output tensors
  // by choosing an effective set of decoded boxes
  // based on Non Maximal Suppression, i.e. selecting the
  // highest-scoring non-overlapping boxes.
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));

  return kTfLiteOk;
}
}  // namespace detection_postprocess

TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
  static TfLiteRegistration r = {
      detection_postprocess::Init, detection_postprocess::Free,
      detection_postprocess::Prepare, detection_postprocess::Eval};
  return &r;
}

// Since the op is named "TFLite_Detection_PostProcess", the selective build
// tool will assume the register function is named
// "Register_TFLITE_DETECTION_POST_PROCESS".
TfLiteRegistration* Register_TFLITE_DETECTION_POST_PROCESS() {
  return Register_DETECTION_POSTPROCESS();
}
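
// Minimal usage sketch (illustrative, not part of this kernel): clients that
// build their own op resolver can register this op under its custom name,
// e.g.:
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   resolver.AddCustom("TFLite_Detection_PostProcess",
//                      tflite::ops::custom::Register_DETECTION_POSTPROCESS());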

}  // namespace custom
}  // namespace ops
}  // namespace tflite