1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/analysis/dependency_graph.cc |
22 | * \brief Implementation of dependency graph APIs. |
23 | */ |
24 | #include "dependency_graph.h" |
25 | |
26 | #include <tvm/relay/expr_functor.h> |
27 | |
28 | #include <unordered_set> |
29 | #include <utility> |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
/*!
 * \brief Builds a DependencyGraph for a Relay expression.
 *
 * The creator walks the expression with MixedModeVisitor (privately inherited, so only
 * Create() is exposed). Every expression reached gets one DependencyGraph::Node, stored in
 * graph_.expr_node. Each Depend() call records a parent -> child edge in both directions
 * (child->parents and parent->children). A node is appended to graph_.post_dfs_order only
 * after the expressions it depends on have been visited.
 *
 * Besides the per-expression nodes, synthetic "scope" nodes (new_scope == true) are inserted
 * between an expression and the sub-expressions it opens a scope for:
 *   - the two branches of an If;
 *   - a Function's params and body;
 *   - a Let's var, value and body;
 *   - each Match clause's rhs.
 * Scope nodes are never entered into graph_.expr_node.
 *
 * All nodes and link cells are allocated from arena_, which the creator does not free.
 * Presumably the arena must outlive the returned graph -- TODO confirm with callers.
 *
 * NOTE(review): std::unordered_map and std::vector are used below, but this file does not
 * include <unordered_map> / <vector> directly. They presumably arrive transitively via
 * dependency_graph.h; consider including them explicitly.
 */
class DependencyGraph::Creator : private MixedModeVisitor {
 public:
  /*!
   * \brief Construct a creator that allocates from \p arena.
   * \param arena Allocator for every node and link cell created by this creator. It is kept
   *        as a raw pointer and is not freed here.
   */
  explicit Creator(support::Arena* arena) : arena_(arena) {}

  /*!
   * \brief Visit \p body and return the dependency graph built for it.
   * \param body The root expression.
   * \return The constructed graph. The member graph_ is moved out, so this creator is
   *         effectively single-use afterwards.
   */
  DependencyGraph Create(const Expr& body) {
    this->VisitExpr(body);
    return std::move(graph_);
  }

 private:
  /*! \brief allocator of all the internal node object */
  support::Arena* arena_;
  /*! \brief The graph under construction; moved out by Create(). */
  DependencyGraph graph_;

  /*!
   * \brief Make \p child a dependency of \p parent.
   *
   * Visits \p child first, so that it has a node in graph_.expr_node (checked below). Then it
   * records the edge between \p parent and that node.
   */
  void Depend(DependencyGraph::Node* parent, const Expr& child) {
    VisitExpr(child);

    ICHECK_NE(graph_.expr_node.count(child), 0);

    Depend(parent, graph_.expr_node[child]);
  }

  /*!
   * \brief Record the edge parent -> child in both directions.
   *
   * Pushes \p parent onto child->parents and \p child onto parent->children. Each push uses
   * a freshly arena-allocated link cell.
   */
  void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
    auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*>>();
    parent_link->value = parent;
    child->parents.Push(parent_link);

    auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*>>();
    child_link->value = child;
    parent->children.Push(child_link);
  }

  /*!
   * \brief Expressions already registered by VisitLeaf (or by the Let pre-visit below).
   *
   * Guards VisitLeaf, so an expression is dispatched and appended to post_dfs_order at most
   * once through that path.
   */
  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited_;

  /*!
   * \brief Allocate a fresh node from the arena.
   * \param new_scope True for the synthetic scope nodes created by the If / Function / Let /
   *        Match handlers; false for per-expression nodes.
   */
  DependencyGraph::Node* NewNode(bool new_scope) {
    auto* ret = arena_->make<DependencyGraph::Node>();
    ret->new_scope = new_scope;
    return ret;
  }

  /*!
   * \brief Per-expression hook of the MixedModeVisitor traversal.
   *
   * On the first visit of \p e, this:
   *   - makes sure e has a node;
   *   - marks e visited;
   *   - lets the base class dispatch e (presumably to the VisitExpr_ overloads below, which
   *     add e's outgoing edges);
   *   - appends e's node to post_dfs_order, after everything appended during that dispatch.
   * The VisitExpr_ overloads below rely on this: they fetch the node with expr_node[...]
   * without creating it.
   */
  void VisitLeaf(const Expr& e) override {
    if (visited_.count(e) == 0) {
      if (graph_.expr_node.count(e) == 0) {
        graph_.expr_node[e] = NewNode(false);
      }
      visited_.insert(e);
      MixedModeVisitor::VisitLeaf(e);
      graph_.post_dfs_order.push_back(graph_.expr_node[e]);
    }
  }

  // A call depends on its callee, then on each argument in argument order.
  void VisitExpr_(const CallNode* c) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
    Depend(n, c->op);
    for (const auto& a : c->args) {
      Depend(n, a);
    }
  }

  // A tuple depends on each of its fields, in order.
  void VisitExpr_(const TupleNode* t) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
    for (const auto& a : t->fields) {
      Depend(n, a);
    }
  }

  // A tuple projection depends on the tuple being projected.
  void VisitExpr_(const TupleGetItemNode* t) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
    Depend(n, t->tuple);
  }

  // Creating a reference depends on its initial value.
  void VisitExpr_(const RefCreateNode* r) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
    Depend(n, r->value);
  }

  // Reading a reference depends on the reference expression.
  void VisitExpr_(const RefReadNode* r) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
    Depend(n, r->ref);
  }

  // Writing a reference depends on the reference, then on the value being stored.
  void VisitExpr_(const RefWriteNode* r) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
    Depend(n, r->ref);
    Depend(n, r->value);
  }

  /*!
   * \brief If: the node depends on the condition directly, and on each branch through its own
   * scope node (t for the true branch, f for the false branch).
   *
   * The two scope nodes are appended to post_dfs_order f first, then t, i.e. in reverse of the
   * order they were linked. When reached through VisitLeaf, the If node itself is appended
   * after them.
   * NOTE(review): the reason for the reversed order is not visible here. Presumably consumers
   * walk post_dfs_order backwards and expect the true-branch scope first -- confirm.
   */
  void VisitExpr_(const IfNode* i) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
    DependencyGraph::Node* t = NewNode(true);
    DependencyGraph::Node* f = NewNode(true);
    Depend(n, i->cond);
    Depend(n, t);
    Depend(n, f);
    Depend(t, i->true_branch);
    Depend(f, i->false_branch);
    graph_.post_dfs_order.push_back(f);
    graph_.post_dfs_order.push_back(t);
  }

  /*!
   * \brief Function: the node depends on a single scope node b. b depends on every parameter,
   * then on the body.
   *
   * b is appended to post_dfs_order after the params and body have been visited.
   */
  void VisitExpr_(const FunctionNode* f) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
    DependencyGraph::Node* b = NewNode(true);
    Depend(n, b);
    for (const auto& p : f->params) {
      Depend(b, p);
    }
    Depend(b, f->body);
    graph_.post_dfs_order.push_back(b);
  }

  /*!
   * \brief Let: handles the whole let chain iteratively through ExpandANormalForm (presumably
   * to avoid deep recursion on long A-normal-form chains).
   *
   * pre_visit, run for every let in the chain:
   *   - registers the let's node, repeating VisitLeaf's bookkeeping but without the
   *     post_dfs_order push;
   *   - creates a scope node b and makes the let depend on b;
   *   - makes b depend on the bound var and then on the value.
   *
   * post_visit:
   *   - makes b depend on the body and appends b to post_dfs_order;
   *   - for nested lets only (op != l), also does what VisitLeaf would do after dispatch:
   *     bumps the base class's visit_counter_ and appends the let's own node.
   * The outermost let l is appended by the VisitLeaf call that dispatched here.
   */
  void VisitExpr_(const LetNode* l) final {
    // Scope node created in pre_visit for each let in the chain, consumed by post_visit.
    std::unordered_map<const LetNode*, DependencyGraph::Node*> b_map;
    auto pre_visit = [&](const LetNode* op) {
      Expr e = GetRef<Expr>(op);
      // Same bookkeeping as this class's VisitLeaf (minus the post_dfs_order push).
      if (visited_.count(e) == 0) {
        if (graph_.expr_node.count(e) == 0) {
          graph_.expr_node[e] = NewNode(false);
        }
        visited_.insert(e);
      }
      DependencyGraph::Node* n = graph_.expr_node[e];
      DependencyGraph::Node* b = NewNode(true);
      Depend(n, b);
      Depend(b, op->var);
      Depend(b, op->value);
      b_map[op] = b;
    };
    auto post_visit = [&](const LetNode* op) {
      ICHECK(b_map.count(op));
      DependencyGraph::Node* b = b_map[op];
      Expr e = GetRef<Expr>(op);
      Depend(b, op->body);
      graph_.post_dfs_order.push_back(b);
      if (op != l) {
        // Mirror the base-class VisitLeaf bookkeeping: count this nested let as visited.
        // Presumably this stops the base traversal from dispatching it again when the
        // enclosing let's body is depended on.
        this->visit_counter_[op]++;
        // Mirror this class's VisitLeaf: append the nested let's node in post order.
        graph_.post_dfs_order.push_back(graph_.expr_node[e]);
      }
    };
    ExpandANormalForm(l, pre_visit, post_visit);
  }

  /*!
   * \brief Match: the node depends on the scrutinee directly, and on each clause's rhs through a
   * per-clause scope node.
   *
   * As in the If handler, the scope nodes are appended to post_dfs_order in reverse clause
   * order.
   * NOTE(review): clause patterns (and any variables they bind) are not visited here, so this
   * handler gives them no edges.
   */
  void VisitExpr_(const MatchNode* m) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
    Depend(n, m->data);
    std::vector<DependencyGraph::Node*> v;
    for (const Clause& c : m->clauses) {
      DependencyGraph::Node* b = NewNode(true);
      Depend(n, b);
      Depend(b, c->rhs);
      v.push_back(b);
    }
    for (auto it = v.rbegin(); it != v.rend(); ++it) {
      graph_.post_dfs_order.push_back(*it);
    }
  }

  // These expression kinds add no outgoing edges here. Their nodes are still created and
  // placed in post_dfs_order by VisitLeaf.
  void VisitExpr_(const VarNode* v) final {}

  void VisitExpr_(const GlobalVarNode* v) final {}

  void VisitExpr_(const ConstantNode* c) final {}

  void VisitExpr_(const OpNode* o) final {}

  void VisitExpr_(const ConstructorNode* c) final {}
};
206 | |
207 | DependencyGraph DependencyGraph::Create(support::Arena* arena, const Expr& body) { |
208 | return Creator(arena).Create(body); |
209 | } |
210 | |
211 | } // namespace relay |
212 | } // namespace tvm |
213 | |