1 | #pragma once |
2 | |
3 | #include "taichi/common/core.h" |
4 | #include "taichi/util/bit.h" |
5 | #include "taichi/util/hash.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
// Forward declaration; TensorType is defined further down in this header.
class TensorType;

// Identifier of every built-in primitive type. The enumerators are generated
// by the PER_TYPE X-macro: each entry listed in taichi/inc/data_type.inc.h
// becomes one enumerator, in the order that file lists them.
enum class PrimitiveTypeID : int {
#define PER_TYPE(x) x,
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
};
16 | |
17 | class TI_DLL_EXPORT Type { |
18 | public: |
19 | virtual std::string to_string() const = 0; |
20 | |
21 | template <typename T> |
22 | bool is() const { |
23 | return cast<T>() != nullptr; |
24 | } |
25 | |
26 | template <typename T> |
27 | const T *cast() const { |
28 | return dynamic_cast<const T *>(this); |
29 | } |
30 | |
31 | template <typename T> |
32 | T *cast() { |
33 | return dynamic_cast<T *>(this); |
34 | } |
35 | |
36 | template <typename T> |
37 | T *as() { |
38 | auto p = dynamic_cast<T *>(this); |
39 | TI_ASSERT_INFO(p != nullptr, "Cannot treat {} as {}" , this->to_string(), |
40 | typeid(T).name()); |
41 | return p; |
42 | } |
43 | |
44 | template <typename T> |
45 | const T *as() const { |
46 | auto p = dynamic_cast<const T *>(this); |
47 | TI_ASSERT_INFO(p != nullptr, "Cannot treat {} as {}" , this->to_string(), |
48 | typeid(T).name()); |
49 | return p; |
50 | } |
51 | |
52 | bool is_primitive(PrimitiveTypeID type) const; |
53 | |
54 | virtual Type *get_compute_type() { |
55 | TI_NOT_IMPLEMENTED; |
56 | } |
57 | |
58 | virtual ~Type() { |
59 | } |
60 | }; |
61 | |
62 | // A "Type" handle. This should be removed later. |
63 | class TI_DLL_EXPORT DataType { |
64 | public: |
65 | DataType(); |
66 | |
67 | // NOLINTNEXTLINE(google-explicit-constructor) |
68 | DataType(const Type *ptr) : ptr_((Type *)ptr) { |
69 | } |
70 | |
71 | DataType(const DataType &o) : ptr_(o.ptr_) { |
72 | } |
73 | |
74 | bool operator==(const DataType &o) const { |
75 | return ptr_ == o.ptr_; |
76 | } |
77 | |
78 | bool operator!=(const DataType &o) const { |
79 | return !(*this == o); |
80 | } |
81 | |
82 | std::size_t hash() const; |
83 | |
84 | std::string to_string() const { |
85 | return ptr_->to_string(); |
86 | }; |
87 | |
88 | // NOLINTNEXTLINE(google-explicit-constructor) |
89 | operator const Type *() const { |
90 | return ptr_; |
91 | } |
92 | |
93 | // NOLINTNEXTLINE(google-explicit-constructor) |
94 | operator Type *() { |
95 | return ptr_; |
96 | } |
97 | |
98 | // Temporary API and members |
99 | // for LegacyVectorType-compatibility |
100 | |
101 | Type *operator->() const { |
102 | return ptr_; |
103 | } |
104 | |
105 | DataType &operator=(const DataType &o) { |
106 | ptr_ = o.ptr_; |
107 | return *this; |
108 | } |
109 | |
110 | bool is_pointer() const; |
111 | |
112 | void set_is_pointer(bool ptr); |
113 | |
114 | DataType ptr_removed() const; |
115 | |
116 | std::vector<int> get_shape() const; |
117 | |
118 | DataType get_element_type() const; |
119 | |
120 | private: |
121 | Type *ptr_; |
122 | }; |
123 | |
124 | // Note that all types are immutable once created. |
125 | |
126 | class TI_DLL_EXPORT PrimitiveType : public Type { |
127 | public: |
128 | #define PER_TYPE(x) static DataType x; |
129 | #include "taichi/inc/data_type.inc.h" |
130 | #undef PER_TYPE |
131 | |
132 | // TODO(type): make 'type' private and add a const getter |
133 | PrimitiveTypeID type; |
134 | |
135 | explicit PrimitiveType(PrimitiveTypeID type) : type(type) { |
136 | } |
137 | |
138 | std::string to_string() const override; |
139 | |
140 | Type *get_compute_type() override { |
141 | return this; |
142 | } |
143 | |
144 | static DataType get(PrimitiveTypeID type); |
145 | }; |
146 | |
147 | class PointerType : public Type { |
148 | public: |
149 | PointerType(Type *pointee, bool is_bit_pointer) |
150 | : pointee_(pointee), is_bit_pointer_(is_bit_pointer) { |
151 | } |
152 | |
153 | Type *get_pointee_type() const { |
154 | return pointee_; |
155 | } |
156 | |
157 | auto get_addr_space() const { |
158 | return addr_space_; |
159 | } |
160 | |
161 | bool is_bit_pointer() const { |
162 | return is_bit_pointer_; |
163 | } |
164 | |
165 | std::string to_string() const override; |
166 | |
167 | private: |
168 | Type *pointee_{nullptr}; |
169 | int addr_space_{0}; // TODO: make this an enum |
170 | bool is_bit_pointer_{false}; |
171 | }; |
172 | |
173 | class TensorType : public Type { |
174 | public: |
175 | TensorType(std::vector<int> shape, Type *element) |
176 | : shape_(std::move(shape)), element_(element) { |
177 | } |
178 | |
179 | Type *get_element_type() const { |
180 | return element_; |
181 | } |
182 | |
183 | int get_num_elements() const { |
184 | int num_elements = 1; |
185 | for (int i = 0; i < (int)shape_.size(); ++i) |
186 | num_elements *= shape_[i]; |
187 | return num_elements; |
188 | } |
189 | |
190 | std::vector<int> get_shape() const { |
191 | return shape_; |
192 | } |
193 | |
194 | Type *get_compute_type() override { |
195 | return this; |
196 | } |
197 | |
198 | std::string to_string() const override; |
199 | |
200 | size_t get_element_offset(int ind) const; |
201 | |
202 | private: |
203 | std::vector<int> shape_; |
204 | Type *element_{nullptr}; |
205 | }; |
206 | |
207 | struct StructMember { |
208 | const Type *type; |
209 | std::string name; |
210 | size_t offset{0}; |
211 | bool operator==(const StructMember &other) const { |
212 | return type == other.type && name == other.name && offset == other.offset; |
213 | } |
214 | }; |
215 | |
216 | class StructType : public Type { |
217 | public: |
218 | explicit StructType(const std::vector<StructMember> &elements, |
219 | const std::string &layout = "none" ) |
220 | : elements_(elements), layout_(layout) { |
221 | } |
222 | |
223 | std::string to_string() const override; |
224 | |
225 | const std::string &get_layout() const { |
226 | return layout_; |
227 | } |
228 | |
229 | const Type *get_element_type(const std::vector<int> &indices) const; |
230 | size_t get_element_offset(const std::vector<int> &indices) const; |
231 | const std::vector<StructMember> &elements() const { |
232 | return elements_; |
233 | } |
234 | |
235 | int get_num_elements() const { |
236 | int num = 0; |
237 | for (const auto &element : elements_) { |
238 | if (auto struct_type = element.type->cast<StructType>()) { |
239 | num += struct_type->get_num_elements(); |
240 | } else if (auto tensor_type = element.type->cast<TensorType>()) { |
241 | num += tensor_type->get_num_elements(); |
242 | } else { |
243 | TI_ASSERT(element.type->is<PrimitiveType>()); |
244 | num += 1; |
245 | } |
246 | } |
247 | return num; |
248 | } |
249 | |
250 | Type *get_compute_type() override { |
251 | return this; |
252 | } |
253 | |
254 | private: |
255 | std::vector<StructMember> elements_; |
256 | std::string layout_; |
257 | }; |
258 | |
259 | class QuantIntType : public Type { |
260 | public: |
261 | QuantIntType(int num_bits, bool is_signed, Type *compute_type = nullptr); |
262 | |
263 | std::string to_string() const override; |
264 | |
265 | Type *get_compute_type() override { |
266 | return compute_type_; |
267 | } |
268 | |
269 | int get_num_bits() const { |
270 | return num_bits_; |
271 | } |
272 | |
273 | bool get_is_signed() const { |
274 | return is_signed_; |
275 | } |
276 | |
277 | private: |
278 | // TODO(type): for now we can uniformly use i32 as the "compute_type". It may |
279 | // be a good idea to make "compute_type" also customizable. |
280 | Type *compute_type_{nullptr}; |
281 | int num_bits_{32}; |
282 | bool is_signed_{true}; |
283 | }; |
284 | |
285 | class QuantFixedType : public Type { |
286 | public: |
287 | QuantFixedType(Type *digits_type, Type *compute_type, float64 scale); |
288 | |
289 | std::string to_string() const override; |
290 | |
291 | bool get_is_signed() const; |
292 | |
293 | Type *get_digits_type() { |
294 | return digits_type_; |
295 | } |
296 | |
297 | Type *get_compute_type() override { |
298 | return compute_type_; |
299 | } |
300 | |
301 | float64 get_scale() const { |
302 | return scale_; |
303 | } |
304 | |
305 | private: |
306 | Type *digits_type_{nullptr}; |
307 | Type *compute_type_{nullptr}; |
308 | float64 scale_{1.0}; |
309 | }; |
310 | |
311 | class QuantFloatType : public Type { |
312 | public: |
313 | QuantFloatType(Type *digits_type, Type *exponent_type, Type *compute_type); |
314 | |
315 | std::string to_string() const override; |
316 | |
317 | Type *get_digits_type() { |
318 | return digits_type_; |
319 | } |
320 | |
321 | Type *get_exponent_type() { |
322 | return exponent_type_; |
323 | } |
324 | |
325 | int get_exponent_conversion_offset() const; |
326 | |
327 | int get_digit_bits() const; |
328 | |
329 | bool get_is_signed() const; |
330 | |
331 | Type *get_compute_type() override { |
332 | return compute_type_; |
333 | } |
334 | |
335 | private: |
336 | Type *digits_type_{nullptr}; |
337 | Type *exponent_type_{nullptr}; |
338 | Type *compute_type_{nullptr}; |
339 | }; |
340 | |
341 | class BitStructType : public Type { |
342 | public: |
343 | BitStructType(PrimitiveType *physical_type, |
344 | const std::vector<Type *> &member_types, |
345 | const std::vector<int> &member_bit_offsets, |
346 | const std::vector<int> &member_exponents, |
347 | const std::vector<std::vector<int>> &member_exponent_users); |
348 | |
349 | std::string to_string() const override; |
350 | |
351 | PrimitiveType *get_physical_type() const { |
352 | return physical_type_; |
353 | } |
354 | |
355 | int get_num_members() const { |
356 | return (int)member_types_.size(); |
357 | } |
358 | |
359 | Type *get_member_type(int i) const { |
360 | return member_types_[i]; |
361 | } |
362 | |
363 | int get_member_bit_offset(int i) const { |
364 | return member_bit_offsets_[i]; |
365 | } |
366 | |
367 | bool get_member_owns_shared_exponent(int i) const { |
368 | return member_exponents_[i] != -1 && |
369 | member_exponent_users_[member_exponents_[i]].size() > 1; |
370 | } |
371 | |
372 | int get_member_exponent(int i) const { |
373 | return member_exponents_[i]; |
374 | } |
375 | |
376 | const std::vector<int> &get_member_exponent_users(int i) const { |
377 | return member_exponent_users_[i]; |
378 | } |
379 | |
380 | private: |
381 | PrimitiveType *physical_type_; |
382 | std::vector<Type *> member_types_; |
383 | std::vector<int> member_bit_offsets_; |
384 | std::vector<int> member_exponents_; |
385 | std::vector<std::vector<int>> member_exponent_users_; |
386 | }; |
387 | |
388 | class QuantArrayType : public Type { |
389 | public: |
390 | QuantArrayType(PrimitiveType *physical_type, |
391 | Type *element_type_, |
392 | int num_elements_) |
393 | : physical_type_(physical_type), |
394 | element_type_(element_type_), |
395 | num_elements_(num_elements_) { |
396 | if (auto qit = element_type_->cast<QuantIntType>()) { |
397 | element_num_bits_ = qit->get_num_bits(); |
398 | } else if (auto qfxt = element_type_->cast<QuantFixedType>()) { |
399 | element_num_bits_ = |
400 | qfxt->get_digits_type()->as<QuantIntType>()->get_num_bits(); |
401 | } else { |
402 | TI_ERROR("Quant array only supports quant int/fixed type for now." ); |
403 | } |
404 | } |
405 | |
406 | std::string to_string() const override; |
407 | |
408 | PrimitiveType *get_physical_type() const { |
409 | return physical_type_; |
410 | } |
411 | |
412 | Type *get_element_type() const { |
413 | return element_type_; |
414 | } |
415 | |
416 | int get_num_elements() const { |
417 | return num_elements_; |
418 | } |
419 | |
420 | int get_element_num_bits() const { |
421 | return element_num_bits_; |
422 | } |
423 | |
424 | private: |
425 | PrimitiveType *physical_type_; |
426 | Type *element_type_; |
427 | int num_elements_; |
428 | int element_num_bits_; |
429 | }; |
430 | |
431 | class TypedConstant { |
432 | public: |
433 | DataType dt; |
434 | union { |
435 | uint64 value_bits; |
436 | int32 val_i32; |
437 | float32 val_f32; |
438 | int64 val_i64; |
439 | float64 val_f64; |
440 | int8 val_i8; |
441 | int16 val_i16; |
442 | uint8 val_u8; |
443 | uint16 val_u16; |
444 | uint32 val_u32; |
445 | uint64 val_u64; |
446 | }; |
447 | |
448 | public: |
449 | TypedConstant() : dt(PrimitiveType::unknown) { |
450 | } |
451 | |
452 | explicit TypedConstant(DataType dt) : dt(dt) { |
453 | TI_ASSERT_INFO(dt->is<PrimitiveType>(), |
454 | "TypedConstant can only be PrimitiveType, got {}" , |
455 | dt->to_string()); |
456 | value_bits = 0; |
457 | } |
458 | |
459 | explicit TypedConstant(int32 x) : dt(PrimitiveType::i32), val_i32(x) { |
460 | } |
461 | |
462 | explicit TypedConstant(float32 x) : dt(PrimitiveType::f32), val_f32(x) { |
463 | } |
464 | |
465 | explicit TypedConstant(int64 x) : dt(PrimitiveType::i64), val_i64(x) { |
466 | } |
467 | |
468 | explicit TypedConstant(float64 x) : dt(PrimitiveType::f64), val_f64(x) { |
469 | } |
470 | |
471 | explicit TypedConstant(int8 x) : dt(PrimitiveType::i8), val_i8(x) { |
472 | } |
473 | |
474 | explicit TypedConstant(int16 x) : dt(PrimitiveType::i16), val_i16(x) { |
475 | } |
476 | |
477 | explicit TypedConstant(uint8 x) : dt(PrimitiveType::u8), val_u8(x) { |
478 | } |
479 | |
480 | explicit TypedConstant(uint16 x) : dt(PrimitiveType::u16), val_u16(x) { |
481 | } |
482 | |
483 | explicit TypedConstant(uint32 x) : dt(PrimitiveType::u32), val_u32(x) { |
484 | } |
485 | |
486 | explicit TypedConstant(uint64 x) : dt(PrimitiveType::u64), val_u64(x) { |
487 | } |
488 | |
489 | template <typename T> |
490 | TypedConstant(DataType dt, const T &value) : dt(dt) { |
491 | // TODO: loud failure on pointers |
492 | dt.set_is_pointer(false); |
493 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
494 | val_f32 = value; |
495 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
496 | val_i32 = value; |
497 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
498 | val_i64 = value; |
499 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
500 | val_f64 = value; |
501 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
502 | val_i8 = value; |
503 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
504 | val_i16 = value; |
505 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
506 | val_u8 = value; |
507 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
508 | val_u16 = value; |
509 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
510 | val_u32 = value; |
511 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
512 | val_u64 = value; |
513 | } else { |
514 | TI_NOT_IMPLEMENTED |
515 | } |
516 | } |
517 | |
518 | template <typename T> |
519 | bool equal_value(const T &value) const { |
520 | return equal_type_and_value(TypedConstant(dt, value)); |
521 | } |
522 | |
523 | std::string stringify() const; |
524 | |
525 | bool equal_type_and_value(const TypedConstant &o) const; |
526 | |
527 | bool operator==(const TypedConstant &o) const { |
528 | return equal_type_and_value(o); |
529 | } |
530 | |
531 | int32 &val_int32(); |
532 | float32 &val_float32(); |
533 | int64 &val_int64(); |
534 | float64 &val_float64(); |
535 | int8 &val_int8(); |
536 | int16 &val_int16(); |
537 | uint8 &val_uint8(); |
538 | uint16 &val_uint16(); |
539 | uint32 &val_uint32(); |
540 | uint64 &val_uint64(); |
541 | int64 val_int() const; |
542 | uint64 val_uint() const; |
543 | float64 val_float() const; |
544 | int64 val_as_int64() const; // unifies val_int() and val_uint() |
545 | float64 val_cast_to_float64() const; |
546 | }; |
547 | |
548 | } // namespace taichi::lang |
549 | |
550 | namespace taichi::hashing { |
551 | |
552 | template <> |
553 | struct Hasher<lang::StructMember> { |
554 | public: |
555 | size_t operator()(lang::StructMember const &member) const { |
556 | size_t ret = hash_value(member.type); |
557 | hash_combine(ret, member.name); |
558 | hash_combine(ret, member.offset); |
559 | return ret; |
560 | } |
561 | }; |
562 | |
563 | } // namespace taichi::hashing |
564 | |