1 | #pragma once |
2 | |
#include <limits>
#include <type_traits>

#include "taichi/ir/type.h"
#include "taichi/ir/type_factory.h"
#include "taichi/rhi/arch.h"
6 | |
7 | namespace taichi::lang { |
8 | |
9 | std::vector<int> data_type_shape(DataType t); |
10 | |
11 | TI_DLL_EXPORT std::string data_type_name(DataType t); |
12 | |
13 | TI_DLL_EXPORT int data_type_size(DataType t); |
14 | |
15 | TI_DLL_EXPORT std::string data_type_format(DataType dt, Arch arch = Arch::x64); |
16 | |
17 | inline int data_type_bits(DataType t) { |
18 | return data_type_size(t) * 8; |
19 | } |
20 | |
21 | template <typename T> |
22 | inline DataType get_data_type() { |
23 | if (std::is_same<T, float32>()) { |
24 | return PrimitiveType::f32; |
25 | } else if (std::is_same<T, float64>()) { |
26 | return PrimitiveType::f64; |
27 | } else if (std::is_same<T, bool>()) { |
28 | return PrimitiveType::u1; |
29 | } else if (std::is_same<T, int8>()) { |
30 | return PrimitiveType::i8; |
31 | } else if (std::is_same<T, int16>()) { |
32 | return PrimitiveType::i16; |
33 | } else if (std::is_same<T, int32>()) { |
34 | return PrimitiveType::i32; |
35 | } else if (std::is_same<T, int64>()) { |
36 | return PrimitiveType::i64; |
37 | } else if (std::is_same<T, uint8>()) { |
38 | return PrimitiveType::u8; |
39 | } else if (std::is_same<T, uint16>()) { |
40 | return PrimitiveType::u16; |
41 | } else if (std::is_same<T, uint32>()) { |
42 | return PrimitiveType::u32; |
43 | } else if (std::is_same<T, uint64>()) { |
44 | return PrimitiveType::u64; |
45 | } else { |
46 | TI_NOT_IMPLEMENTED; |
47 | } |
48 | } |
49 | |
50 | template <typename T> |
51 | inline PrimitiveTypeID get_primitive_data_type() { |
52 | if (std::is_same<T, float32>()) { |
53 | return PrimitiveTypeID::f32; |
54 | } else if (std::is_same<T, float64>()) { |
55 | return PrimitiveTypeID::f64; |
56 | } else if (std::is_same<T, bool>()) { |
57 | return PrimitiveTypeID::u1; |
58 | } else if (std::is_same<T, int8>()) { |
59 | return PrimitiveTypeID::i8; |
60 | } else if (std::is_same<T, int16>()) { |
61 | return PrimitiveTypeID::i16; |
62 | } else if (std::is_same<T, int32>()) { |
63 | return PrimitiveTypeID::i32; |
64 | } else if (std::is_same<T, int64>()) { |
65 | return PrimitiveTypeID::i64; |
66 | } else if (std::is_same<T, uint8>()) { |
67 | return PrimitiveTypeID::u8; |
68 | } else if (std::is_same<T, uint16>()) { |
69 | return PrimitiveTypeID::u16; |
70 | } else if (std::is_same<T, uint32>()) { |
71 | return PrimitiveTypeID::u32; |
72 | } else if (std::is_same<T, uint64>()) { |
73 | return PrimitiveTypeID::u64; |
74 | } else { |
75 | TI_NOT_IMPLEMENTED; |
76 | } |
77 | } |
78 | |
79 | inline bool is_tensor(DataType dt) { |
80 | return dt->is<TensorType>(); |
81 | } |
82 | |
83 | inline bool is_quant(DataType dt) { |
84 | return dt->is<QuantIntType>() || dt->is<QuantFixedType>() || |
85 | dt->is<QuantFloatType>(); |
86 | } |
87 | |
88 | inline bool is_real(DataType dt) { |
89 | return dt->is_primitive(PrimitiveTypeID::f16) || |
90 | dt->is_primitive(PrimitiveTypeID::f32) || |
91 | dt->is_primitive(PrimitiveTypeID::f64) || dt->is<QuantFixedType>() || |
92 | dt->is<QuantFloatType>(); |
93 | } |
94 | |
95 | inline bool is_integral(DataType dt) { |
96 | return dt->is_primitive(PrimitiveTypeID::i8) || |
97 | dt->is_primitive(PrimitiveTypeID::i16) || |
98 | dt->is_primitive(PrimitiveTypeID::i32) || |
99 | dt->is_primitive(PrimitiveTypeID::i64) || |
100 | dt->is_primitive(PrimitiveTypeID::u8) || |
101 | dt->is_primitive(PrimitiveTypeID::u16) || |
102 | dt->is_primitive(PrimitiveTypeID::u32) || |
103 | dt->is_primitive(PrimitiveTypeID::u64) || dt->is<QuantIntType>(); |
104 | } |
105 | |
106 | inline bool is_signed(DataType dt) { |
107 | // Shall we return false if is_integral returns false? |
108 | TI_ASSERT(is_integral(dt)); |
109 | if (auto t = dt->cast<QuantIntType>()) |
110 | return t->get_is_signed(); |
111 | return dt->is_primitive(PrimitiveTypeID::i8) || |
112 | dt->is_primitive(PrimitiveTypeID::i16) || |
113 | dt->is_primitive(PrimitiveTypeID::i32) || |
114 | dt->is_primitive(PrimitiveTypeID::i64); |
115 | } |
116 | |
117 | inline bool is_unsigned(DataType dt) { |
118 | TI_ASSERT(is_integral(dt)); |
119 | return !is_signed(dt); |
120 | } |
121 | |
122 | inline DataType to_unsigned(DataType dt) { |
123 | TI_ASSERT(is_signed(dt)); |
124 | if (dt->is_primitive(PrimitiveTypeID::i8)) |
125 | return PrimitiveType::u8; |
126 | else if (dt->is_primitive(PrimitiveTypeID::i16)) |
127 | return PrimitiveType::u16; |
128 | else if (dt->is_primitive(PrimitiveTypeID::i32)) |
129 | return PrimitiveType::u32; |
130 | else if (dt->is_primitive(PrimitiveTypeID::i64)) |
131 | return PrimitiveType::u64; |
132 | else |
133 | return PrimitiveType::unknown; |
134 | } |
135 | |
136 | inline TypedConstant get_max_value(DataType dt) { |
137 | if (dt->is_primitive(PrimitiveTypeID::i8)) { |
138 | return {dt, std::numeric_limits<int8>::max()}; |
139 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
140 | return {dt, std::numeric_limits<int16>::max()}; |
141 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
142 | return {dt, std::numeric_limits<int32>::max()}; |
143 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
144 | return {dt, std::numeric_limits<int64>::max()}; |
145 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
146 | return {dt, std::numeric_limits<uint8>::max()}; |
147 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
148 | return {dt, std::numeric_limits<uint16>::max()}; |
149 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
150 | return {dt, std::numeric_limits<uint32>::max()}; |
151 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
152 | return {dt, std::numeric_limits<uint64>::max()}; |
153 | } else if (dt->is_primitive(PrimitiveTypeID::f32)) { |
154 | return {dt, std::numeric_limits<float32>::max()}; |
155 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
156 | return {dt, std::numeric_limits<float64>::max()}; |
157 | } else { |
158 | TI_NOT_IMPLEMENTED; |
159 | } |
160 | } |
161 | |
162 | inline TypedConstant get_min_value(DataType dt) { |
163 | if (dt->is_primitive(PrimitiveTypeID::i8)) { |
164 | return {dt, std::numeric_limits<int8>::min()}; |
165 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
166 | return {dt, std::numeric_limits<int16>::min()}; |
167 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
168 | return {dt, std::numeric_limits<int32>::min()}; |
169 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
170 | return {dt, std::numeric_limits<int64>::min()}; |
171 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
172 | return {dt, std::numeric_limits<uint8>::min()}; |
173 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
174 | return {dt, std::numeric_limits<uint16>::min()}; |
175 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
176 | return {dt, std::numeric_limits<uint32>::min()}; |
177 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
178 | return {dt, std::numeric_limits<uint64>::min()}; |
179 | } else if (dt->is_primitive(PrimitiveTypeID::f32)) { |
180 | return {dt, std::numeric_limits<float32>::min()}; |
181 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
182 | return {dt, std::numeric_limits<float64>::min()}; |
183 | } else { |
184 | TI_NOT_IMPLEMENTED; |
185 | } |
186 | } |
187 | |
188 | class BitStructTypeBuilder { |
189 | public: |
190 | explicit BitStructTypeBuilder(int max_num_bits) { |
191 | physical_type_ = |
192 | TypeFactory::get_instance().get_primitive_int_type(max_num_bits); |
193 | } |
194 | |
195 | int add_member(Type *member_type) { |
196 | if (auto qflt = member_type->cast<QuantFloatType>()) { |
197 | auto exponent_type = qflt->get_exponent_type(); |
198 | auto exponent_id = -1; |
199 | if (is_placing_shared_exponent_ && current_shared_exponent_ != -1) { |
200 | // Reuse existing exponent |
201 | TI_ASSERT_INFO(member_types_[current_shared_exponent_] == exponent_type, |
202 | "QuantFloatTypes with shared exponents must have " |
203 | "exactly the same exponent type." ); |
204 | exponent_id = current_shared_exponent_; |
205 | } else { |
206 | exponent_id = add_member_impl(exponent_type); |
207 | if (is_placing_shared_exponent_) { |
208 | current_shared_exponent_ = exponent_id; |
209 | } |
210 | } |
211 | auto digits_id = add_member_impl(member_type); |
212 | member_exponents_[digits_id] = exponent_id; |
213 | member_exponent_users_[exponent_id].push_back(digits_id); |
214 | return digits_id; |
215 | } |
216 | return add_member_impl(member_type); |
217 | } |
218 | |
219 | void begin_placing_shared_exponent() { |
220 | TI_ASSERT(!is_placing_shared_exponent_); |
221 | TI_ASSERT(current_shared_exponent_ == -1); |
222 | is_placing_shared_exponent_ = true; |
223 | } |
224 | |
225 | void end_placing_shared_exponent() { |
226 | TI_ASSERT(is_placing_shared_exponent_); |
227 | TI_ASSERT(current_shared_exponent_ != -1); |
228 | current_shared_exponent_ = -1; |
229 | is_placing_shared_exponent_ = false; |
230 | } |
231 | |
232 | BitStructType *build() const { |
233 | return TypeFactory::get_instance().get_bit_struct_type( |
234 | physical_type_, member_types_, member_bit_offsets_, member_exponents_, |
235 | member_exponent_users_); |
236 | } |
237 | |
238 | private: |
239 | int add_member_impl(Type *member_type) { |
240 | int old_num_members = member_types_.size(); |
241 | member_types_.push_back(member_type); |
242 | member_bit_offsets_.push_back(member_total_bits_); |
243 | member_exponents_.push_back(-1); |
244 | member_exponent_users_.push_back({}); |
245 | QuantIntType *member_qit = nullptr; |
246 | if (auto qit = member_type->cast<QuantIntType>()) { |
247 | member_qit = qit; |
248 | } else if (auto qfxt = member_type->cast<QuantFixedType>()) { |
249 | member_qit = qfxt->get_digits_type()->as<QuantIntType>(); |
250 | } else if (auto qflt = member_type->cast<QuantFloatType>()) { |
251 | member_qit = qflt->get_digits_type()->as<QuantIntType>(); |
252 | } else { |
253 | TI_ERROR("Only a QuantType can be a member of a BitStructType." ); |
254 | } |
255 | member_total_bits_ += member_qit->get_num_bits(); |
256 | auto physical_bits = data_type_bits(physical_type_); |
257 | TI_ERROR_IF(member_total_bits_ > physical_bits, |
258 | "BitStructType overflows: {} bits used out of {}." , |
259 | member_total_bits_, physical_bits); |
260 | return old_num_members; |
261 | } |
262 | |
263 | PrimitiveType *physical_type_{nullptr}; |
264 | std::vector<Type *> member_types_; |
265 | std::vector<int> member_bit_offsets_; |
266 | int member_total_bits_{0}; |
267 | std::vector<int> member_exponents_; |
268 | std::vector<std::vector<int>> member_exponent_users_; |
269 | bool is_placing_shared_exponent_{false}; |
270 | int current_shared_exponent_{-1}; |
271 | }; |
272 | |
273 | } // namespace taichi::lang |
274 | |