1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/system/profiler.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | // Flag accesses to be either weak (non-activating) or strong (activating) |
10 | class FlagAccess : public IRVisitor { |
11 | public: |
12 | explicit FlagAccess(IRNode *node) { |
13 | allow_undefined_visitor = true; |
14 | invoke_default_visitor = false; |
15 | node->accept(this); |
16 | } |
17 | |
18 | void visit(Block *stmt_list) override { // block itself has no id |
19 | for (auto &stmt : stmt_list->statements) { |
20 | stmt->accept(this); |
21 | } |
22 | } |
23 | |
24 | void visit(IfStmt *if_stmt) override { |
25 | if (if_stmt->true_statements) |
26 | if_stmt->true_statements->accept(this); |
27 | if (if_stmt->false_statements) { |
28 | if_stmt->false_statements->accept(this); |
29 | } |
30 | } |
31 | |
32 | void visit(WhileStmt *stmt) override { |
33 | stmt->body->accept(this); |
34 | } |
35 | |
36 | void visit(RangeForStmt *for_stmt) override { |
37 | for_stmt->body->accept(this); |
38 | } |
39 | |
40 | void visit(StructForStmt *for_stmt) override { |
41 | for_stmt->body->accept(this); |
42 | } |
43 | |
44 | void visit(MeshForStmt *for_stmt) override { |
45 | for_stmt->body->accept(this); |
46 | } |
47 | |
48 | void visit(OffloadedStmt *stmt) override { |
49 | stmt->all_blocks_accept(this); |
50 | } |
51 | |
52 | // Assuming pointers will be visited before global load/st |
53 | void visit(GlobalPtrStmt *stmt) override { |
54 | stmt->activate = false; |
55 | } |
56 | |
57 | void visit(GlobalStoreStmt *stmt) override { |
58 | if (stmt->dest->is<GlobalPtrStmt>()) { |
59 | stmt->dest->as<GlobalPtrStmt>()->activate = true; |
60 | } |
61 | if (stmt->dest->is<MatrixPtrStmt>()) { |
62 | if (stmt->dest->as<MatrixPtrStmt>()->is_unlowered_global_ptr()) { |
63 | stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>()->activate = |
64 | true; |
65 | } |
66 | } |
67 | } |
68 | |
69 | void visit(AtomicOpStmt *stmt) override { |
70 | if (stmt->dest->is<GlobalPtrStmt>()) { |
71 | stmt->dest->as<GlobalPtrStmt>()->activate = true; |
72 | } |
73 | } |
74 | }; |
75 | |
// For struct fors, weaken accesses on variables currently being looped over
// E.g.
//   for i in x:
//     x[i] = 0
// Although we are writing to x[i], i only loops over the active elements of
// x, so the write needs no additional activation. Note that the indices of
// the x accesses must be exactly the loop indices for this optimization to
// be correct.
83 | |
84 | class WeakenAccess : public BasicStmtVisitor { |
85 | public: |
86 | using BasicStmtVisitor::visit; |
87 | |
88 | explicit WeakenAccess(IRNode *node) { |
89 | allow_undefined_visitor = true; |
90 | invoke_default_visitor = false; |
91 | current_struct_for_ = nullptr; |
92 | current_offload_ = nullptr; |
93 | node->accept(this); |
94 | } |
95 | |
96 | void visit(Block *stmt_list) override { // block itself has no id |
97 | for (auto &stmt : stmt_list->statements) { |
98 | stmt->accept(this); |
99 | } |
100 | } |
101 | |
102 | void visit(StructForStmt *stmt) override { |
103 | current_struct_for_ = stmt; |
104 | stmt->body->accept(this); |
105 | current_struct_for_ = nullptr; |
106 | } |
107 | |
108 | void visit(OffloadedStmt *stmt) override { |
109 | current_offload_ = stmt; |
110 | if (stmt->body) |
111 | stmt->body->accept(this); |
112 | current_offload_ = nullptr; |
113 | } |
114 | |
115 | static SNode *least_sparse_ancestor(SNode *a) { |
116 | while (a->type == SNodeType::place || a->type == SNodeType::dense || |
117 | a->type == SNodeType::bit_struct || |
118 | a->type == SNodeType::quant_array) { |
119 | a = a->parent; |
120 | } |
121 | return a; |
122 | } |
123 | |
124 | static bool share_sparsity(SNode *a, SNode *b) { |
125 | return least_sparse_ancestor(a) == least_sparse_ancestor(b); |
126 | } |
127 | |
128 | void visit(GlobalPtrStmt *stmt) override { |
129 | if (stmt->activate) { |
130 | bool is_struct_for = |
131 | (current_offload_ && current_offload_->task_type == |
132 | OffloadedStmt::TaskType::struct_for) || |
133 | current_struct_for_; |
134 | if (is_struct_for) { |
135 | bool same_as_loop_snode = true; |
136 | SNode *loop_snode = nullptr; |
137 | if (current_struct_for_) { |
138 | loop_snode = current_struct_for_->snode; |
139 | } else { |
140 | loop_snode = current_offload_->snode; |
141 | } |
142 | TI_ASSERT(loop_snode); |
143 | if (!share_sparsity(stmt->snode, loop_snode)) { |
144 | same_as_loop_snode = false; |
145 | } |
146 | if (stmt->indices.size() == loop_snode->num_active_indices) |
147 | for (int i = 0; i < loop_snode->num_active_indices; i++) { |
148 | auto ind = stmt->indices[i]; |
149 | // TODO: vectorized cases? |
150 | if (auto loop_var = ind->cast<LoopIndexStmt>()) { |
151 | if (loop_var->index != i) { |
152 | same_as_loop_snode = false; |
153 | } |
154 | } else { |
155 | same_as_loop_snode = false; |
156 | } |
157 | } |
158 | if (same_as_loop_snode) |
159 | stmt->activate = false; |
160 | } |
161 | } |
162 | } |
163 | |
164 | private: |
165 | OffloadedStmt *current_offload_; |
166 | StructForStmt *current_struct_for_; |
167 | }; |
168 | |
169 | namespace irpass { |
170 | |
171 | void flag_access(IRNode *root) { |
172 | TI_AUTO_PROF; |
173 | FlagAccess flag_access(root); |
174 | WeakenAccess weaken_access(root); |
175 | } |
176 | |
177 | } // namespace irpass |
178 | |
179 | } // namespace taichi::lang |
180 |