1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/reference/non_max_suppression.h"
16
17#include <initializer_list>
18
19#include "tensorflow/lite/c/common.h"
20#include "tensorflow/lite/kernels/internal/tensor.h"
21#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22#include "tensorflow/lite/kernels/kernel_util.h"
23
24namespace tflite {
25namespace ops {
26namespace builtin {
27namespace non_max_suppression {
28
// ---- Input tensor positions -------------------------------------------------

// Box coordinates, one row per box laid out as [y1, x1, y2, x2].
// Shape: [num_boxes, 4]. Type: Float.
constexpr int kInputTensorBoxes{0};
// One score per box. Shape: [num_boxes]. Type: Float.
constexpr int kInputTensorScores{1};
// Scalar upper bound on the number of boxes to output; fewer may be selected.
// The output indices/scores tensors are sized to this length.
// Type: Int32.
constexpr int kInputTensorMaxOutputSize{2};
// Scalar IoU threshold. Type: Float.
constexpr int kInputTensorIouThreshold{3};
// Scalar score threshold. Type: Float.
constexpr int kInputTensorScoreThreshold{4};
// Scalar sigma; present only for NON_MAX_SUPPRESSION_V5 (six-input nodes).
// Type: Float.
constexpr int kInputTensorSigma{5};

// ---- Output tensor positions: five-input (hard NMS) nodes -------------------

// Indices of the selected boxes. The tensor is sized to max_output_size; only
// the first num_selected_indices entries are meaningful.
// Type: Int32.
constexpr int kNMSOutputTensorSelectedIndices{0};
// Scalar count of selected boxes. Type: Int32.
constexpr int kNMSOutputTensorNumSelectedIndices{1};

// ---- Output tensor positions: six-input (soft NMS) nodes --------------------

// Indices of the selected boxes. The tensor is sized to max_output_size; only
// the first num_selected_indices entries are meaningful.
// Type: Int32.
constexpr int kSoftNMSOutputTensorSelectedIndices{0};
// Scores of the selected boxes, laid out like the selected indices.
// Type: Float.
constexpr int kSoftNMSOutputTensorSelectedScores{1};
// Scalar count of selected boxes. Type: Int32.
constexpr int kSoftNMSOutputTensorNumSelectedIndices{2};
61
62TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
63 std::initializer_list<int> values) {
64 TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
65 int index = 0;
66 for (const auto& v : values) {
67 size->data[index++] = v;
68 }
69 return context->ResizeTensor(context, tensor, size);
70}
71
72TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
73 const int num_inputs = NumInputs(node);
74 const bool is_soft_nms = num_inputs == 6;
75 if (num_inputs != 5 && num_inputs != 6) {
76 TF_LITE_KERNEL_LOG(context, "Found NMS op with invalid num inputs: %d",
77 NumInputs(node));
78 return kTfLiteError;
79 }
80
81 // Boxes & Scores.
82 const TfLiteTensor* input_boxes;
83 TF_LITE_ENSURE_OK(
84 context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
85 TF_LITE_ENSURE_EQ(context, input_boxes->type, kTfLiteFloat32);
86 TF_LITE_ENSURE_EQ(context, NumDimensions(input_boxes), 2);
87 TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_boxes, 1), 4);
88 const int num_boxes = SizeOfDimension(input_boxes, 0);
89 const TfLiteTensor* input_scores;
90 TF_LITE_ENSURE_OK(
91 context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
92 TF_LITE_ENSURE_EQ(context, input_scores->type, kTfLiteFloat32);
93 TF_LITE_ENSURE_EQ(context, NumDimensions(input_scores), 1);
94 TF_LITE_ENSURE_EQ(context, num_boxes, SizeOfDimension(input_scores, 0));
95
96 // Max output size.
97 const TfLiteTensor* input_max_output_size;
98 TF_LITE_ENSURE_OK(context,
99 GetInputSafe(context, node, kInputTensorMaxOutputSize,
100 &input_max_output_size));
101 TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
102 TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
103 const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
104 int max_output_size_value = 0;
105 if (is_max_output_size_const) {
106 max_output_size_value = *GetTensorData<int>(input_max_output_size);
107 TF_LITE_ENSURE(context, (max_output_size_value >= 0));
108 }
109
110 // IoU & Score thresholds.
111 const TfLiteTensor* input_iou_threshold;
112 TF_LITE_ENSURE_OK(context,
113 GetInputSafe(context, node, kInputTensorIouThreshold,
114 &input_iou_threshold));
115 TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
116 TF_LITE_ENSURE_EQ(context, NumDimensions(input_iou_threshold), 0);
117 const TfLiteTensor* input_score_threshold;
118 TF_LITE_ENSURE_OK(context,
119 GetInputSafe(context, node, kInputTensorScoreThreshold,
120 &input_score_threshold));
121 TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
122 TF_LITE_ENSURE_EQ(context, NumDimensions(input_score_threshold), 0);
123
124 if (is_soft_nms) {
125 const TfLiteTensor* input_sigma;
126 TF_LITE_ENSURE_OK(
127 context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
128 TF_LITE_ENSURE_EQ(context, input_sigma->type, kTfLiteFloat32);
129 TF_LITE_ENSURE_EQ(context, NumDimensions(input_sigma), 0);
130
131 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3);
132 TfLiteTensor* output_selected_indices;
133 TF_LITE_ENSURE_OK(
134 context,
135 GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
136 &output_selected_indices));
137 output_selected_indices->type = kTfLiteInt32;
138 TfLiteTensor* output_selected_scores;
139 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
140 kSoftNMSOutputTensorSelectedScores,
141 &output_selected_scores));
142 output_selected_scores->type = kTfLiteFloat32;
143 TfLiteTensor* output_num_selected_indices;
144 TF_LITE_ENSURE_OK(
145 context,
146 GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
147 &output_num_selected_indices));
148 output_num_selected_indices->type = kTfLiteInt32;
149 SetTensorSizes(context, output_num_selected_indices, {});
150
151 if (is_max_output_size_const) {
152 SetTensorSizes(context, output_selected_indices, {max_output_size_value});
153 SetTensorSizes(context, output_selected_scores, {max_output_size_value});
154 } else {
155 SetTensorToDynamic(output_selected_indices);
156 SetTensorToDynamic(output_selected_scores);
157 }
158 } else {
159 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
160 TfLiteTensor* output_selected_indices;
161 TF_LITE_ENSURE_OK(
162 context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
163 &output_selected_indices));
164 output_selected_indices->type = kTfLiteInt32;
165 TfLiteTensor* output_num_selected_indices;
166 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
167 kNMSOutputTensorNumSelectedIndices,
168 &output_num_selected_indices));
169 output_num_selected_indices->type = kTfLiteInt32;
170 SetTensorSizes(context, output_num_selected_indices, {});
171
172 if (is_max_output_size_const) {
173 SetTensorSizes(context, output_selected_indices, {max_output_size_value});
174 } else {
175 SetTensorToDynamic(output_selected_indices);
176 }
177 }
178
179 return kTfLiteOk;
180}
181
// Zero-fills the tail of the output buffers that NMS did not populate.
//
// When fewer than max_output_size boxes are selected, the trailing slots
// [num_selected_indices, max_output_size) still hold whatever was in memory.
// Because one output carries indices, large garbage ints there can make
// downstream ops such as GATHER segfault, so the tail is cleared here.
// `selected_scores` may be null, in which case only the indices are cleared.
// NOTE: The memory being reset is valid because Prepare/Eval size the
// pertinent output tensors to max_output_size.
void ResetUnusedElementsToZeroes(const int max_output_size,
                                 const int num_selected_indices,
                                 int* selected_indices,
                                 float* selected_scores) {
  const bool has_scores = selected_scores != nullptr;
  int slot = num_selected_indices;
  while (slot < max_output_size) {
    selected_indices[slot] = 0;
    if (has_scores) {
      selected_scores[slot] = 0.0f;
    }
    ++slot;
  }
}
200
201TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
202 const bool is_soft_nms = NumInputs(node) == 6;
203
204 const TfLiteTensor* input_boxes;
205 TF_LITE_ENSURE_OK(
206 context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
207 const int num_boxes = SizeOfDimension(input_boxes, 0);
208 const TfLiteTensor* input_scores;
209 TF_LITE_ENSURE_OK(
210 context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
211 const TfLiteTensor* input_max_output_size;
212 TF_LITE_ENSURE_OK(context,
213 GetInputSafe(context, node, kInputTensorMaxOutputSize,
214 &input_max_output_size));
215 const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
216 TF_LITE_ENSURE(context, (max_output_size_value >= 0));
217 const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
218 const TfLiteTensor* input_iou_threshold;
219 TF_LITE_ENSURE_OK(context,
220 GetInputSafe(context, node, kInputTensorIouThreshold,
221 &input_iou_threshold));
222 const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
223 const TfLiteTensor* input_score_threshold;
224 TF_LITE_ENSURE_OK(context,
225 GetInputSafe(context, node, kInputTensorScoreThreshold,
226 &input_score_threshold));
227 const float score_threshold = *GetTensorData<float>(input_score_threshold);
228
229 TfLiteTensor* output_selected_indices = nullptr;
230 TfLiteTensor* output_selected_scores = nullptr;
231 TfLiteTensor* output_num_selected_indices = nullptr;
232
233 if (is_soft_nms) {
234 const TfLiteTensor* input_sigma;
235 TF_LITE_ENSURE_OK(
236 context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
237 const float soft_nms_sigma = *GetTensorData<float>(input_sigma);
238 if (soft_nms_sigma < 0) {
239 TF_LITE_KERNEL_LOG(context, "Invalid sigma value for soft NMS: %f",
240 soft_nms_sigma);
241 return kTfLiteError;
242 }
243
244 TF_LITE_ENSURE_OK(
245 context,
246 GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
247 &output_selected_indices));
248 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
249 kSoftNMSOutputTensorSelectedScores,
250 &output_selected_scores));
251 TF_LITE_ENSURE_OK(
252 context,
253 GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
254 &output_num_selected_indices));
255 if (!is_max_output_size_const) {
256 SetTensorSizes(context, output_selected_indices, {max_output_size_value});
257 SetTensorSizes(context, output_selected_scores, {max_output_size_value});
258 }
259 reference_ops::NonMaxSuppression(
260 input_boxes->data.f, num_boxes, input_scores->data.f,
261 max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
262 output_selected_indices->data.i32, output_selected_scores->data.f,
263 output_num_selected_indices->data.i32);
264 ResetUnusedElementsToZeroes(
265 max_output_size_value, *output_num_selected_indices->data.i32,
266 output_selected_indices->data.i32, output_selected_scores->data.f);
267 } else {
268 TF_LITE_ENSURE_OK(
269 context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
270 &output_selected_indices));
271 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
272 kNMSOutputTensorNumSelectedIndices,
273 &output_num_selected_indices));
274 if (!is_max_output_size_const) {
275 SetTensorSizes(context, output_selected_indices, {max_output_size_value});
276 }
277 reference_ops::NonMaxSuppression(
278 input_boxes->data.f, num_boxes, input_scores->data.f,
279 max_output_size_value, iou_threshold, score_threshold, /**sigma=**/ 0.0,
280 output_selected_indices->data.i32, /**selected_scores=**/ nullptr,
281 output_num_selected_indices->data.i32);
282 ResetUnusedElementsToZeroes(max_output_size_value,
283 *output_num_selected_indices->data.i32,
284 output_selected_indices->data.i32, nullptr);
285 }
286
287 return kTfLiteOk;
288}
289} // namespace non_max_suppression
290
291TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4() {
292 static TfLiteRegistration r = {nullptr, nullptr, non_max_suppression::Prepare,
293 non_max_suppression::Eval};
294 return &r;
295}
296
297TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5() {
298 static TfLiteRegistration r = {nullptr, nullptr, non_max_suppression::Prepare,
299 non_max_suppression::Eval};
300 return &r;
301}
302
303} // namespace builtin
304} // namespace ops
305} // namespace tflite
306