1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/type.h |
22 | * \brief IR/AST nodes for the unified type system in TVM. |
23 | * |
24 | * We use Relay's type system as the unified type system |
25 | * throughout the stack. |
26 | * |
27 | * This file contains types that are common across IR variants. |
28 | * |
29 | * ## Relation between Type and runtime::DataType |
30 | * |
31 | * Besides Type, we also store a dtype field in the low-level PrimExpr. |
32 | * runtime::DataType(dtype) provides coarse grained type information |
33 | * during compile time and runtime. It is eagerly built in |
34 | * low-level expression construction and can be used for |
35 | * quick type checking in the low-level IR. |
36 | * For example, when an Expr's dtype is int32, |
37 | * we know for sure that its type is also int32. |
38 | * |
39 | * On the other hand, Type provides more fine grained information. |
40 | * For example, a low level expression can have DataType::Handle() as |
41 | * its dtype and MemRef[float32] as its type. |
42 | * Types are usually lazily constructed via type checking, |
43 | * so they may not readily be available during IR construction. |
44 | * |
45 | * The unified Type serves as a common bridge across IR dialects. |
 * For example, we require all functions to have a type signature,
 * which allows us to build cross-dialect function calls.
48 | */ |
49 | #ifndef TVM_IR_TYPE_H_ |
50 | #define TVM_IR_TYPE_H_ |
51 | |
52 | #include <tvm/ir/source_map.h> |
53 | #include <tvm/node/node.h> |
54 | #include <tvm/runtime/container/array.h> |
55 | #include <tvm/runtime/data_type.h> |
56 | #include <tvm/runtime/object.h> |
57 | |
58 | #include <string> |
59 | |
60 | namespace tvm { |
61 | |
62 | /*! |
63 | * \brief Type is the base type of all types. |
64 | * |
65 | * Relay's type system contains following subclasses: |
66 | * |
67 | * - PrimType: type of primitive type values used in the low-level IR. |
68 | * - FuncType: type of a function. |
69 | * - TensorType: type of certain Tensor values in the expression. |
70 | * |
71 | * There are also advanced types to support generic(polymorphic types). |
72 | * \sa Type |
73 | */ |
74 | class TypeNode : public Object { |
75 | public: |
76 | /*! |
77 | * \brief Span that points to the original source code. |
78 | * Reserved debug information. |
79 | */ |
80 | mutable Span span; |
81 | |
82 | static constexpr const char* _type_key = "Type" ; |
83 | static constexpr const bool _type_has_method_sequal_reduce = true; |
84 | static constexpr const bool _type_has_method_shash_reduce = true; |
85 | static constexpr const uint32_t _type_child_slots = 14; |
86 | TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); |
87 | }; |
88 | |
/*!
 * \brief Managed reference to TypeNode.
 *
 * Base class of every type reference declared in this header. Use
 * `as<NodeType>()` to inspect the concrete node behind a Type.
 *
 * \sa TypeNode
 */
class Type : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
};
97 | |
98 | /*! |
99 | * \brief Primitive data types used in the low-level IR. |
100 | * |
101 | * PrimType represents POD-values and handles that are |
102 | * not automatically managed by the runtime. |
103 | * |
104 | * \sa PrimType |
105 | */ |
106 | class PrimTypeNode : public TypeNode { |
107 | public: |
108 | /*! |
109 | * \brief The corresponding dtype field. |
110 | */ |
111 | runtime::DataType dtype; |
112 | |
113 | void VisitAttrs(AttrVisitor* v) { v->Visit("dtype" , &dtype); } |
114 | |
115 | bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { |
116 | return equal(dtype, other->dtype); |
117 | } |
118 | |
119 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } |
120 | |
121 | static constexpr const char* _type_key = "PrimType" ; |
122 | TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); |
123 | }; |
124 | |
/*!
 * \brief Managed reference to PrimTypeNode.
 * \sa PrimTypeNode
 */
class PrimType : public Type {
 public:
  /*!
   * \brief Constructor
   * \param dtype The runtime data type this primitive type wraps.
   */
  TVM_DLL explicit PrimType(runtime::DataType dtype);

  TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
139 | |
140 | /*! |
141 | * \brief Low-level raw pointer type. |
142 | * |
143 | * PointerType represents type hints in the TIR to be |
144 | * passed to the final code generator. |
145 | * |
146 | * PointerType should not occur in the high-level analysis. |
147 | * |
148 | * \sa PointerType |
149 | */ |
150 | class PointerTypeNode : public TypeNode { |
151 | public: |
152 | /*! |
153 | * \brief The type of the element which the pointer points to. |
154 | */ |
155 | Type element_type; |
156 | /*! |
157 | * \brief The storage scope of the pointer |
158 | */ |
159 | String storage_scope; |
160 | |
161 | void VisitAttrs(AttrVisitor* v) { |
162 | v->Visit("element_type" , &element_type); |
163 | v->Visit("storage_scope" , &storage_scope); |
164 | } |
165 | |
166 | bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { |
167 | // Make "global" equal to "" |
168 | String lhs_scope = storage_scope.empty() ? "global" : storage_scope; |
169 | String rhs_scope = other->storage_scope.empty() ? "global" : other->storage_scope; |
170 | return equal(element_type, other->element_type) && equal(lhs_scope, rhs_scope); |
171 | } |
172 | |
173 | void SHashReduce(SHashReducer hash_reduce) const { |
174 | hash_reduce(element_type); |
175 | // Make "global" equal to "" |
176 | hash_reduce(storage_scope.empty() ? "global" : storage_scope); |
177 | } |
178 | |
179 | static constexpr const char* _type_key = "PointerType" ; |
180 | TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); |
181 | }; |
182 | |
/*!
 * \brief Managed reference to PointerTypeNode.
 * \sa PointerTypeNode
 */
class PointerType : public Type {
 public:
  /*!
   * \brief Constructor
   * \param element_type The type of the element which the pointer points to.
   * \param storage_scope The storage scope into which the pointer addresses.
   *        Defaults to the empty string, which PointerTypeNode's structural
   *        equality and hashing treat the same as "global".
   */
  TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");

  TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
198 | |
/*!
 * \brief Possible kinds of TypeVars.
 *
 * The same enumeration is stored in the `kind` field of TypeVarNode,
 * GlobalTypeVarNode and IncompleteTypeNode.
 *
 * \note Every enumerator has an explicit value, and 3 is skipped.
 *       NOTE(review): these fields are visited as the "kind" attribute, so the
 *       numeric values are presumably relied upon by serialized IR. Confirm
 *       before renumbering.
 */
enum TypeKind : int {
  kType = 0,
  /*! \brief Template variable in shape expression. */
  kShapeVar = 1,
  kBaseType = 2,
  kConstraint = 4,
  kAdtHandle = 5,
  kTypeData = 6
};
209 | |
210 | /*! \brief Converts a TypeKind to a string. */ |
211 | inline String TypeKind2String(TypeKind kind) { |
212 | switch (kind) { |
213 | case TypeKind::kType: |
214 | return "Type" ; |
215 | case TypeKind::kShapeVar: |
216 | return "ShapeVar" ; |
217 | case TypeKind::kBaseType: |
218 | return "BaseType" ; |
219 | case TypeKind::kConstraint: |
220 | return "Constraint" ; |
221 | case TypeKind::kAdtHandle: |
222 | return "AdtHandle" ; |
223 | case TypeKind::kTypeData: |
224 | return "TypeData" ; |
225 | } |
226 | LOG(FATAL) << "ValueError: Unknown TypeKind: " << static_cast<int>(kind); |
227 | } |
228 | |
229 | /*! |
230 | * \brief Type parameter in functions. |
231 | * |
232 | * A type variable can be viewed as template parameter in c++ template function. |
233 | * |
234 | * For example, in the following pesudo code, |
235 | * the TypeVar of f is TypeVar("n", kind=kShapeVar). |
236 | * This function can take in a Tensor with shape=(3, 3) and |
237 | * returns a Tensor with shape=(9,) |
238 | * |
239 | * \code |
240 | * |
241 | * template<i32 n> |
242 | * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] |
243 | * |
244 | * \endcode |
245 | * \sa TypeVar, TypeKind |
246 | */ |
247 | class TypeVarNode : public TypeNode { |
248 | public: |
249 | /*! |
250 | * \brief The name of the variable, |
251 | * this only acts as a hint to the user, |
252 | * and is not used for equality. |
253 | */ |
254 | String name_hint; |
255 | /*! \brief The kind of type parameter */ |
256 | TypeKind kind; |
257 | |
258 | void VisitAttrs(AttrVisitor* v) { |
259 | v->Visit("name_hint" , &name_hint); |
260 | v->Visit("kind" , &kind); |
261 | v->Visit("span" , &span); |
262 | } |
263 | |
264 | bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const { |
265 | return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); |
266 | } |
267 | |
268 | void SHashReduce(SHashReducer hash_reduce) const { |
269 | hash_reduce(kind); |
270 | hash_reduce.FreeVarHashImpl(this); |
271 | } |
272 | |
273 | static constexpr const char* _type_key = "TypeVar" ; |
274 | TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); |
275 | }; |
276 | |
/*!
 * \brief Managed reference to TypeVarNode.
 * \sa TypeVarNode
 */
