1#pragma once
2
3#include "taichi/common/core.h"
4#include "taichi/util/bit.h"
5#include "taichi/util/hash.h"
6
7namespace taichi::lang {
8
9class TensorType;
10
11enum class PrimitiveTypeID : int {
12#define PER_TYPE(x) x,
13#include "taichi/inc/data_type.inc.h"
14#undef PER_TYPE
15};
16
17class TI_DLL_EXPORT Type {
18 public:
19 virtual std::string to_string() const = 0;
20
21 template <typename T>
22 bool is() const {
23 return cast<T>() != nullptr;
24 }
25
26 template <typename T>
27 const T *cast() const {
28 return dynamic_cast<const T *>(this);
29 }
30
31 template <typename T>
32 T *cast() {
33 return dynamic_cast<T *>(this);
34 }
35
36 template <typename T>
37 T *as() {
38 auto p = dynamic_cast<T *>(this);
39 TI_ASSERT_INFO(p != nullptr, "Cannot treat {} as {}", this->to_string(),
40 typeid(T).name());
41 return p;
42 }
43
44 template <typename T>
45 const T *as() const {
46 auto p = dynamic_cast<const T *>(this);
47 TI_ASSERT_INFO(p != nullptr, "Cannot treat {} as {}", this->to_string(),
48 typeid(T).name());
49 return p;
50 }
51
52 bool is_primitive(PrimitiveTypeID type) const;
53
54 virtual Type *get_compute_type() {
55 TI_NOT_IMPLEMENTED;
56 }
57
58 virtual ~Type() {
59 }
60};
61
62// A "Type" handle. This should be removed later.
63class TI_DLL_EXPORT DataType {
64 public:
65 DataType();
66
67 // NOLINTNEXTLINE(google-explicit-constructor)
68 DataType(const Type *ptr) : ptr_((Type *)ptr) {
69 }
70
71 DataType(const DataType &o) : ptr_(o.ptr_) {
72 }
73
74 bool operator==(const DataType &o) const {
75 return ptr_ == o.ptr_;
76 }
77
78 bool operator!=(const DataType &o) const {
79 return !(*this == o);
80 }
81
82 std::size_t hash() const;
83
84 std::string to_string() const {
85 return ptr_->to_string();
86 };
87
88 // NOLINTNEXTLINE(google-explicit-constructor)
89 operator const Type *() const {
90 return ptr_;
91 }
92
93 // NOLINTNEXTLINE(google-explicit-constructor)
94 operator Type *() {
95 return ptr_;
96 }
97
98 // Temporary API and members
99 // for LegacyVectorType-compatibility
100
101 Type *operator->() const {
102 return ptr_;
103 }
104
105 DataType &operator=(const DataType &o) {
106 ptr_ = o.ptr_;
107 return *this;
108 }
109
110 bool is_pointer() const;
111
112 void set_is_pointer(bool ptr);
113
114 DataType ptr_removed() const;
115
116 std::vector<int> get_shape() const;
117
118 DataType get_element_type() const;
119
120 private:
121 Type *ptr_;
122};
123
124// Note that all types are immutable once created.
125
126class TI_DLL_EXPORT PrimitiveType : public Type {
127 public:
128#define PER_TYPE(x) static DataType x;
129#include "taichi/inc/data_type.inc.h"
130#undef PER_TYPE
131
132 // TODO(type): make 'type' private and add a const getter
133 PrimitiveTypeID type;
134
135 explicit PrimitiveType(PrimitiveTypeID type) : type(type) {
136 }
137
138 std::string to_string() const override;
139
140 Type *get_compute_type() override {
141 return this;
142 }
143
144 static DataType get(PrimitiveTypeID type);
145};
146
147class PointerType : public Type {
148 public:
149 PointerType(Type *pointee, bool is_bit_pointer)
150 : pointee_(pointee), is_bit_pointer_(is_bit_pointer) {
151 }
152
153 Type *get_pointee_type() const {
154 return pointee_;
155 }
156
157 auto get_addr_space() const {
158 return addr_space_;
159 }
160
161 bool is_bit_pointer() const {
162 return is_bit_pointer_;
163 }
164
165 std::string to_string() const override;
166
167 private:
168 Type *pointee_{nullptr};
169 int addr_space_{0}; // TODO: make this an enum
170 bool is_bit_pointer_{false};
171};
172
173class TensorType : public Type {
174 public:
175 TensorType(std::vector<int> shape, Type *element)
176 : shape_(std::move(shape)), element_(element) {
177 }
178
179 Type *get_element_type() const {
180 return element_;
181 }
182
183 int get_num_elements() const {
184 int num_elements = 1;
185 for (int i = 0; i < (int)shape_.size(); ++i)
186 num_elements *= shape_[i];
187 return num_elements;
188 }
189
190 std::vector<int> get_shape() const {
191 return shape_;
192 }
193
194 Type *get_compute_type() override {
195 return this;
196 }
197
198 std::string to_string() const override;
199
200 size_t get_element_offset(int ind) const;
201
202 private:
203 std::vector<int> shape_;
204 Type *element_{nullptr};
205};
206
207struct StructMember {
208 const Type *type;
209 std::string name;
210 size_t offset{0};
211 bool operator==(const StructMember &other) const {
212 return type == other.type && name == other.name && offset == other.offset;
213 }
214};
215
216class StructType : public Type {
217 public:
218 explicit StructType(const std::vector<StructMember> &elements,
219 const std::string &layout = "none")
220 : elements_(elements), layout_(layout) {
221 }
222
223 std::string to_string() const override;
224
225 const std::string &get_layout() const {
226 return layout_;
227 }
228
229 const Type *get_element_type(const std::vector<int> &indices) const;
230 size_t get_element_offset(const std::vector<int> &indices) const;
231 const std::vector<StructMember> &elements() const {
232 return elements_;
233 }
234
235 int get_num_elements() const {
236 int num = 0;
237 for (const auto &element : elements_) {
238 if (auto struct_type = element.type->cast<StructType>()) {
239 num += struct_type->get_num_elements();
240 } else if (auto tensor_type = element.type->cast<TensorType>()) {
241 num += tensor_type->get_num_elements();
242 } else {
243 TI_ASSERT(element.type->is<PrimitiveType>());
244 num += 1;
245 }
246 }
247 return num;
248 }
249
250 Type *get_compute_type() override {
251 return this;
252 }
253
254 private:
255 std::vector<StructMember> elements_;
256 std::string layout_;
257};
258
259class QuantIntType : public Type {
260 public:
261 QuantIntType(int num_bits, bool is_signed, Type *compute_type = nullptr);
262
263 std::string to_string() const override;
264
265 Type *get_compute_type() override {
266 return compute_type_;
267 }
268
269 int get_num_bits() const {
270 return num_bits_;
271 }
272
273 bool get_is_signed() const {
274 return is_signed_;
275 }
276
277 private:
278 // TODO(type): for now we can uniformly use i32 as the "compute_type". It may
279 // be a good idea to make "compute_type" also customizable.
280 Type *compute_type_{nullptr};
281 int num_bits_{32};
282 bool is_signed_{true};
283};
284
285class QuantFixedType : public Type {
286 public:
287 QuantFixedType(Type *digits_type, Type *compute_type, float64 scale);
288
289 std::string to_string() const override;
290
291 bool get_is_signed() const;
292
293 Type *get_digits_type() {
294 return digits_type_;
295 }
296
297 Type *get_compute_type() override {
298 return compute_type_;
299 }
300
301 float64 get_scale() const {
302 return scale_;
303 }
304
305 private:
306 Type *digits_type_{nullptr};
307 Type *compute_type_{nullptr};
308 float64 scale_{1.0};
309};
310
311class QuantFloatType : public Type {
312 public:
313 QuantFloatType(Type *digits_type, Type *exponent_type, Type *compute_type);
314
315 std::string to_string() const override;
316
317 Type *get_digits_type() {
318 return digits_type_;
319 }
320
321 Type *get_exponent_type() {
322 return exponent_type_;
323 }
324
325 int get_exponent_conversion_offset() const;
326
327 int get_digit_bits() const;
328
329 bool get_is_signed() const;
330
331 Type *get_compute_type() override {
332 return compute_type_;
333 }
334
335 private:
336 Type *digits_type_{nullptr};
337 Type *exponent_type_{nullptr};
338 Type *compute_type_{nullptr};
339};
340
341class BitStructType : public Type {
342 public:
343 BitStructType(PrimitiveType *physical_type,
344 const std::vector<Type *> &member_types,
345 const std::vector<int> &member_bit_offsets,
346 const std::vector<int> &member_exponents,
347 const std::vector<std::vector<int>> &member_exponent_users);
348
349 std::string to_string() const override;
350
351 PrimitiveType *get_physical_type() const {
352 return physical_type_;
353 }
354
355 int get_num_members() const {
356 return (int)member_types_.size();
357 }
358
359 Type *get_member_type(int i) const {
360 return member_types_[i];
361 }
362
363 int get_member_bit_offset(int i) const {
364 return member_bit_offsets_[i];
365 }
366
367 bool get_member_owns_shared_exponent(int i) const {
368 return member_exponents_[i] != -1 &&
369 member_exponent_users_[member_exponents_[i]].size() > 1;
370 }
371
372 int get_member_exponent(int i) const {
373 return member_exponents_[i];
374 }
375
376 const std::vector<int> &get_member_exponent_users(int i) const {
377 return member_exponent_users_[i];
378 }
379
380 private:
381 PrimitiveType *physical_type_;
382 std::vector<Type *> member_types_;
383 std::vector<int> member_bit_offsets_;
384 std::vector<int> member_exponents_;
385 std::vector<std::vector<int>> member_exponent_users_;
386};
387
388class QuantArrayType : public Type {
389 public:
390 QuantArrayType(PrimitiveType *physical_type,
391 Type *element_type_,
392 int num_elements_)
393 : physical_type_(physical_type),
394 element_type_(element_type_),
395 num_elements_(num_elements_) {
396 if (auto qit = element_type_->cast<QuantIntType>()) {
397 element_num_bits_ = qit->get_num_bits();
398 } else if (auto qfxt = element_type_->cast<QuantFixedType>()) {
399 element_num_bits_ =
400 qfxt->get_digits_type()->as<QuantIntType>()->get_num_bits();
401 } else {
402 TI_ERROR("Quant array only supports quant int/fixed type for now.");
403 }
404 }
405
406 std::string to_string() const override;
407
408 PrimitiveType *get_physical_type() const {
409 return physical_type_;
410 }
411
412 Type *get_element_type() const {
413 return element_type_;
414 }
415
416 int get_num_elements() const {
417 return num_elements_;
418 }
419
420 int get_element_num_bits() const {
421 return element_num_bits_;
422 }
423
424 private:
425 PrimitiveType *physical_type_;
426 Type *element_type_;
427 int num_elements_;
428 int element_num_bits_;
429};
430
431class TypedConstant {
432 public:
433 DataType dt;
434 union {
435 uint64 value_bits;
436 int32 val_i32;
437 float32 val_f32;
438 int64 val_i64;
439 float64 val_f64;
440 int8 val_i8;
441 int16 val_i16;
442 uint8 val_u8;
443 uint16 val_u16;
444 uint32 val_u32;
445 uint64 val_u64;
446 };
447
448 public:
449 TypedConstant() : dt(PrimitiveType::unknown) {
450 }
451
452 explicit TypedConstant(DataType dt) : dt(dt) {
453 TI_ASSERT_INFO(dt->is<PrimitiveType>(),
454 "TypedConstant can only be PrimitiveType, got {}",
455 dt->to_string());
456 value_bits = 0;
457 }
458
459 explicit TypedConstant(int32 x) : dt(PrimitiveType::i32), val_i32(x) {
460 }
461
462 explicit TypedConstant(float32 x) : dt(PrimitiveType::f32), val_f32(x) {
463 }
464
465 explicit TypedConstant(int64 x) : dt(PrimitiveType::i64), val_i64(x) {
466 }
467
468 explicit TypedConstant(float64 x) : dt(PrimitiveType::f64), val_f64(x) {
469 }
470
471 explicit TypedConstant(int8 x) : dt(PrimitiveType::i8), val_i8(x) {
472 }
473
474 explicit TypedConstant(int16 x) : dt(PrimitiveType::i16), val_i16(x) {
475 }
476
477 explicit TypedConstant(uint8 x) : dt(PrimitiveType::u8), val_u8(x) {
478 }
479
480 explicit TypedConstant(uint16 x) : dt(PrimitiveType::u16), val_u16(x) {
481 }
482
483 explicit TypedConstant(uint32 x) : dt(PrimitiveType::u32), val_u32(x) {
484 }
485
486 explicit TypedConstant(uint64 x) : dt(PrimitiveType::u64), val_u64(x) {
487 }
488
489 template <typename T>
490 TypedConstant(DataType dt, const T &value) : dt(dt) {
491 // TODO: loud failure on pointers
492 dt.set_is_pointer(false);
493 if (dt->is_primitive(PrimitiveTypeID::f32)) {
494 val_f32 = value;
495 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
496 val_i32 = value;
497 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
498 val_i64 = value;
499 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
500 val_f64 = value;
501 } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
502 val_i8 = value;
503 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
504 val_i16 = value;
505 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
506 val_u8 = value;
507 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
508 val_u16 = value;
509 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
510 val_u32 = value;
511 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
512 val_u64 = value;
513 } else {
514 TI_NOT_IMPLEMENTED
515 }
516 }
517
518 template <typename T>
519 bool equal_value(const T &value) const {
520 return equal_type_and_value(TypedConstant(dt, value));
521 }
522
523 std::string stringify() const;
524
525 bool equal_type_and_value(const TypedConstant &o) const;
526
527 bool operator==(const TypedConstant &o) const {
528 return equal_type_and_value(o);
529 }
530
531 int32 &val_int32();
532 float32 &val_float32();
533 int64 &val_int64();
534 float64 &val_float64();
535 int8 &val_int8();
536 int16 &val_int16();
537 uint8 &val_uint8();
538 uint16 &val_uint16();
539 uint32 &val_uint32();
540 uint64 &val_uint64();
541 int64 val_int() const;
542 uint64 val_uint() const;
543 float64 val_float() const;
544 int64 val_as_int64() const; // unifies val_int() and val_uint()
545 float64 val_cast_to_float64() const;
546};
547
548} // namespace taichi::lang
549
550namespace taichi::hashing {
551
552template <>
553struct Hasher<lang::StructMember> {
554 public:
555 size_t operator()(lang::StructMember const &member) const {
556 size_t ret = hash_value(member.type);
557 hash_combine(ret, member.name);
558 hash_combine(ret, member.offset);
559 return ret;
560 }
561};
562
563} // namespace taichi::hashing
564