1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file combine_parallel_op_batch.cc |
23 | * \brief Combine parallel ops into a single batch op. |
24 | * |
25 | * This pass replaces ops that share the same input node and same shape |
26 | * with a single op that takes in batched input. The inputs of the new |
27 | * batched op are the stack of the original inputs. Elementwise and |
28 | * broadcast ops following the original op are also stacked |
29 | * and fused if possible. For example: |
30 | * |
31 | * data |
32 | * / \ |
33 | * add (2,2) add (2,2) |
34 | * | | |
35 | * elemwise (2,2) elemwise (2,2) |
36 | * | | |
37 | * |
38 | * Would become: |
39 | * |
40 | * data |
41 | * | |
42 | * add+elemwise (2,2,2) |
43 | * / \ |
44 | * |
45 | */ |
46 | |
47 | #include "./combine_parallel_op_batch.h" |
48 | |
49 | #include <tvm/relay/analysis.h> |
50 | #include <tvm/relay/attrs/nn.h> |
51 | #include <tvm/relay/attrs/transform.h> |
52 | #include <tvm/relay/expr_functor.h> |
53 | #include <tvm/relay/op_attr_types.h> |
54 | #include <tvm/relay/transform.h> |
55 | |
56 | #include <unordered_map> |
57 | #include <unordered_set> |
58 | |
59 | #include "./combine_parallel_op.h" |
60 | #include "./expr_subst.h" |
61 | #include "pattern_utils.h" |
62 | |
63 | namespace tvm { |
64 | namespace relay { |
65 | |
/*!
 * \brief Construct a combiner that replaces groups of parallel op_name calls
 *        with a single batch_op_name call.
 * \param op_name Name of the operator whose parallel occurrences are combined.
 * \param batch_op_name Name of the batched operator emitted for each group.
 * \param min_num_branches Minimum number of parallel branches required before
 *        combining takes place (forwarded to the base ParallelOpCombiner).
 */
ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name,
                                                 const std::string& batch_op_name,
                                                 uint64_t min_num_branches)
    : ParallelOpCombiner(op_name, min_num_branches), batch_op_name_(batch_op_name) {}
70 | |
71 | bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { return true; } |
72 | |
73 | bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) { |
74 | if (a->args.size() != b->args.size()) { |
75 | return false; |
76 | } |
77 | |
78 | StructuralEqual eq; |
79 | for (size_t i = 0; i < a->args.size(); i++) { |
80 | auto ta = a->args[i]->type_as<TensorTypeNode>(); |
81 | auto tb = b->args[i]->type_as<TensorTypeNode>(); |
82 | if (ta->shape.size() != tb->shape.size() || !eq(ta->dtype, tb->dtype)) { |
83 | return false; |
84 | } |
85 | |
86 | for (size_t j = 0; j < ta->shape.size(); j++) { |
87 | if (!eq(ta->shape[j], tb->shape[j])) { |
88 | return false; |
89 | } |
90 | } |
91 | } |
92 | |
93 | return true; |
94 | } |
95 | |
96 | Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) { |
97 | const Op& batch_op = Op::Get(batch_op_name_); |
98 | |
99 | Array<Expr> new_args; |
100 | size_t num_args = branches[0][0]->args.size(); |
101 | for (size_t i = 0; i < num_args; i++) { |
102 | Array<Expr> arg_from_all_branches; |
103 | for (const auto& branch : branches) { |
104 | arg_from_all_branches.push_back(branch[0]->args[i]); |
105 | } |
106 | |
107 | new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0)); |
108 | } |
109 | |
110 | return Call(batch_op, new_args, Attrs(), {}); |
111 | } |
112 | |
113 | bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { |
114 | StructuralEqual eq; |
115 | auto ta = a->args[index]->type_as<TensorTypeNode>(); |
116 | auto tb = b->args[index]->type_as<TensorTypeNode>(); |
117 | |
118 | if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; |
119 | |
120 | for (size_t i = 0; i < ta->shape.size(); i++) { |
121 | if (!eq(ta->shape[i], tb->shape[i])) return false; |
122 | } |
123 | return true; |
124 | } |
125 | |
126 | Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, |
127 | const Group& branches, size_t depth, |
128 | size_t parent_index) { |
129 | Array<Expr> new_args; |
130 | const CallNode* call = branches[0][depth]; |
131 | |
132 | for (size_t i = 0; i < call->args.size(); i++) { |
133 | if (i == parent_index) { |
134 | new_args.push_back(data); |
135 | continue; |
136 | } |
137 | |
138 | Array<Expr> tuple; |
139 | for (const auto& branch : branches) { |
140 | // if the shape of the arg is of shape (j,), |
141 | // expand it to (1,j) so it can be properly broadcasted. |
142 | Expr arg = branch[depth]->args[i]; |
143 | const TensorTypeNode* arg_tensor = arg->type_as<TensorTypeNode>(); |
144 | if (arg_tensor->shape.size() == 1) { |
145 | Expr expanded_arg = MakeExpandDims(arg, 0, 1); |
146 | tuple.push_back(expanded_arg); |
147 | } else { |
148 | tuple.push_back(arg); |
149 | } |
150 | } |
151 | |
152 | auto stack = MakeStack(Tuple(tuple), 0); |
153 | new_args.push_back(std::move(stack)); |
154 | } |
155 | |
156 | return Call(call->op, new_args, call->attrs, {}); |
157 | } |
158 | |
159 | void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, |
160 | size_t depth, ExprSubstMap* subst_map) { |
161 | int index = 0; |
162 | auto split = MakeSplit(data, Integer(branches.size()), 0); |
163 | for (const auto& branch : branches) { |
164 | auto split_data = TupleGetItem(split, index++); |
165 | auto squeezed_data = MakeSqueeze(split_data, {0}); |
166 | subst_map->insert({GetRef<Expr>(branch[depth]), squeezed_data}); |
167 | } |
168 | } |
169 | |
170 | /*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */ |
171 | Expr CombineParallelOpBatch(const Expr& expr, const std::string& op_name, |
172 | const std::string& batch_op_name, uint64_t min_num_branches) { |
173 | return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr); |
174 | } |
175 | |
176 | namespace transform { |
177 | |
178 | Pass CombineParallelOpBatch(const String& op_name, const String& batch_op_name, |
179 | uint64_t min_num_branches) { |
180 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
181 | [=](Function f, IRModule m, PassContext pc) { |
182 | return Downcast<Function>( |
183 | CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches)); |
184 | }; |
185 | return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch" , {"InferType" }); |
186 | } |
187 | |
// Expose the pass constructor to the frontend (reachable from Python as
// relay._transform.CombineParallelOpBatch).
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
    .set_body_typed(CombineParallelOpBatch);
190 | |
191 | } // namespace transform |
192 | |
193 | } // namespace relay |
194 | } // namespace tvm |
195 | |