/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/fold_explicit_padding.cc
 * \brief A pass for folding explicit pads into other ops.
 */

#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/topi/nn/pooling.h>

#include <optional>
#include <set>
#include <string>

#include "../op/tensor/transform.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

/*!
 * \brief SimplifyExplicitPad matches a pad followed by a conv/maxpool/avgpool
 * with a pad attribute and folds the explicit padding into that op's padding
 * attribute.
 */
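/*
 * An illustrative sketch of the rewrite (the tensor names and the NCHW
 * pad widths below are made up for the example, not produced by the pass):
 *
 *   %p = nn.pad(%x, 0, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]])
 *   %y = nn.conv2d(%p, %w, padding=[0, 0, 0, 0])
 *
 * becomes
 *
 *   %y = nn.conv2d(%x, %w, padding=[1, 1, 1, 1])
 *
 * The fold only happens when the pad value is compatible with the consuming
 * op: zero for conv/avg_pool, the dtype's minimum for max_pool, and the
 * input zero point for qnn.conv2d.
 */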
class SimplifyExplicitPad {
 public:
  DFPattern pattern() const { return pattern_; }

  SimplifyExplicitPad() {
    x_ = IsWildcard();
    pad_ = IsOp("nn.pad")({x_, IsWildcard()});

    // pad->conv patterns
    w_ = IsWildcard();
    conv1d_ = IsOp("nn.conv1d");
    conv2d_ = IsOp("nn.conv2d");
    conv3d_ = IsOp("nn.conv3d");
    contrib_conv2d_nchwc_ = IsOp("nn.contrib_conv2d_NCHWc");
    conv_ = (conv1d_ || conv2d_ || conv3d_ || contrib_conv2d_nchwc_)({pad_, w_});

    input_zero_point_ = IsWildcard();
    kernel_zero_point_ = IsWildcard();
    input_scale_ = IsWildcard();
    kernel_scale_ = IsWildcard();
    qconv2d_ = IsOp("qnn.conv2d")(
        {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});

    // pad->pool patterns
    avg_pool1d_ = IsOp("nn.avg_pool1d");
    avg_pool2d_ = IsOp("nn.avg_pool2d");
    avg_pool3d_ = IsOp("nn.avg_pool3d");
    max_pool1d_ = IsOp("nn.max_pool1d");
    max_pool2d_ = IsOp("nn.max_pool2d");
    max_pool3d_ = IsOp("nn.max_pool3d");
    max_pool_ = max_pool1d_ || max_pool2d_ || max_pool3d_;
    pool_ = (max_pool_ || avg_pool1d_ || avg_pool2d_ || avg_pool3d_)({pad_});

    pattern_ = conv_ || qconv2d_ || pool_;
  }

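  // Element-wise sum of the explicit pad amounts (spatial axes only, already
  // rearranged by get_padding into the same order as the op's padding
  // attribute) and the padding the conv/pool op already carries. For example,
  // assuming an explicit pad of [1, 1, 1, 1] and an existing attribute
  // padding of [0, 1, 0, 1], the folded padding becomes [1, 2, 1, 2].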
  template <typename T>
  Array<PrimExpr> get_combined_padding(const T* old_attrs, Array<PrimExpr> padding) const {
    ICHECK(padding.size() == old_attrs->padding.size())
        << "Number of dimensions to pad and convolution padding attributes should have the same "
           "extent";

    Array<PrimExpr> combined_padding;
    for (size_t i = 0; i < padding.size(); ++i) {
      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
    }
    return combined_padding;
  }

  template <typename T>
  Attrs MakeConvAttrs(const PadAttrs* param, const T* old_attrs) const {
    // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D conv attrs
    ICHECK(old_attrs);
    ICHECK(param);
    auto padding = get_padding(param, old_attrs->data_layout);
    if (!padding) {
      return Attrs();
    }
    auto combined_padding = get_combined_padding(old_attrs, padding.value());

    auto new_attrs = make_object<T>();
    new_attrs->strides = old_attrs->strides;
    new_attrs->padding = combined_padding;
    new_attrs->dilation = old_attrs->dilation;
    new_attrs->groups = old_attrs->groups;
    new_attrs->channels = old_attrs->channels;
    new_attrs->kernel_size = old_attrs->kernel_size;
    new_attrs->data_layout = old_attrs->data_layout;
    new_attrs->kernel_layout = old_attrs->kernel_layout;
    new_attrs->out_layout = old_attrs->out_layout;
    new_attrs->out_dtype = old_attrs->out_dtype;
    return Attrs(new_attrs);
  }

  template <typename T>
  Attrs MakeConv2D3DAttrs(const PadAttrs* param, const T* old_attrs) const {
    // Propagate additional Conv2D- and Conv3D-specific attrs
    auto attrs = MakeConvAttrs(param, old_attrs);
    if (!attrs.defined()) {
      return Attrs();
    }

    T* new_attrs = const_cast<T*>(attrs.template as<T>());
    new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout;
    new_attrs->meta_schedule_original_shape = old_attrs->meta_schedule_original_shape;
    return attrs;
  }

  template <typename T>
  Attrs MakePoolAttrs(const PadAttrs* param, const T* old_attrs) const {
    // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D pool attrs
    ICHECK(old_attrs);
    ICHECK(param);
    auto padding = get_padding(param, old_attrs->layout);
    if (!padding) {
      return Attrs();
    }
    auto combined_padding = get_combined_padding(old_attrs, padding.value());

    auto new_attrs = make_object<T>();
    new_attrs->pool_size = old_attrs->pool_size;
    new_attrs->strides = old_attrs->strides;
    new_attrs->dilation = old_attrs->dilation;
    new_attrs->padding = combined_padding;
    new_attrs->layout = old_attrs->layout;
    new_attrs->out_layout = old_attrs->out_layout;
    new_attrs->ceil_mode = old_attrs->ceil_mode;
    return Attrs(new_attrs);
  }

  template <typename T>
  Attrs MakeAvgPoolAttrs(const PadAttrs* param, const T* old_attrs) const {
    // Propagate additional AvgPool-specific attrs
    auto attrs = MakePoolAttrs(param, old_attrs);
    if (!attrs.defined()) {
      return attrs;
    }

    T* new_attrs = const_cast<T*>(attrs.template as<T>());
    new_attrs->count_include_pad = old_attrs->count_include_pad;
    if (!new_attrs->count_include_pad) {
      // AvgPool's divisor doesn't include padding, so don't fold the explicit pad
      // unless all original pad items are 0.
      for (IndexExpr pad : old_attrs->padding) {
        const IntImmNode* maybe_int_imm = pad.as<IntImmNode>();
        if (!maybe_int_imm || maybe_int_imm->value != 0) {
          // Return undefined attrs to signal that we don't want to fold explicit pad
          return Attrs();
        }
      }
      // Turn on `count_include_pad` so the folded op reproduces the original
      // pad-then-pool behavior, in which the explicitly padded elements are
      // counted in AvgPool's divisor.
      new_attrs->count_include_pad = true;
    }

    return attrs;
  }

  static const std::optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
                                                          std::string data_layout) {
    // Gets spatial axes padding from the given PadAttrs `param`. If padding
    // is non-zero on non-spatial axes, return std::nullopt.
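    // For example (illustrative values only): with data_layout "NCHW" and
    // pad_width [[0, 0], [0, 0], [1, 2], [3, 4]], this returns [1, 3, 2, 4],
    // i.e. the "before" amounts for all spatial axes followed by the "after"
    // amounts, matching the conv/pool padding attribute order.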
    ICHECK(param);
    ICHECK(data_layout.size() == param->pad_width.size())
        << "Data Layout and padding attributes should have the same extent";

    std::set<char> image_dims({'H', 'W', 'D'});
    Array<PrimExpr> padding;
    // If we're padding a non-spatial dimension, don't simplify;
    // Convolution/Pool can only pad on spatial axes.
    for (size_t i = 0; i < param->pad_width.size(); ++i) {
      if (!image_dims.count(data_layout[i])) {
        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
          if (param->pad_width[i][j] != 0) {
            return std::nullopt;
          }
        }
      }
    }
    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
      for (size_t i = 0; i < param->pad_width.size(); ++i) {
        if (image_dims.count(data_layout[i])) {
          padding.push_back(param->pad_width[i][j]);
        }
      }
    }
    return padding;
  }

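  /*!
   * \brief Rewrite callback invoked on each pattern match. The explicit
   * nn.pad is folded into the consuming op's padding attribute only when the
   * pad value is compatible with that op: the input zero point for
   * qnn.conv2d, zero for conv and avg_pool, and the dtype's minimum value
   * for max_pool. Otherwise `post` is returned unchanged.
   */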
  Expr callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const {
    const CallNode* call_node = post.as<CallNode>();
    ICHECK(call_node);
    auto pad = node_map[pad_][0];
    const CallNode* pad_node = pad.as<CallNode>();
    ICHECK(pad_node);
    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
    ICHECK(param);

    auto x = node_map[x_][0];

    const Expr& pv = pad_node->args[1];
    // The pad value is an arbitrary expression and is only a ConstantNode when
    // it is a literal, so defer reading it until after the null check below.
    const ConstantNode* pad_value = pv.as<ConstantNode>();

    if (node_map.find(qconv2d_) != node_map.end()) {
      Attrs attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
      if (!attrs.defined()) {
        return post;
      }
      auto input_zero_point = node_map[input_zero_point_][0];
      auto kernel_zero_point = node_map[kernel_zero_point_][0];
      auto input_scale = node_map[input_scale_][0];
      auto kernel_scale = node_map[kernel_scale_][0];
      // Fold Padding and QNN Convolution only if pad value == input zero point.
      if (IsEqualScalar(input_zero_point, pv)) {
        auto w = node_map[w_][0];
        return Call(call_node->op,
                    {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
                    call_node->type_args, call_node->span);
      }
      return post;
    }

    if (param->pad_mode == "constant" && pad_value) {
      Attrs attrs;
      auto pad_scalar = ToScalar(pad_value->data);
      if (pad_scalar == 0.0) {
        // Fold Padding and Conv/AvgPool only if pad_value == 0.
        if (node_map.count(conv_)) {
          if (node_map.count(conv1d_)) {
            attrs = MakeConvAttrs(param, call_node->attrs.as<Conv1DAttrs>());
          } else if (node_map.count(conv2d_)) {
            attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
          } else if (node_map.count(conv3d_)) {
            attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv3DAttrs>());
          }
          if (!attrs.defined()) {
            return post;
          }
          auto w = node_map[w_][0];
          return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
        } else if (node_map.count(avg_pool1d_)) {
          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool1DAttrs>());
        } else if (node_map.count(avg_pool2d_)) {
          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool2DAttrs>());
        } else if (node_map.count(avg_pool3d_)) {
          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool3DAttrs>());
        }
      }
      if (node_map.count(max_pool_)) {
        // Fold Padding and MaxPool only if pad_value is the min possible value for the dtype
        auto min_value = tvm::min_value(tvm::runtime::DataType(pad_value->data->dtype));
        const FloatImmNode* maybe_min_float = min_value.as<FloatImmNode>();
        const IntImmNode* maybe_min_int = min_value.as<IntImmNode>();

        if ((maybe_min_float && pad_scalar == maybe_min_float->value) ||
            (maybe_min_int && pad_scalar == maybe_min_int->value)) {
          if (node_map.count(max_pool1d_)) {
            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool1DAttrs>());
          } else if (node_map.count(max_pool2d_)) {
            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool2DAttrs>());
          } else if (node_map.count(max_pool3d_)) {
            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool3DAttrs>());
          }
        }
      }
      if (!attrs.defined()) {
        return post;
      }
      return Call(call_node->op, {x}, attrs, call_node->type_args, call_node->span);
    }
    return post;
  }

 private:
  /*! \brief Pattern for rewriting */
  DFPattern pattern_;
  /*! \brief Pattern input */
  DFPattern x_;
  /*! \brief Pattern input weight */
  DFPattern w_;
  /*! \brief Pattern pad */
  DFPattern pad_;
  /*! \brief Pattern conv */
  DFPattern conv_;
  DFPattern conv1d_;
  DFPattern conv2d_;
  DFPattern conv3d_;
  DFPattern contrib_conv2d_nchwc_;
  DFPattern qconv2d_;
  DFPattern input_zero_point_;
  DFPattern kernel_zero_point_;
  DFPattern input_scale_;
  DFPattern kernel_scale_;
  /*! \brief Pattern pool */
  DFPattern pool_;
  DFPattern avg_pool1d_;
  DFPattern avg_pool2d_;
  DFPattern avg_pool3d_;
  DFPattern max_pool1d_;
  DFPattern max_pool2d_;
  DFPattern max_pool3d_;
  DFPattern max_pool_;
};

class SimplifyExplicitPadding {
 public:
  explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) {
    CreateCallback(SimplifyExplicitPad());
  }
  template <typename T>
  void CreateCallback(const T& pattern) {
    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
      Expr pre = args[0];
      Expr post = args[1];
      Map<DFPattern, Array<Expr>> node_map = args[2];
      *rv = pattern.callback(pre, post, node_map);
    };
    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
  }

  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }

 private:
  IRModule mod_;
  /*! \brief Callbacks for expr simplification */
  Array<DFPatternCallback> callbacks_;
};

/*!
 * \brief FoldExplicitPadding finds explicit padding before an op that
 * supports implicit padding and fuses the two.
 */
Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) {
  return SimplifyExplicitPadding(mod).Simplify(expr);
}

namespace transform {

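/*
 * A minimal usage sketch (module construction elided). The returned pass is
 * an ordinary function-level pass and can be applied to an IRModule directly:
 *
 *   IRModule mod = ...;
 *   mod = relay::transform::FoldExplicitPadding()(mod);
 *
 * The registration below also exposes the pass to the Python API, where it is
 * typically composed with other Relay passes in a tvm.transform.Sequential.
 */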
Pass FoldExplicitPadding() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(FoldExplicitPadding(f, m));
      };
  return CreateFunctionPass(pass_func, 0, "FoldExplicitPadding", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FoldExplicitPadding").set_body_typed(FoldExplicitPadding);

}  // namespace transform

}  // namespace relay
}  // namespace tvm