1 | #include "taichi/codegen/cpu/codegen_cpu.h" |
2 | |
3 | #include "taichi/runtime/program_impls/llvm/llvm_program.h" |
4 | #include "taichi/common/core.h" |
5 | #include "taichi/util/io.h" |
6 | #include "taichi/util/lang_util.h" |
7 | #include "taichi/program/program.h" |
8 | #include "taichi/ir/ir.h" |
9 | #include "taichi/ir/statements.h" |
10 | #include "taichi/ir/transforms.h" |
11 | #include "taichi/ir/analysis.h" |
12 | #include "taichi/analysis/offline_cache_util.h" |
13 | namespace taichi::lang { |
14 | |
15 | namespace { |
16 | |
// Generates LLVM IR for a single offloaded task on the CPU backend.
// Parallel offloads (range-for / mesh-for) are lowered into a standalone
// "body" function plus optional TLS prologue/epilogue functions, which are
// then dispatched through the runtime's cpu_parallel_* entry points.
class TaskCodeGenCPU : public TaskCodeGenLLVM {
 public:
  using IRVisitor::visit;

  // NOTE: the trailing nullptr is the (absent) caller module passed to the
  // TaskCodeGenLLVM base constructor.
  TaskCodeGenCPU(const CompileConfig &config,
                 TaichiLLVMContext &tlctx,
                 Kernel *kernel,
                 IRNode *ir)
      : TaskCodeGenLLVM(config, tlctx, kernel, ir, nullptr) {
    TI_AUTO_PROF
  }

  // Lowers a range-for offload: emits the loop body as a separate function
  // and hands it to the runtime's cpu_parallel_range_for together with the
  // bounds, step, block_dim and TLS prologue/epilogue.
  void create_offload_range_for(OffloadedStmt *stmt) override {
    int step = 1;

    // In parallel for-loops reversing the order doesn't make sense.
    // However, we may need to support serial offloaded range for's in the
    // future, so it still makes sense to reverse the order here.
    if (stmt->reversed) {
      step = -1;
    }

    auto *tls_prologue = create_xlogue(stmt->tls_prologue);

    // The loop body, emitted as a function taking
    // (RuntimeContext *, tls_buffer_ptr, loop_index).
    llvm::Function *body;
    {
      auto guard = get_function_creation_guard(
          {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
           llvm::Type::getInt8PtrTy(*llvm_context),
           tlctx->get_data_type<int>()});

      // Argument 2 of the body function carries the current loop index.
      auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
      loop_vars_llvm[stmt].push_back(loop_var);
      builder->CreateStore(get_arg(2), loop_var);
      stmt->body->accept(this);

      body = guard.body;
    }

    llvm::Value *epilogue = create_xlogue(stmt->tls_epilogue);

    auto [begin, end] = get_range_for_bounds(stmt);

    // adaptive block_dim: derive a task size from the total item count and
    // the number of worker threads.
    if (compile_config.cpu_block_dim_adaptive) {
      int num_items = (stmt->end_value - stmt->begin_value) / std::abs(step);
      int num_threads = stmt->num_cpu_threads;
      int items_per_thread = std::max(1, num_items / (num_threads * 32));
      // Keep each task at least 512 items to amortize scheduler overhead;
      // also saturate the value at 1024 for better load balancing.
      stmt->block_dim = std::min(1024, std::max(512, items_per_thread));
    }

    call("cpu_parallel_range_for", get_arg(0),
         tlctx->get_constant(stmt->num_cpu_threads), begin, end,
         tlctx->get_constant(step), tlctx->get_constant(stmt->block_dim),
         tls_prologue, body, epilogue, tlctx->get_constant(stmt->tls_size));
  }

  // Lowers a mesh-for offload: the body function iterates over the owned
  // elements of the major element type inside one patch; patches are
  // distributed by the runtime's cpu_parallel_mesh_for.
  void create_offload_mesh_for(OffloadedStmt *stmt) override {
    auto *tls_prologue = create_mesh_xlogue(stmt->tls_prologue);

    llvm::Function *body;
    {
      auto guard = get_function_creation_guard(
          {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
           llvm::Type::getInt8PtrTy(*llvm_context),
           tlctx->get_data_type<int>()});

      // Per-patch setup statements emitted ahead of the element loop.
      for (int i = 0; i < stmt->mesh_prologue->size(); i++) {
        auto &s = stmt->mesh_prologue->statements[i];
        s->accept(this);
      }

      if (stmt->bls_prologue) {
        stmt->bls_prologue->accept(this);
      }

      // Hand-built loop over the patch's owned elements:
      //   loop_test -> (loop_body -> loop_test)* -> func_exit
      auto loop_test_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_test", func);
      auto loop_body_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_body", func);
      auto func_exit =
          llvm::BasicBlock::Create(*llvm_context, "func_exit", func);
      auto loop_index =
          create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context));
      builder->CreateStore(tlctx->get_constant(0), loop_index);
      builder->CreateBr(loop_test_bb);

      {
        // loop_test: continue while loop_index < owned_num_local of the
        // major element type.
        builder->SetInsertPoint(loop_test_bb);
        auto *loop_index_load =
            builder->CreateLoad(builder->getInt32Ty(), loop_index);
        auto cond = builder->CreateICmp(
            llvm::CmpInst::Predicate::ICMP_SLT, loop_index_load,
            llvm_val[stmt->owned_num_local.find(stmt->major_from_type)
                         ->second]);
        builder->CreateCondBr(cond, loop_body_bb, func_exit);
      }

      {
        // loop_body: emit the user body, then loop_index++ and jump back to
        // the test block.
        builder->SetInsertPoint(loop_body_bb);
        loop_vars_llvm[stmt].push_back(loop_index);
        for (int i = 0; i < stmt->body->size(); i++) {
          auto &s = stmt->body->statements[i];
          s->accept(this);
        }
        auto *loop_index_load =
            builder->CreateLoad(builder->getInt32Ty(), loop_index);
        builder->CreateStore(
            builder->CreateAdd(loop_index_load, tlctx->get_constant(1)),
            loop_index);
        builder->CreateBr(loop_test_bb);
        builder->SetInsertPoint(func_exit);
      }

      // The BLS epilogue (if any) is emitted after the element loop exits.
      if (stmt->bls_epilogue) {
        stmt->bls_epilogue->accept(this);
      }

      body = guard.body;
    }

    llvm::Value *epilogue = create_mesh_xlogue(stmt->tls_epilogue);

    call("cpu_parallel_mesh_for", get_arg(0),
         tlctx->get_constant(stmt->num_cpu_threads),
         tlctx->get_constant(stmt->mesh->num_patches),
         tlctx->get_constant(stmt->block_dim), tls_prologue, body, epilogue,
         tlctx->get_constant(stmt->tls_size));
  }

  // Creates the thread-local global byte buffer backing block-local storage
  // (BLS) for this offload; sized by stmt->bls_size.
  void create_bls_buffer(OffloadedStmt *stmt) {
    auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context),
                                     stmt->bls_size);
    bls_buffer = new llvm::GlobalVariable(
        *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr,
        "bls_buffer", nullptr, llvm::GlobalVariable::LocalExecTLSModel, 0);
    /* module->getOrInsertGlobal("bls_buffer", type);
    bls_buffer = module->getNamedGlobal("bls_buffer");
    bls_buffer->setAlignment(llvm::MaybeAlign(8));*/  // TODO(changyu): Fix JIT session error: Symbols not found: [ __emutls_get_address ] in python 3.10

    // initialize the variable with an undef value to ensure it is added to the
    // symbol table
    bls_buffer->setInitializer(llvm::UndefValue::get(type));
  }

  // Dispatches an offloaded statement to the matching emitter, wrapping it in
  // an offloaded-task function and optional kernel-profiler start/stop calls.
  void visit(OffloadedStmt *stmt) override {
    TI_ASSERT(current_offload == nullptr);  // offloads must not nest
    current_offload = stmt;
    if (stmt->bls_size > 0)
      create_bls_buffer(stmt);
    using Type = OffloadedStmt::TaskType;
    auto offloaded_task_name = init_offloaded_task_function(stmt);
    if (compile_config.kernel_profiler && arch_is_cpu(compile_config.arch)) {
      call("LLVMRuntime_profiler_start", get_runtime(),
           builder->CreateGlobalStringPtr(offloaded_task_name));
    }
    if (stmt->task_type == Type::serial) {
      stmt->body->accept(this);
    } else if (stmt->task_type == Type::range_for) {
      create_offload_range_for(stmt);
    } else if (stmt->task_type == Type::mesh_for) {
      create_offload_mesh_for(stmt);
    } else if (stmt->task_type == Type::struct_for) {
      // A block cannot be larger than the container it iterates over.
      stmt->block_dim = std::min(stmt->snode->parent->max_num_elements(),
                                 (int64)stmt->block_dim);
      create_offload_struct_for(stmt);
    } else if (stmt->task_type == Type::listgen) {
      emit_list_gen(stmt);
    } else if (stmt->task_type == Type::gc) {
      emit_gc(stmt);
    } else if (stmt->task_type == Type::gc_rc) {
      emit_gc_rc();
    } else {
      TI_NOT_IMPLEMENTED
    }
    if (compile_config.kernel_profiler && arch_is_cpu(compile_config.arch)) {
      // The profiler-stop call goes at the end of the task function, so save
      // and restore the current insert point.
      llvm::IRBuilderBase::InsertPointGuard guard(*builder);
      builder->SetInsertPoint(final_block);
      call("LLVMRuntime_profiler_stop", get_runtime());
    }
    finalize_offloaded_task_function();
    offloaded_tasks.push_back(*current_task);
    current_task = nullptr;
    current_offload = nullptr;
  }

  // External function calls: only bitcode and shared-object flavors are
  // supported on CPU.
  void visit(ExternalFuncCallStmt *stmt) override {
    if (stmt->type == ExternalFuncCallStmt::BITCODE) {
      TaskCodeGenLLVM::visit_call_bitcode(stmt);
    } else if (stmt->type == ExternalFuncCallStmt::SHARED_OBJECT) {
      TaskCodeGenLLVM::visit_call_shared_object(stmt);
    } else {
      TI_NOT_IMPLEMENTED
    }
  }

 private:
  // On CPU each task body runs single-threaded within its block:
  // thread_idx = 0 and block_dim = 1.
  std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
    auto thread_idx = tlctx->get_constant(0);
    auto block_dim = tlctx->get_constant(1);
    return std::make_tuple(thread_idx, block_dim);
  }
};
223 | |
224 | } // namespace |
225 | |
226 | #ifdef TI_WITH_LLVM |
227 | FunctionType CPUModuleToFunctionConverter::convert( |
228 | const std::string &kernel_name, |
229 | const std::vector<LlvmLaunchArgInfo> &args, |
230 | LLVMCompiledKernel data) const { |
231 | TI_AUTO_PROF; |
232 | auto jit_module = executor_->create_jit_module(std::move(data.module)); |
233 | using TaskFunc = int32 (*)(void *); |
234 | std::vector<TaskFunc> task_funcs; |
235 | task_funcs.reserve(data.tasks.size()); |
236 | for (auto &task : data.tasks) { |
237 | auto *func_ptr = jit_module->lookup_function(task.name); |
238 | TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found" , |
239 | task.name); |
240 | task_funcs.push_back((TaskFunc)(func_ptr)); |
241 | } |
242 | // Do NOT capture `this`... |
243 | return [executor = this->executor_, args, kernel_name, |
244 | task_funcs](RuntimeContext &context) { |
245 | TI_TRACE("Launching kernel {}" , kernel_name); |
246 | // For taichi ndarrays, context.args saves pointer to its |
247 | // |DeviceAllocation|, CPU backend actually want to use the raw ptr here. |
248 | for (int i = 0; i < (int)args.size(); i++) { |
249 | if (args[i].is_array && |
250 | context.device_allocation_type[i] != |
251 | RuntimeContext::DevAllocType::kNone && |
252 | context.array_runtime_sizes[i] > 0) { |
253 | DeviceAllocation *ptr = |
254 | static_cast<DeviceAllocation *>(context.get_arg<void *>(i)); |
255 | uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr); |
256 | context.set_arg(i, host_ptr); |
257 | context.set_array_device_allocation_type( |
258 | i, RuntimeContext::DevAllocType::kNone); |
259 | |
260 | if (context.has_grad[i]) { |
261 | DeviceAllocation *ptr_grad = |
262 | static_cast<DeviceAllocation *>(context.get_grad_arg<void *>(i)); |
263 | uint64 host_ptr_grad = |
264 | (uint64)executor->get_ndarray_alloc_info_ptr(*ptr_grad); |
265 | context.set_grad_arg(i, host_ptr_grad); |
266 | } |
267 | } |
268 | } |
269 | for (auto task : task_funcs) { |
270 | task(&context); |
271 | } |
272 | }; |
273 | } |
274 | |
275 | LLVMCompiledTask KernelCodeGenCPU::compile_task( |
276 | const CompileConfig &config, |
277 | std::unique_ptr<llvm::Module> &&module, |
278 | OffloadedStmt *stmt) { |
279 | TaskCodeGenCPU gen(config, get_taichi_llvm_context(), kernel, stmt); |
280 | return gen.run_compilation(); |
281 | } |
282 | #endif // TI_WITH_LLVM |
283 | |
284 | FunctionType KernelCodeGenCPU::compile_to_function() { |
285 | TI_AUTO_PROF; |
286 | CPUModuleToFunctionConverter converter( |
287 | &get_taichi_llvm_context(), |
288 | get_llvm_program(prog)->get_runtime_executor()); |
289 | return converter.convert(kernel, compile_kernel_to_module()); |
290 | } |
291 | } // namespace taichi::lang |
292 | |