1 | #include "taichi/ir/type_factory.h" |
2 | |
3 | #include "taichi/ir/type_utils.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
7 | TypeFactory &TypeFactory::get_instance() { |
8 | static TypeFactory *type_factory = new TypeFactory; |
9 | return *type_factory; |
10 | } |
11 | |
12 | TypeFactory::TypeFactory() { |
13 | } |
14 | |
15 | Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { |
16 | std::lock_guard<std::mutex> _(primitive_mut_); |
17 | |
18 | if (primitive_types_.find(id) == primitive_types_.end()) { |
19 | primitive_types_[id] = std::make_unique<PrimitiveType>(id); |
20 | } |
21 | |
22 | return primitive_types_[id].get(); |
23 | } |
24 | |
25 | Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) { |
26 | std::lock_guard<std::mutex> _(tensor_mut_); |
27 | |
28 | auto encode = [](const std::vector<int> &shape) -> std::string { |
29 | std::string s; |
30 | for (int i = 0; i < (int)shape.size(); ++i) |
31 | s += fmt::format(i == 0 ? "{}" : "_{}" , std::to_string(shape[i])); |
32 | return s; |
33 | }; |
34 | auto key = std::make_pair(encode(shape), element); |
35 | if (tensor_types_.find(key) == tensor_types_.end()) { |
36 | tensor_types_[key] = std::make_unique<TensorType>(shape, element); |
37 | } |
38 | return tensor_types_[key].get(); |
39 | } |
40 | |
41 | const Type *TypeFactory::get_struct_type( |
42 | const std::vector<StructMember> &elements, |
43 | const std::string &layout) { |
44 | std::lock_guard<std::mutex> _(struct_mut_); |
45 | auto key = std::make_pair(elements, layout); |
46 | |
47 | if (struct_types_.find(key) == struct_types_.end()) { |
48 | for (const auto &element : elements) { |
49 | TI_ASSERT_INFO( |
50 | element.type->is<PrimitiveType>() || element.type->is<TensorType>() || |
51 | element.type->is<StructType>() || element.type->is<PointerType>(), |
52 | "Unsupported struct element type for element " + element.name + ": " + |
53 | element.type->to_string()); |
54 | } |
55 | struct_types_[key] = std::make_unique<StructType>(elements, layout); |
56 | } |
57 | return struct_types_[key].get(); |
58 | } |
59 | |
60 | Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { |
61 | std::lock_guard<std::mutex> _(pointer_mut_); |
62 | |
63 | auto key = std::make_pair(element, is_bit_pointer); |
64 | if (pointer_types_.find(key) == pointer_types_.end()) { |
65 | pointer_types_[key] = |
66 | std::make_unique<PointerType>(element, is_bit_pointer); |
67 | } |
68 | return pointer_types_[key].get(); |
69 | } |
70 | |
71 | Type *TypeFactory::get_quant_int_type(int num_bits, |
72 | bool is_signed, |
73 | Type *compute_type) { |
74 | std::lock_guard<std::mutex> _(quant_int_mut_); |
75 | |
76 | auto key = std::make_tuple(num_bits, is_signed, compute_type); |
77 | if (quant_int_types_.find(key) == quant_int_types_.end()) { |
78 | quant_int_types_[key] = |
79 | std::make_unique<QuantIntType>(num_bits, is_signed, compute_type); |
80 | } |
81 | return quant_int_types_[key].get(); |
82 | } |
83 | |
84 | Type *TypeFactory::get_quant_fixed_type(Type *digits_type, |
85 | Type *compute_type, |
86 | float64 scale) { |
87 | std::lock_guard<std::mutex> _(quant_fixed_mut_); |
88 | |
89 | auto key = std::make_tuple(digits_type, compute_type, scale); |
90 | if (quant_fixed_types_.find(key) == quant_fixed_types_.end()) { |
91 | quant_fixed_types_[key] = |
92 | std::make_unique<QuantFixedType>(digits_type, compute_type, scale); |
93 | } |
94 | return quant_fixed_types_[key].get(); |
95 | } |
96 | |
97 | Type *TypeFactory::get_quant_float_type(Type *digits_type, |
98 | Type *exponent_type, |
99 | Type *compute_type) { |
100 | std::lock_guard<std::mutex> _(quant_float_mut_); |
101 | |
102 | auto key = std::make_tuple(digits_type, exponent_type, compute_type); |
103 | if (quant_float_types_.find(key) == quant_float_types_.end()) { |
104 | quant_float_types_[key] = std::make_unique<QuantFloatType>( |
105 | digits_type, exponent_type, compute_type); |
106 | } |
107 | return quant_float_types_[key].get(); |
108 | } |
109 | |
110 | BitStructType *TypeFactory::get_bit_struct_type( |
111 | PrimitiveType *physical_type, |
112 | const std::vector<Type *> &member_types, |
113 | const std::vector<int> &member_bit_offsets, |
114 | const std::vector<int> &member_exponents, |
115 | const std::vector<std::vector<int>> &member_exponent_users) { |
116 | std::lock_guard<std::mutex> _(bit_struct_mut_); |
117 | |
118 | bit_struct_types_.push_back(std::make_unique<BitStructType>( |
119 | physical_type, member_types, member_bit_offsets, member_exponents, |
120 | member_exponent_users)); |
121 | return bit_struct_types_.back().get(); |
122 | } |
123 | |
124 | Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type, |
125 | Type *element_type, |
126 | int num_elements) { |
127 | std::lock_guard<std::mutex> _(quant_array_mut_); |
128 | |
129 | quant_array_types_.push_back(std::make_unique<QuantArrayType>( |
130 | physical_type, element_type, num_elements)); |
131 | return quant_array_types_.back().get(); |
132 | } |
133 | |
134 | PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { |
135 | Type *int_type; |
136 | if (bits == 8) { |
137 | int_type = get_primitive_type(PrimitiveTypeID::i8); |
138 | } else if (bits == 16) { |
139 | int_type = get_primitive_type(PrimitiveTypeID::i16); |
140 | } else if (bits == 32) { |
141 | int_type = get_primitive_type(PrimitiveTypeID::i32); |
142 | } else if (bits == 64) { |
143 | int_type = get_primitive_type(PrimitiveTypeID::i64); |
144 | } else { |
145 | TI_ERROR("No primitive int type has {} bits" , bits); |
146 | } |
147 | if (!is_signed) { |
148 | int_type = to_unsigned(DataType(int_type)); |
149 | } |
150 | return int_type->cast<PrimitiveType>(); |
151 | } |
152 | |
153 | PrimitiveType *TypeFactory::get_primitive_real_type(int bits) { |
154 | Type *real_type; |
155 | if (bits == 16) { |
156 | real_type = get_primitive_type(PrimitiveTypeID::f16); |
157 | } else if (bits == 32) { |
158 | real_type = get_primitive_type(PrimitiveTypeID::f32); |
159 | } else if (bits == 64) { |
160 | real_type = get_primitive_type(PrimitiveTypeID::f64); |
161 | } else { |
162 | TI_ERROR("No primitive real type has {} bits" , bits); |
163 | } |
164 | return real_type->cast<PrimitiveType>(); |
165 | } |
166 | |
167 | DataType TypeFactory::create_tensor_type(std::vector<int> shape, |
168 | DataType element) { |
169 | return TypeFactory::get_instance().get_tensor_type(shape, element); |
170 | } |
171 | |
172 | namespace { |
173 | static bool compare_types(DataType x, DataType y) { |
174 | // Is the first type "bigger" than the second type? |
175 | if (is_real(x) != is_real(y)) { |
176 | // One is real, the other is integral. |
177 | // real > integral |
178 | return is_real(x); |
179 | } else { |
180 | if (is_real(x) && is_real(y)) { |
181 | // Both are real |
182 | return data_type_bits(x) > data_type_bits(y); |
183 | } else { |
184 | // Both are integral |
185 | auto x_bits = data_type_bits(x); |
186 | auto y_bits = data_type_bits(y); |
187 | if (x_bits != y_bits) { |
188 | return x_bits > y_bits; |
189 | } else { |
190 | // Same number of bits. Unsigned > signed |
191 | auto x_unsigned = !is_signed(x); |
192 | auto y_unsigned = !is_signed(y); |
193 | return x_unsigned > y_unsigned; |
194 | } |
195 | } |
196 | } |
197 | } |
198 | |
199 | static DataType to_primitive_type(DataType d) { |
200 | if (d->is<PointerType>()) { |
201 | d = d->as<PointerType>()->get_pointee_type(); |
202 | TI_WARN("promoted_type got a pointer input." ); |
203 | } |
204 | |
205 | if (d->is<TensorType>()) { |
206 | d = d->as<TensorType>()->get_element_type(); |
207 | TI_WARN("promoted_type got a tensor input." ); |
208 | } |
209 | |
210 | auto primitive = d->cast<PrimitiveType>(); |
211 | TI_ASSERT_INFO(primitive, "Failed to get primitive type from {}" , |
212 | d->to_string()); |
213 | return primitive; |
214 | }; |
215 | } // namespace |
216 | |
217 | DataType promoted_primitive_type(DataType x, DataType y) { |
218 | if (compare_types(to_primitive_type(x), to_primitive_type(y))) |
219 | return x; |
220 | else |
221 | return y; |
222 | } |
223 | |
224 | DataType promoted_type(DataType a, DataType b) { |
225 | if (a->is<TensorType>() || b->is<TensorType>()) { |
226 | TI_ASSERT_INFO(a->is<TensorType>() && b->is<TensorType>(), |
227 | "a = {}, b = {}, only one of them is a tensor type" , |
228 | a->to_string(), b->to_string()); |
229 | auto tensor_ty_a = a->cast<TensorType>(); |
230 | auto tensor_ty_b = b->cast<TensorType>(); |
231 | auto promoted_dt = promoted_type(tensor_ty_a->get_element_type(), |
232 | tensor_ty_b->get_element_type()); |
233 | return TypeFactory::create_tensor_type(tensor_ty_a->get_shape(), |
234 | promoted_dt); |
235 | } else { |
236 | return promoted_primitive_type(a, b); |
237 | } |
238 | }; |
239 | |
240 | } // namespace taichi::lang |
241 | |