1#pragma once
2
#include <limits>
#include <string>
#include <type_traits>
#include <vector>

#include "taichi/ir/type.h"
#include "taichi/ir/type_factory.h"
#include "taichi/rhi/arch.h"
6
7namespace taichi::lang {
8
// Returns a shape vector for `t`.
// NOTE(review): presumably the element shape of a TensorType (empty for
// scalars) -- confirm against the definition.
std::vector<int> data_type_shape(DataType t);

// Returns a textual name for `t` (presumably a short name such as "i32";
// confirm against the definition).
TI_DLL_EXPORT std::string data_type_name(DataType t);

// Returns the size of `t` in bytes; data_type_bits() converts this to bits by
// multiplying by 8.
TI_DLL_EXPORT int data_type_size(DataType t);

// Returns a format string for values of type `dt`, which may depend on the
// target `arch` (defaults to Arch::x64).
// NOTE(review): presumably a printf-style conversion specifier -- confirm.
TI_DLL_EXPORT std::string data_type_format(DataType dt, Arch arch = Arch::x64);
16
17inline int data_type_bits(DataType t) {
18 return data_type_size(t) * 8;
19}
20
21template <typename T>
22inline DataType get_data_type() {
23 if (std::is_same<T, float32>()) {
24 return PrimitiveType::f32;
25 } else if (std::is_same<T, float64>()) {
26 return PrimitiveType::f64;
27 } else if (std::is_same<T, bool>()) {
28 return PrimitiveType::u1;
29 } else if (std::is_same<T, int8>()) {
30 return PrimitiveType::i8;
31 } else if (std::is_same<T, int16>()) {
32 return PrimitiveType::i16;
33 } else if (std::is_same<T, int32>()) {
34 return PrimitiveType::i32;
35 } else if (std::is_same<T, int64>()) {
36 return PrimitiveType::i64;
37 } else if (std::is_same<T, uint8>()) {
38 return PrimitiveType::u8;
39 } else if (std::is_same<T, uint16>()) {
40 return PrimitiveType::u16;
41 } else if (std::is_same<T, uint32>()) {
42 return PrimitiveType::u32;
43 } else if (std::is_same<T, uint64>()) {
44 return PrimitiveType::u64;
45 } else {
46 TI_NOT_IMPLEMENTED;
47 }
48}
49
50template <typename T>
51inline PrimitiveTypeID get_primitive_data_type() {
52 if (std::is_same<T, float32>()) {
53 return PrimitiveTypeID::f32;
54 } else if (std::is_same<T, float64>()) {
55 return PrimitiveTypeID::f64;
56 } else if (std::is_same<T, bool>()) {
57 return PrimitiveTypeID::u1;
58 } else if (std::is_same<T, int8>()) {
59 return PrimitiveTypeID::i8;
60 } else if (std::is_same<T, int16>()) {
61 return PrimitiveTypeID::i16;
62 } else if (std::is_same<T, int32>()) {
63 return PrimitiveTypeID::i32;
64 } else if (std::is_same<T, int64>()) {
65 return PrimitiveTypeID::i64;
66 } else if (std::is_same<T, uint8>()) {
67 return PrimitiveTypeID::u8;
68 } else if (std::is_same<T, uint16>()) {
69 return PrimitiveTypeID::u16;
70 } else if (std::is_same<T, uint32>()) {
71 return PrimitiveTypeID::u32;
72 } else if (std::is_same<T, uint64>()) {
73 return PrimitiveTypeID::u64;
74 } else {
75 TI_NOT_IMPLEMENTED;
76 }
77}
78
79inline bool is_tensor(DataType dt) {
80 return dt->is<TensorType>();
81}
82
83inline bool is_quant(DataType dt) {
84 return dt->is<QuantIntType>() || dt->is<QuantFixedType>() ||
85 dt->is<QuantFloatType>();
86}
87
88inline bool is_real(DataType dt) {
89 return dt->is_primitive(PrimitiveTypeID::f16) ||
90 dt->is_primitive(PrimitiveTypeID::f32) ||
91 dt->is_primitive(PrimitiveTypeID::f64) || dt->is<QuantFixedType>() ||
92 dt->is<QuantFloatType>();
93}
94
95inline bool is_integral(DataType dt) {
96 return dt->is_primitive(PrimitiveTypeID::i8) ||
97 dt->is_primitive(PrimitiveTypeID::i16) ||
98 dt->is_primitive(PrimitiveTypeID::i32) ||
99 dt->is_primitive(PrimitiveTypeID::i64) ||
100 dt->is_primitive(PrimitiveTypeID::u8) ||
101 dt->is_primitive(PrimitiveTypeID::u16) ||
102 dt->is_primitive(PrimitiveTypeID::u32) ||
103 dt->is_primitive(PrimitiveTypeID::u64) || dt->is<QuantIntType>();
104}
105
106inline bool is_signed(DataType dt) {
107 // Shall we return false if is_integral returns false?
108 TI_ASSERT(is_integral(dt));
109 if (auto t = dt->cast<QuantIntType>())
110 return t->get_is_signed();
111 return dt->is_primitive(PrimitiveTypeID::i8) ||
112 dt->is_primitive(PrimitiveTypeID::i16) ||
113 dt->is_primitive(PrimitiveTypeID::i32) ||
114 dt->is_primitive(PrimitiveTypeID::i64);
115}
116
117inline bool is_unsigned(DataType dt) {
118 TI_ASSERT(is_integral(dt));
119 return !is_signed(dt);
120}
121
122inline DataType to_unsigned(DataType dt) {
123 TI_ASSERT(is_signed(dt));
124 if (dt->is_primitive(PrimitiveTypeID::i8))
125 return PrimitiveType::u8;
126 else if (dt->is_primitive(PrimitiveTypeID::i16))
127 return PrimitiveType::u16;
128 else if (dt->is_primitive(PrimitiveTypeID::i32))
129 return PrimitiveType::u32;
130 else if (dt->is_primitive(PrimitiveTypeID::i64))
131 return PrimitiveType::u64;
132 else
133 return PrimitiveType::unknown;
134}
135
136inline TypedConstant get_max_value(DataType dt) {
137 if (dt->is_primitive(PrimitiveTypeID::i8)) {
138 return {dt, std::numeric_limits<int8>::max()};
139 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
140 return {dt, std::numeric_limits<int16>::max()};
141 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
142 return {dt, std::numeric_limits<int32>::max()};
143 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
144 return {dt, std::numeric_limits<int64>::max()};
145 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
146 return {dt, std::numeric_limits<uint8>::max()};
147 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
148 return {dt, std::numeric_limits<uint16>::max()};
149 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
150 return {dt, std::numeric_limits<uint32>::max()};
151 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
152 return {dt, std::numeric_limits<uint64>::max()};
153 } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
154 return {dt, std::numeric_limits<float32>::max()};
155 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
156 return {dt, std::numeric_limits<float64>::max()};
157 } else {
158 TI_NOT_IMPLEMENTED;
159 }
160}
161
162inline TypedConstant get_min_value(DataType dt) {
163 if (dt->is_primitive(PrimitiveTypeID::i8)) {
164 return {dt, std::numeric_limits<int8>::min()};
165 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
166 return {dt, std::numeric_limits<int16>::min()};
167 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
168 return {dt, std::numeric_limits<int32>::min()};
169 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
170 return {dt, std::numeric_limits<int64>::min()};
171 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
172 return {dt, std::numeric_limits<uint8>::min()};
173 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
174 return {dt, std::numeric_limits<uint16>::min()};
175 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
176 return {dt, std::numeric_limits<uint32>::min()};
177 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
178 return {dt, std::numeric_limits<uint64>::min()};
179 } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
180 return {dt, std::numeric_limits<float32>::min()};
181 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
182 return {dt, std::numeric_limits<float64>::min()};
183 } else {
184 TI_NOT_IMPLEMENTED;
185 }
186}
187
188class BitStructTypeBuilder {
189 public:
190 explicit BitStructTypeBuilder(int max_num_bits) {
191 physical_type_ =
192 TypeFactory::get_instance().get_primitive_int_type(max_num_bits);
193 }
194
195 int add_member(Type *member_type) {
196 if (auto qflt = member_type->cast<QuantFloatType>()) {
197 auto exponent_type = qflt->get_exponent_type();
198 auto exponent_id = -1;
199 if (is_placing_shared_exponent_ && current_shared_exponent_ != -1) {
200 // Reuse existing exponent
201 TI_ASSERT_INFO(member_types_[current_shared_exponent_] == exponent_type,
202 "QuantFloatTypes with shared exponents must have "
203 "exactly the same exponent type.");
204 exponent_id = current_shared_exponent_;
205 } else {
206 exponent_id = add_member_impl(exponent_type);
207 if (is_placing_shared_exponent_) {
208 current_shared_exponent_ = exponent_id;
209 }
210 }
211 auto digits_id = add_member_impl(member_type);
212 member_exponents_[digits_id] = exponent_id;
213 member_exponent_users_[exponent_id].push_back(digits_id);
214 return digits_id;
215 }
216 return add_member_impl(member_type);
217 }
218
219 void begin_placing_shared_exponent() {
220 TI_ASSERT(!is_placing_shared_exponent_);
221 TI_ASSERT(current_shared_exponent_ == -1);
222 is_placing_shared_exponent_ = true;
223 }
224
225 void end_placing_shared_exponent() {
226 TI_ASSERT(is_placing_shared_exponent_);
227 TI_ASSERT(current_shared_exponent_ != -1);
228 current_shared_exponent_ = -1;
229 is_placing_shared_exponent_ = false;
230 }
231
232 BitStructType *build() const {
233 return TypeFactory::get_instance().get_bit_struct_type(
234 physical_type_, member_types_, member_bit_offsets_, member_exponents_,
235 member_exponent_users_);
236 }
237
238 private:
239 int add_member_impl(Type *member_type) {
240 int old_num_members = member_types_.size();
241 member_types_.push_back(member_type);
242 member_bit_offsets_.push_back(member_total_bits_);
243 member_exponents_.push_back(-1);
244 member_exponent_users_.push_back({});
245 QuantIntType *member_qit = nullptr;
246 if (auto qit = member_type->cast<QuantIntType>()) {
247 member_qit = qit;
248 } else if (auto qfxt = member_type->cast<QuantFixedType>()) {
249 member_qit = qfxt->get_digits_type()->as<QuantIntType>();
250 } else if (auto qflt = member_type->cast<QuantFloatType>()) {
251 member_qit = qflt->get_digits_type()->as<QuantIntType>();
252 } else {
253 TI_ERROR("Only a QuantType can be a member of a BitStructType.");
254 }
255 member_total_bits_ += member_qit->get_num_bits();
256 auto physical_bits = data_type_bits(physical_type_);
257 TI_ERROR_IF(member_total_bits_ > physical_bits,
258 "BitStructType overflows: {} bits used out of {}.",
259 member_total_bits_, physical_bits);
260 return old_num_members;
261 }
262
263 PrimitiveType *physical_type_{nullptr};
264 std::vector<Type *> member_types_;
265 std::vector<int> member_bit_offsets_;
266 int member_total_bits_{0};
267 std::vector<int> member_exponents_;
268 std::vector<std::vector<int>> member_exponent_users_;
269 bool is_placing_shared_exponent_{false};
270 int current_shared_exponent_{-1};
271};
272
273} // namespace taichi::lang
274