1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file tvm/ir/affine_type.h
 * \brief Affine types describing quantized tensors and tuples of them.
 */
24#ifndef TVM_IR_AFFINE_TYPE_H_
25#define TVM_IR_AFFINE_TYPE_H_
26
27#include <tvm/ir/expr.h>
28#include <tvm/ir/type.h>
29
30namespace tvm {
31
32/*!
33 * \brief AffineType representation
34 * \sa AffineType
35 */
36class AffineTypeNode : public Object {
37 public:
38 /*!
39 * \brief Span that points to the original source code.
40 * Reserved debug information.
41 */
42 mutable Span span;
43
44 static constexpr const char* _type_key = "AffineType";
45 static constexpr const bool _type_has_method_sequal_reduce = true;
46 static constexpr const bool _type_has_method_shash_reduce = true;
47 TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object);
48};
49
50/*!
51 * \brief Managed reference to AffineTypeNode.
52 * \sa AffineTypeNode
53 */
54class AffineType : public ObjectRef {
55 public:
56 TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode);
57};
58
/*!
 * \brief TensorAffineType representation
 * \sa TensorAffineType
 *
 * This type represents a quantized integer tensor that can be converted
 * back to real space via x_real = scale * (x_quant - zero_point).
 */
66class TensorAffineTypeNode : public AffineTypeNode {
67 public:
68 /*! \brief The scale of this type */
69 RelayExpr scale;
70 /*! \brief The zero point of this type */
71 RelayExpr zero_point;
72 /*! \brief The data type of this type */
73 DataType dtype;
74 /*! \brief The axis for per-channel quantization */
75 int axis;
76
77 void VisitAttrs(tvm::AttrVisitor* v) {
78 v->Visit("scale", &scale);
79 v->Visit("zero_point", &zero_point);
80 v->Visit("dtype", &dtype);
81 v->Visit("axis", &axis);
82 }
83
84 bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const {
85 equal->MarkGraphNode();
86 return equal(scale, other->scale) && equal(zero_point, other->zero_point) &&
87 equal(dtype, other->dtype) && equal(axis, other->axis);
88 }
89
90 void SHashReduce(SHashReducer hash_reduce) const {
91 hash_reduce->MarkGraphNode();
92 hash_reduce(scale);
93 hash_reduce(zero_point);
94 hash_reduce(dtype);
95 hash_reduce(axis);
96 }
97
98 static constexpr const char* _type_key = "TensorAffineType";
99 TVM_DECLARE_BASE_OBJECT_INFO(TensorAffineTypeNode, AffineTypeNode);
100};
101
/*!
 * \brief Managed reference to TensorAffineTypeNode.
 * \sa TensorAffineTypeNode
 */
106class TensorAffineType : public AffineType {
107 public:
108 TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis);
109
110 TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode);
111};
112
113/*!
114 * \brief TupleAffineType representation
115 * \sa TupleAffineType
116 */
117class TupleAffineTypeNode : public AffineTypeNode {
118 public:
119 /*! \brief The types of this tuple*/
120 Array<TensorAffineType> types;
121
122 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("types", &types); }
123
124 bool SEqualReduce(const TupleAffineTypeNode* other, SEqualReducer equal) const {
125 equal->MarkGraphNode();
126 return equal(types, other->types);
127 }
128
129 void SHashReduce(SHashReducer hash_reduce) const {
130 hash_reduce->MarkGraphNode();
131 hash_reduce(types);
132 }
133
134 static constexpr const char* _type_key = "TupleAffineType";
135 TVM_DECLARE_BASE_OBJECT_INFO(TupleAffineTypeNode, AffineTypeNode);
136};
137
/*!
 * \brief Managed reference to TupleAffineTypeNode.
 * \sa TupleAffineTypeNode
 */
142class TupleAffineType : public AffineType {
143 public:
144 TVM_DLL TupleAffineType(Array<TensorAffineType> types);
145
146 TVM_DEFINE_OBJECT_REF_METHODS(TupleAffineType, AffineType, TupleAffineTypeNode);
147};
148
149} // namespace tvm
150#endif // TVM_IR_AFFINE_TYPE_H_
151