1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/program/program.h" |
7 | |
8 | #include <unordered_map> |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | class IRCloner : public IRVisitor { |
13 | private: |
14 | IRNode *other_node; |
15 | std::unordered_map<Stmt *, Stmt *> operand_map_; |
16 | |
17 | public: |
18 | enum Phase { register_operand_map, replace_operand } phase; |
19 | |
20 | explicit IRCloner(IRNode *other_node) |
21 | : other_node(other_node), phase(register_operand_map) { |
22 | allow_undefined_visitor = true; |
23 | invoke_default_visitor = true; |
24 | } |
25 | |
26 | void visit(Block *stmt_list) override { |
27 | auto other = other_node->as<Block>(); |
28 | for (int i = 0; i < (int)stmt_list->size(); i++) { |
29 | other_node = other->statements[i].get(); |
30 | stmt_list->statements[i]->accept(this); |
31 | } |
32 | other_node = other; |
33 | } |
34 | |
35 | void generic_visit(Stmt *stmt) { |
36 | if (phase == register_operand_map) |
37 | operand_map_[stmt] = other_node->as<Stmt>(); |
38 | else { |
39 | TI_ASSERT(phase == replace_operand); |
40 | auto other_stmt = other_node->as<Stmt>(); |
41 | TI_ASSERT(stmt->num_operands() == other_stmt->num_operands()); |
42 | for (int i = 0; i < stmt->num_operands(); i++) { |
43 | if (operand_map_.find(stmt->operand(i)) == operand_map_.end()) |
44 | other_stmt->set_operand(i, stmt->operand(i)); |
45 | else |
46 | other_stmt->set_operand(i, operand_map_[stmt->operand(i)]); |
47 | } |
48 | } |
49 | } |
50 | |
51 | void visit(Stmt *stmt) override { |
52 | generic_visit(stmt); |
53 | } |
54 | |
55 | void visit(IfStmt *stmt) override { |
56 | generic_visit(stmt); |
57 | auto other = other_node->as<IfStmt>(); |
58 | if (stmt->true_statements) { |
59 | other_node = other->true_statements.get(); |
60 | stmt->true_statements->accept(this); |
61 | other_node = other; |
62 | } |
63 | if (stmt->false_statements) { |
64 | other_node = other->false_statements.get(); |
65 | stmt->false_statements->accept(this); |
66 | other_node = other; |
67 | } |
68 | } |
69 | |
70 | void visit(WhileStmt *stmt) override { |
71 | generic_visit(stmt); |
72 | auto other = other_node->as<WhileStmt>(); |
73 | other_node = other->body.get(); |
74 | stmt->body->accept(this); |
75 | other_node = other; |
76 | } |
77 | |
78 | void visit(RangeForStmt *stmt) override { |
79 | generic_visit(stmt); |
80 | auto other = other_node->as<RangeForStmt>(); |
81 | other_node = other->body.get(); |
82 | stmt->body->accept(this); |
83 | other_node = other; |
84 | } |
85 | |
86 | void visit(StructForStmt *stmt) override { |
87 | generic_visit(stmt); |
88 | auto other = other_node->as<StructForStmt>(); |
89 | other_node = other->body.get(); |
90 | stmt->body->accept(this); |
91 | other_node = other; |
92 | } |
93 | |
94 | void visit(OffloadedStmt *stmt) override { |
95 | generic_visit(stmt); |
96 | auto other = other_node->as<OffloadedStmt>(); |
97 | |
98 | #define CLONE_BLOCK(B) \ |
99 | if (stmt->B) { \ |
100 | other->B = std::make_unique<Block>(); \ |
101 | other_node = other->B.get(); \ |
102 | stmt->B->accept(this); \ |
103 | } |
104 | |
105 | CLONE_BLOCK(tls_prologue) |
106 | CLONE_BLOCK(bls_prologue) |
107 | CLONE_BLOCK(mesh_prologue) |
108 | |
109 | if (stmt->body) { |
110 | other_node = other->body.get(); |
111 | stmt->body->accept(this); |
112 | } |
113 | |
114 | CLONE_BLOCK(bls_epilogue) |
115 | CLONE_BLOCK(tls_epilogue) |
116 | #undef CLONE_BLOCK |
117 | |
118 | other_node = other; |
119 | } |
120 | |
121 | static std::unique_ptr<IRNode> run(IRNode *root) { |
122 | std::unique_ptr<IRNode> new_root = root->clone(); |
123 | IRCloner cloner(new_root.get()); |
124 | cloner.phase = IRCloner::register_operand_map; |
125 | root->accept(&cloner); |
126 | cloner.phase = IRCloner::replace_operand; |
127 | root->accept(&cloner); |
128 | |
129 | return new_root; |
130 | } |
131 | }; |
132 | |
133 | namespace irpass::analysis { |
134 | std::unique_ptr<IRNode> clone(IRNode *root) { |
135 | return IRCloner::run(root); |
136 | } |
137 | } // namespace irpass::analysis |
138 | |
139 | } // namespace taichi::lang |
140 |