1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/transforms/demote_mesh_statements.h" |
6 | #include "taichi/ir/visitors.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | const PassID DemoteMeshStatements::id = "DemoteMeshStatements" ; |
11 | |
12 | namespace irpass { |
13 | |
14 | auto get_load = [](SNode *snode, Stmt *idx, VecStatement &block) { |
15 | Stmt *casted_idx = block.push_back<UnaryOpStmt>(UnaryOpType::cast_value, idx); |
16 | casted_idx->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
17 | const auto lane = std::vector<Stmt *>{casted_idx}; |
18 | Stmt *globalptr = block.push_back<GlobalPtrStmt>(snode, lane); |
19 | Stmt *load = block.push_back<GlobalLoadStmt>(globalptr); |
20 | return load; |
21 | }; |
22 | |
23 | class ReplaceIndexConversion : public BasicStmtVisitor { |
24 | public: |
25 | using BasicStmtVisitor::visit; |
26 | |
27 | explicit ReplaceIndexConversion(OffloadedStmt *node) { |
28 | allow_undefined_visitor = true; |
29 | invoke_default_visitor = true; |
30 | |
31 | offload = node; |
32 | visit(node); |
33 | } |
34 | |
35 | void visit(MeshIndexConversionStmt *stmt) override { |
36 | SNode *mapping = (stmt->mesh->index_mapping |
37 | .find(std::make_pair(stmt->idx_type, stmt->conv_type)) |
38 | ->second); |
39 | |
40 | VecStatement block; |
41 | Stmt *val; |
42 | if (stmt->conv_type == mesh::ConvType::g2r) { |
43 | // E.g, v_reordered = v_g2r[v_global] |
44 | val = get_load(mapping, stmt->idx, block); |
45 | } else { |
46 | // E.g, v_global = v_l2g[v_local + total_vertices_offset] |
47 | Stmt *offset = offload->total_offset_local.find(stmt->idx_type)->second; |
48 | Stmt *index = |
49 | block.push_back<BinaryOpStmt>(BinaryOpType::add, stmt->idx, offset); |
50 | val = get_load(mapping, index, block); |
51 | } |
52 | Stmt *casted_val = |
53 | block.push_back<UnaryOpStmt>(UnaryOpType::cast_value, val); |
54 | casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
55 | stmt->replace_with(std::move(block)); |
56 | } |
57 | |
58 | OffloadedStmt *offload; |
59 | }; |
60 | |
61 | void demote_mesh_statements_offload(OffloadedStmt *offload, |
62 | const CompileConfig &config, |
63 | const std::string &kernel_name) { |
64 | ReplaceIndexConversion rep_conv( |
65 | offload); // This demote should work for any offloaed statement |
66 | |
67 | if (offload->task_type != OffloadedStmt::TaskType::mesh_for) { |
68 | return; |
69 | } |
70 | |
71 | auto stmts = irpass::analysis::gather_statements( |
72 | offload->body.get(), |
73 | [&](Stmt *stmt) { return stmt->is<MeshRelationAccessStmt>(); }); |
74 | |
75 | for (int i = stmts.size() - 1; i >= 0; --i) { |
76 | auto stmt = stmts[i]->cast<MeshRelationAccessStmt>(); |
77 | mesh::MeshElementType from_type = stmt->from_type(); |
78 | |
79 | auto from_order = mesh::element_order(from_type); |
80 | auto to_order = mesh::element_order(stmt->to_type); |
81 | mesh::MeshRelationType rel_type = |
82 | mesh::relation_by_orders(from_order, to_order); |
83 | if (from_order > to_order) { // high-to-low relation |
84 | if (stmt->is_size()) { |
85 | stmt->replace_with(Stmt::make<ConstStmt>( |
86 | TypedConstant{from_type == mesh::MeshElementType::Cell && |
87 | stmt->to_type == mesh::MeshElementType::Edge |
88 | ? /*Cell-Edge=*/6 |
89 | : (from_order + 1)})); |
90 | } else { |
91 | SNode *rel_value = stmt->mesh->relations.find(rel_type)->second.value; |
92 | VecStatement block; |
93 | Stmt *to_size = block.push_back<ConstStmt>( |
94 | TypedConstant{from_type == mesh::MeshElementType::Cell && |
95 | stmt->to_type == mesh::MeshElementType::Edge |
96 | ? /*Cell-Edge=*/6 |
97 | : (from_order + 1)}); |
98 | // E.g, v_2 = CV[(c + total_cells_offset) * 4 + 2] |
99 | Stmt *offset = offload->total_offset_local.find(from_type)->second; |
100 | Stmt *tmp0 = block.push_back<BinaryOpStmt>(BinaryOpType::add, offset, |
101 | stmt->mesh_idx); |
102 | Stmt *tmp1 = |
103 | block.push_back<BinaryOpStmt>(BinaryOpType::mul, tmp0, to_size); |
104 | Stmt *index = block.push_back<BinaryOpStmt>(BinaryOpType::add, tmp1, |
105 | stmt->neighbor_idx); |
106 | [[maybe_unused]] Stmt *val = get_load(rel_value, index, block); |
107 | stmt->replace_with(std::move(block)); |
108 | } |
109 | } else { // low-to-high or same-order |
110 | const auto &rel = stmt->mesh->relations.find(rel_type)->second; |
111 | SNode *rel_offset = rel.offset; |
112 | SNode *rel_patch_offset = rel.patch_offset; |
113 | VecStatement block; |
114 | Stmt *patch_idx = block.push_back<MeshPatchIndexStmt>(); |
115 | Stmt *owned_offset = offload->owned_offset_local.find(from_type)->second; |
116 | Stmt *patch_offset = get_load(rel_patch_offset, patch_idx, block); |
117 | Stmt *index_offset = block.push_back<BinaryOpStmt>( |
118 | BinaryOpType::add, patch_idx, owned_offset); |
119 | Stmt *index = block.push_back<BinaryOpStmt>(BinaryOpType::add, |
120 | index_offset, stmt->mesh_idx); |
121 | Stmt *offset = get_load(rel_offset, index, block); |
122 | if (stmt->is_size()) { |
123 | Stmt *one = block.push_back<ConstStmt>(TypedConstant{1}); |
124 | Stmt *index_1 = |
125 | block.push_back<BinaryOpStmt>(BinaryOpType::add, index, one); |
126 | Stmt *offset_1 = get_load(rel_offset, index_1, block); |
127 | Stmt *val = |
128 | block.push_back<BinaryOpStmt>(BinaryOpType::sub, offset_1, offset); |
129 | Stmt *casted_val = |
130 | block.push_back<UnaryOpStmt>(UnaryOpType::cast_value, val); |
131 | casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
132 | } else { |
133 | SNode *rel_value = stmt->mesh->relations.find(rel_type)->second.value; |
134 | Stmt *val_local_index = block.push_back<BinaryOpStmt>( |
135 | BinaryOpType::add, offset, stmt->neighbor_idx); |
136 | Stmt *val_index = block.push_back<BinaryOpStmt>( |
137 | BinaryOpType::add, val_local_index, patch_offset); |
138 | [[maybe_unused]] Stmt *val = get_load(rel_value, val_index, block); |
139 | } |
140 | stmt->replace_with(std::move(block)); |
141 | } |
142 | } |
143 | } |
144 | |
145 | // This pass should happen after offloading but before lower_access |
146 | void demote_mesh_statements(IRNode *root, |
147 | const CompileConfig &config, |
148 | const DemoteMeshStatements::Args &args) { |
149 | TI_AUTO_PROF; |
150 | |
151 | if (auto root_block = root->cast<Block>()) { |
152 | for (auto &offload : root_block->statements) { |
153 | demote_mesh_statements_offload(offload->cast<OffloadedStmt>(), config, |
154 | args.kernel_name); |
155 | } |
156 | } else { |
157 | demote_mesh_statements_offload(root->as<OffloadedStmt>(), config, |
158 | args.kernel_name); |
159 | } |
160 | type_check(root, config); |
161 | } |
162 | |
163 | } // namespace irpass |
164 | } // namespace taichi::lang |
165 | |