1#include "taichi/ir/analysis.h"
2#include "taichi/ir/ir.h"
3#include "taichi/ir/pass.h"
4#include "taichi/ir/statements.h"
5#include "taichi/ir/transforms.h"
6#include "taichi/ir/visitors.h"
7#include "taichi/program/compile_config.h"
8
9namespace {
10
11using namespace taichi;
12using namespace taichi::lang;
13
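// Rewrites each GlobalStoreStmt whose destination is a member of a bit_struct
// SNode into an equivalent single-element BitStructStoreStmt.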
class CreateBitStructStores : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  CreateBitStructStores() {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
  }

  static void run(IRNode *root) {
    CreateBitStructStores pass;
    root->accept(&pass);
  }

  void visit(GlobalStoreStmt *stmt) override {
    // We only handle stores whose destination is a member of a bit_struct.
    auto get_ch = stmt->dest->cast<GetChStmt>();
    if (!get_ch || get_ch->input_snode->type != SNodeType::bit_struct)
      return;

    auto s = Stmt::make<BitStructStoreStmt>(
        get_ch->input_ptr,
        std::vector<int>{get_ch->output_snode->id_in_bit_struct},
        std::vector<Stmt *>{stmt->val});
    stmt->replace_with(VecStatement(std::move(s)));
  }
};

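// Merges multiple BitStructStoreStmts that write to the same bit_struct
// pointer within a basic block into a single store, repeating until no
// further change is made. Container statements act as merge barriers.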
class MergeBitStructStores : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  MergeBitStructStores() {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
  }

  static void run(IRNode *root) {
    // Iterate until a fixed point is reached.
    while (true) {
      MergeBitStructStores pass;
      root->accept(&pass);
      if (!pass.modified_)
        break;
    }
  }

  void visit(Block *block) override {
    auto &statements = block->statements;
    std::unordered_map<Stmt *, std::vector<BitStructStoreStmt *>>
        ptr_to_bit_struct_stores;
    std::vector<Stmt *> statements_to_delete;
    for (int i = 0; i <= (int)statements.size(); i++) {
      // TODO: in some cases BitStructStoreStmts across container statements
      // can still be merged, similar to basic block vs. CFG optimizations.
      if (i == (int)statements.size() ||
          statements[i]->is_container_statement()) {
        for (const auto &item : ptr_to_bit_struct_stores) {
          auto ptr = item.first;
          auto stores = item.second;
          if (stores.size() == 1)
            continue;
          // Map each child id to the last value stored to it, so later stores
          // to the same child overwrite earlier ones.
          std::map<int, Stmt *> values;
          for (auto s : stores) {
            for (int j = 0; j < (int)s->ch_ids.size(); j++) {
              values[s->ch_ids[j]] = s->values[j];
            }
          }
          std::vector<int> ch_ids;
          std::vector<Stmt *> store_values;
          for (auto &ch_id_and_value : values) {
            ch_ids.push_back(ch_id_and_value.first);
            store_values.push_back(ch_id_and_value.second);
          }
          // Now erase all (except the last) related BitStructStoreStmts.
          // Replace the last one with a merged version.
          for (int j = 0; j < (int)stores.size() - 1; j++) {
            statements_to_delete.push_back(stores[j]);
          }
          stores.back()->replace_with(
              Stmt::make<BitStructStoreStmt>(ptr, ch_ids, store_values));
          modified_ = true;
        }
        ptr_to_bit_struct_stores.clear();
        continue;
      }
      if (auto stmt = statements[i]->cast<BitStructStoreStmt>()) {
        ptr_to_bit_struct_stores[stmt->ptr].push_back(stmt);
      }
    }

    for (auto stmt : statements_to_delete) {
      block->erase(stmt);
    }

    for (auto &stmt : statements) {
      stmt->accept(this);
    }
  }

 private:
  bool modified_{false};
};

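// Demotes atomic BitStructStoreStmts to plain stores when no data race is
// possible: inside serial offloads, or inside parallel offloads where the
// gather_uniquely_accessed_bit_structs analysis shows the containing
// bit_struct is accessed through a unique pointer.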
class DemoteAtomicBitStructStores : public BasicStmtVisitor {
 private:
  const std::unordered_map<OffloadedStmt *,
                           std::unordered_map<const SNode *, GlobalPtrStmt *>>
      &uniquely_accessed_bit_structs_;
  std::unordered_map<OffloadedStmt *,
                     std::unordered_map<const SNode *, GlobalPtrStmt *>>::
      const_iterator current_iterator_;
  bool modified_{false};

 public:
  using BasicStmtVisitor::visit;
  OffloadedStmt *current_offloaded;

  explicit DemoteAtomicBitStructStores(
      const std::unordered_map<
          OffloadedStmt *,
          std::unordered_map<const SNode *, GlobalPtrStmt *>>
          &uniquely_accessed_bit_structs)
      : uniquely_accessed_bit_structs_(uniquely_accessed_bit_structs),
        current_offloaded(nullptr) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
  }

  void visit(BitStructStoreStmt *stmt) override {
    bool demote = false;
    TI_ASSERT(current_offloaded);
    if (current_offloaded->task_type == OffloadedTaskType::serial) {
      demote = true;
    } else if (current_offloaded->task_type == OffloadedTaskType::range_for ||
               current_offloaded->task_type == OffloadedTaskType::mesh_for ||
               current_offloaded->task_type == OffloadedTaskType::struct_for) {
      auto *snode = stmt->ptr->as<SNodeLookupStmt>()->snode;
      // Find the nearest non-bit-level ancestor.
      while (snode->is_bit_level) {
        snode = snode->parent;
      }
      auto accessed_ptr_iterator = current_iterator_->second.find(snode);
      if (accessed_ptr_iterator != current_iterator_->second.end() &&
          accessed_ptr_iterator->second != nullptr) {
        demote = true;
      }
    }
    if (demote) {
      stmt->is_atomic = false;
      modified_ = true;
    }
  }

  void visit(OffloadedStmt *stmt) override {
    current_offloaded = stmt;
    if (stmt->task_type == OffloadedTaskType::range_for ||
        stmt->task_type == OffloadedTaskType::mesh_for ||
        stmt->task_type == OffloadedTaskType::struct_for) {
      current_iterator_ =
          uniquely_accessed_bit_structs_.find(current_offloaded);
    }
    // We don't need to visit TLS/BLS prologues/epilogues.
    if (stmt->body) {
      stmt->body->accept(this);
    }
    current_offloaded = nullptr;
  }

  static bool run(IRNode *node,
                  const std::unordered_map<
                      OffloadedStmt *,
                      std::unordered_map<const SNode *, GlobalPtrStmt *>>
                      &uniquely_accessed_bit_structs) {
    DemoteAtomicBitStructStores demoter(uniquely_accessed_bit_structs);
    node->accept(&demoter);
    return demoter.modified_;
  }
};

} // namespace

namespace taichi::lang {

namespace irpass {

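// Entry point: builds BitStructStoreStmts from plain global stores, cleans up
// dead GetChStmts, then optionally fuses stores and demotes atomic stores
// depending on the compile config.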
void optimize_bit_struct_stores(IRNode *root,
                                const CompileConfig &config,
                                AnalysisManager *amgr) {
  TI_AUTO_PROF;
  CreateBitStructStores::run(root);
  die(root); // remove unused GetChStmts
  if (config.quant_opt_store_fusion) {
    MergeBitStructStores::run(root);
  }
  if (config.quant_opt_atomic_demotion) {
    auto *res = amgr->get_pass_result<GatherUniquelyAccessedBitStructsPass>();
    TI_ASSERT_INFO(res,
                   "The optimize_bit_struct_stores pass must be after the "
                   "gather_uniquely_accessed_bit_structs pass when "
                   "config.quant_opt_atomic_demotion is true.");
    DemoteAtomicBitStructStores::run(root, res->uniquely_accessed_bit_structs);
  }
}

} // namespace irpass

} // namespace taichi::lang