1#ifdef TI_WITH_LLVM
2#include "taichi/codegen/llvm/codegen_llvm.h"
3#include "taichi/ir/statements.h"
4
5namespace taichi::lang {
6
7namespace {
8
9inline void update_mask(uint64 &mask, uint32 num_bits, uint32 offset) {
10 uint64 new_mask =
11 (((~(uint64)0) << (64 - num_bits)) >> (64 - offset - num_bits));
12 TI_ASSERT((mask & new_mask) == 0);
13 mask |= new_mask;
14}
15
16} // namespace
17
18llvm::Value *TaskCodeGenLLVM::atomic_add_quant_int(llvm::Value *ptr,
19 llvm::Type *physical_type,
20 QuantIntType *qit,
21 llvm::Value *value,
22 bool value_is_signed) {
23 auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
24 return call(fmt::format("atomic_add_partial_bits_b{}",
25 physical_type->getIntegerBitWidth()),
26 byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()),
27 builder->CreateIntCast(value, physical_type, value_is_signed));
28}
29
30llvm::Value *TaskCodeGenLLVM::atomic_add_quant_fixed(llvm::Value *ptr,
31 llvm::Type *physical_type,
32 QuantFixedType *qfxt,
33 llvm::Value *value) {
34 auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
35 auto qit = qfxt->get_digits_type()->as<QuantIntType>();
36 auto val_store = to_quant_fixed(value, qfxt);
37 val_store = builder->CreateSExt(val_store, physical_type);
38 return call(fmt::format("atomic_add_partial_bits_b{}",
39 physical_type->getIntegerBitWidth()),
40 byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()),
41 val_store);
42}
43
44llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,
45 QuantFixedType *qfxt) {
46 // Compute int(real * (1.0 / scale) + 0.5)
47 auto compute_type = qfxt->get_compute_type();
48 auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()),
49 tlctx->get_data_type(compute_type));
50 auto input_real =
51 builder->CreateFPCast(real, tlctx->get_data_type(compute_type));
52 auto scaled = builder->CreateFMul(input_real, s);
53
54 // Add/minus the 0.5 offset for rounding
55 scaled =
56 call(fmt::format("rounding_prepare_f{}", data_type_bits(compute_type)),
57 scaled);
58
59 auto qit = qfxt->get_digits_type()->as<QuantIntType>();
60 if (qit->get_is_signed()) {
61 return builder->CreateFPToSI(scaled,
62 tlctx->get_data_type(qit->get_compute_type()));
63 } else {
64 return builder->CreateFPToUI(scaled,
65 tlctx->get_data_type(qit->get_compute_type()));
66 }
67}
68
69void TaskCodeGenLLVM::store_quant_int(llvm::Value *ptr,
70 llvm::Type *physical_type,
71 QuantIntType *qit,
72 llvm::Value *value,
73 bool atomic) {
74 auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
75 // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers.
76 // Try to support 8/16-bit physical types.
77 call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "",
78 physical_type->getIntegerBitWidth()),
79 byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()),
80 builder->CreateIntCast(value, physical_type, false));
81}
82
83void TaskCodeGenLLVM::store_quant_fixed(llvm::Value *ptr,
84 llvm::Type *physical_type,
85 QuantFixedType *qfxt,
86 llvm::Value *value,
87 bool atomic) {
88 store_quant_int(ptr, physical_type,
89 qfxt->get_digits_type()->as<QuantIntType>(),
90 to_quant_fixed(value, qfxt), atomic);
91}
92
93void TaskCodeGenLLVM::store_masked(llvm::Value *ptr,
94 llvm::Type *ty,
95 uint64 mask,
96 llvm::Value *value,
97 bool atomic) {
98 if (!mask) {
99 // do not store anything
100 return;
101 }
102 uint64 full_mask = (~(uint64)0) >> (64 - ty->getIntegerBitWidth());
103 if ((!atomic || compile_config.quant_opt_atomic_demotion) &&
104 ((mask & full_mask) == full_mask)) {
105 builder->CreateStore(value, ptr);
106 return;
107 }
108 call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "",
109 ty->getIntegerBitWidth()),
110 ptr, tlctx->get_constant(mask),
111 builder->CreateIntCast(value, ty, false));
112}
113
114llvm::Value *TaskCodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
115 QuantFloatType *qflt) {
116 // Since we have fewer bits in the exponent type than in f32, an
117 // offset is necessary to make sure the stored exponent values are
118 // representable by the exponent quant int type.
119 auto cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, exponent,
120 tlctx->get_constant(0));
121 return builder->CreateSelect(
122 cond, tlctx->get_constant(qflt->get_exponent_conversion_offset()),
123 tlctx->get_constant(0));
124}
125
126llvm::Value *TaskCodeGenLLVM::quant_int_or_quant_fixed_to_bits(
127 llvm::Value *val,
128 Type *input_type,
129 llvm::Type *output_type) {
130 QuantIntType *qit = nullptr;
131 if (auto qfxt = input_type->cast<QuantFixedType>()) {
132 qit = qfxt->get_digits_type()->as<QuantIntType>();
133 val = to_quant_fixed(val, qfxt);
134 } else {
135 qit = input_type->as<QuantIntType>();
136 }
137 if (qit->get_num_bits() < val->getType()->getIntegerBitWidth()) {
138 val = builder->CreateAnd(
139 val, tlctx->get_constant(qit->get_compute_type(),
140 uint64((1ULL << qit->get_num_bits()) - 1)));
141 }
142 val = builder->CreateZExt(val, output_type);
143 return val;
144}
145
// Lowers a BitStructStoreStmt: writes the members `stmt->ch_ids` with the
// values `stmt->values` into the bit struct at `stmt->ptr`.
//
// Quant floats whose exponent is shared by more than one member are written
// by store_quant_floats_with_shared_exponents(). Every other stored member
// is encoded, shifted to its bit offset and OR-ed into `bit_struct_val`.
// `bit_struct_val` is then written either with a plain store or with one
// masked (possibly atomic) store through store_masked().
void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) {
  auto bit_struct = stmt->get_bit_struct();
  auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type());

  // Count the members that no other member uses as its exponent.
  int num_non_exponent_children = 0;
  for (int i = 0; i < bit_struct->get_num_members(); i++) {
    if (bit_struct->get_member_exponent_users(i).empty()) {
      num_non_exponent_children++;
    }
  }
  // With atomic demotion enabled, a statement that stores as many members as
  // there are non-exponent members is demoted to non-atomic, and its packed
  // value may be written with a plain store (see the end of this function).
  bool store_all_components = false;
  if (compile_config.quant_opt_atomic_demotion &&
      stmt->ch_ids.size() == num_non_exponent_children) {
    stmt->is_atomic = false;
    store_all_components = true;
  }

  // Shared exponents are handled by a dedicated routine, which emits its own
  // masked store.
  bool has_shared_exponent = false;
  for (auto ch_id : stmt->ch_ids) {
    if (bit_struct->get_member_owns_shared_exponent(ch_id)) {
      has_shared_exponent = true;
    }
  }
  if (has_shared_exponent) {
    store_quant_floats_with_shared_exponents(stmt);
  }

  // Pack the remaining stored members, each shifted to its bit offset.
  llvm::Value *bit_struct_val = nullptr;
  for (int i = 0; i < stmt->ch_ids.size(); i++) {
    auto ch_id = stmt->ch_ids[i];
    auto exp = bit_struct->get_member_exponent(ch_id);
    if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) {
      // already handled in store_quant_floats_with_shared_exponents
      continue;
    }
    auto dtype = bit_struct->get_member_type(ch_id);
    auto val = llvm_val[stmt->values[i]];
    if (auto qflt = dtype->cast<QuantFloatType>()) {
      // Quant float type with non-shared exponent.
      llvm::Value *digit_bits = nullptr;
      // Extract exponent and digits from compute type (assumed to be f32 for
      // now).
      TI_ASSERT(qflt->get_compute_type()->is_primitive(PrimitiveTypeID::f32));

      // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits

      auto f32_bits =
          builder->CreateBitCast(val, llvm::Type::getInt32Ty(*llvm_context));
      // Rounding to nearest here. Note that if the digits overflows then the
      // carry-on will contribute to the exponent, which is desired.
      // (Adds half the weight of the lowest kept fraction bit.)
      if (qflt->get_digit_bits() < 23) {
        f32_bits = builder->CreateAdd(
            f32_bits, tlctx->get_constant(1 << (22 - qflt->get_digit_bits())));
      }

      // The 8 biased f32 exponent bits.
      auto exponent_bits = builder->CreateAShr(f32_bits, 23);
      exponent_bits =
          builder->CreateAnd(exponent_bits, tlctx->get_constant((1 << 8) - 1));
      // The top `digit_bits` fraction bits become the digits.
      auto value_bits = builder->CreateAShr(
          f32_bits, tlctx->get_constant(23 - qflt->get_digit_bits()));

      digit_bits = builder->CreateAnd(
          value_bits, tlctx->get_constant((1 << (qflt->get_digit_bits())) - 1));

      if (qflt->get_is_signed()) {
        // extract the sign bit
        auto sign_bit =
            builder->CreateAnd(f32_bits, tlctx->get_constant(0x80000000u));
        // insert the sign bit to digit bits (at bit position `digit_bits`)
        digit_bits = builder->CreateOr(
            digit_bits,
            builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits()));
      }

      // Re-bias the exponent for the exponent type and clamp it at 0.
      // NOTE(review): only a lower clamp is applied. An exponent above the
      // exponent type's range is not clamped here; confirm that it cannot
      // spill into neighbouring bits on the plain-store path.
      auto exponent_offset = get_exponent_offset(exponent_bits, qflt);
      exponent_bits = builder->CreateSub(exponent_bits, exponent_offset);
      exponent_bits = call("max_i32", exponent_bits, tlctx->get_constant(0));

      // Place the exponent bits at the exponent member's bit offset.
      val = builder->CreateIntCast(exponent_bits, physical_type, false);
      val = builder->CreateShl(val, bit_struct->get_member_bit_offset(exp));

      if (bit_struct_val == nullptr) {
        bit_struct_val = val;
      } else {
        bit_struct_val = builder->CreateOr(bit_struct_val, val);
      }
      // Here we implement flush to zero (FTZ): if exponent is zero, we force
      // the digits to be zero.
      // TODO: it seems that this can be more efficiently implemented using a
      // bit_and.
      auto exp_non_zero =
          builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, exponent_bits,
                              tlctx->get_constant(0));
      val = builder->CreateSelect(exp_non_zero, digit_bits,
                                  tlctx->get_constant(0));
      val = builder->CreateIntCast(val, physical_type, false);
      val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
    } else {
      // Quant int / quant fixed: the field's raw bits at its bit offset.
      val = quant_int_or_quant_fixed_to_bits(val, dtype, physical_type);
      val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
    }

    if (bit_struct_val == nullptr) {
      bit_struct_val = val;
    } else {
      bit_struct_val = builder->CreateOr(bit_struct_val, val);
    }
  }
  if (store_all_components && !has_shared_exponent) {
    // Store all components here.
    builder->CreateStore(bit_struct_val, llvm_val[stmt->ptr]);
  } else {
    // Create a mask and use a single (atomic)CAS
    // The mask mirrors the packing loop above: the digits of every member it
    // packed, plus the exponent field of each non-shared quant float. If
    // every member was skipped, the mask stays 0 and store_masked() writes
    // nothing.
    uint64 mask = 0;
    for (int i = 0; i < stmt->ch_ids.size(); i++) {
      auto ch_id = stmt->ch_ids[i];
      auto exp = bit_struct->get_member_exponent(ch_id);
      if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) {
        // already handled in store_quant_floats_with_shared_exponents
        continue;
      }
      auto dtype = bit_struct->get_member_type(ch_id);
      QuantIntType *qit = nullptr;
      if (auto qflt = dtype->cast<QuantFloatType>()) {
        auto exponent_qit = qflt->get_exponent_type()->as<QuantIntType>();
        update_mask(mask, exponent_qit->get_num_bits(),
                    bit_struct->get_member_bit_offset(exp));
        qit = qflt->get_digits_type()->as<QuantIntType>();
      } else if (auto qfxt = dtype->cast<QuantFixedType>()) {
        qit = qfxt->get_digits_type()->as<QuantIntType>();
      } else {
        qit = dtype->as<QuantIntType>();
      }
      update_mask(mask, qit->get_num_bits(),
                  bit_struct->get_member_bit_offset(ch_id));
    }
    store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val,
                 stmt->is_atomic);
  }
}
287
288void TaskCodeGenLLVM::store_quant_floats_with_shared_exponents(
289 BitStructStoreStmt *stmt) {
290 // handle each exponent separately
291 auto bit_struct = stmt->get_bit_struct();
292 auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type());
293 auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]);
294 // fuse all stores into a masked store
295 llvm::Value *masked_val = nullptr;
296 uint64 mask = 0;
297 for (int i = 0; i < bit_struct->get_num_members(); i++) {
298 auto &exponent_users = bit_struct->get_member_exponent_users(i);
299 // make sure i-th member is a shared exponent
300 if (exponent_users.size() < 2)
301 continue;
302 // load all floats with the shared exponent
303 std::vector<llvm::Value *> floats;
304 for (auto user : exponent_users) {
305 if (auto input =
306 std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), user);
307 input != stmt->ch_ids.end()) {
308 floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]);
309 } else {
310 floats.push_back(extract_quant_float(physical_value, bit_struct, user));
311 }
312 }
313 // convert to i32 for bit operations
314 llvm::Value *max_exp_bits = nullptr;
315 for (auto f : floats) {
316 // TODO: we only support f32 here.
317 auto exp_bits = extract_exponent_from_f32(f);
318 if (max_exp_bits) {
319 max_exp_bits = call("max_u32", max_exp_bits, exp_bits);
320 } else {
321 max_exp_bits = exp_bits;
322 }
323 }
324
325 auto first_qflt =
326 bit_struct->get_member_type(exponent_users[0])->as<QuantFloatType>();
327 auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt);
328
329 auto max_exp_bits_to_store =
330 builder->CreateSub(max_exp_bits, exponent_offset);
331
332 max_exp_bits_to_store =
333 call("max_i32", max_exp_bits_to_store, tlctx->get_constant(0));
334
335 // store the exponent
336 auto bit_offset = bit_struct->get_member_bit_offset(i);
337 auto val = builder->CreateZExt(max_exp_bits_to_store, physical_type);
338 val = builder->CreateShl(val, bit_offset);
339 if (masked_val == nullptr) {
340 masked_val = val;
341 } else {
342 masked_val = builder->CreateOr(masked_val, val);
343 }
344 update_mask(
345 mask,
346 bit_struct->get_member_type(i)->as<QuantIntType>()->get_num_bits(),
347 bit_offset);
348
349 for (int c = 0; c < (int)exponent_users.size(); c++) {
350 auto user = exponent_users[c];
351 auto digits =
352 extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits);
353 auto qflt = bit_struct->get_member_type(user)->as<QuantFloatType>();
354 auto digits_bit_offset = bit_struct->get_member_bit_offset(user);
355 auto right_shift_bits = 24 - qflt->get_digit_bits();
356
357 // round to nearest
358 digits = builder->CreateAdd(
359 digits, tlctx->get_constant(1 << (right_shift_bits - 1)));
360 // do not allow overflowing
361 digits = call("min_u32", digits, tlctx->get_constant((1u << 24) - 1));
362
363 // Compress f32 digits to qflt digits.
364 // Note that we need to keep the leading 1 bit so 24 instead of 23 in the
365 // following code.
366 digits = builder->CreateLShr(digits, right_shift_bits);
367 if (qflt->get_is_signed()) {
368 auto float_bits = builder->CreateBitCast(
369 floats[c], llvm::Type::getInt32Ty(*llvm_context));
370 auto sign_bit = builder->CreateAnd(float_bits, 1 << 31);
371 sign_bit = builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits());
372 digits = builder->CreateOr(digits, sign_bit);
373 }
374
375 // store the digits
376 val = builder->CreateZExt(digits, physical_type);
377 val = builder->CreateShl(val, digits_bit_offset);
378 masked_val = builder->CreateOr(masked_val, val);
379 auto num_digit_bits =
380 qflt->get_digits_type()->as<QuantIntType>()->get_num_bits();
381 update_mask(mask, num_digit_bits, digits_bit_offset);
382 }
383 }
384 store_masked(llvm_val[stmt->ptr], physical_type, mask, masked_val,
385 stmt->is_atomic);
386}
387
388llvm::Value *TaskCodeGenLLVM::extract_exponent_from_f32(llvm::Value *f) {
389 TI_ASSERT(f->getType() == llvm::Type::getFloatTy(*llvm_context));
390 f = builder->CreateBitCast(f, llvm::Type::getInt32Ty(*llvm_context));
391 auto exp_bits = builder->CreateLShr(f, tlctx->get_constant(23));
392 return builder->CreateAnd(exp_bits, tlctx->get_constant((1 << 8) - 1));
393}
394
395llvm::Value *TaskCodeGenLLVM::extract_digits_from_f32(llvm::Value *f,
396 bool full) {
397 TI_ASSERT(f->getType() == llvm::Type::getFloatTy(*llvm_context));
398 f = builder->CreateBitCast(f, llvm::Type::getInt32Ty(*llvm_context));
399 auto digits = builder->CreateAnd(f, tlctx->get_constant((1 << 23) - 1));
400 if (full) {
401 digits = builder->CreateOr(digits, tlctx->get_constant(1 << 23));
402 }
403 return digits;
404}
405
406llvm::Value *TaskCodeGenLLVM::extract_digits_from_f32_with_shared_exponent(
407 llvm::Value *f,
408 llvm::Value *shared_exp) {
409 auto exp = extract_exponent_from_f32(f);
410 auto exp_offset = builder->CreateSub(shared_exp, exp);
411 // TODO: handle negative digits
412
413 // There are two cases that may result in zero digits:
414 // - exp is zero. This means f itself is zero. Note that when processors
415 // running under FTZ (flush to zero), exp = 0 implies digits = 0.
416 // - exp is too small compared to shared_exp, or equivalently exp_offset is
417 // too large. This means we need to flush digits to zero.
418
419 // If exp is nonzero, insert an extra "1" bit that was originally implicit.
420 auto exp_non_zero = builder->CreateICmpNE(exp, tlctx->get_constant(0));
421 exp_non_zero =
422 builder->CreateZExt(exp_non_zero, llvm::Type::getInt32Ty(*llvm_context));
423 auto implicit_bit = builder->CreateShl(exp_non_zero, tlctx->get_constant(23));
424
425 auto digits = extract_digits_from_f32(f, true);
426 digits = builder->CreateOr(digits, implicit_bit);
427 exp_offset = call("min_u32", exp_offset, tlctx->get_constant(31));
428 return builder->CreateLShr(digits, exp_offset);
429}
430
431llvm::Value *TaskCodeGenLLVM::extract_quant_float(llvm::Value *physical_value,
432 BitStructType *bit_struct,
433 int digits_id) {
434 auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>();
435 auto exponent_id = bit_struct->get_member_exponent(digits_id);
436 auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id);
437 auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id);
438 auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id);
439 auto digits =
440 extract_quant_int(physical_value, tlctx->get_constant(digits_bit_offset),
441 qflt->get_digits_type()->as<QuantIntType>());
442 auto exponent = extract_quant_int(
443 physical_value, tlctx->get_constant(exponent_bit_offset),
444 qflt->get_exponent_type()->as<QuantIntType>());
445 return reconstruct_quant_float(digits, exponent, qflt, shared_exponent);
446}
447
448llvm::Value *TaskCodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
449 llvm::Value *bit_offset,
450 QuantIntType *qit) {
451 auto physical_type = physical_value->getType();
452 // bit shifting
453 // first left shift `physical_type - (offset + num_bits)`
454 // then right shift `physical_type - num_bits`
455 auto bit_end =
456 builder->CreateAdd(bit_offset, tlctx->get_constant(qit->get_num_bits()));
457 auto left = builder->CreateSub(
458 tlctx->get_constant(physical_type->getIntegerBitWidth()), bit_end);
459 auto right = builder->CreateSub(
460 tlctx->get_constant(physical_type->getIntegerBitWidth()),
461 tlctx->get_constant(qit->get_num_bits()));
462 left = builder->CreateIntCast(left, physical_type, false);
463 right = builder->CreateIntCast(right, physical_type, false);
464 auto step1 = builder->CreateShl(physical_value, left);
465 llvm::Value *step2 = nullptr;
466
467 if (qit->get_is_signed())
468 step2 = builder->CreateAShr(step1, right);
469 else
470 step2 = builder->CreateLShr(step1, right);
471
472 return builder->CreateIntCast(step2,
473 tlctx->get_data_type(qit->get_compute_type()),
474 qit->get_is_signed());
475}
476
477llvm::Value *TaskCodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
478 QuantFixedType *qfxt) {
479 // Compute float(digits) * scale
480 llvm::Value *cast = nullptr;
481 auto compute_type = qfxt->get_compute_type()->as<PrimitiveType>();
482 if (qfxt->get_is_signed()) {
483 cast = builder->CreateSIToFP(digits, tlctx->get_data_type(compute_type));
484 } else {
485 cast = builder->CreateUIToFP(digits, tlctx->get_data_type(compute_type));
486 }
487 llvm::Value *s = tlctx->get_constant(qfxt->get_scale());
488 s = builder->CreateFPCast(s, tlctx->get_data_type(compute_type));
489 return builder->CreateFMul(cast, s);
490}
491
// Builds an f32 from a quant float's raw digits field (`input_digits`) and
// exponent field (`input_exponent_val`).
//
// With `shared_exponent` set, the digits carry an explicit leading 1 whose
// position (found with ctlz) is used to renormalize the digits and adjust
// the exponent. Otherwise the digits are the top fraction bits and the
// leading 1 stays implicit. Only an f32 compute type is implemented;
// anything else hits TI_NOT_IMPLEMENTED.
llvm::Value *TaskCodeGenLLVM::reconstruct_quant_float(
    llvm::Value *input_digits,
    llvm::Value *input_exponent_val,
    QuantFloatType *qflt,
    bool shared_exponent) {
  auto digits = input_digits;
  auto exponent_val = input_exponent_val;
  // Make sure the exponent is within the range of the exponent type
  auto exponent_offset =
      tlctx->get_constant(qflt->get_exponent_conversion_offset());

  // Note that zeros need special treatment, when truncated during store.
  // For exponent types narrower than 8 bits, a stored exponent of 0 gets no
  // offset.
  auto exponent_type = qflt->get_exponent_type()->as<QuantIntType>();
  if (exponent_type->get_num_bits() < 8) {
    auto cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE,
                                    exponent_val, tlctx->get_constant(0));
    exponent_offset =
        builder->CreateSelect(cond, exponent_offset, tlctx->get_constant(0));
  }

  if (qflt->get_compute_type()->is_primitive(PrimitiveTypeID::f32)) {
    // Construct an f32 out of exponent_val and digits
    // Assuming digits and exponent_val are i32
    // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits

    // Keep only the bits of the digits field.
    digits = builder->CreateAnd(
        digits,
        (1u << qflt->get_digits_type()->as<QuantIntType>()->get_num_bits()) -
            1);

    llvm::Value *sign_bit = nullptr;

    if (shared_exponent) {
      if (qflt->get_is_signed()) {
        // The sign is the bit at position `digit_bits`. Clear it from the
        // digits and move it to bit 31.
        sign_bit = builder->CreateAnd(
            digits, tlctx->get_constant(1u << qflt->get_digit_bits()));
        digits = builder->CreateXor(digits, sign_bit);
        sign_bit = builder->CreateShl(sign_bit, 31 - qflt->get_digit_bits());
        digits = builder->CreateShl(digits, 1);
      }
      // There is a leading 1 that marks the beginning of the digits.
      // When not using shared exponents, the 1 bit is not needed (since digits
      // always starts with 1).
      // declare i32 @llvm.ctlz.i32 (i32 <src>, i1 <is_zero_undef>)
      auto num_leading_zeros = builder->CreateIntrinsic(
          llvm::Intrinsic::ctlz, {llvm::Type::getInt32Ty(*llvm_context)},
          {digits, tlctx->get_constant(false)});
      // For a leading 1 at bit p, extra_shift = p - digit_bits. The exponent
      // is raised by that amount and the digits are shifted left by 23 - p,
      // which puts the leading 1 on bit 23 (masked out below).
      auto extra_shift = builder->CreateSub(
          tlctx->get_constant(31 - qflt->get_digit_bits()), num_leading_zeros);
      exponent_offset = builder->CreateAdd(exponent_offset, extra_shift);

      if (!qflt->get_is_signed())
        exponent_offset =
            builder->CreateAdd(exponent_offset, tlctx->get_constant(1));

      auto digits_shift = builder->CreateSub(
          tlctx->get_constant(23 - qflt->get_digit_bits()), extra_shift);
      digits = builder->CreateShl(digits, digits_shift);
    } else {
      // Align the digits with the top of the 23-bit f32 fraction.
      digits = builder->CreateShl(
          digits, tlctx->get_constant(23 - qflt->get_digit_bits()));
    }
    // Keep the 23 fraction bits; bit 23 and above are dropped.
    auto fraction_bits = builder->CreateAnd(digits, (1u << 23) - 1);

    exponent_val = builder->CreateAdd(exponent_val, exponent_offset);

    auto exponent_bits =
        builder->CreateShl(exponent_val, tlctx->get_constant(23));

    auto f32_bits = builder->CreateOr(exponent_bits, fraction_bits);

    if (shared_exponent) {
      // Handle zero exponent
      // A zero stored exponent or all-zero stored digits yield zero magnitude
      // bits (the sign bit, if any, is OR-ed in below).
      auto zero_exponent =
          builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_EQ,
                              input_exponent_val, tlctx->get_constant(0));
      auto zero_digits =
          builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_EQ, input_digits,
                              tlctx->get_constant(0));
      auto zero_output = builder->CreateOr(zero_exponent, zero_digits);
      f32_bits =
          builder->CreateSelect(zero_output, tlctx->get_constant(0), f32_bits);
    }

    if (qflt->get_is_signed()) {
      if (!sign_bit) {
        // Non-shared path: the sign bit (at position `digit_bits`) was
        // shifted to bit 23 above; move it to bit 31.
        sign_bit = builder->CreateAnd(digits, tlctx->get_constant(1u << 23));
        sign_bit = builder->CreateShl(sign_bit, tlctx->get_constant(31 - 23));
      }
      f32_bits = builder->CreateOr(f32_bits, sign_bit);
    }

    return builder->CreateBitCast(f32_bits,
                                  llvm::Type::getFloatTy(*llvm_context));
  } else {
    TI_NOT_IMPLEMENTED;
  }
}
590
591} // namespace taichi::lang
592
593#endif // #ifdef TI_WITH_LLVM
594