1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file kindchecker.cc |
23 | * |
24 | * \brief Check that types are well formed by applying "kinding rules". |
25 | * |
26 | * This pass ensures we do not do things that violate the design of the |
27 | * type system when writing down types. |
28 | * |
29 | * For example tensors are not allowed to contain functions in Relay. |
30 | * |
 * Tensor types satisfy this by construction: the `dtype` field of a Tensor
 * always holds a data type such as `int`, `float`, or `uint` (never a type),
 * so every tensor type is of kind `Type`.
33 | */ |
34 | #include <tvm/ir/type_functor.h> |
35 | #include <tvm/relay/analysis.h> |
36 | #include <tvm/relay/error.h> |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | |
41 | using namespace tvm::runtime; |
42 | |
43 | struct KindChecker : TypeFunctor<Kind(const Type&)> { |
44 | const IRModule& mod; |
45 | Optional<DiagnosticContext> diag_ctx; |
46 | |
47 | explicit KindChecker(const IRModule& mod, Optional<DiagnosticContext> diag_ctx) |
48 | : mod(mod), diag_ctx(diag_ctx) {} |
49 | |
50 | void EmitFatal(Diagnostic diagnostic) { |
51 | if (this->diag_ctx) { |
52 | this->diag_ctx.value().EmitFatal(diagnostic); |
53 | } else { |
54 | LOG(FATAL) << diagnostic->message; |
55 | } |
56 | } |
57 | |
58 | void CheckKindMatches(const Type& t, const Type& outer, Kind expected, |
59 | const std::string& description) { |
60 | Kind k = this->VisitType(t); |
61 | if (k != expected) { |
62 | EmitFatal(Diagnostic::Error(t->span) |
63 | << "Incorrect kind for a " << description << ". Type " << t << " inside " << outer |
64 | << " is of kind " << k << " but was expected to be " << expected); |
65 | } |
66 | } |
67 | |
68 | Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; } |
69 | |
70 | Kind VisitType_(const TypeVarNode* op) override { return op->kind; } |
71 | |
72 | Kind VisitType_(const GlobalTypeVarNode* op) override { return op->kind; } |
73 | |
74 | Kind VisitType_(const TensorTypeNode* op) override { return Kind::kType; } |
75 | |
76 | Kind VisitType_(const TupleTypeNode* op) override { |
77 | // tuples should only contain normal types |
78 | for (const Type& t : op->fields) { |
79 | CheckKindMatches(t, GetRef<TupleType>(op), Kind::kType, "tuple member" ); |
80 | } |
81 | return Kind::kType; |
82 | } |
83 | |
84 | Kind VisitType_(const FuncTypeNode* op) override { |
85 | // Func types should only take normal types for arguments |
86 | // and only return a normal type. They should also have |
87 | // well-formed constraints |
88 | FuncType ft = GetRef<FuncType>(op); |
89 | for (const Type& t : op->arg_types) { |
90 | CheckKindMatches(t, ft, Kind::kType, "function type parameter" ); |
91 | } |
92 | |
93 | CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type" ); |
94 | |
95 | for (const TypeConstraint& tc : op->type_constraints) { |
96 | CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint" ); |
97 | } |
98 | |
99 | return Kind::kType; |
100 | } |
101 | |
102 | Kind VisitType_(const RelayRefTypeNode* op) override { |
103 | // ref types should only contain normal types |
104 | RelayRefType rt = GetRef<RelayRefType>(op); |
105 | CheckKindMatches(op->value, rt, Kind::kType, "ref contents" ); |
106 | return Kind::kType; |
107 | } |
108 | |
109 | Kind VisitType_(const TypeRelationNode* op) override { |
110 | // arguments to type relation should be normal types |
111 | for (const Type& t : op->args) { |
112 | CheckKindMatches(t, GetRef<TypeRelation>(op), Kind::kType, "argument to type relation" ); |
113 | } |
114 | return Kind::kConstraint; |
115 | } |
116 | |
117 | Kind VisitType_(const TypeCallNode* op) override { |
118 | // type call func should be a global type var, args should be type |
119 | TypeCall tc = GetRef<TypeCall>(op); |
120 | const auto* gtv = op->func.as<GlobalTypeVarNode>(); |
121 | if (gtv == nullptr) { |
122 | EmitFatal(Diagnostic::Error(op->span) |
123 | << "The callee in " << tc << " is not a global type var, but is " << op->func); |
124 | } |
125 | |
126 | CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function" ); |
127 | |
128 | for (const Type& t : op->args) { |
129 | CheckKindMatches(t, tc, Kind::kType, "type call argument" ); |
130 | } |
131 | |
132 | // finally we need to check the module to check the number of type params |
133 | auto var = GetRef<GlobalTypeVar>(gtv); |
134 | try { |
135 | auto data = mod->LookupTypeDef(var); |
136 | |
137 | if (data->type_vars.size() != op->args.size()) { |
138 | EmitFatal(Diagnostic::Error(op->span) |
139 | << "Expected " << data->type_vars.size() << "arguments for " << tc << "; got " |
140 | << op->args.size()); |
141 | } |
142 | } catch (const Error& err) { |
143 | // TODO(@jroesch): can probably relax to just emit |
144 | EmitFatal(Diagnostic::Error(op->span) |
145 | << "the type variable : `" << var->name_hint << "` is undefined" ); |
146 | } |
147 | |
148 | return Kind::kType; |
149 | } |
150 | |
151 | Kind VisitType_(const TypeDataNode* op) override { |
152 | // Constructors can reference the header var, but no other GlobalTypeVars. |
153 | // In theory, a TypeData could be nested, so the header scope |
154 | // should be tracked recursively, but it is unclear that we need |
155 | // to support it. |
156 | TypeData td = GetRef<TypeData>(op); |
157 | CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header" ); |
158 | |
159 | for (const auto& var : op->type_vars) { |
160 | CheckKindMatches(var, td, Kind::kType, "ADT type var" ); |
161 | } |
162 | |
163 | for (const auto& con : op->constructors) { |
164 | if (!con->belong_to.same_as(op->header)) { |
165 | EmitFatal(Diagnostic::Error(op->span) << con << " has header " << con->belong_to << " but " |
166 | << op << " has header " << op->header); |
167 | } |
168 | |
169 | for (const Type& t : con->inputs) { |
170 | CheckKindMatches(t, td, Kind::kType, "ADT constructor input" ); |
171 | } |
172 | } |
173 | return Kind::kTypeData; |
174 | } |
175 | |
176 | Kind Check(const Type& t) { return this->VisitType(t); } |
177 | }; |
178 | |
179 | Kind KindCheck(const Type& t, const IRModule& mod, Optional<DiagnosticContext> diag_ctx) { |
180 | KindChecker kc(mod, diag_ctx); |
181 | return kc.Check(t); |
182 | } |
183 | |
184 | TVM_REGISTER_GLOBAL("relay.analysis.check_kind" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
185 | if (args.size() == 1) { |
186 | *ret = KindCheck(args[0], IRModule({}, {})); |
187 | } else if (args.size() == 2) { |
188 | *ret = KindCheck(args[0], args[1], Optional<DiagnosticContext>()); |
189 | } else { |
190 | *ret = KindCheck(args[0], args[1], args[2]); |
191 | } |
192 | }); |
193 | |
194 | } // namespace relay |
195 | } // namespace tvm |
196 | |