1 | #include "taichi/transforms/loop_invariant_detector.h" |
2 | #include "taichi/ir/analysis.h" |
3 | |
4 | namespace taichi::lang { |
5 | |
6 | class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { |
7 | public: |
8 | using LoopInvariantDetector::visit; |
9 | |
10 | enum class CacheStatus { |
11 | None = 0, |
12 | Read = 1, |
13 | Write = 2, |
14 | ReadWrite = 3, |
15 | }; |
16 | |
17 | typedef std::unordered_map<Stmt *, std::pair<CacheStatus, AllocaStmt *>> |
18 | CacheMap; |
19 | std::vector<CacheMap> cached_maps; |
20 | |
21 | DelayedIRModifier modifier; |
22 | std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_; |
23 | std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_; |
24 | |
25 | OffloadedStmt *current_offloaded; |
26 | |
27 | explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) |
28 | : LoopInvariantDetector(config) { |
29 | } |
30 | |
31 | void visit(OffloadedStmt *stmt) override { |
32 | if (stmt->task_type == OffloadedTaskType::range_for || |
33 | stmt->task_type == OffloadedTaskType::mesh_for || |
34 | stmt->task_type == OffloadedTaskType::struct_for) { |
35 | auto uniquely_accessed_pointers = |
36 | irpass::analysis::gather_uniquely_accessed_pointers(stmt); |
37 | loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first); |
38 | loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second); |
39 | } |
40 | current_offloaded = stmt; |
41 | // We don't need to visit TLS/BLS prologues/epilogues. |
42 | if (stmt->body) { |
43 | if (stmt->task_type == OffloadedStmt::TaskType::range_for || |
44 | stmt->task_type == OffloadedTaskType::mesh_for || |
45 | stmt->task_type == OffloadedStmt::TaskType::struct_for) |
46 | visit_loop(stmt->body.get()); |
47 | else |
48 | stmt->body->accept(this); |
49 | } |
50 | current_offloaded = nullptr; |
51 | } |
52 | |
53 | bool is_offload_unique(Stmt *stmt) { |
54 | if (current_offloaded->task_type == OffloadedTaskType::serial) { |
55 | return true; |
56 | } |
57 | if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) { |
58 | auto snode = global_ptr->snode; |
59 | if (loop_unique_ptr_[snode] == nullptr || |
60 | loop_unique_ptr_[snode]->indices.empty()) { |
61 | // not uniquely accessed |
62 | return false; |
63 | } |
64 | if (current_offloaded->mem_access_opt.has_flag( |
65 | snode, SNodeAccessFlag::block_local) || |
66 | current_offloaded->mem_access_opt.has_flag( |
67 | snode, SNodeAccessFlag::mesh_local)) { |
68 | // BLS does not support write access yet so we keep atomic_adds. |
69 | return false; |
70 | } |
71 | return true; |
72 | } else if (stmt->is<ExternalPtrStmt>()) { |
73 | ExternalPtrStmt *dest_ptr = stmt->as<ExternalPtrStmt>(); |
74 | if (dest_ptr->indices.empty()) { |
75 | return false; |
76 | } |
77 | ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>(); |
78 | int arg_id = arg_load_stmt->arg_id; |
79 | if (loop_unique_arr_ptr_[arg_id] == nullptr) { |
80 | // Not loop unique |
81 | return false; |
82 | } |
83 | return true; |
84 | // TODO: Is BLS / Mem Access Opt a thing for any_arr? |
85 | } |
86 | return false; |
87 | } |
88 | |
89 | void visit_loop(Block *body) override { |
90 | cached_maps.emplace_back(); |
91 | LoopInvariantDetector::visit_loop(body); |
92 | cached_maps.pop_back(); |
93 | } |
94 | |
95 | void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var, int depth) { |
96 | auto final_value = std::make_unique<LocalLoadStmt>(alloca_stmt); |
97 | auto global_store = |
98 | std::make_unique<GlobalStoreStmt>(global_var, final_value.get()); |
99 | modifier.insert_after(get_loop_stmt(depth), std::move(global_store)); |
100 | modifier.insert_after(get_loop_stmt(depth), std::move(final_value)); |
101 | } |
102 | |
103 | void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var, int depth) { |
104 | auto new_global_load = std::make_unique<GlobalLoadStmt>(global_var); |
105 | auto local_store = |
106 | std::make_unique<LocalStoreStmt>(alloca_stmt, new_global_load.get()); |
107 | modifier.insert_before(get_loop_stmt(depth), std::move(new_global_load)); |
108 | modifier.insert_before(get_loop_stmt(depth), std::move(local_store)); |
109 | } |
110 | |
111 | AllocaStmt *cache_global_to_local(Stmt *dest, CacheStatus status, int depth) { |
112 | if (auto &[cached_status, alloca_stmt] = cached_maps[depth][dest]; |
113 | cached_status != CacheStatus::None) { |
114 | // The global variable has already been cached. |
115 | if (cached_status == CacheStatus::Read && status == CacheStatus::Write) { |
116 | add_writeback(alloca_stmt, dest, depth); |
117 | cached_status = CacheStatus::ReadWrite; |
118 | } |
119 | return alloca_stmt; |
120 | } |
121 | auto alloca_unique = |
122 | std::make_unique<AllocaStmt>(dest->ret_type.ptr_removed()); |
123 | auto alloca_stmt = alloca_unique.get(); |
124 | modifier.insert_before(get_loop_stmt(depth), std::move(alloca_unique)); |
125 | set_init_value(alloca_stmt, dest, depth); |
126 | if (status == CacheStatus::Write) { |
127 | add_writeback(alloca_stmt, dest, depth); |
128 | } |
129 | cached_maps[depth][dest] = {status, alloca_stmt}; |
130 | return alloca_stmt; |
131 | } |
132 | |
133 | std::optional<int> find_cache_depth_if_cacheable(Stmt *operand, |
134 | Block *current_scope) { |
135 | if (!is_offload_unique(operand)) { |
136 | return std::nullopt; |
137 | } |
138 | std::optional<int> depth; |
139 | for (int n = loop_blocks.size() - 1; n > 0; n--) { |
140 | if (is_operand_loop_invariant(operand, current_scope, n)) { |
141 | depth = n; |
142 | } else { |
143 | break; |
144 | } |
145 | } |
146 | return depth; |
147 | } |
148 | |
149 | void visit(GlobalLoadStmt *stmt) override { |
150 | if (auto depth = find_cache_depth_if_cacheable(stmt->src, stmt->parent)) { |
151 | auto alloca_stmt = |
152 | cache_global_to_local(stmt->src, CacheStatus::Read, depth.value()); |
153 | auto local_load = std::make_unique<LocalLoadStmt>(alloca_stmt); |
154 | stmt->replace_usages_with(local_load.get()); |
155 | modifier.insert_before(stmt, std::move(local_load)); |
156 | modifier.erase(stmt); |
157 | } |
158 | } |
159 | |
160 | void visit(GlobalStoreStmt *stmt) override { |
161 | if (auto depth = find_cache_depth_if_cacheable(stmt->dest, stmt->parent)) { |
162 | auto alloca_stmt = |
163 | cache_global_to_local(stmt->dest, CacheStatus::Write, depth.value()); |
164 | auto local_store = |
165 | std::make_unique<LocalStoreStmt>(alloca_stmt, stmt->val); |
166 | stmt->replace_usages_with(local_store.get()); |
167 | modifier.insert_before(stmt, std::move(local_store)); |
168 | modifier.erase(stmt); |
169 | } |
170 | } |
171 | |
172 | static bool run(IRNode *node, const CompileConfig &config) { |
173 | bool modified = false; |
174 | |
175 | while (true) { |
176 | CacheLoopInvariantGlobalVars eliminator(config); |
177 | node->accept(&eliminator); |
178 | if (eliminator.modifier.modify_ir()) |
179 | modified = true; |
180 | else |
181 | break; |
182 | }; |
183 | |
184 | return modified; |
185 | } |
186 | }; |
187 | |
188 | namespace irpass { |
189 | bool cache_loop_invariant_global_vars(IRNode *root, |
190 | const CompileConfig &config) { |
191 | TI_AUTO_PROF; |
192 | return CacheLoopInvariantGlobalVars::run(root, config); |
193 | } |
194 | } // namespace irpass |
195 | } // namespace taichi::lang |
196 | |