/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file combine_parallel_conv2d.cc
 * \brief Combine parallel 2d convolutions into a single convolution.
 *
 * This pass replaces convolutions that share the same input node and the same
 * arguments (except that the number of output channels can be different) with a
 * single convolution. The weight of the new 2d convolution is the concatenation
 * of the original weights. Elemwise and broadcast ops following conv2d are also
 * combined if possible.
 *
 * This prevents launching multiple kernels in networks with multiple
 * convolution branches, such as an Inception block.
 */
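
/*
 * Schematically (an illustrative sketch, assuming NCHW data layout and OIHW
 * kernel layout; not exact Relay printer output), two parallel convolutions
 *
 *   %0 = nn.conv2d(%data, %w1, channels=32, ...);
 *   %1 = nn.conv2d(%data, %w2, channels=64, ...);
 *
 * are rewritten into a single convolution over the concatenated weights, and
 * each branch's output is recovered by a strided_slice along the channel axis:
 *
 *   %w = concatenate((%w1, %w2), axis=0);  // axis 0 is 'O' of OIHW
 *   %c = nn.conv2d(%data, %w, channels=96, ...);
 *   %0 = strided_slice(%c, begin=[0, 0], end=[-1, 32], slice_mode="size");
 *   %1 = strided_slice(%c, begin=[0, 32], end=[-1, 64], slice_mode="size");
 */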

#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "./combine_parallel_op.h"
#include "./expr_subst.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

class ParallelConv2DCombiner : public ParallelOpCombiner {
 public:
  explicit ParallelConv2DCombiner(uint64_t min_num_branches)
      : ParallelOpCombiner("nn.conv2d", min_num_branches) {}

 protected:
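  // Only conv2d calls with groups == 1 are combined; grouped (e.g. depthwise)
  // convolutions are not supported by this pass.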
  bool IsSupportedOp(const CallNode* n) { return n->attrs.as<Conv2DAttrs>()->groups == 1; }

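  // Two conv2d calls can be combined when all attributes except the number of
  // output channels match and the kernels have the same spatial extents
  // (H and W of the weight shape normalized to OIHW).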
  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
    StructuralEqual eq;
    const Layout kOIHW("OIHW");
    const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
    const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
    ICHECK(attrs_a);
    ICHECK(attrs_b);
    const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
    const auto shape_a =
        tir::BijectiveLayout(Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
    const auto shape_b =
        tir::BijectiveLayout(Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);

    return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
           eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
           eq(attrs_a->data_layout, attrs_b->data_layout) &&
           eq(attrs_a->kernel_layout, attrs_b->kernel_layout) &&
           eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
           eq(shape_a[3], shape_b[3]);
  }

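  // Build the fused conv2d: the weights of all branches are concatenated along
  // the output channel axis, and the attributes of the first branch are reused
  // with `channels` updated to the combined output channel count.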
  Call MakeCombinedOp(const Group& branches) {
    const Op& conv2d = Op::Get("nn.conv2d");
    Expr data = branches[0][0]->args[0];
    auto [new_weight, new_channels] = TransformWeight(branches);

    const CallNode* group_root = branches[0][0];
    const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
    ICHECK(attrs);
    const auto new_attrs = make_object<Conv2DAttrs>();
    new_attrs->strides = attrs->strides;
    new_attrs->padding = attrs->padding;
    new_attrs->dilation = attrs->dilation;
    new_attrs->groups = attrs->groups;
    new_attrs->kernel_size = attrs->kernel_size;
    new_attrs->data_layout = attrs->data_layout;
    new_attrs->kernel_layout = attrs->kernel_layout;
    new_attrs->out_layout = attrs->out_layout;
    new_attrs->out_dtype = attrs->out_dtype;
    new_attrs->channels = new_channels;

    const std::string& layout =
        new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout;
    channel_pos_ = layout.find('C');
    ICHECK_NE(channel_pos_, std::string::npos);

    return Call(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
  }

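  // An extra argument of a following elemwise/broadcast op is compatible if its
  // dtype and rank match across the two branches, its channel dimension lines
  // up with the conv2d output channels (i.e. it is not broadcast over them),
  // and all of its non-channel dimensions agree between the branches.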
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
    StructuralEqual eq;
    auto ta = a->args[index]->type_as<TensorTypeNode>();
    auto tb = b->args[index]->type_as<TensorTypeNode>();
    auto toutput_a = a->type_as<TensorTypeNode>();
    auto toutput_b = b->type_as<TensorTypeNode>();

    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false;

    // Position of the 'C' dimension in the argument
    size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size();

    // The channel super-dimension should be present and not broadcast
    if ((arg_channel_pos > channel_pos_) ||  // size_t overflow
        !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) ||
        !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_]))
      return false;

    for (size_t i = 0; i < ta->shape.size(); i++) {
      if (i == arg_channel_pos) continue;
      if (!eq(ta->shape[i], tb->shape[i])) return false;
    }
    return true;
  }

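  // Rebuild a following op on top of the combined conv2d output: the argument
  // at `parent_index` is replaced by the combined data, and every other
  // argument becomes the concatenation of the corresponding branch arguments
  // along the channel axis.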
  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
                                        size_t parent_index) {
    Array<Expr> new_args;
    const CallNode* call = branches[0][depth];
    size_t ndim = call->type_as<TensorTypeNode>()->shape.size();

    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        new_args.push_back(data);
        continue;
      }

      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
      size_t arg_channel_pos = channel_pos_ - ndim + arg_ndim;
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        tuple.push_back(branch[depth]->args[i]);
      }

      auto concat = MakeConcatenate(Tuple(tuple), arg_channel_pos);
      new_args.push_back(std::move(concat));
    }

    return Call(call->op, new_args, call->attrs, {});
  }

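  // Split the combined output back into per-branch results using strided_slice
  // in "size" mode: `begin` carries the branch's channel offset and `end`
  // carries slice sizes, where -1 means "all remaining elements of the axis".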
  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                         ExprSubstMap* subst_map) {
    int64_t index = 0;

    for (const auto& branch : branches) {
      const CallNode* conv2d = branch[0];
      int64_t channels = GetConv2DSuperChannelsDim(conv2d);
      Array<Integer> begin;
      Array<Integer> end;
      for (size_t i = 0; i < channel_pos_; i++) {
        begin.push_back(0);
        end.push_back(-1);
      }
      begin.push_back(index);
      index += channels;
      end.push_back(channels);
      Array<Integer> strides(begin.size(), 1);
      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
    }
  }

 private:
  /*! \brief Index of the channel dimension in the output layout. */
  size_t channel_pos_;

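  // Concatenate the branch weights along the output channel axis ('O' in the
  // kernel layout) and return the new weight together with the total number of
  // filters as a constant IndexExpr.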
  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
    int64_t num_filters = 0;  // number of filters of the transformed weight
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto conv2d = branch[0];
      weights.push_back(conv2d->args[1]);
      auto channels = GetConv2DSuperChannelsDim(conv2d);
      num_filters += channels;
    }
    auto index =
        branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.operator std::string().find('O');
    ICHECK_NE(index, std::string::npos);
    return std::make_tuple(MakeConcatenate(Tuple(weights), index),
                           tir::make_const(DataType::Int(32), num_filters));
  }
};

/*! \brief Combine parallel conv2d if the number of branches >= min_num_branches */
Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
}

namespace transform {

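// From Python this pass is exposed as tvm.relay.transform.CombineParallelConv2D;
// a typical (illustrative) invocation on an IRModule `mod` would be:
//   mod = relay.transform.CombineParallelConv2D(min_num_branches=3)(mod)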
Pass CombineParallelConv2D(uint64_t min_num_branches) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
      };
  return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D").set_body_typed(CombineParallelConv2D);

}  // namespace transform

}  // namespace relay
}  // namespace tvm