1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "./graph_partitioner.h"
21
22#include <vector>
23
24namespace tvm {
25namespace relay {
26
27DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
28 DominatorTree tree;
29 tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
30 // reverse topo order
31 for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
32 size_t index = i - 1;
33 tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
34 }
35 return tree;
36}
37
38DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
39 OpPatternKind* edge_pattern) {
40 while (lhs != rhs) {
41 if (lhs == nullptr) return nullptr;
42 if (rhs == nullptr) return nullptr;
43 if (lhs->depth < rhs->depth) {
44 edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
45 rhs = rhs->parent;
46 } else if (rhs->depth < lhs->depth) {
47 edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
48 lhs = lhs->parent;
49 } else {
50 edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
51 edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
52 lhs = lhs->parent;
53 rhs = rhs->parent;
54 }
55 }
56 return lhs;
57}
58
59DominatorTree::Node* DominatorTree::LeastCommonAncestor(
60 const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
61 auto link = input_nodes.head;
62 if (link == nullptr) {
63 return nullptr;
64 }
65 auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
66 size_t oindex = edge.node->index;
67 ICHECK_LT(oindex, nodes.size());
68 Node* onode = nodes[oindex];
69 ICHECK(onode != nullptr);
70 return onode;
71 };
72 Node* parent = get_node(link->value);
73 *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
74 link = link->next;
75 for (; link != nullptr; link = link->next) {
76 parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
77 *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
78 }
79 return parent;
80}
81
82DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
83 IndexedForwardGraph::Node* gnode) {
84 Node* tnode = arena->make<Node>();
85 tnode->gnode = gnode;
86 if (gnode->extern_ref) {
87 tnode->depth = 1;
88 tnode->parent = nullptr;
89 tnode->pattern = kOpaque;
90 } else {
91 // find the LCAs of all outputs.
92 OpPatternKind pattern = kElemWise;
93 Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
94 tnode->depth = parent ? parent->depth + 1 : 1;
95 tnode->parent = parent;
96 tnode->pattern = pattern;
97 }
98 return tnode;
99}
100
101std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
102 const IndexedForwardGraph& graph) {
103 this->InitGroups(graph);
104 if (opt_level_ == 0) return std::move(groups_);
105 // get post dominator tree
106 auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
107 // run fusion algorithm.
108 for (int phase = 0; phase < 3; ++phase) {
109 this->RunFuse(graph, post_dom_tree, phase);
110 }
111 return std::move(groups_);
112}
113
114GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() {
115 // fast path
116 if (this->parent == nullptr) return this;
117 // slow path with path compression.
118 Group* root = this;
119 while (root->parent != nullptr) {
120 root = root->parent;
121 }
122 for (Group* p = this; p != root;) {
123 Group* parent = p->parent;
124 p->parent = root;
125 p = parent;
126 }
127 return root;
128}
129
130template <typename F>
131bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
132 F fcond) {
133 if (visited_.count(src)) return true;
134 visited_.insert(src);
135 Group* gnode = groups_[src->index];
136 ICHECK(gnode != nullptr);
137 gnode = gnode->FindRoot();
138 if (!fcond(gnode->pattern, src == sink)) return false;
139 if (src == sink) return true;
140 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
141 if (!CheckPath_(link->value.node, sink, fcond)) return false;
142 }
143 return true;
144}
145
146template <typename F>
147bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
148 F fcond) {
149 ICHECK(!src->extern_ref);
150 visited_.clear();
151 ICHECK(src != sink);
152 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
153 if (!CheckPath_(link->value.node, sink, fcond)) return false;
154 }
155 return true;
156}
157
158OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
159 if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) {
160 LOG(FATAL) << "Cannot merge two complex group together";
161 }
162 if (lhs > rhs) return lhs;
163 return rhs;
164}
165
166void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
167 child = child->FindRoot();
168 parent = parent->FindRoot();
169 if (child == parent) return;
170 // update the number of nodes of the parent group
171 parent->num_nodes += child->num_nodes;
172 child->parent = parent;
173 // update anchor ref and pattern
174 if (child->anchor_ref != nullptr) {
175 ICHECK(parent->anchor_ref == nullptr);
176 parent->anchor_ref = child->anchor_ref;
177 parent->pattern = CombinePattern(child->pattern, parent->pattern);
178 }
179}
180
181void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
182 Group* target) {
183 if (src == sink) return;
184 if (visited_.count(src)) return;
185 visited_.insert(src);
186 Group* gnode = groups_[src->index];
187 ICHECK(gnode != nullptr);
188 // merge the current group to the parent if possible.
189 MergeFromTo(gnode, target);
190 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
191 CommitFuse_(link->value.node, sink, target);
192 }
193}
194
195void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
196 Group* target = groups_[sink->index];
197 visited_.clear();
198 ICHECK(src != sink);
199 CommitFuse_(src, sink, target);
200}
201
202size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src,
203 IndexedForwardGraph::Node* sink) {
204 if (src == sink || visited_.count(src)) return 0;
205 visited_.insert(src);
206 Group* gnode = groups_[src->index];
207 ICHECK(gnode != nullptr);
208 auto sum = gnode->num_nodes;
209 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
210 sum += CountNodesUptoSink_(link->value.node, sink);
211 }
212 return sum;
213}
214
215size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
216 IndexedForwardGraph::Node* dom_parent) {
217 Group* target = groups_[dom_parent->index];
218 visited_.clear();
219 ICHECK(child != dom_parent);
220 return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
221}
222
223void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
224 groups_.resize(graph.post_dfs_order.size());
225 for (size_t nid = 0; nid < groups_.size(); ++nid) {
226 const auto* graph_node = graph.post_dfs_order[nid];
227 auto* group_node = arena_->make<Group>();
228 group_node->pattern = graph_node->pattern;
229 group_node->root_ref = graph_node->ref;
230 // set anchor ref if necessary.
231 if (group_node->pattern == relay::kOutEWiseFusable) {
232 group_node->anchor_ref = graph_node->ref;
233 }
234 groups_[nid] = group_node;
235 }
236}
237
// Run one phase of the fusion algorithm over all nodes in post-DFS order.
//
// For every node, the candidate fusion target is the graph node of its parent
// in the post-dominator tree. What is attempted depends on `phase` and on the
// pattern stored in the node's own group entry (not its root group):
//   - kOutEWiseFusable nodes are only handled in phase 0;
//   - nodes with pattern <= kBroadcast are handled in phases 0 and 1;
//   - kInjective and kTuple nodes are only handled in phase 1;
//   - phase 2 only fuses nodes with pattern <= kInjective into a dominator
//     whose own group is a tuple that has been merged into a group whose root
//     pattern is <= kInjective.
// A fusion is committed only when CheckPath accepts all nodes between the
// current node and the dominator.
//
// \param graph The indexed forward graph being partitioned.
// \param post_dom_tree Post-dominator tree; nodes[i] matches post_dfs_order[i].
// \param phase The phase number; Partition calls this with 0, 1 and 2.
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph,    //
                               const DominatorTree& post_dom_tree,  //
                               int phase) {
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    // the group of current node has been specified already.
    auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    ICHECK(group_node != nullptr);
    // no actions for opaque nodes
    if (group_node->pattern == kOpaque) continue;
    // no actions needed if the current node has no dominator
    if (dom_node->parent == nullptr) continue;
    // Nodes flagged extern_ref get a null parent in DominatorTree::GetNode, so
    // reaching this point with such a node indicates an inconsistency.
    ICHECK(!graph_node->extern_ref);
    // Index of the dominator's graph node, used to look up its group.
    size_t dom_parent_gindex = dom_node->parent->gnode->index;

    // refuse the fusion if too many ops are going to be fused together
    if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
      continue;

    if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If dom node group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      // The dominator's own group is a tuple, but its root group is not.
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
      // Phase 2 never falls through to the phase 0/1 logic below.
      continue;
    }

    // Skip if current node is already fused to the parent.
    // NOTE(review): the null check here is not repeated before the dereference
    // of groups_[dom_parent_gindex] a few lines below — confirm that entry can
    // never be null.
    if (groups_[dom_parent_gindex] != nullptr &&
        group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
      continue;
    }
    // Do not fuse into tuple for now
    if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
    // Try to fuse current node to its post-dominator.
    if (group_node->pattern == kOutEWiseFusable) {
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        ICHECK(dom_node->parent->gnode != nullptr);
        // The fuse can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to a OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            // The sink may additionally be a reduction or an
            // out-elementwise-fusable group.
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
          }
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
      // defer injective fusion to second phase.
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all path are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    } else {
      // do nothing.
      // The only pattern expected to reach this branch is kCommReduce.
      ICHECK(group_node->pattern == kCommReduce);
    }
  }
}
332
333} // namespace relay
334} // namespace tvm
335