1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5
6namespace taichi::lang {
7
8// Transform each filtered statement
9class StatementsTransformer : public BasicStmtVisitor {
10 public:
11 using BasicStmtVisitor::visit;
12
13 StatementsTransformer(
14 std::function<bool(Stmt *)> filter,
15 std::function<void(Stmt *, DelayedIRModifier *)> transformer)
16 : filter_(std::move(filter)), transformer_(std::move(transformer)) {
17 allow_undefined_visitor = true;
18 invoke_default_visitor = true;
19 }
20
21 void maybe_transform(Stmt *stmt) {
22 if (filter_(stmt)) {
23 transformer_(stmt, &modifier_);
24 }
25 }
26
27 void preprocess_container_stmt(Stmt *stmt) override {
28 maybe_transform(stmt);
29 }
30
31 void visit(Stmt *stmt) override {
32 maybe_transform(stmt);
33 }
34
35 static bool run(IRNode *root,
36 std::function<bool(Stmt *)> filter,
37 std::function<void(Stmt *, DelayedIRModifier *)> replacer) {
38 StatementsTransformer transformer(std::move(filter), std::move(replacer));
39 root->accept(&transformer);
40 return transformer.modifier_.modify_ir();
41 }
42
43 private:
44 std::function<bool(Stmt *)> filter_;
45 std::function<void(Stmt *, DelayedIRModifier *)> transformer_;
46 DelayedIRModifier modifier_;
47};
48
49namespace irpass {
50
51bool transform_statements(
52 IRNode *root,
53 std::function<bool(Stmt *)> filter,
54 std::function<void(Stmt *, DelayedIRModifier *)> transformer) {
55 return StatementsTransformer::run(root, std::move(filter),
56 std::move(transformer));
57}
58
59} // namespace irpass
60
61} // namespace taichi::lang
62