1 | #include "taichi/analysis/gather_uniquely_accessed_pointers.h" |
2 | #include "taichi/ir/ir.h" |
3 | #include "taichi/ir/analysis.h" |
4 | #include "taichi/ir/statements.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include <algorithm> |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | class LoopUniqueStmtSearcher : public BasicStmtVisitor { |
11 | private: |
12 | // Constant values that don't change in the loop. |
13 | std::unordered_set<Stmt *> loop_invariant_; |
14 | |
15 | // If loop_unique_[stmt] is -1, the value of stmt is unique among the |
16 | // top-level loop. |
17 | // If loop_unique_[stmt] is x >= 0, the value of stmt is unique to |
18 | // the x-th loop index. |
19 | std::unordered_map<Stmt *, int> loop_unique_; |
20 | |
21 | public: |
22 | // The number of loop indices of the top-level loop. |
23 | // -1 means uninitialized. |
24 | int num_different_loop_indices{-1}; |
25 | using BasicStmtVisitor::visit; |
26 | |
27 | LoopUniqueStmtSearcher() { |
28 | allow_undefined_visitor = true; |
29 | invoke_default_visitor = true; |
30 | } |
31 | |
32 | void visit(LoopIndexStmt *stmt) override { |
33 | if (stmt->loop->is<OffloadedStmt>()) |
34 | loop_unique_[stmt] = stmt->index; |
35 | } |
36 | |
37 | void visit(LoopUniqueStmt *stmt) override { |
38 | loop_unique_[stmt] = -1; |
39 | } |
40 | |
41 | void visit(ConstStmt *stmt) override { |
42 | loop_invariant_.insert(stmt); |
43 | } |
44 | |
45 | void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { |
46 | loop_invariant_.insert(stmt); |
47 | } |
48 | |
49 | void visit(UnaryOpStmt *stmt) override { |
50 | if (loop_invariant_.count(stmt->operand) > 0) { |
51 | loop_invariant_.insert(stmt); |
52 | } |
53 | |
54 | // op loop-unique -> loop-unique |
55 | if (loop_unique_.count(stmt->operand) > 0 && |
56 | (stmt->op_type == UnaryOpType::neg)) { |
57 | // TODO: Other injective unary operations |
58 | loop_unique_[stmt] = loop_unique_[stmt->operand]; |
59 | } |
60 | } |
61 | |
62 | void visit(DecorationStmt *stmt) override { |
63 | if (stmt->decoration.size() == 2 && |
64 | stmt->decoration[0] == |
65 | uint32_t(DecorationStmt::Decoration::kLoopUnique)) { |
66 | if (loop_unique_.find(stmt->operand) == loop_unique_.end()) { |
67 | // This decoration exists IFF we are looping over NDArray (or any other |
68 | // cases where the array index is linearized by the codegen) In that |
69 | // case the original loop dimensions have been reduced to 1D. |
70 | loop_unique_[stmt->operand] = stmt->decoration[1]; |
71 | num_different_loop_indices = std::max(loop_unique_[stmt->operand] + 1, |
72 | num_different_loop_indices); |
73 | } |
74 | } |
75 | } |
76 | |
77 | void visit(BinaryOpStmt *stmt) override { |
78 | if (loop_invariant_.count(stmt->lhs) > 0 && |
79 | loop_invariant_.count(stmt->rhs) > 0) { |
80 | loop_invariant_.insert(stmt); |
81 | } |
82 | |
83 | // loop-unique op loop-invariant -> loop-unique |
84 | if ((loop_unique_.count(stmt->lhs) > 0 && |
85 | loop_invariant_.count(stmt->rhs) > 0) && |
86 | (stmt->op_type == BinaryOpType::add || |
87 | stmt->op_type == BinaryOpType::sub || |
88 | stmt->op_type == BinaryOpType::bit_xor)) { |
89 | // TODO: Other operations |
90 | loop_unique_[stmt] = loop_unique_[stmt->lhs]; |
91 | } |
92 | |
93 | // loop-invariant op loop-unique -> loop-unique |
94 | if ((loop_invariant_.count(stmt->lhs) > 0 && |
95 | loop_unique_.count(stmt->rhs) > 0) && |
96 | (stmt->op_type == BinaryOpType::add || |
97 | stmt->op_type == BinaryOpType::sub || |
98 | stmt->op_type == BinaryOpType::bit_xor)) { |
99 | loop_unique_[stmt] = loop_unique_[stmt->rhs]; |
100 | } |
101 | } |
102 | |
103 | bool is_partially_loop_unique(Stmt *stmt) const { |
104 | return loop_unique_.find(stmt) != loop_unique_.end(); |
105 | } |
106 | |
107 | bool is_ptr_indices_loop_unique(GlobalPtrStmt *stmt) const { |
108 | // Check if the address is loop-unique, i.e., stmt contains |
109 | // either a loop-unique index or all top-level loop indices. |
110 | TI_ASSERT(num_different_loop_indices != -1); |
111 | std::vector<int> loop_indices; |
112 | loop_indices.reserve(stmt->indices.size()); |
113 | for (auto &index : stmt->indices) { |
114 | auto loop_unique_index = loop_unique_.find(index); |
115 | if (loop_unique_index != loop_unique_.end()) { |
116 | if (loop_unique_index->second == -1) { |
117 | // LoopUniqueStmt |
118 | return true; |
119 | } else { |
120 | // LoopIndexStmt |
121 | loop_indices.push_back(loop_unique_index->second); |
122 | } |
123 | } |
124 | } |
125 | std::sort(loop_indices.begin(), loop_indices.end()); |
126 | auto current_num_different_loop_indices = |
127 | std::unique(loop_indices.begin(), loop_indices.end()) - |
128 | loop_indices.begin(); |
129 | // for i, j in x: |
130 | // a[j, i] is loop-unique |
131 | // b[i, i] is not loop-unique (because there's no j) |
132 | return current_num_different_loop_indices == num_different_loop_indices; |
133 | } |
134 | |
135 | bool is_ptr_indices_loop_unique(ExternalPtrStmt *stmt) const { |
136 | // Check if the address is loop-unique, i.e., stmt contains |
137 | // either a loop-unique index or all top-level loop indices. |
138 | TI_ASSERT(num_different_loop_indices != -1); |
139 | std::vector<int> loop_indices; |
140 | loop_indices.reserve(stmt->indices.size()); |
141 | for (auto &index : stmt->indices) { |
142 | auto loop_unique_index = loop_unique_.find(index); |
143 | if (loop_unique_index != loop_unique_.end()) { |
144 | if (loop_unique_index->second == -1) { |
145 | // LoopUniqueStmt |
146 | return true; |
147 | } else { |
148 | // LoopIndexStmt |
149 | loop_indices.push_back(loop_unique_index->second); |
150 | } |
151 | } |
152 | } |
153 | std::sort(loop_indices.begin(), loop_indices.end()); |
154 | auto current_num_different_loop_indices = |
155 | std::unique(loop_indices.begin(), loop_indices.end()) - |
156 | loop_indices.begin(); |
157 | |
158 | // for i, j in x: |
159 | // a[j, i] is loop-unique |
160 | // b[i, i] is not loop-unique (because there's no j) |
161 | // c[j, i, 1] is loop-unique |
162 | return current_num_different_loop_indices == num_different_loop_indices; |
163 | } |
164 | }; |
165 | |
class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
 private:
  LoopUniqueStmtSearcher loop_unique_stmt_searcher_;

  // Search SNodes that are uniquely accessed, i.e., accessed by
  // one GlobalPtrStmt (or by definitely-same-address GlobalPtrStmts),
  // and that GlobalPtrStmt's address is loop-unique.
  // A nullptr value marks the SNode as NOT uniquely accessed.
  std::unordered_map<const SNode *, GlobalPtrStmt *> accessed_pointer_;
  // SNodes accessed through a mesh neighbor ("to-end" access); once an SNode
  // appears here, from-end accesses to it can no longer be unique.
  std::unordered_map<const SNode *, GlobalPtrStmt *> rel_access_pointer_;

  // Search any_arrs that are uniquely accessed. Maps: ArgID -> ExternalPtrStmt
  // A nullptr value marks the array as NOT uniquely accessed.
  std::unordered_map<int, ExternalPtrStmt *> accessed_arr_pointer_;

 public:
  using BasicStmtVisitor::visit;

  UniquelyAccessedSNodeSearcher() {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
  }

  // Classifies each SNode access as uniquely accessed or not, handling
  // mesh-for index conversions specially before the generic
  // range-for/struct-for uniqueness check.
  void visit(GlobalPtrStmt *stmt) override {
    auto snode = stmt->snode;
    // mesh-for loop unique
    if (stmt->indices.size() == 1 &&
        stmt->indices[0]->is<MeshIndexConversionStmt>()) {
      auto idx = stmt->indices[0]->as<MeshIndexConversionStmt>()->idx;
      while (idx->is<MeshIndexConversionStmt>()) { // special case: l2g +
                                                   // g2r (unwrap nested
                                                   // conversions)
        idx = idx->as<MeshIndexConversionStmt>()->idx;
      }
      if (idx->is<LoopIndexStmt>() &&
          idx->as<LoopIndexStmt>()->is_mesh_index()) { // from-end access
        if (rel_access_pointer_.find(snode) ==
            rel_access_pointer_.end()) { // not accessed by neighbors yet
          accessed_pointer_[snode] = stmt;
        } else { // accessed by neighbors, so it's not unique
          accessed_pointer_[snode] = nullptr;
        }
      } else { // to-end access
        rel_access_pointer_[snode] = stmt;
        accessed_pointer_[snode] =
            nullptr; // once neighbors access this SNode, from-end accesses
                     // cannot be unique either
      }
    }
    // Range-for / struct-for
    // NOTE(review): the mesh-for branch above falls through to this check, so
    // accessed_pointer_[snode] may already hold `stmt` (or nullptr) here —
    // confirm definitely_same_address tolerates a nullptr first argument.
    auto accessed_ptr = accessed_pointer_.find(snode);
    if (accessed_ptr == accessed_pointer_.end()) {
      if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
        accessed_pointer_[snode] = stmt;
      } else {
        accessed_pointer_[snode] = nullptr; // not loop-unique
      }
    } else {
      if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
                                                     stmt)) {
        accessed_ptr->second = nullptr; // not uniquely accessed
      }
    }
  }

  // Classifies each external-array (any_arr) access, keyed by argument id.
  void visit(ExternalPtrStmt *stmt) override {
    // A memory location of an ExternalPtrStmt depends on the indices
    // If the accessed indices are loop unique,
    // the accessed memory location is loop unique
    ArgLoadStmt *arg_load_stmt = stmt->base_ptr->as<ArgLoadStmt>();
    int arg_id = arg_load_stmt->arg_id;

    auto accessed_ptr = accessed_arr_pointer_.find(arg_id);

    bool stmt_loop_unique =
        loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);

    if (!stmt_loop_unique) {
      accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique
    } else {
      if (accessed_ptr == accessed_arr_pointer_.end()) {
        // First time using arr @ arg_id
        accessed_arr_pointer_[arg_id] = stmt;
      } else {
        /**
         * We know stmt->base_ptr and the previously recorded pointers
         * are loop-unique. We need to figure out whether their loop-unique
         * indices are the same while ignoring the others.
         * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
         *      a[i, j, 1] and a[j, i, 2] are not uniquely accessed
         *      a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
         * This is a bit stricter than needed.
         * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
         * However this is probably not common and improvements can be made
         * in a future patch.
         */
        if (accessed_ptr->second) {
          ExternalPtrStmt *other_ptr = accessed_ptr->second;
          TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
          for (int axis = 0; axis < stmt->indices.size(); axis++) {
            Stmt *this_index = stmt->indices[axis];
            Stmt *other_index = other_ptr->indices[axis];
            // We only compare unique indices here.
            // Since both pointers are loop-unique, all the unique indices
            // need to be the same for both to be uniquely accessed
            if (loop_unique_stmt_searcher_.is_partially_loop_unique(
                    this_index)) {
              if (!irpass::analysis::same_value(this_index, other_index)) {
                // Not equal -> not uniquely accessed
                accessed_arr_pointer_[arg_id] = nullptr;
                break;
              }
            }
          }
        }
      }
    }
  }

  // Runs the analysis on a single OffloadedStmt; returns
  // (SNode -> unique GlobalPtrStmt, arg id -> unique ExternalPtrStmt),
  // with nullptr values for non-uniquely-accessed entries.
  static std::pair<std::unordered_map<const SNode *, GlobalPtrStmt *>,
                   std::unordered_map<int, ExternalPtrStmt *>>
  run(IRNode *root) {
    TI_ASSERT(root->is<OffloadedStmt>());
    auto offload = root->as<OffloadedStmt>();
    UniquelyAccessedSNodeSearcher searcher;
    // Seed the expected number of distinct loop indices per task type.
    if (offload->task_type == OffloadedTaskType::range_for ||
        offload->task_type == OffloadedTaskType::mesh_for) {
      searcher.loop_unique_stmt_searcher_.num_different_loop_indices = 1;
    } else if (offload->task_type == OffloadedTaskType::struct_for) {
      searcher.loop_unique_stmt_searcher_.num_different_loop_indices =
          offload->snode->num_active_indices;
    } else {
      // serial
      searcher.loop_unique_stmt_searcher_.num_different_loop_indices = 0;
    }
    // First pass gathers loop-unique statements; second pass uses them.
    root->accept(&searcher.loop_unique_stmt_searcher_);
    root->accept(&searcher);

    return std::make_pair(searcher.accessed_pointer_,
                          searcher.accessed_arr_pointer_);
  }
};
304 | |
305 | class UniquelyAccessedBitStructGatherer : public BasicStmtVisitor { |
306 | private: |
307 | std::unordered_map<OffloadedStmt *, |
308 | std::unordered_map<const SNode *, GlobalPtrStmt *>> |
309 | result_; |
310 | |
311 | public: |
312 | using BasicStmtVisitor::visit; |
313 | |
314 | UniquelyAccessedBitStructGatherer() { |
315 | allow_undefined_visitor = true; |
316 | invoke_default_visitor = false; |
317 | } |
318 | |
319 | void visit(OffloadedStmt *stmt) override { |
320 | if (stmt->task_type == OffloadedTaskType::range_for || |
321 | stmt->task_type == OffloadedTaskType::mesh_for || |
322 | stmt->task_type == OffloadedTaskType::struct_for) { |
323 | auto &loop_unique_bit_struct = result_[stmt]; |
324 | auto loop_unique_ptr = |
325 | irpass::analysis::gather_uniquely_accessed_pointers(stmt).first; |
326 | for (auto &it : loop_unique_ptr) { |
327 | auto *snode = it.first; |
328 | auto *ptr1 = it.second; |
329 | if (ptr1 != nullptr && ptr1->indices.size() > 0 && |
330 | ptr1->indices[0]->is<MeshIndexConversionStmt>()) { |
331 | continue; |
332 | } |
333 | if (snode->is_bit_level) { |
334 | // Find the nearest non-bit-level ancestor |
335 | while (snode->is_bit_level) { |
336 | snode = snode->parent; |
337 | } |
338 | // Check whether uniquely accessed |
339 | auto accessed_ptr = loop_unique_bit_struct.find(snode); |
340 | if (accessed_ptr == loop_unique_bit_struct.end()) { |
341 | loop_unique_bit_struct[snode] = ptr1; |
342 | } else { |
343 | if (ptr1 == nullptr) { |
344 | accessed_ptr->second = nullptr; |
345 | continue; |
346 | } |
347 | auto *ptr2 = accessed_ptr->second; |
348 | TI_ASSERT(ptr1->indices.size() == ptr2->indices.size()); |
349 | for (int id = 0; id < (int)ptr1->indices.size(); id++) { |
350 | if (!irpass::analysis::same_value(ptr1->indices[id], |
351 | ptr2->indices[id])) { |
352 | accessed_ptr->second = nullptr; // not uniquely accessed |
353 | } |
354 | } |
355 | } |
356 | } |
357 | } |
358 | } |
359 | // Do not dive into OffloadedStmt |
360 | } |
361 | |
362 | static std::unordered_map<OffloadedStmt *, |
363 | std::unordered_map<const SNode *, GlobalPtrStmt *>> |
364 | run(IRNode *root) { |
365 | UniquelyAccessedBitStructGatherer gatherer; |
366 | root->accept(&gatherer); |
367 | return gatherer.result_; |
368 | } |
369 | }; |
370 | |
371 | const std::string GatherUniquelyAccessedBitStructsPass::id = |
372 | "GatherUniquelyAccessedBitStructsPass" ; |
373 | |
374 | namespace irpass::analysis { |
375 | std::pair<std::unordered_map<const SNode *, GlobalPtrStmt *>, |
376 | std::unordered_map<int, ExternalPtrStmt *>> |
377 | gather_uniquely_accessed_pointers(IRNode *root) { |
378 | // TODO: What about SNodeOpStmts? |
379 | return UniquelyAccessedSNodeSearcher::run(root); |
380 | } |
381 | |
382 | void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr) { |
383 | amgr->put_pass_result<GatherUniquelyAccessedBitStructsPass>( |
384 | {UniquelyAccessedBitStructGatherer::run(root)}); |
385 | } |
386 | } // namespace irpass::analysis |
387 | |
388 | } // namespace taichi::lang |
389 | |