/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file combine_parallel_dense.cc
 * \brief Combine parallel dense ops into a single dense op.
 *
 * This pass replaces dense ops that share the same input node and the same
 * shape, and that do not have "units" defined, with a single batch matrix
 * multiplication. Each input of the new batch_matmul is a stack of the
 * corresponding original inputs. Elemwise and broadcast ops following dense
 * are also combined where possible.
 *
 * This avoids launching multiple kernels in networks with multiple
 * dense branches, such as BERT.
 */
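/*
 * For example (a sketch with two parallel branches; %x, %w1, %w2 are
 * placeholder names, not identifiers from this file):
 *
 *   %0 = nn.dense(%x, %w1)          %0 = stack((%x, %x))
 *   %1 = nn.dense(%x, %w2)    ==>   %1 = stack((%w1, %w2))
 *                                   %2 = nn.batch_matmul(%0, %1)
 */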

#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "./combine_parallel_op_batch.h"
#include "./expr_subst.h"
#include "./pattern_utils.h"

namespace tvm {
namespace relay {

/*
 * Class that finds and combines parallel dense ops into one batch_matmul op.
 */
class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
 public:
  explicit ParallelDenseToBatchCombiner(uint64_t min_num_branches)
      : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}

 protected:
  Call MakeCombinedOp(const Group& branches) override {
    Array<Expr> new_args;
    size_t num_args = branches[0][0]->args.size();
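    // Stack the i-th argument of every branch along a new leading axis, so
    // the branch inputs and the branch weights each become one batched tensor.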
    for (size_t i = 0; i < num_args; i++) {
      Array<Expr> arg_from_all_branches;
      for (const auto& branch : branches) {
        arg_from_all_branches.push_back(branch[0]->args[i]);
      }

      new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0));
    }

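    // nn.dense takes exactly two inputs: data and weight.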
    ICHECK_EQ(num_args, 2);
    const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
    ICHECK(origin_attrs);
    return Downcast<Call>(
        MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype, false, true));
  }

  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) override {
    StructuralEqual eq;
    const auto* attrs_a = a->attrs.as<DenseAttrs>();
    const auto* attrs_b = b->attrs.as<DenseAttrs>();
    ICHECK(attrs_a);
    ICHECK(attrs_b);
    const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();

    return eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(weight_a->shape[0], weight_b->shape[0]) && eq(weight_a->shape[1], weight_b->shape[1]);
  }
};

/*
 * Class that finds and combines parallel dense ops into one dense op whose
 * number of output units equals the sum of the units of each sub-op.
 */
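/*
 * For example (a sketch; %x, %w1 with m rows, and %w2 with n rows are
 * placeholder names, not identifiers from this file):
 *
 *   %0 = nn.dense(%x, %w1)          %w = concatenate((%w1, %w2), axis=0)
 *   %1 = nn.dense(%x, %w2)    ==>   %2 = nn.dense(%x, %w)  // units = m + n
 *
 * followed by one strided_slice per branch to recover %0 and %1.
 */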
class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
 public:
  explicit ParallelDenseToDenseCombiner(uint64_t min_num_branches)
      : ParallelOpCombiner("nn.dense", min_num_branches) {}

 protected:
  bool IsSupportedOp(const CallNode* n) override { return true; }

  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) override {
    StructuralEqual eq;
    const auto* attrs_a = a->attrs.as<DenseAttrs>();
    const auto* attrs_b = b->attrs.as<DenseAttrs>();
    const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
    ICHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
    // output dims (weight->shape[0]) can be different
    return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
  }

  Call MakeCombinedOp(const Group& branches) override {
    const Op& dense_op = Op::Get("nn.dense");
    Expr input = branches[0][0]->args[0];
    // concat all weights into one
    auto [new_weight, new_output_dims] = TransformWeight(branches);
    const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
    ICHECK(origin_attrs);
    const auto dense_attrs = make_object<DenseAttrs>();
    dense_attrs->units = new_output_dims;
    dense_attrs->out_dtype = origin_attrs->out_dtype;
    return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
  }

  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) override {
    StructuralEqual eq;
    auto ta = a->args[index]->type_as<TensorTypeNode>();
    auto tb = b->args[index]->type_as<TensorTypeNode>();
    auto toutput_a = a->type_as<TensorTypeNode>();
    auto toutput_b = b->type_as<TensorTypeNode>();
    ICHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);
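    // Following args are concatenated along their last axis, so dtypes, ranks,
    // and every dimension except the last must agree between branches.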

    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
      return false;
    }
    if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
      return false;  // not broadcast/elemwise
    }
    if (ta->shape.size() > 0) {
      for (size_t i = 0; i < ta->shape.size() - 1; i++) {
        // shape dims must match except last dim
        if (!eq(ta->shape[i], tb->shape[i])) return false;
      }
    }
    return true;
  }

  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
                                        size_t parent_index) override {
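    // For each argument other than the one fed by the combined parent op,
    // concatenate the per-branch args along the channel (last) axis, widening
    // scalar or size-1 last dims first so broadcasting stays valid after the
    // concatenation.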
    Array<Expr> new_args;
    const CallNode* call = branches[0][depth];
    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        new_args.push_back(data);
        continue;
      }
      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
      size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        auto parent = branch[depth]->args[parent_index];
        auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
        auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
        ICHECK(out_dim != nullptr);

        auto arg = branch[depth]->args[i];
        auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
        bool repeat_last_dim = false;
        if (arg_ndim == 0) {
          repeat_last_dim = true;
          arg = MakeExpandDims(arg, -1, 1);
        } else {
          auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
          ICHECK(arg_last_dim != nullptr);
          if (*out_dim > 1 && *arg_last_dim == 1) {
            repeat_last_dim = true;
          }
        }
        if (repeat_last_dim) {
          // ensure broadcast is valid after concat args
          arg = MakeRepeat(arg, *out_dim, concat_axis);
        }
        tuple.push_back(arg);
      }
      auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
      new_args.push_back(std::move(concat));
    }
    return Call(call->op, new_args, call->attrs, {});
  }

  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                         ExprSubstMap* subst_map) override {
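    // The combined op lays each branch's output channels side by side along
    // one axis; recover every branch's original output by slicing out its
    // channel range and substituting the slice for the branch output.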
    int index = 0;
    const auto dense_op = Op::Get("nn.dense");
    for (const auto& branch : branches) {
      const CallNode* call = branch[depth];
      auto& out_shape = call->type_as<TensorTypeNode>()->shape;

      const CallNode* dense = branch[0];
      ICHECK(dense->op.same_as(dense_op));
      auto& dense_shape = dense->type_as<TensorTypeNode>()->shape;
      auto dense_out_dims = tir::as_const_int(dense_shape[1]);
      ICHECK(dense_out_dims != nullptr);

      // dense can be followed by shape-changing operations, so the slicing axis is
      // not necessarily the last one.
      // TODO(masahi): The following logic is incorrect if (1) there is no axis in
      // out_shape[i] that directly corresponds to the output channel of dense or (2) there
      // is another axis that happens to have the same size as the output channel of dense.
      // Such cases might arise due to reshape / transpose / split etc. Revisit this logic
      // when we encounter them in practice.
      int slice_axis = -1;
      for (int i = static_cast<int>(out_shape.size()) - 1; i >= 0; --i) {
        ICHECK(tir::as_const_int(out_shape[i]));
        if (*tir::as_const_int(out_shape[i]) == *dense_out_dims) {
          slice_axis = i;
          break;
        }
      }
      ICHECK(slice_axis != -1);

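      // With slice_mode="size", `end` holds slice lengths rather than end
      // indices; -1 keeps the full extent of an axis.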
      Array<Integer> begin(out_shape.size(), 0);
      Array<Integer> end(out_shape.size(), -1);
      Array<Integer> strides(out_shape.size(), 1);
      begin.Set(slice_axis, index);
      end.Set(slice_axis, *dense_out_dims);
      index += *dense_out_dims;
      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
    }
  }

 private:
  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
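    // Concatenate every branch's weight matrix along axis 0, the output-unit
    // axis of nn.dense, and accumulate the combined number of output units.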
    int64_t out_dims = 0;
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto weight = branch[0]->args[1];
      weights.push_back(weight);
      out_dims += *tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
    }
    return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
                           tir::make_const(DataType::Int(32), out_dims));
  }
};

/*! \brief Combine parallel dense ops if the number of branches >= min_num_branches. */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
  if (to_batch) {
    return ParallelDenseToBatchCombiner(min_num_branches).Combine(expr);
  } else {
    return ParallelDenseToDenseCombiner(min_num_branches).Combine(expr);
  }
}

namespace transform {

Pass CombineParallelDense(uint64_t min_num_branches, bool to_batch_matmul) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(CombineParallelDense(f, min_num_branches, to_batch_matmul));
      };
  return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
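// Example usage (a sketch, assuming `mod` is a type-checked IRModule): combine
// groups of at least three parallel dense branches into one batch_matmul:
//
//   Pass pass = CombineParallelDense(3, /*to_batch_matmul=*/true);
//   mod = pass(mod);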

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense").set_body_typed(CombineParallelDense);

}  // namespace transform

}  // namespace relay
}  // namespace tvm