1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/visitors.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | // Find if there is a store (or AtomicOpStmt). |
9 | class LocalStoreSearcher : public BasicStmtVisitor { |
10 | private: |
11 | const std::vector<Stmt *> &vars_; |
12 | bool result_; |
13 | |
14 | public: |
15 | using BasicStmtVisitor::visit; |
16 | |
17 | explicit LocalStoreSearcher(const std::vector<Stmt *> &vars) |
18 | : vars_(vars), result_(false) { |
19 | for (auto var : vars) { |
20 | TI_ASSERT(var->is<AllocaStmt>()); |
21 | } |
22 | allow_undefined_visitor = true; |
23 | invoke_default_visitor = true; |
24 | } |
25 | |
26 | void visit(LocalStoreStmt *stmt) override { |
27 | for (auto var : vars_) { |
28 | if (stmt->dest == var) { |
29 | result_ = true; |
30 | break; |
31 | } |
32 | } |
33 | } |
34 | |
35 | void visit(AtomicOpStmt *stmt) override { |
36 | for (auto var : vars_) { |
37 | if (stmt->dest == var) { |
38 | result_ = true; |
39 | break; |
40 | } |
41 | } |
42 | } |
43 | |
44 | static bool run(IRNode *root, const std::vector<Stmt *> &vars) { |
45 | LocalStoreSearcher searcher(vars); |
46 | root->accept(&searcher); |
47 | return searcher.result_; |
48 | } |
49 | }; |
50 | |
51 | namespace irpass::analysis { |
52 | bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars) { |
53 | return LocalStoreSearcher::run(root, vars); |
54 | } |
55 | } // namespace irpass::analysis |
56 | |
57 | } // namespace taichi::lang |
58 |