1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/analysis.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/ir/transforms.h" |
6 | |
7 | #include <algorithm> |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | namespace irpass { |
12 | |
13 | namespace { |
14 | |
15 | void detect_read_only_in_task(OffloadedStmt *offload) { |
16 | auto accessed = irpass::analysis::gather_snode_read_writes(offload); |
17 | for (auto snode : accessed.first) { |
18 | if (accessed.second.count(snode) == 0) { |
19 | // read-only SNode |
20 | offload->mem_access_opt.add_flag(snode, SNodeAccessFlag::read_only); |
21 | } |
22 | } |
23 | } |
24 | |
25 | class ExternalPtrAccessVisitor : public BasicStmtVisitor { |
26 | private: |
27 | std::unordered_map<int, ExternalPtrAccess> &map_; |
28 | |
29 | public: |
30 | using BasicStmtVisitor::visit; |
31 | |
32 | explicit ExternalPtrAccessVisitor( |
33 | std::unordered_map<int, ExternalPtrAccess> &map) |
34 | : map_(map) { |
35 | } |
36 | |
37 | void visit(GlobalLoadStmt *stmt) override { |
38 | if (!(stmt->src && stmt->src->is<ExternalPtrStmt>())) |
39 | return; |
40 | |
41 | ExternalPtrStmt *src = stmt->src->cast<ExternalPtrStmt>(); |
42 | ArgLoadStmt *arg = src->base_ptr->cast<ArgLoadStmt>(); |
43 | if (map_.find(arg->arg_id) != map_.end()) { |
44 | map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::READ; |
45 | } else { |
46 | map_[arg->arg_id] = ExternalPtrAccess::READ; |
47 | } |
48 | } |
49 | |
50 | void visit(GlobalStoreStmt *stmt) override { |
51 | if (!(stmt->dest && stmt->dest->is<ExternalPtrStmt>())) |
52 | return; |
53 | |
54 | ExternalPtrStmt *dst = stmt->dest->cast<ExternalPtrStmt>(); |
55 | ArgLoadStmt *arg = dst->base_ptr->cast<ArgLoadStmt>(); |
56 | if (map_.find(arg->arg_id) != map_.end()) { |
57 | map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::WRITE; |
58 | } else { |
59 | map_[arg->arg_id] = ExternalPtrAccess::WRITE; |
60 | } |
61 | } |
62 | |
63 | void visit(AtomicOpStmt *stmt) override { |
64 | if (!(stmt->dest && stmt->dest->is<ExternalPtrStmt>())) |
65 | return; |
66 | |
67 | // Atomics modifies existing state (therefore both read & write) |
68 | ExternalPtrStmt *dst = stmt->dest->cast<ExternalPtrStmt>(); |
69 | ArgLoadStmt *arg = dst->base_ptr->cast<ArgLoadStmt>(); |
70 | map_[arg->arg_id] = ExternalPtrAccess::WRITE | ExternalPtrAccess::READ; |
71 | } |
72 | }; |
73 | |
74 | } // namespace |
75 | |
76 | void detect_read_only(IRNode *root) { |
77 | if (root->is<Block>()) { |
78 | for (auto &offload : root->as<Block>()->statements) { |
79 | detect_read_only_in_task(offload->as<OffloadedStmt>()); |
80 | } |
81 | } else { |
82 | detect_read_only_in_task(root->as<OffloadedStmt>()); |
83 | } |
84 | } |
85 | |
86 | std::unordered_map<int, ExternalPtrAccess> detect_external_ptr_access_in_task( |
87 | OffloadedStmt *offload) { |
88 | std::unordered_map<int, ExternalPtrAccess> map; |
89 | ExternalPtrAccessVisitor v(map); |
90 | offload->accept(&v); |
91 | return map; |
92 | } |
93 | |
94 | } // namespace irpass |
95 | |
96 | } // namespace taichi::lang |
97 |