1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file de_duplicate.cc |
23 | * \brief Use a fresh Id for every Var to make the result well-formed. |
24 | */ |
25 | #include <tvm/ir/type_functor.h> |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/pattern_functor.h> |
29 | |
30 | #include <stack> |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
35 | Expr DeDup(const Expr& e) { |
36 | class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator { |
37 | public: |
38 | TypeVar Fresh(const TypeVar& tv) { |
39 | TypeVar ret = TypeVar(tv->name_hint, tv->kind); |
40 | type_rename_[tv] = ret; |
41 | return ret; |
42 | } |
43 | |
44 | Var Fresh(const Var& v) { |
45 | ICHECK_EQ(rename_.count(v), 0); |
46 | ICHECK_EQ(memo_.count(v), 0) << v.as<VarNode>(); |
47 | Var ret = Var(v->name_hint(), VisitType(v->type_annotation)); |
48 | rename_[v] = ret; |
49 | return ret; |
50 | } |
51 | |
52 | Expr DispatchVisitExpr(const Expr& e) final { |
53 | auto ret = ExprMutator::VisitExpr(e); |
54 | ret->checked_type_ = e->checked_type_; |
55 | ret->virtual_device_ = e->virtual_device_; |
56 | return ret; |
57 | } |
58 | |
59 | using MixedModeMutator::VisitExpr_; |
60 | |
61 | Expr VisitExpr_(const VarNode* op) final { |
62 | Var v = GetRef<Var>(op); |
63 | return rename_.count(v) != 0 ? rename_.at(v) : v; |
64 | } |
65 | |
66 | Expr VisitExpr_(const LetNode* op) final { |
67 | std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> new_vars; |
68 | auto pre_visit = [this, &new_vars](const LetNode* op) { |
69 | Expr expr = GetRef<Expr>(op); |
70 | new_vars[expr] = this->Fresh(op->var); |
71 | // Rely on the Memoizer to cache pre-visit values |
72 | this->VisitExpr(op->value); |
73 | }; |
74 | auto post_visit = [this, &new_vars](const LetNode* op) { |
75 | Expr expr = GetRef<Expr>(op); |
76 | this->memo_[expr] = |
77 | Let(new_vars[expr], this->VisitExpr(op->value), this->VisitExpr(op->body)); |
78 | }; |
79 | ExpandANormalForm(op, pre_visit, post_visit); |
80 | return memo_[GetRef<Expr>(op)]; |
81 | } |
82 | |
83 | Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } |
84 | |
85 | Expr VisitExpr_(const FunctionNode* func_node) final { |
86 | tvm::Array<TypeVar> type_params; |
87 | for (const TypeVar& type_param : func_node->type_params) { |
88 | type_params.push_back(Fresh(type_param)); |
89 | } |
90 | tvm::Array<Var> params; |
91 | for (const Var& param : func_node->params) { |
92 | params.push_back(Fresh(param)); |
93 | } |
94 | return WithFields(GetRef<Function>(func_node), params, VisitExpr(func_node->body), |
95 | VisitType(func_node->ret_type), type_params); |
96 | } |
97 | |
98 | Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } |
99 | |
100 | Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(Fresh(op->var)); } |
101 | |
102 | Type VisitType_(const TypeVarNode* op) final { |
103 | TypeVar v = GetRef<TypeVar>(op); |
104 | return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; |
105 | } |
106 | |
107 | Var VisitVar(const Var& v) final { return Fresh(v); } |
108 | |
109 | private: |
110 | std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> rename_; |
111 | std::unordered_map<TypeVar, TypeVar, ObjectPtrHash, ObjectPtrEqual> type_rename_; |
112 | }; |
113 | ICHECK(WellFormed(e)) << AsText(e, false); |
114 | Expr ret = DeDupMutator().VisitExpr(e); |
115 | ICHECK(WellFormed(ret)); |
116 | ICHECK_EQ(FreeVars(e).size(), FreeVars(ret).size()); |
117 | return ret; |
118 | } // namespace relay |
119 | |
120 | TVM_REGISTER_GLOBAL("relay._transform.dedup" ).set_body_typed(DeDup); |
121 | |
122 | } // namespace relay |
123 | } // namespace tvm |
124 | |