1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/analysis/graph_partitioner.h |
22 | * \brief The helper function for op fusion. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ |
26 | #define TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ |
27 | |
#include <tvm/relay/op_attr_types.h>

#include <cstdint>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../../support/arena.h"
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
39 | using support::LinkedList; |
40 | using support::LinkNode; |
41 | |
42 | /*! |
43 | * \brief Indexed data flow graph in forward direction. |
44 | * This is a temporary data structure used for operator fusion analysis. |
45 | * |
46 | * This data structure only captures the dataflow fragment and |
47 | * could ignore blocks like let by simply ordering each dataflow block |
48 | * and mark the output node as extern_ref; |
49 | */ |
50 | class IndexedForwardGraph { |
51 | public: |
52 | struct Node; |
53 | /*! |
54 | * The forward edge in the dataflow graph. |
55 | */ |
56 | struct Edge { |
57 | /*! \brief The corresponding node */ |
58 | Node* node{nullptr}; |
59 | /*! \brief The respective pattern of this op */ |
60 | OpPatternKind pattern{kOpaque}; |
61 | }; |
62 | /*! \brief A node in the graph. */ |
63 | struct Node { |
64 | /*! \brief weak reference to the corresponding edge. */ |
65 | const tvm::Object* ref{nullptr}; |
66 | /*! \brief The index of the node in topological order. */ |
67 | size_t index{0}; |
68 | /*! \brief Whether this node is referenced by external source */ |
69 | bool extern_ref{false}; |
70 | /*! \brief The general pattern in the node */ |
71 | OpPatternKind pattern{kOpaque}; |
72 | /*! \brief The outputs of the node. */ |
73 | LinkedList<Edge> outputs; |
74 | }; |
75 | /*! \brief The node map that maps node to graph */ |
76 | std::unordered_map<const tvm::Object*, Node*> node_map; |
77 | /*! \brief All the nodes in post DFS order */ |
78 | std::vector<Node*> post_dfs_order; |
79 | |
80 | /*! \brief Dump the graph into string. */ |
81 | void DebugDump() { |
82 | std::ostringstream os; |
83 | for (size_t i = 0; i < post_dfs_order.size(); ++i) { |
84 | Node* node = post_dfs_order[i]; |
85 | os << "node[" << i << "], " << GetRef<ObjectRef>(node->ref) << " outputs=[" ; |
86 | for (auto* link = node->outputs.head; link != nullptr; link = link->next) { |
87 | os << link->value.node->index << ", " ; |
88 | } |
89 | os << "]\n" ; |
90 | } |
91 | LOG(INFO) << os.str(); |
92 | } |
93 | }; |
94 | |
95 | /*! |
96 | * \brief Dominator tree that represent domination or |
97 | * post domination relation of the node. |
98 | */ |
99 | class DominatorTree { |
100 | public: |
101 | /*! |
102 | * \brief A node in the dominator tree. |
103 | */ |
104 | struct Node { |
105 | /*! \brief The node in the tree */ |
106 | IndexedForwardGraph::Node* gnode{nullptr}; |
107 | /*! \brief parent of the tree */ |
108 | Node* parent{nullptr}; |
109 | /*! \brief current depth*/ |
110 | int depth{0}; |
111 | /*! \brief aggregated pattern to parent */ |
112 | OpPatternKind pattern{kOpaque}; |
113 | }; |
114 | // index -> node. |
115 | std::vector<Node*> nodes; |
116 | /*! |
117 | * \brief compute a post dominator relation for a given dataflow graph. |
118 | * \param arena The arena used for node allocation. |
119 | * \param graph The graph to be analyzed. |
120 | * \return The dominator tree of the graph. |
121 | * \note This algorithm makes use of the fact that graph is DAG, |
122 | * and runs a single pass algorithm via LCA (Least Common Ancestor) |
123 | */ |
124 | static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); |
125 | |
126 | private: |
127 | // Combine pattern together. |
128 | inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { |
129 | if (lhs > rhs) return lhs; |
130 | return rhs; |
131 | } |
132 | /*! |
133 | * \brief Find the least common ancestor of the two nodes. |
134 | * \param lhs The left node. |
135 | * \param rhs The right node. |
136 | * \param edge_pattern |
137 | * The combined edge pattern across all the parents. |
138 | * \return The least common ancestor of the two. |
139 | */ |
140 | static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern); |
141 | /*! |
142 | * \brief Find the least common ancestor of a list of nodes. |
143 | * \param nodes the nodes. |
144 | * \param edge_pattern |
145 | * The combined edge pattern across all the parents. |
146 | * \return The least common ancestor of all nodes. |
147 | */ |
148 | Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes, |
149 | OpPatternKind* edge_pattern); |
150 | |
151 | /*! |
152 | * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. |
153 | * \param arena The Arena. |
154 | * \param gnode An IndexedForwardGraph Node. |
155 | * \return The DominatorTree Node. |
156 | */ |
157 | Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode); |
158 | }; |
159 | |
160 | /*! |
161 | * \brief A partition of the graph marked by union find data structure. |
162 | */ |
163 | class GraphPartitioner { |
164 | public: |
165 | explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) |
166 | : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} |
167 | /*! |
168 | * \brief Group as a union find data structure. |
169 | */ |
170 | struct Group { |
171 | /*! \brief The parent in the union find data structure. */ |
172 | Group* parent{nullptr}; |
173 | /*! \brief The pattern of the group */ |
174 | OpPatternKind pattern; |
175 | /*! \brief reference to the root node. */ |
176 | const tvm::Object* root_ref{nullptr}; |
177 | /*! |
178 | * \brief Reference to the anchor node, |
179 | * this field is not nullptr only if pattern is kOutEWiseFusable. |
180 | */ |
181 | const tvm::Object* anchor_ref{nullptr}; |
182 | /*! |
183 | * \brief The number of nodes belonging to this group |
184 | */ |
185 | uint32_t num_nodes{1}; |
186 | |
187 | /*! \brief Optional attributes to annotate the grouped function. */ |
188 | runtime::Map<runtime::String, ObjectRef> attrs; |
189 | /*! |
190 | * \brief Find the group root, perform path compression |
191 | * \return The root type node. |
192 | */ |
193 | Group* FindRoot(); |
194 | }; |
195 | /*! |
196 | * \brief Partition a graph. |
197 | * \return group assignments of each node. |
198 | */ |
199 | std::vector<Group*> Partition(const IndexedForwardGraph& graph); |
200 | |
201 | private: |
202 | /*! \brief The internal arena for temporary space. */ |
203 | support::Arena* arena_; |
204 | /*! \brief optimization level for fuse operation. */ |
205 | int opt_level_; |
206 | /*! \brief The maximum number of operations in one fused function */ |
207 | size_t max_fuse_depth_; |
208 | /*! \brief The internal groups. */ |
209 | std::vector<Group*> groups_; |
210 | /*! \brief internal field used for deduplication */ |
211 | std::unordered_set<IndexedForwardGraph::Node*> visited_; |
212 | // Internal implementation of CheckPath |
213 | template <typename F> |
214 | bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); |
215 | |
216 | /*! |
217 | * \brief Check all the node and edge pattern |
218 | * between src and sink satisfies fcond. |
219 | * |
220 | * src is not checked. |
221 | * |
222 | * \param src The source node. |
223 | * \param sink The termination node. |
224 | * \param fcond The condition to be checked. |
225 | * \tparam F the condition function, with signature |
226 | * \note sink must be a post-dominator of src. |
227 | */ |
228 | template <typename F> |
229 | bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); |
230 | |
231 | /*! |
232 | * \brief Merge the child group to the parent. |
233 | * \param child The child group. |
234 | * \param parent The parent group. |
235 | */ |
236 | void MergeFromTo(Group* child, Group* parent); |
237 | |
238 | // Internal implementation of CommitFuse |
239 | void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target); |
240 | |
241 | /*! |
242 | * \brief Commit fusion operation. |
243 | * \param src The source node. |
244 | * \param sink The termination node. |
245 | * \note sink must be a post-dominator of src. |
246 | */ |
247 | void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); |
248 | |
249 | size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); |
250 | |
251 | // Count the number of nodes in a fused subgraph if child is additionally fused. |
252 | // dom_parent is already known to be a part of the subgraph. |
253 | // For a diamond structure, there can be multiple paths connecting child and dom_parent. |
254 | // All intermediate nodes between child and dom_parent are taken into account. |
255 | // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() |
256 | // is important for correct calculation. |
257 | size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, |
258 | IndexedForwardGraph::Node* dom_parent); |
259 | |
260 | // Initialize the groups. |
261 | void InitGroups(const IndexedForwardGraph& graph); |
262 | |
263 | // execute the fusion algorithm. |
264 | void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase); |
265 | }; |
266 | |
267 | } // namespace relay |
268 | } // namespace tvm |
269 | #endif // TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ |
270 | |