1#include "taichi/ir/control_flow_graph.h"
2#include "taichi/ir/ir.h"
3#include "taichi/ir/statements.h"
4#include "taichi/program/function.h"
5
6namespace taichi::lang {
7
8struct CFGFuncKey {
9 FunctionKey func_key{"", -1, -1};
10 bool in_parallel_for{false};
11
12 bool operator==(const CFGFuncKey &other_key) const {
13 return func_key == other_key.func_key &&
14 in_parallel_for == other_key.in_parallel_for;
15 }
16};
17
18} // namespace taichi::lang
19
20namespace std {
21template <>
22struct hash<taichi::lang::CFGFuncKey> {
23 std::size_t operator()(const taichi::lang::CFGFuncKey &key) const noexcept {
24 return std::hash<taichi::lang::FunctionKey>()(key.func_key) ^
25 ((std::size_t)key.in_parallel_for << 32);
26 }
27};
28} // namespace std
29
30namespace taichi::lang {
31
32/**
33 * Build a control-flow graph. The resulting graph is guaranteed to have an
34 * empty start node and an empty final node.
35 *
36 * In the following docstrings, node... means a CFGNode's corresponding
37 * statements in the CHI IR. Other blocks are just Blocks in the CHI IR.
38 * Nodes denoted with "()" mean not yet created when visiting the Stmt/Block.
39 *
40 * Structures like
41 * node_a {
42 * ...
43 * } -> node_b, node_c;
44 * means node_a has edges to node_b and node_c, or equivalently, node_b and
45 * node_c appear in the |next| field of node_a.
46 *
47 * Structures like
48 * node_a {
49 * ...
50 * } -> node_b, [node_c if "cond"];
51 * means node_a has an edge to node_b, and node_a has an edge to node_c iff
52 * the condition "cond" is true.
53 *
54 * When there can be many CFGNodes in a Block, internal nodes are omitted for
55 * simplicity.
56 *
57 * TODO(#2193): Make sure ReturnStmt is handled properly.
58 */
59class CFGBuilder : public IRVisitor {
60 public:
61 CFGBuilder()
62 : current_block_(nullptr),
63 last_node_in_current_block_(nullptr),
64 current_stmt_id_(-1),
65 begin_location_(-1),
66 current_offload_(nullptr),
67 in_parallel_for_(false) {
68 allow_undefined_visitor = true;
69 invoke_default_visitor = true;
70 graph_ = std::make_unique<ControlFlowGraph>();
71 // Make an empty start node.
72 auto start_node = graph_->push_back();
73 prev_nodes_.push_back(start_node);
74 }
75
76 void visit(Stmt *stmt) override {
77 if (stmt->is_container_statement()) {
78 TI_ERROR("Visitor for container statement undefined.");
79 }
80 }
81
82 /**
83 * Create a node for the current control-flow graph,
84 * mark the current statement as the end location (exclusive) of the node,
85 * and add edges from |prev_nodes| to the node.
86 *
87 * @param next_begin_location The location in the IR block of the first
88 * statement in the next node, if the next node is in the same IR block of
89 * the node to be returned. Otherwise, next_begin_location must be -1.
90 * @return The node which is just created.
91 */
92 CFGNode *new_node(int next_begin_location) {
93 auto node = graph_->push_back(
94 current_block_, begin_location_, /*end_location=*/current_stmt_id_,
95 /*is_parallel_executed=*/in_parallel_for_,
96 /*prev_node_in_same_block=*/last_node_in_current_block_);
97 for (auto &prev_node : prev_nodes_) {
98 // Now that the "(next node)" is created, we should insert edges
99 // "node... -> (next node)" here.
100 CFGNode::add_edge(prev_node, node);
101 }
102 prev_nodes_.clear();
103 begin_location_ = next_begin_location;
104 last_node_in_current_block_ = node;
105 return node;
106 }
107
108 /**
109 * Structure:
110 *
111 * block {
112 * node {
113 * ...
114 * } -> node_loop_begin, (the next node after the loop);
115 * continue;
116 * (next node) {
117 * ...
118 * }
119 * }
120 *
121 * Note that the edges are inserted in visit_loop().
122 */
123 void visit(ContinueStmt *stmt) override {
124 // Don't put ContinueStmt in any CFGNodes.
125 continues_in_current_loop_.push_back(new_node(current_stmt_id_ + 1));
126 }
127
128 /**
129 * Structure:
130 *
131 * block {
132 * node {
133 * ...
134 * } -> (next node), (the next node after the loop);
135 * while_control (possibly break);
136 * (next node) {
137 * ...
138 * }
139 * }
140 *
141 * Note that the edges are inserted in visit_loop().
142 */
143 void visit(WhileControlStmt *stmt) override {
144 // Don't put WhileControlStmt in any CFGNodes.
145 auto node = new_node(current_stmt_id_ + 1);
146 breaks_in_current_loop_.push_back(node);
147 prev_nodes_.push_back(node);
148 }
149
150 /**
151 * Structure:
152 *
153 * node_before_if {
154 * ...
155 * } -> node_true_branch_begin, node_false_branch_begin;
156 * if (...) {
157 * node_true_branch_begin {
158 * ...
159 * } -> ... -> node_true_branch_end;
160 * node_true_branch_end {
161 * ...
162 * } -> (next node);
163 * } else {
164 * node_false_branch_begin {
165 * ...
166 * } -> ... -> node_false_branch_end;
167 * node_false_branch_end {
168 * ...
169 * } -> (next node);
170 * }
171 * (next node) {
172 * ...
173 * }
174 */
175 void visit(IfStmt *if_stmt) override {
176 auto before_if = new_node(-1);
177 CFGNode *true_branch_end = nullptr;
178 if (if_stmt->true_statements) {
179 auto true_branch_begin = graph_->size();
180 if_stmt->true_statements->accept(this);
181 CFGNode::add_edge(before_if, graph_->nodes[true_branch_begin].get());
182 true_branch_end = graph_->back();
183 }
184 CFGNode *false_branch_end = nullptr;
185 if (if_stmt->false_statements) {
186 auto false_branch_begin = graph_->size();
187 if_stmt->false_statements->accept(this);
188 CFGNode::add_edge(before_if, graph_->nodes[false_branch_begin].get());
189 false_branch_end = graph_->back();
190 }
191 TI_ASSERT(prev_nodes_.empty());
192 if (if_stmt->true_statements)
193 prev_nodes_.push_back(true_branch_end);
194 if (if_stmt->false_statements)
195 prev_nodes_.push_back(false_branch_end);
196 if (!if_stmt->true_statements || !if_stmt->false_statements)
197 prev_nodes_.push_back(before_if);
198 // Container statements don't belong to any CFGNodes.
199 begin_location_ = current_stmt_id_ + 1;
200 }
201
202 /**
203 * Structure ([(next node) if !is_while_true] means the node has an edge to
204 * (next node) only when is_while_true is false):
205 *
206 * node_before_loop {
207 * ...
208 * } -> node_loop_begin, [(next node) if !is_while_true];
209 * loop (...) {
210 * node_loop_begin {
211 * ...
212 * } -> ... -> node_loop_end;
213 * node_loop_end {
214 * ...
215 * } -> node_loop_begin, [(next node) if !is_while_true];
216 * }
217 * (next node) {
218 * ...
219 * }
220 */
221 void visit_loop(Block *body, CFGNode *before_loop, bool is_while_true) {
222 int loop_stmt_id = current_stmt_id_;
223 auto backup_continues = std::move(continues_in_current_loop_);
224 auto backup_breaks = std::move(breaks_in_current_loop_);
225 continues_in_current_loop_.clear();
226 breaks_in_current_loop_.clear();
227
228 auto loop_begin_index = graph_->size();
229 body->accept(this);
230 auto loop_begin = graph_->nodes[loop_begin_index].get();
231 CFGNode::add_edge(before_loop, loop_begin);
232 auto loop_end = graph_->back();
233 CFGNode::add_edge(loop_end, loop_begin);
234 if (!is_while_true) {
235 prev_nodes_.push_back(before_loop);
236 prev_nodes_.push_back(loop_end);
237 }
238 for (auto &node : continues_in_current_loop_) {
239 CFGNode::add_edge(node, loop_begin);
240 prev_nodes_.push_back(node);
241 }
242 for (auto &node : breaks_in_current_loop_) {
243 prev_nodes_.push_back(node);
244 }
245
246 // Container statements don't belong to any CFGNodes.
247 begin_location_ = loop_stmt_id + 1;
248 continues_in_current_loop_ = std::move(backup_continues);
249 breaks_in_current_loop_ = std::move(backup_breaks);
250 }
251
252 void visit(WhileStmt *stmt) override {
253 visit_loop(stmt->body.get(), new_node(-1), true);
254 }
255
256 void visit(RangeForStmt *stmt) override {
257 auto old_in_parallel_for = in_parallel_for_;
258 if (!current_offload_)
259 in_parallel_for_ = true;
260 visit_loop(stmt->body.get(), new_node(-1), false);
261 in_parallel_for_ = old_in_parallel_for;
262 }
263
264 void visit(StructForStmt *stmt) override {
265 auto old_in_parallel_for = in_parallel_for_;
266 if (!current_offload_)
267 in_parallel_for_ = true;
268 visit_loop(stmt->body.get(), new_node(-1), false);
269 in_parallel_for_ = old_in_parallel_for;
270 }
271
272 void visit(MeshForStmt *stmt) override {
273 auto old_in_parallel_for = in_parallel_for_;
274 if (!current_offload_)
275 in_parallel_for_ = true;
276 visit_loop(stmt->body.get(), new_node(-1), false);
277 in_parallel_for_ = old_in_parallel_for;
278 }
279
280 /**
281 * Structure:
282 *
283 * node_before_offload {
284 * ...
285 * } -> node_tls_prologue;
286 * node_tls_prologue {
287 * ...
288 * } -> node_mesh_prologue;
289 * node_mesh_prologue:
290 * ...
291 * } -> node_bls_prologue;
292 * node_bls_prologue {
293 * ...
294 * } -> node_body;
295 * node_body {
296 * ...
297 * } -> node_bls_epilogue;
298 * node_bls_epilogue {
299 * ...
300 * } -> node_tls_epilogue;
301 * node_tls_epilogue {
302 * ...
303 * } -> (next node);
304 * (next node) {
305 * ...
306 * }
307 */
308 void visit(OffloadedStmt *stmt) override {
309 current_offload_ = stmt;
310 if (stmt->tls_prologue) {
311 auto before_offload = new_node(-1);
312 int offload_stmt_id = current_stmt_id_;
313 auto block_begin_index = graph_->size();
314 stmt->tls_prologue->accept(this);
315 prev_nodes_.push_back(graph_->back());
316 // Container statements don't belong to any CFGNodes.
317 begin_location_ = offload_stmt_id + 1;
318 CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get());
319 }
320 if (stmt->mesh_prologue) {
321 auto before_offload = new_node(-1);
322 int offload_stmt_id = current_stmt_id_;
323 auto block_begin_index = graph_->size();
324 stmt->mesh_prologue->accept(this);
325 prev_nodes_.push_back(graph_->back());
326 // Container statements don't belong to any CFGNodes.
327 begin_location_ = offload_stmt_id + 1;
328 CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get());
329 }
330 if (stmt->bls_prologue) {
331 auto before_offload = new_node(-1);
332 int offload_stmt_id = current_stmt_id_;
333 auto block_begin_index = graph_->size();
334 stmt->bls_prologue->accept(this);
335 prev_nodes_.push_back(graph_->back());
336 // Container statements don't belong to any CFGNodes.
337 begin_location_ = offload_stmt_id + 1;
338 CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get());
339 }
340 if (stmt->has_body()) {
341 auto before_offload = new_node(-1);
342 int offload_stmt_id = current_stmt_id_;
343 auto block_begin_index = graph_->size();
344 if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
345 stmt->task_type == OffloadedStmt::TaskType::struct_for ||
346 stmt->task_type == OffloadedStmt::TaskType::mesh_for) {
347 in_parallel_for_ = true;
348 }
349 stmt->body->accept(this);
350 auto block_begin = graph_->nodes[block_begin_index].get();
351 for (auto &node : continues_in_current_loop_) {
352 CFGNode::add_edge(node, block_begin);
353 prev_nodes_.push_back(node);
354 }
355 in_parallel_for_ = false;
356 prev_nodes_.push_back(graph_->back());
357 // Container statements don't belong to any CFGNodes.
358 begin_location_ = offload_stmt_id + 1;
359 CFGNode::add_edge(before_offload, block_begin);
360 }
361 if (stmt->bls_epilogue) {
362 auto before_offload = new_node(-1);
363 int offload_stmt_id = current_stmt_id_;
364 auto block_begin_index = graph_->size();
365 stmt->bls_epilogue->accept(this);
366 prev_nodes_.push_back(graph_->back());
367 // Container statements don't belong to any CFGNodes.
368 begin_location_ = offload_stmt_id + 1;
369 CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get());
370 }
371 if (stmt->tls_epilogue) {
372 auto before_offload = new_node(-1);
373 int offload_stmt_id = current_stmt_id_;
374 auto block_begin_index = graph_->size();
375 stmt->tls_epilogue->accept(this);
376 prev_nodes_.push_back(graph_->back());
377 // Container statements don't belong to any CFGNodes.
378 begin_location_ = offload_stmt_id + 1;
379 CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get());
380 }
381 current_offload_ = nullptr;
382 }
383
384 /**
385 * Structure:
386 *
387 * graph->start_node {
388 * // no statements
389 * } -> node_block_begin if this is the first top-level block;
390 * block {
391 * node_block_begin {
392 * ...
393 * } -> ... -> node_block_end;
394 * node_block_end {
395 * ...
396 * }
397 * }
398 *
399 * graph->final_node = node_block_end;
400 */
401 void visit(Block *block) override {
402 auto backup_block = current_block_;
403 auto backup_last_node = last_node_in_current_block_;
404 auto backup_stmt_id = current_stmt_id_;
405 // |begin_location| must be -1 (indicating we are not building any CFGNode)
406 // when the |current_block| changes.
407 TI_ASSERT(begin_location_ == -1);
408 TI_ASSERT(prev_nodes_.empty() || graph_->size() == 1);
409 current_block_ = block;
410 last_node_in_current_block_ = nullptr;
411 begin_location_ = 0;
412
413 for (int i = 0; i < (int)block->size(); i++) {
414 current_stmt_id_ = i;
415 block->statements[i]->accept(this);
416 }
417 current_stmt_id_ = block->size();
418 new_node(-1); // Each block has a deterministic last node.
419 graph_->final_node = (int)graph_->size() - 1;
420
421 current_block_ = backup_block;
422 last_node_in_current_block_ = backup_last_node;
423 current_stmt_id_ = backup_stmt_id;
424 }
425
426 static std::unique_ptr<ControlFlowGraph> run(IRNode *root) {
427 CFGBuilder builder;
428 root->accept(&builder);
429 if (!builder.graph_->nodes[builder.graph_->final_node]->empty()) {
430 // Make the final node empty (by adding an empty final node).
431 builder.graph_->push_back();
432 CFGNode::add_edge(builder.graph_->nodes[builder.graph_->final_node].get(),
433 builder.graph_->back());
434 builder.graph_->final_node = (int)builder.graph_->size() - 1;
435 }
436 return std::move(builder.graph_);
437 }
438
439 private:
  // The graph under construction; created in the constructor and moved out
  // of the builder by run().
  std::unique_ptr<ControlFlowGraph> graph_;
  // The Block whose statements are currently being visited.
  Block *current_block_;
  // The node most recently created by new_node() for |current_block_|
  // (nullptr right after entering a block).
  CFGNode *last_node_in_current_block_;
  // Nodes closed at a ContinueStmt; filled by visit(ContinueStmt *) and
  // consumed by visit_loop() and visit(OffloadedStmt *).
  std::vector<CFGNode *> continues_in_current_loop_;
  // Nodes closed at a WhileControlStmt; filled by visit(WhileControlStmt *)
  // and consumed by visit_loop().
  std::vector<CFGNode *> breaks_in_current_loop_;
  // Index of the statement being visited inside |current_block_|; used as the
  // (exclusive) end location of the next node created by new_node().
  int current_stmt_id_;
  // Begin location of the next node created by new_node(); -1 means no node
  // is being built.
  int begin_location_;
  // Nodes that receive an edge to the next node created by new_node().
  std::vector<CFGNode *> prev_nodes_;
  // The OffloadedStmt being visited, or nullptr outside of one.
  OffloadedStmt *current_offload_;
  // Passed as |is_parallel_executed| to every node created by new_node().
  bool in_parallel_for_;
  // NOTE(review): the two maps below are not referenced by any code visible
  // in this file -- presumably left over from function-call handling; confirm
  // whether they are still needed.
  std::unordered_map<CFGFuncKey, CFGNode *> node_func_begin_;
  std::unordered_map<CFGFuncKey, CFGNode *> node_func_end_;
452};
453
454namespace irpass::analysis {
455std::unique_ptr<ControlFlowGraph> build_cfg(IRNode *root) {
456 return CFGBuilder::run(root);
457}
458} // namespace irpass::analysis
459
460} // namespace taichi::lang
461