1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
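/*!
 * \file src/relay/analysis/graph_partitioner.h
 * \brief Helper data structures for operator fusion: an indexed forward
 *  dataflow graph, a post-dominator tree, and a union-find based graph partitioner.
 */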
24
25#ifndef TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_
26#define TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_
27
#include <tvm/relay/op_attr_types.h>

#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../../support/arena.h"
35
36namespace tvm {
37namespace relay {
38
39using support::LinkedList;
40using support::LinkNode;
41
42/*!
43 * \brief Indexed data flow graph in forward direction.
44 * This is a temporary data structure used for operator fusion analysis.
45 *
46 * This data structure only captures the dataflow fragment and
47 * could ignore blocks like let by simply ordering each dataflow block
48 * and mark the output node as extern_ref;
49 */
50class IndexedForwardGraph {
51 public:
52 struct Node;
53 /*!
54 * The forward edge in the dataflow graph.
55 */
56 struct Edge {
57 /*! \brief The corresponding node */
58 Node* node{nullptr};
59 /*! \brief The respective pattern of this op */
60 OpPatternKind pattern{kOpaque};
61 };
62 /*! \brief A node in the graph. */
63 struct Node {
64 /*! \brief weak reference to the corresponding edge. */
65 const tvm::Object* ref{nullptr};
66 /*! \brief The index of the node in topological order. */
67 size_t index{0};
68 /*! \brief Whether this node is referenced by external source */
69 bool extern_ref{false};
70 /*! \brief The general pattern in the node */
71 OpPatternKind pattern{kOpaque};
72 /*! \brief The outputs of the node. */
73 LinkedList<Edge> outputs;
74 };
75 /*! \brief The node map that maps node to graph */
76 std::unordered_map<const tvm::Object*, Node*> node_map;
77 /*! \brief All the nodes in post DFS order */
78 std::vector<Node*> post_dfs_order;
79
80 /*! \brief Dump the graph into string. */
81 void DebugDump() {
82 std::ostringstream os;
83 for (size_t i = 0; i < post_dfs_order.size(); ++i) {
84 Node* node = post_dfs_order[i];
85 os << "node[" << i << "], " << GetRef<ObjectRef>(node->ref) << " outputs=[";
86 for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
87 os << link->value.node->index << ", ";
88 }
89 os << "]\n";
90 }
91 LOG(INFO) << os.str();
92 }
93};
94
95/*!
96 * \brief Dominator tree that represent domination or
97 * post domination relation of the node.
98 */
99class DominatorTree {
100 public:
101 /*!
102 * \brief A node in the dominator tree.
103 */
104 struct Node {
105 /*! \brief The node in the tree */
106 IndexedForwardGraph::Node* gnode{nullptr};
107 /*! \brief parent of the tree */
108 Node* parent{nullptr};
109 /*! \brief current depth*/
110 int depth{0};
111 /*! \brief aggregated pattern to parent */
112 OpPatternKind pattern{kOpaque};
113 };
114 // index -> node.
115 std::vector<Node*> nodes;
116 /*!
117 * \brief compute a post dominator relation for a given dataflow graph.
118 * \param arena The arena used for node allocation.
119 * \param graph The graph to be analyzed.
120 * \return The dominator tree of the graph.
121 * \note This algorithm makes use of the fact that graph is DAG,
122 * and runs a single pass algorithm via LCA (Least Common Ancestor)
123 */
124 static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph);
125
126 private:
127 // Combine pattern together.
128 inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
129 if (lhs > rhs) return lhs;
130 return rhs;
131 }
132 /*!
133 * \brief Find the least common ancestor of the two nodes.
134 * \param lhs The left node.
135 * \param rhs The right node.
136 * \param edge_pattern
137 * The combined edge pattern across all the parents.
138 * \return The least common ancestor of the two.
139 */
140 static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern);
141 /*!
142 * \brief Find the least common ancestor of a list of nodes.
143 * \param nodes the nodes.
144 * \param edge_pattern
145 * The combined edge pattern across all the parents.
146 * \return The least common ancestor of all nodes.
147 */
148 Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
149 OpPatternKind* edge_pattern);
150
151 /*!
152 * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
153 * \param arena The Arena.
154 * \param gnode An IndexedForwardGraph Node.
155 * \return The DominatorTree Node.
156 */
157 Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode);
158};
159
160/*!
161 * \brief A partition of the graph marked by union find data structure.
162 */
163class GraphPartitioner {
164 public:
165 explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
166 : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
167 /*!
168 * \brief Group as a union find data structure.
169 */
170 struct Group {
171 /*! \brief The parent in the union find data structure. */
172 Group* parent{nullptr};
173 /*! \brief The pattern of the group */
174 OpPatternKind pattern;
175 /*! \brief reference to the root node. */
176 const tvm::Object* root_ref{nullptr};
177 /*!
178 * \brief Reference to the anchor node,
179 * this field is not nullptr only if pattern is kOutEWiseFusable.
180 */
181 const tvm::Object* anchor_ref{nullptr};
182 /*!
183 * \brief The number of nodes belonging to this group
184 */
185 uint32_t num_nodes{1};
186
187 /*! \brief Optional attributes to annotate the grouped function. */
188 runtime::Map<runtime::String, ObjectRef> attrs;
189 /*!
190 * \brief Find the group root, perform path compression
191 * \return The root type node.
192 */
193 Group* FindRoot();
194 };
195 /*!
196 * \brief Partition a graph.
197 * \return group assignments of each node.
198 */
199 std::vector<Group*> Partition(const IndexedForwardGraph& graph);
200
201 private:
202 /*! \brief The internal arena for temporary space. */
203 support::Arena* arena_;
204 /*! \brief optimization level for fuse operation. */
205 int opt_level_;
206 /*! \brief The maximum number of operations in one fused function */
207 size_t max_fuse_depth_;
208 /*! \brief The internal groups. */
209 std::vector<Group*> groups_;
210 /*! \brief internal field used for deduplication */
211 std::unordered_set<IndexedForwardGraph::Node*> visited_;
212 // Internal implementation of CheckPath
213 template <typename F>
214 bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
215
216 /*!
217 * \brief Check all the node and edge pattern
218 * between src and sink satisfies fcond.
219 *
220 * src is not checked.
221 *
222 * \param src The source node.
223 * \param sink The termination node.
224 * \param fcond The condition to be checked.
225 * \tparam F the condition function, with signature
226 * \note sink must be a post-dominator of src.
227 */
228 template <typename F>
229 bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
230
231 /*!
232 * \brief Merge the child group to the parent.
233 * \param child The child group.
234 * \param parent The parent group.
235 */
236 void MergeFromTo(Group* child, Group* parent);
237
238 // Internal implementation of CommitFuse
239 void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target);
240
241 /*!
242 * \brief Commit fusion operation.
243 * \param src The source node.
244 * \param sink The termination node.
245 * \note sink must be a post-dominator of src.
246 */
247 void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
248
249 size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
250
251 // Count the number of nodes in a fused subgraph if child is additionally fused.
252 // dom_parent is already known to be a part of the subgraph.
253 // For a diamond structure, there can be multiple paths connecting child and dom_parent.
254 // All intermediate nodes between child and dom_parent are taken into account.
255 // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
256 // is important for correct calculation.
257 size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
258 IndexedForwardGraph::Node* dom_parent);
259
260 // Initialize the groups.
261 void InitGroups(const IndexedForwardGraph& graph);
262
263 // execute the fusion algorithm.
264 void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase);
265};
266
267} // namespace relay
268} // namespace tvm
269#endif // TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_
270