1 | #include "taichi/ir/type.h" |
2 | |
3 | #include "taichi/ir/type_factory.h" |
4 | #include "taichi/ir/type_utils.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
// Note: these primitive types should never be freed. They are supposed to live
// for the whole lifetime of the process. This is a temporary solution. Later we
// should manage their ownership more systematically.
11 | |
12 | // This part doesn't look good, but we will remove it soon anyway. |
// X-macro expansion: while PER_TYPE is defined, "taichi/inc/data_type.inc.h"
// is included (presumably it invokes PER_TYPE once per primitive type name —
// confirm against that file). Each PER_TYPE(x) expands to the definition of
// PrimitiveType::x, initialized with the Type returned by
// TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::x).
#define PER_TYPE(x)                     \
  DataType PrimitiveType::x = DataType( \
      TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::x));

#include "taichi/inc/data_type.inc.h"
// Undefine so the include file can be re-expanded with a different PER_TYPE.
#undef PER_TYPE
19 | |
// Default constructor: the new DataType holds the same Type pointer as
// PrimitiveType::unknown.
DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
}
22 | |
23 | DataType PrimitiveType::get(PrimitiveTypeID t) { |
24 | if (false) { |
25 | } |
26 | #define PER_TYPE(x) else if (t == PrimitiveTypeID::x) return PrimitiveType::x; |
27 | #include "taichi/inc/data_type.inc.h" |
28 | #undef PER_TYPE |
29 | else { |
30 | TI_NOT_IMPLEMENTED |
31 | } |
32 | } |
33 | |
34 | std::size_t DataType::hash() const { |
35 | if (auto primitive = ptr_->cast<PrimitiveType>()) { |
36 | return (std::size_t)primitive->type; |
37 | } else if (auto pointer = ptr_->cast<PointerType>()) { |
38 | return 10007 + DataType(pointer->get_pointee_type()).hash(); |
39 | } else { |
40 | TI_NOT_IMPLEMENTED |
41 | } |
42 | } |
43 | |
44 | bool DataType::is_pointer() const { |
45 | return ptr_->is<PointerType>(); |
46 | } |
47 | |
48 | void DataType::set_is_pointer(bool is_ptr) { |
49 | if (is_ptr && !ptr_->is<PointerType>()) { |
50 | ptr_ = TypeFactory::get_instance().get_pointer_type(ptr_); |
51 | } |
52 | if (!is_ptr && ptr_->is<PointerType>()) { |
53 | ptr_ = ptr_->cast<PointerType>()->get_pointee_type(); |
54 | } |
55 | } |
56 | |
57 | DataType DataType::ptr_removed() const { |
58 | auto t = ptr_; |
59 | auto ptr_type = t->cast<PointerType>(); |
60 | if (ptr_type) { |
61 | return DataType(ptr_type->get_pointee_type()); |
62 | } else { |
63 | return *this; |
64 | } |
65 | } |
66 | |
67 | std::vector<int> DataType::get_shape() const { |
68 | if (ptr_->is<TensorType>()) { |
69 | return ptr_->as<TensorType>()->get_shape(); |
70 | } |
71 | |
72 | return {}; |
73 | } |
74 | |
75 | DataType DataType::get_element_type() const { |
76 | if (ptr_->is<TensorType>()) { |
77 | return ptr_->as<TensorType>()->get_element_type(); |
78 | } |
79 | |
80 | return *this; |
81 | } |
82 | |
83 | std::string PrimitiveType::to_string() const { |
84 | return data_type_name(DataType(const_cast<PrimitiveType *>(this))); |
85 | } |
86 | |
87 | std::string PointerType::to_string() const { |
88 | if (is_bit_pointer_) { |
89 | // "^" for bit-level pointers |
90 | return fmt::format("^{}" , pointee_->to_string()); |
91 | } else { |
92 | // "*" for C-style byte-level pointers |
93 | return fmt::format("*{}" , pointee_->to_string()); |
94 | } |
95 | } |
96 | |
97 | std::string TensorType::to_string() const { |
98 | std::string s = "[Tensor (" ; |
99 | for (int i = 0; i < (int)shape_.size(); ++i) { |
100 | s += fmt::format(i == 0 ? "{}" : ", {}" , shape_[i]); |
101 | } |
102 | s += fmt::format(") {}]" , element_->to_string()); |
103 | return s; |
104 | } |
105 | |
106 | size_t TensorType::get_element_offset(int ind) const { |
107 | return data_type_size(element_) * ind; |
108 | } |
109 | |
110 | std::string StructType::to_string() const { |
111 | std::string s = fmt::format("struct[{}]{{" , layout_); |
112 | for (int i = 0; i < elements_.size(); i++) { |
113 | if (i) { |
114 | s += ", " ; |
115 | } |
116 | s += fmt::format("{}({}, at {}B): {}" , i, elements_[i].name, |
117 | elements_[i].offset, elements_[i].type->to_string()); |
118 | } |
119 | s += "}" ; |
120 | return s; |
121 | } |
122 | |
123 | const Type *StructType::get_element_type( |
124 | const std::vector<int> &indices) const { |
125 | const Type *type_now = this; |
126 | for (auto ind : indices) { |
127 | if (auto tensor_type = type_now->cast<TensorType>()) { |
128 | TI_ASSERT(ind < tensor_type->get_num_elements()) |
129 | type_now = tensor_type->get_element_type(); |
130 | } else { |
131 | type_now = type_now->as<StructType>()->elements_[ind].type; |
132 | } |
133 | } |
134 | return type_now; |
135 | } |
136 | |
137 | size_t StructType::get_element_offset(const std::vector<int> &indices) const { |
138 | const Type *type_now = this; |
139 | size_t offset = 0; |
140 | for (auto ind : indices) { |
141 | if (auto tensor_type = type_now->cast<TensorType>()) { |
142 | TI_ASSERT(ind < tensor_type->get_num_elements()) |
143 | offset += tensor_type->get_element_offset(ind); |
144 | type_now = tensor_type->get_element_type(); |
145 | } else { |
146 | offset += type_now->as<StructType>()->elements_[ind].offset; |
147 | type_now = type_now->as<StructType>()->elements_[ind].type; |
148 | } |
149 | } |
150 | return offset; |
151 | } |
152 | |
153 | bool Type::is_primitive(PrimitiveTypeID type) const { |
154 | if (auto p = cast<PrimitiveType>()) { |
155 | return p->type == type; |
156 | } else { |
157 | return false; |
158 | } |
159 | } |
160 | |
161 | std::string QuantIntType::to_string() const { |
162 | return fmt::format("q{}{}" , is_signed_ ? 'i' : 'u', num_bits_); |
163 | } |
164 | |
165 | QuantIntType::QuantIntType(int num_bits, bool is_signed, Type *compute_type) |
166 | : compute_type_(compute_type), num_bits_(num_bits), is_signed_(is_signed) { |
167 | if (compute_type == nullptr) { |
168 | auto type_id = is_signed ? PrimitiveTypeID::i32 : PrimitiveTypeID::u32; |
169 | this->compute_type_ = |
170 | TypeFactory::get_instance().get_primitive_type(type_id); |
171 | } |
172 | } |
173 | |
// Builds a quantized fixed-point type from a digits type, a compute type and
// a scale. The arguments are stored as given and then validated:
//  - digits_type must be a QuantIntType;
//  - compute_type must be a PrimitiveType for which is_real() holds.
// Both pointers are dereferenced by the assertions, so neither may be null.
QuantFixedType::QuantFixedType(Type *digits_type,
                               Type *compute_type,
                               float64 scale)
    : digits_type_(digits_type), compute_type_(compute_type), scale_(scale) {
  TI_ASSERT(digits_type->is<QuantIntType>());
  TI_ASSERT(compute_type->is<PrimitiveType>());
  TI_ASSERT(is_real(compute_type));
}
182 | |
183 | std::string QuantFixedType::to_string() const { |
184 | return fmt::format("qfx(d={} c={} s={})" , digits_type_->to_string(), |
185 | compute_type_->to_string(), scale_); |
186 | } |
187 | |
188 | bool QuantFixedType::get_is_signed() const { |
189 | return digits_type_->as<QuantIntType>()->get_is_signed(); |
190 | } |
191 | |
// Builds a quantized floating-point type from a digits type, an exponent type
// and a compute type. The arguments are stored as given and then validated by
// the assertions below. All three pointers are dereferenced by the
// assertions, so none of them may be null.
QuantFloatType::QuantFloatType(Type *digits_type,
                               Type *exponent_type,
                               Type *compute_type)
    : digits_type_(digits_type),
      exponent_type_(exponent_type),
      compute_type_(compute_type) {
  TI_ASSERT(digits_type->is<QuantIntType>());
  // We only support f32 as the compute type when using exponents
  TI_ASSERT(compute_type_->is_primitive(PrimitiveTypeID::f32));
  // Exponent must be an unsigned quant int of at most 8 bits
  TI_ASSERT(exponent_type->is<QuantIntType>());
  TI_ASSERT(exponent_type->as<QuantIntType>()->get_num_bits() <= 8);
  TI_ASSERT(exponent_type->as<QuantIntType>()->get_is_signed() == false);
  // Digit bits (sign bit excluded, see get_digit_bits) must not exceed 23
  TI_ASSERT(get_digit_bits() <= 23);
}
207 | |
208 | std::string QuantFloatType::to_string() const { |
209 | return fmt::format("qfl(d={} e={} c={})" , digits_type_->to_string(), |
210 | exponent_type_->to_string(), compute_type_->to_string()); |
211 | } |
212 | |
213 | int QuantFloatType::get_exponent_conversion_offset() const { |
214 | // Note that f32 has exponent offset -127 |
215 | return 127 - (1 << (exponent_type_->as<QuantIntType>()->get_num_bits() - 1)) + |
216 | 1; |
217 | } |
218 | |
219 | int QuantFloatType::get_digit_bits() const { |
220 | return digits_type_->as<QuantIntType>()->get_num_bits() - |
221 | (int)get_is_signed(); |
222 | } |
223 | |
224 | bool QuantFloatType::get_is_signed() const { |
225 | return digits_type_->as<QuantIntType>()->get_is_signed(); |
226 | } |
227 | |
228 | BitStructType::BitStructType( |
229 | PrimitiveType *physical_type, |
230 | const std::vector<Type *> &member_types, |
231 | const std::vector<int> &member_bit_offsets, |
232 | const std::vector<int> &member_exponents, |
233 | const std::vector<std::vector<int>> &member_exponent_users) |
234 | : physical_type_(physical_type), |
235 | member_types_(member_types), |
236 | member_bit_offsets_(member_bit_offsets), |
237 | member_exponents_(member_exponents), |
238 | member_exponent_users_(member_exponent_users) { |
239 | TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); |
240 | TI_ASSERT(member_types_.size() == member_exponents_.size()); |
241 | TI_ASSERT(member_types_.size() == member_exponent_users_.size()); |
242 | int physical_type_bits = data_type_bits(physical_type_); |
243 | int member_total_bits = 0; |
244 | for (auto i = 0; i < member_types_.size(); ++i) { |
245 | QuantIntType *component_qit = nullptr; |
246 | if (auto qit = member_types_[i]->cast<QuantIntType>()) { |
247 | component_qit = qit; |
248 | } else if (auto qfxt = member_types_[i]->cast<QuantFixedType>()) { |
249 | component_qit = qfxt->get_digits_type()->as<QuantIntType>(); |
250 | } else { |
251 | TI_ASSERT(member_types_[i]->is<QuantFloatType>()); |
252 | auto qflt = member_types_[i]->as<QuantFloatType>(); |
253 | component_qit = qflt->get_digits_type()->as<QuantIntType>(); |
254 | } |
255 | TI_ASSERT(member_bit_offsets_[i] == member_total_bits); |
256 | member_total_bits += component_qit->get_num_bits(); |
257 | } |
258 | TI_ASSERT(physical_type_bits >= member_total_bits); |
259 | for (auto i = 0; i < member_types_.size(); ++i) { |
260 | auto exponent = member_exponents_[i]; |
261 | if (exponent != -1) { |
262 | TI_ASSERT(std::find(member_exponent_users_[exponent].begin(), |
263 | member_exponent_users_[exponent].end(), |
264 | i) != member_exponent_users_[exponent].end()); |
265 | } |
266 | for (auto user : member_exponent_users_[i]) { |
267 | TI_ASSERT(member_exponents_[user] == i); |
268 | } |
269 | } |
270 | } |
271 | |
272 | std::string BitStructType::to_string() const { |
273 | std::string str = "bs(" ; |
274 | int num_members = (int)member_bit_offsets_.size(); |
275 | for (int i = 0; i < num_members; i++) { |
276 | str += fmt::format("{}: {}@{}" , i, member_types_[i]->to_string(), |
277 | member_bit_offsets_[i]); |
278 | if (member_exponents_[i] != -1) { |
279 | str += fmt::format(" {}exp={}" , |
280 | get_member_owns_shared_exponent(i) ? "shared_" : "" , |
281 | member_exponents_[i]); |
282 | } |
283 | if (i + 1 < num_members) { |
284 | str += ", " ; |
285 | } |
286 | } |
287 | return str + ")" ; |
288 | } |
289 | |
290 | std::string QuantArrayType::to_string() const { |
291 | return fmt::format("qa({}x{})" , element_type_->to_string(), num_elements_); |
292 | } |
293 | |
294 | std::string TypedConstant::stringify() const { |
295 | // TODO: remove the line below after type system upgrade. |
296 | auto dt = this->dt.ptr_removed(); |
297 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
298 | return fmt::format("{}" , val_f32); |
299 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
300 | return fmt::format("{}" , val_i32); |
301 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
302 | return fmt::format("{}" , val_i64); |
303 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
304 | return fmt::format("{}" , val_f64); |
305 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
306 | return fmt::format("{}" , val_i8); |
307 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
308 | return fmt::format("{}" , val_i16); |
309 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
310 | return fmt::format("{}" , val_u8); |
311 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
312 | return fmt::format("{}" , val_u16); |
313 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
314 | return fmt::format("{}" , val_u32); |
315 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
316 | return fmt::format("{}" , val_u64); |
317 | } else { |
318 | TI_P(data_type_name(dt)); |
319 | TI_NOT_IMPLEMENTED |
320 | return "" ; |
321 | } |
322 | } |
323 | |
324 | bool TypedConstant::equal_type_and_value(const TypedConstant &o) const { |
325 | if (dt != o.dt) |
326 | return false; |
327 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
328 | return val_f32 == o.val_f32; |
329 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
330 | return val_i32 == o.val_i32; |
331 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
332 | return val_i64 == o.val_i64; |
333 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
334 | return val_f64 == o.val_f64; |
335 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
336 | return val_i8 == o.val_i8; |
337 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
338 | return val_i16 == o.val_i16; |
339 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
340 | return val_u8 == o.val_u8; |
341 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
342 | return val_u16 == o.val_u16; |
343 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
344 | return val_u32 == o.val_u32; |
345 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
346 | return val_u64 == o.val_u64; |
347 | } else { |
348 | TI_NOT_IMPLEMENTED |
349 | return false; |
350 | } |
351 | } |
352 | |
// Returns a mutable reference to the int32 value field. Asserts that the
// constant's data type equals get_data_type<int32>().
int32 &TypedConstant::val_int32() {
  TI_ASSERT(get_data_type<int32>() == dt);
  return val_i32;
}
357 | |
// Returns a mutable reference to the float32 value field. Asserts that the
// constant's data type equals get_data_type<float32>().
float32 &TypedConstant::val_float32() {
  TI_ASSERT(get_data_type<float32>() == dt);
  return val_f32;
}
362 | |
// Returns a mutable reference to the int64 value field. Asserts that the
// constant's data type equals get_data_type<int64>().
int64 &TypedConstant::val_int64() {
  TI_ASSERT(get_data_type<int64>() == dt);
  return val_i64;
}
367 | |
// Returns a mutable reference to the float64 value field. Asserts that the
// constant's data type equals get_data_type<float64>().
float64 &TypedConstant::val_float64() {
  TI_ASSERT(get_data_type<float64>() == dt);
  return val_f64;
}
372 | |
// Returns a mutable reference to the int8 value field. Asserts that the
// constant's data type equals get_data_type<int8>().
int8 &TypedConstant::val_int8() {
  TI_ASSERT(get_data_type<int8>() == dt);
  return val_i8;
}
377 | |
// Returns a mutable reference to the int16 value field. Asserts that the
// constant's data type equals get_data_type<int16>().
int16 &TypedConstant::val_int16() {
  TI_ASSERT(get_data_type<int16>() == dt);
  return val_i16;
}
382 | |
// Returns a mutable reference to the uint8 value field. Asserts that the
// constant's data type equals get_data_type<uint8>().
uint8 &TypedConstant::val_uint8() {
  TI_ASSERT(get_data_type<uint8>() == dt);
  return val_u8;
}
387 | |
// Returns a mutable reference to the uint16 value field. Asserts that the
// constant's data type equals get_data_type<uint16>().
uint16 &TypedConstant::val_uint16() {
  TI_ASSERT(get_data_type<uint16>() == dt);
  return val_u16;
}
392 | |
// Returns a mutable reference to the uint32 value field. Asserts that the
// constant's data type equals get_data_type<uint32>().
uint32 &TypedConstant::val_uint32() {
  TI_ASSERT(get_data_type<uint32>() == dt);
  return val_u32;
}
397 | |
// Returns a mutable reference to the uint64 value field. Asserts that the
// constant's data type equals get_data_type<uint64>().
uint64 &TypedConstant::val_uint64() {
  TI_ASSERT(get_data_type<uint64>() == dt);
  return val_u64;
}
402 | |
403 | int64 TypedConstant::val_int() const { |
404 | TI_ASSERT(is_signed(dt)); |
405 | if (dt->is_primitive(PrimitiveTypeID::i32)) { |
406 | return val_i32; |
407 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
408 | return val_i64; |
409 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
410 | return val_i8; |
411 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
412 | return val_i16; |
413 | } else { |
414 | TI_NOT_IMPLEMENTED |
415 | } |
416 | } |
417 | |
418 | uint64 TypedConstant::val_uint() const { |
419 | TI_ASSERT(is_unsigned(dt)); |
420 | if (dt->is_primitive(PrimitiveTypeID::u32)) { |
421 | return val_u32; |
422 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
423 | return val_u64; |
424 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
425 | return val_u8; |
426 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
427 | return val_u16; |
428 | } else { |
429 | TI_NOT_IMPLEMENTED |
430 | } |
431 | } |
432 | |
433 | float64 TypedConstant::val_float() const { |
434 | TI_ASSERT(is_real(dt)); |
435 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
436 | return val_f32; |
437 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
438 | return val_f64; |
439 | } else { |
440 | TI_NOT_IMPLEMENTED |
441 | } |
442 | } |
443 | |
444 | int64 TypedConstant::val_as_int64() const { |
445 | if (is_real(dt)) { |
446 | TI_ERROR("Cannot cast floating point type {} to int64." , dt->to_string()); |
447 | } else if (is_signed(dt)) { |
448 | return val_int(); |
449 | } else if (is_unsigned(dt)) { |
450 | return val_uint(); |
451 | } else { |
452 | TI_NOT_IMPLEMENTED |
453 | } |
454 | } |
455 | |
456 | float64 TypedConstant::val_cast_to_float64() const { |
457 | if (is_real(dt)) |
458 | return val_float(); |
459 | else if (is_signed(dt)) |
460 | return val_int(); |
461 | else if (is_unsigned(dt)) |
462 | return val_uint(); |
463 | else { |
464 | TI_NOT_IMPLEMENTED |
465 | } |
466 | } |
467 | |
468 | } // namespace taichi::lang |
469 | |