1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file nms.cc
22 * \brief Non-maximum suppression operators
23 */
24#include <tvm/relay/attrs/vision.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/op.h>
27
28namespace tvm {
29namespace relay {
30
// Register the attribute node type that MakeGetValidCounts attaches to
// "vision.get_valid_counts" calls.
TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);
32
33bool GetValidCountRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
34 const TypeReporter& reporter) {
35 ICHECK_EQ(types.size(), 3);
36 const auto* data = types[0].as<TensorTypeNode>();
37 if (data == nullptr) return false;
38 const auto& dshape = data->shape;
39 ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
40
41 std::vector<IndexExpr> oshape({data->shape[0]});
42 std::vector<IndexExpr> oshape_indices({data->shape[0], data->shape[1]});
43 std::vector<Type> fields;
44 fields.push_back(TensorType(oshape, DataType::Int(32)));
45 fields.push_back(TensorType(data->shape, data->dtype));
46 fields.push_back(TensorType(oshape_indices, DataType::Int(32)));
47
48 // assign output type
49 reporter->Assign(types[2], TupleType(Array<Type>(fields)));
50 return true;
51}
52
53Expr MakeGetValidCounts(Expr data, Expr score_threshold, int id_index, int score_index) {
54 auto attrs = make_object<GetValidCountsAttrs>();
55 attrs->id_index = id_index;
56 attrs->score_index = score_index;
57 static const Op& op = Op::Get("vision.get_valid_counts");
58 return Call(op, {data, score_threshold}, Attrs(attrs), {});
59}
60
// Register MakeGetValidCounts as the global function
// "relay.op.vision._make.get_valid_counts".
TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts);

// Register the "vision.get_valid_counts" operator: two inputs
// (data, score_threshold), support level 5, typed by GetValidCountRel.
// NOTE(review): keep the order of the chained calls as-is; whether later calls
// depend on values set by earlier ones is not visible from this file.
RELAY_REGISTER_OP("vision.get_valid_counts")
    .describe(R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "Input data.")
    .add_argument("score_threshold", "Tensor", "Minimum Score.")
    .set_support_level(5)
    .add_type_rel("GetValidCount", GetValidCountRel);
73
// Register the attribute node type that MakeNMS attaches to
// "vision.non_max_suppression" calls and that NMSRel reads.
TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
75
76bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
77 const TypeReporter& reporter) {
78 ICHECK_EQ(types.size(), 6);
79 const auto* data = types[0].as<TensorTypeNode>();
80 if (data == nullptr) return false;
81 const auto* valid_count = types[1].as<TensorTypeNode>();
82 if (valid_count == nullptr) return false;
83 const NonMaximumSuppressionAttrs* param = attrs.as<NonMaximumSuppressionAttrs>();
84 const auto& dshape = data->shape;
85 const auto& vshape = valid_count->shape;
86 ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
87 ICHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";
88
89 // assign output type
90 if (param->return_indices) {
91 std::vector<Type> fields;
92 // dynamic happens for return_indices in TensorFlow & ONNX
93 std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
94 fields.push_back(TensorType(oshape, DataType::Int(32)));
95 std::vector<IndexExpr> countshape({dshape[0], 1});
96 fields.push_back(TensorType(countshape, DataType::Int(32)));
97 reporter->Assign(types[5], TupleType(Array<Type>(fields)));
98 } else {
99 reporter->Assign(types[5], TensorType(dshape, data->dtype));
100 }
101 return true;
102}
103
104Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, Expr iou_threshold,
105 bool force_suppress, int top_k, int coord_start, int score_index, int id_index,
106 bool return_indices, bool invalid_to_bottom) {
107 auto attrs = make_object<NonMaximumSuppressionAttrs>();
108 attrs->force_suppress = force_suppress;
109 attrs->top_k = top_k;
110 attrs->coord_start = coord_start;
111 attrs->score_index = score_index;
112 attrs->id_index = id_index;
113 attrs->return_indices = return_indices;
114 attrs->invalid_to_bottom = invalid_to_bottom;
115 static const Op& op = Op::Get("vision.non_max_suppression");
116 return Call(op, {data, valid_count, indices, max_output_size, iou_threshold}, Attrs(attrs), {});
117}
118
// Register MakeNMS as the global function
// "relay.op.vision._make.non_max_suppression".
TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS);

// Register the "vision.non_max_suppression" operator: five inputs
// (data, valid_count, indices, max_output_size, iou_threshold),
// support level 5, typed by NMSRel.
// NOTE(review): keep the order of the chained calls as-is; whether later calls
// depend on values set by earlier ones is not visible from this file.
RELAY_REGISTER_OP("vision.non_max_suppression")
    .describe(R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom]
