1 | #include "taichi/ir/transforms.h" |
---|---|
2 | #include "taichi/ir/visitors.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/program/function.h" |
5 | #include "taichi/program/compile_config.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | class CompileTaichiFunctions : public BasicStmtVisitor { |
10 | public: |
11 | using BasicStmtVisitor::visit; |
12 | |
13 | explicit CompileTaichiFunctions(const CompileConfig &compile_config) |
14 | : compile_config_(compile_config) { |
15 | } |
16 | |
17 | void visit(FuncCallStmt *stmt) override { |
18 | using IRType = Function::IRType; |
19 | auto *func = stmt->func; |
20 | const auto ir_type = func->ir_type(); |
21 | if (ir_type != IRType::OptimizedIR) { |
22 | TI_ASSERT(ir_type == IRType::AST || ir_type == IRType::InitialIR); |
23 | func->set_ir_type(IRType::OptimizedIR); |
24 | irpass::compile_function(func->ir.get(), compile_config_, func, |
25 | /*autodiff_mode=*/AutodiffMode::kNone, |
26 | /*verbose=*/compile_config_.print_ir, |
27 | /*start_from_ast=*/ir_type == IRType::AST); |
28 | func->ir->accept(this); |
29 | } |
30 | } |
31 | |
32 | static void run(IRNode *ir, const CompileConfig &compile_config) { |
33 | CompileTaichiFunctions ctf{compile_config}; |
34 | ir->accept(&ctf); |
35 | } |
36 | |
37 | private: |
38 | const CompileConfig &compile_config_; |
39 | }; |
40 | |
41 | namespace irpass { |
42 | |
43 | void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config) { |
44 | TI_AUTO_PROF; |
45 | CompileTaichiFunctions::run(ir, compile_config); |
46 | } |
47 | |
48 | } // namespace irpass |
49 | |
50 | } // namespace taichi::lang |
51 |