class TypeVar : public Type {
 public:
  /*!
   * \brief Constructor
   * \param name_hint The name of the type var (a hint only; not used for equality).
   * \param kind The kind of the type var.
   * \param span The span information.
   */
  TVM_DLL TypeVar(String name_hint, TypeKind kind, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
293 | |
294 | /*! |
295 | * \brief A global type variable that is used for defining new types or type aliases. |
296 | * \sa GlobalTypeVar |
297 | */ |
298 | class GlobalTypeVarNode : public TypeNode { |
299 | public: |
300 | /*! |
301 | * \brief The name of the variable, |
302 | * this only acts as a hint to the user, |
303 | * and is not used for equality. |
304 | */ |
305 | String name_hint; |
306 | /*! \brief The kind of type parameter */ |
307 | TypeKind kind; |
308 | |
309 | void VisitAttrs(AttrVisitor* v) { |
310 | v->Visit("name_hint" , &name_hint); |
311 | v->Visit("kind" , &kind); |
312 | } |
313 | |
314 | bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const { |
315 | // name matters for now in global type var. |
316 | return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); |
317 | } |
318 | |
319 | void SHashReduce(SHashReducer hash_reduce) const { |
320 | hash_reduce(name_hint); |
321 | hash_reduce.FreeVarHashImpl(this); |
322 | } |
323 | |
324 | static constexpr const char* _type_key = "GlobalTypeVar" ; |
325 | TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); |
326 | }; |
327 | |
/*!
 * \brief Managed reference to GlobalTypeVarNode.
 * \sa GlobalTypeVarNode
 */
class GlobalTypeVar : public Type {
 public:
  /*!
   * \brief Constructor
   * \param name_hint The name of the type var; it takes part in structural equality.
   * \param kind The kind of the type var.
   * \param span The span of the type. GlobalTypeVarNode::VisitAttrs does not
   *        visit span, so it is not part of the reflected attributes.
   */
  TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
344 | |
345 | /*! |
346 | * \brief The type of tuple values. |
347 | * \sa TupleType |
348 | */ |
349 | class TupleTypeNode : public TypeNode { |
350 | public: |
351 | /*! \brief The type of each field in the tuple. */ |
352 | Array<Type> fields; |
353 | |
354 | TupleTypeNode() {} |
355 | |
356 | void VisitAttrs(AttrVisitor* v) { |
357 | v->Visit("fields" , &fields); |
358 | v->Visit("span" , &span); |
359 | } |
360 | |
361 | bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const { |
362 | return equal(fields, other->fields); |
363 | } |
364 | |
365 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } |
366 | |
367 | static constexpr const char* _type_key = "TupleType" ; |
368 | TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); |
369 | }; |
370 | |
371 | /*! |
372 | * \brief Managed reference to TupleTypeNode. |
373 | * \sa TupleTypeNode. |
374 | */ |
375 | class TupleType : public Type { |
376 | public: |
377 | /*! |
378 | * \brief Constructor |
379 | * \param fields Fields in the tuple. |
380 | * \param span The span of the type. |
381 | */ |
382 | TVM_DLL explicit TupleType(Array<Type> fields, Span span = Span()); |
383 | |
384 | /*! |
385 | * \brief Create an empty tuple type that constains nothing. |
386 | * \return A empty tuple type. |
387 | */ |
388 | TVM_DLL TupleType static Empty(); |
389 | |
390 | TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); |
391 | }; |
392 | |
/*!
 * \brief Get the type that represents void.
 * \return An empty tuple type, as produced by TupleType::Empty().
 * \sa IsVoidType
 */
inline Type VoidType() { return TupleType::Empty(); }
397 | |
398 | /*! |
399 | * \brief Check whether the tyep represents void. |
400 | * \return The check result. |
401 | */ |
402 | inline bool IsVoidType(const Type& type) { |
403 | auto* n = type.as<TupleTypeNode>(); |
404 | return n && n->fields.size() == 0; |
405 | } |
406 | |
407 | /*! |
408 | * \brief Potential Constraints in a function. |
409 | * \sa TypeConstraint |
410 | */ |
411 | class TypeConstraintNode : public TypeNode { |
412 | public: |
413 | static constexpr const char* _type_key = "TypeConstraint" ; |
414 | static constexpr const uint32_t _type_child_slots = 1; |
415 | TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); |
416 | }; |
417 | |
/*!
 * \brief Managed reference to TypeConstraintNode.
 *
 * Used as the element type of FuncTypeNode::type_constraints.
 *
 * \sa TypeConstraintNode, TypeRelation
 */
