1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/ir/indexed_graph.h |
22 | * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) |
23 | * pattern. Each 'indexed graph' node is 1:1 with an expression/pattern 'node', hence the |
24 | * term 'IndexedGraph'. Dataflow is captured in a generic representation which is convenient |
25 | * for analysis, particularly pattern matching and partitioning. |
26 | * |
27 | * TODO(mbs): Copied from fuse_ops.cc, consider refactoring to share implementation. |
28 | */ |
29 | #ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ |
30 | #define TVM_RELAY_IR_INDEXED_GRAPH_H_ |
31 | |
#include <tvm/relay/dataflow_pattern.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
41 | |
42 | namespace tvm { |
43 | namespace relay { |
44 | |
45 | /*! \brief The index of a node in the post-dfs traversal of overall expression. */ |
46 | using PostDfsIndex = size_t; |
47 | |
48 | /*! |
49 | * \brief Returns a brief summary of the 'reference' expression or pattern. Only used by |
50 | * IndexedGraph::ToString() for debugging. |
51 | */ |
52 | std::string RefToSummary(const Expr& expr); |
53 | std::string RefToSummary(const DFPattern& pattern); |
54 | |
55 | /*! |
56 | * \brief Represents the implied dataflow of an expression or (dataflow) pattern as a DAG who's |
57 | * nodes are 1:1 with those in the underlying expression/pattern. |
58 | * |
59 | * Each indexed graph node captures: |
60 | * - Dataflow inputs. |
61 | * - Dataflow outputs (or a flag indicating the node is an implied output). |
62 | * - Dominator parent (ie closest node at which all outputs of the current node re-combine). |
63 | * - Dominator children (inverse of above). |
64 | * - Basic block (ie node representing the body of a function, arm of an if, etc). |
65 | * |
66 | * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure. |
67 | * |
68 | * IndexedGraph should be instantiated through the CreateIndexedGraph utilities below. |
69 | */ |
70 | template <typename T> |
71 | class IndexedGraph { |
72 | public: |
73 | using TNode = typename T::ContainerType; |
74 | |
75 | /*! \brief A Node in the graph. */ |
76 | struct Node { |
77 | /*! \brief Node Constructor |
78 | * \param ref The expression or dataflow pattern node this indexed graph node is augmenting. |
79 | * \param index The index of this node in the topological order |
80 | */ |
81 | Node(const TNode* ref, PostDfsIndex index) : node_ref_(ref), index_(index) {} |
82 | |
83 | /*! \brief The underlying expression or pattern node. */ |
84 | const TNode* node_ref_; |
85 | |
86 | T ref() const { |
87 | ICHECK(node_ref_ != nullptr); |
88 | return GetRef<T>(node_ref_); |
89 | } |
90 | |
91 | /*! |
92 | * \brief The index of this node in post-dfs order. If left.index_ > right.index_ then |
93 | * left does not flow into right. If left.index_ = right.index_ then left and right are |
94 | * the same node. |
95 | */ |
96 | const PostDfsIndex index_; |
97 | |
98 | /*! \brief If true this node has implicit outputs, for example as the result of a function. */ |
99 | bool is_external_ = false; |
100 | /*! \brief Immediate dataflow inputs to this node. */ |
101 | std::vector<Node*> inputs_; |
102 | /*! \brief Immediate dataflow outputs of this node -- may be empty if is_external_ is true. */ |
103 | std::vector<Node*> outputs_; |
104 | |
105 | /*! |
106 | * \brief The node representing the 'basic block' containing this node: |
107 | * - Function bodies start a new basic block for their bodies. |
108 | * - The true and false branches of an if start their own blocks. |
109 | * - The arms of a match each have their own blocks. |
110 | */ |
111 | Node* basic_block_ = nullptr; |
112 | |
113 | /*! \brief The depth of this node in the dominator tree */ |
114 | size_t depth_ = 0; |
115 | /*! |
116 | * \brief The dominator parent of this node. This is the node N with least index such that |
117 | * all possible dataflows from this node pass through N. |
118 | */ |
119 | Node* dominator_parent_ = nullptr; |
120 | /*! \brief The nodes this node dominates. */ |
121 | std::vector<Node*> dominator_children_; |
122 | |
123 | /*! |
124 | * Add to \p nodes all the nodes which are strictly downstream of \p this, ie can be |
125 | * reached by following output paths. |
126 | */ |
127 | void AccumulateDownstreamNodes(std::unordered_set<const Node*>* nodes) const { |
128 | std::stack<const Node*> stack; |
129 | stack.push(this); |
130 | while (!stack.empty()) { |
131 | const Node* current = stack.top(); |
132 | stack.pop(); |
133 | for (auto node : current->outputs_) { |
134 | if (nodes->count(node) == 0) { |
135 | stack.push(node); |
136 | nodes->insert(node); |
137 | } |
138 | } |
139 | } |
140 | } |
141 | |
142 | /*! |
143 | * \brief Returns true if \p this is a dominator of \p other. Ie all dataflow paths from \p |
144 | * other pass through \p this. |
145 | */ |
146 | bool Dominates(const Node* other) const { |
147 | std::stack<const Node*> stack; |
148 | std::unordered_set<const Node*> visited; |
149 | stack.push(this); |
150 | while (!stack.empty()) { |
151 | const Node* current = stack.top(); |
152 | stack.pop(); |
153 | for (auto node : current->dominator_children_) { |
154 | if (visited.count(node) == 0) { |
155 | if (other == node) { |
156 | return true; |
157 | } else { |
158 | stack.push(node); |
159 | } |
160 | visited.insert(node); |
161 | } |
162 | } |
163 | } |
164 | return false; |
165 | } |
166 | }; |
167 | |
168 | PostDfsIndex size() const { return topological_order_.size(); } |
169 | |
170 | Node* item_to_node(const T& item) { return item_to_node(item.get()); } |
171 | const Node* item_to_node(const T& item) const { return item_to_node(item.get()); } |
172 | |
173 | Node* item_to_node(const TNode* item) { |
174 | auto itr = node_map_.find(item); |
175 | ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef<T>(item)); |
176 | return itr->second; |
177 | } |
178 | |
179 | const Node* item_to_node(const TNode* item) const { |
180 | auto itr = node_map_.find(item); |
181 | ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef<T>(item)); |
182 | return itr->second; |
183 | } |
184 | |
185 | Node* index_to_node(PostDfsIndex index) { |
186 | ICHECK_LT(index, topological_order_.size()) << index; |
187 | return topological_order_[index].get(); |
188 | } |
189 | |
190 | const Node* index_to_node(PostDfsIndex index) const { |
191 | ICHECK_LT(index, topological_order_.size()) << index; |
192 | return topological_order_[index].get(); |
193 | } |
194 | |
195 | /*! |
196 | * \brief (For debugging only) Returns description of indexed graph with hints as to the |
197 | * sub-expressions or sub-patterns corresponding to each indexed graph node. |
198 | */ |
199 | std::string ToString() const { |
200 | std::ostringstream os; |
201 | os << "IndexedGraph(size = " << topological_order_.size() << ") {" << std::endl; |
202 | for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { |
203 | const Node* node = topological_order_[index].get(); |
204 | ICHECK_EQ(index, node->index_); |
205 | os << " " << index << " (" << RefToSummary(node->ref()) << "): inputs=[" ; |
206 | for (const auto* sub_node : node->inputs_) { |
207 | os << sub_node->index_ << "," ; |
208 | } |
209 | os << "], outputs=[" ; |
210 | for (const auto* sub_node : node->outputs_) { |
211 | os << sub_node->index_ << "," ; |
212 | } |
213 | os << "]" ; |
214 | if (node->is_external_) { |
215 | os << ", external" ; |
216 | } |
217 | if (node->basic_block_) { |
218 | os << ", basic_block=" << node->basic_block_->index_; |
219 | } |
220 | if (node->depth_ > 0) { |
221 | os << ", depth=" << node->depth_; |
222 | } |
223 | if (node->dominator_parent_) { |
224 | os << ", dom_parent=" << node->dominator_parent_->index_; |
225 | } |
226 | os << ", dom_children=[" ; |
227 | for (const auto* sub_node : node->dominator_children_) { |
228 | os << sub_node->index_ << "," ; |
229 | } |
230 | os << "]" << std::endl; |
231 | } |
232 | os << "}" ; |
233 | return os.str(); |
234 | } |
235 | |
236 | /*! |
237 | * Check-fails if the graph is ill-formed. For debugging only. |
238 | */ |
239 | void CheckValid() const { |
240 | ICHECK_GT(topological_order_.size(), 0); |
241 | for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { |
242 | const Node* node = topological_order_[index].get(); |
243 | // We have a node. |
244 | ICHECK(node); |
245 | // Bijections with post-dfs indexes and expressions/patterns are correct. |
246 | ICHECK_EQ(node->index_, index); |
247 | ICHECK(node->node_ref_); |
248 | auto itr = node_map_.find(node->node_ref_); |
249 | ICHECK(itr != node_map_.end()); |
250 | ICHECK_EQ(itr->second, node) << "at index " << index << " in:" << std::endl << ToString(); |
251 | // Inputs come before. |
252 | for (size_t i = 0; i < node->inputs_.size(); ++i) { |
253 | const Node* input = node->inputs_[i]; |
254 | ICHECK(input); |
255 | ICHECK_LT(input->index_, index); |
256 | ICHECK(std::find(input->outputs_.begin(), input->outputs_.end(), node) != |
257 | input->outputs_.end()); |
258 | } |
259 | // Outputs come after. |
260 | for (size_t i = 0; i < node->outputs_.size(); ++i) { |
261 | const Node* output = node->outputs_[i]; |
262 | ICHECK(output); |
263 | ICHECK_GT(output->index_, index); |
264 | ICHECK(std::find(output->inputs_.begin(), output->inputs_.end(), node) != |
265 | output->inputs_.end()); |
266 | } |
267 | ICHECK_GT(node->depth_, 0); |
268 | // Dominator children come before. |
269 | for (size_t i = 0; i < node->dominator_children_.size(); ++i) { |
270 | const Node* child = node->dominator_children_[i]; |
271 | ICHECK(child); |
272 | ICHECK_LT(child->index_, index); |
273 | } |
274 | if (node->dominator_parent_) { |
275 | // Dominator comes after. |
276 | ICHECK_GT(node->dominator_parent_->index_, index); |
277 | } |
278 | } |
279 | } |
280 | |
281 | private: |
282 | /*! \brief Construct the domination tree inside IndexedGraph */ |
283 | void PostDom() { |
284 | for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { |
285 | PostDfsIndex index = i - 1; |
286 | auto* current = topological_order_[index].get(); |
287 | if (current->is_external_) { |
288 | current->depth_ = 1; |
289 | current->dominator_parent_ = nullptr; |
290 | } else { |
291 | auto parent = LeastCommonAncestor(current->outputs_); |
292 | current->depth_ = parent ? parent->depth_ + 1 : 1; |
293 | current->dominator_parent_ = parent; |
294 | if (parent) { |
295 | parent->dominator_children_.push_back(current); |
296 | } |
297 | } |
298 | } |
299 | } |
300 | |
301 | /*! \brief Find the least common ancestor of all outputs of a node */ |
302 | Node* LeastCommonAncestor(const std::vector<Node*>& outputs) { |
303 | if (outputs.size() == 0) { |
304 | return nullptr; |
305 | } |
306 | auto parent = outputs.at(0); |
307 | for (size_t i = 1; i < outputs.size(); ++i) { |
308 | parent = LeastCommonAncestor(parent, outputs.at(i)); |
309 | } |
310 | return parent; |
311 | } |
312 | |
313 | /*! \brief Find the least common ancestor of two nodes */ |
314 | Node* LeastCommonAncestor(Node* lhs, Node* rhs) { |
315 | if (lhs == nullptr || rhs == nullptr) { |
316 | return nullptr; |
317 | } |
318 | PostDfsIndex lhs_index = lhs->index_; |
319 | PostDfsIndex rhs_index = rhs->index_; |
320 | while (lhs != rhs) { |
321 | ICHECK(lhs && rhs) << "LCA(" << lhs_index << ", " << rhs_index << ") on graph:" << std::endl |
322 | << ToString(); |
323 | if (lhs->depth_ < rhs->depth_) { |
324 | rhs = rhs->dominator_parent_; |
325 | } else if (lhs->depth_ > rhs->depth_) { |
326 | lhs = lhs->dominator_parent_; |
327 | } else { |
328 | rhs = rhs->dominator_parent_; |
329 | lhs = lhs->dominator_parent_; |
330 | } |
331 | } |
332 | return lhs; |
333 | } |
334 | |
335 | /*! |
336 | * \brief Appends a node corresponding to \p ref, and maintains the sub-expression/sub-pattern to |
337 | * node bijection. The insertion index will be the node's PostDfsIndex. All other node properties |
338 | * are accumulated in-place. |
339 | */ |
340 | void AddNode(const T& ref) { |
341 | PostDfsIndex index = topological_order_.size(); |
342 | auto node = std::make_unique<Node>(ref.get(), index); |
343 | node_map_[ref.get()] = node.get(); |
344 | topological_order_.emplace_back(std::move(node)); |
345 | } |
346 | |
347 | /*! |
348 | * \brief Map from underlying sub-expression or sub-pattern nodes to their indexed graph nodes. |
349 | */ |
350 | std::unordered_map<const TNode*, Node*> node_map_; |
351 | /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. */ |
352 | std::vector<std::unique_ptr<Node>> topological_order_; |
353 | |
354 | friend std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr); |
355 | friend std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pattern); |
356 | }; |
357 | |
358 | /*! \brief Returns an Indexed Graph for \p expr, which much outlive the result. */ |
359 | std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr); |
360 | |
361 | /*! |
362 | * \brief Returns an Indexed Graph for \p pattern, which must outlive the result. |
363 | * The dataflow for a pattern mimics the dataflow for the expression which would match |
364 | * that pattern. |
365 | */ |
366 | std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pattern); |
367 | |
368 | } // namespace relay |
369 | } // namespace tvm |
370 | |
371 | #endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ |
372 | |