1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/visitors.h" |
4 | #include "taichi/ir/frontend_ir.h" |
5 | |
6 | #include <unordered_set> |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | class DetectForsWithBreak : public BasicStmtVisitor { |
11 | public: |
12 | using BasicStmtVisitor::visit; |
13 | |
14 | std::vector<Stmt *> loop_stack; |
15 | std::unordered_set<Stmt *> fors_with_break; |
16 | IRNode *root; |
17 | |
18 | explicit DetectForsWithBreak(IRNode *root) : root(root) { |
19 | } |
20 | |
21 | void visit(FrontendBreakStmt *stmt) override { |
22 | TI_ASSERT_INFO(loop_stack.size() != 0, "break statement out of loop scope"); |
23 | auto loop = loop_stack.back(); |
24 | if (loop->is<FrontendForStmt>()) |
25 | fors_with_break.insert(loop); |
26 | } |
27 | |
28 | void visit(FrontendWhileStmt *stmt) override { |
29 | loop_stack.push_back(stmt); |
30 | stmt->body->accept(this); |
31 | loop_stack.pop_back(); |
32 | } |
33 | |
34 | void visit(FrontendForStmt *stmt) override { |
35 | loop_stack.push_back(stmt); |
36 | stmt->body->accept(this); |
37 | loop_stack.pop_back(); |
38 | } |
39 | |
40 | std::unordered_set<Stmt *> run() { |
41 | root->accept(this); |
42 | return fors_with_break; |
43 | } |
44 | }; |
45 | |
46 | namespace irpass::analysis { |
47 | std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root) { |
48 | DetectForsWithBreak detective(root); |
49 | return detective.run(); |
50 | } |
51 | } // namespace irpass::analysis |
52 | |
53 | } // namespace taichi::lang |
54 |