1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | // Transform each filtered statement |
9 | class StatementsTransformer : public BasicStmtVisitor { |
10 | public: |
11 | using BasicStmtVisitor::visit; |
12 | |
13 | StatementsTransformer( |
14 | std::function<bool(Stmt *)> filter, |
15 | std::function<void(Stmt *, DelayedIRModifier *)> transformer) |
16 | : filter_(std::move(filter)), transformer_(std::move(transformer)) { |
17 | allow_undefined_visitor = true; |
18 | invoke_default_visitor = true; |
19 | } |
20 | |
21 | void maybe_transform(Stmt *stmt) { |
22 | if (filter_(stmt)) { |
23 | transformer_(stmt, &modifier_); |
24 | } |
25 | } |
26 | |
27 | void preprocess_container_stmt(Stmt *stmt) override { |
28 | maybe_transform(stmt); |
29 | } |
30 | |
31 | void visit(Stmt *stmt) override { |
32 | maybe_transform(stmt); |
33 | } |
34 | |
35 | static bool run(IRNode *root, |
36 | std::function<bool(Stmt *)> filter, |
37 | std::function<void(Stmt *, DelayedIRModifier *)> replacer) { |
38 | StatementsTransformer transformer(std::move(filter), std::move(replacer)); |
39 | root->accept(&transformer); |
40 | return transformer.modifier_.modify_ir(); |
41 | } |
42 | |
43 | private: |
44 | std::function<bool(Stmt *)> filter_; |
45 | std::function<void(Stmt *, DelayedIRModifier *)> transformer_; |
46 | DelayedIRModifier modifier_; |
47 | }; |
48 | |
49 | namespace irpass { |
50 | |
51 | bool transform_statements( |
52 | IRNode *root, |
53 | std::function<bool(Stmt *)> filter, |
54 | std::function<void(Stmt *, DelayedIRModifier *)> transformer) { |
55 | return StatementsTransformer::run(root, std::move(filter), |
56 | std::move(transformer)); |
57 | } |
58 | |
59 | } // namespace irpass |
60 | |
61 | } // namespace taichi::lang |
62 | |