1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5
6namespace taichi::lang {
7
8class GatherStatementUsages : public BasicStmtVisitor {
9 private:
10 using BasicStmtVisitor::visit;
11
12 // maps a stmt to all its usages <stmt, operand>
13 std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> stmt_usages_;
14
15 public:
16 explicit GatherStatementUsages() {
17 invoke_default_visitor = true;
18 }
19
20 void default_visit(Stmt *stmt) {
21 auto ops = stmt->get_operands();
22 for (int i = 0; i < ops.size(); i++) {
23 auto &op = ops[i];
24 if (op != nullptr) {
25 stmt_usages_[op].push_back({stmt, i});
26 }
27 }
28 }
29
30 void visit(Stmt *stmt) override {
31 default_visit(stmt);
32 }
33
34 void preprocess_container_stmt(Stmt *stmt) override {
35 default_visit(stmt);
36 }
37
38 static std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> run(
39 IRNode *node) {
40 GatherStatementUsages pass;
41 node->accept(&pass);
42 return pass.stmt_usages_;
43 }
44};
45
46namespace irpass::analysis {
47
48std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>>
49gather_statement_usages(IRNode *root) {
50 return GatherStatementUsages::run(root);
51}
52
53} // namespace irpass::analysis
54
55} // namespace taichi::lang
56