1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/expr_functor.cc
22 * \brief A wrapper around ExprFunctor which functionally updates the AST.
23 *
24 * ExprMutator uses memoization and self return in order to amortize
25 * the cost of using functional updates.
26 */
27#include <tvm/ir/type_functor.h>
28#include <tvm/relay/adt.h>
29#include <tvm/relay/analysis.h>
30#include <tvm/relay/expr_functor.h>
31#include <tvm/relay/pattern_functor.h>
32
33#include <stack>
34
35#include "../op/annotation/annotation.h"
36#include "../op/memory/on_device.h"
37
38namespace tvm {
39namespace relay {
/*!
 * \brief Construct a MixedModeVisitor with a per-node visit limit.
 * \param visit_limit Value stored in visit_limit_. Both bounds are enforced
 *        with ICHECK: it must be greater than 0 and less than 10.
 */
MixedModeVisitor::MixedModeVisitor(int visit_limit) {
  ICHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
  ICHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
  visit_limit_ = visit_limit;
}
45
46void MixedModeVisitor::VisitLeaf(const Expr& expr) {
47 if (visit_counter_[expr.get()] < visit_limit_) {
48 ExprFunctor::VisitExpr(expr);
49 }
50 visit_counter_[expr.get()]++;
51}
52
53bool MixedModeVisitor::CheckVisited(const Expr& expr) {
54 if (visit_counter_[expr.get()] < visit_limit_) {
55 return false;
56 } else {
57 visit_counter_[expr.get()]++;
58 return true;
59 }
60}
61
62void MixedModeVisitor::VisitExpr(const Expr& expr) {
63 auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
64 auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
65 if (visit_counter_[expr.get()] < visit_limit_) {
66 ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
67 }
68}
69
// Override VisitExpr_ for CallNode with an intentionally empty body so we
// don't recurse for dataflow nodes.
void MixedModeVisitor::VisitExpr_(const CallNode* op) {}
72
// Override VisitExpr_ for TupleNode with an intentionally empty body so we
// don't recurse for dataflow nodes.
void MixedModeVisitor::VisitExpr_(const TupleNode* op) {}
75
// Override VisitExpr_ for TupleGetItemNode with an intentionally empty body so
// we don't recurse for dataflow nodes.
void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {}
78
79void MixedModeMutator::VisitLeaf(const Expr& expr) {
80 if (!memo_.count(expr)) {
81 Expr ret = this->DispatchVisitExpr(expr);
82 memo_[expr] = ret;
83 }
84}
85
86bool MixedModeMutator::CheckVisited(const Expr& expr) {
87 if (memo_.count(expr)) {
88 return true;
89 } else {
90 return false;
91 }
92}
93
94Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { return ExprMutator::VisitExpr(expr); }
95
96Expr MixedModeMutator::VisitExpr(const Expr& expr) {
97 auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
98 auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
99 if (memo_.count(expr)) {
100 return memo_[expr];
101 } else {
102 ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
103 return memo_[expr];
104 }
105}
106
107class PostOrderRewriter : public MixedModeMutator {
108 public:
109 explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
110
111 Expr DispatchVisitExpr(const Expr& expr) final {
112 auto post = ExprFunctor::VisitExpr(expr);
113 return rewriter_->Rewrite(expr, post);
114 }
115
116 using MixedModeMutator::VisitExpr_;
117
118 Expr VisitExpr_(const LetNode* node) final {
119 auto pre_visit = [this](const LetNode* op) {
120 Expr var = this->Mutate(op->var);
121 Expr value = this->Mutate(op->value);
122 };
123 auto post_visit = [this, node](const LetNode* op) {
124 Var var = Downcast<Var>(this->Mutate(op->var));
125 Expr value = this->Mutate(op->value);
126 Expr body = this->Mutate(op->body);
127 Expr expr = GetRef<Expr>(op);
128 Expr post;
129 if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
130 post = expr;
131 } else {
132 post = Let(var, value, body);
133 }
134 // avoid rewriting the first LetNode twice
135 if (op == node) {
136 this->memo_[expr] = post;
137 } else {
138 this->memo_[expr] = this->rewriter_->Rewrite(expr, post);
139 }
140 };
141 ExpandANormalForm(node, pre_visit, post_visit);
142 return memo_[GetRef<Expr>(node)];
143 }
144
145 protected:
146 ExprRewriter* rewriter_;
147};
148
149Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter) {
150 return PostOrderRewriter(rewriter).VisitExpr(expr);
151}
152
153Expr ExprMutator::VisitExpr(const Expr& expr) {
154 auto it = this->memo_.find(expr);
155 if (it != this->memo_.end()) {
156 return it->second;
157 } else {
158 Expr new_expr = ExprFunctor::VisitExpr(expr);
159 memo_[expr] = new_expr;
160 return new_expr;
161 }
162}
163
164Expr ExprMutator::VisitExpr_(const VarNode* var_node) {
165 Type type_annotation = var_node->type_annotation;
166 if (var_node->type_annotation.defined()) {
167 type_annotation = this->VisitType(var_node->type_annotation);
168 }
169 return WithFields(GetRef<Var>(var_node), var_node->vid, type_annotation);
170}
171
172Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }
173
174Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); }
175
176Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }
177
178Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
179 tvm::Array<Expr> fields;
180 fields.reserve(tuple_node->fields.size());
181
182 for (auto field : tuple_node->fields) {
183 auto new_field = this->Mutate(field);
184 fields.push_back(new_field);
185 }
186 return WithFields(GetRef<Tuple>(tuple_node), fields);
187}
188
189Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) {
190 tvm::Array<TypeVar> ty_params;
191
192 for (auto ty_param : func_node->type_params) {
193 TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
194 ty_params.push_back(new_ty_param);
195 }
196
197 tvm::Array<Var> params;
198 for (auto param : func_node->params) {
199 Var new_param = Downcast<Var>(this->Mutate(param));
200 params.push_back(new_param);
201 }
202
203 auto ret_type = this->VisitType(func_node->ret_type);
204 auto body = this->Mutate(func_node->body);
205
206 return WithFields(GetRef<Function>(func_node), params, body, ret_type, ty_params);
207}
208
209Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
210 auto new_op = this->Mutate(call_node->op);
211
212 tvm::Array<Type> ty_args;
213 ty_args.reserve(call_node->type_args.size());
214
215 for (auto ty_arg : call_node->type_args) {
216 auto new_ty_arg = this->VisitType(ty_arg);
217 ty_args.push_back(new_ty_arg);
218 }
219
220 tvm::Array<Expr> call_args;
221 call_args.reserve(call_node->args.size());
222 for (auto arg : call_node->args) {
223 auto new_arg = this->Mutate(arg);
224 call_args.push_back(new_arg);
225 }
226
227 return WithFields(GetRef<Call>(call_node), new_op, call_args, {}, ty_args);
228}
229
230Expr ExprMutator::VisitExpr_(const LetNode* let_node) {
231 Var var = Downcast<Var>(this->Mutate(let_node->var));
232 auto value = this->Mutate(let_node->value);
233 auto body = this->Mutate(let_node->body);
234
235 return WithFields(GetRef<Let>(let_node), var, value, body);
236}
237
238Expr ExprMutator::VisitExpr_(const IfNode* if_node) {
239 auto cond = this->Mutate(if_node->cond);
240 auto true_b = this->Mutate(if_node->true_branch);
241 auto false_b = this->Mutate(if_node->false_branch);
242
243 return WithFields(GetRef<If>(if_node), cond, true_b, false_b);
244}
245
246Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) {
247 Expr tuple = this->Mutate(get_item->tuple);
248 return WithFields(GetRef<TupleGetItem>(get_item), tuple);
249}
250
251Expr ExprMutator::VisitExpr_(const RefCreateNode* ref_create) {
252 Expr value = this->Mutate(ref_create->value);
253 return WithFields(GetRef<RefCreate>(ref_create), value);
254}
255
256Expr ExprMutator::VisitExpr_(const RefReadNode* ref_read) {
257 Expr ref = this->Mutate(ref_read->ref);
258 return WithFields(GetRef<RefRead>(ref_read), ref);
259}
260
261Expr ExprMutator::VisitExpr_(const RefWriteNode* ref_write) {
262 Expr ref = this->Mutate(ref_write->ref);
263 Expr value = this->Mutate(ref_write->value);
264 return WithFields(GetRef<RefWrite>(ref_write), ref, value);
265}
266
267Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }
268
269Expr ExprMutator::VisitExpr_(const MatchNode* match_node) {
270 Array<Clause> clauses;
271 for (const Clause& p : match_node->clauses) {
272 clauses.push_back(VisitClause(p));
273 }
274 Expr data = Mutate(match_node->data);
275
276 return WithFields(GetRef<Match>(match_node), data, clauses);
277}
278
279Clause ExprMutator::VisitClause(const Clause& clause) {
280 Pattern lhs = VisitPattern(clause->lhs);
281 Expr rhs = Mutate(clause->rhs);
282 return WithFields(clause, lhs, rhs);
283}
284
285Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
286
287Type ExprMutator::VisitType(const Type& t) { return t; }
288
289void ExprVisitor::VisitExpr(const Expr& expr) {
290 auto it = visit_counter_.find(expr.get());
291 if (it != visit_counter_.end()) {
292 ++it->second;
293 } else {
294 using TParent = ExprFunctor<void(const Expr&)>;
295 TParent::VisitExpr(expr);
296 visit_counter_.insert({expr.get(), 1});
297 }
298}
299
300void ExprVisitor::VisitExpr_(const VarNode* op) {
301 this->VisitSpan(op->span);
302 if (op->type_annotation.defined()) {
303 this->VisitType(op->type_annotation);
304 }
305}
306
307void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); }
308
309void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); }
310
311void ExprVisitor::VisitExpr_(const TupleNode* op) {
312 this->VisitSpan(op->span);
313 for (auto field : op->fields) {
314 this->VisitExpr(field);
315 }
316}
317
318void ExprVisitor::VisitExpr_(const FunctionNode* op) {
319 this->VisitSpan(op->span);
320 for (auto param : op->params) {
321 this->VisitExpr(param);
322 }
323
324 this->VisitExpr(op->body);
325}
326
327void ExprVisitor::VisitExpr_(const CallNode* op) {
328 this->VisitSpan(op->span);
329 this->VisitExpr(op->op);
330
331 for (auto ty_arg : op->type_args) {
332 this->VisitType(ty_arg);
333 }
334
335 for (auto arg : op->args) {
336 this->VisitExpr(arg);
337 }
338}
339
340void ExprVisitor::VisitExpr_(const LetNode* op) {
341 this->VisitSpan(op->span);
342 this->VisitExpr(op->value);
343 this->VisitExpr(op->var);
344 this->VisitExpr(op->body);
345}
346
347void ExprVisitor::VisitExpr_(const IfNode* op) {
348 this->VisitSpan(op->span);
349 this->VisitExpr(op->cond);
350 this->VisitExpr(op->true_branch);
351 this->VisitExpr(op->false_branch);
352}
353
354void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
355
356void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
357 this->VisitSpan(op->span);
358 this->VisitExpr(op->tuple);
359}
360
361void ExprVisitor::VisitExpr_(const RefCreateNode* op) {
362 this->VisitSpan(op->span);
363 this->VisitExpr(op->value);
364}
365
366void ExprVisitor::VisitExpr_(const RefReadNode* op) {
367 this->VisitSpan(op->span);
368 this->VisitExpr(op->ref);
369}
370
371void ExprVisitor::VisitExpr_(const RefWriteNode* op) {
372 this->VisitSpan(op->span);
373 this->VisitExpr(op->ref);
374 this->VisitExpr(op->value);
375}
376
377void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
378 // TODO(@jroesch): visit spans
379 for (const Type& t : op->inputs) {
380 this->VisitType(t);
381 }
382 this->VisitType(op->belong_to);
383}
384
385void ExprVisitor::VisitExpr_(const MatchNode* op) {
386 this->VisitSpan(op->span);
387 this->VisitExpr(op->data);
388 for (const Clause& c : op->clauses) {
389 this->VisitClause(c);
390 }
391}
392
393void ExprVisitor::VisitClause(const Clause& op) {
394 // TODO(@jroesch): visit spans
395 this->VisitPattern(op->lhs);
396 this->VisitExpr(op->rhs);
397}
398
399void ExprVisitor::VisitPattern(const Pattern& p) { return; }
400
401void ExprVisitor::VisitType(const Type& t) { return; }
402
403void ExprVisitor::VisitSpan(const Span& span) { return; }
404
405// visitor to implement apply
406class ExprApplyVisit : public ExprVisitor {
407 public:
408 explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
409
410 void VisitExpr(const Expr& e) final {
411 if (visited_.count(e.get()) != 0) return;
412 visited_.insert(e.get());
413 ExprVisitor::VisitExpr(e);
414 f_(e);
415 }
416
417 private:
418 std::function<void(const Expr&)> f_;
419 std::unordered_set<const Object*> visited_;
420};
421
422void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
423 ExprApplyVisit(fvisit).VisitExpr(e);
424}
425
426TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) {
427 PostOrderVisit(expr, [f](const Expr& n) { f(n); });
428});
429
430// Implement bind.
431class ExprBinder : public MixedModeMutator, PatternMutator {
432 public:
433 explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}
434
435 using MixedModeMutator::VisitExpr_;
436
437 Expr VisitExpr_(const LetNode* op) final {
438 ICHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let";
439 return ExprMutator::VisitExpr_(op);
440 }
441
442 Expr VisitExpr_(const FunctionNode* op) final {
443 for (Var param : op->params) {
444 ICHECK(!args_map_.count(param)) << "Cannnot bind an internal function parameter";
445 }
446 return ExprMutator::VisitExpr_(op);
447 }
448
449 Expr VisitExpr_(const VarNode* op) final {
450 auto id = GetRef<Var>(op);
451 auto it = args_map_.find(id);
452 if (it != args_map_.end()) {
453 return (*it).second;
454 } else {
455 return std::move(id);
456 }
457 }
458
459 Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
460
461 Clause VisitClause(const Clause& clause) final {
462 Pattern lhs = VisitPattern(clause->lhs);
463 return WithFields(clause, lhs, VisitExpr(clause->rhs));
464 }
465
466 Var VisitVar(const Var& v) final {
467 ICHECK(!args_map_.count(v)) << "Cannnot bind an internal pattern variable";
468 return v;
469 }
470
471 private:
472 const tvm::Map<Var, Expr>& args_map_;
473};
474
475// This function should be called SubstAndBind, since it assumes any variables introduced
476// in the substitution right hand side should be implicitly bound in the function.
477Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
478 if (const FunctionNode* func = expr.as<FunctionNode>()) {
479 Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
480 Array<Var> new_params;
481 for (size_t i = 0; i < func->params.size(); ++i) {
482 if (!args_map.count(func->params[i])) {
483 new_params.push_back(func->params[i]);
484 }
485 }
486 if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
487 return expr;
488 }
489
490 auto ret =
491 Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
492 ret->virtual_device_ = func->virtual_device();
493
494 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;
495 for (const auto& v : FreeVars(expr)) {
496 set.insert(v);
497 }
498 for (const auto& v : FreeVars(ret)) {
499 if (set.count(v) == 0) {
500 new_params.push_back(v);
501 }
502 }
503
504 ret =
505 Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
506 ret->virtual_device_ = func->virtual_device();
507
508 VLOG(4) << "Expr:\n" << expr;
509 VLOG(4) << "Ret:\n" << ret;
510
511 ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
512 return std::move(ret);
513 } else {
514 return ExprBinder(args_map).VisitExpr(expr);
515 }
516}
517
518TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) {
519 ObjectRef input = args[0];
520 if (input->IsInstance<ExprNode>()) {
521 *ret = Bind(Downcast<Expr>(input), args[1]);
522 } else {
523 ICHECK(input->IsInstance<TypeNode>());
524 *ret = Bind(Downcast<Type>(input), args[1]);
525 }
526});
527
528Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& args_map) {
529 Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
530 Array<Var> new_params;
531 for (size_t i = 0; i < func->params.size(); i++) {
532 if (!args_map.count(func->params[i])) {
533 new_params.push_back(func->params[i]);
534 } else {
535 if (const VarNode* var = args_map[func->params[i]].as<VarNode>()) {
536 new_params.push_back(GetRef<Var>(var));
537 } else {
538 ICHECK(false) << "Expected all values in args_map to be vars, but found "
539 << args_map[func->params[i]]->GetTypeKey();
540 }
541 }
542 }
543 auto ret =
544 Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
545 ret->virtual_device_ = func->virtual_device();
546 return ret;
547}
548
549void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
550 std::function<void(const LetNode*)> post_visit) {
551 std::stack<const LetNode*> stack;
552 stack.push(op);
553 bool is_anormal = true;
554 while (is_anormal) {
555 const LetNode* current_op = stack.top();
556 pre_visit(current_op);
557 if (const LetNode* new_op = current_op->body.as<LetNode>()) {
558 stack.push(new_op);
559 } else {
560 is_anormal = false;
561 }
562 }
563 while (stack.size()) {
564 const LetNode* current_op = stack.top();
565 stack.pop();
566 post_visit(current_op);
567 }
568}
569
570} // namespace relay
571} // namespace tvm
572