/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/ir/dataflow_matcher.cc
 * \brief The dataflow pattern matcher for Relay.
 */

#include <tvm/ir/global_var_supply.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <stack>

#include "dataflow_matcher_impl.h"

namespace tvm {
namespace relay {

// Pattern Matcher
bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
  VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr);
  memo_.clear();
  matched_nodes_.clear();
  return VisitDFPattern(pattern, expr);
}

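// Forgets any pattern -> expression bindings recorded after the given watermark, so that a failed
// sub-match does not leave stale entries in the memoization tables.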
void DFPatternMatcher::ClearMap(size_t watermark) {
  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
    memo_.erase(matched_nodes_[i]);
  }
  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
}

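// Dispatches to the per-pattern-node visitors. Successful matches are memoized; on failure the
// bindings recorded while exploring this pattern are rolled back via ClearMap.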
bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
  if (memoize_ && memo_.count(pattern)) {
    ICHECK_EQ(memo_[pattern].size(), 1);
    return expr.same_as(memo_[pattern][0]);
  } else {
    auto watermark = matched_nodes_.size();
    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
    if (out) {
      memo_[pattern].push_back(expr);
      matched_nodes_.push_back(pattern);
      VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr);
    } else {
      ClearMap(watermark);
    }
    return out;
  }
}

bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
}

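// Compares a pattern-provided attribute value (lhs) against an attribute value retrieved from an
// Op attr map or attrs reflection (rhs), dispatching on rhs's runtime type code.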
bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
  switch (rhs.type_code()) {
    case kDLInt:
      if (auto* val = lhs.as<IntImmNode>()) {
        return val->value == rhs.operator int64_t();
      }
      break;
    case kDLFloat:
      if (auto* val = lhs.as<FloatImmNode>()) {
        return val->value == rhs.operator double();
      }
      break;
    case kTVMStr:
      if (auto* val = lhs.as<tir::StringImmNode>()) {
        return val->value == rhs.operator std::string();
      } else if (auto* val = lhs.as<StringObj>()) {
        return val->data == rhs.operator std::string();
      }
      break;
    case kTVMDataType:
      if (auto* val = lhs.as<tir::StringImmNode>()) {
        return rhs.operator std::string() == val->value;
      } else if (auto* val = lhs.as<StringObj>()) {
        return rhs.operator std::string() == val->data;
      } else {
        ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs;
      }
      break;
    case kTVMObjectHandle:
      if (rhs.IsObjectRef<String>()) {
        if (auto* val = lhs.as<tir::StringImmNode>()) {
          return rhs.operator String() == val->value;
        } else if (auto* val = lhs.as<StringObj>()) {
          return rhs.operator String() == val->data;
        }
      } else {
        // Compare the objects for structural equality
        static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
        ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
        if ((*structural_equal)(lhs, GetRef<ObjectRef>(rhs.ptr<Object>()), false, true)) {
          return true;
        }
      }
      break;
    default:
      ICHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
  }
  return false;
}

bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
  bool matches = VisitDFPattern(attr_pattern->pattern, expr);
  if (!matches) {
    return matches;
  }
  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
  if (const auto* op_node = expr.as<OpNode>()) {
    Op op = GetRef<Op>(op_node);
    for (auto kv : attributes) {
      auto attr_name = kv.first;
      auto attr_value = kv.second;
      if (Op::HasAttrMap(attr_name)) {
        auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
        if (op_map.count(op)) {
          matches &= MatchRetValue(attr_value, op_map[op]);
        } else {
          matches = false;
        }
      } else {
        matches = false;
      }
    }
  } else if (auto* op = expr.as<CallNode>()) {
    matches = true;
    // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this
    // and replace the whole thing with a Visitor-based approach
    ReflectionVTable* reflection = ReflectionVTable::Global();
    auto attrs_node = const_cast<BaseAttrsNode*>(op->attrs.get());
    // attrs may be undefined on non-op calls so we check first
    std::vector<std::string> attr_names;
    if (attrs_node) {
      attr_names = reflection->ListAttrNames(attrs_node);
    }
    for (auto kv : attributes) {
      std::string attr = kv.first;
      if (matches && std::find(attr_names.begin(), attr_names.end(), attr) != attr_names.end()) {
        matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, attr));
      } else {
        matches = false;
        break;
      }
    }
  } else if (auto* op = expr.as<FunctionNode>()) {
    matches = true;
    for (auto kv : attributes) {
      if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) {
        matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
      } else {
        matches = false;
        break;
      }
    }
  } else {
    matches = false;
  }
  return matches;
}

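// Returns the pattern arguments in reverse order; used below to try the flipped argument order
// when matching commutative operators.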
Array<DFPattern> reverse(const Array<DFPattern>& args) {
  Array<DFPattern> new_args;
  for (auto it = args.rbegin(); it != args.rend(); ++it) {
    new_args.push_back(*it);
  }
  return new_args;
}

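// Matches a CallPattern against a call expression. Besides the direct match, this also tries the
// reversed argument order for the commutative "add" and "multiply" ops, and reassociates
// divide/multiply patterns so that, e.g., a pattern written as (a * b) / c can match an
// expression of the form a * (b / c).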
bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
  // utilities
  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
    if (op) {
      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
        return expr_pattern->expr.as<OpNode>();
      }
    }
    return nullptr;
  };
  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
    if (const auto* op_node = get_op_node(op)) {
      if (op_node->name == op_type) {
        return true;
      }
    }
    return false;
  };
  auto is_expr_op = [](const Expr& expr, std::string op_type) {
    if (const auto* call_node = expr.as<CallNode>()) {
      if (const auto* op_node = call_node->op.as<OpNode>()) {
        if (op_node->name == op_type) {
          return true;
        }
      }
    }
    return false;
  };

  // logic
  auto watermark = matched_nodes_.size();
  if (const auto* call_node = expr.as<CallNode>()) {
    auto matches_op = VisitDFPattern(op->op, call_node->op);
    if (matches_op) {
      auto watermark2 = matched_nodes_.size();

      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
                                            const Array<Expr> expr_args) {
        bool matches = true;
        size_t i = 0;
        if (pattern_args.defined()) {
          if (pattern_args.size() == expr_args.size()) {
            while (matches && i < pattern_args.size()) {
              matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
              ++i;
            }
          } else {
            matches = false;
          }
        }
        if (!matches) {
          ClearMap(watermark2);
        }
        return matches;
      };

      // Standard case
      if (match_args(op->args, call_node->args)) {
        return true;
      }
      // Commutative Matching
      if (const OpNode* op_node = get_op_node(op)) {
        if ((op_node->name == "add") || (op_node->name == "multiply")) {
          if (match_args(reverse(op->args), call_node->args)) {
            return true;
          }
        }
      }
    } else {
      ClearMap(watermark);
      // associate divide/multiply
      if (is_pattern_op(op, "divide")) {
        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
              (is_expr_op(call_node->args[0], "divide") ||
               is_expr_op(call_node->args[1], "divide"))) {
            bool out = false;
            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]});
              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div});
              out = VisitDFPattern(mul, expr);
              if (out) {
                return true;
              } else {
                ClearMap(watermark);
              }
            }
            return out;
          }
        }
      }
      if (is_pattern_op(op, "multiply")) {
        // associate multiply/divide
        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
                (is_expr_op(call_node->args[0], "multiply") ||
                 is_expr_op(call_node->args[1], "multiply"))) {
              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]});
              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]});
              return VisitDFPattern(div, expr);
            }
          }
        }
      }
    }
  }
  return false;
}

// Recursively find the Dominator parent along all input paths.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
  auto call_node = expr.as<CallNode>();
  auto index_node = expr_to_node(expr);
  for (auto node : index_node->inputs_) {
    if (!(call_node && node->ref() == call_node->op)) {
      memoize_ = true;
      if (VisitDFPattern(op->parent, node->ref())) {
        return true;
      } else {
        memoize_ = false;
        if (!VisitDFPattern(op->path, node->ref())) {
          return false;
        }
        if (!MatchesPath(op, node->ref())) {
          return false;
        }
      }
    }
  }
  return true;
}

// Iteratively ensure that the parent is dominated somewhere by the child or the path
bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
  std::stack<Expr> stack;
  std::unordered_set<const ExprNode*> visited;
  stack.push(expr);
  while (!stack.empty()) {
    Expr current = stack.top();
    stack.pop();
    for (auto node : expr_to_node(current)->dominator_children_) {
      if (visited.count(node->node_ref_) == 0) {
        if (VisitDFPattern(op->parent, node->ref())) {
          return true;
        } else {
          stack.push(node->ref());
        }
        visited.insert(node->node_ref_);
      }
    }
  }
  return false;
}

bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
  if (VisitDFPattern(op->child, expr)) {
    bool matches_path = MatchesPath(op, expr);
    memoize_ = true;
    if (matches_path) {
      return DominatesParent(op, expr);
    }
  }
  return false;
}

bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
  return StructuralEqual()(op->expr, expr);
}

bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) {
  bool matches = false;
  if (const auto* func = expr.as<FunctionNode>()) {
    matches = true;
    if (op->params.defined()) {
      size_t i = 0;
      if (op->params.size() == func->params.size()) {
        while (matches && i < op->params.size()) {
          matches &= VisitDFPattern(op->params[i], func->params[i]);
          ++i;
        }
      } else {
        matches = false;
      }
    }
    if (matches) {
      matches &= VisitDFPattern(op->body, func->body);
    }
  }
  return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
  bool matches = false;
  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
    matches = (op->index == -1 || op->index == tuple_get_item_node->index) &&
              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
  }
  return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
  bool matches = false;
  if (const auto* tuple_node = expr.as<TupleNode>()) {
    matches = true;
    if (op->fields.defined()) {
      if (op->fields.size() == tuple_node->fields.size()) {
        size_t i = 0;
        while (matches && i < op->fields.size()) {
          matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
          ++i;
        }
      } else {
        matches = false;
      }
    }
  }
  return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr) {
  if (const auto* if_node = expr.as<IfNode>()) {
    auto cond = if_node->cond;
    auto true_branch = if_node->true_branch;
    auto false_branch = if_node->false_branch;
    return VisitDFPattern(op->cond, cond) && VisitDFPattern(op->true_branch, true_branch) &&
           VisitDFPattern(op->false_branch, false_branch);
  }
  return false;
}

bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& expr) {
  if (const auto* let_node = expr.as<LetNode>()) {
    return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, let_node->value) &&
           VisitDFPattern(op->body, let_node->body);
  }
  return false;
}

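// Runs type inference on expr in the context of module m by adding it under a fresh GlobalVar,
// invoking the InferType pass, and returning the typed expression (or function body).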
Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
  IRModule mod(m->functions, m->type_definitions, m->Imports());
  GlobalVarSupply global_var_supply = GlobalVarSupply(mod);
  GlobalVar gvar = global_var_supply->FreshGlobal("_tmp", false);
  BaseFunc func;
  if (expr.as<FunctionNode>()) {
    func = Downcast<Function>(expr);
  } else {
    func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
  }
  mod->Add(gvar, func);
  mod = transform::InferType()(mod);
  Expr ret;
  if (expr.as<FunctionNode>()) {
    ret = mod->Lookup(gvar);
  } else {
    ret = mod->Lookup(gvar).as<FunctionNode>()->body;
  }
  return ret;
}

bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
}

bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) {
  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
  if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
    return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr);
  }
  return false;
}

bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) {
  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
  if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
    return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr);
  }
  return false;
}

bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
  bool matches = false;
  if (const auto* var_node = expr.as<VarNode>()) {
    matches = true;
    if (op->name_hint() != "") {
      matches &= op->name_hint() == var_node->name_hint();
    }
  }
  return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) {
  return expr.as<ConstantNode>() != nullptr;
}

bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
  return true;
}

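// Entry point for one-shot matching: builds the indexed dataflow graph for expr and checks whether
// the pattern matches with expr as the root. Illustrative C++ usage (assuming the pattern sugar
// declared in tvm/relay/dataflow_pattern.h, e.g. IsOp/IsWildcard):
//   DFPattern pat = IsOp("add")({IsWildcard(), IsWildcard()});
//   bool matched = MatchPattern(pat, some_expr);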
bool MatchPattern(DFPattern pattern, Expr expr) {
  std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
  return DFPatternMatcher(expr_graph.get()).Match(pattern, expr);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);

/*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
 * group overlap analysis */
class MatchExtractor : public ExprMutator {
 public:
  explicit MatchExtractor(
      const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>& inputs)
      : inputs_(inputs) {}
  const std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>& GetMemo() {
    return this->memo_;
  }
  const std::string& GetName() { return name_; }

 protected:
  Expr VisitExpr(const Expr& pre) override {
    if (inputs_.count(pre)) {
      return inputs_.at(pre);
    }
    return ExprMutator::VisitExpr(pre);
  }
  Expr VisitExpr_(const TupleNode* op) override {
    auto out = ExprMutator::VisitExpr_(op);
    name_ += "Tuple_";
    return out;
  };
  Expr VisitExpr_(const FunctionNode* op) override {
    auto out = ExprMutator::VisitExpr_(op);
    name_ += "Function";
    return out;
  };
  Expr VisitExpr_(const CallNode* call_node) override {
    auto out = ExprMutator::VisitExpr_(call_node);
    if (auto operation = call_node->op.as<OpNode>()) {
      name_ += operation->name + "_";
    } else {
      name_ += "Call_";
    }
    return out;
  };
  Expr VisitExpr_(const LetNode* op) override {
    auto out = ExprMutator::VisitExpr_(op);
    name_ += "Let_";
    return out;
  };
  Expr VisitExpr_(const IfNode* op) override {
    auto out = ExprMutator::VisitExpr_(op);
    name_ += "If_";
    return out;
  };
  Expr VisitExpr_(const TupleGetItemNode* op) override {
    auto out = ExprMutator::VisitExpr_(op);
    name_ += "TupleGetItem" + std::to_string(op->index) + "_";
    return out;
  };
  Expr VisitExpr_(const MatchNode* op) override {
    auto out = ExprMutator::VisitExpr_(op);
    name_ += "Match_";
    return out;
  };
  std::string name_;
  const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
};

/*! \brief Group expressions that match the pattern */
const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatches(
    const DFPattern& pattern, const Expr& pre) {
  groups_.clear();
  gid_assignments_.clear();

  pattern_ = pattern;
  pattern_graph_ = CreateIndexedGraph(pattern_);
  std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(pre);
  DFPatternMatcher matcher(expr_graph.get());
  matcher_ = &matcher;
  this->VisitExprs();
  return this->groups_;
}

void PatternGrouper::VisitExprs() {
  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
  for (PostDfsIndex i = matcher_->size(); i != 0; --i) {
    PostDfsIndex index = i - 1;
    const auto current = matcher_->index_to_node(index)->ref();
    if (gid_assignments_.count(current) == 0) {  // Don't visit nodes we've already grouped
      if (auto op = current.as<FunctionNode>()) {
        if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
          pre_partitioned.insert(current);
          PostOrderVisit(op->body,
                         [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); });
        }
      }
      if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) {
        CreateGroup(current);
      }
    }
  }
}

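// Creates a partition group rooted at expr (the expression that matched pattern_): records the
// matched nodes, lifts the group's inputs into fresh parameters, extracts a candidate Function,
// and rejects groups that would overlap an existing group or introduce a cycle after fusion.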
void PatternGrouper::CreateGroup(const Expr& expr) {
  VLOG(1) << "Creating group for:" << std::endl << PrettyPrint(expr);

  int var_number = 0;

  auto node_map = matcher_->GetMemo();
  // Get fuzzy patterns
  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
  for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
    auto node = pattern_graph_->index_to_node(index);
    // Don't treat fuzzy Dominator pattern matches as input variables for the partition
    if (auto op = node->ref().as<DominatorPatternNode>()) {
      for (auto fuzzy_op : {op->parent, op->path}) {
        for (auto match : node_map[fuzzy_op]) {
          fuzzy_matches.insert(match);
        }
      }
    }
    // Don't treat Function params or body as input variables for partition
    if (node->ref().as<FunctionPatternNode>()) {
      if (node_map.count(node->ref())) {
        auto matches = node_map[node->ref()];
        for (auto match : matches) {
          auto sub_graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
          for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) {
            auto sub_node = sub_graph->index_to_node(sub_index);
            fuzzy_matches.insert(sub_node->ref());
          }
        }
      }
    }
  }

  // Create input variables
  Group group;
  group.root_node = expr;
  group.matched_nodes = node_map;

  std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
  Array<Var> params;

  for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
    auto node = pattern_graph_->index_to_node(index);
    auto make_input = [&](const Expr& input) {
      if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
          input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref())) {
        // Avoid adding parameters repeatedly because multiple operators in the partition
        // may use the same input.
        if (inputs.find(input) != inputs.end()) {
          return;
        }
        inputs[input] =
            Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
                NullValue<Type>());
        group.args.push_back(input);
        params.push_back(inputs[input]);
        var_number++;
      }
    };
    auto tuple = node->ref().as<TuplePatternNode>();
    auto call = node->ref().as<CallPatternNode>();
    if (tuple && !tuple->fields.defined()) {
      if (node_map.count(node->ref())) {
        auto matches = node_map[node->ref()];
        for (auto match : matches) {
          for (auto input : match.as<TupleNode>()->fields) {
            make_input(input);
          }
        }
      }
    } else if (call && !call->args.defined()) {
      if (node_map.count(node->ref())) {
        auto matches = node_map[node->ref()];
        for (auto match : matches) {
          for (auto input : match.as<CallNode>()->args) {
            make_input(input);
          }
        }
      }
    } else if (node->inputs_.size() == 0) {
      if (node_map.count(node->ref())) {
        auto matches = node_map[node->ref()];
        for (auto match : matches) {
          make_input(match);
        }
      }
    }
  }

  graph_number_++;

  // Extract a Function. Used in Partition directly,
  // used to determine Group overlap in other passes
  auto extractor = MatchExtractor(inputs);
  auto body = extractor.Mutate(expr);

  group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
  VLOG(1) << "Candidate extracted function:" << std::endl << PrettyPrint(group.function);
  group.name = extractor.GetName();
  // Check to make sure we aren't overlapping with another group or creating an invalid fusion.
  // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
  // pattern with the input FunctionVar* Variables. The resulting memoization map will only
  // contain nodes in the expression that matched the pattern. If a non-input node of the pattern
  // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a
  // situation where we try to rewrite the same node twice in the second rewriting or partition
  // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants
  // because they exist more globally outside of the fusion.
  // Similarly, if interior nodes in a group are used outside of the group, fusing to a single
  // output would create an invalid graph transformation, so we block the creation of such groups.
  auto memo = extractor.GetMemo();
  for (auto kv : memo) {
    VLOG(1) << "matched index " << matcher_->expr_to_node(kv.first)->index_;
  }

  for (auto kv : memo) {
    // Check to ensure that this node isn't an input or a global
    if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
        kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() == nullptr) {
      if (gid_assignments_.count(kv.first) != 0) {
        // The node is already used in another group.
        // Exit due to overlapping partitions
        return;
      } else if (kv.second != body) {
        // if the node isn't the output of the group
        auto node = matcher_->expr_to_node(kv.first);
        for (auto* output : node->outputs_) {
          if (memo.count(output->ref()) == 0) {
            // A node inside the matched group contributes an output to nodes outside of the
            // matched group...
            auto root = matcher_->expr_to_node(expr);
            if (!root->Dominates(output)) {
              // ...and the outside dataflow does not come back to the root of the matched group.
              // So reject the match since it would create a cycle.
              VLOG(1) << "Rejecting group since it would create a cycle with output "
                      << output->index_ << " for root " << root->index_
                      << " in graph:" << std::endl
                      << matcher_->expr_graph().ToString();
              return;
            }
            // else: We'll allow the output to be included in the matched group.
          }
        }
      }
    }
  }
  // Assign Group Ids
  group.gid = ++gid_;
  for (auto kv : extractor.GetMemo()) {
    gid_assignments_[kv.first] = gid_;
  }

  // Save Group
  groups_[group.gid] = std::move(group);
}

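// Determines whether a constant matched by the pattern should be embedded in the partitioned
// function body (true) instead of being lifted into a function parameter (false).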
bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) {
  bool embed = false;
  if (expr.as<ConstantNode>()) {
    if (pattern.as<ConstantPatternNode>() != nullptr) {
      embed = true;
    } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
      if (expr_pat->expr.as<ConstantNode>()) {
        embed = true;
      }
    } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
      if (matcher_->Match(alt_pat->left, expr)) {
        embed = EmbedConst(expr, alt_pat->left);
      } else {
        embed = EmbedConst(expr, alt_pat->right);
      }
    }
  }
  return embed;
}

// Rewrite

DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type,
                                     bool rewrite_once) {
  ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
  n->pattern = std::move(pattern);
  n->function = std::move(function);
  n->require_type = require_type;
  n->rewrite_once = rewrite_once;
  data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
    .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type,
                       bool rewrite_once) {
      return DFPatternCallback(pattern, function, require_type, rewrite_once);
    });

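// Applies each callback's rewrite over the expression, re-grouping matches after every pass and
// iterating until the result stops changing (with a hard limit of 100 passes) unless rewrite_once
// is requested.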
Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
  VLOG_CONTEXT << "PatternRewriter";
  VLOG(1) << "rewriting:" << std::endl << PrettyPrint(pre);
  auto post = pre;
  auto last = post;
  // rewrite the graph until it stops changing to make sure all rewrites are complete
  int count = 0;
  bool equal = true;
  static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
  ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
  do {
    last = post;
    for (auto callback : callbacks) {
      callback_ = callback;
      if (callback_->require_type) {
        post = InferTypeWithModule(post, mod_);
      }
      auto grouper = PatternGrouper();
      groups_ = grouper.GroupMatches(callback_->pattern, post);
      gid_assignments_ = grouper.GetGIDAssignments();
      memo_.clear();
      VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
      post = this->VisitExpr(post);
      VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
      count++;
    }
    equal = (*structural_equal)(last, post, false, true);
  } while (!equal && count < 100 && !callback_->rewrite_once);
  if (count >= 100) {
    LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
  }
  return post;
}

Expr PatternRewriter::DispatchVisitExpr(const Expr& pre) {
  auto post = MixedModeMutator::DispatchVisitExpr(pre);
  if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
    // Convert the pre-rewrite node map to a post-rewrite node map
    auto group = groups_[gid_assignments_[pre]];
    std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> node_map;
    for (auto kv : group.matched_nodes) {
      Array<Expr> tmp;
      for (size_t i = 0; i < kv.second.size(); ++i) {
        tmp.push_back(this->memo_[kv.second[i]]);
      }
      node_map.insert({kv.first, tmp});
    }
    // run the user callback function
    return callback_->function(pre, post, Map<DFPattern, Array<Expr>>(node_map));
  }
  return post;
}

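// Convenience wrapper used by the "relay.dataflow_pattern.rewrite" FFI registration below.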
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod) {
  return PatternRewriter(mod).Rewrite(callbacks, expr);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns);

/*!
 * \brief PatternPartitioner replaces expressions that match a pattern with function calls that
 * perform the same computation but allow for further analysis and lowering.
 *
 * The class uses PatternGrouper to support the dominator pattern.
 */
class PatternPartitioner : protected MixedModeMutator {
 public:
  Expr Partition(const DFPattern& pattern, const Expr& pre, const Map<String, ObjectRef>& attrs,
                 PackedFunc check) {
    if (pattern.as<FunctionPatternNode>()) {
      LOG(WARNING) << "Partitioning a Function that isn't called doesn't make sense, skipping "
                   << pattern;
      return pre;
    }
    auto grouper = PatternGrouper();
    groups_ = grouper.GroupMatches(pattern, pre);
    gid_assignments_ = grouper.GetGIDAssignments();
    attrs_ = attrs;
    check_ = check;
    return this->VisitExpr(pre);
  }

 protected:
  Expr RewritePartition(const PatternGrouper::Group& group) {
    Array<Expr> args;
    for (size_t i = 0; i < group.args.size(); ++i) {
      args.push_back(memo_[group.args[i]]);
    }
    Function func = WithAttr(group.function, attr::kPartitionedFromPattern, String(group.name));
    if (!attrs_.empty()) {
      for (auto kv : attrs_) {
        func = WithAttr(std::move(func), kv.first, kv.second);
      }
    }
    return Call(func, args);
  }

  Expr DispatchVisitExpr(const Expr& pre) override {
    auto post = MixedModeMutator::DispatchVisitExpr(pre);
    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node &&
        static_cast<bool>(check_(pre))) {
      post = RewritePartition(groups_[gid_assignments_[pre]]);
    }
    return post;
  }

  Map<String, ObjectRef> attrs_;
  std::unordered_map<int, PatternGrouper::Group> groups_;
  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
  PackedFunc check_;
};

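// Partitions every non-overlapping occurrence of pattern in expr into a separate Function call,
// attaching attrs to each partitioned function and skipping matches rejected by check.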
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs,
                      PackedFunc check) {
  return PatternPartitioner().Partition(pattern, expr, attrs, check);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
    .set_body_typed([](DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs,
                       PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); });

}  // namespace relay
}  // namespace tvm