1 | #include "taichi/ir/control_flow_graph.h" |
2 | #include "taichi/ir/ir.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/program/function.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | struct CFGFuncKey { |
9 | FunctionKey func_key{"" , -1, -1}; |
10 | bool in_parallel_for{false}; |
11 | |
12 | bool operator==(const CFGFuncKey &other_key) const { |
13 | return func_key == other_key.func_key && |
14 | in_parallel_for == other_key.in_parallel_for; |
15 | } |
16 | }; |
17 | |
18 | } // namespace taichi::lang |
19 | |
20 | namespace std { |
21 | template <> |
22 | struct hash<taichi::lang::CFGFuncKey> { |
23 | std::size_t operator()(const taichi::lang::CFGFuncKey &key) const noexcept { |
24 | return std::hash<taichi::lang::FunctionKey>()(key.func_key) ^ |
25 | ((std::size_t)key.in_parallel_for << 32); |
26 | } |
27 | }; |
28 | } // namespace std |
29 | |
30 | namespace taichi::lang { |
31 | |
32 | /** |
33 | * Build a control-flow graph. The resulting graph is guaranteed to have an |
34 | * empty start node and an empty final node. |
35 | * |
36 | * In the following docstrings, node... means a CFGNode's corresponding |
37 | * statements in the CHI IR. Other blocks are just Blocks in the CHI IR. |
38 | * Nodes denoted with "()" mean not yet created when visiting the Stmt/Block. |
39 | * |
40 | * Structures like |
41 | * node_a { |
42 | * ... |
43 | * } -> node_b, node_c; |
44 | * means node_a has edges to node_b and node_c, or equivalently, node_b and |
45 | * node_c appear in the |next| field of node_a. |
46 | * |
47 | * Structures like |
48 | * node_a { |
49 | * ... |
50 | * } -> node_b, [node_c if "cond"]; |
51 | * means node_a has an edge to node_b, and node_a has an edge to node_c iff |
52 | * the condition "cond" is true. |
53 | * |
54 | * When there can be many CFGNodes in a Block, internal nodes are omitted for |
55 | * simplicity. |
56 | * |
57 | * TODO(#2193): Make sure ReturnStmt is handled properly. |
58 | */ |
59 | class CFGBuilder : public IRVisitor { |
60 | public: |
61 | CFGBuilder() |
62 | : current_block_(nullptr), |
63 | last_node_in_current_block_(nullptr), |
64 | current_stmt_id_(-1), |
65 | begin_location_(-1), |
66 | current_offload_(nullptr), |
67 | in_parallel_for_(false) { |
68 | allow_undefined_visitor = true; |
69 | invoke_default_visitor = true; |
70 | graph_ = std::make_unique<ControlFlowGraph>(); |
71 | // Make an empty start node. |
72 | auto start_node = graph_->push_back(); |
73 | prev_nodes_.push_back(start_node); |
74 | } |
75 | |
76 | void visit(Stmt *stmt) override { |
77 | if (stmt->is_container_statement()) { |
78 | TI_ERROR("Visitor for container statement undefined." ); |
79 | } |
80 | } |
81 | |
82 | /** |
83 | * Create a node for the current control-flow graph, |
84 | * mark the current statement as the end location (exclusive) of the node, |
85 | * and add edges from |prev_nodes| to the node. |
86 | * |
87 | * @param next_begin_location The location in the IR block of the first |
88 | * statement in the next node, if the next node is in the same IR block of |
89 | * the node to be returned. Otherwise, next_begin_location must be -1. |
90 | * @return The node which is just created. |
91 | */ |
92 | CFGNode *new_node(int next_begin_location) { |
93 | auto node = graph_->push_back( |
94 | current_block_, begin_location_, /*end_location=*/current_stmt_id_, |
95 | /*is_parallel_executed=*/in_parallel_for_, |
96 | /*prev_node_in_same_block=*/last_node_in_current_block_); |
97 | for (auto &prev_node : prev_nodes_) { |
98 | // Now that the "(next node)" is created, we should insert edges |
99 | // "node... -> (next node)" here. |
100 | CFGNode::add_edge(prev_node, node); |
101 | } |
102 | prev_nodes_.clear(); |
103 | begin_location_ = next_begin_location; |
104 | last_node_in_current_block_ = node; |
105 | return node; |
106 | } |
107 | |
108 | /** |
109 | * Structure: |
110 | * |
111 | * block { |
112 | * node { |
113 | * ... |
114 | * } -> node_loop_begin, (the next node after the loop); |
115 | * continue; |
116 | * (next node) { |
117 | * ... |
118 | * } |
119 | * } |
120 | * |
121 | * Note that the edges are inserted in visit_loop(). |
122 | */ |
123 | void visit(ContinueStmt *stmt) override { |
124 | // Don't put ContinueStmt in any CFGNodes. |
125 | continues_in_current_loop_.push_back(new_node(current_stmt_id_ + 1)); |
126 | } |
127 | |
128 | /** |
129 | * Structure: |
130 | * |
131 | * block { |
132 | * node { |
133 | * ... |
134 | * } -> (next node), (the next node after the loop); |
135 | * while_control (possibly break); |
136 | * (next node) { |
137 | * ... |
138 | * } |
139 | * } |
140 | * |
141 | * Note that the edges are inserted in visit_loop(). |
142 | */ |
143 | void visit(WhileControlStmt *stmt) override { |
144 | // Don't put WhileControlStmt in any CFGNodes. |
145 | auto node = new_node(current_stmt_id_ + 1); |
146 | breaks_in_current_loop_.push_back(node); |
147 | prev_nodes_.push_back(node); |
148 | } |
149 | |
150 | /** |
151 | * Structure: |
152 | * |
153 | * node_before_if { |
154 | * ... |
155 | * } -> node_true_branch_begin, node_false_branch_begin; |
156 | * if (...) { |
157 | * node_true_branch_begin { |
158 | * ... |
159 | * } -> ... -> node_true_branch_end; |
160 | * node_true_branch_end { |
161 | * ... |
162 | * } -> (next node); |
163 | * } else { |
164 | * node_false_branch_begin { |
165 | * ... |
166 | * } -> ... -> node_false_branch_end; |
167 | * node_false_branch_end { |
168 | * ... |
169 | * } -> (next node); |
170 | * } |
171 | * (next node) { |
172 | * ... |
173 | * } |
174 | */ |
175 | void visit(IfStmt *if_stmt) override { |
176 | auto before_if = new_node(-1); |
177 | CFGNode *true_branch_end = nullptr; |
178 | if (if_stmt->true_statements) { |
179 | auto true_branch_begin = graph_->size(); |
180 | if_stmt->true_statements->accept(this); |
181 | CFGNode::add_edge(before_if, graph_->nodes[true_branch_begin].get()); |
182 | true_branch_end = graph_->back(); |
183 | } |
184 | CFGNode *false_branch_end = nullptr; |
185 | if (if_stmt->false_statements) { |
186 | auto false_branch_begin = graph_->size(); |
187 | if_stmt->false_statements->accept(this); |
188 | CFGNode::add_edge(before_if, graph_->nodes[false_branch_begin].get()); |
189 | false_branch_end = graph_->back(); |
190 | } |
191 | TI_ASSERT(prev_nodes_.empty()); |
192 | if (if_stmt->true_statements) |
193 | prev_nodes_.push_back(true_branch_end); |
194 | if (if_stmt->false_statements) |
195 | prev_nodes_.push_back(false_branch_end); |
196 | if (!if_stmt->true_statements || !if_stmt->false_statements) |
197 | prev_nodes_.push_back(before_if); |
198 | // Container statements don't belong to any CFGNodes. |
199 | begin_location_ = current_stmt_id_ + 1; |
200 | } |
201 | |
202 | /** |
203 | * Structure ([(next node) if !is_while_true] means the node has an edge to |
204 | * (next node) only when is_while_true is false): |
205 | * |
206 | * node_before_loop { |
207 | * ... |
208 | * } -> node_loop_begin, [(next node) if !is_while_true]; |
209 | * loop (...) { |
210 | * node_loop_begin { |
211 | * ... |
212 | * } -> ... -> node_loop_end; |
213 | * node_loop_end { |
214 | * ... |
215 | * } -> node_loop_begin, [(next node) if !is_while_true]; |
216 | * } |
217 | * (next node) { |
218 | * ... |
219 | * } |
220 | */ |
221 | void visit_loop(Block *body, CFGNode *before_loop, bool is_while_true) { |
222 | int loop_stmt_id = current_stmt_id_; |
223 | auto backup_continues = std::move(continues_in_current_loop_); |
224 | auto backup_breaks = std::move(breaks_in_current_loop_); |
225 | continues_in_current_loop_.clear(); |
226 | breaks_in_current_loop_.clear(); |
227 | |
228 | auto loop_begin_index = graph_->size(); |
229 | body->accept(this); |
230 | auto loop_begin = graph_->nodes[loop_begin_index].get(); |
231 | CFGNode::add_edge(before_loop, loop_begin); |
232 | auto loop_end = graph_->back(); |
233 | CFGNode::add_edge(loop_end, loop_begin); |
234 | if (!is_while_true) { |
235 | prev_nodes_.push_back(before_loop); |
236 | prev_nodes_.push_back(loop_end); |
237 | } |
238 | for (auto &node : continues_in_current_loop_) { |
239 | CFGNode::add_edge(node, loop_begin); |
240 | prev_nodes_.push_back(node); |
241 | } |
242 | for (auto &node : breaks_in_current_loop_) { |
243 | prev_nodes_.push_back(node); |
244 | } |
245 | |
246 | // Container statements don't belong to any CFGNodes. |
247 | begin_location_ = loop_stmt_id + 1; |
248 | continues_in_current_loop_ = std::move(backup_continues); |
249 | breaks_in_current_loop_ = std::move(backup_breaks); |
250 | } |
251 | |
252 | void visit(WhileStmt *stmt) override { |
253 | visit_loop(stmt->body.get(), new_node(-1), true); |
254 | } |
255 | |
256 | void visit(RangeForStmt *stmt) override { |
257 | auto old_in_parallel_for = in_parallel_for_; |
258 | if (!current_offload_) |
259 | in_parallel_for_ = true; |
260 | visit_loop(stmt->body.get(), new_node(-1), false); |
261 | in_parallel_for_ = old_in_parallel_for; |
262 | } |
263 | |
264 | void visit(StructForStmt *stmt) override { |
265 | auto old_in_parallel_for = in_parallel_for_; |
266 | if (!current_offload_) |
267 | in_parallel_for_ = true; |
268 | visit_loop(stmt->body.get(), new_node(-1), false); |
269 | in_parallel_for_ = old_in_parallel_for; |
270 | } |
271 | |
272 | void visit(MeshForStmt *stmt) override { |
273 | auto old_in_parallel_for = in_parallel_for_; |
274 | if (!current_offload_) |
275 | in_parallel_for_ = true; |
276 | visit_loop(stmt->body.get(), new_node(-1), false); |
277 | in_parallel_for_ = old_in_parallel_for; |
278 | } |
279 | |
280 | /** |
281 | * Structure: |
282 | * |
283 | * node_before_offload { |
284 | * ... |
285 | * } -> node_tls_prologue; |
286 | * node_tls_prologue { |
287 | * ... |
288 | * } -> node_mesh_prologue; |
289 | * node_mesh_prologue: |
290 | * ... |
291 | * } -> node_bls_prologue; |
292 | * node_bls_prologue { |
293 | * ... |
294 | * } -> node_body; |
295 | * node_body { |
296 | * ... |
297 | * } -> node_bls_epilogue; |
298 | * node_bls_epilogue { |
299 | * ... |
300 | * } -> node_tls_epilogue; |
301 | * node_tls_epilogue { |
302 | * ... |
303 | * } -> (next node); |
304 | * (next node) { |
305 | * ... |
306 | * } |
307 | */ |
308 | void visit(OffloadedStmt *stmt) override { |
309 | current_offload_ = stmt; |
310 | if (stmt->tls_prologue) { |
311 | auto before_offload = new_node(-1); |
312 | int offload_stmt_id = current_stmt_id_; |
313 | auto block_begin_index = graph_->size(); |
314 | stmt->tls_prologue->accept(this); |
315 | prev_nodes_.push_back(graph_->back()); |
316 | // Container statements don't belong to any CFGNodes. |
317 | begin_location_ = offload_stmt_id + 1; |
318 | CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); |
319 | } |
320 | if (stmt->mesh_prologue) { |
321 | auto before_offload = new_node(-1); |
322 | int offload_stmt_id = current_stmt_id_; |
323 | auto block_begin_index = graph_->size(); |
324 | stmt->mesh_prologue->accept(this); |
325 | prev_nodes_.push_back(graph_->back()); |
326 | // Container statements don't belong to any CFGNodes. |
327 | begin_location_ = offload_stmt_id + 1; |
328 | CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); |
329 | } |
330 | if (stmt->bls_prologue) { |
331 | auto before_offload = new_node(-1); |
332 | int offload_stmt_id = current_stmt_id_; |
333 | auto block_begin_index = graph_->size(); |
334 | stmt->bls_prologue->accept(this); |
335 | prev_nodes_.push_back(graph_->back()); |
336 | // Container statements don't belong to any CFGNodes. |
337 | begin_location_ = offload_stmt_id + 1; |
338 | CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); |
339 | } |
340 | if (stmt->has_body()) { |
341 | auto before_offload = new_node(-1); |
342 | int offload_stmt_id = current_stmt_id_; |
343 | auto block_begin_index = graph_->size(); |
344 | if (stmt->task_type == OffloadedStmt::TaskType::range_for || |
345 | stmt->task_type == OffloadedStmt::TaskType::struct_for || |
346 | stmt->task_type == OffloadedStmt::TaskType::mesh_for) { |
347 | in_parallel_for_ = true; |
348 | } |
349 | stmt->body->accept(this); |
350 | auto block_begin = graph_->nodes[block_begin_index].get(); |
351 | for (auto &node : continues_in_current_loop_) { |
352 | CFGNode::add_edge(node, block_begin); |
353 | prev_nodes_.push_back(node); |
354 | } |
355 | in_parallel_for_ = false; |
356 | prev_nodes_.push_back(graph_->back()); |
357 | // Container statements don't belong to any CFGNodes. |
358 | begin_location_ = offload_stmt_id + 1; |
359 | CFGNode::add_edge(before_offload, block_begin); |
360 | } |
361 | if (stmt->bls_epilogue) { |
362 | auto before_offload = new_node(-1); |
363 | int offload_stmt_id = current_stmt_id_; |
364 | auto block_begin_index = graph_->size(); |
365 | stmt->bls_epilogue->accept(this); |
366 | prev_nodes_.push_back(graph_->back()); |
367 | // Container statements don't belong to any CFGNodes. |
368 | begin_location_ = offload_stmt_id + 1; |
369 | CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); |
370 | } |
371 | if (stmt->tls_epilogue) { |
372 | auto before_offload = new_node(-1); |
373 | int offload_stmt_id = current_stmt_id_; |
374 | auto block_begin_index = graph_->size(); |
375 | stmt->tls_epilogue->accept(this); |
376 | prev_nodes_.push_back(graph_->back()); |
377 | // Container statements don't belong to any CFGNodes. |
378 | begin_location_ = offload_stmt_id + 1; |
379 | CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); |
380 | } |
381 | current_offload_ = nullptr; |
382 | } |
383 | |
384 | /** |
385 | * Structure: |
386 | * |
387 | * graph->start_node { |
388 | * // no statements |
389 | * } -> node_block_begin if this is the first top-level block; |
390 | * block { |
391 | * node_block_begin { |
392 | * ... |
393 | * } -> ... -> node_block_end; |
394 | * node_block_end { |
395 | * ... |
396 | * } |
397 | * } |
398 | * |
399 | * graph->final_node = node_block_end; |
400 | */ |
401 | void visit(Block *block) override { |
402 | auto backup_block = current_block_; |
403 | auto backup_last_node = last_node_in_current_block_; |
404 | auto backup_stmt_id = current_stmt_id_; |
405 | // |begin_location| must be -1 (indicating we are not building any CFGNode) |
406 | // when the |current_block| changes. |
407 | TI_ASSERT(begin_location_ == -1); |
408 | TI_ASSERT(prev_nodes_.empty() || graph_->size() == 1); |
409 | current_block_ = block; |
410 | last_node_in_current_block_ = nullptr; |
411 | begin_location_ = 0; |
412 | |
413 | for (int i = 0; i < (int)block->size(); i++) { |
414 | current_stmt_id_ = i; |
415 | block->statements[i]->accept(this); |
416 | } |
417 | current_stmt_id_ = block->size(); |
418 | new_node(-1); // Each block has a deterministic last node. |
419 | graph_->final_node = (int)graph_->size() - 1; |
420 | |
421 | current_block_ = backup_block; |
422 | last_node_in_current_block_ = backup_last_node; |
423 | current_stmt_id_ = backup_stmt_id; |
424 | } |
425 | |
426 | static std::unique_ptr<ControlFlowGraph> run(IRNode *root) { |
427 | CFGBuilder builder; |
428 | root->accept(&builder); |
429 | if (!builder.graph_->nodes[builder.graph_->final_node]->empty()) { |
430 | // Make the final node empty (by adding an empty final node). |
431 | builder.graph_->push_back(); |
432 | CFGNode::add_edge(builder.graph_->nodes[builder.graph_->final_node].get(), |
433 | builder.graph_->back()); |
434 | builder.graph_->final_node = (int)builder.graph_->size() - 1; |
435 | } |
436 | return std::move(builder.graph_); |
437 | } |
438 | |
439 | private: |
440 | std::unique_ptr<ControlFlowGraph> graph_; |
441 | Block *current_block_; |
442 | CFGNode *last_node_in_current_block_; |
443 | std::vector<CFGNode *> continues_in_current_loop_; |
444 | std::vector<CFGNode *> breaks_in_current_loop_; |
445 | int current_stmt_id_; |
446 | int begin_location_; |
447 | std::vector<CFGNode *> prev_nodes_; |
448 | OffloadedStmt *current_offload_; |
449 | bool in_parallel_for_; |
450 | std::unordered_map<CFGFuncKey, CFGNode *> node_func_begin_; |
451 | std::unordered_map<CFGFuncKey, CFGNode *> node_func_end_; |
452 | }; |
453 | |
454 | namespace irpass::analysis { |
455 | std::unique_ptr<ControlFlowGraph> build_cfg(IRNode *root) { |
456 | return CFGBuilder::run(root); |
457 | } |
458 | } // namespace irpass::analysis |
459 | |
460 | } // namespace taichi::lang |
461 | |