1 | #include "taichi/ir/ir.h" |
2 | |
3 | #include <numeric> |
4 | #include <thread> |
5 | #include <unordered_map> |
6 | |
7 | #include "taichi/ir/analysis.h" |
8 | #include "taichi/ir/statements.h" |
9 | #include "taichi/ir/transforms.h" |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | std::string snode_access_flag_name(SNodeAccessFlag type) { |
14 | if (type == SNodeAccessFlag::block_local) { |
15 | return "block_local" ; |
16 | } else if (type == SNodeAccessFlag::read_only) { |
17 | return "read_only" ; |
18 | } else if (type == SNodeAccessFlag::mesh_local) { |
19 | return "mesh_local" ; |
20 | } else { |
21 | TI_ERROR("Undefined SNode AccessType (value={})" , int(type)); |
22 | } |
23 | } |
24 | |
25 | std::string Identifier::raw_name() const { |
26 | if (name_.empty()) |
27 | return fmt::format("tmp{}" , id); |
28 | else |
29 | return name_; |
30 | } |
31 | |
32 | Stmt *VecStatement::push_back(pStmt &&stmt) { |
33 | auto ret = stmt.get(); |
34 | stmts.push_back(std::move(stmt)); |
35 | return ret; |
36 | } |
37 | |
38 | IRNode *IRNode::get_ir_root() { |
39 | auto node = this; |
40 | while (node->get_parent()) { |
41 | node = node->get_parent(); |
42 | } |
43 | return node; |
44 | } |
45 | |
46 | std::unique_ptr<IRNode> IRNode::clone() { |
47 | std::unique_ptr<IRNode> new_irnode; |
48 | if (is<Block>()) |
49 | new_irnode = as<Block>()->clone(); |
50 | else if (is<Stmt>()) |
51 | new_irnode = as<Stmt>()->clone(); |
52 | else { |
53 | TI_NOT_IMPLEMENTED |
54 | } |
55 | return new_irnode; |
56 | } |
57 | |
58 | class StatementTypeNameVisitor : public IRVisitor { |
59 | public: |
60 | std::string type_name; |
61 | StatementTypeNameVisitor() { |
62 | } |
63 | |
64 | #define PER_STATEMENT(x) \ |
65 | void visit(x *stmt) override { type_name = #x; } |
66 | #include "taichi/inc/statements.inc.h" |
67 | |
68 | #undef PER_STATEMENT |
69 | }; |
70 | |
71 | int StmtFieldSNode::get_snode_id(SNode *snode) { |
72 | if (snode == nullptr) |
73 | return -1; |
74 | return snode->id; |
75 | } |
76 | |
77 | bool StmtFieldSNode::equal(const StmtField *other_generic) const { |
78 | if (auto other = dynamic_cast<const StmtFieldSNode *>(other_generic)) { |
79 | return get_snode_id(snode_) == get_snode_id(other->snode_); |
80 | } else { |
81 | // Different types |
82 | return false; |
83 | } |
84 | } |
85 | |
86 | bool StmtFieldMemoryAccessOptions::equal(const StmtField *other_generic) const { |
87 | if (auto other = |
88 | dynamic_cast<const StmtFieldMemoryAccessOptions *>(other_generic)) { |
89 | return opt_.get_all() == other->opt_.get_all(); |
90 | } else { |
91 | // Different types |
92 | return false; |
93 | } |
94 | } |
95 | |
96 | bool StmtFieldManager::equal(StmtFieldManager &other) const { |
97 | if (fields.size() != other.fields.size()) { |
98 | return false; |
99 | } |
100 | auto num_fields = fields.size(); |
101 | for (std::size_t i = 0; i < num_fields; i++) { |
102 | if (!fields[i]->equal(other.fields[i].get())) { |
103 | return false; |
104 | } |
105 | } |
106 | return true; |
107 | } |
108 | |
109 | std::atomic<int> Stmt::instance_id_counter(0); |
110 | |
111 | Stmt::Stmt() : field_manager(this), fields_registered(false) { |
112 | parent = nullptr; |
113 | instance_id = instance_id_counter++; |
114 | id = instance_id; |
115 | erased = false; |
116 | } |
117 | |
118 | Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { |
119 | parent = stmt.parent; |
120 | instance_id = instance_id_counter++; |
121 | id = instance_id; |
122 | erased = stmt.erased; |
123 | tb = stmt.tb; |
124 | ret_type = stmt.ret_type; |
125 | } |
126 | |
127 | Stmt *Stmt::insert_before_me(std::unique_ptr<Stmt> &&new_stmt) { |
128 | auto ret = new_stmt.get(); |
129 | TI_ASSERT(parent); |
130 | auto iter = parent->find(this); |
131 | TI_ASSERT(iter != parent->statements.end()); |
132 | parent->insert_at(std::move(new_stmt), iter); |
133 | return ret; |
134 | } |
135 | |
136 | Stmt *Stmt::insert_after_me(std::unique_ptr<Stmt> &&new_stmt) { |
137 | auto ret = new_stmt.get(); |
138 | TI_ASSERT(parent); |
139 | auto iter = parent->find(this); |
140 | TI_ASSERT(iter != parent->statements.end()); |
141 | parent->insert_at(std::move(new_stmt), std::next(iter)); |
142 | return ret; |
143 | } |
144 | |
145 | void Stmt::replace_usages_with(Stmt *new_stmt) { |
146 | irpass::replace_all_usages_with(nullptr, this, new_stmt); |
147 | } |
148 | |
149 | void Stmt::replace_with(VecStatement &&new_statements, bool replace_usages) { |
150 | parent->replace_with(this, std::move(new_statements), replace_usages); |
151 | } |
152 | |
153 | void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) { |
154 | int n_op = num_operands(); |
155 | for (int i = 0; i < n_op; i++) { |
156 | if (operand(i) == old_stmt) { |
157 | *operands[i] = new_stmt; |
158 | } |
159 | } |
160 | } |
161 | |
162 | std::string Stmt::type_hint() const { |
163 | if (ret_type->is_primitive(PrimitiveTypeID::unknown)) |
164 | return "" ; |
165 | else |
166 | return fmt::format("<{}> " , ret_type.to_string()); |
167 | } |
168 | |
169 | std::string Stmt::type() { |
170 | StatementTypeNameVisitor v; |
171 | this->accept(&v); |
172 | return v.type_name; |
173 | } |
174 | |
175 | IRNode *Stmt::get_parent() const { |
176 | return parent; |
177 | } |
178 | |
179 | std::vector<Stmt *> Stmt::get_operands() const { |
180 | std::vector<Stmt *> ret; |
181 | for (int i = 0; i < num_operands(); i++) { |
182 | ret.push_back(*operands[i]); |
183 | } |
184 | return ret; |
185 | } |
186 | |
187 | void Stmt::set_operand(int i, Stmt *stmt) { |
188 | *operands[i] = stmt; |
189 | } |
190 | |
191 | void Stmt::register_operand(Stmt *&stmt) { |
192 | operands.push_back(&stmt); |
193 | } |
194 | |
195 | void Stmt::mark_fields_registered() { |
196 | TI_ASSERT(!fields_registered); |
197 | fields_registered = true; |
198 | } |
199 | |
200 | bool Stmt::has_operand(Stmt *stmt) const { |
201 | for (int i = 0; i < num_operands(); i++) { |
202 | if (*operands[i] == stmt) { |
203 | return true; |
204 | } |
205 | } |
206 | return false; |
207 | } |
208 | |
209 | int Stmt::locate_operand(Stmt **stmt) { |
210 | for (int i = 0; i < num_operands(); i++) { |
211 | if (operands[i] == stmt) { |
212 | return i; |
213 | } |
214 | } |
215 | return -1; |
216 | } |
217 | |
218 | void Block::erase(int location) { |
219 | auto iter = locate(location); |
220 | erase_range(iter, std::next(iter)); |
221 | } |
222 | |
223 | void Block::erase(Stmt *stmt) { |
224 | auto iter = find(stmt); |
225 | erase_range(iter, std::next(iter)); |
226 | } |
227 | |
228 | void Block::erase_range(stmt_vector::iterator begin, |
229 | stmt_vector::iterator end) { |
230 | for (auto iter = begin; iter != end; iter++) { |
231 | (*iter)->erased = true; |
232 | trash_bin.push_back(std::move(*iter)); |
233 | } |
234 | statements.erase(begin, end); |
235 | } |
236 | |
237 | void Block::erase(std::unordered_set<Stmt *> stmts) { |
238 | stmt_vector clean_stmts; |
239 | clean_stmts.reserve(statements.size()); |
240 | // We dont have access to erase_if in C++17 |
241 | for (pStmt &stmt : statements) { |
242 | if (stmts.find(stmt.get()) != stmts.end()) { |
243 | stmt->erased = true; |
244 | trash_bin.push_back(std::move(stmt)); |
245 | } else { |
246 | clean_stmts.push_back(std::move(stmt)); |
247 | } |
248 | } |
249 | statements = std::move(clean_stmts); |
250 | } |
251 | |
252 | std::unique_ptr<Stmt> Block::(int location) { |
253 | auto stmt = std::move(statements[location]); |
254 | statements.erase(statements.begin() + location); |
255 | return stmt; |
256 | } |
257 | |
258 | std::unique_ptr<Stmt> Block::(Stmt *stmt) { |
259 | for (int i = 0; i < (int)statements.size(); i++) { |
260 | if (statements[i].get() == stmt) { |
261 | return extract(i); |
262 | } |
263 | } |
264 | TI_ERROR("stmt not found" ); |
265 | } |
266 | |
267 | Stmt *Block::insert(std::unique_ptr<Stmt> &&stmt, int location) { |
268 | return insert_at(std::move(stmt), locate(location)); |
269 | } |
270 | |
271 | Stmt *Block::insert_at(std::unique_ptr<Stmt> &&stmt, |
272 | stmt_vector::iterator location) { |
273 | auto stmt_ptr = stmt.get(); |
274 | stmt->parent = this; |
275 | statements.insert(location, std::move(stmt)); |
276 | return stmt_ptr; |
277 | } |
278 | |
279 | Stmt *Block::insert(VecStatement &&stmt, int location) { |
280 | return insert_at(std::move(stmt), locate(location)); |
281 | } |
282 | |
283 | Stmt *Block::insert_at(VecStatement &&stmt, stmt_vector::iterator location) { |
284 | Stmt *stmt_ptr = nullptr; |
285 | if (stmt.size()) { |
286 | stmt_ptr = stmt.back().get(); |
287 | } |
288 | for (auto &s : stmt.stmts) { |
289 | s->parent = this; |
290 | } |
291 | statements.insert(location, std::make_move_iterator(stmt.stmts.begin()), |
292 | std::make_move_iterator(stmt.stmts.end())); |
293 | return stmt_ptr; |
294 | } |
295 | |
296 | void Block::replace_statements_in_range(int start, |
297 | int end, |
298 | VecStatement &&stmts) { |
299 | TI_ASSERT(start <= end); |
300 | erase_range(locate(start), locate(end)); |
301 | insert(std::move(stmts), start); |
302 | } |
303 | |
304 | void Block::replace_with(Stmt *old_statement, |
305 | std::unique_ptr<Stmt> &&new_statement, |
306 | bool replace_usages) { |
307 | VecStatement vec; |
308 | vec.push_back(std::move(new_statement)); |
309 | replace_with(old_statement, std::move(vec), replace_usages); |
310 | } |
311 | |
312 | Stmt *Block::lookup_var(const Identifier &ident) const { |
313 | auto ptr = local_var_to_stmt.find(ident); |
314 | if (ptr != local_var_to_stmt.end()) { |
315 | return ptr->second; |
316 | } else { |
317 | if (parent_block()) { |
318 | return parent_block()->lookup_var(ident); |
319 | } else { |
320 | return nullptr; |
321 | } |
322 | } |
323 | } |
324 | |
325 | void Block::set_statements(VecStatement &&stmts) { |
326 | statements.clear(); |
327 | for (int i = 0; i < (int)stmts.size(); i++) { |
328 | insert(std::move(stmts[i]), i); |
329 | } |
330 | } |
331 | |
332 | void Block::insert_before(Stmt *old_statement, VecStatement &&new_statements) { |
333 | insert_at(std::move(new_statements), find(old_statement)); |
334 | } |
335 | |
336 | void Block::insert_after(Stmt *old_statement, VecStatement &&new_statements) { |
337 | insert_at(std::move(new_statements), std::next(find(old_statement))); |
338 | } |
339 | |
340 | void Block::replace_with(Stmt *old_statement, |
341 | VecStatement &&new_statements, |
342 | bool replace_usages) { |
343 | auto iter = find(old_statement); |
344 | TI_ASSERT(iter != statements.end()); |
345 | if (replace_usages && !new_statements.stmts.empty()) |
346 | old_statement->replace_usages_with(new_statements.back().get()); |
347 | trash_bin.push_back(std::move(*iter)); |
348 | if (new_statements.size() == 1) { |
349 | // Keep all std::vector::iterator valid in this case. |
350 | *iter = std::move(new_statements[0]); |
351 | (*iter)->parent = this; |
352 | } else { |
353 | iter = statements.erase(iter); |
354 | insert_at(std::move(new_statements), iter); |
355 | } |
356 | } |
357 | |
358 | Block *Block::parent_block() const { |
359 | if (parent_stmt == nullptr) |
360 | return nullptr; |
361 | return parent_stmt->parent; |
362 | } |
363 | |
364 | IRNode *Block::get_parent() const { |
365 | return parent_stmt; |
366 | } |
367 | |
368 | bool Block::has_container_statements() { |
369 | for (auto &s : statements) { |
370 | if (s->is_container_statement()) |
371 | return true; |
372 | } |
373 | return false; |
374 | } |
375 | |
376 | int Block::locate(Stmt *stmt) { |
377 | for (int i = 0; i < (int)statements.size(); i++) { |
378 | if (statements[i].get() == stmt) { |
379 | return i; |
380 | } |
381 | } |
382 | return -1; |
383 | } |
384 | |
385 | stmt_vector::iterator Block::locate(int location) { |
386 | if (location == -1) |
387 | return statements.end(); |
388 | return statements.begin() + location; |
389 | } |
390 | |
391 | stmt_vector::iterator Block::find(Stmt *stmt) { |
392 | return std::find_if(statements.begin(), statements.end(), |
393 | [stmt](const pStmt &x) { return x.get() == stmt; }); |
394 | } |
395 | |
396 | std::unique_ptr<Block> Block::clone() const { |
397 | auto new_block = std::make_unique<Block>(); |
398 | new_block->parent_stmt = parent_stmt; |
399 | new_block->stop_gradients = stop_gradients; |
400 | new_block->statements.reserve(size()); |
401 | for (auto &stmt : statements) |
402 | new_block->insert(stmt->clone()); |
403 | return new_block; |
404 | } |
405 | |
406 | DelayedIRModifier::~DelayedIRModifier() { |
407 | // TODO: destructors should not be interrupted |
408 | TI_ASSERT(to_insert_before_.empty()); |
409 | TI_ASSERT(to_insert_after_.empty()); |
410 | TI_ASSERT(to_erase_.empty()); |
411 | TI_ASSERT(to_replace_with_.empty()); |
412 | TI_ASSERT(to_extract_to_block_front_.empty()); |
413 | TI_ASSERT(to_type_check_.empty()); |
414 | } |
415 | |
416 | void DelayedIRModifier::erase(Stmt *stmt) { |
417 | to_erase_.push_back(stmt); |
418 | } |
419 | |
420 | void DelayedIRModifier::insert_before(Stmt *old_statement, |
421 | std::unique_ptr<Stmt> new_statements) { |
422 | to_insert_before_.emplace_back(old_statement, |
423 | VecStatement(std::move(new_statements))); |
424 | } |
425 | |
426 | void DelayedIRModifier::insert_before(Stmt *old_statement, |
427 | VecStatement &&new_statements) { |
428 | to_insert_before_.emplace_back(old_statement, std::move(new_statements)); |
429 | } |
430 | |
431 | void DelayedIRModifier::insert_after(Stmt *old_statement, |
432 | std::unique_ptr<Stmt> new_statements) { |
433 | to_insert_after_.emplace_back(old_statement, |
434 | VecStatement(std::move(new_statements))); |
435 | } |
436 | |
437 | void DelayedIRModifier::insert_after(Stmt *old_statement, |
438 | VecStatement &&new_statements) { |
439 | to_insert_after_.emplace_back(old_statement, std::move(new_statements)); |
440 | } |
441 | |
442 | void DelayedIRModifier::replace_with(Stmt *stmt, |
443 | VecStatement &&new_statements, |
444 | bool replace_usages) { |
445 | to_replace_with_.emplace_back(stmt, std::move(new_statements), |
446 | replace_usages); |
447 | } |
448 | |
449 | void DelayedIRModifier::(Stmt *stmt, Block *blk) { |
450 | to_extract_to_block_front_.emplace_back(stmt, blk); |
451 | } |
452 | |
453 | void DelayedIRModifier::type_check(IRNode *node, CompileConfig cfg) { |
454 | to_type_check_.emplace_back(node, cfg); |
455 | } |
456 | |
457 | bool DelayedIRModifier::modify_ir() { |
458 | bool force_modified = modified_; |
459 | modified_ = false; |
460 | if (to_insert_before_.empty() && to_insert_after_.empty() && |
461 | to_erase_.empty() && to_replace_with_.empty() && |
462 | to_extract_to_block_front_.empty() && to_type_check_.empty()) |
463 | return force_modified; |
464 | for (auto &i : to_insert_before_) { |
465 | i.first->parent->insert_before(i.first, std::move(i.second)); |
466 | } |
467 | to_insert_before_.clear(); |
468 | for (auto &i : to_insert_after_) { |
469 | i.first->parent->insert_after(i.first, std::move(i.second)); |
470 | } |
471 | to_insert_after_.clear(); |
472 | for (auto &stmt : to_erase_) { |
473 | stmt->parent->erase(stmt); |
474 | } |
475 | to_erase_.clear(); |
476 | for (auto &i : to_replace_with_) { |
477 | std::get<0>(i)->replace_with(std::move(std::get<1>(i)), std::get<2>(i)); |
478 | } |
479 | to_replace_with_.clear(); |
480 | for (auto &i : to_extract_to_block_front_) { |
481 | auto = i.first->parent->extract(i.first); |
482 | i.second->insert(std::move(extracted), 0); |
483 | } |
484 | to_extract_to_block_front_.clear(); |
485 | for (auto &i : to_type_check_) { |
486 | irpass::type_check(i.first, i.second); |
487 | } |
488 | to_type_check_.clear(); |
489 | return true; |
490 | } |
491 | |
492 | void DelayedIRModifier::mark_as_modified() { |
493 | modified_ = true; |
494 | } |
495 | |
496 | ImmediateIRModifier::ImmediateIRModifier(IRNode *root) { |
497 | stmt_usages_ = irpass::analysis::gather_statement_usages(root); |
498 | } |
499 | |
500 | void ImmediateIRModifier::replace_usages_with(Stmt *old_stmt, Stmt *new_stmt) { |
501 | if (stmt_usages_.find(old_stmt) == stmt_usages_.end()) |
502 | return; |
503 | for (auto &[usage, i] : stmt_usages_.at(old_stmt)) { |
504 | usage->set_operand(i, new_stmt); |
505 | } |
506 | } |
507 | |
508 | } // namespace taichi::lang |
509 | |