#include "taichi/ir/type.h"

#include <cstring>

#include "taichi/ir/type_factory.h"
#include "taichi/ir/type_utils.h"
5
6namespace taichi::lang {
7
// Note: these primitive types are never freed; they are meant to live for the
// whole lifetime of the process. This is a temporary solution: their ownership
// should later be managed more systematically.
11
12// This part doesn't look good, but we will remove it soon anyway.
// Defines one static DataType member of PrimitiveType (e.g.
// PrimitiveType::unknown) for every PER_TYPE(x) entry listed in
// data_type.inc.h. Each is initialized from
// TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::x).
#define PER_TYPE(x)                     \
  DataType PrimitiveType::x = DataType( \
      TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::x));

#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
19
20DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
21}
22
23DataType PrimitiveType::get(PrimitiveTypeID t) {
24 if (false) {
25 }
26#define PER_TYPE(x) else if (t == PrimitiveTypeID::x) return PrimitiveType::x;
27#include "taichi/inc/data_type.inc.h"
28#undef PER_TYPE
29 else {
30 TI_NOT_IMPLEMENTED
31 }
32}
33
34std::size_t DataType::hash() const {
35 if (auto primitive = ptr_->cast<PrimitiveType>()) {
36 return (std::size_t)primitive->type;
37 } else if (auto pointer = ptr_->cast<PointerType>()) {
38 return 10007 + DataType(pointer->get_pointee_type()).hash();
39 } else {
40 TI_NOT_IMPLEMENTED
41 }
42}
43
44bool DataType::is_pointer() const {
45 return ptr_->is<PointerType>();
46}
47
48void DataType::set_is_pointer(bool is_ptr) {
49 if (is_ptr && !ptr_->is<PointerType>()) {
50 ptr_ = TypeFactory::get_instance().get_pointer_type(ptr_);
51 }
52 if (!is_ptr && ptr_->is<PointerType>()) {
53 ptr_ = ptr_->cast<PointerType>()->get_pointee_type();
54 }
55}
56
57DataType DataType::ptr_removed() const {
58 auto t = ptr_;
59 auto ptr_type = t->cast<PointerType>();
60 if (ptr_type) {
61 return DataType(ptr_type->get_pointee_type());
62 } else {
63 return *this;
64 }
65}
66
67std::vector<int> DataType::get_shape() const {
68 if (ptr_->is<TensorType>()) {
69 return ptr_->as<TensorType>()->get_shape();
70 }
71
72 return {};
73}
74
75DataType DataType::get_element_type() const {
76 if (ptr_->is<TensorType>()) {
77 return ptr_->as<TensorType>()->get_element_type();
78 }
79
80 return *this;
81}
82
83std::string PrimitiveType::to_string() const {
84 return data_type_name(DataType(const_cast<PrimitiveType *>(this)));
85}
86
87std::string PointerType::to_string() const {
88 if (is_bit_pointer_) {
89 // "^" for bit-level pointers
90 return fmt::format("^{}", pointee_->to_string());
91 } else {
92 // "*" for C-style byte-level pointers
93 return fmt::format("*{}", pointee_->to_string());
94 }
95}
96
97std::string TensorType::to_string() const {
98 std::string s = "[Tensor (";
99 for (int i = 0; i < (int)shape_.size(); ++i) {
100 s += fmt::format(i == 0 ? "{}" : ", {}", shape_[i]);
101 }
102 s += fmt::format(") {}]", element_->to_string());
103 return s;
104}
105
106size_t TensorType::get_element_offset(int ind) const {
107 return data_type_size(element_) * ind;
108}
109
110std::string StructType::to_string() const {
111 std::string s = fmt::format("struct[{}]{{", layout_);
112 for (int i = 0; i < elements_.size(); i++) {
113 if (i) {
114 s += ", ";
115 }
116 s += fmt::format("{}({}, at {}B): {}", i, elements_[i].name,
117 elements_[i].offset, elements_[i].type->to_string());
118 }
119 s += "}";
120 return s;
121}
122
123const Type *StructType::get_element_type(
124 const std::vector<int> &indices) const {
125 const Type *type_now = this;
126 for (auto ind : indices) {
127 if (auto tensor_type = type_now->cast<TensorType>()) {
128 TI_ASSERT(ind < tensor_type->get_num_elements())
129 type_now = tensor_type->get_element_type();
130 } else {
131 type_now = type_now->as<StructType>()->elements_[ind].type;
132 }
133 }
134 return type_now;
135}
136
137size_t StructType::get_element_offset(const std::vector<int> &indices) const {
138 const Type *type_now = this;
139 size_t offset = 0;
140 for (auto ind : indices) {
141 if (auto tensor_type = type_now->cast<TensorType>()) {
142 TI_ASSERT(ind < tensor_type->get_num_elements())
143 offset += tensor_type->get_element_offset(ind);
144 type_now = tensor_type->get_element_type();
145 } else {
146 offset += type_now->as<StructType>()->elements_[ind].offset;
147 type_now = type_now->as<StructType>()->elements_[ind].type;
148 }
149 }
150 return offset;
151}
152
153bool Type::is_primitive(PrimitiveTypeID type) const {
154 if (auto p = cast<PrimitiveType>()) {
155 return p->type == type;
156 } else {
157 return false;
158 }
159}
160
161std::string QuantIntType::to_string() const {
162 return fmt::format("q{}{}", is_signed_ ? 'i' : 'u', num_bits_);
163}
164
165QuantIntType::QuantIntType(int num_bits, bool is_signed, Type *compute_type)
166 : compute_type_(compute_type), num_bits_(num_bits), is_signed_(is_signed) {
167 if (compute_type == nullptr) {
168 auto type_id = is_signed ? PrimitiveTypeID::i32 : PrimitiveTypeID::u32;
169 this->compute_type_ =
170 TypeFactory::get_instance().get_primitive_type(type_id);
171 }
172}
173
// Constructs a quantized fixed-point type.
//   digits_type:  must be a QuantIntType holding the raw digits.
//   compute_type: must be a real (floating-point) PrimitiveType.
//   scale:        scaling factor stored with the type (presumably the
//                 represented value is digits * scale -- confirm).
QuantFixedType::QuantFixedType(Type *digits_type,
                               Type *compute_type,
                               float64 scale)
    : digits_type_(digits_type), compute_type_(compute_type), scale_(scale) {
  TI_ASSERT(digits_type->is<QuantIntType>());
  TI_ASSERT(compute_type->is<PrimitiveType>());
  TI_ASSERT(is_real(compute_type));
}
182
183std::string QuantFixedType::to_string() const {
184 return fmt::format("qfx(d={} c={} s={})", digits_type_->to_string(),
185 compute_type_->to_string(), scale_);
186}
187
188bool QuantFixedType::get_is_signed() const {
189 return digits_type_->as<QuantIntType>()->get_is_signed();
190}
191
// Constructs a quantized floating-point type.
//   digits_type:   must be a QuantIntType holding the digits (and the sign
//                  bit, if signed).
//   exponent_type: must be an unsigned QuantIntType of at most 8 bits.
//   compute_type:  must be f32.
QuantFloatType::QuantFloatType(Type *digits_type,
                               Type *exponent_type,
                               Type *compute_type)
    : digits_type_(digits_type),
      exponent_type_(exponent_type),
      compute_type_(compute_type) {
  TI_ASSERT(digits_type->is<QuantIntType>());
  // We only support f32 as compute type when using exponents
  TI_ASSERT(compute_type_->is_primitive(PrimitiveTypeID::f32));
  // Exponent must be unsigned quant int
  TI_ASSERT(exponent_type->is<QuantIntType>());
  // The limits 8 (exponent bits) and 23 (digit bits, sign excluded) match
  // the widths of the f32 exponent and mantissa fields.
  TI_ASSERT(exponent_type->as<QuantIntType>()->get_num_bits() <= 8);
  TI_ASSERT(exponent_type->as<QuantIntType>()->get_is_signed() == false);
  TI_ASSERT(get_digit_bits() <= 23);
}
207
208std::string QuantFloatType::to_string() const {
209 return fmt::format("qfl(d={} e={} c={})", digits_type_->to_string(),
210 exponent_type_->to_string(), compute_type_->to_string());
211}
212
213int QuantFloatType::get_exponent_conversion_offset() const {
214 // Note that f32 has exponent offset -127
215 return 127 - (1 << (exponent_type_->as<QuantIntType>()->get_num_bits() - 1)) +
216 1;
217}
218
219int QuantFloatType::get_digit_bits() const {
220 return digits_type_->as<QuantIntType>()->get_num_bits() -
221 (int)get_is_signed();
222}
223
224bool QuantFloatType::get_is_signed() const {
225 return digits_type_->as<QuantIntType>()->get_is_signed();
226}
227
// Constructs a bit struct that packs quantized members into one physical
// primitive type.
//   physical_type:         primitive type holding all packed bits.
//   member_types:          each member's QuantIntType, QuantFixedType or
//                          QuantFloatType.
//   member_bit_offsets:    bit offset of each member. Members must be packed
//                          contiguously in order, starting at bit 0.
//   member_exponents:      for each member, the index of the member that
//                          provides its exponent, or -1 if none.
//   member_exponent_users: for each member, the indices of the members that
//                          use it as their exponent.
// All vectors must have the same length. The exponent / exponent-user
// relations must agree in both directions. Violations fail a TI_ASSERT.
BitStructType::BitStructType(
    PrimitiveType *physical_type,
    const std::vector<Type *> &member_types,
    const std::vector<int> &member_bit_offsets,
    const std::vector<int> &member_exponents,
    const std::vector<std::vector<int>> &member_exponent_users)
    : physical_type_(physical_type),
      member_types_(member_types),
      member_bit_offsets_(member_bit_offsets),
      member_exponents_(member_exponents),
      member_exponent_users_(member_exponent_users) {
  TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
  TI_ASSERT(member_types_.size() == member_exponents_.size());
  TI_ASSERT(member_types_.size() == member_exponent_users_.size());
  int physical_type_bits = data_type_bits(physical_type_);
  int member_total_bits = 0;
  for (auto i = 0; i < member_types_.size(); ++i) {
    // Determine the QuantIntType whose width this member occupies. For
    // fixed/float members only the digits type is counted here (a float
    // member's exponent is presumably a separate member referenced via
    // member_exponents -- confirm).
    QuantIntType *component_qit = nullptr;
    if (auto qit = member_types_[i]->cast<QuantIntType>()) {
      component_qit = qit;
    } else if (auto qfxt = member_types_[i]->cast<QuantFixedType>()) {
      component_qit = qfxt->get_digits_type()->as<QuantIntType>();
    } else {
      TI_ASSERT(member_types_[i]->is<QuantFloatType>());
      auto qflt = member_types_[i]->as<QuantFloatType>();
      component_qit = qflt->get_digits_type()->as<QuantIntType>();
    }
    // Each member must start exactly where the previous one ended.
    TI_ASSERT(member_bit_offsets_[i] == member_total_bits);
    member_total_bits += component_qit->get_num_bits();
  }
  // All packed members must fit into the physical type.
  TI_ASSERT(physical_type_bits >= member_total_bits);
  for (auto i = 0; i < member_types_.size(); ++i) {
    // If member i uses an exponent member, that member must list i as a user.
    auto exponent = member_exponents_[i];
    if (exponent != -1) {
      TI_ASSERT(std::find(member_exponent_users_[exponent].begin(),
                          member_exponent_users_[exponent].end(),
                          i) != member_exponent_users_[exponent].end());
    }
    // Conversely, every listed user of member i must name i as its exponent.
    for (auto user : member_exponent_users_[i]) {
      TI_ASSERT(member_exponents_[user] == i);
    }
  }
}
271
272std::string BitStructType::to_string() const {
273 std::string str = "bs(";
274 int num_members = (int)member_bit_offsets_.size();
275 for (int i = 0; i < num_members; i++) {
276 str += fmt::format("{}: {}@{}", i, member_types_[i]->to_string(),
277 member_bit_offsets_[i]);
278 if (member_exponents_[i] != -1) {
279 str += fmt::format(" {}exp={}",
280 get_member_owns_shared_exponent(i) ? "shared_" : "",
281 member_exponents_[i]);
282 }
283 if (i + 1 < num_members) {
284 str += ", ";
285 }
286 }
287 return str + ")";
288}
289
290std::string QuantArrayType::to_string() const {
291 return fmt::format("qa({}x{})", element_type_->to_string(), num_elements_);
292}
293
294std::string TypedConstant::stringify() const {
295 // TODO: remove the line below after type system upgrade.
296 auto dt = this->dt.ptr_removed();
297 if (dt->is_primitive(PrimitiveTypeID::f32)) {
298 return fmt::format("{}", val_f32);
299 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
300 return fmt::format("{}", val_i32);
301 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
302 return fmt::format("{}", val_i64);
303 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
304 return fmt::format("{}", val_f64);
305 } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
306 return fmt::format("{}", val_i8);
307 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
308 return fmt::format("{}", val_i16);
309 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
310 return fmt::format("{}", val_u8);
311 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
312 return fmt::format("{}", val_u16);
313 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
314 return fmt::format("{}", val_u32);
315 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
316 return fmt::format("{}", val_u64);
317 } else {
318 TI_P(data_type_name(dt));
319 TI_NOT_IMPLEMENTED
320 return "";
321 }
322}
323
324bool TypedConstant::equal_type_and_value(const TypedConstant &o) const {
325 if (dt != o.dt)
326 return false;
327 if (dt->is_primitive(PrimitiveTypeID::f32)) {
328 return val_f32 == o.val_f32;
329 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
330 return val_i32 == o.val_i32;
331 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
332 return val_i64 == o.val_i64;
333 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
334 return val_f64 == o.val_f64;
335 } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
336 return val_i8 == o.val_i8;
337 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
338 return val_i16 == o.val_i16;
339 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
340 return val_u8 == o.val_u8;
341 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
342 return val_u16 == o.val_u16;
343 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
344 return val_u32 == o.val_u32;
345 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
346 return val_u64 == o.val_u64;
347 } else {
348 TI_NOT_IMPLEMENTED
349 return false;
350 }
351}
352
// Typed mutable accessors: each asserts that `dt` is exactly the requested
// type (compared against get_data_type<T>()) and returns a reference to the
// corresponding stored value field.
int32 &TypedConstant::val_int32() {
  TI_ASSERT(get_data_type<int32>() == dt);
  return val_i32;
}

