1#include "taichi/codegen/llvm/codegen_llvm.h"
2
3#include <algorithm>
4
5#ifdef TI_WITH_LLVM
6#include "llvm/Bitcode/BitcodeReader.h"
7#include "llvm/IR/Module.h"
8#include "llvm/Linker/Linker.h"
9#include "taichi/analysis/offline_cache_util.h"
10#include "taichi/ir/statements.h"
11#include "taichi/ir/transforms.h"
12#include "taichi/runtime/llvm/launch_arg_info.h"
13#include "taichi/runtime/llvm/llvm_offline_cache.h"
14#include "taichi/program/extension.h"
15#include "taichi/runtime/program_impls/llvm/llvm_program.h"
16#include "taichi/codegen/llvm/struct_llvm.h"
17#include "taichi/util/file_sequence_writer.h"
18#include "taichi/codegen/codegen_utils.h"
19
20namespace taichi::lang {
21
22// TODO: sort function definitions to match declaration order in header
23
24// TODO(k-ye): Hide FunctionCreationGuard inside cpp file
25FunctionCreationGuard::FunctionCreationGuard(
26 TaskCodeGenLLVM *mb,
27 std::vector<llvm::Type *> arguments,
28 const std::string &func_name)
29 : mb(mb) {
30 // Create the loop body function
31 auto body_function_type = llvm::FunctionType::get(
32 llvm::Type::getVoidTy(*mb->llvm_context), arguments, false);
33
34 body = llvm::Function::Create(body_function_type,
35 llvm::Function::InternalLinkage, func_name,
36 mb->module.get());
37 old_func = mb->func;
38 // emit into loop body function
39 mb->func = body;
40
41 allocas = llvm::BasicBlock::Create(*mb->llvm_context, "allocs", body);
42 old_entry = mb->entry_block;
43 mb->entry_block = allocas;
44
45 final = llvm::BasicBlock::Create(*mb->llvm_context, "final", body);
46 old_final = mb->final_block;
47 mb->final_block = final;
48
49 entry = llvm::BasicBlock::Create(*mb->llvm_context, "entry", mb->func);
50
51 ip = mb->builder->saveIP();
52 mb->builder->SetInsertPoint(entry);
53
54 auto body_bb =
55 llvm::BasicBlock::Create(*mb->llvm_context, "function_body", mb->func);
56 mb->builder->CreateBr(body_bb);
57 mb->builder->SetInsertPoint(body_bb);
58}
59
60FunctionCreationGuard::~FunctionCreationGuard() {
61 if (!mb->returned) {
62 mb->builder->CreateBr(final);
63 }
64 mb->builder->SetInsertPoint(final);
65 mb->builder->CreateRetVoid();
66 mb->returned = false;
67
68 mb->builder->SetInsertPoint(allocas);
69 mb->builder->CreateBr(entry);
70
71 mb->entry_block = old_entry;
72 mb->final_block = old_final;
73 mb->func = old_func;
74 mb->builder->restoreIP(ip);
75
76 TI_ASSERT(!llvm::verifyFunction(*body, &llvm::errs()));
77}
78
79namespace {
80
81class CodeGenStmtGuard {
82 public:
83 using Getter = std::function<llvm::BasicBlock *(void)>;
84 using Setter = std::function<void(llvm::BasicBlock *)>;
85
86 explicit CodeGenStmtGuard(Getter getter, Setter setter)
87 : saved_stmt_(getter()), setter_(std::move(setter)) {
88 }
89
90 ~CodeGenStmtGuard() {
91 setter_(saved_stmt_);
92 }
93
94 CodeGenStmtGuard(CodeGenStmtGuard &&) = default;
95 CodeGenStmtGuard &operator=(CodeGenStmtGuard &&) = default;
96
97 private:
98 llvm::BasicBlock *saved_stmt_;
99 Setter setter_;
100};
101
102CodeGenStmtGuard make_loop_reentry_guard(TaskCodeGenLLVM *cg) {
103 return CodeGenStmtGuard([cg]() { return cg->current_loop_reentry; },
104 [cg](llvm::BasicBlock *saved_stmt) {
105 cg->current_loop_reentry = saved_stmt;
106 });
107}
108
109CodeGenStmtGuard make_while_after_loop_guard(TaskCodeGenLLVM *cg) {
110 return CodeGenStmtGuard([cg]() { return cg->current_while_after_loop; },
111 [cg](llvm::BasicBlock *saved_stmt) {
112 cg->current_while_after_loop = saved_stmt;
113 });
114}
115
116} // namespace
117
118// TaskCodeGenLLVM
119void TaskCodeGenLLVM::visit(Block *stmt_list) {
120 for (auto &stmt : stmt_list->statements) {
121 stmt->accept(this);
122 if (returned) {
123 break;
124 }
125 }
126}
127
128void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
129 if (stmt->ret_type->is<TensorType>()) {
130 auto tensor_type = stmt->ret_type->cast<TensorType>();
131 auto type = tlctx->get_data_type(tensor_type);
132 if (stmt->is_shared) {
133 auto base = new llvm::GlobalVariable(
134 *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr,
135 fmt::format("shared_array_{}", stmt->id), nullptr,
136 llvm::GlobalVariable::NotThreadLocal, 3 /*addrspace=shared*/);
137 base->setAlignment(llvm::MaybeAlign(8));
138 auto ptr_type = llvm::PointerType::get(type, 0);
139 llvm_val[stmt] = builder->CreatePointerCast(base, ptr_type);
140 } else {
141 llvm_val[stmt] = create_entry_block_alloca(type);
142 }
143 } else {
144 llvm_val[stmt] =
145 create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer());
146 // initialize as zero if element is not a pointer
147 if (!stmt->ret_type.is_pointer())
148 builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0),
149 llvm_val[stmt]);
150 }
151}
152
153void TaskCodeGenLLVM::visit(RandStmt *stmt) {
154 if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
155 // Promoting to f32 since there's no rand_f16 support in runtime.cpp.
156 auto val_f32 = call("rand_f32", get_context());
157 llvm_val[stmt] =
158 builder->CreateFPTrunc(val_f32, llvm::Type::getHalfTy(*llvm_context));
159 } else {
160 llvm_val[stmt] = call(
161 fmt::format("rand_{}", data_type_name(stmt->ret_type)), get_context());
162 }
163}
164
165void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
166 auto input = llvm_val[stmt->operand];
167 auto input_taichi_type = stmt->operand->ret_type;
168 if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) {
169 // Promote to f32 since we don't have f16 support for extra unary ops in in
170 // runtime.cpp.
171 input = builder->CreateFPExt(input, llvm::Type::getFloatTy(*llvm_context));
172 input_taichi_type = PrimitiveType::f32;
173 }
174
175 auto op = stmt->op_type;
176 auto input_type = input->getType();
177
178#define UNARY_STD(x) \
179 else if (op == UnaryOpType::x) { \
180 if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \
181 llvm_val[stmt] = call(#x "_f32", input); \
182 } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \
183 llvm_val[stmt] = call(#x "_f64", input); \
184 } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \
185 llvm_val[stmt] = call(#x "_i32", input); \
186 } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) { \
187 llvm_val[stmt] = call(#x "_i64", input); \
188 } else { \
189 TI_NOT_IMPLEMENTED \
190 } \
191 }
192 if (false) {
193 }
194 UNARY_STD(abs)
195 UNARY_STD(exp)
196 UNARY_STD(log)
197 UNARY_STD(tan)
198 UNARY_STD(tanh)
199 UNARY_STD(sgn)
200 UNARY_STD(logic_not)
201 UNARY_STD(acos)
202 UNARY_STD(asin)
203 UNARY_STD(cos)
204 UNARY_STD(sin)
205 else if (op == UnaryOpType::sqrt) {
206 llvm_val[stmt] =
207 builder->CreateIntrinsic(llvm::Intrinsic::sqrt, {input_type}, {input});
208 }
209 else {
210 TI_P(unary_op_type_name(op));
211 TI_NOT_IMPLEMENTED
212 }
213#undef UNARY_STD
214 if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
215 // Convert back to f16
216 llvm_val[stmt] = builder->CreateFPTrunc(
217 llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
218 }
219}
220
221std::unique_ptr<RuntimeObject> TaskCodeGenLLVM::emit_struct_meta_object(
222 SNode *snode) {
223 std::unique_ptr<RuntimeObject> meta;
224 if (snode->type == SNodeType::dense) {
225 meta = std::make_unique<RuntimeObject>("DenseMeta", this, builder.get());
226 emit_struct_meta_base("Dense", meta->ptr, snode);
227 meta->call("set_morton_dim", tlctx->get_constant((int)snode->_morton));
228 } else if (snode->type == SNodeType::pointer) {
229 meta = std::make_unique<RuntimeObject>("PointerMeta", this, builder.get());
230 emit_struct_meta_base("Pointer", meta->ptr, snode);
231 } else if (snode->type == SNodeType::root) {
232 meta = std::make_unique<RuntimeObject>("RootMeta", this, builder.get());
233 emit_struct_meta_base("Root", meta->ptr, snode);
234 } else if (snode->type == SNodeType::dynamic) {
235 meta = std::make_unique<RuntimeObject>("DynamicMeta", this, builder.get());
236 emit_struct_meta_base("Dynamic", meta->ptr, snode);
237 meta->call("set_chunk_size", tlctx->get_constant(snode->chunk_size));
238 } else if (snode->type == SNodeType::bitmasked) {
239 meta =
240 std::make_unique<RuntimeObject>("BitmaskedMeta", this, builder.get());
241 emit_struct_meta_base("Bitmasked", meta->ptr, snode);
242 } else if (snode->type == SNodeType::quant_array) {
243 meta = std::make_unique<RuntimeObject>("DenseMeta", this, builder.get());
244 emit_struct_meta_base("Dense", meta->ptr, snode);
245 } else {
246 TI_P(snode_type_name(snode->type));
247 TI_NOT_IMPLEMENTED;
248 }
249 return meta;
250}
251
252void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name,
253 llvm::Value *node_meta,
254 SNode *snode) {
255 RuntimeObject common("StructMeta", this, builder.get(), node_meta);
256 std::size_t element_size;
257 if (snode->type == SNodeType::dense) {
258 auto body_type =
259 StructCompilerLLVM::get_llvm_body_type(module.get(), snode);
260 auto element_ty = body_type->getArrayElementType();
261 element_size = tlctx->get_type_size(element_ty);
262 } else if (snode->type == SNodeType::pointer) {
263 auto element_ty = StructCompilerLLVM::get_llvm_node_type(
264 module.get(), snode->ch[0].get());
265 element_size = tlctx->get_type_size(element_ty);
266 } else {
267 auto element_ty =
268 StructCompilerLLVM::get_llvm_element_type(module.get(), snode);
269 element_size = tlctx->get_type_size(element_ty);
270 }
271 common.set("snode_id", tlctx->get_constant(snode->id));
272 common.set("element_size", tlctx->get_constant((uint64)element_size));
273 common.set("max_num_elements",
274 tlctx->get_constant(snode->max_num_elements()));
275 common.set("context", get_context());
276
277 /*
278 uint8 *(*lookup_element)(uint8 *, int i);
279 uint8 *(*from_parent_element)(uint8 *);
280 bool (*is_active)(uint8 *, int i);
281 int (*get_num_elements)(uint8 *);
282 void (*refine_coordinates)(PhysicalCoordinates *inp_coord,
283 PhysicalCoordinates *refined_coord,
284 int index);
285 */
286
287 std::vector<std::string> functions = {"lookup_element", "is_active",
288 "get_num_elements"};
289
290 for (auto const &f : functions)
291 common.set(f, get_runtime_function(fmt::format("{}_{}", name, f)));
292
293 // "from_parent_element", "refine_coordinates" are different for different
294 // snodes, even if they have the same type.
295 if (snode->parent)
296 common.set("from_parent_element",
297 get_struct_function(snode->get_ch_from_parent_func_name(),
298 snode->get_snode_tree_id()));
299
300 if (snode->type != SNodeType::place)
301 common.set("refine_coordinates",
302 get_struct_function(snode->refine_coordinates_func_name(),
303 snode->get_snode_tree_id()));
304}
305
306TaskCodeGenLLVM::TaskCodeGenLLVM(const CompileConfig &compile_config,
307 TaichiLLVMContext &tlctx,
308 Kernel *kernel,
309 IRNode *ir,
310 std::unique_ptr<llvm::Module> &&module)
311 // TODO: simplify LLVMModuleBuilder ctor input
312 : LLVMModuleBuilder(
313 module == nullptr ? tlctx.new_module("kernel") : std::move(module),
314 &tlctx),
315 compile_config(compile_config),
316 kernel(kernel),
317 ir(ir),
318 prog(kernel->program) {
319 if (ir == nullptr)
320 this->ir = kernel->ir.get();
321 initialize_context();
322
323 context_ty = get_runtime_type("RuntimeContext");
324 physical_coordinate_ty = get_runtime_type(kLLVMPhysicalCoordinatesName);
325
326 kernel_name = kernel->name + "_kernel";
327 current_callable = kernel;
328}
329
330void TaskCodeGenLLVM::visit(DecorationStmt *stmt) {
331}
332
333void TaskCodeGenLLVM::create_elementwise_cast(
334 UnaryOpStmt *stmt,
335 llvm::Type *to_ty,
336 std::function<llvm::Value *(llvm::Value *, llvm::Type *)> func,
337 bool on_self) {
338 auto from_ty = stmt->operand->ret_type->cast<TensorType>();
339 TI_ASSERT_INFO(from_ty,
340 "Cannot perform elementwise ops on non-tensor type {}",
341 from_ty->to_string());
342 llvm::Value *vec = llvm::UndefValue::get(llvm::VectorType::get(
343 to_ty, from_ty->get_num_elements(), /*scalable=*/false));
344 for (int i = 0; i < from_ty->get_num_elements(); ++i) {
345 auto elem = builder->CreateExtractElement(
346 on_self ? llvm_val[stmt] : llvm_val[stmt->operand], i);
347 auto cast_value = func(elem, to_ty);
348 vec = builder->CreateInsertElement(vec, cast_value, i);
349 }
350 llvm_val[stmt] = vec;
351}
352
353void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
354 auto input = llvm_val[stmt->operand];
355 auto input_type = input->getType();
356 auto op = stmt->op_type;
357
358#define UNARY_INTRINSIC(x) \
359 else if (op == UnaryOpType::x) { \
360 llvm_val[stmt] = \
361 builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input}); \
362 }
363 if (stmt->op_type == UnaryOpType::cast_value) {
364 llvm::CastInst::CastOps cast_op;
365 auto from = stmt->operand->ret_type;
366 auto to = stmt->ret_type;
367 TI_ASSERT_INFO(
368 from->is<TensorType>() == to->is<TensorType>(),
369 "Cannot cast between tensor type and non-tensor type: {} v.s. {}",
370 from->to_string(), to->to_string());
371
372 if (from == to) {
373 llvm_val[stmt] = llvm_val[stmt->operand];
374 } else if (is_real(from.get_element_type()) !=
375 is_real(to.get_element_type())) {
376 if (is_real(from.get_element_type()) &&
377 (is_integral(to.get_element_type()))) {
378 cast_op = (is_signed(to.get_element_type()))
379 ? llvm::Instruction::CastOps::FPToSI
380 : llvm::Instruction::CastOps::FPToUI;
381 } else if (is_integral(from.get_element_type()) &&
382 is_real(to.get_element_type())) {
383 cast_op = (is_signed(from.get_element_type()))
384 ? llvm::Instruction::CastOps::SIToFP
385 : llvm::Instruction::CastOps::UIToFP;
386 } else {
387 TI_P(data_type_name(from));
388 TI_P(data_type_name(to));
389 TI_NOT_IMPLEMENTED;
390 }
391 bool use_f16 = to->is_primitive(PrimitiveTypeID::f16) ||
392 (to->is<TensorType>() &&
393 to->cast<TensorType>()->get_element_type()->is_primitive(
394 PrimitiveTypeID::f16));
395 auto cast_type = use_f16 ? (to->is<TensorType>()
396 ? TypeFactory::create_tensor_type(
397 to->cast<TensorType>()->get_shape(),
398 PrimitiveType::f32)
399 : PrimitiveType::f32)
400 : stmt->cast_type;
401
402 auto cast_func = [this, cast_op](llvm::Value *value, llvm::Type *type) {
403 return this->builder->CreateCast(cast_op, value, type);
404 };
405 if (!cast_type->is<TensorType>()) {
406 llvm_val[stmt] = cast_func(input, tlctx->get_data_type(cast_type));
407 } else {
408 create_elementwise_cast(
409 stmt,
410 tlctx->get_data_type(
411 cast_type->cast<TensorType>()->get_element_type()),
412 cast_func);
413 }
414
415 if (use_f16) {
416 auto trunc_func = [this](llvm::Value *value, llvm::Type *type) {
417 return this->builder->CreateFPTrunc(value, type);
418 };
419 auto to_ty = llvm::Type::getHalfTy(*llvm_context);
420 if (!cast_type->is<TensorType>()) {
421 llvm_val[stmt] = trunc_func(llvm_val[stmt], to_ty);
422 } else {
423 create_elementwise_cast(stmt, to_ty, trunc_func, /*on_self=*/true);
424 }
425 }
426 } else if (is_real(from.get_element_type()) &&
427 is_real(to.get_element_type())) {
428 auto t1 = from->is<TensorType>()
429 ? from->cast<TensorType>()->get_element_type()
430 : from.operator->();
431 auto t2 = to->is<TensorType>()
432 ? to->cast<TensorType>()->get_element_type()
433 : to.operator->();
434 if (data_type_size(t1) < data_type_size(t2)) {
435 auto cast_func = [this](llvm::Value *value, llvm::Type *type) {
436 return this->builder->CreateFPExt(value, type);
437 };
438 if (!stmt->cast_type->is<TensorType>()) {
439 llvm_val[stmt] =
440 cast_func(input, tlctx->get_data_type(stmt->cast_type));
441 } else {
442 create_elementwise_cast(
443 stmt,
444 tlctx->get_data_type(
445 stmt->cast_type->cast<TensorType>()->get_element_type()),
446 cast_func);
447 }
448 } else {
449 if (to->is_primitive(PrimitiveTypeID::f16) ||
450 (to->is<TensorType>() &&
451 to->cast<TensorType>()->get_element_type()->is_primitive(
452 PrimitiveTypeID::f16))) {
453 if (!to->is<TensorType>()) {
454 llvm_val[stmt] = builder->CreateFPTrunc(
455 builder->CreateFPTrunc(llvm_val[stmt->operand],
456 llvm::Type::getFloatTy(*llvm_context)),
457 llvm::Type::getHalfTy(*llvm_context));
458 } else {
459 auto tensor_type = to->cast<TensorType>();
460 llvm::Value *vec = llvm::UndefValue::get(tlctx->get_data_type(to));
461 for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
462 auto elem = builder->CreateExtractElement(vec, i);
463 auto double_trunced = builder->CreateFPTrunc(
464 builder->CreateFPTrunc(elem,
465 llvm::Type::getFloatTy(*llvm_context)),
466 llvm::Type::getHalfTy(*llvm_context));
467 vec = builder->CreateInsertElement(vec, double_trunced, i);
468 }
469 llvm_val[stmt] = vec;
470 }
471 } else {
472 auto trunc_fn = [this](llvm::Value *value, llvm::Type *type) {
473 return this->builder->CreateFPTrunc(value, type);
474 };
475 auto cast_type =
476 stmt->cast_type->is<TensorType>()
477 ? stmt->cast_type->cast<TensorType>()->get_element_type()
478 : stmt->cast_type.operator->();
479 if (!stmt->cast_type->is<TensorType>()) {
480 llvm_val[stmt] = trunc_fn(input, tlctx->get_data_type(cast_type));
481 } else {
482 create_elementwise_cast(
483 stmt,
484 tlctx->get_data_type(
485 cast_type->cast<TensorType>()->get_element_type()),
486 trunc_fn);
487 }
488 }
489 }
490 } else if (!is_real(from.get_element_type()) &&
491 !is_real(to.get_element_type())) {
492 llvm_val[stmt] = builder->CreateIntCast(
493 llvm_val[stmt->operand], tlctx->get_data_type(to),
494 is_signed(from.get_element_type()));
495 }
496 } else if (stmt->op_type == UnaryOpType::cast_bits) {
497 TI_ASSERT(data_type_size(stmt->ret_type) ==
498 data_type_size(stmt->cast_type));
499 if (stmt->operand->ret_type.is_pointer()) {
500 TI_ASSERT(is_integral(stmt->cast_type));
501 llvm_val[stmt] = builder->CreatePtrToInt(
502 llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
503 } else {
504 llvm_val[stmt] = builder->CreateBitCast(
505 llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
506 }
507 } else if (op == UnaryOpType::rsqrt) {
508 llvm::Function *sqrt_fn = llvm::Intrinsic::getDeclaration(
509 module.get(), llvm::Intrinsic::sqrt, input->getType());
510 auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt");
511 llvm_val[stmt] = builder->CreateFDiv(
512 tlctx->get_constant(stmt->ret_type, 1.0), intermediate);
513 } else if (op == UnaryOpType::bit_not) {
514 llvm_val[stmt] = builder->CreateNot(input);
515 } else if (op == UnaryOpType::neg) {
516 if (is_real(stmt->operand->ret_type)) {
517 llvm_val[stmt] = builder->CreateFNeg(input, "neg");
518 } else {
519 llvm_val[stmt] = builder->CreateNeg(input, "neg");
520 }
521 }
522 UNARY_INTRINSIC(round)
523 UNARY_INTRINSIC(floor)
524 UNARY_INTRINSIC(ceil)
525 else {
526 emit_extra_unary(stmt);
527 }
528#undef UNARY_INTRINSIC
529}
530
531void TaskCodeGenLLVM::create_elementwise_binary(
532 BinaryOpStmt *stmt,
533 std::function<llvm::Value *(llvm::Value *lhs, llvm::Value *rhs)> f) {
534 TI_ASSERT(stmt->lhs->ret_type->is<TensorType>());
535 TI_ASSERT(stmt->rhs->ret_type->is<TensorType>());
536 auto lhs_ty = stmt->lhs->ret_type->cast<TensorType>();
537 auto rhs_ty = stmt->rhs->ret_type->cast<TensorType>();
538 TI_ASSERT(lhs_ty->get_num_elements() == rhs_ty->get_num_elements());
539 auto lhs_vec = llvm_val[stmt->lhs];
540 auto rhs_vec = llvm_val[stmt->rhs];
541 auto elt_type_name = data_type_name(lhs_ty->get_element_type());
542 llvm::Value *result =
543 llvm::UndefValue::get(tlctx->get_data_type(stmt->ret_type));
544 for (int i = 0; i < lhs_ty->get_num_elements(); ++i) {
545 auto lhs = builder->CreateExtractElement(lhs_vec, i);
546 auto rhs = builder->CreateExtractElement(rhs_vec, i);
547 result = builder->CreateInsertElement(result, f(lhs, rhs), i);
548 }
549 llvm_val[stmt] = result;
550}
551
552void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
553 auto op = stmt->op_type;
554 auto ret_type = stmt->ret_type;
555
556 if (op == BinaryOpType::add) {
557 if (is_real(stmt->ret_type.get_element_type())) {
558 llvm_val[stmt] =
559 builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
560#if defined(__clang__) || defined(__GNUC__)
561 } else if (compile_config.debug && is_integral(stmt->ret_type)) {
562 llvm_val[stmt] =
563 call("debug_add_" + stmt->ret_type->to_string(), get_arg(0),
564 llvm_val[stmt->lhs], llvm_val[stmt->rhs],
565 builder->CreateGlobalStringPtr(stmt->tb));
566#endif
567 } else {
568 llvm_val[stmt] =
569 builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
570 }
571 } else if (op == BinaryOpType::sub) {
572 if (is_real(stmt->ret_type.get_element_type())) {
573 llvm_val[stmt] =
574 builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
575#if defined(__clang__) || defined(__GNUC__)
576 } else if (compile_config.debug && is_integral(stmt->ret_type)) {
577 llvm_val[stmt] =
578 call("debug_sub_" + stmt->ret_type->to_string(), get_arg(0),
579 llvm_val[stmt->lhs], llvm_val[stmt->rhs],
580 builder->CreateGlobalStringPtr(stmt->tb));
581#endif
582 } else {
583 llvm_val[stmt] =
584 builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
585 }
586 } else if (op == BinaryOpType::mul) {
587 if (is_real(stmt->ret_type.get_element_type())) {
588 llvm_val[stmt] =
589 builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
590#if defined(__clang__) || defined(__GNUC__)
591 } else if (compile_config.debug && is_integral(stmt->ret_type)) {
592 llvm_val[stmt] =
593 call("debug_mul_" + stmt->ret_type->to_string(), get_arg(0),
594 llvm_val[stmt->lhs], llvm_val[stmt->rhs],
595 builder->CreateGlobalStringPtr(stmt->tb));
596#endif
597 } else {
598 llvm_val[stmt] =
599 builder->CreateMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
600 }
601 } else if (op == BinaryOpType::div) {
602 if (is_real(stmt->ret_type.get_element_type())) {
603 llvm_val[stmt] =
604 builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
605 } else if (is_signed(stmt->ret_type)) {
606 llvm_val[stmt] =
607 builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
608 } else {
609 llvm_val[stmt] =
610 builder->CreateUDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
611 }
612 } else if (op == BinaryOpType::mod) {
613 llvm_val[stmt] =
614 builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
615 } else if (op == BinaryOpType::bit_and) {
616 llvm_val[stmt] =
617 builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
618 } else if (op == BinaryOpType::bit_or) {
619 llvm_val[stmt] =
620 builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
621 } else if (op == BinaryOpType::bit_xor) {
622 llvm_val[stmt] =
623 builder->CreateXor(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
624 } else if (op == BinaryOpType::bit_shl) {
625#if defined(__clang__) || defined(__GNUC__)
626 if (compile_config.debug && is_integral(stmt->ret_type)) {
627 llvm_val[stmt] =
628 call("debug_shl_" + stmt->ret_type->to_string(), get_arg(0),
629 llvm_val[stmt->lhs], llvm_val[stmt->rhs],
630 builder->CreateGlobalStringPtr(stmt->tb));
631 } else {
632 llvm_val[stmt] =
633 builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
634 }
635#else
636 llvm_val[stmt] =
637 builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
638#endif
639 } else if (op == BinaryOpType::bit_sar) {
640 if (is_signed(stmt->lhs->element_type())) {
641 llvm_val[stmt] =
642 builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
643 } else {
644 llvm_val[stmt] =
645 builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
646 }
647 } else if (op == BinaryOpType::max) {
648#define BINARYOP_MAX(x) \
649 else if (ret_type->is_primitive(PrimitiveTypeID::x)) { \
650 llvm_val[stmt] = \
651 call("max_" #x, llvm_val[stmt->lhs], llvm_val[stmt->rhs]); \
652 }
653
654 if (is_real(ret_type.get_element_type())) {
655 llvm_val[stmt] =
656 builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
657 }
658 BINARYOP_MAX(u16)
659 BINARYOP_MAX(i16)
660 BINARYOP_MAX(u32)
661 BINARYOP_MAX(i32)
662 BINARYOP_MAX(u64)
663 BINARYOP_MAX(i64)
664 else {
665 if (auto tensor_ty = ret_type->cast<TensorType>()) {
666 auto elt_ty = tensor_ty->get_element_type();
667 TI_ASSERT(elt_ty->is_primitive(PrimitiveTypeID::u16) ||
668 elt_ty->is_primitive(PrimitiveTypeID::i16) ||
669 elt_ty->is_primitive(PrimitiveTypeID::u32) ||
670 elt_ty->is_primitive(PrimitiveTypeID::i32) ||
671 elt_ty->is_primitive(PrimitiveTypeID::u64) ||
672 elt_ty->is_primitive(PrimitiveTypeID::i64));
673 auto dtype_name = data_type_name(elt_ty);
674 auto binary_max = [this, &dtype_name](llvm::Value *lhs,
675 llvm::Value *rhs) {
676 return call("max_" + dtype_name, lhs, rhs);
677 };
678 create_elementwise_binary(stmt, binary_max);
679 } else {
680 TI_P(data_type_name(ret_type));
681 TI_NOT_IMPLEMENTED
682 }
683 }
684 } else if (op == BinaryOpType::min) {
685#define BINARYOP_MIN(x) \
686 else if (ret_type->is_primitive(PrimitiveTypeID::x)) { \
687 llvm_val[stmt] = \
688 call("min_" #x, llvm_val[stmt->lhs], llvm_val[stmt->rhs]); \
689 }
690
691 if (is_real(ret_type.get_element_type())) {
692 llvm_val[stmt] =
693 builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
694 }
695 BINARYOP_MIN(u16)
696 BINARYOP_MIN(i16)
697 BINARYOP_MIN(u32)
698 BINARYOP_MIN(i32)
699 BINARYOP_MIN(u64)
700 BINARYOP_MIN(i64)
701 else {
702 if (auto tensor_ty = ret_type->cast<TensorType>()) {
703 auto elt_ty = tensor_ty->get_element_type();
704 TI_ASSERT(elt_ty->is_primitive(PrimitiveTypeID::u16) ||
705 elt_ty->is_primitive(PrimitiveTypeID::i16) ||
706 elt_ty->is_primitive(PrimitiveTypeID::u32) ||
707 elt_ty->is_primitive(PrimitiveTypeID::i32) ||
708 elt_ty->is_primitive(PrimitiveTypeID::u64) ||
709 elt_ty->is_primitive(PrimitiveTypeID::i64));
710 auto dtype_name = data_type_name(elt_ty);
711 auto binary_min = [this, &dtype_name](llvm::Value *lhs,
712 llvm::Value *rhs) {
713 return call("min_" + dtype_name, lhs, rhs);
714 };
715 create_elementwise_binary(stmt, binary_min);
716 } else {
717 TI_P(data_type_name(ret_type));
718 TI_NOT_IMPLEMENTED
719 }
720 }
721 } else if (is_comparison(op)) {
722 llvm::Value *cmp = nullptr;
723 auto input_type = stmt->lhs->ret_type;
724 if (op == BinaryOpType::cmp_eq) {
725 if (is_real(input_type.get_element_type())) {
726 cmp = builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
727 } else {
728 cmp = builder->CreateICmpEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
729 }
730 } else if (op == BinaryOpType::cmp_le) {
731 if (is_real(input_type.get_element_type())) {
732 cmp = builder->CreateFCmpOLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
733 } else {
734 if (is_signed(input_type.get_element_type())) {
735 cmp =
736 builder->CreateICmpSLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
737 } else {
738 cmp =
739 builder->CreateICmpULE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
740 }
741 }
742 } else if (op == BinaryOpType::cmp_ge) {
743 if (is_real(input_type.get_element_type())) {
744 cmp = builder->CreateFCmpOGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
745 } else {
746 if (is_signed(input_type.get_element_type())) {
747 cmp =
748 builder->CreateICmpSGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
749 } else {
750 cmp =
751 builder->CreateICmpUGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
752 }
753 }
754 } else if (op == BinaryOpType::cmp_lt) {
755 if (is_real(input_type.get_element_type())) {
756 cmp = builder->CreateFCmpOLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
757 } else {
758 if (is_signed(input_type.get_element_type())) {
759 cmp =
760 builder->CreateICmpSLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
761 } else {
762 cmp =
763 builder->CreateICmpULT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
764 }
765 }
766 } else if (op == BinaryOpType::cmp_gt) {
767 if (is_real(input_type.get_element_type())) {
768 cmp = builder->CreateFCmpOGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
769 } else {
770 if (is_signed(input_type.get_element_type())) {
771 cmp =
772 builder->CreateICmpSGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
773 } else {
774 cmp =
775 builder->CreateICmpUGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
776 }
777 }
778 } else if (op == BinaryOpType::cmp_ne) {
779 if (is_real(input_type.get_element_type())) {
780 cmp = builder->CreateFCmpONE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
781 } else {
782 cmp = builder->CreateICmpNE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
783 }
784 } else {
785 TI_NOT_IMPLEMENTED
786 }
787 llvm_val[stmt] =
788 builder->CreateSExt(cmp, tlctx->get_data_type(PrimitiveType::i32));
789 } else {
790 // This branch contains atan2 and pow which use runtime.cpp function for
791 // **real** type. We don't have f16 support there so promoting to f32 is
792 // necessary.
793 llvm::Value *lhs = llvm_val[stmt->lhs];
794 llvm::Value *rhs = llvm_val[stmt->rhs];
795 if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::f16)) {
796 lhs = builder->CreateFPExt(lhs, llvm::Type::getFloatTy(*llvm_context));
797 }
798 if (stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::f16)) {
799 rhs = builder->CreateFPExt(rhs, llvm::Type::getFloatTy(*llvm_context));
800 }
801 if (ret_type->is_primitive(PrimitiveTypeID::f16)) {
802 ret_type = PrimitiveType::f32;
803 }
804
805 if (op == BinaryOpType::atan2) {
806 if (arch_is_cpu(current_arch())) {
807 if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
808 llvm_val[stmt] = call("atan2_f32", lhs, rhs);
809 } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
810 llvm_val[stmt] = call("atan2_f64", lhs, rhs);
811 } else {
812 TI_P(data_type_name(ret_type));
813 TI_NOT_IMPLEMENTED
814 }
815 } else {
816 TI_NOT_IMPLEMENTED
817 }
818 } else if (op == BinaryOpType::pow) {
819 if (arch_is_cpu(current_arch())) {
820 // Note that ret_type here cannot be integral because pow with an
821 // integral exponent has been demoted in the demote_operations pass
822 if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
823 llvm_val[stmt] = call("pow_f32", lhs, rhs);
824 } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
825 llvm_val[stmt] = call("pow_f64", lhs, rhs);
826 } else {
827 TI_P(data_type_name(ret_type));
828 TI_NOT_IMPLEMENTED
829 }
830 } else {
831 TI_NOT_IMPLEMENTED
832 }
833 } else {
834 TI_P(binary_op_type_name(op));
835 TI_NOT_IMPLEMENTED
836 }
837
838 // Convert back to f16 if applicable.
839 if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
840 llvm_val[stmt] = builder->CreateFPTrunc(
841 llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
842 }
843 }
844}
845
846void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) {
847 TI_ASSERT(stmt->op_type == TernaryOpType::select);
848 llvm_val[stmt] = builder->CreateSelect(
849 builder->CreateTrunc(llvm_val[stmt->op1],
850 tlctx->get_data_type(PrimitiveType::u1)),
851 llvm_val[stmt->op2], llvm_val[stmt->op3]);
852}
853
854void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
855 // TODO: take care of vectorized cases
856 llvm::BasicBlock *true_block =
857 llvm::BasicBlock::Create(*llvm_context, "true_block", func);
858 llvm::BasicBlock *false_block =
859 llvm::BasicBlock::Create(*llvm_context, "false_block", func);
860 llvm::BasicBlock *after_if =
861 llvm::BasicBlock::Create(*llvm_context, "after_if", func);
862 builder->CreateCondBr(
863 builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)),
864 true_block, false_block);
865 builder->SetInsertPoint(true_block);
866 if (if_stmt->true_statements) {
867 if_stmt->true_statements->accept(this);
868 }
869 if (!returned) {
870 builder->CreateBr(after_if);
871 } else {
872 returned = false;
873 }
874 builder->SetInsertPoint(false_block);
875 if (if_stmt->false_statements) {
876 if_stmt->false_statements->accept(this);
877 }
878 if (!returned) {
879 builder->CreateBr(after_if);
880 } else {
881 returned = false;
882 }
883 builder->SetInsertPoint(after_if);
884}
885
886llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
887 DataType dt,
888 llvm::Value *value) {
889 if (!arch_is_cpu(compile_config.arch)) {
890 TI_WARN("print not supported on arch {}", arch_name(compile_config.arch));
891 return nullptr;
892 }
893 std::vector<llvm::Value *> args;
894 std::string format = data_type_format(dt);
895 auto runtime_printf = call("LLVMRuntime_get_host_printf", get_runtime());
896 args.push_back(builder->CreateGlobalStringPtr(
897 ("[llvm codegen debug] " + tag + " = " + format + "\n").c_str(),
898 "format_string"));
899 if (dt->is_primitive(PrimitiveTypeID::f32))
900 value =
901 builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64));
902 args.push_back(value);
903
904 auto func_type_func = get_runtime_function("get_func_type_host_printf");
905 return call(runtime_printf, func_type_func->getFunctionType(),
906 std::move(args));
907}
908
909llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
910 llvm::Value *value) {
911 if (value->getType() == llvm::Type::getFloatTy(*llvm_context))
912 return create_print(
913 tag,
914 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32),
915 value);
916 else if (value->getType() == llvm::Type::getInt32Ty(*llvm_context))
917 return create_print(
918 tag,
919 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32),
920 value);
921 else if (value->getType() == llvm::Type::getHalfTy(*llvm_context)) {
922 auto extended =
923 builder->CreateFPExt(value, llvm::Type::getFloatTy(*llvm_context));
924 return create_print(
925 tag,
926 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32),
927 extended);
928 } else if (value->getType() == llvm::Type::getInt64Ty(*llvm_context))
929 return create_print(
930 tag,
931 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i64),
932 value);
933 else if (value->getType() == llvm::Type::getInt16Ty(*llvm_context))
934 return create_print(
935 tag,
936 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i16),
937 value);
938 else
939 TI_NOT_IMPLEMENTED
940}
941
942void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
943 std::vector<llvm::Value *> args;
944 std::string formats;
945 auto value_for_printf = [this](llvm::Value *to_print, DataType dtype) {
946 if (dtype->is_primitive(PrimitiveTypeID::f32) ||
947 dtype->is_primitive(PrimitiveTypeID::f16))
948 return this->builder->CreateFPExt(
949 to_print, this->tlctx->get_data_type(PrimitiveType::f64));
950 if (dtype->is_primitive(PrimitiveTypeID::i8))
951 return builder->CreateSExt(to_print,
952 tlctx->get_data_type(PrimitiveType::i16));
953 if (dtype->is_primitive(PrimitiveTypeID::u8))
954 return builder->CreateZExt(to_print,
955 tlctx->get_data_type(PrimitiveType::u16));
956 return to_print;
957 };
958 for (auto const &content : stmt->contents) {
959 if (std::holds_alternative<Stmt *>(content)) {
960 auto arg_stmt = std::get<Stmt *>(content);
961 auto value = llvm_val[arg_stmt];
962 auto value_type = value->getType();
963 if (arg_stmt->ret_type->is<TensorType>()) {
964 auto dtype = arg_stmt->ret_type->cast<TensorType>();
965 auto elem_type = dtype->get_element_type();
966 for (int i = 0; i < dtype->get_num_elements(); ++i) {
967 if (codegen_vector_type(compile_config)) {
968 TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type));
969 auto elem = builder->CreateExtractElement(value, i);
970 args.push_back(value_for_printf(elem, elem_type));
971 } else {
972 TI_ASSERT(llvm::dyn_cast<llvm::ArrayType>(value_type));
973 auto elem = builder->CreateExtractValue(value, i);
974 args.push_back(value_for_printf(elem, elem_type));
975 }
976 }
977 formats += data_type_format(arg_stmt->ret_type);
978 } else {
979 args.push_back(value_for_printf(value, arg_stmt->ret_type));
980 formats += data_type_format(arg_stmt->ret_type);
981 }
982 } else {
983 auto arg_str = std::get<std::string>(content);
984 auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
985 args.push_back(value);
986 formats += "%s";
987 }
988 }
989 auto runtime_printf = call("LLVMRuntime_get_host_printf", get_runtime());
990 args.insert(args.begin(),
991 builder->CreateGlobalStringPtr(formats.c_str(), "format_string"));
992 auto func_type_func = get_runtime_function("get_func_type_host_printf");
993 llvm_val[stmt] =
994 call(runtime_printf, func_type_func->getFunctionType(), std::move(args));
995}
996
997void TaskCodeGenLLVM::visit(ConstStmt *stmt) {
998 auto val = stmt->val;
999 if (val.dt->is_primitive(PrimitiveTypeID::f32)) {
1000 llvm_val[stmt] =
1001 llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float32()));
1002 } else if (val.dt->is_primitive(PrimitiveTypeID::f16)) {
1003 llvm_val[stmt] = llvm::ConstantFP::get(llvm::Type::getHalfTy(*llvm_context),
1004 val.val_float32());
1005 } else if (val.dt->is_primitive(PrimitiveTypeID::f64)) {
1006 llvm_val[stmt] =
1007 llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float64()));
1008 } else if (val.dt->is_primitive(PrimitiveTypeID::i8)) {
1009 llvm_val[stmt] = llvm::ConstantInt::get(
1010 *llvm_context, llvm::APInt(8, (uint64)val.val_int8(), true));
1011 } else if (val.dt->is_primitive(PrimitiveTypeID::u8)) {
1012 llvm_val[stmt] = llvm::ConstantInt::get(
1013 *llvm_context, llvm::APInt(8, (uint64)val.val_uint8(), false));
1014 } else if (val.dt->is_primitive(PrimitiveTypeID::i16)) {
1015 llvm_val[stmt] = llvm::ConstantInt::get(
1016 *llvm_context, llvm::APInt(16, (uint64)val.val_int16(), true));
1017 } else if (val.dt->is_primitive(PrimitiveTypeID::u16)) {
1018 llvm_val[stmt] = llvm::ConstantInt::get(
1019 *llvm_context, llvm::APInt(16, (uint64)val.val_uint16(), false));
1020 } else if (val.dt->is_primitive(PrimitiveTypeID::i32)) {
1021 llvm_val[stmt] = llvm::ConstantInt::get(
1022 *llvm_context, llvm::APInt(32, (uint64)val.val_int32(), true));
1023 } else if (val.dt->is_primitive(PrimitiveTypeID::u32)) {
1024 llvm_val[stmt] = llvm::ConstantInt::get(
1025 *llvm_context, llvm::APInt(32, (uint64)val.val_uint32(), false));
1026 } else if (val.dt->is_primitive(PrimitiveTypeID::i64)) {
1027 llvm_val[stmt] = llvm::ConstantInt::get(
1028 *llvm_context, llvm::APInt(64, (uint64)val.val_int64(), true));
1029 } else if (val.dt->is_primitive(PrimitiveTypeID::u64)) {
1030 llvm_val[stmt] = llvm::ConstantInt::get(
1031 *llvm_context, llvm::APInt(64, val.val_uint64(), false));
1032 } else {
1033 TI_P(data_type_name(val.dt));
1034 TI_NOT_IMPLEMENTED;
1035 }
1036}
1037
1038void TaskCodeGenLLVM::visit(WhileControlStmt *stmt) {
1039 using namespace llvm;
1040
1041 BasicBlock *after_break =
1042 BasicBlock::Create(*llvm_context, "after_break", func);
1043 TI_ASSERT(current_while_after_loop);
1044 auto cond =
1045 builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0));
1046 builder->CreateCondBr(cond, current_while_after_loop, after_break);
1047 builder->SetInsertPoint(after_break);
1048}
1049
1050void TaskCodeGenLLVM::visit(ContinueStmt *stmt) {
1051 using namespace llvm;
1052 auto stmt_in_off_range_for = [stmt]() {
1053 TI_ASSERT(stmt->scope != nullptr);
1054 if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
1055 TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
1056 offl->task_type == OffloadedStmt::TaskType::struct_for);
1057 return offl->task_type == OffloadedStmt::TaskType::range_for;
1058 }
1059 return false;
1060 };
1061 if (stmt_in_off_range_for()) {
1062 builder->CreateRetVoid();
1063 } else {
1064 TI_ASSERT(current_loop_reentry != nullptr);
1065 builder->CreateBr(current_loop_reentry);
1066 }
1067 // Stmts after continue are useless, so we switch the insertion point to
1068 // /dev/null. In LLVM IR, the "after_continue" label shows "No predecessors!".
1069 BasicBlock *after_continue =
1070 BasicBlock::Create(*llvm_context, "after_continue", func);
1071 builder->SetInsertPoint(after_continue);
1072}
1073
1074void TaskCodeGenLLVM::visit(WhileStmt *stmt) {
1075 using namespace llvm;
1076 BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func);
1077 builder->CreateBr(body);
1078 builder->SetInsertPoint(body);
1079 auto lrg = make_loop_reentry_guard(this);
1080 current_loop_reentry = body;
1081
1082 BasicBlock *after_loop =
1083 BasicBlock::Create(*llvm_context, "after_while", func);
1084 auto walg = make_while_after_loop_guard(this);
1085 current_while_after_loop = after_loop;
1086
1087 stmt->body->accept(this);
1088
1089 if (!returned) {
1090 builder->CreateBr(body); // jump to head
1091 } else {
1092 returned = false;
1093 }
1094
1095 builder->SetInsertPoint(after_loop);
1096}
1097
1098llvm::Value *TaskCodeGenLLVM::cast_pointer(llvm::Value *val,
1099 std::string dest_ty_name,
1100 int addr_space) {
1101 return builder->CreateBitCast(
1102 val, llvm::PointerType::get(get_runtime_type(dest_ty_name), addr_space));
1103}
1104
1105void TaskCodeGenLLVM::emit_list_gen(OffloadedStmt *listgen) {
1106 auto snode_child = listgen->snode;
1107 auto snode_parent = listgen->snode->parent;
1108 auto meta_child = cast_pointer(emit_struct_meta(snode_child), "StructMeta");
1109 auto meta_parent = cast_pointer(emit_struct_meta(snode_parent), "StructMeta");
1110 if (snode_parent->type == SNodeType::root) {
1111 // Since there's only one container to expand, we need a special kernel for
1112 // more parallelism.
1113 call("element_listgen_root", get_runtime(), meta_parent, meta_child);
1114 } else {
1115 call("element_listgen_nonroot", get_runtime(), meta_parent, meta_child);
1116 }
1117}
1118
1119void TaskCodeGenLLVM::emit_gc(OffloadedStmt *stmt) {
1120 auto snode = stmt->snode->id;
1121 call("node_gc", get_runtime(), tlctx->get_constant(snode));
1122}
1123
1124void TaskCodeGenLLVM::emit_gc_rc() {
1125 call("runtime_context_gc", get_runtime());
1126}
1127
1128void TaskCodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) {
1129 auto original_value = builder->CreateLoad(value->getType(), ptr);
1130 builder->CreateStore(builder->CreateAdd(original_value, value), ptr);
1131}
1132
1133void TaskCodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
1134 using namespace llvm;
1135 BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
1136 BasicBlock *loop_inc =
1137 BasicBlock::Create(*llvm_context, "for_loop_inc", func);
1138 BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
1139 BasicBlock *loop_test =
1140 BasicBlock::Create(*llvm_context, "for_loop_test", func);
1141
1142 auto loop_var_ty = tlctx->get_data_type(PrimitiveType::i32);
1143
1144 auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
1145 loop_vars_llvm[for_stmt].push_back(loop_var);
1146
1147 if (!for_stmt->reversed) {
1148 builder->CreateStore(llvm_val[for_stmt->begin], loop_var);
1149 } else {
1150 builder->CreateStore(
1151 builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
1152 loop_var);
1153 }
1154 builder->CreateBr(loop_test);
1155
1156 {
1157 // test block
1158 builder->SetInsertPoint(loop_test);
1159 llvm::Value *cond;
1160 if (!for_stmt->reversed) {
1161 cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
1162 builder->CreateLoad(loop_var_ty, loop_var),
1163 llvm_val[for_stmt->end]);
1164 } else {
1165 cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
1166 builder->CreateLoad(loop_var_ty, loop_var),
1167 llvm_val[for_stmt->begin]);
1168 }
1169 builder->CreateCondBr(cond, body, after_loop);
1170 }
1171
1172 {
1173 {
1174 auto lrg = make_loop_reentry_guard(this);
1175 // The continue stmt should jump to the loop-increment block!
1176 current_loop_reentry = loop_inc;
1177 // body cfg
1178 builder->SetInsertPoint(body);
1179
1180 for_stmt->body->accept(this);
1181 }
1182 if (!returned) {
1183 builder->CreateBr(loop_inc);
1184 } else {
1185 returned = false;
1186 }
1187 builder->SetInsertPoint(loop_inc);
1188
1189 if (!for_stmt->reversed) {
1190 create_increment(loop_var, tlctx->get_constant(1));
1191 } else {
1192 create_increment(loop_var, tlctx->get_constant(-1));
1193 }
1194 builder->CreateBr(loop_test);
1195 }
1196
1197 // next cfg
1198 builder->SetInsertPoint(after_loop);
1199}
1200
1201void TaskCodeGenLLVM::visit(RangeForStmt *for_stmt) {
1202 create_naive_range_for(for_stmt);
1203}
1204
1205llvm::Value *TaskCodeGenLLVM::bitcast_from_u64(llvm::Value *val,
1206 DataType type) {
1207 llvm::Type *dest_ty = nullptr;
1208 TI_ASSERT(!type->is<PointerType>());
1209 if (auto qit = type->cast<QuantIntType>()) {
1210 if (qit->get_is_signed())
1211 dest_ty = tlctx->get_data_type(PrimitiveType::i32);
1212 else
1213 dest_ty = tlctx->get_data_type(PrimitiveType::u32);
1214 } else {
1215 dest_ty = tlctx->get_data_type(type);
1216 }
1217 auto dest_bits = dest_ty->getPrimitiveSizeInBits();
1218 if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) {
1219 // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
1220 // which doesn't mean anything.
1221 // So we truncate to 32 bits first and then fptrunc to half if applicable
1222 auto truncated =
1223 builder->CreateTrunc(val, llvm::Type::getIntNTy(*llvm_context, 32));
1224 auto casted = builder->CreateBitCast(truncated,
1225 llvm::Type::getFloatTy(*llvm_context));
1226 return builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context));
1227 } else {
1228 auto truncated = builder->CreateTrunc(
1229 val, llvm::Type::getIntNTy(*llvm_context, dest_bits));
1230
1231 return builder->CreateBitCast(truncated, dest_ty);
1232 }
1233}
1234
1235llvm::Value *TaskCodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
1236 auto intermediate_bits = 0;
1237 if (type.is_pointer()) {
1238 return builder->CreatePtrToInt(val, tlctx->get_data_type<int64>());
1239 }
1240 if (auto qit = type->cast<QuantIntType>()) {
1241 intermediate_bits = data_type_bits(qit->get_compute_type());
1242 } else {
1243 intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits();
1244 }
1245 llvm::Type *dest_ty = tlctx->get_data_type<int64>();
1246 llvm::Type *intermediate_type = nullptr;
1247 if (val->getType() == llvm::Type::getHalfTy(*llvm_context)) {
1248 val = builder->CreateFPExt(val, tlctx->get_data_type<float>());
1249 intermediate_type = tlctx->get_data_type<int32>();
1250 } else {
1251 intermediate_type = llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
1252 }
1253 return builder->CreateZExt(builder->CreateBitCast(val, intermediate_type),
1254 dest_ty);
1255}
1256
1257void TaskCodeGenLLVM::visit(ArgLoadStmt *stmt) {
1258 auto raw_arg = stmt->is_grad
1259 ? (call(builder.get(), "RuntimeContext_get_grad_args",
1260 get_context(), tlctx->get_constant(stmt->arg_id)))
1261 : (call(builder.get(), "RuntimeContext_get_args",
1262 get_context(), tlctx->get_constant(stmt->arg_id)));
1263 llvm::Type *dest_ty = nullptr;
1264 if (stmt->is_ptr) {
1265 dest_ty = llvm::PointerType::get(
1266 tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
1267 llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
1268 } else {
1269 llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type);
1270 }
1271}
1272
1273void TaskCodeGenLLVM::visit(ReturnStmt *stmt) {
1274 auto types = stmt->element_types();
1275 if (std::any_of(types.begin(), types.end(),
1276 [](const DataType &t) { return t.is_pointer(); })) {
1277 TI_NOT_IMPLEMENTED
1278 } else {
1279 TI_ASSERT(stmt->values.size() ==
1280 current_callable->ret_type->get_num_elements());
1281 create_return(stmt->values);
1282 }
1283 builder->CreateBr(final_block);
1284 returned = true;
1285}
1286
1287void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) {
1288 // FIXME: get ptr_ty from taichi instead of llvm.
1289 llvm::Type *ptr_ty = nullptr;
1290 auto *val = llvm_val[stmt->src];
1291 if (auto *alloc = llvm::dyn_cast<llvm::AllocaInst>(val))
1292 ptr_ty = alloc->getAllocatedType();
1293 if (!ptr_ty && stmt->src->element_type().is_pointer()) {
1294 ptr_ty = tlctx->get_data_type(stmt->src->element_type().ptr_removed());
1295 }
1296 TI_ASSERT(ptr_ty);
1297 llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src]);
1298}
1299
1300void TaskCodeGenLLVM::visit(LocalStoreStmt *stmt) {
1301 builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
1302}
1303
1304void TaskCodeGenLLVM::visit(AssertStmt *stmt) {
1305 TI_ASSERT((int)stmt->args.size() <= taichi_error_message_max_num_arguments);
1306 auto argument_buffer_size = llvm::ArrayType::get(
1307 llvm::Type::getInt64Ty(*llvm_context), stmt->args.size());
1308
1309 // TODO: maybe let all asserts in a single offload share a single buffer?
1310 auto arguments = create_entry_block_alloca(argument_buffer_size);
1311
1312 std::vector<llvm::Value *> args;
1313 args.emplace_back(get_runtime());
1314 args.emplace_back(llvm_val[stmt->cond]);
1315 args.emplace_back(builder->CreateGlobalStringPtr(stmt->text));
1316
1317 for (int i = 0; i < stmt->args.size(); i++) {
1318 auto arg = stmt->args[i];
1319 TI_ASSERT(llvm_val[arg]);
1320
1321 // First convert the argument to an integral type with the same number of
1322 // bits:
1323 auto cast_type = llvm::Type::getIntNTy(
1324 *llvm_context, 8 * (std::size_t)data_type_size(arg->ret_type));
1325 auto cast_int = builder->CreateBitCast(llvm_val[arg], cast_type);
1326
1327 // Then zero-extend the conversion result into int64:
1328 auto cast_int64 =
1329 builder->CreateZExt(cast_int, llvm::Type::getInt64Ty(*llvm_context));
1330
1331 // Finally store the int64 value to the argument buffer:
1332 builder->CreateStore(
1333 cast_int64,
1334 builder->CreateGEP(argument_buffer_size, arguments,
1335 {tlctx->get_constant(0), tlctx->get_constant(i)}));
1336 }
1337
1338 args.emplace_back(tlctx->get_constant((int)stmt->args.size()));
1339 args.emplace_back(
1340 builder->CreateGEP(argument_buffer_size, arguments,
1341 {tlctx->get_constant(0), tlctx->get_constant(0)}));
1342
1343 llvm_val[stmt] = call("taichi_assert_format", std::move(args));
1344}
1345
1346void TaskCodeGenLLVM::visit(SNodeOpStmt *stmt) {
1347 auto snode = stmt->snode;
1348 if (stmt->op_type == SNodeOpType::allocate) {
1349 TI_ASSERT(snode->type == SNodeType::dynamic);
1350 TI_ASSERT(stmt->ret_type.is_pointer() &&
1351 stmt->ret_type.ptr_removed()->is_primitive(PrimitiveTypeID::gen));
1352 auto ptr =
1353 call(snode, llvm_val[stmt->ptr], "allocate", {llvm_val[stmt->val]});
1354 llvm_val[stmt] = ptr;
1355 } else if (stmt->op_type == SNodeOpType::length) {
1356 TI_ASSERT(snode->type == SNodeType::dynamic);
1357 llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "get_num_elements", {});
1358 } else if (stmt->op_type == SNodeOpType::is_active) {
1359 llvm_val[stmt] =
1360 call(snode, llvm_val[stmt->ptr], "is_active", {llvm_val[stmt->val]});
1361 } else if (stmt->op_type == SNodeOpType::activate) {
1362 llvm_val[stmt] =
1363 call(snode, llvm_val[stmt->ptr], "activate", {llvm_val[stmt->val]});
1364 } else if (stmt->op_type == SNodeOpType::deactivate) {
1365 if (snode->type == SNodeType::pointer || snode->type == SNodeType::hash ||
1366 snode->type == SNodeType::bitmasked) {
1367 llvm_val[stmt] =
1368 call(snode, llvm_val[stmt->ptr], "deactivate", {llvm_val[stmt->val]});
1369 } else if (snode->type == SNodeType::dynamic) {
1370 llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "deactivate", {});
1371 }
1372 } else {
1373 TI_NOT_IMPLEMENTED
1374 }
1375}
1376
1377llvm::Value *TaskCodeGenLLVM::optimized_reduction(AtomicOpStmt *stmt) {
1378 return nullptr;
1379}
1380
1381llvm::Value *TaskCodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {
1382 // TODO(type): support all AtomicOpTypes on quant types
1383 if (stmt->op_type != AtomicOpType::add) {
1384 return nullptr;
1385 }
1386
1387 auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
1388 if (auto qit = dst_type->cast<QuantIntType>()) {
1389 return atomic_add_quant_int(
1390 llvm_val[stmt->dest],
1391 tlctx->get_data_type(
1392 stmt->dest->as<GetChStmt>()->input_snode->physical_type),
1393 qit, llvm_val[stmt->val], is_signed(stmt->val->ret_type));
1394 } else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
1395 return atomic_add_quant_fixed(
1396 llvm_val[stmt->dest],
1397 tlctx->get_data_type(
1398 stmt->dest->as<GetChStmt>()->input_snode->physical_type),
1399 qfxt, llvm_val[stmt->val]);
1400 } else {
1401 return nullptr;
1402 }
1403}
1404
1405llvm::Value *TaskCodeGenLLVM::integral_type_atomic(AtomicOpStmt *stmt) {
1406 if (!is_integral(stmt->val->ret_type)) {
1407 return nullptr;
1408 }
1409
1410 std::unordered_map<AtomicOpType, llvm::AtomicRMWInst::BinOp> bin_op;
1411 bin_op[AtomicOpType::add] = llvm::AtomicRMWInst::BinOp::Add;
1412 if (is_signed(stmt->val->ret_type)) {
1413 bin_op[AtomicOpType::min] = llvm::AtomicRMWInst::BinOp::Min;
1414 bin_op[AtomicOpType::max] = llvm::AtomicRMWInst::BinOp::Max;
1415 } else {
1416 bin_op[AtomicOpType::min] = llvm::AtomicRMWInst::BinOp::UMin;
1417 bin_op[AtomicOpType::max] = llvm::AtomicRMWInst::BinOp::UMax;
1418 }
1419 bin_op[AtomicOpType::bit_and] = llvm::AtomicRMWInst::BinOp::And;
1420 bin_op[AtomicOpType::bit_or] = llvm::AtomicRMWInst::BinOp::Or;
1421 bin_op[AtomicOpType::bit_xor] = llvm::AtomicRMWInst::BinOp::Xor;
1422 TI_ASSERT(bin_op.find(stmt->op_type) != bin_op.end());
1423 return builder->CreateAtomicRMW(
1424 bin_op.at(stmt->op_type), llvm_val[stmt->dest], llvm_val[stmt->val],
1425 llvm::MaybeAlign(0), llvm::AtomicOrdering::SequentiallyConsistent);
1426}
1427
1428llvm::Value *TaskCodeGenLLVM::atomic_op_using_cas(
1429 llvm::Value *dest,
1430 llvm::Value *val,
1431 std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op) {
1432 using namespace llvm;
1433 BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func);
1434 BasicBlock *after_loop =
1435 BasicBlock::Create(*llvm_context, "after_while", func);
1436
1437 builder->CreateBr(body);
1438 builder->SetInsertPoint(body);
1439
1440 llvm::Value *old_val;
1441
1442 {
1443 old_val = builder->CreateLoad(val->getType(), dest);
1444 auto new_val = op(old_val, val);
1445 dest =
1446 builder->CreateBitCast(dest, llvm::Type::getInt16PtrTy(*llvm_context));
1447 auto atomicCmpXchg = builder->CreateAtomicCmpXchg(
1448 dest,
1449 builder->CreateBitCast(old_val, llvm::Type::getInt16Ty(*llvm_context)),
1450 builder->CreateBitCast(new_val, llvm::Type::getInt16Ty(*llvm_context)),
1451 llvm::MaybeAlign(0), AtomicOrdering::SequentiallyConsistent,
1452 AtomicOrdering::SequentiallyConsistent);
1453 // Check whether CAS was succussful
1454 auto ok = builder->CreateExtractValue(atomicCmpXchg, 1);
1455 builder->CreateCondBr(builder->CreateNot(ok), body, after_loop);
1456 }
1457
1458 builder->SetInsertPoint(after_loop);
1459
1460 return old_val;
1461}
1462
1463llvm::Value *TaskCodeGenLLVM::real_type_atomic(AtomicOpStmt *stmt) {
1464 if (!is_real(stmt->val->ret_type)) {
1465 return nullptr;
1466 }
1467
1468 PrimitiveTypeID prim_type = stmt->val->ret_type->cast<PrimitiveType>()->type;
1469 AtomicOpType op = stmt->op_type;
1470 if (prim_type == PrimitiveTypeID::f16) {
1471 switch (op) {
1472 case AtomicOpType::add:
1473 return atomic_op_using_cas(
1474 llvm_val[stmt->dest], llvm_val[stmt->val],
1475 [&](auto v1, auto v2) { return builder->CreateFAdd(v1, v2); });
1476 case AtomicOpType::max:
1477 return atomic_op_using_cas(
1478 llvm_val[stmt->dest], llvm_val[stmt->val],
1479 [&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); });
1480 case AtomicOpType::min:
1481 return atomic_op_using_cas(
1482 llvm_val[stmt->dest], llvm_val[stmt->val],
1483 [&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); });
1484 default:
1485 break;
1486 }
1487 }
1488
1489 if (op == AtomicOpType::add) {
1490 return builder->CreateAtomicRMW(
1491 llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest], llvm_val[stmt->val],
1492 llvm::MaybeAlign(0), llvm::AtomicOrdering::SequentiallyConsistent);
1493 }
1494
1495 std::unordered_map<PrimitiveTypeID,
1496 std::unordered_map<AtomicOpType, std::string>>
1497 atomics;
1498 atomics[PrimitiveTypeID::f32][AtomicOpType::min] = "atomic_min_f32";
1499 atomics[PrimitiveTypeID::f64][AtomicOpType::min] = "atomic_min_f64";
1500 atomics[PrimitiveTypeID::f32][AtomicOpType::max] = "atomic_max_f32";
1501 atomics[PrimitiveTypeID::f64][AtomicOpType::max] = "atomic_max_f64";
1502 TI_ASSERT(atomics.find(prim_type) != atomics.end());
1503 TI_ASSERT(atomics.at(prim_type).find(op) != atomics.at(prim_type).end());
1504 return call(atomics.at(prim_type).at(op), llvm_val[stmt->dest],
1505 llvm_val[stmt->val]);
1506}
1507
1508void TaskCodeGenLLVM::visit(AtomicOpStmt *stmt) {
1509 bool is_local = stmt->dest->is<AllocaStmt>();
1510 if (is_local) {
1511 TI_ERROR("Local atomics should have been demoted.");
1512 }
1513 llvm::Value *old_value;
1514 if (llvm::Value *result = optimized_reduction(stmt)) {
1515 old_value = result;
1516 } else if (llvm::Value *result = quant_type_atomic(stmt)) {
1517 old_value = result;
1518 } else if (llvm::Value *result = real_type_atomic(stmt)) {
1519 old_value = result;
1520 } else if (llvm::Value *result = integral_type_atomic(stmt)) {
1521 old_value = result;
1522 } else {
1523 TI_NOT_IMPLEMENTED
1524 }
1525 llvm_val[stmt] = old_value;
1526}
1527
1528void TaskCodeGenLLVM::visit(GlobalPtrStmt *stmt) {
1529 TI_ERROR("Global Ptrs should have been lowered.");
1530}
1531
1532void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
1533 TI_ASSERT(llvm_val[stmt->val]);
1534 TI_ASSERT(llvm_val[stmt->dest]);
1535 auto ptr_type = stmt->dest->ret_type->as<PointerType>();
1536 if (ptr_type->is_bit_pointer()) {
1537 auto pointee_type = ptr_type->get_pointee_type();
1538 auto snode = stmt->dest->as<GetChStmt>()->input_snode;
1539 if (snode->type == SNodeType::bit_struct) {
1540 TI_ERROR(
1541 "Bit struct stores with type {} should have been handled by "
1542 "BitStructStoreStmt.",
1543 pointee_type->to_string());
1544 }
1545 if (auto qit = pointee_type->cast<QuantIntType>()) {
1546 store_quant_int(llvm_val[stmt->dest],
1547 tlctx->get_data_type(snode->physical_type), qit,
1548 llvm_val[stmt->val], true);
1549 } else if (auto qfxt = pointee_type->cast<QuantFixedType>()) {
1550 store_quant_fixed(llvm_val[stmt->dest],
1551 tlctx->get_data_type(snode->physical_type), qfxt,
1552 llvm_val[stmt->val], true);
1553 } else {
1554 TI_NOT_IMPLEMENTED;
1555 }
1556 } else {
1557 builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
1558 }
1559}
1560
1561llvm::Value *TaskCodeGenLLVM::create_intrinsic_load(llvm::Value *ptr,
1562 llvm::Type *ty) {
1563 TI_NOT_IMPLEMENTED;
1564}
1565
1566void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
1567 bool should_cache_as_read_only) {
1568 auto ptr = llvm_val[stmt->src];
1569 auto ptr_type = stmt->src->ret_type->as<PointerType>();
1570 if (ptr_type->is_bit_pointer()) {
1571 auto val_type = ptr_type->get_pointee_type();
1572 auto get_ch = stmt->src->as<GetChStmt>();
1573 auto physical_type =
1574 tlctx->get_data_type(get_ch->input_snode->physical_type);
1575 auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
1576 auto physical_value = should_cache_as_read_only
1577 ? create_intrinsic_load(byte_ptr, physical_type)
1578 : builder->CreateLoad(physical_type, byte_ptr);
1579 if (auto qit = val_type->cast<QuantIntType>()) {
1580 llvm_val[stmt] = extract_quant_int(physical_value, bit_offset, qit);
1581 } else if (auto qfxt = val_type->cast<QuantFixedType>()) {
1582 qit = qfxt->get_digits_type()->as<QuantIntType>();
1583 auto digits = extract_quant_int(physical_value, bit_offset, qit);
1584 llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt);
1585 } else {
1586 TI_ASSERT(val_type->is<QuantFloatType>());
1587 TI_ASSERT(get_ch->input_snode->dt->is<BitStructType>());
1588 llvm_val[stmt] = extract_quant_float(
1589 physical_value, get_ch->input_snode->dt->as<BitStructType>(),
1590 get_ch->output_snode->id_in_bit_struct);
1591 }
1592 } else {
1593 // Byte pointer case.
1594 if (should_cache_as_read_only) {
1595 llvm_val[stmt] =
1596 create_intrinsic_load(ptr, tlctx->get_data_type(stmt->ret_type));
1597 } else {
1598 llvm_val[stmt] =
1599 builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr);
1600 }
1601 }
1602}
1603
1604void TaskCodeGenLLVM::visit(GlobalLoadStmt *stmt) {
1605 create_global_load(stmt, false);
1606}
1607
1608std::string TaskCodeGenLLVM::get_runtime_snode_name(SNode *snode) {
1609 if (snode->type == SNodeType::root) {
1610 return "Root";
1611 } else if (snode->type == SNodeType::dense) {
1612 return "Dense";
1613 } else if (snode->type == SNodeType::dynamic) {
1614 return "Dynamic";
1615 } else if (snode->type == SNodeType::pointer) {
1616 return "Pointer";
1617 } else if (snode->type == SNodeType::hash) {
1618 return "Hash";
1619 } else if (snode->type == SNodeType::bitmasked) {
1620 return "Bitmasked";
1621 } else if (snode->type == SNodeType::bit_struct) {
1622 return "BitStruct";
1623 } else if (snode->type == SNodeType::quant_array) {
1624 return "QuantArray";
1625 } else {
1626 TI_P(snode_type_name(snode->type));
1627 TI_NOT_IMPLEMENTED
1628 }
1629}
1630
1631llvm::Value *TaskCodeGenLLVM::call(
1632 SNode *snode,
1633 llvm::Value *node_ptr,
1634 const std::string &method,
1635 const std::vector<llvm::Value *> &arguments) {
1636 auto prefix = get_runtime_snode_name(snode);
1637 auto s = emit_struct_meta(snode);
1638 auto s_ptr =
1639 builder->CreateBitCast(s, llvm::Type::getInt8PtrTy(*llvm_context));
1640
1641 node_ptr =
1642 builder->CreateBitCast(node_ptr, llvm::Type::getInt8PtrTy(*llvm_context));
1643
1644 std::vector<llvm::Value *> func_arguments{s_ptr, node_ptr};
1645
1646 func_arguments.insert(func_arguments.end(), arguments.begin(),
1647 arguments.end());
1648
1649 return call(prefix + "_" + method, std::move(func_arguments));
1650}
1651
1652llvm::Function *TaskCodeGenLLVM::get_struct_function(const std::string &name,
1653 int tree_id) {
1654 used_tree_ids.insert(tree_id);
1655 auto f = tlctx->get_struct_function(name, tree_id);
1656 if (!f) {
1657 TI_ERROR("Struct function {} not found.", name);
1658 }
1659 f = llvm::cast<llvm::Function>(
1660 module
1661 ->getOrInsertFunction(name, f->getFunctionType(), f->getAttributes())
1662 .getCallee());
1663 return f;
1664}
1665
1666template <typename... Args>
1667llvm::Value *TaskCodeGenLLVM::call_struct_func(int tree_id,
1668 const std::string &func_name,
1669 Args &&...args) {
1670 auto func = get_struct_function(func_name, tree_id);
1671 auto arglist = std::vector<llvm::Value *>({args...});
1672 check_func_call_signature(func->getFunctionType(), func->getName(), arglist,
1673 builder.get());
1674 return builder->CreateCall(func, arglist);
1675}
1676
1677void TaskCodeGenLLVM::visit(GetRootStmt *stmt) {
1678 if (stmt->root() == nullptr)
1679 llvm_val[stmt] = builder->CreateBitCast(
1680 get_root(SNodeTree::kFirstID),
1681 llvm::PointerType::get(
1682 StructCompilerLLVM::get_llvm_node_type(
1683 module.get(), prog->get_snode_root(SNodeTree::kFirstID)),
1684 0));
1685 else
1686 llvm_val[stmt] = builder->CreateBitCast(
1687 get_root(stmt->root()->get_snode_tree_id()),
1688 llvm::PointerType::get(
1689 StructCompilerLLVM::get_llvm_node_type(module.get(), stmt->root()),
1690 0));
1691}
1692
1693void TaskCodeGenLLVM::visit(LinearizeStmt *stmt) {
1694 llvm::Value *val = tlctx->get_constant(0);
1695 for (int i = 0; i < (int)stmt->inputs.size(); i++) {
1696 val = builder->CreateAdd(
1697 builder->CreateMul(val, tlctx->get_constant(stmt->strides[i])),
1698 llvm_val[stmt->inputs[i]]);
1699 }
1700 llvm_val[stmt] = val;
1701}
1702
1703void TaskCodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}
1704
1705llvm::Value *TaskCodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr,
1706 llvm::Value *bit_offset) {
1707 // 1. define the bit pointer struct (X=8/16/32/64)
1708 // struct bit_pointer_X {
1709 // iX* byte_ptr;
1710 // i32 bit_offset;
1711 // };
1712 TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
1713 auto struct_type = llvm::StructType::get(
1714 *llvm_context, {byte_ptr->getType(), bit_offset->getType()});
1715 // 2. allocate the bit pointer struct
1716 auto bit_ptr = create_entry_block_alloca(struct_type);
1717 // 3. store `byte_ptr`
1718 builder->CreateStore(byte_ptr, builder->CreateGEP(struct_type, bit_ptr,
1719 {tlctx->get_constant(0),
1720 tlctx->get_constant(0)}));
1721 // 4. store `bit_offset
1722 builder->CreateStore(
1723 bit_offset,
1724 builder->CreateGEP(struct_type, bit_ptr,
1725 {tlctx->get_constant(0), tlctx->get_constant(1)}));
1726 return bit_ptr;
1727}
1728
1729std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::load_bit_ptr(
1730 llvm::Value *bit_ptr) {
1731 // FIXME: get ptr_ty from taichi instead of llvm.
1732 llvm::Type *ptr_ty = nullptr;
1733 if (auto *AI = llvm::dyn_cast<llvm::AllocaInst>(bit_ptr))
1734 ptr_ty = AI->getAllocatedType();
1735 TI_ASSERT(ptr_ty);
1736 auto *struct_ty = llvm::cast<llvm::StructType>(ptr_ty);
1737 auto byte_ptr = builder->CreateLoad(
1738 struct_ty->getElementType(0),
1739 builder->CreateGEP(ptr_ty, bit_ptr,
1740 {tlctx->get_constant(0), tlctx->get_constant(0)}));
1741 auto bit_offset = builder->CreateLoad(
1742 struct_ty->getElementType(1),
1743 builder->CreateGEP(ptr_ty, bit_ptr,
1744 {tlctx->get_constant(0), tlctx->get_constant(1)}));
1745
1746 return std::make_tuple(byte_ptr, bit_offset);
1747}
1748
1749void TaskCodeGenLLVM::visit(SNodeLookupStmt *stmt) {
1750 llvm::Value *parent = nullptr;
1751 parent = llvm_val[stmt->input_snode];
1752 TI_ASSERT(parent);
1753 auto snode = stmt->snode;
1754 if (snode->type == SNodeType::root) {
1755 // FIXME: get parent_type from taichi instead of llvm.
1756 llvm::Type *parent_ty = builder->getInt8Ty();
1757 if (auto bit_cast = llvm::dyn_cast<llvm::BitCastInst>(parent)) {
1758 parent_ty = bit_cast->getDestTy();
1759 if (auto ptr_ty = llvm::dyn_cast<llvm::PointerType>(parent_ty))
1760 parent_ty = ptr_ty->getPointerElementType();
1761 }
1762 llvm_val[stmt] =
1763 builder->CreateGEP(parent_ty, parent, llvm_val[stmt->input_index]);
1764 } else if (snode->type == SNodeType::dense ||
1765 snode->type == SNodeType::pointer ||
1766 snode->type == SNodeType::dynamic ||
1767 snode->type == SNodeType::bitmasked) {
1768 if (stmt->activate) {
1769 call(snode, llvm_val[stmt->input_snode], "activate",
1770 {llvm_val[stmt->input_index]});
1771 }
1772 llvm_val[stmt] = call(snode, llvm_val[stmt->input_snode], "lookup_element",
1773 {llvm_val[stmt->input_index]});
1774 } else if (snode->type == SNodeType::bit_struct) {
1775 llvm_val[stmt] = parent;
1776 } else if (snode->type == SNodeType::quant_array) {
1777 auto element_num_bits =
1778 snode->dt->as<QuantArrayType>()->get_element_num_bits();
1779 auto offset = tlctx->get_constant(element_num_bits);
1780 offset = builder->CreateMul(offset, llvm_val[stmt->input_index]);
1781 llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_snode], offset);
1782 } else {
1783 TI_INFO(snode_type_name(snode->type));
1784 TI_NOT_IMPLEMENTED
1785 }
1786}
1787
1788void TaskCodeGenLLVM::visit(GetChStmt *stmt) {
1789 if (stmt->input_snode->type == SNodeType::quant_array) {
1790 llvm_val[stmt] = llvm_val[stmt->input_ptr];
1791 } else if (stmt->ret_type->as<PointerType>()->is_bit_pointer()) {
1792 auto bit_struct = stmt->input_snode->dt->cast<BitStructType>();
1793 auto bit_offset =
1794 bit_struct->get_member_bit_offset(stmt->output_snode->id_in_bit_struct);
1795 auto offset = tlctx->get_constant(bit_offset);
1796 llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
1797 } else {
1798 auto ch = call_struct_func(
1799 stmt->output_snode->get_snode_tree_id(),
1800 stmt->output_snode->get_ch_from_parent_func_name(),
1801 builder->CreateBitCast(llvm_val[stmt->input_ptr],
1802 llvm::PointerType::getInt8PtrTy(*llvm_context)));
1803 llvm_val[stmt] = builder->CreateBitCast(
1804 ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
1805 module.get(), stmt->output_snode),
1806 0));
1807 }
1808}
1809
1810void TaskCodeGenLLVM::visit(MatrixPtrStmt *stmt) {
1811 if (stmt->offset_used_as_index()) {
1812 auto type = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed());
1813 llvm_val[stmt] =
1814 builder->CreateGEP(type, llvm_val[stmt->origin],
1815 {tlctx->get_constant(0), llvm_val[stmt->offset]});
1816 } else {
1817 // Access PtrOffset via: base_ptr + offset
1818 auto origin_address = builder->CreatePtrToInt(
1819 llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context));
1820 auto address_offset = builder->CreateSExt(
1821 llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context));
1822 auto target_address = builder->CreateAdd(origin_address, address_offset);
1823 auto dt = stmt->ret_type.ptr_removed();
1824 llvm_val[stmt] = builder->CreateIntToPtr(
1825 target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0));
1826 }
1827}
1828
1829void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
1830 auto argload = stmt->base_ptr->as<ArgLoadStmt>();
1831 auto arg_id = argload->arg_id;
1832 int num_indices = stmt->indices.size();
1833 std::vector<llvm::Value *> sizes(num_indices);
1834 auto dt = stmt->ret_type.ptr_removed();
1835 int num_element_indices =
1836 dt->is<TensorType>() ? 0 : stmt->element_shape.size();
1837 const auto layout = stmt->element_dim <= 0 ? ExternalArrayLayout::kAOS
1838 : ExternalArrayLayout::kSOA;
1839
1840 /*
1841 ExternalPtrStmt can be divided into "outter" and "inner" parts.
1842
1843 For example, "x" is an Ndarray with shape = (5, 5, 6), m=2, n=3.
1844 Indexing to a single element of "x" is of form: x[i, j, k][m, n]
1845
1846 The "outter" part is x[i, j, k], and the "inner" part is [m, n].
1847 Shape of the inner part is known at compile time, stored in its ret_type.
1848 Shape of the outter part is determined at runtime, passed from the
1849 "extra_args".
1850
1851 "num_indices - num_element_indices" gives how many "extra_args" to read from
1852 */
1853 int num_array_args = num_indices - num_element_indices;
1854 const size_t element_shape_index_offset =
1855 (layout == ExternalArrayLayout::kAOS) ? num_array_args : 0;
1856
1857 for (int i = 0; i < num_array_args; i++) {
1858 auto raw_arg = call("RuntimeContext_get_extra_args", get_context(),
1859 tlctx->get_constant(arg_id), tlctx->get_constant(i));
1860 sizes[i] = raw_arg;
1861 }
1862
1863 auto linear_index = tlctx->get_constant(0);
1864 size_t size_var_index = 0;
1865 for (int i = 0; i < num_indices; i++) {
1866 if (i >= element_shape_index_offset &&
1867 i < element_shape_index_offset + num_element_indices) {
1868 // Indexing TensorType-elements
1869 llvm::Value *size_var = tlctx->get_constant(
1870 stmt->element_shape[i - element_shape_index_offset]);
1871 linear_index = builder->CreateMul(linear_index, size_var);
1872 } else {
1873 // Indexing array dimensions
1874 linear_index = builder->CreateMul(linear_index, sizes[size_var_index++]);
1875 }
1876 linear_index = builder->CreateAdd(linear_index, llvm_val[stmt->indices[i]]);
1877 }
1878 TI_ASSERT(size_var_index == num_indices - num_element_indices);
1879
1880 /*
1881 llvm::GEP implicitly indicates alignment when used upon llvm::VectorType.
1882 For example:
1883
1884 "getelementptr <10 x i32>* %1, 0, 1" is interpreted as "%1 + 16(aligned)"
1885
1886 However, this does not fit with Taichi's Ndarray semantics. We will have to
1887 do pointer arithmetics to manually calculate the offset.
1888 */
1889 DataType operand_dtype = argload->ret_type.ptr_removed();
1890 if (operand_dtype->is<TensorType>()) {
1891 // Access PtrOffset via: base_ptr + offset * sizeof(element)
1892 auto primitive_type = operand_dtype.get_element_type();
1893 auto primitive_ptr = builder->CreateBitCast(
1894 llvm_val[stmt->base_ptr],
1895 llvm::PointerType::get(tlctx->get_data_type(primitive_type), 0));
1896
1897 auto address_offset = builder->CreateSExt(
1898 linear_index, llvm::Type::getInt64Ty(*llvm_context));
1899
1900 if (stmt->ret_type->is<TensorType>()) {
1901 // This case corresponds to outter indexing only
1902 // The stride for linear_index is num_elements() in TensorType.
1903 address_offset = builder->CreateMul(
1904 address_offset,
1905 tlctx->get_constant(
1906 get_data_type<int64>(),
1907 stmt->ret_type->cast<TensorType>()->get_num_elements()));
1908 } else {
1909 // This case corresponds to outter + inner indexing
1910 // Since both outter and inner indices are linearized into linear_index,
1911 // the stride for linear_index is 1, and there's nothing to do here.
1912 }
1913
1914 auto ret_ptr = builder->CreateGEP(tlctx->get_data_type(primitive_type),
1915 primitive_ptr, address_offset);
1916 llvm_val[stmt] = builder->CreateBitCast(
1917 ret_ptr, llvm::PointerType::get(tlctx->get_data_type(dt), 0));
1918
1919 } else {
1920 auto base_ty = tlctx->get_data_type(dt);
1921 auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
1922 llvm::PointerType::get(base_ty, 0));
1923
1924 llvm_val[stmt] = builder->CreateGEP(base_ty, base, linear_index);
1925 }
1926}
1927
1928void TaskCodeGenLLVM::visit(ExternalTensorShapeAlongAxisStmt *stmt) {
1929 const auto arg_id = stmt->arg_id;
1930 const auto axis = stmt->axis;
1931 llvm_val[stmt] = call("RuntimeContext_get_extra_args", get_context(),
1932 tlctx->get_constant(arg_id), tlctx->get_constant(axis));
1933}
1934
1935std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
1936 std::string suffix) {
1937 current_loop_reentry = nullptr;
1938 current_while_after_loop = nullptr;
1939
1940 task_function_type =
1941 llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
1942 {llvm::PointerType::get(context_ty, 0)}, false);
1943
1944 auto task_kernel_name =
1945 fmt::format("{}_{}_{}{}", kernel_name, kernel->get_next_task_id(),
1946 stmt->task_name(), suffix);
1947 func = llvm::Function::Create(task_function_type,
1948 llvm::Function::ExternalLinkage,
1949 task_kernel_name, module.get());
1950
1951 current_task = std::make_unique<OffloadedTask>(task_kernel_name);
1952
1953 for (auto &arg : func->args()) {
1954 kernel_args.push_back(&arg);
1955 }
1956 kernel_args[0]->setName("context");
1957 if (kernel_argument_by_val())
1958 func->addParamAttr(
1959 0, llvm::Attribute::getWithByValType(*llvm_context, context_ty));
1960 // entry_block has all the allocas
1961 this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
1962 this->final_block = llvm::BasicBlock::Create(*llvm_context, "final", func);
1963
1964 // The real function body
1965 func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
1966 builder->SetInsertPoint(func_body_bb);
1967 return task_kernel_name;
1968}
1969
1970void TaskCodeGenLLVM::finalize_offloaded_task_function() {
1971 if (!returned) {
1972 builder->CreateBr(final_block);
1973 } else {
1974 returned = false;
1975 }
1976 builder->SetInsertPoint(final_block);
1977 builder->CreateRetVoid();
1978
1979 // entry_block should jump to the body after all allocas are inserted
1980 builder->SetInsertPoint(entry_block);
1981 builder->CreateBr(func_body_bb);
1982
1983 if (compile_config.print_kernel_llvm_ir) {
1984 static FileSequenceWriter writer("taichi_kernel_generic_llvm_ir_{:04d}.ll",
1985 "unoptimized LLVM IR (generic)");
1986 writer.write(module.get());
1987 }
1988 TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
1989 // TI_INFO("Kernel function verified.");
1990}
1991
1992std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
1993 OffloadedStmt *stmt) {
1994 llvm::Value *begin, *end;
1995 if (stmt->const_begin) {
1996 begin = tlctx->get_constant(stmt->begin_value);
1997 } else {
1998 auto begin_stmt =
1999 Stmt::make<GlobalTemporaryStmt>(stmt->begin_offset, PrimitiveType::i32);
2000 begin_stmt->accept(this);
2001 begin = builder->CreateLoad(tlctx->get_data_type(PrimitiveType::i32),
2002 llvm_val[begin_stmt.get()]);
2003 }
2004 if (stmt->const_end) {
2005 end = tlctx->get_constant(stmt->end_value);
2006 } else {
2007 auto end_stmt =
2008 Stmt::make<GlobalTemporaryStmt>(stmt->end_offset, PrimitiveType::i32);
2009 end_stmt->accept(this);
2010 end = builder->CreateLoad(tlctx->get_data_type(PrimitiveType::i32),
2011 llvm_val[end_stmt.get()]);
2012 }
2013 return std::tuple(begin, end);
2014}
2015
2016void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) {
2017 using namespace llvm;
2018 // TODO: instead of constructing tons of LLVM IR, writing the logic in
2019 // runtime.cpp may be a cleaner solution. See
2020 // TaskCodeGenCPU::create_offload_range_for as an example.
2021
2022 llvm::Function *body = nullptr;
2023 auto leaf_block = stmt->snode;
2024
2025 // For a bit-vectorized loop over a quant array, we generate struct for on its
2026 // parent node (must be "dense") instead of itself for higher performance.
2027 if (stmt->is_bit_vectorized) {
2028 if (leaf_block->type == SNodeType::quant_array &&
2029 leaf_block->parent->type == SNodeType::dense) {
2030 leaf_block = leaf_block->parent;
2031 } else {
2032 TI_ERROR(
2033 "A bit-vectorized struct-for must loop over a quant array with a "
2034 "dense parent");
2035 }
2036 }
2037
2038 {
2039 // Create the loop body function
2040 auto guard = get_function_creation_guard({
2041 llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
2042 get_tls_buffer_type(),
2043 llvm::PointerType::get(get_runtime_type("Element"), 0),
2044 tlctx->get_data_type<int>(),
2045 tlctx->get_data_type<int>(),
2046 });
2047
2048 body = guard.body;
2049
2050 /* Function structure:
2051 *
2052 * function_body (entry):
2053 * loop_index = lower_bound;
2054 * tls_prologue()
2055 * bls_prologue()
2056 * goto loop_test
2057 *
2058 * loop_test:
2059 * if (loop_index < upper_bound)
2060 * goto loop_body
2061 * else
2062 * goto func_exit
2063 *
2064 * loop_body:
2065 * initialize_coordinates()
2066 * if (bitmasked voxel is active)
2067 * goto struct_for_body
2068 * else
2069 * goto loop_body_tail
2070 *
2071 * struct_for_body:
2072 * ... (Run codegen on the StructForStmt::body Taichi Block)
2073 * goto loop_body_tail
2074 *
2075 * loop_body_tail:
2076 * loop_index += block_dim
2077 * goto loop_test
2078 *
2079 * func_exit:
2080 * bls_epilogue()
2081 * tls_epilogue()
2082 * return
2083 */
2084 auto loop_index_ty = llvm::Type::getInt32Ty(*llvm_context);
2085 auto loop_index = create_entry_block_alloca(loop_index_ty);
2086
2087 RuntimeObject element("Element", this, builder.get(), get_arg(2));
2088
2089 // Loop ranges
2090 auto lower_bound = get_arg(3);
2091 auto upper_bound = get_arg(4);
2092
2093 parent_coordinates = element.get_ptr("pcoord");
2094 block_corner_coordinates =
2095 create_entry_block_alloca(physical_coordinate_ty);
2096
2097 auto refine =
2098 get_struct_function(leaf_block->refine_coordinates_func_name(),
2099 leaf_block->get_snode_tree_id());
2100 // A block corner is the global coordinate/index of the lower-left corner
2101 // cell within that block, and is the same for all the cells within that
2102 // block.
2103 call(refine, parent_coordinates, block_corner_coordinates,
2104 tlctx->get_constant(0));
2105
2106 if (stmt->tls_prologue) {
2107 stmt->tls_prologue->accept(this);
2108 }
2109
2110 if (stmt->bls_prologue) {
2111 call("block_barrier"); // "__syncthreads()"
2112 stmt->bls_prologue->accept(this);
2113 call("block_barrier"); // "__syncthreads()"
2114 }
2115
2116 auto [thread_idx, block_dim] = this->get_spmd_info();
2117 builder->CreateStore(builder->CreateAdd(thread_idx, lower_bound),
2118 loop_index);
2119
2120 auto loop_test_bb = BasicBlock::Create(*llvm_context, "loop_test", func);
2121 auto loop_body_bb = BasicBlock::Create(*llvm_context, "loop_body", func);
2122 auto body_tail_bb =
2123 BasicBlock::Create(*llvm_context, "loop_body_tail", func);
2124 auto func_exit = BasicBlock::Create(*llvm_context, "func_exit", func);
2125 auto struct_for_body_bb =
2126 BasicBlock::Create(*llvm_context, "struct_for_body_body", func);
2127
2128 auto lrg = make_loop_reentry_guard(this);
2129 current_loop_reentry = body_tail_bb;
2130
2131 builder->CreateBr(loop_test_bb);
2132
2133 {
2134 // loop_test:
2135 // if (loop_index < upper_bound)
2136 // goto loop_body;
2137 // else
2138 // goto func_exit
2139
2140 builder->SetInsertPoint(loop_test_bb);
2141 auto cond = builder->CreateICmp(
2142 llvm::CmpInst::Predicate::ICMP_SLT,
2143 builder->CreateLoad(loop_index_ty, loop_index), upper_bound);
2144 builder->CreateCondBr(cond, loop_body_bb, func_exit);
2145 }
2146
2147 // ***********************
2148 // Begin loop_body_bb:
2149 builder->SetInsertPoint(loop_body_bb);
2150
2151 // initialize the coordinates
2152 auto new_coordinates = create_entry_block_alloca(physical_coordinate_ty);
2153
2154 call(refine, parent_coordinates, new_coordinates,
2155 builder->CreateLoad(loop_index_ty, loop_index));
2156
2157 // For a bit-vectorized loop over a quant array, one more refine step is
2158 // needed to make final coordinates non-consecutive, since each thread will
2159 // process multiple coordinates via vectorization
2160 if (stmt->is_bit_vectorized) {
2161 refine = get_struct_function(stmt->snode->refine_coordinates_func_name(),
2162 stmt->snode->get_snode_tree_id());
2163 call(refine, new_coordinates, new_coordinates, tlctx->get_constant(0));
2164 }
2165
2166 current_coordinates = new_coordinates;
2167
2168 // exec_cond: safe-guard the execution of loop body:
2169 // - if non-POT field dim exists, make sure we don't go out of bounds
2170 // - if leaf block is bitmasked, make sure we only loop over active
2171 // voxels
2172 auto exec_cond = tlctx->get_constant(true);
2173 auto coord_object = RuntimeObject(kLLVMPhysicalCoordinatesName, this,
2174 builder.get(), new_coordinates);
2175
2176 if (leaf_block->type == SNodeType::bitmasked ||
2177 leaf_block->type == SNodeType::pointer) {
2178 // test whether the current voxel is active or not
2179 auto is_active = call(leaf_block, element.get("element"), "is_active",
2180 {builder->CreateLoad(loop_index_ty, loop_index)});
2181 is_active =
2182 builder->CreateTrunc(is_active, llvm::Type::getInt1Ty(*llvm_context));
2183 exec_cond = builder->CreateAnd(exec_cond, is_active);
2184 }
2185
2186 builder->CreateCondBr(exec_cond, struct_for_body_bb, body_tail_bb);
2187
2188 {
2189 builder->SetInsertPoint(struct_for_body_bb);
2190
2191 // The real loop body of the StructForStmt
2192 stmt->body->accept(this);
2193
2194 builder->CreateBr(body_tail_bb);
2195 }
2196
2197 {
2198 // body tail: increment loop_index and jump to loop_test
2199 builder->SetInsertPoint(body_tail_bb);
2200
2201 create_increment(loop_index, block_dim);
2202 builder->CreateBr(loop_test_bb);
2203
2204 builder->SetInsertPoint(func_exit);
2205 }
2206
2207 if (stmt->bls_epilogue) {
2208 call("block_barrier"); // "__syncthreads()"
2209 stmt->bls_epilogue->accept(this);
2210 call("block_barrier"); // "__syncthreads()"
2211 }
2212
2213 if (stmt->tls_epilogue) {
2214 stmt->tls_epilogue->accept(this);
2215 }
2216 }
2217
2218 int list_element_size = std::min(leaf_block->max_num_elements(),
2219 (int64)taichi_listgen_max_element_size);
2220 int num_splits = std::max(1, list_element_size / stmt->block_dim +
2221 (list_element_size % stmt->block_dim != 0));
2222
2223 auto struct_for_func = get_runtime_function("parallel_struct_for");
2224
2225 if (arch_is_gpu(current_arch())) {
2226 struct_for_func = llvm::cast<llvm::Function>(
2227 module
2228 ->getOrInsertFunction(
2229 tlctx->get_struct_for_func_name(stmt->tls_size),
2230 struct_for_func->getFunctionType(),
2231 struct_for_func->getAttributes())
2232 .getCallee());
2233 struct_for_tls_sizes.insert(stmt->tls_size);
2234 }
2235 // Loop over nodes in the element list, in parallel
2236 call(struct_for_func, get_context(), tlctx->get_constant(leaf_block->id),
2237 tlctx->get_constant(list_element_size), tlctx->get_constant(num_splits),
2238 body, tlctx->get_constant(stmt->tls_size),
2239 tlctx->get_constant(stmt->num_cpu_threads));
2240 // TODO: why do we need num_cpu_threads on GPUs?
2241
2242 current_coordinates = nullptr;
2243 parent_coordinates = nullptr;
2244 block_corner_coordinates = nullptr;
2245}
2246
2247void TaskCodeGenLLVM::visit(LoopIndexStmt *stmt) {
2248 if (stmt->loop->is<OffloadedStmt>() &&
2249 stmt->loop->as<OffloadedStmt>()->task_type ==
2250 OffloadedStmt::TaskType::struct_for) {
2251 llvm::Type *struct_ty = nullptr;
2252 // FIXME: get struct_ty from taichi instead of llvm.
2253 if (auto *alloca = llvm::dyn_cast<llvm::AllocaInst>(current_coordinates)) {
2254 struct_ty = alloca->getAllocatedType();
2255 }
2256 TI_ASSERT(struct_ty);
2257 auto *GEP =
2258 builder->CreateGEP(struct_ty, current_coordinates,
2259 {tlctx->get_constant(0), tlctx->get_constant(0),
2260 tlctx->get_constant(stmt->index)});
2261 if (stmt->index == 0 && !llvm::isa<llvm::GEPOperator>(GEP))
2262 GEP = builder->CreateBitCast(GEP, struct_ty->getPointerTo());
2263 llvm_val[stmt] =
2264 builder->CreateLoad(llvm::Type::getInt32Ty(*llvm_context), GEP);
2265 } else {
2266 llvm_val[stmt] =
2267 builder->CreateLoad(llvm::Type::getInt32Ty(*llvm_context),
2268 loop_vars_llvm[stmt->loop][stmt->index]);
2269 }
2270}
2271
2272void TaskCodeGenLLVM::visit(LoopLinearIndexStmt *stmt) {
2273 if (stmt->loop->is<OffloadedStmt>() &&
2274 (stmt->loop->as<OffloadedStmt>()->task_type ==
2275 OffloadedStmt::TaskType::struct_for ||
2276 stmt->loop->as<OffloadedStmt>()->task_type ==
2277 OffloadedStmt::TaskType::mesh_for)) {
2278 llvm_val[stmt] = call("thread_idx");
2279 } else {
2280 TI_NOT_IMPLEMENTED;
2281 }
2282}
2283
2284void TaskCodeGenLLVM::visit(BlockCornerIndexStmt *stmt) {
2285 if (stmt->loop->is<OffloadedStmt>() &&
2286 stmt->loop->as<OffloadedStmt>()->task_type ==
2287 OffloadedStmt::TaskType::struct_for) {
2288 TI_ASSERT(block_corner_coordinates);
2289 // Make sure physical_coordinate_ty matches
2290 // struct PhysicalCoordinates {
2291 // i32 val[taichi_max_num_indices];
2292 // };
2293 TI_ASSERT(physical_coordinate_ty->isStructTy());
2294 auto physical_coordinate_ty_as_struct =
2295 llvm::cast<llvm::StructType>(physical_coordinate_ty);
2296 TI_ASSERT(physical_coordinate_ty_as_struct);
2297 TI_ASSERT(physical_coordinate_ty_as_struct->getNumElements() == 1);
2298 auto val_ty = physical_coordinate_ty_as_struct->getElementType(0);
2299 TI_ASSERT(val_ty->isArrayTy());
2300 auto val_ty_as_array = llvm::cast<llvm::ArrayType>(val_ty);
2301 llvm_val[stmt] = builder->CreateLoad(
2302 val_ty_as_array->getElementType(),
2303 builder->CreateGEP(physical_coordinate_ty, block_corner_coordinates,
2304 {tlctx->get_constant(0), tlctx->get_constant(0),
2305 tlctx->get_constant(stmt->index)}));
2306 } else {
2307 TI_NOT_IMPLEMENTED;
2308 }
2309}
2310
2311void TaskCodeGenLLVM::visit(GlobalTemporaryStmt *stmt) {
2312 auto runtime = get_runtime();
2313 auto buffer = call("get_temporary_pointer", runtime,
2314 tlctx->get_constant((int64)stmt->offset));
2315
2316 auto ptr_type = llvm::PointerType::get(
2317 tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
2318 llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type);
2319}
2320
2321void TaskCodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) {
2322 auto base = get_tls_base_ptr();
2323 auto ptr = builder->CreateGEP(llvm::Type::getInt8Ty(*llvm_context), base,
2324 tlctx->get_constant(stmt->offset));
2325 auto ptr_type = llvm::PointerType::get(
2326 tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
2327 llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
2328}
2329
2330void TaskCodeGenLLVM::visit(BlockLocalPtrStmt *stmt) {
2331 TI_ASSERT(bls_buffer);
2332 auto base = bls_buffer;
2333 auto ptr =
2334 builder->CreateGEP(base->getValueType(), base,
2335 {tlctx->get_constant(0), llvm_val[stmt->offset]});
2336 auto ptr_type = llvm::PointerType::get(
2337 tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
2338 llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
2339}
2340
2341void TaskCodeGenLLVM::visit(ClearListStmt *stmt) {
2342 auto snode_child = stmt->snode;
2343 auto snode_parent = stmt->snode->parent;
2344 auto meta_child = cast_pointer(emit_struct_meta(snode_child), "StructMeta");
2345 auto meta_parent = cast_pointer(emit_struct_meta(snode_parent), "StructMeta");
2346 call("clear_list", get_runtime(), meta_parent, meta_child);
2347}
2348
2349void TaskCodeGenLLVM::visit(InternalFuncStmt *stmt) {
2350 std::vector<llvm::Value *> args;
2351
2352 if (stmt->with_runtime_context)
2353 args.push_back(get_context());
2354
2355 for (auto s : stmt->args) {
2356 args.push_back(llvm_val[s]);
2357 }
2358 llvm_val[stmt] = call(stmt->func_name, std::move(args));
2359}
2360
2361void TaskCodeGenLLVM::visit(AdStackAllocaStmt *stmt) {
2362 TI_ASSERT_INFO(stmt->max_size > 0,
2363 "Adaptive autodiff stack's size should have been determined.");
2364 auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context),
2365 stmt->size_in_bytes());
2366 auto alloca = create_entry_block_alloca(type, sizeof(int64));
2367 llvm_val[stmt] = builder->CreateBitCast(
2368 alloca, llvm::PointerType::getInt8PtrTy(*llvm_context));
2369 call("stack_init", llvm_val[stmt]);
2370}
2371
2372void TaskCodeGenLLVM::visit(AdStackPopStmt *stmt) {
2373 call("stack_pop", llvm_val[stmt->stack]);
2374}
2375
2376void TaskCodeGenLLVM::visit(AdStackPushStmt *stmt) {
2377 auto stack = stmt->stack->as<AdStackAllocaStmt>();
2378 call("stack_push", llvm_val[stack], tlctx->get_constant(stack->max_size),
2379 tlctx->get_constant(stack->element_size_in_bytes()));
2380 auto primal_ptr = call("stack_top_primal", llvm_val[stack],
2381 tlctx->get_constant(stack->element_size_in_bytes()));
2382 primal_ptr = builder->CreateBitCast(
2383 primal_ptr,
2384 llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0));
2385 builder->CreateStore(llvm_val[stmt->v], primal_ptr);
2386}
2387
2388void TaskCodeGenLLVM::visit(AdStackLoadTopStmt *stmt) {
2389 auto stack = stmt->stack->as<AdStackAllocaStmt>();
2390 auto primal_ptr = call("stack_top_primal", llvm_val[stack],
2391 tlctx->get_constant(stack->element_size_in_bytes()));
2392 auto primal_ty = tlctx->get_data_type(stmt->ret_type);
2393 primal_ptr =
2394 builder->CreateBitCast(primal_ptr, llvm::PointerType::get(primal_ty, 0));
2395 llvm_val[stmt] = builder->CreateLoad(primal_ty, primal_ptr);
2396}
2397
2398void TaskCodeGenLLVM::visit(AdStackLoadTopAdjStmt *stmt) {
2399 auto stack = stmt->stack->as<AdStackAllocaStmt>();
2400 auto adjoint = call("stack_top_adjoint", llvm_val[stack],
2401 tlctx->get_constant(stack->element_size_in_bytes()));
2402 auto adjoint_ty = tlctx->get_data_type(stmt->ret_type);
2403 adjoint =
2404 builder->CreateBitCast(adjoint, llvm::PointerType::get(adjoint_ty, 0));
2405 llvm_val[stmt] = builder->CreateLoad(adjoint_ty, adjoint);
2406}
2407
2408void TaskCodeGenLLVM::visit(AdStackAccAdjointStmt *stmt) {
2409 auto stack = stmt->stack->as<AdStackAllocaStmt>();
2410 auto adjoint_ptr = call("stack_top_adjoint", llvm_val[stack],
2411 tlctx->get_constant(stack->element_size_in_bytes()));
2412 auto adjoint_ty = tlctx->get_data_type(stack->ret_type);
2413 adjoint_ptr = builder->CreateBitCast(adjoint_ptr,
2414 llvm::PointerType::get(adjoint_ty, 0));
2415 auto old_val = builder->CreateLoad(adjoint_ty, adjoint_ptr);
2416 TI_ASSERT(is_real(stmt->v->ret_type));
2417 auto new_val = builder->CreateFAdd(old_val, llvm_val[stmt->v]);
2418 builder->CreateStore(new_val, adjoint_ptr);
2419}
2420
2421void TaskCodeGenLLVM::visit(RangeAssumptionStmt *stmt) {
2422 llvm_val[stmt] = llvm_val[stmt->input];
2423}
2424
2425void TaskCodeGenLLVM::visit(LoopUniqueStmt *stmt) {
2426 llvm_val[stmt] = llvm_val[stmt->input];
2427}
2428
2429void TaskCodeGenLLVM::visit_call_bitcode(ExternalFuncCallStmt *stmt) {
2430 TI_ASSERT(stmt->type == ExternalFuncCallStmt::BITCODE);
2431 std::vector<llvm::Value *> arg_values;
2432 for (const auto &s : stmt->arg_stmts)
2433 arg_values.push_back(llvm_val[s]);
2434 // Link external module to the core module
2435 if (linked_modules.find(stmt->bc_filename) == linked_modules.end()) {
2436 linked_modules.insert(stmt->bc_filename);
2437 std::unique_ptr<llvm::Module> external_module =
2438 module_from_bitcode_file(stmt->bc_filename, llvm_context);
2439 auto *func_ptr = external_module->getFunction(stmt->bc_funcname);
2440 TI_ASSERT_INFO(func_ptr != nullptr, "{} is not found in {}.",
2441 stmt->bc_funcname, stmt->bc_filename);
2442 auto link_error =
2443 llvm::Linker::linkModules(*module, std::move(external_module));
2444 TI_ASSERT(!link_error);
2445 }
2446 // Retrieve function again. Do it here to detect name conflicting.
2447 auto *func_ptr = module->getFunction(stmt->bc_funcname);
2448 // Convert pointer type from a[n * m] to a[n][m]
2449 for (int i = 0; i < func_ptr->getFunctionType()->getNumParams(); ++i) {
2450 TI_ASSERT_INFO(func_ptr->getArg(i)->getType()->getTypeID() ==
2451 arg_values[i]->getType()->getTypeID(),
2452 "TypeID {} != {} with {}",
2453 (int)func_ptr->getArg(i)->getType()->getTypeID(),
2454 (int)arg_values[i]->getType()->getTypeID(), i);
2455 auto tmp_value = arg_values[i];
2456 arg_values[i] =
2457 builder->CreatePointerCast(tmp_value, func_ptr->getArg(i)->getType());
2458 }
2459 call(func_ptr, arg_values);
2460}
2461
2462void TaskCodeGenLLVM::visit_call_shared_object(ExternalFuncCallStmt *stmt) {
2463 TI_ASSERT(stmt->type == ExternalFuncCallStmt::SHARED_OBJECT);
2464 std::vector<llvm::Type *> arg_types;
2465 std::vector<llvm::Value *> arg_values;
2466
2467 for (const auto &s : stmt->arg_stmts) {
2468 arg_types.push_back(tlctx->get_data_type(s->ret_type));
2469 arg_values.push_back(llvm_val[s]);
2470 }
2471
2472 for (const auto &s : stmt->output_stmts) {
2473 auto t = tlctx->get_data_type(s->ret_type);
2474 auto ptr = llvm::PointerType::get(t, 0);
2475 arg_types.push_back(ptr);
2476 arg_values.push_back(llvm_val[s]);
2477 }
2478
2479 auto func_type = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
2480 arg_types, false);
2481 auto func_ptr_type = llvm::PointerType::get(func_type, 0);
2482
2483 auto addr = tlctx->get_constant((std::size_t)stmt->so_func);
2484 auto func = builder->CreateIntToPtr(addr, func_ptr_type);
2485 call(func, func_type, arg_values);
2486}
2487
2488void TaskCodeGenLLVM::visit(ExternalFuncCallStmt *stmt) {
2489 TI_NOT_IMPLEMENTED
2490}
2491
2492void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) {
2493 llvm_val[stmt] = get_arg(2);
2494}
2495
2496void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) {
2497 auto type = tlctx->get_data_type(stmt->ret_type->as<TensorType>());
2498 llvm::Value *vec = llvm::UndefValue::get(type);
2499 for (int i = 0; i < stmt->values.size(); ++i) {
2500 auto *elem = llvm_val[stmt->values[i]];
2501 if (codegen_vector_type(compile_config)) {
2502 TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(type));
2503 vec = builder->CreateInsertElement(vec, elem, i);
2504 } else {
2505 TI_ASSERT(llvm::dyn_cast<llvm::ArrayType>(type));
2506 vec = builder->CreateInsertValue(vec, elem, i);
2507 }
2508 }
2509 llvm_val[stmt] = vec;
2510}
2511
2512void TaskCodeGenLLVM::eliminate_unused_functions() {
2513 TaichiLLVMContext::eliminate_unused_functions(
2514 module.get(), [&](std::string func_name) {
2515 for (auto &task : offloaded_tasks) {
2516 if (task.name == func_name)
2517 return true;
2518 }
2519 return false;
2520 });
2521}
2522
2523FunctionCreationGuard TaskCodeGenLLVM::get_function_creation_guard(
2524 std::vector<llvm::Type *> argument_types,
2525 const std::string &func_name) {
2526 return FunctionCreationGuard(this, argument_types, func_name);
2527}
2528
2529void TaskCodeGenLLVM::initialize_context() {
2530 TI_ASSERT(tlctx != nullptr);
2531 llvm_context = tlctx->get_this_thread_context();
2532 builder = std::make_unique<llvm::IRBuilder<>>(*llvm_context);
2533}
2534
2535llvm::Value *TaskCodeGenLLVM::get_arg(int i) {
2536 std::vector<llvm::Value *> args;
2537 for (auto &arg : func->args()) {
2538 args.push_back(&arg);
2539 }
2540 return args[i];
2541}
2542
2543llvm::Value *TaskCodeGenLLVM::get_context() {
2544 return get_arg(0);
2545}
2546
2547llvm::Value *TaskCodeGenLLVM::get_tls_base_ptr() {
2548 return get_arg(1);
2549}
2550
2551llvm::Type *TaskCodeGenLLVM::get_tls_buffer_type() {
2552 return llvm::Type::getInt8PtrTy(*llvm_context);
2553}
2554
2555std::vector<llvm::Type *> TaskCodeGenLLVM::get_xlogue_argument_types() {
2556 return {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
2557 get_tls_buffer_type()};
2558}
2559
2560std::vector<llvm::Type *> TaskCodeGenLLVM::get_mesh_xlogue_argument_types() {
2561 return {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
2562 get_tls_buffer_type(), tlctx->get_data_type<uint32_t>()};
2563}
2564
2565llvm::Type *TaskCodeGenLLVM::get_xlogue_function_type() {
2566 return llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
2567 get_xlogue_argument_types(), false);
2568}
2569
2570llvm::Type *TaskCodeGenLLVM::get_mesh_xlogue_function_type() {
2571 return llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
2572 get_mesh_xlogue_argument_types(), false);
2573}
2574
2575llvm::Value *TaskCodeGenLLVM::get_root(int snode_tree_id) {
2576 return call("LLVMRuntime_get_roots", get_runtime(),
2577 tlctx->get_constant(snode_tree_id));
2578}
2579
2580llvm::Value *TaskCodeGenLLVM::get_runtime() {
2581 auto runtime_ptr = call("RuntimeContext_get_runtime", get_context());
2582 return builder->CreateBitCast(
2583 runtime_ptr, llvm::PointerType::get(get_runtime_type("LLVMRuntime"), 0));
2584}
2585
2586llvm::Value *TaskCodeGenLLVM::emit_struct_meta(SNode *snode) {
2587 auto obj = emit_struct_meta_object(snode);
2588 TI_ASSERT(obj != nullptr);
2589 return obj->ptr;
2590}
2591
2592void TaskCodeGenLLVM::emit_to_module() {
2593 TI_AUTO_PROF
2594 ir->accept(this);
2595}
2596
2597LLVMCompiledTask TaskCodeGenLLVM::run_compilation() {
2598 // Final lowering
2599 auto offload_to_executable = [](IRNode *ir, const CompileConfig &config,
2600 Kernel *kernel) {
2601 bool verbose = config.print_ir;
2602 if ((kernel->is_accessor && !config.print_accessor_ir) ||
2603 (kernel->is_evaluator && !config.print_evaluator_ir)) {
2604 verbose = false;
2605 }
2606 irpass::offload_to_executable(
2607 ir, config, kernel, verbose,
2608 /*determine_ad_stack_size=*/kernel->autodiff_mode ==
2609 AutodiffMode::kReverse,
2610 /*lower_global_access=*/true,
2611 /*make_thread_local=*/config.make_thread_local,
2612 /*make_block_local=*/
2613 is_extension_supported(config.arch, Extension::bls) &&
2614 config.make_block_local);
2615 };
2616
2617 offload_to_executable(ir, compile_config, kernel);
2618
2619 emit_to_module();
2620 eliminate_unused_functions();
2621
2622 if (compile_config.arch == Arch::cuda) {
2623 // CUDA specific metadata
2624 for (const auto &task : offloaded_tasks) {
2625 llvm::Function *func = module->getFunction(task.name);
2626 TI_ASSERT(func);
2627 tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
2628 }
2629 } else if (compile_config.arch == Arch::amdgpu) {
2630 for (const auto &task : offloaded_tasks) {
2631 llvm::Function *func = module->getFunction(task.name);
2632 TI_ASSERT(func);
2633 tlctx->mark_function_as_amdgpu_kernel(func);
2634 }
2635 }
2636
2637 return {std::move(offloaded_tasks), std::move(module),
2638 std::move(used_tree_ids), std::move(struct_for_tls_sizes)};
2639}
2640
2641llvm::Value *TaskCodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
2642 llvm::Value *xlogue;
2643
2644 auto xlogue_type = get_xlogue_function_type();
2645 auto xlogue_ptr_type = llvm::PointerType::get(xlogue_type, 0);
2646
2647 if (block) {
2648 auto guard = get_function_creation_guard(get_xlogue_argument_types());
2649 block->accept(this);
2650 xlogue = guard.body;
2651 } else {
2652 xlogue = llvm::ConstantPointerNull::get(xlogue_ptr_type);
2653 }
2654
2655 return xlogue;
2656}
2657
2658llvm::Value *TaskCodeGenLLVM::create_mesh_xlogue(
2659 std::unique_ptr<Block> &block) {
2660 llvm::Value *xlogue;
2661
2662 auto xlogue_type = get_mesh_xlogue_function_type();
2663 auto xlogue_ptr_type = llvm::PointerType::get(xlogue_type, 0);
2664
2665 if (block) {
2666 auto guard = get_function_creation_guard(get_mesh_xlogue_argument_types());
2667 block->accept(this);
2668 xlogue = guard.body;
2669 } else {
2670 xlogue = llvm::ConstantPointerNull::get(xlogue_ptr_type);
2671 }
2672
2673 return xlogue;
2674}
2675
2676void TaskCodeGenLLVM::visit(ReferenceStmt *stmt) {
2677 llvm_val[stmt] = llvm_val[stmt->var];
2678}
2679
2680void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
2681 if (!func_map.count(stmt->func)) {
2682 auto guard = get_function_creation_guard(
2683 {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)},
2684 stmt->func->get_name());
2685 Callable *old_callable = current_callable;
2686 current_callable = stmt->func;
2687 func_map.insert({stmt->func, guard.body});
2688 stmt->func->ir->accept(this);
2689 current_callable = old_callable;
2690 }
2691 llvm::Function *llvm_func = func_map[stmt->func];
2692 auto *new_ctx = call("allocate_runtime_context", get_runtime());
2693 call("RuntimeContext_set_runtime", new_ctx, get_runtime());
2694 for (int i = 0; i < stmt->args.size(); i++) {
2695 auto *val =
2696 bitcast_to_u64(llvm_val[stmt->args[i]], stmt->args[i]->ret_type);
2697 call("RuntimeContext_set_args", new_ctx,
2698 llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
2699 }
2700 llvm::Value *result_buffer = nullptr;
2701 if (stmt->ret_type) {
2702 auto *ret_type = tlctx->get_data_type(stmt->ret_type);
2703 result_buffer = builder->CreateAlloca(ret_type);
2704 auto *result_buffer_u64 = builder->CreatePointerCast(
2705 result_buffer,
2706 llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
2707 call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
2708 }
2709 call(llvm_func, new_ctx);
2710 llvm_val[stmt] = result_buffer;
2711 call("recycle_runtime_context", get_runtime(), new_ctx);
2712}
2713
2714void TaskCodeGenLLVM::visit(GetElementStmt *stmt) {
2715 auto *struct_type = tlctx->get_data_type(stmt->src->ret_type);
2716 std::vector<llvm::Value *> index;
2717 index.reserve(stmt->index.size() + 1);
2718 index.push_back(tlctx->get_constant(0));
2719 for (auto &i : stmt->index) {
2720 index.push_back(tlctx->get_constant(i));
2721 }
2722 auto *gep = builder->CreateGEP(struct_type, llvm_val[stmt->src], index);
2723 auto *val = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), gep);
2724 llvm_val[stmt] = val;
2725}
2726
2727void TaskCodeGenLLVM::create_return(llvm::Value *buffer,
2728 llvm::Type *buffer_type,
2729 const std::vector<Stmt *> &elements,
2730 const Type *current_type,
2731 int &current_element,
2732 std::vector<llvm::Value *> &current_index) {
2733 if (auto primitive_type = current_type->cast<PrimitiveType>()) {
2734 TI_ASSERT((Type *)elements[current_element]->ret_type == current_type);
2735 auto *gep = builder->CreateGEP(buffer_type, buffer, current_index);
2736 builder->CreateStore(llvm_val[elements[current_element]], gep);
2737 current_element++;
2738 } else if (auto struct_type = current_type->cast<StructType>()) {
2739 int i = 0;
2740 for (const auto &element : struct_type->elements()) {
2741 current_index.push_back(tlctx->get_constant(i++));
2742 create_return(buffer, buffer_type, elements, element.type,
2743 current_element, current_index);
2744 current_index.pop_back();
2745 }
2746 } else {
2747 auto tensor_type = current_type->as<TensorType>();
2748 int num_elements = tensor_type->get_num_elements();
2749 Type *element_type = tensor_type->get_element_type();
2750 for (int i = 0; i < num_elements; i++) {
2751 current_index.push_back(tlctx->get_constant(i));
2752 create_return(buffer, buffer_type, elements, element_type,
2753 current_element, current_index);
2754 current_index.pop_back();
2755 }
2756 }
2757}
2758
2759void TaskCodeGenLLVM::create_return(const std::vector<Stmt *> &elements) {
2760 auto buffer = call("RuntimeContext_get_result_buffer", get_context());
2761 auto ret_type = current_callable->ret_type;
2762 auto buffer_type = tlctx->get_data_type(ret_type);
2763 buffer = builder->CreatePointerCast(buffer,
2764 llvm::PointerType::get(buffer_type, 0));
2765 int current_element = 0;
2766 std::vector<llvm::Value *> current_index = {tlctx->get_constant(0)};
2767 create_return(buffer, buffer_type, elements, ret_type, current_element,
2768 current_index);
2769};
2770
2771LLVMCompiledTask LLVMCompiledTask::clone() const {
2772 return {tasks, llvm::CloneModule(*module), used_tree_ids,
2773 struct_for_tls_sizes};
2774}
2775
2776LLVMCompiledKernel LLVMCompiledKernel::clone() const {
2777 return {tasks, llvm::CloneModule(*module)};
2778}
2779
2780} // namespace taichi::lang
2781
2782#endif // #ifdef TI_WITH_LLVM
2783