1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/program/kernel.h" |
7 | #include "taichi/program/program.h" |
8 | #include "taichi/transforms/lower_access.h" |
9 | #include "taichi/transforms/scalar_pointer_lowerer.h" |
10 | |
11 | #include <deque> |
12 | #include <set> |
13 | |
14 | namespace taichi::lang { |
15 | namespace { |
16 | |
17 | class LowerAccess; |
18 | |
// Lowers a single GlobalPtrStmt into per-SNode-level micro-access ops.
// Extends ScalarPointerLowerer with knowledge of the enclosing LowerAccess
// pass so that sparse-node activation can be elided for SNodes that the
// current struct-for loop is already iterating over (those cells are
// guaranteed active).
class PtrLowererImpl : public ScalarPointerLowerer {
 public:
  using ScalarPointerLowerer::ScalarPointerLowerer;

  // Attaches the driving LowerAccess pass and caches the chain of SNodes
  // covered by its current struct-for loop (if any).
  void set_lower_access(LowerAccess *la);

  // Whether the lowered access should activate sparse SNodes along the path
  // (true for writes/atomics, false for pure reads).
  void set_pointer_needs_activation(bool v) {
    pointer_needs_activation_ = v;
  }

 protected:
  // Emits the access ops for one level of the SNode path; returns the
  // statement that subsequent levels chain from.
  Stmt *handle_snode_at_level(int level,
                              LinearizeStmt *linearized,
                              Stmt *last) override;

 private:
  LowerAccess *la_{nullptr};
  // SNodes on the path from the struct-for's leaf SNode up to the root;
  // accesses into these need no activation inside the loop body.
  std::unordered_set<SNode *> snodes_on_loop_;
  bool pointer_needs_activation_{false};
};
39 | |
40 | // Lower GlobalPtrStmt into smaller pieces for access optimization |
41 | |
42 | class LowerAccess : public IRVisitor { |
43 | public: |
44 | DelayedIRModifier modifier; |
45 | StructForStmt *current_struct_for; |
46 | const std::vector<SNode *> &kernel_forces_no_activate; |
47 | bool lower_atomic_ptr; |
48 | |
49 | LowerAccess(const std::vector<SNode *> &kernel_forces_no_activate, |
50 | bool lower_atomic_ptr) |
51 | : kernel_forces_no_activate(kernel_forces_no_activate), |
52 | lower_atomic_ptr(lower_atomic_ptr) { |
53 | // TODO: change this to false |
54 | allow_undefined_visitor = true; |
55 | current_struct_for = nullptr; |
56 | } |
57 | |
58 | void visit(Block *stmt_list) override { |
59 | for (auto &stmt : stmt_list->statements) { |
60 | stmt->accept(this); |
61 | } |
62 | } |
63 | |
64 | void visit(IfStmt *if_stmt) override { |
65 | if (if_stmt->true_statements) |
66 | if_stmt->true_statements->accept(this); |
67 | if (if_stmt->false_statements) { |
68 | if_stmt->false_statements->accept(this); |
69 | } |
70 | } |
71 | |
72 | void visit(OffloadedStmt *stmt) override { |
73 | stmt->all_blocks_accept(this); |
74 | } |
75 | |
76 | void visit(WhileStmt *stmt) override { |
77 | stmt->body->accept(this); |
78 | } |
79 | |
80 | void visit(RangeForStmt *for_stmt) override { |
81 | for_stmt->body->accept(this); |
82 | } |
83 | |
84 | void visit(StructForStmt *for_stmt) override { |
85 | current_struct_for = for_stmt; |
86 | for_stmt->body->accept(this); |
87 | current_struct_for = nullptr; |
88 | } |
89 | |
90 | VecStatement lower_ptr(GlobalPtrStmt *ptr, |
91 | bool activate, |
92 | SNodeOpType snode_op = SNodeOpType::undefined) { |
93 | VecStatement lowered; |
94 | if (snode_op == SNodeOpType::is_active) { |
95 | // For ti.is_active |
96 | TI_ASSERT(!activate); |
97 | } |
98 | PtrLowererImpl lowerer{ptr->snode, ptr->indices, snode_op, |
99 | ptr->is_bit_vectorized, &lowered}; |
100 | lowerer.set_pointer_needs_activation(activate); |
101 | lowerer.set_lower_access(this); |
102 | lowerer.run(); |
103 | TI_ASSERT(lowered.size() > 0); |
104 | auto lowered_ptr = lowered.back().get(); |
105 | if (ptr->is_bit_vectorized) { |
106 | // if the global ptr is bit vectorized, we start from the place snode |
107 | // and find the parent quant array snode, use its physical type |
108 | auto parent_ret_type = ptr->snode->parent->physical_type; |
109 | auto ptr_ret_type = |
110 | TypeFactory::get_instance().get_pointer_type(parent_ret_type); |
111 | lowered_ptr->ret_type = DataType(ptr_ret_type); |
112 | } else { |
113 | lowered_ptr->ret_type = ptr->snode->dt; |
114 | } |
115 | return lowered; |
116 | } |
117 | |
118 | void visit(GlobalLoadStmt *stmt) override { |
119 | if (!stmt->src->is<GlobalPtrStmt>()) |
120 | return; |
121 | // No need to activate for all read accesses |
122 | auto lowered = lower_ptr(stmt->src->as<GlobalPtrStmt>(), false); |
123 | stmt->src = lowered.back().get(); |
124 | modifier.insert_before(stmt, std::move(lowered)); |
125 | } |
126 | |
127 | // TODO: this seems to be redundant |
128 | void visit(MatrixPtrStmt *stmt) override { |
129 | if (!stmt->is_unlowered_global_ptr()) |
130 | return; |
131 | auto ptr = stmt->origin->as<GlobalPtrStmt>(); |
132 | // If ptr already has activate = false, no need to activate all the |
133 | // generated micro-access ops. Otherwise, activate the nodes. |
134 | auto lowered = lower_ptr(ptr, ptr->activate); |
135 | stmt->origin = lowered.back().get(); |
136 | modifier.insert_before(stmt, std::move(lowered)); |
137 | } |
138 | |
139 | void visit(GlobalStoreStmt *stmt) override { |
140 | if (!stmt->dest->is<GlobalPtrStmt>()) |
141 | return; |
142 | auto ptr = stmt->dest->as<GlobalPtrStmt>(); |
143 | // If ptr already has activate = false, no need to activate all the |
144 | // generated micro-access ops. Otherwise, activate the nodes. |
145 | auto lowered = lower_ptr(ptr, ptr->activate); |
146 | stmt->dest = lowered.back().get(); |
147 | modifier.insert_before(stmt, std::move(lowered)); |
148 | } |
149 | |
150 | void visit(SNodeOpStmt *stmt) override { |
151 | if (stmt->ptr->is<GlobalPtrStmt>()) { |
152 | auto global_ptr = stmt->ptr->as<GlobalPtrStmt>(); |
153 | if (global_ptr->is_cell_access) { |
154 | auto lowered = lower_ptr(global_ptr, false, stmt->op_type); |
155 | modifier.replace_with(stmt, std::move(lowered), true); |
156 | } else if (stmt->op_type == SNodeOpType::get_addr) { |
157 | auto lowered = lower_ptr(global_ptr, false); |
158 | auto cast = lowered.push_back<UnaryOpStmt>(UnaryOpType::cast_bits, |
159 | lowered.back().get()); |
160 | cast->cast_type = TypeFactory::get_instance().get_primitive_type( |
161 | PrimitiveTypeID::u64); |
162 | stmt->ptr = lowered.back().get(); |
163 | modifier.replace_with(stmt, std::move(lowered)); |
164 | } else { |
165 | auto lowered = |
166 | lower_ptr(global_ptr, SNodeOpStmt::need_activation(stmt->op_type)); |
167 | stmt->ptr = lowered.back().get(); |
168 | modifier.insert_before(stmt, std::move(lowered)); |
169 | } |
170 | } |
171 | } |
172 | |
173 | void visit(AtomicOpStmt *stmt) override { |
174 | if (!lower_atomic_ptr) |
175 | return; |
176 | if (stmt->dest->is<GlobalPtrStmt>()) { |
177 | auto lowered = lower_ptr(stmt->dest->as<GlobalPtrStmt>(), |
178 | stmt->dest->as<GlobalPtrStmt>()->activate); |
179 | stmt->dest = lowered.back().get(); |
180 | modifier.insert_before(stmt, std::move(lowered)); |
181 | } |
182 | } |
183 | |
184 | void visit(LocalStoreStmt *stmt) override { |
185 | if (stmt->val->is<GlobalPtrStmt>()) { |
186 | auto lowered = lower_ptr(stmt->val->as<GlobalPtrStmt>(), true); |
187 | stmt->val = lowered.back().get(); |
188 | modifier.insert_before(stmt, std::move(lowered)); |
189 | } |
190 | } |
191 | |
192 | static bool run(IRNode *node, |
193 | const std::vector<SNode *> &kernel_forces_no_activate, |
194 | bool lower_atomic) { |
195 | LowerAccess inst(kernel_forces_no_activate, lower_atomic); |
196 | bool modified = false; |
197 | while (true) { |
198 | node->accept(&inst); |
199 | if (inst.modifier.modify_ir()) { |
200 | modified = true; |
201 | } else { |
202 | break; |
203 | } |
204 | } |
205 | return modified; |
206 | } |
207 | }; |
208 | |
209 | void PtrLowererImpl::set_lower_access(LowerAccess *la) { |
210 | la_ = la; |
211 | |
212 | snodes_on_loop_.clear(); |
213 | if (la_->current_struct_for) { |
214 | for (SNode *s = la_->current_struct_for->snode; s != nullptr; |
215 | s = s->parent) { |
216 | snodes_on_loop_.insert(s); |
217 | } |
218 | } |
219 | } |
220 | |
// Emits the access micro-ops for one level of the SNode path. Decides whether
// the access at this level is provably covered by the enclosing struct-for
// (in which case activation is unnecessary), then emits either a terminal
// SNodeOpStmt (for queries like is_active) or a lookup + get-child pair.
Stmt *PtrLowererImpl::handle_snode_at_level(int level,
                                            LinearizeStmt *linearized,
                                            Stmt *last) {
  // Check whether |snode| is part of the tree being iterated over by struct for
  auto *snode = snodes()[level];
  bool on_loop_tree = (snodes_on_loop_.find(snode) != snodes_on_loop_.end());
  auto *current_struct_for = la_->current_struct_for;
  if (on_loop_tree && current_struct_for &&
      (indices_.size() == current_struct_for->snode->num_active_indices)) {
    // The SNode is on the loop tree, but the access is only guaranteed active
    // if every index is tied to the corresponding loop index. Check each one
    // via value-diff analysis; any failure clears |on_loop_tree|.
    for (int j = 0; j < (int)indices_.size(); j++) {
      auto diff = irpass::analysis::value_diff_loop_index(
          indices_[j], current_struct_for, j);
      if (!diff.linear_related()) {
        on_loop_tree = false;
      } else if (j == (int)indices_.size() - 1) {
        // Last index: a small non-negative offset within [0, 1] is accepted.
        if (!(0 <= diff.low && diff.high <= 1)) {  // TODO: Vectorize
          on_loop_tree = false;
        }
      } else {
        // Non-last indices must equal the loop index exactly (offset 0).
        if (!diff.certain() || diff.low != 0) {
          on_loop_tree = false;
        }
      }
    }
  }

  // Generates the SNode access operations at the current |level|.
  if ((snode_op_ != SNodeOpType::undefined) &&
      (level == (int)snodes().size() - 1)) {
    // Create a SNodeOp querying if element i(linearized) of node is active
    lowered_->push_back<SNodeOpStmt>(snode_op_, snode, last, linearized);
  } else {
    // Activation is forced off when the kernel blacklists this SNode.
    const bool kernel_forces_no_activate_snode =
        std::find(la_->kernel_forces_no_activate.begin(),
                  la_->kernel_forces_no_activate.end(),
                  snode) != la_->kernel_forces_no_activate.end();

    // Activate only when: the SNode type supports it, the caller requested
    // it, the kernel doesn't forbid it, and the struct-for doesn't already
    // guarantee the cell is active.
    const bool needs_activation =
        snode->need_activation() && pointer_needs_activation_ &&
        !kernel_forces_no_activate_snode && !on_loop_tree;

    auto lookup = lowered_->push_back<SNodeLookupStmt>(snode, last, linearized,
                                                       needs_activation);
    int chid = snode->child_id(snodes()[level + 1]);
    // At the second-to-last level of a bit-vectorized dense access, mark the
    // GetChStmt as bit-vectorized so codegen uses the packed physical type.
    if (is_bit_vectorized_ && (snode->type == SNodeType::dense) &&
        (level == path_length() - 2)) {
      last = lowered_->push_back<GetChStmt>(lookup, chid,
                                            /*is_bit_vectorized=*/true);
    } else {
      last = lowered_->push_back<GetChStmt>(lookup, chid,
                                            /*is_bit_vectorized=*/false);
    }
  }
  return last;
}
276 | |
277 | } // namespace |
278 | |
279 | const PassID LowerAccessPass::id = "LowerAccessPass" ; |
280 | |
281 | namespace irpass { |
282 | |
283 | bool lower_access(IRNode *root, |
284 | const CompileConfig &config, |
285 | const LowerAccessPass::Args &args) { |
286 | bool modified = |
287 | LowerAccess::run(root, args.kernel_forces_no_activate, args.lower_atomic); |
288 | type_check(root, config); |
289 | return modified; |
290 | } |
291 | |
292 | } // namespace irpass |
293 | } // namespace taichi::lang |
294 | |