1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/tensor_type.h |
22 | * \brief Polymorphic tensor types. |
23 | */ |
24 | #ifndef TVM_IR_TENSOR_TYPE_H_ |
25 | #define TVM_IR_TENSOR_TYPE_H_ |
26 | |
27 | #include <tvm/ir/expr.h> |
28 | #include <tvm/ir/type.h> |
29 | |
30 | namespace tvm { |
31 | /*! |
32 | * \brief Base of all Tensor types |
33 | * This container can hold TensorType or GenericTensorType. |
34 | * \sa BaseTensorType, TensorTypeNode |
35 | */ |
36 | class BaseTensorTypeNode : public TypeNode { |
37 | public: |
38 | static constexpr const char* _type_key = "relay.BaseTensorType" ; |
39 | static constexpr const uint32_t _type_child_slots = 1; |
40 | TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode); |
41 | }; |
42 | |
43 | /*! |
44 | * \brief Managed reference to BaseTensorTypeNode. |
45 | * \sa BaseTensorTypeNode. |
46 | */ |
47 | class BaseTensorType : public Type { |
48 | public: |
49 | TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode); |
50 | }; |
51 | |
52 | /*! |
53 | * \brief This is the most commonly used type in relay. |
54 | * TensorType have a fixed dimension, data type. |
55 | * |
56 | * The elements of shape can be either IntImm(constant integer), |
57 | * or any symbolic integer expression. |
58 | * The symbolic integer allows generic shape inference in certain cases. |
59 | * \sa TensorType |
60 | */ |
61 | class TensorTypeNode : public BaseTensorTypeNode { |
62 | public: |
63 | /*! |
64 | * \brief The shape of the tensor, |
65 | * represented by PrimExpr(tvm::Expr). |
66 | */ |
67 | Array<PrimExpr> shape; |
68 | /*! \brief The content data type */ |
69 | DataType dtype; |
70 | |
71 | void VisitAttrs(tvm::AttrVisitor* v) { |
72 | v->Visit("shape" , &shape); |
73 | v->Visit("dtype" , &dtype); |
74 | v->Visit("span" , &span); |
75 | } |
76 | |
77 | bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { |
78 | return equal(shape, other->shape) && equal(dtype, other->dtype); |
79 | } |
80 | |
81 | void SHashReduce(SHashReducer hash_reduce) const { |
82 | hash_reduce(shape); |
83 | hash_reduce(dtype); |
84 | } |
85 | |
86 | /*! \brief Return product of elements in the shape. |
87 | * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero. |
88 | */ |
89 | TVM_DLL PrimExpr Size() const; |
90 | |
91 | static constexpr const char* _type_key = "relay.TensorType" ; |
92 | TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode); |
93 | }; |
94 | |
95 | /*! |
96 | * \brief Managed reference to TensorTypeNode. |
97 | * \sa TensorTypeNode. |
98 | */ |
99 | class TensorType : public Type { |
100 | public: |
101 | /*! |
102 | * \brief Constructor. |
103 | * \param shape The shape of the tensor. |
104 | * \param dtype The runtime dtype of the tensor's elements. |
105 | */ |
106 | TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype); |
107 | |
108 | /*! |
109 | * \brief Construct an scalar containing elements of dtype. |
110 | * \param dtype The runtime dtype of the tensor's elements. |
111 | * \return THe constructed type. |
112 | */ |
113 | TVM_DLL static TensorType Scalar(DataType dtype); |
114 | |
115 | TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode); |
116 | }; |
117 | |
// The following classes are placeholders for advanced typing features.
// Only their names are declared; they have no definitions here and are
// reserved for future use.

// Reserved: a more general tensor type.
class GenericTensorType;
// Reserved: intended to store a DataType.
class GenericDataType;
// Reserved: presumably intended to store a (possibly generic) shape rather
// than a DataType. TODO confirm once the class is designed.
class GenericShape;
125 | |
126 | } // namespace tvm |
127 | #endif // TVM_IR_TENSOR_TYPE_H_ |
128 | |