class TypeConstraint : public Type {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
426 | |
427 | /*! |
428 | * \brief Function type. |
429 | * |
430 | * We support polymorphic function type. |
431 | * This can be roughly viewed as template function in C++. |
432 | * |
433 | * \sa FuncType, TypeVar, TypeConstraint |
434 | */ |
435 | class FuncTypeNode : public TypeNode { |
436 | public: |
437 | /*! \brief type type of arguments */ |
438 | Array<Type> arg_types; |
439 | /*! \brief The type of return value. */ |
440 | Type ret_type; |
441 | // The following fields are used in polymorphic(template) functions |
442 | // For normal functions, the following two fields will be empty. |
443 | /*! \brief The type parameters of the function */ |
444 | Array<TypeVar> type_params; |
445 | /*! |
446 | * \brief potential constraint the type need to obey |
447 | * \note this field is reserved for further purposes. |
448 | */ |
449 | Array<TypeConstraint> type_constraints; |
450 | |
451 | void VisitAttrs(AttrVisitor* v) { |
452 | v->Visit("arg_types" , &arg_types); |
453 | v->Visit("ret_type" , &ret_type); |
454 | v->Visit("type_params" , &type_params); |
455 | v->Visit("type_constraints" , &type_constraints); |
456 | v->Visit("span" , &span); |
457 | } |
458 | |
459 | bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { |
460 | // type params first as they defines type vars. |
461 | return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) && |
462 | equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints); |
463 | } |
464 | |
465 | void SHashReduce(SHashReducer hash_reduce) const { |
466 | hash_reduce.DefHash(type_params); |
467 | hash_reduce(arg_types); |
468 | hash_reduce(ret_type); |
469 | hash_reduce(type_constraints); |
470 | } |
471 | |
472 | static constexpr const char* _type_key = "FuncType" ; |
473 | TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); |
474 | }; |
475 | |
/*!
 * \brief Managed reference to FuncTypeNode.
 * \sa FuncTypeNode
 */
class FuncType : public Type {
 public:
  /*!
   * \brief Constructor
   * \param arg_types The types of the arguments.
   * \param ret_type The type of the return value.
   * \param type_params The type parameters (empty for non-polymorphic functions).
   * \param type_constraints The type constraints (empty for non-polymorphic functions).
   * \param span The span information.
   * \sa FuncTypeNode for more docs about these fields.
   */
  TVM_DLL FuncType(Array<Type> arg_types, Type ret_type, Array<TypeVar> type_params,
                   Array<TypeConstraint> type_constraints, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
496 | |
497 | /*! |
498 | * \brief Intermediate values that is used to indicate incomplete type |
499 | * during type inference. |
500 | * |
501 | * If we view the type relations as "computational graph of types", |
502 | * then IncompleteType represents intermediate values of the graph, |
503 | * TypeVar represents the input to the graph. |
504 | * |
505 | * \sa IncompleteType |
506 | */ |
507 | class IncompleteTypeNode : public TypeNode { |
508 | public: |
509 | /*! \brief kind of the type. */ |
510 | TypeKind kind; |
511 | |
512 | void VisitAttrs(tvm::AttrVisitor* v) { |
513 | v->Visit("kind" , &kind); |
514 | v->Visit("span" , &span); |
515 | } |
516 | |
517 | bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { |
518 | return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); |
519 | } |
520 | |
521 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); } |
522 | |
523 | static constexpr const char* _type_key = "IncompleteType" ; |
524 | TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); |
525 | }; |
526 | |
/*!
 * \brief Managed reference to IncompleteTypeNode.
 * \sa IncompleteTypeNode
 */
class IncompleteType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param kind Kind of the type that is still to be inferred.
   * \param span The span information.
   */
  TVM_DLL explicit IncompleteType(TypeKind kind, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
542 | |
543 | /*! |
544 | * \brief Reference Type High-level Relay IR. |
545 | * |
546 | * \sa RelayRefType. |
547 | */ |
548 | class RelayRefTypeNode : public TypeNode { |
549 | public: |
550 | /*! \brief The type of value in the Reference. */ |
551 | Type value; |
552 | |
553 | RelayRefTypeNode() {} |
554 | |
555 | void VisitAttrs(tvm::AttrVisitor* v) { |
556 | v->Visit("value" , &value); |
557 | v->Visit("span" , &span); |
558 | } |
559 | |
560 | bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const { |
561 | return equal(value, other->value); |
562 | } |
563 | |
564 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } |
565 | |
566 | // Keep the relay prefix in the type as this type is specific |
567 | // to the relay itself. |
568 | static constexpr const char* _type_key = "relay.RefType" ; |
569 | TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode); |
570 | }; |
571 | |
/*!
 * \brief Managed reference to RelayRefTypeNode.
 * \sa RelayRefTypeNode
 */
class RelayRefType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param value The type of the value held by the reference.
   * \param span The span information.
   */
  TVM_DLL explicit RelayRefType(Type value, Span span = Span());
  TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
};
581 | } // namespace tvm |
582 | #endif // TVM_IR_TYPE_H_ |
583 | |