1 | // Driver class for kernel code generators. |
2 | |
3 | #pragma once |
4 | #include <taichi/runtime/llvm/llvm_runtime_executor.h> |
5 | #include "taichi/ir/ir.h" |
6 | #include "taichi/program/program.h" |
7 | #ifdef TI_WITH_LLVM |
8 | #include "llvm/IR/Module.h" |
9 | #include "taichi/codegen/llvm/codegen_llvm.h" |
10 | #include "taichi/runtime/llvm/launch_arg_info.h" |
11 | #include "taichi/codegen/llvm/llvm_codegen_utils.h" |
12 | #endif |
13 | namespace taichi::lang { |
14 | class TaichiLLVMContext; |
15 | |
16 | /* |
17 | [Note] Codegen of LLVM-based backends |
18 | * KernelCodeGen is the base class of the codegen of all backends using LLVM. |
19 | * Function `compile_to_function` first compiles the IR of a kernel |
20 | * into an LLVM module using `compile_kernel_to_module`, and then constructs a |
21 | * function for runtime execution using `ModuleToFunctionConverter`. |
22 | * |
23 | * Function `compile_kernel_to_module` compiles the IR of a kernel into an LLVM |
24 | * module. A kernel is composed of several offloaded tasks. To compile a kernel, |
25 | * we first compile each task independently into an LLVM module using function |
26 | * `compile_task`. Then, we link the LLVM modules of the offloaded tasks, |
27 | * the runtime module and the struct modules of the SNode trees which are used |
28 | * in the kernel all together into a single LLVM module using |
29 | * `tlctx->link_compiled_tasks`. The LLVM module and the names of the entry |
30 | * functions of the offloaded tasks in the module are stored in the returned |
31 | * LLVMCompiledKernel. |
32 | * |
33 | * Function `compile_task` uses `TaskCodeGen` of the respective backend to |
34 | * compile the IR of an offloaded task to an LLVM module. It also generates some |
35 | * extra information for linking such as which SNode tree is used in the task. |
36 | * The LLVM module, the name of the entry function of the offloaded task in the |
37 | * module and the extra information are stored in the returned LLVMCompiledTask. |
38 | */ |
39 | class KernelCodeGen { |
40 | protected: |
41 | Program *prog; |
42 | Kernel *kernel; |
43 | IRNode *ir; |
44 | |
45 | public: |
46 | explicit KernelCodeGen(const CompileConfig &compile_config, |
47 | Kernel *kernel, |
48 | TaichiLLVMContext &tlctx); |
49 | |
50 | virtual ~KernelCodeGen() = default; |
51 | |
52 | static std::unique_ptr<KernelCodeGen> create( |
53 | const CompileConfig &compile_config, |
54 | Kernel *kernel, |
55 | TaichiLLVMContext &tlctx); |
56 | |
57 | virtual FunctionType compile_to_function() = 0; |
58 | virtual bool supports_offline_cache() const { |
59 | return false; |
60 | } |
61 | |
62 | #ifdef TI_WITH_LLVM |
63 | virtual LLVMCompiledKernel compile_kernel_to_module(); |
64 | |
65 | virtual LLVMCompiledTask compile_task( |
66 | const CompileConfig &config, |
67 | std::unique_ptr<llvm::Module> &&module = nullptr, |
68 | OffloadedStmt *stmt = nullptr){TI_NOT_IMPLEMENTED} |
69 | |
70 | std::optional<LLVMCompiledKernel> maybe_read_compilation_from_cache( |
71 | const std::string &kernel_key); |
72 | void cache_kernel(const std::string &kernel_key, |
73 | const LLVMCompiledKernel &data); |
74 | #endif |
75 | protected: |
76 | const CompileConfig &get_compile_config() const { |
77 | return compile_config_; |
78 | } |
79 | |
80 | TaichiLLVMContext &get_taichi_llvm_context() { |
81 | return tlctx_; |
82 | } |
83 | |
84 | private: |
85 | const CompileConfig &compile_config_; |
86 | TaichiLLVMContext &tlctx_; |
87 | }; |
88 | |
89 | #ifdef TI_WITH_LLVM |
90 | |
91 | class ModuleToFunctionConverter { |
92 | public: |
93 | explicit ModuleToFunctionConverter(TaichiLLVMContext *tlctx, |
94 | LlvmRuntimeExecutor *program); |
95 | |
96 | virtual ~ModuleToFunctionConverter() = default; |
97 | |
98 | virtual FunctionType convert(const std::string &kernel_name, |
99 | const std::vector<LlvmLaunchArgInfo> &args, |
100 | LLVMCompiledKernel data) const = 0; |
101 | |
102 | virtual FunctionType convert(const Kernel *kernel, |
103 | LLVMCompiledKernel data) const; |
104 | |
105 | protected: |
106 | TaichiLLVMContext *tlctx_{nullptr}; |
107 | LlvmRuntimeExecutor *executor_{nullptr}; |
108 | }; |
109 | |
110 | #endif |
111 | } // namespace taichi::lang |
112 | |