#include "taichi/ir/ir.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"

#include <utility>
4
5namespace taichi::lang {
6
7class StmtSearcher : public BasicStmtVisitor {
8 private:
9 std::function<bool(Stmt *)> test_;
10 std::vector<Stmt *> results_;
11
12 public:
13 using BasicStmtVisitor::visit;
14
15 explicit StmtSearcher(std::function<bool(Stmt *)> test) : test_(test) {
16 allow_undefined_visitor = true;
17 invoke_default_visitor = true;
18 }
19
20 void visit(Stmt *stmt) override {
21 if (test_(stmt))
22 results_.push_back(stmt);
23 }
24
25 static std::vector<Stmt *> run(IRNode *root,
26 const std::function<bool(Stmt *)> &test) {
27 StmtSearcher searcher(test);
28 root->accept(&searcher);
29 return searcher.results_;
30 }
31};
32
33namespace irpass::analysis {
34std::vector<Stmt *> gather_statements(IRNode *root,
35 const std::function<bool(Stmt *)> &test) {
36 return StmtSearcher::run(root, test);
37}
38} // namespace irpass::analysis
39
40} // namespace taichi::lang
41