1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file pad.cc |
 * \brief Implementation of the nn.pad and nn.mirror_pad operators.
23 | */ |
24 | #include <tvm/relay/attrs/nn.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/topi/elemwise.h> |
29 | #include <tvm/topi/nn.h> |
30 | |
31 | #include <vector> |
32 | |
33 | #include "../make_op.h" |
34 | #include "../op_common.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
// relay.nn.pad
// Register the PadAttrs node type; PadAttrs carries the pad_width and pad_mode attributes that
// MakePad attaches to nn.pad calls.
TVM_REGISTER_NODE_TYPE(PadAttrs);
41 | |
42 | InferCorrectLayoutOutput PadInferCorrectLayout(const Attrs& attrs, |
43 | const Array<Layout>& new_in_layouts, |
44 | const Array<Layout>& old_in_layouts, |
45 | const Array<tvm::relay::Type>& old_in_types) { |
46 | const auto* attrs_ptr = attrs.as<PadAttrs>(); |
47 | CHECK(attrs_ptr); |
48 | ObjectPtr<PadAttrs> params = make_object<PadAttrs>(*attrs_ptr); |
49 | |
50 | Layout ret_data; |
51 | // If new_in_layouts are defined, this code tries to modify the layout. |
52 | bool is_layout_modified = new_in_layouts.defined(); |
53 | if (new_in_layouts.defined()) { |
54 | // Create a map of axis to param_width. For the new layout, a new param_width is generated using |
55 | // the map. The new layout is rejected, if the padding is happening along the axis which was |
56 | // split. |
57 | |
58 | // 1) Create a map from axis to param_width using old layout. |
59 | std::map<std::string, tvm::Array<Integer>> axis_pad_width; |
60 | int index_counter = 0; |
61 | ICHECK_EQ(new_in_layouts.size(), 2); |
62 | ICHECK_EQ(old_in_layouts.size(), 2); |
63 | for (auto iter_var : old_in_layouts[0]->axes) { |
64 | const auto& old_layout_axis = LayoutAxis::Get(iter_var); |
65 | axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]); |
66 | index_counter++; |
67 | } |
68 | |
69 | // 2) Create new pad width by walking over the new layout and using the map. |
70 | tvm::Array<tvm::Array<Integer>> new_pad_width; |
71 | for (auto iter_var : new_in_layouts[0]->axes) { |
72 | const auto& new_layout_axis = LayoutAxis::Get(iter_var); |
73 | auto axis_name = new_layout_axis.name(); |
74 | if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) { |
75 | // This is primal axis. So, directly use the original pad_width. |
76 | new_pad_width.push_back(axis_pad_width.at(axis_name)); |
77 | } else { |
78 | // This is the axis that got split. So, check that pad_width was [0, 0] originally. |
79 | const auto& dual_axis = new_layout_axis.ToPrimal(); |
80 | auto dual_axis_name = dual_axis.name(); |
81 | ICHECK(axis_pad_width.count(dual_axis_name)) |
82 | << "Missing axis " << dual_axis << " in " << old_in_layouts[0].name(); |
83 | new_pad_width.push_back(axis_pad_width.at(dual_axis_name)); |
84 | |
85 | // If any pad_width element is not zero, do not change the layout. |
86 | for (auto width : axis_pad_width.at(dual_axis_name)) { |
87 | if (auto* width_imm = width.as<IntImmNode>()) { |
88 | if (width_imm->value != 0) { |
89 | is_layout_modified = false; |
90 | } |
91 | } else { |
92 | is_layout_modified = false; |
93 | } |
94 | } |
95 | } |
96 | } |
97 | |
98 | // If the above conditions satisfied, we can set the newly created pad_width and use the new |
99 | // layout. |
100 | if (is_layout_modified) { |
101 | ret_data = new_in_layouts[0]; |
102 | params->pad_width = new_pad_width; |
103 | } |
104 | } |
105 | |
106 | if (!is_layout_modified) { |
107 | if (old_in_layouts.defined()) { |
108 | ICHECK_EQ(old_in_layouts.size(), 2); |
109 | ret_data = old_in_layouts[0]; |
110 | } else { |
111 | ret_data = Layout::Undef(); |
112 | } |
113 | } |
114 | |
115 | // The pad value is always a scalar |
116 | Layout ret_pad_value = Layout("1" ); |
117 | return InferCorrectLayoutOutput({ret_data, ret_pad_value}, {ret_data}, Attrs(params)); |
118 | } |
119 | |
120 | bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
121 | const TypeReporter& reporter) { |
122 | // types = [pad_data_type, pad_value_type, ret_type] |
123 | ICHECK_EQ(types.size(), 3); |
124 | const auto* data = types[0].as<TensorTypeNode>(); |
125 | if (data == nullptr) return false; |
126 | |
127 | const PadAttrs* param = attrs.as<PadAttrs>(); |
128 | ICHECK(param != nullptr); |
129 | |
130 | // check that pad widths match lengths |
131 | ICHECK(data->shape.size() == param->pad_width.size()) |
132 | << "There should be as many pad width pairs as shape dimensions " |
133 | << "but the shape has " << data->shape.size() << " dimensions " |
134 | << "and there are " << param->pad_width.size() << " pad width pairs." ; |
135 | |
136 | // each pad width element should be a pair of positive integers |
137 | std::vector<IndexExpr> oshape; |
138 | for (size_t i = 0; i < param->pad_width.size(); i++) { |
139 | ICHECK(param->pad_width[i].size() == 2) |
140 | << "Each pad width element should be a pair but at index " << i << " there are " |
141 | << param->pad_width[i].size() << " elements." ; |
142 | |
143 | auto width1 = tir::as_const_int(param->pad_width[i][0]); |
144 | auto width2 = tir::as_const_int(param->pad_width[i][1]); |
145 | ICHECK(width1 != nullptr); |
146 | ICHECK(width2 != nullptr); |
147 | |
148 | if (!data->shape[i].as<tir::AnyNode>()) { |
149 | auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); |
150 | oshape.push_back(data->shape[i] + padding); |
151 | if (tir::as_const_int(data->shape[i])) { |
152 | ICHECK(topi::detail::GetConstInt(data->shape[i] + padding) >= 0) |
153 | << "Output shape post padding should be positive but got " << data->shape[i] + padding; |
154 | } |
155 | } else { |
156 | oshape.push_back(data->shape[i]); |
157 | } |
158 | } |
159 | |
160 | reporter->Assign(types[2], TensorType(Array<IndexExpr>(oshape), data->dtype)); |
161 | return true; |
162 | } |
163 | |
164 | Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
165 | const Type& out_type) { |
166 | const auto* param = attrs.as<PadAttrs>(); |
167 | ICHECK(param != nullptr); |
168 | |
169 | auto pad_width = param->pad_width; |
170 | ICHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width" ; |
171 | Array<IndexExpr> pad_before; |
172 | for (size_t i = 0; i < pad_width.size(); ++i) { |
173 | pad_before.push_back(pad_width[i][0]); |
174 | } |
175 | Array<IndexExpr> pad_after; |
176 | for (size_t i = 0; i < pad_width.size(); ++i) { |
177 | pad_after.push_back(pad_width[i][1]); |
178 | } |
179 | te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype); |
180 | const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>()); |
181 | return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad" , |
182 | topi::kElementWise, param->pad_mode)}; |
183 | } |
184 | |
185 | // Handler to create a call to the padding op used by front-end FFI |
186 | Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode) { |
187 | auto attrs = make_object<PadAttrs>(); |
188 | attrs->pad_width = std::move(pad_width); |
189 | attrs->pad_mode = std::move(pad_mode); |
190 | static const Op& op = Op::Get("nn.pad" ); |
191 | return Call(op, {data, pad_value}, Attrs(attrs), {}); |
192 | } |
193 | |
// Expose MakePad through the global function registry as "relay.op.nn._make.pad".
TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad);

