1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file combine_parallel_op.h |
23 | * \brief Abstract class to combine parallel ops and their successive element-wise ops. |
24 | */ |
25 | #ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ |
26 | #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ |
27 | |
28 | #include <tvm/relay/analysis.h> |
29 | #include <tvm/relay/attrs/nn.h> |
30 | #include <tvm/relay/attrs/transform.h> |
31 | #include <tvm/relay/expr_functor.h> |
32 | #include <tvm/relay/op_attr_types.h> |
33 | #include <tvm/relay/transform.h> |
34 | |
35 | #include <string> |
36 | #include <unordered_map> |
37 | #include <unordered_set> |
38 | #include <vector> |
39 | |
40 | #include "./expr_subst.h" |
41 | #include "pattern_utils.h" |
42 | |
43 | namespace tvm { |
44 | namespace relay { |
45 | |
46 | using Branch = std::vector<const CallNode*>; |
47 | using Group = std::vector<Branch>; |
48 | using FIsSupportedOp = std::function<bool(const CallNode* n)>; |
49 | using FAreCompatibleOps = std::function<bool(const CallNode* a, const CallNode* b)>; |
50 | using ExprSubstMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>; |
51 | |
52 | /* |
53 | * Class to find parallel branches starting with op that are |
54 | * grouped if they are able to be combined. They are eligible to |
55 | * be combined if they have the same input data. |
56 | * Op can be followed by zero or more elemwise or broadcast ops, |
57 | * which are included in the group. |
58 | * Intermediate nodes have exactly one successor. It is possible that branches meet at a point, |
59 | * which should be handled in ParallelOpCombiner. |
60 | * |
61 | * data |
62 | * / \ |
63 | * op op |
64 | * | | |
65 | * elem-wise elem-wise |
66 | * | | |
67 | */ |
68 | class BranchGroupFinder : private ExprVisitor { |
69 | public: |
70 | /* |
71 | * \brief Constructor |
72 | * \param op The op that indicates the start of each group |
73 | * \param fis_supported_op function that returns true if op |
74 | * is supported for combining |
75 | * \param fare_compatible_ops function that returns true if |
76 | * two ops are compatible for combining |
77 | */ |
78 | BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, |
79 | FAreCompatibleOps fare_compatible_ops); |
80 | |
81 | /* |
82 | * \brief Finds all groups that can be combined. |
83 | * \param expr Relay expression that represents function |
84 | * to look at for groups to be combined |
85 | * \return Vector of groups which can be combined. |
86 | */ |
87 | std::vector<Group> Find(const Expr& expr); |
88 | |
89 | private: |
90 | /* \brief Cache the op for finding parallel branches */ |
91 | const Op& cached_op_; |
92 | |
93 | /* \brief function to return true if op is eligible to be combined, |
94 | * false otherwise |
95 | */ |
96 | FIsSupportedOp fis_supported_op_; |
97 | |
98 | /* \brief function to return true if two parallel ops are eligible |
99 | * to be combined, false otherwise |
100 | */ |
101 | FAreCompatibleOps fare_compatible_ops_; |
102 | |
103 | /* \brief ops that are on the first (logically, leftmost) branch |
104 | * of parallel ops and are eligible to be combined |
105 | */ |
106 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> op_roots_; |
107 | |
108 | /* \brief map of Expr to CallNodes that follow it */ |
109 | std::unordered_map<Expr, std::vector<const CallNode*>, ObjectPtrHash, ObjectPtrEqual> |
110 | children_map_; |
111 | |
112 | /* |
113 | * \brief Creates new branch from op and its children that have |
114 | * elementwise or broadcast patterns |
115 | * \return New branch |
116 | */ |
117 | Branch CreateBranch(const CallNode* op); |
118 | |
119 | /* |
120 | * \brief Expression visitor function |
121 | */ |
122 | void VisitExpr_(const CallNode* n) final; |
123 | }; |
124 | |
125 | /* |
126 | * Abstract class to find and combine parallel ops and the elementwise ops that follow. |
127 | */ |
128 | class ParallelOpCombiner { |
129 | public: |
130 | /*! \brief virtual destructor */ |
131 | virtual ~ParallelOpCombiner() {} |
132 | /* |
133 | * \brief Constructor. |
134 | * \param op_name name of op to combine |
135 | * \param min_num_branches min number of parallel branches beginning with op |
136 | * to start combining |
137 | */ |
138 | explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); |
139 | |
140 | /* |
141 | * \brief Combines ops and following elementwise or broadcast ops |
142 | * \param expr function to modify |
143 | * \return new function with combined ops |
144 | */ |
145 | Expr Combine(const Expr& expr); |
146 | |
147 | protected: |
148 | /* |
149 | * \brief Checks if node is supported to be combined |
150 | * \param n node in question |
151 | * \return True if the op represented by n is supported to be the root of a branch |
152 | * to be combined. False otherwise. |
153 | */ |
154 | virtual bool IsSupportedOp(const CallNode* n) = 0; |
155 | |
156 | /* |
157 | * \brief Checks if two ops can be combined |
158 | * \param a node a |
159 | * \param b node b |
160 | * \return True if a and b can be combined. False otherwise. |
161 | */ |
162 | virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; |
163 | |
164 | /* |
165 | * \brief Makes combined op from parallel ops in branches. This usually involves |
166 | * concatenating or stacking inputs, then creating a new call. |
167 | * \param branches branches that are to be combined |
168 | * \return new call with branches combined. |
169 | */ |
170 | virtual Call MakeCombinedOp(const Group& branches) = 0; |
171 | |
172 | /* |
173 | * \brief Checks if argument of op following combined ops are able to be combined |
174 | * \param a node a |
175 | * \param b node b |
176 | * \param index index of argument in question |
177 | * \return True if argument of a and b and index can be combined |
178 | */ |
179 | virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; |
180 | |
181 | /* |
182 | * \brief Create combined call from ops that follow the initial combined op at the depth-th level. |
183 | * This usually involves concatenating or stacking inputs, then creating a new call. |
184 | * Only called if IsArgCompatbile returns true for each arg. |
185 | * \param data combined op |
186 | * \param branches branches of parallel ops to be combined |
187 | * \param depth depth at which to combine ops |
188 | * \param parent_index index of arg that corresponds to original input that was shared among |
189 | * all combined ops |
190 | * \return new combined call |
191 | */ |
192 | virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, |
193 | size_t depth, size_t parent_index) = 0; |
194 | |
195 | /* |
196 | * \brief Updates map of expr to substitute with combined expr. This usually involves |
197 | * slicing or splitting data. |
198 | * \param data combined op |
199 | * \param branches branches of parallel ops to be combined |
200 | * \param depth depth at which to substitute |
201 | * \param subst_map map of Expr to replace with Expr to replace it with |
202 | */ |
203 | virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, |
204 | ExprSubstMap* subst_map) = 0; |
205 | |
206 | private: |
207 | /* \brief Cache the op to be combined */ |
208 | const Op& cached_op_; |
209 | |
210 | /* \brief minimum number of parallel branches to combine */ |
211 | uint64_t min_num_branches_; |
212 | |
213 | /* \brief map of Expr to Expr to substitute it with after running pass */ |
214 | ExprSubstMap subst_map_; |
215 | |
216 | /* |
217 | * \brief Combine parallel branches and updates subst_map_ with Exprs |
218 | * to be substituted |
219 | * \param branches branches to be combined |
220 | */ |
221 | void CombineBranches(const Group& branches); |
222 | |
223 | /* |
224 | * \brief Combine parallel branches and updates subst_map_ with Exprs |
225 | * to be substituted |
226 | * \param branches parallel branches to potentially be combined |
227 | * \param depth depth at which to look at op |
228 | * \param parent_index index of arg that corresponds to original input that was shared among |
229 | * all combined ops |
230 | * \return true if parallel ops at depth can be combined, false otherwise |
231 | */ |
232 | bool CheckLevel(const Group& branches, size_t depth, size_t parent_index); |
233 | }; |
234 | |
235 | } // namespace relay |
236 | } // namespace tvm |
237 | #endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ |
238 | |