// An LLVM backend helper

#include "taichi/runtime/llvm/llvm_context.h"

#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#ifdef TI_WITH_AMDGPU
#include "llvm/IR/IntrinsicsAMDGPU.h"
#endif  // TI_WITH_AMDGPU
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Pass.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/Bitcode/BitcodeWriter.h"

#include "taichi/util/lang_util.h"
#include "taichi/jit/jit_session.h"
#include "taichi/common/task.h"
#include "taichi/util/environ_config.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/codegen/codegen_utils.h"

#include "taichi/runtime/llvm/llvm_context_pass.h"

#ifdef _WIN32
// Travis CI doesn't seem to support <filesystem>...
#include <filesystem>
#else
#include <unistd.h>
#endif

#if defined(TI_WITH_CUDA)
#include "taichi/rhi/cuda/cuda_context.h"
#endif

#if defined(TI_WITH_AMDGPU)
#include "taichi/rhi/amdgpu/amdgpu_context.h"
#endif

namespace taichi::lang {

using namespace llvm;

TaichiLLVMContext::TaichiLLVMContext(const CompileConfig &config, Arch arch)
    : config_(config), arch_(arch) {
  TI_TRACE("Creating Taichi llvm context for arch: {}", arch_name(arch));
  main_thread_id_ = std::this_thread::get_id();
  main_thread_data_ = get_this_thread_data();
  llvm::remove_fatal_error_handler();
  llvm::install_fatal_error_handler(
      [](void *user_data, const char *reason, bool gen_crash_diag) {
        TI_ERROR("LLVM Fatal Error: {}", reason);
      },
      nullptr);

  if (arch_is_cpu(arch)) {
#if defined(TI_PLATFORM_OSX) and defined(TI_ARCH_ARM)
    // Note that on Apple Silicon (M1), "native" seems to mean arm instead of
    // arm64 (aka AArch64).
    LLVMInitializeAArch64Target();
    LLVMInitializeAArch64TargetMC();
    LLVMInitializeAArch64TargetInfo();
    LLVMInitializeAArch64AsmPrinter();
#else
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
#endif
  } else if (arch == Arch::dx12) {
    // FIXME: Must initialize the native target before initializing Arch::dx12
    // because dx12 currently relies on the CPU JIT.
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
    // The dx target is used elsewhere, so we need to initialize it too.
#if defined(TI_WITH_DX12)
    LLVMInitializeDirectXTarget();
    LLVMInitializeDirectXTargetMC();
    LLVMInitializeDirectXTargetInfo();
    LLVMInitializeDirectXAsmPrinter();
#endif
  } else if (arch == Arch::amdgpu) {
#if defined(TI_WITH_AMDGPU)
    LLVMInitializeAMDGPUTarget();
    LLVMInitializeAMDGPUTargetMC();
    LLVMInitializeAMDGPUTargetInfo();
    LLVMInitializeAMDGPUAsmPrinter();
    LLVMInitializeAMDGPUAsmParser();
#else
    TI_NOT_IMPLEMENTED
#endif
  } else {
#if defined(TI_WITH_CUDA)
    LLVMInitializeNVPTXTarget();
    LLVMInitializeNVPTXTargetMC();
    LLVMInitializeNVPTXTargetInfo();
    LLVMInitializeNVPTXAsmPrinter();
#else
    TI_NOT_IMPLEMENTED
#endif
  }
  jit = JITSession::create(this, config, arch);

  linking_context_data = std::make_unique<ThreadLocalData>(
      std::make_unique<llvm::orc::ThreadSafeContext>(
          std::make_unique<llvm::LLVMContext>()));
  linking_context_data->runtime_module = clone_module_to_context(
      get_this_thread_runtime_module(), linking_context_data->llvm_context);

  TI_TRACE("Taichi llvm context created.");
}

TaichiLLVMContext::~TaichiLLVMContext() {
}

llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
  auto ctx = get_this_thread_context();
  if (dt->is_primitive(PrimitiveTypeID::i8) ||
      dt->is_primitive(PrimitiveTypeID::u8)) {
    return llvm::Type::getInt8Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::i16) ||
             dt->is_primitive(PrimitiveTypeID::u16)) {
    return llvm::Type::getInt16Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::i32) ||
             dt->is_primitive(PrimitiveTypeID::u32)) {
    return llvm::Type::getInt32Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::i64) ||
             dt->is_primitive(PrimitiveTypeID::u64)) {
    return llvm::Type::getInt64Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::u1)) {
    return llvm::Type::getInt1Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
    return llvm::Type::getFloatTy(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    return llvm::Type::getDoubleTy(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    return llvm::Type::getHalfTy(*ctx);
  } else if (dt->is<TensorType>()) {
    auto tensor_type = dt->cast<TensorType>();
    auto element_type = get_data_type(tensor_type->get_element_type());
    auto num_elements = tensor_type->get_num_elements();
    // Return <num_elements x element_type> if real-matrix (vector) codegen is
    // enabled, otherwise [num_elements x element_type].
    if (codegen_vector_type(config_)) {
      return llvm::VectorType::get(element_type, num_elements,
                                   /*scalable=*/false);
    }
    return llvm::ArrayType::get(element_type, num_elements);
  } else if (dt->is<StructType>()) {
    std::vector<llvm::Type *> types;
    auto struct_type = dt->cast<StructType>();
    for (const auto &element : struct_type->elements()) {
      types.push_back(get_data_type(element.type));
    }
    return llvm::StructType::get(*ctx, types);
  } else {
    TI_INFO(data_type_name(dt));
    TI_NOT_IMPLEMENTED;
  }
}

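// Returns the first command in `commands` that exists on this system; raises
// an error (after warning with the candidate list) if none of them exist.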
std::string find_existing_command(const std::vector<std::string> &commands) {
  for (const auto &cmd : commands) {
    if (command_exist(cmd)) {
      return cmd;
    }
  }
  for (const auto &cmd : commands) {
    TI_WARN("Potential command {}", cmd);
  }
  TI_ERROR("No command found.");
}

std::string get_runtime_fn(Arch arch) {
  return fmt::format("runtime_{}.bc", arch_name(arch));
}

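// Composes the path of the slimmed CUDA libdevice bitcode matching the
// installed CUDA major version, e.g. "<runtime_lib_dir>/slim_libdevice.11.bc".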
std::string libdevice_path() {
  std::string folder = runtime_lib_dir();
  auto cuda_version_string = get_cuda_version_string();
  auto cuda_version_major = int(std::atof(cuda_version_string.c_str()));
  return fmt::format("{}/slim_libdevice.{}.bc", folder, cuda_version_major);
}

std::unique_ptr<llvm::Module> TaichiLLVMContext::clone_module_to_context(
    llvm::Module *module,
    llvm::LLVMContext *target_context) {
  // Dump a module from one context to bitcode, then parse the bitcode in a
  // different context
  std::string bitcode;

  {
    std::lock_guard<std::mutex> _(mut_);
    // Use a scope to make sure sos flushes on destruction
    llvm::raw_string_ostream sos(bitcode);
    llvm::WriteBitcodeToFile(*module, sos);
  }

  auto cloned = parseBitcodeFile(
      llvm::MemoryBufferRef(bitcode, "runtime_bitcode"), *target_context);
  if (!cloned) {
    auto error = cloned.takeError();
    TI_ERROR("Bitcode clone failed.");
  }
  return std::move(cloned.get());
}

std::unique_ptr<llvm::Module>
TaichiLLVMContext::clone_module_to_this_thread_context(llvm::Module *module) {
  TI_TRACE("Cloning struct module");
  TI_ASSERT(module);
  auto this_context = get_this_thread_context();
  return clone_module_to_context(module, this_context);
}

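// Reads the bitcode file at `bitcode_path_` into memory, parses it into a
// module owned by `ctx`, optionally marks every function as force-inlined,
// and verifies the result before handing it back.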
std::unique_ptr<llvm::Module> LlvmModuleBitcodeLoader::load(
    llvm::LLVMContext *ctx) const {
  TI_AUTO_PROF;
  std::ifstream ifs(bitcode_path_, std::ios::binary);
  TI_ERROR_IF(!ifs, "Bitcode file ({}) not found.", bitcode_path_);
  std::string bitcode(std::istreambuf_iterator<char>(ifs),
                      (std::istreambuf_iterator<char>()));
  auto runtime =
      parseBitcodeFile(llvm::MemoryBufferRef(bitcode, buffer_id_), *ctx);
  if (!runtime) {
    auto error = runtime.takeError();
    TI_WARN("Bitcode loading error message:");
    llvm::errs() << error << "\n";
    TI_ERROR("Failed to load bitcode={}", bitcode_path_);
    return nullptr;
  }

  if (inline_funcs_) {
    for (auto &f : *(runtime.get())) {
      TaichiLLVMContext::mark_inline(&f);
    }
  }

  const bool module_broken = llvm::verifyModule(*runtime.get(), &llvm::errs());
  if (module_broken) {
    TI_ERROR("Broken bitcode={}", bitcode_path_);
    return nullptr;
  }
  return std::move(runtime.get());
}

std::unique_ptr<llvm::Module> module_from_bitcode_file(
    const std::string &bitcode_path,
    llvm::LLVMContext *ctx) {
  LlvmModuleBitcodeLoader loader;
  return loader.set_bitcode_path(bitcode_path)
      .set_buffer_id("runtime_bitcode")
      .set_inline_funcs(true)
      .load(ctx);
}

// The goal of this function is to strip out, at an early stage, huge libdevice
// functions that are not going to be used later. Although the LLVM optimizer
// will ultimately remove unused functions during a global DCE pass, we don't
// even want these functions to waste clock cycles during module cloning and
// linking.
static void remove_useless_cuda_libdevice_functions(llvm::Module *module) {
  std::vector<std::string> function_name_list = {
      "rnorm3df",
      "norm4df",
      "rnorm4df",
      "normf",
      "rnormf",
      "j0f",
      "j1f",
      "y0f",
      "y1f",
      "ynf",
      "jnf",
      "cyl_bessel_i0f",
      "cyl_bessel_i1f",
      "j0",
      "j1",
      "y0",
      "y1",
      "yn",
      "jn",
      "cyl_bessel_i0",
      "cyl_bessel_i1",
      "tgammaf",
      "lgammaf",
      "tgamma",
      "lgamma",
      "erff",
      "erfinvf",
      "erfcf",
      "erfcxf",
      "erfcinvf",
      "erf",
      "erfinv",
      "erfcx",
      "erfcinv",
      "erfc",
  };
  for (const auto &fn : function_name_list) {
    module->getFunction("__nv_" + fn)->eraseFromParent();
  }
  module->getFunction("__internal_lgamma_pos")->eraseFromParent();
}

// Note: runtime_module = init_module < struct_module

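// Clones this thread's runtime module so the caller gets a private copy it
// can freely mutate without touching the shared per-thread copy.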
std::unique_ptr<llvm::Module> TaichiLLVMContext::clone_runtime_module() {
  TI_AUTO_PROF
  auto *mod = get_this_thread_runtime_module();

  std::unique_ptr<llvm::Module> cloned;
  {
    TI_PROFILER("clone module");
    cloned = llvm::CloneModule(*mod);
  }

  TI_ASSERT(cloned != nullptr);

  return cloned;
}

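// Loads a bitcode module from `runtime_lib_dir()` and, for CUDA/AMDGPU,
// replaces the bodies of placeholder functions (thread/block indices,
// barriers, warp intrinsics, atomics) with the corresponding target
// intrinsics, then links in the device library.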
std::unique_ptr<llvm::Module> TaichiLLVMContext::module_from_file(
    const std::string &file) {
  auto ctx = get_this_thread_context();
  std::unique_ptr<llvm::Module> module = module_from_bitcode_file(
      fmt::format("{}/{}", runtime_lib_dir(), file), ctx);
  if (arch_ == Arch::cuda || arch_ == Arch::amdgpu) {
    auto patch_intrinsic = [&](std::string name, Intrinsic::ID intrin,
                               bool ret = true,
                               std::vector<llvm::Type *> types = {},
                               std::vector<llvm::Value *> extra_args = {}) {
      auto func = module->getFunction(name);
      if (!func) {
        return;
      }
      func->deleteBody();
      auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
      IRBuilder<> builder(*ctx);
      builder.SetInsertPoint(bb);
      std::vector<llvm::Value *> args;
      for (auto &arg : func->args())
        args.push_back(&arg);
      args.insert(args.end(), extra_args.begin(), extra_args.end());
      if (ret) {
        builder.CreateRet(builder.CreateIntrinsic(intrin, types, args));
      } else {
        builder.CreateIntrinsic(intrin, types, args);
        builder.CreateRetVoid();
      }
      TaichiLLVMContext::mark_inline(func);
    };

    auto patch_atomic_add = [&](std::string name,
                                llvm::AtomicRMWInst::BinOp op) {
      auto func = module->getFunction(name);
      if (!func) {
        return;
      }
      func->deleteBody();
      auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
      IRBuilder<> builder(*ctx);
      builder.SetInsertPoint(bb);
      std::vector<llvm::Value *> args;
      for (auto &arg : func->args())
        args.push_back(&arg);
      builder.CreateRet(builder.CreateAtomicRMW(
          op, args[0], args[1], llvm::MaybeAlign(0),
          llvm::AtomicOrdering::SequentiallyConsistent));
      TaichiLLVMContext::mark_inline(func);
    };

    patch_atomic_add("atomic_add_i32", llvm::AtomicRMWInst::Add);
    patch_atomic_add("atomic_add_i64", llvm::AtomicRMWInst::Add);
    patch_atomic_add("atomic_add_f64", llvm::AtomicRMWInst::FAdd);
    patch_atomic_add("atomic_add_f32", llvm::AtomicRMWInst::FAdd);

    if (arch_ == Arch::cuda) {
      module->setTargetTriple("nvptx64-nvidia-cuda");

#if defined(TI_WITH_CUDA)
      auto func = module->getFunction("cuda_compute_capability");
      if (func) {
        func->deleteBody();
        auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
        IRBuilder<> builder(*ctx);
        builder.SetInsertPoint(bb);
        builder.CreateRet(
            get_constant(CUDAContext::get_instance().get_compute_capability()));
        TaichiLLVMContext::mark_inline(func);
      }
#endif

      patch_intrinsic("thread_idx", Intrinsic::nvvm_read_ptx_sreg_tid_x);
      patch_intrinsic("cuda_clock_i64", Intrinsic::nvvm_read_ptx_sreg_clock64);
      patch_intrinsic("block_idx", Intrinsic::nvvm_read_ptx_sreg_ctaid_x);
      patch_intrinsic("block_dim", Intrinsic::nvvm_read_ptx_sreg_ntid_x);
      patch_intrinsic("grid_dim", Intrinsic::nvvm_read_ptx_sreg_nctaid_x);
      patch_intrinsic("block_barrier", Intrinsic::nvvm_barrier0, false);
      patch_intrinsic("warp_barrier", Intrinsic::nvvm_bar_warp_sync, false);
      patch_intrinsic("block_memfence", Intrinsic::nvvm_membar_cta, false);
      patch_intrinsic("grid_memfence", Intrinsic::nvvm_membar_gl, false);
      patch_intrinsic("system_memfence", Intrinsic::nvvm_membar_sys, false);

      patch_intrinsic("cuda_all", Intrinsic::nvvm_vote_all);
      patch_intrinsic("cuda_all_sync", Intrinsic::nvvm_vote_all_sync);

      patch_intrinsic("cuda_any", Intrinsic::nvvm_vote_any);
      patch_intrinsic("cuda_any_sync", Intrinsic::nvvm_vote_any_sync);

      patch_intrinsic("cuda_uni", Intrinsic::nvvm_vote_uni);
      patch_intrinsic("cuda_uni_sync", Intrinsic::nvvm_vote_uni_sync);

      patch_intrinsic("cuda_ballot", Intrinsic::nvvm_vote_ballot);
      patch_intrinsic("cuda_ballot_sync", Intrinsic::nvvm_vote_ballot_sync);

      patch_intrinsic("cuda_shfl_down_sync_i32",
                      Intrinsic::nvvm_shfl_sync_down_i32);
      patch_intrinsic("cuda_shfl_down_sync_f32",
                      Intrinsic::nvvm_shfl_sync_down_f32);

      patch_intrinsic("cuda_shfl_up_sync_i32",
                      Intrinsic::nvvm_shfl_sync_up_i32);
      patch_intrinsic("cuda_shfl_up_sync_f32",
                      Intrinsic::nvvm_shfl_sync_up_f32);

      patch_intrinsic("cuda_shfl_sync_i32", Intrinsic::nvvm_shfl_sync_idx_i32);

      patch_intrinsic("cuda_shfl_sync_f32", Intrinsic::nvvm_shfl_sync_idx_f32);

      patch_intrinsic("cuda_shfl_xor_sync_i32",
                      Intrinsic::nvvm_shfl_sync_bfly_i32);

      patch_intrinsic("cuda_match_any_sync_i32",
                      Intrinsic::nvvm_match_any_sync_i32);

      // LLVM 10.0.0 seems to have a bug in this intrinsic function
      /*
      nvvm_match_all_sync_i32
      Args:
        1. u32 mask
        2. i32 value
        3. i32 *pred
      */
      /*
      patch_intrinsic("cuda_match_all_sync_i32p",
                      Intrinsic::nvvm_match_all_sync_i32);
      */

      // LLVM 10.0.0 seems to have a bug in this intrinsic function
      /*
      patch_intrinsic("cuda_match_any_sync_i64",
                      Intrinsic::nvvm_match_any_sync_i64);
      */

      patch_intrinsic("ctlz_i32", Intrinsic::ctlz, true,
                      {llvm::Type::getInt32Ty(*ctx)}, {get_constant(false)});
      patch_intrinsic("cttz_i32", Intrinsic::cttz, true,
                      {llvm::Type::getInt32Ty(*ctx)}, {get_constant(false)});


      link_module_with_cuda_libdevice(module);

      // To prevent potential symbol name conflicts, we use "cuda_vprintf"
      // instead of "vprintf" in llvm/runtime.cpp. Now we change it back for
      // linking.
      for (auto &f : *module) {
        if (f.getName() == "cuda_vprintf") {
          f.setName("vprintf");
        }
      }

      // runtime_module->print(llvm::errs(), nullptr);
    }

#ifdef TI_WITH_AMDGPU
    auto patch_amdgpu_kernel_dim = [&](std::string name, llvm::Value *lhs) {
      std::string actual_name;
      if (name == "block_dim")
        actual_name = "__ockl_get_local_size";
      else if (name == "grid_dim")
        actual_name = "__ockl_get_num_groups";
      else
        TI_ERROR("Unknown patch function name");
      auto func = module->getFunction(name);
      auto actual_func = module->getFunction(actual_name);
      if (!func || !actual_func) {
        return;
      }
      func->deleteBody();
      auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
      IRBuilder<> builder(*ctx);
      builder.SetInsertPoint(bb);
      auto dim_ = builder.CreateCall(actual_func->getFunctionType(),
                                     actual_func, {lhs});
      auto ret_ = builder.CreateTrunc(dim_, llvm::Type::getInt32Ty(*ctx));
      builder.CreateRet(ret_);
      TaichiLLVMContext::mark_inline(func);
    };
#endif

    if (arch_ == Arch::amdgpu) {
      module->setTargetTriple("amdgcn-amd-amdhsa");
#ifdef TI_WITH_AMDGPU
      llvm::legacy::FunctionPassManager function_pass_manager(module.get());
      function_pass_manager.add(new AMDGPUConvertAllocaInstAddressSpacePass());
      function_pass_manager.doInitialization();
      for (auto func = module->begin(); func != module->end(); ++func) {
        function_pass_manager.run(*func);
      }
      function_pass_manager.doFinalization();
      patch_intrinsic("thread_idx", llvm::Intrinsic::amdgcn_workitem_id_x);
      patch_intrinsic("block_idx", llvm::Intrinsic::amdgcn_workgroup_id_x);

      link_module_with_amdgpu_libdevice(module);
      patch_amdgpu_kernel_dim(
          "block_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
      patch_amdgpu_kernel_dim(
          "grid_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
#endif
    }
  }

  return module;
}

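// Links the slimmed CUDA libdevice module into `module` so that math
// functions such as __nv_expf resolve to device implementations, and reports
// any libdevice definition that is missing after the link.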
void TaichiLLVMContext::link_module_with_cuda_libdevice(
    std::unique_ptr<llvm::Module> &module) {
  TI_AUTO_PROF
  TI_ASSERT(arch_ == Arch::cuda);

  auto libdevice_module =
      module_from_bitcode_file(libdevice_path(), get_this_thread_context());

  std::vector<std::string> libdevice_function_names;
  for (auto &f : *libdevice_module) {
    if (!f.isDeclaration()) {
      libdevice_function_names.push_back(f.getName().str());
    }
  }

  libdevice_module->setTargetTriple("nvptx64-nvidia-cuda");
  module->setDataLayout(libdevice_module->getDataLayout());

  bool failed = llvm::Linker::linkModules(*module, std::move(libdevice_module));
  if (failed) {
    TI_ERROR("CUDA libdevice linking failure.");
  }

  // Make sure all libdevice functions are linked
  for (const auto &func_name : libdevice_function_names) {
    auto func = module->getFunction(func_name);
    if (!func) {
      TI_INFO("Function {} not found", func_name);
    }
  }
}

void TaichiLLVMContext::link_module_with_amdgpu_libdevice(
    std::unique_ptr<llvm::Module> &module) {
  TI_ASSERT(arch_ == Arch::amdgpu);
#if defined(TI_WITH_AMDGPU)
  auto isa_version = AMDGPUContext::get_instance().get_mcpu().substr(3, 4);
  std::string libdevice_files[] = {"ocml.bc",
                                   "oclc_wavefrontsize64_off.bc",
                                   "ockl.bc",
                                   "oclc_abi_version_400.bc",
                                   "oclc_correctly_rounded_sqrt_off.bc",
                                   "oclc_daz_opt_off.bc",
                                   "oclc_finite_only_off.bc",
                                   "oclc_isa_version_" + isa_version + ".bc",
                                   "oclc_unsafe_math_off.bc",
                                   "opencl.bc"};

  for (auto &libdevice : libdevice_files) {
    std::string lib_dir = runtime_lib_dir() + "/";
    auto libdevice_module = module_from_bitcode_file(lib_dir + libdevice,
                                                     get_this_thread_context());

    if (libdevice == "ocml.bc")
      module->setDataLayout(libdevice_module->getDataLayout());

    std::vector<std::string> libdevice_func_names;
    for (auto &f : *libdevice_module) {
      if (!f.isDeclaration()) {
        libdevice_func_names.push_back(f.getName().str());
      }
    }

    for (auto &f : libdevice_module->functions()) {
      auto func_name = libdevice.substr(0, libdevice.length() - 3);
      if (starts_with(f.getName().lower(), "__" + func_name))
        f.setLinkage(llvm::Function::CommonLinkage);
    }

    bool failed =
        llvm::Linker::linkModules(*module, std::move(libdevice_module));
    if (failed) {
      TI_ERROR("AMDGPU libdevice linking failure.");
    }
  }
#endif
}

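// Registers the struct (SNode tree) module under `tree_id`, cloning it into
// the linking context and into every other thread's context so later
// compilation and linking can use it.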
void TaichiLLVMContext::add_struct_module(std::unique_ptr<Module> module,
                                          int tree_id) {
  TI_AUTO_PROF;
  TI_ASSERT(std::this_thread::get_id() == main_thread_id_);
  auto this_thread_data = get_this_thread_data();
  TI_ASSERT(module);
  if (llvm::verifyModule(*module, &llvm::errs())) {
    module->print(llvm::errs(), nullptr);
    TI_ERROR("module broken");
  }

  linking_context_data->struct_modules[tree_id] =
      clone_module_to_context(module.get(), linking_context_data->llvm_context);

  for (auto &[id, data] : per_thread_data_) {
    if (id == std::this_thread::get_id()) {
      continue;
    }
    data->struct_modules[tree_id] =
        clone_module_to_context(module.get(), data->llvm_context);
  }

  this_thread_data->struct_modules[tree_id] = std::move(module);
}

template <typename T>
llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) {
  auto ctx = get_this_thread_context();
  if (dt->is_primitive(PrimitiveTypeID::f32)) {
    return llvm::ConstantFP::get(*ctx, llvm::APFloat((float32)t));
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    return llvm::ConstantFP::get(llvm::Type::getHalfTy(*ctx), (float32)t);
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t));
  } else if (is_integral(dt)) {
    if (is_signed(dt)) {
      return llvm::ConstantInt::get(
          *ctx, llvm::APInt(data_type_bits(dt), (uint64_t)t, true));
    } else {
      return llvm::ConstantInt::get(
          *ctx, llvm::APInt(data_type_bits(dt), (uint64_t)t, false));
    }
  } else {
    TI_NOT_IMPLEMENTED
  }
}

template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, int32 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, int64 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, uint32 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, uint64 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, float32 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, float64 t);

template <typename T>
llvm::Value *TaichiLLVMContext::get_constant(T t) {
  auto ctx = get_this_thread_context();
  TI_ASSERT(ctx != nullptr);
  using TargetType = T;
  if constexpr (std::is_same_v<TargetType, float32> ||
                std::is_same_v<TargetType, float64>) {
    return llvm::ConstantFP::get(*ctx, llvm::APFloat(t));
  } else if constexpr (std::is_same_v<TargetType, bool>) {
    return llvm::ConstantInt::get(*ctx, llvm::APInt(1, (uint64)t, true));
  } else if constexpr (std::is_same_v<TargetType, int32> ||
                       std::is_same_v<TargetType, uint32>) {
    return llvm::ConstantInt::get(*ctx, llvm::APInt(32, (uint64)t, true));
  } else if constexpr (std::is_same_v<TargetType, int64> ||
                       std::is_same_v<TargetType, std::size_t> ||
                       std::is_same_v<TargetType, uint64>) {
    static_assert(sizeof(std::size_t) == sizeof(uint64));
    return llvm::ConstantInt::get(*ctx, llvm::APInt(64, (uint64)t, true));
  } else {
    TI_NOT_IMPLEMENTED
  }
}

std::string TaichiLLVMContext::type_name(llvm::Type *type) {
  std::string type_name;
  llvm::raw_string_ostream rso(type_name);
  type->print(rso);
  return rso.str();
}

std::size_t TaichiLLVMContext::get_type_size(llvm::Type *type) {
  return get_data_layout().getTypeAllocSize(type);
}

std::size_t TaichiLLVMContext::get_struct_element_offset(llvm::StructType *type,
                                                         int idx) {
  return get_data_layout().getStructLayout(type)->getElementOffset(idx);
}

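// Marks `f` as always-inline unless its body calls the special marker
// function "mark_force_no_inline", which opts the function out.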
void TaichiLLVMContext::mark_inline(llvm::Function *f) {
  for (auto &B : *f)
    for (auto &I : B) {
      if (auto *call = llvm::dyn_cast<llvm::CallInst>(&I)) {
        if (auto func = call->getCalledFunction();
            func && func->getName() == "mark_force_no_inline") {
          // Found "mark_force_no_inline". Do not inline.
          return;
        }
      }
    }
  f->removeFnAttr(llvm::Attribute::OptimizeNone);
  f->removeFnAttr(llvm::Attribute::NoInline);
  f->addFnAttr(llvm::Attribute::AlwaysInline);
}

int TaichiLLVMContext::num_instructions(llvm::Function *func) {
  int counter = 0;
  for (BasicBlock &bb : *func)
    counter += std::distance(bb.begin(), bb.end());
  return counter;
}

void TaichiLLVMContext::print_huge_functions(llvm::Module *module) {
  int total_inst = 0;
  int total_big_inst = 0;

  for (auto &f : *module) {
    int c = num_instructions(&f);
    if (c > 100) {
      total_big_inst += c;
      TI_INFO("{}: {} inst.", std::string(f.getName()), c);
    }
    total_inst += c;
  }
  TI_P(total_inst);
  TI_P(total_big_inst);
}

llvm::DataLayout TaichiLLVMContext::get_data_layout() {
  return jit->get_data_layout();
}

void TaichiLLVMContext::insert_nvvm_annotation(llvm::Function *func,
                                               std::string key,
                                               int val) {
  /*******************************************************************
  Example annotation from llvm PTX doc:

  define void @kernel(float addrspace(1)* %A,
                      float addrspace(1)* %B,
                      float addrspace(1)* %C);

  !nvvm.annotations = !{!0}
  !0 = !{void (float addrspace(1)*,
               float addrspace(1)*,
               float addrspace(1)*)* @kernel, !"kernel", i32 1}
  *******************************************************************/
  auto ctx = get_this_thread_context();
  llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
                               MDString::get(*ctx, key),
                               llvm::ValueAsMetadata::get(get_constant(val))};

  MDNode *md_node = MDNode::get(*ctx, md_args);

  func->getParent()
      ->getOrInsertNamedMetadata("nvvm.annotations")
      ->addOperand(md_node);
}

void TaichiLLVMContext::mark_function_as_cuda_kernel(llvm::Function *func,
                                                     int block_dim) {
  // Mark kernel function as a CUDA __global__ function
  // Add the nvvm annotation that it is considered a kernel function.
  insert_nvvm_annotation(func, "kernel", 1);
  if (block_dim != 0) {
    // CUDA launch bounds
    insert_nvvm_annotation(func, "maxntidx", block_dim);
    insert_nvvm_annotation(func, "minctasm", 2);
  }
}

void TaichiLLVMContext::mark_function_as_amdgpu_kernel(llvm::Function *func) {
  func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
}

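// Internalizes every global that `export_indicator` does not mark as
// exported, then runs global DCE to strip the now-unreachable definitions.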
void TaichiLLVMContext::eliminate_unused_functions(
    llvm::Module *module,
    std::function<bool(const std::string &)> export_indicator) {
  TI_AUTO_PROF
  using namespace llvm;
  TI_ASSERT(module);
  if (false) {
    // temporary fix for now to make LLVM 8 work with CUDA
    // TODO: recover this when it's time
    if (llvm::verifyModule(*module, &llvm::errs())) {
      TI_ERROR("Module broken\n");
    }
  }
  llvm::ModulePassManager manager;
  llvm::ModuleAnalysisManager ana;
  llvm::PassBuilder pb;
  pb.registerModuleAnalyses(ana);
  manager.addPass(llvm::InternalizePass([&](const GlobalValue &val) -> bool {
    return export_indicator(val.getName().str());
  }));
  manager.addPass(GlobalDCEPass());
  manager.run(*module, ana);
}

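// Returns the ThreadLocalData of the calling thread, lazily creating it
// (together with a fresh LLVMContext) on first use.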
TaichiLLVMContext::ThreadLocalData *TaichiLLVMContext::get_this_thread_data() {
  std::lock_guard<std::mutex> _(thread_map_mut_);
  auto tid = std::this_thread::get_id();
  if (per_thread_data_.find(tid) == per_thread_data_.end()) {
    std::stringstream ss;
    ss << tid;
    TI_TRACE("Creating thread local data for thread {}", ss.str());
    per_thread_data_[tid] = std::make_unique<ThreadLocalData>(
        std::make_unique<llvm::orc::ThreadSafeContext>(
            std::make_unique<llvm::LLVMContext>()));
  }
  return per_thread_data_[tid].get();
}

llvm::LLVMContext *TaichiLLVMContext::get_this_thread_context() {
  ThreadLocalData *data = get_this_thread_data();
  TI_ASSERT(data->llvm_context);
  return data->llvm_context;
}

llvm::orc::ThreadSafeContext *
TaichiLLVMContext::get_this_thread_thread_safe_context() {
  get_this_thread_context();  // make sure the context is created
  ThreadLocalData *data = get_this_thread_data();
  return data->thread_safe_llvm_context.get();
}

template llvm::Value *TaichiLLVMContext::get_constant(float32 t);
template llvm::Value *TaichiLLVMContext::get_constant(float64 t);

template llvm::Value *TaichiLLVMContext::get_constant(bool t);

template llvm::Value *TaichiLLVMContext::get_constant(int32 t);
template llvm::Value *TaichiLLVMContext::get_constant(uint32 t);

template llvm::Value *TaichiLLVMContext::get_constant(int64 t);
template llvm::Value *TaichiLLVMContext::get_constant(uint64 t);

#ifdef TI_PLATFORM_OSX
template llvm::Value *TaichiLLVMContext::get_constant(unsigned long t);
#endif

auto make_slim_libdevice = [](const std::vector<std::string> &args) {
  TI_ASSERT_INFO(args.size() == 1,
                 "Usage: ti task make_slim_libdevice [libdevice.X.bc file]");

  auto ctx = std::make_unique<llvm::LLVMContext>();
  auto libdevice_module = module_from_bitcode_file(args[0], ctx.get());

  remove_useless_cuda_libdevice_functions(libdevice_module.get());

  std::error_code ec;
  auto output_fn = "slim_" + args[0];
  llvm::raw_fd_ostream os(output_fn, ec, llvm::sys::fs::OF_None);
  llvm::WriteBitcodeToFile(*libdevice_module, os);
  os.flush();
  TI_INFO("Slimmed libdevice written to {}", output_fn);
};

void TaichiLLVMContext::init_runtime_module(llvm::Module *runtime_module) {
  if (config_.arch == Arch::cuda) {
    for (auto &f : *runtime_module) {
      bool is_kernel = false;
      const std::string func_name = f.getName().str();
      if (starts_with(func_name, "runtime_")) {
        mark_function_as_cuda_kernel(&f);
        is_kernel = true;
      }

      // Give non-kernel functions that have definitions private linkage, to
      // avoid duplicated symbols and to remove external symbol dependencies
      // such as std::sin
      if (!is_kernel && !f.isDeclaration())
        f.setLinkage(llvm::Function::PrivateLinkage);
    }
  }

  if (config_.arch == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
    llvm::legacy::PassManager module_pass_manager;
    module_pass_manager.add(new AMDGPUConvertFuncParamAddressSpacePass());
    module_pass_manager.run(*runtime_module);
#endif
  }

  eliminate_unused_functions(runtime_module, [](std::string func_name) {
    return starts_with(func_name, "runtime_") ||
           starts_with(func_name, "LLVMRuntime_");
  });
}

void TaichiLLVMContext::delete_snode_tree(int id) {
  TI_ASSERT(linking_context_data->struct_modules.erase(id));
  for (auto &[thread_id, data] : per_thread_data_) {
    TI_ASSERT(data->struct_modules.erase(id));
  }
}

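// Lazily populates this thread's struct modules by cloning them from the
// main thread's copies.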
void TaichiLLVMContext::fetch_this_thread_struct_module() {
  ThreadLocalData *data = get_this_thread_data();
  if (data->struct_modules.empty()) {
    for (auto &[id, mod] : main_thread_data_->struct_modules) {
      data->struct_modules[id] = clone_module_to_this_thread_context(mod.get());
    }
  }
}

llvm::Function *TaichiLLVMContext::get_runtime_function(
    const std::string &name) {
  return get_this_thread_runtime_module()->getFunction(name);
}

llvm::Module *TaichiLLVMContext::get_this_thread_runtime_module() {
  TI_AUTO_PROF;
  auto data = get_this_thread_data();
  if (!data->runtime_module) {
    data->runtime_module = module_from_file(get_runtime_fn(arch_));
  }
  return data->runtime_module.get();
}

llvm::Function *TaichiLLVMContext::get_struct_function(const std::string &name,
                                                       int tree_id) {
  auto *data = get_this_thread_data();
  return data->struct_modules[tree_id]->getFunction(name);
}

llvm::Type *TaichiLLVMContext::get_runtime_type(const std::string &name) {
  auto ty = llvm::StructType::getTypeByName(
      get_this_thread_runtime_module()->getContext(), ("struct." + name));
  if (!ty) {
    TI_ERROR("LLVMRuntime type {} not found.", name);
  }
  return ty;
}

std::unique_ptr<llvm::Module> TaichiLLVMContext::new_module(
    std::string name,
    llvm::LLVMContext *context) {
  auto new_mod = std::make_unique<llvm::Module>(
      name, context ? *context : *get_this_thread_context());
  new_mod->setDataLayout(get_this_thread_runtime_module()->getDataLayout());
  return new_mod;
}

TaichiLLVMContext::ThreadLocalData::ThreadLocalData(
    std::unique_ptr<llvm::orc::ThreadSafeContext> ctx)
    : thread_safe_llvm_context(std::move(ctx)),
      llvm_context(thread_safe_llvm_context->getContext()) {
}

TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() {
  runtime_module.reset();
  struct_modules.clear();
  thread_safe_llvm_context.reset();
}

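// Links the per-task modules, the struct modules of every referenced SNode
// tree, and the runtime module into a single kernel module, then strips all
// functions that are not offloaded task entry points.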
LLVMCompiledKernel TaichiLLVMContext::link_compiled_tasks(
    std::vector<std::unique_ptr<LLVMCompiledTask>> data_list) {
  LLVMCompiledKernel linked;
  std::unordered_set<int> used_tree_ids;
  std::unordered_set<int> tls_sizes;
  std::unordered_set<std::string> offloaded_names;
  auto mod = new_module("kernel", linking_context_data->llvm_context);
  llvm::Linker linker(*mod);
  for (auto &datum : data_list) {
    for (auto tree_id : datum->used_tree_ids) {
      used_tree_ids.insert(tree_id);
    }
    for (auto tls_size : datum->struct_for_tls_sizes) {
      tls_sizes.insert(tls_size);
    }
    for (auto &task : datum->tasks) {
      offloaded_names.insert(task.name);
      linked.tasks.push_back(std::move(task));
    }
    linker.linkInModule(clone_module_to_context(
        datum->module.get(), linking_context_data->llvm_context));
  }
  for (auto tree_id : used_tree_ids) {
    linker.linkInModule(
        llvm::CloneModule(*linking_context_data->struct_modules[tree_id]),
        llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc);
  }
  auto runtime_module =
      llvm::CloneModule(*linking_context_data->runtime_module);
  for (auto tls_size : tls_sizes) {
    add_struct_for_func(runtime_module.get(), tls_size);
  }
  linker.linkInModule(
      std::move(runtime_module),
      llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc);
  eliminate_unused_functions(mod.get(), [&](std::string func_name) -> bool {
    return offloaded_names.count(func_name);
  });
  linked.module = std::move(mod);
  return linked;
}

void TaichiLLVMContext::add_struct_for_func(llvm::Module *module,
                                            int tls_size) {
  // Note that on CUDA local array allocation must have a compile-time
  // constant size. Therefore, instead of passing in the tls_buffer_size
  // argument, we directly clone the "parallel_struct_for" function and
  // replace the "alignas(8) char tls_buffer[1]" statement with "alignas(8)
  // char tls_buffer[tls_buffer_size]" at compile time.
  auto func_name = get_struct_for_func_name(tls_size);
  if (module->getFunction(func_name)) {
    return;
  }
  llvm::legacy::PassManager module_pass_manager;
  if (config_.arch == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
    module_pass_manager.add(
        new AMDGPUAddStructForFuncPass(func_name, tls_size));
    module_pass_manager.run(*module);
#else
    TI_NOT_IMPLEMENTED
#endif
  } else {
    module_pass_manager.add(new AddStructForFuncPass(func_name, tls_size));
    module_pass_manager.run(*module);
  }
}

std::string TaichiLLVMContext::get_struct_for_func_name(int tls_size) {
  return "parallel_struct_for_" + std::to_string(tls_size);
}

std::string TaichiLLVMContext::get_data_layout_string() {
  return get_data_layout().getStringRepresentation();
}

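// Recomputes the element offsets of `old_ty` (recursively for nested structs)
// under the given LLVM data layout string and returns the resulting struct
// type.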
const StructType *TaichiLLVMContext::get_struct_type_with_data_layout(
    const StructType *old_ty,
    const std::string &layout) {
  if (old_ty->get_layout() == layout) {
    return old_ty;
  }
  std::vector<StructMember> elements = old_ty->elements();
  for (auto &element : elements) {
    if (auto struct_type = element.type->cast<StructType>()) {
      element.type = get_struct_type_with_data_layout(struct_type, layout);
    }
  }
  auto *llvm_struct_type = llvm::cast<llvm::StructType>(get_data_type(old_ty));
  auto data_layout = llvm::DataLayout::parse(layout);
  TI_ASSERT(data_layout);
  auto struct_layout = data_layout->getStructLayout(llvm_struct_type);
  for (int i = 0; i < elements.size(); i++) {
    elements[i].offset = struct_layout->getElementOffset(i);
  }
  return TypeFactory::get_instance()
      .get_struct_type(elements, layout)
      ->cast<StructType>();
}

TI_REGISTER_TASK(make_slim_libdevice);

}  // namespace taichi::lang