1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/flatten_atrous_conv.cc
22 * \brief This transform flattens atrous convolution, which corresponds to the sequence of
23 * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd".
24 */
25
26#include <tvm/relay/attrs/nn.h>
27#include <tvm/relay/dataflow_matcher.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/expr_functor.h>
30#include <tvm/relay/qnn/attrs.h>
31#include <tvm/relay/transform.h>
32#include <tvm/topi/broadcast.h>
33
34#include <array>
35#include <set>
36#include <unordered_map>
37
38#include "../qnn/utils.h"
39#include "pattern_utils.h"
40
41namespace tvm {
42namespace relay {
43
44/* Description of FlattenAtrousConv
45 *
46 * The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd
47 * operations:
48 *
49 * x w
50 * | |
51 * s2b |
52 * \ /
53 * conv2d
54 * |
55 * b2s
56 *
57 * and convert them into subgraphs with a convolution with the modified "dilation" and
58 * recalculated "padding" parameters.
59 */
60
61using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
62
63class FlattenAtrousConvSubgraphMutator {
64 public:
65 Expr MutateSubgraph(const Expr& expr) {
66 try {
67 const CallNode* b2s_node_ = expr.as<CallNode>();
68 const CallNode* conv2d_node_ = b2s_node_->args[0].as<CallNode>();
69 const CallNode* s2b_node_ = conv2d_node_->args[0].as<CallNode>();
70
71 ICHECK(b2s_node_ != nullptr);
72 const auto* b2s_attrs = b2s_node_->attrs.as<BatchToSpaceNDAttrs>();
73 ICHECK(b2s_attrs != nullptr);
74
75 Array<PrimExpr> dilation = {b2s_attrs->block_shape[0], b2s_attrs->block_shape[1]};
76
77 ICHECK(conv2d_node_ != nullptr);
78 const auto* conv2d_attrs = conv2d_node_->attrs.as<Conv2DAttrs>();
79 ICHECK(conv2d_attrs != nullptr);
80
81 Array<PrimExpr> kernel_shape = conv2d_attrs->kernel_size;
82 PrimExpr kernel_h = kernel_shape[0];
83 PrimExpr kernel_w = kernel_shape[1];
84
85 ICHECK(s2b_node_ != nullptr);
86 const auto* s2b_attrs = s2b_node_->attrs.as<SpaceToBatchNDAttrs>();
87 ICHECK(s2b_attrs != nullptr);
88
89 Expr data = s2b_node_->args[0];
90 ICHECK(conv2d_attrs->data_layout == "NHWC");
91 Array<PrimExpr> data_shape = transform::InferTypeLocal(data).as<TensorTypeNode>()->shape;
92 PrimExpr in_h = data_shape[1];
93 PrimExpr in_w = data_shape[2];
94
95 PrimExpr dilation_h = dilation[0];
96 PrimExpr dilation_w = dilation[1];
97
98 PrimExpr dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
99 PrimExpr dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
100
101 Array<PrimExpr> strides = {1, 1};
102 PrimExpr stride_h = strides[0];
103 PrimExpr stride_w = strides[1];
104
105 auto _get_pad_pair = [](PrimExpr input1d, PrimExpr kernel1d,
106 PrimExpr stride1d) -> Array<PrimExpr> {
107 PrimExpr out1d = truncdiv((input1d + stride1d - 1), stride1d);
108 PrimExpr pad = topi::maximum(((out1d - 1) * stride1d + kernel1d - input1d), 0);
109 PrimExpr pad_before = truncdiv(pad, 2);
110 PrimExpr pad_after = pad - pad_before;
111 return {pad_before, pad_after};
112 };
113
114 Array<PrimExpr> pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h);
115 Array<PrimExpr> pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w);
116
117 Array<IndexExpr> padding = {pad_v[0], pad_h[0], pad_v[1], pad_h[1]};
118
119 Expr weight = conv2d_node_->args[1];
120
121 if (conv2d_node_->op == Op::Get("nn.conv2d")) {
122 return Conv2D(data, weight, strides, padding, dilation, conv2d_attrs->groups,
123 conv2d_attrs->channels, conv2d_attrs->kernel_size, conv2d_attrs->data_layout,
124 conv2d_attrs->kernel_layout, conv2d_attrs->out_layout,
125 conv2d_attrs->out_dtype);
126 }
127
128 if (conv2d_node_->op == Op::Get("qnn.conv2d")) {
129 Expr input_zero_point = conv2d_node_->args[2];
130 Expr kernel_zero_point = conv2d_node_->args[3];
131 Expr input_scale = conv2d_node_->args[4];
132 Expr kernel_scale = conv2d_node_->args[5];
133 return qnn::MakeQnnConv2D(data, weight, input_zero_point, kernel_zero_point, input_scale,
134 kernel_scale, strides, padding, dilation, conv2d_attrs->groups,
135 conv2d_attrs->channels, conv2d_attrs->kernel_size,
136 conv2d_attrs->data_layout, conv2d_attrs->kernel_layout,
137 conv2d_attrs->out_layout, conv2d_attrs->out_dtype);
138 }
139
140 DLOG(INFO) << "Ran into an unhandled convolution, skipping " << expr << std::endl;
141 return expr;
142 } catch (std::exception& e) {
143 DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping " << expr << " with "
144 << e.what() << std::endl;
145 return expr;
146 }
147 }
148};
149
150class FlattenAtrousConvRewriter : public MixedModeMutator {
151 protected:
152 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
153 if (const CallNode* call_node = post.as<CallNode>()) {
154 if (ops_[op_iter_].count(call_node->op)) {
155 ++op_iter_;
156 if (op_iter_ == ops_.size()) {
157 op_iter_ = 0;
158 return FlattenAtrousConvSubgraphMutator().MutateSubgraph(post);
159 }
160 } else {
161 op_iter_ = 0;
162 }
163 }
164 return post;
165 }
166
167 private:
168 size_t op_iter_ = 0;
169 const std::array<ExprSet, 3> ops_ = {
170 ExprSet{Op::Get("nn.space_to_batch_nd")},
171 ExprSet{Op::Get("nn.conv2d"), Op::Get("qnn.conv2d")},
172 ExprSet{Op::Get("nn.batch_to_space_nd")},
173 };
174};
175
176Expr FlattenAtrousConv(const Expr& expr, const IRModule& mod) {
177 return FlattenAtrousConvRewriter().Mutate(expr);
178}
179
180namespace transform {
181
182Pass FlattenAtrousConv() {
183 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
184 [=](Function f, IRModule m, PassContext pc) {
185 return Downcast<Function>(FlattenAtrousConv(f, m));
186 };
187 return CreateFunctionPass(pass_func, 0, "FlattenAtrousConv", {"InferType"});
188}
189
// Expose the pass constructor to the FFI as "relay._transform.FlattenAtrousConv".
TVM_REGISTER_GLOBAL("relay._transform.FlattenAtrousConv").set_body_typed(FlattenAtrousConv);
191
192} // namespace transform
193
194} // namespace relay
195} // namespace tvm
196