1 | #include "codegen_cc.h" |
2 | #include "cc_kernel.h" |
3 | #include "cc_layout.h" |
4 | #include "cc_program.h" |
5 | #include "taichi/ir/ir.h" |
6 | #include "taichi/ir/statements.h" |
7 | #include "taichi/ir/transforms.h" |
8 | #include "taichi/util/line_appender.h" |
9 | #include "taichi/util/str.h" |
10 | #include "cc_utils.h" |
11 | |
12 | #define C90_COMPAT 0 |
13 | |
14 | namespace taichi::lang { |
15 | namespace cccp { // Codegen for C Compiler Processor |
16 | |
17 | namespace { |
18 | std::string get_node_ptr_name(SNode *snode) { |
19 | return fmt::format("struct Ti_{} *" , snode->get_node_type_name_hinted()); |
20 | } |
21 | |
22 | static void lower_ast(const CompileConfig &config, Kernel *kernel) { |
23 | auto ir = kernel->ir.get(); |
24 | irpass::compile_to_executable(ir, config, kernel, |
25 | /*autodiff_mode=*/kernel->autodiff_mode, |
26 | /*ad_use_stack=*/true, config.print_ir, |
27 | /*lower_global_access*/ true); |
28 | } |
29 | |
30 | } // namespace |
31 | |
32 | class CCTransformer : public IRVisitor { |
33 | private: |
34 | [[maybe_unused]] Kernel *kernel_; |
35 | [[maybe_unused]] CCLayout *layout_; |
36 | |
37 | LineAppender line_appender_; |
38 | LineAppender ; |
39 | bool is_top_level_{true}; |
40 | GetRootStmt *root_stmt_; |
41 | |
42 | public: |
43 | CCTransformer(Kernel *kernel, CCLayout *layout) |
44 | : kernel_(kernel), layout_(layout) { |
45 | allow_undefined_visitor = true; |
46 | invoke_default_visitor = true; |
47 | } |
48 | |
49 | void run() { |
50 | emit_header("void Tk_{}(struct Ti_Context *ti_ctx) {{" , kernel_->name); |
51 | kernel_->ir->accept(this); |
52 | emit("}}" ); |
53 | } |
54 | |
55 | std::string get_source() { |
56 | return line_appender_header_.lines() + line_appender_.lines(); |
57 | } |
58 | |
59 | private: |
60 | void visit(Block *stmt) override { |
61 | if (!is_top_level_) |
62 | line_appender_.push_indent(); |
63 | for (auto &s : stmt->statements) { |
64 | s->accept(this); |
65 | } |
66 | if (!is_top_level_) |
67 | line_appender_.pop_indent(); |
68 | } |
69 | |
70 | void visit(Stmt *stmt) override { |
71 | TI_WARN("[cc] unsupported statement type {}\n{}" , typeid(*stmt).name(), |
72 | stmt->tb); |
73 | } |
74 | |
75 | std::string define_var(std::string const &type, std::string const &name) { |
76 | if (C90_COMPAT) { |
77 | emit_header("{} {};" , type, name); |
78 | return name; |
79 | } else { |
80 | return fmt::format("{} {}" , type, name); |
81 | } |
82 | } |
83 | |
84 | void visit(GetRootStmt *stmt) override { |
85 | auto *root = kernel_->program->get_snode_root(SNodeTree::kFirstID); |
86 | emit("{} = ti_ctx->root;" , |
87 | define_var(get_node_ptr_name(root), stmt->raw_name())); |
88 | root_stmt_ = stmt; |
89 | } |
90 | |
91 | void visit(SNodeLookupStmt *stmt) override { |
92 | Stmt *input_ptr; |
93 | if (stmt->input_snode) { |
94 | input_ptr = stmt->input_snode; |
95 | } else { |
96 | TI_ASSERT(root_stmt_ != nullptr); |
97 | input_ptr = root_stmt_; |
98 | } |
99 | |
100 | emit("{} = &{}[{}];" , |
101 | define_var(get_node_ptr_name(stmt->snode), stmt->raw_name()), |
102 | input_ptr->raw_name(), stmt->input_index->raw_name()); |
103 | } |
104 | |
105 | void visit(GetChStmt *stmt) override { |
106 | auto snode = stmt->output_snode; |
107 | std::string type; |
108 | if (snode->type == SNodeType::place) { |
109 | auto dt = fmt::format("{} *" , cc_data_type_name(snode->dt)); |
110 | emit("{} = &{}->{};" , define_var(dt, stmt->raw_name()), |
111 | stmt->input_ptr->raw_name(), snode->get_node_type_name()); |
112 | } else { |
113 | emit("{} = {}->{};" , |
114 | define_var(get_node_ptr_name(snode), stmt->raw_name()), |
115 | stmt->input_ptr->raw_name(), snode->get_node_type_name()); |
116 | } |
117 | } |
118 | |
119 | void visit(GlobalLoadStmt *stmt) override { |
120 | emit("{} = *{};" , |
121 | define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()), |
122 | stmt->src->raw_name()); |
123 | } |
124 | |
125 | void visit(GlobalStoreStmt *stmt) override { |
126 | emit("*{} = {};" , stmt->dest->raw_name(), stmt->val->raw_name()); |
127 | } |
128 | |
129 | void visit(GlobalTemporaryStmt *stmt) override { |
130 | auto ptr_type = |
131 | cc_data_type_name(stmt->element_type().ptr_removed()) + " *" ; |
132 | auto var = define_var(ptr_type, stmt->raw_name()); |
133 | emit("{} = ({}) (ti_ctx->gtmp + {});" , var, ptr_type, stmt->offset); |
134 | } |
135 | |
136 | void visit(LinearizeStmt *stmt) override { |
137 | std::string val = "0" ; |
138 | for (int i = 0; i < stmt->inputs.size(); i++) { |
139 | val = fmt::format("({} * {} + {})" , val, stmt->strides[i], |
140 | stmt->inputs[i]->raw_name()); |
141 | } |
142 | emit("{} = {};" , define_var("Ti_i32" , stmt->raw_name()), val); |
143 | } |
144 | |
145 | void visit(ExternalPtrStmt *stmt) override { |
146 | std::string offset = "0" ; |
147 | const auto *argload = stmt->base_ptr->as<ArgLoadStmt>(); |
148 | const int arg_id = argload->arg_id; |
149 | const auto element_shape = stmt->element_shape; |
150 | const auto layout = stmt->element_dim < 0 ? ExternalArrayLayout::kAOS |
151 | : ExternalArrayLayout::kSOA; |
152 | const size_t element_shape_index_offset = |
153 | (layout == ExternalArrayLayout::kAOS) |
154 | ? stmt->indices.size() - element_shape.size() |
155 | : 0; |
156 | size_t size_var_index = 0; |
157 | for (int i = 0; i < stmt->indices.size(); i++) { |
158 | std::string stride; |
159 | if (i >= element_shape_index_offset && |
160 | i < element_shape_index_offset + element_shape.size()) { |
161 | stride = fmt::format("{}" , element_shape[i - element_shape.size()]); |
162 | } else { |
163 | stride = fmt::format("ti_ctx->earg[{} * {} + {}]" , arg_id, |
164 | taichi_max_num_indices, size_var_index++); |
165 | } |
166 | offset = fmt::format("({} * {} + {})" , offset, stride, |
167 | stmt->indices[i]->raw_name()); |
168 | } |
169 | auto var = |
170 | define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *" , |
171 | stmt->raw_name()); |
172 | emit("{} = {} + {};" , var, stmt->base_ptr->raw_name(), offset); |
173 | } |
174 | |
175 | void visit(ArgLoadStmt *stmt) override { |
176 | if (stmt->is_ptr) { |
177 | auto var = define_var( |
178 | cc_data_type_name(stmt->element_type().ptr_removed()) + " *" , |
179 | stmt->raw_name()); |
180 | emit("{} = ti_ctx->args[{}].ptr_{};" , var, stmt->arg_id, |
181 | data_type_name(stmt->element_type().ptr_removed())); |
182 | } else { |
183 | auto var = |
184 | define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()); |
185 | emit("{} = ti_ctx->args[{}].val_{};" , var, stmt->arg_id, |
186 | data_type_name(stmt->element_type())); |
187 | } |
188 | } |
189 | |
190 | void visit(ReturnStmt *stmt) override { |
191 | int idx{0}; |
192 | for (auto &value : stmt->values) { |
193 | emit("ti_ctx->args[{}].val_{} = {};" , idx++, |
194 | data_type_name(value->element_type()), value->raw_name()); |
195 | } |
196 | } |
197 | |
198 | void visit(ConstStmt *stmt) override { |
199 | emit("{} = {};" , |
200 | define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()), |
201 | stmt->val.stringify()); |
202 | } |
203 | |
204 | void visit(AllocaStmt *stmt) override { |
205 | emit("{} = 0;" , |
206 | define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name())); |
207 | } |
208 | |
209 | void visit(LocalLoadStmt *stmt) override { |
210 | auto var = |
211 | define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()); |
212 | emit("{} = {};" , var, stmt->src->raw_name()); |
213 | } |
214 | |
215 | void visit(LocalStoreStmt *stmt) override { |
216 | emit("{} = {};" , stmt->dest->raw_name(), stmt->val->raw_name()); |
217 | } |
218 | |
219 | void visit(ExternalFuncCallStmt *stmt) override { |
220 | TI_ASSERT(stmt->type == ExternalFuncCallStmt::ASSEMBLY); |
221 | auto format = stmt->asm_source; |
222 | std::string source; |
223 | |
224 | for (int i = 0; i < format.size(); i++) { |
225 | char c = format[i]; |
226 | if (c == '%' || c == '$') { // '$' for output, '%' for input |
227 | int num = 0; |
228 | while (i < format.size()) { |
229 | i += 1; |
230 | if (!::isdigit(format[i])) { |
231 | i -= 1; |
232 | break; |
233 | } |
234 | num *= 10; |
235 | num += format[i] - '0'; |
236 | } |
237 | auto args = (c == '%') ? stmt->arg_stmts : stmt->output_stmts; |
238 | TI_ASSERT_INFO(num < args.size(), "{}{} out of {} argument range {}" , c, |
239 | num, ((c == '%') ? "input" : "output" ), args.size()); |
240 | source += args[num]->raw_name(); |
241 | } else { |
242 | source.push_back(c); |
243 | } |
244 | } |
245 | |
246 | emit("{};" , source); |
247 | } |
248 | |
249 | void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { |
250 | const auto type = cc_data_type_name(stmt->element_type()); |
251 | const auto name = stmt->raw_name(); |
252 | const auto var = define_var(type, name); |
253 | const auto arg_id = stmt->arg_id; |
254 | const auto axis = stmt->axis; |
255 | const auto axis_size = fmt::format("ti_ctx->earg[{} * {} + {}]" , arg_id, |
256 | taichi_max_num_indices, axis); |
257 | emit("{} = {};" , var, axis_size); |
258 | } |
259 | |
260 | static std::string get_libc_function_name(std::string name, DataType dt) { |
261 | std::string ret; |
262 | if (dt->is_primitive(PrimitiveTypeID::i32)) |
263 | ret = name; |
264 | else if (dt->is_primitive(PrimitiveTypeID::i64)) |
265 | ret = "ll" + name; |
266 | else if (dt->is_primitive(PrimitiveTypeID::f32)) |
267 | ret = name + "f" ; |
268 | else if (dt->is_primitive(PrimitiveTypeID::f64)) |
269 | ret = name; |
270 | else |
271 | TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend" , name, |
272 | data_type_name(dt)); |
273 | |
274 | if (name == "rsqrt" ) { |
275 | ret = "Ti_" + ret; |
276 | } else if (name == "sgn" ) { |
277 | if (is_real(dt)) { |
278 | ret = "f" + ret; |
279 | } |
280 | ret = "Ti_" + ret; |
281 | } else if (name == "max" || name == "min" || name == "abs" ) { |
282 | if (is_real(dt)) { |
283 | ret = "f" + ret; |
284 | } else if (ret != "abs" ) { |
285 | ret = "Ti_" + ret; |
286 | } |
287 | } |
288 | return ret; |
289 | } |
290 | |
291 | static std::string invoke_libc(std::string name, |
292 | DataType dt, |
293 | std::string arguments) { |
294 | auto func_name = get_libc_function_name(name, dt); |
295 | return fmt::format("{}({})" , func_name, arguments); |
296 | } |
297 | |
298 | template <typename... Args> |
299 | static inline std::string invoke_libc(std::string name, |
300 | DataType dt, |
301 | std::string const &fmt, |
302 | Args &&...args) { |
303 | auto arguments = fmt::format(fmt, std::forward<Args>(args)...); |
304 | return invoke_libc(name, dt, arguments); |
305 | } |
306 | |
307 | void visit(TernaryOpStmt *tri) override { |
308 | TI_ASSERT(tri->op_type == TernaryOpType::select); |
309 | emit("{} {} = {} != 0 ? {} : {};" , cc_data_type_name(tri->element_type()), |
310 | tri->raw_name(), tri->op1->raw_name(), tri->op2->raw_name(), |
311 | tri->op3->raw_name()); |
312 | } |
313 | |
314 | void visit(BinaryOpStmt *bin) override { |
315 | const auto dt_name = cc_data_type_name(bin->element_type()); |
316 | const auto lhs_name = bin->lhs->raw_name(); |
317 | const auto rhs_name = bin->rhs->raw_name(); |
318 | const auto bin_name = bin->raw_name(); |
319 | const auto type = bin->element_type(); |
320 | const auto binop = binary_op_type_symbol(bin->op_type); |
321 | const auto var = define_var(dt_name, bin_name); |
322 | if (cc_is_binary_op_infix(bin->op_type)) { |
323 | if (is_comparison(bin->op_type)) { |
324 | // XXX(#577): Taichi uses -1 as true due to LLVM i1... |
325 | emit("{} = -({} {} {});" , var, lhs_name, binop, rhs_name); |
326 | } else if (bin->op_type == BinaryOpType::truediv) { |
327 | emit("{} = ({}) {} / {};" , var, dt_name, lhs_name, rhs_name); |
328 | } else { |
329 | emit("{} = {} {} {};" , var, lhs_name, binop, rhs_name); |
330 | } |
331 | } else { |
332 | emit("{} = {};" , var, |
333 | invoke_libc(binop, type, "{}, {}" , lhs_name, rhs_name)); |
334 | } |
335 | } |
336 | |
337 | void visit(DecorationStmt *stmt) override { |
338 | } |
339 | |
340 | void visit(UnaryOpStmt *stmt) override { |
341 | const auto dt_name = cc_data_type_name(stmt->element_type()); |
342 | const auto operand_name = stmt->operand->raw_name(); |
343 | const auto dest_name = stmt->raw_name(); |
344 | const auto type = stmt->element_type(); |
345 | const auto op = unary_op_type_symbol(stmt->op_type); |
346 | const auto var = define_var(dt_name, dest_name); |
347 | if (stmt->op_type == UnaryOpType::cast_value) { |
348 | emit("{} = ({}) {};" , var, dt_name, operand_name); |
349 | |
350 | } else if (stmt->op_type == UnaryOpType::cast_bits) { |
351 | const auto operand_dt_name = |
352 | cc_data_type_name(stmt->operand->element_type()); |
353 | emit("union {{ {} bc_src; {} bc_dst; }} {}_bitcast;" , operand_dt_name, |
354 | dt_name, dest_name); |
355 | emit("{}_bitcast.bc_src = {};" , dest_name, operand_name); |
356 | emit("{} = {}_bitcast.bc_dst;" , var, dest_name); |
357 | |
358 | } else if (cc_is_unary_op_infix(stmt->op_type)) { |
359 | emit("{} = {}{};" , var, op, operand_name); |
360 | } else { |
361 | emit("{} = {};" , var, invoke_libc(op, type, "{}" , operand_name)); |
362 | } |
363 | } |
364 | |
365 | void visit(AtomicOpStmt *stmt) override { |
366 | const auto dest_ptr = stmt->dest->raw_name(); |
367 | const auto src_name = stmt->val->raw_name(); |
368 | const auto op = cc_atomic_op_type_symbol(stmt->op_type); |
369 | const auto type = stmt->dest->element_type().ptr_removed(); |
370 | auto var = define_var(cc_data_type_name(type), stmt->raw_name()); |
371 | emit("{} = *{};" , var, dest_ptr); |
372 | if (stmt->op_type == AtomicOpType::max || |
373 | stmt->op_type == AtomicOpType::min) { |
374 | emit("*{} = {};" , dest_ptr, |
375 | invoke_libc(op, type, "*{}, {}" , dest_ptr, src_name)); |
376 | } else { |
377 | emit("*{} {}= {};" , dest_ptr, op, src_name); |
378 | } |
379 | } |
380 | |
381 | void visit(PrintStmt *stmt) override { |
382 | std::string format; |
383 | std::vector<std::string> values; |
384 | |
385 | for (int i = 0; i < stmt->contents.size(); i++) { |
386 | auto const &content = stmt->contents[i]; |
387 | |
388 | if (std::holds_alternative<Stmt *>(content)) { |
389 | auto arg_stmt = std::get<Stmt *>(content); |
390 | format += data_type_format(arg_stmt->ret_type); |
391 | values.push_back(arg_stmt->raw_name()); |
392 | |
393 | } else { |
394 | auto str = std::get<std::string>(content); |
395 | format += "%s" ; |
396 | values.push_back(c_quoted(str)); |
397 | } |
398 | } |
399 | |
400 | values.insert(values.begin(), c_quoted(format)); |
401 | emit("printf({});" , fmt::join(values, ", " )); |
402 | } |
403 | |
404 | void generate_serial_kernel(OffloadedStmt *stmt) { |
405 | stmt->body->accept(this); |
406 | } |
407 | |
408 | void generate_range_for_kernel(OffloadedStmt *stmt) { |
409 | if (stmt->const_begin && stmt->const_end) { |
410 | ScopedIndent _s(line_appender_); |
411 | auto begin_value = stmt->begin_value; |
412 | auto end_value = stmt->end_value; |
413 | auto var = define_var("Ti_i32" , stmt->raw_name()); |
414 | emit("for ({} = {}; {} < {}; {} += {}) {{" , var, begin_value, |
415 | stmt->raw_name(), end_value, stmt->raw_name(), 1 /* stmt->step? */); |
416 | stmt->body->accept(this); |
417 | emit("}}" ); |
418 | } else { |
419 | auto var = define_var("Ti_i32" , stmt->raw_name()); |
420 | auto begin_expr = "tmp_begin_" + stmt->raw_name(); |
421 | auto end_expr = "tmp_end_" + stmt->raw_name(); |
422 | auto begin_var = define_var("Ti_i32" , begin_expr); |
423 | auto end_var = define_var("Ti_i32" , end_expr); |
424 | if (!stmt->const_begin) { |
425 | emit("{} = *(Ti_i32 *) (ti_ctx->gtmp + {});" , begin_var, |
426 | stmt->begin_offset); |
427 | } else { |
428 | emit("{} = {};" , begin_var, stmt->begin_value); |
429 | } |
430 | if (!stmt->const_end) { |
431 | emit("{} = *(Ti_i32 *) (ti_ctx->gtmp + {});" , end_var, |
432 | stmt->end_offset); |
433 | } else { |
434 | emit("{} = {};" , end_var, stmt->end_value); |
435 | } |
436 | emit("for ({} = {}; {} < {}; {} += {}) {{" , var, begin_expr, |
437 | stmt->raw_name(), end_expr, stmt->raw_name(), 1 /* stmt->step? */); |
438 | stmt->body->accept(this); |
439 | emit("}}" ); |
440 | } |
441 | } |
442 | |
443 | void visit(OffloadedStmt *stmt) override { |
444 | TI_ASSERT(is_top_level_); |
445 | is_top_level_ = false; |
446 | if (stmt->task_type == OffloadedStmt::TaskType::serial) { |
447 | generate_serial_kernel(stmt); |
448 | } else if (stmt->task_type == OffloadedStmt::TaskType::range_for) { |
449 | generate_range_for_kernel(stmt); |
450 | } else { |
451 | TI_ERROR("[glsl] Unsupported offload type={} on C backend" , |
452 | stmt->task_name()); |
453 | } |
454 | is_top_level_ = true; |
455 | } |
456 | |
457 | void visit(LoopIndexStmt *stmt) override { |
458 | TI_ASSERT(stmt->index == 0); // TODO: multiple indices |
459 | if (stmt->loop->is<OffloadedStmt>()) { |
460 | auto type = stmt->loop->as<OffloadedStmt>()->task_type; |
461 | if (type == OffloadedStmt::TaskType::range_for) { |
462 | emit("Ti_i32 {} = {};" , stmt->raw_name(), stmt->loop->raw_name()); |
463 | } else { |
464 | TI_NOT_IMPLEMENTED |
465 | } |
466 | } else if (stmt->loop->is<RangeForStmt>()) { |
467 | emit("Ti_i32 {} = {};" , stmt->raw_name(), stmt->loop->raw_name()); |
468 | } else { |
469 | TI_NOT_IMPLEMENTED |
470 | } |
471 | } |
472 | |
473 | void visit(RangeForStmt *stmt) override { |
474 | auto var = define_var("Ti_i32" , stmt->raw_name()); |
475 | if (!stmt->reversed) { |
476 | emit("for ({} = {}; {} < {}; {} += {}) {{" , var, stmt->begin->raw_name(), |
477 | stmt->raw_name(), stmt->end->raw_name(), stmt->raw_name(), 1); |
478 | } else { |
479 | // reversed for loop |
480 | emit("for ({} = {} - {}; {} >= {}; {} -= {}) {{" , var, |
481 | stmt->end->raw_name(), 1, stmt->raw_name(), stmt->begin->raw_name(), |
482 | stmt->raw_name(), 1); |
483 | } |
484 | stmt->body->accept(this); |
485 | emit("}}" ); |
486 | } |
487 | |
488 | void visit(WhileControlStmt *stmt) override { |
489 | emit("if (!{}) break;" , stmt->cond->raw_name()); |
490 | } |
491 | |
492 | void visit(ContinueStmt *stmt) override { |
493 | emit("continue;" ); |
494 | } |
495 | |
496 | void visit(WhileStmt *stmt) override { |
497 | emit("while (1) {{" ); |
498 | stmt->body->accept(this); |
499 | emit("}}" ); |
500 | } |
501 | |
502 | void visit(IfStmt *stmt) override { |
503 | emit("if ({}) {{" , stmt->cond->raw_name()); |
504 | if (stmt->true_statements) { |
505 | stmt->true_statements->accept(this); |
506 | } |
507 | if (stmt->false_statements) { |
508 | emit("}} else {{" ); |
509 | stmt->false_statements->accept(this); |
510 | } |
511 | emit("}}" ); |
512 | } |
513 | |
514 | void visit(RandStmt *stmt) override { |
515 | auto var = define_var(cc_data_type_name(stmt->ret_type), stmt->raw_name()); |
516 | emit("{} = Ti_rand_{}();" , var, data_type_name(stmt->ret_type)); |
517 | } |
518 | |
519 | void visit(AdStackAllocaStmt *stmt) override { |
520 | TI_ASSERT_INFO( |
521 | stmt->max_size > 0, |
522 | "Adaptive autodiff stack's size should have been determined." ); |
523 | |
524 | const auto &var_name = stmt->raw_name(); |
525 | emit("Ti_u8 {}[{}];" , var_name, stmt->size_in_bytes() + sizeof(uint32_t)); |
526 | emit("Ti_ad_stack_init({});" , var_name); |
527 | } |
528 | |
529 | void visit(AdStackPopStmt *stmt) override { |
530 | emit("Ti_ad_stack_pop({});" , stmt->stack->raw_name()); |
531 | } |
532 | |
533 | void visit(AdStackPushStmt *stmt) override { |
534 | auto *stack = stmt->stack->as<AdStackAllocaStmt>(); |
535 | const auto &stack_name = stack->raw_name(); |
536 | auto elem_size = stack->element_size_in_bytes(); |
537 | emit("Ti_ad_stack_push({}, {});" , stack_name, elem_size); |
538 | auto primal_name = stmt->raw_name() + "_primal_" ; |
539 | auto dt_name = cc_data_type_name(stmt->element_type()); |
540 | auto var = define_var(dt_name + " *" , primal_name); |
541 | emit("{} = ({} *) Ti_ad_stack_top_primal({}, {});" , var, dt_name, |
542 | stack_name, elem_size); |
543 | emit("*{} = {};" , primal_name, stmt->v->raw_name()); |
544 | } |
545 | |
546 | void visit(AdStackLoadTopStmt *stmt) override { |
547 | auto *stack = stmt->stack->as<AdStackAllocaStmt>(); |
548 | const auto primal_name = stmt->raw_name() + "_primal_" ; |
549 | auto dt_name = cc_data_type_name(stmt->element_type()); |
550 | auto var = define_var(dt_name + " *" , primal_name); |
551 | emit("{} = ({} *)Ti_ad_stack_top_primal({}, {});" , var, dt_name, |
552 | stack->raw_name(), stack->element_size_in_bytes()); |
553 | emit("{} = *{};" , define_var(dt_name, stmt->raw_name()), primal_name); |
554 | } |
555 | |
556 | void visit(AdStackLoadTopAdjStmt *stmt) override { |
557 | auto *stack = stmt->stack->as<AdStackAllocaStmt>(); |
558 | const auto adjoint_name = stmt->raw_name() + "_adjoint_" ; |
559 | auto dt_name = cc_data_type_name(stmt->element_type()); |
560 | auto var = define_var(dt_name + " *" , adjoint_name); |
561 | emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});" , var, dt_name, |
562 | stack->raw_name(), stack->element_size_in_bytes()); |
563 | emit("{} = *{};" , define_var(dt_name, stmt->raw_name()), adjoint_name); |
564 | } |
565 | |
566 | void visit(AdStackAccAdjointStmt *stmt) override { |
567 | auto *stack = stmt->stack->as<AdStackAllocaStmt>(); |
568 | const auto adjoint_name = stmt->raw_name() + "_adjoint_" ; |
569 | auto dt_name = cc_data_type_name(stmt->element_type()); |
570 | auto var = define_var(dt_name + " *" , adjoint_name); |
571 | emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});" , var, dt_name, |
572 | stack->raw_name(), stack->element_size_in_bytes()); |
573 | emit("*{} += {};" , adjoint_name, stmt->v->raw_name()); |
574 | } |
575 | |
576 | template <typename... Args> |
577 | void emit(std::string f, Args &&...args) { |
578 | line_appender_.append(std::move(f), std::move(args)...); |
579 | } |
580 | |
581 | template <typename... Args> |
582 | void (std::string f, Args &&...args) { |
583 | line_appender_header_.append(std::move(f), std::move(args)...); |
584 | } |
585 | }; // namespace cccp |
586 | |
587 | std::unique_ptr<CCKernel> CCKernelGen::compile() { |
588 | lower_ast(compile_config_, kernel_); |
589 | auto layout = cc_program_impl_->get_layout(); |
590 | CCTransformer tran(kernel_, layout); |
591 | |
592 | tran.run(); |
593 | auto source = tran.get_source(); |
594 | auto ker = std::make_unique<CCKernel>(cc_program_impl_, kernel_, source, |
595 | kernel_->name); |
596 | ker->compile(); |
597 | return ker; |
598 | } |
599 | |
600 | } // namespace cccp |
601 | } // namespace taichi::lang |
602 | |