1#include "codegen_cc.h"
2#include "cc_kernel.h"
3#include "cc_layout.h"
4#include "cc_program.h"
5#include "taichi/ir/ir.h"
6#include "taichi/ir/statements.h"
7#include "taichi/ir/transforms.h"
8#include "taichi/util/line_appender.h"
9#include "taichi/util/str.h"
10#include "cc_utils.h"
11
12#define C90_COMPAT 0
13
14namespace taichi::lang {
15namespace cccp { // Codegen for C Compiler Processor
16
17namespace {
18std::string get_node_ptr_name(SNode *snode) {
19 return fmt::format("struct Ti_{} *", snode->get_node_type_name_hinted());
20}
21
22static void lower_ast(const CompileConfig &config, Kernel *kernel) {
23 auto ir = kernel->ir.get();
24 irpass::compile_to_executable(ir, config, kernel,
25 /*autodiff_mode=*/kernel->autodiff_mode,
26 /*ad_use_stack=*/true, config.print_ir,
27 /*lower_global_access*/ true);
28}
29
30} // namespace
31
32class CCTransformer : public IRVisitor {
33 private:
34 [[maybe_unused]] Kernel *kernel_;
35 [[maybe_unused]] CCLayout *layout_;
36
37 LineAppender line_appender_;
38 LineAppender line_appender_header_;
39 bool is_top_level_{true};
40 GetRootStmt *root_stmt_;
41
42 public:
43 CCTransformer(Kernel *kernel, CCLayout *layout)
44 : kernel_(kernel), layout_(layout) {
45 allow_undefined_visitor = true;
46 invoke_default_visitor = true;
47 }
48
49 void run() {
50 emit_header("void Tk_{}(struct Ti_Context *ti_ctx) {{", kernel_->name);
51 kernel_->ir->accept(this);
52 emit("}}");
53 }
54
55 std::string get_source() {
56 return line_appender_header_.lines() + line_appender_.lines();
57 }
58
59 private:
60 void visit(Block *stmt) override {
61 if (!is_top_level_)
62 line_appender_.push_indent();
63 for (auto &s : stmt->statements) {
64 s->accept(this);
65 }
66 if (!is_top_level_)
67 line_appender_.pop_indent();
68 }
69
70 void visit(Stmt *stmt) override {
71 TI_WARN("[cc] unsupported statement type {}\n{}", typeid(*stmt).name(),
72 stmt->tb);
73 }
74
75 std::string define_var(std::string const &type, std::string const &name) {
76 if (C90_COMPAT) {
77 emit_header("{} {};", type, name);
78 return name;
79 } else {
80 return fmt::format("{} {}", type, name);
81 }
82 }
83
84 void visit(GetRootStmt *stmt) override {
85 auto *root = kernel_->program->get_snode_root(SNodeTree::kFirstID);
86 emit("{} = ti_ctx->root;",
87 define_var(get_node_ptr_name(root), stmt->raw_name()));
88 root_stmt_ = stmt;
89 }
90
91 void visit(SNodeLookupStmt *stmt) override {
92 Stmt *input_ptr;
93 if (stmt->input_snode) {
94 input_ptr = stmt->input_snode;
95 } else {
96 TI_ASSERT(root_stmt_ != nullptr);
97 input_ptr = root_stmt_;
98 }
99
100 emit("{} = &{}[{}];",
101 define_var(get_node_ptr_name(stmt->snode), stmt->raw_name()),
102 input_ptr->raw_name(), stmt->input_index->raw_name());
103 }
104
105 void visit(GetChStmt *stmt) override {
106 auto snode = stmt->output_snode;
107 std::string type;
108 if (snode->type == SNodeType::place) {
109 auto dt = fmt::format("{} *", cc_data_type_name(snode->dt));
110 emit("{} = &{}->{};", define_var(dt, stmt->raw_name()),
111 stmt->input_ptr->raw_name(), snode->get_node_type_name());
112 } else {
113 emit("{} = {}->{};",
114 define_var(get_node_ptr_name(snode), stmt->raw_name()),
115 stmt->input_ptr->raw_name(), snode->get_node_type_name());
116 }
117 }
118
119 void visit(GlobalLoadStmt *stmt) override {
120 emit("{} = *{};",
121 define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()),
122 stmt->src->raw_name());
123 }
124
125 void visit(GlobalStoreStmt *stmt) override {
126 emit("*{} = {};", stmt->dest->raw_name(), stmt->val->raw_name());
127 }
128
129 void visit(GlobalTemporaryStmt *stmt) override {
130 auto ptr_type =
131 cc_data_type_name(stmt->element_type().ptr_removed()) + " *";
132 auto var = define_var(ptr_type, stmt->raw_name());
133 emit("{} = ({}) (ti_ctx->gtmp + {});", var, ptr_type, stmt->offset);
134 }
135
136 void visit(LinearizeStmt *stmt) override {
137 std::string val = "0";
138 for (int i = 0; i < stmt->inputs.size(); i++) {
139 val = fmt::format("({} * {} + {})", val, stmt->strides[i],
140 stmt->inputs[i]->raw_name());
141 }
142 emit("{} = {};", define_var("Ti_i32", stmt->raw_name()), val);
143 }
144
145 void visit(ExternalPtrStmt *stmt) override {
146 std::string offset = "0";
147 const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
148 const int arg_id = argload->arg_id;
149 const auto element_shape = stmt->element_shape;
150 const auto layout = stmt->element_dim < 0 ? ExternalArrayLayout::kAOS
151 : ExternalArrayLayout::kSOA;
152 const size_t element_shape_index_offset =
153 (layout == ExternalArrayLayout::kAOS)
154 ? stmt->indices.size() - element_shape.size()
155 : 0;
156 size_t size_var_index = 0;
157 for (int i = 0; i < stmt->indices.size(); i++) {
158 std::string stride;
159 if (i >= element_shape_index_offset &&
160 i < element_shape_index_offset + element_shape.size()) {
161 stride = fmt::format("{}", element_shape[i - element_shape.size()]);
162 } else {
163 stride = fmt::format("ti_ctx->earg[{} * {} + {}]", arg_id,
164 taichi_max_num_indices, size_var_index++);
165 }
166 offset = fmt::format("({} * {} + {})", offset, stride,
167 stmt->indices[i]->raw_name());
168 }
169 auto var =
170 define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
171 stmt->raw_name());
172 emit("{} = {} + {};", var, stmt->base_ptr->raw_name(), offset);
173 }
174
175 void visit(ArgLoadStmt *stmt) override {
176 if (stmt->is_ptr) {
177 auto var = define_var(
178 cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
179 stmt->raw_name());
180 emit("{} = ti_ctx->args[{}].ptr_{};", var, stmt->arg_id,
181 data_type_name(stmt->element_type().ptr_removed()));
182 } else {
183 auto var =
184 define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name());
185 emit("{} = ti_ctx->args[{}].val_{};", var, stmt->arg_id,
186 data_type_name(stmt->element_type()));
187 }
188 }
189
190 void visit(ReturnStmt *stmt) override {
191 int idx{0};
192 for (auto &value : stmt->values) {
193 emit("ti_ctx->args[{}].val_{} = {};", idx++,
194 data_type_name(value->element_type()), value->raw_name());
195 }
196 }
197
198 void visit(ConstStmt *stmt) override {
199 emit("{} = {};",
200 define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()),
201 stmt->val.stringify());
202 }
203
204 void visit(AllocaStmt *stmt) override {
205 emit("{} = 0;",
206 define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()));
207 }
208
209 void visit(LocalLoadStmt *stmt) override {
210 auto var =
211 define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name());
212 emit("{} = {};", var, stmt->src->raw_name());
213 }
214
215 void visit(LocalStoreStmt *stmt) override {
216 emit("{} = {};", stmt->dest->raw_name(), stmt->val->raw_name());
217 }
218
219 void visit(ExternalFuncCallStmt *stmt) override {
220 TI_ASSERT(stmt->type == ExternalFuncCallStmt::ASSEMBLY);
221 auto format = stmt->asm_source;
222 std::string source;
223
224 for (int i = 0; i < format.size(); i++) {
225 char c = format[i];
226 if (c == '%' || c == '$') { // '$' for output, '%' for input
227 int num = 0;
228 while (i < format.size()) {
229 i += 1;
230 if (!::isdigit(format[i])) {
231 i -= 1;
232 break;
233 }
234 num *= 10;
235 num += format[i] - '0';
236 }
237 auto args = (c == '%') ? stmt->arg_stmts : stmt->output_stmts;
238 TI_ASSERT_INFO(num < args.size(), "{}{} out of {} argument range {}", c,
239 num, ((c == '%') ? "input" : "output"), args.size());
240 source += args[num]->raw_name();
241 } else {
242 source.push_back(c);
243 }
244 }
245
246 emit("{};", source);
247 }
248
249 void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
250 const auto type = cc_data_type_name(stmt->element_type());
251 const auto name = stmt->raw_name();
252 const auto var = define_var(type, name);
253 const auto arg_id = stmt->arg_id;
254 const auto axis = stmt->axis;
255 const auto axis_size = fmt::format("ti_ctx->earg[{} * {} + {}]", arg_id,
256 taichi_max_num_indices, axis);
257 emit("{} = {};", var, axis_size);
258 }
259
260 static std::string get_libc_function_name(std::string name, DataType dt) {
261 std::string ret;
262 if (dt->is_primitive(PrimitiveTypeID::i32))
263 ret = name;
264 else if (dt->is_primitive(PrimitiveTypeID::i64))
265 ret = "ll" + name;
266 else if (dt->is_primitive(PrimitiveTypeID::f32))
267 ret = name + "f";
268 else if (dt->is_primitive(PrimitiveTypeID::f64))
269 ret = name;
270 else
271 TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name,
272 data_type_name(dt));
273
274 if (name == "rsqrt") {
275 ret = "Ti_" + ret;
276 } else if (name == "sgn") {
277 if (is_real(dt)) {
278 ret = "f" + ret;
279 }
280 ret = "Ti_" + ret;
281 } else if (name == "max" || name == "min" || name == "abs") {
282 if (is_real(dt)) {
283 ret = "f" + ret;
284 } else if (ret != "abs") {
285 ret = "Ti_" + ret;
286 }
287 }
288 return ret;
289 }
290
291 static std::string invoke_libc(std::string name,
292 DataType dt,
293 std::string arguments) {
294 auto func_name = get_libc_function_name(name, dt);
295 return fmt::format("{}({})", func_name, arguments);
296 }
297
298 template <typename... Args>
299 static inline std::string invoke_libc(std::string name,
300 DataType dt,
301 std::string const &fmt,
302 Args &&...args) {
303 auto arguments = fmt::format(fmt, std::forward<Args>(args)...);
304 return invoke_libc(name, dt, arguments);
305 }
306
307 void visit(TernaryOpStmt *tri) override {
308 TI_ASSERT(tri->op_type == TernaryOpType::select);
309 emit("{} {} = {} != 0 ? {} : {};", cc_data_type_name(tri->element_type()),
310 tri->raw_name(), tri->op1->raw_name(), tri->op2->raw_name(),
311 tri->op3->raw_name());
312 }
313
314 void visit(BinaryOpStmt *bin) override {
315 const auto dt_name = cc_data_type_name(bin->element_type());
316 const auto lhs_name = bin->lhs->raw_name();
317 const auto rhs_name = bin->rhs->raw_name();
318 const auto bin_name = bin->raw_name();
319 const auto type = bin->element_type();
320 const auto binop = binary_op_type_symbol(bin->op_type);
321 const auto var = define_var(dt_name, bin_name);
322 if (cc_is_binary_op_infix(bin->op_type)) {
323 if (is_comparison(bin->op_type)) {
324 // XXX(#577): Taichi uses -1 as true due to LLVM i1...
325 emit("{} = -({} {} {});", var, lhs_name, binop, rhs_name);
326 } else if (bin->op_type == BinaryOpType::truediv) {
327 emit("{} = ({}) {} / {};", var, dt_name, lhs_name, rhs_name);
328 } else {
329 emit("{} = {} {} {};", var, lhs_name, binop, rhs_name);
330 }
331 } else {
332 emit("{} = {};", var,
333 invoke_libc(binop, type, "{}, {}", lhs_name, rhs_name));
334 }
335 }
336
337 void visit(DecorationStmt *stmt) override {
338 }
339
340 void visit(UnaryOpStmt *stmt) override {
341 const auto dt_name = cc_data_type_name(stmt->element_type());
342 const auto operand_name = stmt->operand->raw_name();
343 const auto dest_name = stmt->raw_name();
344 const auto type = stmt->element_type();
345 const auto op = unary_op_type_symbol(stmt->op_type);
346 const auto var = define_var(dt_name, dest_name);
347 if (stmt->op_type == UnaryOpType::cast_value) {
348 emit("{} = ({}) {};", var, dt_name, operand_name);
349
350 } else if (stmt->op_type == UnaryOpType::cast_bits) {
351 const auto operand_dt_name =
352 cc_data_type_name(stmt->operand->element_type());
353 emit("union {{ {} bc_src; {} bc_dst; }} {}_bitcast;", operand_dt_name,
354 dt_name, dest_name);
355 emit("{}_bitcast.bc_src = {};", dest_name, operand_name);
356 emit("{} = {}_bitcast.bc_dst;", var, dest_name);
357
358 } else if (cc_is_unary_op_infix(stmt->op_type)) {
359 emit("{} = {}{};", var, op, operand_name);
360 } else {
361 emit("{} = {};", var, invoke_libc(op, type, "{}", operand_name));
362 }
363 }
364
365 void visit(AtomicOpStmt *stmt) override {
366 const auto dest_ptr = stmt->dest->raw_name();
367 const auto src_name = stmt->val->raw_name();
368 const auto op = cc_atomic_op_type_symbol(stmt->op_type);
369 const auto type = stmt->dest->element_type().ptr_removed();
370 auto var = define_var(cc_data_type_name(type), stmt->raw_name());
371 emit("{} = *{};", var, dest_ptr);
372 if (stmt->op_type == AtomicOpType::max ||
373 stmt->op_type == AtomicOpType::min) {
374 emit("*{} = {};", dest_ptr,
375 invoke_libc(op, type, "*{}, {}", dest_ptr, src_name));
376 } else {
377 emit("*{} {}= {};", dest_ptr, op, src_name);
378 }
379 }
380
381 void visit(PrintStmt *stmt) override {
382 std::string format;
383 std::vector<std::string> values;
384
385 for (int i = 0; i < stmt->contents.size(); i++) {
386 auto const &content = stmt->contents[i];
387
388 if (std::holds_alternative<Stmt *>(content)) {
389 auto arg_stmt = std::get<Stmt *>(content);
390 format += data_type_format(arg_stmt->ret_type);
391 values.push_back(arg_stmt->raw_name());
392
393 } else {
394 auto str = std::get<std::string>(content);
395 format += "%s";
396 values.push_back(c_quoted(str));
397 }
398 }
399
400 values.insert(values.begin(), c_quoted(format));
401 emit("printf({});", fmt::join(values, ", "));
402 }
403
404 void generate_serial_kernel(OffloadedStmt *stmt) {
405 stmt->body->accept(this);
406 }
407
408 void generate_range_for_kernel(OffloadedStmt *stmt) {
409 if (stmt->const_begin && stmt->const_end) {
410 ScopedIndent _s(line_appender_);
411 auto begin_value = stmt->begin_value;
412 auto end_value = stmt->end_value;
413 auto var = define_var("Ti_i32", stmt->raw_name());
414 emit("for ({} = {}; {} < {}; {} += {}) {{", var, begin_value,
415 stmt->raw_name(), end_value, stmt->raw_name(), 1 /* stmt->step? */);
416 stmt->body->accept(this);
417 emit("}}");
418 } else {
419 auto var = define_var("Ti_i32", stmt->raw_name());
420 auto begin_expr = "tmp_begin_" + stmt->raw_name();
421 auto end_expr = "tmp_end_" + stmt->raw_name();
422 auto begin_var = define_var("Ti_i32", begin_expr);
423 auto end_var = define_var("Ti_i32", end_expr);
424 if (!stmt->const_begin) {
425 emit("{} = *(Ti_i32 *) (ti_ctx->gtmp + {});", begin_var,
426 stmt->begin_offset);
427 } else {
428 emit("{} = {};", begin_var, stmt->begin_value);
429 }
430 if (!stmt->const_end) {
431 emit("{} = *(Ti_i32 *) (ti_ctx->gtmp + {});", end_var,
432 stmt->end_offset);
433 } else {
434 emit("{} = {};", end_var, stmt->end_value);
435 }
436 emit("for ({} = {}; {} < {}; {} += {}) {{", var, begin_expr,
437 stmt->raw_name(), end_expr, stmt->raw_name(), 1 /* stmt->step? */);
438 stmt->body->accept(this);
439 emit("}}");
440 }
441 }
442
443 void visit(OffloadedStmt *stmt) override {
444 TI_ASSERT(is_top_level_);
445 is_top_level_ = false;
446 if (stmt->task_type == OffloadedStmt::TaskType::serial) {
447 generate_serial_kernel(stmt);
448 } else if (stmt->task_type == OffloadedStmt::TaskType::range_for) {
449 generate_range_for_kernel(stmt);
450 } else {
451 TI_ERROR("[glsl] Unsupported offload type={} on C backend",
452 stmt->task_name());
453 }
454 is_top_level_ = true;
455 }
456
457 void visit(LoopIndexStmt *stmt) override {
458 TI_ASSERT(stmt->index == 0); // TODO: multiple indices
459 if (stmt->loop->is<OffloadedStmt>()) {
460 auto type = stmt->loop->as<OffloadedStmt>()->task_type;
461 if (type == OffloadedStmt::TaskType::range_for) {
462 emit("Ti_i32 {} = {};", stmt->raw_name(), stmt->loop->raw_name());
463 } else {
464 TI_NOT_IMPLEMENTED
465 }
466 } else if (stmt->loop->is<RangeForStmt>()) {
467 emit("Ti_i32 {} = {};", stmt->raw_name(), stmt->loop->raw_name());
468 } else {
469 TI_NOT_IMPLEMENTED
470 }
471 }
472
473 void visit(RangeForStmt *stmt) override {
474 auto var = define_var("Ti_i32", stmt->raw_name());
475 if (!stmt->reversed) {
476 emit("for ({} = {}; {} < {}; {} += {}) {{", var, stmt->begin->raw_name(),
477 stmt->raw_name(), stmt->end->raw_name(), stmt->raw_name(), 1);
478 } else {
479 // reversed for loop
480 emit("for ({} = {} - {}; {} >= {}; {} -= {}) {{", var,
481 stmt->end->raw_name(), 1, stmt->raw_name(), stmt->begin->raw_name(),
482 stmt->raw_name(), 1);
483 }
484 stmt->body->accept(this);
485 emit("}}");
486 }
487
488 void visit(WhileControlStmt *stmt) override {
489 emit("if (!{}) break;", stmt->cond->raw_name());
490 }
491
492 void visit(ContinueStmt *stmt) override {
493 emit("continue;");
494 }
495
496 void visit(WhileStmt *stmt) override {
497 emit("while (1) {{");
498 stmt->body->accept(this);
499 emit("}}");
500 }
501
502 void visit(IfStmt *stmt) override {
503 emit("if ({}) {{", stmt->cond->raw_name());
504 if (stmt->true_statements) {
505 stmt->true_statements->accept(this);
506 }
507 if (stmt->false_statements) {
508 emit("}} else {{");
509 stmt->false_statements->accept(this);
510 }
511 emit("}}");
512 }
513
514 void visit(RandStmt *stmt) override {
515 auto var = define_var(cc_data_type_name(stmt->ret_type), stmt->raw_name());
516 emit("{} = Ti_rand_{}();", var, data_type_name(stmt->ret_type));
517 }
518
519 void visit(AdStackAllocaStmt *stmt) override {
520 TI_ASSERT_INFO(
521 stmt->max_size > 0,
522 "Adaptive autodiff stack's size should have been determined.");
523
524 const auto &var_name = stmt->raw_name();
525 emit("Ti_u8 {}[{}];", var_name, stmt->size_in_bytes() + sizeof(uint32_t));
526 emit("Ti_ad_stack_init({});", var_name);
527 }
528
529 void visit(AdStackPopStmt *stmt) override {
530 emit("Ti_ad_stack_pop({});", stmt->stack->raw_name());
531 }
532
533 void visit(AdStackPushStmt *stmt) override {
534 auto *stack = stmt->stack->as<AdStackAllocaStmt>();
535 const auto &stack_name = stack->raw_name();
536 auto elem_size = stack->element_size_in_bytes();
537 emit("Ti_ad_stack_push({}, {});", stack_name, elem_size);
538 auto primal_name = stmt->raw_name() + "_primal_";
539 auto dt_name = cc_data_type_name(stmt->element_type());
540 auto var = define_var(dt_name + " *", primal_name);
541 emit("{} = ({} *) Ti_ad_stack_top_primal({}, {});", var, dt_name,
542 stack_name, elem_size);
543 emit("*{} = {};", primal_name, stmt->v->raw_name());
544 }
545
546 void visit(AdStackLoadTopStmt *stmt) override {
547 auto *stack = stmt->stack->as<AdStackAllocaStmt>();
548 const auto primal_name = stmt->raw_name() + "_primal_";
549 auto dt_name = cc_data_type_name(stmt->element_type());
550 auto var = define_var(dt_name + " *", primal_name);
551 emit("{} = ({} *)Ti_ad_stack_top_primal({}, {});", var, dt_name,
552 stack->raw_name(), stack->element_size_in_bytes());
553 emit("{} = *{};", define_var(dt_name, stmt->raw_name()), primal_name);
554 }
555
556 void visit(AdStackLoadTopAdjStmt *stmt) override {
557 auto *stack = stmt->stack->as<AdStackAllocaStmt>();
558 const auto adjoint_name = stmt->raw_name() + "_adjoint_";
559 auto dt_name = cc_data_type_name(stmt->element_type());
560 auto var = define_var(dt_name + " *", adjoint_name);
561 emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});", var, dt_name,
562 stack->raw_name(), stack->element_size_in_bytes());
563 emit("{} = *{};", define_var(dt_name, stmt->raw_name()), adjoint_name);
564 }
565
566 void visit(AdStackAccAdjointStmt *stmt) override {
567 auto *stack = stmt->stack->as<AdStackAllocaStmt>();
568 const auto adjoint_name = stmt->raw_name() + "_adjoint_";
569 auto dt_name = cc_data_type_name(stmt->element_type());
570 auto var = define_var(dt_name + " *", adjoint_name);
571 emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});", var, dt_name,
572 stack->raw_name(), stack->element_size_in_bytes());
573 emit("*{} += {};", adjoint_name, stmt->v->raw_name());
574 }
575
576 template <typename... Args>
577 void emit(std::string f, Args &&...args) {
578 line_appender_.append(std::move(f), std::move(args)...);
579 }
580
581 template <typename... Args>
582 void emit_header(std::string f, Args &&...args) {
583 line_appender_header_.append(std::move(f), std::move(args)...);
584 }
585}; // namespace cccp
586
587std::unique_ptr<CCKernel> CCKernelGen::compile() {
588 lower_ast(compile_config_, kernel_);
589 auto layout = cc_program_impl_->get_layout();
590 CCTransformer tran(kernel_, layout);
591
592 tran.run();
593 auto source = tran.get_source();
594 auto ker = std::make_unique<CCKernel>(cc_program_impl_, kernel_, source,
595 kernel_->name);
596 ker->compile();
597 return ker;
598}
599
600} // namespace cccp
601} // namespace taichi::lang
602