1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file dilation2d.cc |
22 | * \brief Morphological dilation operator |
23 | */ |
24 | #include <tvm/relay/attrs/image.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | |
28 | #include "../op_common.h" |
29 | |
30 | namespace tvm { |
31 | namespace relay { |
32 | |
// relay.image.dilation2d
// Register the attribute node type (Dilation2DAttrs) used by the
// image.dilation2d operator below.
TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);
35 | |
36 | template <typename T> |
37 | InferCorrectLayoutOutput Dilation2DInferCorrectLayout(const Attrs& attrs, |
38 | const Array<Layout>& new_in_layouts, |
39 | const Array<Layout>& old_in_layouts, |
40 | const Array<tvm::relay::Type>& old_in_types) { |
41 | const T* params = attrs.as<T>(); |
42 | return InferCorrectLayoutOutput({params->data_layout, params->kernel_layout}, |
43 | {params->data_layout}, attrs); |
44 | } |
45 | |
46 | // Positional relay function to create dilation2d operator |
47 | // used by frontend FFI. |
48 | Expr MakeDilation2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding, |
49 | Array<IndexExpr> dilations, String data_layout, String kernel_layout, |
50 | DataType out_dtype) { |
51 | auto attrs = make_object<Dilation2DAttrs>(); |
52 | attrs->strides = std::move(strides); |
53 | attrs->padding = std::move(padding); |
54 | attrs->dilations = std::move(dilations); |
55 | attrs->data_layout = std::move(data_layout); |
56 | attrs->kernel_layout = std::move(kernel_layout); |
57 | attrs->out_dtype = std::move(out_dtype); |
58 | static const Op& op = Op::Get("image.dilation2d" ); |
59 | return Call(op, {data, weight}, Attrs(attrs), {}); |
60 | } |
61 | |
62 | template <typename AttrType> |
63 | bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
64 | const TypeReporter& reporter) { |
65 | ICHECK_EQ(types.size(), 3); |
66 | const auto* data = types[0].as<TensorTypeNode>(); |
67 | const auto* weight = types[1].as<TensorTypeNode>(); |
68 | if (data == nullptr) return false; |
69 | static const Layout kNCHW("NCHW" ); |
70 | static const Layout kOIHW("IHW" ); |
71 | |
72 | const AttrType* param = attrs.as<AttrType>(); |
73 | ICHECK(param != nullptr); |
74 | const Layout in_layout(param->data_layout); |
75 | const Layout kernel_layout(param->kernel_layout); |
76 | |
77 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); |
78 | ICHECK(trans_in_layout.defined()) |
79 | << "Dilation2D only support input layouts that are convertible from NCHW." |
80 | << " But got " << in_layout; |
81 | |
82 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); |
83 | ICHECK(trans_kernel_layout.defined()) |
84 | << "Dilation2D only support kernel layouts that are convertible from OIHW." |
85 | << " But got " << kernel_layout; |
86 | |
87 | Layout out_layout(param->data_layout); |
88 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); |
89 | ICHECK(trans_out_layout.defined()) |
90 | << "Dilation2D only support output layouts that are convertible from NCHW." |
91 | << " But got " << out_layout; |
92 | |
93 | Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape); |
94 | |
95 | IndexExpr channels, dilated_ksize_y, dilated_ksize_x; |
96 | |
97 | // use weight to infer the conv shape. |
98 | if (weight == nullptr) return false; |
99 | auto wshape = trans_kernel_layout.ForwardShape(weight->shape); |
100 | channels = wshape[0]; |
101 | |
102 | dilated_ksize_y = 1 + (wshape[1] - 1) * param->dilations[0]; |
103 | dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilations[1]; |
104 | |
105 | // dilation |
106 | Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); |
107 | IndexExpr pad_h, pad_w; |
108 | GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); |
109 | if (!dshape_nchw[2].as<tir::AnyNode>()) { |
110 | oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); |
111 | } else { |
112 | oshape.Set(2, dshape_nchw[2]); |
113 | } |
114 | |
115 | if (!dshape_nchw[3].as<tir::AnyNode>()) { |
116 | oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); |
117 | } else { |
118 | oshape.Set(3, dshape_nchw[3]); |
119 | } |
120 | |
121 | DataType out_dtype = param->out_dtype; |
122 | if (out_dtype.bits() == 0) { |
123 | out_dtype = data->dtype; |
124 | } |
125 | oshape = trans_out_layout.BackwardShape(oshape); |
126 | // assign output type |
127 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
128 | return true; |
129 | } |
130 | |
// Expose the positional constructor MakeDilation2D to the frontend FFI under
// the global name "relay.op.image._make.dilation2d".
TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d").set_body_typed(MakeDilation2D);
132 | |
// Register the image.dilation2d operator.
// It takes two inputs (data, weight) with attributes of type Dilation2DAttrs,
// infers its output type with Dilation2DRel, and infers layouts with
// Dilation2DInferCorrectLayout.
RELAY_REGISTER_OP("image.dilation2d")
    .describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, height, width)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Dilation2DAttrs>()
    .set_num_inputs(2)
    // Argument order must match the (data, weight) order used by MakeDilation2D.
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   Dilation2DInferCorrectLayout<Dilation2DAttrs>);
149 | |
150 | } // namespace relay |
151 | } // namespace tvm |
152 | |