or [score, left, top, right, bottom]. Set id_index to be -1 to
ignore class_id axis.
)doc" TVM_ADD_FILELINE)
    .set_num_inputs(5)
    .add_argument("data", "Tensor", "Input data.")
    .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
    .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.")
    .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.")
    .add_argument("iou_threshold", "Tensor", "Threshold for box overlap.")
    .set_support_level(5)
    .add_type_rel("NMS", NMSRel);
135
// Register the attribute node type that MakeAllClassNMS attaches to
// "vision.all_class_non_max_suppression" calls and that AllClassNMSRel reads.
TVM_REGISTER_NODE_TYPE(AllClassNonMaximumSuppressionAttrs);
137
138bool AllClassNMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
139 const TypeReporter& reporter) {
140 ICHECK_EQ(types.size(), 6);
141 const auto* boxes = types[0].as<TensorTypeNode>();
142 if (boxes == nullptr) return false;
143 const auto* scores = types[1].as<TensorTypeNode>();
144 if (scores == nullptr) return false;
145
146 const auto& boxes_shape = boxes->shape;
147 const auto& scores_shape = scores->shape;
148 ICHECK_EQ(boxes_shape.size(), 3) << "Input boxes should be 3-D.";
149 ICHECK_EQ(scores_shape.size(), 3) << "Input scores count should be 3-D.";
150
151 IndexExpr batch = boxes_shape[0];
152 IndexExpr num_classes = scores_shape[1];
153 IndexExpr num_boxes = boxes_shape[1];
154
155 const auto* param = attrs.as<AllClassNonMaximumSuppressionAttrs>();
156 CHECK(param);
157
158 std::vector<Type> fields;
159 if (param->output_format == "onnx") {
160 IndexExpr num_total_boxes = Any();
161 if (!batch.as<AnyNode>() && !num_boxes.as<AnyNode>()) {
162 num_total_boxes = batch * num_classes * num_boxes;
163 }
164 std::vector<IndexExpr> oshape{num_total_boxes, 3};
165 std::vector<IndexExpr> counts_shape{1};
166 fields.push_back(TensorType(oshape, DataType::Int(64)));
167 fields.push_back(TensorType(counts_shape, DataType::Int(64)));
168 } else {
169 IndexExpr num_total_boxes_per_batch = Any();
170 if (!num_boxes.as<AnyNode>()) {
171 num_total_boxes_per_batch = num_classes * num_boxes;
172 }
173 std::vector<IndexExpr> indices_shape{batch, num_total_boxes_per_batch, 2};
174 std::vector<IndexExpr> scores_shape{batch, num_total_boxes_per_batch};
175 std::vector<IndexExpr> counts_shape{batch};
176 fields.push_back(TensorType(indices_shape, DataType::Int(64)));
177 fields.push_back(TensorType(scores_shape, DataType::Float(32)));
178 fields.push_back(TensorType(counts_shape, DataType::Int(64)));
179 }
180 reporter->Assign(types[5], TupleType(Array<Type>(fields)));
181 return true;
182}
183
184Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold,
185 Expr score_threshold, std::string output_format = "onnx") {
186 auto attrs = make_object<AllClassNonMaximumSuppressionAttrs>();
187 attrs->output_format = std::move(output_format);
188 static const Op& op = Op::Get("vision.all_class_non_max_suppression");
189 return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold},
190 Attrs(attrs), {});
191}
192
193TVM_REGISTER_GLOBAL("relay.op.vision._make.all_class_non_max_suppression")
194 .set_body_typed(MakeAllClassNMS);
195
196RELAY_REGISTER_OP("vision.all_class_non_max_suppression")
197 .describe(R"doc(Non-maximum suppression operator for object detection, corresponding to ONNX
198 NonMaxSuppression and TensorFlow combined_non_max_suppression.
199 NMS is performed for each class separately
200)doc" TVM_ADD_FILELINE)
201 .set_num_inputs(5)
202 .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].")
203 .add_argument("scores", "Tensor",
204 "Scores for each box and class in the format [batch, num_classes, num_boxes].")
205 .add_argument("max_output_boxes_per_class", "Tensor",
206 "The maximum number of output boxes per class.")
207 .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.")
208 .add_argument("score_threshold", "Tensor",
209 "The score threshold to filter out low score boxes early.")
210 .set_support_level(5)
211 .add_type_rel("AllClassNMS", AllClassNMSRel);
212
213} // namespace relay
214} // namespace tvm
215