1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file well_formed.cc
 * \brief Check that a Relay expression is well formed.
23 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/logging.h>

#include <memory>
#include <unordered_set>
#include <vector>
30
31namespace tvm {
32namespace relay {
33
34//! brief make sure each Var is bound at most once in a scope.
35class WellFormedChecker : private MixedModeVisitor, PatternVisitor {
36 public:
37 Optional<DiagnosticContext> diag_ctx;
38 Span occurs_in;
39
40 explicit WellFormedChecker(const Optional<DiagnosticContext>& ctx) : diag_ctx(ctx) {}
41
42 bool well_formed = true;
43
44 void Illformed(Diagnostic diag) {
45 well_formed = false;
46 if (diag_ctx) {
47 diag_ctx.value().Emit(diag);
48 } else {
49 LOG(INFO) << "The IR is not well formed with: " << diag->message;
50 }
51 }
52
53 std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> scope;
54 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> current_bound;
55 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> total_bound;
56 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> free;
57
58 struct Scope {
59 WellFormedChecker* wfc;
60 explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { wfc->scope.push_back({{}}); }
61 ~Scope() {
62 ICHECK_GE(wfc->scope.size(), 0);
63 for (const Var& v : wfc->scope.back()) {
64 ICHECK_GE(wfc->current_bound.count(v), 0);
65 wfc->current_bound.erase(v);
66 }
67 wfc->scope.pop_back();
68 }
69 };
70
71 void Bound(const Var& v) {
72 if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) {
73 Illformed(Diagnostic::Error(v->span) << "The variable " << v->name_hint()
74 << " is bound more than once, this is not valid IR");
75 }
76 ICHECK_GE(scope.size(), 0);
77 scope.back().insert(v);
78 current_bound.insert(v);
79 total_bound.insert(v);
80 }
81
82 using MixedModeVisitor::VisitExpr_;
83
84 void VisitExpr_(const VarNode* op) final {
85 Var v = GetRef<Var>(op);
86 if (current_bound.count(v) == 0) {
87 if (total_bound.count(v) != 0) {
88 Illformed(Diagnostic::Error(v->span) << "the variable " << v->name_hint()
89 << "is bound more then once, this is not valid IR");
90 } else {
91 free.insert(v);
92 }
93 }
94 }
95
96 void VisitExpr_(const LetNode* l) final {
97 std::vector<Scope*> scopes;
98 Expr let = GetRef<Let>(l);
99 while (auto let_node = let.as<LetNode>()) {
100 scopes.push_back(new Scope(this));
101 // we do letrec only for FunctionNode,
102 // but shadowing let in let binding is likely programming error, and we should forbidden it.
103 Bound(let_node->var);
104 CheckWellFormed(let_node->value);
105 let = let_node->body;
106 }
107 CheckWellFormed(let);
108 while (!scopes.empty()) {
109 delete scopes.back();
110 scopes.pop_back();
111 }
112 }
113
114 void VisitExpr_(const FunctionNode* f) final {
115 Scope s(this);
116 for (const Var& param : f->params) {
117 Bound(param);
118 }
119 CheckWellFormed(f->body);
120 }
121
122 void VisitExpr_(const CallNode* call) final {
123 ICHECK(call->op.defined());
124
125 for (auto arg : call->args) {
126 ICHECK(arg.defined());
127 }
128
129 // ICHECK(call->attrs.defined());
130 ICHECK(call->type_args.defined());
131 MixedModeVisitor::VisitExpr_(call);
132 }
133
134 void VisitClause(const Clause& c) final {
135 Scope s(this);
136 VisitPattern(c->lhs);
137 VisitExpr(c->rhs);
138 }
139
140 void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
141
142 void VisitVar(const Var& v) final { Bound(v); }
143
144 public:
145 bool CheckWellFormed(const Expr& e) {
146 if (auto v = e.as<VarNode>()) {
147 VisitExpr_(v);
148 } else {
149 // this->occurs_in = e->span;
150 VisitExpr(e);
151 }
152 return well_formed;
153 }
154};
155
156bool WellFormed(const Expr& e, Optional<DiagnosticContext> diag_ctx) {
157 return WellFormedChecker(diag_ctx).CheckWellFormed(e);
158}
159
160TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed([](Expr e) {
161 return WellFormed(e);
162});
163
164} // namespace relay
165} // namespace tvm
166