// An LLVM backend helper

#include "taichi/runtime/llvm/llvm_context.h"

#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#ifdef TI_WITH_AMDGPU
#include "llvm/IR/IntrinsicsAMDGPU.h"
#endif  // TI_WITH_AMDGPU
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Pass.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/Bitcode/BitcodeWriter.h"

#include "taichi/util/lang_util.h"
#include "taichi/jit/jit_session.h"
#include "taichi/common/task.h"
#include "taichi/util/environ_config.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/codegen/codegen_utils.h"

#include "taichi/runtime/llvm/llvm_context_pass.h"

#ifdef _WIN32
// Travis CI doesn't seem to support <filesystem>...
#include <filesystem>
#else
#include <unistd.h>
#endif

#if defined(TI_WITH_CUDA)
#include "taichi/rhi/cuda/cuda_context.h"
#endif

#if defined(TI_WITH_AMDGPU)
#include "taichi/rhi/amdgpu/amdgpu_context.h"
#endif

namespace taichi::lang {

using namespace llvm;

TaichiLLVMContext::TaichiLLVMContext(const CompileConfig &config, Arch arch)
    : config_(config), arch_(arch) {
  TI_TRACE("Creating Taichi llvm context for arch: {}", arch_name(arch));
  main_thread_id_ = std::this_thread::get_id();
  main_thread_data_ = get_this_thread_data();
  llvm::remove_fatal_error_handler();
  llvm::install_fatal_error_handler(
      [](void *user_data, const char *reason, bool gen_crash_diag) {
        TI_ERROR("LLVM Fatal Error: {}", reason);
      },
      nullptr);

  if (arch_is_cpu(arch)) {
#if defined(TI_PLATFORM_OSX) and defined(TI_ARCH_ARM)
    // Note that on Apple Silicon (M1), "native" seems to mean arm instead of
    // arm64 (aka AArch64).
    LLVMInitializeAArch64Target();
    LLVMInitializeAArch64TargetMC();
    LLVMInitializeAArch64TargetInfo();
    LLVMInitializeAArch64AsmPrinter();
#else
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
#endif
  } else if (arch == Arch::dx12) {
    // FIXME: Must initialize the native (CPU) target first, because
    // Arch::dx12 currently relies on the CPU JIT.
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
    // The dx target is used elsewhere, so we need to initialize it too.
#if defined(TI_WITH_DX12)
    LLVMInitializeDirectXTarget();
    LLVMInitializeDirectXTargetMC();
    LLVMInitializeDirectXTargetInfo();
    LLVMInitializeDirectXAsmPrinter();
#endif
  } else if (arch == Arch::amdgpu) {
#if defined(TI_WITH_AMDGPU)
    LLVMInitializeAMDGPUTarget();
    LLVMInitializeAMDGPUTargetMC();
    LLVMInitializeAMDGPUTargetInfo();
    LLVMInitializeAMDGPUAsmPrinter();
    LLVMInitializeAMDGPUAsmParser();
#else
    TI_NOT_IMPLEMENTED
#endif
  } else {
#if defined(TI_WITH_CUDA)
    LLVMInitializeNVPTXTarget();
    LLVMInitializeNVPTXTargetMC();
    LLVMInitializeNVPTXTargetInfo();
    LLVMInitializeNVPTXAsmPrinter();
#else
    TI_NOT_IMPLEMENTED
#endif
  }
  jit = JITSession::create(this, config, arch);

  linking_context_data = std::make_unique<ThreadLocalData>(
      std::make_unique<llvm::orc::ThreadSafeContext>(
          std::make_unique<llvm::LLVMContext>()));
  linking_context_data->runtime_module = clone_module_to_context(
      get_this_thread_runtime_module(), linking_context_data->llvm_context);

  TI_TRACE("Taichi llvm context created.");
}

TaichiLLVMContext::~TaichiLLVMContext() {
}

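// Maps a Taichi DataType to the corresponding llvm::Type in this thread's
// LLVM context. As an illustration, a TensorType of 4 x f32 lowers to the
// LLVM array type [4 x float], or to the vector type <4 x float> when
// real-matrix (vector) codegen is enabled.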
llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
  auto ctx = get_this_thread_context();
  if (dt->is_primitive(PrimitiveTypeID::i8) ||
      dt->is_primitive(PrimitiveTypeID::u8)) {
    return llvm::Type::getInt8Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::i16) ||
             dt->is_primitive(PrimitiveTypeID::u16)) {
    return llvm::Type::getInt16Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::i32) ||
             dt->is_primitive(PrimitiveTypeID::u32)) {
    return llvm::Type::getInt32Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::i64) ||
             dt->is_primitive(PrimitiveTypeID::u64)) {
    return llvm::Type::getInt64Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::u1)) {
    return llvm::Type::getInt1Ty(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
    return llvm::Type::getFloatTy(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    return llvm::Type::getDoubleTy(*ctx);
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    return llvm::Type::getHalfTy(*ctx);
  } else if (dt->is<TensorType>()) {
    auto tensor_type = dt->cast<TensorType>();
    auto element_type = get_data_type(tensor_type->get_element_type());
    auto num_elements = tensor_type->get_num_elements();
    // Return the vector type <num_elements x element_type> when real
    // matrices are used, otherwise the array type
    // [num_elements x element_type].
    if (codegen_vector_type(config_)) {
      return llvm::VectorType::get(element_type, num_elements,
                                   /*scalable=*/false);
    }
    return llvm::ArrayType::get(element_type, num_elements);
  } else if (dt->is<StructType>()) {
    std::vector<llvm::Type *> types;
    auto struct_type = dt->cast<StructType>();
    for (const auto &element : struct_type->elements()) {
      types.push_back(get_data_type(element.type));
    }
    return llvm::StructType::get(*ctx, types);
  } else {
    TI_INFO(data_type_name(dt));
    TI_NOT_IMPLEMENTED;
  }
}

std::string find_existing_command(const std::vector<std::string> &commands) {
  for (const auto &cmd : commands) {
    if (command_exist(cmd)) {
      return cmd;
    }
  }
  for (const auto &cmd : commands) {
    TI_WARN("Potential command {}", cmd);
  }
  TI_ERROR("No command found.");
}

std::string get_runtime_fn(Arch arch) {
  return fmt::format("runtime_{}.bc", arch_name(arch));
}

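// Returns the path of the slimmed CUDA libdevice bitcode matching the
// installed CUDA toolkit's major version; for example, CUDA 11.x maps to
// "<runtime_lib_dir>/slim_libdevice.11.bc" (illustrative).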
std::string libdevice_path() {
  std::string folder = runtime_lib_dir();
  auto cuda_version_string = get_cuda_version_string();
  auto cuda_version_major = int(std::atof(cuda_version_string.c_str()));
  return fmt::format("{}/slim_libdevice.{}.bc", folder, cuda_version_major);
}

std::unique_ptr<llvm::Module> TaichiLLVMContext::clone_module_to_context(
    llvm::Module *module,
    llvm::LLVMContext *target_context) {
  // llvm::CloneModule cannot clone across contexts, so dump the module to
  // bitcode and then parse the bitcode within the target context.
  std::string bitcode;

  {
    std::lock_guard<std::mutex> _(mut_);
    // Use a scope to make sure sos flushes on destruction
    llvm::raw_string_ostream sos(bitcode);
    llvm::WriteBitcodeToFile(*module, sos);
  }

  auto cloned = parseBitcodeFile(
      llvm::MemoryBufferRef(bitcode, "runtime_bitcode"), *target_context);
  if (!cloned) {
    auto error = cloned.takeError();
    TI_ERROR("Cloning bitcode failed.");
  }
  return std::move(cloned.get());
}

std::unique_ptr<llvm::Module>
TaichiLLVMContext::clone_module_to_this_thread_context(llvm::Module *module) {
  TI_TRACE("Cloning struct module");
  TI_ASSERT(module);
  auto this_context = get_this_thread_context();
  return clone_module_to_context(module, this_context);
}

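// Loads the bitcode file at bitcode_path_ into `ctx`. Call sites chain the
// builder-style setters, as module_from_bitcode_file() below does:
//   LlvmModuleBitcodeLoader loader;
//   loader.set_bitcode_path(path).set_buffer_id("runtime_bitcode")
//       .set_inline_funcs(true).load(ctx);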
std::unique_ptr<llvm::Module> LlvmModuleBitcodeLoader::load(
    llvm::LLVMContext *ctx) const {
  TI_AUTO_PROF;
  std::ifstream ifs(bitcode_path_, std::ios::binary);
  TI_ERROR_IF(!ifs, "Bitcode file ({}) not found.", bitcode_path_);
  std::string bitcode(std::istreambuf_iterator<char>(ifs),
                      (std::istreambuf_iterator<char>()));
  auto runtime =
      parseBitcodeFile(llvm::MemoryBufferRef(bitcode, buffer_id_), *ctx);
  if (!runtime) {
    auto error = runtime.takeError();
    TI_WARN("Bitcode loading error message:");
    llvm::errs() << error << "\n";
    TI_ERROR("Failed to load bitcode={}", bitcode_path_);
    return nullptr;
  }

  if (inline_funcs_) {
    for (auto &f : *(runtime.get())) {
      TaichiLLVMContext::mark_inline(&f);
    }
  }

  const bool module_broken = llvm::verifyModule(*runtime.get(), &llvm::errs());
  if (module_broken) {
    TI_ERROR("Broken bitcode={}", bitcode_path_);
    return nullptr;
  }
  return std::move(runtime.get());
}

std::unique_ptr<llvm::Module> module_from_bitcode_file(
    const std::string &bitcode_path,
    llvm::LLVMContext *ctx) {
  LlvmModuleBitcodeLoader loader;
  return loader.set_bitcode_path(bitcode_path)
      .set_buffer_id("runtime_bitcode")
      .set_inline_funcs(true)
      .load(ctx);
}

// The goal of this function is to strip out, at an early stage, huge
// libdevice functions that are not going to be used later. Although the LLVM
// optimizer will ultimately remove unused functions during a global DCE pass,
// we don't even want these functions to waste clock cycles during module
// cloning and linking.
static void remove_useless_cuda_libdevice_functions(llvm::Module *module) {
  std::vector<std::string> function_name_list = {
      "rnorm3df",
      "norm4df",
      "rnorm4df",
      "normf",
      "rnormf",
      "j0f",
      "j1f",
      "y0f",
      "y1f",
      "ynf",
      "jnf",
      "cyl_bessel_i0f",
      "cyl_bessel_i1f",
      "j0",
      "j1",
      "y0",
      "y1",
      "yn",
      "jn",
      "cyl_bessel_i0",
      "cyl_bessel_i1",
      "tgammaf",
      "lgammaf",
      "tgamma",
      "lgamma",
      "erff",
      "erfinvf",
      "erfcf",
      "erfcxf",
      "erfcinvf",
      "erf",
      "erfinv",
      "erfcx",
      "erfcinv",
      "erfc",
  };
  for (const auto &fn : function_name_list) {
    // getFunction() returns nullptr if this libdevice version lacks the
    // symbol; guard against that before erasing.
    if (auto *func = module->getFunction("__nv_" + fn)) {
      func->eraseFromParent();
    }
  }
  if (auto *func = module->getFunction("__internal_lgamma_pos")) {
    func->eraseFromParent();
  }
}

// Note: runtime_module = init_module < struct_module

std::unique_ptr<llvm::Module> TaichiLLVMContext::clone_runtime_module() {
  TI_AUTO_PROF
  auto *mod = get_this_thread_runtime_module();

  std::unique_ptr<llvm::Module> cloned;
  {
    TI_PROFILER("clone module");
    cloned = llvm::CloneModule(*mod);
  }

  TI_ASSERT(cloned != nullptr);

  return cloned;
}

std::unique_ptr<llvm::Module> TaichiLLVMContext::module_from_file(
    const std::string &file) {
  auto ctx = get_this_thread_context();
  std::unique_ptr<llvm::Module> module = module_from_bitcode_file(
      fmt::format("{}/{}", runtime_lib_dir(), file), ctx);
  if (arch_ == Arch::cuda || arch_ == Arch::amdgpu) {
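    // patch_intrinsic replaces the body of a runtime placeholder function
    // with a single call to the given LLVM intrinsic, returning its result
    // when `ret` is true. As a rough sketch, after
    // patch_intrinsic("thread_idx", Intrinsic::nvvm_read_ptx_sreg_tid_x)
    // the patched body is:
    //   entry:
    //     %0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
    //     ret i32 %0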
    auto patch_intrinsic = [&](std::string name, Intrinsic::ID intrin,
                               bool ret = true,
                               std::vector<llvm::Type *> types = {},
                               std::vector<llvm::Value *> extra_args = {}) {
      auto func = module->getFunction(name);
      if (!func) {
        return;
      }
      func->deleteBody();
      auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
      IRBuilder<> builder(*ctx);
      builder.SetInsertPoint(bb);
      std::vector<llvm::Value *> args;
      for (auto &arg : func->args())
        args.push_back(&arg);
      args.insert(args.end(), extra_args.begin(), extra_args.end());
      if (ret) {
        builder.CreateRet(builder.CreateIntrinsic(intrin, types, args));
      } else {
        builder.CreateIntrinsic(intrin, types, args);
        builder.CreateRetVoid();
      }
      TaichiLLVMContext::mark_inline(func);
    };

    auto patch_atomic_add = [&](std::string name,
                                llvm::AtomicRMWInst::BinOp op) {
      auto func = module->getFunction(name);
      if (!func) {
        return;
      }
      func->deleteBody();
      auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
      IRBuilder<> builder(*ctx);
      builder.SetInsertPoint(bb);
      std::vector<llvm::Value *> args;
      for (auto &arg : func->args())
        args.push_back(&arg);
      builder.CreateRet(builder.CreateAtomicRMW(
          op, args[0], args[1], llvm::MaybeAlign(0),
          llvm::AtomicOrdering::SequentiallyConsistent));
      TaichiLLVMContext::mark_inline(func);
    };

    patch_atomic_add("atomic_add_i32", llvm::AtomicRMWInst::Add);
    patch_atomic_add("atomic_add_i64", llvm::AtomicRMWInst::Add);
    patch_atomic_add("atomic_add_f64", llvm::AtomicRMWInst::FAdd);
    patch_atomic_add("atomic_add_f32", llvm::AtomicRMWInst::FAdd);

    if (arch_ == Arch::cuda) {
      module->setTargetTriple("nvptx64-nvidia-cuda");

#if defined(TI_WITH_CUDA)
      auto func = module->getFunction("cuda_compute_capability");
      if (func) {
        func->deleteBody();
        auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
        IRBuilder<> builder(*ctx);
        builder.SetInsertPoint(bb);
        builder.CreateRet(
            get_constant(CUDAContext::get_instance().get_compute_capability()));
        TaichiLLVMContext::mark_inline(func);
      }
#endif

      patch_intrinsic("thread_idx", Intrinsic::nvvm_read_ptx_sreg_tid_x);
      patch_intrinsic("cuda_clock_i64", Intrinsic::nvvm_read_ptx_sreg_clock64);
      patch_intrinsic("block_idx", Intrinsic::nvvm_read_ptx_sreg_ctaid_x);
      patch_intrinsic("block_dim", Intrinsic::nvvm_read_ptx_sreg_ntid_x);
      patch_intrinsic("grid_dim", Intrinsic::nvvm_read_ptx_sreg_nctaid_x);
      patch_intrinsic("block_barrier", Intrinsic::nvvm_barrier0, false);
      patch_intrinsic("warp_barrier", Intrinsic::nvvm_bar_warp_sync, false);
      patch_intrinsic("block_memfence", Intrinsic::nvvm_membar_cta, false);
      patch_intrinsic("grid_memfence", Intrinsic::nvvm_membar_gl, false);
      patch_intrinsic("system_memfence", Intrinsic::nvvm_membar_sys, false);

      patch_intrinsic("cuda_all", Intrinsic::nvvm_vote_all);
      patch_intrinsic("cuda_all_sync", Intrinsic::nvvm_vote_all_sync);

      patch_intrinsic("cuda_any", Intrinsic::nvvm_vote_any);
      patch_intrinsic("cuda_any_sync", Intrinsic::nvvm_vote_any_sync);

      patch_intrinsic("cuda_uni", Intrinsic::nvvm_vote_uni);
      patch_intrinsic("cuda_uni_sync", Intrinsic::nvvm_vote_uni_sync);

      patch_intrinsic("cuda_ballot", Intrinsic::nvvm_vote_ballot);
      patch_intrinsic("cuda_ballot_sync", Intrinsic::nvvm_vote_ballot_sync);

      patch_intrinsic("cuda_shfl_down_sync_i32",
                      Intrinsic::nvvm_shfl_sync_down_i32);
      patch_intrinsic("cuda_shfl_down_sync_f32",
                      Intrinsic::nvvm_shfl_sync_down_f32);

      patch_intrinsic("cuda_shfl_up_sync_i32",
                      Intrinsic::nvvm_shfl_sync_up_i32);
      patch_intrinsic("cuda_shfl_up_sync_f32",
                      Intrinsic::nvvm_shfl_sync_up_f32);

      patch_intrinsic("cuda_shfl_sync_i32", Intrinsic::nvvm_shfl_sync_idx_i32);

      patch_intrinsic("cuda_shfl_sync_f32", Intrinsic::nvvm_shfl_sync_idx_f32);

      patch_intrinsic("cuda_shfl_xor_sync_i32",
                      Intrinsic::nvvm_shfl_sync_bfly_i32);

      patch_intrinsic("cuda_match_any_sync_i32",
                      Intrinsic::nvvm_match_any_sync_i32);

      // LLVM 10.0.0 seems to have a bug on this intrinsic function
      /*
      nvvm_match_all_sync_i32
      Args:
          1. u32 mask
          2. i32 value
          3. i32 *pred
      */
      /*
      patch_intrinsic("cuda_match_all_sync_i32p",
                      Intrinsic::nvvm_match_all_sync_i32);
      */

      // LLVM 10.0.0 seems to have a bug on this intrinsic function
      /*
      patch_intrinsic("cuda_match_any_sync_i64",
                      Intrinsic::nvvm_match_any_sync_i64);
      */

      patch_intrinsic("ctlz_i32", Intrinsic::ctlz, true,
                      {llvm::Type::getInt32Ty(*ctx)}, {get_constant(false)});
      patch_intrinsic("cttz_i32", Intrinsic::cttz, true,
                      {llvm::Type::getInt32Ty(*ctx)}, {get_constant(false)});

      link_module_with_cuda_libdevice(module);

      // To prevent potential symbol name conflicts, we use "cuda_vprintf"
      // instead of "vprintf" in llvm/runtime.cpp. Now we change it back for
      // linking.
      for (auto &f : *module) {
        if (f.getName() == "cuda_vprintf") {
          f.setName("vprintf");
        }
      }

      // runtime_module->print(llvm::errs(), nullptr);
    }

#ifdef TI_WITH_AMDGPU
    auto patch_amdgpu_kernel_dim = [&](std::string name, llvm::Value *lhs) {
      std::string actual_name;
      if (name == "block_dim")
        actual_name = "__ockl_get_local_size";
      else if (name == "grid_dim")
        actual_name = "__ockl_get_num_groups";
      else
        TI_ERROR("Unknown patch function name");
      auto func = module->getFunction(name);
      auto actual_func = module->getFunction(actual_name);
      if (!func || !actual_func) {
        return;
      }
      func->deleteBody();
      auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
      IRBuilder<> builder(*ctx);
      builder.SetInsertPoint(bb);
      auto dim_ = builder.CreateCall(actual_func->getFunctionType(),
                                     actual_func, {lhs});
      auto ret_ = builder.CreateTrunc(dim_, llvm::Type::getInt32Ty(*ctx));
      builder.CreateRet(ret_);
      TaichiLLVMContext::mark_inline(func);
    };
#endif

    if (arch_ == Arch::amdgpu) {
      module->setTargetTriple("amdgcn-amd-amdhsa");
#ifdef TI_WITH_AMDGPU
      llvm::legacy::FunctionPassManager function_pass_manager(module.get());
      function_pass_manager.add(new AMDGPUConvertAllocaInstAddressSpacePass());
      function_pass_manager.doInitialization();
      for (auto func = module->begin(); func != module->end(); ++func) {
        function_pass_manager.run(*func);
      }
      function_pass_manager.doFinalization();
      patch_intrinsic("thread_idx", llvm::Intrinsic::amdgcn_workitem_id_x);
      patch_intrinsic("block_idx", llvm::Intrinsic::amdgcn_workgroup_id_x);

      link_module_with_amdgpu_libdevice(module);
      patch_amdgpu_kernel_dim(
          "block_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
      patch_amdgpu_kernel_dim(
          "grid_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
#endif
    }
  }

  return module;
}

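// Links the (slimmed) CUDA libdevice bitcode into `module` so that math
// helpers such as __nv_sinf resolve to definitions, and adopts libdevice's
// data layout for the module.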
void TaichiLLVMContext::link_module_with_cuda_libdevice(
    std::unique_ptr<llvm::Module> &module) {
  TI_AUTO_PROF
  TI_ASSERT(arch_ == Arch::cuda);

  auto libdevice_module =
      module_from_bitcode_file(libdevice_path(), get_this_thread_context());

  std::vector<std::string> libdevice_function_names;
  for (auto &f : *libdevice_module) {
    if (!f.isDeclaration()) {
      libdevice_function_names.push_back(f.getName().str());
    }
  }

  libdevice_module->setTargetTriple("nvptx64-nvidia-cuda");
  module->setDataLayout(libdevice_module->getDataLayout());

  bool failed = llvm::Linker::linkModules(*module, std::move(libdevice_module));
  if (failed) {
    TI_ERROR("CUDA libdevice linking failure.");
  }

  // Make sure all libdevice functions are linked
  for (const auto &func_name : libdevice_function_names) {
    auto func = module->getFunction(func_name);
    if (!func) {
      TI_INFO("Function {} not found", func_name);
    }
  }
}

void TaichiLLVMContext::link_module_with_amdgpu_libdevice(
    std::unique_ptr<llvm::Module> &module) {
  TI_ASSERT(arch_ == Arch::amdgpu);
#if defined(TI_WITH_AMDGPU)
  auto isa_version = AMDGPUContext::get_instance().get_mcpu().substr(3, 4);
  std::string libdevice_files[] = {"ocml.bc",
                                   "oclc_wavefrontsize64_off.bc",
                                   "ockl.bc",
                                   "oclc_abi_version_400.bc",
                                   "oclc_correctly_rounded_sqrt_off.bc",
                                   "oclc_daz_opt_off.bc",
                                   "oclc_finite_only_off.bc",
                                   "oclc_isa_version_" + isa_version + ".bc",
                                   "oclc_unsafe_math_off.bc",
                                   "opencl.bc"};

  for (auto &libdevice : libdevice_files) {
    std::string lib_dir = runtime_lib_dir() + "/";
    auto libdevice_module = module_from_bitcode_file(lib_dir + libdevice,
                                                     get_this_thread_context());

    if (libdevice == "ocml.bc")
      module->setDataLayout(libdevice_module->getDataLayout());

    std::vector<std::string> libdevice_func_names;
    for (auto &f : *libdevice_module) {
      if (!f.isDeclaration()) {
        libdevice_func_names.push_back(f.getName().str());
      }
    }

    for (auto &f : libdevice_module->functions()) {
      auto func_name = libdevice.substr(0, libdevice.length() - 3);
      if (starts_with(f.getName().lower(), "__" + func_name))
        f.setLinkage(llvm::Function::CommonLinkage);
    }

    bool failed =
        llvm::Linker::linkModules(*module, std::move(libdevice_module));
    if (failed) {
      TI_ERROR("AMDGPU libdevice linking failure.");
    }
  }
#endif
}

void TaichiLLVMContext::add_struct_module(std::unique_ptr<Module> module,
                                          int tree_id) {
  TI_AUTO_PROF;
  TI_ASSERT(std::this_thread::get_id() == main_thread_id_);
  auto this_thread_data = get_this_thread_data();
  TI_ASSERT(module);
  if (llvm::verifyModule(*module, &llvm::errs())) {
    module->print(llvm::errs(), nullptr);
    TI_ERROR("module broken");
  }

  linking_context_data->struct_modules[tree_id] =
      clone_module_to_context(module.get(), linking_context_data->llvm_context);

  for (auto &[id, data] : per_thread_data_) {
    if (id == std::this_thread::get_id()) {
      continue;
    }
    data->struct_modules[tree_id] =
        clone_module_to_context(module.get(), data->llvm_context);
  }

  this_thread_data->struct_modules[tree_id] = std::move(module);
}
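
// Builds an llvm::Constant of the LLVM type corresponding to `dt`, casting
// `t` to the matching numeric representation (e.g., f16 constants are
// materialized from a float32 value).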
template <typename T>
llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) {
  auto ctx = get_this_thread_context();
  if (dt->is_primitive(PrimitiveTypeID::f32)) {
    return llvm::ConstantFP::get(*ctx, llvm::APFloat((float32)t));
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    return llvm::ConstantFP::get(llvm::Type::getHalfTy(*ctx), (float32)t);
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t));
  } else if (is_integral(dt)) {
    if (is_signed(dt)) {
      return llvm::ConstantInt::get(
          *ctx, llvm::APInt(data_type_bits(dt), (uint64_t)t, true));
    } else {
      return llvm::ConstantInt::get(
          *ctx, llvm::APInt(data_type_bits(dt), (uint64_t)t, false));
    }
  } else {
    TI_NOT_IMPLEMENTED
  }
}

template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, int32 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, int64 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, uint32 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, uint64 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, float32 t);
template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, float64 t);

template <typename T>
llvm::Value *TaichiLLVMContext::get_constant(T t) {
  auto ctx = get_this_thread_context();
  TI_ASSERT(ctx != nullptr);
  using TargetType = T;
  if constexpr (std::is_same_v<TargetType, float32> ||
                std::is_same_v<TargetType, float64>) {
    return llvm::ConstantFP::get(*ctx, llvm::APFloat(t));
  } else if constexpr (std::is_same_v<TargetType, bool>) {
    return llvm::ConstantInt::get(*ctx, llvm::APInt(1, (uint64)t, true));
  } else if constexpr (std::is_same_v<TargetType, int32> ||
                       std::is_same_v<TargetType, uint32>) {
    return llvm::ConstantInt::get(*ctx, llvm::APInt(32, (uint64)t, true));
  } else if constexpr (std::is_same_v<TargetType, int64> ||
                       std::is_same_v<TargetType, std::size_t> ||
                       std::is_same_v<TargetType, uint64>) {
    static_assert(sizeof(std::size_t) == sizeof(uint64));
    return llvm::ConstantInt::get(*ctx, llvm::APInt(64, (uint64)t, true));
  } else {
    TI_NOT_IMPLEMENTED
  }
}

std::string TaichiLLVMContext::type_name(llvm::Type *type) {
  std::string type_name;
  llvm::raw_string_ostream rso(type_name);
  type->print(rso);
  return rso.str();
}

std::size_t TaichiLLVMContext::get_type_size(llvm::Type *type) {
  return get_data_layout().getTypeAllocSize(type);
}

std::size_t TaichiLLVMContext::get_struct_element_offset(llvm::StructType *type,
                                                         int idx) {
  return get_data_layout().getStructLayout(type)->getElementOffset(idx);
}

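// Marks `f` as always-inline, unless its body calls the sentinel function
// "mark_force_no_inline", in which case the function is left untouched.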
void TaichiLLVMContext::mark_inline(llvm::Function *f) {
  for (auto &B : *f)
    for (auto &I : B) {
      if (auto *call = llvm::dyn_cast<llvm::CallInst>(&I)) {
        if (auto func = call->getCalledFunction();
            func && func->getName() == "mark_force_no_inline") {
          // Found "mark_force_no_inline". Do not inline.
          return;
        }
      }
    }
  f->removeFnAttr(llvm::Attribute::OptimizeNone);
  f->removeFnAttr(llvm::Attribute::NoInline);
  f->addFnAttr(llvm::Attribute::AlwaysInline);
}

int TaichiLLVMContext::num_instructions(llvm::Function *func) {
  int counter = 0;
  for (BasicBlock &bb : *func)
    counter += std::distance(bb.begin(), bb.end());
  return counter;
}

void TaichiLLVMContext::print_huge_functions(llvm::Module *module) {
  int total_inst = 0;
  int total_big_inst = 0;

  for (auto &f : *module) {
    int c = num_instructions(&f);
    if (c > 100) {
      total_big_inst += c;
      TI_INFO("{}: {} inst.", std::string(f.getName()), c);
    }
    total_inst += c;
  }
  TI_P(total_inst);
  TI_P(total_big_inst);
}

llvm::DataLayout TaichiLLVMContext::get_data_layout() {
  return jit->get_data_layout();
}

void TaichiLLVMContext::insert_nvvm_annotation(llvm::Function *func,
                                               std::string key,
                                               int val) {
  /*******************************************************************
  Example annotation from llvm PTX doc:

  define void @kernel(float addrspace(1)* %A,
                      float addrspace(1)* %B,
                      float addrspace(1)* %C);

  !nvvm.annotations = !{!0}
  !0 = !{void (float addrspace(1)*,
               float addrspace(1)*,
               float addrspace(1)*)* @kernel, !"kernel", i32 1}
  *******************************************************************/
  auto ctx = get_this_thread_context();
  llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
                               MDString::get(*ctx, key),
                               llvm::ValueAsMetadata::get(get_constant(val))};

  MDNode *md_node = MDNode::get(*ctx, md_args);

  func->getParent()
      ->getOrInsertNamedMetadata("nvvm.annotations")
      ->addOperand(md_node);
}

void TaichiLLVMContext::mark_function_as_cuda_kernel(llvm::Function *func,
                                                     int block_dim) {
  // Mark the function as a CUDA __global__ kernel by adding the "kernel"
  // nvvm annotation.
  insert_nvvm_annotation(func, "kernel", 1);
  if (block_dim != 0) {
    // CUDA launch bounds
    insert_nvvm_annotation(func, "maxntidx", block_dim);
    insert_nvvm_annotation(func, "minctasm", 2);
  }
}

void TaichiLLVMContext::mark_function_as_amdgpu_kernel(llvm::Function *func) {
  func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
}

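// Internalizes every global that `export_indicator` does not mark as
// exported, then runs GlobalDCE to delete the resulting dead definitions.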
void TaichiLLVMContext::eliminate_unused_functions(
    llvm::Module *module,
    std::function<bool(const std::string &)> export_indicator) {
  TI_AUTO_PROF
  using namespace llvm;
  TI_ASSERT(module);
  if (false) {
    // temporary fix for now to make LLVM 8 work with CUDA
    // TODO: recover this when it's time
    if (llvm::verifyModule(*module, &llvm::errs())) {
      TI_ERROR("Module broken\n");
    }
  }
  llvm::ModulePassManager manager;
  llvm::ModuleAnalysisManager ana;
  llvm::PassBuilder pb;
  pb.registerModuleAnalyses(ana);
  manager.addPass(llvm::InternalizePass([&](const GlobalValue &val) -> bool {
    return export_indicator(val.getName().str());
  }));
  manager.addPass(GlobalDCEPass());
  manager.run(*module, ana);
}

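// Lazily creates the per-thread LLVM context and bookkeeping on first use;
// access is guarded by thread_map_mut_ so each thread gets its own
// LLVMContext.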
TaichiLLVMContext::ThreadLocalData *TaichiLLVMContext::get_this_thread_data() {
  std::lock_guard<std::mutex> _(thread_map_mut_);
  auto tid = std::this_thread::get_id();
  if (per_thread_data_.find(tid) == per_thread_data_.end()) {
    std::stringstream ss;
    ss << tid;
    TI_TRACE("Creating thread local data for thread {}", ss.str());
    per_thread_data_[tid] = std::make_unique<ThreadLocalData>(
        std::make_unique<llvm::orc::ThreadSafeContext>(
            std::make_unique<llvm::LLVMContext>()));
  }
  return per_thread_data_[tid].get();
}

llvm::LLVMContext *TaichiLLVMContext::get_this_thread_context() {
  ThreadLocalData *data = get_this_thread_data();
  TI_ASSERT(data->llvm_context);
  return data->llvm_context;
}

llvm::orc::ThreadSafeContext *
TaichiLLVMContext::get_this_thread_thread_safe_context() {
  get_this_thread_context();  // make sure the context is created
  ThreadLocalData *data = get_this_thread_data();
  return data->thread_safe_llvm_context.get();
}

template llvm::Value *TaichiLLVMContext::get_constant(float32 t);
template llvm::Value *TaichiLLVMContext::get_constant(float64 t);

template llvm::Value *TaichiLLVMContext::get_constant(bool t);

template llvm::Value *TaichiLLVMContext::get_constant(int32 t);
template llvm::Value *TaichiLLVMContext::get_constant(uint32 t);

template llvm::Value *TaichiLLVMContext::get_constant(int64 t);
template llvm::Value *TaichiLLVMContext::get_constant(uint64 t);

#ifdef TI_PLATFORM_OSX
template llvm::Value *TaichiLLVMContext::get_constant(unsigned long t);
#endif

auto make_slim_libdevice = [](const std::vector<std::string> &args) {
  TI_ASSERT_INFO(args.size() == 1,
                 "Usage: ti task make_slim_libdevice [libdevice.X.bc file]");

  auto ctx = std::make_unique<llvm::LLVMContext>();
  auto libdevice_module = module_from_bitcode_file(args[0], ctx.get());

  remove_useless_cuda_libdevice_functions(libdevice_module.get());

  std::error_code ec;
  auto output_fn = "slim_" + args[0];
  llvm::raw_fd_ostream os(output_fn, ec, llvm::sys::fs::OF_None);
  llvm::WriteBitcodeToFile(*libdevice_module, os);
  os.flush();
  TI_INFO("Slimmed libdevice written to {}", output_fn);
};

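// Prepares the freshly loaded runtime module for the current arch: on CUDA,
// runtime_* entry points become kernels while other defined functions become
// private; the module is then stripped down to the exported entry points.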
void TaichiLLVMContext::init_runtime_module(llvm::Module *runtime_module) {
  if (config_.arch == Arch::cuda) {
    for (auto &f : *runtime_module) {
      bool is_kernel = false;
      const std::string func_name = f.getName().str();
      if (starts_with(func_name, "runtime_")) {
        mark_function_as_cuda_kernel(&f);
        is_kernel = true;
      }

      if (!is_kernel && !f.isDeclaration())
        // Give non-kernel functions that have bodies private linkage, to
        // avoid duplicated symbols and to remove external symbol
        // dependencies such as std::sin.
        f.setLinkage(llvm::Function::PrivateLinkage);
    }
  }

  if (config_.arch == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
    llvm::legacy::PassManager module_pass_manager;
    module_pass_manager.add(new AMDGPUConvertFuncParamAddressSpacePass());
    module_pass_manager.run(*runtime_module);
#endif
  }

  eliminate_unused_functions(runtime_module, [](std::string func_name) {
    return starts_with(func_name, "runtime_") ||
           starts_with(func_name, "LLVMRuntime_");
  });
}

void TaichiLLVMContext::delete_snode_tree(int id) {
  TI_ASSERT(linking_context_data->struct_modules.erase(id));
  for (auto &[thread_id, data] : per_thread_data_) {
    TI_ASSERT(data->struct_modules.erase(id));
  }
}

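// Lazily populates this thread's struct modules by cloning them from the
// main thread's copies the first time this thread needs them.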
void TaichiLLVMContext::fetch_this_thread_struct_module() {
  ThreadLocalData *data = get_this_thread_data();
  if (data->struct_modules.empty()) {
    for (auto &[id, mod] : main_thread_data_->struct_modules) {
      data->struct_modules[id] = clone_module_to_this_thread_context(mod.get());
    }
  }
}

llvm::Function *TaichiLLVMContext::get_runtime_function(
    const std::string &name) {
  return get_this_thread_runtime_module()->getFunction(name);
}

llvm::Module *TaichiLLVMContext::get_this_thread_runtime_module() {
  TI_AUTO_PROF;
  auto data = get_this_thread_data();
  if (!data->runtime_module) {
    data->runtime_module = module_from_file(get_runtime_fn(arch_));
  }
  return data->runtime_module.get();
}

llvm::Function *TaichiLLVMContext::get_struct_function(const std::string &name,
                                                       int tree_id) {
  auto *data = get_this_thread_data();
  return data->struct_modules[tree_id]->getFunction(name);
}

llvm::Type *TaichiLLVMContext::get_runtime_type(const std::string &name) {
  auto ty = llvm::StructType::getTypeByName(
      get_this_thread_runtime_module()->getContext(), ("struct." + name));
  if (!ty) {
    TI_ERROR("LLVMRuntime type {} not found.", name);
  }
  return ty;
}
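
// Creates an empty module named `name` in `context` (falling back to this
// thread's context) that inherits the runtime module's data layout.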
std::unique_ptr<llvm::Module> TaichiLLVMContext::new_module(
    std::string name,
    llvm::LLVMContext *context) {
  auto new_mod = std::make_unique<llvm::Module>(
      name, context ? *context : *get_this_thread_context());
  new_mod->setDataLayout(get_this_thread_runtime_module()->getDataLayout());
  return new_mod;
}

TaichiLLVMContext::ThreadLocalData::ThreadLocalData(
    std::unique_ptr<llvm::orc::ThreadSafeContext> ctx)
    : thread_safe_llvm_context(std::move(ctx)),
      llvm_context(thread_safe_llvm_context->getContext()) {
}

TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() {
  runtime_module.reset();
  struct_modules.clear();
  thread_safe_llvm_context.reset();
}

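// Links the per-task modules of a kernel into a single module, pulling in
// (1) the struct modules of every SNode tree the tasks touch, (2) a
// "parallel_struct_for" instantiation for every TLS buffer size used, and
// (3) the runtime module, then strips everything except the offloaded tasks.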
LLVMCompiledKernel TaichiLLVMContext::link_compiled_tasks(
    std::vector<std::unique_ptr<LLVMCompiledTask>> data_list) {
  LLVMCompiledKernel linked;
  std::unordered_set<int> used_tree_ids;
  std::unordered_set<int> tls_sizes;
  std::unordered_set<std::string> offloaded_names;
  auto mod = new_module("kernel", linking_context_data->llvm_context);
  llvm::Linker linker(*mod);
  for (auto &datum : data_list) {
    for (auto tree_id : datum->used_tree_ids) {
      used_tree_ids.insert(tree_id);
    }
    for (auto tls_size : datum->struct_for_tls_sizes) {
      tls_sizes.insert(tls_size);
    }
    for (auto &task : datum->tasks) {
      offloaded_names.insert(task.name);
      linked.tasks.push_back(std::move(task));
    }
    linker.linkInModule(clone_module_to_context(
        datum->module.get(), linking_context_data->llvm_context));
  }
  for (auto tree_id : used_tree_ids) {
    linker.linkInModule(
        llvm::CloneModule(*linking_context_data->struct_modules[tree_id]),
        llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc);
  }
  auto runtime_module =
      llvm::CloneModule(*linking_context_data->runtime_module);
  for (auto tls_size : tls_sizes) {
    add_struct_for_func(runtime_module.get(), tls_size);
  }
  linker.linkInModule(
      std::move(runtime_module),
      llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc);
  eliminate_unused_functions(mod.get(), [&](std::string func_name) -> bool {
    return offloaded_names.count(func_name);
  });
  linked.module = std::move(mod);
  return linked;
}

void TaichiLLVMContext::add_struct_for_func(llvm::Module *module,
                                            int tls_size) {
  // Note that on CUDA local array allocation must have a compile-time
  // constant size. Therefore, instead of passing in the tls_buffer_size
  // argument, we directly clone the "parallel_struct_for" function and
  // replace the "alignas(8) char tls_buffer[1]" statement with "alignas(8)
  // char tls_buffer[tls_buffer_size]" at compile time.
  auto func_name = get_struct_for_func_name(tls_size);
  if (module->getFunction(func_name)) {
    return;
  }
  llvm::legacy::PassManager module_pass_manager;
  if (config_.arch == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
    module_pass_manager.add(
        new AMDGPUAddStructForFuncPass(func_name, tls_size));
    module_pass_manager.run(*module);
#else
    TI_NOT_IMPLEMENTED
#endif
  } else {
    module_pass_manager.add(new AddStructForFuncPass(func_name, tls_size));
    module_pass_manager.run(*module);
  }
}

std::string TaichiLLVMContext::get_struct_for_func_name(int tls_size) {
  return "parallel_struct_for_" + std::to_string(tls_size);
}

std::string TaichiLLVMContext::get_data_layout_string() {
  return get_data_layout().getStringRepresentation();
}

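// Returns a StructType equivalent to `old_ty` but with element offsets
// recomputed under `layout` (an LLVM data layout string), recursing into
// nested struct members. Returns `old_ty` unchanged if it already uses
// `layout`.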
const StructType *TaichiLLVMContext::get_struct_type_with_data_layout(
    const StructType *old_ty,
    const std::string &layout) {
  if (old_ty->get_layout() == layout) {
    return old_ty;
  }
  std::vector<StructMember> elements = old_ty->elements();
  for (auto &element : elements) {
    if (auto struct_type = element.type->cast<StructType>()) {
      element.type = get_struct_type_with_data_layout(struct_type, layout);
    }
  }
  auto *llvm_struct_type = llvm::cast<llvm::StructType>(get_data_type(old_ty));
  auto data_layout = llvm::DataLayout::parse(layout);
  TI_ASSERT(data_layout);
  auto struct_layout = data_layout->getStructLayout(llvm_struct_type);
  for (int i = 0; i < elements.size(); i++) {
    elements[i].offset = struct_layout->getElementOffset(i);
  }
  return TypeFactory::get_instance()
      .get_struct_type(elements, layout)
      ->cast<StructType>();
}

TI_REGISTER_TASK(make_slim_libdevice);

}  // namespace taichi::lang