1 | #include "taichi/ir/control_flow_graph.h" |
2 | |
3 | #include <queue> |
4 | #include <unordered_set> |
5 | |
6 | #include "taichi/ir/analysis.h" |
7 | #include "taichi/ir/statements.h" |
8 | #include "taichi/system/profiler.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | CFGNode::CFGNode(Block *block, |
13 | int begin_location, |
14 | int end_location, |
15 | bool is_parallel_executed, |
16 | CFGNode *prev_node_in_same_block) |
17 | : block(block), |
18 | begin_location(begin_location), |
19 | end_location(end_location), |
20 | is_parallel_executed(is_parallel_executed), |
21 | prev_node_in_same_block(prev_node_in_same_block), |
22 | next_node_in_same_block(nullptr) { |
23 | if (prev_node_in_same_block != nullptr) |
24 | prev_node_in_same_block->next_node_in_same_block = this; |
25 | if (!empty()) { |
26 | // For non-empty nodes, precompute |parent_blocks| to accelerate |
27 | // get_store_forwarding_data(). |
28 | TI_ASSERT(begin_location >= 0); |
29 | TI_ASSERT(block); |
30 | auto parent_block = block; |
31 | parent_blocks_.insert(parent_block); |
32 | while (parent_block->parent_block()) { |
33 | parent_block = parent_block->parent_block(); |
34 | parent_blocks_.insert(parent_block); |
35 | } |
36 | } |
37 | } |
38 | |
// Construct a dummy empty node: no block, invalid statement range [-1, -1),
// not parallel-executed, and no predecessor in the same block.
CFGNode::CFGNode() : CFGNode(nullptr, -1, -1, false, nullptr) {
}
41 | |
42 | void CFGNode::add_edge(CFGNode *from, CFGNode *to) { |
43 | from->next.push_back(to); |
44 | to->prev.push_back(from); |
45 | } |
46 | |
47 | bool CFGNode::empty() const { |
48 | return begin_location >= end_location; |
49 | } |
50 | |
51 | std::size_t CFGNode::size() const { |
52 | return end_location - begin_location; |
53 | } |
54 | |
55 | void CFGNode::erase(int location) { |
56 | TI_ASSERT(location >= begin_location && location < end_location); |
57 | block->erase(location); |
58 | end_location--; |
59 | for (auto node = next_node_in_same_block; node != nullptr; |
60 | node = node->next_node_in_same_block) { |
61 | node->begin_location--; |
62 | node->end_location--; |
63 | } |
64 | } |
65 | |
66 | void CFGNode::insert(std::unique_ptr<Stmt> &&new_stmt, int location) { |
67 | TI_ASSERT(location >= begin_location && location <= end_location); |
68 | block->insert(std::move(new_stmt), location); |
69 | end_location++; |
70 | for (auto node = next_node_in_same_block; node != nullptr; |
71 | node = node->next_node_in_same_block) { |
72 | node->begin_location++; |
73 | node->end_location++; |
74 | } |
75 | } |
76 | |
77 | void CFGNode::replace_with(int location, |
78 | std::unique_ptr<Stmt> &&new_stmt, |
79 | bool replace_usages) const { |
80 | TI_ASSERT(location >= begin_location && location < end_location); |
81 | block->replace_with(block->statements[location].get(), std::move(new_stmt), |
82 | replace_usages); |
83 | } |
84 | |
85 | bool CFGNode::contain_variable(const std::unordered_set<Stmt *> &var_set, |
86 | Stmt *var) { |
87 | if (var->is<AllocaStmt>() || var->is<AdStackAllocaStmt>()) { |
88 | return var_set.find(var) != var_set.end(); |
89 | } else { |
90 | // TODO: How to optimize this? |
91 | if (var_set.find(var) != var_set.end()) |
92 | return true; |
93 | return std::any_of(var_set.begin(), var_set.end(), [&](Stmt *set_var) { |
94 | return irpass::analysis::definitely_same_address(var, set_var); |
95 | }); |
96 | } |
97 | } |
98 | |
99 | bool CFGNode::may_contain_variable(const std::unordered_set<Stmt *> &var_set, |
100 | Stmt *var) { |
101 | if (var->is<AllocaStmt>() || var->is<AdStackAllocaStmt>()) { |
102 | return var_set.find(var) != var_set.end(); |
103 | } else { |
104 | // TODO: How to optimize this? |
105 | if (var_set.find(var) != var_set.end()) |
106 | return true; |
107 | return std::any_of(var_set.begin(), var_set.end(), [&](Stmt *set_var) { |
108 | return irpass::analysis::maybe_same_address(var, set_var); |
109 | }); |
110 | } |
111 | } |
112 | |
bool CFGNode::reach_kill_variable(Stmt *var) const {
  // Does this node (definitely) kill a definition of var?
  // Delegates to contain_variable() on the precomputed |reach_kill| set
  // (filled by reaching_definition_analysis()).
  return contain_variable(reach_kill, var);
}
117 | |
Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
  // Return the stored data if all definitions in the UD-chain of |var| at
  // this position store the same data.
  //
  // Step 1: scan backwards inside this node for the nearest statement that
  // definitely stores to |var|'s address.
  int last_def_position = -1;
  for (int i = position - 1; i >= begin_location; i--) {
    // A function call may read or write anything; give up.
    if (block->statements[i]->is<FuncCallStmt>()) {
      return nullptr;
    }
    for (auto store_ptr :
         irpass::analysis::get_store_destination(block->statements[i].get())) {
      if (irpass::analysis::definitely_same_address(var, store_ptr)) {
        last_def_position = i;
        break;
      }
    }
    if (last_def_position != -1) {
      break;
    }
  }
  // Whether |store_stmt| may write to the same address as |var|.
  auto may_contain_address = [](Stmt *store_stmt, Stmt *var) {
    for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) {
      if (irpass::analysis::maybe_same_address(var, store_ptr)) {
        return true;
      }
    }
    return false;
  };
  if (last_def_position != -1) {
    // The UD-chain is inside this node.
    Stmt *result = irpass::analysis::get_store_data(
        block->statements[last_def_position].get());
    if (!var->is<AllocaStmt>()) {
      // For non-local variables, make sure no statement between the
      // definition and |position| may overwrite the address with a
      // different value.
      for (int i = last_def_position + 1; i < position; i++) {
        if (!irpass::analysis::same_value(
                result,
                irpass::analysis::get_store_data(block->statements[i].get()))) {
          if (may_contain_address(block->statements[i].get(), var)) {
            return nullptr;
          }
        }
      }
    }
    return result;
  }
  // Step 2: the definition comes from outside this node. Combine all
  // reaching definitions of |var| and require them to agree on the value.
  Stmt *result = nullptr;
  bool result_visible = false;
  auto visible = [&](Stmt *stmt) {
    // Check if |stmt| is before |position| here.
    if (stmt->parent == block) {
      return stmt->parent->locate(stmt) < position;
    }
    // |parent_blocks| is precomputed in the constructor of CFGNode.
    // TODO: What if |stmt| appears in an ancestor of |block| but after
    // |position|?
    return parent_blocks_.find(stmt->parent) != parent_blocks_.end();
  };
  /**
   * |stmt| is a definition in the UD-chain of |var|. Update |result| with
   * |stmt|. If either the stored data of |stmt| is not forwardable or the
   * stored data of |stmt| is not definitely the same as other definitions of
   * |var|, return false to show that there is no store-to-load forwardable
   * data.
   */
  auto update_result = [&](Stmt *stmt) {
    auto data = irpass::analysis::get_store_data(stmt);
    if (!data) {  // not forwardable
      return false;  // return nullptr
    }
    if (!result) {
      result = data;
      result_visible = visible(data);
      return true;  // continue the following loops
    }
    if (!irpass::analysis::same_value(result, data)) {
      // check the special case of alloca (initialized to 0)
      if (!(result->is<AllocaStmt>() && data->is<ConstStmt>() &&
            data->as<ConstStmt>()->val.equal_value(0))) {
        return false;  // return nullptr
      }
    }
    if (!result_visible && visible(data)) {
      // pick the visible one for store-to-load forwarding
      result = data;
      result_visible = true;
    }
    return true;  // continue the following loops
  };
  for (auto stmt : reach_in) {
    // var == stmt is for the case that a global ptr is never stored.
    // In this case, stmt is from nodes[start_node]->reach_gen.
    if (var == stmt || may_contain_address(stmt, var)) {
      if (!update_result(stmt))
        return nullptr;
    }
  }
  for (auto stmt : reach_gen) {
    // Definitions generated within this node count only if they precede
    // |position|.
    if (may_contain_address(stmt, var) &&
        stmt->parent->locate(stmt) < position) {
      if (!update_result(stmt))
        return nullptr;
    }
  }
  if (!result) {
    // The UD-chain is empty.
    TI_WARN("stmt {} loaded in stmt {} before storing.", var->id,
            block->statements[position]->id);
    return nullptr;
  }
  if (!result_visible) {
    // The data is store-to-load forwardable but not visible at the place we
    // are going to forward. We cannot forward it in this case.
    return nullptr;
  }
  return result;
}
233 | |
void CFGNode::reaching_definition_analysis(bool after_lower_access) {
  // Calculate |reach_gen| and |reach_kill|.
  reach_gen.clear();
  reach_kill.clear();
  for (int i = end_location - 1; i >= begin_location; i--) {
    // loop in reversed order, so that a later store to the same address
    // shadows earlier ones: only the last not-yet-killed definition of an
    // address enters |reach_gen|.
    auto stmt = block->statements[i].get();
    auto data_source_ptrs = irpass::analysis::get_store_destination(stmt);
    for (auto data_source_ptr : data_source_ptrs) {
      // stmt provides a data source
      if (after_lower_access && !(data_source_ptr->is<AllocaStmt>())) {
        // After lower_access, we only analyze local variables.
        continue;
      }
      if (!reach_kill_variable(data_source_ptr)) {
        reach_gen.insert(stmt);
        reach_kill.insert(data_source_ptr);
      }
    }
  }
}
255 | |
256 | bool CFGNode::store_to_load_forwarding(bool after_lower_access, |
257 | bool autodiff_enabled) { |
258 | bool modified = false; |
259 | for (int i = begin_location; i < end_location; i++) { |
260 | // Store-to-load forwarding |
261 | auto stmt = block->statements[i].get(); |
262 | Stmt *result = nullptr; |
263 | if (auto local_load = stmt->cast<LocalLoadStmt>()) { |
264 | result = get_store_forwarding_data(local_load->src, i); |
265 | } else if (auto global_load = stmt->cast<GlobalLoadStmt>()) { |
266 | if (!after_lower_access && !autodiff_enabled) { |
267 | result = get_store_forwarding_data(global_load->src, i); |
268 | } |
269 | } |
270 | if (result) { |
271 | // Forward the stored data |result|. |
272 | if (result->is<AllocaStmt>()) { |
273 | // special case of alloca (initialized to 0) |
274 | auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0)); |
275 | replace_with(i, std::move(zero), true); |
276 | } else { |
277 | stmt->replace_usages_with(result); |
278 | erase(i); // This causes end_location-- |
279 | i--; // to cancel i++ in the for loop |
280 | modified = true; |
281 | } |
282 | continue; |
283 | } |
284 | |
285 | // Identical store elimination |
286 | if (auto local_store = stmt->cast<LocalStoreStmt>()) { |
287 | result = get_store_forwarding_data(local_store->dest, i); |
288 | if (result && result->is<AllocaStmt>() && !autodiff_enabled) { |
289 | // special case of alloca (initialized to 0) |
290 | if (auto stored_data = local_store->val->cast<ConstStmt>()) { |
291 | if (stored_data->val.equal_value(0)) { |
292 | erase(i); // This causes end_location-- |
293 | i--; // to cancel i++ in the for loop |
294 | modified = true; |
295 | } |
296 | } |
297 | } else { |
298 | // not alloca |
299 | if (irpass::analysis::same_value(result, local_store->val)) { |
300 | erase(i); // This causes end_location-- |
301 | i--; // to cancel i++ in the for loop |
302 | modified = true; |
303 | } |
304 | } |
305 | } else if (auto global_store = stmt->cast<GlobalStoreStmt>()) { |
306 | if (!after_lower_access) { |
307 | result = get_store_forwarding_data(global_store->dest, i); |
308 | if (irpass::analysis::same_value(result, global_store->val)) { |
309 | erase(i); // This causes end_location-- |
310 | i--; // to cancel i++ in the for loop |
311 | modified = true; |
312 | } |
313 | } |
314 | } |
315 | } |
316 | return modified; |
317 | } |
318 | |
319 | void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const { |
320 | // Gather the SNodes which this CFGNode loads. |
321 | // Requires reaching definition analysis. |
322 | std::unordered_set<Stmt *> killed_in_this_node; |
323 | for (int i = begin_location; i < end_location; i++) { |
324 | auto stmt = block->statements[i].get(); |
325 | auto load_ptrs = irpass::analysis::get_load_pointers(stmt); |
326 | for (auto &load_ptr : load_ptrs) { |
327 | if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) { |
328 | // Avoid computing the UD-chain if every SNode in this global ptr |
329 | // are already loaded because it can be time-consuming. |
330 | auto snode = global_ptr->snode; |
331 | if (snodes.count(snode) > 0) { |
332 | continue; |
333 | } |
334 | if (reach_in.find(global_ptr) != reach_in.end() && |
335 | !contain_variable(killed_in_this_node, global_ptr)) { |
336 | // The UD-chain contains the value before this offloaded task. |
337 | snodes.insert(snode); |
338 | } |
339 | } |
340 | } |
341 | auto store_ptrs = irpass::analysis::get_store_destination(stmt); |
342 | for (auto &store_ptr : store_ptrs) { |
343 | if (store_ptr->is<GlobalPtrStmt>()) { |
344 | killed_in_this_node.insert(store_ptr); |
345 | } |
346 | } |
347 | } |
348 | } |
349 | |
void CFGNode::live_variable_analysis(bool after_lower_access) {
  // Calculate |live_gen| and |live_kill|, the per-node building blocks of
  // the backward live-variable dataflow analysis.
  live_gen.clear();
  live_kill.clear();
  for (int i = begin_location; i < end_location; i++) {
    auto stmt = block->statements[i].get();
    auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
    for (auto &load_ptr : load_ptrs) {
      if (!after_lower_access ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        // A load generates liveness only when no earlier statement in this
        // node definitely stored to the same variable.
        if (!contain_variable(live_kill, load_ptr)) {
          live_gen.insert(load_ptr);
        }
      }
    }
    auto store_ptrs = irpass::analysis::get_store_destination(stmt);
    // TODO: Consider AD-stacks in get_store_destination instead of here
    // for store-to-load forwarding on AD-stacks
    // TODO: SNode deactivation is also a definite store
    if (auto stack_pop = stmt->cast<AdStackPopStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_pop->stack);
    } else if (auto stack_push = stmt->cast<AdStackPushStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_push->stack);
    } else if (auto stack_acc_adj = stmt->cast<AdStackAccAdjointStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_acc_adj->stack);
    }
    for (auto store_ptr : store_ptrs) {
      if (!after_lower_access ||
          (store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        live_kill.insert(store_ptr);
      }
    }
  }
}
385 | |
bool CFGNode::dead_store_elimination(bool after_lower_access) {
  // Backward scan over this node: eliminate stores whose value is never read
  // afterwards, weaken dead atomic operations to plain loads, and merge
  // identical adjacent loads. Requires live variable analysis (|live_out|).
  bool modified = false;
  // Variables read later in this node (below the current statement).
  std::unordered_set<Stmt *> live_in_this_node;
  // Variables definitely overwritten later in this node before any read.
  std::unordered_set<Stmt *> killed_in_this_node;
  // Map a variable to its nearest load
  std::unordered_map<Stmt *, Stmt *> live_load_in_this_node;
  for (int i = end_location - 1; i >= begin_location; i--) {
    auto stmt = block->statements[i].get();
    if (stmt->is<FuncCallStmt>()) {
      // A call may read or write anything: conservatively forget the kill
      // and pending-load information gathered so far.
      killed_in_this_node.clear();
      live_load_in_this_node.clear();
    }
    auto store_ptrs = irpass::analysis::get_store_destination(stmt);
    // TODO: Consider AD-stacks in get_store_destination instead of here
    // for store-to-load forwarding on AD-stacks
    if (auto stack_pop = stmt->cast<AdStackPopStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_pop->stack);
    } else if (auto stack_push = stmt->cast<AdStackPushStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_push->stack);
    } else if (auto stack_acc_adj = stmt->cast<AdStackAccAdjointStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_acc_adj->stack);
    } else if (stmt->is<AdStackAllocaStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stmt);
    }
    if (store_ptrs.size() == 1) {
      // Dead store elimination
      auto store_ptr = *store_ptrs.begin();
      if (!after_lower_access ||
          (store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        // Do not eliminate AllocaStmt and AdStackAllocaStmt here.
        if (!stmt->is<AllocaStmt>() && !stmt->is<AdStackAllocaStmt>() &&
            !stmt->is<ExternalFuncCallStmt>() &&
            !may_contain_variable(live_in_this_node, store_ptr) &&
            (contain_variable(killed_in_this_node, store_ptr) ||
             !may_contain_variable(live_out, store_ptr))) {
          // Neither used in other nodes nor used in this node.
          if (!stmt->is<AtomicOpStmt>()) {
            // Eliminate the dead store.
            erase(i);
            modified = true;
            continue;
          }
          auto atomic = stmt->cast<AtomicOpStmt>();
          // Weaken the atomic operation to a load.
          if (atomic->dest->is<AllocaStmt>()) {
            auto local_load = Stmt::make<LocalLoadStmt>(atomic->dest);
            local_load->ret_type = atomic->ret_type;
            // Notice that we have a load here
            // (the return value of AtomicOpStmt).
            live_in_this_node.insert(atomic->dest);
            live_load_in_this_node[atomic->dest] = local_load.get();
            killed_in_this_node.erase(atomic->dest);
            replace_with(i, std::move(local_load), true);
            modified = true;
            continue;
          } else if (!is_parallel_executed ||
                     (atomic->dest->is<GlobalPtrStmt>() &&
                      atomic->dest->as<GlobalPtrStmt>()->snode->is_scalar())) {
            // If this node is parallel executed, we can't weaken a global
            // atomic operation to a global load.
            // TODO: we can weaken it if it's element-wise (i.e. never
            // accessed by other threads).
            auto global_load = Stmt::make<GlobalLoadStmt>(atomic->dest);
            global_load->ret_type = atomic->ret_type;
            // Notice that we have a load here
            // (the return value of AtomicOpStmt).
            live_in_this_node.insert(atomic->dest);
            live_load_in_this_node[atomic->dest] = global_load.get();
            killed_in_this_node.erase(atomic->dest);
            replace_with(i, std::move(global_load), true);
            modified = true;
            continue;
          }
        } else {
          // A non-eliminated store.
          killed_in_this_node.insert(store_ptr);
          // The store overwrites |store_ptr|: anything at definitely the
          // same address is no longer live above this point.
          auto old_live_in_this_node = std::move(live_in_this_node);
          live_in_this_node.clear();
          for (auto &var : old_live_in_this_node) {
            if (!irpass::analysis::definitely_same_address(store_ptr, var))
              live_in_this_node.insert(var);
          }
        }
      }
    }
    auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
    if (load_ptrs.size() == 1 && store_ptrs.empty()) {
      // Identical load elimination
      auto load_ptr = load_ptrs.begin()[0];
      if (!after_lower_access ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        if (live_load_in_this_node.find(load_ptr) !=
                live_load_in_this_node.end() &&
            !may_contain_variable(killed_in_this_node, load_ptr)) {
          // Only perform identical load elimination within a CFGNode.
          auto next_load_stmt = live_load_in_this_node[load_ptr];
          TI_ASSERT(irpass::analysis::same_statements(stmt, next_load_stmt));
          next_load_stmt->replace_usages_with(stmt);
          erase(block->locate(next_load_stmt));
          modified = true;
        }
        live_load_in_this_node[load_ptr] = stmt;
        killed_in_this_node.erase(load_ptr);
      }
    }
    // Every load makes its pointer live above this statement.
    for (auto &load_ptr : load_ptrs) {
      if (!after_lower_access ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        live_in_this_node.insert(load_ptr);
      }
    }
  }
  return modified;
}
503 | |
504 | void ControlFlowGraph::erase(int node_id) { |
505 | // Erase an empty node. |
506 | TI_ASSERT(node_id >= 0 && node_id < (int)size()); |
507 | TI_ASSERT(nodes[node_id] && nodes[node_id]->empty()); |
508 | if (nodes[node_id]->prev_node_in_same_block) { |
509 | nodes[node_id]->prev_node_in_same_block->next_node_in_same_block = |
510 | nodes[node_id]->next_node_in_same_block; |
511 | } |
512 | if (nodes[node_id]->next_node_in_same_block) { |
513 | nodes[node_id]->next_node_in_same_block->prev_node_in_same_block = |
514 | nodes[node_id]->prev_node_in_same_block; |
515 | } |
516 | for (auto &prev_node : nodes[node_id]->prev) { |
517 | prev_node->next.erase(std::find( |
518 | prev_node->next.begin(), prev_node->next.end(), nodes[node_id].get())); |
519 | } |
520 | for (auto &next_node : nodes[node_id]->next) { |
521 | next_node->prev.erase(std::find( |
522 | next_node->prev.begin(), next_node->prev.end(), nodes[node_id].get())); |
523 | } |
524 | for (auto &prev_node : nodes[node_id]->prev) { |
525 | for (auto &next_node : nodes[node_id]->next) { |
526 | CFGNode::add_edge(prev_node, next_node); |
527 | } |
528 | } |
529 | nodes[node_id].reset(); |
530 | } |
531 | |
// Number of node slots in the graph. Slots erased by erase() hold nullptr
// until simplify_graph() compacts them.
std::size_t ControlFlowGraph::size() const {
  return nodes.size();
}
535 | |
// The last node in |nodes|.
CFGNode *ControlFlowGraph::back() const {
  return nodes.back().get();
}
539 | |
540 | void ControlFlowGraph::print_graph_structure() const { |
541 | const int num_nodes = size(); |
542 | std::cout << "Control Flow Graph with " << num_nodes |
543 | << " nodes:" << std::endl; |
544 | std::unordered_map<CFGNode *, int> to_index; |
545 | for (int i = 0; i < num_nodes; i++) { |
546 | to_index[nodes[i].get()] = i; |
547 | } |
548 | for (int i = 0; i < num_nodes; i++) { |
549 | std::string node_info = fmt::format("Node {} : " , i); |
550 | if (nodes[i]->empty()) { |
551 | node_info += "empty" ; |
552 | } else { |
553 | node_info += fmt::format( |
554 | "{}~{} (size={})" , |
555 | nodes[i]->block->statements[nodes[i]->begin_location]->name(), |
556 | nodes[i]->block->statements[nodes[i]->end_location - 1]->name(), |
557 | nodes[i]->size()); |
558 | } |
559 | if (!nodes[i]->prev.empty()) { |
560 | std::vector<std::string> indices; |
561 | for (auto prev_node : nodes[i]->prev) { |
562 | indices.push_back(std::to_string(to_index[prev_node])); |
563 | } |
564 | node_info += fmt::format("; prev={{{}}}" , fmt::join(indices, ", " )); |
565 | } |
566 | if (!nodes[i]->next.empty()) { |
567 | std::vector<std::string> indices; |
568 | for (auto next_node : nodes[i]->next) { |
569 | indices.push_back(std::to_string(to_index[next_node])); |
570 | } |
571 | node_info += fmt::format("; next={{{}}}" , fmt::join(indices, ", " )); |
572 | } |
573 | if (!nodes[i]->reach_in.empty()) { |
574 | std::vector<std::string> indices; |
575 | for (auto stmt : nodes[i]->reach_in) { |
576 | indices.push_back(stmt->name()); |
577 | } |
578 | node_info += fmt::format("; reach_in={{{}}}" , fmt::join(indices, ", " )); |
579 | } |
580 | if (!nodes[i]->reach_out.empty()) { |
581 | std::vector<std::string> indices; |
582 | for (auto stmt : nodes[i]->reach_out) { |
583 | indices.push_back(stmt->name()); |
584 | } |
585 | node_info += fmt::format("; reach_out={{{}}}" , fmt::join(indices, ", " )); |
586 | } |
587 | std::cout << node_info << std::endl; |
588 | } |
589 | } |
590 | |
void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
  // Whole-graph forward dataflow: compute |reach_in|/|reach_out| of every
  // node with the classic worklist algorithm.
  TI_AUTO_PROF;
  const int num_nodes = size();
  std::queue<CFGNode *> to_visit;
  std::unordered_map<CFGNode *, bool> in_queue;
  TI_ASSERT(nodes[start_node]->empty());
  nodes[start_node]->reach_gen.clear();
  nodes[start_node]->reach_kill.clear();
  // Seed the start node's |reach_gen| with every pointer that may carry
  // data from before this kernel.
  for (int i = 0; i < num_nodes; i++) {
    for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
      auto stmt = nodes[i]->block->statements[j].get();
      if ((stmt->is<MatrixPtrStmt>() &&
           stmt->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
          (!after_lower_access &&
           (stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
            stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
            stmt->is<GlobalTemporaryStmt>() || stmt->is<MatrixPtrStmt>() ||
            stmt->is<GetChStmt>()))) {
        // TODO: unify them
        // A global pointer that may contain some data before this kernel.
        nodes[start_node]->reach_gen.insert(stmt);
      }
    }
  }
  // Initialize each node's local gen/kill sets and enqueue every node.
  for (int i = 0; i < num_nodes; i++) {
    if (i != start_node) {
      nodes[i]->reaching_definition_analysis(after_lower_access);
    }
    nodes[i]->reach_in.clear();
    nodes[i]->reach_out = nodes[i]->reach_gen;
    to_visit.push(nodes[i].get());
    in_queue[nodes[i].get()] = true;
  }

  // The worklist algorithm.
  while (!to_visit.empty()) {
    auto now = to_visit.front();
    to_visit.pop();
    in_queue[now] = false;

    // reach_in = union of the predecessors' reach_out.
    now->reach_in.clear();
    for (auto prev_node : now->prev) {
      now->reach_in.insert(prev_node->reach_out.begin(),
                           prev_node->reach_out.end());
    }
    // reach_out = reach_gen + (reach_in minus definitions killed here).
    auto old_out = std::move(now->reach_out);
    now->reach_out = now->reach_gen;
    for (auto stmt : now->reach_in) {
      auto store_ptrs = irpass::analysis::get_store_destination(stmt);
      bool killed;
      if (store_ptrs.empty()) {  // the case of a global pointer
        killed = now->reach_kill_variable(stmt);
      } else {
        // The definition survives if at least one of its destinations is
        // not killed in this node.
        killed = true;
        for (auto store_ptr : store_ptrs) {
          if (!now->reach_kill_variable(store_ptr)) {
            killed = false;
            break;
          }
        }
      }
      if (!killed) {
        now->reach_out.insert(stmt);
      }
    }
    if (now->reach_out != old_out) {
      // changed: the successors must be recomputed.
      for (auto next_node : now->next) {
        if (!in_queue[next_node]) {
          to_visit.push(next_node);
          in_queue[next_node] = true;
        }
      }
    }
  }
}
667 | |
void ControlFlowGraph::live_variable_analysis(
    bool after_lower_access,
    const std::optional<LiveVarAnalysisConfig> &config_opt) {
  // Whole-graph backward dataflow: compute |live_in|/|live_out| of every
  // node with the worklist algorithm.
  TI_AUTO_PROF;
  const int num_nodes = size();
  std::queue<CFGNode *> to_visit;
  std::unordered_map<CFGNode *, bool> in_queue;
  TI_ASSERT(nodes[final_node]->empty());
  nodes[final_node]->live_gen.clear();
  nodes[final_node]->live_kill.clear();

  // Whether a store destination should be considered live after the kernel
  // ends, i.e. seeded into the final node's |live_gen|.
  auto in_final_node_live_gen = [&config_opt](const Stmt *stmt) -> bool {
    if (stmt->is<AllocaStmt>() || stmt->is<AdStackAllocaStmt>()) {
      // Local variables and AD-stacks cannot outlive the kernel.
      return false;
    }
    if (stmt->is<MatrixPtrStmt>() &&
        stmt->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>()) {
      // A pointer into a local matrix is also local.
      return false;
    }
    if (auto *gptr = stmt->cast<GlobalPtrStmt>();
        gptr && config_opt.has_value()) {
      // SNodes listed as eliminable are known not to be read afterwards.
      const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0);
      return res;
    }
    // A global pointer that may be loaded after this kernel.
    return true;
  };
  if (!after_lower_access) {
    for (int i = 0; i < num_nodes; i++) {
      for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
        auto stmt = nodes[i]->block->statements[j].get();
        for (auto store_ptr : irpass::analysis::get_store_destination(stmt)) {
          if (in_final_node_live_gen(store_ptr)) {
            nodes[final_node]->live_gen.insert(store_ptr);
          }
        }
      }
    }
  }
  // Initialize each node's local gen/kill sets and enqueue every node.
  for (int i = num_nodes - 1; i >= 0; i--) {
    // push into the queue in reversed order to make it slightly faster
    if (i != final_node) {
      nodes[i]->live_variable_analysis(after_lower_access);
    }
    nodes[i]->live_out.clear();
    nodes[i]->live_in = nodes[i]->live_gen;
    to_visit.push(nodes[i].get());
    in_queue[nodes[i].get()] = true;
  }

  // The worklist algorithm.
  while (!to_visit.empty()) {
    auto now = to_visit.front();
    to_visit.pop();
    in_queue[now] = false;

    // live_out = union of the successors' live_in.
    now->live_out.clear();
    for (auto next_node : now->next) {
      now->live_out.insert(next_node->live_in.begin(),
                           next_node->live_in.end());
    }
    // live_in = live_gen + (live_out minus variables killed here).
    auto old_in = std::move(now->live_in);
    now->live_in = now->live_gen;
    for (auto stmt : now->live_out) {
      if (!CFGNode::contain_variable(now->live_kill, stmt)) {
        now->live_in.insert(stmt);
      }
    }
    if (now->live_in != old_in) {
      // changed: the predecessors must be recomputed.
      for (auto prev_node : now->prev) {
        if (!in_queue[prev_node]) {
          to_visit.push(prev_node);
          in_queue[prev_node] = true;
        }
      }
    }
  }
}
747 | |
void ControlFlowGraph::simplify_graph() {
  // Simplify the graph structure, do not modify the IR.
  const int num_nodes = size();
  // Repeatedly erase eligible empty nodes until a fixed point is reached.
  while (true) {
    bool modified = false;
    for (int i = 0; i < num_nodes; i++) {
      // If a node is empty with in-degree or out-degree <= 1, we can eliminate
      // it (except for the start node and the final node).
      if (nodes[i] && nodes[i]->empty() && i != start_node && i != final_node &&
          (nodes[i]->prev.size() <= 1 || nodes[i]->next.size() <= 1)) {
        erase(i);
        modified = true;
      }
    }
    if (!modified)
      break;
  }
  // Compact |nodes| by removing the null slots left behind by erase().
  int new_num_nodes = 0;
  for (int i = 0; i < num_nodes; i++) {
    if (nodes[i]) {
      if (i != new_num_nodes) {
        nodes[new_num_nodes] = std::move(nodes[i]);
      }
      // NOTE(review): |final_node| is remapped here but |start_node| is not.
      // This looks safe only if |start_node| is always index 0 (so no slot
      // before it can be erased) -- confirm against the CFG builder.
      if (final_node == i) {
        final_node = new_num_nodes;
      }
      new_num_nodes++;
    }
  }
  nodes.resize(new_num_nodes);
}
779 | |
780 | bool ControlFlowGraph::unreachable_code_elimination() { |
781 | // Note that container statements are not in the control-flow graph, so |
782 | // this pass cannot eliminate container statements properly for now. |
783 | TI_AUTO_PROF; |
784 | std::unordered_set<CFGNode *> visited; |
785 | std::queue<CFGNode *> to_visit; |
786 | to_visit.push(nodes[start_node].get()); |
787 | visited.insert(nodes[start_node].get()); |
788 | // Breadth-first search |
789 | while (!to_visit.empty()) { |
790 | auto now = to_visit.front(); |
791 | to_visit.pop(); |
792 | for (auto &next : now->next) { |
793 | if (visited.find(next) == visited.end()) { |
794 | to_visit.push(next); |
795 | visited.insert(next); |
796 | } |
797 | } |
798 | } |
799 | bool modified = false; |
800 | for (auto &node : nodes) { |
801 | if (visited.find(node.get()) == visited.end()) { |
802 | // unreachable |
803 | if (!node->empty()) { |
804 | while (!node->empty()) |
805 | node->erase(node->end_location - 1); |
806 | modified = true; |
807 | } |
808 | } |
809 | } |
810 | return modified; |
811 | } |
812 | |
813 | bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, |
814 | bool autodiff_enabled) { |
815 | TI_AUTO_PROF; |
816 | reaching_definition_analysis(after_lower_access); |
817 | const int num_nodes = size(); |
818 | bool modified = false; |
819 | for (int i = 0; i < num_nodes; i++) { |
820 | if (nodes[i]->store_to_load_forwarding(after_lower_access, |
821 | autodiff_enabled)) |
822 | modified = true; |
823 | } |
824 | return modified; |
825 | } |
826 | |
827 | bool ControlFlowGraph::dead_store_elimination( |
828 | bool after_lower_access, |
829 | const std::optional<LiveVarAnalysisConfig> &lva_config_opt) { |
830 | TI_AUTO_PROF; |
831 | live_variable_analysis(after_lower_access, lva_config_opt); |
832 | const int num_nodes = size(); |
833 | bool modified = false; |
834 | for (int i = 0; i < num_nodes; i++) { |
835 | if (nodes[i]->dead_store_elimination(after_lower_access)) |
836 | modified = true; |
837 | } |
838 | return modified; |
839 | } |
840 | |
841 | std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() { |
842 | TI_AUTO_PROF; |
843 | reaching_definition_analysis(/*after_lower_access=*/false); |
844 | const int num_nodes = size(); |
845 | std::unordered_set<SNode *> snodes; |
846 | |
847 | // Note: since global store may only partially modify a value state, the |
848 | // result (which contains the modified and unmodified part) actually needs a |
849 | // read from the previous version of the value state. |
850 | // |
851 | // I.e., |
852 | // output_value_state = merge(input_value_state, written_part) |
853 | // |
854 | // Therefore we include the nodes[final_node]->reach_in in snodes. |
855 | for (auto &stmt : nodes[final_node]->reach_in) { |
856 | if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) { |
857 | snodes.insert(global_ptr->snode); |
858 | } |
859 | } |
860 | |
861 | for (int i = 0; i < num_nodes; i++) { |
862 | if (i != final_node) { |
863 | nodes[i]->gather_loaded_snodes(snodes); |
864 | } |
865 | } |
866 | return snodes; |
867 | } |
868 | |
void ControlFlowGraph::determine_ad_stack_size(int default_ad_stack_size) {
  /**
   * Determine all adaptive AD-stacks' necessary size using the Bellman-Ford
   * algorithm. When there is a positive loop (#pushes > #pops in a loop)
   * for an AD-stack, we cannot determine the size of the AD-stack, and
   * |default_ad_stack_size| is used. The time complexity is
   * O(num_statements + num_stacks * num_edges * num_nodes).
   */
  const int num_nodes = size();

  // max_increased_size[i][j] is the maximum number of (pushes - pops) of
  // stack |i| among all prefixes of the CFGNode |j|.
  std::unordered_map<AdStackAllocaStmt *, std::vector<int>> max_increased_size;

  // increased_size[i][j] is the number of (pushes - pops) of stack |i| in
  // the CFGNode |j|.
  std::unordered_map<AdStackAllocaStmt *, std::vector<int>> increased_size;

  // Maps each CFGNode back to its index in |nodes|.
  std::unordered_map<CFGNode *, int> node_ids;
  // All AD-stack allocations found in the CFG.
  std::unordered_set<AdStackAllocaStmt *> all_stacks;
  // Stacks whose size could not be determined (positive loop found);
  // collected only for the debug message at the end.
  std::unordered_set<AdStackAllocaStmt *> indeterminable_stacks;

  for (int i = 0; i < num_nodes; i++)
    node_ids[nodes[i].get()] = i;

  // Discover every AD-stack allocation and initialize its per-node
  // size-delta arrays to zero.
  for (int i = 0; i < num_nodes; i++) {
    for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
      Stmt *stmt = nodes[i]->block->statements[j].get();
      if (auto *stack = stmt->cast<AdStackAllocaStmt>()) {
        all_stacks.insert(stack);
        max_increased_size.insert(
            std::make_pair(stack, std::vector<int>(num_nodes, 0)));
        increased_size.insert(
            std::make_pair(stack, std::vector<int>(num_nodes, 0)));
      }
    }
  }

  // For each basic block we compute the increase of stack size. This is a
  // pre-processing step for the next maximum stack size determining algorithm.
  // Only stacks with max_size == 0 (adaptive) are tracked; fixed-size stacks
  // are skipped.
  for (int i = 0; i < num_nodes; i++) {
    for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
      Stmt *stmt = nodes[i]->block->statements[j].get();
      if (auto *stack_push = stmt->cast<AdStackPushStmt>()) {
        auto *stack = stack_push->stack->as<AdStackAllocaStmt>();
        if (stack->max_size == 0 /*adaptive*/) {
          increased_size[stack][i]++;
          // Track the running maximum over all prefixes of this node.
          if (increased_size[stack][i] > max_increased_size[stack][i]) {
            max_increased_size[stack][i] = increased_size[stack][i];
          }
        }
      } else if (auto *stack_pop = stmt->cast<AdStackPopStmt>()) {
        auto *stack = stack_pop->stack->as<AdStackAllocaStmt>();
        if (stack->max_size == 0 /*adaptive*/) {
          increased_size[stack][i]--;
        }
      }
    }
  }

  // The maximum stack size determining algorithm -- run the Bellman-Ford
  // algorithm on each AD-stack separately.
  for (auto *stack : all_stacks) {
    // The maximum size of |stack| among all control flows starting at the
    // beginning of the IR.
    int max_size = 0;

    // max_size_at_node_begin[j] is the maximum size of |stack| among
    // all control flows starting at the beginning of the IR and ending at the
    // beginning of the CFGNode |j|. Initialize this array to -1 to make sure
    // that the first iteration of the Bellman-Ford algorithm fully updates
    // this array.
    std::vector<int> max_size_at_node_begin(num_nodes, -1);

    // The queue for the Bellman-Ford algorithm.
    std::queue<int> to_visit;

    // An optimization for the Bellman-Ford algorithm.
    std::vector<bool> in_queue(num_nodes);

    // An array for detecting positive loop in the Bellman-Ford algorithm.
    std::vector<int> times_pushed_in_queue(num_nodes, 0);

    max_size_at_node_begin[start_node] = 0;
    to_visit.push(start_node);
    in_queue[start_node] = true;
    times_pushed_in_queue[start_node]++;

    bool has_positive_loop = false;

    // The Bellman-Ford algorithm.
    while (!to_visit.empty()) {
      int node_id = to_visit.front();
      to_visit.pop();
      in_queue[node_id] = false;
      CFGNode *now = nodes[node_id].get();

      // Inside this CFGNode -- update the answer |max_size|
      // (peak size reached somewhere within this node).
      const auto max_size_inside_this_node = max_increased_size[stack][node_id];
      const auto current_max_size =
          max_size_at_node_begin[node_id] + max_size_inside_this_node;
      if (current_max_size > max_size) {
        max_size = current_max_size;
      }
      // At the end of this CFGNode -- update the state
      // |max_size_at_node_begin| of other CFGNodes
      const auto increase_in_this_node = increased_size[stack][node_id];
      const auto current_size =
          max_size_at_node_begin[node_id] + increase_in_this_node;
      for (auto *next_node : now->next) {
        int next_node_id = node_ids[next_node];
        // Relax the edge: propagate a larger starting size to the successor.
        if (current_size > max_size_at_node_begin[next_node_id]) {
          max_size_at_node_begin[next_node_id] = current_size;
          if (!in_queue[next_node_id]) {
            if (times_pushed_in_queue[next_node_id] <= num_nodes) {
              to_visit.push(next_node_id);
              in_queue[next_node_id] = true;
              times_pushed_in_queue[next_node_id]++;
            } else {
              // A positive loop is found because a node is going to be pushed
              // into the queue the (num_nodes + 1)-th time.
              has_positive_loop = true;
              break;
            }
          }
        }
      }
      if (has_positive_loop) {
        break;
      }
    }

    if (has_positive_loop) {
      // Size is unbounded along some loop; fall back to the configured size.
      stack->max_size = default_ad_stack_size;
      indeterminable_stacks.insert(stack);
    } else {
      // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks
      // with maximum capacity indeed equal to 0.
      TI_WARN_IF(max_size == 0,
                 "Unused autodiff stack {} should have been eliminated." ,
                 stack->name());
      stack->max_size = max_size;
    }
  }

  // Print a debug message if we have indeterminable AD-stacks' sizes.
  if (!indeterminable_stacks.empty()) {
    std::vector<std::string> indeterminable_stacks_name;
    indeterminable_stacks_name.reserve(indeterminable_stacks.size());
    for (auto &stack : indeterminable_stacks) {
      indeterminable_stacks_name.push_back(stack->name());
    }
    TI_DEBUG(
        "Unable to determine the necessary size for autodiff stacks [{}]. "
        "Use "
        "configured size (CompileConfig::default_ad_stack_size) {} instead." ,
        fmt::join(indeterminable_stacks_name, ", " ), default_ad_stack_size);
  }
}
1028 | |
1029 | } // namespace taichi::lang |
1030 | |