1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file combine_parallel_op.h
23 * \brief Abstract class to combine parallel ops and their successive element-wise ops.
24 */
25#ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
26#define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
27
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "./expr_subst.h"
#include "pattern_utils.h"
42
43namespace tvm {
44namespace relay {
45
46using Branch = std::vector<const CallNode*>;
47using Group = std::vector<Branch>;
48using FIsSupportedOp = std::function<bool(const CallNode* n)>;
49using FAreCompatibleOps = std::function<bool(const CallNode* a, const CallNode* b)>;
50using ExprSubstMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>;
51
52/*
53 * Class to find parallel branches starting with op that are
54 * grouped if they are able to be combined. They are eligible to
55 * be combined if they have the same input data.
56 * Op can be followed by zero or more elemwise or broadcast ops,
57 * which are included in the group.
58 * Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
59 * which should be handled in ParallelOpCombiner.
60 *
61 * data
62 * / \
63 * op op
64 * | |
65 * elem-wise elem-wise
66 * | |
67 */
68class BranchGroupFinder : private ExprVisitor {
69 public:
70 /*
71 * \brief Constructor
72 * \param op The op that indicates the start of each group
73 * \param fis_supported_op function that returns true if op
74 * is supported for combining
75 * \param fare_compatible_ops function that returns true if
76 * two ops are compatible for combining
77 */
78 BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op,
79 FAreCompatibleOps fare_compatible_ops);
80
81 /*
82 * \brief Finds all groups that can be combined.
83 * \param expr Relay expression that represents function
84 * to look at for groups to be combined
85 * \return Vector of groups which can be combined.
86 */
87 std::vector<Group> Find(const Expr& expr);
88
89 private:
90 /* \brief Cache the op for finding parallel branches */
91 const Op& cached_op_;
92
93 /* \brief function to return true if op is eligible to be combined,
94 * false otherwise
95 */
96 FIsSupportedOp fis_supported_op_;
97
98 /* \brief function to return true if two parallel ops are eligible
99 * to be combined, false otherwise
100 */
101 FAreCompatibleOps fare_compatible_ops_;
102
103 /* \brief ops that are on the first (logically, leftmost) branch
104 * of parallel ops and are eligible to be combined
105 */
106 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> op_roots_;
107
108 /* \brief map of Expr to CallNodes that follow it */
109 std::unordered_map<Expr, std::vector<const CallNode*>, ObjectPtrHash, ObjectPtrEqual>
110 children_map_;
111
112 /*
113 * \brief Creates new branch from op and its children that have
114 * elementwise or broadcast patterns
115 * \return New branch
116 */
117 Branch CreateBranch(const CallNode* op);
118
119 /*
120 * \brief Expression visitor function
121 */
122 void VisitExpr_(const CallNode* n) final;
123};
124
125/*
126 * Abstract class to find and combine parallel ops and the elementwise ops that follow.
127 */
128class ParallelOpCombiner {
129 public:
130 /*! \brief virtual destructor */
131 virtual ~ParallelOpCombiner() {}
132 /*
133 * \brief Constructor.
134 * \param op_name name of op to combine
135 * \param min_num_branches min number of parallel branches beginning with op
136 * to start combining
137 */
138 explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches);
139
140 /*
141 * \brief Combines ops and following elementwise or broadcast ops
142 * \param expr function to modify
143 * \return new function with combined ops
144 */
145 Expr Combine(const Expr& expr);
146
147 protected:
148 /*
149 * \brief Checks if node is supported to be combined
150 * \param n node in question
151 * \return True if the op represented by n is supported to be the root of a branch
152 * to be combined. False otherwise.
153 */
154 virtual bool IsSupportedOp(const CallNode* n) = 0;
155
156 /*
157 * \brief Checks if two ops can be combined
158 * \param a node a
159 * \param b node b
160 * \return True if a and b can be combined. False otherwise.
161 */
162 virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0;
163
164 /*
165 * \brief Makes combined op from parallel ops in branches. This usually involves
166 * concatenating or stacking inputs, then creating a new call.
167 * \param branches branches that are to be combined
168 * \return new call with branches combined.
169 */
170 virtual Call MakeCombinedOp(const Group& branches) = 0;
171
172 /*
173 * \brief Checks if argument of op following combined ops are able to be combined
174 * \param a node a
175 * \param b node b
176 * \param index index of argument in question
177 * \return True if argument of a and b and index can be combined
178 */
179 virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0;
180
181 /*
182 * \brief Create combined call from ops that follow the initial combined op at the depth-th level.
183 * This usually involves concatenating or stacking inputs, then creating a new call.
184 * Only called if IsArgCompatbile returns true for each arg.
185 * \param data combined op
186 * \param branches branches of parallel ops to be combined
187 * \param depth depth at which to combine ops
188 * \param parent_index index of arg that corresponds to original input that was shared among
189 * all combined ops
190 * \return new combined call
191 */
192 virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches,
193 size_t depth, size_t parent_index) = 0;
194
195 /*
196 * \brief Updates map of expr to substitute with combined expr. This usually involves
197 * slicing or splitting data.
198 * \param data combined op
199 * \param branches branches of parallel ops to be combined
200 * \param depth depth at which to substitute
201 * \param subst_map map of Expr to replace with Expr to replace it with
202 */
203 virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
204 ExprSubstMap* subst_map) = 0;
205
206 private:
207 /* \brief Cache the op to be combined */
208 const Op& cached_op_;
209
210 /* \brief minimum number of parallel branches to combine */
211 uint64_t min_num_branches_;
212
213 /* \brief map of Expr to Expr to substitute it with after running pass */
214 ExprSubstMap subst_map_;
215
216 /*
217 * \brief Combine parallel branches and updates subst_map_ with Exprs
218 * to be substituted
219 * \param branches branches to be combined
220 */
221 void CombineBranches(const Group& branches);
222
223 /*
224 * \brief Combine parallel branches and updates subst_map_ with Exprs
225 * to be substituted
226 * \param branches parallel branches to potentially be combined
227 * \param depth depth at which to look at op
228 * \param parent_index index of arg that corresponds to original input that was shared among
229 * all combined ops
230 * \return true if parallel ops at depth can be combined, false otherwise
231 */
232 bool CheckLevel(const Group& branches, size_t depth, size_t parent_index);
233};
234
235} // namespace relay
236} // namespace tvm
237#endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
238