1#include "taichi/codegen/llvm/struct_llvm.h"
2
3#include "llvm/IR/Verifier.h"
4#include "llvm/IR/IRBuilder.h"
5
6#include "taichi/ir/ir.h"
7#include "taichi/struct/struct.h"
8#include "taichi/util/file_sequence_writer.h"
9
10namespace taichi::lang {
11
12StructCompilerLLVM::StructCompilerLLVM(Arch arch,
13 const CompileConfig &config,
14 TaichiLLVMContext *tlctx,
15 std::unique_ptr<llvm::Module> &&module,
16 int snode_tree_id)
17 : LLVMModuleBuilder(std::move(module), tlctx),
18 arch_(arch),
19 config_(config),
20 tlctx_(tlctx),
21 llvm_ctx_(tlctx_->get_this_thread_context()),
22 snode_tree_id_(snode_tree_id) {
23}
24
25StructCompilerLLVM::StructCompilerLLVM(Arch arch,
26 LlvmProgramImpl *prog,
27 std::unique_ptr<llvm::Module> &&module,
28 int snode_tree_id)
29 : StructCompilerLLVM(arch,
30 *prog->config,
31 prog->get_llvm_context(),
32 std::move(module),
33 snode_tree_id) {
34}
35
36void StructCompilerLLVM::generate_types(SNode &snode) {
37 TI_AUTO_PROF;
38 auto type = snode.type;
39 if (snode.is_bit_level)
40 return;
41 llvm::Type *node_type = nullptr;
42
43 auto ctx = llvm_ctx_;
44 TI_ASSERT(ctx == tlctx_->get_this_thread_context());
45
46 // create children type that supports forking...
47
48 std::vector<llvm::Type *> ch_types;
49 for (int i = 0; i < snode.ch.size(); i++) {
50 if (!snode.ch[i]->is_bit_level) {
51 // Bit-level SNodes do not really have a corresponding LLVM type
52 auto ch = get_llvm_node_type(module.get(), snode.ch[i].get());
53 ch_types.push_back(ch);
54 }
55 }
56
57 auto ch_type =
58 llvm::StructType::create(*ctx, ch_types, snode.node_type_name + "_ch");
59
60 snode.cell_size_bytes = tlctx_->get_type_size(ch_type);
61
62 for (int i = 0; i < snode.ch.size(); i++) {
63 if (!snode.ch[i]->is_bit_level) {
64 snode.ch[i]->offset_bytes_in_parent_cell =
65 tlctx_->get_struct_element_offset(ch_type, i);
66 }
67 }
68
69 llvm::Type *body_type = nullptr, *aux_type = nullptr;
70 if (type == SNodeType::dense || type == SNodeType::bitmasked) {
71 TI_ASSERT(snode._morton == false);
72 body_type = llvm::ArrayType::get(ch_type, snode.max_num_elements());
73 if (type == SNodeType::bitmasked) {
74 aux_type = llvm::ArrayType::get(llvm::Type::getInt32Ty(*llvm_ctx_),
75 (snode.max_num_elements() + 31) / 32);
76 }
77 } else if (type == SNodeType::root) {
78 body_type = ch_type;
79 } else if (type == SNodeType::place) {
80 body_type = tlctx_->get_data_type(snode.dt);
81 } else if (type == SNodeType::bit_struct) {
82 if (!arch_is_cpu(arch_)) {
83 TI_ERROR_IF(data_type_bits(snode.physical_type) < 32,
84 "bit_struct physical type must be at least 32 bits on "
85 "non-CPU backends.");
86 }
87 body_type = tlctx_->get_data_type(snode.physical_type);
88 } else if (type == SNodeType::quant_array) {
89 // A quant array SNode should have only one child
90 TI_ASSERT(snode.ch.size() == 1);
91 auto &ch = snode.ch[0];
92 Type *ch_type = ch->dt;
93 if (!arch_is_cpu(arch_)) {
94 TI_ERROR_IF(data_type_bits(snode.physical_type) <= 16,
95 "quant_array physical type must be at least 32 bits on "
96 "non-CPU backends.");
97 }
98 snode.dt = TypeFactory::get_instance().get_quant_array_type(
99 snode.physical_type, ch_type, snode.num_cells_per_container);
100
101 DataType container_primitive_type(snode.physical_type);
102 body_type = tlctx_->get_data_type(container_primitive_type);
103 } else if (type == SNodeType::pointer) {
104 // mutex
105 aux_type = llvm::ArrayType::get(llvm::PointerType::getInt64Ty(*ctx),
106 snode.max_num_elements());
107 body_type = llvm::ArrayType::get(llvm::PointerType::getInt8PtrTy(*ctx),
108 snode.max_num_elements());
109 } else if (type == SNodeType::dynamic) {
110 // mutex and n (number of elements)
111 aux_type =
112 llvm::StructType::get(*ctx, {llvm::PointerType::getInt32Ty(*ctx),
113 llvm::PointerType::getInt32Ty(*ctx)});
114 body_type = llvm::PointerType::getInt8PtrTy(*ctx);
115 } else {
116 TI_P(snode.type_name());
117 TI_NOT_IMPLEMENTED;
118 }
119 if (aux_type != nullptr) {
120 node_type = llvm::StructType::create(*ctx, {aux_type, body_type}, "");
121 } else {
122 node_type = body_type;
123 }
124
125 TI_ASSERT(node_type != nullptr);
126 TI_ASSERT(body_type != nullptr);
127
128 // Here we create a stub holding 4 LLVM types as struct members.
129 // The aim is to give a **unique** name to the stub, so that we can look up
130 // these types using this name. This decouples them from the LLVM context.
131 // Note that body_type might not have a unique name, since literal structs
132 // (such as {i32, i32}) cannot be aliased in LLVM.
133 auto stub = llvm::StructType::create(
134 *ctx,
135 {node_type, body_type, aux_type ? aux_type : llvm::Type::getInt8Ty(*ctx),
136 // aux_type might be null
137 ch_type},
138 type_stub_name(&snode));
139
140 // Create a dummy function in the module with the type stub as return type
141 // so that the type is referenced in the module
142 auto ft = llvm::FunctionType::get(stub, false);
143 create_function(ft, type_stub_name(&snode) + "_func");
144}
145
146void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) {
147 TI_AUTO_PROF;
148 auto coord_type = get_runtime_type("PhysicalCoordinates");
149 auto coord_type_ptr = llvm::PointerType::get(coord_type, 0);
150
151 auto ft = llvm::FunctionType::get(
152 llvm::Type::getVoidTy(*llvm_ctx_),
153 {coord_type_ptr, coord_type_ptr, llvm::Type::getInt32Ty(*llvm_ctx_)},
154 false);
155
156 auto func = create_function(ft, snode->refine_coordinates_func_name());
157
158 auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry", func);
159
160 llvm::IRBuilder<> builder(bb, bb->begin());
161 std::vector<llvm::Value *> args;
162
163 for (auto &arg : func->args()) {
164 args.push_back(&arg);
165 }
166
167 auto inp_coords = args[0];
168 auto outp_coords = args[1];
169 auto l = args[2];
170
171 for (int i = 0; i < taichi_max_num_indices; i++) {
172 auto addition = tlctx_->get_constant(0);
173 if (snode->extractors[i].shape > 1) {
174 auto prev = tlctx_->get_constant(snode->extractors[i].acc_shape *
175 snode->extractors[i].shape);
176 auto next = tlctx_->get_constant(snode->extractors[i].acc_shape);
177 // Use UDiv/URem instead of SDiv/SRem so that LLVM can optimize them
178 // into bitwise operations when the divisor is a power of two.
179 addition = builder.CreateUDiv(builder.CreateURem(l, prev), next);
180 }
181 auto in = call(&builder, "PhysicalCoordinates_get_val", inp_coords,
182 tlctx_->get_constant(i));
183 in =
184 builder.CreateMul(in, tlctx_->get_constant(snode->extractors[i].shape));
185 auto added = builder.CreateAdd(in, addition);
186 call(&builder, "PhysicalCoordinates_set_val", outp_coords,
187 tlctx_->get_constant(i), added);
188 }
189 builder.CreateRetVoid();
190}
191
192void StructCompilerLLVM::generate_child_accessors(SNode &snode) {
193 TI_AUTO_PROF;
194 auto type = snode.type;
195 stack.push_back(&snode);
196
197 bool is_leaf = type == SNodeType::place;
198
199 if (!is_leaf) {
200 generate_refine_coordinates(&snode);
201 }
202
203 if (snode.parent != nullptr) {
204 // create the get ch function
205 auto parent = snode.parent;
206
207 auto inp_type =
208 llvm::PointerType::get(get_llvm_element_type(module.get(), parent), 0);
209
210 auto ft =
211 llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx_),
212 {llvm::Type::getInt8PtrTy(*llvm_ctx_)}, false);
213
214 auto func = create_function(ft, snode.get_ch_from_parent_func_name());
215
216 auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry", func);
217
218 llvm::IRBuilder<> builder(bb, bb->begin());
219 std::vector<llvm::Value *> args;
220
221 for (auto &arg : func->args()) {
222 args.push_back(&arg);
223 }
224 llvm::Value *ret;
225 ret = builder.CreateGEP(get_llvm_element_type(module.get(), parent),
226 builder.CreateBitCast(args[0], inp_type),
227 {tlctx_->get_constant(0),
228 tlctx_->get_constant(parent->child_id(&snode))},
229 "getch");
230
231 builder.CreateRet(
232 builder.CreateBitCast(ret, llvm::Type::getInt8PtrTy(*llvm_ctx_)));
233 }
234
235 for (auto &ch : snode.ch) {
236 if (!ch->is_bit_level)
237 generate_child_accessors(*ch);
238 }
239
240 stack.pop_back();
241}
242
243std::string StructCompilerLLVM::type_stub_name(SNode *snode) {
244 return snode->node_type_name + "_type_stubs";
245}
246
247void StructCompilerLLVM::run(SNode &root) {
248 TI_AUTO_PROF;
249 // bottom to top
250 collect_snodes(root);
251
252 auto snodes_rev = snodes;
253 std::reverse(snodes_rev.begin(), snodes_rev.end());
254
255 for (auto &n : snodes_rev)
256 generate_types(*n);
257
258 generate_child_accessors(root);
259
260 if (config_.print_struct_llvm_ir) {
261 static FileSequenceWriter writer("taichi_struct_llvm_ir_{:04d}.ll",
262 "struct LLVM IR");
263 writer.write(module.get());
264 }
265
266 TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes);
267
268 auto node_type = get_llvm_node_type(module.get(), &root);
269 root_size = tlctx_->get_type_size(node_type);
270
271 tlctx_->add_struct_module(std::move(module), root.get_snode_tree_id());
272}
273
274llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module,
275 SNode *snode,
276 uint32 index) {
277 TI_ASSERT(module);
278 TI_ASSERT(snode);
279 auto stub = llvm::StructType::getTypeByName(module->getContext(),
280 type_stub_name(snode));
281 TI_ASSERT(stub);
282 TI_ASSERT(stub->getStructNumElements() == 4);
283 TI_ASSERT(0 <= index && index < 4);
284 auto type = stub->getContainedType(index);
285 TI_ASSERT(type);
286 return type;
287}
288
289llvm::Type *StructCompilerLLVM::get_llvm_node_type(llvm::Module *module,
290 SNode *snode) {
291 return get_stub(module, snode, 0);
292}
293
294llvm::Type *StructCompilerLLVM::get_llvm_body_type(llvm::Module *module,
295 SNode *snode) {
296 return get_stub(module, snode, 1);
297}
298
299llvm::Type *StructCompilerLLVM::get_llvm_aux_type(llvm::Module *module,
300 SNode *snode) {
301 return get_stub(module, snode, 2);
302}
303
304llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module,
305 SNode *snode) {
306 return get_stub(module, snode, 3);
307}
308
309llvm::Function *StructCompilerLLVM::create_function(llvm::FunctionType *ft,
310 std::string func_name) {
311 return llvm::Function::Create(ft, llvm::Function::ExternalLinkage, func_name,
312 *module);
313}
314
315} // namespace taichi::lang
316