1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file well_formed.cc |
22 | * \brief check that expression is well formed. |
23 | */ |
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/logging.h>

#include <memory>
#include <unordered_set>
#include <vector>
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
34 | //! brief make sure each Var is bound at most once in a scope. |
35 | class WellFormedChecker : private MixedModeVisitor, PatternVisitor { |
36 | public: |
37 | Optional<DiagnosticContext> diag_ctx; |
38 | Span occurs_in; |
39 | |
40 | explicit WellFormedChecker(const Optional<DiagnosticContext>& ctx) : diag_ctx(ctx) {} |
41 | |
42 | bool well_formed = true; |
43 | |
44 | void Illformed(Diagnostic diag) { |
45 | well_formed = false; |
46 | if (diag_ctx) { |
47 | diag_ctx.value().Emit(diag); |
48 | } else { |
49 | LOG(INFO) << "The IR is not well formed with: " << diag->message; |
50 | } |
51 | } |
52 | |
53 | std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> scope; |
54 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> current_bound; |
55 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> total_bound; |
56 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> free; |
57 | |
58 | struct Scope { |
59 | WellFormedChecker* wfc; |
60 | explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { wfc->scope.push_back({{}}); } |
61 | ~Scope() { |
62 | ICHECK_GE(wfc->scope.size(), 0); |
63 | for (const Var& v : wfc->scope.back()) { |
64 | ICHECK_GE(wfc->current_bound.count(v), 0); |
65 | wfc->current_bound.erase(v); |
66 | } |
67 | wfc->scope.pop_back(); |
68 | } |
69 | }; |
70 | |
71 | void Bound(const Var& v) { |
72 | if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) { |
73 | Illformed(Diagnostic::Error(v->span) << "The variable " << v->name_hint() |
74 | << " is bound more than once, this is not valid IR" ); |
75 | } |
76 | ICHECK_GE(scope.size(), 0); |
77 | scope.back().insert(v); |
78 | current_bound.insert(v); |
79 | total_bound.insert(v); |
80 | } |
81 | |
82 | using MixedModeVisitor::VisitExpr_; |
83 | |
84 | void VisitExpr_(const VarNode* op) final { |
85 | Var v = GetRef<Var>(op); |
86 | if (current_bound.count(v) == 0) { |
87 | if (total_bound.count(v) != 0) { |
88 | Illformed(Diagnostic::Error(v->span) << "the variable " << v->name_hint() |
89 | << "is bound more then once, this is not valid IR" ); |
90 | } else { |
91 | free.insert(v); |
92 | } |
93 | } |
94 | } |
95 | |
96 | void VisitExpr_(const LetNode* l) final { |
97 | std::vector<Scope*> scopes; |
98 | Expr let = GetRef<Let>(l); |
99 | while (auto let_node = let.as<LetNode>()) { |
100 | scopes.push_back(new Scope(this)); |
101 | // we do letrec only for FunctionNode, |
102 | // but shadowing let in let binding is likely programming error, and we should forbidden it. |
103 | Bound(let_node->var); |
104 | CheckWellFormed(let_node->value); |
105 | let = let_node->body; |
106 | } |
107 | CheckWellFormed(let); |
108 | while (!scopes.empty()) { |
109 | delete scopes.back(); |
110 | scopes.pop_back(); |
111 | } |
112 | } |
113 | |
114 | void VisitExpr_(const FunctionNode* f) final { |
115 | Scope s(this); |
116 | for (const Var& param : f->params) { |
117 | Bound(param); |
118 | } |
119 | CheckWellFormed(f->body); |
120 | } |
121 | |
122 | void VisitExpr_(const CallNode* call) final { |
123 | ICHECK(call->op.defined()); |
124 | |
125 | for (auto arg : call->args) { |
126 | ICHECK(arg.defined()); |
127 | } |
128 | |
129 | // ICHECK(call->attrs.defined()); |
130 | ICHECK(call->type_args.defined()); |
131 | MixedModeVisitor::VisitExpr_(call); |
132 | } |
133 | |
134 | void VisitClause(const Clause& c) final { |
135 | Scope s(this); |
136 | VisitPattern(c->lhs); |
137 | VisitExpr(c->rhs); |
138 | } |
139 | |
140 | void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } |
141 | |
142 | void VisitVar(const Var& v) final { Bound(v); } |
143 | |
144 | public: |
145 | bool CheckWellFormed(const Expr& e) { |
146 | if (auto v = e.as<VarNode>()) { |
147 | VisitExpr_(v); |
148 | } else { |
149 | // this->occurs_in = e->span; |
150 | VisitExpr(e); |
151 | } |
152 | return well_formed; |
153 | } |
154 | }; |
155 | |
156 | bool WellFormed(const Expr& e, Optional<DiagnosticContext> diag_ctx) { |
157 | return WellFormedChecker(diag_ctx).CheckWellFormed(e); |
158 | } |
159 | |
160 | TVM_REGISTER_GLOBAL("relay.analysis.well_formed" ).set_body_typed([](Expr e) { |
161 | return WellFormed(e); |
162 | }); |
163 | |
164 | } // namespace relay |
165 | } // namespace tvm |
166 | |