1 | #pragma once |
2 | |
3 | #include <cstddef> |
4 | #include <memory> |
5 | |
6 | #include "taichi/runtime/llvm/llvm_offline_cache.h" |
7 | #include "taichi/program/compile_config.h" |
8 | #include "taichi/runtime/llvm/llvm_runtime_executor.h" |
9 | #include "taichi/system/memory_pool.h" |
10 | #include "taichi/program/program_impl.h" |
11 | #include "taichi/program/parallel_executor.h" |
12 | #include "taichi/util/bit.h" |
13 | #define TI_RUNTIME_HOST |
14 | #include "taichi/program/context.h" |
15 | #undef TI_RUNTIME_HOST |
16 | |
// Forward declaration of llvm::Module so that LLVM headers are not needed
// here. NOTE(review): llvm::Module is only mentioned in comments of this
// header's visible code; it may be kept for includers — confirm before
// removing.
namespace llvm { class Module; }  // namespace llvm
20 | |
21 | namespace taichi::lang { |
22 | |
// Forward declarations. The code below refers to these types only through
// pointers, references or std::unique_ptr, so their full definitions are not
// required in this header.
class Program;
class StructCompiler;

// Per-backend device classes returned by the LlvmProgramImpl accessors.
namespace cpu {
class CpuDevice;
}  // namespace cpu

namespace cuda {
class CudaDevice;
}  // namespace cuda

// NOTE(review): amdgpu::AmdgpuDevice is not referenced by the code in this
// header; presumably declared for parity with the other backends.
namespace amdgpu {
class AmdgpuDevice;
}  // namespace amdgpu
37 | |
38 | class LlvmProgramImpl : public ProgramImpl { |
39 | public: |
40 | LlvmProgramImpl(CompileConfig &config, KernelProfilerBase *profiler); |
41 | |
42 | /* ------------------------------------ */ |
43 | /* ---- JIT-Compilation Interfaces ---- */ |
44 | /* ------------------------------------ */ |
45 | |
46 | // TODO(zhanlue): compile-time runtime split for LLVM::CodeGen |
47 | // For now, compile = codegen + convert |
48 | FunctionType compile(const CompileConfig &compile_config, |
49 | Kernel *kernel) override; |
50 | |
51 | void compile_snode_tree_types(SNodeTree *tree) override; |
52 | |
53 | // TODO(zhanlue): refactor materialize_snode_tree() |
54 | // materialize_snode_tree = compile_snode_tree_types + |
55 | // initialize_llvm_runtime_snodes It's a 2-in-1 interface |
56 | void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override; |
57 | |
58 | void cache_kernel(const std::string &kernel_key, |
59 | const LLVMCompiledKernel &data, |
60 | std::vector<LlvmLaunchArgInfo> &&args); |
61 | ; |
62 | |
63 | void cache_field(int snode_tree_id, |
64 | int root_id, |
65 | const StructCompiler &struct_compiler); |
66 | |
67 | LlvmOfflineCache::FieldCacheData get_cached_field(int snode_tree_id) const { |
68 | TI_ASSERT(cache_data_->fields.find(snode_tree_id) != |
69 | cache_data_->fields.end()); |
70 | return cache_data_->fields.at(snode_tree_id); |
71 | } |
72 | |
73 | private: |
74 | std::unique_ptr<StructCompiler> compile_snode_tree_types_impl( |
75 | SNodeTree *tree); |
76 | |
77 | std::unique_ptr<AotModuleBuilder> make_aot_module_builder( |
78 | const DeviceCapabilityConfig &caps) override; |
79 | |
80 | void dump_cache_data_to_disk() override; |
81 | |
82 | /* -------------------------------- */ |
83 | /* ---- JIT-Runtime Interfaces ---- */ |
84 | /* -------------------------------- */ |
85 | // ** Please implement new runtime interfaces in LlvmRuntimeExecutor ** |
86 | // |
87 | // There are two major customer-level classes, namely Kernel and |
88 | // FieldsBuilder. |
89 | // |
90 | // For now, both Kernel and FieldsBuilder rely on Program/ProgramImpl to |
91 | // access compile time and runtime interfaces. |
92 | // |
93 | // We keep these runtime interfaces in ProgramImpl for now, so as to avoid |
94 | // changing the higher-level architecture, which is coupled with base classes |
95 | // and other backends. |
96 | // |
97 | // The runtime interfaces in ProgramImpl should be nothing but a simple |
98 | // wrapper. The one with actual implementation should go inside |
99 | // LlvmRuntimeExecutor class. |
100 | |
101 | public: |
102 | /** |
103 | * Initializes the runtime system for LLVM based backends. |
104 | */ |
105 | void materialize_runtime(MemoryPool *memory_pool, |
106 | KernelProfilerBase *profiler, |
107 | uint64 **result_buffer_ptr) override { |
108 | runtime_exec_->materialize_runtime(memory_pool, profiler, |
109 | result_buffer_ptr); |
110 | } |
111 | |
112 | void destroy_snode_tree(SNodeTree *snode_tree) override { |
113 | return runtime_exec_->destroy_snode_tree(snode_tree); |
114 | } |
115 | |
116 | template <typename T> |
117 | T fetch_result(int i, uint64 *result_buffer) { |
118 | return runtime_exec_->fetch_result<T>(i, result_buffer); |
119 | } |
120 | |
121 | void finalize() override { |
122 | runtime_exec_->finalize(); |
123 | } |
124 | |
125 | uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) override { |
126 | return runtime_exec_->get_ndarray_alloc_info_ptr(alloc); |
127 | } |
128 | |
129 | void fill_ndarray(const DeviceAllocation &alloc, |
130 | std::size_t size, |
131 | uint32_t data) override { |
132 | return runtime_exec_->fill_ndarray(alloc, size, data); |
133 | } |
134 | |
135 | void prepare_runtime_context(RuntimeContext *ctx) override { |
136 | runtime_exec_->prepare_runtime_context(ctx); |
137 | } |
138 | |
139 | DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, |
140 | uint64 *result_buffer) override { |
141 | return runtime_exec_->allocate_memory_ndarray(alloc_size, result_buffer); |
142 | } |
143 | |
144 | Device *get_compute_device() override { |
145 | return runtime_exec_->get_compute_device(); |
146 | } |
147 | |
148 | /** |
149 | * Initializes the SNodes for LLVM based backends. |
150 | */ |
151 | void initialize_llvm_runtime_snodes( |
152 | const LlvmOfflineCache::FieldCacheData &field_cache_data, |
153 | uint64 *result_buffer) { |
154 | runtime_exec_->initialize_llvm_runtime_snodes(field_cache_data, |
155 | result_buffer); |
156 | } |
157 | |
158 | uint64 fetch_result_uint64(int i, uint64 *result_buffer) override { |
159 | return runtime_exec_->fetch_result_uint64(i, result_buffer); |
160 | } |
161 | |
162 | TypedConstant fetch_result(char *result_buffer, |
163 | int offset, |
164 | const Type *dt) override { |
165 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
166 | return TypedConstant( |
167 | runtime_exec_->fetch_result<float32>(result_buffer, offset)); |
168 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
169 | return TypedConstant( |
170 | runtime_exec_->fetch_result<float64>(result_buffer, offset)); |
171 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
172 | return TypedConstant( |
173 | runtime_exec_->fetch_result<int32>(result_buffer, offset)); |
174 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
175 | return TypedConstant( |
176 | runtime_exec_->fetch_result<int64>(result_buffer, offset)); |
177 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
178 | return TypedConstant( |
179 | runtime_exec_->fetch_result<int8>(result_buffer, offset)); |
180 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
181 | return TypedConstant( |
182 | runtime_exec_->fetch_result<int16>(result_buffer, offset)); |
183 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
184 | return TypedConstant( |
185 | runtime_exec_->fetch_result<uint8>(result_buffer, offset)); |
186 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
187 | return TypedConstant( |
188 | runtime_exec_->fetch_result<uint16>(result_buffer, offset)); |
189 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
190 | return TypedConstant( |
191 | runtime_exec_->fetch_result<uint32>(result_buffer, offset)); |
192 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
193 | return TypedConstant( |
194 | runtime_exec_->fetch_result<uint64>(result_buffer, offset)); |
195 | } else if (dt->is_primitive(PrimitiveTypeID::f16)) { |
196 | // first fetch the data as u16, and then convert it to f32 |
197 | uint16 half = runtime_exec_->fetch_result<uint16>(result_buffer, offset); |
198 | return TypedConstant(bit::half_to_float(half)); |
199 | } else { |
200 | TI_NOT_IMPLEMENTED |
201 | } |
202 | } |
203 | |
204 | template <typename T, typename... Args> |
205 | T runtime_query(const std::string &key, |
206 | uint64 *result_buffer, |
207 | Args &&...args) { |
208 | return runtime_exec_->runtime_query<T>(key, result_buffer, |
209 | std::forward<Args>(args)...); |
210 | } |
211 | |
212 | void print_list_manager_info(void *list_manager, uint64 *result_buffer) { |
213 | runtime_exec_->print_list_manager_info(list_manager, result_buffer); |
214 | } |
215 | |
216 | void print_memory_profiler_info( |
217 | std::vector<std::unique_ptr<SNodeTree>> &snode_trees_, |
218 | uint64 *result_buffer) override { |
219 | runtime_exec_->print_memory_profiler_info(snode_trees_, result_buffer); |
220 | } |
221 | |
222 | TaichiLLVMContext *get_llvm_context() { |
223 | return runtime_exec_->get_llvm_context(); |
224 | } |
225 | |
226 | void synchronize() override { |
227 | runtime_exec_->synchronize(); |
228 | } |
229 | |
230 | LLVMRuntime *get_llvm_runtime() { |
231 | return runtime_exec_->get_llvm_runtime(); |
232 | } |
233 | |
234 | std::size_t get_snode_num_dynamically_allocated( |
235 | SNode *snode, |
236 | uint64 *result_buffer) override { |
237 | return runtime_exec_->get_snode_num_dynamically_allocated(snode, |
238 | result_buffer); |
239 | } |
240 | |
241 | void check_runtime_error(uint64 *result_buffer) override { |
242 | runtime_exec_->check_runtime_error(result_buffer); |
243 | } |
244 | |
245 | size_t get_field_in_tree_offset(int tree_id, const SNode *child) override { |
246 | // FIXME: Compute the proper offset. Current method taken from GGUI code |
247 | size_t offset = 0; |
248 | |
249 | SNode *dense_parent = child->parent; |
250 | SNode *root = dense_parent->parent; |
251 | |
252 | int child_id = root->child_id(dense_parent); |
253 | |
254 | for (int i = 0; i < child_id; ++i) { |
255 | SNode *child = root->ch[i].get(); |
256 | offset += child->cell_size_bytes * child->num_cells_per_container; |
257 | } |
258 | |
259 | return offset; |
260 | } |
261 | |
262 | DevicePtr get_snode_tree_device_ptr(int tree_id) override { |
263 | return runtime_exec_->get_snode_tree_device_ptr(tree_id); |
264 | } |
265 | |
266 | cuda::CudaDevice *cuda_device() { |
267 | return runtime_exec_->cuda_device(); |
268 | } |
269 | |
270 | cpu::CpuDevice *cpu_device() { |
271 | return runtime_exec_->cpu_device(); |
272 | } |
273 | |
274 | LlvmDevice *llvm_device() { |
275 | return runtime_exec_->llvm_device(); |
276 | } |
277 | |
278 | LlvmRuntimeExecutor *get_runtime_executor() { |
279 | return runtime_exec_.get(); |
280 | } |
281 | |
282 | const std::unique_ptr<LlvmOfflineCacheFileReader> &get_cache_reader() { |
283 | return cache_reader_; |
284 | } |
285 | |
286 | std::string get_kernel_return_data_layout() override { |
287 | return get_llvm_context()->get_data_layout_string(); |
288 | }; |
289 | |
290 | std::string get_kernel_argument_data_layout() override { |
291 | return get_llvm_context()->get_data_layout_string(); |
292 | }; |
293 | const StructType *get_struct_type_with_data_layout( |
294 | const StructType *old_ty, |
295 | const std::string &layout) override { |
296 | return get_llvm_context()->get_struct_type_with_data_layout(old_ty, layout); |
297 | } |
298 | |
299 | // TODO(zhanlue): Rearrange llvm::Context's ownership |
300 | // |
301 | // In LLVM backend, most of the compiled information are stored in |
302 | // llvm::Module: |
303 | // 1. Runtime functions are compiled into runtime_module, |
304 | // 2. Fields are compiled into struct_module, |
305 | // 3. Each kernel is compiled into individual kernel_module |
306 | // |
307 | // However, all the llvm::Modules are owned by llvm::Context, which belongs to |
308 | // TaichiLLVMContext. Upon destruction, there's an implicit requirement that |
309 | // TaichiLLVMContext has to stay alive until all the llvm::Modules are |
310 | // destructed, otherwise there will be risks of dangling references. |
311 | // |
312 | // To guarantee the life cycle of llvm::Module stay aligned with |
313 | // llvm::Context, we better make llvm::Context a more global-scoped variable, |
314 | // instead of owned by TaichiLLVMContext. |
315 | // |
316 | // Objects owning llvm::Module so far (from direct to indirect): |
317 | // 1. LlvmOfflineCache::CachedKernelData(direct owner) |
318 | // 2. LlvmOfflineCache |
319 | // 3.1 LlvmProgramImpl |
320 | // 3.2 LlvmAotModuleBuilder |
321 | // 3.3 llvm_aot::KernelImpl (for use in CGraph) |
322 | // |
323 | // Objects owning llvm::Context (from direct to indirect) |
324 | // 1. TaichiLLVMContext |
325 | // 2. LlvmProgramImpl |
326 | // |
327 | // Make sure the above mentioned objects are destructed in order. |
328 | ~LlvmProgramImpl() override { |
329 | // Explicitly enforce "LlvmOfflineCache::CachedKernelData::owned_module" |
330 | // destructs before |
331 | // "LlvmRuntimeExecutor::TaichiLLVMContext::ThreadSafeContext" |
332 | |
333 | // 1. Destructs cache_data_ |
334 | cache_data_.reset(); |
335 | |
336 | // 2. Destructs cache_reader_ |
337 | cache_reader_.reset(); |
338 | |
339 | // 3. Destructs runtime_exec_ |
340 | runtime_exec_.reset(); |
341 | } |
342 | ParallelExecutor compilation_workers; // parallel compilation |
343 | |
344 | private: |
345 | std::size_t num_snode_trees_processed_{0}; |
346 | std::unique_ptr<LlvmRuntimeExecutor> runtime_exec_; |
347 | std::unique_ptr<LlvmOfflineCache> cache_data_; |
348 | std::unique_ptr<LlvmOfflineCacheFileReader> cache_reader_; |
349 | }; |
350 | |
/**
 * Presumably returns the LlvmProgramImpl backing |prog|.
 * NOTE(review): the definition is not in this header; it likely downcasts
 * prog's ProgramImpl and is only meaningful for LLVM-based backends — confirm
 * against the implementation before relying on it.
 */
LlvmProgramImpl *get_llvm_program(Program *prog);
352 | |
353 | } // namespace taichi::lang |
354 | |