1#include "taichi/ir/transforms.h"
2
3namespace taichi::lang {
4
5namespace irpass {
6
7bool replace_and_insert_statements(
8 IRNode *root,
9 std::function<bool(Stmt *)> filter,
10 std::function<std::unique_ptr<Stmt>(Stmt *)> generator) {
11 return transform_statements(root, std::move(filter),
12 [&](Stmt *stmt, DelayedIRModifier *modifier) {
13 modifier->replace_with(stmt, generator(stmt));
14 });
15}
16
17bool replace_statements(IRNode *root,
18 std::function<bool(Stmt *)> filter,
19 std::function<Stmt *(Stmt *)> finder) {
20 return transform_statements(
21 root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) {
22 auto existing_new_stmt = finder(stmt);
23 irpass::replace_all_usages_with(root, stmt, existing_new_stmt);
24 modifier->erase(stmt);
25 });
26}
27
28} // namespace irpass
29
30} // namespace taichi::lang
31