1#include "taichi/analysis/bls_analyzer.h"
2
3#include "taichi/system/profiler.h"
4#include "taichi/ir/analysis.h"
5
6namespace taichi::lang {
7
8BLSAnalyzer::BLSAnalyzer(OffloadedStmt *for_stmt, ScratchPads *pads)
9 : for_stmt_(for_stmt), pads_(pads) {
10 TI_AUTO_PROF;
11 allow_undefined_visitor = true;
12 invoke_default_visitor = false;
13 for (auto &snode : for_stmt->mem_access_opt.get_snodes_with_flag(
14 SNodeAccessFlag::block_local)) {
15 auto *block = snode->parent;
16 if (block_indices_.find(block) == block_indices_.end()) {
17 generate_block_indices(block, &block_indices_[block]);
18 }
19 }
20}
21
22// static
23void BLSAnalyzer::generate_block_indices(SNode *snode, BlockIndices *indices) {
24 // NOTE: Assuming not vectorized
25 for (int i = 0; i < snode->num_active_indices; i++) {
26 auto j = snode->physical_index_position[i];
27 indices->push_back({/*low=*/0, /*high=*/snode->extractors[j].shape - 1});
28 }
29}
30
31void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
32 if (!analysis_ok_) {
33 return;
34 }
35 if (!stmt->is<GlobalPtrStmt>())
36 return; // local alloca
37 auto ptr = stmt->as<GlobalPtrStmt>();
38 auto snode = ptr->snode;
39 if (!pads_->has(snode)) {
40 return;
41 }
42 bool matching_indices = true;
43 std::vector<IndexRange> offsets;
44 std::vector<int> coeffs;
45 offsets.resize(ptr->indices.size());
46 coeffs.resize(ptr->indices.size());
47 const int num_indices = (int)ptr->indices.size();
48 for (int i = 0; i < num_indices; i++) {
49 auto diff =
50 irpass::analysis::value_diff_loop_index(ptr->indices[i], for_stmt_, i);
51 if (diff.related() && diff.coeff > 0) {
52 offsets[i].low = diff.low;
53 offsets[i].high = diff.high;
54 coeffs[i] = diff.coeff;
55 } else {
56 matching_indices = false;
57 analysis_ok_ = false;
58 }
59 }
60 if (matching_indices) {
61 auto *block = snode->parent;
62 const auto &index_bounds = block_indices_[block];
63 std::vector<int> index(num_indices, 0);
64 std::function<void(int)> visit = [&](int dimension) {
65 if (dimension == num_indices) {
66 pads_->access(snode, coeffs, index, flag);
67 return;
68 }
69 for (int i = (index_bounds[dimension].low + offsets[dimension].low);
70 i < (index_bounds[dimension].high + offsets[dimension].high); i++) {
71 index[dimension] = i;
72 visit(dimension + 1);
73 }
74 };
75 visit(0);
76 }
77}
78
79// Do not eliminate global data access
80void BLSAnalyzer::visit(GlobalLoadStmt *stmt) {
81 record_access(stmt->src, AccessFlag::read);
82}
83
84void BLSAnalyzer::visit(GlobalStoreStmt *stmt) {
85 record_access(stmt->dest, AccessFlag::write);
86}
87
88void BLSAnalyzer::visit(AtomicOpStmt *stmt) {
89 if (stmt->op_type == AtomicOpType::add) {
90 record_access(stmt->dest, AccessFlag::accumulate);
91 }
92}
93
// Generic visitor: performs no analysis and only asserts that `stmt` is not
// a container statement.
// NOTE(review): presumably this is reached for statement kinds that have no
// more specific overload in this class, which would mean nested
// control-flow statements are unsupported inside the analyzed loop body —
// confirm against the visitor base class.
void BLSAnalyzer::visit(Stmt *stmt) {
  TI_ASSERT(!stmt->is_container_statement());
}
97
98bool BLSAnalyzer::run() {
99 const auto &block = for_stmt_->body;
100
101 for (int i = 0; i < (int)block->statements.size(); i++) {
102 block->statements[i]->accept(this);
103 }
104
105 return analysis_ok_;
106}
107
108} // namespace taichi::lang
109