1 | #ifdef TI_WITH_LLVM |
2 | #include "taichi/codegen/llvm/codegen_llvm.h" |
3 | #include "taichi/ir/statements.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
7 | namespace { |
8 | |
9 | inline void update_mask(uint64 &mask, uint32 num_bits, uint32 offset) { |
10 | uint64 new_mask = |
11 | (((~(uint64)0) << (64 - num_bits)) >> (64 - offset - num_bits)); |
12 | TI_ASSERT((mask & new_mask) == 0); |
13 | mask |= new_mask; |
14 | } |
15 | |
16 | } // namespace |
17 | |
18 | llvm::Value *TaskCodeGenLLVM::atomic_add_quant_int(llvm::Value *ptr, |
19 | llvm::Type *physical_type, |
20 | QuantIntType *qit, |
21 | llvm::Value *value, |
22 | bool value_is_signed) { |
23 | auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); |
24 | return call(fmt::format("atomic_add_partial_bits_b{}" , |
25 | physical_type->getIntegerBitWidth()), |
26 | byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), |
27 | builder->CreateIntCast(value, physical_type, value_is_signed)); |
28 | } |
29 | |
30 | llvm::Value *TaskCodeGenLLVM::atomic_add_quant_fixed(llvm::Value *ptr, |
31 | llvm::Type *physical_type, |
32 | QuantFixedType *qfxt, |
33 | llvm::Value *value) { |
34 | auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); |
35 | auto qit = qfxt->get_digits_type()->as<QuantIntType>(); |
36 | auto val_store = to_quant_fixed(value, qfxt); |
37 | val_store = builder->CreateSExt(val_store, physical_type); |
38 | return call(fmt::format("atomic_add_partial_bits_b{}" , |
39 | physical_type->getIntegerBitWidth()), |
40 | byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), |
41 | val_store); |
42 | } |
43 | |
44 | llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real, |
45 | QuantFixedType *qfxt) { |
46 | // Compute int(real * (1.0 / scale) + 0.5) |
47 | auto compute_type = qfxt->get_compute_type(); |
48 | auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()), |
49 | tlctx->get_data_type(compute_type)); |
50 | auto input_real = |
51 | builder->CreateFPCast(real, tlctx->get_data_type(compute_type)); |
52 | auto scaled = builder->CreateFMul(input_real, s); |
53 | |
54 | // Add/minus the 0.5 offset for rounding |
55 | scaled = |
56 | call(fmt::format("rounding_prepare_f{}" , data_type_bits(compute_type)), |
57 | scaled); |
58 | |
59 | auto qit = qfxt->get_digits_type()->as<QuantIntType>(); |
60 | if (qit->get_is_signed()) { |
61 | return builder->CreateFPToSI(scaled, |
62 | tlctx->get_data_type(qit->get_compute_type())); |
63 | } else { |
64 | return builder->CreateFPToUI(scaled, |
65 | tlctx->get_data_type(qit->get_compute_type())); |
66 | } |
67 | } |
68 | |
69 | void TaskCodeGenLLVM::store_quant_int(llvm::Value *ptr, |
70 | llvm::Type *physical_type, |
71 | QuantIntType *qit, |
72 | llvm::Value *value, |
73 | bool atomic) { |
74 | auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); |
75 | // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers. |
76 | // Try to support 8/16-bit physical types. |
77 | call(fmt::format("{}set_partial_bits_b{}" , atomic ? "atomic_" : "" , |
78 | physical_type->getIntegerBitWidth()), |
79 | byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), |
80 | builder->CreateIntCast(value, physical_type, false)); |
81 | } |
82 | |
83 | void TaskCodeGenLLVM::store_quant_fixed(llvm::Value *ptr, |
84 | llvm::Type *physical_type, |
85 | QuantFixedType *qfxt, |
86 | llvm::Value *value, |
87 | bool atomic) { |
88 | store_quant_int(ptr, physical_type, |
89 | qfxt->get_digits_type()->as<QuantIntType>(), |
90 | to_quant_fixed(value, qfxt), atomic); |
91 | } |
92 | |
93 | void TaskCodeGenLLVM::store_masked(llvm::Value *ptr, |
94 | llvm::Type *ty, |
95 | uint64 mask, |
96 | llvm::Value *value, |
97 | bool atomic) { |
98 | if (!mask) { |
99 | // do not store anything |
100 | return; |
101 | } |
102 | uint64 full_mask = (~(uint64)0) >> (64 - ty->getIntegerBitWidth()); |
103 | if ((!atomic || compile_config.quant_opt_atomic_demotion) && |
104 | ((mask & full_mask) == full_mask)) { |
105 | builder->CreateStore(value, ptr); |
106 | return; |
107 | } |
108 | call(fmt::format("{}set_mask_b{}" , atomic ? "atomic_" : "" , |
109 | ty->getIntegerBitWidth()), |
110 | ptr, tlctx->get_constant(mask), |
111 | builder->CreateIntCast(value, ty, false)); |
112 | } |
113 | |
114 | llvm::Value *TaskCodeGenLLVM::get_exponent_offset(llvm::Value *exponent, |
115 | QuantFloatType *qflt) { |
116 | // Since we have fewer bits in the exponent type than in f32, an |
117 | // offset is necessary to make sure the stored exponent values are |
118 | // representable by the exponent quant int type. |
119 | auto cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, exponent, |
120 | tlctx->get_constant(0)); |
121 | return builder->CreateSelect( |
122 | cond, tlctx->get_constant(qflt->get_exponent_conversion_offset()), |
123 | tlctx->get_constant(0)); |
124 | } |
125 | |
126 | llvm::Value *TaskCodeGenLLVM::quant_int_or_quant_fixed_to_bits( |
127 | llvm::Value *val, |
128 | Type *input_type, |
129 | llvm::Type *output_type) { |
130 | QuantIntType *qit = nullptr; |
131 | if (auto qfxt = input_type->cast<QuantFixedType>()) { |
132 | qit = qfxt->get_digits_type()->as<QuantIntType>(); |
133 | val = to_quant_fixed(val, qfxt); |
134 | } else { |
135 | qit = input_type->as<QuantIntType>(); |
136 | } |
137 | if (qit->get_num_bits() < val->getType()->getIntegerBitWidth()) { |
138 | val = builder->CreateAnd( |
139 | val, tlctx->get_constant(qit->get_compute_type(), |
140 | uint64((1ULL << qit->get_num_bits()) - 1))); |
141 | } |
142 | val = builder->CreateZExt(val, output_type); |
143 | return val; |
144 | } |
145 | |
146 | void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) { |
147 | auto bit_struct = stmt->get_bit_struct(); |
148 | auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type()); |
149 | |
150 | int num_non_exponent_children = 0; |
151 | for (int i = 0; i < bit_struct->get_num_members(); i++) { |
152 | if (bit_struct->get_member_exponent_users(i).empty()) { |
153 | num_non_exponent_children++; |
154 | } |
155 | } |
156 | bool store_all_components = false; |
157 | if (compile_config.quant_opt_atomic_demotion && |
158 | stmt->ch_ids.size() == num_non_exponent_children) { |
159 | stmt->is_atomic = false; |
160 | store_all_components = true; |
161 | } |
162 | |
163 | bool has_shared_exponent = false; |
164 | for (auto ch_id : stmt->ch_ids) { |
165 | if (bit_struct->get_member_owns_shared_exponent(ch_id)) { |
166 | has_shared_exponent = true; |
167 | } |
168 | } |
169 | if (has_shared_exponent) { |
170 | store_quant_floats_with_shared_exponents(stmt); |
171 | } |
172 | |
173 | llvm::Value *bit_struct_val = nullptr; |
174 | for (int i = 0; i < stmt->ch_ids.size(); i++) { |
175 | auto ch_id = stmt->ch_ids[i]; |
176 | auto exp = bit_struct->get_member_exponent(ch_id); |
177 | if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) { |
178 | // already handled in store_quant_floats_with_shared_exponents |
179 | continue; |
180 | } |
181 | auto dtype = bit_struct->get_member_type(ch_id); |
182 | auto val = llvm_val[stmt->values[i]]; |
183 | if (auto qflt = dtype->cast<QuantFloatType>()) { |
184 | // Quant float type with non-shared exponent. |
185 | llvm::Value *digit_bits = nullptr; |
186 | // Extract exponent and digits from compute type (assumed to be f32 for |
187 | // now). |
188 | TI_ASSERT(qflt->get_compute_type()->is_primitive(PrimitiveTypeID::f32)); |
189 | |
190 | // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits |
191 | |
192 | auto f32_bits = |
193 | builder->CreateBitCast(val, llvm::Type::getInt32Ty(*llvm_context)); |
194 | // Rounding to nearest here. Note that if the digits overflows then the |
195 | // carry-on will contribute to the exponent, which is desired. |
196 | if (qflt->get_digit_bits() < 23) { |
197 | f32_bits = builder->CreateAdd( |
198 | f32_bits, tlctx->get_constant(1 << (22 - qflt->get_digit_bits()))); |
199 | } |
200 | |
201 | auto exponent_bits = builder->CreateAShr(f32_bits, 23); |
202 | exponent_bits = |
203 | builder->CreateAnd(exponent_bits, tlctx->get_constant((1 << 8) - 1)); |
204 | auto value_bits = builder->CreateAShr( |
205 | f32_bits, tlctx->get_constant(23 - qflt->get_digit_bits())); |
206 | |
207 | digit_bits = builder->CreateAnd( |
208 | value_bits, tlctx->get_constant((1 << (qflt->get_digit_bits())) - 1)); |
209 | |
210 | if (qflt->get_is_signed()) { |
211 | // extract the sign bit |
212 | auto sign_bit = |
213 | builder->CreateAnd(f32_bits, tlctx->get_constant(0x80000000u)); |
214 | // insert the sign bit to digit bits |
215 | digit_bits = builder->CreateOr( |
216 | digit_bits, |
217 | builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits())); |
218 | } |
219 | |
220 | auto exponent_offset = get_exponent_offset(exponent_bits, qflt); |
221 | exponent_bits = builder->CreateSub(exponent_bits, exponent_offset); |
222 | exponent_bits = call("max_i32" , exponent_bits, tlctx->get_constant(0)); |
223 | |
224 | // Compute the bit pointer of the exponent bits. |
225 | val = builder->CreateIntCast(exponent_bits, physical_type, false); |
226 | val = builder->CreateShl(val, bit_struct->get_member_bit_offset(exp)); |
227 | |
228 | if (bit_struct_val == nullptr) { |
229 | bit_struct_val = val; |
230 | } else { |
231 | bit_struct_val = builder->CreateOr(bit_struct_val, val); |
232 | } |
233 | // Here we implement flush to zero (FTZ): if exponent is zero, we force |
234 | // the digits to be zero. |
235 | // TODO: it seems that this can be more efficiently implemented using a |
236 | // bit_and. |
237 | auto exp_non_zero = |
238 | builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, exponent_bits, |
239 | tlctx->get_constant(0)); |
240 | val = builder->CreateSelect(exp_non_zero, digit_bits, |
241 | tlctx->get_constant(0)); |
242 | val = builder->CreateIntCast(val, physical_type, false); |
243 | val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id)); |
244 | } else { |
245 | val = quant_int_or_quant_fixed_to_bits(val, dtype, physical_type); |
246 | val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id)); |
247 | } |
248 | |
249 | if (bit_struct_val == nullptr) { |
250 | bit_struct_val = val; |
251 | } else { |
252 | bit_struct_val = builder->CreateOr(bit_struct_val, val); |
253 | } |
254 | } |
255 | if (store_all_components && !has_shared_exponent) { |
256 | // Store all components here. |
257 | builder->CreateStore(bit_struct_val, llvm_val[stmt->ptr]); |
258 | } else { |
259 | // Create a mask and use a single (atomic)CAS |
260 | uint64 mask = 0; |
261 | for (int i = 0; i < stmt->ch_ids.size(); i++) { |
262 | auto ch_id = stmt->ch_ids[i]; |
263 | auto exp = bit_struct->get_member_exponent(ch_id); |
264 | if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) { |
265 | // already handled in store_quant_floats_with_shared_exponents |
266 | continue; |
267 | } |
268 | auto dtype = bit_struct->get_member_type(ch_id); |
269 | QuantIntType *qit = nullptr; |
270 | if (auto qflt = dtype->cast<QuantFloatType>()) { |
271 | auto exponent_qit = qflt->get_exponent_type()->as<QuantIntType>(); |
272 | update_mask(mask, exponent_qit->get_num_bits(), |
273 | bit_struct->get_member_bit_offset(exp)); |
274 | qit = qflt->get_digits_type()->as<QuantIntType>(); |
275 | } else if (auto qfxt = dtype->cast<QuantFixedType>()) { |
276 | qit = qfxt->get_digits_type()->as<QuantIntType>(); |
277 | } else { |
278 | qit = dtype->as<QuantIntType>(); |
279 | } |
280 | update_mask(mask, qit->get_num_bits(), |
281 | bit_struct->get_member_bit_offset(ch_id)); |
282 | } |
283 | store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val, |
284 | stmt->is_atomic); |
285 | } |
286 | } |
287 | |
288 | void TaskCodeGenLLVM::store_quant_floats_with_shared_exponents( |
289 | BitStructStoreStmt *stmt) { |
290 | // handle each exponent separately |
291 | auto bit_struct = stmt->get_bit_struct(); |
292 | auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type()); |
293 | auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]); |
294 | // fuse all stores into a masked store |
295 | llvm::Value *masked_val = nullptr; |
296 | uint64 mask = 0; |
297 | for (int i = 0; i < bit_struct->get_num_members(); i++) { |
298 | auto &exponent_users = bit_struct->get_member_exponent_users(i); |
299 | // make sure i-th member is a shared exponent |
300 | if (exponent_users.size() < 2) |
301 | continue; |
302 | // load all floats with the shared exponent |
303 | std::vector<llvm::Value *> floats; |
304 | for (auto user : exponent_users) { |
305 | if (auto input = |
306 | std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), user); |
307 | input != stmt->ch_ids.end()) { |
308 | floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]); |
309 | } else { |
310 | floats.push_back(extract_quant_float(physical_value, bit_struct, user)); |
311 | } |
312 | } |
313 | // convert to i32 for bit operations |
314 | llvm::Value *max_exp_bits = nullptr; |
315 | for (auto f : floats) { |
316 | // TODO: we only support f32 here. |
317 | auto exp_bits = extract_exponent_from_f32(f); |
318 | if (max_exp_bits) { |
319 | max_exp_bits = call("max_u32" , max_exp_bits, exp_bits); |
320 | } else { |
321 | max_exp_bits = exp_bits; |
322 | } |
323 | } |
324 | |
325 | auto first_qflt = |
326 | bit_struct->get_member_type(exponent_users[0])->as<QuantFloatType>(); |
327 | auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt); |
328 | |
329 | auto max_exp_bits_to_store = |
330 | builder->CreateSub(max_exp_bits, exponent_offset); |
331 | |
332 | max_exp_bits_to_store = |
333 | call("max_i32" , max_exp_bits_to_store, tlctx->get_constant(0)); |
334 | |
335 | // store the exponent |
336 | auto bit_offset = bit_struct->get_member_bit_offset(i); |
337 | auto val = builder->CreateZExt(max_exp_bits_to_store, physical_type); |
338 | val = builder->CreateShl(val, bit_offset); |
339 | if (masked_val == nullptr) { |
340 | masked_val = val; |
341 | } else { |
342 | masked_val = builder->CreateOr(masked_val, val); |
343 | } |
344 | update_mask( |
345 | mask, |
346 | bit_struct->get_member_type(i)->as<QuantIntType>()->get_num_bits(), |
347 | bit_offset); |
348 | |
349 | for (int c = 0; c < (int)exponent_users.size(); c++) { |
350 | auto user = exponent_users[c]; |
351 | auto digits = |
352 | extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits); |
353 | auto qflt = bit_struct->get_member_type(user)->as<QuantFloatType>(); |
354 | auto digits_bit_offset = bit_struct->get_member_bit_offset(user); |
355 | auto right_shift_bits = 24 - qflt->get_digit_bits(); |
356 | |
357 | // round to nearest |
358 | digits = builder->CreateAdd( |
359 | digits, tlctx->get_constant(1 << (right_shift_bits - 1))); |
360 | // do not allow overflowing |
361 | digits = call("min_u32" , digits, tlctx->get_constant((1u << 24) - 1)); |
362 | |
363 | // Compress f32 digits to qflt digits. |
364 | // Note that we need to keep the leading 1 bit so 24 instead of 23 in the |
365 | // following code. |
366 | digits = builder->CreateLShr(digits, right_shift_bits); |
367 | if (qflt->get_is_signed()) { |
368 | auto float_bits = builder->CreateBitCast( |
369 | floats[c], llvm::Type::getInt32Ty(*llvm_context)); |
370 | auto sign_bit = builder->CreateAnd(float_bits, 1 << 31); |
371 | sign_bit = builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits()); |
372 | digits = builder->CreateOr(digits, sign_bit); |
373 | } |
374 | |
375 | // store the digits |
376 | val = builder->CreateZExt(digits, physical_type); |
377 | val = builder->CreateShl(val, digits_bit_offset); |
378 | masked_val = builder->CreateOr(masked_val, val); |
379 | auto num_digit_bits = |
380 | qflt->get_digits_type()->as<QuantIntType>()->get_num_bits(); |
381 | update_mask(mask, num_digit_bits, digits_bit_offset); |
382 | } |
383 | } |
384 | store_masked(llvm_val[stmt->ptr], physical_type, mask, masked_val, |
385 | stmt->is_atomic); |
386 | } |
387 | |
388 | llvm::Value *TaskCodeGenLLVM::(llvm::Value *f) { |
389 | TI_ASSERT(f->getType() == llvm::Type::getFloatTy(*llvm_context)); |
390 | f = builder->CreateBitCast(f, llvm::Type::getInt32Ty(*llvm_context)); |
391 | auto exp_bits = builder->CreateLShr(f, tlctx->get_constant(23)); |
392 | return builder->CreateAnd(exp_bits, tlctx->get_constant((1 << 8) - 1)); |
393 | } |
394 | |
395 | llvm::Value *TaskCodeGenLLVM::(llvm::Value *f, |
396 | bool full) { |
397 | TI_ASSERT(f->getType() == llvm::Type::getFloatTy(*llvm_context)); |
398 | f = builder->CreateBitCast(f, llvm::Type::getInt32Ty(*llvm_context)); |
399 | auto digits = builder->CreateAnd(f, tlctx->get_constant((1 << 23) - 1)); |
400 | if (full) { |
401 | digits = builder->CreateOr(digits, tlctx->get_constant(1 << 23)); |
402 | } |
403 | return digits; |
404 | } |
405 | |
406 | llvm::Value *TaskCodeGenLLVM::( |
407 | llvm::Value *f, |
408 | llvm::Value *shared_exp) { |
409 | auto exp = extract_exponent_from_f32(f); |
410 | auto exp_offset = builder->CreateSub(shared_exp, exp); |
411 | // TODO: handle negative digits |
412 | |
413 | // There are two cases that may result in zero digits: |
414 | // - exp is zero. This means f itself is zero. Note that when processors |
415 | // running under FTZ (flush to zero), exp = 0 implies digits = 0. |
416 | // - exp is too small compared to shared_exp, or equivalently exp_offset is |
417 | // too large. This means we need to flush digits to zero. |
418 | |
419 | // If exp is nonzero, insert an extra "1" bit that was originally implicit. |
420 | auto exp_non_zero = builder->CreateICmpNE(exp, tlctx->get_constant(0)); |
421 | exp_non_zero = |
422 | builder->CreateZExt(exp_non_zero, llvm::Type::getInt32Ty(*llvm_context)); |
423 | auto implicit_bit = builder->CreateShl(exp_non_zero, tlctx->get_constant(23)); |
424 | |
425 | auto digits = extract_digits_from_f32(f, true); |
426 | digits = builder->CreateOr(digits, implicit_bit); |
427 | exp_offset = call("min_u32" , exp_offset, tlctx->get_constant(31)); |
428 | return builder->CreateLShr(digits, exp_offset); |
429 | } |
430 | |
431 | llvm::Value *TaskCodeGenLLVM::(llvm::Value *physical_value, |
432 | BitStructType *bit_struct, |
433 | int digits_id) { |
434 | auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>(); |
435 | auto exponent_id = bit_struct->get_member_exponent(digits_id); |
436 | auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id); |
437 | auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id); |
438 | auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id); |
439 | auto digits = |
440 | extract_quant_int(physical_value, tlctx->get_constant(digits_bit_offset), |
441 | qflt->get_digits_type()->as<QuantIntType>()); |
442 | auto exponent = extract_quant_int( |
443 | physical_value, tlctx->get_constant(exponent_bit_offset), |
444 | qflt->get_exponent_type()->as<QuantIntType>()); |
445 | return reconstruct_quant_float(digits, exponent, qflt, shared_exponent); |
446 | } |
447 | |
448 | llvm::Value *TaskCodeGenLLVM::(llvm::Value *physical_value, |
449 | llvm::Value *bit_offset, |
450 | QuantIntType *qit) { |
451 | auto physical_type = physical_value->getType(); |
452 | // bit shifting |
453 | // first left shift `physical_type - (offset + num_bits)` |
454 | // then right shift `physical_type - num_bits` |
455 | auto bit_end = |
456 | builder->CreateAdd(bit_offset, tlctx->get_constant(qit->get_num_bits())); |
457 | auto left = builder->CreateSub( |
458 | tlctx->get_constant(physical_type->getIntegerBitWidth()), bit_end); |
459 | auto right = builder->CreateSub( |
460 | tlctx->get_constant(physical_type->getIntegerBitWidth()), |
461 | tlctx->get_constant(qit->get_num_bits())); |
462 | left = builder->CreateIntCast(left, physical_type, false); |
463 | right = builder->CreateIntCast(right, physical_type, false); |
464 | auto step1 = builder->CreateShl(physical_value, left); |
465 | llvm::Value *step2 = nullptr; |
466 | |
467 | if (qit->get_is_signed()) |
468 | step2 = builder->CreateAShr(step1, right); |
469 | else |
470 | step2 = builder->CreateLShr(step1, right); |
471 | |
472 | return builder->CreateIntCast(step2, |
473 | tlctx->get_data_type(qit->get_compute_type()), |
474 | qit->get_is_signed()); |
475 | } |
476 | |
477 | llvm::Value *TaskCodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, |
478 | QuantFixedType *qfxt) { |
479 | // Compute float(digits) * scale |
480 | llvm::Value *cast = nullptr; |
481 | auto compute_type = qfxt->get_compute_type()->as<PrimitiveType>(); |
482 | if (qfxt->get_is_signed()) { |
483 | cast = builder->CreateSIToFP(digits, tlctx->get_data_type(compute_type)); |
484 | } else { |
485 | cast = builder->CreateUIToFP(digits, tlctx->get_data_type(compute_type)); |
486 | } |
487 | llvm::Value *s = tlctx->get_constant(qfxt->get_scale()); |
488 | s = builder->CreateFPCast(s, tlctx->get_data_type(compute_type)); |
489 | return builder->CreateFMul(cast, s); |
490 | } |
491 | |
492 | llvm::Value *TaskCodeGenLLVM::reconstruct_quant_float( |
493 | llvm::Value *input_digits, |
494 | llvm::Value *input_exponent_val, |
495 | QuantFloatType *qflt, |
496 | bool shared_exponent) { |
497 | auto digits = input_digits; |
498 | auto exponent_val = input_exponent_val; |
499 | // Make sure the exponent is within the range of the exponent type |
500 | auto exponent_offset = |
501 | tlctx->get_constant(qflt->get_exponent_conversion_offset()); |
502 | |
503 | // Note that zeros need special treatment, when truncated during store. |
504 | auto exponent_type = qflt->get_exponent_type()->as<QuantIntType>(); |
505 | if (exponent_type->get_num_bits() < 8) { |
506 | auto cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, |
507 | exponent_val, tlctx->get_constant(0)); |
508 | exponent_offset = |
509 | builder->CreateSelect(cond, exponent_offset, tlctx->get_constant(0)); |
510 | } |
511 | |
512 | if (qflt->get_compute_type()->is_primitive(PrimitiveTypeID::f32)) { |
513 | // Construct an f32 out of exponent_val and digits |
514 | // Assuming digits and exponent_val are i32 |
515 | // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits |
516 | |
517 | digits = builder->CreateAnd( |
518 | digits, |
519 | (1u << qflt->get_digits_type()->as<QuantIntType>()->get_num_bits()) - |
520 | 1); |
521 | |
522 | llvm::Value *sign_bit = nullptr; |
523 | |
524 | if (shared_exponent) { |
525 | if (qflt->get_is_signed()) { |
526 | sign_bit = builder->CreateAnd( |
527 | digits, tlctx->get_constant(1u << qflt->get_digit_bits())); |
528 | digits = builder->CreateXor(digits, sign_bit); |
529 | sign_bit = builder->CreateShl(sign_bit, 31 - qflt->get_digit_bits()); |
530 | digits = builder->CreateShl(digits, 1); |
531 | } |
532 | // There is a leading 1 that marks the beginning of the digits. |
533 | // When not using shared exponents, the 1 bit is not needed (since digits |
534 | // always starts with 1). |
535 | // declare i32 @llvm.ctlz.i32 (i32 <src>, i1 <is_zero_undef>) |
536 | auto num_leading_zeros = builder->CreateIntrinsic( |
537 | llvm::Intrinsic::ctlz, {llvm::Type::getInt32Ty(*llvm_context)}, |
538 | {digits, tlctx->get_constant(false)}); |
539 | auto = builder->CreateSub( |
540 | tlctx->get_constant(31 - qflt->get_digit_bits()), num_leading_zeros); |
541 | exponent_offset = builder->CreateAdd(exponent_offset, extra_shift); |
542 | |
543 | if (!qflt->get_is_signed()) |
544 | exponent_offset = |
545 | builder->CreateAdd(exponent_offset, tlctx->get_constant(1)); |
546 | |
547 | auto digits_shift = builder->CreateSub( |
548 | tlctx->get_constant(23 - qflt->get_digit_bits()), extra_shift); |
549 | digits = builder->CreateShl(digits, digits_shift); |
550 | } else { |
551 | digits = builder->CreateShl( |
552 | digits, tlctx->get_constant(23 - qflt->get_digit_bits())); |
553 | } |
554 | auto fraction_bits = builder->CreateAnd(digits, (1u << 23) - 1); |
555 | |
556 | exponent_val = builder->CreateAdd(exponent_val, exponent_offset); |
557 | |
558 | auto exponent_bits = |
559 | builder->CreateShl(exponent_val, tlctx->get_constant(23)); |
560 | |
561 | auto f32_bits = builder->CreateOr(exponent_bits, fraction_bits); |
562 | |
563 | if (shared_exponent) { |
564 | // Handle zero exponent |
565 | auto zero_exponent = |
566 | builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_EQ, |
567 | input_exponent_val, tlctx->get_constant(0)); |
568 | auto zero_digits = |
569 | builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_EQ, input_digits, |
570 | tlctx->get_constant(0)); |
571 | auto zero_output = builder->CreateOr(zero_exponent, zero_digits); |
572 | f32_bits = |
573 | builder->CreateSelect(zero_output, tlctx->get_constant(0), f32_bits); |
574 | } |
575 | |
576 | if (qflt->get_is_signed()) { |
577 | if (!sign_bit) { |
578 | sign_bit = builder->CreateAnd(digits, tlctx->get_constant(1u << 23)); |
579 | sign_bit = builder->CreateShl(sign_bit, tlctx->get_constant(31 - 23)); |
580 | } |
581 | f32_bits = builder->CreateOr(f32_bits, sign_bit); |
582 | } |
583 | |
584 | return builder->CreateBitCast(f32_bits, |
585 | llvm::Type::getFloatTy(*llvm_context)); |
586 | } else { |
587 | TI_NOT_IMPLEMENTED; |
588 | } |
589 | } |
590 | |
591 | } // namespace taichi::lang |
592 | |
593 | #endif // #ifdef TI_WITH_LLVM |
594 | |