1#include "taichi/ir/ir.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/visitors.h"
5
6namespace taichi::lang {
7
8class GatherDeactivations : public BasicStmtVisitor {
9 public:
10 using BasicStmtVisitor::visit;
11
12 std::unordered_set<SNode *> snodes;
13 IRNode *root;
14
15 explicit GatherDeactivations(IRNode *root) : root(root) {
16 }
17
18 void visit(SNodeOpStmt *stmt) override {
19 if (stmt->op_type == SNodeOpType::deactivate) {
20 if (snodes.find(stmt->snode) == snodes.end()) {
21 snodes.insert(stmt->snode);
22 }
23 }
24 }
25
26 std::unordered_set<SNode *> run() {
27 root->accept(this);
28 return snodes;
29 }
30};
31
32namespace irpass::analysis {
33std::unordered_set<SNode *> gather_deactivations(IRNode *root) {
34 GatherDeactivations gather(root);
35 return gather.run();
36}
37} // namespace irpass::analysis
38
39} // namespace taichi::lang
40