1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/visitors.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | class GatherDeactivations : public BasicStmtVisitor { |
9 | public: |
10 | using BasicStmtVisitor::visit; |
11 | |
12 | std::unordered_set<SNode *> snodes; |
13 | IRNode *root; |
14 | |
15 | explicit GatherDeactivations(IRNode *root) : root(root) { |
16 | } |
17 | |
18 | void visit(SNodeOpStmt *stmt) override { |
19 | if (stmt->op_type == SNodeOpType::deactivate) { |
20 | if (snodes.find(stmt->snode) == snodes.end()) { |
21 | snodes.insert(stmt->snode); |
22 | } |
23 | } |
24 | } |
25 | |
26 | std::unordered_set<SNode *> run() { |
27 | root->accept(this); |
28 | return snodes; |
29 | } |
30 | }; |
31 | |
32 | namespace irpass::analysis { |
33 | std::unordered_set<SNode *> gather_deactivations(IRNode *root) { |
34 | GatherDeactivations gather(root); |
35 | return gather.run(); |
36 | } |
37 | } // namespace irpass::analysis |
38 | |
39 | } // namespace taichi::lang |
40 |