1 | #include "taichi/codegen/llvm/codegen_llvm.h" |
2 | |
3 | #include <algorithm> |
4 | |
5 | #ifdef TI_WITH_LLVM |
6 | #include "llvm/Bitcode/BitcodeReader.h" |
7 | #include "llvm/IR/Module.h" |
8 | #include "llvm/Linker/Linker.h" |
9 | #include "taichi/analysis/offline_cache_util.h" |
10 | #include "taichi/ir/statements.h" |
11 | #include "taichi/ir/transforms.h" |
12 | #include "taichi/runtime/llvm/launch_arg_info.h" |
13 | #include "taichi/runtime/llvm/llvm_offline_cache.h" |
14 | #include "taichi/program/extension.h" |
15 | #include "taichi/runtime/program_impls/llvm/llvm_program.h" |
16 | #include "taichi/codegen/llvm/struct_llvm.h" |
17 | #include "taichi/util/file_sequence_writer.h" |
18 | #include "taichi/codegen/codegen_utils.h" |
19 | |
20 | namespace taichi::lang { |
21 | |
22 | // TODO: sort function definitions to match declaration order in header |
23 | |
24 | // TODO(k-ye): Hide FunctionCreationGuard inside cpp file |
// RAII guard that redirects IR emission into a freshly created internal void
// function (e.g. a parallel loop body). The constructor swaps out the
// module builder's current function/blocks and insert point; the destructor
// restores them.
//
// @param mb         The module builder whose state is hijacked.
// @param arguments  Parameter types of the new function.
// @param func_name  Name of the new function in the module.
FunctionCreationGuard::FunctionCreationGuard(
    TaskCodeGenLLVM *mb,
    std::vector<llvm::Type *> arguments,
    const std::string &func_name)
    : mb(mb) {
  // Create the loop body function
  auto body_function_type = llvm::FunctionType::get(
      llvm::Type::getVoidTy(*mb->llvm_context), arguments, false);

  body = llvm::Function::Create(body_function_type,
                                llvm::Function::InternalLinkage, func_name,
                                mb->module.get());
  old_func = mb->func;
  // emit into loop body function
  mb->func = body;

  // Dedicated block for entry-block allocas; linked to `entry` in the dtor.
  allocas = llvm::BasicBlock::Create(*mb->llvm_context, "allocs" , body);
  old_entry = mb->entry_block;
  mb->entry_block = allocas;

  // Common exit block; the dtor emits the `ret void` here.
  final = llvm::BasicBlock::Create(*mb->llvm_context, "final" , body);
  old_final = mb->final_block;
  mb->final_block = final;

  entry = llvm::BasicBlock::Create(*mb->llvm_context, "entry" , mb->func);

  // Remember the caller's insert point so the dtor can resume there.
  ip = mb->builder->saveIP();
  mb->builder->SetInsertPoint(entry);

  auto body_bb =
      llvm::BasicBlock::Create(*mb->llvm_context, "function_body" , mb->func);
  mb->builder->CreateBr(body_bb);
  mb->builder->SetInsertPoint(body_bb);
}
59 | |
// Finalizes the body function and restores the builder state saved by the
// constructor: falls through to `final` (unless a return was already
// emitted), emits `ret void`, and links the alloca block into the CFG.
FunctionCreationGuard::~FunctionCreationGuard() {
  if (!mb->returned) {
    mb->builder->CreateBr(final);
  }
  mb->builder->SetInsertPoint(final);
  mb->builder->CreateRetVoid();
  // Reset so the *enclosing* function's emission is not cut short.
  mb->returned = false;

  // The alloca block must branch into `entry` to keep the CFG well-formed.
  mb->builder->SetInsertPoint(allocas);
  mb->builder->CreateBr(entry);

  mb->entry_block = old_entry;
  mb->final_block = old_final;
  mb->func = old_func;
  mb->builder->restoreIP(ip);

  // llvm::verifyFunction returns true on error, hence the negation.
  TI_ASSERT(!llvm::verifyFunction(*body, &llvm::errs()));
}
78 | |
79 | namespace { |
80 | |
81 | class CodeGenStmtGuard { |
82 | public: |
83 | using Getter = std::function<llvm::BasicBlock *(void)>; |
84 | using Setter = std::function<void(llvm::BasicBlock *)>; |
85 | |
86 | explicit CodeGenStmtGuard(Getter getter, Setter setter) |
87 | : saved_stmt_(getter()), setter_(std::move(setter)) { |
88 | } |
89 | |
90 | ~CodeGenStmtGuard() { |
91 | setter_(saved_stmt_); |
92 | } |
93 | |
94 | CodeGenStmtGuard(CodeGenStmtGuard &&) = default; |
95 | CodeGenStmtGuard &operator=(CodeGenStmtGuard &&) = default; |
96 | |
97 | private: |
98 | llvm::BasicBlock *saved_stmt_; |
99 | Setter setter_; |
100 | }; |
101 | |
102 | CodeGenStmtGuard make_loop_reentry_guard(TaskCodeGenLLVM *cg) { |
103 | return CodeGenStmtGuard([cg]() { return cg->current_loop_reentry; }, |
104 | [cg](llvm::BasicBlock *saved_stmt) { |
105 | cg->current_loop_reentry = saved_stmt; |
106 | }); |
107 | } |
108 | |
109 | CodeGenStmtGuard make_while_after_loop_guard(TaskCodeGenLLVM *cg) { |
110 | return CodeGenStmtGuard([cg]() { return cg->current_while_after_loop; }, |
111 | [cg](llvm::BasicBlock *saved_stmt) { |
112 | cg->current_while_after_loop = saved_stmt; |
113 | }); |
114 | } |
115 | |
116 | } // namespace |
117 | |
118 | // TaskCodeGenLLVM |
119 | void TaskCodeGenLLVM::visit(Block *stmt_list) { |
120 | for (auto &stmt : stmt_list->statements) { |
121 | stmt->accept(this); |
122 | if (returned) { |
123 | break; |
124 | } |
125 | } |
126 | } |
127 | |
128 | void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { |
129 | if (stmt->ret_type->is<TensorType>()) { |
130 | auto tensor_type = stmt->ret_type->cast<TensorType>(); |
131 | auto type = tlctx->get_data_type(tensor_type); |
132 | if (stmt->is_shared) { |
133 | auto base = new llvm::GlobalVariable( |
134 | *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, |
135 | fmt::format("shared_array_{}" , stmt->id), nullptr, |
136 | llvm::GlobalVariable::NotThreadLocal, 3 /*addrspace=shared*/); |
137 | base->setAlignment(llvm::MaybeAlign(8)); |
138 | auto ptr_type = llvm::PointerType::get(type, 0); |
139 | llvm_val[stmt] = builder->CreatePointerCast(base, ptr_type); |
140 | } else { |
141 | llvm_val[stmt] = create_entry_block_alloca(type); |
142 | } |
143 | } else { |
144 | llvm_val[stmt] = |
145 | create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer()); |
146 | // initialize as zero if element is not a pointer |
147 | if (!stmt->ret_type.is_pointer()) |
148 | builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0), |
149 | llvm_val[stmt]); |
150 | } |
151 | } |
152 | |
153 | void TaskCodeGenLLVM::visit(RandStmt *stmt) { |
154 | if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { |
155 | // Promoting to f32 since there's no rand_f16 support in runtime.cpp. |
156 | auto val_f32 = call("rand_f32" , get_context()); |
157 | llvm_val[stmt] = |
158 | builder->CreateFPTrunc(val_f32, llvm::Type::getHalfTy(*llvm_context)); |
159 | } else { |
160 | llvm_val[stmt] = call( |
161 | fmt::format("rand_{}" , data_type_name(stmt->ret_type)), get_context()); |
162 | } |
163 | } |
164 | |
165 | void TaskCodeGenLLVM::(UnaryOpStmt *stmt) { |
166 | auto input = llvm_val[stmt->operand]; |
167 | auto input_taichi_type = stmt->operand->ret_type; |
168 | if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) { |
169 | // Promote to f32 since we don't have f16 support for extra unary ops in in |
170 | // runtime.cpp. |
171 | input = builder->CreateFPExt(input, llvm::Type::getFloatTy(*llvm_context)); |
172 | input_taichi_type = PrimitiveType::f32; |
173 | } |
174 | |
175 | auto op = stmt->op_type; |
176 | auto input_type = input->getType(); |
177 | |
178 | #define UNARY_STD(x) \ |
179 | else if (op == UnaryOpType::x) { \ |
180 | if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \ |
181 | llvm_val[stmt] = call(#x "_f32", input); \ |
182 | } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \ |
183 | llvm_val[stmt] = call(#x "_f64", input); \ |
184 | } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \ |
185 | llvm_val[stmt] = call(#x "_i32", input); \ |
186 | } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) { \ |
187 | llvm_val[stmt] = call(#x "_i64", input); \ |
188 | } else { \ |
189 | TI_NOT_IMPLEMENTED \ |
190 | } \ |
191 | } |
192 | if (false) { |
193 | } |
194 | UNARY_STD(abs) |
195 | UNARY_STD(exp) |
196 | UNARY_STD(log) |
197 | UNARY_STD(tan) |
198 | UNARY_STD(tanh) |
199 | UNARY_STD(sgn) |
200 | UNARY_STD(logic_not) |
201 | UNARY_STD(acos) |
202 | UNARY_STD(asin) |
203 | UNARY_STD(cos) |
204 | UNARY_STD(sin) |
205 | else if (op == UnaryOpType::sqrt) { |
206 | llvm_val[stmt] = |
207 | builder->CreateIntrinsic(llvm::Intrinsic::sqrt, {input_type}, {input}); |
208 | } |
209 | else { |
210 | TI_P(unary_op_type_name(op)); |
211 | TI_NOT_IMPLEMENTED |
212 | } |
213 | #undef UNARY_STD |
214 | if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { |
215 | // Convert back to f16 |
216 | llvm_val[stmt] = builder->CreateFPTrunc( |
217 | llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); |
218 | } |
219 | } |
220 | |
221 | std::unique_ptr<RuntimeObject> TaskCodeGenLLVM::emit_struct_meta_object( |
222 | SNode *snode) { |
223 | std::unique_ptr<RuntimeObject> meta; |
224 | if (snode->type == SNodeType::dense) { |
225 | meta = std::make_unique<RuntimeObject>("DenseMeta" , this, builder.get()); |
226 | emit_struct_meta_base("Dense" , meta->ptr, snode); |
227 | meta->call("set_morton_dim" , tlctx->get_constant((int)snode->_morton)); |
228 | } else if (snode->type == SNodeType::pointer) { |
229 | meta = std::make_unique<RuntimeObject>("PointerMeta" , this, builder.get()); |
230 | emit_struct_meta_base("Pointer" , meta->ptr, snode); |
231 | } else if (snode->type == SNodeType::root) { |
232 | meta = std::make_unique<RuntimeObject>("RootMeta" , this, builder.get()); |
233 | emit_struct_meta_base("Root" , meta->ptr, snode); |
234 | } else if (snode->type == SNodeType::dynamic) { |
235 | meta = std::make_unique<RuntimeObject>("DynamicMeta" , this, builder.get()); |
236 | emit_struct_meta_base("Dynamic" , meta->ptr, snode); |
237 | meta->call("set_chunk_size" , tlctx->get_constant(snode->chunk_size)); |
238 | } else if (snode->type == SNodeType::bitmasked) { |
239 | meta = |
240 | std::make_unique<RuntimeObject>("BitmaskedMeta" , this, builder.get()); |
241 | emit_struct_meta_base("Bitmasked" , meta->ptr, snode); |
242 | } else if (snode->type == SNodeType::quant_array) { |
243 | meta = std::make_unique<RuntimeObject>("DenseMeta" , this, builder.get()); |
244 | emit_struct_meta_base("Dense" , meta->ptr, snode); |
245 | } else { |
246 | TI_P(snode_type_name(snode->type)); |
247 | TI_NOT_IMPLEMENTED; |
248 | } |
249 | return meta; |
250 | } |
251 | |
252 | void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name, |
253 | llvm::Value *node_meta, |
254 | SNode *snode) { |
255 | RuntimeObject common("StructMeta" , this, builder.get(), node_meta); |
256 | std::size_t element_size; |
257 | if (snode->type == SNodeType::dense) { |
258 | auto body_type = |
259 | StructCompilerLLVM::get_llvm_body_type(module.get(), snode); |
260 | auto element_ty = body_type->getArrayElementType(); |
261 | element_size = tlctx->get_type_size(element_ty); |
262 | } else if (snode->type == SNodeType::pointer) { |
263 | auto element_ty = StructCompilerLLVM::get_llvm_node_type( |
264 | module.get(), snode->ch[0].get()); |
265 | element_size = tlctx->get_type_size(element_ty); |
266 | } else { |
267 | auto element_ty = |
268 | StructCompilerLLVM::get_llvm_element_type(module.get(), snode); |
269 | element_size = tlctx->get_type_size(element_ty); |
270 | } |
271 | common.set("snode_id" , tlctx->get_constant(snode->id)); |
272 | common.set("element_size" , tlctx->get_constant((uint64)element_size)); |
273 | common.set("max_num_elements" , |
274 | tlctx->get_constant(snode->max_num_elements())); |
275 | common.set("context" , get_context()); |
276 | |
277 | /* |
278 | uint8 *(*lookup_element)(uint8 *, int i); |
279 | uint8 *(*from_parent_element)(uint8 *); |
280 | bool (*is_active)(uint8 *, int i); |
281 | int (*get_num_elements)(uint8 *); |
282 | void (*refine_coordinates)(PhysicalCoordinates *inp_coord, |
283 | PhysicalCoordinates *refined_coord, |
284 | int index); |
285 | */ |
286 | |
287 | std::vector<std::string> functions = {"lookup_element" , "is_active" , |
288 | "get_num_elements" }; |
289 | |
290 | for (auto const &f : functions) |
291 | common.set(f, get_runtime_function(fmt::format("{}_{}" , name, f))); |
292 | |
293 | // "from_parent_element", "refine_coordinates" are different for different |
294 | // snodes, even if they have the same type. |
295 | if (snode->parent) |
296 | common.set("from_parent_element" , |
297 | get_struct_function(snode->get_ch_from_parent_func_name(), |
298 | snode->get_snode_tree_id())); |
299 | |
300 | if (snode->type != SNodeType::place) |
301 | common.set("refine_coordinates" , |
302 | get_struct_function(snode->refine_coordinates_func_name(), |
303 | snode->get_snode_tree_id())); |
304 | } |
305 | |
// Constructs the per-task codegen state. A null |module| requests a fresh
// module named "kernel"; a null |ir| falls back to the kernel's own IR
// root. Frequently used runtime struct types are resolved eagerly here.
TaskCodeGenLLVM::TaskCodeGenLLVM(const CompileConfig &compile_config,
                                 TaichiLLVMContext &tlctx,
                                 Kernel *kernel,
                                 IRNode *ir,
                                 std::unique_ptr<llvm::Module> &&module)
    // TODO: simplify LLVMModuleBuilder ctor input
    : LLVMModuleBuilder(
          module == nullptr ? tlctx.new_module("kernel" ) : std::move(module),
          &tlctx),
      compile_config(compile_config),
      kernel(kernel),
      ir(ir),
      prog(kernel->program) {
  if (ir == nullptr)
    this->ir = kernel->ir.get();
  initialize_context();

  // Cache the LLVM types for runtime structs used throughout codegen.
  context_ty = get_runtime_type("RuntimeContext" );
  physical_coordinate_ty = get_runtime_type(kLLVMPhysicalCoordinatesName);

  kernel_name = kernel->name + "_kernel" ;
  current_callable = kernel;
}
329 | |
// Decorations carry compile-time metadata only; no IR is emitted for them.
void TaskCodeGenLLVM::visit(DecorationStmt *stmt) {
}
332 | |
333 | void TaskCodeGenLLVM::create_elementwise_cast( |
334 | UnaryOpStmt *stmt, |
335 | llvm::Type *to_ty, |
336 | std::function<llvm::Value *(llvm::Value *, llvm::Type *)> func, |
337 | bool on_self) { |
338 | auto from_ty = stmt->operand->ret_type->cast<TensorType>(); |
339 | TI_ASSERT_INFO(from_ty, |
340 | "Cannot perform elementwise ops on non-tensor type {}" , |
341 | from_ty->to_string()); |
342 | llvm::Value *vec = llvm::UndefValue::get(llvm::VectorType::get( |
343 | to_ty, from_ty->get_num_elements(), /*scalable=*/false)); |
344 | for (int i = 0; i < from_ty->get_num_elements(); ++i) { |
345 | auto elem = builder->CreateExtractElement( |
346 | on_self ? llvm_val[stmt] : llvm_val[stmt->operand], i); |
347 | auto cast_value = func(elem, to_ty); |
348 | vec = builder->CreateInsertElement(vec, cast_value, i); |
349 | } |
350 | llvm_val[stmt] = vec; |
351 | } |
352 | |
// Lowers a unary op. cast_value handles numeric conversions (promoting via
// f32 when the target is f16); cast_bits reinterprets bits; rsqrt/bit_not/
// neg and round/floor/ceil map to LLVM instructions or intrinsics; all
// remaining ops fall through to emit_extra_unary().
void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
  auto input = llvm_val[stmt->operand];
  auto input_type = input->getType();
  auto op = stmt->op_type;

