1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file de_duplicate.cc
23 * \brief Use a fresh Id for every Var to make the result well-formed.
24 */
25#include <tvm/ir/type_functor.h>
26#include <tvm/relay/analysis.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/pattern_functor.h>
29
30#include <stack>
31
32namespace tvm {
33namespace relay {
34
35Expr DeDup(const Expr& e) {
36 class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator {
37 public:
38 TypeVar Fresh(const TypeVar& tv) {
39 TypeVar ret = TypeVar(tv->name_hint, tv->kind);
40 type_rename_[tv] = ret;
41 return ret;
42 }
43
44 Var Fresh(const Var& v) {
45 ICHECK_EQ(rename_.count(v), 0);
46 ICHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
47 Var ret = Var(v->name_hint(), VisitType(v->type_annotation));
48 rename_[v] = ret;
49 return ret;
50 }
51
52 Expr DispatchVisitExpr(const Expr& e) final {
53 auto ret = ExprMutator::VisitExpr(e);
54 ret->checked_type_ = e->checked_type_;
55 ret->virtual_device_ = e->virtual_device_;
56 return ret;
57 }
58
59 using MixedModeMutator::VisitExpr_;
60
61 Expr VisitExpr_(const VarNode* op) final {
62 Var v = GetRef<Var>(op);
63 return rename_.count(v) != 0 ? rename_.at(v) : v;
64 }
65
66 Expr VisitExpr_(const LetNode* op) final {
67 std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> new_vars;
68 auto pre_visit = [this, &new_vars](const LetNode* op) {
69 Expr expr = GetRef<Expr>(op);
70 new_vars[expr] = this->Fresh(op->var);
71 // Rely on the Memoizer to cache pre-visit values
72 this->VisitExpr(op->value);
73 };
74 auto post_visit = [this, &new_vars](const LetNode* op) {
75 Expr expr = GetRef<Expr>(op);
76 this->memo_[expr] =
77 Let(new_vars[expr], this->VisitExpr(op->value), this->VisitExpr(op->body));
78 };
79 ExpandANormalForm(op, pre_visit, post_visit);
80 return memo_[GetRef<Expr>(op)];
81 }
82
83 Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; }
84
85 Expr VisitExpr_(const FunctionNode* func_node) final {
86 tvm::Array<TypeVar> type_params;
87 for (const TypeVar& type_param : func_node->type_params) {
88 type_params.push_back(Fresh(type_param));
89 }
90 tvm::Array<Var> params;
91 for (const Var& param : func_node->params) {
92 params.push_back(Fresh(param));
93 }
94 return WithFields(GetRef<Function>(func_node), params, VisitExpr(func_node->body),
95 VisitType(func_node->ret_type), type_params);
96 }
97
98 Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); }
99
100 Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(Fresh(op->var)); }
101
102 Type VisitType_(const TypeVarNode* op) final {
103 TypeVar v = GetRef<TypeVar>(op);
104 return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
105 }
106
107 Var VisitVar(const Var& v) final { return Fresh(v); }
108
109 private:
110 std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> rename_;
111 std::unordered_map<TypeVar, TypeVar, ObjectPtrHash, ObjectPtrEqual> type_rename_;
112 };
113 ICHECK(WellFormed(e)) << AsText(e, false);
114 Expr ret = DeDupMutator().VisitExpr(e);
115 ICHECK(WellFormed(ret));
116 ICHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
117 return ret;
118} // namespace relay
119
120TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup);
121
122} // namespace relay
123} // namespace tvm
124