1// Driver class for kernel codegen
2
3#include "codegen.h"
4
5#if defined(TI_WITH_LLVM)
6#include "taichi/codegen/cpu/codegen_cpu.h"
7#include "taichi/codegen/wasm/codegen_wasm.h"
8#include "taichi/runtime/llvm/llvm_offline_cache.h"
9#include "taichi/runtime/program_impls/llvm/llvm_program.h"
10#endif
11#if defined(TI_WITH_CUDA)
12#include "taichi/codegen/cuda/codegen_cuda.h"
13#endif
14#if defined(TI_WITH_DX12)
15#include "taichi/codegen/dx12/codegen_dx12.h"
16#endif
17#if defined(TI_WITH_AMDGPU)
18#include "taichi/codegen/amdgpu/codegen_amdgpu.h"
19#endif
20#include "taichi/system/timer.h"
21#include "taichi/ir/analysis.h"
22#include "taichi/ir/transforms.h"
23#include "taichi/analysis/offline_cache_util.h"
24
25namespace taichi::lang {
26
27KernelCodeGen::KernelCodeGen(const CompileConfig &compile_config,
28 Kernel *kernel,
29 TaichiLLVMContext &tlctx)
30 : prog(kernel->program),
31 kernel(kernel),
32 compile_config_(compile_config),
33 tlctx_(tlctx) {
34 this->ir = kernel->ir.get();
35}
36
37std::unique_ptr<KernelCodeGen> KernelCodeGen::create(
38 const CompileConfig &compile_config,
39 Kernel *kernel,
40 TaichiLLVMContext &tlctx) {
41#ifdef TI_WITH_LLVM
42 const auto arch = compile_config.arch;
43 if (arch_is_cpu(arch) && arch != Arch::wasm) {
44 return std::make_unique<KernelCodeGenCPU>(compile_config, kernel, tlctx);
45 } else if (arch == Arch::wasm) {
46 return std::make_unique<KernelCodeGenWASM>(compile_config, kernel, tlctx);
47 } else if (arch == Arch::cuda) {
48#if defined(TI_WITH_CUDA)
49 return std::make_unique<KernelCodeGenCUDA>(compile_config, kernel, tlctx);
50#else
51 TI_NOT_IMPLEMENTED
52#endif
53 } else if (arch == Arch::dx12) {
54#if defined(TI_WITH_DX12)
55 return std::make_unique<KernelCodeGenDX12>(compile_config, kernel, tlctx);
56#else
57 TI_NOT_IMPLEMENTED
58#endif
59 } else if (arch == Arch::amdgpu) {
60#if defined(TI_WITH_AMDGPU)
61 return std::make_unique<KernelCodeGenAMDGPU>(compile_config, kernel, tlctx);
62#else
63 TI_NOT_IMPLEMENTED
64#endif
65 } else {
66 TI_NOT_IMPLEMENTED
67 }
68#else
69 TI_ERROR("Llvm disabled");
70#endif
71}
72#ifdef TI_WITH_LLVM
73
74std::optional<LLVMCompiledKernel>
75KernelCodeGen::maybe_read_compilation_from_cache(
76 const std::string &kernel_key) {
77 TI_AUTO_PROF;
78 auto *llvm_prog = get_llvm_program(prog);
79 const auto &reader = llvm_prog->get_cache_reader();
80 if (!reader) {
81 return std::nullopt;
82 }
83
84 LlvmOfflineCache::KernelCacheData cache_data;
85 auto &llvm_ctx = *tlctx_.get_this_thread_context();
86
87 if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) {
88 return std::nullopt;
89 }
90 return {std::move(cache_data.compiled_data)};
91}
92
93void KernelCodeGen::cache_kernel(const std::string &kernel_key,
94 const LLVMCompiledKernel &data) {
95 get_llvm_program(prog)->cache_kernel(kernel_key, data,
96 infer_launch_args(kernel));
97}
98
99LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
100 std::string kernel_key =
101 get_hashed_offline_cache_key(compile_config_, kernel);
102 kernel->set_kernel_key_for_cache(kernel_key);
103 if (compile_config_.offline_cache && this->supports_offline_cache() &&
104 !kernel->is_evaluator) {
105 auto res = maybe_read_compilation_from_cache(kernel_key);
106 if (res) {
107 TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
108 kernel_key);
109 cache_kernel(kernel_key, *res);
110 return std::move(*res);
111 }
112 }
113
114 irpass::ast_to_ir(compile_config_, *kernel, false);
115
116 auto block = dynamic_cast<Block *>(kernel->ir.get());
117 auto &worker = get_llvm_program(kernel->program)->compilation_workers;
118 TI_ASSERT(block);
119
120 auto &offloads = block->statements;
121 std::vector<std::unique_ptr<LLVMCompiledTask>> data(offloads.size());
122 for (int i = 0; i < offloads.size(); i++) {
123 auto compile_func = [&, i] {
124 tlctx_.fetch_this_thread_struct_module();
125 auto offload = irpass::analysis::clone(offloads[i].get());
126 irpass::re_id(offload.get());
127 auto new_data = this->compile_task(compile_config_, nullptr,
128 offload->as<OffloadedStmt>());
129 data[i] = std::make_unique<LLVMCompiledTask>(std::move(new_data));
130 };
131 if (kernel->is_evaluator) {
132 compile_func();
133 } else {
134 worker.enqueue(compile_func);
135 }
136 }
137 if (!kernel->is_evaluator) {
138 worker.flush();
139 }
140 auto linked = tlctx_.link_compiled_tasks(std::move(data));
141
142 if (!kernel->is_evaluator) {
143 TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key);
144 cache_kernel(kernel_key, linked);
145 }
146 return linked;
147}
148
149ModuleToFunctionConverter::ModuleToFunctionConverter(
150 TaichiLLVMContext *tlctx,
151 LlvmRuntimeExecutor *executor)
152 : tlctx_(tlctx), executor_(executor) {
153}
154
155FunctionType ModuleToFunctionConverter::convert(const Kernel *kernel,
156 LLVMCompiledKernel data) const {
157 return convert(kernel->name, infer_launch_args(kernel), std::move(data));
158}
159
160#endif
161} // namespace taichi::lang
162