1 | #pragma once |
2 | |
// LLVM backend compiler (x64, arm64, CUDA, AMDGPU, etc.):
// in charge of creating and JIT-compiling arch-specific LLVM modules,
// and invoking compiled functions (kernels).
// Designed to be multithreaded for parallel compilation.
7 | |
8 | #include <mutex> |
9 | #include <functional> |
10 | #include <thread> |
11 | |
12 | #include "taichi/util/lang_util.h" |
13 | #include "taichi/runtime/llvm/llvm_fwd.h" |
14 | #include "taichi/ir/snode.h" |
15 | #include "taichi/jit/jit_session.h" |
16 | #include "taichi/codegen/llvm/llvm_compiled_data.h" |
17 | |
18 | namespace taichi::lang { |
19 | |
20 | class JITSessionCPU; |
21 | class LlvmProgramImpl; |
22 | |
23 | /** |
24 | * Manages an LLVMContext for Taichi's usage. |
25 | */ |
26 | class TaichiLLVMContext { |
27 | private: |
28 | struct ThreadLocalData { |
29 | std::unique_ptr<llvm::orc::ThreadSafeContext> thread_safe_llvm_context{ |
30 | nullptr}; |
31 | llvm::LLVMContext *llvm_context{nullptr}; |
32 | std::unique_ptr<llvm::Module> runtime_module{nullptr}; |
33 | std::unordered_map<int, std::unique_ptr<llvm::Module>> struct_modules; |
34 | explicit ThreadLocalData(std::unique_ptr<llvm::orc::ThreadSafeContext> ctx); |
35 | ~ThreadLocalData(); |
36 | }; |
37 | const CompileConfig &config_; |
38 | |
39 | public: |
40 | std::unique_ptr<JITSession> jit{nullptr}; |
41 | // main_thread is defined to be the thread that runs the initializer |
42 | |
43 | std::unique_ptr<ThreadLocalData> linking_context_data{nullptr}; |
44 | |
45 | TaichiLLVMContext(const CompileConfig &config, Arch arch); |
46 | |
47 | virtual ~TaichiLLVMContext(); |
48 | |
49 | llvm::LLVMContext *get_this_thread_context(); |
50 | |
51 | llvm::orc::ThreadSafeContext *get_this_thread_thread_safe_context(); |
52 | |
53 | /** |
54 | * Updates the LLVM module of the JIT compiled SNode structs. |
55 | * |
56 | * @param module Module containing the JIT compiled SNode structs. |
57 | */ |
58 | void add_struct_module(std::unique_ptr<llvm::Module> module, int tree_id); |
59 | |
60 | void init_runtime_module(llvm::Module *runtime_module); |
61 | |
62 | /** |
63 | * Clones the LLVM module compiled from llvm/runtime.cpp |
64 | * |
65 | * @return The cloned module. |
66 | */ |
67 | std::unique_ptr<llvm::Module> clone_runtime_module(); |
68 | |
69 | std::unique_ptr<llvm::Module> module_from_file(const std::string &file); |
70 | |
71 | llvm::Type *get_data_type(DataType dt); |
72 | |
73 | template <typename T> |
74 | llvm::Type *get_data_type() { |
75 | return TaichiLLVMContext::get_data_type(taichi::lang::get_data_type<T>()); |
76 | } |
77 | |
78 | std::size_t get_type_size(llvm::Type *type); |
79 | |
80 | std::size_t get_struct_element_offset(llvm::StructType *type, int idx); |
81 | |
82 | const StructType *get_struct_type_with_data_layout(const StructType *old_ty, |
83 | const std::string &layout); |
84 | |
85 | template <typename T> |
86 | llvm::Value *get_constant(T t); |
87 | |
88 | template <typename T> |
89 | llvm::Value *get_constant(DataType dt, T t); |
90 | |
91 | llvm::DataLayout get_data_layout(); |
92 | |
93 | std::string get_data_layout_string(); |
94 | |
95 | std::string type_name(llvm::Type *type); |
96 | |
97 | static void mark_inline(llvm::Function *func); |
98 | |
99 | static void print_huge_functions(llvm::Module *module); |
100 | |
101 | // remove all functions that are not (directly & indirectly) used by those |
102 | // with export_indicator(func_name) = true |
103 | static void eliminate_unused_functions( |
104 | llvm::Module *module, |
105 | std::function<bool(const std::string &)> export_indicator); |
106 | |
107 | void mark_function_as_cuda_kernel(llvm::Function *func, int block_dim = 0); |
108 | |
109 | void mark_function_as_amdgpu_kernel(llvm::Function *func); |
110 | |
111 | void fetch_this_thread_struct_module(); |
112 | llvm::Module *get_this_thread_runtime_module(); |
113 | llvm::Function *get_runtime_function(const std::string &name); |
114 | llvm::Function *get_struct_function(const std::string &name, int tree_id); |
115 | llvm::Type *get_runtime_type(const std::string &name); |
116 | |
117 | std::unique_ptr<llvm::Module> new_module( |
118 | std::string name, |
119 | llvm::LLVMContext *context = nullptr); |
120 | |
121 | void delete_snode_tree(int id); |
122 | |
123 | void add_struct_for_func(llvm::Module *module, int tls_size); |
124 | |
125 | static std::string get_struct_for_func_name(int tls_size); |
126 | |
127 | LLVMCompiledKernel link_compiled_tasks( |
128 | std::vector<std::unique_ptr<LLVMCompiledTask>> data_list); |
129 | |
130 | private: |
131 | std::unique_ptr<llvm::Module> clone_module_to_context( |
132 | llvm::Module *module, |
133 | llvm::LLVMContext *target_context); |
134 | |
135 | void link_module_with_cuda_libdevice(std::unique_ptr<llvm::Module> &module); |
136 | |
137 | void link_module_with_amdgpu_libdevice(std::unique_ptr<llvm::Module> &module); |
138 | |
139 | static int num_instructions(llvm::Function *func); |
140 | |
141 | void insert_nvvm_annotation(llvm::Function *func, std::string key, int val); |
142 | |
143 | std::unique_ptr<llvm::Module> clone_module_to_this_thread_context( |
144 | llvm::Module *module); |
145 | |
146 | ThreadLocalData *get_this_thread_data(); |
147 | |
148 | std::unordered_map<std::thread::id, std::unique_ptr<ThreadLocalData>> |
149 | per_thread_data_; |
150 | |
151 | Arch arch_; |
152 | |
153 | std::thread::id main_thread_id_; |
154 | ThreadLocalData *main_thread_data_{nullptr}; |
155 | std::mutex mut_; |
156 | std::mutex thread_map_mut_; |
157 | |
158 | std::unordered_map<int, std::vector<std::string>> snode_tree_funcs_; |
159 | }; |
160 | |
161 | class LlvmModuleBitcodeLoader { |
162 | public: |
163 | LlvmModuleBitcodeLoader &set_bitcode_path(const std::string &bitcode_path) { |
164 | bitcode_path_ = bitcode_path; |
165 | return *this; |
166 | } |
167 | |
168 | LlvmModuleBitcodeLoader &set_buffer_id(const std::string &buffer_id) { |
169 | buffer_id_ = buffer_id; |
170 | return *this; |
171 | } |
172 | |
173 | LlvmModuleBitcodeLoader &set_inline_funcs(bool inline_funcs) { |
174 | inline_funcs_ = inline_funcs; |
175 | return *this; |
176 | } |
177 | |
178 | std::unique_ptr<llvm::Module> load(llvm::LLVMContext *ctx) const; |
179 | |
180 | private: |
181 | std::string bitcode_path_; |
182 | std::string buffer_id_; |
183 | bool inline_funcs_{false}; |
184 | }; |
185 | |
186 | std::unique_ptr<llvm::Module> module_from_bitcode_file( |
187 | const std::string &bitcode_path, |
188 | llvm::LLVMContext *ctx); |
189 | |
190 | } // namespace taichi::lang |
191 | |