1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | class GatherStatementUsages : public BasicStmtVisitor { |
9 | private: |
10 | using BasicStmtVisitor::visit; |
11 | |
12 | // maps a stmt to all its usages <stmt, operand> |
13 | std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> stmt_usages_; |
14 | |
15 | public: |
16 | explicit GatherStatementUsages() { |
17 | invoke_default_visitor = true; |
18 | } |
19 | |
20 | void default_visit(Stmt *stmt) { |
21 | auto ops = stmt->get_operands(); |
22 | for (int i = 0; i < ops.size(); i++) { |
23 | auto &op = ops[i]; |
24 | if (op != nullptr) { |
25 | stmt_usages_[op].push_back({stmt, i}); |
26 | } |
27 | } |
28 | } |
29 | |
30 | void visit(Stmt *stmt) override { |
31 | default_visit(stmt); |
32 | } |
33 | |
34 | void preprocess_container_stmt(Stmt *stmt) override { |
35 | default_visit(stmt); |
36 | } |
37 | |
38 | static std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> run( |
39 | IRNode *node) { |
40 | GatherStatementUsages pass; |
41 | node->accept(&pass); |
42 | return pass.stmt_usages_; |
43 | } |
44 | }; |
45 | |
46 | namespace irpass::analysis { |
47 | |
48 | std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> |
49 | gather_statement_usages(IRNode *root) { |
50 | return GatherStatementUsages::run(root); |
51 | } |
52 | |
53 | } // namespace irpass::analysis |
54 | |
55 | } // namespace taichi::lang |
56 |