1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file nms.cc |
22 | * \brief Non-maximum suppression operators |
23 | */ |
24 | #include <tvm/relay/attrs/vision.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | namespace tvm { |
29 | namespace relay { |
30 | |
// Registers GetValidCountsAttrs through the TVM_REGISTER_NODE_TYPE macro. This is the
// attrs type that MakeGetValidCounts instantiates with make_object<> in this file.
31 | TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs); |
32 | |
33 | bool GetValidCountRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
34 | const TypeReporter& reporter) { |
35 | ICHECK_EQ(types.size(), 3); |
36 | const auto* data = types[0].as<TensorTypeNode>(); |
37 | if (data == nullptr) return false; |
38 | const auto& dshape = data->shape; |
39 | ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D." ; |
40 | |
41 | std::vector<IndexExpr> oshape({data->shape[0]}); |
42 | std::vector<IndexExpr> oshape_indices({data->shape[0], data->shape[1]}); |
43 | std::vector<Type> fields; |
44 | fields.push_back(TensorType(oshape, DataType::Int(32))); |
45 | fields.push_back(TensorType(data->shape, data->dtype)); |
46 | fields.push_back(TensorType(oshape_indices, DataType::Int(32))); |
47 | |
48 | // assign output type |
49 | reporter->Assign(types[2], TupleType(Array<Type>(fields))); |
50 | return true; |
51 | } |
52 | |
53 | Expr MakeGetValidCounts(Expr data, Expr score_threshold, int id_index, int score_index) { |
54 | auto attrs = make_object<GetValidCountsAttrs>(); |
55 | attrs->id_index = id_index; |
56 | attrs->score_index = score_index; |
57 | static const Op& op = Op::Get("vision.get_valid_counts" ); |
58 | return Call(op, {data, score_threshold}, Attrs(attrs), {}); |
59 | } |
60 | |
// Registers MakeGetValidCounts as a global function under the name
// "relay.op.vision._make.get_valid_counts" (presumably the entry point used by the
// language frontend -- confirm against the caller).
61 | TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts" ).set_body_typed(MakeGetValidCounts); |
62 | |
// Operator registration for "vision.get_valid_counts": attaches the description text,
// declares two inputs (data, score_threshold), sets support level 5 and binds the
// type relation named "GetValidCount" to GetValidCountRel defined in this file.
63 | RELAY_REGISTER_OP("vision.get_valid_counts" ) |
64 | .describe(R"doc(Get valid count of bounding boxes given |
65 | a score threshold. Also moves valid boxes to the top of |
66 | input data. |
67 | )doc" TVM_ADD_FILELINE) |
68 | .set_num_inputs(2) |
69 | .add_argument("data" , "Tensor" , "Input data." ) |
70 | .add_argument("score_threshold" , "Tensor" , "Minimum Score." ) |
71 | .set_support_level(5) |
72 | .add_type_rel("GetValidCount" , GetValidCountRel); |
73 | |
// Registers NonMaximumSuppressionAttrs through the TVM_REGISTER_NODE_TYPE macro. This is
// the attrs type read by NMSRel and instantiated by MakeNMS in this file.
74 | TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); |
75 | |
76 | bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
77 | const TypeReporter& reporter) { |
78 | ICHECK_EQ(types.size(), 6); |
79 | const auto* data = types[0].as<TensorTypeNode>(); |
80 | if (data == nullptr) return false; |
81 | const auto* valid_count = types[1].as<TensorTypeNode>(); |
82 | if (valid_count == nullptr) return false; |
83 | const NonMaximumSuppressionAttrs* param = attrs.as<NonMaximumSuppressionAttrs>(); |
84 | const auto& dshape = data->shape; |
85 | const auto& vshape = valid_count->shape; |
86 | ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D." ; |
87 | ICHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D." ; |
88 | |
89 | // assign output type |
90 | if (param->return_indices) { |
91 | std::vector<Type> fields; |
92 | // dynamic happens for return_indices in TensorFlow & ONNX |
93 | std::vector<IndexExpr> oshape({dshape[0], dshape[1]}); |
94 | fields.push_back(TensorType(oshape, DataType::Int(32))); |
95 | std::vector<IndexExpr> countshape({dshape[0], 1}); |
96 | fields.push_back(TensorType(countshape, DataType::Int(32))); |
97 | reporter->Assign(types[5], TupleType(Array<Type>(fields))); |
98 | } else { |
99 | reporter->Assign(types[5], TensorType(dshape, data->dtype)); |
100 | } |
101 | return true; |
102 | } |
103 | |
104 | Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, Expr iou_threshold, |
105 | bool force_suppress, int top_k, int coord_start, int score_index, int id_index, |
106 | bool return_indices, bool invalid_to_bottom) { |
107 | auto attrs = make_object<NonMaximumSuppressionAttrs>(); |
108 | attrs->force_suppress = force_suppress; |
109 | attrs->top_k = top_k; |
110 | attrs->coord_start = coord_start; |
111 | attrs->score_index = score_index; |
112 | attrs->id_index = id_index; |
113 | attrs->return_indices = return_indices; |
114 | attrs->invalid_to_bottom = invalid_to_bottom; |
115 | static const Op& op = Op::Get("vision.non_max_suppression" ); |
116 | return Call(op, {data, valid_count, indices, max_output_size, iou_threshold}, Attrs(attrs), {}); |
117 | } |
118 | |
// Registers MakeNMS as a global function under the name
// "relay.op.vision._make.non_max_suppression" (presumably the entry point used by the
// language frontend -- confirm against the caller).
119 | TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression" ).set_body_typed(MakeNMS); |
120 | |
// Operator registration for "vision.non_max_suppression": attaches the description text,
// declares five inputs (data, valid_count, indices, max_output_size, iou_threshold),
// sets support level 5 and binds the type relation named "NMS" to NMSRel defined in
// this file.
121 | RELAY_REGISTER_OP("vision.non_max_suppression" ) |
122 | .describe(R"doc(Non-maximum suppression. The input boxes should |
123 | be in the format of [class_id, score, left, top, right, bottom] |
124 | or [score, left, top, right, bottom]. Set id_index to be -1 to |
125 | ignore class_id axis. |
126 | )doc" TVM_ADD_FILELINE) |
127 | .set_num_inputs(5) |
128 | .add_argument("data" , "Tensor" , "Input data." ) |
129 | .add_argument("valid_count" , "Tensor" , "Number of valid anchor boxes." ) |
130 | .add_argument("indices" , "Tensor" , "Corresponding indices in original input tensor." ) |
131 | .add_argument("max_output_size" , "Tensor" , "Max number of output valid boxes." ) |
132 | .add_argument("iou_threshold" , "Tensor" , "Threshold for box overlap." ) |
133 | .set_support_level(5) |
134 | .add_type_rel("NMS" , NMSRel); |
135 | |
// Registers AllClassNonMaximumSuppressionAttrs through the TVM_REGISTER_NODE_TYPE macro.
// This is the attrs type read by AllClassNMSRel and instantiated by MakeAllClassNMS in
// this file.
136 | TVM_REGISTER_NODE_TYPE(AllClassNonMaximumSuppressionAttrs); |
137 | |
138 | bool AllClassNMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
139 | const TypeReporter& reporter) { |
140 | ICHECK_EQ(types.size(), 6); |
141 | const auto* boxes = types[0].as<TensorTypeNode>(); |
142 | if (boxes == nullptr) return false; |
143 | const auto* scores = types[1].as<TensorTypeNode>(); |
144 | if (scores == nullptr) return false; |
145 | |
146 | const auto& boxes_shape = boxes->shape; |
147 | const auto& scores_shape = scores->shape; |
148 | ICHECK_EQ(boxes_shape.size(), 3) << "Input boxes should be 3-D." ; |
149 | ICHECK_EQ(scores_shape.size(), 3) << "Input scores count should be 3-D." ; |
150 | |
151 | IndexExpr batch = boxes_shape[0]; |
152 | IndexExpr num_classes = scores_shape[1]; |
153 | IndexExpr num_boxes = boxes_shape[1]; |
154 | |
155 | const auto* param = attrs.as<AllClassNonMaximumSuppressionAttrs>(); |
156 | CHECK(param); |
157 | |
158 | std::vector<Type> fields; |
159 | if (param->output_format == "onnx" ) { |
160 | IndexExpr num_total_boxes = Any(); |
161 | if (!batch.as<AnyNode>() && !num_boxes.as<AnyNode>()) { |
162 | num_total_boxes = batch * num_classes * num_boxes; |
163 | } |
164 | std::vector<IndexExpr> oshape{num_total_boxes, 3}; |
165 | std::vector<IndexExpr> counts_shape{1}; |
166 | fields.push_back(TensorType(oshape, DataType::Int(64))); |
167 | fields.push_back(TensorType(counts_shape, DataType::Int(64))); |
168 | } else { |
169 | IndexExpr num_total_boxes_per_batch = Any(); |
170 | if (!num_boxes.as<AnyNode>()) { |
171 | num_total_boxes_per_batch = num_classes * num_boxes; |
172 | } |
173 | std::vector<IndexExpr> indices_shape{batch, num_total_boxes_per_batch, 2}; |
174 | std::vector<IndexExpr> scores_shape{batch, num_total_boxes_per_batch}; |
175 | std::vector<IndexExpr> counts_shape{batch}; |
176 | fields.push_back(TensorType(indices_shape, DataType::Int(64))); |
177 | fields.push_back(TensorType(scores_shape, DataType::Float(32))); |
178 | fields.push_back(TensorType(counts_shape, DataType::Int(64))); |
179 | } |
180 | reporter->Assign(types[5], TupleType(Array<Type>(fields))); |
181 | return true; |
182 | } |
183 | |
184 | Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, |
185 | Expr score_threshold, std::string output_format = "onnx" ) { |
186 | auto attrs = make_object<AllClassNonMaximumSuppressionAttrs>(); |
187 | attrs->output_format = std::move(output_format); |
188 | static const Op& op = Op::Get("vision.all_class_non_max_suppression" ); |
189 | return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, |
190 | Attrs(attrs), {}); |
191 | } |
192 | |
// Registers MakeAllClassNMS as a global function under the name
// "relay.op.vision._make.all_class_non_max_suppression" (presumably the entry point used
// by the language frontend -- confirm against the caller).
193 | TVM_REGISTER_GLOBAL("relay.op.vision._make.all_class_non_max_suppression" ) |
194 | .set_body_typed(MakeAllClassNMS); |
195 | |
// Operator registration for "vision.all_class_non_max_suppression": attaches the
// description text, declares five inputs (boxes, scores, max_output_boxes_per_class,
// iou_threshold, score_threshold), sets support level 5 and binds the type relation
// named "AllClassNMS" to AllClassNMSRel defined in this file.
196 | RELAY_REGISTER_OP("vision.all_class_non_max_suppression" ) |
197 | .describe(R"doc(Non-maximum suppression operator for object detection, corresponding to ONNX |
198 | NonMaxSuppression and TensorFlow combined_non_max_suppression. |
199 | NMS is performed for each class separately |
200 | )doc" TVM_ADD_FILELINE) |
201 | .set_num_inputs(5) |
202 | .add_argument("boxes" , "Tensor" , "The input boxes in the format [batch, num_boxes, 4]." ) |
203 | .add_argument("scores" , "Tensor" , |
204 | "Scores for each box and class in the format [batch, num_classes, num_boxes]." ) |
205 | .add_argument("max_output_boxes_per_class" , "Tensor" , |
206 | "The maximum number of output boxes per class." ) |
207 | .add_argument("iou_threshold" , "Tensor" , "The IoU threshold for box the overlap test." ) |
208 | .add_argument("score_threshold" , "Tensor" , |
209 | "The score threshold to filter out low score boxes early." ) |
210 | .set_support_level(5) |
211 | .add_type_rel("AllClassNMS" , AllClassNMSRel); |
212 | |
213 | } // namespace relay |
214 | } // namespace tvm |
215 | |