float32 &TypedConstant::val_float32() {
  TI_ASSERT(get_data_type<float32>() == dt);
  return val_f32;
}

int64 &TypedConstant::val_int64() {
  TI_ASSERT(get_data_type<int64>() == dt);
  return val_i64;
}

float64 &TypedConstant::val_float64() {
  TI_ASSERT(get_data_type<float64>() == dt);
  return val_f64;
}

int8 &TypedConstant::val_int8() {
  TI_ASSERT(get_data_type<int8>() == dt);
  return val_i8;
}

int16 &TypedConstant::val_int16() {
  TI_ASSERT(get_data_type<int16>() == dt);
  return val_i16;
}

uint8 &TypedConstant::val_uint8() {
  TI_ASSERT(get_data_type<uint8>() == dt);
  return val_u8;
}

uint16 &TypedConstant::val_uint16() {
  TI_ASSERT(get_data_type<uint16>() == dt);
  return val_u16;
}

uint32 &TypedConstant::val_uint32() {
  TI_ASSERT(get_data_type<uint32>() == dt);
  return val_u32;
}

uint64 &TypedConstant::val_uint64() {
  TI_ASSERT(get_data_type<uint64>() == dt);
  return val_u64;
}
402
403int64 TypedConstant::val_int() const {
404 TI_ASSERT(is_signed(dt));
405 if (dt->is_primitive(PrimitiveTypeID::i32)) {
406 return val_i32;
407 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
408 return val_i64;
409 } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
410 return val_i8;
411 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
412 return val_i16;
413 } else {
414 TI_NOT_IMPLEMENTED
415 }
416}
417
418uint64 TypedConstant::val_uint() const {
419 TI_ASSERT(is_unsigned(dt));
420 if (dt->is_primitive(PrimitiveTypeID::u32)) {
421 return val_u32;
422 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
423 return val_u64;
424 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
425 return val_u8;
426 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
427 return val_u16;
428 } else {
429 TI_NOT_IMPLEMENTED
430 }
431}
432
433float64 TypedConstant::val_float() const {
434 TI_ASSERT(is_real(dt));
435 if (dt->is_primitive(PrimitiveTypeID::f32)) {
436 return val_f32;
437 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
438 return val_f64;
439 } else {
440 TI_NOT_IMPLEMENTED
441 }
442}
443
444int64 TypedConstant::val_as_int64() const {
445 if (is_real(dt)) {
446 TI_ERROR("Cannot cast floating point type {} to int64.", dt->to_string());
447 } else if (is_signed(dt)) {
448 return val_int();
449 } else if (is_unsigned(dt)) {
450 return val_uint();
451 } else {
452 TI_NOT_IMPLEMENTED
453 }
454}
455
456float64 TypedConstant::val_cast_to_float64() const {
457 if (is_real(dt))
458 return val_float();
459 else if (is_signed(dt))
460 return val_int();
461 else if (is_unsigned(dt))
462 return val_uint();
463 else {
464 TI_NOT_IMPLEMENTED
465 }
466}
467
468} // namespace taichi::lang
469