1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file kindchecker.cc
23 *
24 * \brief Check that types are well formed by applying "kinding rules".
25 *
26 * This pass ensures we do not do things that violate the design of the
27 * type system when writing down types.
28 *
29 * For example tensors are not allowed to contain functions in Relay.
30 *
31 * We check this by ensuring the `dtype` field of a Tensor always
32 * contains a data type such as `int`, `float`, `uint`.
33 */
34#include <tvm/ir/type_functor.h>
35#include <tvm/relay/analysis.h>
36#include <tvm/relay/error.h>
37
38namespace tvm {
39namespace relay {
40
41using namespace tvm::runtime;
42
43struct KindChecker : TypeFunctor<Kind(const Type&)> {
44 const IRModule& mod;
45 Optional<DiagnosticContext> diag_ctx;
46
47 explicit KindChecker(const IRModule& mod, Optional<DiagnosticContext> diag_ctx)
48 : mod(mod), diag_ctx(diag_ctx) {}
49
50 void EmitFatal(Diagnostic diagnostic) {
51 if (this->diag_ctx) {
52 this->diag_ctx.value().EmitFatal(diagnostic);
53 } else {
54 LOG(FATAL) << diagnostic->message;
55 }
56 }
57
58 void CheckKindMatches(const Type& t, const Type& outer, Kind expected,
59 const std::string& description) {
60 Kind k = this->VisitType(t);
61 if (k != expected) {
62 EmitFatal(Diagnostic::Error(t->span)
63 << "Incorrect kind for a " << description << ". Type " << t << " inside " << outer
64 << " is of kind " << k << " but was expected to be " << expected);
65 }
66 }
67
68 Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; }
69
70 Kind VisitType_(const TypeVarNode* op) override { return op->kind; }
71
72 Kind VisitType_(const GlobalTypeVarNode* op) override { return op->kind; }
73
74 Kind VisitType_(const TensorTypeNode* op) override { return Kind::kType; }
75
76 Kind VisitType_(const TupleTypeNode* op) override {
77 // tuples should only contain normal types
78 for (const Type& t : op->fields) {
79 CheckKindMatches(t, GetRef<TupleType>(op), Kind::kType, "tuple member");
80 }
81 return Kind::kType;
82 }
83
84 Kind VisitType_(const FuncTypeNode* op) override {
85 // Func types should only take normal types for arguments
86 // and only return a normal type. They should also have
87 // well-formed constraints
88 FuncType ft = GetRef<FuncType>(op);
89 for (const Type& t : op->arg_types) {
90 CheckKindMatches(t, ft, Kind::kType, "function type parameter");
91 }
92
93 CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type");
94
95 for (const TypeConstraint& tc : op->type_constraints) {
96 CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint");
97 }
98
99 return Kind::kType;
100 }
101
102 Kind VisitType_(const RelayRefTypeNode* op) override {
103 // ref types should only contain normal types
104 RelayRefType rt = GetRef<RelayRefType>(op);
105 CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
106 return Kind::kType;
107 }
108
109 Kind VisitType_(const TypeRelationNode* op) override {
110 // arguments to type relation should be normal types
111 for (const Type& t : op->args) {
112 CheckKindMatches(t, GetRef<TypeRelation>(op), Kind::kType, "argument to type relation");
113 }
114 return Kind::kConstraint;
115 }
116
117 Kind VisitType_(const TypeCallNode* op) override {
118 // type call func should be a global type var, args should be type
119 TypeCall tc = GetRef<TypeCall>(op);
120 const auto* gtv = op->func.as<GlobalTypeVarNode>();
121 if (gtv == nullptr) {
122 EmitFatal(Diagnostic::Error(op->span)
123 << "The callee in " << tc << " is not a global type var, but is " << op->func);
124 }
125
126 CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");
127
128 for (const Type& t : op->args) {
129 CheckKindMatches(t, tc, Kind::kType, "type call argument");
130 }
131
132 // finally we need to check the module to check the number of type params
133 auto var = GetRef<GlobalTypeVar>(gtv);
134 try {
135 auto data = mod->LookupTypeDef(var);
136
137 if (data->type_vars.size() != op->args.size()) {
138 EmitFatal(Diagnostic::Error(op->span)
139 << "Expected " << data->type_vars.size() << "arguments for " << tc << "; got "
140 << op->args.size());
141 }
142 } catch (const Error& err) {
143 // TODO(@jroesch): can probably relax to just emit
144 EmitFatal(Diagnostic::Error(op->span)
145 << "the type variable : `" << var->name_hint << "` is undefined");
146 }
147
148 return Kind::kType;
149 }
150
151 Kind VisitType_(const TypeDataNode* op) override {
152 // Constructors can reference the header var, but no other GlobalTypeVars.
153 // In theory, a TypeData could be nested, so the header scope
154 // should be tracked recursively, but it is unclear that we need
155 // to support it.
156 TypeData td = GetRef<TypeData>(op);
157 CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header");
158
159 for (const auto& var : op->type_vars) {
160 CheckKindMatches(var, td, Kind::kType, "ADT type var");
161 }
162
163 for (const auto& con : op->constructors) {
164 if (!con->belong_to.same_as(op->header)) {
165 EmitFatal(Diagnostic::Error(op->span) << con << " has header " << con->belong_to << " but "
166 << op << " has header " << op->header);
167 }
168
169 for (const Type& t : con->inputs) {
170 CheckKindMatches(t, td, Kind::kType, "ADT constructor input");
171 }
172 }
173 return Kind::kTypeData;
174 }
175
176 Kind Check(const Type& t) { return this->VisitType(t); }
177};
178
179Kind KindCheck(const Type& t, const IRModule& mod, Optional<DiagnosticContext> diag_ctx) {
180 KindChecker kc(mod, diag_ctx);
181 return kc.Check(t);
182}
183
184TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) {
185 if (args.size() == 1) {
186 *ret = KindCheck(args[0], IRModule({}, {}));
187 } else if (args.size() == 2) {
188 *ret = KindCheck(args[0], args[1], Optional<DiagnosticContext>());
189 } else {
190 *ret = KindCheck(args[0], args[1], args[2]);
191 }
192});
193
194} // namespace relay
195} // namespace tvm
196