1 | #include "taichi/analysis/bls_analyzer.h" |
2 | |
3 | #include "taichi/system/profiler.h" |
4 | #include "taichi/ir/analysis.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | BLSAnalyzer::BLSAnalyzer(OffloadedStmt *for_stmt, ScratchPads *pads) |
9 | : for_stmt_(for_stmt), pads_(pads) { |
10 | TI_AUTO_PROF; |
11 | allow_undefined_visitor = true; |
12 | invoke_default_visitor = false; |
13 | for (auto &snode : for_stmt->mem_access_opt.get_snodes_with_flag( |
14 | SNodeAccessFlag::block_local)) { |
15 | auto *block = snode->parent; |
16 | if (block_indices_.find(block) == block_indices_.end()) { |
17 | generate_block_indices(block, &block_indices_[block]); |
18 | } |
19 | } |
20 | } |
21 | |
22 | // static |
23 | void BLSAnalyzer::generate_block_indices(SNode *snode, BlockIndices *indices) { |
24 | // NOTE: Assuming not vectorized |
25 | for (int i = 0; i < snode->num_active_indices; i++) { |
26 | auto j = snode->physical_index_position[i]; |
27 | indices->push_back({/*low=*/0, /*high=*/snode->extractors[j].shape - 1}); |
28 | } |
29 | } |
30 | |
31 | void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) { |
32 | if (!analysis_ok_) { |
33 | return; |
34 | } |
35 | if (!stmt->is<GlobalPtrStmt>()) |
36 | return; // local alloca |
37 | auto ptr = stmt->as<GlobalPtrStmt>(); |
38 | auto snode = ptr->snode; |
39 | if (!pads_->has(snode)) { |
40 | return; |
41 | } |
42 | bool matching_indices = true; |
43 | std::vector<IndexRange> offsets; |
44 | std::vector<int> coeffs; |
45 | offsets.resize(ptr->indices.size()); |
46 | coeffs.resize(ptr->indices.size()); |
47 | const int num_indices = (int)ptr->indices.size(); |
48 | for (int i = 0; i < num_indices; i++) { |
49 | auto diff = |
50 | irpass::analysis::value_diff_loop_index(ptr->indices[i], for_stmt_, i); |
51 | if (diff.related() && diff.coeff > 0) { |
52 | offsets[i].low = diff.low; |
53 | offsets[i].high = diff.high; |
54 | coeffs[i] = diff.coeff; |
55 | } else { |
56 | matching_indices = false; |
57 | analysis_ok_ = false; |
58 | } |
59 | } |
60 | if (matching_indices) { |
61 | auto *block = snode->parent; |
62 | const auto &index_bounds = block_indices_[block]; |
63 | std::vector<int> index(num_indices, 0); |
64 | std::function<void(int)> visit = [&](int dimension) { |
65 | if (dimension == num_indices) { |
66 | pads_->access(snode, coeffs, index, flag); |
67 | return; |
68 | } |
69 | for (int i = (index_bounds[dimension].low + offsets[dimension].low); |
70 | i < (index_bounds[dimension].high + offsets[dimension].high); i++) { |
71 | index[dimension] = i; |
72 | visit(dimension + 1); |
73 | } |
74 | }; |
75 | visit(0); |
76 | } |
77 | } |
78 | |
79 | // Do not eliminate global data access |
80 | void BLSAnalyzer::visit(GlobalLoadStmt *stmt) { |
81 | record_access(stmt->src, AccessFlag::read); |
82 | } |
83 | |
84 | void BLSAnalyzer::visit(GlobalStoreStmt *stmt) { |
85 | record_access(stmt->dest, AccessFlag::write); |
86 | } |
87 | |
88 | void BLSAnalyzer::visit(AtomicOpStmt *stmt) { |
89 | if (stmt->op_type == AtomicOpType::add) { |
90 | record_access(stmt->dest, AccessFlag::accumulate); |
91 | } |
92 | } |
93 | |
// Fallback for statement types without a dedicated overload: records nothing
// and only asserts that the statement is not a container statement.
void BLSAnalyzer::visit(Stmt *stmt) {
  TI_ASSERT(!stmt->is_container_statement());
}
97 | |
98 | bool BLSAnalyzer::run() { |
99 | const auto &block = for_stmt_->body; |
100 | |
101 | for (int i = 0; i < (int)block->statements.size(); i++) { |
102 | block->statements[i]->accept(this); |
103 | } |
104 | |
105 | return analysis_ok_; |
106 | } |
107 | |
108 | } // namespace taichi::lang |
109 | |