1#include "taichi/analysis/mesh_bls_analyzer.h"
2
3#include "taichi/system/profiler.h"
4#include "taichi/ir/analysis.h"
5
6namespace taichi::lang {
7
8MeshBLSAnalyzer::MeshBLSAnalyzer(OffloadedStmt *for_stmt,
9 MeshBLSCaches *caches,
10 bool auto_mesh_local,
11 const CompileConfig &config)
12 : for_stmt_(for_stmt),
13 caches_(caches),
14 auto_mesh_local_(auto_mesh_local),
15 config_(config) {
16 TI_AUTO_PROF;
17 allow_undefined_visitor = true;
18 invoke_default_visitor = false;
19}
20
21void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
22 if (!analysis_ok_) {
23 return;
24 }
25 if (!stmt->is<GlobalPtrStmt>())
26 return; // local alloca
27 auto ptr = stmt->as<GlobalPtrStmt>();
28 if (ptr->indices.size() != std::size_t(1) ||
29 !ptr->indices[0]->is<MeshIndexConversionStmt>())
30 return;
31 auto conv = ptr->indices[0]->as<MeshIndexConversionStmt>();
32 auto element_type = conv->idx_type;
33 auto conv_type = conv->conv_type;
34 auto idx = conv->idx;
35 if (conv_type == mesh::ConvType::g2r)
36 return;
37 auto snode = ptr->snode;
38 if (!caches_->has(snode)) {
39 if (auto_mesh_local_ &&
40 (flag == AccessFlag::accumulate ||
41 (flag == AccessFlag::read && config_.arch == Arch::cuda)) &&
42 (!idx->is<LoopIndexStmt>() ||
43 !idx->as<LoopIndexStmt>()->is_mesh_index())) {
44 caches_->insert(snode);
45 } else {
46 return;
47 }
48 }
49 if (idx->is<MeshRelationAccessStmt>()) {
50 if (!caches_->access(snode, element_type, conv_type, flag,
51 idx->as<MeshRelationAccessStmt>()->neighbor_idx)) {
52 analysis_ok_ = false;
53 return;
54 }
55 } else {
56 // No optimization for front-end attribute access
57 }
58}
59
60void MeshBLSAnalyzer::visit(GlobalLoadStmt *stmt) {
61 record_access(stmt->src, AccessFlag::read);
62}
63
64void MeshBLSAnalyzer::visit(GlobalStoreStmt *stmt) {
65 record_access(stmt->dest, AccessFlag::write);
66}
67
68void MeshBLSAnalyzer::visit(AtomicOpStmt *stmt) {
69 if (stmt->op_type == AtomicOpType::add) {
70 record_access(stmt->dest, AccessFlag::accumulate);
71 }
72}
73
74void MeshBLSAnalyzer::visit(Stmt *stmt) {
75 TI_ASSERT(!stmt->is_container_statement());
76}
77
78bool MeshBLSAnalyzer::run() {
79 const auto &block = for_stmt_->body;
80
81 for (int i = 0; i < (int)block->statements.size(); i++) {
82 block->statements[i]->accept(this);
83 }
84
85 return analysis_ok_;
86}
87
88namespace irpass {
89namespace analysis {
90
91std::unique_ptr<MeshBLSCaches> initialize_mesh_local_attribute(
92 OffloadedStmt *offload,
93 bool auto_mesh_local,
94 const CompileConfig &config) {
95 TI_AUTO_PROF
96 TI_ASSERT(offload->task_type == OffloadedTaskType::mesh_for);
97 std::unique_ptr<MeshBLSCaches> caches;
98 caches = std::make_unique<MeshBLSCaches>();
99 for (auto snode : offload->mem_access_opt.get_snodes_with_flag(
100 SNodeAccessFlag::mesh_local)) {
101 caches->insert(snode);
102 }
103
104 MeshBLSAnalyzer bls_analyzer(offload, caches.get(), auto_mesh_local, config);
105 bool analysis_ok = bls_analyzer.run();
106 if (!analysis_ok) {
107 TI_ERROR("Mesh BLS analysis failed !");
108 }
109 return caches;
110}
111
112} // namespace analysis
113} // namespace irpass
114
115} // namespace taichi::lang
116