1// The IRPrinter prints the IR in a human-readable format
2
3#include "taichi/ir/expression_printer.h"
4#include "taichi/ir/ir.h"
5#include "taichi/ir/statements.h"
6#include "taichi/ir/transforms.h"
7#include "taichi/ir/visitors.h"
8#include "taichi/ir/frontend_ir.h"
9#include "taichi/util/str.h"
10
11namespace taichi::lang {
12
13namespace {
14
15std::string scratch_pad_info(const MemoryAccessOptions &opt) {
16 std::string ser;
17 if (!opt.get_all().empty()) {
18 ser += "mem_access_opt [ ";
19 for (auto &rec : opt.get_all()) {
20 for (auto flag : rec.second) {
21 ser += rec.first->get_node_type_name_hinted() + ":" +
22 snode_access_flag_name(flag) + " ";
23 }
24 }
25 ser += "] ";
26 } else {
27 ser = "none";
28 }
29 return ser;
30}
31
// Renders a loop's block dimension as "block_dim=<n> "; a block_dim of 0 is
// shown as "adaptive".
std::string block_dim_info(int block_dim) {
  std::string info = "block_dim=";
  if (block_dim == 0) {
    info += "adaptive";
  } else {
    info += std::to_string(block_dim);
  }
  info += ' ';
  return info;
}
36
37class IRPrinter : public IRVisitor {
38 private:
39 ExpressionPrinter *expr_printer_{nullptr};
40
41 public:
42 int current_indent{0};
43
44 std::string *output{nullptr};
45 std::stringstream ss;
46
47 explicit IRPrinter(ExpressionPrinter *expr_printer = nullptr,
48 std::string *output = nullptr)
49 : expr_printer_(expr_printer), output(output) {
50 }
51
52 template <typename... Args>
53 void print(std::string f, Args &&...args) {
54 print_raw(fmt::format(f, std::forward<Args>(args)...));
55 }
56
57 void print_raw(std::string f) {
58 for (int i = 0; i < current_indent; i++)
59 f.insert(0, " ");
60 f += "\n";
61 if (output) {
62 ss << f;
63 } else {
64 std::cout << f;
65 }
66 }
67
68 static void run(ExpressionPrinter *expr_printer,
69 IRNode *node,
70 std::string *output) {
71 if (node == nullptr) {
72 TI_WARN("IRPrinter: Printing nullptr.");
73 if (output) {
74 *output = std::string();
75 }
76 return;
77 }
78 auto p = IRPrinter(expr_printer, output);
79 p.print("kernel {{");
80 node->accept(&p);
81 p.print("}}");
82 if (output)
83 *output = p.ss.str();
84 }
85
86 void visit(Block *stmt_list) override {
87 current_indent++;
88 for (auto &stmt : stmt_list->statements) {
89 stmt->accept(this);
90 }
91 current_indent--;
92 }
93
94 void visit(FrontendExprStmt *stmt) override {
95 print("{}", (stmt->val));
96 }
97
98 void visit(FrontendBreakStmt *stmt) override {
99 print("break");
100 }
101
102 void visit(FrontendContinueStmt *stmt) override {
103 print("continue");
104 }
105
106 void visit(FrontendAssignStmt *assign) override {
107 print("{} = {}", expr_to_string(assign->lhs), expr_to_string(assign->rhs));
108 }
109
110 void visit(FrontendAllocaStmt *alloca) override {
111 std::string shared_suffix = (alloca->is_shared) ? "(shared)" : "";
112 print("{}${} = alloca{} {}", alloca->type_hint(), alloca->id, shared_suffix,
113 alloca->ident.name());
114 }
115
116 void visit(FrontendAssertStmt *assert) override {
117 print("{} : assert {}", assert->name(), expr_to_string(assert->cond));
118 }
119
120 void visit(AssertStmt *assert) override {
121 std::string extras;
122 for (auto &arg : assert->args) {
123 extras += ", ";
124 extras += arg->name();
125 }
126 print("{} : assert {}, \"{}\"{}", assert->id, assert->cond->name(),
127 assert->text, extras);
128 }
129
130 void visit(ExternalFuncCallStmt *stmt) override {
131 std::string extras;
132 if (stmt->so_func != nullptr) {
133 extras += fmt::format("so {:x} ", (uint64)stmt->so_func);
134 } else if (!stmt->asm_source.empty()) {
135 extras += fmt::format("asm \"{}\" ", stmt->asm_source);
136 } else {
137 extras += fmt::format("bc {}:{} ", stmt->bc_filename, stmt->bc_funcname);
138 }
139 extras += "inputs=";
140 for (auto &arg : stmt->arg_stmts) {
141 extras += ", ";
142 extras += arg->name();
143 }
144 extras += "outputs=";
145 for (auto &output : stmt->output_stmts) {
146 extras += ", ";
147 extras += output->name();
148 }
149 print("{} : {}", stmt->name(), extras);
150 }
151
152 void visit(FrontendSNodeOpStmt *stmt) override {
153 std::string extras = "[";
154 for (int i = 0; i < (int)stmt->indices.size(); i++) {
155 extras += expr_to_string(stmt->indices[i]);
156 if (i + 1 < (int)stmt->indices.size())
157 extras += ", ";
158 }
159 extras += "]";
160 if (stmt->val.expr) {
161 extras += ", " + expr_to_string(stmt->val);
162 }
163 print("{} : {} {} {}", stmt->name(), snode_op_type_name(stmt->op_type),
164 stmt->snode->get_node_type_name_hinted(), extras);
165 }
166
167 void visit(SNodeOpStmt *stmt) override {
168 std::string extras;
169 if (stmt->ptr)
170 extras = "ptr = " + stmt->ptr->name();
171 if (stmt->val) {
172 extras += ", val = " + stmt->val->name();
173 }
174 std::string snode = stmt->snode->get_node_type_name_hinted();
175 print("{}{} = {} [{}] {}", stmt->type_hint(), stmt->name(),
176 snode_op_type_name(stmt->op_type), snode, extras);
177 }
178
179 void visit(AllocaStmt *alloca) override {
180 std::string shared_suffix = (alloca->is_shared) ? "(shared)" : "";
181 print("{}${} = alloca{}", alloca->type_hint(), alloca->id, shared_suffix);
182 }
183
184 void visit(RandStmt *stmt) override {
185 print("{}{} = rand()", stmt->type_hint(), stmt->name());
186 }
187
188 void visit(DecorationStmt *stmt) override {
189 if (stmt->decoration.size() == 2 &&
190 stmt->decoration[0] ==
191 uint32_t(DecorationStmt::Decoration::kLoopUnique)) {
192 print("decorate {} : Loop-unique {}", stmt->operand->name(),
193 stmt->decoration[0], stmt->decoration[1]);
194 } else {
195 print("decorate {} : ... size = {}", stmt->operand->name(),
196 stmt->decoration.size());
197 }
198 }
199
200 void visit(UnaryOpStmt *stmt) override {
201 if (stmt->is_cast()) {
202 std::string reint =
203 stmt->op_type == UnaryOpType::cast_value ? "" : "reinterpret_";
204 print("{}{} = {}{}<{}> {}", stmt->type_hint(), stmt->name(), reint,
205 unary_op_type_name(stmt->op_type), data_type_name(stmt->cast_type),
206 stmt->operand->name());
207 } else {
208 print("{}{} = {} {}", stmt->type_hint(), stmt->name(),
209 unary_op_type_name(stmt->op_type), stmt->operand->name());
210 }
211 }
212
213 void visit(BinaryOpStmt *bin) override {
214 print("{}{} = {} {} {}", bin->type_hint(), bin->name(),
215 binary_op_type_name(bin->op_type), bin->lhs->name(),
216 bin->rhs->name());
217 }
218
219 void visit(TernaryOpStmt *stmt) override {
220 print("{}{} = {}({}, {}, {})", stmt->type_hint(), stmt->name(),
221 ternary_type_name(stmt->op_type), stmt->op1->name(),
222 stmt->op2->name(), stmt->op3->name());
223 }
224
225 void visit(AtomicOpStmt *stmt) override {
226 print("{}{} = atomic {}({}, {})", stmt->type_hint(), stmt->name(),
227 atomic_op_type_name(stmt->op_type), stmt->dest->name(),
228 stmt->val->name());
229 }
230
231 void visit(IfStmt *if_stmt) override {
232 print("{} : if {} {{", if_stmt->name(), if_stmt->cond->name());
233 if (if_stmt->true_statements)
234 if_stmt->true_statements->accept(this);
235 if (if_stmt->false_statements) {
236 print("}} else {{");
237 if_stmt->false_statements->accept(this);
238 }
239 print("}}");
240 }
241
242 void visit(FrontendIfStmt *if_stmt) override {
243 print("{} : if {} {{", if_stmt->name(), expr_to_string(if_stmt->condition));
244 if (if_stmt->true_statements)
245 if_stmt->true_statements->accept(this);
246 if (if_stmt->false_statements) {
247 print("}} else {{");
248 if_stmt->false_statements->accept(this);
249 }
250 print("}}");
251 }
252
253 void visit(FrontendPrintStmt *print_stmt) override {
254 std::vector<std::string> contents;
255 for (auto const &c : print_stmt->contents) {
256 std::string name;
257 if (std::holds_alternative<Expr>(c))
258 name = expr_to_string(std::get<Expr>(c).expr.get());
259 else
260 name = c_quoted(std::get<std::string>(c));
261 contents.push_back(name);
262 }
263 print("print {}", fmt::join(contents, ", "));
264 }
265
266 void visit(PrintStmt *print_stmt) override {
267 std::vector<std::string> names;
268 for (auto const &c : print_stmt->contents) {
269 std::string name;
270 if (std::holds_alternative<Stmt *>(c))
271 name = std::get<Stmt *>(c)->name();
272 else
273 name = c_quoted(std::get<std::string>(c));
274 names.push_back(name);
275 }
276 print("print {}", fmt::join(names, ", "));
277 }
278
279 void visit(ConstStmt *const_stmt) override {
280 print("{}{} = const {}", const_stmt->type_hint(), const_stmt->name(),
281 const_stmt->val.stringify());
282 }
283
284 void visit(WhileControlStmt *stmt) override {
285 print("{} : while control {}, {}", stmt->name(),
286 stmt->mask ? stmt->mask->name() : "nullptr", stmt->cond->name());
287 }
288
289 void visit(ContinueStmt *stmt) override {
290 if (stmt->scope) {
291 print("{} continue (scope={})", stmt->name(), stmt->scope->name());
292 } else {
293 print("{} continue", stmt->name());
294 }
295 }
296
297 void visit(FrontendFuncCallStmt *stmt) override {
298 std::string args;
299 for (int i = 0; i < stmt->args.exprs.size(); i++) {
300 if (i) {
301 args += ", ";
302 }
303 args += expr_to_string(stmt->args.exprs[i]);
304 }
305 print("{}${} = call \"{}\", args = ({}), ret = {}", stmt->type_hint(),
306 stmt->id, stmt->func->get_name(), args, stmt->ident->name());
307 }
308
309 void visit(FuncCallStmt *stmt) override {
310 std::vector<std::string> args;
311 for (const auto &arg : stmt->args) {
312 args.push_back(arg->name());
313 }
314 print("{}{} = call \"{}\", args = {{{}}}", stmt->type_hint(), stmt->name(),
315 stmt->func->get_name(), fmt::join(args, ", "));
316 }
317
318 void visit(FrontendFuncDefStmt *stmt) override {
319 print("function \"{}\" {{", stmt->funcid);
320 stmt->body->accept(this);
321 print("}}");
322 }
323
324 void visit(WhileStmt *stmt) override {
325 print("{} : while true {{", stmt->name());
326 stmt->body->accept(this);
327 print("}}");
328 }
329
330 void visit(FrontendWhileStmt *stmt) override {
331 print("{} : while {} {{", stmt->name(), expr_to_string(stmt->cond));
332 stmt->body->accept(this);
333 print("}}");
334 }
335
336 void visit(FrontendForStmt *for_stmt) override {
337 auto vars = make_list<Identifier>(
338 for_stmt->loop_var_ids,
339 [](const Identifier &id) -> std::string { return id.name(); });
340 if (for_stmt->snode) {
341 print("{} : for {} in {} {}{}{{", for_stmt->name(), vars,
342 for_stmt->snode->get_node_type_name_hinted(),
343 scratch_pad_info(for_stmt->mem_access_opt),
344 block_dim_info(for_stmt->block_dim));
345 } else if (for_stmt->external_tensor) {
346 print("{} : for {} in {} {}{}{{", for_stmt->name(), vars,
347 expr_to_string(for_stmt->external_tensor),
348 scratch_pad_info(for_stmt->mem_access_opt),
349 block_dim_info(for_stmt->block_dim));
350 } else if (for_stmt->mesh) {
351 print("{} : for {} in mesh {{", for_stmt->name(), vars);
352 } else {
353 print("{} : for {} in range({}, {}) {}{{", for_stmt->name(), vars,
354 expr_to_string(for_stmt->begin), expr_to_string(for_stmt->end),
355 block_dim_info(for_stmt->block_dim));
356 }
357 for_stmt->body->accept(this);
358 print("}}");
359 }
360
361 void visit(RangeForStmt *for_stmt) override {
362 print("{} : {}for in range({}, {}) {}{}{{", for_stmt->name(),
363 for_stmt->reversed ? "reversed " : "", for_stmt->begin->name(),
364 for_stmt->end->name(),
365 for_stmt->is_bit_vectorized ? "(bit_vectorized) " : "",
366 block_dim_info(for_stmt->block_dim));
367 for_stmt->body->accept(this);
368 print("}}");
369 }
370
371 void visit(StructForStmt *for_stmt) override {
372 print("{} : struct for in {} {}{}{}{{", for_stmt->name(),
373 for_stmt->snode->get_node_type_name_hinted(),
374 for_stmt->is_bit_vectorized ? "(bit_vectorized) " : "",
375 scratch_pad_info(for_stmt->mem_access_opt),
376 block_dim_info(for_stmt->block_dim));
377 for_stmt->body->accept(this);
378 print("}}");
379 }
380
381 void visit(MeshForStmt *for_stmt) override {
382 print("{} : mesh for ({} -> {}) {}{{", for_stmt->name(),
383 mesh::element_type_name(for_stmt->major_from_type),
384 for_stmt->major_to_types.size() == 0
385 ? "Unknown"
386 : mesh::element_type_name(*for_stmt->major_to_types.begin()),
387 scratch_pad_info(for_stmt->mem_access_opt));
388 for_stmt->body->accept(this);
389 print("}}");
390 }
391
392 void visit(MatrixOfGlobalPtrStmt *stmt) override {
393 std::string s = fmt::format("{}{} = matrix of global ptr [",
394 stmt->type_hint(), stmt->name());
395
396 for (int i = 0; i < (int)stmt->snodes.size(); i++) {
397 s += fmt::format("{}", stmt->snodes[i]->get_node_type_name_hinted());
398 if (i + 1 < (int)stmt->snodes.size()) {
399 s += ", ";
400 }
401 }
402 s += "], index [";
403 for (int i = 0; i < (int)stmt->indices.size(); i++) {
404 s += fmt::format("{}", stmt->indices[i]->name());
405 if (i + 1 < (int)stmt->indices.size()) {
406 s += ", ";
407 }
408 }
409 s += "]";
410
411 s += " activate=" + std::string(stmt->activate ? "true" : "false");
412
413 print_raw(s);
414 }
415
416 void visit(GlobalPtrStmt *stmt) override {
417 std::string s =
418 fmt::format("{}{} = global ptr [", stmt->type_hint(), stmt->name());
419
420 std::string snode_name;
421 if (stmt->snode) {
422 snode_name = stmt->snode->get_node_type_name_hinted();
423 } else {
424 snode_name = "unknown";
425 }
426 s += snode_name;
427 s += "], index [";
428 for (int i = 0; i < (int)stmt->indices.size(); i++) {
429 s += fmt::format("{}", stmt->indices[i]->name());
430 if (i + 1 < (int)stmt->indices.size()) {
431 s += ", ";
432 }
433 }
434 s += "]";
435
436 s += " activate=" + std::string(stmt->activate ? "true" : "false");
437
438 print_raw(s);
439 }
440
441 void visit(MatrixOfMatrixPtrStmt *stmt) override {
442 std::string s = fmt::format("{}{} = matrix of matrix ptr [",
443 stmt->type_hint(), stmt->name());
444 for (int i = 0; i < (int)stmt->stmts.size(); i++) {
445 s += fmt::format("{}", stmt->stmts[i]->name());
446 if (i + 1 < (int)stmt->stmts.size()) {
447 s += ", ";
448 }
449 }
450 s += "]";
451 print_raw(s);
452 }
453
454 void visit(MatrixPtrStmt *stmt) override {
455 std::string s =
456 fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(),
457 stmt->name(), stmt->origin->name(), stmt->offset->name());
458 print_raw(s);
459 }
460
461 void visit(ArgLoadStmt *stmt) override {
462 if (!stmt->is_grad) {
463 print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
464 } else {
465 print("{}{} = grad_arg[{}]", stmt->type_hint(), stmt->name(),
466 stmt->arg_id);
467 }
468 }
469
470 void visit(TexturePtrStmt *stmt) override {
471 print("<*Texture> {} = {}", stmt->name(), stmt->arg_load_stmt->name());
472 }
473
474 void visit(TextureOpStmt *stmt) override {
475 std::string args_string = "";
476 for (int i = 0; i < (int)stmt->args.size(); i++) {
477 args_string += fmt::format("{}", stmt->args[i]->name());
478 if (i + 1 < (int)stmt->args.size()) {
479 args_string += ", ";
480 }
481 }
482
483 print("<struct> {} = texture_{}({})", stmt->name(),
484 texture_op_type_name(stmt->op), args_string);
485 }
486
487 void visit(FrontendReturnStmt *stmt) override {
488 print("{}{} : return [{}]", stmt->type_hint(), stmt->name(),
489 expr_group_to_string(stmt->values));
490 }
491
492 void visit(ReturnStmt *stmt) override {
493 print("{}{} : return {}", stmt->type_hint(), stmt->name(),
494 stmt->values_raw_names());
495 }
496
497 void visit(LocalLoadStmt *stmt) override {
498 print("{}{} = local load [{}]", stmt->type_hint(), stmt->name(),
499 stmt->src->name());
500 }
501
502 void visit(LocalStoreStmt *stmt) override {
503 print("{}{} : local store [{} <- {}]", stmt->type_hint(), stmt->name(),
504 stmt->dest->name(), stmt->val->name());
505 }
506
507 void visit(GlobalLoadStmt *stmt) override {
508 print("{}{} = global load {}", stmt->type_hint(), stmt->name(),
509 stmt->src->name());
510 }
511
512 void visit(GlobalStoreStmt *stmt) override {
513 print("{}{} : global store [{} <- {}]", stmt->type_hint(), stmt->name(),
514 stmt->dest->name(), stmt->val->name());
515 }
516
517 void visit(RangeAssumptionStmt *stmt) override {
518 print("{}{} = assume_in_range({}{:+d} <= {} < {}{:+d})", stmt->type_hint(),
519 stmt->name(), stmt->base->name(), stmt->low, stmt->input->name(),
520 stmt->base->name(), stmt->high);
521 }
522
523 void visit(LoopUniqueStmt *stmt) override {
524 std::string add = "";
525 if (!stmt->covers.empty()) {
526 add = ", covers=[";
527 for (const auto &sn : stmt->covers) {
528 add += fmt::format("S{}, ", sn);
529 }
530 add.erase(add.size() - 2, 2); // remove the last ", "
531 add += "]";
532 }
533 print("{}{} = loop_unique({}{})", stmt->type_hint(), stmt->name(),
534 stmt->input->name(), add);
535 }
536
537 void visit(LinearizeStmt *stmt) override {
538 auto ind = make_list<Stmt *>(
539 stmt->inputs, [&](Stmt *const &stmt) { return stmt->name(); }, "{");
540 auto stride = make_list<int>(
541 stmt->strides,
542 [&](const int &stride) { return std::to_string(stride); }, "{");
543
544 print("{}{} = linearized(ind {}, stride {})", stmt->type_hint(),
545 stmt->name(), ind, stride);
546 }
547
548 void visit(IntegerOffsetStmt *stmt) override {
549 print("{}{} = offset {} + {}", stmt->type_hint(), stmt->name(),
550 stmt->input->name(), stmt->offset);
551 }
552
553 void visit(GetRootStmt *stmt) override {
554 if (stmt->root() == nullptr)
555 print("{}{} = get root nullptr", stmt->type_hint(), stmt->name());
556 else
557 print("{}{} = get root [{}][{}]", stmt->type_hint(), stmt->name(),
558 stmt->root()->get_node_type_name_hinted(),
559 stmt->root()->type_name());
560 }
561
562 void visit(SNodeLookupStmt *stmt) override {
563 print("{}{} = [{}][{}]::lookup({}, {}) activate = {}", stmt->type_hint(),
564 stmt->name(), stmt->snode->get_node_type_name_hinted(),
565 stmt->snode->type_name(), stmt->input_snode->name(),
566 stmt->input_index->name(), stmt->activate);
567 }
568
569 void visit(GetChStmt *stmt) override {
570 print("{}{} = get child [{}->{}] {}", stmt->type_hint(), stmt->name(),
571 stmt->input_snode->get_node_type_name_hinted(),
572 stmt->output_snode->get_node_type_name_hinted(),
573 stmt->input_ptr->name());
574 }
575
576 void visit(ExternalPtrStmt *stmt) override {
577 std::string s = stmt->base_ptr->name();
578 s += ", [";
579 for (int i = 0; i < (int)stmt->indices.size(); i++) {
580 s += fmt::format("{}", stmt->indices[i]->name());
581 if (i + 1 < (int)stmt->indices.size()) {
582 s += ", ";
583 }
584 }
585 s += "]";
586 if (stmt->element_shape.size()) {
587 s += ", (";
588 for (int i = 0; i < (int)stmt->element_shape.size(); i++) {
589 s += fmt::format("{}", stmt->element_shape[i]);
590 if (i + 1 < (int)stmt->element_shape.size()) {
591 s += ", ";
592 }
593 }
594 s += ")";
595 }
596 s += fmt::format(" element_dim={} layout={} is_grad={}", stmt->element_dim,
597 (stmt->element_dim <= 0) ? "AOS" : "SOA",
598 stmt->base_ptr->as<ArgLoadStmt>()->is_grad);
599
600 print(fmt::format("{}{} = external_ptr {}", stmt->type_hint(), stmt->name(),
601 s));
602 }
603
604 void visit(OffloadedStmt *stmt) override {
605 std::string details;
606 if (stmt->task_type == OffloadedTaskType::range_for) {
607 std::string begin_str, end_str;
608 if (stmt->const_begin) {
609 begin_str = std::to_string(stmt->begin_value);
610 } else {
611 begin_str = fmt::format("tmp(offset={}B)", stmt->begin_offset);
612 }
613 if (stmt->const_end) {
614 end_str = std::to_string(stmt->end_value);
615 } else if (stmt->end_stmt && !stmt->end_stmt->is<ConstStmt>()) {
616 // range_for end is a non-const stmt (e.g. ndarray axis)
617 end_str = stmt->end_stmt->name();
618 } else {
619 end_str = fmt::format("tmp(offset={}B)", stmt->end_offset);
620 }
621 details =
622 fmt::format("range_for({}, {}) grid_dim={} block_dim={}", begin_str,
623 end_str, stmt->grid_dim, stmt->block_dim);
624 } else if (stmt->task_type == OffloadedTaskType::struct_for) {
625 details =
626 fmt::format("struct_for({}) grid_dim={} block_dim={} bls={}",
627 stmt->snode->get_node_type_name_hinted(), stmt->grid_dim,
628 stmt->block_dim, scratch_pad_info(stmt->mem_access_opt));
629 } else if (stmt->task_type == OffloadedTaskType::mesh_for) {
630 details = fmt::format(
631 "mesh_for({} -> {}) num_patches={} grid_dim={} block_dim={} bls={}",
632 mesh::element_type_name(stmt->major_from_type),
633 stmt->major_to_types.size() == 0
634 ? "Unknown"
635 : mesh::element_type_name(*stmt->major_to_types.begin()),
636 stmt->mesh->num_patches, stmt->grid_dim, stmt->block_dim,
637 scratch_pad_info(stmt->mem_access_opt));
638 }
639 if (stmt->task_type == OffloadedTaskType::listgen) {
640 print("{} = offloaded listgen {}->{}", stmt->name(),
641 stmt->snode->parent->get_node_type_name_hinted(),
642 stmt->snode->get_node_type_name_hinted());
643 } else if (stmt->task_type == OffloadedTaskType::gc) {
644 print("{} = offloaded garbage collect {}", stmt->name(),
645 stmt->snode->get_node_type_name_hinted());
646 } else if (stmt->task_type == OffloadedTaskType::gc_rc) {
647 print("{} = offloaded garbage collect runtime context", stmt->name());
648 } else {
649 print("{} = offloaded {} ", stmt->name(), details);
650 if (stmt->tls_prologue) {
651 print("tls prologue {{");
652 stmt->tls_prologue->accept(this);
653 print("}}");
654 }
655 if (stmt->mesh_prologue) {
656 TI_ASSERT(stmt->task_type == OffloadedTaskType::mesh_for);
657 print("body prologue {{");
658 stmt->mesh_prologue->accept(this);
659 print("}}");
660 }
661 if (stmt->bls_prologue) {
662 print("bls prologue {{");
663 stmt->bls_prologue->accept(this);
664 print("}}");
665 }
666 TI_ASSERT(stmt->body);
667 print("body {{");
668 stmt->body->accept(this);
669 print("}}");
670 if (stmt->bls_epilogue) {
671 print("bls_epilogue {{");
672 stmt->bls_epilogue->accept(this);
673 print("}}");
674 }
675 if (stmt->tls_epilogue) {
676 print("tls_epilogue {{");
677 stmt->tls_epilogue->accept(this);
678 print("}}");
679 }
680 }
681 }
682
683 void visit(ClearListStmt *stmt) override {
684 print("{} = clear_list {}", stmt->name(),
685 stmt->snode->get_node_type_name_hinted());
686 }
687
688 void visit(LoopIndexStmt *stmt) override {
689 print("{}{} = loop {} index {}", stmt->type_hint(), stmt->name(),
690 stmt->loop->name(), stmt->index);
691 }
692
693 void visit(LoopLinearIndexStmt *stmt) override {
694 print("{}{} = loop {} index linear", stmt->type_hint(), stmt->name(),
695 stmt->loop->name());
696 }
697
698 void visit(BlockCornerIndexStmt *stmt) override {
699 print("{}{} = loop {} block corner index {}", stmt->type_hint(),
700 stmt->name(), stmt->loop->name(), stmt->index);
701 }
702
703 void visit(GlobalTemporaryStmt *stmt) override {
704 print("{}{} = global tmp var (offset = {} B)", stmt->type_hint(),
705 stmt->name(), stmt->offset);
706 }
707
708 void visit(ThreadLocalPtrStmt *stmt) override {
709 print("{}{} = thread local ptr (offset = {} B)", stmt->type_hint(),
710 stmt->name(), stmt->offset);
711 }
712
713 void visit(BlockLocalPtrStmt *stmt) override {
714 print("{}{} = block local ptr (offset = {})", stmt->type_hint(),
715 stmt->name(), stmt->offset->name());
716 }
717
718 void visit(InternalFuncStmt *stmt) override {
719 std::string args;
720 bool first = true;
721 for (auto &arg : stmt->args) {
722 if (!first) {
723 args += ", ";
724 }
725 args += arg->name();
726 first = false;
727 }
728 print("{}{} = internal call {}({})", stmt->type_hint(), stmt->name(),
729 stmt->func_name, args);
730 }
731
732 void visit(AdStackAllocaStmt *stmt) override {
733 print("{}{} = stack alloc (max_size={})", stmt->type_hint(), stmt->name(),
734 stmt->max_size);
735 }
736
737 void visit(AdStackLoadTopStmt *stmt) override {
738 print("{}{} = stack load top {}", stmt->type_hint(), stmt->name(),
739 stmt->stack->name());
740 }
741
742 void visit(AdStackLoadTopAdjStmt *stmt) override {
743 print("{}{} = stack load top adj {}", stmt->type_hint(), stmt->name(),
744 stmt->stack->name());
745 }
746
747 void visit(AdStackPushStmt *stmt) override {
748 print("{}{} : stack push {}, val = {}", stmt->type_hint(), stmt->name(),
749 stmt->stack->name(), stmt->v->name());
750 }
751
752 void visit(AdStackPopStmt *stmt) override {
753 print("{}{} : stack pop {}", stmt->type_hint(), stmt->name(),
754 stmt->stack->name());
755 }
756
757 void visit(AdStackAccAdjointStmt *stmt) override {
758 print("{}{} : stack acc adj {}, val = {}", stmt->type_hint(), stmt->name(),
759 stmt->stack->name(), stmt->v->name());
760 }
761
762 void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
763 print("{}{} = external_tensor_shape_along_axis {}, arg_id {}",
764 stmt->type_hint(), stmt->name(), stmt->axis, stmt->arg_id);
765 }
766
767 void visit(BitStructStoreStmt *stmt) override {
768 std::string ch_ids;
769 std::string values;
770 for (int i = 0; i < stmt->ch_ids.size(); i++) {
771 ch_ids += fmt::format("{}", stmt->ch_ids[i]);
772 values += fmt::format("{}", stmt->values[i]->name());
773 if (i != stmt->ch_ids.size() - 1) {
774 ch_ids += ", ";
775 values += ", ";
776 }
777 }
778 print("{} : {}bit_struct_store {}, ch_ids=[{}], values=[{}]", stmt->name(),
779 stmt->is_atomic ? "atomic " : "", stmt->ptr->name(), ch_ids, values);
780 }
781
782 // Mesh related.
783
784 void visit(MeshRelationAccessStmt *stmt) override {
785 if (stmt->is_size()) {
786 print("{}{} = {} idx relation {} size", stmt->type_hint(), stmt->name(),
787 stmt->mesh_idx->name(), mesh::element_type_name(stmt->to_type));
788 } else {
789 print("{}{} = {} idx relation {}[{}]", stmt->type_hint(), stmt->name(),
790 stmt->mesh_idx->name(), mesh::element_type_name(stmt->to_type),
791 stmt->neighbor_idx->name());
792 }
793 }
794
795 void visit(MeshIndexConversionStmt *stmt) override {
796 print("{}{} = {} {} {}", stmt->type_hint(), stmt->name(),
797 mesh::conv_type_name(stmt->conv_type),
798 mesh::element_type_name(stmt->idx_type), stmt->idx->name());
799 }
800
801 void visit(MeshPatchIndexStmt *stmt) override {
802 print("{}{} = mesh patch idx", stmt->type_hint(), stmt->name());
803 }
804
805 void visit(FrontendExternalFuncStmt *stmt) override {
806 if (stmt->so_func != nullptr) {
807 print("so {:x}", (uint64)stmt->so_func);
808 } else if (!stmt->asm_source.empty()) {
809 print("asm \"{}\"", stmt->asm_source);
810 } else {
811 print("bc {}:{}", stmt->bc_filename, stmt->bc_funcname);
812 }
813 print(" (inputs=");
814 for (auto &s : stmt->args) {
815 print(expr_to_string(s));
816 }
817 print(", outputs=");
818 for (auto &s : stmt->outputs) {
819 print(expr_to_string(s));
820 }
821 print(")");
822 }
823
824 void visit(ReferenceStmt *stmt) override {
825 print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name());
826 }
827
828 void visit(MatrixInitStmt *stmt) override {
829 std::string result = "";
830 result += fmt::format("{}{} = [", stmt->type_hint(), stmt->name());
831 for (int i = 0; i < stmt->values.size(); ++i) {
832 result += stmt->values[i]->name();
833 if (i != stmt->values.size() - 1) {
834 result += ", ";
835 }
836 }
837 result += "]";
838 print(result);
839 }
840
841 void visit(GetElementStmt *stmt) override {
842 print("{}{} = get_element({}, {})", stmt->type_hint(), stmt->name(),
843 stmt->src->name(), fmt::join(stmt->index, ", "));
844 }
845
846 private:
847 std::string expr_to_string(Expr &expr) {
848 return expr_to_string(expr.expr.get());
849 }
850
851 std::string expr_to_string(Expression *expr) {
852 TI_ASSERT(expr_printer_);
853 std::ostringstream oss;
854 expr_printer_->set_ostream(&oss);
855 expr->accept(expr_printer_);
856 return oss.str();
857 }
858
859 std::string expr_group_to_string(ExprGroup &expr_group) {
860 TI_ASSERT(expr_printer_);
861 std::ostringstream oss;
862 expr_printer_->set_ostream(&oss);
863 expr_printer_->visit(expr_group);
864 return oss.str();
865 }
866};
867
868} // namespace
869
870namespace irpass {
871
872void print(IRNode *root, std::string *output) {
873 ExpressionHumanFriendlyPrinter expr_printer;
874 return IRPrinter::run(&expr_printer, root, output);
875}
876
877} // namespace irpass
878
879} // namespace taichi::lang
880