#include <unordered_set>
#include <utility>

#include "taichi/ir/ir.h"
#include "taichi/ir/snode.h"
#include "taichi/ir/mesh.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
7
8namespace taichi::lang {
9
10using MeshElementTypeSet = std::unordered_set<mesh::MeshElementType>;
11
12class GatherMeshThreadLocal : public BasicStmtVisitor {
13 public:
14 using BasicStmtVisitor::visit;
15
16 GatherMeshThreadLocal(OffloadedStmt *offload_,
17 MeshElementTypeSet *owned_ptr_,
18 MeshElementTypeSet *total_ptr_,
19 bool optimize_mesh_reordered_mapping_) {
20 allow_undefined_visitor = true;
21 invoke_default_visitor = true;
22
23 this->offload = offload_;
24 this->owned_ptr = owned_ptr_;
25 this->total_ptr = total_ptr_;
26 this->optimize_mesh_reordered_mapping = optimize_mesh_reordered_mapping_;
27 }
28
29 static void run(OffloadedStmt *offload,
30 MeshElementTypeSet *owned_ptr,
31 MeshElementTypeSet *total_ptr,
32 const CompileConfig &config) {
33 TI_ASSERT(offload->task_type == OffloadedStmt::TaskType::mesh_for);
34 GatherMeshThreadLocal analyser(offload, owned_ptr, total_ptr,
35 config.optimize_mesh_reordered_mapping);
36 offload->accept(&analyser);
37 }
38
39 void visit(LoopIndexStmt *stmt) override {
40 if (stmt->is_mesh_index()) {
41 this->owned_ptr->insert(stmt->mesh_index_type());
42 }
43 }
44
45 void visit(MeshRelationAccessStmt *stmt) override {
46 if (mesh::element_order(stmt->from_type()) >
47 mesh::element_order(stmt->to_type)) {
48 this->total_ptr->insert(stmt->from_type());
49 } else {
50 this->owned_ptr->insert(stmt->from_type());
51 }
52 }
53
54 void visit(MeshIndexConversionStmt *stmt) override {
55 this->total_ptr->insert(stmt->idx_type);
56 if (optimize_mesh_reordered_mapping &&
57 stmt->conv_type == mesh::ConvType::l2r) {
58 this->owned_ptr->insert(stmt->idx_type);
59 }
60 }
61
62 OffloadedStmt *offload{nullptr};
63 MeshElementTypeSet *owned_ptr{nullptr};
64 MeshElementTypeSet *total_ptr{nullptr};
65 bool optimize_mesh_reordered_mapping{false};
66};
67
68namespace irpass::analysis {
69
70std::pair</* owned= */ MeshElementTypeSet,
71 /* total= */ MeshElementTypeSet>
72gather_mesh_thread_local(OffloadedStmt *offload, const CompileConfig &config) {
73 MeshElementTypeSet local_owned{};
74 MeshElementTypeSet local_total{};
75
76 GatherMeshThreadLocal::run(offload, &local_owned, &local_total, config);
77 return std::make_pair(local_owned, local_total);
78}
79
80} // namespace irpass::analysis
81
82} // namespace taichi::lang
83