1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/ir/indexed_graph.cc
22 * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow)
23 * pattern.
24 */
25#include "indexed_graph.h"
26
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/dataflow_pattern_functor.h>
29#include <tvm/relay/expr_functor.h>
30#include <tvm/relay/pattern_functor.h>
31
32#include <string>
33
34namespace tvm {
35namespace relay {
36
37std::string RefToSummary(const Expr& expr) {
38 class Visitor : public ExprFunctor<std::string(const Expr&)> {
39 std::string VisitExpr_(const VarNode* op) final { return "%" + op->name_hint(); }
40 std::string VisitExpr_(const GlobalVarNode* op) final { return "@" + op->name_hint; }
41 std::string VisitExpr_(const ConstantNode* op) final { return "const"; }
42 std::string VisitExpr_(const TupleNode* op) final {
43 return "tuple(" + std::to_string(op->fields.size()) + ")";
44 }
45 std::string VisitExpr_(const FunctionNode* op) final { return "fn"; }
46 std::string VisitExpr_(const CallNode* op) final {
47 return VisitExpr(op->op) + "(" + std::to_string(op->args.size()) + ")";
48 }
49 std::string VisitExpr_(const LetNode* op) final { return "let"; }
50 std::string VisitExpr_(const IfNode* op) final { return "if"; }
51 std::string VisitExpr_(const OpNode* op) final { return op->name; }
52 std::string VisitExpr_(const TupleGetItemNode* op) final {
53 return "." + std::to_string(op->index);
54 }
55 std::string VisitExpr_(const RefCreateNode* op) final { return "ref_create"; }
56 std::string VisitExpr_(const RefReadNode* op) final { return "ref_read"; }
57 std::string VisitExpr_(const RefWriteNode* op) final { return "ref_write"; }
58 std::string VisitExpr_(const ConstructorNode* op) final { return "ctor"; }
59 std::string VisitExpr_(const MatchNode* op) final { return "match"; }
60 };
61 return Visitor().VisitExpr(expr);
62}
63
64std::string RefToSummary(const DFPattern& pattern) {
65 // TODO(mbs): Implement as debugging requires.
66 return "";
67}
68
/*!
 * \brief Builds an IndexedGraph for \p expr.
 *
 * Runs three passes, chained together at the end of this function:
 *  - Creator: adds a node for every sub-expression in post-dfs order and collects the calls
 *    which recursively call a let-rec bound function.
 *  - Annotator: fills in each node's inputs and outputs, then runs the post-dominator analysis.
 *  - Blocker: fills in each node's basic block.
 */
std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr) {
  /*!
   * \brief Adds indexed graph nodes in post-dfs order, and records which calls are recursive
   * calls to a let-rec bound function (ie calls, within a let-bound value, whose callee is the
   * let-bound var itself).
   */
  class Creator : public MixedModeVisitor {
   public:
    /*!
     * \brief Visits \p expr, marks its node as external, and returns the accumulated graph
     * together with the set of recursive calls. Ownership of both moves to the caller.
     */
    std::pair<std::unique_ptr<IndexedGraph<Expr>>,
              std::unique_ptr<std::unordered_set<const CallNode*>>>
    CreateGraph(const Expr& expr) {
      VisitExpr(expr);
      // Last visited node is implicitly used 'externally'.
      graph_->item_to_node(expr)->is_external_ = true;
      return {std::move(graph_), std::move(rec_calls_)};
    }

   protected:
    using MixedModeVisitor::VisitExpr_;

    // By default the MixedModeVisitor will place
    // - callee and arguments before a call
    // - tuple fields before a tuple
    // - tuple before a tuple projection
    void VisitLeaf(const Expr& expr) override {
      if (const auto* var_node = expr.as<VarNode>()) {
        if (var_node == current_let_bound_var_) {
          // Don't visit occurrences of let-rec bound vars in the recursive function body.
          // Instead, wait for them to be visited later, eg when VisitExpr_(const LetNode*)
          // below visits the let-bound var after its let-bound value.
          VLOG(1) << "Ignore let-rec var '" << var_node->name_hint() << "'";
          return;
        }
      }

      MixedModeVisitor::VisitLeaf(expr);
      // Added only after the sub-expressions visited above, giving post-dfs order.
      graph_->AddNode(expr);

      if (const auto* call_node = expr.as<CallNode>()) {
        if (const auto* var_node = call_node->op.as<VarNode>()) {
          if (var_node == current_let_bound_var_) {
            // Remember this is a recursive call to the let-rec bound function.
            // The Annotator functor below will not record any dependency from the let-rec bound
            // var to the expression so that the indexed graph is always a DAG.
            VLOG(1) << "Remembering recursive call to '" << var_node->name_hint() << "'";
            rec_calls_->emplace(call_node);
          }
        }
      }
    }

    /*!
     * \brief Visits a (possibly nested) chain of lets via ExpandANormalForm, presumably to
     * avoid deep recursion on long let chains.
     */
    void VisitExpr_(const LetNode* let_node) override {
      auto pre_visit = [&](const LetNode* op) {
        // Let-bound values come before their let-bound variable.
        // While visiting the value, op->var is the 'current' let-bound var so that recursive
        // references to it are ignored/recorded by VisitLeaf above.
        const VarNode* prev_let_bound_var = current_let_bound_var_;
        current_let_bound_var_ = op->var.get();
        VisitExpr(op->value);
        current_let_bound_var_ = prev_let_bound_var;
        VisitExpr(op->var);
      };
      auto post_visit = [&](const LetNode* op) {
        VisitExpr(op->body);
        // The outermost let (let_node itself) is skipped here: VisitLeaf adds its node once
        // this method returns.
        if (let_node != op) {
          // Replicate VisitLeaf, which we are effectively bypassing.
          visit_counter_[op]++;
          graph_->AddNode(GetRef<Expr>(op));
        }
      };
      ExpandANormalForm(let_node, pre_visit, post_visit);
    }

    /*! \brief Adds a graph node for every variable bound by a match clause pattern. */
    class PatternCreator : public PatternVisitor {
     public:
      explicit PatternCreator(Creator* creator) : creator_(creator) {}

     private:
      void VisitPattern_(const PatternVarNode* pattern_var_node) final {
        creator_->VisitLeaf(pattern_var_node->var);
      }

      Creator* creator_;
    };

    void VisitExpr_(const MatchNode* match_node) override {
      // Matched data comes before match-bound vars then match rhs, in match order.
      VisitExpr(match_node->data);
      for (const Clause& c : match_node->clauses) {
        PatternCreator pattern_creator(this);
        pattern_creator.VisitPattern(c->lhs);
        VisitExpr(c->rhs);
      }
    }

    /*! \brief Graph we are accumulating nodes into. */
    std::unique_ptr<IndexedGraph<Expr>> graph_ = std::make_unique<IndexedGraph<Expr>>();
    /*! \brief Variable the currently visited expression is to be let-bound to, if any. */
    const VarNode* current_let_bound_var_ = nullptr;
    /*! \brief Accumulated calls to recursive functions. */
    std::unique_ptr<std::unordered_set<const CallNode*>> rec_calls_ =
        std::make_unique<std::unordered_set<const CallNode*>>();
  };

  /*!
   * \brief Fills in the inputs and outputs for all nodes, then does dominator analysis.
   *
   * Though we use the ExprFunctor to visit nodes, we never recurse and instead just inspect
   * each sub-expression's immediate sub-sub-expressions to accumulate inputs and outputs.
   */
  class Annotator : public ExprFunctor<void(const Expr&)> {
   public:
    /*! \brief Takes ownership of the graph and recursive-call set produced by Creator. */
    explicit Annotator(std::pair<std::unique_ptr<IndexedGraph<Expr>>,
                                 std::unique_ptr<std::unordered_set<const CallNode*>>>
                           args)
        : graph_(std::move(args.first)), rec_calls_(std::move(args.second)) {}

    /*! \brief Annotates every node, runs PostDom, and returns ownership of the graph. */
    std::unique_ptr<IndexedGraph<Expr>> Annotate() {
      // Visit all of the nodes in topological order to get forward outputs
      for (PostDfsIndex index = 0; index < graph_->size(); ++index) {
        VisitExpr(graph_->index_to_node(index)->ref());
      }
      // do the dominator analysis
      graph_->PostDom();
      return std::move(graph_);
    }

    /*!
     * \brief Add \p parent as a possible output of the node corresponding to \p expr, and that
     * node as an input of \p parent.
     */
    void AddOutput(const Expr& expr, IndexedGraph<Expr>::Node* parent) {
      auto current = graph_->item_to_node(expr);
      current->outputs_.push_back(parent);
      parent->inputs_.push_back(current);
    }

   protected:
    // Leaves have no sub-expressions, so contribute no inputs.
    void VisitExpr_(const VarNode* var_node) override {}

    void VisitExpr_(const GlobalVarNode* global_var_node) override {}

    void VisitExpr_(const ConstantNode* constant_node) override {}

    void VisitExpr_(const TupleNode* tuple_node) override {
      auto node = graph_->item_to_node(GetRef<Tuple>(tuple_node));
      for (auto field : tuple_node->fields) {
        AddOutput(field, node);
      }
    }

    void VisitExpr_(const FunctionNode* function_node) override {
      auto node = graph_->item_to_node(GetRef<Function>(function_node));
      // Nothing to do for parameters -- each use of a parameter will contribute to its outputs.
      AddOutput(function_node->body, node);
    }

    void VisitExpr_(const CallNode* call_node) override {
      auto node = graph_->item_to_node(GetRef<Call>(call_node));
      if (rec_calls_->count(call_node)) {
        // We want the indexed graph to be a DAG, so don't consider a call to a let-rec bound
        // function from inside the function to depend on the let-rec bound var.
        VLOG(1) << "Ignoring op in call " << RefToSummary(GetRef<Call>(call_node));
      } else {
        AddOutput(call_node->op, node);
      }
      for (auto arg : call_node->args) {
        AddOutput(arg, node);
      }
    }

    void VisitExpr_(const LetNode* let_node) override {
      auto node = graph_->item_to_node(GetRef<Let>(let_node));
      auto let_var_node = graph_->item_to_node(let_node->var);
      // The let-bound value flows into the let-bound var.
      AddOutput(let_node->value, let_var_node);
      // Nothing to do for the let-bound variable -- each use of that variable in the let-body
      // will contribute to its outputs.
      AddOutput(let_node->body, node);
    }

    void VisitExpr_(const IfNode* if_node) override {
      auto node = graph_->item_to_node(GetRef<If>(if_node));
      AddOutput(if_node->cond, node);
      AddOutput(if_node->true_branch, node);
      AddOutput(if_node->false_branch, node);
    }

    void VisitExpr_(const OpNode* op_node) override {}

    void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) override {
      auto node = graph_->item_to_node(GetRef<TupleGetItem>(tuple_get_item_node));
      AddOutput(tuple_get_item_node->tuple, node);
    }

    void VisitExpr_(const RefCreateNode* ref_create_node) override {
      auto node = graph_->item_to_node(GetRef<RefCreate>(ref_create_node));
      AddOutput(ref_create_node->value, node);
    }

    void VisitExpr_(const RefReadNode* ref_read_node) override {
      auto node = graph_->item_to_node(GetRef<RefRead>(ref_read_node));
      AddOutput(ref_read_node->ref, node);
    }

    void VisitExpr_(const RefWriteNode* ref_write_node) override {
      auto node = graph_->item_to_node(GetRef<RefWrite>(ref_write_node));
      AddOutput(ref_write_node->ref, node);
      AddOutput(ref_write_node->value, node);
    }

    void VisitExpr_(const ConstructorNode* constructor_node) override {}

    /*! \brief Records the matched data as an input of every variable bound by a pattern. */
    class PatternAnnotator : public PatternVisitor {
     public:
      PatternAnnotator(Annotator* annotator, const ExprNode* adt_node)
          : annotator_(annotator), adt_node_(adt_node) {}

     private:
      void VisitPattern_(const PatternVarNode* pattern_var_node) final {
        auto node = annotator_->graph_->item_to_node(pattern_var_node->var);
        annotator_->AddOutput(GetRef<Expr>(adt_node_), node);
      }

      Annotator* annotator_;
      const ExprNode* adt_node_;
    };

    void VisitExpr_(const MatchNode* match_node) override {
      // Data flows from the match data to pattern vars into match arms and out into overall
      // match.
      // NOTE(review): no edge is added directly from match_node->data to the match node, so if
      // no clause binds a pattern var the data has no outputs -- confirm this is intended.
      auto node = graph_->item_to_node(GetRef<Match>(match_node));
      for (const Clause& c : match_node->clauses) {
        PatternAnnotator pattern_annotator(this, match_node->data.get());
        pattern_annotator.VisitPattern(c->lhs);
        AddOutput(c->rhs, node);
      }
    }

    std::unique_ptr<IndexedGraph<Expr>> graph_;
    /*! \brief Accumulated calls to recursive functions. */
    std::unique_ptr<std::unordered_set<const CallNode*>> rec_calls_;
  };

  /*!
   * \brief Fills in the basic blocks for all nodes.
   *
   * A node's basic block is the top of basic_block_stack_ when the node is visited: the
   * innermost enclosing function, if (branches only) or match (clauses only) node. Nodes
   * visited while the stack is empty are not assigned a basic block.
   *
   * NOTE(review): unlike Creator, this pass does not skip occurrences of a let-rec bound var
   * inside its own let-bound value. If MixedModeVisitor visits each var only once, the var's
   * basic block may therefore be the recursive function (or a block nested inside it) rather
   * than the let's enclosing block -- confirm this is intended.
   */
  class Blocker : public MixedModeVisitor {
   public:
    explicit Blocker(std::unique_ptr<IndexedGraph<Expr>> graph) : graph_(std::move(graph)) {}

    /*! \brief Assigns basic blocks for all nodes of \p expr and returns ownership of the graph. */
    std::unique_ptr<IndexedGraph<Expr>> Scope(const Expr& expr) {
      VisitExpr(expr);
      return std::move(graph_);
    }

   private:
    using MixedModeVisitor::VisitExpr_;

    void VisitLeaf(const Expr& expr) override {
      MixedModeVisitor::VisitLeaf(expr);
      // Scoped after visiting, so a function/if/match node itself gets the outer block.
      SetScope(expr);
    }

    void VisitExpr_(const FunctionNode* function_node) override {
      auto node = graph_->item_to_node(GetRef<Function>(function_node));
      basic_block_stack_.push_back(node);
      ExprVisitor::VisitExpr_(function_node);
      basic_block_stack_.pop_back();
    }

    void VisitExpr_(const IfNode* if_node) override {
      // The condition belongs to the enclosing block; only the branches are scoped by the if.
      VisitExpr(if_node->cond);
      auto node = graph_->item_to_node(GetRef<If>(if_node));
      basic_block_stack_.push_back(node);
      VisitExpr(if_node->true_branch);
      VisitExpr(if_node->false_branch);
      basic_block_stack_.pop_back();
    }

    void VisitExpr_(const LetNode* let_node) override {
      auto pre_visit = [&](const LetNode* op) {
        VisitExpr(op->value);
        VisitExpr(op->var);
      };
      auto post_visit = [&](const LetNode* op) {
        VisitExpr(op->body);
        // As in Creator, replicate VisitLeaf for the inner lets of the chain; the outermost
        // let is scoped by VisitLeaf itself.
        if (let_node != op) {
          visit_counter_[op]++;
          SetScope(GetRef<Let>(op));
        }
      };
      ExpandANormalForm(let_node, pre_visit, post_visit);
    }

    /*! \brief Assigns the current basic block to every variable bound by a pattern. */
    class PatternBlocker : public PatternVisitor {
     public:
      explicit PatternBlocker(Blocker* scoper) : scoper_(scoper) {}

     private:
      void VisitPattern_(const PatternVarNode* pattern_var_node) final {
        scoper_->SetScope(pattern_var_node->var);
      }

      Blocker* scoper_;
    };

    void VisitExpr_(const MatchNode* match_node) override {
      // The matched data belongs to the enclosing block; pattern vars and clause rhs are
      // scoped by the match.
      VisitExpr(match_node->data);
      auto node = graph_->item_to_node(GetRef<Match>(match_node));
      basic_block_stack_.push_back(node);
      for (const Clause& c : match_node->clauses) {
        PatternBlocker pattern_scoper(this);
        pattern_scoper.VisitPattern(c->lhs);
        VisitExpr(c->rhs);
      }
      basic_block_stack_.pop_back();
    }

    /*! \brief Sets the basic block of \p expr's node to the innermost enclosing block, if any. */
    void SetScope(const Expr& expr) {
      auto node = graph_->item_to_node(expr);
      if (!basic_block_stack_.empty()) {
        node->basic_block_ = basic_block_stack_.back();
      }
    }

    std::unique_ptr<IndexedGraph<Expr>> graph_;
    /*! \brief Enclosing function/if/match nodes, innermost last. */
    std::vector<IndexedGraph<Expr>::Node*> basic_block_stack_;
  };

  VLOG(1) << "CreateIndexedGraph:" << std::endl << PrettyPrint(expr);
  std::unique_ptr<IndexedGraph<Expr>> graph =
      Blocker(Annotator(Creator().CreateGraph(expr)).Annotate()).Scope(expr);
  VLOG(1) << "graph:" << std::endl << graph->ToString();
#if TVM_LOG_DEBUG
  graph->CheckValid();
#endif
  return graph;
}
400
401std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pattern) {
402 /*! \brief Creates an IndexedGraph and determines topological order */
403 class Creator : public DFPatternVisitor {
404 public:
405 std::unique_ptr<IndexedGraph<DFPattern>> CreateGraph(const DFPattern& pattern) {
406 graph_ = std::make_unique<IndexedGraph<DFPattern>>();
407 VisitDFPattern(pattern);
408 graph_->item_to_node(pattern)->is_external_ = true;
409 return std::move(graph_);
410 }
411
412 protected:
413 void VisitDFPattern(const DFPattern& pattern) override {
414 if (this->visited_.count(pattern.get()) == 0) {
415 DFPatternVisitor::VisitDFPattern(pattern);
416 graph_->AddNode(pattern);
417 }
418 }
419
420 std::unique_ptr<IndexedGraph<DFPattern>> graph_;
421 };
422
423 /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree
424 * analysis.
425 *
426 * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined
427 * topological order instead of recursing.
428 */
429 class Annotator : public DFPatternFunctor<void(const DFPattern&)> {
430 public:
431 Annotator(std::unique_ptr<IndexedGraph<DFPattern>> graph) : graph_(std::move(graph)) {}
432
433 std::unique_ptr<IndexedGraph<DFPattern>> Annotate() {
434 // Visit all of the nodes in topological order to get forward outputs
435 for (PostDfsIndex index = 0; index < graph_->size(); ++index) {
436 VisitDFPattern(graph_->index_to_node(index)->ref());
437 }
438 // do the dominator analysis
439 graph_->PostDom();
440 return std::move(graph_);
441 }
442
443 /*! Default visitation pushes the parent to the child's outputs */
444 void AddOutput(const DFPattern& pattern, IndexedGraph<DFPattern>::Node* parent) {
445 auto current = graph_->item_to_node(pattern);
446 if (parent) {
447 current->outputs_.push_back(parent);
448 parent->inputs_.push_back(current);
449 }
450 }
451
452 protected:
453 void VisitDFPattern_(const AltPatternNode* op) override {
454 auto node = graph_->item_to_node(GetRef<AltPattern>(op));
455 AddOutput(op->left, node);
456 AddOutput(op->right, node);
457 }
458
459 void VisitDFPattern_(const AttrPatternNode* op) override {
460 auto node = graph_->item_to_node(GetRef<AttrPattern>(op));
461 AddOutput(op->pattern, node);
462 }
463
464 void VisitDFPattern_(const CallPatternNode* op) override {
465 auto node = graph_->item_to_node(GetRef<CallPattern>(op));
466 AddOutput(op->op, node);
467 if (op->args.defined()) {
468 for (auto arg : op->args) {
469 AddOutput(arg, node);
470 }
471 }
472 }
473
474 void VisitDFPattern_(const ConstantPatternNode* op) override {}
475
476 void VisitDFPattern_(const DataTypePatternNode* op) override {
477 auto node = graph_->item_to_node(GetRef<DataTypePattern>(op));
478 AddOutput(op->pattern, node);
479 }
480
481 void VisitDFPattern_(const DominatorPatternNode* op) override {
482 auto node = graph_->item_to_node(GetRef<DominatorPattern>(op));
483 AddOutput(op->parent, node);
484 AddOutput(op->path, node);
485 AddOutput(op->child, node);
486 }
487
488 void VisitDFPattern_(const ExprPatternNode* op) override {}
489
490 void VisitDFPattern_(const FunctionPatternNode* op) override {
491 auto node = graph_->item_to_node(GetRef<FunctionPattern>(op));
492 if (op->params.defined()) {
493 for (auto param : op->params) {
494 AddOutput(param, node);
495 }
496 }
497 AddOutput(op->body, node);
498 }
499
500 void VisitDFPattern_(const ShapePatternNode* op) override {
501 auto node = graph_->item_to_node(GetRef<ShapePattern>(op));
502 AddOutput(op->pattern, node);
503 }
504
505 void VisitDFPattern_(const TupleGetItemPatternNode* op) override {
506 auto node = graph_->item_to_node(GetRef<TupleGetItemPattern>(op));
507 AddOutput(op->tuple, node);
508 }
509
510 void VisitDFPattern_(const TuplePatternNode* op) override {
511 auto node = graph_->item_to_node(GetRef<TuplePattern>(op));
512 if (op->fields.defined()) {
513 for (auto field : op->fields) {
514 AddOutput(field, node);
515 }
516 }
517 }
518
519 void VisitDFPattern_(const IfPatternNode* op) override {
520 auto node = graph_->item_to_node(GetRef<IfPattern>(op));
521 AddOutput(op->cond, node);
522 AddOutput(op->true_branch, node);
523 AddOutput(op->false_branch, node);
524 }
525
526 void VisitDFPattern_(const LetPatternNode* op) override {
527 auto node = graph_->item_to_node(GetRef<LetPattern>(op));
528 AddOutput(op->var, node);
529 AddOutput(op->value, node);
530 AddOutput(op->body, node);
531 }
532
533 void VisitDFPattern_(const TypePatternNode* op) override {
534 auto node = graph_->item_to_node(GetRef<TypePattern>(op));
535 AddOutput(op->pattern, node);
536 }
537
538 void VisitDFPattern_(const VarPatternNode* op) override {}
539
540 void VisitDFPattern_(const WildcardPatternNode* op) override {}
541
542 std::unique_ptr<IndexedGraph<DFPattern>> graph_;
543 };
544
545 return Annotator(Creator().CreateGraph(pattern)).Annotate();
546}
547
548} // namespace relay
549} // namespace tvm
550