1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file combine_parallel_conv2d.cc |
23 | * \brief Combine parallel 2d convolutions into a single convolution. |
24 | * |
25 | * This pass replaces convolutions that share the same input node and the same |
26 | * arguments (except that the number of output channels can be different) with a |
27 | * single convolution. The weight of the new 2d convolution is the concatenation |
28 | * of the original weights. Elemwise and broadcast ops following conv2d are also |
29 | * combined if possible. |
30 | * |
31 | * This prevents launching multiple kernels in networks with multiple |
32 | * convolution branches, such as Inception block. |
33 | */ |
34 | |
35 | #include <tvm/relay/analysis.h> |
36 | #include <tvm/relay/attrs/nn.h> |
37 | #include <tvm/relay/attrs/transform.h> |
38 | #include <tvm/relay/expr_functor.h> |
39 | #include <tvm/relay/op_attr_types.h> |
40 | #include <tvm/relay/transform.h> |
41 | |
42 | #include <unordered_map> |
43 | #include <unordered_set> |
44 | |
45 | #include "./combine_parallel_op.h" |
46 | #include "./expr_subst.h" |
47 | #include "pattern_utils.h" |
48 | |
49 | namespace tvm { |
50 | namespace relay { |
51 | |
/*!
 * \brief Combines parallel conv2d branches that share the same input into a
 *        single conv2d whose weight is the concatenation of the branch weights.
 *
 * Strategy: find groups of nn.conv2d calls with identical attributes (only the
 * number of output channels may differ), fuse them into one conv2d, combine any
 * compatible elemwise/broadcast ops that follow, then slice the combined output
 * back apart so each original consumer sees its own channels.
 */
class ParallelConv2DCombiner : public ParallelOpCombiner {
 public:
  /*!
   * \brief Constructor.
   * \param min_num_branches minimum number of parallel nn.conv2d branches
   *        required before the combination is applied.
   */
  explicit ParallelConv2DCombiner(uint64_t min_num_branches)
      : ParallelOpCombiner("nn.conv2d", min_num_branches) {}

 protected:
  /*!
   * \brief Only depthwise-free convolutions (groups == 1) are combined;
   *        grouped convolutions cannot be merged by concatenating weights
   *        along the output-channel axis.
   */
  bool IsSupportedOp(const CallNode* n) { return n->attrs.as<Conv2DAttrs>()->groups == 1; }

  /*!
   * \brief Two conv2d calls are combinable when every attribute matches except
   *        the output-channel count, and the kernels have the same spatial size.
   * \note Both kernel shapes are normalized to OIHW so that H (index 2) and
   *       W (index 3) can be compared regardless of the actual kernel layout.
   */
  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
    StructuralEqual eq;
    const Layout kOIHW("OIHW");
    const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
    const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
    ICHECK(attrs_a);
    ICHECK(attrs_b);
    // args[1] is the weight tensor of nn.conv2d.
    const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
    const auto shape_a =
        tir::BijectiveLayout(Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
    const auto shape_b =
        tir::BijectiveLayout(Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);

    // Channels (shape[0] in OIHW) are deliberately NOT compared: differing
    // output-channel counts are exactly what this pass merges.
    return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
           eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
           eq(attrs_a->data_layout, attrs_b->data_layout) &&
           eq(attrs_a->kernel_layout, attrs_b->kernel_layout) &&
           eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
           eq(shape_a[3], shape_b[3]);
  }

  /*!
   * \brief Builds the single conv2d that replaces a group of parallel branches.
   *
   * The combined weight is the concatenation of the branch weights along the
   * output-channel axis, and `channels` is updated to the summed channel count.
   * All other attributes are copied from the first branch (CanOpsBeCombined
   * guarantees they are identical across the group).
   *
   * Side effect: records channel_pos_ (index of 'C' in the output layout) for
   * use by the later per-argument / slicing helpers below.
   */
  Call MakeCombinedOp(const Group& branches) {
    const Op& conv2d = Op::Get("nn.conv2d");
    Expr data = branches[0][0]->args[0];
    auto [new_weight, new_channels] = TransformWeight(branches);

    const CallNode* group_root = branches[0][0];
    const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
    ICHECK(attrs);
    const auto new_attrs = make_object<Conv2DAttrs>();
    new_attrs->strides = attrs->strides;
    new_attrs->padding = attrs->padding;
    new_attrs->dilation = attrs->dilation;
    new_attrs->groups = attrs->groups;
    new_attrs->kernel_size = attrs->kernel_size;
    new_attrs->data_layout = attrs->data_layout;
    new_attrs->kernel_layout = attrs->kernel_layout;
    new_attrs->out_layout = attrs->out_layout;
    new_attrs->out_dtype = attrs->out_dtype;
    new_attrs->channels = new_channels;

    // An empty out_layout means "same as data_layout".
    const std::string& layout =
        new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout;
    channel_pos_ = layout.find('C');
    ICHECK_NE(channel_pos_, std::string::npos);

    return Call(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
  }

  /*!
   * \brief Checks whether argument `index` of two follower ops (elemwise or
   *        broadcast ops after the conv2d) can be concatenated channel-wise.
   *
   * Requires matching dtype and rank, and that the argument's channel dimension
   * equals the op's output channel dimension (i.e. channels are not broadcast).
   * All non-channel dimensions must be structurally equal between a and b.
   *
   * \note Relies on channel_pos_ having been set by MakeCombinedOp; the call
   *       order is managed by the ParallelOpCombiner base class.
   */
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
    StructuralEqual eq;
    auto ta = a->args[index]->type_as<TensorTypeNode>();
    auto tb = b->args[index]->type_as<TensorTypeNode>();
    auto toutput_a = a->type_as<TensorTypeNode>();
    auto toutput_b = b->type_as<TensorTypeNode>();

    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false;

    // Position of the 'C' dimension in the argument. The argument may have
    // fewer dims than the output (trailing dims aligned, broadcast-style).
    size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size();

    // Channel super-dimension should be present and not broadcasted
    if ((arg_channel_pos > channel_pos_) ||  // size_t overflow
        !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) ||
        !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_]))
      return false;

    for (size_t i = 0; i < ta->shape.size(); i++) {
      if (i == arg_channel_pos) continue;
      if (!eq(ta->shape[i], tb->shape[i])) return false;
    }
    return true;
  }

  /*!
   * \brief Rebuilds a follower op over the combined conv2d output.
   *
   * \param data the combined output of the previous depth (fed in at
   *        parent_index); every other argument is the channel-axis
   *        concatenation of the corresponding argument from each branch.
   */
  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
                                        size_t parent_index) {
    Array<Expr> new_args;
    const CallNode* call = branches[0][depth];
    size_t ndim = call->type_as<TensorTypeNode>()->shape.size();

    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        new_args.push_back(data);
        continue;
      }

      // Map the output channel position onto this (possibly lower-rank) arg.
      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
      size_t arg_channel_pos = channel_pos_ - ndim + arg_ndim;
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        tuple.push_back(branch[depth]->args[i]);
      }

      auto concat = MakeConcatenate(Tuple(tuple), arg_channel_pos);
      new_args.push_back(std::move(concat));
    }

    return Call(call->op, new_args, call->attrs, {});
  }

  /*!
   * \brief Slices the combined output back into per-branch pieces and records
   *        the substitution old-branch-output -> strided_slice(combined).
   *
   * Uses strided_slice in "size" mode: begin is the running channel offset and
   * end holds the slice *size* per axis (-1 meaning the full extent of that
   * axis), so each branch recovers exactly its own `channels` channels.
   */
  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                         ExprSubstMap* subst_map) {
    int64_t index = 0;

    for (const auto& branch : branches) {
      const CallNode* conv2d = branch[0];
      int64_t channels = GetConv2DSuperChannelsDim(conv2d);
      Array<Integer> begin;
      Array<Integer> end;
      // Axes before the channel axis are taken whole (offset 0, size -1).
      for (size_t i = 0; i < channel_pos_; i++) {
        begin.push_back(0);
        end.push_back(-1);
      }
      begin.push_back(index);
      index += channels;
      end.push_back(channels);
      Array<Integer> strides(begin.size(), 1);
      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
    }
  }

 private:
  /* \brief index of channel dimension, set by MakeCombinedOp from the output
   * (or data) layout and consumed by the follower-op helpers above. */
  size_t channel_pos_;

  /*!
   * \brief Concatenates all branch weights along the kernel layout's output
   *        channel axis ('O').
   * \return (combined weight expr, total output-channel count as an Int(32)
   *         constant).
   */
  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
    int64_t num_filters = 0;  // number of filters of the transformed weight
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto conv2d = branch[0];
      weights.push_back(conv2d->args[1]);
      auto channels = GetConv2DSuperChannelsDim(conv2d);
      num_filters += channels;
    }
    auto index =
        branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.operator std::string().find('O');
    ICHECK_NE(index, std::string::npos);
    return std::make_tuple(MakeConcatenate(Tuple(weights), index),
                           tir::make_const(DataType::Int(32), num_filters));
  }
};
204 | |
205 | /*! \brief Combine parallel conv2d if number of branches >= min_num_branches */ |
206 | Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { |
207 | return ParallelConv2DCombiner(min_num_branches).Combine(expr); |
208 | } |
209 | |
210 | namespace transform { |
211 | |
212 | Pass CombineParallelConv2D(uint64_t min_num_branches) { |
213 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
214 | [=](Function f, IRModule m, PassContext pc) { |
215 | return Downcast<Function>(CombineParallelConv2D(f, min_num_branches)); |
216 | }; |
217 | return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d" , {"InferType" }); |
218 | } |
219 | |
// Expose the pass constructor to the FFI (reachable from the Python frontend
// as relay._transform.CombineParallelConv2D).
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D").set_body_typed(CombineParallelConv2D);
221 | |
222 | } // namespace transform |
223 | |
224 | } // namespace relay |
225 | } // namespace tvm |
226 | |