1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/fold_explicit_padding.cc |
22 | * \brief A pass for folding explicit pads into other ops. |
23 | */ |
24 | |
25 | #include <tvm/relay/dataflow_matcher.h> |
26 | #include <tvm/relay/expr.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/transform.h> |
29 | #include <tvm/runtime/data_type.h> |
30 | #include <tvm/runtime/logging.h> |
31 | #include <tvm/tir/op.h> |
32 | #include <tvm/topi/nn/pooling.h> |
33 | |
34 | #include <optional> |
35 | #include <set> |
36 | #include <string> |
37 | |
38 | #include "../op/tensor/transform.h" |
39 | #include "pattern_utils.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | |
44 | /*! |
45 | * \brief SimplifyExplicitPad matches a pad followed by a conv/maxpool/avgpool |
46 | * with a pad attribute and merges the padding into the kernel. |
47 | */ |
48 | class SimplifyExplicitPad { |
49 | public: |
50 | DFPattern pattern() const { return pattern_; } |
51 | |
52 | SimplifyExplicitPad() { |
53 | x_ = IsWildcard(); |
54 | pad_ = IsOp("nn.pad" )({x_, IsWildcard()}); |
55 | |
56 | // pad->conv patterns |
57 | w_ = IsWildcard(); |
58 | conv1d_ = IsOp("nn.conv1d" ); |
59 | conv2d_ = IsOp("nn.conv2d" ); |
60 | conv3d_ = IsOp("nn.conv3d" ); |
61 | contrib_conv2d_nchwc_ = IsOp("nn.contrib_conv2d_NCHWc" ); |
62 | conv_ = (conv1d_ || conv2d_ || conv3d_ || contrib_conv2d_nchwc_)({pad_, w_}); |
63 | |
64 | input_zero_point_ = IsWildcard(); |
65 | kernel_zero_point_ = IsWildcard(); |
66 | input_scale_ = IsWildcard(); |
67 | kernel_scale_ = IsWildcard(); |
68 | qconv2d_ = IsOp("qnn.conv2d" )( |
69 | {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_}); |
70 | |
71 | // pad->pool patterns |
72 | avg_pool1d_ = IsOp("nn.avg_pool1d" ); |
73 | avg_pool2d_ = IsOp("nn.avg_pool2d" ); |
74 | avg_pool3d_ = IsOp("nn.avg_pool3d" ); |
75 | max_pool1d_ = IsOp("nn.max_pool1d" ); |
76 | max_pool2d_ = IsOp("nn.max_pool2d" ); |
77 | max_pool3d_ = IsOp("nn.max_pool3d" ); |
78 | max_pool_ = max_pool1d_ || max_pool2d_ || max_pool3d_; |
79 | pool_ = (max_pool_ || avg_pool1d_ || avg_pool2d_ || avg_pool3d_)({pad_}); |
80 | |
81 | pattern_ = conv_ || qconv2d_ || pool_; |
82 | } |
83 | |
84 | template <typename T> |
85 | Array<PrimExpr> get_combined_padding(const T* old_attrs, Array<PrimExpr> padding) const { |
86 | ICHECK(padding.size() == old_attrs->padding.size()) |
87 | << "Number of dimensions to pad and convolution padding attributes should have the same " |
88 | "extent" ; |
89 | |
90 | Array<PrimExpr> combined_padding; |
91 | for (size_t i = 0; i < padding.size(); ++i) { |
92 | combined_padding.push_back(padding[i] + old_attrs->padding[i]); |
93 | } |
94 | return combined_padding; |
95 | } |
96 | |
97 | template <typename T> |
98 | Attrs MakeConvAttrs(const PadAttrs* param, const T* old_attrs) const { |
99 | // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D conv attrs |
100 | ICHECK(old_attrs); |
101 | ICHECK(param); |
102 | auto padding = get_padding(param, old_attrs->data_layout); |
103 | if (!padding) { |
104 | return Attrs(); |
105 | } |
106 | auto combined_padding = get_combined_padding(old_attrs, padding.value()); |
107 | |
108 | auto new_attrs = make_object<T>(); |
109 | new_attrs->strides = old_attrs->strides; |
110 | new_attrs->padding = combined_padding; |
111 | new_attrs->dilation = old_attrs->dilation; |
112 | new_attrs->groups = old_attrs->groups; |
113 | new_attrs->channels = old_attrs->channels; |
114 | new_attrs->kernel_size = old_attrs->kernel_size; |
115 | new_attrs->data_layout = old_attrs->data_layout; |
116 | new_attrs->kernel_layout = old_attrs->kernel_layout; |
117 | new_attrs->out_layout = old_attrs->out_layout; |
118 | new_attrs->out_dtype = old_attrs->out_dtype; |
119 | return Attrs(new_attrs); |
120 | } |
121 | |
122 | template <typename T> |
123 | Attrs MakeConv2D3DAttrs(const PadAttrs* param, const T* old_attrs) const { |
124 | // Propagate additional Conv2D- and Conv3D-specific attrs |
125 | auto attrs = MakeConvAttrs(param, old_attrs); |
126 | if (!attrs.defined()) { |
127 | return Attrs(); |
128 | } |
129 | |
130 | T* new_attrs = const_cast<T*>(attrs.template as<T>()); |
131 | new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout; |
132 | new_attrs->meta_schedule_original_shape = old_attrs->meta_schedule_original_shape; |
133 | return attrs; |
134 | } |
135 | |
136 | template <typename T> |
137 | Attrs MakePoolAttrs(const PadAttrs* param, const T* old_attrs) const { |
138 | // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D pool attrs |
139 | ICHECK(old_attrs); |
140 | ICHECK(param); |
141 | auto padding = get_padding(param, old_attrs->layout); |
142 | if (!padding) { |
143 | return Attrs(); |
144 | } |
145 | auto combined_padding = get_combined_padding(old_attrs, padding.value()); |
146 | |
147 | auto new_attrs = make_object<T>(); |
148 | new_attrs->pool_size = old_attrs->pool_size; |
149 | new_attrs->strides = old_attrs->strides; |
150 | new_attrs->dilation = old_attrs->dilation; |
151 | new_attrs->padding = combined_padding; |
152 | new_attrs->layout = old_attrs->layout; |
153 | new_attrs->out_layout = old_attrs->out_layout; |
154 | new_attrs->ceil_mode = old_attrs->ceil_mode; |
155 | return Attrs(new_attrs); |
156 | } |
157 | |
158 | template <typename T> |
159 | Attrs MakeAvgPoolAttrs(const PadAttrs* param, const T* old_attrs) const { |
160 | // Propagate additional AvgPool-specific attrs |
161 | auto attrs = MakePoolAttrs(param, old_attrs); |
162 | if (!attrs.defined()) { |
163 | return attrs; |
164 | } |
165 | |
166 | T* new_attrs = const_cast<T*>(attrs.template as<T>()); |
167 | new_attrs->count_include_pad = old_attrs->count_include_pad; |
168 | if (!new_attrs->count_include_pad) { |
169 | // AvgPool's divisor doesn't include padding, so don't fold the explicit pad |
170 | // unless all original pad items are 0. |
171 | for (IndexExpr pad : old_attrs->padding) { |
172 | const IntImmNode* maybe_int_imm = pad.as<IntImmNode>(); |
173 | if (!maybe_int_imm || maybe_int_imm->value != 0) { |
174 | // Return undefined attrs to signal that we don't want to fold explicit pad |
175 | return Attrs(); |
176 | } |
177 | } |
178 | // Turn on `count_include_pad` to preserve original pad first, then pool behavior |
179 | // where AvgPool's divisor implicitly includes padding. |
180 | new_attrs->count_include_pad = true; |
181 | } |
182 | |
183 | return attrs; |
184 | } |
185 | |
186 | static const std::optional<Array<PrimExpr>> get_padding(const PadAttrs* param, |
187 | std::string data_layout) { |
188 | // Gets spatial axes padding from the given PadAttrs `param`. If padding |
189 | // is non-zero on non-spatial axes, return std::nullopt. |
190 | ICHECK(param); |
191 | ICHECK(data_layout.size() == param->pad_width.size()) |
192 | << "Data Layout and padding attributes should have the same extent" ; |
193 | |
194 | std::set<char> image_dims({'H', 'W', 'D'}); |
195 | Array<PrimExpr> padding; |
196 | // If we're padding a non-spatial dimension, don't simplify |
197 | // Convolution/Pool can only pad on spatial axes |
198 | for (size_t i = 0; i < param->pad_width.size(); ++i) { |
199 | if (!image_dims.count(data_layout[i])) { |
200 | for (size_t j = 0; j < param->pad_width[i].size(); ++j) { |
201 | if (param->pad_width[i][j] != 0) { |
202 | return std::nullopt; |
203 | } |
204 | } |
205 | } |
206 | } |
207 | for (size_t j = 0; j < param->pad_width[0].size(); ++j) { |
208 | for (size_t i = 0; i < param->pad_width.size(); ++i) { |
209 | if (image_dims.count(data_layout[i])) { |
210 | padding.push_back(param->pad_width[i][j]); |
211 | } |
212 | } |
213 | } |
214 | return padding; |
215 | } |
216 | |
217 | Expr callback(const Expr& pre, const Expr& post, |
218 | const Map<DFPattern, Array<Expr>>& node_map) const { |
219 | const CallNode* call_node = post.as<CallNode>(); |
220 | ICHECK(call_node); |
221 | auto pad = node_map[pad_][0]; |
222 | const CallNode* pad_node = pad.as<CallNode>(); |
223 | ICHECK(pad_node); |
224 | const PadAttrs* param = pad_node->attrs.as<PadAttrs>(); |
225 | ICHECK(param); |
226 | |
227 | auto x = node_map[x_][0]; |
228 | |
229 | const Expr& pv = pad_node->args[1]; |
230 | const ConstantNode* pad_value = pv.as<ConstantNode>(); |
231 | auto pad_scalar = ToScalar(pad_value->data); |
232 | |
233 | if (node_map.find(qconv2d_) != node_map.end()) { |
234 | Attrs attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>()); |
235 | if (!attrs.defined()) { |
236 | return post; |
237 | } |
238 | auto input_zero_point = node_map[input_zero_point_][0]; |
239 | auto kernel_zero_point = node_map[kernel_zero_point_][0]; |
240 | auto input_scale = node_map[input_scale_][0]; |
241 | auto kernel_scale = node_map[kernel_scale_][0]; |
242 | // Fold Padding and QNN Convolution only if pad value == input zero point. |
243 | if (IsEqualScalar(input_zero_point, pv)) { |
244 | auto w = node_map[w_][0]; |
245 | return Call(call_node->op, |
246 | {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs, |
247 | call_node->type_args, call_node->span); |
248 | } |
249 | return post; |
250 | } |
251 | |
252 | if (param->pad_mode == "constant" && pad_value) { |
253 | Attrs attrs; |
254 | if (pad_scalar == 0.0) { |
255 | // Fold Padding and Conv/AvgPool only if pad_value == 0. |
256 | if (node_map.count(conv_)) { |
257 | if (node_map.count(conv1d_)) { |
258 | attrs = MakeConvAttrs(param, call_node->attrs.as<Conv1DAttrs>()); |
259 | } else if (node_map.count(conv2d_)) { |
260 | attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>()); |
261 | } else if (node_map.count(conv3d_)) { |
262 | attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv3DAttrs>()); |
263 | } |
264 | if (!attrs.defined()) { |
265 | return post; |
266 | } |
267 | auto w = node_map[w_][0]; |
268 | return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span); |
269 | } else if (node_map.count(avg_pool1d_)) { |
270 | attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool1DAttrs>()); |
271 | } else if (node_map.count(avg_pool2d_)) { |
272 | attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool2DAttrs>()); |
273 | } else if (node_map.count(avg_pool3d_)) { |
274 | attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool3DAttrs>()); |
275 | } |
276 | } |
277 | if (node_map.count(max_pool_)) { |
278 | // Fold Padding and MaxPool only if pad_value is the min possible value for the dtype |
279 | auto min_value = tvm::min_value(tvm::runtime::DataType(pad_value->data->dtype)); |
280 | const FloatImmNode* maybe_min_float = min_value.as<FloatImmNode>(); |
281 | const IntImmNode* maybe_min_int = min_value.as<IntImmNode>(); |
282 | |
283 | if ((maybe_min_float && pad_scalar == maybe_min_float->value) || |
284 | (maybe_min_int && pad_scalar == maybe_min_int->value)) { |
285 | if (node_map.count(max_pool1d_)) { |
286 | attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool1DAttrs>()); |
287 | } else if (node_map.count(max_pool2d_)) { |
288 | attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool2DAttrs>()); |
289 | } else if (node_map.count(max_pool3d_)) { |
290 | attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool3DAttrs>()); |
291 | } |
292 | } |
293 | } |
294 | if (!attrs.defined()) { |
295 | return post; |
296 | } |
297 | return Call(call_node->op, {x}, attrs, call_node->type_args, call_node->span); |
298 | } |
299 | return post; |
300 | } |
301 | |
302 | private: |
303 | /*! \brief Pattern for rewriting */ |
304 | DFPattern pattern_; |
305 | /*! \brief Pattern input */ |
306 | DFPattern x_; |
307 | /*! \brief Pattern input weight */ |
308 | DFPattern w_; |
309 | /*! \brief Pattern pad */ |
310 | DFPattern pad_; |
311 | /*! \brief Pattern conv */ |
312 | DFPattern conv_; |
313 | DFPattern conv1d_; |
314 | DFPattern conv2d_; |
315 | DFPattern conv3d_; |
316 | DFPattern contrib_conv2d_nchwc_; |
317 | DFPattern qconv2d_; |
318 | DFPattern input_zero_point_; |
319 | DFPattern kernel_zero_point_; |
320 | DFPattern input_scale_; |
321 | DFPattern kernel_scale_; |
322 | /*! \brief Pattern pool */ |
323 | DFPattern pool_; |
324 | DFPattern avg_pool1d_; |
325 | DFPattern avg_pool2d_; |
326 | DFPattern avg_pool3d_; |
327 | DFPattern max_pool1d_; |
328 | DFPattern max_pool2d_; |
329 | DFPattern max_pool3d_; |
330 | DFPattern max_pool_; |
331 | }; |
332 | |
333 | class SimplifyExplicitPadding { |
334 | public: |
335 | explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) { |
336 | CreateCallback(SimplifyExplicitPad()); |
337 | } |
338 | template <typename T> |
339 | void CreateCallback(const T& pattern) { |
340 | auto func = [pattern](TVMArgs args, TVMRetValue* rv) { |
341 | Expr pre = args[0]; |
342 | Expr post = args[1]; |
343 | Map<DFPattern, Array<Expr>> node_map = args[2]; |
344 | *rv = pattern.callback(pre, post, node_map); |
345 | }; |
346 | callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true)); |
347 | } |
348 | |
349 | Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); } |
350 | |
351 | private: |
352 | IRModule mod_; |
353 | /*! \brief Callbacks for expr simplification */ |
354 | Array<DFPatternCallback> callbacks_; |
355 | }; |
356 | |
357 | /*! |
358 | * \brief FoldExplicitPadding finds explict padding before an op that can |
359 | * support implicit padding and fuses them. |
360 | */ |
361 | Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) { |
362 | return SimplifyExplicitPadding(mod).Simplify(expr); |
363 | } |
364 | |
365 | namespace transform { |
366 | |
367 | Pass FoldExplicitPadding() { |
368 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
369 | [=](Function f, IRModule m, PassContext pc) { |
370 | return Downcast<Function>(FoldExplicitPadding(f, m)); |
371 | }; |
372 | return CreateFunctionPass(pass_func, 0, " FoldExplicitPadding" , {"InferType" }); |
373 | } |
374 | |
375 | TVM_REGISTER_GLOBAL("relay._transform.FoldExplicitPadding" ).set_body_typed(FoldExplicitPadding); |
376 | |
377 | } // namespace transform |
378 | |
379 | } // namespace relay |
380 | } // namespace tvm |
381 | |