/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/collage/sub_graph.cc
 * \brief Represents a sub-graph of an overall Relay expression.
 */

#include "./sub_graph.h"

#include <tvm/relay/transform.h>

#include "../../support/scalars.h"
#include "../transforms/pass_utils.h"
#include "./utils.h"

namespace tvm {
namespace relay {
namespace collage {

namespace {

class Extractor;

/*!
 * \brief Helper class for rewriting expressions to replace a sub-graph according to the
 * given extractor.
 */
class Rewriter : public ExprMutator {
 public:
  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}

  Expr VisitExpr(const Expr& expr) final;

 private:
  /*! \brief Already prepared extractor which will guide the rewrite. */
  const Extractor* extractor_;
};

/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
class Extractor : public ExprMutator {
 public:
  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
            FunctionAttrsMap opt_attrs)
      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
  }

  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }

  /*!
   * \brief Collect the parameters and output expressions for the function representing
   * the sub-graph.
   */
  void Extract() {
    ICHECK(!sub_graph_->IsEmpty());
    VLOG(2) << "Extracting " << sub_graph_->ToString();
    const bool for_function = opt_attrs_.defined();

    // In reverse dataflow order...
    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
      PostDfsIndex index = i - 1;
      if (!sub_graph_->inside_[index]) {
        // Node is outside sub-graph.
        continue;
      }
      VLOG(2) << "index " << index;
      auto node = dataflow_graph_->index_to_node(index);
      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
        // This sub-expression is either:
        //  - inside the sub-graph and needed outside the sub-graph, so it must contribute to an
        //    output (even if we've already visited it while constructing an output from a
        //    downstream sub-expression); or
        //  - not yet visited, in which case it must still be considered an 'output' so it will
        //    be evaluated for any possible side effects.
        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
        VLOG(2) << "index " << index << " added as output:\n"
                << PrettyPrint(output) << "\nat " << outputs_.size();
        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
        outputs_.emplace_back(std::move(output));
        output_types_.emplace_back(node->node_ref_->checked_type());
      }
    }
    ICHECK(!outputs_.empty());

    // Reverse the outputs so as to preserve the original evaluation order.
    std::reverse(outputs_.begin(), outputs_.end());
    std::reverse(output_types_.begin(), output_types_.end());
    for (auto& kv : expr_to_output_index_) {
      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
    }

    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
    // outputs we'll place them in a tuple.
    Type body_type;
    Expr body;
    if (outputs_.size() > 1) {
      body_type = TupleType(output_types_);
      body = Tuple(outputs_);
      body->checked_type_ = body_type;
    } else {
      body_type = output_types_.front();
      body = outputs_.front();
    }

    // Re-express all the nested sub-graphs in terms of the body.
    DataflowGraph body_dataflow_graph(body);
    std::vector<NestedSubGraph> nested_sub_graphs;
    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
    for (const auto& nested_sub_graph : sub_graph_->nested_sub_graphs_) {
      nested_sub_graphs.emplace_back(nested_sub_graph.Subst(body_dataflow_graph, subst));
    }

    // Sweep backwards through the body, rewriting to account for each nested sub-graph.
    body = NestedSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(nested_sub_graphs));

    if (for_function) {
      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
      Array<Type> arg_types;
      arg_types.reserve(params_.size());
      for (const auto& param : params_) {
        arg_types.push_back(param->checked_type());
      }
      extracted_ = Function(std::move(params_), std::move(body), body_type,
                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
      extracted_->checked_type_ =
          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
      body = Call(extracted_, std::move(args_));
      body->checked_type_ = body_type;
    } else {
      // Don't do anything with the inputs.
      extracted_ = body;
    }

    // Setup the output substitution.
    for (const auto& kv : expr_to_output_index_) {
      Expr expr;
      if (outputs_.size() == 1) {
        expr = body;
      } else if (for_function) {
        expr = TupleGetItem(body, kv.second);
        expr->checked_type_ = output_types_[kv.second];
      } else {
        const auto* tuple_node = body.as<TupleNode>();
        ICHECK(tuple_node);
        expr = tuple_node->fields[kv.second];
      }
      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
              << kv.second << " (of " << outputs_.size() << " outputs)";
      output_substitution_.emplace(kv.first, std::move(expr));
    }
  }

  ////// Following members are valid only after Extract() has returned.

  /*!
   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
   * defined then it will be a function.
   */
  Expr extracted() const { return extracted_; }

  /*!
   * \brief Returns the substitution to apply to all expression nodes in the overall expression
   * so as to replace references to outputs of the sub-graph with their rewritten form.
   */
  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
    return output_substitution_;
  }

 private:
  /*!
   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
   * valid after \p Extract has made its backwards dataflow sweep.
   */
  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
    VLOG(2) << "building extractor substitution";
    IndexSubst subst;
    for (PostDfsIndex index : sub_graph_->inside_) {
      auto orig_node = dataflow_graph_->index_to_node(index);
      ICHECK_EQ(orig_node->index_, index);
      auto itr = memo_.find(orig_node->ref());
      ICHECK(itr != memo_.end());
      auto new_node = new_dataflow_graph.item_to_node(itr->second);
      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
      subst.emplace(orig_node->index_, new_node->index_);
    }
    return subst;
  }

  /*! \brief Returns true if \p expr is inside the sub-graph. */
  bool inside(const Expr& expr) {
    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
  }

  /*!
   * \brief Returns the variable uniquely representing \p expr, which should be
   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
   *
   * It is valid for:
   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
   */
  Var VarFor(const Expr& expr) {
    ICHECK(!inside(expr));
    ICHECK(opt_attrs_.defined());
    auto itr = expr_to_param_.find(expr.get());
    if (itr != expr_to_param_.end()) {
      return itr->second;
    }
    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
    fresh_var->checked_type_ = expr->checked_type();
    params_.push_back(fresh_var);
    args_.push_back(expr);
    expr_to_param_.emplace(expr.get(), fresh_var);
    return fresh_var;
  }
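
  // For example (illustrative only): if an expression %x outside the sub-graph feeds two calls
  // inside it, both uses are rewritten to the same fresh parameter %FunctionVar_0, and %x is
  // recorded exactly once in args_ as the corresponding call argument.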

  /*!
   * \brief If \p expr is inside the sub-graph then return its rewritten form.
   * If \p expr is outside the sub-graph then it must correspond to an input node.
   *  - If opt_attrs_ is defined return the variable to represent it.
   *  - Otherwise just return the expression directly.
   *
   * Should be called only on inputs to nodes which are inside the sub-graph.
   */
  Expr VisitExpr(const Expr& expr) final {
    if (inside(expr)) {
      return ExprMutator::VisitExpr(expr);
    } else if (CanInline(expr)) {
      // Implicitly include inlinable input sub-expressions.
      return expr;
    } else if (opt_attrs_.defined()) {
      // Map to a function parameter.
      return VarFor(expr);
    } else {
      // Stop rewriting.
      return expr;
    }
  }

  Expr VisitExpr_(const FunctionNode* function_node) override {
    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
      return GetRef<Function>(function_node);
    }
    return ExprMutator::VisitExpr_(function_node);
  }

  //// Context fields, passed in constructor.

  /*! \brief The dataflow graph corresponding to the overall expression. */
  const DataflowGraph* dataflow_graph_;
  /*! \brief The sub-graph of the above we are extracting. */
  const SubGraphNode* sub_graph_;
  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
  FunctionAttrsMap opt_attrs_;

  //// Result fields, available after Extract() called.

  /*!
   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
   */
  Expr extracted_;
  /*!
   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
   * one exit node then each entry will be a tuple projection.
   */
  std::unordered_map<const ExprNode*, Expr> output_substitution_;

  //// Accumulator fields, built as we visit expressions.

  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
  Array<Var> params_;
  /*!
   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
   */
  Array<Expr> args_;
  /*!
   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
   * in params_ which now represent them.
   */
  std::unordered_map<const ExprNode*, Var> expr_to_param_;
  /*!
   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
   * It is possible to have multiple outputs. It is possible for one output to also contribute to
   * other outputs (ie the output is a 'tap').
   */
  std::vector<Expr> outputs_;
  /*! \brief Types of original expressions corresponding to outputs_. */
  std::vector<Type> output_types_;
  /*!
   * \brief Map from existing exit expression nodes to the index in outputs_ which should
   * represent them in the rewritten overall expression.
   */
  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
};

Expr Rewriter::VisitExpr(const Expr& expr) {
  auto itr = extractor_->output_substitution().find(expr.get());
  if (itr == extractor_->output_substitution().end()) {
    return ExprMutator::VisitExpr(expr);
  } else {
    return itr->second;
  }
}
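
// A minimal usage sketch of the two helpers above, mirroring NestedSubGraphNode::Extract and
// NestedSubGraphNode::Rewrite further below ('sub_graph', 'attrs' and 'expr' are placeholders):
//
//   Extractor extractor(&dataflow_graph, sub_graph.get(), attrs);
//   extractor.Extract();
//   Function fn = Downcast<Function>(extractor.extracted());  // extracted partition as a function
//   Rewriter rewriter(&extractor);
//   Expr rewritten = rewriter.VisitExpr(expr);  // overall expression with outputs substituted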

}  // namespace

std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
  class Visitor : public ExprFunctor<std::pair<OpPatternKind, std::string>(const Expr&)> {
   private:
    std::pair<OpPatternKind, std::string> VisitExpr_(const CallNode* call_node) final {
      if (const auto* op_node = call_node->op.as<OpNode>()) {
        auto op = GetRef<Op>(op_node);
        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
        if (fpattern.count(op) == 0) {
          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
          return {kOpaque, op->name};
        } else if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
          return {kOpaque, op->name};
        } else {
          OpPatternKind kind = static_cast<OpPatternKind>(fpattern[op]);
          VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
          return {kind, op->name};
        }
      } else if (const auto* function_node = call_node->op.as<FunctionNode>()) {
        Optional<Integer> opt_i =
            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
        if (opt_i.defined()) {
          OpPatternKind kind = static_cast<OpPatternKind>(opt_i.value()->value);
          VLOG(1) << "TOpPattern for function is " << KindToString(kind);
          return {kind, "call_prim"};
        } else {
          VLOG(1) << "calling function without TOpPattern, considering opaque";
          return {kOpaque, "call_fun"};
        }
      } else {
        VLOG(1) << "unsupported call, considering opaque";
        return {kOpaque, "call_any"};
      }
    }

    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstantNode* constant_node) final {
      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
      if (support::IsSimpleScalar(constant_node)) {
        return {kElemWise, "scalar"};
      } else {
        return {kElemWise, "const"};
      }
    }

    std::pair<OpPatternKind, std::string> VisitExpr_(const TupleNode* tuple_node) final {
      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
      ICHECK(tuple_type_node != nullptr);
      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
        VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
        return {kInjective, "tuple"};
      } else {
        VLOG(1) << "tuple contains non-tensors, considering opaque";
        return {kOpaque, "tuple"};
      }
    }

    std::pair<OpPatternKind, std::string> VisitExpr_(
        const TupleGetItemNode* tuple_get_item_node) final {
      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
      ICHECK(tuple_type_node != nullptr);
      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
        VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
        return {kInjective, "proj"};
      } else {
        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
        return {kOpaque, "proj"};
      }
    }

    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
    // sub-language we should revise the returned operator kinds to match.

    std::pair<OpPatternKind, std::string> VisitExpr_(const VarNode* var_node) final {
      return {kOpaque, "%" + var_node->name_hint()};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const GlobalVarNode* global_var_node) final {
      return {kOpaque, "@" + global_var_node->name_hint};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const OpNode* op_node) final {
      return {kOpaque, "`" + op_node->name};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const FunctionNode* function_node) final {
      return {kOpaque, "fn"};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const LetNode* let_node) final {
      return {kOpaque, "let"};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const IfNode* if_node) final {
      return {kOpaque, "if"};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const RefCreateNode* ref_create_node) final {
      return {kOpaque, "ref"};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const RefReadNode* op) final {
      return {kOpaque, "ref_read"};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const RefWriteNode* op) final {
      return {kOpaque, "ref_write"};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstructorNode* op) final {
      return {kOpaque, "`" + op->name_hint};
    }
    std::pair<OpPatternKind, std::string> VisitExpr_(const MatchNode* op) final {
      return {kOpaque, "match"};
    }
  };
  return Visitor().VisitExpr(sub_expr);
}
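
// For example, SubExprKindAndLabel maps a constant to {kElemWise, "scalar"} or {kElemWise, "const"},
// a call to an operator with a registered "TOpPattern" attribute to that operator's kind and name,
// and anything it cannot otherwise classify (lets, ifs, refs, matches, ...) to kOpaque.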

std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
                                                           const IndexSet& inside) {
  std::ostringstream os;
  bool first = true;
  OpPatternKind max_kind = kElemWise;
  for (PostDfsIndex index : inside) {
    auto [sub_kind, sub_label] = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
    if (!sub_label.empty()) {
      if (first) {
        first = false;
      } else {
        os << "+";
      }
      os << sub_label;
    }
    max_kind = CombineKinds(max_kind, sub_kind);
  }
  return {max_kind, os.str()};
}
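
// For example, a candidate whose inside nodes are a conv2d call followed by a relu call would get
// a label along the lines of "nn.conv2d+nn.relu" (sub-labels joined with '+'), with the overall
// kind being the per-node kinds folded together by CombineKinds.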

IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
  IndexSet result(matcher.size());
  for (const auto& kv : matcher.memo()) {
    for (const auto& matched_sub_expr : kv.second) {
      if (CanInline(matched_sub_expr)) {
        // Trivial sub-expressions can just be included in the extracted function body
        // when we construct it and don't need to be considered part of the sub-graph.
        continue;
      }
      if (kv.first.as<WildcardPatternNode>()) {
        // Don't consider the expressions matched by a wildcard to be part of the sub-graph.
        continue;
      }
      result.Add(matcher.expr_to_node(matched_sub_expr)->index_);
    }
  }
  return result;
}

std::string SubGraphConfig::ToString() const {
  std::ostringstream os;
  os << "{max_exits=" << max_exits;
  os << ", allow_taps=" << allow_taps;
  os << ", max_depth=" << max_depth;
  os << "}";
  return os.str();
}

TVM_REGISTER_NODE_TYPE(NestedSubGraphNode);

void NestedSubGraphNode::VisitAttrs(AttrVisitor* v) {
  // TODO(mbs)
}

SubGraph NestedSubGraphNode::sub_graph() const { return Downcast<SubGraph>(sub_graph_obj_); }

bool NestedSubGraphNode::operator==(const NestedSubGraphNode& that) const {
  return *sub_graph().get() == *that.sub_graph().get();
}

bool NestedSubGraphNode::operator<(const NestedSubGraphNode& that) const {
  return *sub_graph().get() < *that.sub_graph().get();
}

size_t NestedSubGraphNode::hash() const {
  size_t h = StructuralHash()(attrs_);
  h ^= sub_graph()->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
  return h;
}
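
// (The 0x9e3779b9 expression here and in SubGraphNode::hash below is the usual boost-style
// hash_combine mixing step.)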

std::string NestedSubGraphNode::ToString() const {
  std::ostringstream os;
  os << "{sub_graph=" << sub_graph()->ToString();
  os << ", attrs=" << PrettyPrint(attrs_);
  os << "}";
  return os.str();
}

Function NestedSubGraphNode::Extract(const DataflowGraph& dataflow_graph) const {
  Extractor extractor(&dataflow_graph, sub_graph().get(), attrs_);
  extractor.Extract();
  return Downcast<Function>(extractor.extracted());
}

Expr NestedSubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
  Extractor extractor(&dataflow_graph, sub_graph().get(), attrs_);
  extractor.Extract();
  Rewriter rewriter(&extractor);
  return rewriter.VisitExpr(expr);
}

NestedSubGraph::NestedSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs) {
  auto data = runtime::make_object<NestedSubGraphNode>();
  data->sub_graph_obj_ = std::move(sub_graph);
  data->attrs_ = std::move(attrs);
  data_ = std::move(data);
}

NestedSubGraph NestedSubGraph::Subst(
    const DataflowGraph& new_dataflow_graph,
    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const {
  return NestedSubGraph(get()->sub_graph().Subst(new_dataflow_graph, subst), get()->attrs_);
}

bool NestedSubGraph::TriviallyUnionable(const NestedSubGraph& that) const {
  if (get()->attrs_.size() != that->attrs_.size()) {
    return false;
  }
  for (const auto& kv : get()->attrs_) {
    if (kv.first == "Composite") {
      // Even if all the attributes agree we don't consider "Composite" functions to
      // ever be unionable.
      // TODO(mbs): Find a cleaner way to do this.
      return false;
    }
    auto itr = that->attrs_.find(kv.first);
    if (itr == that->attrs_.end()) {
      return false;
    }
    if (!StructuralEqual()(kv.second, (*itr).second)) {
      return false;
    }
  }
  return true;
}

NestedSubGraph NestedSubGraph::DisjointUnion(const DataflowGraph& dataflow_graph,
                                             const NestedSubGraph& that) const {
  ICHECK(TriviallyUnionable(that));
  return NestedSubGraph(get()->sub_graph().DisjointUnion(dataflow_graph, that->sub_graph()),
                        get()->attrs_);
}

/*static*/
Expr NestedSubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
                                     std::vector<NestedSubGraph> nested_sub_graphs) {
  // IMPORTANT: See the corresponding comment in SubGraph::ParallelRewrite.
  std::sort(nested_sub_graphs.begin(), nested_sub_graphs.end(),
            [](const NestedSubGraph& left, const NestedSubGraph& right) {
              return left->sub_graph()->last_inside_index_ > right->sub_graph()->last_inside_index_;
            });

  Expr result = expr;
  for (const auto& nested_sub_graph : nested_sub_graphs) {
    result = nested_sub_graph->Rewrite(dataflow_graph, result);
  }
  return result;
}

TVM_REGISTER_NODE_TYPE(SubGraphNode);

void SubGraphNode::VisitAttrs(AttrVisitor* v) {
  // TODO(mbs)
}

IndexSet SubGraphNode::Downstream(const DataflowGraph& dataflow_graph) const {
  IndexSet downstream(dataflow_graph.size());
  for (PostDfsIndex exit_index : exit_) {
    downstream = downstream | dataflow_graph.downstream_of(exit_index);
  }
  return downstream;
}

bool SubGraphNode::IsValid(const DataflowGraph& dataflow_graph,
                           const SubGraphConfig& config) const {
  // Check we don't have too many exit nodes.
  if (config.max_exits > 0 && exit_.PopCount() > config.max_exits) {
    VLOG(1) << "Subgraph " << ToString() << " is invalid: " << exit_.PopCount()
            << " exits exceeds maximum " << config.max_exits;
    return false;
  }

  // Check the maximum path depth is within the limit.
  if (config.max_depth > 0 && depth_ > config.max_depth) {
    VLOG(1) << "Subgraph " << ToString() << " is invalid: maximum depth " << depth_
            << " exceeds limit " << config.max_depth;
    return false;
  }

  // All inside nodes must be in the same basic block.
  const DataflowGraph::Node* basic_block = nullptr;
  for (PostDfsIndex index : inside_) {
    auto node = dataflow_graph.index_to_node(index);
    if (basic_block == nullptr) {
      basic_block = node->basic_block_;
    }
    if (node->basic_block_ != basic_block) {
      VLOG(1) << "Subgraph " << ToString() << " is invalid: nodes are from different basic blocks";
      return false;
    }
  }

  // The nested sub-graphs must be subsets and non-overlapping.
  IndexSet union_inside(dataflow_graph.size());
  for (const auto& nested_sub_graph : nested_sub_graphs_) {
    if (!nested_sub_graph->sub_graph()->inside_.AreDisjoint(union_inside)) {
      VLOG(1) << "Subgraph " << ToString() << " is invalid: nested sub-graphs overlap";
      return false;
    }
    if (!nested_sub_graph->sub_graph()->inside_.IsSubset(inside_)) {
      VLOG(1) << "Subgraph " << ToString()
              << " is invalid: nested sub-graph is not subset of overall sub-graph";
      return false;
    }
    // Accumulate the nested 'inside' sets seen so far so the overlap check above is effective.
    union_inside = union_inside | nested_sub_graph->sub_graph()->inside_;
  }

  if (!config.allow_taps) {
    // Exit nodes cannot also contribute to inside nodes.
    for (PostDfsIndex index : exit_) {
      auto node = dataflow_graph.index_to_node(index);
      if (AnyOutputInside(node)) {
        VLOG(1) << "Subgraph " << ToString()
                << " is invalid: inner node is 'tapped' and also contributes to output, but taps "
                   "are disabled";
        return false;
      }
    }
  }

  // Check no output would end up feeding into any entry node.
  for (PostDfsIndex output_index : output_) {
    if (dataflow_graph.downstream_of(output_index).Intersects(entry_)) {
      VLOG(1) << "Subgraph " << ToString() << " is invalid: output node " << output_index
              << " feeds back into this sub-graph";
      return false;
    }
  }

  // Looks legit!
  return true;
}
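
// A hedged example of driving the validity check above (the field values are illustrative only):
//
//   SubGraphConfig config;
//   config.max_exits = 1;       // reject sub-graphs with more than one exit node
//   config.allow_taps = false;  // reject exit nodes which also feed other inside nodes
//   config.max_depth = 8;       // reject dataflow paths of more than 8 nodes
//   bool ok = sub_graph->IsValid(dataflow_graph, config);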

Function SubGraphNode::ExtractAsFunction(const DataflowGraph& dataflow_graph) const {
  NestedSubGraph nested_sub_graph(GetRef<SubGraph>(this), FunctionAttrsMap());
  return nested_sub_graph->Extract(dataflow_graph);
}

Expr SubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
  if (nested_sub_graphs_.empty()) {
    // Nothing to rewrite.
    return expr;
  }
  Extractor extractor(&dataflow_graph, this, NullValue<FunctionAttrsMap>());
  extractor.Extract();
  Rewriter rewriter(&extractor);
  return rewriter.VisitExpr(expr);
}

std::string SubGraphNode::ToString() const {
  std::ostringstream os;
  os << "{inside=" << inside_.ToString();
  os << ", entry=" << entry_.ToString();
  os << ", exit=" << exit_.ToString();
  os << ", input=" << input_.ToString();
  os << ", output=" << output_.ToString();
  os << ", depth=" << depth_;
  os << ", kind=" << KindToString(kind_);
  if (!label_.empty()) {
    os << ", label=" << label_;
  }
  for (const auto& nested_sub_graph : nested_sub_graphs_) {
    os << ", nested_sub_graph=" << nested_sub_graph->ToString();
  }
  os << "}";
  return os.str();
}

bool SubGraphNode::operator==(const SubGraphNode& that) const {
  ICHECK_EQ(inside_.end_index(), that.inside_.end_index());
  if (inside_ != that.inside_) {
    return false;
  }
  if (nested_sub_graphs_.size() != that.nested_sub_graphs_.size()) {
    return false;
  }
  for (size_t i = 0; i < nested_sub_graphs_.size(); ++i) {
    if (*nested_sub_graphs_[i].get() != *that.nested_sub_graphs_[i].get()) {
      return false;
    }
  }
  return true;
}

bool SubGraphNode::operator<(const SubGraphNode& that) const {
  if (first_inside_index_ < that.first_inside_index_) {
    return true;
  }
  if (that.first_inside_index_ < first_inside_index_) {
    return false;
  }
  return inside_ < that.inside_;
}

size_t SubGraphNode::hash() const {
  size_t h = inside_.hash();
  for (const auto& nested_sub_graph : nested_sub_graphs_) {
    h ^= nested_sub_graph->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
  }
  return h;
}

void SubGraphNode::Init(const DataflowGraph& dataflow_graph) {
  for (PostDfsIndex index = 0; index < inside_.end_index(); ++index) {
    auto node = dataflow_graph.index_to_node(index);
    if (inside_[index]) {
      if (AnyInputOutside(node)) {
        entry_.Add(index);
      }
      if (AnyOutputOutside(node) || node->is_external_) {
        exit_.Add(index);
      }
    } else {
      if (AnyInputInside(node)) {
        output_.Add(index);
      }
      if (AnyOutputInside(node) && !CanInline(node->ref())) {
        input_.Add(index);
      }
    }
  }
  depth_ = Depth(dataflow_graph);
}

size_t SubGraphNode::Depth(const DataflowGraph& dataflow_graph) const {
  std::unordered_map<const DataflowGraph::Node*, size_t> max_depths;
  std::vector<const DataflowGraph::Node*> stack;
  size_t max_depth = 0;
  // All the entry nodes have max depth 0.
  for (PostDfsIndex index : entry_) {
    auto node = dataflow_graph.index_to_node(index);
    max_depths.emplace(node, 0);
    stack.push_back(node);
  }
  while (!stack.empty()) {
    const DataflowGraph::Node* node = stack.back();
    stack.pop_back();
    size_t next_depth = max_depths[node] + 1;
    if (exit_[node->index_]) {
      // If this node is external then it will have no outputs, but we still wish to consider
      // the path to the implied output as requiring one more step.
      // Otherwise we're accounting for reaching one of the external outputs below.
      max_depth = std::max(max_depth, next_depth);
    }
    for (const DataflowGraph::Node* output_node : node->outputs_) {
      if (!inside_[output_node->index_]) {
        continue;
      }
      if (max_depths.count(output_node) == 0) {
        max_depths.emplace(output_node, next_depth);
        stack.push_back(output_node);
      } else if (next_depth > max_depths[output_node]) {
        // We found a deeper path to an already expanded node. We'll expand again.
        max_depths[output_node] = next_depth;
        stack.push_back(output_node);
      }
    }
  }
  return max_depth;
}

/*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
bool SubGraphNode::AnyInputOutside(const DataflowGraph::Node* node) const {
  return std::any_of(node->inputs_.begin(), node->inputs_.end(),
                     [this](const DataflowGraph::Node* sub_node) {
                       return !inside_[sub_node->index_] && !CanInline(sub_node->ref());
                     });
}

bool SubGraphNode::AnyInputInside(const DataflowGraph::Node* node) const {
  return std::any_of(
      node->inputs_.begin(), node->inputs_.end(),
      [this](const DataflowGraph::Node* sub_node) { return inside_[sub_node->index_]; });
}

bool SubGraphNode::AnyOutputOutside(const DataflowGraph::Node* node) const {
  return std::any_of(
      node->outputs_.begin(), node->outputs_.end(),
      [this](const DataflowGraph::Node* sub_node) { return !inside_[sub_node->index_]; });
}

bool SubGraphNode::AnyOutputInside(const DataflowGraph::Node* node) const {
  return std::any_of(
      node->outputs_.begin(), node->outputs_.end(),
      [this](const DataflowGraph::Node* sub_node) { return inside_[sub_node->index_]; });
}

SubGraph::SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind,
                   String label, std::vector<NestedSubGraph> nested_sub_graphs) {
  std::sort(nested_sub_graphs.begin(), nested_sub_graphs.end(),
            [](const NestedSubGraph& left, const NestedSubGraph& right) {
              return *left.get() < *right.get();
            });
  auto node = runtime::make_object<SubGraphNode>();
  node->inside_ = std::move(inside);
  node->first_inside_index_ = node->inside_.FirstInsideIndex();
  node->last_inside_index_ = node->inside_.LastInsideIndex();
  node->entry_ = IndexSet(node->inside_.end_index());
  node->exit_ = IndexSet(node->inside_.end_index());
  node->input_ = IndexSet(node->inside_.end_index());
  node->output_ = IndexSet(node->inside_.end_index());
  node->kind_ = kind;
  node->label_ = std::move(label);
  node->nested_sub_graphs_ = nested_sub_graphs;
  node->Init(dataflow_graph);
  data_ = std::move(node);
}

SubGraph::SubGraph(const DataflowGraph& dataflow_graph)
    : SubGraph(dataflow_graph, IndexSet(dataflow_graph.size())) {}

bool SubGraph::AreDisjoint(const SubGraph& that) const {
  return get()->inside_.AreDisjoint(that->inside_);
}

namespace {
/*! \brief Returns true if an output of \p left not in \p right ultimately flows into \p right. */
bool FlowsInto(const DataflowGraph& dataflow_graph, const SubGraph& left, const SubGraph& right) {
  for (PostDfsIndex output_index : left->output_) {
    if (!right->inside_[output_index] &&
        dataflow_graph.downstream_of(output_index).Intersects(right->entry_)) {
      return true;
    }
  }
  return false;
}
}  // namespace

bool SubGraph::AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
  if (!get()->inside_.AreDisjoint(that->inside_)) {
    // Easy rejection.
    return false;
  }
  if (!get()->output_.Intersects(that->entry_)) {
    // Not touching.
    return false;
  }
  if (FlowsInto(dataflow_graph, *this, that) || FlowsInto(dataflow_graph, that, *this)) {
    // Unioning would create a cycle.
    return false;
  }
  return true;
}

bool SubGraph::AreSelfContained(const SubGraph& that) const {
  return get()->output_.IsSubset(that->entry_) && that->input_.IsSubset(get()->exit_);
}

SubGraph SubGraph::DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
  ICHECK(AreDisjoint(that));
  IndexSet inside = get()->inside_ | that->inside_;
  std::vector<NestedSubGraph> nested_sub_graphs;
  for (const auto& nested_sub_graph : get()->nested_sub_graphs_) {
    nested_sub_graphs.push_back(nested_sub_graph);
  }
  for (const auto& nested_sub_graph : that->nested_sub_graphs_) {
    auto existing_itr = std::find_if(nested_sub_graphs.begin(), nested_sub_graphs.end(),
                                     [&nested_sub_graph](const NestedSubGraph& existing) {
                                       return existing.TriviallyUnionable(nested_sub_graph);
                                     });
    if (existing_itr != nested_sub_graphs.end()) {
      *existing_itr = existing_itr->DisjointUnion(dataflow_graph, nested_sub_graph);
    } else {
      nested_sub_graphs.push_back(nested_sub_graph);
    }
  }
  return SubGraph(dataflow_graph, std::move(inside), CombineKinds(get()->kind_, that->kind_),
                  UnionLabels(get()->label_, that->label_), std::move(nested_sub_graphs));
}

SubGraph SubGraph::WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const {
  std::vector<NestedSubGraph> nested_sub_graphs;
  nested_sub_graphs.push_back(NestedSubGraph(*this, attrs));
  return SubGraph(dataflow_graph, get()->inside_, get()->kind_, get()->label_,
                  std::move(nested_sub_graphs));
}

SubGraph SubGraph::Subst(const DataflowGraph& new_dataflow_graph, const IndexSubst& subst) const {
  IndexSet new_inside = get()->inside_.Subst(new_dataflow_graph.size(), subst);
  std::vector<NestedSubGraph> new_nested_sub_graphs;
  for (const auto& nested_sub_graph : get()->nested_sub_graphs_) {
    new_nested_sub_graphs.push_back(nested_sub_graph.Subst(new_dataflow_graph, subst));
  }
  return SubGraph(new_dataflow_graph, std::move(new_inside), get()->kind_, get()->label_,
                  std::move(new_nested_sub_graphs));
}

/*static*/
Expr SubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph,
                               std::vector<SubGraph> sub_graphs) {
  // IMPORTANT:
  //  - All the sub-graphs will be w.r.t. the dataflow graph for the original expression.
  //    Each time we call Rewrite on one of those graphs the result expression will be rewritten
  //    from the final output back to the inputs. The inputs will then be shared with the original
  //    expression. Thus it is safe to iteratively rewrite all the sub-graphs without redoing the
  //    dataflow_graph and substituting indexes provided we work in reverse dataflow order.
  //  - We rely on the dataflow_graph expression reference holding the original expression alive
  //    so that the dataflow_graph will never contain dangling pointers (even though as per above
  //    we'll never dereference them).
  std::sort(sub_graphs.begin(), sub_graphs.end(), [](const SubGraph& left, const SubGraph& right) {
    return left->last_inside_index_ > right->last_inside_index_;
  });
  Expr result = dataflow_graph.expr();
  for (const auto& sub_graph : sub_graphs) {
    result = sub_graph->Rewrite(dataflow_graph, result);
  }
  return result;
}

/*!
 * \brief A pass which partitions the (unique) global function in the module according to the
 * post-dfs indexes in \p indexes. The partitioning must respect the configuration given by
 * \p max_exits and \p allow_taps.
 *
 * Each index is also paired with a label. A non-empty label denotes the index should also be
 * included in a nested sub-graph which will be extracted as a function with the label as its
 * "Composite" attribute. An empty label denotes the index should go into the overall partitioned
 * "Compiler" function. In this way we can simulate the usual partitioning needed by external
 * codegen integrations.
 *
 * This function is intended to support \p SubGraph unit tests and is not used by the regular
 * compilation flow.
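 *
 * A minimal usage sketch, where mod is some already constructed IRModule and the compiler name
 * and index/label values are illustrative only:
 *
 * \code
 *   transform::Pass pass = PartitionForTesting(
 *       Integer(1), Bool(false), "my_compiler",
 *       {Integer(3), Integer(4), Integer(5)}, {"", "", ""});
 *   IRModule partitioned = pass(mod);
 * \endcode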
 */
transform::Pass PartitionForTesting(Integer max_exits, Bool allow_taps, String compiler,
                                    Array<Integer> indexes, Array<String> labels) {
  auto pass_func = [=](Function function, IRModule mod, transform::PassContext ctxt) {
    ICHECK(max_exits.defined() && max_exits->value >= 0);
    ICHECK(allow_taps.defined());
    ICHECK(indexes.size() == labels.size());
    VLOG(1) << "Partitioning:" << std::endl << PrettyPrint(function);
    DataflowGraph dataflow_graph(function);
    VLOG(1) << "Dataflow graph is:" << std::endl << dataflow_graph.indexed_graph().ToString();

    // Collect the 'inside' indexes and any nested sub-graph indexes and labels.
    std::vector<PostDfsIndex> node_indexes;
    std::unordered_map<String, std::vector<PostDfsIndex>> nested_sub_graph_indexes;
    node_indexes.reserve(indexes.size());
    for (size_t i = 0; i < indexes.size(); ++i) {
      const Integer& index = indexes[i];
      ICHECK_GE(index->value, 0);
      ICHECK_LT(index->value, dataflow_graph.size());
      auto index_int = static_cast<PostDfsIndex>(index->value);
      node_indexes.push_back(index_int);
      const String& label = labels[i];
      if (!label.empty()) {
        nested_sub_graph_indexes[label].push_back(index_int);
      }
    }

    // Build the nested sub-graphs representing the "Composite" functions (if any).
    std::vector<NestedSubGraph> nested_sub_graphs;
    for (const auto& kv : nested_sub_graph_indexes) {
      FunctionAttrsMap composite_attrs;
      composite_attrs.Set("Composite", kv.first);
      nested_sub_graphs.emplace_back(
          SubGraph(dataflow_graph, IndexSet(dataflow_graph.size(), kv.second)), composite_attrs);
    }

    // Build the overall sub-graph, which will include any "Composite" functions as
    // well as any nodes without a label.
    IndexSet inside(dataflow_graph.size(), node_indexes);
    auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
    SubGraph sub_graph(dataflow_graph, inside, kind, label, std::move(nested_sub_graphs));

    // Push the overall sub-graph into the final "Compiler" function.
    FunctionAttrsMap compiler_attrs;
    compiler_attrs.Set("Compiler", compiler);
    NestedSubGraph overall_nested_sub_graph(sub_graph, compiler_attrs);
    SubGraph overall_sub_graph(dataflow_graph, inside, kind, label, {overall_nested_sub_graph});

    // Check the sub-graph is valid.
    SubGraphConfig config;
    config.max_exits = static_cast<size_t>(max_exits->value);
    config.allow_taps = allow_taps;
    if (overall_sub_graph->IsValid(dataflow_graph, config)) {
      VLOG(1) << "Sub-graph " << overall_sub_graph->ToString() << " is considered valid";
    } else {
      VLOG(1) << "Sub-graph " << overall_sub_graph->ToString()
              << " is NOT considered valid, not partitioning";
      return function;
    }

    // Do the partitioning.
    Function result = Downcast<Function>(overall_sub_graph->Rewrite(dataflow_graph, function));
    VLOG(1) << "Extracted as:" << std::endl << PrettyPrint(result);

    return result;
  };
  return transform::CreateFunctionPass(pass_func, /*opt_level=*/0, "PartitionForTesting", {});
}

TVM_REGISTER_GLOBAL("relay.collage.PartitionForTesting").set_body_typed(PartitionForTesting);

}  // namespace collage
}  // namespace relay
}  // namespace tvm