1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file multibox_op.cc
22 * \brief Multibox related operators
23 */
24#include <tvm/relay/attrs/vision.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/op.h>
27
28namespace tvm {
29namespace relay {
30
// Register the MultiBoxPriorAttrs node type; it is the attrs type of the
// "vision.multibox_prior" operator registered later in this file.
31TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs);
32
33bool MultiboxPriorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
34 const TypeReporter& reporter) {
35 ICHECK_EQ(types.size(), 2);
36 const auto* data = types[0].as<TensorTypeNode>();
37 const MultiBoxPriorAttrs* param = attrs.as<MultiBoxPriorAttrs>();
38 const auto& dshape = data->shape;
39 ICHECK_EQ(dshape.size(), 4) << "Input data should be 4D: "
40 "[batch, channel, height, width]";
41 IndexExpr in_height = dshape[2];
42 IndexExpr in_width = dshape[3];
43 int num_sizes = static_cast<int>(param->sizes.size());
44 int num_ratios = static_cast<int>(param->ratios.size());
45
46 // since input sizes are same in each batch, we could share MultiBoxPrior
47 std::vector<IndexExpr> oshape({1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
48
49 // assign output type
50 reporter->Assign(types[1], TensorType(oshape, data->dtype));
51 return true;
52}
53
54Expr MakeMultiBoxPrior(Expr data, Array<IndexExpr> sizes, Array<IndexExpr> ratios,
55 Array<IndexExpr> steps, Array<IndexExpr> offsets, bool clip) {
56 auto attrs = make_object<MultiBoxPriorAttrs>();
57 attrs->sizes = std::move(sizes);
58 attrs->ratios = std::move(ratios);
59 attrs->steps = std::move(steps);
60 attrs->offsets = std::move(offsets);
61 attrs->clip = clip;
62 static const Op& op = Op::Get("vision.multibox_prior");
63 return Call(op, {data}, Attrs(attrs), {});
64}
65
// Register MakeMultiBoxPrior as the typed body of the global function
// "relay.op.vision._make.multibox_prior".
66TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior").set_body_typed(MakeMultiBoxPrior);
67
// Register the "vision.multibox_prior" operator: one input ("data"),
// MultiBoxPriorAttrs as its attrs type, support level 5, and MultiboxPriorRel
// as its type relation.
// NOTE(review): the raw-string description below contains literal double-quote
// characters around the sentence; they look unintentional, but the text is
// emitted at runtime, so confirm before changing it.
68RELAY_REGISTER_OP("vision.multibox_prior")
69 .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
70)doc" TVM_ADD_FILELINE)
71 .set_attrs_type<MultiBoxPriorAttrs>()
72 .set_num_inputs(1)
73 .add_argument("data", "Tensor", "The input tensor.")
74 .set_support_level(5)
75 .add_type_rel("MultiBoxPrior", MultiboxPriorRel);
76
// Register the MultiBoxTransformLocAttrs node type; it is the attrs type of the
// "vision.multibox_transform_loc" operator registered later in this file.
77TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
78
79bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
80 const TypeReporter& reporter) {
81 ICHECK_EQ(types.size(), 4);
82
83 const auto* cls_prob = types[0].as<TensorTypeNode>();
84 const auto* loc_pred = types[1].as<TensorTypeNode>();
85 const auto* anchor = types[2].as<TensorTypeNode>();
86
87 if (cls_prob == nullptr || loc_pred == nullptr || anchor == nullptr) {
88 return false;
89 }
90
91 const auto& cls_shape = cls_prob->shape;
92 const auto& loc_shape = loc_pred->shape;
93 const auto& anchor_shape = anchor->shape;
94
95 ICHECK_EQ(cls_shape.size(), 3U) << "The dimension of class probability should be 3, but received "
96 << cls_shape.size();
97 ICHECK_EQ(loc_shape.size(), 2U)
98 << "The dimension of location prediction should be 2, but received " << loc_shape.size();
99 ICHECK_EQ(anchor_shape.size(), 3U)
100 << "The dimension of anchor should be 3, but received " << anchor_shape.size();
101
102 ICHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) << "Number of anchors mismatch found";
103 ICHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) << "# anchors mismatch with # loc.";
104 ICHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0.";
105 ICHECK(reporter->AssertEQ(anchor_shape[2], 4));
106
107 std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
108 std::vector<IndexExpr> oshape1({cls_shape[0]});
109 std::vector<Type> fields;
110 fields.push_back(TensorType(oshape0, cls_prob->dtype));
111 fields.push_back(TensorType(oshape1, DataType::Int(32)));
112
113 // assign output type
114 reporter->Assign(types[3], TupleType(Array<Type>(fields)));
115 return true;
116}
117
118Expr MakeMultiBoxTransformLoc(Expr cls_prob, Expr loc_pred, Expr anchor, bool clip,
119 double threshold, Array<IndexExpr> variances) {
120 auto attrs = make_object<MultiBoxTransformLocAttrs>();
121 attrs->clip = std::move(clip);
122 attrs->threshold = std::move(threshold);
123 attrs->variances = std::move(variances);
124 static const Op& op = Op::Get("vision.multibox_transform_loc");
125 return Call(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {});
126}
127
// Register MakeMultiBoxTransformLoc as the typed body of the global function
// "relay.op.vision._make.multibox_transform_loc".
128TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc")
129 .set_body_typed(MakeMultiBoxTransformLoc);
130
// Register the "vision.multibox_transform_loc" operator: three inputs
// ("cls_prob", "loc_pred", "anchor"), MultiBoxTransformLocAttrs as its attrs
// type, MultiBoxTransformLocRel as its type relation, and support level 5.
// NOTE(review): the raw-string description below contains literal double-quote
// characters around the sentence; they look unintentional, but the text is
// emitted at runtime, so confirm before changing it.
131RELAY_REGISTER_OP("vision.multibox_transform_loc")
132 .describe(R"doc("Location transformation for multibox detection."
133)doc" TVM_ADD_FILELINE)
134 .set_attrs_type<MultiBoxTransformLocAttrs>()
135 .set_num_inputs(3)
136 .add_argument("cls_prob", "Tensor", "Class probabilities.")
137 .add_argument("loc_pred", "Tensor", "Location regression predictions.")
138 .add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
139 .add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel)
140 .set_support_level(5);
141
142} // namespace relay
143} // namespace tvm
144