1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file yolo.cc |
22 | * \brief Yolo related operators |
23 | */ |
24 | #include <tvm/relay/attrs/vision.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/topi/vision/reorg.h> |
27 | |
28 | #include <vector> |
29 | |
30 | #include "../op_common.h" |
31 | #include "../type_relations.h" |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
36 | TVM_REGISTER_NODE_TYPE(YoloReorgAttrs); |
37 | |
38 | /*! |
39 | * \brief YoloReorgRel Output type and shape relation evaluation function. |
40 | * \param num_inputs Number of input types in the args. |
41 | * \param attrs The additional attributes of the operator. |
42 | * \param reporter The reporter to report solution to. |
43 | * \return false if This relation cannot be resolved. true if this relation has been resolved. |
44 | */ |
45 | bool YoloReorgRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
46 | const TypeReporter& reporter) { |
47 | ICHECK_EQ(types.size(), 2); |
48 | const auto* data = types[0].as<TensorTypeNode>(); |
49 | if (data == nullptr) return false; |
50 | |
51 | const YoloReorgAttrs* param = attrs.as<YoloReorgAttrs>(); |
52 | ICHECK(param != nullptr); |
53 | |
54 | ICHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension." ; |
55 | std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end()); |
56 | oshape[1] = oshape[1] * param->stride * param->stride; |
57 | oshape[2] = indexdiv(oshape[2], param->stride); |
58 | oshape[3] = indexdiv(oshape[3], param->stride); |
59 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
60 | return true; |
61 | } |
62 | |
63 | Expr MakeYoloReorg(Expr data, Integer stride) { |
64 | auto attrs = make_object<YoloReorgAttrs>(); |
65 | attrs->stride = stride; |
66 | static const Op& op = Op::Get("vision.yolo_reorg" ); |
67 | return Call(op, {data}, Attrs(attrs), {}); |
68 | } |
69 | |
70 | TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg" ).set_body_typed(MakeYoloReorg); |
71 | |
72 | RELAY_REGISTER_OP("vision.yolo_reorg" ) |
73 | .describe(R"doc("Yolo reorg operation. This layer reorganize the output. |
74 | Its function is mostly shape transform.")doc" TVM_ADD_FILELINE) |
75 | .add_argument("data" , "Tensor" , "The input tensor." ) |
76 | .set_num_inputs(1) |
77 | .set_support_level(5) |
78 | .set_attrs_type<YoloReorgAttrs>() |
79 | .add_type_rel("YoloReorg" , YoloReorgRel) |
80 | .set_attr<FTVMCompute>("FTVMCompute" , [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
81 | const Type& out_type) { |
82 | const auto* params = attrs.as<YoloReorgAttrs>(); |
83 | ICHECK(params != nullptr); |
84 | return Array<te::Tensor>{topi::vision::reorg(inputs[0], params->stride.IntValue())}; |
85 | }); |
86 | |
87 | } // namespace relay |
88 | } // namespace tvm |
89 | |