1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/analysis/dependency_graph.cc
22 * \brief Implementation of dependency graph APIs.
23 */
#include "dependency_graph.h"

#include <tvm/relay/expr_functor.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
30
31namespace tvm {
32namespace relay {
33
34// Creator of DependencyGraph
35class DependencyGraph::Creator : private MixedModeVisitor {
36 public:
37 explicit Creator(support::Arena* arena) : arena_(arena) {}
38
39 DependencyGraph Create(const Expr& body) {
40 this->VisitExpr(body);
41 return std::move(graph_);
42 }
43
44 private:
45 /*! \brief allocator of all the internal node object */
46 support::Arena* arena_;
47 // The output.
48 DependencyGraph graph_;
49 // Update the message stored at the node.
50 void Depend(DependencyGraph::Node* parent, const Expr& child) {
51 VisitExpr(child);
52
53 ICHECK_NE(graph_.expr_node.count(child), 0);
54
55 Depend(parent, graph_.expr_node[child]);
56 }
57
58 void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
59 auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*>>();
60 parent_link->value = parent;
61 child->parents.Push(parent_link);
62
63 auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*>>();
64 child_link->value = child;
65 parent->children.Push(child_link);
66 }
67
68 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited_;
69
70 DependencyGraph::Node* NewNode(bool new_scope) {
71 auto* ret = arena_->make<DependencyGraph::Node>();
72 ret->new_scope = new_scope;
73 return ret;
74 }
75
76 void VisitLeaf(const Expr& e) override {
77 if (visited_.count(e) == 0) {
78 if (graph_.expr_node.count(e) == 0) {
79 graph_.expr_node[e] = NewNode(false);
80 }
81 visited_.insert(e);
82 MixedModeVisitor::VisitLeaf(e);
83 graph_.post_dfs_order.push_back(graph_.expr_node[e]);
84 }
85 }
86
87 void VisitExpr_(const CallNode* c) final {
88 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
89 Depend(n, c->op);
90 for (const auto& a : c->args) {
91 Depend(n, a);
92 }
93 }
94
95 void VisitExpr_(const TupleNode* t) final {
96 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
97 for (const auto& a : t->fields) {
98 Depend(n, a);
99 }
100 }
101
102 void VisitExpr_(const TupleGetItemNode* t) final {
103 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
104 Depend(n, t->tuple);
105 }
106
107 void VisitExpr_(const RefCreateNode* r) final {
108 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
109 Depend(n, r->value);
110 }
111
112 void VisitExpr_(const RefReadNode* r) final {
113 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
114 Depend(n, r->ref);
115 }
116
117 void VisitExpr_(const RefWriteNode* r) final {
118 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
119 Depend(n, r->ref);
120 Depend(n, r->value);
121 }
122
123 void VisitExpr_(const IfNode* i) final {
124 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
125 DependencyGraph::Node* t = NewNode(true);
126 DependencyGraph::Node* f = NewNode(true);
127 Depend(n, i->cond);
128 Depend(n, t);
129 Depend(n, f);
130 Depend(t, i->true_branch);
131 Depend(f, i->false_branch);
132 graph_.post_dfs_order.push_back(f);
133 graph_.post_dfs_order.push_back(t);
134 }
135
136 void VisitExpr_(const FunctionNode* f) final {
137 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
138 DependencyGraph::Node* b = NewNode(true);
139 Depend(n, b);
140 for (const auto& p : f->params) {
141 Depend(b, p);
142 }
143 Depend(b, f->body);
144 graph_.post_dfs_order.push_back(b);
145 }
146
147 void VisitExpr_(const LetNode* l) final {
148 std::unordered_map<const LetNode*, DependencyGraph::Node*> b_map;
149 auto pre_visit = [&](const LetNode* op) {
150 Expr e = GetRef<Expr>(op);
151 // Derived VisitLeaf
152 if (visited_.count(e) == 0) {
153 if (graph_.expr_node.count(e) == 0) {
154 graph_.expr_node[e] = NewNode(false);
155 }
156 visited_.insert(e);
157 }
158 DependencyGraph::Node* n = graph_.expr_node[e];
159 DependencyGraph::Node* b = NewNode(true);
160 Depend(n, b);
161 Depend(b, op->var);
162 Depend(b, op->value);
163 b_map[op] = b;
164 };
165 auto post_visit = [&](const LetNode* op) {
166 ICHECK(b_map.count(op));
167 DependencyGraph::Node* b = b_map[op];
168 Expr e = GetRef<Expr>(op);
169 Depend(b, op->body);
170 graph_.post_dfs_order.push_back(b);
171 if (op != l) {
172 // Base VisitLeaf
173 this->visit_counter_[op]++;
174 // Derived VisitLeaf
175 graph_.post_dfs_order.push_back(graph_.expr_node[e]);
176 }
177 };
178 ExpandANormalForm(l, pre_visit, post_visit);
179 }
180
181 void VisitExpr_(const MatchNode* m) final {
182 DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
183 Depend(n, m->data);
184 std::vector<DependencyGraph::Node*> v;
185 for (const Clause& c : m->clauses) {
186 DependencyGraph::Node* b = NewNode(true);
187 Depend(n, b);
188 Depend(b, c->rhs);
189 v.push_back(b);
190 }
191 for (auto it = v.rbegin(); it != v.rend(); ++it) {
192 graph_.post_dfs_order.push_back(*it);
193 }
194 }
195
196 void VisitExpr_(const VarNode* v) final {}
197
198 void VisitExpr_(const GlobalVarNode* v) final {}
199
200 void VisitExpr_(const ConstantNode* c) final {}
201
202 void VisitExpr_(const OpNode* o) final {}
203
204 void VisitExpr_(const ConstructorNode* c) final {}
205};
206
207DependencyGraph DependencyGraph::Create(support::Arena* arena, const Expr& body) {
208 return Creator(arena).Create(body);
209}
210
211} // namespace relay
212} // namespace tvm
213