1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/type.h
22 * \brief IR/AST nodes for the unified type system in TVM.
23 *
24 * We use Relay's type system as the unified type system
25 * throughout the stack.
26 *
27 * This file contains types that are common across IR variants.
28 *
29 * ## Relation between Type and runtime::DataType
30 *
31 * Besides Type, we also store a dtype field in the low-level PrimExpr.
32 * runtime::DataType(dtype) provides coarse grained type information
33 * during compile time and runtime. It is eagerly built in
34 * low-level expression construction and can be used for
35 * quick type checking in the low-level IR.
36 * For example, when an Expr's dtype is int32,
37 * we know for sure that its type is also int32.
38 *
 * On the other hand, Type provides more fine-grained information.
40 * For example, a low level expression can have DataType::Handle() as
41 * its dtype and MemRef[float32] as its type.
42 * Types are usually lazily constructed via type checking,
43 * so they may not readily be available during IR construction.
44 *
45 * The unified Type serves as a common bridge across IR dialects.
 * For example, we require all functions to have a type signature,
 * which allows us to build cross-dialect function calls.
48 */
49#ifndef TVM_IR_TYPE_H_
50#define TVM_IR_TYPE_H_
51
52#include <tvm/ir/source_map.h>
53#include <tvm/node/node.h>
54#include <tvm/runtime/container/array.h>
55#include <tvm/runtime/data_type.h>
56#include <tvm/runtime/object.h>
57
58#include <string>
59
60namespace tvm {
61
62/*!
63 * \brief Type is the base type of all types.
64 *
65 * Relay's type system contains following subclasses:
66 *
67 * - PrimType: type of primitive type values used in the low-level IR.
68 * - FuncType: type of a function.
69 * - TensorType: type of certain Tensor values in the expression.
70 *
71 * There are also advanced types to support generic(polymorphic types).
72 * \sa Type
73 */
74class TypeNode : public Object {
75 public:
76 /*!
77 * \brief Span that points to the original source code.
78 * Reserved debug information.
79 */
80 mutable Span span;
81
82 static constexpr const char* _type_key = "Type";
83 static constexpr const bool _type_has_method_sequal_reduce = true;
84 static constexpr const bool _type_has_method_shash_reduce = true;
85 static constexpr const uint32_t _type_child_slots = 14;
86 TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
87};
88
89/*!
90 * \brief Managed reference to TypeNode.
91 * \sa TypeNode
92 */
93class Type : public ObjectRef {
94 public:
95 TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
96};
97
98/*!
99 * \brief Primitive data types used in the low-level IR.
100 *
101 * PrimType represents POD-values and handles that are
102 * not automatically managed by the runtime.
103 *
104 * \sa PrimType
105 */
106class PrimTypeNode : public TypeNode {
107 public:
108 /*!
109 * \brief The corresponding dtype field.
110 */
111 runtime::DataType dtype;
112
113 void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
114
115 bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
116 return equal(dtype, other->dtype);
117 }
118
119 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
120
121 static constexpr const char* _type_key = "PrimType";
122 TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
123};
124
125/*
126 * \brief Managed reference to PrimTypeNode.
127 * \sa PrimTypeNode
128 */
129class PrimType : public Type {
130 public:
131 /*!
132 * \brief Constructor
133 * \param dtype The corresponding dtype.
134 */
135 TVM_DLL explicit PrimType(runtime::DataType dtype);
136
137 TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
138};
139
140/*!
141 * \brief Low-level raw pointer type.
142 *
143 * PointerType represents type hints in the TIR to be
144 * passed to the final code generator.
145 *
146 * PointerType should not occur in the high-level analysis.
147 *
148 * \sa PointerType
149 */
150class PointerTypeNode : public TypeNode {
151 public:
152 /*!
153 * \brief The type of the element which the pointer points to.
154 */
155 Type element_type;
156 /*!
157 * \brief The storage scope of the pointer
158 */
159 String storage_scope;
160
161 void VisitAttrs(AttrVisitor* v) {
162 v->Visit("element_type", &element_type);
163 v->Visit("storage_scope", &storage_scope);
164 }
165
166 bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
167 // Make "global" equal to ""
168 String lhs_scope = storage_scope.empty() ? "global" : storage_scope;
169 String rhs_scope = other->storage_scope.empty() ? "global" : other->storage_scope;
170 return equal(element_type, other->element_type) && equal(lhs_scope, rhs_scope);
171 }
172
173 void SHashReduce(SHashReducer hash_reduce) const {
174 hash_reduce(element_type);
175 // Make "global" equal to ""
176 hash_reduce(storage_scope.empty() ? "global" : storage_scope);
177 }
178
179 static constexpr const char* _type_key = "PointerType";
180 TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
181};
182
183/*
184 * \brief Managed reference to PointerTypeNode.
185 * \sa PointerTypeNode
186 */
187class PointerType : public Type {
188 public:
189 /*!
190 * \brief Constructor
191 * \param element_type The type of the element which the pointer points to.
192 * \param storage_scope The storage scope into which the pointer addresses
193 */
194 TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");
195
196 TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
197};
198
/*!
 * \brief Possible kinds of TypeVars.
 *
 * \note The underlying integer values are spelled out explicitly;
 *       the value 3 is not assigned to any enumerator.
 */
