1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "./graph_partitioner.h" |
21 | |
22 | #include <vector> |
23 | |
24 | namespace tvm { |
25 | namespace relay { |
26 | |
27 | DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { |
28 | DominatorTree tree; |
29 | tree.nodes.resize(graph.post_dfs_order.size(), nullptr); |
30 | // reverse topo order |
31 | for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { |
32 | size_t index = i - 1; |
33 | tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); |
34 | } |
35 | return tree; |
36 | } |
37 | |
38 | DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs, |
39 | OpPatternKind* edge_pattern) { |
40 | while (lhs != rhs) { |
41 | if (lhs == nullptr) return nullptr; |
42 | if (rhs == nullptr) return nullptr; |
43 | if (lhs->depth < rhs->depth) { |
44 | edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); |
45 | rhs = rhs->parent; |
46 | } else if (rhs->depth < lhs->depth) { |
47 | edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); |
48 | lhs = lhs->parent; |
49 | } else { |
50 | edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); |
51 | edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); |
52 | lhs = lhs->parent; |
53 | rhs = rhs->parent; |
54 | } |
55 | } |
56 | return lhs; |
57 | } |
58 | |
59 | DominatorTree::Node* DominatorTree::LeastCommonAncestor( |
60 | const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) { |
61 | auto link = input_nodes.head; |
62 | if (link == nullptr) { |
63 | return nullptr; |
64 | } |
65 | auto get_node = [&](const IndexedForwardGraph::Edge& edge) { |
66 | size_t oindex = edge.node->index; |
67 | ICHECK_LT(oindex, nodes.size()); |
68 | Node* onode = nodes[oindex]; |
69 | ICHECK(onode != nullptr); |
70 | return onode; |
71 | }; |
72 | Node* parent = get_node(link->value); |
73 | *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); |
74 | link = link->next; |
75 | for (; link != nullptr; link = link->next) { |
76 | parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); |
77 | *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); |
78 | } |
79 | return parent; |
80 | } |
81 | |
82 | DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena, |
83 | IndexedForwardGraph::Node* gnode) { |
84 | Node* tnode = arena->make<Node>(); |
85 | tnode->gnode = gnode; |
86 | if (gnode->extern_ref) { |
87 | tnode->depth = 1; |
88 | tnode->parent = nullptr; |
89 | tnode->pattern = kOpaque; |
90 | } else { |
91 | // find the LCAs of all outputs. |
92 | OpPatternKind pattern = kElemWise; |
93 | Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); |
94 | tnode->depth = parent ? parent->depth + 1 : 1; |
95 | tnode->parent = parent; |
96 | tnode->pattern = pattern; |
97 | } |
98 | return tnode; |
99 | } |
100 | |
101 | std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition( |
102 | const IndexedForwardGraph& graph) { |
103 | this->InitGroups(graph); |
104 | if (opt_level_ == 0) return std::move(groups_); |
105 | // get post dominator tree |
106 | auto post_dom_tree = DominatorTree::PostDom(arena_, graph); |
107 | // run fusion algorithm. |
108 | for (int phase = 0; phase < 3; ++phase) { |
109 | this->RunFuse(graph, post_dom_tree, phase); |
110 | } |
111 | return std::move(groups_); |
112 | } |
113 | |
114 | GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() { |
115 | // fast path |
116 | if (this->parent == nullptr) return this; |
117 | // slow path with path compression. |
118 | Group* root = this; |
119 | while (root->parent != nullptr) { |
120 | root = root->parent; |
121 | } |
122 | for (Group* p = this; p != root;) { |
123 | Group* parent = p->parent; |
124 | p->parent = root; |
125 | p = parent; |
126 | } |
127 | return root; |
128 | } |
129 | |
130 | template <typename F> |
131 | bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, |
132 | F fcond) { |
133 | if (visited_.count(src)) return true; |
134 | visited_.insert(src); |
135 | Group* gnode = groups_[src->index]; |
136 | ICHECK(gnode != nullptr); |
137 | gnode = gnode->FindRoot(); |
138 | if (!fcond(gnode->pattern, src == sink)) return false; |
139 | if (src == sink) return true; |
140 | for (auto link = src->outputs.head; link != nullptr; link = link->next) { |
141 | if (!CheckPath_(link->value.node, sink, fcond)) return false; |
142 | } |
143 | return true; |
144 | } |
145 | |
146 | template <typename F> |
147 | bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, |
148 | F fcond) { |
149 | ICHECK(!src->extern_ref); |
150 | visited_.clear(); |
151 | ICHECK(src != sink); |
152 | for (auto link = src->outputs.head; link != nullptr; link = link->next) { |
153 | if (!CheckPath_(link->value.node, sink, fcond)) return false; |
154 | } |
155 | return true; |
156 | } |
157 | |
158 | OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { |
159 | if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) { |
160 | LOG(FATAL) << "Cannot merge two complex group together" ; |
161 | } |
162 | if (lhs > rhs) return lhs; |
163 | return rhs; |
164 | } |
165 | |
166 | void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { |
167 | child = child->FindRoot(); |
168 | parent = parent->FindRoot(); |
169 | if (child == parent) return; |
170 | // update the number of nodes of the parent group |
171 | parent->num_nodes += child->num_nodes; |
172 | child->parent = parent; |
173 | // update anchor ref and pattern |
174 | if (child->anchor_ref != nullptr) { |
175 | ICHECK(parent->anchor_ref == nullptr); |
176 | parent->anchor_ref = child->anchor_ref; |
177 | parent->pattern = CombinePattern(child->pattern, parent->pattern); |
178 | } |
179 | } |
180 | |
181 | void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, |
182 | Group* target) { |
183 | if (src == sink) return; |
184 | if (visited_.count(src)) return; |
185 | visited_.insert(src); |
186 | Group* gnode = groups_[src->index]; |
187 | ICHECK(gnode != nullptr); |
188 | // merge the current group to the parent if possible. |
189 | MergeFromTo(gnode, target); |
190 | for (auto link = src->outputs.head; link != nullptr; link = link->next) { |
191 | CommitFuse_(link->value.node, sink, target); |
192 | } |
193 | } |
194 | |
195 | void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { |
196 | Group* target = groups_[sink->index]; |
197 | visited_.clear(); |
198 | ICHECK(src != sink); |
199 | CommitFuse_(src, sink, target); |
200 | } |
201 | |
202 | size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src, |
203 | IndexedForwardGraph::Node* sink) { |
204 | if (src == sink || visited_.count(src)) return 0; |
205 | visited_.insert(src); |
206 | Group* gnode = groups_[src->index]; |
207 | ICHECK(gnode != nullptr); |
208 | auto sum = gnode->num_nodes; |
209 | for (auto link = src->outputs.head; link != nullptr; link = link->next) { |
210 | sum += CountNodesUptoSink_(link->value.node, sink); |
211 | } |
212 | return sum; |
213 | } |
214 | |
215 | size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, |
216 | IndexedForwardGraph::Node* dom_parent) { |
217 | Group* target = groups_[dom_parent->index]; |
218 | visited_.clear(); |
219 | ICHECK(child != dom_parent); |
220 | return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); |
221 | } |
222 | |
223 | void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { |
224 | groups_.resize(graph.post_dfs_order.size()); |
225 | for (size_t nid = 0; nid < groups_.size(); ++nid) { |
226 | const auto* graph_node = graph.post_dfs_order[nid]; |
227 | auto* group_node = arena_->make<Group>(); |
228 | group_node->pattern = graph_node->pattern; |
229 | group_node->root_ref = graph_node->ref; |
230 | // set anchor ref if necessary. |
231 | if (group_node->pattern == relay::kOutEWiseFusable) { |
232 | group_node->anchor_ref = graph_node->ref; |
233 | } |
234 | groups_[nid] = group_node; |
235 | } |
236 | } |
237 | |
238 | void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // |
239 | const DominatorTree& post_dom_tree, // |
240 | int phase) { |
241 | for (size_t nid = 0; nid < groups_.size(); ++nid) { |
242 | // the group of current node has been specified already. |
243 | auto* graph_node = graph.post_dfs_order[nid]; |
244 | auto* dom_node = post_dom_tree.nodes[nid]; |
245 | Group* group_node = groups_[nid]; |
246 | ICHECK(group_node != nullptr); |
247 | // no actions for opaque nodes |
248 | if (group_node->pattern == kOpaque) continue; |
249 | // no actions needed if the current node have no dominator |
250 | if (dom_node->parent == nullptr) continue; |
251 | ICHECK(!graph_node->extern_ref); |
252 | size_t dom_parent_gindex = dom_node->parent->gnode->index; |
253 | |
254 | // refuse the fusion if too many ops are going to be fused together |
255 | if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) |
256 | continue; |
257 | |
258 | if (phase == 2) { |
259 | // Fuse injective ops into intermediate tuples, if any |
260 | if (group_node->pattern > relay::kInjective) continue; |
261 | Group* dom_parent_group = groups_[dom_parent_gindex]; |
262 | Group* dom_root_group = dom_parent_group->FindRoot(); |
263 | // If dom node group has a tuple as its root, we do not fuse tuple fields into it |
264 | if (dom_root_group->pattern == relay::kTuple) continue; |
265 | if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) { |
266 | // Now we know the tuple has been fused into subsequent injective ops |
267 | auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; |
268 | // dom_root_group can also be tuple, as in inception layers |
269 | // CheckPath is needed to avoid fusing two intermediate tuples |
270 | if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { |
271 | CommitFuse(graph_node, dom_node->parent->gnode); |
272 | } |
273 | } |
274 | continue; |
275 | } |
276 | |
277 | // Skip if current node is already fused to the parent. |
278 | if (groups_[dom_parent_gindex] != nullptr && |
279 | group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { |
280 | continue; |
281 | } |
282 | // Do not fuse into tuple for now |
283 | if (groups_[dom_parent_gindex]->pattern == kTuple) continue; |
284 | // Try to fuse current node to its post-dominator. |
285 | if (group_node->pattern == kOutEWiseFusable) { |
286 | if (phase != 0) continue; |
287 | // Path for OutEWiseFusable: conv2d |
288 | // Check if the dominator relation is elemwise. |
289 | if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { |
290 | ICHECK(dom_node->parent->gnode != nullptr); |
291 | // The fuse can be executed if all the intermediate ops are still broadcast. |
292 | auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; |
293 | if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { |
294 | CommitFuse(graph_node, dom_node->parent->gnode); |
295 | } |
296 | } |
297 | } else if (group_node->pattern <= kBroadcast) { |
298 | // Pre-condition: can only be fused to parent which is injective or reduction. |
299 | if (dom_node->parent != nullptr && |
300 | (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { |
301 | // Check if all the intermediate ops are still broadcast. |
302 | // The final terminal node can already be fused to a OutEWiseFusable group. |
303 | auto fcond = [](OpPatternKind kind, bool is_sink) { |
304 | if (!is_sink) { |
305 | // Elemwise, broadcast, and injective ops on the parallel branches |
306 | // are allowed be fused to the elemwise/broadcast anchor. |
307 | return kind <= kInjective; |
308 | } else { |
309 | return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || |
310 | kind == kOutEWiseFusable); |
311 | } |
312 | }; |
313 | if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { |
314 | CommitFuse(graph_node, dom_node->parent->gnode); |
315 | } |
316 | } |
317 | } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { |
318 | // defer injective fusion to second phase. |
319 | // so conv2d always finishes fusing. |
320 | if (phase != 1) continue; |
321 | // Check if all path are injective. |
322 | auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; |
323 | if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { |
324 | CommitFuse(graph_node, dom_node->parent->gnode); |
325 | } |
326 | } else { |
327 | // do nothing. |
328 | ICHECK(group_node->pattern == kCommReduce); |
329 | } |
330 | } |
331 | } |
332 | |
333 | } // namespace relay |
334 | } // namespace tvm |
335 | |