1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/transforms/utils.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | namespace { |
11 | |
12 | void convert_to_range_for(OffloadedStmt *offloaded) { |
13 | TI_ASSERT(offloaded->task_type == OffloadedTaskType::mesh_for); |
14 | |
15 | DelayedIRModifier modifier; |
16 | auto stmts = irpass::analysis::gather_statements( |
17 | offloaded->body.get(), |
18 | [&](Stmt *stmt) { return stmt->is<MeshIndexConversionStmt>(); }); |
19 | for (size_t i = 0; i < stmts.size(); ++i) { |
20 | auto conv_stmt = stmts[i]->cast<MeshIndexConversionStmt>(); |
21 | if (conv_stmt->conv_type == mesh::ConvType::l2g) { |
22 | stmts[i]->replace_usages_with(conv_stmt->idx); |
23 | modifier.erase(stmts[i]); |
24 | } else if (conv_stmt->conv_type == mesh::ConvType::l2r) { |
25 | stmts[i]->as<MeshIndexConversionStmt>()->conv_type = mesh::ConvType::g2r; |
26 | } |
27 | } |
28 | |
29 | modifier.modify_ir(); |
30 | |
31 | offloaded->const_begin = true; |
32 | offloaded->const_end = true; |
33 | offloaded->begin_value = 0; |
34 | offloaded->end_value = |
35 | offloaded->mesh->num_elements.find(offloaded->major_from_type)->second; |
36 | offloaded->mesh = nullptr; |
37 | offloaded->task_type = OffloadedTaskType::range_for; |
38 | } |
39 | |
40 | void maybe_convert(OffloadedStmt *offloaded) { |
41 | if (offloaded->task_type == OffloadedTaskType::mesh_for && |
42 | offloaded->major_to_types.size() == 0) { |
43 | auto stmts = irpass::analysis::gather_statements( // ti.mesh_patch_idx() |
44 | // relies on mesh-for |
45 | offloaded->body.get(), |
46 | [&](Stmt *stmt) { return stmt->is<MeshPatchIndexStmt>(); }); |
47 | if (stmts.size() == 0) { |
48 | convert_to_range_for(offloaded); |
49 | } |
50 | } |
51 | } |
52 | |
53 | } // namespace |
54 | |
55 | namespace irpass { |
56 | |
57 | void demote_no_access_mesh_fors(IRNode *root) { |
58 | if (auto *block = root->cast<Block>()) { |
59 | for (auto &s_ : block->statements) { |
60 | if (auto *s = s_->cast<OffloadedStmt>()) { |
61 | maybe_convert(s); |
62 | } |
63 | } |
64 | } else if (auto *s = root->cast<OffloadedStmt>()) { |
65 | maybe_convert(s); |
66 | } |
67 | re_id(root); |
68 | } |
69 | |
70 | } // namespace irpass |
71 | |
72 | } // namespace taichi::lang |
73 |