1#include "taichi/ir/ir.h"
2#include "taichi/ir/snode.h"
3#include "taichi/ir/visitors.h"
4#include "taichi/ir/analysis.h"
5#include "taichi/ir/statements.h"
6
7namespace taichi::lang {
8
9namespace irpass::analysis {
10
11// Returns the set of SNodes that are read or written
12std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
13gather_snode_read_writes(IRNode *root) {
14 std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>> accessed;
15 irpass::analysis::gather_statements(root, [&](Stmt *stmt) {
16 Stmt *ptr = nullptr;
17 bool read = false, write = false;
18 if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
19 read = true;
20 ptr = global_load->src;
21 } else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
22 write = true;
23 ptr = global_store->dest;
24 } else if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
25 read = true;
26 write = true;
27 ptr = global_atomic->dest;
28 }
29 if (ptr) {
30 if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
31 if (read)
32 accessed.first.emplace(global_ptr->snode);
33 if (write)
34 accessed.second.emplace(global_ptr->snode);
35 }
36 }
37 return false;
38 });
39 return accessed;
40}
41} // namespace irpass::analysis
42
43} // namespace taichi::lang
44