/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file combine_parallel_batch_matmul.cc
 * \brief Combine parallel batch matmuls into a single one.
 *
 * This pass replaces batch_matmul ops that share the same lhs node with a
 * single batch_matmul. Elemwise and broadcast ops following batch_matmul
 * are also combined if possible.
 *
 * This prevents launching multiple kernels in networks with multiple
 * parallel branches, such as an Inception block.
 */
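
/*
 * A minimal before/after sketch, assuming hypothetical shapes and the only
 * layout handled below (transpose_a=false, transpose_b=true):
 *
 *   before:  y0 = batch_matmul(x, w0)            // (B, M, N0)
 *            y1 = batch_matmul(x, w1)            // (B, M, N1)
 *
 *   after:   w  = concatenate((w0, w1), axis=1)  // (B, N0+N1, K)
 *            y  = batch_matmul(x, w)             // (B, M, N0+N1)
 *            y0 = strided_slice(y, ...)          // first N0 of the last axis
 *            y1 = strided_slice(y, ...)          // next N1 of the last axis
 *
 *   where x: (B, M, K), w0: (B, N0, K), w1: (B, N1, K).
 */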

#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "./combine_parallel_op.h"
#include "./expr_subst.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
 public:
  explicit ParallelBatchMatmulCombiner(uint64_t min_num_branches)
      : ParallelOpCombiner("nn.batch_matmul", min_num_branches) {}

 protected:
  bool IsSupportedOp(const CallNode* n) { return true; }

  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
    StructuralEqual eq;
    const auto* attrs_a = a->attrs.as<BatchMatmulAttrs>();
    const auto* attrs_b = b->attrs.as<BatchMatmulAttrs>();
    ICHECK(attrs_a);
    ICHECK(attrs_b);
    const auto* rhs_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* rhs_b = b->args[1]->type_as<TensorTypeNode>();
    const auto* restype_a = a->type_as<TensorTypeNode>();
    const auto* restype_b = b->type_as<TensorTypeNode>();
    // shape[2] is the contraction axis; it is automatically consistent
    // if both are valid batch_matmul ops.

    // TODO(jcf94): Add full support of layout format
    if (!(attrs_a->transpose_a == false && attrs_a->transpose_b == true &&
          attrs_b->transpose_a == false && attrs_b->transpose_b == true)) {
      LOG(WARNING) << "For legacy reasons, this pass only supports"
                   << " (transpose_a=false, transpose_b=true); skip combining these two with:"
                   << " batch_matmul_a: " << attrs_a->transpose_a << ", " << attrs_a->transpose_b
                   << " batch_matmul_b: " << attrs_b->transpose_a << ", " << attrs_b->transpose_b;
      return false;
    }

    // Combinable iff dtypes match, both rhs are rank-3, and batch dims are equal.
    return eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) &&
           (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) &&
           eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype);
  }
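
  // Illustrative example of the check above, with hypothetical shapes: rhs
  // operands (4, 16, 64) and (4, 32, 64) with matching dtypes can be combined;
  // a pair with differing batch dims, e.g. (4, 16, 64) vs (8, 16, 64), cannot.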

  Call MakeCombinedOp(const Group& branches) {
    // All branches share the same lhs operand.
    Expr data = branches[0][0]->args[0];

    // Concatenate the rhs operands along axis 1; with transpose_b=true this is
    // the output-feature axis: (B, N_i, K) -> (B, sum(N_i), K).
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto call = branch[0];
      weights.push_back(call->args[1]);
    }
    Expr new_weight = MakeConcatenate(Tuple(weights), 1);

    const auto* origin_attrs = branches[0][0]->attrs.as<BatchMatmulAttrs>();
    ICHECK(origin_attrs);
    return Downcast<Call>(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype,
                                          origin_attrs->transpose_a, origin_attrs->transpose_b));
  }

  // All following-op arguments are treated as combinable.
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; }

  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
                                        size_t parent_index) {
    Array<Expr> new_args;
    const CallNode* call = branches[0][depth];

    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        // This argument is fed by the (already combined) parent op.
        new_args.push_back(data);
        continue;
      }

      // Concatenate the corresponding argument of every branch along the last
      // axis, matching the combined batch_matmul output layout.
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        tuple.push_back(branch[depth]->args[i]);
      }

      auto concat = MakeConcatenate(Tuple(tuple), -1);
      new_args.push_back(std::move(concat));
    }

    return Call(call->op, new_args, call->attrs, {});
  }
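
  // For instance (hypothetical following op): if every branch computes
  // add(batch_matmul_i, bias_i), the per-branch biases are concatenated along
  // the last axis and a single add is emitted on the combined output.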

  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                         ExprSubstMap* subst_map) {
    int64_t index = 0;

    for (const auto& branch : branches) {
      const CallNode* batch_matmul = branch[0];
      // With transpose_b=true, rhs shape[1] is this branch's output-feature count.
      auto feature_dim = batch_matmul->args[1]->type_as<TensorTypeNode>()->shape[1];
      auto fpp = tir::as_const_int(feature_dim);
      int64_t features = *fpp;
      // Slice this branch's features back out of the combined output:
      // begin = [0, 0, index], size = [-1, -1, features] with slice_mode="size",
      // where -1 means "take the whole axis".
      Array<Integer> begin;
      Array<Integer> end;
      for (size_t i = 0; i < 2; i++) {
        begin.push_back(0);
        end.push_back(-1);
      }
      begin.push_back(index);
      index += features;
      end.push_back(features);
      Array<Integer> strides(begin.size(), 1);
      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
    }
  }
};

/*! \brief Combine parallel batch_matmul ops if the number of branches >= min_num_branches. */
Expr CombineParallelBatchMatmul(const Expr& expr, uint64_t min_num_branches) {
  return ParallelBatchMatmulCombiner(min_num_branches).Combine(expr);
}

namespace transform {

Pass CombineParallelBatchMatmul(uint64_t min_num_branches) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(CombineParallelBatchMatmul(f, min_num_branches));
      };
  return CreateFunctionPass(pass_func, 4, "CombineParallelBatchMatmul", {"InferType"});
}
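
// A minimal usage sketch (illustrative, not part of this file):
//
//   IRModule mod = ...;  // a type-checked Relay module
//   mod = transform::CombineParallelBatchMatmul(/*min_num_branches=*/3)(mod);
//
// The same pass is exposed to Python as relay.transform.CombineParallelBatchMatmul.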

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelBatchMatmul")
    .set_body_typed(CombineParallelBatchMatmul);

}  // namespace transform

}  // namespace relay
}  // namespace tvm