1 | #include "taichi/analysis/mesh_bls_analyzer.h" |
2 | |
3 | #include "taichi/system/profiler.h" |
4 | #include "taichi/ir/analysis.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | MeshBLSAnalyzer::MeshBLSAnalyzer(OffloadedStmt *for_stmt, |
9 | MeshBLSCaches *caches, |
10 | bool auto_mesh_local, |
11 | const CompileConfig &config) |
12 | : for_stmt_(for_stmt), |
13 | caches_(caches), |
14 | auto_mesh_local_(auto_mesh_local), |
15 | config_(config) { |
16 | TI_AUTO_PROF; |
17 | allow_undefined_visitor = true; |
18 | invoke_default_visitor = false; |
19 | } |
20 | |
21 | void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) { |
22 | if (!analysis_ok_) { |
23 | return; |
24 | } |
25 | if (!stmt->is<GlobalPtrStmt>()) |
26 | return; // local alloca |
27 | auto ptr = stmt->as<GlobalPtrStmt>(); |
28 | if (ptr->indices.size() != std::size_t(1) || |
29 | !ptr->indices[0]->is<MeshIndexConversionStmt>()) |
30 | return; |
31 | auto conv = ptr->indices[0]->as<MeshIndexConversionStmt>(); |
32 | auto element_type = conv->idx_type; |
33 | auto conv_type = conv->conv_type; |
34 | auto idx = conv->idx; |
35 | if (conv_type == mesh::ConvType::g2r) |
36 | return; |
37 | auto snode = ptr->snode; |
38 | if (!caches_->has(snode)) { |
39 | if (auto_mesh_local_ && |
40 | (flag == AccessFlag::accumulate || |
41 | (flag == AccessFlag::read && config_.arch == Arch::cuda)) && |
42 | (!idx->is<LoopIndexStmt>() || |
43 | !idx->as<LoopIndexStmt>()->is_mesh_index())) { |
44 | caches_->insert(snode); |
45 | } else { |
46 | return; |
47 | } |
48 | } |
49 | if (idx->is<MeshRelationAccessStmt>()) { |
50 | if (!caches_->access(snode, element_type, conv_type, flag, |
51 | idx->as<MeshRelationAccessStmt>()->neighbor_idx)) { |
52 | analysis_ok_ = false; |
53 | return; |
54 | } |
55 | } else { |
56 | // No optimization for front-end attribute access |
57 | } |
58 | } |
59 | |
60 | void MeshBLSAnalyzer::visit(GlobalLoadStmt *stmt) { |
61 | record_access(stmt->src, AccessFlag::read); |
62 | } |
63 | |
64 | void MeshBLSAnalyzer::visit(GlobalStoreStmt *stmt) { |
65 | record_access(stmt->dest, AccessFlag::write); |
66 | } |
67 | |
68 | void MeshBLSAnalyzer::visit(AtomicOpStmt *stmt) { |
69 | if (stmt->op_type == AtomicOpType::add) { |
70 | record_access(stmt->dest, AccessFlag::accumulate); |
71 | } |
72 | } |
73 | |
// Overload for a statement handled as a plain Stmt *. It records nothing; it
// only asserts that the statement is not a container statement.
// NOTE(review): the constructor disables invoke_default_visitor, so it is not
// clear from this file which dispatch path reaches this overload — confirm.
void MeshBLSAnalyzer::visit(Stmt *stmt) {
  TI_ASSERT(!stmt->is_container_statement());
}
77 | |
78 | bool MeshBLSAnalyzer::run() { |
79 | const auto &block = for_stmt_->body; |
80 | |
81 | for (int i = 0; i < (int)block->statements.size(); i++) { |
82 | block->statements[i]->accept(this); |
83 | } |
84 | |
85 | return analysis_ok_; |
86 | } |
87 | |
88 | namespace irpass { |
89 | namespace analysis { |
90 | |
91 | std::unique_ptr<MeshBLSCaches> initialize_mesh_local_attribute( |
92 | OffloadedStmt *offload, |
93 | bool auto_mesh_local, |
94 | const CompileConfig &config) { |
95 | TI_AUTO_PROF |
96 | TI_ASSERT(offload->task_type == OffloadedTaskType::mesh_for); |
97 | std::unique_ptr<MeshBLSCaches> caches; |
98 | caches = std::make_unique<MeshBLSCaches>(); |
99 | for (auto snode : offload->mem_access_opt.get_snodes_with_flag( |
100 | SNodeAccessFlag::mesh_local)) { |
101 | caches->insert(snode); |
102 | } |
103 | |
104 | MeshBLSAnalyzer bls_analyzer(offload, caches.get(), auto_mesh_local, config); |
105 | bool analysis_ok = bls_analyzer.run(); |
106 | if (!analysis_ok) { |
107 | TI_ERROR("Mesh BLS analysis failed !" ); |
108 | } |
109 | return caches; |
110 | } |
111 | |
112 | } // namespace analysis |
113 | } // namespace irpass |
114 | |
115 | } // namespace taichi::lang |
116 | |