1#include "taichi/codegen/cpu/codegen_cpu.h"
2
3#include "taichi/runtime/program_impls/llvm/llvm_program.h"
4#include "taichi/common/core.h"
5#include "taichi/util/io.h"
6#include "taichi/util/lang_util.h"
7#include "taichi/program/program.h"
8#include "taichi/ir/ir.h"
9#include "taichi/ir/statements.h"
10#include "taichi/ir/transforms.h"
11#include "taichi/ir/analysis.h"
12#include "taichi/analysis/offline_cache_util.h"
13namespace taichi::lang {
14
15namespace {
16
// Code generator that lowers one offloaded task of a Taichi kernel into
// LLVM IR for the CPU backend. Parallel offloads (range-for / mesh-for /
// struct-for) are emitted as standalone body functions which the runtime
// thread pool invokes via cpu_parallel_range_for / cpu_parallel_mesh_for.
class TaskCodeGenCPU : public TaskCodeGenLLVM {
 public:
  using IRVisitor::visit;

  // Constructs a generator for `kernel`'s IR rooted at `ir`.
  // NOTE(review): the trailing nullptr is forwarded to TaskCodeGenLLVM;
  // its meaning is not visible in this file — confirm against the base
  // class before relying on it.
  TaskCodeGenCPU(const CompileConfig &config,
                 TaichiLLVMContext &tlctx,
                 Kernel *kernel,
                 IRNode *ir)
      : TaskCodeGenLLVM(config, tlctx, kernel, ir, nullptr) {
    TI_AUTO_PROF
  }

  // Lowers an offloaded range-for: the loop body (plus TLS
  // prologue/epilogue) becomes a separate LLVM function, and the task
  // itself reduces to one call into the runtime's cpu_parallel_range_for,
  // which schedules blocks of iterations across num_cpu_threads threads.
  void create_offload_range_for(OffloadedStmt *stmt) override {
    int step = 1;

    // In parallel for-loops reversing the order doesn't make sense.
    // However, we may need to support serial offloaded range for's in the
    // future, so it still makes sense to reverse the order here.
    if (stmt->reversed) {
      step = -1;
    }

    auto *tls_prologue = create_xlogue(stmt->tls_prologue);

    // The loop body, emitted as its own function with signature
    // (RuntimeContext *, i8 *tls_buffer, i32 loop_index).
    llvm::Function *body;
    {
      auto guard = get_function_creation_guard(
          {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
           llvm::Type::getInt8PtrTy(*llvm_context),
           tlctx->get_data_type<int>()});

      // The loop index arrives as the third argument; copy it into an
      // alloca so inner statements can treat it like an ordinary local.
      auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
      loop_vars_llvm[stmt].push_back(loop_var);
      builder->CreateStore(get_arg(2), loop_var);
      stmt->body->accept(this);

      body = guard.body;
    }

    llvm::Value *epilogue = create_xlogue(stmt->tls_epilogue);

    auto [begin, end] = get_range_for_bounds(stmt);

    // adaptive block_dim: derive the block size from the iteration count
    // and thread count instead of using the statically configured value
    if (compile_config.cpu_block_dim_adaptive) {
      int num_items = (stmt->end_value - stmt->begin_value) / std::abs(step);
      int num_threads = stmt->num_cpu_threads;
      int items_per_thread = std::max(1, num_items / (num_threads * 32));
      // keep each task at least 512 items to amortize scheduler overhead;
      // also saturate the value to 1024 for better load balancing
      stmt->block_dim = std::min(1024, std::max(512, items_per_thread));
    }

    call("cpu_parallel_range_for", get_arg(0),
         tlctx->get_constant(stmt->num_cpu_threads), begin, end,
         tlctx->get_constant(step), tlctx->get_constant(stmt->block_dim),
         tls_prologue, body, epilogue, tlctx->get_constant(stmt->tls_size));
  }

  // Lowers an offloaded mesh-for. The emitted body function runs the mesh
  // prologue and optional BLS prologue, then a hand-built i32 loop over
  // the owned elements of the major element type (bound taken from
  // owned_num_local), then the optional BLS epilogue. The per-patch
  // dispatch is done by the runtime's cpu_parallel_mesh_for.
  void create_offload_mesh_for(OffloadedStmt *stmt) override {
    auto *tls_prologue = create_mesh_xlogue(stmt->tls_prologue);

    llvm::Function *body;
    {
      // Body signature: (RuntimeContext *, i8 *tls_buffer, i32 patch_arg).
      auto guard = get_function_creation_guard(
          {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
           llvm::Type::getInt8PtrTy(*llvm_context),
           tlctx->get_data_type<int>()});

      for (int i = 0; i < stmt->mesh_prologue->size(); i++) {
        auto &s = stmt->mesh_prologue->statements[i];
        s->accept(this);
      }

      if (stmt->bls_prologue) {
        stmt->bls_prologue->accept(this);
      }

      // Manually build the inner loop CFG:
      //   entry -> loop_test <-> loop_body, loop_test -> func_exit
      auto loop_test_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_test", func);
      auto loop_body_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_body", func);
      auto func_exit =
          llvm::BasicBlock::Create(*llvm_context, "func_exit", func);
      auto loop_index =
          create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context));
      builder->CreateStore(tlctx->get_constant(0), loop_index);
      builder->CreateBr(loop_test_bb);

      {
        // loop_test: continue while loop_index < number of locally owned
        // elements of the major element type.
        builder->SetInsertPoint(loop_test_bb);
        auto *loop_index_load =
            builder->CreateLoad(builder->getInt32Ty(), loop_index);
        auto cond = builder->CreateICmp(
            llvm::CmpInst::Predicate::ICMP_SLT, loop_index_load,
            llvm_val[stmt->owned_num_local.find(stmt->major_from_type)
                         ->second]);
        builder->CreateCondBr(cond, loop_body_bb, func_exit);
      }

      {
        // loop_body: emit the user body, then loop_index++ and branch back
        // to loop_test; subsequent emission continues in func_exit.
        builder->SetInsertPoint(loop_body_bb);
        loop_vars_llvm[stmt].push_back(loop_index);
        for (int i = 0; i < stmt->body->size(); i++) {
          auto &s = stmt->body->statements[i];
          s->accept(this);
        }
        auto *loop_index_load =
            builder->CreateLoad(builder->getInt32Ty(), loop_index);
        builder->CreateStore(
            builder->CreateAdd(loop_index_load, tlctx->get_constant(1)),
            loop_index);
        builder->CreateBr(loop_test_bb);
        builder->SetInsertPoint(func_exit);
      }

      if (stmt->bls_epilogue) {
        stmt->bls_epilogue->accept(this);
      }

      body = guard.body;
    }

    llvm::Value *epilogue = create_mesh_xlogue(stmt->tls_epilogue);

    call("cpu_parallel_mesh_for", get_arg(0),
         tlctx->get_constant(stmt->num_cpu_threads),
         tlctx->get_constant(stmt->mesh->num_patches),
         tlctx->get_constant(stmt->block_dim), tls_prologue, body, epilogue,
         tlctx->get_constant(stmt->tls_size));
  }

  // Creates the block-local-storage buffer: a thread-local (LocalExec TLS
  // model) global array of bls_size i8 elements, stored into the
  // `bls_buffer` member for use by the offload's loads/stores.
  void create_bls_buffer(OffloadedStmt *stmt) {
    auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context),
                                     stmt->bls_size);
    bls_buffer = new llvm::GlobalVariable(
        *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr,
        "bls_buffer", nullptr, llvm::GlobalVariable::LocalExecTLSModel, 0);
    /* module->getOrInsertGlobal("bls_buffer", type);
    bls_buffer = module->getNamedGlobal("bls_buffer");
    bls_buffer->setAlignment(llvm::MaybeAlign(8));*/ // TODO(changyu): Fix JIT session error: Symbols not found: [ __emutls_get_address ] in python 3.10

    // initialize the variable with an undef value to ensure it is added to the
    // symbol table
    bls_buffer->setInitializer(llvm::UndefValue::get(type));
  }

  // Entry point for one offloaded statement: sets up the task function,
  // optionally brackets it with profiler start/stop calls, and dispatches
  // on the task type (serial / range_for / mesh_for / struct_for /
  // listgen / gc / gc_rc).
  void visit(OffloadedStmt *stmt) override {
    TI_ASSERT(current_offload == nullptr);
    current_offload = stmt;
    if (stmt->bls_size > 0)
      create_bls_buffer(stmt);
    using Type = OffloadedStmt::TaskType;
    auto offloaded_task_name = init_offloaded_task_function(stmt);
    if (compile_config.kernel_profiler && arch_is_cpu(compile_config.arch)) {
      call("LLVMRuntime_profiler_start", get_runtime(),
           builder->CreateGlobalStringPtr(offloaded_task_name));
    }
    if (stmt->task_type == Type::serial) {
      stmt->body->accept(this);
    } else if (stmt->task_type == Type::range_for) {
      create_offload_range_for(stmt);
    } else if (stmt->task_type == Type::mesh_for) {
      create_offload_mesh_for(stmt);
    } else if (stmt->task_type == Type::struct_for) {
      // Cap the block size at the parent SNode's element capacity.
      stmt->block_dim = std::min(stmt->snode->parent->max_num_elements(),
                                 (int64)stmt->block_dim);
      create_offload_struct_for(stmt);
    } else if (stmt->task_type == Type::listgen) {
      emit_list_gen(stmt);
    } else if (stmt->task_type == Type::gc) {
      emit_gc(stmt);
    } else if (stmt->task_type == Type::gc_rc) {
      emit_gc_rc();
    } else {
      TI_NOT_IMPLEMENTED
    }
    if (compile_config.kernel_profiler && arch_is_cpu(compile_config.arch)) {
      // Emit the profiler-stop call in the task's final block so it runs
      // no matter which branch above generated the body.
      llvm::IRBuilderBase::InsertPointGuard guard(*builder);
      builder->SetInsertPoint(final_block);
      call("LLVMRuntime_profiler_stop", get_runtime());
    }
    finalize_offloaded_task_function();
    offloaded_tasks.push_back(*current_task);
    current_task = nullptr;
    current_offload = nullptr;
  }

  // External function calls: only the bitcode and shared-object flavors
  // are supported on the CPU backend.
  void visit(ExternalFuncCallStmt *stmt) override {
    if (stmt->type == ExternalFuncCallStmt::BITCODE) {
      TaskCodeGenLLVM::visit_call_bitcode(stmt);
    } else if (stmt->type == ExternalFuncCallStmt::SHARED_OBJECT) {
      TaskCodeGenLLVM::visit_call_shared_object(stmt);
    } else {
      TI_NOT_IMPLEMENTED
    }
  }

 private:
  // On the CPU backend each task invocation is a single thread, so the
  // SPMD info is the constant pair (thread_idx = 0, block_dim = 1).
  std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
    auto thread_idx = tlctx->get_constant(0);
    auto block_dim = tlctx->get_constant(1);
    return std::make_tuple(thread_idx, block_dim);
  }
};
223
224} // namespace
225
226#ifdef TI_WITH_LLVM
227FunctionType CPUModuleToFunctionConverter::convert(
228 const std::string &kernel_name,
229 const std::vector<LlvmLaunchArgInfo> &args,
230 LLVMCompiledKernel data) const {
231 TI_AUTO_PROF;
232 auto jit_module = executor_->create_jit_module(std::move(data.module));
233 using TaskFunc = int32 (*)(void *);
234 std::vector<TaskFunc> task_funcs;
235 task_funcs.reserve(data.tasks.size());
236 for (auto &task : data.tasks) {
237 auto *func_ptr = jit_module->lookup_function(task.name);
238 TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found",
239 task.name);
240 task_funcs.push_back((TaskFunc)(func_ptr));
241 }
242 // Do NOT capture `this`...
243 return [executor = this->executor_, args, kernel_name,
244 task_funcs](RuntimeContext &context) {
245 TI_TRACE("Launching kernel {}", kernel_name);
246 // For taichi ndarrays, context.args saves pointer to its
247 // |DeviceAllocation|, CPU backend actually want to use the raw ptr here.
248 for (int i = 0; i < (int)args.size(); i++) {
249 if (args[i].is_array &&
250 context.device_allocation_type[i] !=
251 RuntimeContext::DevAllocType::kNone &&
252 context.array_runtime_sizes[i] > 0) {
253 DeviceAllocation *ptr =
254 static_cast<DeviceAllocation *>(context.get_arg<void *>(i));
255 uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr);
256 context.set_arg(i, host_ptr);
257 context.set_array_device_allocation_type(
258 i, RuntimeContext::DevAllocType::kNone);
259
260 if (context.has_grad[i]) {
261 DeviceAllocation *ptr_grad =
262 static_cast<DeviceAllocation *>(context.get_grad_arg<void *>(i));
263 uint64 host_ptr_grad =
264 (uint64)executor->get_ndarray_alloc_info_ptr(*ptr_grad);
265 context.set_grad_arg(i, host_ptr_grad);
266 }
267 }
268 }
269 for (auto task : task_funcs) {
270 task(&context);
271 }
272 };
273}
274
275LLVMCompiledTask KernelCodeGenCPU::compile_task(
276 const CompileConfig &config,
277 std::unique_ptr<llvm::Module> &&module,
278 OffloadedStmt *stmt) {
279 TaskCodeGenCPU gen(config, get_taichi_llvm_context(), kernel, stmt);
280 return gen.run_compilation();
281}
282#endif // TI_WITH_LLVM
283
284FunctionType KernelCodeGenCPU::compile_to_function() {
285 TI_AUTO_PROF;
286 CPUModuleToFunctionConverter converter(
287 &get_taichi_llvm_context(),
288 get_llvm_program(prog)->get_runtime_executor());
289 return converter.convert(kernel, compile_kernel_to_module());
290}
291} // namespace taichi::lang
292