1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/program/compile_config.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | class ExtractConstant : public BasicStmtVisitor { |
10 | private: |
11 | Block *top_level_; |
12 | DelayedIRModifier modifier_; |
13 | |
14 | public: |
15 | using BasicStmtVisitor::visit; |
16 | |
17 | explicit ExtractConstant(IRNode *node) : top_level_(nullptr) { |
18 | if (node->is<Block>()) |
19 | top_level_ = node->as<Block>(); |
20 | } |
21 | |
22 | void visit(ConstStmt *stmt) override { |
23 | TI_ASSERT(top_level_); |
24 | if (stmt->parent != top_level_) { |
25 | modifier_.extract_to_block_front(stmt, top_level_); |
26 | } |
27 | } |
28 | |
29 | void visit(OffloadedStmt *offload) override { |
30 | if (offload->body) { |
31 | Block *backup = top_level_; |
32 | top_level_ = offload->body.get(); |
33 | offload->body->accept(this); |
34 | top_level_ = backup; |
35 | } |
36 | } |
37 | |
38 | static bool run(IRNode *node) { |
39 | ExtractConstant extractor(node); |
40 | bool ir_modified = false; |
41 | while (true) { |
42 | node->accept(&extractor); |
43 | if (extractor.modifier_.modify_ir()) { |
44 | ir_modified = true; |
45 | } else { |
46 | break; |
47 | } |
48 | } |
49 | return ir_modified; |
50 | } |
51 | }; |
52 | |
53 | namespace irpass { |
54 | bool extract_constant(IRNode *root, const CompileConfig &config) { |
55 | TI_AUTO_PROF; |
56 | if (config.advanced_optimization) { |
57 | return ExtractConstant::run(root); |
58 | } else { |
59 | return false; |
60 | } |
61 | } |
62 | } // namespace irpass |
63 | |
64 | } // namespace taichi::lang |
65 |