1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/tvm/relay/dataflow_matcher.cc |
22 | * \brief The dataflow pattern matcher for Relay. |
23 | */ |
24 | |
25 | #include <tvm/ir/global_var_supply.h> |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/dataflow_matcher.h> |
28 | #include <tvm/relay/expr_functor.h> |
29 | #include <tvm/relay/transform.h> |
30 | |
31 | #include <stack> |
32 | |
33 | #include "dataflow_matcher_impl.h" |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | |
38 | // Pattern Matcher |
39 | bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { |
40 | VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr); |
41 | memo_.clear(); |
42 | matched_nodes_.clear(); |
43 | return VisitDFPattern(pattern, expr); |
44 | } |
45 | |
46 | void DFPatternMatcher::ClearMap(size_t watermark) { |
47 | for (size_t i = watermark; i < matched_nodes_.size(); ++i) { |
48 | memo_.erase(matched_nodes_[i]); |
49 | } |
50 | matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); |
51 | } |
52 | |
// Core dispatch: matches `pattern` against `expr`, memoizing successes and
// rolling back partial state on failure.
bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
  if (memoize_ && memo_.count(pattern)) {
    // When memoization is on, a pattern that previously matched may only
    // re-match the exact same expression node.
    ICHECK_EQ(memo_[pattern].size(), 1);
    return expr.same_as(memo_[pattern][0]);
  } else {
    // Record a watermark so that any sub-matches made during this attempt
    // can be undone if the overall attempt fails.
    auto watermark = matched_nodes_.size();
    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
    if (out) {
      memo_[pattern].push_back(expr);
      matched_nodes_.push_back(pattern);
      VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr);
    } else {
      // Failed attempt: erase everything recorded past the watermark.
      ClearMap(watermark);
    }
    return out;
  }
}
70 | |
71 | bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { |
72 | return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); |
73 | } |
74 | |
// Compares a pattern-supplied attribute value (`lhs`) against an operator
// attribute fetched from an attribute map (`rhs`), dispatching on the runtime
// type code of `rhs`. Returns false for any unmatched/unsupported pairing.
bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
  switch (rhs.type_code()) {
    case kDLInt:
      if (auto* val = lhs.as<IntImmNode>()) {
        return val->value == rhs.operator int64_t();
      }
      break;
    case kDLFloat:
      if (auto* val = lhs.as<FloatImmNode>()) {
        return val->value == rhs.operator double();
      }
      break;
    case kTVMStr:
      // Pattern-side strings may be either tir::StringImm or runtime String.
      if (auto* val = lhs.as<tir::StringImmNode>()) {
        return val->value == rhs.operator std::string();
      } else if (auto* val = lhs.as<StringObj>()) {
        return val->data == rhs.operator std::string();
      }
      break;
    case kTVMDataType:
      // DataType attributes are compared through their string spelling.
      if (auto* val = lhs.as<tir::StringImmNode>()) {
        return rhs.operator std::string() == val->value;
      } else if (auto* val = lhs.as<StringObj>()) {
        return rhs.operator std::string() == val->data;
      } else {
        ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs;
      }
      break;
    case kTVMObjectHandle:
      if (rhs.IsObjectRef<String>()) {
        if (auto* val = lhs.as<tir::StringImmNode>()) {
          return rhs.operator String() == val->value;
        } else if (auto* val = lhs.as<StringObj>()) {
          return rhs.operator String() == val->data;
        }
      } else {
        // Compare the objects for structural equality
        static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
        ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
        if ((*structural_equal)(lhs, GetRef<ObjectRef>(rhs.ptr<Object>()), false, true)) {
          return true;
        }
      }
      break;
    default:
      ICHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
  }
  return false;
}
124 | |
// Matches an attribute-constrained pattern: the wrapped pattern must match,
// and every attribute listed in the pattern must agree with the matched
// expression. Ops are checked via the Op attr maps, Calls via reflection on
// their Attrs object, and Functions via their attrs dictionary.
bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
  bool matches = VisitDFPattern(attr_pattern->pattern, expr);
  if (!matches) {
    return matches;
  }
  // NOTE(review): assumes attr_pattern->attrs is a DictAttrs — a different
  // Attrs subclass would null-deref here; confirm against pattern builders.
  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
  if (const auto* op_node = expr.as<OpNode>()) {
    Op op = GetRef<Op>(op_node);
    for (auto kv : attributes) {
      auto attr_name = kv.first;
      auto attr_value = kv.second;
      if (Op::HasAttrMap(attr_name)) {
        auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
        if (op_map.count(op)) {
          matches &= MatchRetValue(attr_value, op_map[op]);
        } else {
          matches = false;
        }
      } else {
        matches = false;
      }
    }
  } else if (auto* op = expr.as<CallNode>()) {
    matches = true;
    // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this
    // and replace the whole thing with a Visitor-based approach
    ReflectionVTable* reflection = ReflectionVTable::Global();
    auto attrs_node = const_cast<BaseAttrsNode*>(op->attrs.get());
    // attrs may be undefined on non-op calls so we check first
    std::vector<std::string> attr_names;
    if (attrs_node) {
      attr_names = reflection->ListAttrNames(attrs_node);
    }
    for (auto kv : attributes) {
      std::string attr = kv.first;
      // Every requested attribute must exist on the call's Attrs and match.
      if (matches && std::find(attr_names.begin(), attr_names.end(), attr) != attr_names.end()) {
        matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, attr));
      } else {
        matches = false;
        break;
      }
    }
  } else if (auto* op = expr.as<FunctionNode>()) {
    matches = true;
    for (auto kv : attributes) {
      // Function attrs are ObjectRefs; compare structurally.
      if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) {
        matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
      } else {
        matches = false;
        break;
      }
    }
  } else {
    // Attribute patterns only apply to Ops, Calls, and Functions.
    matches = false;
  }
  return matches;
}
182 | |
183 | Array<DFPattern> reverse(const Array<DFPattern>& args) { |
184 | Array<DFPattern> new_args; |
185 | for (auto it = args.rbegin(); it != args.rend(); ++it) { |
186 | new_args.push_back(*it); |
187 | } |
188 | return new_args; |
189 | } |
190 | |
// Matches a call pattern against a call expression. Beyond the direct
// operator+argument match, this also tries:
//   * commutative argument swaps for "add" / "multiply", and
//   * re-association between "divide" and "multiply" so that e.g.
//     (a * b) / c can match a * (b / c) and vice versa.
bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
  // utilities
  // Extracts the OpNode if the call pattern's op is a literal Op expression.
  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
    if (op) {
      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
        return expr_pattern->expr.as<OpNode>();
      }
    }
    return nullptr;
  };
  // True if the pattern's op is the named operator.
  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
    if (const auto* op_node = get_op_node(op)) {
      if (op_node->name == op_type) {
        return true;
      }
    }
    return false;
  };
  // True if the expression is a call to the named operator.
  auto is_expr_op = [](const Expr& expr, std::string op_type) {
    if (const auto* call_node = expr.as<CallNode>()) {
      if (const auto* op_node = call_node->op.as<OpNode>()) {
        if (op_node->name == op_type) {
          return true;
        }
      }
    }
    return false;
  };

  // logic
  auto watermark = matched_nodes_.size();
  if (const auto* call_node = expr.as<CallNode>()) {
    auto matches_op = VisitDFPattern(op->op, call_node->op);
    if (matches_op) {
      auto watermark2 = matched_nodes_.size();

      // Matches pattern args against expr args pairwise; rolls back any
      // partial matches on failure. Undefined pattern args match anything.
      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
                                            const Array<Expr> expr_args) {
        bool matches = true;
        size_t i = 0;
        if (pattern_args.defined()) {
          if (pattern_args.size() == expr_args.size()) {
            while (matches && i < pattern_args.size()) {
              matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
              ++i;
            }
          } else {
            matches = false;
          }
        }
        if (!matches) {
          ClearMap(watermark2);
        }
        return matches;
      };

      // Standard case
      if (match_args(op->args, call_node->args)) {
        return true;
      }
      // Commutative Matching
      if (const OpNode* op_node = get_op_node(op)) {
        if ((op_node->name == "add") || (op_node->name == "multiply")) {
          if (match_args(reverse(op->args), call_node->args)) {
            return true;
          }
        }
      }
    } else {
      ClearMap(watermark);
      // associate divide/multiply
      // Pattern (x * y) / z may match an expression shaped x * (y / z)
      // (or with the division on the other operand).
      if (is_pattern_op(op, "divide")) {
        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
              (is_expr_op(call_node->args[0], "divide") ||
               is_expr_op(call_node->args[1], "divide"))) {
            bool out = false;
            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
              // Rewrite the pattern as mul(other_arg, div(arg, denominator))
              // and retry the match.
              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]});
              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div});
              out = VisitDFPattern(mul, expr);
              if (out) {
                return true;
              } else {
                ClearMap(watermark);
              }
            }
            return out;
          }
        }
      }
      if (is_pattern_op(op, "multiply")) {
        // associate multiply/divide
        // Pattern (x / y) * z may match an expression shaped (x * z) / y.
        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
                (is_expr_op(call_node->args[0], "multiply") ||
                 is_expr_op(call_node->args[1], "multiply"))) {
              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]});
              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]});
              return VisitDFPattern(div, expr);
            }
          }
        }
      }
    }
  }
  return false;
}
300 | |
// Recursively find the Dominator parent along all inputs paths.
//
// For every input of `expr` (skipping the input that is the call's operator
// itself), either the input matches the dominator's parent pattern, or it
// must match the path pattern and, recursively, its own inputs must too.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
  auto call_node = expr.as<CallNode>();
  auto index_node = expr_to_node(expr);
  for (auto node : index_node->inputs_) {
    // Skip the operator position of a call; only data inputs are on the path.
    if (!(call_node && node->ref() == call_node->op)) {
      memoize_ = true;
      if (VisitDFPattern(op->parent, node->ref())) {
        // Reached the parent along this input; this path is satisfied.
        return true;
      } else {
        // Path nodes are matched without memoization so the fuzzy path
        // pattern can match many distinct expressions.
        memoize_ = false;
        if (!VisitDFPattern(op->path, node->ref())) {
          return false;
        }
        if (!MatchesPath(op, node->ref())) {
          return false;
        }
      }
    }
  }
  return true;
}
323 | |
// Iteratively ensure that the parent is dominated somewhere by the child or the path
//
// Walks the dominator-children of `expr` (depth-first, with a visited set to
// avoid re-processing shared nodes) looking for a node that matches the
// dominator's parent pattern.
bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
  std::stack<Expr> stack;
  std::unordered_set<const ExprNode*> visited;
  stack.push(expr);
  while (!stack.empty()) {
    Expr current = stack.top();
    stack.pop();
    for (auto node : expr_to_node(current)->dominator_children_) {
      if (visited.count(node->node_ref_) == 0) {
        if (VisitDFPattern(op->parent, node->ref())) {
          return true;
        } else {
          // Not the parent — keep searching through this child's dominators.
          stack.push(node->ref());
        }
        visited.insert(node->node_ref_);
      }
    }
  }
  return false;
}
345 | |
// A dominator pattern matches when the child pattern matches `expr`, every
// input path from `expr` reaches the parent through the path pattern, and
// the parent is found among `expr`'s dominators.
bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
  if (VisitDFPattern(op->child, expr)) {
    bool matches_path = MatchesPath(op, expr);
    // MatchesPath may have disabled memoization for fuzzy path matching;
    // restore it before checking dominance.
    memoize_ = true;
    if (matches_path) {
      return DominatesParent(op, expr);
    }
  }
  return false;
}
356 | |
357 | bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { |
358 | return StructuralEqual()(op->expr, expr); |
359 | } |
360 | |
361 | bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) { |
362 | bool matches = false; |
363 | if (const auto* func = expr.as<FunctionNode>()) { |
364 | matches = true; |
365 | if (op->params.defined()) { |
366 | size_t i = 0; |
367 | if (op->params.size() == func->params.size()) { |
368 | while (matches && i < op->params.size()) { |
369 | matches &= VisitDFPattern(op->params[i], func->params[i]); |
370 | ++i; |
371 | } |
372 | } else { |
373 | matches = false; |
374 | } |
375 | } |
376 | if (matches) { |
377 | matches &= VisitDFPattern(op->body, func->body); |
378 | } |
379 | } |
380 | return matches; |
381 | } |
382 | |
383 | bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { |
384 | bool matches = false; |
385 | if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) { |
386 | matches = (op->index == -1 || op->index == tuple_get_item_node->index) && |
387 | VisitDFPattern(op->tuple, tuple_get_item_node->tuple); |
388 | } |
389 | return matches; |
390 | } |
391 | |
392 | bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) { |
393 | bool matches = false; |
394 | if (const auto* tuple_node = expr.as<TupleNode>()) { |
395 | matches = true; |
396 | if (op->fields.defined()) { |
397 | if (op->fields.size() == tuple_node->fields.size()) { |
398 | size_t i = 0; |
399 | while (matches && i < op->fields.size()) { |
400 | matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); |
401 | ++i; |
402 | } |
403 | } else { |
404 | matches = false; |
405 | } |
406 | } |
407 | } |
408 | return matches; |
409 | } |
410 | |
411 | bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr) { |
412 | if (const auto* if_node = expr.as<IfNode>()) { |
413 | auto cond = if_node->cond; |
414 | auto true_branch = if_node->true_branch; |
415 | auto false_branch = if_node->false_branch; |
416 | return VisitDFPattern(op->cond, cond) && VisitDFPattern(op->true_branch, true_branch) && |
417 | VisitDFPattern(op->false_branch, false_branch); |
418 | } |
419 | return false; |
420 | } |
421 | |
422 | bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& expr) { |
423 | if (const auto* let_node = expr.as<LetNode>()) { |
424 | return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, let_node->value) && |
425 | VisitDFPattern(op->body, let_node->body); |
426 | } |
427 | return false; |
428 | } |
429 | |
// Infers the type of `expr` in the context of module `m` by installing it
// under a fresh global, running the InferType pass, and extracting the
// now-typed result. Non-function expressions are wrapped in a function of
// their free variables and unwrapped again afterwards.
Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
  // Work on a shallow copy so the caller's module is not mutated.
  IRModule mod(m->functions, m->type_definitions, m->Imports());
  GlobalVarSupply global_var_supply = GlobalVarSupply(mod);
  GlobalVar gvar = global_var_supply->FreshGlobal("_tmp", false);
  BaseFunc func;
  if (expr.as<FunctionNode>()) {
    func = Downcast<Function>(expr);
  } else {
    func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
  }
  mod->Add(gvar, func);
  mod = transform::InferType()(mod);
  Expr ret;
  if (expr.as<FunctionNode>()) {
    ret = mod->Lookup(gvar);
  } else {
    // Unwrap the temporary function to recover the typed expression body.
    ret = mod->Lookup(gvar).as<FunctionNode>()->body;
  }
  return ret;
}
450 | |
// Matches when the expression's inferred type structurally equals the
// pattern's type and the wrapped pattern also matches.
bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
}
455 | |
456 | bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { |
457 | auto expr_type = InferType(expr).as<ExprNode>()->checked_type(); |
458 | if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) { |
459 | return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr); |
460 | } |
461 | return false; |
462 | } |
463 | |
464 | bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { |
465 | auto expr_type = InferType(expr).as<ExprNode>()->checked_type(); |
466 | if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) { |
467 | return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); |
468 | } |
469 | return false; |
470 | } |
471 | |
472 | bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { |
473 | bool matches = false; |
474 | if (const auto* var_node = expr.as<VarNode>()) { |
475 | matches = true; |
476 | if (op->name_hint() != "" ) { |
477 | matches &= op->name_hint() == var_node->name_hint(); |
478 | } |
479 | } |
480 | return matches; |
481 | } |
482 | |
// Matches any Constant node, regardless of its value.
bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) {
  return expr.as<ConstantNode>() != nullptr;
}
486 | |
// A wildcard matches any expression.
bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
  return true;
}
490 | |
// Convenience entry point: builds the indexed dataflow graph for `expr` and
// runs a one-shot pattern match against it.
bool MatchPattern(DFPattern pattern, Expr expr) {
  std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
  return DFPatternMatcher(expr_graph.get()).Match(pattern, expr);
}
495 | |
// FFI hook: exposes MatchPattern to the Python API as
// "relay.dataflow_pattern.match".
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
497 | |
498 | /*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform |
499 | * group overlap analysis */ |
500 | class : public ExprMutator { |
501 | public: |
502 | explicit ( |
503 | const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>& inputs) |
504 | : inputs_(inputs) {} |
505 | const std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>& () { |
506 | return this->memo_; |
507 | } |
508 | const std::string& () { return name_; } |
509 | |
510 | protected: |
511 | Expr (const Expr& pre) override { |
512 | if (inputs_.count(pre)) { |
513 | return inputs_.at(pre); |
514 | } |
515 | return ExprMutator::VisitExpr(pre); |
516 | } |
517 | Expr (const TupleNode* op) override { |
518 | auto out = ExprMutator::VisitExpr_(op); |
519 | name_ += "Tuple_" ; |
520 | return out; |
521 | }; |
522 | Expr (const FunctionNode* op) override { |
523 | auto out = ExprMutator::VisitExpr_(op); |
524 | name_ += "Function" ; |
525 | return out; |
526 | }; |
527 | Expr (const CallNode* call_node) override { |
528 | auto out = ExprMutator::VisitExpr_(call_node); |
529 | if (auto operation = call_node->op.as<OpNode>()) { |
530 | name_ += operation->name + "_" ; |
531 | } else { |
532 | name_ += "Call_" ; |
533 | } |
534 | return out; |
535 | }; |
536 | Expr (const LetNode* op) override { |
537 | auto out = ExprMutator::VisitExpr_(op); |
538 | name_ += "Let_" ; |
539 | return out; |
540 | }; |
541 | Expr (const IfNode* op) override { |
542 | auto out = ExprMutator::VisitExpr_(op); |
543 | name_ += "If_" ; |
544 | return out; |
545 | }; |
546 | Expr (const TupleGetItemNode* op) override { |
547 | auto out = ExprMutator::VisitExpr_(op); |
548 | name_ += "TupleGetItem" + std::to_string(op->index) + "_" ; |
549 | return out; |
550 | }; |
551 | Expr (const MatchNode* op) override { |
552 | auto out = ExprMutator::VisitExpr_(op); |
553 | name_ += "Match_" ; |
554 | return out; |
555 | }; |
556 | std::string ; |
557 | const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> ; |
558 | }; |
559 | |
560 | /*! \brief Group expressions that match the pattern */ |
// Finds every match of `pattern` in `pre` and groups them for partitioning;
// returns the map of group id -> Group.
const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatches(
    const DFPattern& pattern, const Expr& pre) {
  // Reset state from any previous grouping run.
  groups_.clear();
  gid_assignments_.clear();

  pattern_ = pattern;
  pattern_graph_ = CreateIndexedGraph(pattern_);
  std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(pre);
  // NOTE(review): matcher_ points at this stack-local matcher, so it is only
  // valid for the duration of this call (VisitExprs/CreateGroup) — confirm no
  // later use escapes this scope.
  DFPatternMatcher matcher(expr_graph.get());
  matcher_ = &matcher;
  this->VisitExprs();
  return this->groups_;
}
574 | |
// Scans expressions in reverse post-dfs order (outputs before inputs) and
// creates a group for each pattern match, skipping nodes that were already
// grouped or that live inside a previously partitioned function.
void PatternGrouper::VisitExprs() {
  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
  // Iterate i from size() down to 1 so `index` covers size()-1 .. 0.
  for (PostDfsIndex i = matcher_->size(); i != 0; --i) {
    PostDfsIndex index = i - 1;
    const auto current = matcher_->index_to_node(index)->ref();
    if (gid_assignments_.count(current) == 0) {  // Don't visit nodes we've already grouped
      if (auto op = current.as<FunctionNode>()) {
        // Functions already produced by a previous partitioning pass (marked
        // with kPartitionedFromPattern) and everything inside them are
        // excluded from re-partitioning.
        if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
          pre_partitioned.insert(current);
          PostOrderVisit(op->body,
                         [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); });
        }
      }
      if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) {
        CreateGroup(current);
      }
    }
  }
}
594 | |
595 | void PatternGrouper::CreateGroup(const Expr& expr) { |
596 | VLOG(1) << "Creating group for:" << std::endl << PrettyPrint(expr); |
597 | |
598 | int var_number = 0; |
599 | |
600 | auto node_map = matcher_->GetMemo(); |
601 | // Get fuzzy patterns |
602 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches; |
603 | for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { |
604 | auto node = pattern_graph_->index_to_node(index); |
605 | // Don't treat fuzzy Dominator patterns input variables for partition |
606 | if (auto op = node->ref().as<DominatorPatternNode>()) { |
607 | for (auto fuzzy_op : {op->parent, op->path}) { |
608 | for (auto match : node_map[fuzzy_op]) { |
609 | fuzzy_matches.insert(match); |
610 | } |
611 | } |
612 | } |
613 | // Don't treat Function params or body as input variables for partition |
614 | if (node->ref().as<FunctionPatternNode>()) { |
615 | if (node_map.count(node->ref())) { |
616 | auto matches = node_map[node->ref()]; |
617 | for (auto match : matches) { |
618 | auto sub_graph = CreateIndexedGraph(match.as<FunctionNode>()->body); |
619 | for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) { |
620 | auto sub_node = sub_graph->index_to_node(sub_index); |
621 | fuzzy_matches.insert(sub_node->ref()); |
622 | } |
623 | } |
624 | } |
625 | } |
626 | } |
627 | |
628 | // Create input variables |
629 | Group group; |
630 | group.root_node = expr; |
631 | group.matched_nodes = node_map; |
632 | |
633 | std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs; |
634 | Array<Var> params; |
635 | |
636 | for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { |
637 | auto node = pattern_graph_->index_to_node(index); |
638 | auto make_input = [&](const Expr& input) { |
639 | if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr && |
640 | input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref())) { |
641 | // Avoid adding parameters repeatedly because multiple operatorss in the partition |
642 | // may use the same input. |
643 | if (inputs.find(input) != inputs.end()) { |
644 | return; |
645 | } |
646 | inputs[input] = |
647 | Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), |
648 | NullValue<Type>()); |
649 | group.args.push_back(input); |
650 | params.push_back(inputs[input]); |
651 | var_number++; |
652 | } |
653 | }; |
654 | auto tuple = node->ref().as<TuplePatternNode>(); |
655 | auto call = node->ref().as<CallPatternNode>(); |
656 | if (tuple && !tuple->fields.defined()) { |
657 | if (node_map.count(node->ref())) { |
658 | auto matches = node_map[node->ref()]; |
659 | for (auto match : matches) { |
660 | for (auto input : match.as<TupleNode>()->fields) { |
661 | make_input(input); |
662 | } |
663 | } |
664 | } |
665 | } else if (call && !call->args.defined()) { |
666 | if (node_map.count(node->ref())) { |
667 | auto matches = node_map[node->ref()]; |
668 | for (auto match : matches) { |
669 | for (auto input : match.as<CallNode>()->args) { |
670 | make_input(input); |
671 | } |
672 | } |
673 | } |
674 | } else if (node->inputs_.size() == 0) { |
675 | if (node_map.count(node->ref())) { |
676 | auto matches = node_map[node->ref()]; |
677 | for (auto match : matches) { |
678 | make_input(match); |
679 | } |
680 | } |
681 | } |
682 | } |
683 | |
684 | graph_number_++; |
685 | |
686 | // Extract a Function. Used in Partition directly, |
687 | // used to determine Group overlap in other passes |
688 | auto = MatchExtractor(inputs); |
689 | auto body = extractor.Mutate(expr); |
690 | |
691 | group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>()); |
692 | VLOG(1) << "Candidate extracted function:" << std::endl << PrettyPrint(group.function); |
693 | group.name = extractor.GetName(); |
694 | // Check to make sure we aren't overlapping with another group or creating an invalid fusion |
695 | // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the |
696 | // pattern with the input FunctionVar* Variables. The resulting memoization map will only |
697 | // contain nodes in the expression that matched the pattern. If a non-input node of the pattern |
698 | // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a |
699 | // situation where we try to rewrite the same node twice in the second rewriting or parition |
700 | // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants |
701 | // because they exist more globally outside of the fusion. |
702 | // Similiarly, if interior nodes in a group are used outside of the group fusing to a single |
703 | // output would create an invalid graph tranformation, so we block the creation of such groups. |
704 | auto memo = extractor.GetMemo(); |
705 | for (auto kv : memo) { |
706 | VLOG(1) << "matched index " << matcher_->expr_to_node(kv.first)->index_; |
707 | } |
708 | |
709 | for (auto kv : memo) { |
710 | // Check to ensure that this node isn't an input or a global |
711 | if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr && |
712 | kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() == nullptr) { |
713 | if (gid_assignments_.count(kv.first) != 0) { |
714 | // check to see if the node is use in other groups |
715 | // Exit due to overlapping partitions |
716 | return; |
717 | } else if (kv.second != body) { |
718 | // if the node isn't the output of the group |
719 | auto node = matcher_->expr_to_node(kv.first); |
720 | for (auto* output : node->outputs_) { |
721 | if (memo.count(output->ref()) == 0) { |
722 | // A node inside the matched group contributes an output to nodes outside of the matched |
723 | // group... |
724 | auto root = matcher_->expr_to_node(expr); |
725 | if (!root->Dominates(output)) { |
726 | // ...and the outside dataflow does not come back to the root of the matched group. |
727 | // So reject the match since it would create a cycle. |
728 | VLOG(1) << "Rejecting group since would create a cycle with output " << output->index_ |
729 | << " for root " << root->index_ << " in graph:" << std::endl |
730 | << matcher_->expr_graph().ToString(); |
731 | return; |
732 | } |
733 | // else: We'll allow the output to be included in the matched group. |
734 | } |
735 | } |
736 | } |
737 | } |
738 | } |
739 | // Assign Group Ids |
740 | group.gid = ++gid_; |
741 | for (auto kv : extractor.GetMemo()) { |
742 | gid_assignments_[kv.first] = gid_; |
743 | } |
744 | |
745 | // Save Group |
746 | groups_[group.gid] = std::move(group); |
747 | } |
748 | |
749 | bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) { |
750 | bool embed = false; |
751 | if (expr.as<ConstantNode>()) { |
752 | if (pattern.as<ConstantPatternNode>() != nullptr) { |
753 | embed = true; |
754 | } else if (auto expr_pat = pattern.as<ExprPatternNode>()) { |
755 | if (expr_pat->expr.as<ConstantNode>()) { |
756 | embed = true; |
757 | } |
758 | } else if (auto alt_pat = pattern.as<AltPatternNode>()) { |
759 | if (matcher_->Match(alt_pat->left, expr)) { |
760 | embed = EmbedConst(expr, alt_pat->left); |
761 | } else { |
762 | embed = EmbedConst(expr, alt_pat->right); |
763 | } |
764 | } |
765 | } |
766 | return embed; |
767 | } |
768 | |
769 | // Rewrite |
770 | |
// Constructs a DFPatternCallback node bundling a pattern with the rewrite
// function to invoke on each match, whether type inference is required
// before matching, and whether to stop after a single rewrite iteration.
DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type,
                                     bool rewrite_once) {
  ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
  n->pattern = std::move(pattern);
  n->function = std::move(function);
  n->require_type = require_type;
  n->rewrite_once = rewrite_once;
  data_ = std::move(n);
}
780 | |
TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);

