#include <utility>

#include "taichi/ir/ir.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"

5 | namespace taichi::lang { |
6 | |
7 | class StmtSearcher : public BasicStmtVisitor { |
8 | private: |
9 | std::function<bool(Stmt *)> test_; |
10 | std::vector<Stmt *> results_; |
11 | |
12 | public: |
13 | using BasicStmtVisitor::visit; |
14 | |
15 | explicit StmtSearcher(std::function<bool(Stmt *)> test) : test_(test) { |
16 | allow_undefined_visitor = true; |
17 | invoke_default_visitor = true; |
18 | } |
19 | |
20 | void visit(Stmt *stmt) override { |
21 | if (test_(stmt)) |
22 | results_.push_back(stmt); |
23 | } |
24 | |
25 | static std::vector<Stmt *> run(IRNode *root, |
26 | const std::function<bool(Stmt *)> &test) { |
27 | StmtSearcher searcher(test); |
28 | root->accept(&searcher); |
29 | return searcher.results_; |
30 | } |
31 | }; |
32 | |
33 | namespace irpass::analysis { |
34 | std::vector<Stmt *> gather_statements(IRNode *root, |
35 | const std::function<bool(Stmt *)> &test) { |
36 | return StmtSearcher::run(root, test); |
37 | } |
38 | } // namespace irpass::analysis |
39 | |
40 | } // namespace taichi::lang |
41 |