#include "taichi/ir/analysis.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/pass.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/compile_config.h"

namespace {

using namespace taichi;
using namespace taichi::lang;

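// CreateBitStructStores rewrites plain global stores whose destination is a
// child of a bit_struct SNode into dedicated BitStructStoreStmts. Roughly
// (illustrative IR sketch, not an exact dump):
//   GlobalStoreStmt(dest = GetChStmt(bit_struct_ptr, ch), val = v)
// becomes
//   BitStructStoreStmt(bit_struct_ptr, ch_ids = {ch}, values = {v})
// The GetChStmt is left dead and is cleaned up by the die() pass afterwards.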
class CreateBitStructStores : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  CreateBitStructStores() {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
  }

  static void run(IRNode *root) {
    CreateBitStructStores pass;
    root->accept(&pass);
  }

  void visit(GlobalStoreStmt *stmt) override {
    // Only stores whose destination is a child of a bit_struct SNode are
    // rewritten here; all other global stores are left unchanged.
    auto get_ch = stmt->dest->cast<GetChStmt>();
    if (!get_ch || get_ch->input_snode->type != SNodeType::bit_struct)
      return;
    auto s = Stmt::make<BitStructStoreStmt>(
        get_ch->input_ptr,
        std::vector<int>{get_ch->output_snode->id_in_bit_struct},
        std::vector<Stmt *>{stmt->val});
    stmt->replace_with(VecStatement(std::move(s)));
  }
};

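// MergeBitStructStores fuses BitStructStoreStmts that write to the same
// bit_struct pointer within a single basic block into one combined store, so
// the packed physical word is updated once instead of once per member. run()
// repeats the pass until a fixed point is reached (no block was modified).
// For example (frontend-level sketch): two consecutive assignments such as
//   s[i].x = a
//   s[i].y = b
// to members of the same quantized struct element end up as a single
// BitStructStoreStmt after fusion.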
class MergeBitStructStores : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  MergeBitStructStores() {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
  }

  static void run(IRNode *root) {
    while (true) {
      MergeBitStructStores pass;
      root->accept(&pass);
      if (!pass.modified_)
        break;
    }
  }

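  // Scans one block: BitStructStoreStmts are grouped by their destination
  // pointer until a container statement (or the end of the block) is reached,
  // which acts as a merge barrier. Within a group, a later store to the same
  // child id overwrites an earlier one; all but the last store are erased and
  // the last store is replaced with the merged BitStructStoreStmt.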
  void visit(Block *block) override {
    auto &statements = block->statements;
    std::unordered_map<Stmt *, std::vector<BitStructStoreStmt *>>
        ptr_to_bit_struct_stores;
    std::vector<Stmt *> statements_to_delete;
    for (int i = 0; i <= (int)statements.size(); i++) {
      // TODO: in some cases BitStructStoreStmts across container statements
      // can still be merged, similar to basic block vs. CFG optimizations.
      if (i == (int)statements.size() ||
          statements[i]->is_container_statement()) {
        for (const auto &item : ptr_to_bit_struct_stores) {
          auto ptr = item.first;
          auto stores = item.second;
          if (stores.size() == 1)
            continue;
          std::map<int, Stmt *> values;
          for (auto s : stores) {
            for (int j = 0; j < (int)s->ch_ids.size(); j++) {
              values[s->ch_ids[j]] = s->values[j];
            }
          }
          std::vector<int> ch_ids;
          std::vector<Stmt *> store_values;
          for (auto &ch_id_and_value : values) {
            ch_ids.push_back(ch_id_and_value.first);
            store_values.push_back(ch_id_and_value.second);
          }
          // Now erase all (except the last) related BitStructStoreStmts.
          // Replace the last one with a merged version.
          for (int j = 0; j < (int)stores.size() - 1; j++) {
            statements_to_delete.push_back(stores[j]);
          }
          stores.back()->replace_with(
              Stmt::make<BitStructStoreStmt>(ptr, ch_ids, store_values));
          modified_ = true;
        }
        ptr_to_bit_struct_stores.clear();
        continue;
      }
      if (auto stmt = statements[i]->cast<BitStructStoreStmt>()) {
        ptr_to_bit_struct_stores[stmt->ptr].push_back(stmt);
      }
    }

    for (auto stmt : statements_to_delete) {
      block->erase(stmt);
    }

    for (auto &stmt : statements) {
      stmt->accept(this);
    }
  }

 private:
  bool modified_{false};
};

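// DemoteAtomicBitStructStores turns atomic bit_struct stores into plain ones
// where no two threads can write the same bit_struct container. It relies on
// the result of GatherUniquelyAccessedBitStructsPass, which records, for each
// offloaded task, the bit_struct SNodes whose accesses were found to be
// unique (non-null entry in the map below).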
class DemoteAtomicBitStructStores : public BasicStmtVisitor {
 private:
  const std::unordered_map<OffloadedStmt *,
                           std::unordered_map<const SNode *, GlobalPtrStmt *>>
      &uniquely_accessed_bit_structs_;
  std::unordered_map<OffloadedStmt *,
                     std::unordered_map<const SNode *, GlobalPtrStmt *>>::
      const_iterator current_iterator_;
  bool modified_{false};

 public:
  using BasicStmtVisitor::visit;
  OffloadedStmt *current_offloaded;

  explicit DemoteAtomicBitStructStores(
      const std::unordered_map<
          OffloadedStmt *,
          std::unordered_map<const SNode *, GlobalPtrStmt *>>
          &uniquely_accessed_bit_structs)
      : uniquely_accessed_bit_structs_(uniquely_accessed_bit_structs),
        current_offloaded(nullptr) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
  }

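  // A store can drop its atomic flag when the enclosing task is serial, or
  // when it is a parallel range/mesh/struct-for and the analysis result marks
  // the store's (non-bit-level) container SNode as uniquely accessed.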
  void visit(BitStructStoreStmt *stmt) override {
    bool demote = false;
    TI_ASSERT(current_offloaded);
    if (current_offloaded->task_type == OffloadedTaskType::serial) {
      demote = true;
    } else if (current_offloaded->task_type == OffloadedTaskType::range_for ||
               current_offloaded->task_type == OffloadedTaskType::mesh_for ||
               current_offloaded->task_type == OffloadedTaskType::struct_for) {
      auto *snode = stmt->ptr->as<SNodeLookupStmt>()->snode;
      // Find the nearest non-bit-level ancestor
      while (snode->is_bit_level) {
        snode = snode->parent;
      }
      auto accessed_ptr_iterator = current_iterator_->second.find(snode);
      if (accessed_ptr_iterator != current_iterator_->second.end() &&
          accessed_ptr_iterator->second != nullptr) {
        demote = true;
      }
    }
    if (demote) {
      stmt->is_atomic = false;
      modified_ = true;
    }
  }

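  // Tracks the enclosing offload and, for parallel for tasks, caches the
  // iterator into the per-offload analysis map so that visit(BitStructStoreStmt)
  // can query it. Only the task body is visited (see the comment below on
  // TLS/BLS prologues/epilogues).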
  void visit(OffloadedStmt *stmt) override {
    current_offloaded = stmt;
    if (stmt->task_type == OffloadedTaskType::range_for ||
        stmt->task_type == OffloadedTaskType::mesh_for ||
        stmt->task_type == OffloadedTaskType::struct_for) {
      current_iterator_ =
          uniquely_accessed_bit_structs_.find(current_offloaded);
    }
    // We don't need to visit TLS/BLS prologues/epilogues.
    if (stmt->body) {
      stmt->body->accept(this);
    }
    current_offloaded = nullptr;
  }

  static bool run(IRNode *node,
                  const std::unordered_map<
                      OffloadedStmt *,
                      std::unordered_map<const SNode *, GlobalPtrStmt *>>
                      &uniquely_accessed_bit_structs) {
    DemoteAtomicBitStructStores demoter(uniquely_accessed_bit_structs);
    node->accept(&demoter);
    return demoter.modified_;
  }
};

}  // namespace

namespace taichi::lang {

namespace irpass {

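// Pass pipeline, in order:
//   1. CreateBitStructStores       - turn GlobalStoreStmts into BitStructStoreStmts
//   2. die                         - remove GetChStmts left dead by step 1
//   3. MergeBitStructStores        - fuse stores to the same bit_struct
//                                    (enabled by config.quant_opt_store_fusion)
//   4. DemoteAtomicBitStructStores - drop atomics where safe
//                                    (enabled by config.quant_opt_atomic_demotion;
//                                    requires GatherUniquelyAccessedBitStructsPass)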
void optimize_bit_struct_stores(IRNode *root,
                                const CompileConfig &config,
                                AnalysisManager *amgr) {
  TI_AUTO_PROF;
  CreateBitStructStores::run(root);
  die(root);  // remove unused GetCh
  if (config.quant_opt_store_fusion) {
    MergeBitStructStores::run(root);
  }
  if (config.quant_opt_atomic_demotion) {
    auto *res = amgr->get_pass_result<GatherUniquelyAccessedBitStructsPass>();
    TI_ASSERT_INFO(res,
                   "The optimize_bit_struct_stores pass must be after the "
                   "gather_uniquely_accessed_bit_structs pass when "
                   "config.quant_opt_atomic_demotion is true.");
    DemoteAtomicBitStructStores::run(root, res->uniquely_accessed_bit_structs);
  }
}

}  // namespace irpass

}  // namespace taichi::lang