1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/affine_type.h |
22 | * \brief Quantized Tensor Types. |
23 | */ |
24 | #ifndef TVM_IR_AFFINE_TYPE_H_ |
25 | #define TVM_IR_AFFINE_TYPE_H_ |
26 | |
27 | #include <tvm/ir/expr.h> |
28 | #include <tvm/ir/type.h> |
29 | |
30 | namespace tvm { |
31 | |
32 | /*! |
33 | * \brief AffineType representation |
34 | * \sa AffineType |
35 | */ |
36 | class AffineTypeNode : public Object { |
37 | public: |
38 | /*! |
39 | * \brief Span that points to the original source code. |
40 | * Reserved debug information. |
41 | */ |
42 | mutable Span span; |
43 | |
44 | static constexpr const char* _type_key = "AffineType" ; |
45 | static constexpr const bool _type_has_method_sequal_reduce = true; |
46 | static constexpr const bool _type_has_method_shash_reduce = true; |
47 | TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); |
48 | }; |
49 | |
50 | /*! |
51 | * \brief Managed reference to AffineTypeNode. |
52 | * \sa AffineTypeNode |
53 | */ |
54 | class AffineType : public ObjectRef { |
55 | public: |
56 | TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); |
57 | }; |
58 | |
59 | /*! |
60 | * \brief TensorAffineType representation |
61 | * \sa TensorAffineType |
62 | * |
63 | * This Type represents a quantized integer tensor that can be converted |
64 | * back to real space via the x_real = scale * (x_quant - zero_point) |
65 | */ |
66 | class TensorAffineTypeNode : public AffineTypeNode { |
67 | public: |
68 | /*! \brief The scale of this type */ |
69 | RelayExpr scale; |
70 | /*! \brief The zero point of this type */ |
71 | RelayExpr zero_point; |
72 | /*! \brief The data type of this type */ |
73 | DataType dtype; |
74 | /*! \brief The axis for per-channel quantization */ |
75 | int axis; |
76 | |
77 | void VisitAttrs(tvm::AttrVisitor* v) { |
78 | v->Visit("scale" , &scale); |
79 | v->Visit("zero_point" , &zero_point); |
80 | v->Visit("dtype" , &dtype); |
81 | v->Visit("axis" , &axis); |
82 | } |
83 | |
84 | bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const { |
85 | equal->MarkGraphNode(); |
86 | return equal(scale, other->scale) && equal(zero_point, other->zero_point) && |
87 | equal(dtype, other->dtype) && equal(axis, other->axis); |
88 | } |
89 | |
90 | void SHashReduce(SHashReducer hash_reduce) const { |
91 | hash_reduce->MarkGraphNode(); |
92 | hash_reduce(scale); |
93 | hash_reduce(zero_point); |
94 | hash_reduce(dtype); |
95 | hash_reduce(axis); |
96 | } |
97 | |
98 | static constexpr const char* _type_key = "TensorAffineType" ; |
99 | TVM_DECLARE_BASE_OBJECT_INFO(TensorAffineTypeNode, AffineTypeNode); |
100 | }; |
101 | |
102 | /*! |
103 | * \brief Managed reference to AffineTypes. |
104 | * \sa AffineTypeNode |
105 | */ |
106 | class TensorAffineType : public AffineType { |
107 | public: |
108 | TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis); |
109 | |
110 | TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode); |
111 | }; |
112 | |
113 | /*! |
114 | * \brief TupleAffineType representation |
115 | * \sa TupleAffineType |
116 | */ |
117 | class TupleAffineTypeNode : public AffineTypeNode { |
118 | public: |
119 | /*! \brief The types of this tuple*/ |
120 | Array<TensorAffineType> types; |
121 | |
122 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("types" , &types); } |
123 | |
124 | bool SEqualReduce(const TupleAffineTypeNode* other, SEqualReducer equal) const { |
125 | equal->MarkGraphNode(); |
126 | return equal(types, other->types); |
127 | } |
128 | |
129 | void SHashReduce(SHashReducer hash_reduce) const { |
130 | hash_reduce->MarkGraphNode(); |
131 | hash_reduce(types); |
132 | } |
133 | |
134 | static constexpr const char* _type_key = "TupleAffineType" ; |
135 | TVM_DECLARE_BASE_OBJECT_INFO(TupleAffineTypeNode, AffineTypeNode); |
136 | }; |
137 | |
138 | /*! |
139 | * \brief Managed reference to TupleAffineTypes. |
140 | * \sa TupleAffineType |
141 | */ |
142 | class TupleAffineType : public AffineType { |
143 | public: |
144 | TVM_DLL TupleAffineType(Array<TensorAffineType> types); |
145 | |
146 | TVM_DEFINE_OBJECT_REF_METHODS(TupleAffineType, AffineType, TupleAffineTypeNode); |
147 | }; |
148 | |
149 | } // namespace tvm |
150 | #endif // TVM_IR_AFFINE_TYPE_H_ |
151 | |