1 | #include "taichi/codegen/wasm/codegen_wasm.h" |
2 | |
3 | #include "taichi/codegen/llvm/codegen_llvm.h" |
4 | #include "taichi/common/core.h" |
5 | #include "taichi/ir/transforms.h" |
6 | #include "taichi/util/io.h" |
7 | #include "taichi/util/lang_util.h" |
8 | #include "taichi/program/program.h" |
9 | #include "taichi/ir/ir.h" |
10 | #include "taichi/ir/statements.h" |
11 | #include "taichi/util/file_sequence_writer.h" |
12 | #include "taichi/runtime/program_impls/llvm/llvm_program.h" |
13 | |
14 | namespace taichi::lang { |
15 | |
namespace {
// Names of runtime functions that are treated specially in this file:
//  * run_compilation() never eliminates them as "unused" functions;
//  * compile_task() appends them to the returned task list when it was
//    called without an existing module.
constexpr std::array<const char *, 5> kPreloadedFuncNames{{
    "wasm_materialize",
    "wasm_set_kernel_parameter_i32",
    "wasm_set_kernel_parameter_f32",
    "wasm_set_print_buffer",
    "wasm_print",
}};
}  // namespace
21 | |
22 | class TaskCodeGenWASM : public TaskCodeGenLLVM { |
23 | public: |
24 | using IRVisitor::visit; |
25 | |
26 | TaskCodeGenWASM(const CompileConfig &config, |
27 | TaichiLLVMContext &tlctx, |
28 | Kernel *kernel, |
29 | IRNode *ir, |
30 | std::unique_ptr<llvm::Module> &&M = nullptr) |
31 | : TaskCodeGenLLVM(config, tlctx, kernel, ir, std::move(M)) { |
32 | TI_AUTO_PROF |
33 | } |
34 | |
35 | void create_offload_range_for(OffloadedStmt *stmt) override { |
36 | [[maybe_unused]] int step = 1; |
37 | |
38 | // In parallel for-loops reversing the order doesn't make sense. |
39 | // However, we may need to support serial offloaded range for's in the |
40 | // future, so it still makes sense to reverse the order here. |
41 | if (stmt->reversed) { |
42 | step = -1; |
43 | } |
44 | |
45 | auto *body = llvm::BasicBlock::Create(*llvm_context, "for_loop_body" , func); |
46 | auto *loop_inc = |
47 | llvm::BasicBlock::Create(*llvm_context, "for_loop_inc" , func); |
48 | auto *after_loop = |
49 | llvm::BasicBlock::Create(*llvm_context, "after_for" , func); |
50 | auto *loop_test = |
51 | llvm::BasicBlock::Create(*llvm_context, "for_loop_test" , func); |
52 | |
53 | auto loop_var = create_entry_block_alloca(PrimitiveType::i32); |
54 | loop_vars_llvm[stmt].push_back(loop_var); |
55 | |
56 | auto [begin, end] = get_range_for_bounds(stmt); |
57 | if (!stmt->reversed) { |
58 | builder->CreateStore(begin, loop_var); |
59 | } else { |
60 | builder->CreateStore(builder->CreateSub(end, tlctx->get_constant(1)), |
61 | loop_var); |
62 | } |
63 | builder->CreateBr(loop_test); |
64 | |
65 | { |
66 | // test block |
67 | builder->SetInsertPoint(loop_test); |
68 | llvm::Value *cond; |
69 | auto *loop_var_load = builder->CreateLoad(begin->getType(), loop_var); |
70 | if (!stmt->reversed) { |
71 | cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT, |
72 | loop_var_load, end); |
73 | } else { |
74 | cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE, |
75 | loop_var_load, begin); |
76 | } |
77 | builder->CreateCondBr(cond, body, after_loop); |
78 | } |
79 | |
80 | { |
81 | { |
82 | builder->SetInsertPoint(body); |
83 | stmt->body->accept(this); |
84 | } |
85 | builder->CreateBr(loop_inc); |
86 | builder->SetInsertPoint(loop_inc); |
87 | if (!stmt->reversed) { |
88 | create_increment(loop_var, tlctx->get_constant(1)); |
89 | } else { |
90 | create_increment(loop_var, tlctx->get_constant(-1)); |
91 | } |
92 | builder->CreateBr(loop_test); |
93 | } |
94 | |
95 | // next cfg |
96 | builder->SetInsertPoint(after_loop); |
97 | } |
98 | |
99 | void visit(OffloadedStmt *stmt) override { |
100 | TI_ASSERT(current_offload == nullptr) |
101 | current_offload = stmt; |
102 | using Type = OffloadedStmt::TaskType; |
103 | if (stmt->task_type == Type::serial) { |
104 | stmt->body->accept(this); |
105 | } else if (stmt->task_type == Type::range_for) { |
106 | create_offload_range_for(stmt); |
107 | } else { |
108 | TI_NOT_IMPLEMENTED |
109 | } |
110 | current_offload = nullptr; |
111 | } |
112 | |
113 | void visit(PrintStmt *stmt) override { |
114 | std::vector<llvm::Value *> args; |
115 | for (auto const &content : stmt->contents) { |
116 | if (std::holds_alternative<Stmt *>(content)) { |
117 | auto arg_stmt = std::get<Stmt *>(content); |
118 | auto value = llvm_val[arg_stmt]; |
119 | if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::i32)) { |
120 | auto func = get_runtime_function("wasm_print_i32" ); |
121 | builder->CreateCall(func, |
122 | std::vector<llvm::Value *>{get_context(), value}); |
123 | } else if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32)) { |
124 | auto func = get_runtime_function("wasm_print_f32" ); |
125 | builder->CreateCall(func, |
126 | std::vector<llvm::Value *>{get_context(), value}); |
127 | } else { |
128 | TI_NOT_IMPLEMENTED |
129 | } |
130 | } else { |
131 | auto arg_str = std::get<std::string>(content); |
132 | for (int i = 0; i < (int)arg_str.size(); i += 4) { |
133 | llvm::Value *values[4]; |
134 | for (int j = 0; j < 4; ++j) |
135 | if (i + j < (int)arg_str.size()) { |
136 | values[j] = llvm::ConstantInt::get( |
137 | *llvm_context, llvm::APInt(8, (uint64)arg_str[i + j], true)); |
138 | } else { |
139 | values[j] = llvm::ConstantInt::get( |
140 | *llvm_context, llvm::APInt(8, (uint64)0, true)); |
141 | } |
142 | auto func = get_runtime_function("wasm_print_char" ); |
143 | builder->CreateCall(func, std::vector<llvm::Value *>{ |
144 | get_context(), values[0], values[1], |
145 | values[2], values[3]}); |
146 | } |
147 | } |
148 | } |
149 | } |
150 | |
151 | /** |
152 | * Extracts the original function name decorated by @ti.kernel |
153 | * |
154 | * @param kernel_name The format is defined in |
155 | * https://github.com/taichi-dev/taichi/blob/734da3f8f4439ce7f6a5337df7c54fb6dc34def8/python/taichi/lang/kernel_impl.py#L360-L362 |
156 | */ |
157 | std::string (const std::string &kernel_name) { |
158 | if (kernel->is_evaluator) |
159 | return kernel_name; |
160 | int pos = kernel_name.length() - 1; |
161 | int underline_count = 0; |
162 | int redundant_count = 3; |
163 | for (; pos >= 0; --pos) { |
164 | if (kernel_name.at(pos) == '_') { |
165 | underline_count += 1; |
166 | if (underline_count == redundant_count) |
167 | break; |
168 | } |
169 | } |
170 | TI_ASSERT(underline_count == redundant_count) |
171 | return kernel_name.substr(0, pos); |
172 | } |
173 | |
174 | std::string init_taichi_kernel_function() { |
175 | task_function_type = |
176 | llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), |
177 | {llvm::PointerType::get(context_ty, 0)}, false); |
178 | |
179 | auto task_kernel_name = |
180 | fmt::format("{}_body" , extract_original_kernel_name(kernel_name)); |
181 | func = llvm::Function::Create(task_function_type, |
182 | llvm::Function::ExternalLinkage, |
183 | task_kernel_name, module.get()); |
184 | |
185 | for (auto &arg : func->args()) { |
186 | kernel_args.push_back(&arg); |
187 | } |
188 | kernel_args[0]->setName("context" ); |
189 | |
190 | // entry_block has all the allocas |
191 | this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry" , func); |
192 | |
193 | // The real function body |
194 | func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body" , func); |
195 | builder->SetInsertPoint(func_body_bb); |
196 | return task_kernel_name; |
197 | } |
198 | |
199 | void finalize_taichi_kernel_function() { |
200 | builder->CreateRetVoid(); |
201 | |
202 | // entry_block should jump to the body after all allocas are inserted |
203 | builder->SetInsertPoint(entry_block); |
204 | builder->CreateBr(func_body_bb); |
205 | |
206 | if (compile_config.print_kernel_llvm_ir) { |
207 | static FileSequenceWriter writer( |
208 | "taichi_kernel_generic_llvm_ir_{:04d}.ll" , |
209 | "unoptimized LLVM IR (generic)" ); |
210 | writer.write(module.get()); |
211 | } |
212 | TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs())); |
213 | } |
214 | |
215 | LLVMCompiledTask run_compilation() override { |
216 | // lower kernel |
217 | irpass::ast_to_ir(compile_config, *kernel); |
218 | |
219 | // emit_to_module |
220 | auto offloaded_task_name = init_taichi_kernel_function(); |
221 | ir->accept(this); |
222 | finalize_taichi_kernel_function(); |
223 | // only keep the current func |
224 | TaichiLLVMContext::eliminate_unused_functions( |
225 | module.get(), [offloaded_task_name](const std::string &func_name) { |
226 | for (auto &name : kPreloadedFuncNames) { |
227 | if (std::string(name) == func_name) { |
228 | return true; |
229 | } |
230 | } |
231 | return func_name == offloaded_task_name; |
232 | }); |
233 | LLVMCompiledTask res; |
234 | res.tasks.emplace_back(offloaded_task_name); |
235 | res.module = std::move(this->module); |
236 | return res; |
237 | } |
238 | |
239 | private: |
240 | std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override { |
241 | TI_NOT_IMPLEMENTED; |
242 | } |
243 | }; |
244 | |
245 | FunctionType KernelCodeGenWASM::compile_to_function() { |
246 | TI_AUTO_PROF |
247 | auto linked = compile_kernel_to_module(); |
248 | auto *executor = get_llvm_program(prog)->get_runtime_executor(); |
249 | auto *jit_module = executor->create_jit_module(std::move(linked.module)); |
250 | auto kernel_symbol = jit_module->lookup_function(linked.tasks[0].name); |
251 | return [kernel_symbol](RuntimeContext &context) { |
252 | TI_TRACE("Launching Taichi Kernel Function" ); |
253 | auto func = (int32(*)(void *))kernel_symbol; |
254 | func(&context); |
255 | }; |
256 | } |
257 | |
258 | LLVMCompiledTask KernelCodeGenWASM::compile_task( |
259 | const CompileConfig &config, |
260 | std::unique_ptr<llvm::Module> &&module, |
261 | OffloadedStmt *stmt) { |
262 | bool init_flag = module == nullptr; |
263 | std::vector<OffloadedTask> name_list; |
264 | auto gen = std::make_unique<TaskCodeGenWASM>( |
265 | config, get_taichi_llvm_context(), kernel, ir, std::move(module)); |
266 | |
267 | name_list.emplace_back(nullptr); |
268 | name_list[0].name = gen->init_taichi_kernel_function(); |
269 | gen->emit_to_module(); |
270 | gen->finalize_taichi_kernel_function(); |
271 | |
272 | // TODO: move the following functions to dump process in AOT. |
273 | if (init_flag) { |
274 | for (auto &name : kPreloadedFuncNames) { |
275 | name_list.emplace_back(nullptr); |
276 | name_list.back().name = name; |
277 | } |
278 | } |
279 | |
280 | gen->tlctx->jit->global_optimize_module(gen->module.get()); |
281 | |
282 | return {name_list, std::move(gen->module), {}, {}}; |
283 | } |
284 | |
285 | LLVMCompiledKernel KernelCodeGenWASM::compile_kernel_to_module() { |
286 | const auto &config = get_compile_config(); |
287 | irpass::ast_to_ir(config, *kernel, true); |
288 | |
289 | auto res = compile_task(config); |
290 | std::vector<std::unique_ptr<LLVMCompiledTask>> data; |
291 | data.push_back(std::make_unique<LLVMCompiledTask>(std::move(res))); |
292 | return get_taichi_llvm_context().link_compiled_tasks(std::move(data)); |
293 | } |
294 | |
295 | } // namespace taichi::lang |
296 | |