1 | #include "taichi/codegen/llvm/struct_llvm.h" |
2 | |
3 | #include "llvm/IR/Verifier.h" |
4 | #include "llvm/IR/IRBuilder.h" |
5 | |
6 | #include "taichi/ir/ir.h" |
7 | #include "taichi/struct/struct.h" |
8 | #include "taichi/util/file_sequence_writer.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | StructCompilerLLVM::StructCompilerLLVM(Arch arch, |
13 | const CompileConfig &config, |
14 | TaichiLLVMContext *tlctx, |
15 | std::unique_ptr<llvm::Module> &&module, |
16 | int snode_tree_id) |
17 | : LLVMModuleBuilder(std::move(module), tlctx), |
18 | arch_(arch), |
19 | config_(config), |
20 | tlctx_(tlctx), |
21 | llvm_ctx_(tlctx_->get_this_thread_context()), |
22 | snode_tree_id_(snode_tree_id) { |
23 | } |
24 | |
25 | StructCompilerLLVM::StructCompilerLLVM(Arch arch, |
26 | LlvmProgramImpl *prog, |
27 | std::unique_ptr<llvm::Module> &&module, |
28 | int snode_tree_id) |
29 | : StructCompilerLLVM(arch, |
30 | *prog->config, |
31 | prog->get_llvm_context(), |
32 | std::move(module), |
33 | snode_tree_id) { |
34 | } |
35 | |
36 | void StructCompilerLLVM::generate_types(SNode &snode) { |
37 | TI_AUTO_PROF; |
38 | auto type = snode.type; |
39 | if (snode.is_bit_level) |
40 | return; |
41 | llvm::Type *node_type = nullptr; |
42 | |
43 | auto ctx = llvm_ctx_; |
44 | TI_ASSERT(ctx == tlctx_->get_this_thread_context()); |
45 | |
46 | // create children type that supports forking... |
47 | |
48 | std::vector<llvm::Type *> ch_types; |
49 | for (int i = 0; i < snode.ch.size(); i++) { |
50 | if (!snode.ch[i]->is_bit_level) { |
51 | // Bit-level SNodes do not really have a corresponding LLVM type |
52 | auto ch = get_llvm_node_type(module.get(), snode.ch[i].get()); |
53 | ch_types.push_back(ch); |
54 | } |
55 | } |
56 | |
57 | auto ch_type = |
58 | llvm::StructType::create(*ctx, ch_types, snode.node_type_name + "_ch" ); |
59 | |
60 | snode.cell_size_bytes = tlctx_->get_type_size(ch_type); |
61 | |
62 | for (int i = 0; i < snode.ch.size(); i++) { |
63 | if (!snode.ch[i]->is_bit_level) { |
64 | snode.ch[i]->offset_bytes_in_parent_cell = |
65 | tlctx_->get_struct_element_offset(ch_type, i); |
66 | } |
67 | } |
68 | |
69 | llvm::Type *body_type = nullptr, *aux_type = nullptr; |
70 | if (type == SNodeType::dense || type == SNodeType::bitmasked) { |
71 | TI_ASSERT(snode._morton == false); |
72 | body_type = llvm::ArrayType::get(ch_type, snode.max_num_elements()); |
73 | if (type == SNodeType::bitmasked) { |
74 | aux_type = llvm::ArrayType::get(llvm::Type::getInt32Ty(*llvm_ctx_), |
75 | (snode.max_num_elements() + 31) / 32); |
76 | } |
77 | } else if (type == SNodeType::root) { |
78 | body_type = ch_type; |
79 | } else if (type == SNodeType::place) { |
80 | body_type = tlctx_->get_data_type(snode.dt); |
81 | } else if (type == SNodeType::bit_struct) { |
82 | if (!arch_is_cpu(arch_)) { |
83 | TI_ERROR_IF(data_type_bits(snode.physical_type) < 32, |
84 | "bit_struct physical type must be at least 32 bits on " |
85 | "non-CPU backends." ); |
86 | } |
87 | body_type = tlctx_->get_data_type(snode.physical_type); |
88 | } else if (type == SNodeType::quant_array) { |
89 | // A quant array SNode should have only one child |
90 | TI_ASSERT(snode.ch.size() == 1); |
91 | auto &ch = snode.ch[0]; |
92 | Type *ch_type = ch->dt; |
93 | if (!arch_is_cpu(arch_)) { |
94 | TI_ERROR_IF(data_type_bits(snode.physical_type) <= 16, |
95 | "quant_array physical type must be at least 32 bits on " |
96 | "non-CPU backends." ); |
97 | } |
98 | snode.dt = TypeFactory::get_instance().get_quant_array_type( |
99 | snode.physical_type, ch_type, snode.num_cells_per_container); |
100 | |
101 | DataType container_primitive_type(snode.physical_type); |
102 | body_type = tlctx_->get_data_type(container_primitive_type); |
103 | } else if (type == SNodeType::pointer) { |
104 | // mutex |
105 | aux_type = llvm::ArrayType::get(llvm::PointerType::getInt64Ty(*ctx), |
106 | snode.max_num_elements()); |
107 | body_type = llvm::ArrayType::get(llvm::PointerType::getInt8PtrTy(*ctx), |
108 | snode.max_num_elements()); |
109 | } else if (type == SNodeType::dynamic) { |
110 | // mutex and n (number of elements) |
111 | aux_type = |
112 | llvm::StructType::get(*ctx, {llvm::PointerType::getInt32Ty(*ctx), |
113 | llvm::PointerType::getInt32Ty(*ctx)}); |
114 | body_type = llvm::PointerType::getInt8PtrTy(*ctx); |
115 | } else { |
116 | TI_P(snode.type_name()); |
117 | TI_NOT_IMPLEMENTED; |
118 | } |
119 | if (aux_type != nullptr) { |
120 | node_type = llvm::StructType::create(*ctx, {aux_type, body_type}, "" ); |
121 | } else { |
122 | node_type = body_type; |
123 | } |
124 | |
125 | TI_ASSERT(node_type != nullptr); |
126 | TI_ASSERT(body_type != nullptr); |
127 | |
128 | // Here we create a stub holding 4 LLVM types as struct members. |
129 | // The aim is to give a **unique** name to the stub, so that we can look up |
130 | // these types using this name. This decouples them from the LLVM context. |
131 | // Note that body_type might not have a unique name, since literal structs |
132 | // (such as {i32, i32}) cannot be aliased in LLVM. |
133 | auto stub = llvm::StructType::create( |
134 | *ctx, |
135 | {node_type, body_type, aux_type ? aux_type : llvm::Type::getInt8Ty(*ctx), |
136 | // aux_type might be null |
137 | ch_type}, |
138 | type_stub_name(&snode)); |
139 | |
140 | // Create a dummy function in the module with the type stub as return type |
141 | // so that the type is referenced in the module |
142 | auto ft = llvm::FunctionType::get(stub, false); |
143 | create_function(ft, type_stub_name(&snode) + "_func" ); |
144 | } |
145 | |
146 | void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { |
147 | TI_AUTO_PROF; |
148 | auto coord_type = get_runtime_type("PhysicalCoordinates" ); |
149 | auto coord_type_ptr = llvm::PointerType::get(coord_type, 0); |
150 | |
151 | auto ft = llvm::FunctionType::get( |
152 | llvm::Type::getVoidTy(*llvm_ctx_), |
153 | {coord_type_ptr, coord_type_ptr, llvm::Type::getInt32Ty(*llvm_ctx_)}, |
154 | false); |
155 | |
156 | auto func = create_function(ft, snode->refine_coordinates_func_name()); |
157 | |
158 | auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry" , func); |
159 | |
160 | llvm::IRBuilder<> builder(bb, bb->begin()); |
161 | std::vector<llvm::Value *> args; |
162 | |
163 | for (auto &arg : func->args()) { |
164 | args.push_back(&arg); |
165 | } |
166 | |
167 | auto inp_coords = args[0]; |
168 | auto outp_coords = args[1]; |
169 | auto l = args[2]; |
170 | |
171 | for (int i = 0; i < taichi_max_num_indices; i++) { |
172 | auto addition = tlctx_->get_constant(0); |
173 | if (snode->extractors[i].shape > 1) { |
174 | auto prev = tlctx_->get_constant(snode->extractors[i].acc_shape * |
175 | snode->extractors[i].shape); |
176 | auto next = tlctx_->get_constant(snode->extractors[i].acc_shape); |
177 | // Use UDiv/URem instead of SDiv/SRem so that LLVM can optimize them |
178 | // into bitwise operations when the divisor is a power of two. |
179 | addition = builder.CreateUDiv(builder.CreateURem(l, prev), next); |
180 | } |
181 | auto in = call(&builder, "PhysicalCoordinates_get_val" , inp_coords, |
182 | tlctx_->get_constant(i)); |
183 | in = |
184 | builder.CreateMul(in, tlctx_->get_constant(snode->extractors[i].shape)); |
185 | auto added = builder.CreateAdd(in, addition); |
186 | call(&builder, "PhysicalCoordinates_set_val" , outp_coords, |
187 | tlctx_->get_constant(i), added); |
188 | } |
189 | builder.CreateRetVoid(); |
190 | } |
191 | |
192 | void StructCompilerLLVM::generate_child_accessors(SNode &snode) { |
193 | TI_AUTO_PROF; |
194 | auto type = snode.type; |
195 | stack.push_back(&snode); |
196 | |
197 | bool is_leaf = type == SNodeType::place; |
198 | |
199 | if (!is_leaf) { |
200 | generate_refine_coordinates(&snode); |
201 | } |
202 | |
203 | if (snode.parent != nullptr) { |
204 | // create the get ch function |
205 | auto parent = snode.parent; |
206 | |
207 | auto inp_type = |
208 | llvm::PointerType::get(get_llvm_element_type(module.get(), parent), 0); |
209 | |
210 | auto ft = |
211 | llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx_), |
212 | {llvm::Type::getInt8PtrTy(*llvm_ctx_)}, false); |
213 | |
214 | auto func = create_function(ft, snode.get_ch_from_parent_func_name()); |
215 | |
216 | auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry" , func); |
217 | |
218 | llvm::IRBuilder<> builder(bb, bb->begin()); |
219 | std::vector<llvm::Value *> args; |
220 | |
221 | for (auto &arg : func->args()) { |
222 | args.push_back(&arg); |
223 | } |
224 | llvm::Value *ret; |
225 | ret = builder.CreateGEP(get_llvm_element_type(module.get(), parent), |
226 | builder.CreateBitCast(args[0], inp_type), |
227 | {tlctx_->get_constant(0), |
228 | tlctx_->get_constant(parent->child_id(&snode))}, |
229 | "getch" ); |
230 | |
231 | builder.CreateRet( |
232 | builder.CreateBitCast(ret, llvm::Type::getInt8PtrTy(*llvm_ctx_))); |
233 | } |
234 | |
235 | for (auto &ch : snode.ch) { |
236 | if (!ch->is_bit_level) |
237 | generate_child_accessors(*ch); |
238 | } |
239 | |
240 | stack.pop_back(); |
241 | } |
242 | |
243 | std::string StructCompilerLLVM::type_stub_name(SNode *snode) { |
244 | return snode->node_type_name + "_type_stubs" ; |
245 | } |
246 | |
247 | void StructCompilerLLVM::run(SNode &root) { |
248 | TI_AUTO_PROF; |
249 | // bottom to top |
250 | collect_snodes(root); |
251 | |
252 | auto snodes_rev = snodes; |
253 | std::reverse(snodes_rev.begin(), snodes_rev.end()); |
254 | |
255 | for (auto &n : snodes_rev) |
256 | generate_types(*n); |
257 | |
258 | generate_child_accessors(root); |
259 | |
260 | if (config_.print_struct_llvm_ir) { |
261 | static FileSequenceWriter writer("taichi_struct_llvm_ir_{:04d}.ll" , |
262 | "struct LLVM IR" ); |
263 | writer.write(module.get()); |
264 | } |
265 | |
266 | TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes); |
267 | |
268 | auto node_type = get_llvm_node_type(module.get(), &root); |
269 | root_size = tlctx_->get_type_size(node_type); |
270 | |
271 | tlctx_->add_struct_module(std::move(module), root.get_snode_tree_id()); |
272 | } |
273 | |
274 | llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module, |
275 | SNode *snode, |
276 | uint32 index) { |
277 | TI_ASSERT(module); |
278 | TI_ASSERT(snode); |
279 | auto stub = llvm::StructType::getTypeByName(module->getContext(), |
280 | type_stub_name(snode)); |
281 | TI_ASSERT(stub); |
282 | TI_ASSERT(stub->getStructNumElements() == 4); |
283 | TI_ASSERT(0 <= index && index < 4); |
284 | auto type = stub->getContainedType(index); |
285 | TI_ASSERT(type); |
286 | return type; |
287 | } |
288 | |
289 | llvm::Type *StructCompilerLLVM::get_llvm_node_type(llvm::Module *module, |
290 | SNode *snode) { |
291 | return get_stub(module, snode, 0); |
292 | } |
293 | |
294 | llvm::Type *StructCompilerLLVM::get_llvm_body_type(llvm::Module *module, |
295 | SNode *snode) { |
296 | return get_stub(module, snode, 1); |
297 | } |
298 | |
299 | llvm::Type *StructCompilerLLVM::get_llvm_aux_type(llvm::Module *module, |
300 | SNode *snode) { |
301 | return get_stub(module, snode, 2); |
302 | } |
303 | |
304 | llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module, |
305 | SNode *snode) { |
306 | return get_stub(module, snode, 3); |
307 | } |
308 | |
309 | llvm::Function *StructCompilerLLVM::create_function(llvm::FunctionType *ft, |
310 | std::string func_name) { |
311 | return llvm::Function::Create(ft, llvm::Function::ExternalLinkage, func_name, |
312 | *module); |
313 | } |
314 | |
315 | } // namespace taichi::lang |
316 | |