1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/analysis.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/ir/transforms.h"
6
7#include <algorithm>
8
9namespace taichi::lang {
10
11namespace irpass {
12
13namespace {
14
15void detect_read_only_in_task(OffloadedStmt *offload) {
16 auto accessed = irpass::analysis::gather_snode_read_writes(offload);
17 for (auto snode : accessed.first) {
18 if (accessed.second.count(snode) == 0) {
19 // read-only SNode
20 offload->mem_access_opt.add_flag(snode, SNodeAccessFlag::read_only);
21 }
22 }
23}
24
25class ExternalPtrAccessVisitor : public BasicStmtVisitor {
26 private:
27 std::unordered_map<int, ExternalPtrAccess> &map_;
28
29 public:
30 using BasicStmtVisitor::visit;
31
32 explicit ExternalPtrAccessVisitor(
33 std::unordered_map<int, ExternalPtrAccess> &map)
34 : map_(map) {
35 }
36
37 void visit(GlobalLoadStmt *stmt) override {
38 if (!(stmt->src && stmt->src->is<ExternalPtrStmt>()))
39 return;
40
41 ExternalPtrStmt *src = stmt->src->cast<ExternalPtrStmt>();
42 ArgLoadStmt *arg = src->base_ptr->cast<ArgLoadStmt>();
43 if (map_.find(arg->arg_id) != map_.end()) {
44 map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::READ;
45 } else {
46 map_[arg->arg_id] = ExternalPtrAccess::READ;
47 }
48 }
49
50 void visit(GlobalStoreStmt *stmt) override {
51 if (!(stmt->dest && stmt->dest->is<ExternalPtrStmt>()))
52 return;
53
54 ExternalPtrStmt *dst = stmt->dest->cast<ExternalPtrStmt>();
55 ArgLoadStmt *arg = dst->base_ptr->cast<ArgLoadStmt>();
56 if (map_.find(arg->arg_id) != map_.end()) {
57 map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::WRITE;
58 } else {
59 map_[arg->arg_id] = ExternalPtrAccess::WRITE;
60 }
61 }
62
63 void visit(AtomicOpStmt *stmt) override {
64 if (!(stmt->dest && stmt->dest->is<ExternalPtrStmt>()))
65 return;
66
67 // Atomics modifies existing state (therefore both read & write)
68 ExternalPtrStmt *dst = stmt->dest->cast<ExternalPtrStmt>();
69 ArgLoadStmt *arg = dst->base_ptr->cast<ArgLoadStmt>();
70 map_[arg->arg_id] = ExternalPtrAccess::WRITE | ExternalPtrAccess::READ;
71 }
72};
73
74} // namespace
75
76void detect_read_only(IRNode *root) {
77 if (root->is<Block>()) {
78 for (auto &offload : root->as<Block>()->statements) {
79 detect_read_only_in_task(offload->as<OffloadedStmt>());
80 }
81 } else {
82 detect_read_only_in_task(root->as<OffloadedStmt>());
83 }
84}
85
86std::unordered_map<int, ExternalPtrAccess> detect_external_ptr_access_in_task(
87 OffloadedStmt *offload) {
88 std::unordered_map<int, ExternalPtrAccess> map;
89 ExternalPtrAccessVisitor v(map);
90 offload->accept(&v);
91 return map;
92}
93
94} // namespace irpass
95
96} // namespace taichi::lang
97