1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file combine_parallel_batch_matmul.cc |
23 | * \brief Combine parallel batch matmuls into a single one. |
24 | * |
 * This pass replaces batch_matmul ops that share the same lhs node with a
 * single batch_matmul. Elemwise and broadcast ops following batch_matmul are
 * also combined if possible.
28 | * |
29 | * This prevents launching multiple kernels in networks with multiple |
30 | * convolution branches, such as Inception block. |
31 | */ |
32 | |
33 | #include <tvm/relay/analysis.h> |
34 | #include <tvm/relay/attrs/nn.h> |
35 | #include <tvm/relay/attrs/transform.h> |
36 | #include <tvm/relay/expr_functor.h> |
37 | #include <tvm/relay/op_attr_types.h> |
38 | #include <tvm/relay/transform.h> |
39 | |
40 | #include <unordered_map> |
41 | #include <unordered_set> |
42 | |
43 | #include "./combine_parallel_op.h" |
44 | #include "./expr_subst.h" |
45 | #include "pattern_utils.h" |
46 | |
47 | namespace tvm { |
48 | namespace relay { |
49 | |
50 | class ParallelBatchMatmulCombiner : public ParallelOpCombiner { |
51 | public: |
52 | explicit ParallelBatchMatmulCombiner(uint64_t min_num_branches) |
53 | : ParallelOpCombiner("nn.batch_matmul" , min_num_branches) {} |
54 | |
55 | protected: |
56 | bool IsSupportedOp(const CallNode* n) { return true; } |
57 | |
58 | bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { |
59 | StructuralEqual eq; |
60 | const auto* attrs_a = a->attrs.as<BatchMatmulAttrs>(); |
61 | const auto* attrs_b = b->attrs.as<BatchMatmulAttrs>(); |
62 | ICHECK(attrs_a); |
63 | ICHECK(attrs_b); |
64 | const auto* rhs_a = a->args[1]->type_as<TensorTypeNode>(); |
65 | const auto* rhs_b = b->args[1]->type_as<TensorTypeNode>(); |
66 | const auto* restype_a = a->type_as<TensorTypeNode>(); |
67 | const auto* restype_b = b->type_as<TensorTypeNode>(); |
68 | // shape[2] is the contraction axis and automatically consistent |
69 | // if it were valid batch_matmul ops |
70 | |
71 | // TODO(jcf94): Add full support of layout format |
72 | if (!(attrs_a->transpose_a == false && attrs_a->transpose_b == true && |
73 | attrs_b->transpose_a == false && attrs_b->transpose_b == true)) { |
74 | LOG(WARNING) << "For legacy reason, this pass only supports" |
75 | << " (transpose_a=false, transpose_b=true) now, skip combining these two with:" |
76 | << " batch_matmul_a: " << attrs_a->transpose_a << ", " << attrs_a->transpose_b |
77 | << " batch_matmul_b: " << attrs_b->transpose_a << ", " << attrs_b->transpose_b; |
78 | return false; |
79 | } |
80 | |
81 | auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) && |
82 | (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) && |
83 | eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype); |
84 | return res; |
85 | } |
86 | |
87 | Call MakeCombinedOp(const Group& branches) { |
88 | Expr data = branches[0][0]->args[0]; |
89 | |
90 | Array<Expr> weights; |
91 | for (const auto& branch : branches) { |
92 | auto call = branch[0]; |
93 | weights.push_back(call->args[1]); |
94 | } |
95 | Expr new_weight = MakeConcatenate(Tuple(weights), 1); |
96 | |
97 | const auto* origin_attrs = branches[0][0]->attrs.as<BatchMatmulAttrs>(); |
98 | ICHECK(origin_attrs); |
99 | return Downcast<Call>(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype, |
100 | origin_attrs->transpose_a, origin_attrs->transpose_b)); |
101 | } |
102 | |
103 | bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; } |
104 | |
105 | Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, |
106 | size_t parent_index) { |
107 | Array<Expr> new_args; |
108 | const CallNode* call = branches[0][depth]; |
109 | |
110 | for (size_t i = 0; i < call->args.size(); i++) { |
111 | if (i == parent_index) { |
112 | new_args.push_back(data); |
113 | continue; |
114 | } |
115 | |
116 | Array<Expr> tuple; |
117 | for (const auto& branch : branches) { |
118 | tuple.push_back(branch[depth]->args[i]); |
119 | } |
120 | |
121 | auto concat = MakeConcatenate(Tuple(tuple), -1); |
122 | new_args.push_back(std::move(concat)); |
123 | } |
124 | |
125 | return Call(call->op, new_args, call->attrs, {}); |
126 | } |
127 | |
128 | void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, |
129 | ExprSubstMap* subst_map) { |
130 | int64_t index = 0; |
131 | |
132 | for (const auto& branch : branches) { |
133 | const CallNode* batch_matmul = branch[0]; |
134 | auto feature_dim = batch_matmul->args[1]->type_as<TensorTypeNode>()->shape[1]; |
135 | auto fpp = tir::as_const_int(feature_dim); |
136 | int64_t features = *fpp; |
137 | Array<Integer> begin; |
138 | Array<Integer> end; |
139 | for (size_t i = 0; i < 2; i++) { |
140 | begin.push_back(0); |
141 | end.push_back(-1); |
142 | } |
143 | begin.push_back(index); |
144 | index += features; |
145 | end.push_back(features); |
146 | Array<Integer> strides(begin.size(), 1); |
147 | auto slice = MakeStridedSlice(data, begin, end, strides, "size" ); |
148 | subst_map->insert({GetRef<Expr>(branch[depth]), slice}); |
149 | } |
150 | } |
151 | }; |
152 | |
153 | /*! \brief Combine parallel batch_matmul if number of branches >= min_num_branches */ |
154 | Expr CombineParallelBatchMatmul(const Expr& expr, uint64_t min_num_branches) { |
155 | return ParallelBatchMatmulCombiner(min_num_branches).Combine(expr); |
156 | } |
157 | |
158 | namespace transform { |
159 | |
160 | Pass CombineParallelBatchMatmul(uint64_t min_num_branches) { |
161 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
162 | [=](Function f, IRModule m, PassContext pc) { |
163 | return Downcast<Function>(CombineParallelBatchMatmul(f, min_num_branches)); |
164 | }; |
165 | return CreateFunctionPass(pass_func, 4, "CombineParallelBatchMatmul" , {"InferType" }); |
166 | } |
167 | |
// Expose the pass constructor to the FFI (e.g. the Python frontend) under
// the global name "relay._transform.CombineParallelBatchMatmul".
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelBatchMatmul")
    .set_body_typed(CombineParallelBatchMatmul);
170 | |
171 | } // namespace transform |
172 | |
173 | } // namespace relay |
174 | } // namespace tvm |
175 | |