1#include "taichi/ir/type_factory.h"
2
3#include "taichi/ir/type_utils.h"
4
5namespace taichi::lang {
6
7TypeFactory &TypeFactory::get_instance() {
8 static TypeFactory *type_factory = new TypeFactory;
9 return *type_factory;
10}
11
12TypeFactory::TypeFactory() {
13}
14
15Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
16 std::lock_guard<std::mutex> _(primitive_mut_);
17
18 if (primitive_types_.find(id) == primitive_types_.end()) {
19 primitive_types_[id] = std::make_unique<PrimitiveType>(id);
20 }
21
22 return primitive_types_[id].get();
23}
24
25Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
26 std::lock_guard<std::mutex> _(tensor_mut_);
27
28 auto encode = [](const std::vector<int> &shape) -> std::string {
29 std::string s;
30 for (int i = 0; i < (int)shape.size(); ++i)
31 s += fmt::format(i == 0 ? "{}" : "_{}", std::to_string(shape[i]));
32 return s;
33 };
34 auto key = std::make_pair(encode(shape), element);
35 if (tensor_types_.find(key) == tensor_types_.end()) {
36 tensor_types_[key] = std::make_unique<TensorType>(shape, element);
37 }
38 return tensor_types_[key].get();
39}
40
41const Type *TypeFactory::get_struct_type(
42 const std::vector<StructMember> &elements,
43 const std::string &layout) {
44 std::lock_guard<std::mutex> _(struct_mut_);
45 auto key = std::make_pair(elements, layout);
46
47 if (struct_types_.find(key) == struct_types_.end()) {
48 for (const auto &element : elements) {
49 TI_ASSERT_INFO(
50 element.type->is<PrimitiveType>() || element.type->is<TensorType>() ||
51 element.type->is<StructType>() || element.type->is<PointerType>(),
52 "Unsupported struct element type for element " + element.name + ": " +
53 element.type->to_string());
54 }
55 struct_types_[key] = std::make_unique<StructType>(elements, layout);
56 }
57 return struct_types_[key].get();
58}
59
60Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
61 std::lock_guard<std::mutex> _(pointer_mut_);
62
63 auto key = std::make_pair(element, is_bit_pointer);
64 if (pointer_types_.find(key) == pointer_types_.end()) {
65 pointer_types_[key] =
66 std::make_unique<PointerType>(element, is_bit_pointer);
67 }
68 return pointer_types_[key].get();
69}
70
71Type *TypeFactory::get_quant_int_type(int num_bits,
72 bool is_signed,
73 Type *compute_type) {
74 std::lock_guard<std::mutex> _(quant_int_mut_);
75
76 auto key = std::make_tuple(num_bits, is_signed, compute_type);
77 if (quant_int_types_.find(key) == quant_int_types_.end()) {
78 quant_int_types_[key] =
79 std::make_unique<QuantIntType>(num_bits, is_signed, compute_type);
80 }
81 return quant_int_types_[key].get();
82}
83
84Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
85 Type *compute_type,
86 float64 scale) {
87 std::lock_guard<std::mutex> _(quant_fixed_mut_);
88
89 auto key = std::make_tuple(digits_type, compute_type, scale);
90 if (quant_fixed_types_.find(key) == quant_fixed_types_.end()) {
91 quant_fixed_types_[key] =
92 std::make_unique<QuantFixedType>(digits_type, compute_type, scale);
93 }
94 return quant_fixed_types_[key].get();
95}
96
97Type *TypeFactory::get_quant_float_type(Type *digits_type,
98 Type *exponent_type,
99 Type *compute_type) {
100 std::lock_guard<std::mutex> _(quant_float_mut_);
101
102 auto key = std::make_tuple(digits_type, exponent_type, compute_type);
103 if (quant_float_types_.find(key) == quant_float_types_.end()) {
104 quant_float_types_[key] = std::make_unique<QuantFloatType>(
105 digits_type, exponent_type, compute_type);
106 }
107 return quant_float_types_[key].get();
108}
109
110BitStructType *TypeFactory::get_bit_struct_type(
111 PrimitiveType *physical_type,
112 const std::vector<Type *> &member_types,
113 const std::vector<int> &member_bit_offsets,
114 const std::vector<int> &member_exponents,
115 const std::vector<std::vector<int>> &member_exponent_users) {
116 std::lock_guard<std::mutex> _(bit_struct_mut_);
117
118 bit_struct_types_.push_back(std::make_unique<BitStructType>(
119 physical_type, member_types, member_bit_offsets, member_exponents,
120 member_exponent_users));
121 return bit_struct_types_.back().get();
122}
123
124Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
125 Type *element_type,
126 int num_elements) {
127 std::lock_guard<std::mutex> _(quant_array_mut_);
128
129 quant_array_types_.push_back(std::make_unique<QuantArrayType>(
130 physical_type, element_type, num_elements));
131 return quant_array_types_.back().get();
132}
133
134PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
135 Type *int_type;
136 if (bits == 8) {
137 int_type = get_primitive_type(PrimitiveTypeID::i8);
138 } else if (bits == 16) {
139 int_type = get_primitive_type(PrimitiveTypeID::i16);
140 } else if (bits == 32) {
141 int_type = get_primitive_type(PrimitiveTypeID::i32);
142 } else if (bits == 64) {
143 int_type = get_primitive_type(PrimitiveTypeID::i64);
144 } else {
145 TI_ERROR("No primitive int type has {} bits", bits);
146 }
147 if (!is_signed) {
148 int_type = to_unsigned(DataType(int_type));
149 }
150 return int_type->cast<PrimitiveType>();
151}
152
153PrimitiveType *TypeFactory::get_primitive_real_type(int bits) {
154 Type *real_type;
155 if (bits == 16) {
156 real_type = get_primitive_type(PrimitiveTypeID::f16);
157 } else if (bits == 32) {
158 real_type = get_primitive_type(PrimitiveTypeID::f32);
159 } else if (bits == 64) {
160 real_type = get_primitive_type(PrimitiveTypeID::f64);
161 } else {
162 TI_ERROR("No primitive real type has {} bits", bits);
163 }
164 return real_type->cast<PrimitiveType>();
165}
166
167DataType TypeFactory::create_tensor_type(std::vector<int> shape,
168 DataType element) {
169 return TypeFactory::get_instance().get_tensor_type(shape, element);
170}
171
172namespace {
173static bool compare_types(DataType x, DataType y) {
174 // Is the first type "bigger" than the second type?
175 if (is_real(x) != is_real(y)) {
176 // One is real, the other is integral.
177 // real > integral
178 return is_real(x);
179 } else {
180 if (is_real(x) && is_real(y)) {
181 // Both are real
182 return data_type_bits(x) > data_type_bits(y);
183 } else {
184 // Both are integral
185 auto x_bits = data_type_bits(x);
186 auto y_bits = data_type_bits(y);
187 if (x_bits != y_bits) {
188 return x_bits > y_bits;
189 } else {
190 // Same number of bits. Unsigned > signed
191 auto x_unsigned = !is_signed(x);
192 auto y_unsigned = !is_signed(y);
193 return x_unsigned > y_unsigned;
194 }
195 }
196 }
197}
198
199static DataType to_primitive_type(DataType d) {
200 if (d->is<PointerType>()) {
201 d = d->as<PointerType>()->get_pointee_type();
202 TI_WARN("promoted_type got a pointer input.");
203 }
204
205 if (d->is<TensorType>()) {
206 d = d->as<TensorType>()->get_element_type();
207 TI_WARN("promoted_type got a tensor input.");
208 }
209
210 auto primitive = d->cast<PrimitiveType>();
211 TI_ASSERT_INFO(primitive, "Failed to get primitive type from {}",
212 d->to_string());
213 return primitive;
214};
215} // namespace
216
217DataType promoted_primitive_type(DataType x, DataType y) {
218 if (compare_types(to_primitive_type(x), to_primitive_type(y)))
219 return x;
220 else
221 return y;
222}
223
224DataType promoted_type(DataType a, DataType b) {
225 if (a->is<TensorType>() || b->is<TensorType>()) {
226 TI_ASSERT_INFO(a->is<TensorType>() && b->is<TensorType>(),
227 "a = {}, b = {}, only one of them is a tensor type",
228 a->to_string(), b->to_string());
229 auto tensor_ty_a = a->cast<TensorType>();
230 auto tensor_ty_b = b->cast<TensorType>();
231 auto promoted_dt = promoted_type(tensor_ty_a->get_element_type(),
232 tensor_ty_b->get_element_type());
233 return TypeFactory::create_tensor_type(tensor_ty_a->get_shape(),
234 promoted_dt);
235 } else {
236 return promoted_primitive_type(a, b);
237 }
238};
239
240} // namespace taichi::lang
241