/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file pad.cc
 * \brief Implementation of operator pad
 */
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>
#include <tvm/topi/elemwise.h>
#include <tvm/topi/nn.h>

#include <vector>

#include "../make_op.h"
#include "../op_common.h"

namespace tvm {
namespace relay {

// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);

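// Layout inference for nn.pad. When the input layout changes (e.g. NCHW ->
// NCHW16c during AlterOpLayout), remap pad_width from the old layout's axes to
// the new layout's axes. The new layout is accepted only if every split
// (sub-dimension) axis carries a [0, 0] pad: padding C after it has been split
// into C and 16c cannot be expressed on the packed layout, so in that case we
// fall back to the old layout.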
InferCorrectLayoutOutput PadInferCorrectLayout(const Attrs& attrs,
                                               const Array<Layout>& new_in_layouts,
                                               const Array<Layout>& old_in_layouts,
                                               const Array<tvm::relay::Type>& old_in_types) {
  const auto* attrs_ptr = attrs.as<PadAttrs>();
  ICHECK(attrs_ptr);
  ObjectPtr<PadAttrs> params = make_object<PadAttrs>(*attrs_ptr);

  Layout ret_data;
  // If new_in_layouts is defined, try to adapt pad_width to the new layout.
  bool is_layout_modified = new_in_layouts.defined();
  if (new_in_layouts.defined()) {
    // Create a map from axis to pad width, then generate a new pad_width for the
    // new layout using the map. The new layout is rejected if padding occurs
    // along an axis that was split.

    // 1) Create a map from axis to pad width using the old layout.
    std::map<std::string, tvm::Array<Integer>> axis_pad_width;
    int index_counter = 0;
    ICHECK_EQ(new_in_layouts.size(), 2);
    ICHECK_EQ(old_in_layouts.size(), 2);
    for (auto iter_var : old_in_layouts[0]->axes) {
      const auto& old_layout_axis = LayoutAxis::Get(iter_var);
      axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
      index_counter++;
    }

    // 2) Create the new pad width by walking over the new layout and using the map.
    tvm::Array<tvm::Array<Integer>> new_pad_width;
    for (auto iter_var : new_in_layouts[0]->axes) {
      const auto& new_layout_axis = LayoutAxis::Get(iter_var);
      auto axis_name = new_layout_axis.name();
      if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
        // This is a primal axis, so directly reuse the original pad_width.
        new_pad_width.push_back(axis_pad_width.at(axis_name));
      } else {
        // This axis was split. Check that its pad_width was [0, 0] originally.
        const auto& dual_axis = new_layout_axis.ToPrimal();
        auto dual_axis_name = dual_axis.name();
        ICHECK(axis_pad_width.count(dual_axis_name))
            << "Missing axis " << dual_axis << " in " << old_in_layouts[0].name();
        new_pad_width.push_back(axis_pad_width.at(dual_axis_name));

        // If any pad_width element is not a constant zero, do not change the layout.
        for (auto width : axis_pad_width.at(dual_axis_name)) {
          if (auto* width_imm = width.as<IntImmNode>()) {
            if (width_imm->value != 0) {
              is_layout_modified = false;
            }
          } else {
            is_layout_modified = false;
          }
        }
      }
    }

    // If the above conditions are satisfied, set the newly created pad_width and
    // use the new layout.
    if (is_layout_modified) {
      ret_data = new_in_layouts[0];
      params->pad_width = new_pad_width;
    }
  }

  if (!is_layout_modified) {
    if (old_in_layouts.defined()) {
      ICHECK_EQ(old_in_layouts.size(), 2);
      ret_data = old_in_layouts[0];
    } else {
      ret_data = Layout::Undef();
    }
  }

  // The pad value is always a scalar.
  Layout ret_pad_value = Layout("1");
  return InferCorrectLayoutOutput({ret_data, ret_pad_value}, {ret_data}, Attrs(params));
}

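// Type relation for nn.pad: checks that pad_width has one (before, after) pair
// of constant integers per input dimension and infers the output shape as
// shape[i] + before[i] + after[i] along each static axis; dynamic `Any` axes
// are passed through unchanged.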
bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
            const TypeReporter& reporter) {
  // types = [pad_data_type, pad_value_type, ret_type]
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const PadAttrs* param = attrs.as<PadAttrs>();
  ICHECK(param != nullptr);

  // Check that the number of pad width pairs matches the number of dimensions.
  ICHECK(data->shape.size() == param->pad_width.size())
      << "There should be as many pad width pairs as shape dimensions "
      << "but the shape has " << data->shape.size() << " dimensions "
      << "and there are " << param->pad_width.size() << " pad width pairs.";

  // Each pad width element should be a pair of constant integers.
  std::vector<IndexExpr> oshape;
  for (size_t i = 0; i < param->pad_width.size(); i++) {
    ICHECK(param->pad_width[i].size() == 2)
        << "Each pad width element should be a pair but at index " << i << " there are "
        << param->pad_width[i].size() << " elements.";

    auto width1 = tir::as_const_int(param->pad_width[i][0]);
    auto width2 = tir::as_const_int(param->pad_width[i][1]);
    ICHECK(width1 != nullptr);
    ICHECK(width2 != nullptr);

    if (!data->shape[i].as<tir::AnyNode>()) {
      auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
      oshape.push_back(data->shape[i] + padding);
      if (tir::as_const_int(data->shape[i])) {
        ICHECK(topi::detail::GetConstInt(data->shape[i] + padding) >= 0)
            << "Output shape post padding should be non-negative but got "
            << data->shape[i] + padding;
      }
    } else {
      oshape.push_back(data->shape[i]);
    }
  }

  reporter->Assign(types[2], TensorType(Array<IndexExpr>(oshape), data->dtype));
  return true;
}

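// Compute lowering for nn.pad: splits pad_width into pad_before/pad_after
// arrays, casts the scalar pad value tensor to the input dtype, and emits a
// topi::pad call with the requested pad_mode.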
Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  const auto* param = attrs.as<PadAttrs>();
  ICHECK(param != nullptr);

  auto pad_width = param->pad_width;
  ICHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width";
  Array<IndexExpr> pad_before;
  for (size_t i = 0; i < pad_width.size(); ++i) {
    pad_before.push_back(pad_width[i][0]);
  }
  Array<IndexExpr> pad_after;
  for (size_t i = 0; i < pad_width.size(); ++i) {
    pad_after.push_back(pad_width[i][1]);
  }
  te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype);
  const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>());
  return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad",
                                     topi::kElementWise, param->pad_mode)};
}

// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode) {
  auto attrs = make_object<PadAttrs>();
  attrs->pad_width = std::move(pad_width);
  attrs->pad_mode = std::move(pad_mode);
  static const Op& op = Op::Get("nn.pad");
  return Call(op, {data, pad_value}, Attrs(attrs), {});
}

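// FFI entry point; in Python this is normally reached through relay.nn.pad.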
TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad);

RELAY_REGISTER_OP("nn.pad")
    .describe(R"code(Pad for n-D tensor.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<PadAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("pad_val", "Tensor", "The value to fill the padded area with.")
    .set_support_level(2)
    .add_type_rel("Pad", PadRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<FTVMCompute>("FTVMCompute", PadCompute);

// relay.nn.mirror_pad
TVM_REGISTER_NODE_TYPE(MirrorPadAttrs);

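// Type relation for nn.mirror_pad: same shape arithmetic as PadRel, but all
// axes must be static and every pad width must be a non-negative constant.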
bool MirrorPadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const MirrorPadAttrs* param = attrs.as<MirrorPadAttrs>();
  ICHECK(param != nullptr);

  // Check that the number of pad width pairs matches the number of dimensions.
  ICHECK(data->shape.size() == param->pad_width.size())
      << "There should be as many pad width pairs as shape dimensions "
      << "but the shape has " << data->shape.size() << " dimensions "
      << "and there are " << param->pad_width.size() << " pad width pairs.";

  // Each pad width element should be a pair of non-negative integers.
  std::vector<IndexExpr> oshape;
  for (size_t i = 0; i < param->pad_width.size(); i++) {
    ICHECK(param->pad_width[i].size() == 2)
        << "Each pad width element should be a pair but at index " << i << " there are "
        << param->pad_width[i].size() << " elements.";

    auto width1 = tir::as_const_int(param->pad_width[i][0]);
    auto width2 = tir::as_const_int(param->pad_width[i][1]);
    ICHECK(width1 != nullptr);
    ICHECK(width2 != nullptr);

    ICHECK(*width1 >= 0) << "Pad width elements should be non-negative but the first pad width at "
                         << "index " << i << " is " << *width1 << ".";
    ICHECK(*width2 >= 0) << "Pad width elements should be non-negative but the second pad width at "
                         << "index " << i << " is " << *width2 << ".";

    auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
    oshape.push_back(data->shape[i] + padding);
  }

  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
  return true;
}

// Handler to create a call to the padding op used by front-end FFI
Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr>> pad_width, String mode) {
  auto attrs = make_object<MirrorPadAttrs>();
  attrs->mode = mode;
  attrs->pad_width = std::move(pad_width);
  static const Op& op = Op::Get("nn.mirror_pad");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad);

RELAY_REGISTER_OP("nn.mirror_pad")
    .describe(R"code(MirrorPad for n-D tensor.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<MirrorPadAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(2)
    .add_type_rel("MirrorPad", MirrorPadRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

}  // namespace relay
}  // namespace tvm