// Register the nn.pad operator: two inputs (data and the pad value), PadAttrs attributes, PadRel as
// its type relation, PadInferCorrectLayout for layout inference and PadCompute as its compute.
RELAY_REGISTER_OP("nn.pad")
    .describe(R"code(Pad for n-D tensor.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<PadAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("pad_val", "Tensor", "The value to fill the padded area with")
    .set_support_level(2)
    .add_type_rel("Pad", PadRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
209 | |
// relay.nn.mirror_pad
// Register the MirrorPadAttrs node type; MirrorPadAttrs carries the pad_width and mode attributes
// that MakeMirrorPad attaches to nn.mirror_pad calls.
TVM_REGISTER_NODE_TYPE(MirrorPadAttrs);
212 | |
213 | bool MirrorPadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
214 | const TypeReporter& reporter) { |
215 | ICHECK_EQ(types.size(), 2); |
216 | const auto* data = types[0].as<TensorTypeNode>(); |
217 | if (data == nullptr) return false; |
218 | |
219 | const MirrorPadAttrs* param = attrs.as<MirrorPadAttrs>(); |
220 | ICHECK(param != nullptr); |
221 | |
222 | // check that pad widths match lengths |
223 | ICHECK(data->shape.size() == param->pad_width.size()) |
224 | << "There should be as many pad width pairs as shape dimensions " |
225 | << "but the shape has " << data->shape.size() << " dimensions " |
226 | << "and there are " << param->pad_width.size() << " pad width pairs." ; |
227 | |
228 | // each pad width element should be a pair of positive integers |
229 | std::vector<IndexExpr> oshape; |
230 | for (size_t i = 0; i < param->pad_width.size(); i++) { |
231 | ICHECK(param->pad_width[i].size() == 2) |
232 | << "Each pad width element should be a pair but at index " << i << " there are " |
233 | << param->pad_width[i].size() << " elements." ; |
234 | |
235 | auto width1 = tir::as_const_int(param->pad_width[i][0]); |
236 | auto width2 = tir::as_const_int(param->pad_width[i][1]); |
237 | ICHECK(width1 != nullptr); |
238 | ICHECK(width2 != nullptr); |
239 | |
240 | ICHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " |
241 | << "index " << i << " is " << *width1 << "." ; |
242 | ICHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " |
243 | << "index " << i << " is " << *width2 << "." ; |
244 | |
245 | auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); |
246 | oshape.push_back(data->shape[i] + padding); |
247 | } |
248 | |
249 | reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype)); |
250 | return true; |
251 | } |
252 | |
253 | // Handler to create a call to the padding op used by front-end FFI |
254 | Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr>> pad_width, String mode) { |
255 | auto attrs = make_object<MirrorPadAttrs>(); |
256 | attrs->mode = mode; |
257 | attrs->pad_width = std::move(pad_width); |
258 | static const Op& op = Op::Get("nn.mirror_pad" ); |
259 | return Call(op, {data}, Attrs(attrs), {}); |
260 | } |
261 | |
// Expose MakeMirrorPad through the global function registry as "relay.op.nn._make.mirror_pad".
TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad);

// Register the nn.mirror_pad operator: one input (data), MirrorPadAttrs attributes and MirrorPadRel
// as its type relation. No FTVMCompute is attached here.
// NOTE(review): the compute is presumably registered elsewhere (e.g. from Python) — confirm.
RELAY_REGISTER_OP("nn.mirror_pad")
    .describe(R"code(MirrorPad for n-D tensor.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<MirrorPadAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(2)
    .add_type_rel("MirrorPad", MirrorPadRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);
274 | |
275 | } // namespace relay |
276 | } // namespace tvm |
277 | |