#include "taichi/transforms/loop_invariant_detector.h"
#include "taichi/ir/analysis.h"

namespace taichi::lang {

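// Hoists offload-unique, loop-invariant global accesses out of loops by
// caching them in local allocas. A rough sketch of the effect: for a
// parallel-for over i whose body contains an inner loop repeatedly
// reading/writing x[i],
//   - an alloca is created before the inner loop and initialized with x[i],
//   - loads/stores of x[i] inside the loop become local loads/stores,
//   - the cached value is written back to x[i] once after the loop.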
class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
 public:
  using LoopInvariantDetector::visit;

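  // How a cached global has been accessed at a given loop level. A Read
  // that is later joined by a Write is upgraded to ReadWrite.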
  enum class CacheStatus {
    None = 0,
    Read = 1,
    Write = 2,
    ReadWrite = 3,
  };

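  // Per loop level: maps a global pointer to its cache status and the
  // alloca that caches it.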
  using CacheMap =
      std::unordered_map<Stmt *, std::pair<CacheStatus, AllocaStmt *>>;
  std::vector<CacheMap> cached_maps;

  DelayedIRModifier modifier;
  std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_;
  std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_;

  OffloadedStmt *current_offloaded = nullptr;

  explicit CacheLoopInvariantGlobalVars(const CompileConfig &config)
      : LoopInvariantDetector(config) {
  }

  void visit(OffloadedStmt *stmt) override {
    if (stmt->task_type == OffloadedTaskType::range_for ||
        stmt->task_type == OffloadedTaskType::mesh_for ||
        stmt->task_type == OffloadedTaskType::struct_for) {
      auto uniquely_accessed_pointers =
          irpass::analysis::gather_uniquely_accessed_pointers(stmt);
      loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first);
      loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second);
    }
    current_offloaded = stmt;
    // We don't need to visit TLS/BLS prologues/epilogues.
    if (stmt->body) {
      if (stmt->task_type == OffloadedTaskType::range_for ||
          stmt->task_type == OffloadedTaskType::mesh_for ||
          stmt->task_type == OffloadedTaskType::struct_for) {
        visit_loop(stmt->body.get());
      } else {
        stmt->body->accept(this);
      }
    }
    current_offloaded = nullptr;
  }

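  // Whether `stmt` addresses memory that at most one iteration of the
  // current offload touches, so caching it locally cannot conflict with
  // other iterations. Serial offloads trivially qualify.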
  bool is_offload_unique(Stmt *stmt) {
    if (current_offloaded->task_type == OffloadedTaskType::serial) {
      return true;
    }
    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
      auto snode = global_ptr->snode;
      if (loop_unique_ptr_[snode] == nullptr ||
          loop_unique_ptr_[snode]->indices.empty()) {
        // Not uniquely accessed.
        return false;
      }
      if (current_offloaded->mem_access_opt.has_flag(
              snode, SNodeAccessFlag::block_local) ||
          current_offloaded->mem_access_opt.has_flag(
              snode, SNodeAccessFlag::mesh_local)) {
        // BLS does not support write access yet, so we keep atomic adds.
        return false;
      }
      return true;
    } else if (stmt->is<ExternalPtrStmt>()) {
      auto *dest_ptr = stmt->as<ExternalPtrStmt>();
      if (dest_ptr->indices.empty()) {
        return false;
      }
      auto *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>();
      int arg_id = arg_load_stmt->arg_id;
      if (loop_unique_arr_ptr_[arg_id] == nullptr) {
        // Not loop-unique.
        return false;
      }
      return true;
      // TODO: Is BLS / Mem Access Opt a thing for any_arr?
    }
    return false;
  }

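  // Each loop level gets its own cache map, pushed on entry and popped on
  // exit.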
  void visit_loop(Block *body) override {
    cached_maps.emplace_back();
    LoopInvariantDetector::visit_loop(body);
    cached_maps.pop_back();
  }

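  // After the loop at `depth`, load the cached value and store it back to
  // the global destination. Both statements are inserted right after the
  // loop; the load is queued second so that it ends up before the store.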
  void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var, int depth) {
    auto final_value = std::make_unique<LocalLoadStmt>(alloca_stmt);
    auto global_store =
        std::make_unique<GlobalStoreStmt>(global_var, final_value.get());
    modifier.insert_after(get_loop_stmt(depth), std::move(global_store));
    modifier.insert_after(get_loop_stmt(depth), std::move(final_value));
  }

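  // Before the loop at `depth`, load the global value once and store it
  // into the cache alloca.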
  void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var, int depth) {
    auto new_global_load = std::make_unique<GlobalLoadStmt>(global_var);
    auto local_store =
        std::make_unique<LocalStoreStmt>(alloca_stmt, new_global_load.get());
    modifier.insert_before(get_loop_stmt(depth), std::move(new_global_load));
    modifier.insert_before(get_loop_stmt(depth), std::move(local_store));
  }

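  // Returns the alloca caching `dest` at loop `depth`, creating it on first
  // use: the alloca is initialized before the loop, and writes get a
  // writeback after it. A read-only cache that later sees a write is
  // upgraded to ReadWrite by adding the missing writeback.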
  AllocaStmt *cache_global_to_local(Stmt *dest, CacheStatus status, int depth) {
    if (auto &[cached_status, alloca_stmt] = cached_maps[depth][dest];
        cached_status != CacheStatus::None) {
      // The global variable has already been cached.
      if (cached_status == CacheStatus::Read && status == CacheStatus::Write) {
        add_writeback(alloca_stmt, dest, depth);
        cached_status = CacheStatus::ReadWrite;
      }
      return alloca_stmt;
    }
    auto alloca_unique =
        std::make_unique<AllocaStmt>(dest->ret_type.ptr_removed());
    auto alloca_stmt = alloca_unique.get();
    modifier.insert_before(get_loop_stmt(depth), std::move(alloca_unique));
    set_init_value(alloca_stmt, dest, depth);
    if (status == CacheStatus::Write) {
      add_writeback(alloca_stmt, dest, depth);
    }
    cached_maps[depth][dest] = {status, alloca_stmt};
    return alloca_stmt;
  }

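  // Scans from the innermost loop outward and returns the outermost depth
  // at which `operand` remains loop-invariant, or std::nullopt if it is not
  // offload-unique or not invariant in any enclosing loop.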
  std::optional<int> find_cache_depth_if_cacheable(Stmt *operand,
                                                   Block *current_scope) {
    if (!is_offload_unique(operand)) {
      return std::nullopt;
    }
    std::optional<int> depth;
    for (int n = loop_blocks.size() - 1; n > 0; n--) {
      if (is_operand_loop_invariant(operand, current_scope, n)) {
        depth = n;
      } else {
        break;
      }
    }
    return depth;
  }

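  // Replace a cacheable global load with a load from the cache alloca.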
  void visit(GlobalLoadStmt *stmt) override {
    if (auto depth = find_cache_depth_if_cacheable(stmt->src, stmt->parent)) {
      auto alloca_stmt =
          cache_global_to_local(stmt->src, CacheStatus::Read, depth.value());
      auto local_load = std::make_unique<LocalLoadStmt>(alloca_stmt);
      stmt->replace_usages_with(local_load.get());
      modifier.insert_before(stmt, std::move(local_load));
      modifier.erase(stmt);
    }
  }

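  // Replace a cacheable global store with a local store into the cache
  // alloca; the value reaches global memory via the writeback after the
  // loop.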
  void visit(GlobalStoreStmt *stmt) override {
    if (auto depth = find_cache_depth_if_cacheable(stmt->dest, stmt->parent)) {
      auto alloca_stmt =
          cache_global_to_local(stmt->dest, CacheStatus::Write, depth.value());
      auto local_store =
          std::make_unique<LocalStoreStmt>(alloca_stmt, stmt->val);
      stmt->replace_usages_with(local_store.get());
      modifier.insert_before(stmt, std::move(local_store));
      modifier.erase(stmt);
    }
  }

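  // Apply the pass repeatedly until the IR reaches a fixed point.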
  static bool run(IRNode *node, const CompileConfig &config) {
    bool modified = false;

    while (true) {
      CacheLoopInvariantGlobalVars eliminator(config);
      node->accept(&eliminator);
      if (eliminator.modifier.modify_ir()) {
        modified = true;
      } else {
        break;
      }
    }

    return modified;
  }
};

namespace irpass {
bool cache_loop_invariant_global_vars(IRNode *root,
                                      const CompileConfig &config) {
  TI_AUTO_PROF;
  return CacheLoopInvariantGlobalVars::run(root, config);
}
}  // namespace irpass
}  // namespace taichi::lang