1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/expr_functor.cc |
22 | * \brief A wrapper around ExprFunctor which functionally updates the AST. |
23 | * |
24 | * ExprMutator uses memoization and self return in order to amortize |
25 | * the cost of using functional updates. |
26 | */ |
27 | #include <tvm/ir/type_functor.h> |
28 | #include <tvm/relay/adt.h> |
29 | #include <tvm/relay/analysis.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/relay/pattern_functor.h> |
32 | |
33 | #include <stack> |
34 | |
35 | #include "../op/annotation/annotation.h" |
36 | #include "../op/memory/on_device.h" |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | MixedModeVisitor::MixedModeVisitor(int visit_limit) { |
41 | ICHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0" ; |
42 | ICHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10" ; |
43 | visit_limit_ = visit_limit; |
44 | } |
45 | |
46 | void MixedModeVisitor::VisitLeaf(const Expr& expr) { |
47 | if (visit_counter_[expr.get()] < visit_limit_) { |
48 | ExprFunctor::VisitExpr(expr); |
49 | } |
50 | visit_counter_[expr.get()]++; |
51 | } |
52 | |
53 | bool MixedModeVisitor::CheckVisited(const Expr& expr) { |
54 | if (visit_counter_[expr.get()] < visit_limit_) { |
55 | return false; |
56 | } else { |
57 | visit_counter_[expr.get()]++; |
58 | return true; |
59 | } |
60 | } |
61 | |
62 | void MixedModeVisitor::VisitExpr(const Expr& expr) { |
63 | auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; |
64 | auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); }; |
65 | if (visit_counter_[expr.get()] < visit_limit_) { |
66 | ExpandDataflow(expr, fcheck_visited, fvisit_leaf); |
67 | } |
68 | } |
69 | |
70 | // Overwrite the VisitExpr so we don't recurse for dataflow nodes |
71 | void MixedModeVisitor::VisitExpr_(const CallNode* op) {} |
72 | |
73 | // Overwrite the VisitExpr so we don't recurse for dataflow nodes |
74 | void MixedModeVisitor::VisitExpr_(const TupleNode* op) {} |
75 | |
76 | // Overwrite the VisitExpr so we don't recurse for dataflow nodes |
77 | void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {} |
78 | |
79 | void MixedModeMutator::VisitLeaf(const Expr& expr) { |
80 | if (!memo_.count(expr)) { |
81 | Expr ret = this->DispatchVisitExpr(expr); |
82 | memo_[expr] = ret; |
83 | } |
84 | } |
85 | |
86 | bool MixedModeMutator::CheckVisited(const Expr& expr) { |
87 | if (memo_.count(expr)) { |
88 | return true; |
89 | } else { |
90 | return false; |
91 | } |
92 | } |
93 | |
94 | Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { return ExprMutator::VisitExpr(expr); } |
95 | |
96 | Expr MixedModeMutator::VisitExpr(const Expr& expr) { |
97 | auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; |
98 | auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); }; |
99 | if (memo_.count(expr)) { |
100 | return memo_[expr]; |
101 | } else { |
102 | ExpandDataflow(expr, fcheck_visited, fvisit_leaf); |
103 | return memo_[expr]; |
104 | } |
105 | } |
106 | |
107 | class PostOrderRewriter : public MixedModeMutator { |
108 | public: |
109 | explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {} |
110 | |
111 | Expr DispatchVisitExpr(const Expr& expr) final { |
112 | auto post = ExprFunctor::VisitExpr(expr); |
113 | return rewriter_->Rewrite(expr, post); |
114 | } |
115 | |
116 | using MixedModeMutator::VisitExpr_; |
117 | |
118 | Expr VisitExpr_(const LetNode* node) final { |
119 | auto pre_visit = [this](const LetNode* op) { |
120 | Expr var = this->Mutate(op->var); |
121 | Expr value = this->Mutate(op->value); |
122 | }; |
123 | auto post_visit = [this, node](const LetNode* op) { |
124 | Var var = Downcast<Var>(this->Mutate(op->var)); |
125 | Expr value = this->Mutate(op->value); |
126 | Expr body = this->Mutate(op->body); |
127 | Expr expr = GetRef<Expr>(op); |
128 | Expr post; |
129 | if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { |
130 | post = expr; |
131 | } else { |
132 | post = Let(var, value, body); |
133 | } |
134 | // avoid rewriting the first LetNode twice |
135 | if (op == node) { |
136 | this->memo_[expr] = post; |
137 | } else { |
138 | this->memo_[expr] = this->rewriter_->Rewrite(expr, post); |
139 | } |
140 | }; |
141 | ExpandANormalForm(node, pre_visit, post_visit); |
142 | return memo_[GetRef<Expr>(node)]; |
143 | } |
144 | |
145 | protected: |
146 | ExprRewriter* rewriter_; |
147 | }; |
148 | |
149 | Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter) { |
150 | return PostOrderRewriter(rewriter).VisitExpr(expr); |
151 | } |
152 | |
153 | Expr ExprMutator::VisitExpr(const Expr& expr) { |
154 | auto it = this->memo_.find(expr); |
155 | if (it != this->memo_.end()) { |
156 | return it->second; |
157 | } else { |
158 | Expr new_expr = ExprFunctor::VisitExpr(expr); |
159 | memo_[expr] = new_expr; |
160 | return new_expr; |
161 | } |
162 | } |
163 | |
164 | Expr ExprMutator::VisitExpr_(const VarNode* var_node) { |
165 | Type type_annotation = var_node->type_annotation; |
166 | if (var_node->type_annotation.defined()) { |
167 | type_annotation = this->VisitType(var_node->type_annotation); |
168 | } |
169 | return WithFields(GetRef<Var>(var_node), var_node->vid, type_annotation); |
170 | } |
171 | |
172 | Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); } |
173 | |
174 | Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); } |
175 | |
176 | Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); } |
177 | |
178 | Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) { |
179 | tvm::Array<Expr> fields; |
180 | fields.reserve(tuple_node->fields.size()); |
181 | |
182 | for (auto field : tuple_node->fields) { |
183 | auto new_field = this->Mutate(field); |
184 | fields.push_back(new_field); |
185 | } |
186 | return WithFields(GetRef<Tuple>(tuple_node), fields); |
187 | } |
188 | |
189 | Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { |
190 | tvm::Array<TypeVar> ty_params; |
191 | |
192 | for (auto ty_param : func_node->type_params) { |
193 | TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param)); |
194 | ty_params.push_back(new_ty_param); |
195 | } |
196 | |
197 | tvm::Array<Var> params; |
198 | for (auto param : func_node->params) { |
199 | Var new_param = Downcast<Var>(this->Mutate(param)); |
200 | params.push_back(new_param); |
201 | } |
202 | |
203 | auto ret_type = this->VisitType(func_node->ret_type); |
204 | auto body = this->Mutate(func_node->body); |
205 | |
206 | return WithFields(GetRef<Function>(func_node), params, body, ret_type, ty_params); |
207 | } |
208 | |
209 | Expr ExprMutator::VisitExpr_(const CallNode* call_node) { |
210 | auto new_op = this->Mutate(call_node->op); |
211 | |
212 | tvm::Array<Type> ty_args; |
213 | ty_args.reserve(call_node->type_args.size()); |
214 | |
215 | for (auto ty_arg : call_node->type_args) { |
216 | auto new_ty_arg = this->VisitType(ty_arg); |
217 | ty_args.push_back(new_ty_arg); |
218 | } |
219 | |
220 | tvm::Array<Expr> call_args; |
221 | call_args.reserve(call_node->args.size()); |
222 | for (auto arg : call_node->args) { |
223 | auto new_arg = this->Mutate(arg); |
224 | call_args.push_back(new_arg); |
225 | } |
226 | |
227 | return WithFields(GetRef<Call>(call_node), new_op, call_args, {}, ty_args); |
228 | } |
229 | |
230 | Expr ExprMutator::VisitExpr_(const LetNode* let_node) { |
231 | Var var = Downcast<Var>(this->Mutate(let_node->var)); |
232 | auto value = this->Mutate(let_node->value); |
233 | auto body = this->Mutate(let_node->body); |
234 | |
235 | return WithFields(GetRef<Let>(let_node), var, value, body); |
236 | } |
237 | |
238 | Expr ExprMutator::VisitExpr_(const IfNode* if_node) { |
239 | auto cond = this->Mutate(if_node->cond); |
240 | auto true_b = this->Mutate(if_node->true_branch); |
241 | auto false_b = this->Mutate(if_node->false_branch); |
242 | |
243 | return WithFields(GetRef<If>(if_node), cond, true_b, false_b); |
244 | } |
245 | |
246 | Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { |
247 | Expr tuple = this->Mutate(get_item->tuple); |
248 | return WithFields(GetRef<TupleGetItem>(get_item), tuple); |
249 | } |
250 | |
251 | Expr ExprMutator::VisitExpr_(const RefCreateNode* ref_create) { |
252 | Expr value = this->Mutate(ref_create->value); |
253 | return WithFields(GetRef<RefCreate>(ref_create), value); |
254 | } |
255 | |
256 | Expr ExprMutator::VisitExpr_(const RefReadNode* ref_read) { |
257 | Expr ref = this->Mutate(ref_read->ref); |
258 | return WithFields(GetRef<RefRead>(ref_read), ref); |
259 | } |
260 | |
261 | Expr ExprMutator::VisitExpr_(const RefWriteNode* ref_write) { |
262 | Expr ref = this->Mutate(ref_write->ref); |
263 | Expr value = this->Mutate(ref_write->value); |
264 | return WithFields(GetRef<RefWrite>(ref_write), ref, value); |
265 | } |
266 | |
267 | Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); } |
268 | |
269 | Expr ExprMutator::VisitExpr_(const MatchNode* match_node) { |
270 | Array<Clause> clauses; |
271 | for (const Clause& p : match_node->clauses) { |
272 | clauses.push_back(VisitClause(p)); |
273 | } |
274 | Expr data = Mutate(match_node->data); |
275 | |
276 | return WithFields(GetRef<Match>(match_node), data, clauses); |
277 | } |
278 | |
279 | Clause ExprMutator::VisitClause(const Clause& clause) { |
280 | Pattern lhs = VisitPattern(clause->lhs); |
281 | Expr rhs = Mutate(clause->rhs); |
282 | return WithFields(clause, lhs, rhs); |
283 | } |
284 | |
285 | Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } |
286 | |
287 | Type ExprMutator::VisitType(const Type& t) { return t; } |
288 | |
289 | void ExprVisitor::VisitExpr(const Expr& expr) { |
290 | auto it = visit_counter_.find(expr.get()); |
291 | if (it != visit_counter_.end()) { |
292 | ++it->second; |
293 | } else { |
294 | using TParent = ExprFunctor<void(const Expr&)>; |
295 | TParent::VisitExpr(expr); |
296 | visit_counter_.insert({expr.get(), 1}); |
297 | } |
298 | } |
299 | |
300 | void ExprVisitor::VisitExpr_(const VarNode* op) { |
301 | this->VisitSpan(op->span); |
302 | if (op->type_annotation.defined()) { |
303 | this->VisitType(op->type_annotation); |
304 | } |
305 | } |
306 | |
307 | void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } |
308 | |
309 | void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); } |
310 | |
311 | void ExprVisitor::VisitExpr_(const TupleNode* op) { |
312 | this->VisitSpan(op->span); |
313 | for (auto field : op->fields) { |
314 | this->VisitExpr(field); |
315 | } |
316 | } |
317 | |
318 | void ExprVisitor::VisitExpr_(const FunctionNode* op) { |
319 | this->VisitSpan(op->span); |
320 | for (auto param : op->params) { |
321 | this->VisitExpr(param); |
322 | } |
323 | |
324 | this->VisitExpr(op->body); |
325 | } |
326 | |
327 | void ExprVisitor::VisitExpr_(const CallNode* op) { |
328 | this->VisitSpan(op->span); |
329 | this->VisitExpr(op->op); |
330 | |
331 | for (auto ty_arg : op->type_args) { |
332 | this->VisitType(ty_arg); |
333 | } |
334 | |
335 | for (auto arg : op->args) { |
336 | this->VisitExpr(arg); |
337 | } |
338 | } |
339 | |
340 | void ExprVisitor::VisitExpr_(const LetNode* op) { |
341 | this->VisitSpan(op->span); |
342 | this->VisitExpr(op->value); |
343 | this->VisitExpr(op->var); |
344 | this->VisitExpr(op->body); |
345 | } |
346 | |
347 | void ExprVisitor::VisitExpr_(const IfNode* op) { |
348 | this->VisitSpan(op->span); |
349 | this->VisitExpr(op->cond); |
350 | this->VisitExpr(op->true_branch); |
351 | this->VisitExpr(op->false_branch); |
352 | } |
353 | |
354 | void ExprVisitor::VisitExpr_(const OpNode* op) { return; } |
355 | |
356 | void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { |
357 | this->VisitSpan(op->span); |
358 | this->VisitExpr(op->tuple); |
359 | } |
360 | |
361 | void ExprVisitor::VisitExpr_(const RefCreateNode* op) { |
362 | this->VisitSpan(op->span); |
363 | this->VisitExpr(op->value); |
364 | } |
365 | |
366 | void ExprVisitor::VisitExpr_(const RefReadNode* op) { |
367 | this->VisitSpan(op->span); |
368 | this->VisitExpr(op->ref); |
369 | } |
370 | |
371 | void ExprVisitor::VisitExpr_(const RefWriteNode* op) { |
372 | this->VisitSpan(op->span); |
373 | this->VisitExpr(op->ref); |
374 | this->VisitExpr(op->value); |
375 | } |
376 | |
377 | void ExprVisitor::VisitExpr_(const ConstructorNode* op) { |
378 | // TODO(@jroesch): visit spans |
379 | for (const Type& t : op->inputs) { |
380 | this->VisitType(t); |
381 | } |
382 | this->VisitType(op->belong_to); |
383 | } |
384 | |
385 | void ExprVisitor::VisitExpr_(const MatchNode* op) { |
386 | this->VisitSpan(op->span); |
387 | this->VisitExpr(op->data); |
388 | for (const Clause& c : op->clauses) { |
389 | this->VisitClause(c); |
390 | } |
391 | } |
392 | |
393 | void ExprVisitor::VisitClause(const Clause& op) { |
394 | // TODO(@jroesch): visit spans |
395 | this->VisitPattern(op->lhs); |
396 | this->VisitExpr(op->rhs); |
397 | } |
398 | |
399 | void ExprVisitor::VisitPattern(const Pattern& p) { return; } |
400 | |
401 | void ExprVisitor::VisitType(const Type& t) { return; } |
402 | |
403 | void ExprVisitor::VisitSpan(const Span& span) { return; } |
404 | |
405 | // visitor to implement apply |
406 | class ExprApplyVisit : public ExprVisitor { |
407 | public: |
408 | explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {} |
409 | |
410 | void VisitExpr(const Expr& e) final { |
411 | if (visited_.count(e.get()) != 0) return; |
412 | visited_.insert(e.get()); |
413 | ExprVisitor::VisitExpr(e); |
414 | f_(e); |
415 | } |
416 | |
417 | private: |
418 | std::function<void(const Expr&)> f_; |
419 | std::unordered_set<const Object*> visited_; |
420 | }; |
421 | |
422 | void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) { |
423 | ExprApplyVisit(fvisit).VisitExpr(e); |
424 | } |
425 | |
426 | TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit" ).set_body_typed([](Expr expr, PackedFunc f) { |
427 | PostOrderVisit(expr, [f](const Expr& n) { f(n); }); |
428 | }); |
429 | |
430 | // Implement bind. |
431 | class ExprBinder : public MixedModeMutator, PatternMutator { |
432 | public: |
433 | explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {} |
434 | |
435 | using MixedModeMutator::VisitExpr_; |
436 | |
437 | Expr VisitExpr_(const LetNode* op) final { |
438 | ICHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let" ; |
439 | return ExprMutator::VisitExpr_(op); |
440 | } |
441 | |
442 | Expr VisitExpr_(const FunctionNode* op) final { |
443 | for (Var param : op->params) { |
444 | ICHECK(!args_map_.count(param)) << "Cannnot bind an internal function parameter" ; |
445 | } |
446 | return ExprMutator::VisitExpr_(op); |
447 | } |
448 | |
449 | Expr VisitExpr_(const VarNode* op) final { |
450 | auto id = GetRef<Var>(op); |
451 | auto it = args_map_.find(id); |
452 | if (it != args_map_.end()) { |
453 | return (*it).second; |
454 | } else { |
455 | return std::move(id); |
456 | } |
457 | } |
458 | |
459 | Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } |
460 | |
461 | Clause VisitClause(const Clause& clause) final { |
462 | Pattern lhs = VisitPattern(clause->lhs); |
463 | return WithFields(clause, lhs, VisitExpr(clause->rhs)); |
464 | } |
465 | |
466 | Var VisitVar(const Var& v) final { |
467 | ICHECK(!args_map_.count(v)) << "Cannnot bind an internal pattern variable" ; |
468 | return v; |
469 | } |
470 | |
471 | private: |
472 | const tvm::Map<Var, Expr>& args_map_; |
473 | }; |
474 | |
475 | // This function should be called SubstAndBind, since it assumes any variables introduced |
476 | // in the substitution right hand side should be implicitly bound in the function. |
477 | Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { |
478 | if (const FunctionNode* func = expr.as<FunctionNode>()) { |
479 | Expr new_body = ExprBinder(args_map).VisitExpr(func->body); |
480 | Array<Var> new_params; |
481 | for (size_t i = 0; i < func->params.size(); ++i) { |
482 | if (!args_map.count(func->params[i])) { |
483 | new_params.push_back(func->params[i]); |
484 | } |
485 | } |
486 | if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { |
487 | return expr; |
488 | } |
489 | |
490 | auto ret = |
491 | Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); |
492 | ret->virtual_device_ = func->virtual_device(); |
493 | |
494 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set; |
495 | for (const auto& v : FreeVars(expr)) { |
496 | set.insert(v); |
497 | } |
498 | for (const auto& v : FreeVars(ret)) { |
499 | if (set.count(v) == 0) { |
500 | new_params.push_back(v); |
501 | } |
502 | } |
503 | |
504 | ret = |
505 | Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); |
506 | ret->virtual_device_ = func->virtual_device(); |
507 | |
508 | VLOG(4) << "Expr:\n" << expr; |
509 | VLOG(4) << "Ret:\n" << ret; |
510 | |
511 | ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); |
512 | return std::move(ret); |
513 | } else { |
514 | return ExprBinder(args_map).VisitExpr(expr); |
515 | } |
516 | } |
517 | |
518 | TVM_REGISTER_GLOBAL("relay.ir.Bind" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
519 | ObjectRef input = args[0]; |
520 | if (input->IsInstance<ExprNode>()) { |
521 | *ret = Bind(Downcast<Expr>(input), args[1]); |
522 | } else { |
523 | ICHECK(input->IsInstance<TypeNode>()); |
524 | *ret = Bind(Downcast<Type>(input), args[1]); |
525 | } |
526 | }); |
527 | |
528 | Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& args_map) { |
529 | Expr new_body = ExprBinder(args_map).VisitExpr(func->body); |
530 | Array<Var> new_params; |
531 | for (size_t i = 0; i < func->params.size(); i++) { |
532 | if (!args_map.count(func->params[i])) { |
533 | new_params.push_back(func->params[i]); |
534 | } else { |
535 | if (const VarNode* var = args_map[func->params[i]].as<VarNode>()) { |
536 | new_params.push_back(GetRef<Var>(var)); |
537 | } else { |
538 | ICHECK(false) << "Expected all values in args_map to be vars, but found " |
539 | << args_map[func->params[i]]->GetTypeKey(); |
540 | } |
541 | } |
542 | } |
543 | auto ret = |
544 | Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); |
545 | ret->virtual_device_ = func->virtual_device(); |
546 | return ret; |
547 | } |
548 | |
549 | void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit, |
550 | std::function<void(const LetNode*)> post_visit) { |
551 | std::stack<const LetNode*> stack; |
552 | stack.push(op); |
553 | bool is_anormal = true; |
554 | while (is_anormal) { |
555 | const LetNode* current_op = stack.top(); |
556 | pre_visit(current_op); |
557 | if (const LetNode* new_op = current_op->body.as<LetNode>()) { |
558 | stack.push(new_op); |
559 | } else { |
560 | is_anormal = false; |
561 | } |
562 | } |
563 | while (stack.size()) { |
564 | const LetNode* current_op = stack.top(); |
565 | stack.pop(); |
566 | post_visit(current_op); |
567 | } |
568 | } |
569 | |
570 | } // namespace relay |
571 | } // namespace tvm |
572 | |