1 | #include "taichi/ir/transforms.h" |
2 | |
3 | namespace taichi::lang { |
4 | |
5 | namespace irpass { |
6 | |
7 | bool replace_and_insert_statements( |
8 | IRNode *root, |
9 | std::function<bool(Stmt *)> filter, |
10 | std::function<std::unique_ptr<Stmt>(Stmt *)> generator) { |
11 | return transform_statements(root, std::move(filter), |
12 | [&](Stmt *stmt, DelayedIRModifier *modifier) { |
13 | modifier->replace_with(stmt, generator(stmt)); |
14 | }); |
15 | } |
16 | |
17 | bool replace_statements(IRNode *root, |
18 | std::function<bool(Stmt *)> filter, |
19 | std::function<Stmt *(Stmt *)> finder) { |
20 | return transform_statements( |
21 | root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) { |
22 | auto existing_new_stmt = finder(stmt); |
23 | irpass::replace_all_usages_with(root, stmt, existing_new_stmt); |
24 | modifier->erase(stmt); |
25 | }); |
26 | } |
27 | |
28 | } // namespace irpass |
29 | |
30 | } // namespace taichi::lang |
31 | |