1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/tensor_type.h
22 * \brief Polymorphic tensor types.
23 */
24#ifndef TVM_IR_TENSOR_TYPE_H_
25#define TVM_IR_TENSOR_TYPE_H_
26
27#include <tvm/ir/expr.h>
28#include <tvm/ir/type.h>
29
30namespace tvm {
31/*!
32 * \brief Base of all Tensor types
33 * This container can hold TensorType or GenericTensorType.
34 * \sa BaseTensorType, TensorTypeNode
35 */
36class BaseTensorTypeNode : public TypeNode {
37 public:
38 static constexpr const char* _type_key = "relay.BaseTensorType";
39 static constexpr const uint32_t _type_child_slots = 1;
40 TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
41};
42
/*!
 * \brief Managed reference to BaseTensorTypeNode.
 *
 * Reference-side counterpart of BaseTensorTypeNode. It declares no members of
 * its own; everything it provides comes from TVM_DEFINE_OBJECT_REF_METHODS.
 * \sa BaseTensorTypeNode.
 */
class BaseTensorType : public Type {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};
51
52/*!
53 * \brief This is the most commonly used type in relay.
54 * TensorType have a fixed dimension, data type.
55 *
56 * The elements of shape can be either IntImm(constant integer),
57 * or any symbolic integer expression.
58 * The symbolic integer allows generic shape inference in certain cases.
59 * \sa TensorType
60 */
61class TensorTypeNode : public BaseTensorTypeNode {
62 public:
63 /*!
64 * \brief The shape of the tensor,
65 * represented by PrimExpr(tvm::Expr).
66 */
67 Array<PrimExpr> shape;
68 /*! \brief The content data type */
69 DataType dtype;
70
71 void VisitAttrs(tvm::AttrVisitor* v) {
72 v->Visit("shape", &shape);
73 v->Visit("dtype", &dtype);
74 v->Visit("span", &span);
75 }
76
77 bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
78 return equal(shape, other->shape) && equal(dtype, other->dtype);
79 }
80
81 void SHashReduce(SHashReducer hash_reduce) const {
82 hash_reduce(shape);
83 hash_reduce(dtype);
84 }
85
86 /*! \brief Return product of elements in the shape.
87 * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
88 */
89 TVM_DLL PrimExpr Size() const;
90
91 static constexpr const char* _type_key = "relay.TensorType";
92 TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
93};
94
95/*!
96 * \brief Managed reference to TensorTypeNode.
97 * \sa TensorTypeNode.
98 */
99class TensorType : public Type {
100 public:
101 /*!
102 * \brief Constructor.
103 * \param shape The shape of the tensor.
104 * \param dtype The runtime dtype of the tensor's elements.
105 */
106 TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);
107
108 /*!
109 * \brief Construct an scalar containing elements of dtype.
110 * \param dtype The runtime dtype of the tensor's elements.
111 * \return THe constructed type.
112 */
113 TVM_DLL static TensorType Scalar(DataType dtype);
114
115 TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
116};
117
// The following classes are reserved for more advanced typing.
// Only their names are declared here; they are kept for future use.
class GenericTensorType;
// Reserved: stores a DataType.
class GenericDataType;
// Reserved: presumably stores a generic (possibly symbolic) shape. The previous
// comment here ("stores a DataType.") appears copy-pasted from GenericDataType;
// TODO confirm intended contents once the class is defined.
class GenericShape;
125
126} // namespace tvm
127#endif // TVM_IR_TENSOR_TYPE_H_
128