1 | // The IRPrinter prints the IR in a human-readable format |
2 | |
3 | #include "taichi/ir/expression_printer.h" |
4 | #include "taichi/ir/ir.h" |
5 | #include "taichi/ir/statements.h" |
6 | #include "taichi/ir/transforms.h" |
7 | #include "taichi/ir/visitors.h" |
8 | #include "taichi/ir/frontend_ir.h" |
9 | #include "taichi/util/str.h" |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | namespace { |
14 | |
15 | std::string scratch_pad_info(const MemoryAccessOptions &opt) { |
16 | std::string ser; |
17 | if (!opt.get_all().empty()) { |
18 | ser += "mem_access_opt [ " ; |
19 | for (auto &rec : opt.get_all()) { |
20 | for (auto flag : rec.second) { |
21 | ser += rec.first->get_node_type_name_hinted() + ":" + |
22 | snode_access_flag_name(flag) + " " ; |
23 | } |
24 | } |
25 | ser += "] " ; |
26 | } else { |
27 | ser = "none" ; |
28 | } |
29 | return ser; |
30 | } |
31 | |
// Renders the "block_dim=<n> " clause used in loop headers. A block_dim of 0
// is rendered as "adaptive"; any other value is printed numerically. The
// result always ends with one space so it can be concatenated directly.
std::string block_dim_info(int block_dim) {
  std::string rendered = "block_dim=";
  if (block_dim == 0) {
    rendered += "adaptive";
  } else {
    rendered += std::to_string(block_dim);
  }
  rendered += " ";
  return rendered;
}
36 | |
37 | class IRPrinter : public IRVisitor { |
38 | private: |
39 | ExpressionPrinter *expr_printer_{nullptr}; |
40 | |
41 | public: |
42 | int current_indent{0}; |
43 | |
44 | std::string *output{nullptr}; |
45 | std::stringstream ss; |
46 | |
47 | explicit IRPrinter(ExpressionPrinter *expr_printer = nullptr, |
48 | std::string *output = nullptr) |
49 | : expr_printer_(expr_printer), output(output) { |
50 | } |
51 | |
52 | template <typename... Args> |
53 | void print(std::string f, Args &&...args) { |
54 | print_raw(fmt::format(f, std::forward<Args>(args)...)); |
55 | } |
56 | |
57 | void print_raw(std::string f) { |
58 | for (int i = 0; i < current_indent; i++) |
59 | f.insert(0, " " ); |
60 | f += "\n" ; |
61 | if (output) { |
62 | ss << f; |
63 | } else { |
64 | std::cout << f; |
65 | } |
66 | } |
67 | |
68 | static void run(ExpressionPrinter *expr_printer, |
69 | IRNode *node, |
70 | std::string *output) { |
71 | if (node == nullptr) { |
72 | TI_WARN("IRPrinter: Printing nullptr." ); |
73 | if (output) { |
74 | *output = std::string(); |
75 | } |
76 | return; |
77 | } |
78 | auto p = IRPrinter(expr_printer, output); |
79 | p.print("kernel {{" ); |
80 | node->accept(&p); |
81 | p.print("}}" ); |
82 | if (output) |
83 | *output = p.ss.str(); |
84 | } |
85 | |
86 | void visit(Block *stmt_list) override { |
87 | current_indent++; |
88 | for (auto &stmt : stmt_list->statements) { |
89 | stmt->accept(this); |
90 | } |
91 | current_indent--; |
92 | } |
93 | |
94 | void visit(FrontendExprStmt *stmt) override { |
95 | print("{}" , (stmt->val)); |
96 | } |
97 | |
98 | void visit(FrontendBreakStmt *stmt) override { |
99 | print("break" ); |
100 | } |
101 | |
102 | void visit(FrontendContinueStmt *stmt) override { |
103 | print("continue" ); |
104 | } |
105 | |
106 | void visit(FrontendAssignStmt *assign) override { |
107 | print("{} = {}" , expr_to_string(assign->lhs), expr_to_string(assign->rhs)); |
108 | } |
109 | |
110 | void visit(FrontendAllocaStmt *alloca) override { |
111 | std::string shared_suffix = (alloca->is_shared) ? "(shared)" : "" ; |
112 | print("{}${} = alloca{} {}" , alloca->type_hint(), alloca->id, shared_suffix, |
113 | alloca->ident.name()); |
114 | } |
115 | |
116 | void visit(FrontendAssertStmt *assert) override { |
117 | print("{} : assert {}" , assert->name(), expr_to_string(assert->cond)); |
118 | } |
119 | |
120 | void visit(AssertStmt *assert) override { |
121 | std::string ; |
122 | for (auto &arg : assert->args) { |
123 | extras += ", " ; |
124 | extras += arg->name(); |
125 | } |
126 | print("{} : assert {}, \"{}\"{}" , assert->id, assert->cond->name(), |
127 | assert->text, extras); |
128 | } |
129 | |
130 | void visit(ExternalFuncCallStmt *stmt) override { |
131 | std::string ; |
132 | if (stmt->so_func != nullptr) { |
133 | extras += fmt::format("so {:x} " , (uint64)stmt->so_func); |
134 | } else if (!stmt->asm_source.empty()) { |
135 | extras += fmt::format("asm \"{}\" " , stmt->asm_source); |
136 | } else { |
137 | extras += fmt::format("bc {}:{} " , stmt->bc_filename, stmt->bc_funcname); |
138 | } |
139 | extras += "inputs=" ; |
140 | for (auto &arg : stmt->arg_stmts) { |
141 | extras += ", " ; |
142 | extras += arg->name(); |
143 | } |
144 | extras += "outputs=" ; |
145 | for (auto &output : stmt->output_stmts) { |
146 | extras += ", " ; |
147 | extras += output->name(); |
148 | } |
149 | print("{} : {}" , stmt->name(), extras); |
150 | } |
151 | |
152 | void visit(FrontendSNodeOpStmt *stmt) override { |
153 | std::string = "[" ; |
154 | for (int i = 0; i < (int)stmt->indices.size(); i++) { |
155 | extras += expr_to_string(stmt->indices[i]); |
156 | if (i + 1 < (int)stmt->indices.size()) |
157 | extras += ", " ; |
158 | } |
159 | extras += "]" ; |
160 | if (stmt->val.expr) { |
161 | extras += ", " + expr_to_string(stmt->val); |
162 | } |
163 | print("{} : {} {} {}" , stmt->name(), snode_op_type_name(stmt->op_type), |
164 | stmt->snode->get_node_type_name_hinted(), extras); |
165 | } |
166 | |
167 | void visit(SNodeOpStmt *stmt) override { |
168 | std::string ; |
169 | if (stmt->ptr) |
170 | extras = "ptr = " + stmt->ptr->name(); |
171 | if (stmt->val) { |
172 | extras += ", val = " + stmt->val->name(); |
173 | } |
174 | std::string snode = stmt->snode->get_node_type_name_hinted(); |
175 | print("{}{} = {} [{}] {}" , stmt->type_hint(), stmt->name(), |
176 | snode_op_type_name(stmt->op_type), snode, extras); |
177 | } |
178 | |
179 | void visit(AllocaStmt *alloca) override { |
180 | std::string shared_suffix = (alloca->is_shared) ? "(shared)" : "" ; |
181 | print("{}${} = alloca{}" , alloca->type_hint(), alloca->id, shared_suffix); |
182 | } |
183 | |
184 | void visit(RandStmt *stmt) override { |
185 | print("{}{} = rand()" , stmt->type_hint(), stmt->name()); |
186 | } |
187 | |
188 | void visit(DecorationStmt *stmt) override { |
189 | if (stmt->decoration.size() == 2 && |
190 | stmt->decoration[0] == |
191 | uint32_t(DecorationStmt::Decoration::kLoopUnique)) { |
192 | print("decorate {} : Loop-unique {}" , stmt->operand->name(), |
193 | stmt->decoration[0], stmt->decoration[1]); |
194 | } else { |
195 | print("decorate {} : ... size = {}" , stmt->operand->name(), |
196 | stmt->decoration.size()); |
197 | } |
198 | } |
199 | |
200 | void visit(UnaryOpStmt *stmt) override { |
201 | if (stmt->is_cast()) { |
202 | std::string reint = |
203 | stmt->op_type == UnaryOpType::cast_value ? "" : "reinterpret_" ; |
204 | print("{}{} = {}{}<{}> {}" , stmt->type_hint(), stmt->name(), reint, |
205 | unary_op_type_name(stmt->op_type), data_type_name(stmt->cast_type), |
206 | stmt->operand->name()); |
207 | } else { |
208 | print("{}{} = {} {}" , stmt->type_hint(), stmt->name(), |
209 | unary_op_type_name(stmt->op_type), stmt->operand->name()); |
210 | } |
211 | } |
212 | |
213 | void visit(BinaryOpStmt *bin) override { |
214 | print("{}{} = {} {} {}" , bin->type_hint(), bin->name(), |
215 | binary_op_type_name(bin->op_type), bin->lhs->name(), |
216 | bin->rhs->name()); |
217 | } |
218 | |
219 | void visit(TernaryOpStmt *stmt) override { |
220 | print("{}{} = {}({}, {}, {})" , stmt->type_hint(), stmt->name(), |
221 | ternary_type_name(stmt->op_type), stmt->op1->name(), |
222 | stmt->op2->name(), stmt->op3->name()); |
223 | } |
224 | |
225 | void visit(AtomicOpStmt *stmt) override { |
226 | print("{}{} = atomic {}({}, {})" , stmt->type_hint(), stmt->name(), |
227 | atomic_op_type_name(stmt->op_type), stmt->dest->name(), |
228 | stmt->val->name()); |
229 | } |
230 | |
231 | void visit(IfStmt *if_stmt) override { |
232 | print("{} : if {} {{" , if_stmt->name(), if_stmt->cond->name()); |
233 | if (if_stmt->true_statements) |
234 | if_stmt->true_statements->accept(this); |
235 | if (if_stmt->false_statements) { |
236 | print("}} else {{" ); |
237 | if_stmt->false_statements->accept(this); |
238 | } |
239 | print("}}" ); |
240 | } |
241 | |
242 | void visit(FrontendIfStmt *if_stmt) override { |
243 | print("{} : if {} {{" , if_stmt->name(), expr_to_string(if_stmt->condition)); |
244 | if (if_stmt->true_statements) |
245 | if_stmt->true_statements->accept(this); |
246 | if (if_stmt->false_statements) { |
247 | print("}} else {{" ); |
248 | if_stmt->false_statements->accept(this); |
249 | } |
250 | print("}}" ); |
251 | } |
252 | |
253 | void visit(FrontendPrintStmt *print_stmt) override { |
254 | std::vector<std::string> contents; |
255 | for (auto const &c : print_stmt->contents) { |
256 | std::string name; |
257 | if (std::holds_alternative<Expr>(c)) |
258 | name = expr_to_string(std::get<Expr>(c).expr.get()); |
259 | else |
260 | name = c_quoted(std::get<std::string>(c)); |
261 | contents.push_back(name); |
262 | } |
263 | print("print {}" , fmt::join(contents, ", " )); |
264 | } |
265 | |
266 | void visit(PrintStmt *print_stmt) override { |
267 | std::vector<std::string> names; |
268 | for (auto const &c : print_stmt->contents) { |
269 | std::string name; |
270 | if (std::holds_alternative<Stmt *>(c)) |
271 | name = std::get<Stmt *>(c)->name(); |
272 | else |
273 | name = c_quoted(std::get<std::string>(c)); |
274 | names.push_back(name); |
275 | } |
276 | print("print {}" , fmt::join(names, ", " )); |
277 | } |
278 | |
279 | void visit(ConstStmt *const_stmt) override { |
280 | print("{}{} = const {}" , const_stmt->type_hint(), const_stmt->name(), |
281 | const_stmt->val.stringify()); |
282 | } |
283 | |
284 | void visit(WhileControlStmt *stmt) override { |
285 | print("{} : while control {}, {}" , stmt->name(), |
286 | stmt->mask ? stmt->mask->name() : "nullptr" , stmt->cond->name()); |
287 | } |
288 | |
289 | void visit(ContinueStmt *stmt) override { |
290 | if (stmt->scope) { |
291 | print("{} continue (scope={})" , stmt->name(), stmt->scope->name()); |
292 | } else { |
293 | print("{} continue" , stmt->name()); |
294 | } |
295 | } |
296 | |
297 | void visit(FrontendFuncCallStmt *stmt) override { |
298 | std::string args; |
299 | for (int i = 0; i < stmt->args.exprs.size(); i++) { |
300 | if (i) { |
301 | args += ", " ; |
302 | } |
303 | args += expr_to_string(stmt->args.exprs[i]); |
304 | } |
305 | print("{}${} = call \"{}\", args = ({}), ret = {}" , stmt->type_hint(), |
306 | stmt->id, stmt->func->get_name(), args, stmt->ident->name()); |
307 | } |
308 | |
309 | void visit(FuncCallStmt *stmt) override { |
310 | std::vector<std::string> args; |
311 | for (const auto &arg : stmt->args) { |
312 | args.push_back(arg->name()); |
313 | } |
314 | print("{}{} = call \"{}\", args = {{{}}}" , stmt->type_hint(), stmt->name(), |
315 | stmt->func->get_name(), fmt::join(args, ", " )); |
316 | } |
317 | |
318 | void visit(FrontendFuncDefStmt *stmt) override { |
319 | print("function \"{}\" {{" , stmt->funcid); |
320 | stmt->body->accept(this); |
321 | print("}}" ); |
322 | } |
323 | |
324 | void visit(WhileStmt *stmt) override { |
325 | print("{} : while true {{" , stmt->name()); |
326 | stmt->body->accept(this); |
327 | print("}}" ); |
328 | } |
329 | |
330 | void visit(FrontendWhileStmt *stmt) override { |
331 | print("{} : while {} {{" , stmt->name(), expr_to_string(stmt->cond)); |
332 | stmt->body->accept(this); |
333 | print("}}" ); |
334 | } |
335 | |
336 | void visit(FrontendForStmt *for_stmt) override { |
337 | auto vars = make_list<Identifier>( |
338 | for_stmt->loop_var_ids, |
339 | [](const Identifier &id) -> std::string { return id.name(); }); |
340 | if (for_stmt->snode) { |
341 | print("{} : for {} in {} {}{}{{" , for_stmt->name(), vars, |
342 | for_stmt->snode->get_node_type_name_hinted(), |
343 | scratch_pad_info(for_stmt->mem_access_opt), |
344 | block_dim_info(for_stmt->block_dim)); |
345 | } else if (for_stmt->external_tensor) { |
346 | print("{} : for {} in {} {}{}{{" , for_stmt->name(), vars, |
347 | expr_to_string(for_stmt->external_tensor), |
348 | scratch_pad_info(for_stmt->mem_access_opt), |
349 | block_dim_info(for_stmt->block_dim)); |
350 | } else if (for_stmt->mesh) { |
351 | print("{} : for {} in mesh {{" , for_stmt->name(), vars); |
352 | } else { |
353 | print("{} : for {} in range({}, {}) {}{{" , for_stmt->name(), vars, |
354 | expr_to_string(for_stmt->begin), expr_to_string(for_stmt->end), |
355 | block_dim_info(for_stmt->block_dim)); |
356 | } |
357 | for_stmt->body->accept(this); |
358 | print("}}" ); |
359 | } |
360 | |
361 | void visit(RangeForStmt *for_stmt) override { |
362 | print("{} : {}for in range({}, {}) {}{}{{" , for_stmt->name(), |
363 | for_stmt->reversed ? "reversed " : "" , for_stmt->begin->name(), |
364 | for_stmt->end->name(), |
365 | for_stmt->is_bit_vectorized ? "(bit_vectorized) " : "" , |
366 | block_dim_info(for_stmt->block_dim)); |
367 | for_stmt->body->accept(this); |
368 | print("}}" ); |
369 | } |
370 | |
371 | void visit(StructForStmt *for_stmt) override { |
372 | print("{} : struct for in {} {}{}{}{{" , for_stmt->name(), |
373 | for_stmt->snode->get_node_type_name_hinted(), |
374 | for_stmt->is_bit_vectorized ? "(bit_vectorized) " : "" , |
375 | scratch_pad_info(for_stmt->mem_access_opt), |
376 | block_dim_info(for_stmt->block_dim)); |
377 | for_stmt->body->accept(this); |
378 | print("}}" ); |
379 | } |
380 | |
381 | void visit(MeshForStmt *for_stmt) override { |
382 | print("{} : mesh for ({} -> {}) {}{{" , for_stmt->name(), |
383 | mesh::element_type_name(for_stmt->major_from_type), |
384 | for_stmt->major_to_types.size() == 0 |
385 | ? "Unknown" |
386 | : mesh::element_type_name(*for_stmt->major_to_types.begin()), |
387 | scratch_pad_info(for_stmt->mem_access_opt)); |
388 | for_stmt->body->accept(this); |
389 | print("}}" ); |
390 | } |
391 | |
392 | void visit(MatrixOfGlobalPtrStmt *stmt) override { |
393 | std::string s = fmt::format("{}{} = matrix of global ptr [" , |
394 | stmt->type_hint(), stmt->name()); |
395 | |
396 | for (int i = 0; i < (int)stmt->snodes.size(); i++) { |
397 | s += fmt::format("{}" , stmt->snodes[i]->get_node_type_name_hinted()); |
398 | if (i + 1 < (int)stmt->snodes.size()) { |
399 | s += ", " ; |
400 | } |
401 | } |
402 | s += "], index [" ; |
403 | for (int i = 0; i < (int)stmt->indices.size(); i++) { |
404 | s += fmt::format("{}" , stmt->indices[i]->name()); |
405 | if (i + 1 < (int)stmt->indices.size()) { |
406 | s += ", " ; |
407 | } |
408 | } |
409 | s += "]" ; |
410 | |
411 | s += " activate=" + std::string(stmt->activate ? "true" : "false" ); |
412 | |
413 | print_raw(s); |
414 | } |
415 | |
416 | void visit(GlobalPtrStmt *stmt) override { |
417 | std::string s = |
418 | fmt::format("{}{} = global ptr [" , stmt->type_hint(), stmt->name()); |
419 | |
420 | std::string snode_name; |
421 | if (stmt->snode) { |
422 | snode_name = stmt->snode->get_node_type_name_hinted(); |
423 | } else { |
424 | snode_name = "unknown" ; |
425 | } |
426 | s += snode_name; |
427 | s += "], index [" ; |
428 | for (int i = 0; i < (int)stmt->indices.size(); i++) { |
429 | s += fmt::format("{}" , stmt->indices[i]->name()); |
430 | if (i + 1 < (int)stmt->indices.size()) { |
431 | s += ", " ; |
432 | } |
433 | } |
434 | s += "]" ; |
435 | |
436 | s += " activate=" + std::string(stmt->activate ? "true" : "false" ); |
437 | |
438 | print_raw(s); |
439 | } |
440 | |
441 | void visit(MatrixOfMatrixPtrStmt *stmt) override { |
442 | std::string s = fmt::format("{}{} = matrix of matrix ptr [" , |
443 | stmt->type_hint(), stmt->name()); |
444 | for (int i = 0; i < (int)stmt->stmts.size(); i++) { |
445 | s += fmt::format("{}" , stmt->stmts[i]->name()); |
446 | if (i + 1 < (int)stmt->stmts.size()) { |
447 | s += ", " ; |
448 | } |
449 | } |
450 | s += "]" ; |
451 | print_raw(s); |
452 | } |
453 | |
454 | void visit(MatrixPtrStmt *stmt) override { |
455 | std::string s = |
456 | fmt::format("{}{} = shift ptr [{} + {}]" , stmt->type_hint(), |
457 | stmt->name(), stmt->origin->name(), stmt->offset->name()); |
458 | print_raw(s); |
459 | } |
460 | |
461 | void visit(ArgLoadStmt *stmt) override { |
462 | if (!stmt->is_grad) { |
463 | print("{}{} = arg[{}]" , stmt->type_hint(), stmt->name(), stmt->arg_id); |
464 | } else { |
465 | print("{}{} = grad_arg[{}]" , stmt->type_hint(), stmt->name(), |
466 | stmt->arg_id); |
467 | } |
468 | } |
469 | |
470 | void visit(TexturePtrStmt *stmt) override { |
471 | print("<*Texture> {} = {}" , stmt->name(), stmt->arg_load_stmt->name()); |
472 | } |
473 | |
474 | void visit(TextureOpStmt *stmt) override { |
475 | std::string args_string = "" ; |
476 | for (int i = 0; i < (int)stmt->args.size(); i++) { |
477 | args_string += fmt::format("{}" , stmt->args[i]->name()); |
478 | if (i + 1 < (int)stmt->args.size()) { |
479 | args_string += ", " ; |
480 | } |
481 | } |
482 | |
483 | print("<struct> {} = texture_{}({})" , stmt->name(), |
484 | texture_op_type_name(stmt->op), args_string); |
485 | } |
486 | |
487 | void visit(FrontendReturnStmt *stmt) override { |
488 | print("{}{} : return [{}]" , stmt->type_hint(), stmt->name(), |
489 | expr_group_to_string(stmt->values)); |
490 | } |
491 | |
492 | void visit(ReturnStmt *stmt) override { |
493 | print("{}{} : return {}" , stmt->type_hint(), stmt->name(), |
494 | stmt->values_raw_names()); |
495 | } |
496 | |
497 | void visit(LocalLoadStmt *stmt) override { |
498 | print("{}{} = local load [{}]" , stmt->type_hint(), stmt->name(), |
499 | stmt->src->name()); |
500 | } |
501 | |
502 | void visit(LocalStoreStmt *stmt) override { |
503 | print("{}{} : local store [{} <- {}]" , stmt->type_hint(), stmt->name(), |
504 | stmt->dest->name(), stmt->val->name()); |
505 | } |
506 | |
507 | void visit(GlobalLoadStmt *stmt) override { |
508 | print("{}{} = global load {}" , stmt->type_hint(), stmt->name(), |
509 | stmt->src->name()); |
510 | } |
511 | |
512 | void visit(GlobalStoreStmt *stmt) override { |
513 | print("{}{} : global store [{} <- {}]" , stmt->type_hint(), stmt->name(), |
514 | stmt->dest->name(), stmt->val->name()); |
515 | } |
516 | |
517 | void visit(RangeAssumptionStmt *stmt) override { |
518 | print("{}{} = assume_in_range({}{:+d} <= {} < {}{:+d})" , stmt->type_hint(), |
519 | stmt->name(), stmt->base->name(), stmt->low, stmt->input->name(), |
520 | stmt->base->name(), stmt->high); |
521 | } |
522 | |
523 | void visit(LoopUniqueStmt *stmt) override { |
524 | std::string add = "" ; |
525 | if (!stmt->covers.empty()) { |
526 | add = ", covers=[" ; |
527 | for (const auto &sn : stmt->covers) { |
528 | add += fmt::format("S{}, " , sn); |
529 | } |
530 | add.erase(add.size() - 2, 2); // remove the last ", " |
531 | add += "]" ; |
532 | } |
533 | print("{}{} = loop_unique({}{})" , stmt->type_hint(), stmt->name(), |
534 | stmt->input->name(), add); |
535 | } |
536 | |
537 | void visit(LinearizeStmt *stmt) override { |
538 | auto ind = make_list<Stmt *>( |
539 | stmt->inputs, [&](Stmt *const &stmt) { return stmt->name(); }, "{" ); |
540 | auto stride = make_list<int>( |
541 | stmt->strides, |
542 | [&](const int &stride) { return std::to_string(stride); }, "{" ); |
543 | |
544 | print("{}{} = linearized(ind {}, stride {})" , stmt->type_hint(), |
545 | stmt->name(), ind, stride); |
546 | } |
547 | |
548 | void visit(IntegerOffsetStmt *stmt) override { |
549 | print("{}{} = offset {} + {}" , stmt->type_hint(), stmt->name(), |
550 | stmt->input->name(), stmt->offset); |
551 | } |
552 | |
553 | void visit(GetRootStmt *stmt) override { |
554 | if (stmt->root() == nullptr) |
555 | print("{}{} = get root nullptr" , stmt->type_hint(), stmt->name()); |
556 | else |
557 | print("{}{} = get root [{}][{}]" , stmt->type_hint(), stmt->name(), |
558 | stmt->root()->get_node_type_name_hinted(), |
559 | stmt->root()->type_name()); |
560 | } |
561 | |
562 | void visit(SNodeLookupStmt *stmt) override { |
563 | print("{}{} = [{}][{}]::lookup({}, {}) activate = {}" , stmt->type_hint(), |
564 | stmt->name(), stmt->snode->get_node_type_name_hinted(), |
565 | stmt->snode->type_name(), stmt->input_snode->name(), |
566 | stmt->input_index->name(), stmt->activate); |
567 | } |
568 | |
569 | void visit(GetChStmt *stmt) override { |
570 | print("{}{} = get child [{}->{}] {}" , stmt->type_hint(), stmt->name(), |
571 | stmt->input_snode->get_node_type_name_hinted(), |
572 | stmt->output_snode->get_node_type_name_hinted(), |
573 | stmt->input_ptr->name()); |
574 | } |
575 | |
576 | void visit(ExternalPtrStmt *stmt) override { |
577 | std::string s = stmt->base_ptr->name(); |
578 | s += ", [" ; |
579 | for (int i = 0; i < (int)stmt->indices.size(); i++) { |
580 | s += fmt::format("{}" , stmt->indices[i]->name()); |
581 | if (i + 1 < (int)stmt->indices.size()) { |
582 | s += ", " ; |
583 | } |
584 | } |
585 | s += "]" ; |
586 | if (stmt->element_shape.size()) { |
587 | s += ", (" ; |
588 | for (int i = 0; i < (int)stmt->element_shape.size(); i++) { |
589 | s += fmt::format("{}" , stmt->element_shape[i]); |
590 | if (i + 1 < (int)stmt->element_shape.size()) { |
591 | s += ", " ; |
592 | } |
593 | } |
594 | s += ")" ; |
595 | } |
596 | s += fmt::format(" element_dim={} layout={} is_grad={}" , stmt->element_dim, |
597 | (stmt->element_dim <= 0) ? "AOS" : "SOA" , |
598 | stmt->base_ptr->as<ArgLoadStmt>()->is_grad); |
599 | |
600 | print(fmt::format("{}{} = external_ptr {}" , stmt->type_hint(), stmt->name(), |
601 | s)); |
602 | } |
603 | |
604 | void visit(OffloadedStmt *stmt) override { |
605 | std::string details; |
606 | if (stmt->task_type == OffloadedTaskType::range_for) { |
607 | std::string begin_str, end_str; |
608 | if (stmt->const_begin) { |
609 | begin_str = std::to_string(stmt->begin_value); |
610 | } else { |
611 | begin_str = fmt::format("tmp(offset={}B)" , stmt->begin_offset); |
612 | } |
613 | if (stmt->const_end) { |
614 | end_str = std::to_string(stmt->end_value); |
615 | } else if (stmt->end_stmt && !stmt->end_stmt->is<ConstStmt>()) { |
616 | // range_for end is a non-const stmt (e.g. ndarray axis) |
617 | end_str = stmt->end_stmt->name(); |
618 | } else { |
619 | end_str = fmt::format("tmp(offset={}B)" , stmt->end_offset); |
620 | } |
621 | details = |
622 | fmt::format("range_for({}, {}) grid_dim={} block_dim={}" , begin_str, |
623 | end_str, stmt->grid_dim, stmt->block_dim); |
624 | } else if (stmt->task_type == OffloadedTaskType::struct_for) { |
625 | details = |
626 | fmt::format("struct_for({}) grid_dim={} block_dim={} bls={}" , |
627 | stmt->snode->get_node_type_name_hinted(), stmt->grid_dim, |
628 | stmt->block_dim, scratch_pad_info(stmt->mem_access_opt)); |
629 | } else if (stmt->task_type == OffloadedTaskType::mesh_for) { |
630 | details = fmt::format( |
631 | "mesh_for({} -> {}) num_patches={} grid_dim={} block_dim={} bls={}" , |
632 | mesh::element_type_name(stmt->major_from_type), |
633 | stmt->major_to_types.size() == 0 |
634 | ? "Unknown" |
635 | : mesh::element_type_name(*stmt->major_to_types.begin()), |
636 | stmt->mesh->num_patches, stmt->grid_dim, stmt->block_dim, |
637 | scratch_pad_info(stmt->mem_access_opt)); |
638 | } |
639 | if (stmt->task_type == OffloadedTaskType::listgen) { |
640 | print("{} = offloaded listgen {}->{}" , stmt->name(), |
641 | stmt->snode->parent->get_node_type_name_hinted(), |
642 | stmt->snode->get_node_type_name_hinted()); |
643 | } else if (stmt->task_type == OffloadedTaskType::gc) { |
644 | print("{} = offloaded garbage collect {}" , stmt->name(), |
645 | stmt->snode->get_node_type_name_hinted()); |
646 | } else if (stmt->task_type == OffloadedTaskType::gc_rc) { |
647 | print("{} = offloaded garbage collect runtime context" , stmt->name()); |
648 | } else { |
649 | print("{} = offloaded {} " , stmt->name(), details); |
650 | if (stmt->tls_prologue) { |
651 | print("tls prologue {{" ); |
652 | stmt->tls_prologue->accept(this); |
653 | print("}}" ); |
654 | } |
655 | if (stmt->mesh_prologue) { |
656 | TI_ASSERT(stmt->task_type == OffloadedTaskType::mesh_for); |
657 | print("body prologue {{" ); |
658 | stmt->mesh_prologue->accept(this); |
659 | print("}}" ); |
660 | } |
661 | if (stmt->bls_prologue) { |
662 | print("bls prologue {{" ); |
663 | stmt->bls_prologue->accept(this); |
664 | print("}}" ); |
665 | } |
666 | TI_ASSERT(stmt->body); |
667 | print("body {{" ); |
668 | stmt->body->accept(this); |
669 | print("}}" ); |
670 | if (stmt->bls_epilogue) { |
671 | print("bls_epilogue {{" ); |
672 | stmt->bls_epilogue->accept(this); |
673 | print("}}" ); |
674 | } |
675 | if (stmt->tls_epilogue) { |
676 | print("tls_epilogue {{" ); |
677 | stmt->tls_epilogue->accept(this); |
678 | print("}}" ); |
679 | } |
680 | } |
681 | } |
682 | |
683 | void visit(ClearListStmt *stmt) override { |
684 | print("{} = clear_list {}" , stmt->name(), |
685 | stmt->snode->get_node_type_name_hinted()); |
686 | } |
687 | |
688 | void visit(LoopIndexStmt *stmt) override { |
689 | print("{}{} = loop {} index {}" , stmt->type_hint(), stmt->name(), |
690 | stmt->loop->name(), stmt->index); |
691 | } |
692 | |
693 | void visit(LoopLinearIndexStmt *stmt) override { |
694 | print("{}{} = loop {} index linear" , stmt->type_hint(), stmt->name(), |
695 | stmt->loop->name()); |
696 | } |
697 | |
698 | void visit(BlockCornerIndexStmt *stmt) override { |
699 | print("{}{} = loop {} block corner index {}" , stmt->type_hint(), |
700 | stmt->name(), stmt->loop->name(), stmt->index); |
701 | } |
702 | |
703 | void visit(GlobalTemporaryStmt *stmt) override { |
704 | print("{}{} = global tmp var (offset = {} B)" , stmt->type_hint(), |
705 | stmt->name(), stmt->offset); |
706 | } |
707 | |
708 | void visit(ThreadLocalPtrStmt *stmt) override { |
709 | print("{}{} = thread local ptr (offset = {} B)" , stmt->type_hint(), |
710 | stmt->name(), stmt->offset); |
711 | } |
712 | |
713 | void visit(BlockLocalPtrStmt *stmt) override { |
714 | print("{}{} = block local ptr (offset = {})" , stmt->type_hint(), |
715 | stmt->name(), stmt->offset->name()); |
716 | } |
717 | |
718 | void visit(InternalFuncStmt *stmt) override { |
719 | std::string args; |
720 | bool first = true; |
721 | for (auto &arg : stmt->args) { |
722 | if (!first) { |
723 | args += ", " ; |
724 | } |
725 | args += arg->name(); |
726 | first = false; |
727 | } |
728 | print("{}{} = internal call {}({})" , stmt->type_hint(), stmt->name(), |
729 | stmt->func_name, args); |
730 | } |
731 | |
732 | void visit(AdStackAllocaStmt *stmt) override { |
733 | print("{}{} = stack alloc (max_size={})" , stmt->type_hint(), stmt->name(), |
734 | stmt->max_size); |
735 | } |
736 | |
737 | void visit(AdStackLoadTopStmt *stmt) override { |
738 | print("{}{} = stack load top {}" , stmt->type_hint(), stmt->name(), |
739 | stmt->stack->name()); |
740 | } |
741 | |
742 | void visit(AdStackLoadTopAdjStmt *stmt) override { |
743 | print("{}{} = stack load top adj {}" , stmt->type_hint(), stmt->name(), |
744 | stmt->stack->name()); |
745 | } |
746 | |
747 | void visit(AdStackPushStmt *stmt) override { |
748 | print("{}{} : stack push {}, val = {}" , stmt->type_hint(), stmt->name(), |
749 | stmt->stack->name(), stmt->v->name()); |
750 | } |
751 | |
752 | void visit(AdStackPopStmt *stmt) override { |
753 | print("{}{} : stack pop {}" , stmt->type_hint(), stmt->name(), |
754 | stmt->stack->name()); |
755 | } |
756 | |
757 | void visit(AdStackAccAdjointStmt *stmt) override { |
758 | print("{}{} : stack acc adj {}, val = {}" , stmt->type_hint(), stmt->name(), |
759 | stmt->stack->name(), stmt->v->name()); |
760 | } |
761 | |
762 | void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { |
763 | print("{}{} = external_tensor_shape_along_axis {}, arg_id {}" , |
764 | stmt->type_hint(), stmt->name(), stmt->axis, stmt->arg_id); |
765 | } |
766 | |
767 | void visit(BitStructStoreStmt *stmt) override { |
768 | std::string ch_ids; |
769 | std::string values; |
770 | for (int i = 0; i < stmt->ch_ids.size(); i++) { |
771 | ch_ids += fmt::format("{}" , stmt->ch_ids[i]); |
772 | values += fmt::format("{}" , stmt->values[i]->name()); |
773 | if (i != stmt->ch_ids.size() - 1) { |
774 | ch_ids += ", " ; |
775 | values += ", " ; |
776 | } |
777 | } |
778 | print("{} : {}bit_struct_store {}, ch_ids=[{}], values=[{}]" , stmt->name(), |
779 | stmt->is_atomic ? "atomic " : "" , stmt->ptr->name(), ch_ids, values); |
780 | } |
781 | |
782 | // Mesh related. |
783 | |
784 | void visit(MeshRelationAccessStmt *stmt) override { |
785 | if (stmt->is_size()) { |
786 | print("{}{} = {} idx relation {} size" , stmt->type_hint(), stmt->name(), |
787 | stmt->mesh_idx->name(), mesh::element_type_name(stmt->to_type)); |
788 | } else { |
789 | print("{}{} = {} idx relation {}[{}]" , stmt->type_hint(), stmt->name(), |
790 | stmt->mesh_idx->name(), mesh::element_type_name(stmt->to_type), |
791 | stmt->neighbor_idx->name()); |
792 | } |
793 | } |
794 | |
795 | void visit(MeshIndexConversionStmt *stmt) override { |
796 | print("{}{} = {} {} {}" , stmt->type_hint(), stmt->name(), |
797 | mesh::conv_type_name(stmt->conv_type), |
798 | mesh::element_type_name(stmt->idx_type), stmt->idx->name()); |
799 | } |
800 | |
801 | void visit(MeshPatchIndexStmt *stmt) override { |
802 | print("{}{} = mesh patch idx" , stmt->type_hint(), stmt->name()); |
803 | } |
804 | |
805 | void visit(FrontendExternalFuncStmt *stmt) override { |
806 | if (stmt->so_func != nullptr) { |
807 | print("so {:x}" , (uint64)stmt->so_func); |
808 | } else if (!stmt->asm_source.empty()) { |
809 | print("asm \"{}\"" , stmt->asm_source); |
810 | } else { |
811 | print("bc {}:{}" , stmt->bc_filename, stmt->bc_funcname); |
812 | } |
813 | print(" (inputs=" ); |
814 | for (auto &s : stmt->args) { |
815 | print(expr_to_string(s)); |
816 | } |
817 | print(", outputs=" ); |
818 | for (auto &s : stmt->outputs) { |
819 | print(expr_to_string(s)); |
820 | } |
821 | print(")" ); |
822 | } |
823 | |
824 | void visit(ReferenceStmt *stmt) override { |
825 | print("{}{} = ref({})" , stmt->type_hint(), stmt->name(), stmt->var->name()); |
826 | } |
827 | |
828 | void visit(MatrixInitStmt *stmt) override { |
829 | std::string result = "" ; |
830 | result += fmt::format("{}{} = [" , stmt->type_hint(), stmt->name()); |
831 | for (int i = 0; i < stmt->values.size(); ++i) { |
832 | result += stmt->values[i]->name(); |
833 | if (i != stmt->values.size() - 1) { |
834 | result += ", " ; |
835 | } |
836 | } |
837 | result += "]" ; |
838 | print(result); |
839 | } |
840 | |
841 | void visit(GetElementStmt *stmt) override { |
842 | print("{}{} = get_element({}, {})" , stmt->type_hint(), stmt->name(), |
843 | stmt->src->name(), fmt::join(stmt->index, ", " )); |
844 | } |
845 | |
846 | private: |
847 | std::string expr_to_string(Expr &expr) { |
848 | return expr_to_string(expr.expr.get()); |
849 | } |
850 | |
851 | std::string expr_to_string(Expression *expr) { |
852 | TI_ASSERT(expr_printer_); |
853 | std::ostringstream oss; |
854 | expr_printer_->set_ostream(&oss); |
855 | expr->accept(expr_printer_); |
856 | return oss.str(); |
857 | } |
858 | |
859 | std::string expr_group_to_string(ExprGroup &expr_group) { |
860 | TI_ASSERT(expr_printer_); |
861 | std::ostringstream oss; |
862 | expr_printer_->set_ostream(&oss); |
863 | expr_printer_->visit(expr_group); |
864 | return oss.str(); |
865 | } |
866 | }; |
867 | |
868 | } // namespace |
869 | |
870 | namespace irpass { |
871 | |
872 | void print(IRNode *root, std::string *output) { |
873 | ExpressionHumanFriendlyPrinter expr_printer; |
874 | return IRPrinter::run(&expr_printer, root, output); |
875 | } |
876 | |
877 | } // namespace irpass |
878 | |
879 | } // namespace taichi::lang |
880 | |