#define UNARY_INTRINSIC(x) \
  else if (op == UnaryOpType::x) { \
    llvm_val[stmt] = \
        builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input}); \
  }
  if (stmt->op_type == UnaryOpType::cast_value) {
    llvm::CastInst::CastOps cast_op;
    auto from = stmt->operand->ret_type;
    auto to = stmt->ret_type;
    TI_ASSERT_INFO(
        from->is<TensorType>() == to->is<TensorType>(),
        "Cannot cast between tensor type and non-tensor type: {} v.s. {}" ,
        from->to_string(), to->to_string());

    if (from == to) {
      // Identity cast: reuse the operand's value directly.
      llvm_val[stmt] = llvm_val[stmt->operand];
    } else if (is_real(from.get_element_type()) !=
               is_real(to.get_element_type())) {
      // int <-> float conversion: pick the LLVM cast op by signedness.
      if (is_real(from.get_element_type()) &&
          (is_integral(to.get_element_type()))) {
        cast_op = (is_signed(to.get_element_type()))
                      ? llvm::Instruction::CastOps::FPToSI
                      : llvm::Instruction::CastOps::FPToUI;
      } else if (is_integral(from.get_element_type()) &&
                 is_real(to.get_element_type())) {
        cast_op = (is_signed(from.get_element_type()))
                      ? llvm::Instruction::CastOps::SIToFP
                      : llvm::Instruction::CastOps::UIToFP;
      } else {
        TI_P(data_type_name(from));
        TI_P(data_type_name(to));
        TI_NOT_IMPLEMENTED;
      }
      // f16 targets are first produced in f32, then truncated below.
      bool use_f16 = to->is_primitive(PrimitiveTypeID::f16) ||
                     (to->is<TensorType>() &&
                      to->cast<TensorType>()->get_element_type()->is_primitive(
                          PrimitiveTypeID::f16));
      auto cast_type = use_f16 ? (to->is<TensorType>()
                                      ? TypeFactory::create_tensor_type(
                                            to->cast<TensorType>()->get_shape(),
                                            PrimitiveType::f32)
                                      : PrimitiveType::f32)
                               : stmt->cast_type;

      auto cast_func = [this, cast_op](llvm::Value *value, llvm::Type *type) {
        return this->builder->CreateCast(cast_op, value, type);
      };
      if (!cast_type->is<TensorType>()) {
        llvm_val[stmt] = cast_func(input, tlctx->get_data_type(cast_type));
      } else {
        create_elementwise_cast(
            stmt,
            tlctx->get_data_type(
                cast_type->cast<TensorType>()->get_element_type()),
            cast_func);
      }

      if (use_f16) {
        // Second pass: truncate the f32 intermediate down to f16.
        auto trunc_func = [this](llvm::Value *value, llvm::Type *type) {
          return this->builder->CreateFPTrunc(value, type);
        };
        auto to_ty = llvm::Type::getHalfTy(*llvm_context);
        if (!cast_type->is<TensorType>()) {
          llvm_val[stmt] = trunc_func(llvm_val[stmt], to_ty);
        } else {
          create_elementwise_cast(stmt, to_ty, trunc_func, /*on_self=*/true);
        }
      }
    } else if (is_real(from.get_element_type()) &&
               is_real(to.get_element_type())) {
      // float -> float: FPExt when widening, FPTrunc when narrowing.
      auto t1 = from->is<TensorType>()
                    ? from->cast<TensorType>()->get_element_type()
                    : from.operator->();
      auto t2 = to->is<TensorType>()
                    ? to->cast<TensorType>()->get_element_type()
                    : to.operator->();
      if (data_type_size(t1) < data_type_size(t2)) {
        auto cast_func = [this](llvm::Value *value, llvm::Type *type) {
          return this->builder->CreateFPExt(value, type);
        };
        if (!stmt->cast_type->is<TensorType>()) {
          llvm_val[stmt] =
              cast_func(input, tlctx->get_data_type(stmt->cast_type));
        } else {
          create_elementwise_cast(
              stmt,
              tlctx->get_data_type(
                  stmt->cast_type->cast<TensorType>()->get_element_type()),
              cast_func);
        }
      } else {
        if (to->is_primitive(PrimitiveTypeID::f16) ||
            (to->is<TensorType>() &&
             to->cast<TensorType>()->get_element_type()->is_primitive(
                 PrimitiveTypeID::f16))) {
          // Narrowing to f16 goes through f32 (double truncation).
          if (!to->is<TensorType>()) {
            llvm_val[stmt] = builder->CreateFPTrunc(
                builder->CreateFPTrunc(llvm_val[stmt->operand],
                                       llvm::Type::getFloatTy(*llvm_context)),
                llvm::Type::getHalfTy(*llvm_context));
          } else {
            // NOTE(review): lanes below are extracted from `vec`, which is
            // undef at this point, rather than from the operand's value —
            // this looks suspicious; confirm against the tensor f16 path.
            auto tensor_type = to->cast<TensorType>();
            llvm::Value *vec = llvm::UndefValue::get(tlctx->get_data_type(to));
            for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
              auto elem = builder->CreateExtractElement(vec, i);
              auto double_trunced = builder->CreateFPTrunc(
                  builder->CreateFPTrunc(elem,
                                         llvm::Type::getFloatTy(*llvm_context)),
                  llvm::Type::getHalfTy(*llvm_context));
              vec = builder->CreateInsertElement(vec, double_trunced, i);
            }
            llvm_val[stmt] = vec;
          }
        } else {
          auto trunc_fn = [this](llvm::Value *value, llvm::Type *type) {
            return this->builder->CreateFPTrunc(value, type);
          };
          auto cast_type =
              stmt->cast_type->is<TensorType>()
                  ? stmt->cast_type->cast<TensorType>()->get_element_type()
                  : stmt->cast_type.operator->();
          if (!stmt->cast_type->is<TensorType>()) {
            llvm_val[stmt] = trunc_fn(input, tlctx->get_data_type(cast_type));
          } else {
            create_elementwise_cast(
                stmt,
                tlctx->get_data_type(
                    cast_type->cast<TensorType>()->get_element_type()),
                trunc_fn);
          }
        }
      }
    } else if (!is_real(from.get_element_type()) &&
               !is_real(to.get_element_type())) {
      // int -> int: sign- or zero-extension chosen by source signedness.
      llvm_val[stmt] = builder->CreateIntCast(
          llvm_val[stmt->operand], tlctx->get_data_type(to),
          is_signed(from.get_element_type()));
    }
  } else if (stmt->op_type == UnaryOpType::cast_bits) {
    // Bit-level reinterpretation; operand and target sizes must match.
    TI_ASSERT(data_type_size(stmt->ret_type) ==
              data_type_size(stmt->cast_type));
    if (stmt->operand->ret_type.is_pointer()) {
      TI_ASSERT(is_integral(stmt->cast_type));
      llvm_val[stmt] = builder->CreatePtrToInt(
          llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
    } else {
      llvm_val[stmt] = builder->CreateBitCast(
          llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
    }
  } else if (op == UnaryOpType::rsqrt) {
    // rsqrt = 1 / sqrt(x) via the sqrt intrinsic.
    llvm::Function *sqrt_fn = llvm::Intrinsic::getDeclaration(
        module.get(), llvm::Intrinsic::sqrt, input->getType());
    auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt" );
    llvm_val[stmt] = builder->CreateFDiv(
        tlctx->get_constant(stmt->ret_type, 1.0), intermediate);
  } else if (op == UnaryOpType::bit_not) {
    llvm_val[stmt] = builder->CreateNot(input);
  } else if (op == UnaryOpType::neg) {
    if (is_real(stmt->operand->ret_type)) {
      llvm_val[stmt] = builder->CreateFNeg(input, "neg" );
    } else {
      llvm_val[stmt] = builder->CreateNeg(input, "neg" );
    }
  }
  UNARY_INTRINSIC(round)
  UNARY_INTRINSIC(floor)
  UNARY_INTRINSIC(ceil)
  else {
    // Everything else is handled by runtime-library calls.
    emit_extra_unary(stmt);
  }
#undef UNARY_INTRINSIC
}
530 | |
531 | void TaskCodeGenLLVM::create_elementwise_binary( |
532 | BinaryOpStmt *stmt, |
533 | std::function<llvm::Value *(llvm::Value *lhs, llvm::Value *rhs)> f) { |
534 | TI_ASSERT(stmt->lhs->ret_type->is<TensorType>()); |
535 | TI_ASSERT(stmt->rhs->ret_type->is<TensorType>()); |
536 | auto lhs_ty = stmt->lhs->ret_type->cast<TensorType>(); |
537 | auto rhs_ty = stmt->rhs->ret_type->cast<TensorType>(); |
538 | TI_ASSERT(lhs_ty->get_num_elements() == rhs_ty->get_num_elements()); |
539 | auto lhs_vec = llvm_val[stmt->lhs]; |
540 | auto rhs_vec = llvm_val[stmt->rhs]; |
541 | auto elt_type_name = data_type_name(lhs_ty->get_element_type()); |
542 | llvm::Value *result = |
543 | llvm::UndefValue::get(tlctx->get_data_type(stmt->ret_type)); |
544 | for (int i = 0; i < lhs_ty->get_num_elements(); ++i) { |
545 | auto lhs = builder->CreateExtractElement(lhs_vec, i); |
546 | auto rhs = builder->CreateExtractElement(rhs_vec, i); |
547 | result = builder->CreateInsertElement(result, f(lhs, rhs), i); |
548 | } |
549 | llvm_val[stmt] = result; |
550 | } |
551 | |
// Lowers a binary op. Arithmetic ops choose float/int LLVM instructions by
// element type (with overflow-checked "debug_*" runtime calls for integral
// add/sub/mul/shl in debug builds on clang/gcc); max/min dispatch to
// runtime helpers or elementwise lowering for tensors; comparisons produce
// an i32 sign-extended from the i1 predicate; atan2/pow call runtime
// functions (with f16 promoted to f32 and truncated back).
void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
  auto op = stmt->op_type;
  auto ret_type = stmt->ret_type;

  if (op == BinaryOpType::add) {
    if (is_real(stmt->ret_type.get_element_type())) {
      llvm_val[stmt] =
          builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#if defined(__clang__) || defined(__GNUC__)
    } else if (compile_config.debug && is_integral(stmt->ret_type)) {
      // Debug build: runtime helper reports integer overflow with traceback.
      llvm_val[stmt] =
          call("debug_add_" + stmt->ret_type->to_string(), get_arg(0),
               llvm_val[stmt->lhs], llvm_val[stmt->rhs],
               builder->CreateGlobalStringPtr(stmt->tb));
#endif
    } else {
      llvm_val[stmt] =
          builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::sub) {
    if (is_real(stmt->ret_type.get_element_type())) {
      llvm_val[stmt] =
          builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#if defined(__clang__) || defined(__GNUC__)
    } else if (compile_config.debug && is_integral(stmt->ret_type)) {
      llvm_val[stmt] =
          call("debug_sub_" + stmt->ret_type->to_string(), get_arg(0),
               llvm_val[stmt->lhs], llvm_val[stmt->rhs],
               builder->CreateGlobalStringPtr(stmt->tb));
#endif
    } else {
      llvm_val[stmt] =
          builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::mul) {
    if (is_real(stmt->ret_type.get_element_type())) {
      llvm_val[stmt] =
          builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#if defined(__clang__) || defined(__GNUC__)
    } else if (compile_config.debug && is_integral(stmt->ret_type)) {
      llvm_val[stmt] =
          call("debug_mul_" + stmt->ret_type->to_string(), get_arg(0),
               llvm_val[stmt->lhs], llvm_val[stmt->rhs],
               builder->CreateGlobalStringPtr(stmt->tb));
#endif
    } else {
      llvm_val[stmt] =
          builder->CreateMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::div) {
    if (is_real(stmt->ret_type.get_element_type())) {
      llvm_val[stmt] =
          builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else if (is_signed(stmt->ret_type)) {
      llvm_val[stmt] =
          builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      llvm_val[stmt] =
          builder->CreateUDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::mod) {
    llvm_val[stmt] =
        builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
  } else if (op == BinaryOpType::bit_and) {
    llvm_val[stmt] =
        builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
  } else if (op == BinaryOpType::bit_or) {
    llvm_val[stmt] =
        builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
  } else if (op == BinaryOpType::bit_xor) {
    llvm_val[stmt] =
        builder->CreateXor(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
  } else if (op == BinaryOpType::bit_shl) {
#if defined(__clang__) || defined(__GNUC__)
    if (compile_config.debug && is_integral(stmt->ret_type)) {
      // Debug build: runtime helper checks the shift for overflow.
      llvm_val[stmt] =
          call("debug_shl_" + stmt->ret_type->to_string(), get_arg(0),
               llvm_val[stmt->lhs], llvm_val[stmt->rhs],
               builder->CreateGlobalStringPtr(stmt->tb));
    } else {
      llvm_val[stmt] =
          builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
#else
    llvm_val[stmt] =
        builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#endif
  } else if (op == BinaryOpType::bit_sar) {
    // Arithmetic shift for signed lhs, logical shift for unsigned.
    if (is_signed(stmt->lhs->element_type())) {
      llvm_val[stmt] =
          builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      llvm_val[stmt] =
          builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::max) {
#define BINARYOP_MAX(x) \
  else if (ret_type->is_primitive(PrimitiveTypeID::x)) { \
    llvm_val[stmt] = \
        call("max_" #x, llvm_val[stmt->lhs], llvm_val[stmt->rhs]); \
  }

    if (is_real(ret_type.get_element_type())) {
      llvm_val[stmt] =
          builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
    BINARYOP_MAX(u16)
    BINARYOP_MAX(i16)
    BINARYOP_MAX(u32)
    BINARYOP_MAX(i32)
    BINARYOP_MAX(u64)
    BINARYOP_MAX(i64)
    else {
      // Integral tensors: lower max elementwise via the runtime helper.
      if (auto tensor_ty = ret_type->cast<TensorType>()) {
        auto elt_ty = tensor_ty->get_element_type();
        TI_ASSERT(elt_ty->is_primitive(PrimitiveTypeID::u16) ||
                  elt_ty->is_primitive(PrimitiveTypeID::i16) ||
                  elt_ty->is_primitive(PrimitiveTypeID::u32) ||
                  elt_ty->is_primitive(PrimitiveTypeID::i32) ||
                  elt_ty->is_primitive(PrimitiveTypeID::u64) ||
                  elt_ty->is_primitive(PrimitiveTypeID::i64));
        auto dtype_name = data_type_name(elt_ty);
        auto binary_max = [this, &dtype_name](llvm::Value *lhs,
                                              llvm::Value *rhs) {
          return call("max_" + dtype_name, lhs, rhs);
        };
        create_elementwise_binary(stmt, binary_max);
      } else {
        TI_P(data_type_name(ret_type));
        TI_NOT_IMPLEMENTED
      }
    }
  } else if (op == BinaryOpType::min) {
#define BINARYOP_MIN(x) \
  else if (ret_type->is_primitive(PrimitiveTypeID::x)) { \
    llvm_val[stmt] = \
        call("min_" #x, llvm_val[stmt->lhs], llvm_val[stmt->rhs]); \
  }

    if (is_real(ret_type.get_element_type())) {
      llvm_val[stmt] =
          builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
    BINARYOP_MIN(u16)
    BINARYOP_MIN(i16)
    BINARYOP_MIN(u32)
    BINARYOP_MIN(i32)
    BINARYOP_MIN(u64)
    BINARYOP_MIN(i64)
    else {
      // Integral tensors: lower min elementwise via the runtime helper.
      if (auto tensor_ty = ret_type->cast<TensorType>()) {
        auto elt_ty = tensor_ty->get_element_type();
        TI_ASSERT(elt_ty->is_primitive(PrimitiveTypeID::u16) ||
                  elt_ty->is_primitive(PrimitiveTypeID::i16) ||
                  elt_ty->is_primitive(PrimitiveTypeID::u32) ||
                  elt_ty->is_primitive(PrimitiveTypeID::i32) ||
                  elt_ty->is_primitive(PrimitiveTypeID::u64) ||
                  elt_ty->is_primitive(PrimitiveTypeID::i64));
        auto dtype_name = data_type_name(elt_ty);
        auto binary_min = [this, &dtype_name](llvm::Value *lhs,
                                              llvm::Value *rhs) {
          return call("min_" + dtype_name, lhs, rhs);
        };
        create_elementwise_binary(stmt, binary_min);
      } else {
        TI_P(data_type_name(ret_type));
        TI_NOT_IMPLEMENTED
      }
    }
  } else if (is_comparison(op)) {
    // Comparisons emit an i1 predicate, sign-extended to i32 at the end.
    llvm::Value *cmp = nullptr;
    auto input_type = stmt->lhs->ret_type;
    if (op == BinaryOpType::cmp_eq) {
      if (is_real(input_type.get_element_type())) {
        cmp = builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        cmp = builder->CreateICmpEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      }
    } else if (op == BinaryOpType::cmp_le) {
      if (is_real(input_type.get_element_type())) {
        cmp = builder->CreateFCmpOLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        if (is_signed(input_type.get_element_type())) {
          cmp =
              builder->CreateICmpSLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
          cmp =
              builder->CreateICmpULE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        }
      }
    } else if (op == BinaryOpType::cmp_ge) {
      if (is_real(input_type.get_element_type())) {
        cmp = builder->CreateFCmpOGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        if (is_signed(input_type.get_element_type())) {
          cmp =
              builder->CreateICmpSGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
          cmp =
              builder->CreateICmpUGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        }
      }
    } else if (op == BinaryOpType::cmp_lt) {
      if (is_real(input_type.get_element_type())) {
        cmp = builder->CreateFCmpOLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        if (is_signed(input_type.get_element_type())) {
          cmp =
              builder->CreateICmpSLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
          cmp =
              builder->CreateICmpULT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        }
      }
    } else if (op == BinaryOpType::cmp_gt) {
      if (is_real(input_type.get_element_type())) {
        cmp = builder->CreateFCmpOGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        if (is_signed(input_type.get_element_type())) {
          cmp =
              builder->CreateICmpSGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
          cmp =
              builder->CreateICmpUGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        }
      }
    } else if (op == BinaryOpType::cmp_ne) {
      if (is_real(input_type.get_element_type())) {
        cmp = builder->CreateFCmpONE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        cmp = builder->CreateICmpNE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      }
    } else {
      TI_NOT_IMPLEMENTED
    }
    llvm_val[stmt] =
        builder->CreateSExt(cmp, tlctx->get_data_type(PrimitiveType::i32));
  } else {
    // This branch contains atan2 and pow which use runtime.cpp function for
    // **real** type. We don't have f16 support there so promoting to f32 is
    // necessary.
    llvm::Value *lhs = llvm_val[stmt->lhs];
    llvm::Value *rhs = llvm_val[stmt->rhs];
    if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      lhs = builder->CreateFPExt(lhs, llvm::Type::getFloatTy(*llvm_context));
    }
    if (stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      rhs = builder->CreateFPExt(rhs, llvm::Type::getFloatTy(*llvm_context));
    }
    if (ret_type->is_primitive(PrimitiveTypeID::f16)) {
      ret_type = PrimitiveType::f32;
    }

    if (op == BinaryOpType::atan2) {
      if (arch_is_cpu(current_arch())) {
        if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
          llvm_val[stmt] = call("atan2_f32" , lhs, rhs);
        } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
          llvm_val[stmt] = call("atan2_f64" , lhs, rhs);
        } else {
          TI_P(data_type_name(ret_type));
          TI_NOT_IMPLEMENTED
        }
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (op == BinaryOpType::pow) {
      if (arch_is_cpu(current_arch())) {
        // Note that ret_type here cannot be integral because pow with an
        // integral exponent has been demoted in the demote_operations pass
        if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
          llvm_val[stmt] = call("pow_f32" , lhs, rhs);
        } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
          llvm_val[stmt] = call("pow_f64" , lhs, rhs);
        } else {
          TI_P(data_type_name(ret_type));
          TI_NOT_IMPLEMENTED
        }
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else {
      TI_P(binary_op_type_name(op));
      TI_NOT_IMPLEMENTED
    }

    // Convert back to f16 if applicable.
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      llvm_val[stmt] = builder->CreateFPTrunc(
          llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
    }
  }
}
845 | |
846 | void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) { |
847 | TI_ASSERT(stmt->op_type == TernaryOpType::select); |
848 | llvm_val[stmt] = builder->CreateSelect( |
849 | builder->CreateTrunc(llvm_val[stmt->op1], |
850 | tlctx->get_data_type(PrimitiveType::u1)), |
851 | llvm_val[stmt->op2], llvm_val[stmt->op3]); |
852 | } |
853 | |
// Lowers an `if` statement into the classic three-block CFG:
// cond ? true_block : false_block, both falling through to after_if.
void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
  // TODO: take care of vectorized cases
  llvm::BasicBlock *true_block =
      llvm::BasicBlock::Create(*llvm_context, "true_block" , func);
  llvm::BasicBlock *false_block =
      llvm::BasicBlock::Create(*llvm_context, "false_block" , func);
  llvm::BasicBlock *after_if =
      llvm::BasicBlock::Create(*llvm_context, "after_if" , func);
  // Branch on (cond != 0) so any integer-typed condition works.
  builder->CreateCondBr(
      builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)),
      true_block, false_block);
  builder->SetInsertPoint(true_block);
  if (if_stmt->true_statements) {
    if_stmt->true_statements->accept(this);
  }
  // If the branch body ended with a `return` (see visit(ReturnStmt)), it
  // already branched to the final block; emitting another terminator here
  // would produce invalid IR, so just reset the flag instead.
  if (!returned) {
    builder->CreateBr(after_if);
  } else {
    returned = false;
  }
  builder->SetInsertPoint(false_block);
  if (if_stmt->false_statements) {
    if_stmt->false_statements->accept(this);
  }
  // Same `return` handling for the false branch.
  if (!returned) {
    builder->CreateBr(after_if);
  } else {
    returned = false;
  }
  builder->SetInsertPoint(after_if);
}
885 | |
// Emits a host-side printf of "[llvm codegen debug] <tag> = <value>" for
// debugging generated code. Returns the emitted call, or nullptr on
// non-CPU archs where this debug print is unsupported.
llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
                                           DataType dt,
                                           llvm::Value *value) {
  if (!arch_is_cpu(compile_config.arch)) {
    TI_WARN("print not supported on arch {}" , arch_name(compile_config.arch));
    return nullptr;
  }
  std::vector<llvm::Value *> args;
  std::string format = data_type_format(dt);
  auto runtime_printf = call("LLVMRuntime_get_host_printf" , get_runtime());
  args.push_back(builder->CreateGlobalStringPtr(
      ("[llvm codegen debug] " + tag + " = " + format + "\n" ).c_str(),
      "format_string" ));
  // printf is variadic: float arguments undergo default promotion to
  // double, so widen f32 explicitly before passing it.
  if (dt->is_primitive(PrimitiveTypeID::f32))
    value =
        builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64));
  args.push_back(value);

  // The runtime-declared function type is reused so the indirect call to
  // the host printf pointer is well-typed.
  auto func_type_func = get_runtime_function("get_func_type_host_printf" );
  return call(runtime_printf, func_type_func->getFunctionType(),
              std::move(args));
}
908 | |
909 | llvm::Value *TaskCodeGenLLVM::create_print(std::string tag, |
910 | llvm::Value *value) { |
911 | if (value->getType() == llvm::Type::getFloatTy(*llvm_context)) |
912 | return create_print( |
913 | tag, |
914 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32), |
915 | value); |
916 | else if (value->getType() == llvm::Type::getInt32Ty(*llvm_context)) |
917 | return create_print( |
918 | tag, |
919 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32), |
920 | value); |
921 | else if (value->getType() == llvm::Type::getHalfTy(*llvm_context)) { |
922 | auto extended = |
923 | builder->CreateFPExt(value, llvm::Type::getFloatTy(*llvm_context)); |
924 | return create_print( |
925 | tag, |
926 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32), |
927 | extended); |
928 | } else if (value->getType() == llvm::Type::getInt64Ty(*llvm_context)) |
929 | return create_print( |
930 | tag, |
931 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i64), |
932 | value); |
933 | else if (value->getType() == llvm::Type::getInt16Ty(*llvm_context)) |
934 | return create_print( |
935 | tag, |
936 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i16), |
937 | value); |
938 | else |
939 | TI_NOT_IMPLEMENTED |
940 | } |
941 | |
// Lowers a print statement into one host printf call: literal segments
// become "%s" arguments and values are appended with their type-specific
// format specifiers.
void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
  std::vector<llvm::Value *> args;
  std::string formats;
  // Promote values printf cannot take as-is: f32/f16 widen to f64
  // (vararg promotion), i8/u8 widen to 16 bits.
  auto value_for_printf = [this](llvm::Value *to_print, DataType dtype) {
    if (dtype->is_primitive(PrimitiveTypeID::f32) ||
        dtype->is_primitive(PrimitiveTypeID::f16))
      return this->builder->CreateFPExt(
          to_print, this->tlctx->get_data_type(PrimitiveType::f64));
    if (dtype->is_primitive(PrimitiveTypeID::i8))
      return builder->CreateSExt(to_print,
                                 tlctx->get_data_type(PrimitiveType::i16));
    if (dtype->is_primitive(PrimitiveTypeID::u8))
      return builder->CreateZExt(to_print,
                                 tlctx->get_data_type(PrimitiveType::u16));
    return to_print;
  };
  // Each content entry is either a value statement or a literal string.
  for (auto const &content : stmt->contents) {
    if (std::holds_alternative<Stmt *>(content)) {
      auto arg_stmt = std::get<Stmt *>(content);
      auto value = llvm_val[arg_stmt];
      auto value_type = value->getType();
      if (arg_stmt->ret_type->is<TensorType>()) {
        // Tensors are flattened into one printf argument per element.
        // The extraction differs by representation: LLVM vector when
        // vectorized codegen is on, LLVM array otherwise.
        auto dtype = arg_stmt->ret_type->cast<TensorType>();
        auto elem_type = dtype->get_element_type();
        for (int i = 0; i < dtype->get_num_elements(); ++i) {
          if (codegen_vector_type(compile_config)) {
            TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type));
            auto elem = builder->CreateExtractElement(value, i);
            args.push_back(value_for_printf(elem, elem_type));
          } else {
            TI_ASSERT(llvm::dyn_cast<llvm::ArrayType>(value_type));
            auto elem = builder->CreateExtractValue(value, i);
            args.push_back(value_for_printf(elem, elem_type));
          }
        }
        formats += data_type_format(arg_stmt->ret_type);
      } else {
        args.push_back(value_for_printf(value, arg_stmt->ret_type));
        formats += data_type_format(arg_stmt->ret_type);
      }
    } else {
      auto arg_str = std::get<std::string>(content);
      auto value = builder->CreateGlobalStringPtr(arg_str, "content_string" );
      args.push_back(value);
      formats += "%s" ;
    }
  }
  auto runtime_printf = call("LLVMRuntime_get_host_printf" , get_runtime());
  // The accumulated format string goes first, as printf expects.
  args.insert(args.begin(),
              builder->CreateGlobalStringPtr(formats.c_str(), "format_string" ));
  auto func_type_func = get_runtime_function("get_func_type_host_printf" );
  llvm_val[stmt] =
      call(runtime_printf, func_type_func->getFunctionType(), std::move(args));
}
996 | |
997 | void TaskCodeGenLLVM::visit(ConstStmt *stmt) { |
998 | auto val = stmt->val; |
999 | if (val.dt->is_primitive(PrimitiveTypeID::f32)) { |
1000 | llvm_val[stmt] = |
1001 | llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float32())); |
1002 | } else if (val.dt->is_primitive(PrimitiveTypeID::f16)) { |
1003 | llvm_val[stmt] = llvm::ConstantFP::get(llvm::Type::getHalfTy(*llvm_context), |
1004 | val.val_float32()); |
1005 | } else if (val.dt->is_primitive(PrimitiveTypeID::f64)) { |
1006 | llvm_val[stmt] = |
1007 | llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float64())); |
1008 | } else if (val.dt->is_primitive(PrimitiveTypeID::i8)) { |
1009 | llvm_val[stmt] = llvm::ConstantInt::get( |
1010 | *llvm_context, llvm::APInt(8, (uint64)val.val_int8(), true)); |
1011 | } else if (val.dt->is_primitive(PrimitiveTypeID::u8)) { |
1012 | llvm_val[stmt] = llvm::ConstantInt::get( |
1013 | *llvm_context, llvm::APInt(8, (uint64)val.val_uint8(), false)); |
1014 | } else if (val.dt->is_primitive(PrimitiveTypeID::i16)) { |
1015 | llvm_val[stmt] = llvm::ConstantInt::get( |
1016 | *llvm_context, llvm::APInt(16, (uint64)val.val_int16(), true)); |
1017 | } else if (val.dt->is_primitive(PrimitiveTypeID::u16)) { |
1018 | llvm_val[stmt] = llvm::ConstantInt::get( |
1019 | *llvm_context, llvm::APInt(16, (uint64)val.val_uint16(), false)); |
1020 | } else if (val.dt->is_primitive(PrimitiveTypeID::i32)) { |
1021 | llvm_val[stmt] = llvm::ConstantInt::get( |
1022 | *llvm_context, llvm::APInt(32, (uint64)val.val_int32(), true)); |
1023 | } else if (val.dt->is_primitive(PrimitiveTypeID::u32)) { |
1024 | llvm_val[stmt] = llvm::ConstantInt::get( |
1025 | *llvm_context, llvm::APInt(32, (uint64)val.val_uint32(), false)); |
1026 | } else if (val.dt->is_primitive(PrimitiveTypeID::i64)) { |
1027 | llvm_val[stmt] = llvm::ConstantInt::get( |
1028 | *llvm_context, llvm::APInt(64, (uint64)val.val_int64(), true)); |
1029 | } else if (val.dt->is_primitive(PrimitiveTypeID::u64)) { |
1030 | llvm_val[stmt] = llvm::ConstantInt::get( |
1031 | *llvm_context, llvm::APInt(64, val.val_uint64(), false)); |
1032 | } else { |
1033 | TI_P(data_type_name(val.dt)); |
1034 | TI_NOT_IMPLEMENTED; |
1035 | } |
1036 | } |
1037 | |
1038 | void TaskCodeGenLLVM::visit(WhileControlStmt *stmt) { |
1039 | using namespace llvm; |
1040 | |
1041 | BasicBlock *after_break = |
1042 | BasicBlock::Create(*llvm_context, "after_break" , func); |
1043 | TI_ASSERT(current_while_after_loop); |
1044 | auto cond = |
1045 | builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0)); |
1046 | builder->CreateCondBr(cond, current_while_after_loop, after_break); |
1047 | builder->SetInsertPoint(after_break); |
1048 | } |
1049 | |
// Lowers `continue`. Inside an offloaded range-for it becomes a plain
// `ret void` (the loop body lives in its own function — see
// FunctionCreationGuard); otherwise it branches back to the current
// loop's re-entry block.
void TaskCodeGenLLVM::visit(ContinueStmt *stmt) {
  using namespace llvm;
  // True iff the enclosing scope is an offloaded range-for task.
  auto stmt_in_off_range_for = [stmt]() {
    TI_ASSERT(stmt->scope != nullptr);
    if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
      TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
                offl->task_type == OffloadedStmt::TaskType::struct_for);
      return offl->task_type == OffloadedStmt::TaskType::range_for;
    }
    return false;
  };
  if (stmt_in_off_range_for()) {
    builder->CreateRetVoid();
  } else {
    TI_ASSERT(current_loop_reentry != nullptr);
    builder->CreateBr(current_loop_reentry);
  }
  // Stmts after continue are useless, so we switch the insertion point to
  // /dev/null. In LLVM IR, the "after_continue" label shows "No predecessors!".
  BasicBlock *after_continue =
      BasicBlock::Create(*llvm_context, "after_continue" , func);
  builder->SetInsertPoint(after_continue);
}
1073 | |
// Lowers a while loop with an unconditional back edge; the exit is
// produced by a WhileControlStmt in the body branching to `after_loop`.
void TaskCodeGenLLVM::visit(WhileStmt *stmt) {
  using namespace llvm;
  BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body" , func);
  builder->CreateBr(body);
  builder->SetInsertPoint(body);
  // The guards restore the previous reentry / after-loop targets on scope
  // exit so nested loops see the right blocks.
  auto lrg = make_loop_reentry_guard(this);
  current_loop_reentry = body;

  BasicBlock *after_loop =
      BasicBlock::Create(*llvm_context, "after_while" , func);
  auto walg = make_while_after_loop_guard(this);
  current_while_after_loop = after_loop;

  stmt->body->accept(this);

  // A `return` in the body already emitted a terminator (branch to the
  // final block); only emit the back edge if control can fall through.
  if (!returned) {
    builder->CreateBr(body);  // jump to head
  } else {
    returned = false;
  }

  builder->SetInsertPoint(after_loop);
}
1097 | |
1098 | llvm::Value *TaskCodeGenLLVM::cast_pointer(llvm::Value *val, |
1099 | std::string dest_ty_name, |
1100 | int addr_space) { |
1101 | return builder->CreateBitCast( |
1102 | val, llvm::PointerType::get(get_runtime_type(dest_ty_name), addr_space)); |
1103 | } |
1104 | |
1105 | void TaskCodeGenLLVM::emit_list_gen(OffloadedStmt *listgen) { |
1106 | auto snode_child = listgen->snode; |
1107 | auto snode_parent = listgen->snode->parent; |
1108 | auto meta_child = cast_pointer(emit_struct_meta(snode_child), "StructMeta" ); |
1109 | auto meta_parent = cast_pointer(emit_struct_meta(snode_parent), "StructMeta" ); |
1110 | if (snode_parent->type == SNodeType::root) { |
1111 | // Since there's only one container to expand, we need a special kernel for |
1112 | // more parallelism. |
1113 | call("element_listgen_root" , get_runtime(), meta_parent, meta_child); |
1114 | } else { |
1115 | call("element_listgen_nonroot" , get_runtime(), meta_parent, meta_child); |
1116 | } |
1117 | } |
1118 | |
1119 | void TaskCodeGenLLVM::emit_gc(OffloadedStmt *stmt) { |
1120 | auto snode = stmt->snode->id; |
1121 | call("node_gc" , get_runtime(), tlctx->get_constant(snode)); |
1122 | } |
1123 | |
// Emits a runtime call that garbage-collects runtime-context resources.
void TaskCodeGenLLVM::emit_gc_rc() {
  call("runtime_context_gc" , get_runtime());
}
1127 | |
1128 | void TaskCodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) { |
1129 | auto original_value = builder->CreateLoad(value->getType(), ptr); |
1130 | builder->CreateStore(builder->CreateAdd(original_value, value), ptr); |
1131 | } |
1132 | |
// Lowers a (possibly reversed) range-for as a serial loop with the
// standard four-block CFG: test -> body -> inc -> test, exiting to
// after_loop. The i32 loop variable lives in an entry-block alloca.
void TaskCodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
  using namespace llvm;
  BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body" , func);
  BasicBlock *loop_inc =
      BasicBlock::Create(*llvm_context, "for_loop_inc" , func);
  BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for" , func);
  BasicBlock *loop_test =
      BasicBlock::Create(*llvm_context, "for_loop_test" , func);

  auto loop_var_ty = tlctx->get_data_type(PrimitiveType::i32);

  auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
  loop_vars_llvm[for_stmt].push_back(loop_var);

  // Forward loops start at `begin`; reversed loops start at `end - 1`
  // (the range is half-open on the `end` side).
  if (!for_stmt->reversed) {
    builder->CreateStore(llvm_val[for_stmt->begin], loop_var);
  } else {
    builder->CreateStore(
        builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
        loop_var);
  }
  builder->CreateBr(loop_test);

  {
    // test block: forward continues while var < end, reversed continues
    // while var >= begin (signed comparisons; the loop var is i32).
    builder->SetInsertPoint(loop_test);
    llvm::Value *cond;
    if (!for_stmt->reversed) {
      cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
                                 builder->CreateLoad(loop_var_ty, loop_var),
                                 llvm_val[for_stmt->end]);
    } else {
      cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
                                 builder->CreateLoad(loop_var_ty, loop_var),
                                 llvm_val[for_stmt->begin]);
    }
    builder->CreateCondBr(cond, body, after_loop);
  }

  {
    {
      auto lrg = make_loop_reentry_guard(this);
      // The continue stmt should jump to the loop-increment block!
      current_loop_reentry = loop_inc;
      // body cfg
      builder->SetInsertPoint(body);

      for_stmt->body->accept(this);
    }
    // A `return` in the body already emitted its terminator; only add
    // the fall-through branch when control can reach here.
    if (!returned) {
      builder->CreateBr(loop_inc);
    } else {
      returned = false;
    }
    builder->SetInsertPoint(loop_inc);

    // Step the loop variable by +1 (forward) or -1 (reversed).
    if (!for_stmt->reversed) {
      create_increment(loop_var, tlctx->get_constant(1));
    } else {
      create_increment(loop_var, tlctx->get_constant(-1));
    }
    builder->CreateBr(loop_test);
  }

  // next cfg
  builder->SetInsertPoint(after_loop);
}
1200 | |
// Base lowering for a range-for: a naive serial loop (see
// create_naive_range_for above).
void TaskCodeGenLLVM::visit(RangeForStmt *for_stmt) {
  create_naive_range_for(for_stmt);
}
1204 | |
// Reinterprets a u64-packed value (e.g. a kernel argument slot) as
// `type`: truncate to the destination bit width, then bitcast.
// Quant ints are widened to their 32-bit compute representation.
llvm::Value *TaskCodeGenLLVM::bitcast_from_u64(llvm::Value *val,
                                               DataType type) {
  llvm::Type *dest_ty = nullptr;
  TI_ASSERT(!type->is<PointerType>());
  if (auto qit = type->cast<QuantIntType>()) {
    if (qit->get_is_signed())
      dest_ty = tlctx->get_data_type(PrimitiveType::i32);
    else
      dest_ty = tlctx->get_data_type(PrimitiveType::u32);
  } else {
    dest_ty = tlctx->get_data_type(type);
  }
  auto dest_bits = dest_ty->getPrimitiveSizeInBits();
  if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) {
    // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
    // which doesn't mean anything.
    // So we truncate to 32 bits first and then fptrunc to half if applicable
    auto truncated =
        builder->CreateTrunc(val, llvm::Type::getIntNTy(*llvm_context, 32));
    auto casted = builder->CreateBitCast(truncated,
                                         llvm::Type::getFloatTy(*llvm_context));
    return builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context));
  } else {
    auto truncated = builder->CreateTrunc(
        val, llvm::Type::getIntNTy(*llvm_context, dest_bits));

    return builder->CreateBitCast(truncated, dest_ty);
  }
}
1234 | |
// Packs a value of `type` into a 64-bit integer (the inverse of
// bitcast_from_u64): bitcast to a same-width integer, then zero-extend.
// Pointers are converted directly with ptrtoint; f16 is first widened to
// f32 since a 16-bit bitcast-and-extend would lose its encoding.
llvm::Value *TaskCodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
  auto intermediate_bits = 0;
  if (type.is_pointer()) {
    return builder->CreatePtrToInt(val, tlctx->get_data_type<int64>());
  }
  // Quant ints are stored through their compute type's width.
  if (auto qit = type->cast<QuantIntType>()) {
    intermediate_bits = data_type_bits(qit->get_compute_type());
  } else {
    intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits();
  }
  llvm::Type *dest_ty = tlctx->get_data_type<int64>();
  llvm::Type *intermediate_type = nullptr;
  if (val->getType() == llvm::Type::getHalfTy(*llvm_context)) {
    val = builder->CreateFPExt(val, tlctx->get_data_type<float>());
    intermediate_type = tlctx->get_data_type<int32>();
  } else {
    intermediate_type = llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
  }
  return builder->CreateZExt(builder->CreateBitCast(val, intermediate_type),
                             dest_ty);
}
1256 | |
1257 | void TaskCodeGenLLVM::visit(ArgLoadStmt *stmt) { |
1258 | auto raw_arg = stmt->is_grad |
1259 | ? (call(builder.get(), "RuntimeContext_get_grad_args" , |
1260 | get_context(), tlctx->get_constant(stmt->arg_id))) |
1261 | : (call(builder.get(), "RuntimeContext_get_args" , |
1262 | get_context(), tlctx->get_constant(stmt->arg_id))); |
1263 | llvm::Type *dest_ty = nullptr; |
1264 | if (stmt->is_ptr) { |
1265 | dest_ty = llvm::PointerType::get( |
1266 | tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); |
1267 | llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); |
1268 | } else { |
1269 | llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type); |
1270 | } |
1271 | } |
1272 | |
1273 | void TaskCodeGenLLVM::visit(ReturnStmt *stmt) { |
1274 | auto types = stmt->element_types(); |
1275 | if (std::any_of(types.begin(), types.end(), |
1276 | [](const DataType &t) { return t.is_pointer(); })) { |
1277 | TI_NOT_IMPLEMENTED |
1278 | } else { |
1279 | TI_ASSERT(stmt->values.size() == |
1280 | current_callable->ret_type->get_num_elements()); |
1281 | create_return(stmt->values); |
1282 | } |
1283 | builder->CreateBr(final_block); |
1284 | returned = true; |
1285 | } |
1286 | |
1287 | void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) { |
1288 | // FIXME: get ptr_ty from taichi instead of llvm. |
1289 | llvm::Type *ptr_ty = nullptr; |
1290 | auto *val = llvm_val[stmt->src]; |
1291 | if (auto *alloc = llvm::dyn_cast<llvm::AllocaInst>(val)) |
1292 | ptr_ty = alloc->getAllocatedType(); |
1293 | if (!ptr_ty && stmt->src->element_type().is_pointer()) { |
1294 | ptr_ty = tlctx->get_data_type(stmt->src->element_type().ptr_removed()); |
1295 | } |
1296 | TI_ASSERT(ptr_ty); |
1297 | llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src]); |
1298 | } |
1299 | |
1300 | void TaskCodeGenLLVM::visit(LocalStoreStmt *stmt) { |
1301 | builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); |
1302 | } |
1303 | |
1304 | void TaskCodeGenLLVM::visit(AssertStmt *stmt) { |
1305 | TI_ASSERT((int)stmt->args.size() <= taichi_error_message_max_num_arguments); |
1306 | auto argument_buffer_size = llvm::ArrayType::get( |
1307 | llvm::Type::getInt64Ty(*llvm_context), stmt->args.size()); |
1308 | |
1309 | // TODO: maybe let all asserts in a single offload share a single buffer? |
1310 | auto arguments = create_entry_block_alloca(argument_buffer_size); |
1311 | |
1312 | std::vector<llvm::Value *> args; |
1313 | args.emplace_back(get_runtime()); |
1314 | args.emplace_back(llvm_val[stmt->cond]); |
1315 | args.emplace_back(builder->CreateGlobalStringPtr(stmt->text)); |
1316 | |
1317 | for (int i = 0; i < stmt->args.size(); i++) { |
1318 | auto arg = stmt->args[i]; |
1319 | TI_ASSERT(llvm_val[arg]); |
1320 | |
1321 | // First convert the argument to an integral type with the same number of |
1322 | // bits: |
1323 | auto cast_type = llvm::Type::getIntNTy( |
1324 | *llvm_context, 8 * (std::size_t)data_type_size(arg->ret_type)); |
1325 | auto cast_int = builder->CreateBitCast(llvm_val[arg], cast_type); |
1326 | |
1327 | // Then zero-extend the conversion result into int64: |
1328 | auto cast_int64 = |
1329 | builder->CreateZExt(cast_int, llvm::Type::getInt64Ty(*llvm_context)); |
1330 | |
1331 | // Finally store the int64 value to the argument buffer: |
1332 | builder->CreateStore( |
1333 | cast_int64, |
1334 | builder->CreateGEP(argument_buffer_size, arguments, |
1335 | {tlctx->get_constant(0), tlctx->get_constant(i)})); |
1336 | } |
1337 | |
1338 | args.emplace_back(tlctx->get_constant((int)stmt->args.size())); |
1339 | args.emplace_back( |
1340 | builder->CreateGEP(argument_buffer_size, arguments, |
1341 | {tlctx->get_constant(0), tlctx->get_constant(0)})); |
1342 | |
1343 | llvm_val[stmt] = call("taichi_assert_format" , std::move(args)); |
1344 | } |
1345 | |
// Lowers SNode structural ops (allocate / length / is_active / activate /
// deactivate) into calls on the SNode's runtime methods.
void TaskCodeGenLLVM::visit(SNodeOpStmt *stmt) {
  auto snode = stmt->snode;
  if (stmt->op_type == SNodeOpType::allocate) {
    // Only dynamic SNodes support appending; the result is a pointer to
    // the newly allocated cell.
    TI_ASSERT(snode->type == SNodeType::dynamic);
    TI_ASSERT(stmt->ret_type.is_pointer() &&
              stmt->ret_type.ptr_removed()->is_primitive(PrimitiveTypeID::gen));
    auto ptr =
        call(snode, llvm_val[stmt->ptr], "allocate" , {llvm_val[stmt->val]});
    llvm_val[stmt] = ptr;
  } else if (stmt->op_type == SNodeOpType::length) {
    TI_ASSERT(snode->type == SNodeType::dynamic);
    llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "get_num_elements" , {});
  } else if (stmt->op_type == SNodeOpType::is_active) {
    llvm_val[stmt] =
        call(snode, llvm_val[stmt->ptr], "is_active" , {llvm_val[stmt->val]});
  } else if (stmt->op_type == SNodeOpType::activate) {
    llvm_val[stmt] =
        call(snode, llvm_val[stmt->ptr], "activate" , {llvm_val[stmt->val]});
  } else if (stmt->op_type == SNodeOpType::deactivate) {
    // Sparse containers deactivate a single cell (by index); dynamic
    // SNodes clear the whole list (no argument).
    if (snode->type == SNodeType::pointer || snode->type == SNodeType::hash ||
        snode->type == SNodeType::bitmasked) {
      llvm_val[stmt] =
          call(snode, llvm_val[stmt->ptr], "deactivate" , {llvm_val[stmt->val]});
    } else if (snode->type == SNodeType::dynamic) {
      llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "deactivate" , {});
    }
    // NOTE(review): deactivate on any other SNode type silently does
    // nothing and leaves llvm_val[stmt] unset — confirm this is intended.
  } else {
    TI_NOT_IMPLEMENTED
  }
}
1376 | |
// Hook for specialized reduction lowering of atomics. The base
// implementation handles nothing and returns nullptr so that
// visit(AtomicOpStmt) falls through to the generic strategies;
// presumably subclasses override this — confirm in the header.
llvm::Value *TaskCodeGenLLVM::optimized_reduction(AtomicOpStmt *stmt) {
  return nullptr;
}
1380 | |
// Lowers atomics whose destination is a quant (bit-packed) type.
// Returns nullptr when it does not apply (non-add ops, or a
// non-quant destination) so the caller can try other strategies.
llvm::Value *TaskCodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {
  // TODO(type): support all AtomicOpTypes on quant types
  if (stmt->op_type != AtomicOpType::add) {
    return nullptr;
  }

  auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
  if (auto qit = dst_type->cast<QuantIntType>()) {
    // The physical type of the containing SNode tells us the storage
    // word to operate on.
    return atomic_add_quant_int(
        llvm_val[stmt->dest],
        tlctx->get_data_type(
            stmt->dest->as<GetChStmt>()->input_snode->physical_type),
        qit, llvm_val[stmt->val], is_signed(stmt->val->ret_type));
  } else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
    return atomic_add_quant_fixed(
        llvm_val[stmt->dest],
        tlctx->get_data_type(
            stmt->dest->as<GetChStmt>()->input_snode->physical_type),
        qfxt, llvm_val[stmt->val]);
  } else {
    return nullptr;
  }
}
1404 | |
1405 | llvm::Value *TaskCodeGenLLVM::integral_type_atomic(AtomicOpStmt *stmt) { |
1406 | if (!is_integral(stmt->val->ret_type)) { |
1407 | return nullptr; |
1408 | } |
1409 | |
1410 | std::unordered_map<AtomicOpType, llvm::AtomicRMWInst::BinOp> bin_op; |
1411 | bin_op[AtomicOpType::add] = llvm::AtomicRMWInst::BinOp::Add; |
1412 | if (is_signed(stmt->val->ret_type)) { |
1413 | bin_op[AtomicOpType::min] = llvm::AtomicRMWInst::BinOp::Min; |
1414 | bin_op[AtomicOpType::max] = llvm::AtomicRMWInst::BinOp::Max; |
1415 | } else { |
1416 | bin_op[AtomicOpType::min] = llvm::AtomicRMWInst::BinOp::UMin; |
1417 | bin_op[AtomicOpType::max] = llvm::AtomicRMWInst::BinOp::UMax; |
1418 | } |
1419 | bin_op[AtomicOpType::bit_and] = llvm::AtomicRMWInst::BinOp::And; |
1420 | bin_op[AtomicOpType::bit_or] = llvm::AtomicRMWInst::BinOp::Or; |
1421 | bin_op[AtomicOpType::bit_xor] = llvm::AtomicRMWInst::BinOp::Xor; |
1422 | TI_ASSERT(bin_op.find(stmt->op_type) != bin_op.end()); |
1423 | return builder->CreateAtomicRMW( |
1424 | bin_op.at(stmt->op_type), llvm_val[stmt->dest], llvm_val[stmt->val], |
1425 | llvm::MaybeAlign(0), llvm::AtomicOrdering::SequentiallyConsistent); |
1426 | } |
1427 | |
// Emulates an atomic read-modify-write with a compare-and-swap loop:
// load the current value, compute op(old, val), and retry until the CAS
// succeeds. Returns the value observed before the successful update.
// NOTE(review): the cmpxchg is hard-wired to 16-bit operands (Int16
// casts below), so this only works for 16-bit types such as f16 —
// confirm callers never pass wider values.
llvm::Value *TaskCodeGenLLVM::atomic_op_using_cas(
    llvm::Value *dest,
    llvm::Value *val,
    std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op) {
  using namespace llvm;
  BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body" , func);
  BasicBlock *after_loop =
      BasicBlock::Create(*llvm_context, "after_while" , func);

  builder->CreateBr(body);
  builder->SetInsertPoint(body);

  llvm::Value *old_val;

  {
    old_val = builder->CreateLoad(val->getType(), dest);
    auto new_val = op(old_val, val);
    // Reinterpret both pointer and values as 16-bit integers, since
    // cmpxchg operates on integer types.
    dest =
        builder->CreateBitCast(dest, llvm::Type::getInt16PtrTy(*llvm_context));
    auto atomicCmpXchg = builder->CreateAtomicCmpXchg(
        dest,
        builder->CreateBitCast(old_val, llvm::Type::getInt16Ty(*llvm_context)),
        builder->CreateBitCast(new_val, llvm::Type::getInt16Ty(*llvm_context)),
        llvm::MaybeAlign(0), AtomicOrdering::SequentiallyConsistent,
        AtomicOrdering::SequentiallyConsistent);
    // Check whether CAS was successful (element 1 of the {value, ok}
    // result pair); retry the loop body otherwise.
    auto ok = builder->CreateExtractValue(atomicCmpXchg, 1);
    builder->CreateCondBr(builder->CreateNot(ok), body, after_loop);
  }

  builder->SetInsertPoint(after_loop);

  return old_val;
}
1462 | |
// Lowers atomics on real (floating-point) types. Returns nullptr when
// the value type is not real. f16 ops are emulated with a 16-bit CAS
// loop; f32/f64 add maps to native atomicrmw fadd; f32/f64 min/max go
// through runtime helper functions.
llvm::Value *TaskCodeGenLLVM::real_type_atomic(AtomicOpStmt *stmt) {
  if (!is_real(stmt->val->ret_type)) {
    return nullptr;
  }

  PrimitiveTypeID prim_type = stmt->val->ret_type->cast<PrimitiveType>()->type;
  AtomicOpType op = stmt->op_type;
  if (prim_type == PrimitiveTypeID::f16) {
    // CAS emulation for half precision (see atomic_op_using_cas).
    switch (op) {
      case AtomicOpType::add:
        return atomic_op_using_cas(
            llvm_val[stmt->dest], llvm_val[stmt->val],
            [&](auto v1, auto v2) { return builder->CreateFAdd(v1, v2); });
      case AtomicOpType::max:
        return atomic_op_using_cas(
            llvm_val[stmt->dest], llvm_val[stmt->val],
            [&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); });
      case AtomicOpType::min:
        return atomic_op_using_cas(
            llvm_val[stmt->dest], llvm_val[stmt->val],
            [&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); });
      default:
        break;
    }
  }

  if (op == AtomicOpType::add) {
    return builder->CreateAtomicRMW(
        llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest], llvm_val[stmt->val],
        llvm::MaybeAlign(0), llvm::AtomicOrdering::SequentiallyConsistent);
  }

  // Remaining ops dispatch to runtime helpers keyed by (type, op).
  std::unordered_map<PrimitiveTypeID,
                     std::unordered_map<AtomicOpType, std::string>>
      atomics;
  atomics[PrimitiveTypeID::f32][AtomicOpType::min] = "atomic_min_f32" ;
  atomics[PrimitiveTypeID::f64][AtomicOpType::min] = "atomic_min_f64" ;
  atomics[PrimitiveTypeID::f32][AtomicOpType::max] = "atomic_max_f32" ;
  atomics[PrimitiveTypeID::f64][AtomicOpType::max] = "atomic_max_f64" ;
  TI_ASSERT(atomics.find(prim_type) != atomics.end());
  TI_ASSERT(atomics.at(prim_type).find(op) != atomics.at(prim_type).end());
  return call(atomics.at(prim_type).at(op), llvm_val[stmt->dest],
              llvm_val[stmt->val]);
}
1507 | |
1508 | void TaskCodeGenLLVM::visit(AtomicOpStmt *stmt) { |
1509 | bool is_local = stmt->dest->is<AllocaStmt>(); |
1510 | if (is_local) { |
1511 | TI_ERROR("Local atomics should have been demoted." ); |
1512 | } |
1513 | llvm::Value *old_value; |
1514 | if (llvm::Value *result = optimized_reduction(stmt)) { |
1515 | old_value = result; |
1516 | } else if (llvm::Value *result = quant_type_atomic(stmt)) { |
1517 | old_value = result; |
1518 | } else if (llvm::Value *result = real_type_atomic(stmt)) { |
1519 | old_value = result; |
1520 | } else if (llvm::Value *result = integral_type_atomic(stmt)) { |
1521 | old_value = result; |
1522 | } else { |
1523 | TI_NOT_IMPLEMENTED |
1524 | } |
1525 | llvm_val[stmt] = old_value; |
1526 | } |
1527 | |
// GlobalPtrStmt must be lowered by earlier passes before codegen;
// reaching this visitor indicates a compiler bug.
void TaskCodeGenLLVM::visit(GlobalPtrStmt *stmt) {
  TI_ERROR("Global Ptrs should have been lowered." );
}
1531 | |
// Lowers a store to global memory. Bit pointers (quant types packed into
// a physical word) go through the quant store helpers; byte pointers are
// plain LLVM stores.
void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
  TI_ASSERT(llvm_val[stmt->val]);
  TI_ASSERT(llvm_val[stmt->dest]);
  auto ptr_type = stmt->dest->ret_type->as<PointerType>();
  if (ptr_type->is_bit_pointer()) {
    auto pointee_type = ptr_type->get_pointee_type();
    auto snode = stmt->dest->as<GetChStmt>()->input_snode;
    // Stores into bit_struct members must be merged by a dedicated pass;
    // a raw store here would clobber sibling fields.
    if (snode->type == SNodeType::bit_struct) {
      TI_ERROR(
          "Bit struct stores with type {} should have been handled by "
          "BitStructStoreStmt." ,
          pointee_type->to_string());
    }
    if (auto qit = pointee_type->cast<QuantIntType>()) {
      store_quant_int(llvm_val[stmt->dest],
                      tlctx->get_data_type(snode->physical_type), qit,
                      llvm_val[stmt->val], true);
    } else if (auto qfxt = pointee_type->cast<QuantFixedType>()) {
      store_quant_fixed(llvm_val[stmt->dest],
                        tlctx->get_data_type(snode->physical_type), qfxt,
                        llvm_val[stmt->val], true);
    } else {
      TI_NOT_IMPLEMENTED;
    }
  } else {
    builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
  }
}
1560 | |
// Loads `ty` from `ptr` using a backend-specific read-only-cache intrinsic.
// The generic LLVM codegen has no such intrinsic; presumably overridden by
// GPU backends — callers gate on should_cache_as_read_only.
llvm::Value *TaskCodeGenLLVM::create_intrinsic_load(llvm::Value *ptr,
                                                    llvm::Type *ty) {
  TI_NOT_IMPLEMENTED;
}
1565 | |
// Emits the load for a GlobalLoadStmt. Byte pointers become a single LLVM
// load; bit pointers (quantized SNode members) load the whole physical word
// first, then extract/reconstruct the quantized value from it.
void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
                                         bool should_cache_as_read_only) {
  auto ptr = llvm_val[stmt->src];
  auto ptr_type = stmt->src->ret_type->as<PointerType>();
  if (ptr_type->is_bit_pointer()) {
    auto val_type = ptr_type->get_pointee_type();
    auto get_ch = stmt->src->as<GetChStmt>();
    // The quantized value is packed into its SNode's physical word type.
    auto physical_type =
        tlctx->get_data_type(get_ch->input_snode->physical_type);
    auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
    // Load the entire physical word containing the value.
    auto physical_value = should_cache_as_read_only
                              ? create_intrinsic_load(byte_ptr, physical_type)
                              : builder->CreateLoad(physical_type, byte_ptr);
    if (auto qit = val_type->cast<QuantIntType>()) {
      llvm_val[stmt] = extract_quant_int(physical_value, bit_offset, qit);
    } else if (auto qfxt = val_type->cast<QuantFixedType>()) {
      // Fixed-point: extract the integer digits, then scale back to real.
      qit = qfxt->get_digits_type()->as<QuantIntType>();
      auto digits = extract_quant_int(physical_value, bit_offset, qit);
      llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt);
    } else {
      // Quant floats need the enclosing bit struct (shared exponent layout).
      TI_ASSERT(val_type->is<QuantFloatType>());
      TI_ASSERT(get_ch->input_snode->dt->is<BitStructType>());
      llvm_val[stmt] = extract_quant_float(
          physical_value, get_ch->input_snode->dt->as<BitStructType>(),
          get_ch->output_snode->id_in_bit_struct);
    }
  } else {
    // Byte pointer case.
    if (should_cache_as_read_only) {
      llvm_val[stmt] =
          create_intrinsic_load(ptr, tlctx->get_data_type(stmt->ret_type));
    } else {
      llvm_val[stmt] =
          builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr);
    }
  }
}
1603 | |
// Plain global load; read-only caching is opted into by backends that
// override/forward with should_cache_as_read_only = true.
void TaskCodeGenLLVM::visit(GlobalLoadStmt *stmt) {
  create_global_load(stmt, false);
}
1607 | |
1608 | std::string TaskCodeGenLLVM::get_runtime_snode_name(SNode *snode) { |
1609 | if (snode->type == SNodeType::root) { |
1610 | return "Root" ; |
1611 | } else if (snode->type == SNodeType::dense) { |
1612 | return "Dense" ; |
1613 | } else if (snode->type == SNodeType::dynamic) { |
1614 | return "Dynamic" ; |
1615 | } else if (snode->type == SNodeType::pointer) { |
1616 | return "Pointer" ; |
1617 | } else if (snode->type == SNodeType::hash) { |
1618 | return "Hash" ; |
1619 | } else if (snode->type == SNodeType::bitmasked) { |
1620 | return "Bitmasked" ; |
1621 | } else if (snode->type == SNodeType::bit_struct) { |
1622 | return "BitStruct" ; |
1623 | } else if (snode->type == SNodeType::quant_array) { |
1624 | return "QuantArray" ; |
1625 | } else { |
1626 | TI_P(snode_type_name(snode->type)); |
1627 | TI_NOT_IMPLEMENTED |
1628 | } |
1629 | } |
1630 | |
1631 | llvm::Value *TaskCodeGenLLVM::call( |
1632 | SNode *snode, |
1633 | llvm::Value *node_ptr, |
1634 | const std::string &method, |
1635 | const std::vector<llvm::Value *> &arguments) { |
1636 | auto prefix = get_runtime_snode_name(snode); |
1637 | auto s = emit_struct_meta(snode); |
1638 | auto s_ptr = |
1639 | builder->CreateBitCast(s, llvm::Type::getInt8PtrTy(*llvm_context)); |
1640 | |
1641 | node_ptr = |
1642 | builder->CreateBitCast(node_ptr, llvm::Type::getInt8PtrTy(*llvm_context)); |
1643 | |
1644 | std::vector<llvm::Value *> func_arguments{s_ptr, node_ptr}; |
1645 | |
1646 | func_arguments.insert(func_arguments.end(), arguments.begin(), |
1647 | arguments.end()); |
1648 | |
1649 | return call(prefix + "_" + method, std::move(func_arguments)); |
1650 | } |
1651 | |
1652 | llvm::Function *TaskCodeGenLLVM::get_struct_function(const std::string &name, |
1653 | int tree_id) { |
1654 | used_tree_ids.insert(tree_id); |
1655 | auto f = tlctx->get_struct_function(name, tree_id); |
1656 | if (!f) { |
1657 | TI_ERROR("Struct function {} not found." , name); |
1658 | } |
1659 | f = llvm::cast<llvm::Function>( |
1660 | module |
1661 | ->getOrInsertFunction(name, f->getFunctionType(), f->getAttributes()) |
1662 | .getCallee()); |
1663 | return f; |
1664 | } |
1665 | |
1666 | template <typename... Args> |
1667 | llvm::Value *TaskCodeGenLLVM::call_struct_func(int tree_id, |
1668 | const std::string &func_name, |
1669 | Args &&...args) { |
1670 | auto func = get_struct_function(func_name, tree_id); |
1671 | auto arglist = std::vector<llvm::Value *>({args...}); |
1672 | check_func_call_signature(func->getFunctionType(), func->getName(), arglist, |
1673 | builder.get()); |
1674 | return builder->CreateCall(func, arglist); |
1675 | } |
1676 | |
1677 | void TaskCodeGenLLVM::visit(GetRootStmt *stmt) { |
1678 | if (stmt->root() == nullptr) |
1679 | llvm_val[stmt] = builder->CreateBitCast( |
1680 | get_root(SNodeTree::kFirstID), |
1681 | llvm::PointerType::get( |
1682 | StructCompilerLLVM::get_llvm_node_type( |
1683 | module.get(), prog->get_snode_root(SNodeTree::kFirstID)), |
1684 | 0)); |
1685 | else |
1686 | llvm_val[stmt] = builder->CreateBitCast( |
1687 | get_root(stmt->root()->get_snode_tree_id()), |
1688 | llvm::PointerType::get( |
1689 | StructCompilerLLVM::get_llvm_node_type(module.get(), stmt->root()), |
1690 | 0)); |
1691 | } |
1692 | |
1693 | void TaskCodeGenLLVM::visit(LinearizeStmt *stmt) { |
1694 | llvm::Value *val = tlctx->get_constant(0); |
1695 | for (int i = 0; i < (int)stmt->inputs.size(); i++) { |
1696 | val = builder->CreateAdd( |
1697 | builder->CreateMul(val, tlctx->get_constant(stmt->strides[i])), |
1698 | llvm_val[stmt->inputs[i]]); |
1699 | } |
1700 | llvm_val[stmt] = val; |
1701 | } |
1702 | |
// No LLVM lowering exists for IntegerOffsetStmt yet.
void TaskCodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}
1704 | |
// Packs (byte_ptr, bit_offset) into a stack-allocated struct acting as a
// "bit pointer" for quantized loads/stores; see load_bit_ptr for unpacking.
llvm::Value *TaskCodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr,
                                             llvm::Value *bit_offset) {
  // 1. define the bit pointer struct (X=8/16/32/64)
  //   struct bit_pointer_X {
  //     iX* byte_ptr;
  //     i32 bit_offset;
  //   };
  TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
  auto struct_type = llvm::StructType::get(
      *llvm_context, {byte_ptr->getType(), bit_offset->getType()});
  // 2. allocate the bit pointer struct
  auto bit_ptr = create_entry_block_alloca(struct_type);
  // 3. store `byte_ptr`
  builder->CreateStore(byte_ptr, builder->CreateGEP(struct_type, bit_ptr,
                                                    {tlctx->get_constant(0),
                                                     tlctx->get_constant(0)}));
  // 4. store `bit_offset`
  builder->CreateStore(
      bit_offset,
      builder->CreateGEP(struct_type, bit_ptr,
                         {tlctx->get_constant(0), tlctx->get_constant(1)}));
  return bit_ptr;
}
1728 | |
// Unpacks a bit pointer created by create_bit_ptr back into its
// (byte_ptr, bit_offset) components.
std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::load_bit_ptr(
    llvm::Value *bit_ptr) {
  // FIXME: get ptr_ty from taichi instead of llvm.
  // Recover the struct type from the alloca that create_bit_ptr emitted.
  llvm::Type *ptr_ty = nullptr;
  if (auto *AI = llvm::dyn_cast<llvm::AllocaInst>(bit_ptr))
    ptr_ty = AI->getAllocatedType();
  TI_ASSERT(ptr_ty);
  auto *struct_ty = llvm::cast<llvm::StructType>(ptr_ty);
  // Field 0: the byte pointer.
  auto byte_ptr = builder->CreateLoad(
      struct_ty->getElementType(0),
      builder->CreateGEP(ptr_ty, bit_ptr,
                         {tlctx->get_constant(0), tlctx->get_constant(0)}));
  // Field 1: the i32 bit offset within the pointed-to word.
  auto bit_offset = builder->CreateLoad(
      struct_ty->getElementType(1),
      builder->CreateGEP(ptr_ty, bit_ptr,
                         {tlctx->get_constant(0), tlctx->get_constant(1)}));

  return std::make_tuple(byte_ptr, bit_offset);
}
1748 | |
1749 | void TaskCodeGenLLVM::visit(SNodeLookupStmt *stmt) { |
1750 | llvm::Value *parent = nullptr; |
1751 | parent = llvm_val[stmt->input_snode]; |
1752 | TI_ASSERT(parent); |
1753 | auto snode = stmt->snode; |
1754 | if (snode->type == SNodeType::root) { |
1755 | // FIXME: get parent_type from taichi instead of llvm. |
1756 | llvm::Type *parent_ty = builder->getInt8Ty(); |
1757 | if (auto bit_cast = llvm::dyn_cast<llvm::BitCastInst>(parent)) { |
1758 | parent_ty = bit_cast->getDestTy(); |
1759 | if (auto ptr_ty = llvm::dyn_cast<llvm::PointerType>(parent_ty)) |
1760 | parent_ty = ptr_ty->getPointerElementType(); |
1761 | } |
1762 | llvm_val[stmt] = |
1763 | builder->CreateGEP(parent_ty, parent, llvm_val[stmt->input_index]); |
1764 | } else if (snode->type == SNodeType::dense || |
1765 | snode->type == SNodeType::pointer || |
1766 | snode->type == SNodeType::dynamic || |
1767 | snode->type == SNodeType::bitmasked) { |
1768 | if (stmt->activate) { |
1769 | call(snode, llvm_val[stmt->input_snode], "activate" , |
1770 | {llvm_val[stmt->input_index]}); |
1771 | } |
1772 | llvm_val[stmt] = call(snode, llvm_val[stmt->input_snode], "lookup_element" , |
1773 | {llvm_val[stmt->input_index]}); |
1774 | } else if (snode->type == SNodeType::bit_struct) { |
1775 | llvm_val[stmt] = parent; |
1776 | } else if (snode->type == SNodeType::quant_array) { |
1777 | auto element_num_bits = |
1778 | snode->dt->as<QuantArrayType>()->get_element_num_bits(); |
1779 | auto offset = tlctx->get_constant(element_num_bits); |
1780 | offset = builder->CreateMul(offset, llvm_val[stmt->input_index]); |
1781 | llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_snode], offset); |
1782 | } else { |
1783 | TI_INFO(snode_type_name(snode->type)); |
1784 | TI_NOT_IMPLEMENTED |
1785 | } |
1786 | } |
1787 | |
// Computes the address of a child SNode within its parent cell.
void TaskCodeGenLLVM::visit(GetChStmt *stmt) {
  if (stmt->input_snode->type == SNodeType::quant_array) {
    // Quant array elements already carry their bit pointer from
    // SNodeLookupStmt; nothing more to offset here.
    llvm_val[stmt] = llvm_val[stmt->input_ptr];
  } else if (stmt->ret_type->as<PointerType>()->is_bit_pointer()) {
    // Quantized member of a bit struct: form a bit pointer from the
    // member's bit offset inside the packed struct.
    auto bit_struct = stmt->input_snode->dt->cast<BitStructType>();
    auto bit_offset =
        bit_struct->get_member_bit_offset(stmt->output_snode->id_in_bit_struct);
    auto offset = tlctx->get_constant(bit_offset);
    llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
  } else {
    // Ordinary child: call the struct-compiler-generated accessor on an
    // i8* view of the parent, then cast to the child's node type.
    auto ch = call_struct_func(
        stmt->output_snode->get_snode_tree_id(),
        stmt->output_snode->get_ch_from_parent_func_name(),
        builder->CreateBitCast(llvm_val[stmt->input_ptr],
                               llvm::PointerType::getInt8PtrTy(*llvm_context)));
    llvm_val[stmt] = builder->CreateBitCast(
        ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
                                       module.get(), stmt->output_snode),
                                   0));
  }
}
1809 | |
1810 | void TaskCodeGenLLVM::visit(MatrixPtrStmt *stmt) { |
1811 | if (stmt->offset_used_as_index()) { |
1812 | auto type = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()); |
1813 | llvm_val[stmt] = |
1814 | builder->CreateGEP(type, llvm_val[stmt->origin], |
1815 | {tlctx->get_constant(0), llvm_val[stmt->offset]}); |
1816 | } else { |
1817 | // Access PtrOffset via: base_ptr + offset |
1818 | auto origin_address = builder->CreatePtrToInt( |
1819 | llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); |
1820 | auto address_offset = builder->CreateSExt( |
1821 | llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); |
1822 | auto target_address = builder->CreateAdd(origin_address, address_offset); |
1823 | auto dt = stmt->ret_type.ptr_removed(); |
1824 | llvm_val[stmt] = builder->CreateIntToPtr( |
1825 | target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); |
1826 | } |
1827 | } |
1828 | |
// Computes the address of an element inside an external array (Ndarray).
// Outer (runtime) dimensions come from extra_args; inner (element) shape is
// known at compile time. Indices are linearized row-major, then applied via
// manual pointer arithmetic.
void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
  auto argload = stmt->base_ptr->as<ArgLoadStmt>();
  auto arg_id = argload->arg_id;
  int num_indices = stmt->indices.size();
  std::vector<llvm::Value *> sizes(num_indices);
  auto dt = stmt->ret_type.ptr_removed();
  // When the return type is a TensorType, no inner indices are consumed.
  int num_element_indices =
      dt->is<TensorType>() ? 0 : stmt->element_shape.size();
  // AOS places element dims last; SOA places them first.
  const auto layout = stmt->element_dim <= 0 ? ExternalArrayLayout::kAOS
                                             : ExternalArrayLayout::kSOA;

  /*
    ExternalPtrStmt can be divided into "outer" and "inner" parts.

    For example, "x" is an Ndarray with shape = (5, 5, 6), m=2, n=3.
    Indexing to a single element of "x" is of form: x[i, j, k][m, n]

    The "outer" part is x[i, j, k], and the "inner" part is [m, n].
    Shape of the inner part is known at compile time, stored in its ret_type.
    Shape of the outer part is determined at runtime, passed from the
    "extra_args".

    "num_indices - num_element_indices" gives how many "extra_args" to read from
  */
  int num_array_args = num_indices - num_element_indices;
  const size_t element_shape_index_offset =
      (layout == ExternalArrayLayout::kAOS) ? num_array_args : 0;

  // Fetch runtime sizes of the outer dimensions from extra_args.
  for (int i = 0; i < num_array_args; i++) {
    auto raw_arg = call("RuntimeContext_get_extra_args" , get_context(),
                        tlctx->get_constant(arg_id), tlctx->get_constant(i));
    sizes[i] = raw_arg;
  }

  // Row-major linearization over outer + inner indices.
  auto linear_index = tlctx->get_constant(0);
  size_t size_var_index = 0;
  for (int i = 0; i < num_indices; i++) {
    if (i >= element_shape_index_offset &&
        i < element_shape_index_offset + num_element_indices) {
      // Indexing TensorType-elements
      llvm::Value *size_var = tlctx->get_constant(
          stmt->element_shape[i - element_shape_index_offset]);
      linear_index = builder->CreateMul(linear_index, size_var);
    } else {
      // Indexing array dimensions
      linear_index = builder->CreateMul(linear_index, sizes[size_var_index++]);
    }
    linear_index = builder->CreateAdd(linear_index, llvm_val[stmt->indices[i]]);
  }
  TI_ASSERT(size_var_index == num_indices - num_element_indices);

  /*
    llvm::GEP implicitly indicates alignment when used upon llvm::VectorType.
    For example:

    "getelementptr <10 x i32>* %1, 0, 1" is interpreted as "%1 + 16(aligned)"

    However, this does not fit with Taichi's Ndarray semantics. We will have to
    do pointer arithmetics to manually calculate the offset.
  */
  DataType operand_dtype = argload->ret_type.ptr_removed();
  if (operand_dtype->is<TensorType>()) {
    // Access PtrOffset via: base_ptr + offset * sizeof(element)
    auto primitive_type = operand_dtype.get_element_type();
    auto primitive_ptr = builder->CreateBitCast(
        llvm_val[stmt->base_ptr],
        llvm::PointerType::get(tlctx->get_data_type(primitive_type), 0));

    auto address_offset = builder->CreateSExt(
        linear_index, llvm::Type::getInt64Ty(*llvm_context));

    if (stmt->ret_type->is<TensorType>()) {
      // This case corresponds to outer indexing only
      // The stride for linear_index is num_elements() in TensorType.
      address_offset = builder->CreateMul(
          address_offset,
          tlctx->get_constant(
              get_data_type<int64>(),
              stmt->ret_type->cast<TensorType>()->get_num_elements()));
    } else {
      // This case corresponds to outer + inner indexing
      // Since both outer and inner indices are linearized into linear_index,
      // the stride for linear_index is 1, and there's nothing to do here.
    }

    auto ret_ptr = builder->CreateGEP(tlctx->get_data_type(primitive_type),
                                      primitive_ptr, address_offset);
    llvm_val[stmt] = builder->CreateBitCast(
        ret_ptr, llvm::PointerType::get(tlctx->get_data_type(dt), 0));

  } else {
    // Scalar-element array: a typed GEP over the base pointer suffices.
    auto base_ty = tlctx->get_data_type(dt);
    auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
                                       llvm::PointerType::get(base_ty, 0));

    llvm_val[stmt] = builder->CreateGEP(base_ty, base, linear_index);
  }
}
1927 | |
1928 | void TaskCodeGenLLVM::visit(ExternalTensorShapeAlongAxisStmt *stmt) { |
1929 | const auto arg_id = stmt->arg_id; |
1930 | const auto axis = stmt->axis; |
1931 | llvm_val[stmt] = call("RuntimeContext_get_extra_args" , get_context(), |
1932 | tlctx->get_constant(arg_id), tlctx->get_constant(axis)); |
1933 | } |
1934 | |
// Creates the LLVM function for one offloaded task (void(RuntimeContext*)),
// sets up its entry/final/body basic blocks, and positions the builder at
// the body. Returns the generated task kernel name.
std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
                                                          std::string suffix) {
  // Reset per-task loop state from any previously emitted task.
  current_loop_reentry = nullptr;
  current_while_after_loop = nullptr;

  task_function_type =
      llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
                              {llvm::PointerType::get(context_ty, 0)}, false);

  // Name encodes kernel, task ordinal, task type, and an optional suffix.
  auto task_kernel_name =
      fmt::format("{}_{}_{}{}" , kernel_name, kernel->get_next_task_id(),
                  stmt->task_name(), suffix);
  func = llvm::Function::Create(task_function_type,
                                llvm::Function::ExternalLinkage,
                                task_kernel_name, module.get());

  current_task = std::make_unique<OffloadedTask>(task_kernel_name);

  for (auto &arg : func->args()) {
    kernel_args.push_back(&arg);
  }
  kernel_args[0]->setName("context" );
  // Some backends pass the context by value (see kernel_argument_by_val()).
  if (kernel_argument_by_val())
    func->addParamAttr(
        0, llvm::Attribute::getWithByValType(*llvm_context, context_ty));
  // entry_block has all the allocas
  this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry" , func);
  this->final_block = llvm::BasicBlock::Create(*llvm_context, "final" , func);

  // The real function body
  func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body" , func);
  builder->SetInsertPoint(func_body_bb);
  return task_kernel_name;
}
1969 | |
// Wires up the task function's control flow (body -> final, entry -> body),
// optionally dumps the unoptimized IR, and verifies the function.
void TaskCodeGenLLVM::finalize_offloaded_task_function() {
  // If the body did not already branch to final_block (via a return), add
  // the fall-through branch; otherwise clear the flag for the next task.
  if (!returned) {
    builder->CreateBr(final_block);
  } else {
    returned = false;
  }
  builder->SetInsertPoint(final_block);
  builder->CreateRetVoid();

  // entry_block should jump to the body after all allocas are inserted
  builder->SetInsertPoint(entry_block);
  builder->CreateBr(func_body_bb);

  if (compile_config.print_kernel_llvm_ir) {
    static FileSequenceWriter writer("taichi_kernel_generic_llvm_ir_{:04d}.ll" ,
                                     "unoptimized LLVM IR (generic)" );
    writer.write(module.get());
  }
  // verifyFunction returns true on broken IR, so assert it is false.
  TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
  // TI_INFO("Kernel function verified.");
}
1991 | |
1992 | std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds( |
1993 | OffloadedStmt *stmt) { |
1994 | llvm::Value *begin, *end; |
1995 | if (stmt->const_begin) { |
1996 | begin = tlctx->get_constant(stmt->begin_value); |
1997 | } else { |
1998 | auto begin_stmt = |
1999 | Stmt::make<GlobalTemporaryStmt>(stmt->begin_offset, PrimitiveType::i32); |
2000 | begin_stmt->accept(this); |
2001 | begin = builder->CreateLoad(tlctx->get_data_type(PrimitiveType::i32), |
2002 | llvm_val[begin_stmt.get()]); |
2003 | } |
2004 | if (stmt->const_end) { |
2005 | end = tlctx->get_constant(stmt->end_value); |
2006 | } else { |
2007 | auto end_stmt = |
2008 | Stmt::make<GlobalTemporaryStmt>(stmt->end_offset, PrimitiveType::i32); |
2009 | end_stmt->accept(this); |
2010 | end = builder->CreateLoad(tlctx->get_data_type(PrimitiveType::i32), |
2011 | llvm_val[end_stmt.get()]); |
2012 | } |
2013 | return std::tuple(begin, end); |
2014 | } |
2015 | |
// Generates a parallel struct-for over `stmt->snode`: builds a loop-body
// function that iterates the elements of one leaf block (with activation
// checks for sparse containers), then calls the runtime's
// parallel_struct_for to run it over the element list.
void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) {
  using namespace llvm;
  // TODO: instead of constructing tons of LLVM IR, writing the logic in
  // runtime.cpp may be a cleaner solution. See
  // TaskCodeGenCPU::create_offload_range_for as an example.

  llvm::Function *body = nullptr;
  auto leaf_block = stmt->snode;

  // For a bit-vectorized loop over a quant array, we generate struct for on its
  // parent node (must be "dense") instead of itself for higher performance.
  if (stmt->is_bit_vectorized) {
    if (leaf_block->type == SNodeType::quant_array &&
        leaf_block->parent->type == SNodeType::dense) {
      leaf_block = leaf_block->parent;
    } else {
      TI_ERROR(
          "A bit-vectorized struct-for must loop over a quant array with a "
          "dense parent" );
    }
  }

  {
    // Create the loop body function
    auto guard = get_function_creation_guard({
        llvm::PointerType::get(get_runtime_type("RuntimeContext" ), 0),
        get_tls_buffer_type(),
        llvm::PointerType::get(get_runtime_type("Element" ), 0),
        tlctx->get_data_type<int>(),
        tlctx->get_data_type<int>(),
    });

    body = guard.body;

    /* Function structure:
     *
     * function_body (entry):
     *   loop_index = lower_bound;
     *   tls_prologue()
     *   bls_prologue()
     *   goto loop_test
     *
     * loop_test:
     *   if (loop_index < upper_bound)
     *     goto loop_body
     *   else
     *     goto func_exit
     *
     * loop_body:
     *   initialize_coordinates()
     *   if (bitmasked voxel is active)
     *     goto struct_for_body
     *   else
     *     goto loop_body_tail
     *
     * struct_for_body:
     *   ... (Run codegen on the StructForStmt::body Taichi Block)
     *   goto loop_body_tail
     *
     * loop_body_tail:
     *   loop_index += block_dim
     *   goto loop_test
     *
     * func_exit:
     *   bls_epilogue()
     *   tls_epilogue()
     *   return
     */
    auto loop_index_ty = llvm::Type::getInt32Ty(*llvm_context);
    auto loop_index = create_entry_block_alloca(loop_index_ty);

    RuntimeObject element("Element" , this, builder.get(), get_arg(2));

    // Loop ranges
    auto lower_bound = get_arg(3);
    auto upper_bound = get_arg(4);

    parent_coordinates = element.get_ptr("pcoord" );
    block_corner_coordinates =
        create_entry_block_alloca(physical_coordinate_ty);

    auto refine =
        get_struct_function(leaf_block->refine_coordinates_func_name(),
                            leaf_block->get_snode_tree_id());
    // A block corner is the global coordinate/index of the lower-left corner
    // cell within that block, and is the same for all the cells within that
    // block.
    call(refine, parent_coordinates, block_corner_coordinates,
         tlctx->get_constant(0));

    if (stmt->tls_prologue) {
      stmt->tls_prologue->accept(this);
    }

    if (stmt->bls_prologue) {
      call("block_barrier" );  // "__syncthreads()"
      stmt->bls_prologue->accept(this);
      call("block_barrier" );  // "__syncthreads()"
    }

    // Each thread starts at lower_bound + thread_idx and strides by
    // block_dim (see body tail below).
    auto [thread_idx, block_dim] = this->get_spmd_info();
    builder->CreateStore(builder->CreateAdd(thread_idx, lower_bound),
                         loop_index);

    auto loop_test_bb = BasicBlock::Create(*llvm_context, "loop_test" , func);
    auto loop_body_bb = BasicBlock::Create(*llvm_context, "loop_body" , func);
    auto body_tail_bb =
        BasicBlock::Create(*llvm_context, "loop_body_tail" , func);
    auto func_exit = BasicBlock::Create(*llvm_context, "func_exit" , func);
    auto struct_for_body_bb =
        BasicBlock::Create(*llvm_context, "struct_for_body_body" , func);

    auto lrg = make_loop_reentry_guard(this);
    current_loop_reentry = body_tail_bb;

    builder->CreateBr(loop_test_bb);

    {
      // loop_test:
      //   if (loop_index < upper_bound)
      //     goto loop_body;
      //   else
      //     goto func_exit

      builder->SetInsertPoint(loop_test_bb);
      auto cond = builder->CreateICmp(
          llvm::CmpInst::Predicate::ICMP_SLT,
          builder->CreateLoad(loop_index_ty, loop_index), upper_bound);
      builder->CreateCondBr(cond, loop_body_bb, func_exit);
    }

    // ***********************
    // Begin loop_body_bb:
    builder->SetInsertPoint(loop_body_bb);

    // initialize the coordinates
    auto new_coordinates = create_entry_block_alloca(physical_coordinate_ty);

    call(refine, parent_coordinates, new_coordinates,
         builder->CreateLoad(loop_index_ty, loop_index));

    // For a bit-vectorized loop over a quant array, one more refine step is
    // needed to make final coordinates non-consecutive, since each thread will
    // process multiple coordinates via vectorization
    if (stmt->is_bit_vectorized) {
      refine = get_struct_function(stmt->snode->refine_coordinates_func_name(),
                                   stmt->snode->get_snode_tree_id());
      call(refine, new_coordinates, new_coordinates, tlctx->get_constant(0));
    }

    current_coordinates = new_coordinates;

    // exec_cond: safe-guard the execution of loop body:
    //  - if non-POT field dim exists, make sure we don't go out of bounds
    //  - if leaf block is bitmasked, make sure we only loop over active
    //    voxels
    auto exec_cond = tlctx->get_constant(true);
    auto coord_object = RuntimeObject(kLLVMPhysicalCoordinatesName, this,
                                      builder.get(), new_coordinates);

    if (leaf_block->type == SNodeType::bitmasked ||
        leaf_block->type == SNodeType::pointer) {
      // test whether the current voxel is active or not
      auto is_active = call(leaf_block, element.get("element" ), "is_active" ,
                            {builder->CreateLoad(loop_index_ty, loop_index)});
      is_active =
          builder->CreateTrunc(is_active, llvm::Type::getInt1Ty(*llvm_context));
      exec_cond = builder->CreateAnd(exec_cond, is_active);
    }

    builder->CreateCondBr(exec_cond, struct_for_body_bb, body_tail_bb);

    {
      builder->SetInsertPoint(struct_for_body_bb);

      // The real loop body of the StructForStmt
      stmt->body->accept(this);

      builder->CreateBr(body_tail_bb);
    }

    {
      // body tail: increment loop_index and jump to loop_test
      builder->SetInsertPoint(body_tail_bb);

      create_increment(loop_index, block_dim);
      builder->CreateBr(loop_test_bb);

      builder->SetInsertPoint(func_exit);
    }

    if (stmt->bls_epilogue) {
      call("block_barrier" );  // "__syncthreads()"
      stmt->bls_epilogue->accept(this);
      call("block_barrier" );  // "__syncthreads()"
    }

    if (stmt->tls_epilogue) {
      stmt->tls_epilogue->accept(this);
    }
  }

  // Cap the element list chunk size, and split each chunk into enough
  // pieces to cover it with blocks of stmt->block_dim threads.
  int list_element_size = std::min(leaf_block->max_num_elements(),
                                   (int64)taichi_listgen_max_element_size);
  int num_splits = std::max(1, list_element_size / stmt->block_dim +
                                   (list_element_size % stmt->block_dim != 0));

  auto struct_for_func = get_runtime_function("parallel_struct_for" );

  if (arch_is_gpu(current_arch())) {
    // On GPU, specialize parallel_struct_for per TLS size (the runtime
    // provides a name-mangled variant per size).
    struct_for_func = llvm::cast<llvm::Function>(
        module
            ->getOrInsertFunction(
                tlctx->get_struct_for_func_name(stmt->tls_size),
                struct_for_func->getFunctionType(),
                struct_for_func->getAttributes())
            .getCallee());
    struct_for_tls_sizes.insert(stmt->tls_size);
  }
  // Loop over nodes in the element list, in parallel
  call(struct_for_func, get_context(), tlctx->get_constant(leaf_block->id),
       tlctx->get_constant(list_element_size), tlctx->get_constant(num_splits),
       body, tlctx->get_constant(stmt->tls_size),
       tlctx->get_constant(stmt->num_cpu_threads));
  // TODO: why do we need num_cpu_threads on GPUs?

  // Reset per-loop coordinate state so later tasks don't see stale pointers.
  current_coordinates = nullptr;
  parent_coordinates = nullptr;
  block_corner_coordinates = nullptr;
}
2246 | |
// Loads the loop index along one axis. Struct-for tasks read it from the
// refined physical coordinates; other loops read their alloca'd loop var.
void TaskCodeGenLLVM::visit(LoopIndexStmt *stmt) {
  if (stmt->loop->is<OffloadedStmt>() &&
      stmt->loop->as<OffloadedStmt>()->task_type ==
          OffloadedStmt::TaskType::struct_for) {
    llvm::Type *struct_ty = nullptr;
    // FIXME: get struct_ty from taichi instead of llvm.
    if (auto *alloca = llvm::dyn_cast<llvm::AllocaInst>(current_coordinates)) {
      struct_ty = alloca->getAllocatedType();
    }
    TI_ASSERT(struct_ty);
    // GEP into PhysicalCoordinates::val[stmt->index].
    auto *GEP =
        builder->CreateGEP(struct_ty, current_coordinates,
                           {tlctx->get_constant(0), tlctx->get_constant(0),
                            tlctx->get_constant(stmt->index)});
    // An all-zero GEP may be folded away; ensure the pointer type matches.
    if (stmt->index == 0 && !llvm::isa<llvm::GEPOperator>(GEP))
      GEP = builder->CreateBitCast(GEP, struct_ty->getPointerTo());
    llvm_val[stmt] =
        builder->CreateLoad(llvm::Type::getInt32Ty(*llvm_context), GEP);
  } else {
    llvm_val[stmt] =
        builder->CreateLoad(llvm::Type::getInt32Ty(*llvm_context),
                            loop_vars_llvm[stmt->loop][stmt->index]);
  }
}
2271 | |
2272 | void TaskCodeGenLLVM::visit(LoopLinearIndexStmt *stmt) { |
2273 | if (stmt->loop->is<OffloadedStmt>() && |
2274 | (stmt->loop->as<OffloadedStmt>()->task_type == |
2275 | OffloadedStmt::TaskType::struct_for || |
2276 | stmt->loop->as<OffloadedStmt>()->task_type == |
2277 | OffloadedStmt::TaskType::mesh_for)) { |
2278 | llvm_val[stmt] = call("thread_idx" ); |
2279 | } else { |
2280 | TI_NOT_IMPLEMENTED; |
2281 | } |
2282 | } |
2283 | |
// Loads one axis of the block-corner coordinates computed during
// create_offload_struct_for; only valid inside struct-for tasks.
void TaskCodeGenLLVM::visit(BlockCornerIndexStmt *stmt) {
  if (stmt->loop->is<OffloadedStmt>() &&
      stmt->loop->as<OffloadedStmt>()->task_type ==
          OffloadedStmt::TaskType::struct_for) {
    TI_ASSERT(block_corner_coordinates);
    // Make sure physical_coordinate_ty matches
    // struct PhysicalCoordinates {
    //   i32 val[taichi_max_num_indices];
    // };
    TI_ASSERT(physical_coordinate_ty->isStructTy());
    auto physical_coordinate_ty_as_struct =
        llvm::cast<llvm::StructType>(physical_coordinate_ty);
    TI_ASSERT(physical_coordinate_ty_as_struct);
    TI_ASSERT(physical_coordinate_ty_as_struct->getNumElements() == 1);
    auto val_ty = physical_coordinate_ty_as_struct->getElementType(0);
    TI_ASSERT(val_ty->isArrayTy());
    auto val_ty_as_array = llvm::cast<llvm::ArrayType>(val_ty);
    // Load val[stmt->index] from the block-corner coordinate struct.
    llvm_val[stmt] = builder->CreateLoad(
        val_ty_as_array->getElementType(),
        builder->CreateGEP(physical_coordinate_ty, block_corner_coordinates,
                           {tlctx->get_constant(0), tlctx->get_constant(0),
                            tlctx->get_constant(stmt->index)}));
  } else {
    TI_NOT_IMPLEMENTED;
  }
}
2310 | |
2311 | void TaskCodeGenLLVM::visit(GlobalTemporaryStmt *stmt) { |
2312 | auto runtime = get_runtime(); |
2313 | auto buffer = call("get_temporary_pointer" , runtime, |
2314 | tlctx->get_constant((int64)stmt->offset)); |
2315 | |
2316 | auto ptr_type = llvm::PointerType::get( |
2317 | tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); |
2318 | llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); |
2319 | } |
2320 | |
2321 | void TaskCodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) { |
2322 | auto base = get_tls_base_ptr(); |
2323 | auto ptr = builder->CreateGEP(llvm::Type::getInt8Ty(*llvm_context), base, |
2324 | tlctx->get_constant(stmt->offset)); |
2325 | auto ptr_type = llvm::PointerType::get( |
2326 | tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); |
2327 | llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); |
2328 | } |
2329 | |
2330 | void TaskCodeGenLLVM::visit(BlockLocalPtrStmt *stmt) { |
2331 | TI_ASSERT(bls_buffer); |
2332 | auto base = bls_buffer; |
2333 | auto ptr = |
2334 | builder->CreateGEP(base->getValueType(), base, |
2335 | {tlctx->get_constant(0), llvm_val[stmt->offset]}); |
2336 | auto ptr_type = llvm::PointerType::get( |
2337 | tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); |
2338 | llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); |
2339 | } |
2340 | |
2341 | void TaskCodeGenLLVM::visit(ClearListStmt *stmt) { |
2342 | auto snode_child = stmt->snode; |
2343 | auto snode_parent = stmt->snode->parent; |
2344 | auto meta_child = cast_pointer(emit_struct_meta(snode_child), "StructMeta" ); |
2345 | auto meta_parent = cast_pointer(emit_struct_meta(snode_parent), "StructMeta" ); |
2346 | call("clear_list" , get_runtime(), meta_parent, meta_child); |
2347 | } |
2348 | |
2349 | void TaskCodeGenLLVM::visit(InternalFuncStmt *stmt) { |
2350 | std::vector<llvm::Value *> args; |
2351 | |
2352 | if (stmt->with_runtime_context) |
2353 | args.push_back(get_context()); |
2354 | |
2355 | for (auto s : stmt->args) { |
2356 | args.push_back(llvm_val[s]); |
2357 | } |
2358 | llvm_val[stmt] = call(stmt->func_name, std::move(args)); |
2359 | } |
2360 | |
// Allocates the backing storage for an adaptive-autodiff stack in the entry
// block and initializes it via the runtime's stack_init.
void TaskCodeGenLLVM::visit(AdStackAllocaStmt *stmt) {
  TI_ASSERT_INFO(stmt->max_size > 0,
                 "Adaptive autodiff stack's size should have been determined.");
  // Raw byte array covering the whole stack storage.
  auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context),
                                   stmt->size_in_bytes());
  // sizeof(int64) is passed as the alignment — presumably so the stack's
  // 64-bit bookkeeping fields can be accessed directly; confirm against
  // create_entry_block_alloca.
  auto alloca = create_entry_block_alloca(type, sizeof(int64));
  // The runtime works with the stack as an untyped i8*.
  llvm_val[stmt] = builder->CreateBitCast(
      alloca, llvm::PointerType::getInt8PtrTy(*llvm_context));
  call("stack_init", llvm_val[stmt]);
}
2371 | |
// Pops the top element of the AD stack via the runtime.
void TaskCodeGenLLVM::visit(AdStackPopStmt *stmt) {
  call("stack_pop", llvm_val[stmt->stack]);
}
2375 | |
2376 | void TaskCodeGenLLVM::visit(AdStackPushStmt *stmt) { |
2377 | auto stack = stmt->stack->as<AdStackAllocaStmt>(); |
2378 | call("stack_push" , llvm_val[stack], tlctx->get_constant(stack->max_size), |
2379 | tlctx->get_constant(stack->element_size_in_bytes())); |
2380 | auto primal_ptr = call("stack_top_primal" , llvm_val[stack], |
2381 | tlctx->get_constant(stack->element_size_in_bytes())); |
2382 | primal_ptr = builder->CreateBitCast( |
2383 | primal_ptr, |
2384 | llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0)); |
2385 | builder->CreateStore(llvm_val[stmt->v], primal_ptr); |
2386 | } |
2387 | |
2388 | void TaskCodeGenLLVM::visit(AdStackLoadTopStmt *stmt) { |
2389 | auto stack = stmt->stack->as<AdStackAllocaStmt>(); |
2390 | auto primal_ptr = call("stack_top_primal" , llvm_val[stack], |
2391 | tlctx->get_constant(stack->element_size_in_bytes())); |
2392 | auto primal_ty = tlctx->get_data_type(stmt->ret_type); |
2393 | primal_ptr = |
2394 | builder->CreateBitCast(primal_ptr, llvm::PointerType::get(primal_ty, 0)); |
2395 | llvm_val[stmt] = builder->CreateLoad(primal_ty, primal_ptr); |
2396 | } |
2397 | |
2398 | void TaskCodeGenLLVM::visit(AdStackLoadTopAdjStmt *stmt) { |
2399 | auto stack = stmt->stack->as<AdStackAllocaStmt>(); |
2400 | auto adjoint = call("stack_top_adjoint" , llvm_val[stack], |
2401 | tlctx->get_constant(stack->element_size_in_bytes())); |
2402 | auto adjoint_ty = tlctx->get_data_type(stmt->ret_type); |
2403 | adjoint = |
2404 | builder->CreateBitCast(adjoint, llvm::PointerType::get(adjoint_ty, 0)); |
2405 | llvm_val[stmt] = builder->CreateLoad(adjoint_ty, adjoint); |
2406 | } |
2407 | |
// Accumulates `stmt->v` into the adjoint slot at the top of the AD stack
// (read-modify-write).
void TaskCodeGenLLVM::visit(AdStackAccAdjointStmt *stmt) {
  auto stack = stmt->stack->as<AdStackAllocaStmt>();
  auto adjoint_ptr = call("stack_top_adjoint", llvm_val[stack],
                          tlctx->get_constant(stack->element_size_in_bytes()));
  // Note: the adjoint slot is typed after the stack's element type.
  auto adjoint_ty = tlctx->get_data_type(stack->ret_type);
  adjoint_ptr = builder->CreateBitCast(adjoint_ptr,
                                       llvm::PointerType::get(adjoint_ty, 0));
  auto old_val = builder->CreateLoad(adjoint_ty, adjoint_ptr);
  // Only floating-point adjoints are supported, since FAdd is used below.
  TI_ASSERT(is_real(stmt->v->ret_type));
  auto new_val = builder->CreateFAdd(old_val, llvm_val[stmt->v]);
  builder->CreateStore(new_val, adjoint_ptr);
}
2420 | |
// Range assumptions are optimizer hints only; codegen simply forwards the
// input value.
void TaskCodeGenLLVM::visit(RangeAssumptionStmt *stmt) {
  llvm_val[stmt] = llvm_val[stmt->input];
}
2424 | |
// Loop-uniqueness is an analysis annotation; codegen simply forwards the
// input value.
void TaskCodeGenLLVM::visit(LoopUniqueStmt *stmt) {
  llvm_val[stmt] = llvm_val[stmt->input];
}
2428 | |
// Emits a call to a function defined in an external LLVM bitcode file,
// linking the bitcode module into the current module on first use.
void TaskCodeGenLLVM::visit_call_bitcode(ExternalFuncCallStmt *stmt) {
  TI_ASSERT(stmt->type == ExternalFuncCallStmt::BITCODE);
  std::vector<llvm::Value *> arg_values;
  for (const auto &s : stmt->arg_stmts)
    arg_values.push_back(llvm_val[s]);
  // Link external module to the core module (done at most once per file;
  // `linked_modules` remembers which files were already linked).
  if (linked_modules.find(stmt->bc_filename) == linked_modules.end()) {
    linked_modules.insert(stmt->bc_filename);
    std::unique_ptr<llvm::Module> external_module =
        module_from_bitcode_file(stmt->bc_filename, llvm_context);
    auto *func_ptr = external_module->getFunction(stmt->bc_funcname);
    TI_ASSERT_INFO(func_ptr != nullptr, "{} is not found in {}.",
                   stmt->bc_funcname, stmt->bc_filename);
    // linkModules consumes the external module; it returns true on error.
    auto link_error =
        llvm::Linker::linkModules(*module, std::move(external_module));
    TI_ASSERT(!link_error);
  }
  // Retrieve function again. Do it here to detect name conflicting.
  auto *func_ptr = module->getFunction(stmt->bc_funcname);
  // Convert pointer type from a[n * m] to a[n][m]
  for (int i = 0; i < func_ptr->getFunctionType()->getNumParams(); ++i) {
    // Only the type *kind* (TypeID) is checked; differing pointee types are
    // reconciled by the pointer cast below.
    TI_ASSERT_INFO(func_ptr->getArg(i)->getType()->getTypeID() ==
                       arg_values[i]->getType()->getTypeID(),
                   "TypeID {} != {} with {}",
                   (int)func_ptr->getArg(i)->getType()->getTypeID(),
                   (int)arg_values[i]->getType()->getTypeID(), i);
    auto tmp_value = arg_values[i];
    arg_values[i] =
        builder->CreatePointerCast(tmp_value, func_ptr->getArg(i)->getType());
  }
  call(func_ptr, arg_values);
}
2461 | |
2462 | void TaskCodeGenLLVM::visit_call_shared_object(ExternalFuncCallStmt *stmt) { |
2463 | TI_ASSERT(stmt->type == ExternalFuncCallStmt::SHARED_OBJECT); |
2464 | std::vector<llvm::Type *> arg_types; |
2465 | std::vector<llvm::Value *> arg_values; |
2466 | |
2467 | for (const auto &s : stmt->arg_stmts) { |
2468 | arg_types.push_back(tlctx->get_data_type(s->ret_type)); |
2469 | arg_values.push_back(llvm_val[s]); |
2470 | } |
2471 | |
2472 | for (const auto &s : stmt->output_stmts) { |
2473 | auto t = tlctx->get_data_type(s->ret_type); |
2474 | auto ptr = llvm::PointerType::get(t, 0); |
2475 | arg_types.push_back(ptr); |
2476 | arg_values.push_back(llvm_val[s]); |
2477 | } |
2478 | |
2479 | auto func_type = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), |
2480 | arg_types, false); |
2481 | auto func_ptr_type = llvm::PointerType::get(func_type, 0); |
2482 | |
2483 | auto addr = tlctx->get_constant((std::size_t)stmt->so_func); |
2484 | auto func = builder->CreateIntToPtr(addr, func_ptr_type); |
2485 | call(func, func_type, arg_values); |
2486 | } |
2487 | |
// NOTE(review): external calls appear to be handled by the specialized
// helpers above (visit_call_bitcode / visit_call_shared_object); this base
// visitor is intentionally unimplemented — confirm dispatch in subclasses.
void TaskCodeGenLLVM::visit(ExternalFuncCallStmt *stmt) {
  TI_NOT_IMPLEMENTED
}
2491 | |
// The mesh patch index is the third argument of the task function
// (after the RuntimeContext and the TLS base pointer).
void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) {
  llvm_val[stmt] = get_arg(2);
}
2495 | |
2496 | void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) { |
2497 | auto type = tlctx->get_data_type(stmt->ret_type->as<TensorType>()); |
2498 | llvm::Value *vec = llvm::UndefValue::get(type); |
2499 | for (int i = 0; i < stmt->values.size(); ++i) { |
2500 | auto *elem = llvm_val[stmt->values[i]]; |
2501 | if (codegen_vector_type(compile_config)) { |
2502 | TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(type)); |
2503 | vec = builder->CreateInsertElement(vec, elem, i); |
2504 | } else { |
2505 | TI_ASSERT(llvm::dyn_cast<llvm::ArrayType>(type)); |
2506 | vec = builder->CreateInsertValue(vec, elem, i); |
2507 | } |
2508 | } |
2509 | llvm_val[stmt] = vec; |
2510 | } |
2511 | |
2512 | void TaskCodeGenLLVM::eliminate_unused_functions() { |
2513 | TaichiLLVMContext::eliminate_unused_functions( |
2514 | module.get(), [&](std::string func_name) { |
2515 | for (auto &task : offloaded_tasks) { |
2516 | if (task.name == func_name) |
2517 | return true; |
2518 | } |
2519 | return false; |
2520 | }); |
2521 | } |
2522 | |
2523 | FunctionCreationGuard TaskCodeGenLLVM::get_function_creation_guard( |
2524 | std::vector<llvm::Type *> argument_types, |
2525 | const std::string &func_name) { |
2526 | return FunctionCreationGuard(this, argument_types, func_name); |
2527 | } |
2528 | |
// Binds this codegen instance to the current thread's LLVM context and
// creates a fresh IR builder on it.
void TaskCodeGenLLVM::initialize_context() {
  TI_ASSERT(tlctx != nullptr);
  llvm_context = tlctx->get_this_thread_context();
  builder = std::make_unique<llvm::IRBuilder<>>(*llvm_context);
}
2534 | |
2535 | llvm::Value *TaskCodeGenLLVM::get_arg(int i) { |
2536 | std::vector<llvm::Value *> args; |
2537 | for (auto &arg : func->args()) { |
2538 | args.push_back(&arg); |
2539 | } |
2540 | return args[i]; |
2541 | } |
2542 | |
// The RuntimeContext pointer is always the first task-function argument.
llvm::Value *TaskCodeGenLLVM::get_context() {
  return get_arg(0);
}
2546 | |
// The thread-local-storage base pointer is the second task-function argument.
llvm::Value *TaskCodeGenLLVM::get_tls_base_ptr() {
  return get_arg(1);
}
2550 | |
// The TLS buffer is passed around untyped, as an i8 pointer.
llvm::Type *TaskCodeGenLLVM::get_tls_buffer_type() {
  return llvm::Type::getInt8PtrTy(*llvm_context);
}
2554 | |
// Parameter types of a prologue/epilogue ("xlogue") function:
// (RuntimeContext *, tls_buffer i8*).
std::vector<llvm::Type *> TaskCodeGenLLVM::get_xlogue_argument_types() {
  return {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
          get_tls_buffer_type()};
}
2559 | |
// Mesh xlogue functions additionally receive a u32 argument
// (beyond RuntimeContext * and the TLS buffer).
std::vector<llvm::Type *> TaskCodeGenLLVM::get_mesh_xlogue_argument_types() {
  return {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
          get_tls_buffer_type(), tlctx->get_data_type<uint32_t>()};
}
2564 | |
// void(RuntimeContext *, i8 *) — the signature of a task prologue/epilogue.
llvm::Type *TaskCodeGenLLVM::get_xlogue_function_type() {
  return llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
                                 get_xlogue_argument_types(), false);
}
2569 | |
// void(RuntimeContext *, i8 *, u32) — the signature of a mesh
// prologue/epilogue.
llvm::Type *TaskCodeGenLLVM::get_mesh_xlogue_function_type() {
  return llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
                                 get_mesh_xlogue_argument_types(), false);
}
2574 | |
// Returns the root pointer of the given SNode tree via the runtime.
llvm::Value *TaskCodeGenLLVM::get_root(int snode_tree_id) {
  return call("LLVMRuntime_get_roots", get_runtime(),
              tlctx->get_constant(snode_tree_id));
}
2579 | |
2580 | llvm::Value *TaskCodeGenLLVM::get_runtime() { |
2581 | auto runtime_ptr = call("RuntimeContext_get_runtime" , get_context()); |
2582 | return builder->CreateBitCast( |
2583 | runtime_ptr, llvm::PointerType::get(get_runtime_type("LLVMRuntime" ), 0)); |
2584 | } |
2585 | |
// Returns the pointer to the StructMeta object emitted for `snode`.
llvm::Value *TaskCodeGenLLVM::emit_struct_meta(SNode *snode) {
  auto obj = emit_struct_meta_object(snode);
  TI_ASSERT(obj != nullptr);
  return obj->ptr;
}
2591 | |
// Walks the kernel IR with this visitor, emitting LLVM IR into `module`.
void TaskCodeGenLLVM::emit_to_module() {
  TI_AUTO_PROF
  ir->accept(this);
}
2596 | |
// Runs the final lowering passes over the kernel IR, emits LLVM IR for all
// offloaded tasks, strips unused functions, and attaches arch-specific kernel
// metadata. Returns the compiled module together with its task list.
LLVMCompiledTask TaskCodeGenLLVM::run_compilation() {
  // Final lowering
  auto offload_to_executable = [](IRNode *ir, const CompileConfig &config,
                                  Kernel *kernel) {
    bool verbose = config.print_ir;
    // Accessor/evaluator kernels are noisy; only print their IR when the
    // corresponding config flag is set.
    if ((kernel->is_accessor && !config.print_accessor_ir) ||
        (kernel->is_evaluator && !config.print_evaluator_ir)) {
      verbose = false;
    }
    irpass::offload_to_executable(
        ir, config, kernel, verbose,
        /*determine_ad_stack_size=*/kernel->autodiff_mode ==
            AutodiffMode::kReverse,
        /*lower_global_access=*/true,
        /*make_thread_local=*/config.make_thread_local,
        /*make_block_local=*/
        is_extension_supported(config.arch, Extension::bls) &&
            config.make_block_local);
  };

  offload_to_executable(ir, compile_config, kernel);

  emit_to_module();
  eliminate_unused_functions();

  if (compile_config.arch == Arch::cuda) {
    // CUDA specific metadata
    for (const auto &task : offloaded_tasks) {
      llvm::Function *func = module->getFunction(task.name);
      TI_ASSERT(func);
      tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
    }
  } else if (compile_config.arch == Arch::amdgpu) {
    // AMDGPU: analogous per-task kernel marking (no block_dim needed).
    for (const auto &task : offloaded_tasks) {
      llvm::Function *func = module->getFunction(task.name);
      TI_ASSERT(func);
      tlctx->mark_function_as_amdgpu_kernel(func);
    }
  }

  // Ownership of the tasks and the module is transferred to the result.
  return {std::move(offloaded_tasks), std::move(module),
          std::move(used_tree_ids), std::move(struct_for_tls_sizes)};
}
2640 | |
2641 | llvm::Value *TaskCodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) { |
2642 | llvm::Value *xlogue; |
2643 | |
2644 | auto xlogue_type = get_xlogue_function_type(); |
2645 | auto xlogue_ptr_type = llvm::PointerType::get(xlogue_type, 0); |
2646 | |
2647 | if (block) { |
2648 | auto guard = get_function_creation_guard(get_xlogue_argument_types()); |
2649 | block->accept(this); |
2650 | xlogue = guard.body; |
2651 | } else { |
2652 | xlogue = llvm::ConstantPointerNull::get(xlogue_ptr_type); |
2653 | } |
2654 | |
2655 | return xlogue; |
2656 | } |
2657 | |
2658 | llvm::Value *TaskCodeGenLLVM::create_mesh_xlogue( |
2659 | std::unique_ptr<Block> &block) { |
2660 | llvm::Value *xlogue; |
2661 | |
2662 | auto xlogue_type = get_mesh_xlogue_function_type(); |
2663 | auto xlogue_ptr_type = llvm::PointerType::get(xlogue_type, 0); |
2664 | |
2665 | if (block) { |
2666 | auto guard = get_function_creation_guard(get_mesh_xlogue_argument_types()); |
2667 | block->accept(this); |
2668 | xlogue = guard.body; |
2669 | } else { |
2670 | xlogue = llvm::ConstantPointerNull::get(xlogue_ptr_type); |
2671 | } |
2672 | |
2673 | return xlogue; |
2674 | } |
2675 | |
// A reference aliases its underlying variable's value/pointer directly.
void TaskCodeGenLLVM::visit(ReferenceStmt *stmt) {
  llvm_val[stmt] = llvm_val[stmt->var];
}
2679 | |
// Emits a call to a Taichi function: lazily codegens the callee on first use,
// then marshals arguments and results through a freshly allocated
// RuntimeContext.
void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
  if (!func_map.count(stmt->func)) {
    // First call site: compile the callee's IR into its own LLVM function
    // with a (RuntimeContext *) signature.
    auto guard = get_function_creation_guard(
        {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)},
        stmt->func->get_name());
    Callable *old_callable = current_callable;
    current_callable = stmt->func;
    // Registered before visiting the body, so that (self-)recursive calls
    // resolve to this function instead of re-entering codegen.
    func_map.insert({stmt->func, guard.body});
    stmt->func->ir->accept(this);
    current_callable = old_callable;
  }
  llvm::Function *llvm_func = func_map[stmt->func];
  auto *new_ctx = call("allocate_runtime_context", get_runtime());
  call("RuntimeContext_set_runtime", new_ctx, get_runtime());
  // Each argument is bit-cast to u64 and stored into the context by index.
  for (int i = 0; i < stmt->args.size(); i++) {
    auto *val =
        bitcast_to_u64(llvm_val[stmt->args[i]], stmt->args[i]->ret_type);
    call("RuntimeContext_set_args", new_ctx,
         llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
  }
  llvm::Value *result_buffer = nullptr;
  if (stmt->ret_type) {
    // Stack-allocate a result buffer and expose it to the callee as u64*.
    auto *ret_type = tlctx->get_data_type(stmt->ret_type);
    result_buffer = builder->CreateAlloca(ret_type);
    auto *result_buffer_u64 = builder->CreatePointerCast(
        result_buffer,
        llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
    call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
  }
  call(llvm_func, new_ctx);
  // The statement's value is the (possibly null) result buffer pointer.
  llvm_val[stmt] = result_buffer;
  call("recycle_runtime_context", get_runtime(), new_ctx);
}
2713 | |
2714 | void TaskCodeGenLLVM::visit(GetElementStmt *stmt) { |
2715 | auto *struct_type = tlctx->get_data_type(stmt->src->ret_type); |
2716 | std::vector<llvm::Value *> index; |
2717 | index.reserve(stmt->index.size() + 1); |
2718 | index.push_back(tlctx->get_constant(0)); |
2719 | for (auto &i : stmt->index) { |
2720 | index.push_back(tlctx->get_constant(i)); |
2721 | } |
2722 | auto *gep = builder->CreateGEP(struct_type, llvm_val[stmt->src], index); |
2723 | auto *val = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), gep); |
2724 | llvm_val[stmt] = val; |
2725 | } |
2726 | |
// Recursively stores the flat list of return-value statements into the result
// buffer, walking `current_type` depth-first.
//
// buffer / buffer_type: typed result buffer and its LLVM type.
// elements:             flat leaf statements, consumed in order.
// current_type:         the (sub)type currently being written.
// current_element:      index of the next unconsumed leaf in `elements`.
// current_index:        GEP index path from the buffer root to this subvalue.
void TaskCodeGenLLVM::create_return(llvm::Value *buffer,
                                    llvm::Type *buffer_type,
                                    const std::vector<Stmt *> &elements,
                                    const Type *current_type,
                                    int &current_element,
                                    std::vector<llvm::Value *> &current_index) {
  if (auto primitive_type = current_type->cast<PrimitiveType>()) {
    // Leaf: store the next element at the current GEP path.
    TI_ASSERT((Type *)elements[current_element]->ret_type == current_type);
    auto *gep = builder->CreateGEP(buffer_type, buffer, current_index);
    builder->CreateStore(llvm_val[elements[current_element]], gep);
    current_element++;
  } else if (auto struct_type = current_type->cast<StructType>()) {
    // Struct: recurse into each member in declaration order.
    int i = 0;
    for (const auto &element : struct_type->elements()) {
      current_index.push_back(tlctx->get_constant(i++));
      create_return(buffer, buffer_type, elements, element.type,
                    current_element, current_index);
      current_index.pop_back();
    }
  } else {
    // Otherwise the type must be a tensor; recurse into each component.
    auto tensor_type = current_type->as<TensorType>();
    int num_elements = tensor_type->get_num_elements();
    Type *element_type = tensor_type->get_element_type();
    for (int i = 0; i < num_elements; i++) {
      current_index.push_back(tlctx->get_constant(i));
      create_return(buffer, buffer_type, elements, element_type,
                    current_element, current_index);
      current_index.pop_back();
    }
  }
}
2758 | |
2759 | void TaskCodeGenLLVM::create_return(const std::vector<Stmt *> &elements) { |
2760 | auto buffer = call("RuntimeContext_get_result_buffer" , get_context()); |
2761 | auto ret_type = current_callable->ret_type; |
2762 | auto buffer_type = tlctx->get_data_type(ret_type); |
2763 | buffer = builder->CreatePointerCast(buffer, |
2764 | llvm::PointerType::get(buffer_type, 0)); |
2765 | int current_element = 0; |
2766 | std::vector<llvm::Value *> current_index = {tlctx->get_constant(0)}; |
2767 | create_return(buffer, buffer_type, elements, ret_type, current_element, |
2768 | current_index); |
2769 | }; |
2770 | |
// Deep-copies the compiled task; the LLVM module is cloned so the copy can be
// modified independently of the original.
LLVMCompiledTask LLVMCompiledTask::clone() const {
  return {tasks, llvm::CloneModule(*module), used_tree_ids,
          struct_for_tls_sizes};
}
2775 | |
// Deep-copies the compiled kernel, cloning its LLVM module.
LLVMCompiledKernel LLVMCompiledKernel::clone() const {
  return {tasks, llvm::CloneModule(*module)};
}
2779 | |
2780 | } // namespace taichi::lang |
2781 | |
2782 | #endif // #ifdef TI_WITH_LLVM |
2783 | |