1#include "taichi/ir/ir.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/program/program.h"
7
8#include <unordered_map>
9
10namespace taichi::lang {
11
12class IRCloner : public IRVisitor {
13 private:
14 IRNode *other_node;
15 std::unordered_map<Stmt *, Stmt *> operand_map_;
16
17 public:
18 enum Phase { register_operand_map, replace_operand } phase;
19
20 explicit IRCloner(IRNode *other_node)
21 : other_node(other_node), phase(register_operand_map) {
22 allow_undefined_visitor = true;
23 invoke_default_visitor = true;
24 }
25
26 void visit(Block *stmt_list) override {
27 auto other = other_node->as<Block>();
28 for (int i = 0; i < (int)stmt_list->size(); i++) {
29 other_node = other->statements[i].get();
30 stmt_list->statements[i]->accept(this);
31 }
32 other_node = other;
33 }
34
35 void generic_visit(Stmt *stmt) {
36 if (phase == register_operand_map)
37 operand_map_[stmt] = other_node->as<Stmt>();
38 else {
39 TI_ASSERT(phase == replace_operand);
40 auto other_stmt = other_node->as<Stmt>();
41 TI_ASSERT(stmt->num_operands() == other_stmt->num_operands());
42 for (int i = 0; i < stmt->num_operands(); i++) {
43 if (operand_map_.find(stmt->operand(i)) == operand_map_.end())
44 other_stmt->set_operand(i, stmt->operand(i));
45 else
46 other_stmt->set_operand(i, operand_map_[stmt->operand(i)]);
47 }
48 }
49 }
50
51 void visit(Stmt *stmt) override {
52 generic_visit(stmt);
53 }
54
55 void visit(IfStmt *stmt) override {
56 generic_visit(stmt);
57 auto other = other_node->as<IfStmt>();
58 if (stmt->true_statements) {
59 other_node = other->true_statements.get();
60 stmt->true_statements->accept(this);
61 other_node = other;
62 }
63 if (stmt->false_statements) {
64 other_node = other->false_statements.get();
65 stmt->false_statements->accept(this);
66 other_node = other;
67 }
68 }
69
70 void visit(WhileStmt *stmt) override {
71 generic_visit(stmt);
72 auto other = other_node->as<WhileStmt>();
73 other_node = other->body.get();
74 stmt->body->accept(this);
75 other_node = other;
76 }
77
78 void visit(RangeForStmt *stmt) override {
79 generic_visit(stmt);
80 auto other = other_node->as<RangeForStmt>();
81 other_node = other->body.get();
82 stmt->body->accept(this);
83 other_node = other;
84 }
85
86 void visit(StructForStmt *stmt) override {
87 generic_visit(stmt);
88 auto other = other_node->as<StructForStmt>();
89 other_node = other->body.get();
90 stmt->body->accept(this);
91 other_node = other;
92 }
93
94 void visit(OffloadedStmt *stmt) override {
95 generic_visit(stmt);
96 auto other = other_node->as<OffloadedStmt>();
97
98#define CLONE_BLOCK(B) \
99 if (stmt->B) { \
100 other->B = std::make_unique<Block>(); \
101 other_node = other->B.get(); \
102 stmt->B->accept(this); \
103 }
104
105 CLONE_BLOCK(tls_prologue)
106 CLONE_BLOCK(bls_prologue)
107 CLONE_BLOCK(mesh_prologue)
108
109 if (stmt->body) {
110 other_node = other->body.get();
111 stmt->body->accept(this);
112 }
113
114 CLONE_BLOCK(bls_epilogue)
115 CLONE_BLOCK(tls_epilogue)
116#undef CLONE_BLOCK
117
118 other_node = other;
119 }
120
121 static std::unique_ptr<IRNode> run(IRNode *root) {
122 std::unique_ptr<IRNode> new_root = root->clone();
123 IRCloner cloner(new_root.get());
124 cloner.phase = IRCloner::register_operand_map;
125 root->accept(&cloner);
126 cloner.phase = IRCloner::replace_operand;
127 root->accept(&cloner);
128
129 return new_root;
130 }
131};
132
133namespace irpass::analysis {
134std::unique_ptr<IRNode> clone(IRNode *root) {
135 return IRCloner::run(root);
136}
137} // namespace irpass::analysis
138
139} // namespace taichi::lang
140