// FFI hook: exposes DFPatternCallback construction to the Python API.
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
    .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type,
                       bool rewrite_once) {
      return DFPatternCallback(pattern, function, require_type, rewrite_once);
    });
788 | |
// Applies each callback's rewrite over `pre`, iterating to a fixed point
// (structural equality between iterations) with a hard cap of 100 passes.
Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
  VLOG_CONTEXT << "PatternRewriter";
  VLOG(1) << "rewriting:" << std::endl << PrettyPrint(pre);
  auto post = pre;
  auto last = post;
  // rewrite the graph until it stops changing to make sure all rewrites are complete
  int count = 0;
  bool equal = true;
  static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
  ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
  do {
    last = post;
    for (auto callback : callbacks) {
      callback_ = callback;
      // Some callbacks need typed expressions; run type inference first.
      if (callback_->require_type) {
        post = InferTypeWithModule(post, mod_);
      }
      auto grouper = PatternGrouper();
      groups_ = grouper.GroupMatches(callback_->pattern, post);
      gid_assignments_ = grouper.GetGIDAssignments();
      memo_.clear();
      VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
      post = this->VisitExpr(post);
      VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
      count++;
    }
    equal = (*structural_equal)(last, post, false, true);
    // NOTE(review): with an empty callback list, `equal` stays true and the
    // && short-circuits before callback_->rewrite_once is dereferenced; the
    // rewrite_once check only consults the last callback — confirm intended.
  } while (!equal && count < 100 && !callback_->rewrite_once);
  if (count >= 100) {
    LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
  }
  return post;
}
822 | |
// Visits each node bottom-up; when `pre` is the root of a matched group,
// invokes the user callback with the pre/post expressions and the
// (post-rewrite) pattern-to-expression node map.
Expr PatternRewriter::DispatchVisitExpr(const Expr& pre) {
  auto post = MixedModeMutator::DispatchVisitExpr(pre);
  if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
    // Convert the pre-rewrite node map to a post-rewrite node map
    auto group = groups_[gid_assignments_[pre]];
    std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> node_map;
    for (auto kv : group.matched_nodes) {
      Array<Expr> tmp;
      for (size_t i = 0; i < kv.second.size(); ++i) {
        // memo_ maps each pre-rewrite node to its rewritten counterpart.
        tmp.push_back(this->memo_[kv.second[i]]);
      }
      node_map.insert({kv.first, tmp});
    }
    // run the user callback function
    return callback_->function(pre, post, Map<DFPattern, Array<Expr>>(node_map));
  }
  return post;
}
841 | |
// Public helper: rewrites `expr` with the given callbacks in the context of
// module `mod`.
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod) {
  return PatternRewriter(mod).Rewrite(callbacks, expr);
}
845 | |
// FFI hook: exposes RewritePatterns to the Python API as
// "relay.dataflow_pattern.rewrite".
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns);
847 | |
848 | /*! |
849 | * \brief PatternPartitioner replaces expressions that match a pattern with function call that |
850 | * perform the same computation but allow for further analysis and lowering. |
851 | * |
852 | * The class uses PatternGrouper to support the dominator pattern. |
853 | */ |
854 | class PatternPartitioner : protected MixedModeMutator { |
855 | public: |
856 | Expr Partition(const DFPattern& pattern, const Expr& pre, const Map<String, ObjectRef>& attrs, |
857 | PackedFunc check) { |
858 | if (pattern.as<FunctionPatternNode>()) { |
859 | LOG(WARNING) << "Partioning a Function that isn't called doesn't make sense, skipping" |
860 | << pattern; |
861 | return pre; |
862 | } |
863 | auto grouper = PatternGrouper(); |
864 | groups_ = grouper.GroupMatches(pattern, pre); |
865 | gid_assignments_ = grouper.GetGIDAssignments(); |
866 | attrs_ = attrs; |
867 | check_ = check; |
868 | return this->VisitExpr(pre); |
869 | } |
870 | |
871 | protected: |
872 | Expr RewritePartition(const PatternGrouper::Group& group) { |
873 | Array<Expr> args; |
874 | for (size_t i = 0; i < group.args.size(); ++i) { |
875 | args.push_back(memo_[group.args[i]]); |
876 | } |
877 | Function func = WithAttr(group.function, attr::kPartitionedFromPattern, String(group.name)); |
878 | if (!attrs_.empty()) { |
879 | for (auto kv : attrs_) { |
880 | func = WithAttr(std::move(func), kv.first, kv.second); |
881 | } |
882 | } |
883 | return Call(func, args); |
884 | } |
885 | |
886 | Expr DispatchVisitExpr(const Expr& pre) override { |
887 | auto post = MixedModeMutator::DispatchVisitExpr(pre); |
888 | if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node && |
889 | static_cast<bool>(check_(pre))) { |
890 | post = RewritePartition(groups_[gid_assignments_[pre]]); |
891 | } |
892 | return post; |
893 | } |
894 | |
895 | Map<String, ObjectRef> attrs_; |
896 | std::unordered_map<int, PatternGrouper::Group> groups_; |
897 | std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_; |
898 | PackedFunc check_; |
899 | }; |
900 | |
// Public helper: partitions matches of `pattern` in `expr` into annotated
// function calls, carrying `attrs` and gated by the `check` predicate.
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs,
                      PackedFunc check) {
  return PatternPartitioner().Partition(pattern, expr, attrs, check);
}
905 | |
// FFI hook: exposes PartitionPattern to the Python API as
// "relay.dataflow_pattern.partition".
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
    .set_body_typed([](DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs,
                       PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); });
909 | |
910 | } // namespace relay |
911 | } // namespace tvm |
912 | |