1#include "taichi/codegen/wasm/codegen_wasm.h"
2
3#include "taichi/codegen/llvm/codegen_llvm.h"
4#include "taichi/common/core.h"
5#include "taichi/ir/transforms.h"
6#include "taichi/util/io.h"
7#include "taichi/util/lang_util.h"
8#include "taichi/program/program.h"
9#include "taichi/ir/ir.h"
10#include "taichi/ir/statements.h"
11#include "taichi/util/file_sequence_writer.h"
12#include "taichi/runtime/program_impls/llvm/llvm_program.h"
13
14namespace taichi::lang {
15
namespace {

// Names of runtime helper functions that are always retained alongside the
// generated kernel body (see their uses in run_compilation() and
// compile_task()).
constexpr std::array<const char *, 5> kPreloadedFuncNames{{
    "wasm_materialize",
    "wasm_set_kernel_parameter_i32",
    "wasm_set_kernel_parameter_f32",
    "wasm_set_print_buffer",
    "wasm_print",
}};

}  // namespace
21
22class TaskCodeGenWASM : public TaskCodeGenLLVM {
23 public:
24 using IRVisitor::visit;
25
26 TaskCodeGenWASM(const CompileConfig &config,
27 TaichiLLVMContext &tlctx,
28 Kernel *kernel,
29 IRNode *ir,
30 std::unique_ptr<llvm::Module> &&M = nullptr)
31 : TaskCodeGenLLVM(config, tlctx, kernel, ir, std::move(M)) {
32 TI_AUTO_PROF
33 }
34
35 void create_offload_range_for(OffloadedStmt *stmt) override {
36 [[maybe_unused]] int step = 1;
37
38 // In parallel for-loops reversing the order doesn't make sense.
39 // However, we may need to support serial offloaded range for's in the
40 // future, so it still makes sense to reverse the order here.
41 if (stmt->reversed) {
42 step = -1;
43 }
44
45 auto *body = llvm::BasicBlock::Create(*llvm_context, "for_loop_body", func);
46 auto *loop_inc =
47 llvm::BasicBlock::Create(*llvm_context, "for_loop_inc", func);
48 auto *after_loop =
49 llvm::BasicBlock::Create(*llvm_context, "after_for", func);
50 auto *loop_test =
51 llvm::BasicBlock::Create(*llvm_context, "for_loop_test", func);
52
53 auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
54 loop_vars_llvm[stmt].push_back(loop_var);
55
56 auto [begin, end] = get_range_for_bounds(stmt);
57 if (!stmt->reversed) {
58 builder->CreateStore(begin, loop_var);
59 } else {
60 builder->CreateStore(builder->CreateSub(end, tlctx->get_constant(1)),
61 loop_var);
62 }
63 builder->CreateBr(loop_test);
64
65 {
66 // test block
67 builder->SetInsertPoint(loop_test);
68 llvm::Value *cond;
69 auto *loop_var_load = builder->CreateLoad(begin->getType(), loop_var);
70 if (!stmt->reversed) {
71 cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
72 loop_var_load, end);
73 } else {
74 cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
75 loop_var_load, begin);
76 }
77 builder->CreateCondBr(cond, body, after_loop);
78 }
79
80 {
81 {
82 builder->SetInsertPoint(body);
83 stmt->body->accept(this);
84 }
85 builder->CreateBr(loop_inc);
86 builder->SetInsertPoint(loop_inc);
87 if (!stmt->reversed) {
88 create_increment(loop_var, tlctx->get_constant(1));
89 } else {
90 create_increment(loop_var, tlctx->get_constant(-1));
91 }
92 builder->CreateBr(loop_test);
93 }
94
95 // next cfg
96 builder->SetInsertPoint(after_loop);
97 }
98
99 void visit(OffloadedStmt *stmt) override {
100 TI_ASSERT(current_offload == nullptr)
101 current_offload = stmt;
102 using Type = OffloadedStmt::TaskType;
103 if (stmt->task_type == Type::serial) {
104 stmt->body->accept(this);
105 } else if (stmt->task_type == Type::range_for) {
106 create_offload_range_for(stmt);
107 } else {
108 TI_NOT_IMPLEMENTED
109 }
110 current_offload = nullptr;
111 }
112
113 void visit(PrintStmt *stmt) override {
114 std::vector<llvm::Value *> args;
115 for (auto const &content : stmt->contents) {
116 if (std::holds_alternative<Stmt *>(content)) {
117 auto arg_stmt = std::get<Stmt *>(content);
118 auto value = llvm_val[arg_stmt];
119 if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::i32)) {
120 auto func = get_runtime_function("wasm_print_i32");
121 builder->CreateCall(func,
122 std::vector<llvm::Value *>{get_context(), value});
123 } else if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32)) {
124 auto func = get_runtime_function("wasm_print_f32");
125 builder->CreateCall(func,
126 std::vector<llvm::Value *>{get_context(), value});
127 } else {
128 TI_NOT_IMPLEMENTED
129 }
130 } else {
131 auto arg_str = std::get<std::string>(content);
132 for (int i = 0; i < (int)arg_str.size(); i += 4) {
133 llvm::Value *values[4];
134 for (int j = 0; j < 4; ++j)
135 if (i + j < (int)arg_str.size()) {
136 values[j] = llvm::ConstantInt::get(
137 *llvm_context, llvm::APInt(8, (uint64)arg_str[i + j], true));
138 } else {
139 values[j] = llvm::ConstantInt::get(
140 *llvm_context, llvm::APInt(8, (uint64)0, true));
141 }
142 auto func = get_runtime_function("wasm_print_char");
143 builder->CreateCall(func, std::vector<llvm::Value *>{
144 get_context(), values[0], values[1],
145 values[2], values[3]});
146 }
147 }
148 }
149 }
150
151 /**
152 * Extracts the original function name decorated by @ti.kernel
153 *
154 * @param kernel_name The format is defined in
155 * https://github.com/taichi-dev/taichi/blob/734da3f8f4439ce7f6a5337df7c54fb6dc34def8/python/taichi/lang/kernel_impl.py#L360-L362
156 */
157 std::string extract_original_kernel_name(const std::string &kernel_name) {
158 if (kernel->is_evaluator)
159 return kernel_name;
160 int pos = kernel_name.length() - 1;
161 int underline_count = 0;
162 int redundant_count = 3;
163 for (; pos >= 0; --pos) {
164 if (kernel_name.at(pos) == '_') {
165 underline_count += 1;
166 if (underline_count == redundant_count)
167 break;
168 }
169 }
170 TI_ASSERT(underline_count == redundant_count)
171 return kernel_name.substr(0, pos);
172 }
173
174 std::string init_taichi_kernel_function() {
175 task_function_type =
176 llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
177 {llvm::PointerType::get(context_ty, 0)}, false);
178
179 auto task_kernel_name =
180 fmt::format("{}_body", extract_original_kernel_name(kernel_name));
181 func = llvm::Function::Create(task_function_type,
182 llvm::Function::ExternalLinkage,
183 task_kernel_name, module.get());
184
185 for (auto &arg : func->args()) {
186 kernel_args.push_back(&arg);
187 }
188 kernel_args[0]->setName("context");
189
190 // entry_block has all the allocas
191 this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
192
193 // The real function body
194 func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
195 builder->SetInsertPoint(func_body_bb);
196 return task_kernel_name;
197 }
198
199 void finalize_taichi_kernel_function() {
200 builder->CreateRetVoid();
201
202 // entry_block should jump to the body after all allocas are inserted
203 builder->SetInsertPoint(entry_block);
204 builder->CreateBr(func_body_bb);
205
206 if (compile_config.print_kernel_llvm_ir) {
207 static FileSequenceWriter writer(
208 "taichi_kernel_generic_llvm_ir_{:04d}.ll",
209 "unoptimized LLVM IR (generic)");
210 writer.write(module.get());
211 }
212 TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
213 }
214
215 LLVMCompiledTask run_compilation() override {
216 // lower kernel
217 irpass::ast_to_ir(compile_config, *kernel);
218
219 // emit_to_module
220 auto offloaded_task_name = init_taichi_kernel_function();
221 ir->accept(this);
222 finalize_taichi_kernel_function();
223 // only keep the current func
224 TaichiLLVMContext::eliminate_unused_functions(
225 module.get(), [offloaded_task_name](const std::string &func_name) {
226 for (auto &name : kPreloadedFuncNames) {
227 if (std::string(name) == func_name) {
228 return true;
229 }
230 }
231 return func_name == offloaded_task_name;
232 });
233 LLVMCompiledTask res;
234 res.tasks.emplace_back(offloaded_task_name);
235 res.module = std::move(this->module);
236 return res;
237 }
238
239 private:
240 std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
241 TI_NOT_IMPLEMENTED;
242 }
243};
244
245FunctionType KernelCodeGenWASM::compile_to_function() {
246 TI_AUTO_PROF
247 auto linked = compile_kernel_to_module();
248 auto *executor = get_llvm_program(prog)->get_runtime_executor();
249 auto *jit_module = executor->create_jit_module(std::move(linked.module));
250 auto kernel_symbol = jit_module->lookup_function(linked.tasks[0].name);
251 return [kernel_symbol](RuntimeContext &context) {
252 TI_TRACE("Launching Taichi Kernel Function");
253 auto func = (int32(*)(void *))kernel_symbol;
254 func(&context);
255 };
256}
257
258LLVMCompiledTask KernelCodeGenWASM::compile_task(
259 const CompileConfig &config,
260 std::unique_ptr<llvm::Module> &&module,
261 OffloadedStmt *stmt) {
262 bool init_flag = module == nullptr;
263 std::vector<OffloadedTask> name_list;
264 auto gen = std::make_unique<TaskCodeGenWASM>(
265 config, get_taichi_llvm_context(), kernel, ir, std::move(module));
266
267 name_list.emplace_back(nullptr);
268 name_list[0].name = gen->init_taichi_kernel_function();
269 gen->emit_to_module();
270 gen->finalize_taichi_kernel_function();
271
272 // TODO: move the following functions to dump process in AOT.
273 if (init_flag) {
274 for (auto &name : kPreloadedFuncNames) {
275 name_list.emplace_back(nullptr);
276 name_list.back().name = name;
277 }
278 }
279
280 gen->tlctx->jit->global_optimize_module(gen->module.get());
281
282 return {name_list, std::move(gen->module), {}, {}};
283}
284
285LLVMCompiledKernel KernelCodeGenWASM::compile_kernel_to_module() {
286 const auto &config = get_compile_config();
287 irpass::ast_to_ir(config, *kernel, true);
288
289 auto res = compile_task(config);
290 std::vector<std::unique_ptr<LLVMCompiledTask>> data;
291 data.push_back(std::make_unique<LLVMCompiledTask>(std::move(res)));
292 return get_taichi_llvm_context().link_compiled_tasks(std::move(data));
293}
294
295} // namespace taichi::lang
296