1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file combine_parallel_op_batch.cc
23 * \brief Combine parallel ops into a single batch op.
24 *
25 * This pass replaces ops that share the same input node and same shape
26 * with a single op that takes in batched input. The inputs of the new
27 * batched op are the stack of the original inputs. Elementwise and
28 * broadcast ops following the original op are also stacked
29 * and fused if possible. For example:
30 *
31 * data
32 * / \
33 * add (2,2) add (2,2)
34 * | |
35 * elemwise (2,2) elemwise (2,2)
36 * | |
37 *
38 * Would become:
39 *
40 * data
41 * |
42 * add+elemwise (2,2,2)
43 * / \
44 *
45 */
46
47#include "./combine_parallel_op_batch.h"
48
49#include <tvm/relay/analysis.h>
50#include <tvm/relay/attrs/nn.h>
51#include <tvm/relay/attrs/transform.h>
52#include <tvm/relay/expr_functor.h>
53#include <tvm/relay/op_attr_types.h>
54#include <tvm/relay/transform.h>
55
56#include <unordered_map>
57#include <unordered_set>
58
59#include "./combine_parallel_op.h"
60#include "./expr_subst.h"
61#include "pattern_utils.h"
62
63namespace tvm {
64namespace relay {
65
/*!
 * \brief Construct a combiner that replaces parallel calls to `op_name` with a
 *        single call to `batch_op_name` over stacked inputs.
 * \param op_name Name of the op whose parallel occurrences are combined.
 * \param batch_op_name Name of the batched op that replaces them.
 * \param min_num_branches Minimum number of parallel branches required before
 *        combining is applied (forwarded to the base ParallelOpCombiner).
 */
ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name,
                                                 const std::string& batch_op_name,
                                                 uint64_t min_num_branches)
    : ParallelOpCombiner(op_name, min_num_branches), batch_op_name_(batch_op_name) {}
70
71bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { return true; }
72
73bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) {
74 if (a->args.size() != b->args.size()) {
75 return false;
76 }
77
78 StructuralEqual eq;
79 for (size_t i = 0; i < a->args.size(); i++) {
80 auto ta = a->args[i]->type_as<TensorTypeNode>();
81 auto tb = b->args[i]->type_as<TensorTypeNode>();
82 if (ta->shape.size() != tb->shape.size() || !eq(ta->dtype, tb->dtype)) {
83 return false;
84 }
85
86 for (size_t j = 0; j < ta->shape.size(); j++) {
87 if (!eq(ta->shape[j], tb->shape[j])) {
88 return false;
89 }
90 }
91 }
92
93 return true;
94}
95
96Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
97 const Op& batch_op = Op::Get(batch_op_name_);
98
99 Array<Expr> new_args;
100 size_t num_args = branches[0][0]->args.size();
101 for (size_t i = 0; i < num_args; i++) {
102 Array<Expr> arg_from_all_branches;
103 for (const auto& branch : branches) {
104 arg_from_all_branches.push_back(branch[0]->args[i]);
105 }
106
107 new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0));
108 }
109
110 return Call(batch_op, new_args, Attrs(), {});
111}
112
113bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
114 StructuralEqual eq;
115 auto ta = a->args[index]->type_as<TensorTypeNode>();
116 auto tb = b->args[index]->type_as<TensorTypeNode>();
117
118 if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false;
119
120 for (size_t i = 0; i < ta->shape.size(); i++) {
121 if (!eq(ta->shape[i], tb->shape[i])) return false;
122 }
123 return true;
124}
125
126Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data,
127 const Group& branches, size_t depth,
128 size_t parent_index) {
129 Array<Expr> new_args;
130 const CallNode* call = branches[0][depth];
131
132 for (size_t i = 0; i < call->args.size(); i++) {
133 if (i == parent_index) {
134 new_args.push_back(data);
135 continue;
136 }
137
138 Array<Expr> tuple;
139 for (const auto& branch : branches) {
140 // if the shape of the arg is of shape (j,),
141 // expand it to (1,j) so it can be properly broadcasted.
142 Expr arg = branch[depth]->args[i];
143 const TensorTypeNode* arg_tensor = arg->type_as<TensorTypeNode>();
144 if (arg_tensor->shape.size() == 1) {
145 Expr expanded_arg = MakeExpandDims(arg, 0, 1);
146 tuple.push_back(expanded_arg);
147 } else {
148 tuple.push_back(arg);
149 }
150 }
151
152 auto stack = MakeStack(Tuple(tuple), 0);
153 new_args.push_back(std::move(stack));
154 }
155
156 return Call(call->op, new_args, call->attrs, {});
157}
158
159void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches,
160 size_t depth, ExprSubstMap* subst_map) {
161 int index = 0;
162 auto split = MakeSplit(data, Integer(branches.size()), 0);
163 for (const auto& branch : branches) {
164 auto split_data = TupleGetItem(split, index++);
165 auto squeezed_data = MakeSqueeze(split_data, {0});
166 subst_map->insert({GetRef<Expr>(branch[depth]), squeezed_data});
167 }
168}
169
170/*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */
171Expr CombineParallelOpBatch(const Expr& expr, const std::string& op_name,
172 const std::string& batch_op_name, uint64_t min_num_branches) {
173 return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr);
174}
175
176namespace transform {
177
178Pass CombineParallelOpBatch(const String& op_name, const String& batch_op_name,
179 uint64_t min_num_branches) {
180 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
181 [=](Function f, IRModule m, PassContext pc) {
182 return Downcast<Function>(
183 CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches));
184 };
185 return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
186}
187
188TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
189 .set_body_typed(CombineParallelOpBatch);
190
191} // namespace transform
192
193} // namespace relay
194} // namespace tvm
195