1 | #pragma once |
2 | |
#include "taichi/program/compile_config.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/mesh.h"

#include <map>
#include <set>
#include <tuple>
#include <utility>
9 | |
10 | namespace taichi::lang { |
11 | |
12 | class MeshBLSCache { |
13 | public: |
14 | using AccessFlag = taichi::lang::AccessFlag; |
15 | using Rec = std::map<std::pair<mesh::MeshElementType, mesh::ConvType>, |
16 | std::set<std::pair<SNode *, AccessFlag>>>; |
17 | |
18 | SNode *snode{nullptr}; |
19 | mesh::MeshElementType element_type; |
20 | mesh::ConvType conv_type; |
21 | |
22 | bool initialized; |
23 | bool finalized; |
24 | bool loop_index; |
25 | int unique_accessed; |
26 | AccessFlag total_flags; |
27 | |
28 | MeshBLSCache() = default; |
29 | |
30 | explicit MeshBLSCache(SNode *snode) : snode(snode) { |
31 | total_flags = AccessFlag(0); |
32 | initialized = false; |
33 | finalized = false; |
34 | loop_index = false; |
35 | unique_accessed = 0; |
36 | } |
37 | |
38 | bool access(mesh::MeshElementType element_type, |
39 | mesh::ConvType conv_type, |
40 | AccessFlag flags, |
41 | Stmt *idx) { |
42 | if (!initialized) { |
43 | initialized = true; |
44 | this->conv_type = conv_type; |
45 | this->element_type = element_type; |
46 | } else { |
47 | if (this->conv_type != conv_type || this->element_type != element_type) |
48 | return false; |
49 | } |
50 | this->total_flags |= flags; |
51 | if (idx->is<LoopIndexStmt>()) { |
52 | loop_index = true; |
53 | } else { |
54 | unique_accessed++; |
55 | } |
56 | return true; |
57 | } |
58 | |
59 | void finalize(Rec &rec) { |
60 | TI_ASSERT(!finalized); |
61 | finalized = true; |
62 | if (initialized) { |
63 | const auto cache_type = std::make_pair(element_type, conv_type); |
64 | auto ptr = rec.find(cache_type); |
65 | if (ptr == rec.end()) { |
66 | ptr = rec.emplace(std::piecewise_construct, |
67 | std::forward_as_tuple(cache_type), |
68 | std::forward_as_tuple()) |
69 | .first; |
70 | } |
71 | ptr->second.insert(std::make_pair(snode, total_flags)); |
72 | } |
73 | } |
74 | }; |
75 | |
76 | class MeshBLSCaches { |
77 | public: |
78 | std::map<SNode *, MeshBLSCache> caches; |
79 | |
80 | using AccessFlag = MeshBLSCache::AccessFlag; |
81 | using Rec = MeshBLSCache::Rec; |
82 | |
83 | void insert(SNode *snode) { |
84 | if (caches.find(snode) == caches.end()) { |
85 | caches.emplace(std::piecewise_construct, std::forward_as_tuple(snode), |
86 | std::forward_as_tuple(snode)); |
87 | } else { |
88 | TI_ERROR("mesh::MeshBLSCaches for {} already exists." , |
89 | snode->node_type_name); |
90 | } |
91 | } |
92 | |
93 | bool access(SNode *snode, |
94 | mesh::MeshElementType element_type, |
95 | mesh::ConvType conv_type, |
96 | AccessFlag flags, |
97 | Stmt *idx) { |
98 | if (caches.find(snode) == caches.end()) |
99 | return false; |
100 | return caches.find(snode)->second.access(element_type, conv_type, flags, |
101 | idx); |
102 | } |
103 | |
104 | Rec finalize() { |
105 | Rec rec; |
106 | for (auto &cache : caches) { |
107 | cache.second.finalize(rec); |
108 | } |
109 | return rec; |
110 | } |
111 | |
112 | bool has(SNode *snode) { |
113 | return caches.find(snode) != caches.end(); |
114 | } |
115 | |
116 | MeshBLSCache &get(SNode *snode) { |
117 | TI_ASSERT(caches.find(snode) != caches.end()); |
118 | return caches[snode]; |
119 | } |
120 | }; |
121 | |
122 | // Figure out accessed SNodes, and their ranges in this for stmt |
123 | class MeshBLSAnalyzer : public BasicStmtVisitor { |
124 | using BasicStmtVisitor::visit; |
125 | |
126 | public: |
127 | MeshBLSAnalyzer(OffloadedStmt *for_stmt, |
128 | MeshBLSCaches *caches, |
129 | bool auto_mesh_local, |
130 | const CompileConfig &config); |
131 | |
132 | void visit(GlobalPtrStmt *stmt) override { |
133 | } |
134 | |
135 | // Do not eliminate global data access |
136 | void visit(GlobalLoadStmt *stmt) override; |
137 | |
138 | void visit(GlobalStoreStmt *stmt) override; |
139 | |
140 | void visit(AtomicOpStmt *stmt) override; |
141 | |
142 | void visit(Stmt *stmt) override; |
143 | |
144 | bool run(); |
145 | |
146 | private: |
147 | void record_access(Stmt *stmt, AccessFlag flag); |
148 | |
149 | OffloadedStmt *for_stmt_{nullptr}; |
150 | MeshBLSCaches *caches_{nullptr}; |
151 | bool analysis_ok_{true}; |
152 | bool auto_mesh_local_{false}; |
153 | CompileConfig config_; |
154 | }; |
155 | |
156 | } // namespace taichi::lang |
157 | |