1#include "taichi/analysis/gather_uniquely_accessed_pointers.h"
2#include "taichi/ir/ir.h"
3#include "taichi/ir/analysis.h"
4#include "taichi/ir/statements.h"
5#include "taichi/ir/visitors.h"
6#include <algorithm>
7
8namespace taichi::lang {
9
10class LoopUniqueStmtSearcher : public BasicStmtVisitor {
11 private:
12 // Constant values that don't change in the loop.
13 std::unordered_set<Stmt *> loop_invariant_;
14
15 // If loop_unique_[stmt] is -1, the value of stmt is unique among the
16 // top-level loop.
17 // If loop_unique_[stmt] is x >= 0, the value of stmt is unique to
18 // the x-th loop index.
19 std::unordered_map<Stmt *, int> loop_unique_;
20
21 public:
22 // The number of loop indices of the top-level loop.
23 // -1 means uninitialized.
24 int num_different_loop_indices{-1};
25 using BasicStmtVisitor::visit;
26
27 LoopUniqueStmtSearcher() {
28 allow_undefined_visitor = true;
29 invoke_default_visitor = true;
30 }
31
32 void visit(LoopIndexStmt *stmt) override {
33 if (stmt->loop->is<OffloadedStmt>())
34 loop_unique_[stmt] = stmt->index;
35 }
36
37 void visit(LoopUniqueStmt *stmt) override {
38 loop_unique_[stmt] = -1;
39 }
40
41 void visit(ConstStmt *stmt) override {
42 loop_invariant_.insert(stmt);
43 }
44
45 void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
46 loop_invariant_.insert(stmt);
47 }
48
49 void visit(UnaryOpStmt *stmt) override {
50 if (loop_invariant_.count(stmt->operand) > 0) {
51 loop_invariant_.insert(stmt);
52 }
53
54 // op loop-unique -> loop-unique
55 if (loop_unique_.count(stmt->operand) > 0 &&
56 (stmt->op_type == UnaryOpType::neg)) {
57 // TODO: Other injective unary operations
58 loop_unique_[stmt] = loop_unique_[stmt->operand];
59 }
60 }
61
62 void visit(DecorationStmt *stmt) override {
63 if (stmt->decoration.size() == 2 &&
64 stmt->decoration[0] ==
65 uint32_t(DecorationStmt::Decoration::kLoopUnique)) {
66 if (loop_unique_.find(stmt->operand) == loop_unique_.end()) {
67 // This decoration exists IFF we are looping over NDArray (or any other
68 // cases where the array index is linearized by the codegen) In that
69 // case the original loop dimensions have been reduced to 1D.
70 loop_unique_[stmt->operand] = stmt->decoration[1];
71 num_different_loop_indices = std::max(loop_unique_[stmt->operand] + 1,
72 num_different_loop_indices);
73 }
74 }
75 }
76
77 void visit(BinaryOpStmt *stmt) override {
78 if (loop_invariant_.count(stmt->lhs) > 0 &&
79 loop_invariant_.count(stmt->rhs) > 0) {
80 loop_invariant_.insert(stmt);
81 }
82
83 // loop-unique op loop-invariant -> loop-unique
84 if ((loop_unique_.count(stmt->lhs) > 0 &&
85 loop_invariant_.count(stmt->rhs) > 0) &&
86 (stmt->op_type == BinaryOpType::add ||
87 stmt->op_type == BinaryOpType::sub ||
88 stmt->op_type == BinaryOpType::bit_xor)) {
89 // TODO: Other operations
90 loop_unique_[stmt] = loop_unique_[stmt->lhs];
91 }
92
93 // loop-invariant op loop-unique -> loop-unique
94 if ((loop_invariant_.count(stmt->lhs) > 0 &&
95 loop_unique_.count(stmt->rhs) > 0) &&
96 (stmt->op_type == BinaryOpType::add ||
97 stmt->op_type == BinaryOpType::sub ||
98 stmt->op_type == BinaryOpType::bit_xor)) {
99 loop_unique_[stmt] = loop_unique_[stmt->rhs];
100 }
101 }
102
103 bool is_partially_loop_unique(Stmt *stmt) const {
104 return loop_unique_.find(stmt) != loop_unique_.end();
105 }
106
107 bool is_ptr_indices_loop_unique(GlobalPtrStmt *stmt) const {
108 // Check if the address is loop-unique, i.e., stmt contains
109 // either a loop-unique index or all top-level loop indices.
110 TI_ASSERT(num_different_loop_indices != -1);
111 std::vector<int> loop_indices;
112 loop_indices.reserve(stmt->indices.size());
113 for (auto &index : stmt->indices) {
114 auto loop_unique_index = loop_unique_.find(index);
115 if (loop_unique_index != loop_unique_.end()) {
116 if (loop_unique_index->second == -1) {
117 // LoopUniqueStmt
118 return true;
119 } else {
120 // LoopIndexStmt
121 loop_indices.push_back(loop_unique_index->second);
122 }
123 }
124 }
125 std::sort(loop_indices.begin(), loop_indices.end());
126 auto current_num_different_loop_indices =
127 std::unique(loop_indices.begin(), loop_indices.end()) -
128 loop_indices.begin();
129 // for i, j in x:
130 // a[j, i] is loop-unique
131 // b[i, i] is not loop-unique (because there's no j)
132 return current_num_different_loop_indices == num_different_loop_indices;
133 }
134
135 bool is_ptr_indices_loop_unique(ExternalPtrStmt *stmt) const {
136 // Check if the address is loop-unique, i.e., stmt contains
137 // either a loop-unique index or all top-level loop indices.
138 TI_ASSERT(num_different_loop_indices != -1);
139 std::vector<int> loop_indices;
140 loop_indices.reserve(stmt->indices.size());
141 for (auto &index : stmt->indices) {
142 auto loop_unique_index = loop_unique_.find(index);
143 if (loop_unique_index != loop_unique_.end()) {
144 if (loop_unique_index->second == -1) {
145 // LoopUniqueStmt
146 return true;
147 } else {
148 // LoopIndexStmt
149 loop_indices.push_back(loop_unique_index->second);
150 }
151 }
152 }
153 std::sort(loop_indices.begin(), loop_indices.end());
154 auto current_num_different_loop_indices =
155 std::unique(loop_indices.begin(), loop_indices.end()) -
156 loop_indices.begin();
157
158 // for i, j in x:
159 // a[j, i] is loop-unique
160 // b[i, i] is not loop-unique (because there's no j)
161 // c[j, i, 1] is loop-unique
162 return current_num_different_loop_indices == num_different_loop_indices;
163 }
164};
165
166class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
167 private:
168 LoopUniqueStmtSearcher loop_unique_stmt_searcher_;
169
170 // Search SNodes that are uniquely accessed, i.e., accessed by
171 // one GlobalPtrStmt (or by definitely-same-address GlobalPtrStmts),
172 // and that GlobalPtrStmt's address is loop-unique.
173 std::unordered_map<const SNode *, GlobalPtrStmt *> accessed_pointer_;
174 std::unordered_map<const SNode *, GlobalPtrStmt *> rel_access_pointer_;
175
176 // Search any_arrs that are uniquely accessed. Maps: ArgID -> ExternalPtrStmt
177 std::unordered_map<int, ExternalPtrStmt *> accessed_arr_pointer_;
178
179 public:
180 using BasicStmtVisitor::visit;
181
182 UniquelyAccessedSNodeSearcher() {
183 allow_undefined_visitor = true;
184 invoke_default_visitor = true;
185 }
186
187 void visit(GlobalPtrStmt *stmt) override {
188 auto snode = stmt->snode;
189 // mesh-for loop unique
190 if (stmt->indices.size() == 1 &&
191 stmt->indices[0]->is<MeshIndexConversionStmt>()) {
192 auto idx = stmt->indices[0]->as<MeshIndexConversionStmt>()->idx;
193 while (idx->is<MeshIndexConversionStmt>()) { // special case: l2g +
194 // g2r
195 idx = idx->as<MeshIndexConversionStmt>()->idx;
196 }
197 if (idx->is<LoopIndexStmt>() &&
198 idx->as<LoopIndexStmt>()->is_mesh_index()) { // from-end access
199 if (rel_access_pointer_.find(snode) ==
200 rel_access_pointer_.end()) { // not accessed by neibhours yet
201 accessed_pointer_[snode] = stmt;
202 } else { // accessed by neibhours, so it's not unique
203 accessed_pointer_[snode] = nullptr;
204 }
205 } else { // to-end access
206 rel_access_pointer_[snode] = stmt;
207 accessed_pointer_[snode] =
208 nullptr; // from-end access should not be unique
209 }
210 }
211 // Range-for / struct-for
212 auto accessed_ptr = accessed_pointer_.find(snode);
213 if (accessed_ptr == accessed_pointer_.end()) {
214 if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
215 accessed_pointer_[snode] = stmt;
216 } else {
217 accessed_pointer_[snode] = nullptr; // not loop-unique
218 }
219 } else {
220 if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
221 stmt)) {
222 accessed_ptr->second = nullptr; // not uniquely accessed
223 }
224 }
225 }
226
227 void visit(ExternalPtrStmt *stmt) override {
228 // A memory location of an ExternalPtrStmt depends on the indices
229 // If the accessed indices are loop unique,
230 // the accessed memory location is loop unique
231 ArgLoadStmt *arg_load_stmt = stmt->base_ptr->as<ArgLoadStmt>();
232 int arg_id = arg_load_stmt->arg_id;
233
234 auto accessed_ptr = accessed_arr_pointer_.find(arg_id);
235
236 bool stmt_loop_unique =
237 loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);
238
239 if (!stmt_loop_unique) {
240 accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique
241 } else {
242 if (accessed_ptr == accessed_arr_pointer_.end()) {
243 // First time using arr @ arg_id
244 accessed_arr_pointer_[arg_id] = stmt;
245 } else {
246 /**
247 * We know stmt->base_ptr and the previously recorded pointers
248 * are loop-unique. We need to figure out whether their loop-unique
249 * indices are the same while ignoring the others.
250 * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
251 * a[i, j, 1] and a[j, i, 2] are not uniquely accessed
252 * a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
253 * This is a bit stricter than needed.
254 * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
255 * However this is probably not common and improvements can be made
256 * in a future patch.
257 */
258 if (accessed_ptr->second) {
259 ExternalPtrStmt *other_ptr = accessed_ptr->second;
260 TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
261 for (int axis = 0; axis < stmt->indices.size(); axis++) {
262 Stmt *this_index = stmt->indices[axis];
263 Stmt *other_index = other_ptr->indices[axis];
264 // We only compare unique indices here.
265 // Since both pointers are loop-unique, all the unique indices
266 // need to be the same for both to be uniquely accessed
267 if (loop_unique_stmt_searcher_.is_partially_loop_unique(
268 this_index)) {
269 if (!irpass::analysis::same_value(this_index, other_index)) {
270 // Not equal -> not uniquely accessed
271 accessed_arr_pointer_[arg_id] = nullptr;
272 break;
273 }
274 }
275 }
276 }
277 }
278 }
279 }
280
281 static std::pair<std::unordered_map<const SNode *, GlobalPtrStmt *>,
282 std::unordered_map<int, ExternalPtrStmt *>>
283 run(IRNode *root) {
284 TI_ASSERT(root->is<OffloadedStmt>());
285 auto offload = root->as<OffloadedStmt>();
286 UniquelyAccessedSNodeSearcher searcher;
287 if (offload->task_type == OffloadedTaskType::range_for ||
288 offload->task_type == OffloadedTaskType::mesh_for) {
289 searcher.loop_unique_stmt_searcher_.num_different_loop_indices = 1;
290 } else if (offload->task_type == OffloadedTaskType::struct_for) {
291 searcher.loop_unique_stmt_searcher_.num_different_loop_indices =
292 offload->snode->num_active_indices;
293 } else {
294 // serial
295 searcher.loop_unique_stmt_searcher_.num_different_loop_indices = 0;
296 }
297 root->accept(&searcher.loop_unique_stmt_searcher_);
298 root->accept(&searcher);
299
300 return std::make_pair(searcher.accessed_pointer_,
301 searcher.accessed_arr_pointer_);
302 }
303};
304
305class UniquelyAccessedBitStructGatherer : public BasicStmtVisitor {
306 private:
307 std::unordered_map<OffloadedStmt *,
308 std::unordered_map<const SNode *, GlobalPtrStmt *>>
309 result_;
310
311 public:
312 using BasicStmtVisitor::visit;
313
314 UniquelyAccessedBitStructGatherer() {
315 allow_undefined_visitor = true;
316 invoke_default_visitor = false;
317 }
318
319 void visit(OffloadedStmt *stmt) override {
320 if (stmt->task_type == OffloadedTaskType::range_for ||
321 stmt->task_type == OffloadedTaskType::mesh_for ||
322 stmt->task_type == OffloadedTaskType::struct_for) {
323 auto &loop_unique_bit_struct = result_[stmt];
324 auto loop_unique_ptr =
325 irpass::analysis::gather_uniquely_accessed_pointers(stmt).first;
326 for (auto &it : loop_unique_ptr) {
327 auto *snode = it.first;
328 auto *ptr1 = it.second;
329 if (ptr1 != nullptr && ptr1->indices.size() > 0 &&
330 ptr1->indices[0]->is<MeshIndexConversionStmt>()) {
331 continue;
332 }
333 if (snode->is_bit_level) {
334 // Find the nearest non-bit-level ancestor
335 while (snode->is_bit_level) {
336 snode = snode->parent;
337 }
338 // Check whether uniquely accessed
339 auto accessed_ptr = loop_unique_bit_struct.find(snode);
340 if (accessed_ptr == loop_unique_bit_struct.end()) {
341 loop_unique_bit_struct[snode] = ptr1;
342 } else {
343 if (ptr1 == nullptr) {
344 accessed_ptr->second = nullptr;
345 continue;
346 }
347 auto *ptr2 = accessed_ptr->second;
348 TI_ASSERT(ptr1->indices.size() == ptr2->indices.size());
349 for (int id = 0; id < (int)ptr1->indices.size(); id++) {
350 if (!irpass::analysis::same_value(ptr1->indices[id],
351 ptr2->indices[id])) {
352 accessed_ptr->second = nullptr; // not uniquely accessed
353 }
354 }
355 }
356 }
357 }
358 }
359 // Do not dive into OffloadedStmt
360 }
361
362 static std::unordered_map<OffloadedStmt *,
363 std::unordered_map<const SNode *, GlobalPtrStmt *>>
364 run(IRNode *root) {
365 UniquelyAccessedBitStructGatherer gatherer;
366 root->accept(&gatherer);
367 return gatherer.result_;
368 }
369};
370
// Pass ID used to register/look up this pass's result in the AnalysisManager
// (see gather_uniquely_accessed_bit_structs below).
const std::string GatherUniquelyAccessedBitStructsPass::id =
    "GatherUniquelyAccessedBitStructsPass";
373
374namespace irpass::analysis {
375std::pair<std::unordered_map<const SNode *, GlobalPtrStmt *>,
376 std::unordered_map<int, ExternalPtrStmt *>>
377gather_uniquely_accessed_pointers(IRNode *root) {
378 // TODO: What about SNodeOpStmts?
379 return UniquelyAccessedSNodeSearcher::run(root);
380}
381
382void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr) {
383 amgr->put_pass_result<GatherUniquelyAccessedBitStructsPass>(
384 {UniquelyAccessedBitStructGatherer::run(root)});
385}
386} // namespace irpass::analysis
387
388} // namespace taichi::lang
389