1#pragma once
2
3#include <cstddef>
4#include <memory>
5
6#include "taichi/runtime/llvm/llvm_offline_cache.h"
7#include "taichi/program/compile_config.h"
8#include "taichi/runtime/llvm/llvm_runtime_executor.h"
9#include "taichi/system/memory_pool.h"
10#include "taichi/program/program_impl.h"
11#include "taichi/program/parallel_executor.h"
12#include "taichi/util/bit.h"
13#define TI_RUNTIME_HOST
14#include "taichi/program/context.h"
15#undef TI_RUNTIME_HOST
16
// Forward declaration so this header does not need to include LLVM headers.
namespace llvm {
class Module;
}  // namespace llvm
20
21namespace taichi::lang {
22
class StructCompiler;
class Program;

// Backend-specific device classes, forward-declared to keep this header light;
// the corresponding accessors below return pointers to these types.
namespace cuda {
class CudaDevice;
}  // namespace cuda

namespace amdgpu {
class AmdgpuDevice;
}  // namespace amdgpu

namespace cpu {
class CpuDevice;
}  // namespace cpu
37
38class LlvmProgramImpl : public ProgramImpl {
39 public:
40 LlvmProgramImpl(CompileConfig &config, KernelProfilerBase *profiler);
41
42 /* ------------------------------------ */
43 /* ---- JIT-Compilation Interfaces ---- */
44 /* ------------------------------------ */
45
46 // TODO(zhanlue): compile-time runtime split for LLVM::CodeGen
47 // For now, compile = codegen + convert
48 FunctionType compile(const CompileConfig &compile_config,
49 Kernel *kernel) override;
50
51 void compile_snode_tree_types(SNodeTree *tree) override;
52
53 // TODO(zhanlue): refactor materialize_snode_tree()
54 // materialize_snode_tree = compile_snode_tree_types +
55 // initialize_llvm_runtime_snodes It's a 2-in-1 interface
56 void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;
57
58 void cache_kernel(const std::string &kernel_key,
59 const LLVMCompiledKernel &data,
60 std::vector<LlvmLaunchArgInfo> &&args);
61 ;
62
63 void cache_field(int snode_tree_id,
64 int root_id,
65 const StructCompiler &struct_compiler);
66
67 LlvmOfflineCache::FieldCacheData get_cached_field(int snode_tree_id) const {
68 TI_ASSERT(cache_data_->fields.find(snode_tree_id) !=
69 cache_data_->fields.end());
70 return cache_data_->fields.at(snode_tree_id);
71 }
72
73 private:
74 std::unique_ptr<StructCompiler> compile_snode_tree_types_impl(
75 SNodeTree *tree);
76
77 std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
78 const DeviceCapabilityConfig &caps) override;
79
80 void dump_cache_data_to_disk() override;
81
82 /* -------------------------------- */
83 /* ---- JIT-Runtime Interfaces ---- */
84 /* -------------------------------- */
85 // ** Please implement new runtime interfaces in LlvmRuntimeExecutor **
86 //
87 // There are two major customer-level classes, namely Kernel and
88 // FieldsBuilder.
89 //
90 // For now, both Kernel and FieldsBuilder rely on Program/ProgramImpl to
91 // access compile time and runtime interfaces.
92 //
93 // We keep these runtime interfaces in ProgramImpl for now, so as to avoid
94 // changing the higher-level architecture, which is coupled with base classes
95 // and other backends.
96 //
97 // The runtime interfaces in ProgramImpl should be nothing but a simple
98 // wrapper. The one with actual implementation should go inside
99 // LlvmRuntimeExecutor class.
100
101 public:
102 /**
103 * Initializes the runtime system for LLVM based backends.
104 */
105 void materialize_runtime(MemoryPool *memory_pool,
106 KernelProfilerBase *profiler,
107 uint64 **result_buffer_ptr) override {
108 runtime_exec_->materialize_runtime(memory_pool, profiler,
109 result_buffer_ptr);
110 }
111
112 void destroy_snode_tree(SNodeTree *snode_tree) override {
113 return runtime_exec_->destroy_snode_tree(snode_tree);
114 }
115
116 template <typename T>
117 T fetch_result(int i, uint64 *result_buffer) {
118 return runtime_exec_->fetch_result<T>(i, result_buffer);
119 }
120
121 void finalize() override {
122 runtime_exec_->finalize();
123 }
124
125 uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) override {
126 return runtime_exec_->get_ndarray_alloc_info_ptr(alloc);
127 }
128
129 void fill_ndarray(const DeviceAllocation &alloc,
130 std::size_t size,
131 uint32_t data) override {
132 return runtime_exec_->fill_ndarray(alloc, size, data);
133 }
134
135 void prepare_runtime_context(RuntimeContext *ctx) override {
136 runtime_exec_->prepare_runtime_context(ctx);
137 }
138
139 DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
140 uint64 *result_buffer) override {
141 return runtime_exec_->allocate_memory_ndarray(alloc_size, result_buffer);
142 }
143
144 Device *get_compute_device() override {
145 return runtime_exec_->get_compute_device();
146 }
147
148 /**
149 * Initializes the SNodes for LLVM based backends.
150 */
151 void initialize_llvm_runtime_snodes(
152 const LlvmOfflineCache::FieldCacheData &field_cache_data,
153 uint64 *result_buffer) {
154 runtime_exec_->initialize_llvm_runtime_snodes(field_cache_data,
155 result_buffer);
156 }
157
158 uint64 fetch_result_uint64(int i, uint64 *result_buffer) override {
159 return runtime_exec_->fetch_result_uint64(i, result_buffer);
160 }
161
162 TypedConstant fetch_result(char *result_buffer,
163 int offset,
164 const Type *dt) override {
165 if (dt->is_primitive(PrimitiveTypeID::f32)) {
166 return TypedConstant(
167 runtime_exec_->fetch_result<float32>(result_buffer, offset));
168 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
169 return TypedConstant(
170 runtime_exec_->fetch_result<float64>(result_buffer, offset));
171 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
172 return TypedConstant(
173 runtime_exec_->fetch_result<int32>(result_buffer, offset));
174 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
175 return TypedConstant(
176 runtime_exec_->fetch_result<int64>(result_buffer, offset));
177 } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
178 return TypedConstant(
179 runtime_exec_->fetch_result<int8>(result_buffer, offset));
180 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
181 return TypedConstant(
182 runtime_exec_->fetch_result<int16>(result_buffer, offset));
183 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
184 return TypedConstant(
185 runtime_exec_->fetch_result<uint8>(result_buffer, offset));
186 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
187 return TypedConstant(
188 runtime_exec_->fetch_result<uint16>(result_buffer, offset));
189 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
190 return TypedConstant(
191 runtime_exec_->fetch_result<uint32>(result_buffer, offset));
192 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
193 return TypedConstant(
194 runtime_exec_->fetch_result<uint64>(result_buffer, offset));
195 } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
196 // first fetch the data as u16, and then convert it to f32
197 uint16 half = runtime_exec_->fetch_result<uint16>(result_buffer, offset);
198 return TypedConstant(bit::half_to_float(half));
199 } else {
200 TI_NOT_IMPLEMENTED
201 }
202 }
203
204 template <typename T, typename... Args>
205 T runtime_query(const std::string &key,
206 uint64 *result_buffer,
207 Args &&...args) {
208 return runtime_exec_->runtime_query<T>(key, result_buffer,
209 std::forward<Args>(args)...);
210 }
211
212 void print_list_manager_info(void *list_manager, uint64 *result_buffer) {
213 runtime_exec_->print_list_manager_info(list_manager, result_buffer);
214 }
215
216 void print_memory_profiler_info(
217 std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
218 uint64 *result_buffer) override {
219 runtime_exec_->print_memory_profiler_info(snode_trees_, result_buffer);
220 }
221
222 TaichiLLVMContext *get_llvm_context() {
223 return runtime_exec_->get_llvm_context();
224 }
225
226 void synchronize() override {
227 runtime_exec_->synchronize();
228 }
229
230 LLVMRuntime *get_llvm_runtime() {
231 return runtime_exec_->get_llvm_runtime();
232 }
233
234 std::size_t get_snode_num_dynamically_allocated(
235 SNode *snode,
236 uint64 *result_buffer) override {
237 return runtime_exec_->get_snode_num_dynamically_allocated(snode,
238 result_buffer);
239 }
240
241 void check_runtime_error(uint64 *result_buffer) override {
242 runtime_exec_->check_runtime_error(result_buffer);
243 }
244
245 size_t get_field_in_tree_offset(int tree_id, const SNode *child) override {
246 // FIXME: Compute the proper offset. Current method taken from GGUI code
247 size_t offset = 0;
248
249 SNode *dense_parent = child->parent;
250 SNode *root = dense_parent->parent;
251
252 int child_id = root->child_id(dense_parent);
253
254 for (int i = 0; i < child_id; ++i) {
255 SNode *child = root->ch[i].get();
256 offset += child->cell_size_bytes * child->num_cells_per_container;
257 }
258
259 return offset;
260 }
261
262 DevicePtr get_snode_tree_device_ptr(int tree_id) override {
263 return runtime_exec_->get_snode_tree_device_ptr(tree_id);
264 }
265
266 cuda::CudaDevice *cuda_device() {
267 return runtime_exec_->cuda_device();
268 }
269
270 cpu::CpuDevice *cpu_device() {
271 return runtime_exec_->cpu_device();
272 }
273
274 LlvmDevice *llvm_device() {
275 return runtime_exec_->llvm_device();
276 }
277
278 LlvmRuntimeExecutor *get_runtime_executor() {
279 return runtime_exec_.get();
280 }
281
282 const std::unique_ptr<LlvmOfflineCacheFileReader> &get_cache_reader() {
283 return cache_reader_;
284 }
285
286 std::string get_kernel_return_data_layout() override {
287 return get_llvm_context()->get_data_layout_string();
288 };
289
290 std::string get_kernel_argument_data_layout() override {
291 return get_llvm_context()->get_data_layout_string();
292 };
293 const StructType *get_struct_type_with_data_layout(
294 const StructType *old_ty,
295 const std::string &layout) override {
296 return get_llvm_context()->get_struct_type_with_data_layout(old_ty, layout);
297 }
298
299 // TODO(zhanlue): Rearrange llvm::Context's ownership
300 //
301 // In LLVM backend, most of the compiled information are stored in
302 // llvm::Module:
303 // 1. Runtime functions are compiled into runtime_module,
304 // 2. Fields are compiled into struct_module,
305 // 3. Each kernel is compiled into individual kernel_module
306 //
307 // However, all the llvm::Modules are owned by llvm::Context, which belongs to
308 // TaichiLLVMContext. Upon destruction, there's an implicit requirement that
309 // TaichiLLVMContext has to stay alive until all the llvm::Modules are
310 // destructed, otherwise there will be risks of dangling references.
311 //
312 // To guarantee the life cycle of llvm::Module stay aligned with
313 // llvm::Context, we better make llvm::Context a more global-scoped variable,
314 // instead of owned by TaichiLLVMContext.
315 //
316 // Objects owning llvm::Module so far (from direct to indirect):
317 // 1. LlvmOfflineCache::CachedKernelData(direct owner)
318 // 2. LlvmOfflineCache
319 // 3.1 LlvmProgramImpl
320 // 3.2 LlvmAotModuleBuilder
321 // 3.3 llvm_aot::KernelImpl (for use in CGraph)
322 //
323 // Objects owning llvm::Context (from direct to indirect)
324 // 1. TaichiLLVMContext
325 // 2. LlvmProgramImpl
326 //
327 // Make sure the above mentioned objects are destructed in order.
328 ~LlvmProgramImpl() override {
329 // Explicitly enforce "LlvmOfflineCache::CachedKernelData::owned_module"
330 // destructs before
331 // "LlvmRuntimeExecutor::TaichiLLVMContext::ThreadSafeContext"
332
333 // 1. Destructs cache_data_
334 cache_data_.reset();
335
336 // 2. Destructs cache_reader_
337 cache_reader_.reset();
338
339 // 3. Destructs runtime_exec_
340 runtime_exec_.reset();
341 }
342 ParallelExecutor compilation_workers; // parallel compilation
343
344 private:
345 std::size_t num_snode_trees_processed_{0};
346 std::unique_ptr<LlvmRuntimeExecutor> runtime_exec_;
347 std::unique_ptr<LlvmOfflineCache> cache_data_;
348 std::unique_ptr<LlvmOfflineCacheFileReader> cache_reader_;
349};
350
// Free helper (defined elsewhere): fetches the LlvmProgramImpl from `prog` —
// presumably a checked downcast of prog's ProgramImpl; verify in the .cpp.
LlvmProgramImpl *get_llvm_program(Program *prog);
352
353} // namespace taichi::lang
354