#include "taichi/ir/control_flow_graph.h"

#include <queue>
#include <unordered_set>

#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/system/profiler.h"

namespace taichi::lang {

CFGNode::CFGNode(Block *block,
                 int begin_location,
                 int end_location,
                 bool is_parallel_executed,
                 CFGNode *prev_node_in_same_block)
    : block(block),
      begin_location(begin_location),
      end_location(end_location),
      is_parallel_executed(is_parallel_executed),
      prev_node_in_same_block(prev_node_in_same_block),
      next_node_in_same_block(nullptr) {
  if (prev_node_in_same_block != nullptr)
    prev_node_in_same_block->next_node_in_same_block = this;
  if (!empty()) {
    // For non-empty nodes, precompute |parent_blocks| to accelerate
    // get_store_forwarding_data().
    TI_ASSERT(begin_location >= 0);
    TI_ASSERT(block);
    auto parent_block = block;
    parent_blocks_.insert(parent_block);
    while (parent_block->parent_block()) {
      parent_block = parent_block->parent_block();
      parent_blocks_.insert(parent_block);
    }
  }
}

CFGNode::CFGNode() : CFGNode(nullptr, -1, -1, false, nullptr) {
}

void CFGNode::add_edge(CFGNode *from, CFGNode *to) {
  from->next.push_back(to);
  to->prev.push_back(from);
}

bool CFGNode::empty() const {
  return begin_location >= end_location;
}

std::size_t CFGNode::size() const {
  return end_location - begin_location;
}

void CFGNode::erase(int location) {
  TI_ASSERT(location >= begin_location && location < end_location);
  block->erase(location);
  end_location--;
  for (auto node = next_node_in_same_block; node != nullptr;
       node = node->next_node_in_same_block) {
    node->begin_location--;
    node->end_location--;
  }
}

void CFGNode::insert(std::unique_ptr<Stmt> &&new_stmt, int location) {
  TI_ASSERT(location >= begin_location && location <= end_location);
  block->insert(std::move(new_stmt), location);
  end_location++;
  for (auto node = next_node_in_same_block; node != nullptr;
       node = node->next_node_in_same_block) {
    node->begin_location++;
    node->end_location++;
  }
}

void CFGNode::replace_with(int location,
                           std::unique_ptr<Stmt> &&new_stmt,
                           bool replace_usages) const {
  TI_ASSERT(location >= begin_location && location < end_location);
  block->replace_with(block->statements[location].get(), std::move(new_stmt),
                      replace_usages);
}

bool CFGNode::contain_variable(const std::unordered_set<Stmt *> &var_set,
                               Stmt *var) {
  if (var->is<AllocaStmt>() || var->is<AdStackAllocaStmt>()) {
    return var_set.find(var) != var_set.end();
  } else {
    // TODO: How to optimize this?
    if (var_set.find(var) != var_set.end())
      return true;
    return std::any_of(var_set.begin(), var_set.end(), [&](Stmt *set_var) {
      return irpass::analysis::definitely_same_address(var, set_var);
    });
  }
}

bool CFGNode::may_contain_variable(const std::unordered_set<Stmt *> &var_set,
                                   Stmt *var) {
  if (var->is<AllocaStmt>() || var->is<AdStackAllocaStmt>()) {
    return var_set.find(var) != var_set.end();
  } else {
    // TODO: How to optimize this?
    if (var_set.find(var) != var_set.end())
      return true;
    return std::any_of(var_set.begin(), var_set.end(), [&](Stmt *set_var) {
      return irpass::analysis::maybe_same_address(var, set_var);
    });
  }
}

bool CFGNode::reach_kill_variable(Stmt *var) const {
  // Does this node (definitely) kill a definition of var?
  return contain_variable(reach_kill, var);
}

Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
  // Return the stored data if all definitions in the UD-chain of |var| at
  // this position store the same data.
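  //
  // Overall strategy (two phases):
  // 1. Scan backwards inside this node for the closest statement that
  //    definitely stores to |var|. If found, forward its data unless some
  //    possibly-aliasing store in between may have changed the value.
  // 2. Otherwise, fall back to the reaching definitions (|reach_in| plus the
  //    part of |reach_gen| before |position|) and require that all of them
  //    store the same, visible value.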
  int last_def_position = -1;
  for (int i = position - 1; i >= begin_location; i--) {
    if (block->statements[i]->is<FuncCallStmt>()) {
      return nullptr;
    }
    for (auto store_ptr :
         irpass::analysis::get_store_destination(block->statements[i].get())) {
      if (irpass::analysis::definitely_same_address(var, store_ptr)) {
        last_def_position = i;
        break;
      }
    }
    if (last_def_position != -1) {
      break;
    }
  }
  auto may_contain_address = [](Stmt *store_stmt, Stmt *var) {
    for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) {
      if (irpass::analysis::maybe_same_address(var, store_ptr)) {
        return true;
      }
    }
    return false;
  };
  if (last_def_position != -1) {
    // The UD-chain is inside this node.
    Stmt *result = irpass::analysis::get_store_data(
        block->statements[last_def_position].get());
    if (!var->is<AllocaStmt>()) {
      for (int i = last_def_position + 1; i < position; i++) {
        if (!irpass::analysis::same_value(
                result,
                irpass::analysis::get_store_data(block->statements[i].get()))) {
          if (may_contain_address(block->statements[i].get(), var)) {
            return nullptr;
          }
        }
      }
    }
    return result;
  }
  Stmt *result = nullptr;
  bool result_visible = false;
  auto visible = [&](Stmt *stmt) {
    // Check if |stmt| is before |position| here.
    if (stmt->parent == block) {
      return stmt->parent->locate(stmt) < position;
    }
    // |parent_blocks| is precomputed in the constructor of CFGNode.
    // TODO: What if |stmt| appears in an ancestor of |block| but after
    // |position|?
    return parent_blocks_.find(stmt->parent) != parent_blocks_.end();
  };
  /**
   * |stmt| is a definition in the UD-chain of |var|. Update |result| with
   * |stmt|. If either the stored data of |stmt| is not forwardable or the
   * stored data of |stmt| is not definitely the same as other definitions of
   * |var|, return false to show that there is no store-to-load forwardable
   * data.
   */
  auto update_result = [&](Stmt *stmt) {
    auto data = irpass::analysis::get_store_data(stmt);
    if (!data) {     // not forwardable
      return false;  // return nullptr
    }
    if (!result) {
      result = data;
      result_visible = visible(data);
      return true;  // continue the following loops
    }
    if (!irpass::analysis::same_value(result, data)) {
      // check the special case of alloca (initialized to 0)
      if (!(result->is<AllocaStmt>() && data->is<ConstStmt>() &&
            data->as<ConstStmt>()->val.equal_value(0))) {
        return false;  // return nullptr
      }
    }
    if (!result_visible && visible(data)) {
      // pick the visible one for store-to-load forwarding
      result = data;
      result_visible = true;
    }
    return true;  // continue the following loops
  };
  for (auto stmt : reach_in) {
    // var == stmt is for the case that a global ptr is never stored.
    // In this case, stmt is from nodes[start_node]->reach_gen.
    if (var == stmt || may_contain_address(stmt, var)) {
      if (!update_result(stmt))
        return nullptr;
    }
  }
  for (auto stmt : reach_gen) {
    if (may_contain_address(stmt, var) &&
        stmt->parent->locate(stmt) < position) {
      if (!update_result(stmt))
        return nullptr;
    }
  }
  if (!result) {
    // The UD-chain is empty.
    TI_WARN("stmt {} loaded in stmt {} before storing.", var->id,
            block->statements[position]->id);
    return nullptr;
  }
  if (!result_visible) {
    // The data is store-to-load forwardable but not visible at the place we
    // are going to forward. We cannot forward it in this case.
    return nullptr;
  }
  return result;
}

void CFGNode::reaching_definition_analysis(bool after_lower_access) {
  // Calculate |reach_gen| and |reach_kill|.
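  //
  // In dataflow terms (as used by the worklist pass in ControlFlowGraph):
  //   reach_gen  = stores in this node whose destination is not overwritten
  //                later within the same node;
  //   reach_kill = destination addresses that are definitely overwritten
  //                somewhere in this node.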
  reach_gen.clear();
  reach_kill.clear();
  for (int i = end_location - 1; i >= begin_location; i--) {
    // loop in reversed order
    auto stmt = block->statements[i].get();
    auto data_source_ptrs = irpass::analysis::get_store_destination(stmt);
    for (auto data_source_ptr : data_source_ptrs) {
      // stmt provides a data source
      if (after_lower_access && !(data_source_ptr->is<AllocaStmt>())) {
        // After lower_access, we only analyze local variables.
        continue;
      }
      if (!reach_kill_variable(data_source_ptr)) {
        reach_gen.insert(stmt);
        reach_kill.insert(data_source_ptr);
      }
    }
  }
}

bool CFGNode::store_to_load_forwarding(bool after_lower_access,
                                       bool autodiff_enabled) {
  bool modified = false;
  for (int i = begin_location; i < end_location; i++) {
    // Store-to-load forwarding
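    // For example (illustrative IR, not exact Taichi statement syntax):
    //   $a = alloca
    //   store $a <- 42
    //   $b = load $a     // forwardable: uses of $b become the constant 42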
    auto stmt = block->statements[i].get();
    Stmt *result = nullptr;
    if (auto local_load = stmt->cast<LocalLoadStmt>()) {
      result = get_store_forwarding_data(local_load->src, i);
    } else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
      if (!after_lower_access && !autodiff_enabled) {
        result = get_store_forwarding_data(global_load->src, i);
      }
    }
    if (result) {
      // Forward the stored data |result|.
      if (result->is<AllocaStmt>()) {
        // special case of alloca (initialized to 0)
        auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0));
        replace_with(i, std::move(zero), true);
      } else {
        stmt->replace_usages_with(result);
        erase(i);  // This causes end_location--
        i--;       // to cancel i++ in the for loop
        modified = true;
      }
      continue;
    }

    // Identical store elimination
    if (auto local_store = stmt->cast<LocalStoreStmt>()) {
      result = get_store_forwarding_data(local_store->dest, i);
      if (result && result->is<AllocaStmt>() && !autodiff_enabled) {
        // special case of alloca (initialized to 0)
        if (auto stored_data = local_store->val->cast<ConstStmt>()) {
          if (stored_data->val.equal_value(0)) {
            erase(i);  // This causes end_location--
            i--;       // to cancel i++ in the for loop
            modified = true;
          }
        }
      } else {
        // not alloca
        if (irpass::analysis::same_value(result, local_store->val)) {
          erase(i);  // This causes end_location--
          i--;       // to cancel i++ in the for loop
          modified = true;
        }
      }
    } else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
      if (!after_lower_access) {
        result = get_store_forwarding_data(global_store->dest, i);
        if (irpass::analysis::same_value(result, global_store->val)) {
          erase(i);  // This causes end_location--
          i--;       // to cancel i++ in the for loop
          modified = true;
        }
      }
    }
  }
  return modified;
}

void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
  // Gather the SNodes which this CFGNode loads.
  // Requires reaching definition analysis.
  std::unordered_set<Stmt *> killed_in_this_node;
  for (int i = begin_location; i < end_location; i++) {
    auto stmt = block->statements[i].get();
    auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
    for (auto &load_ptr : load_ptrs) {
      if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
        // Avoid computing the UD-chain if every SNode in this global ptr
        // is already loaded, because computing it can be time-consuming.
        auto snode = global_ptr->snode;
        if (snodes.count(snode) > 0) {
          continue;
        }
        if (reach_in.find(global_ptr) != reach_in.end() &&
            !contain_variable(killed_in_this_node, global_ptr)) {
          // The UD-chain contains the value before this offloaded task.
          snodes.insert(snode);
        }
      }
    }
    auto store_ptrs = irpass::analysis::get_store_destination(stmt);
    for (auto &store_ptr : store_ptrs) {
      if (store_ptr->is<GlobalPtrStmt>()) {
        killed_in_this_node.insert(store_ptr);
      }
    }
  }
}

void CFGNode::live_variable_analysis(bool after_lower_access) {
  live_gen.clear();
  live_kill.clear();
  for (int i = begin_location; i < end_location; i++) {
    auto stmt = block->statements[i].get();
    auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
    for (auto &load_ptr : load_ptrs) {
      if (!after_lower_access ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        if (!contain_variable(live_kill, load_ptr)) {
          live_gen.insert(load_ptr);
        }
      }
    }
    auto store_ptrs = irpass::analysis::get_store_destination(stmt);
    // TODO: Consider AD-stacks in get_store_destination instead of here
    // for store-to-load forwarding on AD-stacks
    // TODO: SNode deactivation is also a definite store
    if (auto stack_pop = stmt->cast<AdStackPopStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_pop->stack);
    } else if (auto stack_push = stmt->cast<AdStackPushStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_push->stack);
    } else if (auto stack_acc_adj = stmt->cast<AdStackAccAdjointStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_acc_adj->stack);
    }
    for (auto store_ptr : store_ptrs) {
      if (!after_lower_access ||
          (store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        live_kill.insert(store_ptr);
      }
    }
  }
}

bool CFGNode::dead_store_elimination(bool after_lower_access) {
  bool modified = false;
  std::unordered_set<Stmt *> live_in_this_node;
  std::unordered_set<Stmt *> killed_in_this_node;
  // Map a variable to its nearest load
  std::unordered_map<Stmt *, Stmt *> live_load_in_this_node;
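  // The scan below goes backwards through this node. A store can be erased
  // when its destination is not loaded later in this node (live_in_this_node)
  // and is either definitely overwritten later in this node
  // (killed_in_this_node) or not live at the exit of this node (live_out).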
  for (int i = end_location - 1; i >= begin_location; i--) {
    auto stmt = block->statements[i].get();
    if (stmt->is<FuncCallStmt>()) {
      killed_in_this_node.clear();
      live_load_in_this_node.clear();
    }
    auto store_ptrs = irpass::analysis::get_store_destination(stmt);
    // TODO: Consider AD-stacks in get_store_destination instead of here
    // for store-to-load forwarding on AD-stacks
    if (auto stack_pop = stmt->cast<AdStackPopStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_pop->stack);
    } else if (auto stack_push = stmt->cast<AdStackPushStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_push->stack);
    } else if (auto stack_acc_adj = stmt->cast<AdStackAccAdjointStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stack_acc_adj->stack);
    } else if (stmt->is<AdStackAllocaStmt>()) {
      store_ptrs = std::vector<Stmt *>(1, stmt);
    }
    if (store_ptrs.size() == 1) {
      // Dead store elimination
      auto store_ptr = *store_ptrs.begin();
      if (!after_lower_access ||
          (store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        // Do not eliminate AllocaStmt and AdStackAllocaStmt here.
        if (!stmt->is<AllocaStmt>() && !stmt->is<AdStackAllocaStmt>() &&
            !stmt->is<ExternalFuncCallStmt>() &&
            !may_contain_variable(live_in_this_node, store_ptr) &&
            (contain_variable(killed_in_this_node, store_ptr) ||
             !may_contain_variable(live_out, store_ptr))) {
          // Neither used in other nodes nor used in this node.
          if (!stmt->is<AtomicOpStmt>()) {
            // Eliminate the dead store.
            erase(i);
            modified = true;
            continue;
          }
          auto atomic = stmt->cast<AtomicOpStmt>();
          // Weaken the atomic operation to a load.
          if (atomic->dest->is<AllocaStmt>()) {
            auto local_load = Stmt::make<LocalLoadStmt>(atomic->dest);
            local_load->ret_type = atomic->ret_type;
            // Notice that we have a load here
            // (the return value of AtomicOpStmt).
            live_in_this_node.insert(atomic->dest);
            live_load_in_this_node[atomic->dest] = local_load.get();
            killed_in_this_node.erase(atomic->dest);
            replace_with(i, std::move(local_load), true);
            modified = true;
            continue;
          } else if (!is_parallel_executed ||
                     (atomic->dest->is<GlobalPtrStmt>() &&
                      atomic->dest->as<GlobalPtrStmt>()->snode->is_scalar())) {
            // If this node is parallel executed, we can't weaken a global
            // atomic operation to a global load.
            // TODO: we can weaken it if it's element-wise (i.e. never
            // accessed by other threads).
            auto global_load = Stmt::make<GlobalLoadStmt>(atomic->dest);
            global_load->ret_type = atomic->ret_type;
            // Notice that we have a load here
            // (the return value of AtomicOpStmt).
            live_in_this_node.insert(atomic->dest);
            live_load_in_this_node[atomic->dest] = global_load.get();
            killed_in_this_node.erase(atomic->dest);
            replace_with(i, std::move(global_load), true);
            modified = true;
            continue;
          }
        } else {
          // A non-eliminated store.
          killed_in_this_node.insert(store_ptr);
          auto old_live_in_this_node = std::move(live_in_this_node);
          live_in_this_node.clear();
          for (auto &var : old_live_in_this_node) {
            if (!irpass::analysis::definitely_same_address(store_ptr, var))
              live_in_this_node.insert(var);
          }
        }
      }
    }
    auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
    if (load_ptrs.size() == 1 && store_ptrs.empty()) {
      // Identical load elimination
      auto load_ptr = load_ptrs.begin()[0];
      if (!after_lower_access ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        if (live_load_in_this_node.find(load_ptr) !=
                live_load_in_this_node.end() &&
            !may_contain_variable(killed_in_this_node, load_ptr)) {
          // Only perform identical load elimination within a CFGNode.
          auto next_load_stmt = live_load_in_this_node[load_ptr];
          TI_ASSERT(irpass::analysis::same_statements(stmt, next_load_stmt));
          next_load_stmt->replace_usages_with(stmt);
          erase(block->locate(next_load_stmt));
          modified = true;
        }
        live_load_in_this_node[load_ptr] = stmt;
        killed_in_this_node.erase(load_ptr);
      }
    }
    for (auto &load_ptr : load_ptrs) {
      if (!after_lower_access ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
        // After lower_access, we only analyze local variables and stacks.
        live_in_this_node.insert(load_ptr);
      }
    }
  }
  return modified;
}

void ControlFlowGraph::erase(int node_id) {
  // Erase an empty node.
  TI_ASSERT(node_id >= 0 && node_id < (int)size());
  TI_ASSERT(nodes[node_id] && nodes[node_id]->empty());
  if (nodes[node_id]->prev_node_in_same_block) {
    nodes[node_id]->prev_node_in_same_block->next_node_in_same_block =
        nodes[node_id]->next_node_in_same_block;
  }
  if (nodes[node_id]->next_node_in_same_block) {
    nodes[node_id]->next_node_in_same_block->prev_node_in_same_block =
        nodes[node_id]->prev_node_in_same_block;
  }
  for (auto &prev_node : nodes[node_id]->prev) {
    prev_node->next.erase(std::find(
        prev_node->next.begin(), prev_node->next.end(), nodes[node_id].get()));
  }
  for (auto &next_node : nodes[node_id]->next) {
    next_node->prev.erase(std::find(
        next_node->prev.begin(), next_node->prev.end(), nodes[node_id].get()));
  }
  for (auto &prev_node : nodes[node_id]->prev) {
    for (auto &next_node : nodes[node_id]->next) {
      CFGNode::add_edge(prev_node, next_node);
    }
  }
  nodes[node_id].reset();
}

std::size_t ControlFlowGraph::size() const {
  return nodes.size();
}

CFGNode *ControlFlowGraph::back() const {
  return nodes.back().get();
}

void ControlFlowGraph::print_graph_structure() const {
  const int num_nodes = size();
  std::cout << "Control Flow Graph with " << num_nodes
            << " nodes:" << std::endl;
  std::unordered_map<CFGNode *, int> to_index;
  for (int i = 0; i < num_nodes; i++) {
    to_index[nodes[i].get()] = i;
  }
  for (int i = 0; i < num_nodes; i++) {
    std::string node_info = fmt::format("Node {} : ", i);
    if (nodes[i]->empty()) {
      node_info += "empty";
    } else {
      node_info += fmt::format(
          "{}~{} (size={})",
          nodes[i]->block->statements[nodes[i]->begin_location]->name(),
          nodes[i]->block->statements[nodes[i]->end_location - 1]->name(),
          nodes[i]->size());
    }
    if (!nodes[i]->prev.empty()) {
      std::vector<std::string> indices;
      for (auto prev_node : nodes[i]->prev) {
        indices.push_back(std::to_string(to_index[prev_node]));
      }
      node_info += fmt::format("; prev={{{}}}", fmt::join(indices, ", "));
    }
    if (!nodes[i]->next.empty()) {
      std::vector<std::string> indices;
      for (auto next_node : nodes[i]->next) {
        indices.push_back(std::to_string(to_index[next_node]));
      }
      node_info += fmt::format("; next={{{}}}", fmt::join(indices, ", "));
    }
    if (!nodes[i]->reach_in.empty()) {
      std::vector<std::string> indices;
      for (auto stmt : nodes[i]->reach_in) {
        indices.push_back(stmt->name());
      }
      node_info += fmt::format("; reach_in={{{}}}", fmt::join(indices, ", "));
    }
    if (!nodes[i]->reach_out.empty()) {
      std::vector<std::string> indices;
      for (auto stmt : nodes[i]->reach_out) {
        indices.push_back(stmt->name());
      }
      node_info += fmt::format("; reach_out={{{}}}", fmt::join(indices, ", "));
    }
    std::cout << node_info << std::endl;
  }
}

void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
  TI_AUTO_PROF;
  const int num_nodes = size();
  std::queue<CFGNode *> to_visit;
  std::unordered_map<CFGNode *, bool> in_queue;
  TI_ASSERT(nodes[start_node]->empty());
  nodes[start_node]->reach_gen.clear();
  nodes[start_node]->reach_kill.clear();
  for (int i = 0; i < num_nodes; i++) {
    for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
      auto stmt = nodes[i]->block->statements[j].get();
      if ((stmt->is<MatrixPtrStmt>() &&
           stmt->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
          (!after_lower_access &&
           (stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
            stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
            stmt->is<GlobalTemporaryStmt>() || stmt->is<MatrixPtrStmt>() ||
            stmt->is<GetChStmt>()))) {
        // TODO: unify them
        // A global pointer that may contain some data before this kernel.
        nodes[start_node]->reach_gen.insert(stmt);
      }
    }
  }
  for (int i = 0; i < num_nodes; i++) {
    if (i != start_node) {
      nodes[i]->reaching_definition_analysis(after_lower_access);
    }
    nodes[i]->reach_in.clear();
    nodes[i]->reach_out = nodes[i]->reach_gen;
    to_visit.push(nodes[i].get());
    in_queue[nodes[i].get()] = true;
  }

  // The worklist algorithm.
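  // Standard forward dataflow iteration; for each node n:
  //   reach_in(n)  = union of reach_out(p) over all predecessors p
  //   reach_out(n) = reach_gen(n) + (reach_in(n) - definitions killed by n)
  // Nodes are re-queued until a fixed point is reached.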
  while (!to_visit.empty()) {
    auto now = to_visit.front();
    to_visit.pop();
    in_queue[now] = false;

    now->reach_in.clear();
    for (auto prev_node : now->prev) {
      now->reach_in.insert(prev_node->reach_out.begin(),
                           prev_node->reach_out.end());
    }
    auto old_out = std::move(now->reach_out);
    now->reach_out = now->reach_gen;
    for (auto stmt : now->reach_in) {
      auto store_ptrs = irpass::analysis::get_store_destination(stmt);
      bool killed;
      if (store_ptrs.empty()) {  // the case of a global pointer
        killed = now->reach_kill_variable(stmt);
      } else {
        killed = true;
        for (auto store_ptr : store_ptrs) {
          if (!now->reach_kill_variable(store_ptr)) {
            killed = false;
            break;
          }
        }
      }
      if (!killed) {
        now->reach_out.insert(stmt);
      }
    }
    if (now->reach_out != old_out) {
      // changed
      for (auto next_node : now->next) {
        if (!in_queue[next_node]) {
          to_visit.push(next_node);
          in_queue[next_node] = true;
        }
      }
    }
  }
}

void ControlFlowGraph::live_variable_analysis(
    bool after_lower_access,
    const std::optional<LiveVarAnalysisConfig> &config_opt) {
  TI_AUTO_PROF;
  const int num_nodes = size();
  std::queue<CFGNode *> to_visit;
  std::unordered_map<CFGNode *, bool> in_queue;
  TI_ASSERT(nodes[final_node]->empty());
  nodes[final_node]->live_gen.clear();
  nodes[final_node]->live_kill.clear();

  auto in_final_node_live_gen = [&config_opt](const Stmt *stmt) -> bool {
    if (stmt->is<AllocaStmt>() || stmt->is<AdStackAllocaStmt>()) {
      return false;
    }
    if (stmt->is<MatrixPtrStmt>() &&
        stmt->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>()) {
      return false;
    }
    if (auto *gptr = stmt->cast<GlobalPtrStmt>();
        gptr && config_opt.has_value()) {
      const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0);
      return res;
    }
    // A global pointer that may be loaded after this kernel.
    return true;
  };
  if (!after_lower_access) {
    for (int i = 0; i < num_nodes; i++) {
      for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
        auto stmt = nodes[i]->block->statements[j].get();
        for (auto store_ptr : irpass::analysis::get_store_destination(stmt)) {
          if (in_final_node_live_gen(store_ptr)) {
            nodes[final_node]->live_gen.insert(store_ptr);
          }
        }
      }
    }
  }
  for (int i = num_nodes - 1; i >= 0; i--) {
    // push into the queue in reversed order to make it slightly faster
    if (i != final_node) {
      nodes[i]->live_variable_analysis(after_lower_access);
    }
    nodes[i]->live_out.clear();
    nodes[i]->live_in = nodes[i]->live_gen;
    to_visit.push(nodes[i].get());
    in_queue[nodes[i].get()] = true;
  }

  // The worklist algorithm.
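  // Standard backward dataflow iteration; for each node n:
  //   live_out(n) = union of live_in(s) over all successors s
  //   live_in(n)  = live_gen(n) + (live_out(n) - live_kill(n))
  // Nodes are re-queued until a fixed point is reached.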
  while (!to_visit.empty()) {
    auto now = to_visit.front();
    to_visit.pop();
    in_queue[now] = false;

    now->live_out.clear();
    for (auto next_node : now->next) {
      now->live_out.insert(next_node->live_in.begin(),
                           next_node->live_in.end());
    }
    auto old_in = std::move(now->live_in);
    now->live_in = now->live_gen;
    for (auto stmt : now->live_out) {
      if (!CFGNode::contain_variable(now->live_kill, stmt)) {
        now->live_in.insert(stmt);
      }
    }
    if (now->live_in != old_in) {
      // changed
      for (auto prev_node : now->prev) {
        if (!in_queue[prev_node]) {
          to_visit.push(prev_node);
          in_queue[prev_node] = true;
        }
      }
    }
  }
}

void ControlFlowGraph::simplify_graph() {
  // Simplify the graph structure, do not modify the IR.
  const int num_nodes = size();
  while (true) {
    bool modified = false;
    for (int i = 0; i < num_nodes; i++) {
      // If a node is empty with in-degree or out-degree <= 1, we can eliminate
      // it (except for the start node and the final node).
      if (nodes[i] && nodes[i]->empty() && i != start_node && i != final_node &&
          (nodes[i]->prev.size() <= 1 || nodes[i]->next.size() <= 1)) {
        erase(i);
        modified = true;
      }
    }
    if (!modified)
      break;
  }
  int new_num_nodes = 0;
  for (int i = 0; i < num_nodes; i++) {
    if (nodes[i]) {
      if (i != new_num_nodes) {
        nodes[new_num_nodes] = std::move(nodes[i]);
      }
      if (final_node == i) {
        final_node = new_num_nodes;
      }
      new_num_nodes++;
    }
  }
  nodes.resize(new_num_nodes);
}

bool ControlFlowGraph::unreachable_code_elimination() {
  // Note that container statements are not in the control-flow graph, so
  // this pass cannot eliminate container statements properly for now.
  TI_AUTO_PROF;
  std::unordered_set<CFGNode *> visited;
  std::queue<CFGNode *> to_visit;
  to_visit.push(nodes[start_node].get());
  visited.insert(nodes[start_node].get());
  // Breadth-first search
  while (!to_visit.empty()) {
    auto now = to_visit.front();
    to_visit.pop();
    for (auto &next : now->next) {
      if (visited.find(next) == visited.end()) {
        to_visit.push(next);
        visited.insert(next);
      }
    }
  }
  bool modified = false;
  for (auto &node : nodes) {
    if (visited.find(node.get()) == visited.end()) {
      // unreachable
      if (!node->empty()) {
        while (!node->empty())
          node->erase(node->end_location - 1);
        modified = true;
      }
    }
  }
  return modified;
}

bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access,
                                                bool autodiff_enabled) {
  TI_AUTO_PROF;
  reaching_definition_analysis(after_lower_access);
  const int num_nodes = size();
  bool modified = false;
  for (int i = 0; i < num_nodes; i++) {
    if (nodes[i]->store_to_load_forwarding(after_lower_access,
                                           autodiff_enabled))
      modified = true;
  }
  return modified;
}

bool ControlFlowGraph::dead_store_elimination(
    bool after_lower_access,
    const std::optional<LiveVarAnalysisConfig> &lva_config_opt) {
  TI_AUTO_PROF;
  live_variable_analysis(after_lower_access, lva_config_opt);
  const int num_nodes = size();
  bool modified = false;
  for (int i = 0; i < num_nodes; i++) {
    if (nodes[i]->dead_store_elimination(after_lower_access))
      modified = true;
  }
  return modified;
}

std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
  TI_AUTO_PROF;
  reaching_definition_analysis(/*after_lower_access=*/false);
  const int num_nodes = size();
  std::unordered_set<SNode *> snodes;

  // Note: since a global store may only partially modify a value state, the
  // result (which contains both the modified and unmodified parts) actually
  // needs a read from the previous version of the value state.
  //
  // I.e.,
  //   output_value_state = merge(input_value_state, written_part)
  //
  // Therefore we include nodes[final_node]->reach_in in snodes.
  for (auto &stmt : nodes[final_node]->reach_in) {
    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
      snodes.insert(global_ptr->snode);
    }
  }

  for (int i = 0; i < num_nodes; i++) {
    if (i != final_node) {
      nodes[i]->gather_loaded_snodes(snodes);
    }
  }
  return snodes;
}

void ControlFlowGraph::determine_ad_stack_size(int default_ad_stack_size) {
  /**
   * Determine all adaptive AD-stacks' necessary size using the Bellman-Ford
   * algorithm. When there is a positive loop (#pushes > #pops in a loop)
   * for an AD-stack, we cannot determine the size of the AD-stack, and
   * |default_ad_stack_size| is used. The time complexity is
   * O(num_statements + num_stacks * num_edges * num_nodes).
   */
  const int num_nodes = size();

  // max_increased_size[i][j] is the maximum number of (pushes - pops) of
  // stack |i| among all prefixes of the CFGNode |j|.
  std::unordered_map<AdStackAllocaStmt *, std::vector<int>> max_increased_size;

  // increased_size[i][j] is the number of (pushes - pops) of stack |i| in
  // the CFGNode |j|.
  std::unordered_map<AdStackAllocaStmt *, std::vector<int>> increased_size;
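  //
  // For example (illustrative), a node whose statements are
  //   push, push, pop, pop, pop
  // has increased_size = -1 and max_increased_size = 2 for that stack.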

  std::unordered_map<CFGNode *, int> node_ids;
  std::unordered_set<AdStackAllocaStmt *> all_stacks;
  std::unordered_set<AdStackAllocaStmt *> indeterminable_stacks;

  for (int i = 0; i < num_nodes; i++)
    node_ids[nodes[i].get()] = i;

  for (int i = 0; i < num_nodes; i++) {
    for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
      Stmt *stmt = nodes[i]->block->statements[j].get();
      if (auto *stack = stmt->cast<AdStackAllocaStmt>()) {
        all_stacks.insert(stack);
        max_increased_size.insert(
            std::make_pair(stack, std::vector<int>(num_nodes, 0)));
        increased_size.insert(
            std::make_pair(stack, std::vector<int>(num_nodes, 0)));
      }
    }
  }

  // For each basic block we compute the increase of stack size. This is a
  // pre-processing step for the next maximum stack size determining algorithm.
  for (int i = 0; i < num_nodes; i++) {
    for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
      Stmt *stmt = nodes[i]->block->statements[j].get();
      if (auto *stack_push = stmt->cast<AdStackPushStmt>()) {
        auto *stack = stack_push->stack->as<AdStackAllocaStmt>();
        if (stack->max_size == 0 /*adaptive*/) {
          increased_size[stack][i]++;
          if (increased_size[stack][i] > max_increased_size[stack][i]) {
            max_increased_size[stack][i] = increased_size[stack][i];
          }
        }
      } else if (auto *stack_pop = stmt->cast<AdStackPopStmt>()) {
        auto *stack = stack_pop->stack->as<AdStackAllocaStmt>();
        if (stack->max_size == 0 /*adaptive*/) {
          increased_size[stack][i]--;
        }
      }
    }
  }

  // The maximum stack size determining algorithm -- run the Bellman-Ford
  // algorithm on each AD-stack separately.
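  // This is effectively a longest-path computation over the CFG, where each
  // node contributes its net stack growth; pushing a node into the queue more
  // than num_nodes times indicates a positive cycle.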
  for (auto *stack : all_stacks) {
    // The maximum size of |stack| among all control flows starting at the
    // beginning of the IR.
    int max_size = 0;

    // max_size_at_node_begin[j] is the maximum size of |stack| among
    // all control flows starting at the beginning of the IR and ending at the
    // beginning of the CFGNode |j|. Initialize this array to -1 to make sure
    // that the first iteration of the Bellman-Ford algorithm fully updates
    // this array.
    std::vector<int> max_size_at_node_begin(num_nodes, -1);

    // The queue for the Bellman-Ford algorithm.
    std::queue<int> to_visit;

    // An optimization for the Bellman-Ford algorithm.
    std::vector<bool> in_queue(num_nodes);

    // An array for detecting a positive loop in the Bellman-Ford algorithm.
    std::vector<int> times_pushed_in_queue(num_nodes, 0);

    max_size_at_node_begin[start_node] = 0;
    to_visit.push(start_node);
    in_queue[start_node] = true;
    times_pushed_in_queue[start_node]++;

    bool has_positive_loop = false;

    // The Bellman-Ford algorithm.
    while (!to_visit.empty()) {
      int node_id = to_visit.front();
      to_visit.pop();
      in_queue[node_id] = false;
      CFGNode *now = nodes[node_id].get();

      // Inside this CFGNode -- update the answer |max_size|
      const auto max_size_inside_this_node = max_increased_size[stack][node_id];
      const auto current_max_size =
          max_size_at_node_begin[node_id] + max_size_inside_this_node;
      if (current_max_size > max_size) {
        max_size = current_max_size;
      }
      // At the end of this CFGNode -- update the state
      // |max_size_at_node_begin| of other CFGNodes
      const auto increase_in_this_node = increased_size[stack][node_id];
      const auto current_size =
          max_size_at_node_begin[node_id] + increase_in_this_node;
      for (auto *next_node : now->next) {
        int next_node_id = node_ids[next_node];
        if (current_size > max_size_at_node_begin[next_node_id]) {
          max_size_at_node_begin[next_node_id] = current_size;
          if (!in_queue[next_node_id]) {
            if (times_pushed_in_queue[next_node_id] <= num_nodes) {
              to_visit.push(next_node_id);
              in_queue[next_node_id] = true;
              times_pushed_in_queue[next_node_id]++;
            } else {
              // A positive loop is found because a node is going to be pushed
              // into the queue the (num_nodes + 1)-th time.
              has_positive_loop = true;
              break;
            }
          }
        }
      }
      if (has_positive_loop) {
        break;
      }
    }

    if (has_positive_loop) {
      stack->max_size = default_ad_stack_size;
      indeterminable_stacks.insert(stack);
    } else {
      // Since |max_size| == 0 denotes an adaptive size, we do not want stacks
      // whose actual maximum capacity is 0.
      TI_WARN_IF(max_size == 0,
                 "Unused autodiff stack {} should have been eliminated.",
                 stack->name());
      stack->max_size = max_size;
    }
  }

  // Print a debug message if we have indeterminable AD-stacks' sizes.
  if (!indeterminable_stacks.empty()) {
    std::vector<std::string> indeterminable_stacks_name;
    indeterminable_stacks_name.reserve(indeterminable_stacks.size());
    for (auto &stack : indeterminable_stacks) {
      indeterminable_stacks_name.push_back(stack->name());
    }
    TI_DEBUG(
        "Unable to determine the necessary size for autodiff stacks [{}]. "
        "Use "
        "configured size (CompileConfig::default_ad_stack_size) {} instead.",
        fmt::join(indeterminable_stacks_name, ", "), default_ad_stack_size);
  }
}

}  // namespace taichi::lang