#include "taichi/ir/ir.h"

#include <algorithm>
#include <numeric>
#include <thread>
#include <unordered_map>

#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
10
11namespace taichi::lang {
12
13std::string snode_access_flag_name(SNodeAccessFlag type) {
14 if (type == SNodeAccessFlag::block_local) {
15 return "block_local";
16 } else if (type == SNodeAccessFlag::read_only) {
17 return "read_only";
18 } else if (type == SNodeAccessFlag::mesh_local) {
19 return "mesh_local";
20 } else {
21 TI_ERROR("Undefined SNode AccessType (value={})", int(type));
22 }
23}
24
25std::string Identifier::raw_name() const {
26 if (name_.empty())
27 return fmt::format("tmp{}", id);
28 else
29 return name_;
30}
31
32Stmt *VecStatement::push_back(pStmt &&stmt) {
33 auto ret = stmt.get();
34 stmts.push_back(std::move(stmt));
35 return ret;
36}
37
38IRNode *IRNode::get_ir_root() {
39 auto node = this;
40 while (node->get_parent()) {
41 node = node->get_parent();
42 }
43 return node;
44}
45
46std::unique_ptr<IRNode> IRNode::clone() {
47 std::unique_ptr<IRNode> new_irnode;
48 if (is<Block>())
49 new_irnode = as<Block>()->clone();
50 else if (is<Stmt>())
51 new_irnode = as<Stmt>()->clone();
52 else {
53 TI_NOT_IMPLEMENTED
54 }
55 return new_irnode;
56}
57
58class StatementTypeNameVisitor : public IRVisitor {
59 public:
60 std::string type_name;
61 StatementTypeNameVisitor() {
62 }
63
64#define PER_STATEMENT(x) \
65 void visit(x *stmt) override { type_name = #x; }
66#include "taichi/inc/statements.inc.h"
67
68#undef PER_STATEMENT
69};
70
71int StmtFieldSNode::get_snode_id(SNode *snode) {
72 if (snode == nullptr)
73 return -1;
74 return snode->id;
75}
76
77bool StmtFieldSNode::equal(const StmtField *other_generic) const {
78 if (auto other = dynamic_cast<const StmtFieldSNode *>(other_generic)) {
79 return get_snode_id(snode_) == get_snode_id(other->snode_);
80 } else {
81 // Different types
82 return false;
83 }
84}
85
86bool StmtFieldMemoryAccessOptions::equal(const StmtField *other_generic) const {
87 if (auto other =
88 dynamic_cast<const StmtFieldMemoryAccessOptions *>(other_generic)) {
89 return opt_.get_all() == other->opt_.get_all();
90 } else {
91 // Different types
92 return false;
93 }
94}
95
96bool StmtFieldManager::equal(StmtFieldManager &other) const {
97 if (fields.size() != other.fields.size()) {
98 return false;
99 }
100 auto num_fields = fields.size();
101 for (std::size_t i = 0; i < num_fields; i++) {
102 if (!fields[i]->equal(other.fields[i].get())) {
103 return false;
104 }
105 }
106 return true;
107}
108
109std::atomic<int> Stmt::instance_id_counter(0);
110
111Stmt::Stmt() : field_manager(this), fields_registered(false) {
112 parent = nullptr;
113 instance_id = instance_id_counter++;
114 id = instance_id;
115 erased = false;
116}
117
118Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) {
119 parent = stmt.parent;
120 instance_id = instance_id_counter++;
121 id = instance_id;
122 erased = stmt.erased;
123 tb = stmt.tb;
124 ret_type = stmt.ret_type;
125}
126
127Stmt *Stmt::insert_before_me(std::unique_ptr<Stmt> &&new_stmt) {
128 auto ret = new_stmt.get();
129 TI_ASSERT(parent);
130 auto iter = parent->find(this);
131 TI_ASSERT(iter != parent->statements.end());
132 parent->insert_at(std::move(new_stmt), iter);
133 return ret;
134}
135
136Stmt *Stmt::insert_after_me(std::unique_ptr<Stmt> &&new_stmt) {
137 auto ret = new_stmt.get();
138 TI_ASSERT(parent);
139 auto iter = parent->find(this);
140 TI_ASSERT(iter != parent->statements.end());
141 parent->insert_at(std::move(new_stmt), std::next(iter));
142 return ret;
143}
144
145void Stmt::replace_usages_with(Stmt *new_stmt) {
146 irpass::replace_all_usages_with(nullptr, this, new_stmt);
147}
148
149void Stmt::replace_with(VecStatement &&new_statements, bool replace_usages) {
150 parent->replace_with(this, std::move(new_statements), replace_usages);
151}
152
153void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) {
154 int n_op = num_operands();
155 for (int i = 0; i < n_op; i++) {
156 if (operand(i) == old_stmt) {
157 *operands[i] = new_stmt;
158 }
159 }
160}
161
162std::string Stmt::type_hint() const {
163 if (ret_type->is_primitive(PrimitiveTypeID::unknown))
164 return "";
165 else
166 return fmt::format("<{}> ", ret_type.to_string());
167}
168
169std::string Stmt::type() {
170 StatementTypeNameVisitor v;
171 this->accept(&v);
172 return v.type_name;
173}
174
175IRNode *Stmt::get_parent() const {
176 return parent;
177}
178
179std::vector<Stmt *> Stmt::get_operands() const {
180 std::vector<Stmt *> ret;
181 for (int i = 0; i < num_operands(); i++) {
182 ret.push_back(*operands[i]);
183 }
184 return ret;
185}
186
187void Stmt::set_operand(int i, Stmt *stmt) {
188 *operands[i] = stmt;
189}
190
191void Stmt::register_operand(Stmt *&stmt) {
192 operands.push_back(&stmt);
193}
194
195void Stmt::mark_fields_registered() {
196 TI_ASSERT(!fields_registered);
197 fields_registered = true;
198}
199
200bool Stmt::has_operand(Stmt *stmt) const {
201 for (int i = 0; i < num_operands(); i++) {
202 if (*operands[i] == stmt) {
203 return true;
204 }
205 }
206 return false;
207}
208
209int Stmt::locate_operand(Stmt **stmt) {
210 for (int i = 0; i < num_operands(); i++) {
211 if (operands[i] == stmt) {
212 return i;
213 }
214 }
215 return -1;
216}
217
218void Block::erase(int location) {
219 auto iter = locate(location);
220 erase_range(iter, std::next(iter));
221}
222
223void Block::erase(Stmt *stmt) {
224 auto iter = find(stmt);
225 erase_range(iter, std::next(iter));
226}
227
228void Block::erase_range(stmt_vector::iterator begin,
229 stmt_vector::iterator end) {
230 for (auto iter = begin; iter != end; iter++) {
231 (*iter)->erased = true;
232 trash_bin.push_back(std::move(*iter));
233 }
234 statements.erase(begin, end);
235}
236
237void Block::erase(std::unordered_set<Stmt *> stmts) {
238 stmt_vector clean_stmts;
239 clean_stmts.reserve(statements.size());
240 // We dont have access to erase_if in C++17
241 for (pStmt &stmt : statements) {
242 if (stmts.find(stmt.get()) != stmts.end()) {
243 stmt->erased = true;
244 trash_bin.push_back(std::move(stmt));
245 } else {
246 clean_stmts.push_back(std::move(stmt));
247 }
248 }
249 statements = std::move(clean_stmts);
250}
251
252std::unique_ptr<Stmt> Block::extract(int location) {
253 auto stmt = std::move(statements[location]);
254 statements.erase(statements.begin() + location);
255 return stmt;
256}
257
258std::unique_ptr<Stmt> Block::extract(Stmt *stmt) {
259 for (int i = 0; i < (int)statements.size(); i++) {
260 if (statements[i].get() == stmt) {
261 return extract(i);
262 }
263 }
264 TI_ERROR("stmt not found");
265}
266
267Stmt *Block::insert(std::unique_ptr<Stmt> &&stmt, int location) {
268 return insert_at(std::move(stmt), locate(location));
269}
270
271Stmt *Block::insert_at(std::unique_ptr<Stmt> &&stmt,
272 stmt_vector::iterator location) {
273 auto stmt_ptr = stmt.get();
274 stmt->parent = this;
275 statements.insert(location, std::move(stmt));
276 return stmt_ptr;
277}
278
279Stmt *Block::insert(VecStatement &&stmt, int location) {
280 return insert_at(std::move(stmt), locate(location));
281}
282
283Stmt *Block::insert_at(VecStatement &&stmt, stmt_vector::iterator location) {
284 Stmt *stmt_ptr = nullptr;
285 if (stmt.size()) {
286 stmt_ptr = stmt.back().get();
287 }
288 for (auto &s : stmt.stmts) {
289 s->parent = this;
290 }
291 statements.insert(location, std::make_move_iterator(stmt.stmts.begin()),
292 std::make_move_iterator(stmt.stmts.end()));
293 return stmt_ptr;
294}
295
296void Block::replace_statements_in_range(int start,
297 int end,
298 VecStatement &&stmts) {
299 TI_ASSERT(start <= end);
300 erase_range(locate(start), locate(end));
301 insert(std::move(stmts), start);
302}
303
304void Block::replace_with(Stmt *old_statement,
305 std::unique_ptr<Stmt> &&new_statement,
306 bool replace_usages) {
307 VecStatement vec;
308 vec.push_back(std::move(new_statement));
309 replace_with(old_statement, std::move(vec), replace_usages);
310}
311
312Stmt *Block::lookup_var(const Identifier &ident) const {
313 auto ptr = local_var_to_stmt.find(ident);
314 if (ptr != local_var_to_stmt.end()) {
315 return ptr->second;
316 } else {
317 if (parent_block()) {
318 return parent_block()->lookup_var(ident);
319 } else {
320 return nullptr;
321 }
322 }
323}
324
325void Block::set_statements(VecStatement &&stmts) {
326 statements.clear();
327 for (int i = 0; i < (int)stmts.size(); i++) {
328 insert(std::move(stmts[i]), i);
329 }
330}
331
332void Block::insert_before(Stmt *old_statement, VecStatement &&new_statements) {
333 insert_at(std::move(new_statements), find(old_statement));
334}
335
336void Block::insert_after(Stmt *old_statement, VecStatement &&new_statements) {
337 insert_at(std::move(new_statements), std::next(find(old_statement)));
338}
339
340void Block::replace_with(Stmt *old_statement,
341 VecStatement &&new_statements,
342 bool replace_usages) {
343 auto iter = find(old_statement);
344 TI_ASSERT(iter != statements.end());
345 if (replace_usages && !new_statements.stmts.empty())
346 old_statement->replace_usages_with(new_statements.back().get());
347 trash_bin.push_back(std::move(*iter));
348 if (new_statements.size() == 1) {
349 // Keep all std::vector::iterator valid in this case.
350 *iter = std::move(new_statements[0]);
351 (*iter)->parent = this;
352 } else {
353 iter = statements.erase(iter);
354 insert_at(std::move(new_statements), iter);
355 }
356}
357
358Block *Block::parent_block() const {
359 if (parent_stmt == nullptr)
360 return nullptr;
361 return parent_stmt->parent;
362}
363
364IRNode *Block::get_parent() const {
365 return parent_stmt;
366}
367
368bool Block::has_container_statements() {
369 for (auto &s : statements) {
370 if (s->is_container_statement())
371 return true;
372 }
373 return false;
374}
375
376int Block::locate(Stmt *stmt) {
377 for (int i = 0; i < (int)statements.size(); i++) {
378 if (statements[i].get() == stmt) {
379 return i;
380 }
381 }
382 return -1;
383}
384
385stmt_vector::iterator Block::locate(int location) {
386 if (location == -1)
387 return statements.end();
388 return statements.begin() + location;
389}
390
391stmt_vector::iterator Block::find(Stmt *stmt) {
392 return std::find_if(statements.begin(), statements.end(),
393 [stmt](const pStmt &x) { return x.get() == stmt; });
394}
395
396std::unique_ptr<Block> Block::clone() const {
397 auto new_block = std::make_unique<Block>();
398 new_block->parent_stmt = parent_stmt;
399 new_block->stop_gradients = stop_gradients;
400 new_block->statements.reserve(size());
401 for (auto &stmt : statements)
402 new_block->insert(stmt->clone());
403 return new_block;
404}
405
// Asserts that every queued edit has been applied: modify_ir() clears all
// queues, so a non-empty queue here means pending edits were dropped.
DelayedIRModifier::~DelayedIRModifier() {
  // TODO: destructors should not be interrupted
  // NOTE(review): if TI_ASSERT raises on failure, raising out of this
  // (implicitly noexcept) destructor would call std::terminate — confirm.
  TI_ASSERT(to_insert_before_.empty());
  TI_ASSERT(to_insert_after_.empty());
  TI_ASSERT(to_erase_.empty());
  TI_ASSERT(to_replace_with_.empty());
  TI_ASSERT(to_extract_to_block_front_.empty());
  TI_ASSERT(to_type_check_.empty());
}
415
416void DelayedIRModifier::erase(Stmt *stmt) {
417 to_erase_.push_back(stmt);
418}
419
420void DelayedIRModifier::insert_before(Stmt *old_statement,
421 std::unique_ptr<Stmt> new_statements) {
422 to_insert_before_.emplace_back(old_statement,
423 VecStatement(std::move(new_statements)));
424}
425
426void DelayedIRModifier::insert_before(Stmt *old_statement,
427 VecStatement &&new_statements) {
428 to_insert_before_.emplace_back(old_statement, std::move(new_statements));
429}
430
431void DelayedIRModifier::insert_after(Stmt *old_statement,
432 std::unique_ptr<Stmt> new_statements) {
433 to_insert_after_.emplace_back(old_statement,
434 VecStatement(std::move(new_statements)));
435}
436
437void DelayedIRModifier::insert_after(Stmt *old_statement,
438 VecStatement &&new_statements) {
439 to_insert_after_.emplace_back(old_statement, std::move(new_statements));
440}
441
442void DelayedIRModifier::replace_with(Stmt *stmt,
443 VecStatement &&new_statements,
444 bool replace_usages) {
445 to_replace_with_.emplace_back(stmt, std::move(new_statements),
446 replace_usages);
447}
448
449void DelayedIRModifier::extract_to_block_front(Stmt *stmt, Block *blk) {
450 to_extract_to_block_front_.emplace_back(stmt, blk);
451}
452
453void DelayedIRModifier::type_check(IRNode *node, CompileConfig cfg) {
454 to_type_check_.emplace_back(node, cfg);
455}
456
457bool DelayedIRModifier::modify_ir() {
458 bool force_modified = modified_;
459 modified_ = false;
460 if (to_insert_before_.empty() && to_insert_after_.empty() &&
461 to_erase_.empty() && to_replace_with_.empty() &&
462 to_extract_to_block_front_.empty() && to_type_check_.empty())
463 return force_modified;
464 for (auto &i : to_insert_before_) {
465 i.first->parent->insert_before(i.first, std::move(i.second));
466 }
467 to_insert_before_.clear();
468 for (auto &i : to_insert_after_) {
469 i.first->parent->insert_after(i.first, std::move(i.second));
470 }
471 to_insert_after_.clear();
472 for (auto &stmt : to_erase_) {
473 stmt->parent->erase(stmt);
474 }
475 to_erase_.clear();
476 for (auto &i : to_replace_with_) {
477 std::get<0>(i)->replace_with(std::move(std::get<1>(i)), std::get<2>(i));
478 }
479 to_replace_with_.clear();
480 for (auto &i : to_extract_to_block_front_) {
481 auto extracted = i.first->parent->extract(i.first);
482 i.second->insert(std::move(extracted), 0);
483 }
484 to_extract_to_block_front_.clear();
485 for (auto &i : to_type_check_) {
486 irpass::type_check(i.first, i.second);
487 }
488 to_type_check_.clear();
489 return true;
490}
491
492void DelayedIRModifier::mark_as_modified() {
493 modified_ = true;
494}
495
496ImmediateIRModifier::ImmediateIRModifier(IRNode *root) {
497 stmt_usages_ = irpass::analysis::gather_statement_usages(root);
498}
499
500void ImmediateIRModifier::replace_usages_with(Stmt *old_stmt, Stmt *new_stmt) {
501 if (stmt_usages_.find(old_stmt) == stmt_usages_.end())
502 return;
503 for (auto &[usage, i] : stmt_usages_.at(old_stmt)) {
504 usage->set_operand(i, new_stmt);
505 }
506}
507
508} // namespace taichi::lang
509