1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file dilation2d.cc
22 * \brief Morphological dilation operator
23 */
24#include <tvm/relay/attrs/image.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/data_layout.h>
27
28#include "../op_common.h"
29
30namespace tvm {
31namespace relay {
32
// relay.image.dilation2d
// Register the Dilation2DAttrs node type so attrs objects can be created and
// reflected over via the node/FFI system.
TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);
35
36template <typename T>
37InferCorrectLayoutOutput Dilation2DInferCorrectLayout(const Attrs& attrs,
38 const Array<Layout>& new_in_layouts,
39 const Array<Layout>& old_in_layouts,
40 const Array<tvm::relay::Type>& old_in_types) {
41 const T* params = attrs.as<T>();
42 return InferCorrectLayoutOutput({params->data_layout, params->kernel_layout},
43 {params->data_layout}, attrs);
44}
45
46// Positional relay function to create dilation2d operator
47// used by frontend FFI.
48Expr MakeDilation2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
49 Array<IndexExpr> dilations, String data_layout, String kernel_layout,
50 DataType out_dtype) {
51 auto attrs = make_object<Dilation2DAttrs>();
52 attrs->strides = std::move(strides);
53 attrs->padding = std::move(padding);
54 attrs->dilations = std::move(dilations);
55 attrs->data_layout = std::move(data_layout);
56 attrs->kernel_layout = std::move(kernel_layout);
57 attrs->out_dtype = std::move(out_dtype);
58 static const Op& op = Op::Get("image.dilation2d");
59 return Call(op, {data, weight}, Attrs(attrs), {});
60}
61
62template <typename AttrType>
63bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
64 const TypeReporter& reporter) {
65 ICHECK_EQ(types.size(), 3);
66 const auto* data = types[0].as<TensorTypeNode>();
67 const auto* weight = types[1].as<TensorTypeNode>();
68 if (data == nullptr) return false;
69 static const Layout kNCHW("NCHW");
70 static const Layout kOIHW("IHW");
71
72 const AttrType* param = attrs.as<AttrType>();
73 ICHECK(param != nullptr);
74 const Layout in_layout(param->data_layout);
75 const Layout kernel_layout(param->kernel_layout);
76
77 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
78 ICHECK(trans_in_layout.defined())
79 << "Dilation2D only support input layouts that are convertible from NCHW."
80 << " But got " << in_layout;
81
82 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
83 ICHECK(trans_kernel_layout.defined())
84 << "Dilation2D only support kernel layouts that are convertible from OIHW."
85 << " But got " << kernel_layout;
86
87 Layout out_layout(param->data_layout);
88 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
89 ICHECK(trans_out_layout.defined())
90 << "Dilation2D only support output layouts that are convertible from NCHW."
91 << " But got " << out_layout;
92
93 Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
94
95 IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
96
97 // use weight to infer the conv shape.
98 if (weight == nullptr) return false;
99 auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
100 channels = wshape[0];
101
102 dilated_ksize_y = 1 + (wshape[1] - 1) * param->dilations[0];
103 dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilations[1];
104
105 // dilation
106 Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
107 IndexExpr pad_h, pad_w;
108 GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
109 if (!dshape_nchw[2].as<tir::AnyNode>()) {
110 oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
111 } else {
112 oshape.Set(2, dshape_nchw[2]);
113 }
114
115 if (!dshape_nchw[3].as<tir::AnyNode>()) {
116 oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
117 } else {
118 oshape.Set(3, dshape_nchw[3]);
119 }
120
121 DataType out_dtype = param->out_dtype;
122 if (out_dtype.bits() == 0) {
123 out_dtype = data->dtype;
124 }
125 oshape = trans_out_layout.BackwardShape(oshape);
126 // assign output type
127 reporter->Assign(types[2], TensorType(oshape, out_dtype));
128 return true;
129}
130
// Expose the op constructor to the frontend FFI as
// relay.op.image._make.dilation2d.
TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d").set_body_typed(MakeDilation2D);

// Register the image.dilation2d operator: description, arguments, type
// relation, and layout-inference hook.
RELAY_REGISTER_OP("image.dilation2d")
    .describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, height, width)
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Dilation2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   Dilation2DInferCorrectLayout<Dilation2DAttrs>);
149
150} // namespace relay
151} // namespace tvm
152