1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/program/compile_config.h"
6
7namespace taichi::lang {
8
9class ExtractConstant : public BasicStmtVisitor {
10 private:
11 Block *top_level_;
12 DelayedIRModifier modifier_;
13
14 public:
15 using BasicStmtVisitor::visit;
16
17 explicit ExtractConstant(IRNode *node) : top_level_(nullptr) {
18 if (node->is<Block>())
19 top_level_ = node->as<Block>();
20 }
21
22 void visit(ConstStmt *stmt) override {
23 TI_ASSERT(top_level_);
24 if (stmt->parent != top_level_) {
25 modifier_.extract_to_block_front(stmt, top_level_);
26 }
27 }
28
29 void visit(OffloadedStmt *offload) override {
30 if (offload->body) {
31 Block *backup = top_level_;
32 top_level_ = offload->body.get();
33 offload->body->accept(this);
34 top_level_ = backup;
35 }
36 }
37
38 static bool run(IRNode *node) {
39 ExtractConstant extractor(node);
40 bool ir_modified = false;
41 while (true) {
42 node->accept(&extractor);
43 if (extractor.modifier_.modify_ir()) {
44 ir_modified = true;
45 } else {
46 break;
47 }
48 }
49 return ir_modified;
50 }
51};
52
53namespace irpass {
54bool extract_constant(IRNode *root, const CompileConfig &config) {
55 TI_AUTO_PROF;
56 if (config.advanced_optimization) {
57 return ExtractConstant::run(root);
58 } else {
59 return false;
60 }
61}
62} // namespace irpass
63
64} // namespace taichi::lang
65