1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/analysis.h"
5#include "taichi/transforms/demote_mesh_statements.h"
6#include "taichi/ir/visitors.h"
7
8namespace taichi::lang {
9
10const PassID DemoteMeshStatements::id = "DemoteMeshStatements";
11
12namespace irpass {
13
14auto get_load = [](SNode *snode, Stmt *idx, VecStatement &block) {
15 Stmt *casted_idx = block.push_back<UnaryOpStmt>(UnaryOpType::cast_value, idx);
16 casted_idx->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
17 const auto lane = std::vector<Stmt *>{casted_idx};
18 Stmt *globalptr = block.push_back<GlobalPtrStmt>(snode, lane);
19 Stmt *load = block.push_back<GlobalLoadStmt>(globalptr);
20 return load;
21};
22
23class ReplaceIndexConversion : public BasicStmtVisitor {
24 public:
25 using BasicStmtVisitor::visit;
26
27 explicit ReplaceIndexConversion(OffloadedStmt *node) {
28 allow_undefined_visitor = true;
29 invoke_default_visitor = true;
30
31 offload = node;
32 visit(node);
33 }
34
35 void visit(MeshIndexConversionStmt *stmt) override {
36 SNode *mapping = (stmt->mesh->index_mapping
37 .find(std::make_pair(stmt->idx_type, stmt->conv_type))
38 ->second);
39
40 VecStatement block;
41 Stmt *val;
42 if (stmt->conv_type == mesh::ConvType::g2r) {
43 // E.g, v_reordered = v_g2r[v_global]
44 val = get_load(mapping, stmt->idx, block);
45 } else {
46 // E.g, v_global = v_l2g[v_local + total_vertices_offset]
47 Stmt *offset = offload->total_offset_local.find(stmt->idx_type)->second;
48 Stmt *index =
49 block.push_back<BinaryOpStmt>(BinaryOpType::add, stmt->idx, offset);
50 val = get_load(mapping, index, block);
51 }
52 Stmt *casted_val =
53 block.push_back<UnaryOpStmt>(UnaryOpType::cast_value, val);
54 casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
55 stmt->replace_with(std::move(block));
56 }
57
58 OffloadedStmt *offload;
59};
60
61void demote_mesh_statements_offload(OffloadedStmt *offload,
62 const CompileConfig &config,
63 const std::string &kernel_name) {
64 ReplaceIndexConversion rep_conv(
65 offload); // This demote should work for any offloaed statement
66
67 if (offload->task_type != OffloadedStmt::TaskType::mesh_for) {
68 return;
69 }
70
71 auto stmts = irpass::analysis::gather_statements(
72 offload->body.get(),
73 [&](Stmt *stmt) { return stmt->is<MeshRelationAccessStmt>(); });
74
75 for (int i = stmts.size() - 1; i >= 0; --i) {
76 auto stmt = stmts[i]->cast<MeshRelationAccessStmt>();
77 mesh::MeshElementType from_type = stmt->from_type();
78
79 auto from_order = mesh::element_order(from_type);
80 auto to_order = mesh::element_order(stmt->to_type);
81 mesh::MeshRelationType rel_type =
82 mesh::relation_by_orders(from_order, to_order);
83 if (from_order > to_order) { // high-to-low relation
84 if (stmt->is_size()) {
85 stmt->replace_with(Stmt::make<ConstStmt>(
86 TypedConstant{from_type == mesh::MeshElementType::Cell &&
87 stmt->to_type == mesh::MeshElementType::Edge
88 ? /*Cell-Edge=*/6
89 : (from_order + 1)}));
90 } else {
91 SNode *rel_value = stmt->mesh->relations.find(rel_type)->second.value;
92 VecStatement block;
93 Stmt *to_size = block.push_back<ConstStmt>(
94 TypedConstant{from_type == mesh::MeshElementType::Cell &&
95 stmt->to_type == mesh::MeshElementType::Edge
96 ? /*Cell-Edge=*/6
97 : (from_order + 1)});
98 // E.g, v_2 = CV[(c + total_cells_offset) * 4 + 2]
99 Stmt *offset = offload->total_offset_local.find(from_type)->second;
100 Stmt *tmp0 = block.push_back<BinaryOpStmt>(BinaryOpType::add, offset,
101 stmt->mesh_idx);
102 Stmt *tmp1 =
103 block.push_back<BinaryOpStmt>(BinaryOpType::mul, tmp0, to_size);
104 Stmt *index = block.push_back<BinaryOpStmt>(BinaryOpType::add, tmp1,
105 stmt->neighbor_idx);
106 [[maybe_unused]] Stmt *val = get_load(rel_value, index, block);
107 stmt->replace_with(std::move(block));
108 }
109 } else { // low-to-high or same-order
110 const auto &rel = stmt->mesh->relations.find(rel_type)->second;
111 SNode *rel_offset = rel.offset;
112 SNode *rel_patch_offset = rel.patch_offset;
113 VecStatement block;
114 Stmt *patch_idx = block.push_back<MeshPatchIndexStmt>();
115 Stmt *owned_offset = offload->owned_offset_local.find(from_type)->second;
116 Stmt *patch_offset = get_load(rel_patch_offset, patch_idx, block);
117 Stmt *index_offset = block.push_back<BinaryOpStmt>(
118 BinaryOpType::add, patch_idx, owned_offset);
119 Stmt *index = block.push_back<BinaryOpStmt>(BinaryOpType::add,
120 index_offset, stmt->mesh_idx);
121 Stmt *offset = get_load(rel_offset, index, block);
122 if (stmt->is_size()) {
123 Stmt *one = block.push_back<ConstStmt>(TypedConstant{1});
124 Stmt *index_1 =
125 block.push_back<BinaryOpStmt>(BinaryOpType::add, index, one);
126 Stmt *offset_1 = get_load(rel_offset, index_1, block);
127 Stmt *val =
128 block.push_back<BinaryOpStmt>(BinaryOpType::sub, offset_1, offset);
129 Stmt *casted_val =
130 block.push_back<UnaryOpStmt>(UnaryOpType::cast_value, val);
131 casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
132 } else {
133 SNode *rel_value = stmt->mesh->relations.find(rel_type)->second.value;
134 Stmt *val_local_index = block.push_back<BinaryOpStmt>(
135 BinaryOpType::add, offset, stmt->neighbor_idx);
136 Stmt *val_index = block.push_back<BinaryOpStmt>(
137 BinaryOpType::add, val_local_index, patch_offset);
138 [[maybe_unused]] Stmt *val = get_load(rel_value, val_index, block);
139 }
140 stmt->replace_with(std::move(block));
141 }
142 }
143}
144
145// This pass should happen after offloading but before lower_access
146void demote_mesh_statements(IRNode *root,
147 const CompileConfig &config,
148 const DemoteMeshStatements::Args &args) {
149 TI_AUTO_PROF;
150
151 if (auto root_block = root->cast<Block>()) {
152 for (auto &offload : root_block->statements) {
153 demote_mesh_statements_offload(offload->cast<OffloadedStmt>(), config,
154 args.kernel_name);
155 }
156 } else {
157 demote_mesh_statements_offload(root->as<OffloadedStmt>(), config,
158 args.kernel_name);
159 }
160 type_check(root, config);
161}
162
163} // namespace irpass
164} // namespace taichi::lang
165