1 | #pragma once |
2 | |
3 | #include <cstddef> |
4 | #include <memory> |
5 | |
6 | #ifdef TI_WITH_LLVM |
7 | |
8 | #include "taichi/rhi/llvm/llvm_device.h" |
9 | #include "taichi/runtime/llvm/llvm_offline_cache.h" |
10 | #include "taichi/runtime/llvm/snode_tree_buffer_manager.h" |
11 | #include "taichi/runtime/llvm/llvm_context.h" |
12 | #include "taichi/struct/snode_tree.h" |
13 | #include "taichi/program/compile_config.h" |
14 | |
15 | #include "taichi/system/threading.h" |
16 | #include "taichi/system/memory_pool.h" |
17 | |
18 | #define TI_RUNTIME_HOST |
19 | #include "taichi/program/context.h" |
20 | #undef TI_RUNTIME_HOST |
21 | |
22 | namespace taichi::lang { |
23 | |
// Forward declarations of the backend device classes. Further down, this
// header only names them as pointer return types (cuda_device(), cpu_device(),
// amdgpu_device()), so incomplete types are sufficient here.
namespace cpu {
class CpuDevice;
}  // namespace cpu

namespace cuda {
class CudaDevice;
}  // namespace cuda

namespace amdgpu {
class AmdgpuDevice;
}  // namespace amdgpu
35 | |
36 | class LlvmRuntimeExecutor { |
37 | public: |
38 | LlvmRuntimeExecutor(CompileConfig &config, KernelProfilerBase *profiler); |
39 | |
40 | /** |
41 | * Initializes the runtime system for LLVM based backends. |
42 | */ |
43 | void materialize_runtime(MemoryPool *memory_pool, |
44 | KernelProfilerBase *profiler, |
45 | uint64 **result_buffer_ptr); |
46 | |
47 | // SNodeTree Allocation |
48 | void initialize_llvm_runtime_snodes( |
49 | const LlvmOfflineCache::FieldCacheData &field_cache_data, |
50 | uint64 *result_buffer); |
51 | |
52 | // Ndarray Allocation |
53 | DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, |
54 | uint64 *result_buffer); |
55 | |
56 | void deallocate_memory_ndarray(DeviceAllocation handle); |
57 | |
58 | void check_runtime_error(uint64 *result_buffer); |
59 | |
60 | uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc); |
61 | |
62 | const CompileConfig &get_config() const { |
63 | return config_; |
64 | } |
65 | |
66 | TaichiLLVMContext *get_llvm_context(); |
67 | |
68 | JITModule *create_jit_module(std::unique_ptr<llvm::Module> module); |
69 | |
70 | JITModule *get_runtime_jit_module(); |
71 | |
72 | LLVMRuntime *get_llvm_runtime(); |
73 | |
74 | void prepare_runtime_context(RuntimeContext *ctx); |
75 | |
76 | Device *get_compute_device(); |
77 | |
78 | LlvmDevice *llvm_device(); |
79 | |
80 | void synchronize(); |
81 | |
82 | private: |
83 | /* ----------------------- */ |
84 | /* ------ Allocation ----- */ |
85 | /* ----------------------- */ |
86 | template <typename T> |
87 | T fetch_result(int i, uint64 *result_buffer) { |
88 | return taichi_union_cast_with_different_sizes<T>( |
89 | fetch_result_uint64(i, result_buffer)); |
90 | } |
91 | |
92 | template <typename T> |
93 | T fetch_result(char *result_buffer, int offset) { |
94 | T ret; |
95 | fetch_result_impl(&ret, result_buffer, offset, sizeof(T)); |
96 | return ret; |
97 | } |
98 | |
99 | void fetch_result_impl(void *dest, char *result_buffer, int offset, int size); |
100 | |
101 | DevicePtr get_snode_tree_device_ptr(int tree_id); |
102 | |
103 | void fill_ndarray(const DeviceAllocation &alloc, |
104 | std::size_t size, |
105 | uint32_t data); |
106 | |
107 | /* ------------------------- */ |
108 | /* ---- Runtime Helpers ---- */ |
109 | /* ------------------------- */ |
110 | void print_list_manager_info(void *list_manager, uint64 *result_buffer); |
111 | void print_memory_profiler_info( |
112 | std::vector<std::unique_ptr<SNodeTree>> &snode_trees_, |
113 | uint64 *result_buffer); |
114 | |
115 | template <typename T, typename... Args> |
116 | T runtime_query(const std::string &key, |
117 | uint64 *result_buffer, |
118 | Args &&...args) { |
119 | TI_ASSERT(arch_uses_llvm(config_.arch)); |
120 | |
121 | auto runtime = get_runtime_jit_module(); |
122 | runtime->call<void *>("runtime_" + key, llvm_runtime_, |
123 | std::forward<Args>(args)...); |
124 | return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64( |
125 | taichi_result_buffer_runtime_query_id, result_buffer)); |
126 | } |
127 | |
128 | /* -------------------------- */ |
129 | /* ------ Member Access ----- */ |
130 | /* -------------------------- */ |
131 | cuda::CudaDevice *cuda_device(); |
132 | cpu::CpuDevice *cpu_device(); |
133 | amdgpu::AmdgpuDevice *amdgpu_device(); |
134 | |
135 | void finalize(); |
136 | |
137 | uint64 fetch_result_uint64(int i, uint64 *result_buffer); |
138 | void destroy_snode_tree(SNodeTree *snode_tree); |
139 | std::size_t get_snode_num_dynamically_allocated(SNode *snode, |
140 | uint64 *result_buffer); |
141 | |
142 | void init_runtime_jit_module(std::unique_ptr<llvm::Module> module); |
143 | |
144 | private: |
145 | CompileConfig &config_; |
146 | |
147 | // TODO(zhanlue): compile - runtime split for TaichiLLVMContext |
148 | // |
149 | // TaichiLLVMContext is a thread-safe class with llvm::Module for compilation |
150 | // and JITSession/JITModule for runtime loading & execution |
151 | std::unique_ptr<TaichiLLVMContext> llvm_context_{nullptr}; |
152 | JITModule *runtime_jit_module_{nullptr}; |
153 | void *llvm_runtime_{nullptr}; |
154 | |
155 | std::unique_ptr<ThreadPool> thread_pool_{nullptr}; |
156 | std::shared_ptr<Device> device_{nullptr}; |
157 | |
158 | std::unique_ptr<SNodeTreeBufferManager> snode_tree_buffer_manager_{nullptr}; |
159 | std::unordered_map<int, DeviceAllocation> snode_tree_allocs_; |
160 | void *preallocated_device_buffer_{nullptr}; // TODO: move to memory allocator |
161 | DeviceAllocation preallocated_device_buffer_alloc_{kDeviceNullAllocation}; |
162 | |
163 | // good buddy |
164 | friend LlvmProgramImpl; |
165 | friend SNodeTreeBufferManager; |
166 | |
167 | KernelProfilerBase *profiler_ = nullptr; |
168 | }; |
169 | |
170 | } // namespace taichi::lang |
171 | |
172 | #endif // TI_WITH_LLVM |
173 | |