1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/flatten_atrous_conv.cc |
22 | * \brief This transform flattens atrous convolution, which corresponds to the sequence of |
23 | * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd". |
24 | */ |
25 | |
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
#include <tvm/topi/broadcast.h>

#include <array>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "../qnn/utils.h"
#include "pattern_utils.h"
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | |
44 | /* Description of FlattenAtrousConv |
45 | * |
46 | * The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd |
47 | * operations: |
48 | * |
49 | * x w |
50 | * | | |
51 | * s2b | |
52 | * \ / |
53 | * conv2d |
54 | * | |
55 | * b2s |
56 | * |
57 | * and convert them into subgraphs with a convolution with the modified "dilation" and |
58 | * recalculated "padding" parameters. |
59 | */ |
60 | |
// Set of Expr objects (used below to hold operator references) compared by pointer identity.
using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
62 | |
63 | class FlattenAtrousConvSubgraphMutator { |
64 | public: |
65 | Expr MutateSubgraph(const Expr& expr) { |
66 | try { |
67 | const CallNode* b2s_node_ = expr.as<CallNode>(); |
68 | const CallNode* conv2d_node_ = b2s_node_->args[0].as<CallNode>(); |
69 | const CallNode* s2b_node_ = conv2d_node_->args[0].as<CallNode>(); |
70 | |
71 | ICHECK(b2s_node_ != nullptr); |
72 | const auto* b2s_attrs = b2s_node_->attrs.as<BatchToSpaceNDAttrs>(); |
73 | ICHECK(b2s_attrs != nullptr); |
74 | |
75 | Array<PrimExpr> dilation = {b2s_attrs->block_shape[0], b2s_attrs->block_shape[1]}; |
76 | |
77 | ICHECK(conv2d_node_ != nullptr); |
78 | const auto* conv2d_attrs = conv2d_node_->attrs.as<Conv2DAttrs>(); |
79 | ICHECK(conv2d_attrs != nullptr); |
80 | |
81 | Array<PrimExpr> kernel_shape = conv2d_attrs->kernel_size; |
82 | PrimExpr kernel_h = kernel_shape[0]; |
83 | PrimExpr kernel_w = kernel_shape[1]; |
84 | |
85 | ICHECK(s2b_node_ != nullptr); |
86 | const auto* s2b_attrs = s2b_node_->attrs.as<SpaceToBatchNDAttrs>(); |
87 | ICHECK(s2b_attrs != nullptr); |
88 | |
89 | Expr data = s2b_node_->args[0]; |
90 | ICHECK(conv2d_attrs->data_layout == "NHWC" ); |
91 | Array<PrimExpr> data_shape = transform::InferTypeLocal(data).as<TensorTypeNode>()->shape; |
92 | PrimExpr in_h = data_shape[1]; |
93 | PrimExpr in_w = data_shape[2]; |
94 | |
95 | PrimExpr dilation_h = dilation[0]; |
96 | PrimExpr dilation_w = dilation[1]; |
97 | |
98 | PrimExpr dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; |
99 | PrimExpr dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; |
100 | |
101 | Array<PrimExpr> strides = {1, 1}; |
102 | PrimExpr stride_h = strides[0]; |
103 | PrimExpr stride_w = strides[1]; |
104 | |
105 | auto _get_pad_pair = [](PrimExpr input1d, PrimExpr kernel1d, |
106 | PrimExpr stride1d) -> Array<PrimExpr> { |
107 | PrimExpr out1d = truncdiv((input1d + stride1d - 1), stride1d); |
108 | PrimExpr pad = topi::maximum(((out1d - 1) * stride1d + kernel1d - input1d), 0); |
109 | PrimExpr pad_before = truncdiv(pad, 2); |
110 | PrimExpr pad_after = pad - pad_before; |
111 | return {pad_before, pad_after}; |
112 | }; |
113 | |
114 | Array<PrimExpr> pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h); |
115 | Array<PrimExpr> pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w); |
116 | |
117 | Array<IndexExpr> padding = {pad_v[0], pad_h[0], pad_v[1], pad_h[1]}; |
118 | |
119 | Expr weight = conv2d_node_->args[1]; |
120 | |
121 | if (conv2d_node_->op == Op::Get("nn.conv2d" )) { |
122 | return Conv2D(data, weight, strides, padding, dilation, conv2d_attrs->groups, |
123 | conv2d_attrs->channels, conv2d_attrs->kernel_size, conv2d_attrs->data_layout, |
124 | conv2d_attrs->kernel_layout, conv2d_attrs->out_layout, |
125 | conv2d_attrs->out_dtype); |
126 | } |
127 | |
128 | if (conv2d_node_->op == Op::Get("qnn.conv2d" )) { |
129 | Expr input_zero_point = conv2d_node_->args[2]; |
130 | Expr kernel_zero_point = conv2d_node_->args[3]; |
131 | Expr input_scale = conv2d_node_->args[4]; |
132 | Expr kernel_scale = conv2d_node_->args[5]; |
133 | return qnn::MakeQnnConv2D(data, weight, input_zero_point, kernel_zero_point, input_scale, |
134 | kernel_scale, strides, padding, dilation, conv2d_attrs->groups, |
135 | conv2d_attrs->channels, conv2d_attrs->kernel_size, |
136 | conv2d_attrs->data_layout, conv2d_attrs->kernel_layout, |
137 | conv2d_attrs->out_layout, conv2d_attrs->out_dtype); |
138 | } |
139 | |
140 | DLOG(INFO) << "Ran into an unhandled convolution, skipping " << expr << std::endl; |
141 | return expr; |
142 | } catch (std::exception& e) { |
143 | DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping " << expr << " with " |
144 | << e.what() << std::endl; |
145 | return expr; |
146 | } |
147 | } |
148 | }; |
149 | |
150 | class FlattenAtrousConvRewriter : public MixedModeMutator { |
151 | protected: |
152 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
153 | if (const CallNode* call_node = post.as<CallNode>()) { |
154 | if (ops_[op_iter_].count(call_node->op)) { |
155 | ++op_iter_; |
156 | if (op_iter_ == ops_.size()) { |
157 | op_iter_ = 0; |
158 | return FlattenAtrousConvSubgraphMutator().MutateSubgraph(post); |
159 | } |
160 | } else { |
161 | op_iter_ = 0; |
162 | } |
163 | } |
164 | return post; |
165 | } |
166 | |
167 | private: |
168 | size_t op_iter_ = 0; |
169 | const std::array<ExprSet, 3> ops_ = { |
170 | ExprSet{Op::Get("nn.space_to_batch_nd" )}, |
171 | ExprSet{Op::Get("nn.conv2d" ), Op::Get("qnn.conv2d" )}, |
172 | ExprSet{Op::Get("nn.batch_to_space_nd" )}, |
173 | }; |
174 | }; |
175 | |
176 | Expr FlattenAtrousConv(const Expr& expr, const IRModule& mod) { |
177 | return FlattenAtrousConvRewriter().Mutate(expr); |
178 | } |
179 | |
180 | namespace transform { |
181 | |
182 | Pass FlattenAtrousConv() { |
183 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
184 | [=](Function f, IRModule m, PassContext pc) { |
185 | return Downcast<Function>(FlattenAtrousConv(f, m)); |
186 | }; |
187 | return CreateFunctionPass(pass_func, 0, "FlattenAtrousConv" , {"InferType" }); |
188 | } |
189 | |
// Expose the pass constructor to the FFI as "relay._transform.FlattenAtrousConv".
TVM_REGISTER_GLOBAL("relay._transform.FlattenAtrousConv" ).set_body_typed(FlattenAtrousConv);
191 | |
192 | } // namespace transform |
193 | |
194 | } // namespace relay |
195 | } // namespace tvm |
196 | |