1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file multibox_op.cc |
22 | * \brief Multibox related operators |
23 | */ |
24 | #include <tvm/relay/attrs/vision.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | namespace tvm { |
29 | namespace relay { |
30 | |
// Registers the MultiBoxPriorAttrs node type; this is the attribute type that
// the "vision.multibox_prior" operator below declares via set_attrs_type.
TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs);
32 | |
33 | bool MultiboxPriorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
34 | const TypeReporter& reporter) { |
35 | ICHECK_EQ(types.size(), 2); |
36 | const auto* data = types[0].as<TensorTypeNode>(); |
37 | const MultiBoxPriorAttrs* param = attrs.as<MultiBoxPriorAttrs>(); |
38 | const auto& dshape = data->shape; |
39 | ICHECK_EQ(dshape.size(), 4) << "Input data should be 4D: " |
40 | "[batch, channel, height, width]" ; |
41 | IndexExpr in_height = dshape[2]; |
42 | IndexExpr in_width = dshape[3]; |
43 | int num_sizes = static_cast<int>(param->sizes.size()); |
44 | int num_ratios = static_cast<int>(param->ratios.size()); |
45 | |
46 | // since input sizes are same in each batch, we could share MultiBoxPrior |
47 | std::vector<IndexExpr> oshape({1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); |
48 | |
49 | // assign output type |
50 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
51 | return true; |
52 | } |
53 | |
54 | Expr MakeMultiBoxPrior(Expr data, Array<IndexExpr> sizes, Array<IndexExpr> ratios, |
55 | Array<IndexExpr> steps, Array<IndexExpr> offsets, bool clip) { |
56 | auto attrs = make_object<MultiBoxPriorAttrs>(); |
57 | attrs->sizes = std::move(sizes); |
58 | attrs->ratios = std::move(ratios); |
59 | attrs->steps = std::move(steps); |
60 | attrs->offsets = std::move(offsets); |
61 | attrs->clip = clip; |
62 | static const Op& op = Op::Get("vision.multibox_prior" ); |
63 | return Call(op, {data}, Attrs(attrs), {}); |
64 | } |
65 | |
// Exposes MakeMultiBoxPrior as the global function
// "relay.op.vision._make.multibox_prior".
TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior" ).set_body_typed(MakeMultiBoxPrior);
67 | |
// Registers the "vision.multibox_prior" operator: one input ("data"),
// attributes of type MultiBoxPriorAttrs, support level 5, and MultiboxPriorRel
// as its type relation (registered under the name "MultiBoxPrior").
// Note: the raw string literal below includes the double quotes and the
// trailing newline as part of the description text.
RELAY_REGISTER_OP("vision.multibox_prior" )
    .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE)
    .set_attrs_type<MultiBoxPriorAttrs>()
    .set_num_inputs(1)
    .add_argument("data" , "Tensor" , "The input tensor." )
    .set_support_level(5)
    .add_type_rel("MultiBoxPrior" , MultiboxPriorRel);
76 | |
// Registers the MultiBoxTransformLocAttrs node type; this is the attribute type
// that the "vision.multibox_transform_loc" operator declares via set_attrs_type.
TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
78 | |
79 | bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
80 | const TypeReporter& reporter) { |
81 | ICHECK_EQ(types.size(), 4); |
82 | |
83 | const auto* cls_prob = types[0].as<TensorTypeNode>(); |
84 | const auto* loc_pred = types[1].as<TensorTypeNode>(); |
85 | const auto* anchor = types[2].as<TensorTypeNode>(); |
86 | |
87 | if (cls_prob == nullptr || loc_pred == nullptr || anchor == nullptr) { |
88 | return false; |
89 | } |
90 | |
91 | const auto& cls_shape = cls_prob->shape; |
92 | const auto& loc_shape = loc_pred->shape; |
93 | const auto& anchor_shape = anchor->shape; |
94 | |
95 | ICHECK_EQ(cls_shape.size(), 3U) << "The dimension of class probability should be 3, but received " |
96 | << cls_shape.size(); |
97 | ICHECK_EQ(loc_shape.size(), 2U) |
98 | << "The dimension of location prediction should be 2, but received " << loc_shape.size(); |
99 | ICHECK_EQ(anchor_shape.size(), 3U) |
100 | << "The dimension of anchor should be 3, but received " << anchor_shape.size(); |
101 | |
102 | ICHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) << "Number of anchors mismatch found" ; |
103 | ICHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) << "# anchors mismatch with # loc." ; |
104 | ICHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0." ; |
105 | ICHECK(reporter->AssertEQ(anchor_shape[2], 4)); |
106 | |
107 | std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6}); |
108 | std::vector<IndexExpr> oshape1({cls_shape[0]}); |
109 | std::vector<Type> fields; |
110 | fields.push_back(TensorType(oshape0, cls_prob->dtype)); |
111 | fields.push_back(TensorType(oshape1, DataType::Int(32))); |
112 | |
113 | // assign output type |
114 | reporter->Assign(types[3], TupleType(Array<Type>(fields))); |
115 | return true; |
116 | } |
117 | |
118 | Expr MakeMultiBoxTransformLoc(Expr cls_prob, Expr loc_pred, Expr anchor, bool clip, |
119 | double threshold, Array<IndexExpr> variances) { |
120 | auto attrs = make_object<MultiBoxTransformLocAttrs>(); |
121 | attrs->clip = std::move(clip); |
122 | attrs->threshold = std::move(threshold); |
123 | attrs->variances = std::move(variances); |
124 | static const Op& op = Op::Get("vision.multibox_transform_loc" ); |
125 | return Call(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {}); |
126 | } |
127 | |
// Exposes MakeMultiBoxTransformLoc as the global function
// "relay.op.vision._make.multibox_transform_loc".
TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc" )
    .set_body_typed(MakeMultiBoxTransformLoc);
130 | |
// Registers the "vision.multibox_transform_loc" operator: three inputs
// ("cls_prob", "loc_pred", "anchor"), attributes of type
// MultiBoxTransformLocAttrs, support level 5, and MultiBoxTransformLocRel as
// its type relation (registered under the name "MultiBoxTransformLoc").
// Note: the raw string literal below includes the double quotes and the
// trailing newline as part of the description text.
RELAY_REGISTER_OP("vision.multibox_transform_loc" )
    .describe(R"doc("Location transformation for multibox detection."
)doc" TVM_ADD_FILELINE)
    .set_attrs_type<MultiBoxTransformLocAttrs>()
    .set_num_inputs(3)
    .add_argument("cls_prob" , "Tensor" , "Class probabilities." )
    .add_argument("loc_pred" , "Tensor" , "Location regression predictions." )
    .add_argument("anchor" , "Tensor" , "Multibox prior anchor boxes" )
    .add_type_rel("MultiBoxTransformLoc" , MultiBoxTransformLocRel)
    .set_support_level(5);
141 | |
142 | } // namespace relay |
143 | } // namespace tvm |
144 | |