enum TypeKind : int {
  kType = 0,
  kShapeVar = 1, /*!< Template variable in shape expression. */
  kBaseType = 2,
  kConstraint = 4,
  kAdtHandle = 5,
  kTypeData = 6,
};
209
210/*! \brief Converts a TypeKind to a string. */
211inline String TypeKind2String(TypeKind kind) {
212 switch (kind) {
213 case TypeKind::kType:
214 return "Type";
215 case TypeKind::kShapeVar:
216 return "ShapeVar";
217 case TypeKind::kBaseType:
218 return "BaseType";
219 case TypeKind::kConstraint:
220 return "Constraint";
221 case TypeKind::kAdtHandle:
222 return "AdtHandle";
223 case TypeKind::kTypeData:
224 return "TypeData";
225 }
226 LOG(FATAL) << "ValueError: Unknown TypeKind: " << static_cast<int>(kind);
227}
228
229/*!
230 * \brief Type parameter in functions.
231 *
232 * A type variable can be viewed as template parameter in c++ template function.
233 *
234 * For example, in the following pesudo code,
235 * the TypeVar of f is TypeVar("n", kind=kShapeVar).
236 * This function can take in a Tensor with shape=(3, 3) and
237 * returns a Tensor with shape=(9,)
238 *
239 * \code
240 *
241 * template<i32 n>
242 * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
243 *
244 * \endcode
245 * \sa TypeVar, TypeKind
246 */
247class TypeVarNode : public TypeNode {
248 public:
249 /*!
250 * \brief The name of the variable,
251 * this only acts as a hint to the user,
252 * and is not used for equality.
253 */
254 String name_hint;
255 /*! \brief The kind of type parameter */
256 TypeKind kind;
257
258 void VisitAttrs(AttrVisitor* v) {
259 v->Visit("name_hint", &name_hint);
260 v->Visit("kind", &kind);
261 v->Visit("span", &span);
262 }
263
264 bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
265 return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other);
266 }
267
268 void SHashReduce(SHashReducer hash_reduce) const {
269 hash_reduce(kind);
270 hash_reduce.FreeVarHashImpl(this);
271 }
272
273 static constexpr const char* _type_key = "TypeVar";
274 TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
275};
276
277/*!
278 * \brief Managed reference to TypeVarNode
279 * \sa TypeVarNode
280 */
281class TypeVar : public Type {
282 public:
283 /*!
284 * \brief Constructor
285 * \param name_hint The name of the type var.
286 * \param kind The kind of the type var.
287 * \param span The span information.
288 */
289 TVM_DLL TypeVar(String name_hint, TypeKind kind, Span span = Span());
290
291 TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
292};
293
294/*!
295 * \brief A global type variable that is used for defining new types or type aliases.
296 * \sa GlobalTypeVar
297 */
298class GlobalTypeVarNode : public TypeNode {
299 public:
300 /*!
301 * \brief The name of the variable,
302 * this only acts as a hint to the user,
303 * and is not used for equality.
304 */
305 String name_hint;
306 /*! \brief The kind of type parameter */
307 TypeKind kind;
308
309 void VisitAttrs(AttrVisitor* v) {
310 v->Visit("name_hint", &name_hint);
311 v->Visit("kind", &kind);
312 }
313
314 bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
315 // name matters for now in global type var.
316 return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
317 }
318
319 void SHashReduce(SHashReducer hash_reduce) const {
320 hash_reduce(name_hint);
321 hash_reduce.FreeVarHashImpl(this);
322 }
323
324 static constexpr const char* _type_key = "GlobalTypeVar";
325 TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
326};
327
328/*!
329 * \brief Managed reference to GlobalTypeVarNode
330 * \sa GlobalTypeVarNode
331 */
332class GlobalTypeVar : public Type {
333 public:
334 /*!
335 * \brief Constructor
336 * \param name_hint The name of the type var.
337 * \param kind The kind of the type var.
338 * \param span The span of the type.
339 */
340 TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind, Span span = Span());
341
342 TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
343};
344
345/*!
346 * \brief The type of tuple values.
347 * \sa TupleType
348 */
349class TupleTypeNode : public TypeNode {
350 public:
351 /*! \brief The type of each field in the tuple. */
352 Array<Type> fields;
353
354 TupleTypeNode() {}
355
356 void VisitAttrs(AttrVisitor* v) {
357 v->Visit("fields", &fields);
358 v->Visit("span", &span);
359 }
360
361 bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
362 return equal(fields, other->fields);
363 }
364
365 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
366
367 static constexpr const char* _type_key = "TupleType";
368 TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
369};
370
371/*!
372 * \brief Managed reference to TupleTypeNode.
373 * \sa TupleTypeNode.
374 */
375class TupleType : public Type {
376 public:
377 /*!
378 * \brief Constructor
379 * \param fields Fields in the tuple.
380 * \param span The span of the type.
381 */
382 TVM_DLL explicit TupleType(Array<Type> fields, Span span = Span());
383
384 /*!
385 * \brief Create an empty tuple type that constains nothing.
386 * \return A empty tuple type.
387 */
388 TVM_DLL TupleType static Empty();
389
390 TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
391};
392
393/*!
394 * \return a type that represents void.
395 */
396inline Type VoidType() { return TupleType::Empty(); }
397
398/*!
399 * \brief Check whether the tyep represents void.
400 * \return The check result.
401 */
402inline bool IsVoidType(const Type& type) {
403 auto* n = type.as<TupleTypeNode>();
404 return n && n->fields.size() == 0;
405}
406
407/*!
408 * \brief Potential Constraints in a function.
409 * \sa TypeConstraint
410 */
411class TypeConstraintNode : public TypeNode {
412 public:
413 static constexpr const char* _type_key = "TypeConstraint";
414 static constexpr const uint32_t _type_child_slots = 1;
415 TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
416};
417
418/*!
419 * \brief Managed reference to TypeConstraintNode.
420 * \sa TypeConstraintNode, TypeRelation
421 */
422class TypeConstraint : public Type {
423 public:
424 TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
425};
426
427/*!
428 * \brief Function type.
429 *
430 * We support polymorphic function type.
431 * This can be roughly viewed as template function in C++.
432 *
433 * \sa FuncType, TypeVar, TypeConstraint
434 */
435class FuncTypeNode : public TypeNode {
436 public:
437 /*! \brief type type of arguments */
438 Array<Type> arg_types;
439 /*! \brief The type of return value. */
440 Type ret_type;
441 // The following fields are used in polymorphic(template) functions
442 // For normal functions, the following two fields will be empty.
443 /*! \brief The type parameters of the function */
444 Array<TypeVar> type_params;
445 /*!
446 * \brief potential constraint the type need to obey
447 * \note this field is reserved for further purposes.
448 */
449 Array<TypeConstraint> type_constraints;
450
451 void VisitAttrs(AttrVisitor* v) {
452 v->Visit("arg_types", &arg_types);
453 v->Visit("ret_type", &ret_type);
454 v->Visit("type_params", &type_params);
455 v->Visit("type_constraints", &type_constraints);
456 v->Visit("span", &span);
457 }
458
459 bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
460 // type params first as they defines type vars.
461 return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) &&
462 equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints);
463 }
464
465 void SHashReduce(SHashReducer hash_reduce) const {
466 hash_reduce.DefHash(type_params);
467 hash_reduce(arg_types);
468 hash_reduce(ret_type);
469 hash_reduce(type_constraints);
470 }
471
472 static constexpr const char* _type_key = "FuncType";
473 TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
474};
475
476/*!
477 * \brief Managed reference to FuncTypeNode.
478 * \sa FuncTypeNode
479 */
480class FuncType : public Type {
481 public:
482 /*!
483 * \brief Constructor
484 * \param arg_types The types of the arguments.
485 * \param ret_type The type of the return value.
486 * \param type_params The type parameters.
487 * \param type_constraints The type constraints.
488 * \param span The span information.
489 * \sa FuncTypeNode for more docs about these fields.
490 */
491 TVM_DLL FuncType(Array<Type> arg_types, Type ret_type, Array<TypeVar> type_params,
492 Array<TypeConstraint> type_constraints, Span span = Span());
493
494 TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
495};
496
497/*!
498 * \brief Intermediate values that is used to indicate incomplete type
499 * during type inference.
500 *
501 * If we view the type relations as "computational graph of types",
502 * then IncompleteType represents intermediate values of the graph,
503 * TypeVar represents the input to the graph.
504 *
505 * \sa IncompleteType
506 */
507class IncompleteTypeNode : public TypeNode {
508 public:
509 /*! \brief kind of the type. */
510 TypeKind kind;
511
512 void VisitAttrs(tvm::AttrVisitor* v) {
513 v->Visit("kind", &kind);
514 v->Visit("span", &span);
515 }
516
517 bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
518 return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other);
519 }
520
521 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); }
522
523 static constexpr const char* _type_key = "IncompleteType";
524 TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
525};
526
527/*!
528 * \brief Managed reference to IncompleteTypeNode.
529 * \sa IncompleteTypeNode
530 */
531class IncompleteType : public Type {
532 public:
533 /*!
534 * \brief Constructor.
535 * \param kind kind of the type.
536 * \param span The span information.
537 */
538 TVM_DLL explicit IncompleteType(TypeKind kind, Span span = Span());
539
540 TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
541};
542
543/*!
544 * \brief Reference Type High-level Relay IR.
545 *
546 * \sa RelayRefType.
547 */
548class RelayRefTypeNode : public TypeNode {
549 public:
550 /*! \brief The type of value in the Reference. */
551 Type value;
552
553 RelayRefTypeNode() {}
554
555 void VisitAttrs(tvm::AttrVisitor* v) {
556 v->Visit("value", &value);
557 v->Visit("span", &span);
558 }
559
560 bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
561 return equal(value, other->value);
562 }
563
564 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
565
566 // Keep the relay prefix in the type as this type is specific
567 // to the relay itself.
568 static constexpr const char* _type_key = "relay.RefType";
569 TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
570};
571
572/*!
573 * \brief Managed reference to RelayRefTypeNode.
574 * \sa RelayRefTypeNode.
575 */
576class RelayRefType : public Type {
577 public:
578 TVM_DLL explicit RelayRefType(Type value, Span span = Span());
579 TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
580};
581} // namespace tvm
582#endif // TVM_IR_TYPE_H_
583