1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/snode.h" |
3 | #include "taichi/ir/mesh.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/ir/analysis.h" |
6 | #include "taichi/ir/statements.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | using MeshElementTypeSet = std::unordered_set<mesh::MeshElementType>; |
11 | |
12 | class GatherMeshThreadLocal : public BasicStmtVisitor { |
13 | public: |
14 | using BasicStmtVisitor::visit; |
15 | |
16 | GatherMeshThreadLocal(OffloadedStmt *offload_, |
17 | MeshElementTypeSet *owned_ptr_, |
18 | MeshElementTypeSet *total_ptr_, |
19 | bool optimize_mesh_reordered_mapping_) { |
20 | allow_undefined_visitor = true; |
21 | invoke_default_visitor = true; |
22 | |
23 | this->offload = offload_; |
24 | this->owned_ptr = owned_ptr_; |
25 | this->total_ptr = total_ptr_; |
26 | this->optimize_mesh_reordered_mapping = optimize_mesh_reordered_mapping_; |
27 | } |
28 | |
29 | static void run(OffloadedStmt *offload, |
30 | MeshElementTypeSet *owned_ptr, |
31 | MeshElementTypeSet *total_ptr, |
32 | const CompileConfig &config) { |
33 | TI_ASSERT(offload->task_type == OffloadedStmt::TaskType::mesh_for); |
34 | GatherMeshThreadLocal analyser(offload, owned_ptr, total_ptr, |
35 | config.optimize_mesh_reordered_mapping); |
36 | offload->accept(&analyser); |
37 | } |
38 | |
39 | void visit(LoopIndexStmt *stmt) override { |
40 | if (stmt->is_mesh_index()) { |
41 | this->owned_ptr->insert(stmt->mesh_index_type()); |
42 | } |
43 | } |
44 | |
45 | void visit(MeshRelationAccessStmt *stmt) override { |
46 | if (mesh::element_order(stmt->from_type()) > |
47 | mesh::element_order(stmt->to_type)) { |
48 | this->total_ptr->insert(stmt->from_type()); |
49 | } else { |
50 | this->owned_ptr->insert(stmt->from_type()); |
51 | } |
52 | } |
53 | |
54 | void visit(MeshIndexConversionStmt *stmt) override { |
55 | this->total_ptr->insert(stmt->idx_type); |
56 | if (optimize_mesh_reordered_mapping && |
57 | stmt->conv_type == mesh::ConvType::l2r) { |
58 | this->owned_ptr->insert(stmt->idx_type); |
59 | } |
60 | } |
61 | |
62 | OffloadedStmt *offload{nullptr}; |
63 | MeshElementTypeSet *owned_ptr{nullptr}; |
64 | MeshElementTypeSet *total_ptr{nullptr}; |
65 | bool optimize_mesh_reordered_mapping{false}; |
66 | }; |
67 | |
68 | namespace irpass::analysis { |
69 | |
70 | std::pair</* owned= */ MeshElementTypeSet, |
71 | /* total= */ MeshElementTypeSet> |
72 | gather_mesh_thread_local(OffloadedStmt *offload, const CompileConfig &config) { |
73 | MeshElementTypeSet local_owned{}; |
74 | MeshElementTypeSet local_total{}; |
75 | |
76 | GatherMeshThreadLocal::run(offload, &local_owned, &local_total, config); |
77 | return std::make_pair(local_owned, local_total); |
78 | } |
79 | |
80 | } // namespace irpass::analysis |
81 | |
82 | } // namespace taichi::lang |
83 | |