1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/ir/indexed_graph.h
22 * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow)
23 * pattern. Each 'indexed graph' node is 1:1 with an expression/pattern 'node', hence the
24 * term 'IndexedGraph'. Dataflow is captured in a generic representation which is convenient
25 * for analysis, particularly pattern matching and partitioning.
26 *
27 * TODO(mbs): Copied from fuse_ops.cc, consider refactoring to share implementation.
28 */
29#ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_
30#define TVM_RELAY_IR_INDEXED_GRAPH_H_
31
#include <tvm/relay/dataflow_pattern.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
41
42namespace tvm {
43namespace relay {
44
45/*! \brief The index of a node in the post-dfs traversal of overall expression. */
46using PostDfsIndex = size_t;
47
48/*!
49 * \brief Returns a brief summary of the 'reference' expression or pattern. Only used by
50 * IndexedGraph::ToString() for debugging.
51 */
52std::string RefToSummary(const Expr& expr);
53std::string RefToSummary(const DFPattern& pattern);
54
55/*!
56 * \brief Represents the implied dataflow of an expression or (dataflow) pattern as a DAG who's
57 * nodes are 1:1 with those in the underlying expression/pattern.
58 *
59 * Each indexed graph node captures:
60 * - Dataflow inputs.
61 * - Dataflow outputs (or a flag indicating the node is an implied output).
62 * - Dominator parent (ie closest node at which all outputs of the current node re-combine).
63 * - Dominator children (inverse of above).
64 * - Basic block (ie node representing the body of a function, arm of an if, etc).
65 *
66 * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure.
67 *
68 * IndexedGraph should be instantiated through the CreateIndexedGraph utilities below.
69 */
70template <typename T>
71class IndexedGraph {
72 public:
73 using TNode = typename T::ContainerType;
74
75 /*! \brief A Node in the graph. */
76 struct Node {
77 /*! \brief Node Constructor
78 * \param ref The expression or dataflow pattern node this indexed graph node is augmenting.
79 * \param index The index of this node in the topological order
80 */
81 Node(const TNode* ref, PostDfsIndex index) : node_ref_(ref), index_(index) {}
82
83 /*! \brief The underlying expression or pattern node. */
84 const TNode* node_ref_;
85
86 T ref() const {
87 ICHECK(node_ref_ != nullptr);
88 return GetRef<T>(node_ref_);
89 }
90
91 /*!
92 * \brief The index of this node in post-dfs order. If left.index_ > right.index_ then
93 * left does not flow into right. If left.index_ = right.index_ then left and right are
94 * the same node.
95 */
96 const PostDfsIndex index_;
97
98 /*! \brief If true this node has implicit outputs, for example as the result of a function. */
99 bool is_external_ = false;
100 /*! \brief Immediate dataflow inputs to this node. */
101 std::vector<Node*> inputs_;
102 /*! \brief Immediate dataflow outputs of this node -- may be empty if is_external_ is true. */
103 std::vector<Node*> outputs_;
104
105 /*!
106 * \brief The node representing the 'basic block' containing this node:
107 * - Function bodies start a new basic block for their bodies.
108 * - The true and false branches of an if start their own blocks.
109 * - The arms of a match each have their own blocks.
110 */
111 Node* basic_block_ = nullptr;
112
113 /*! \brief The depth of this node in the dominator tree */
114 size_t depth_ = 0;
115 /*!
116 * \brief The dominator parent of this node. This is the node N with least index such that
117 * all possible dataflows from this node pass through N.
118 */
119 Node* dominator_parent_ = nullptr;
120 /*! \brief The nodes this node dominates. */
121 std::vector<Node*> dominator_children_;
122
123 /*!
124 * Add to \p nodes all the nodes which are strictly downstream of \p this, ie can be
125 * reached by following output paths.
126 */
127 void AccumulateDownstreamNodes(std::unordered_set<const Node*>* nodes) const {
128 std::stack<const Node*> stack;
129 stack.push(this);
130 while (!stack.empty()) {
131 const Node* current = stack.top();
132 stack.pop();
133 for (auto node : current->outputs_) {
134 if (nodes->count(node) == 0) {
135 stack.push(node);
136 nodes->insert(node);
137 }
138 }
139 }
140 }
141
142 /*!
143 * \brief Returns true if \p this is a dominator of \p other. Ie all dataflow paths from \p
144 * other pass through \p this.
145 */
146 bool Dominates(const Node* other) const {
147 std::stack<const Node*> stack;
148 std::unordered_set<const Node*> visited;
149 stack.push(this);
150 while (!stack.empty()) {
151 const Node* current = stack.top();
152 stack.pop();
153 for (auto node : current->dominator_children_) {
154 if (visited.count(node) == 0) {
155 if (other == node) {
156 return true;
157 } else {
158 stack.push(node);
159 }
160 visited.insert(node);
161 }
162 }
163 }
164 return false;
165 }
166 };
167
168 PostDfsIndex size() const { return topological_order_.size(); }
169
170 Node* item_to_node(const T& item) { return item_to_node(item.get()); }
171 const Node* item_to_node(const T& item) const { return item_to_node(item.get()); }
172
173 Node* item_to_node(const TNode* item) {
174 auto itr = node_map_.find(item);
175 ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef<T>(item));
176 return itr->second;
177 }
178
179 const Node* item_to_node(const TNode* item) const {
180 auto itr = node_map_.find(item);
181 ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef<T>(item));
182 return itr->second;
183 }
184
185 Node* index_to_node(PostDfsIndex index) {
186 ICHECK_LT(index, topological_order_.size()) << index;
187 return topological_order_[index].get();
188 }
189
190 const Node* index_to_node(PostDfsIndex index) const {
191 ICHECK_LT(index, topological_order_.size()) << index;
192 return topological_order_[index].get();
193 }
194
195 /*!
196 * \brief (For debugging only) Returns description of indexed graph with hints as to the
197 * sub-expressions or sub-patterns corresponding to each indexed graph node.
198 */
199 std::string ToString() const {
200 std::ostringstream os;
201 os << "IndexedGraph(size = " << topological_order_.size() << ") {" << std::endl;
202 for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) {
203 const Node* node = topological_order_[index].get();
204 ICHECK_EQ(index, node->index_);
205 os << " " << index << " (" << RefToSummary(node->ref()) << "): inputs=[";
206 for (const auto* sub_node : node->inputs_) {
207 os << sub_node->index_ << ",";
208 }
209 os << "], outputs=[";
210 for (const auto* sub_node : node->outputs_) {
211 os << sub_node->index_ << ",";
212 }
213 os << "]";
214 if (node->is_external_) {
215 os << ", external";
216 }
217 if (node->basic_block_) {
218 os << ", basic_block=" << node->basic_block_->index_;
219 }
220 if (node->depth_ > 0) {
221 os << ", depth=" << node->depth_;
222 }
223 if (node->dominator_parent_) {
224 os << ", dom_parent=" << node->dominator_parent_->index_;
225 }
226 os << ", dom_children=[";
227 for (const auto* sub_node : node->dominator_children_) {
228 os << sub_node->index_ << ",";
229 }
230 os << "]" << std::endl;
231 }
232 os << "}";
233 return os.str();
234 }
235
236 /*!
237 * Check-fails if the graph is ill-formed. For debugging only.
238 */
239 void CheckValid() const {
240 ICHECK_GT(topological_order_.size(), 0);
241 for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) {
242 const Node* node = topological_order_[index].get();
243 // We have a node.
244 ICHECK(node);
245 // Bijections with post-dfs indexes and expressions/patterns are correct.
246 ICHECK_EQ(node->index_, index);
247 ICHECK(node->node_ref_);
248 auto itr = node_map_.find(node->node_ref_);
249 ICHECK(itr != node_map_.end());
250 ICHECK_EQ(itr->second, node) << "at index " << index << " in:" << std::endl << ToString();
251 // Inputs come before.
252 for (size_t i = 0; i < node->inputs_.size(); ++i) {
253 const Node* input = node->inputs_[i];
254 ICHECK(input);
255 ICHECK_LT(input->index_, index);
256 ICHECK(std::find(input->outputs_.begin(), input->outputs_.end(), node) !=
257 input->outputs_.end());
258 }
259 // Outputs come after.
260 for (size_t i = 0; i < node->outputs_.size(); ++i) {
261 const Node* output = node->outputs_[i];
262 ICHECK(output);
263 ICHECK_GT(output->index_, index);
264 ICHECK(std::find(output->inputs_.begin(), output->inputs_.end(), node) !=
265 output->inputs_.end());
266 }
267 ICHECK_GT(node->depth_, 0);
268 // Dominator children come before.
269 for (size_t i = 0; i < node->dominator_children_.size(); ++i) {
270 const Node* child = node->dominator_children_[i];
271 ICHECK(child);
272 ICHECK_LT(child->index_, index);
273 }
274 if (node->dominator_parent_) {
275 // Dominator comes after.
276 ICHECK_GT(node->dominator_parent_->index_, index);
277 }
278 }
279 }
280
281 private:
282 /*! \brief Construct the domination tree inside IndexedGraph */
283 void PostDom() {
284 for (PostDfsIndex i = topological_order_.size(); i != 0; --i) {
285 PostDfsIndex index = i - 1;
286 auto* current = topological_order_[index].get();
287 if (current->is_external_) {
288 current->depth_ = 1;
289 current->dominator_parent_ = nullptr;
290 } else {
291 auto parent = LeastCommonAncestor(current->outputs_);
292 current->depth_ = parent ? parent->depth_ + 1 : 1;
293 current->dominator_parent_ = parent;
294 if (parent) {
295 parent->dominator_children_.push_back(current);
296 }
297 }
298 }
299 }
300
301 /*! \brief Find the least common ancestor of all outputs of a node */
302 Node* LeastCommonAncestor(const std::vector<Node*>& outputs) {
303 if (outputs.size() == 0) {
304 return nullptr;
305 }
306 auto parent = outputs.at(0);
307 for (size_t i = 1; i < outputs.size(); ++i) {
308 parent = LeastCommonAncestor(parent, outputs.at(i));
309 }
310 return parent;
311 }
312
313 /*! \brief Find the least common ancestor of two nodes */
314 Node* LeastCommonAncestor(Node* lhs, Node* rhs) {
315 if (lhs == nullptr || rhs == nullptr) {
316 return nullptr;
317 }
318 PostDfsIndex lhs_index = lhs->index_;
319 PostDfsIndex rhs_index = rhs->index_;
320 while (lhs != rhs) {
321 ICHECK(lhs && rhs) << "LCA(" << lhs_index << ", " << rhs_index << ") on graph:" << std::endl
322 << ToString();
323 if (lhs->depth_ < rhs->depth_) {
324 rhs = rhs->dominator_parent_;
325 } else if (lhs->depth_ > rhs->depth_) {
326 lhs = lhs->dominator_parent_;
327 } else {
328 rhs = rhs->dominator_parent_;
329 lhs = lhs->dominator_parent_;
330 }
331 }
332 return lhs;
333 }
334
335 /*!
336 * \brief Appends a node corresponding to \p ref, and maintains the sub-expression/sub-pattern to
337 * node bijection. The insertion index will be the node's PostDfsIndex. All other node properties
338 * are accumulated in-place.
339 */
340 void AddNode(const T& ref) {
341 PostDfsIndex index = topological_order_.size();
342 auto node = std::make_unique<Node>(ref.get(), index);
343 node_map_[ref.get()] = node.get();
344 topological_order_.emplace_back(std::move(node));
345 }
346
347 /*!
348 * \brief Map from underlying sub-expression or sub-pattern nodes to their indexed graph nodes.
349 */
350 std::unordered_map<const TNode*, Node*> node_map_;
351 /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. */
352 std::vector<std::unique_ptr<Node>> topological_order_;
353
354 friend std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr);
355 friend std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pattern);
356};
357
358/*! \brief Returns an Indexed Graph for \p expr, which much outlive the result. */
359std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr);
360
361/*!
362 * \brief Returns an Indexed Graph for \p pattern, which must outlive the result.
363 * The dataflow for a pattern mimics the dataflow for the expression which would match
364 * that pattern.
365 */
366std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pattern);
367
368} // namespace relay
369} // namespace tvm
370
371#endif // TVM_RELAY_IR_INDEXED_GRAPH_H_
372