1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/snode.h" |
3 | #include "taichi/ir/visitors.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/ir/statements.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | namespace irpass::analysis { |
10 | |
11 | // Returns the set of SNodes that are read or written |
12 | std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>> |
13 | gather_snode_read_writes(IRNode *root) { |
14 | std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>> accessed; |
15 | irpass::analysis::gather_statements(root, [&](Stmt *stmt) { |
16 | Stmt *ptr = nullptr; |
17 | bool read = false, write = false; |
18 | if (auto global_load = stmt->cast<GlobalLoadStmt>()) { |
19 | read = true; |
20 | ptr = global_load->src; |
21 | } else if (auto global_store = stmt->cast<GlobalStoreStmt>()) { |
22 | write = true; |
23 | ptr = global_store->dest; |
24 | } else if (auto global_atomic = stmt->cast<AtomicOpStmt>()) { |
25 | read = true; |
26 | write = true; |
27 | ptr = global_atomic->dest; |
28 | } |
29 | if (ptr) { |
30 | if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) { |
31 | if (read) |
32 | accessed.first.emplace(global_ptr->snode); |
33 | if (write) |
34 | accessed.second.emplace(global_ptr->snode); |
35 | } |
36 | } |
37 | return false; |
38 | }); |
39 | return accessed; |
40 | } |
41 | } // namespace irpass::analysis |
42 | |
43 | } // namespace taichi::lang |
44 |