1#pragma once
2
#include "taichi/program/compile_config.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/mesh.h"

#include <map>
#include <set>
#include <tuple>
#include <utility>
9
10namespace taichi::lang {
11
12class MeshBLSCache {
13 public:
14 using AccessFlag = taichi::lang::AccessFlag;
15 using Rec = std::map<std::pair<mesh::MeshElementType, mesh::ConvType>,
16 std::set<std::pair<SNode *, AccessFlag>>>;
17
18 SNode *snode{nullptr};
19 mesh::MeshElementType element_type;
20 mesh::ConvType conv_type;
21
22 bool initialized;
23 bool finalized;
24 bool loop_index;
25 int unique_accessed;
26 AccessFlag total_flags;
27
28 MeshBLSCache() = default;
29
30 explicit MeshBLSCache(SNode *snode) : snode(snode) {
31 total_flags = AccessFlag(0);
32 initialized = false;
33 finalized = false;
34 loop_index = false;
35 unique_accessed = 0;
36 }
37
38 bool access(mesh::MeshElementType element_type,
39 mesh::ConvType conv_type,
40 AccessFlag flags,
41 Stmt *idx) {
42 if (!initialized) {
43 initialized = true;
44 this->conv_type = conv_type;
45 this->element_type = element_type;
46 } else {
47 if (this->conv_type != conv_type || this->element_type != element_type)
48 return false;
49 }
50 this->total_flags |= flags;
51 if (idx->is<LoopIndexStmt>()) {
52 loop_index = true;
53 } else {
54 unique_accessed++;
55 }
56 return true;
57 }
58
59 void finalize(Rec &rec) {
60 TI_ASSERT(!finalized);
61 finalized = true;
62 if (initialized) {
63 const auto cache_type = std::make_pair(element_type, conv_type);
64 auto ptr = rec.find(cache_type);
65 if (ptr == rec.end()) {
66 ptr = rec.emplace(std::piecewise_construct,
67 std::forward_as_tuple(cache_type),
68 std::forward_as_tuple())
69 .first;
70 }
71 ptr->second.insert(std::make_pair(snode, total_flags));
72 }
73 }
74};
75
76class MeshBLSCaches {
77 public:
78 std::map<SNode *, MeshBLSCache> caches;
79
80 using AccessFlag = MeshBLSCache::AccessFlag;
81 using Rec = MeshBLSCache::Rec;
82
83 void insert(SNode *snode) {
84 if (caches.find(snode) == caches.end()) {
85 caches.emplace(std::piecewise_construct, std::forward_as_tuple(snode),
86 std::forward_as_tuple(snode));
87 } else {
88 TI_ERROR("mesh::MeshBLSCaches for {} already exists.",
89 snode->node_type_name);
90 }
91 }
92
93 bool access(SNode *snode,
94 mesh::MeshElementType element_type,
95 mesh::ConvType conv_type,
96 AccessFlag flags,
97 Stmt *idx) {
98 if (caches.find(snode) == caches.end())
99 return false;
100 return caches.find(snode)->second.access(element_type, conv_type, flags,
101 idx);
102 }
103
104 Rec finalize() {
105 Rec rec;
106 for (auto &cache : caches) {
107 cache.second.finalize(rec);
108 }
109 return rec;
110 }
111
112 bool has(SNode *snode) {
113 return caches.find(snode) != caches.end();
114 }
115
116 MeshBLSCache &get(SNode *snode) {
117 TI_ASSERT(caches.find(snode) != caches.end());
118 return caches[snode];
119 }
120};
121
122// Figure out accessed SNodes, and their ranges in this for stmt
123class MeshBLSAnalyzer : public BasicStmtVisitor {
124 using BasicStmtVisitor::visit;
125
126 public:
127 MeshBLSAnalyzer(OffloadedStmt *for_stmt,
128 MeshBLSCaches *caches,
129 bool auto_mesh_local,
130 const CompileConfig &config);
131
132 void visit(GlobalPtrStmt *stmt) override {
133 }
134
135 // Do not eliminate global data access
136 void visit(GlobalLoadStmt *stmt) override;
137
138 void visit(GlobalStoreStmt *stmt) override;
139
140 void visit(AtomicOpStmt *stmt) override;
141
142 void visit(Stmt *stmt) override;
143
144 bool run();
145
146 private:
147 void record_access(Stmt *stmt, AccessFlag flag);
148
149 OffloadedStmt *for_stmt_{nullptr};
150 MeshBLSCaches *caches_{nullptr};
151 bool analysis_ok_{true};
152 bool auto_mesh_local_{false};
153 CompileConfig config_;
154};
155
156} // namespace taichi::lang
157