1 | // The CUDA backend |
2 | |
3 | #pragma once |
4 | |
5 | #include "taichi/codegen/codegen.h" |
6 | #include "taichi/codegen/llvm/codegen_llvm.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | class KernelCodeGenCUDA : public KernelCodeGen { |
11 | public: |
12 | explicit KernelCodeGenCUDA(const CompileConfig &compile_config, |
13 | Kernel *kernel, |
14 | TaichiLLVMContext &tlctx) |
15 | : KernelCodeGen(compile_config, kernel, tlctx) { |
16 | } |
17 | |
18 | // TODO: Stop defining this macro guards in the headers |
19 | #ifdef TI_WITH_LLVM |
20 | LLVMCompiledTask compile_task( |
21 | const CompileConfig &config, |
22 | std::unique_ptr<llvm::Module> &&module = nullptr, |
23 | OffloadedStmt *stmt = nullptr) override; |
24 | #endif // TI_WITH_LLVM |
25 | |
26 | bool supports_offline_cache() const override { |
27 | return true; |
28 | } |
29 | |
30 | FunctionType compile_to_function() override; |
31 | }; |
32 | |
33 | class CUDAModuleToFunctionConverter : public ModuleToFunctionConverter { |
34 | public: |
35 | explicit CUDAModuleToFunctionConverter(TaichiLLVMContext *tlctx, |
36 | LlvmRuntimeExecutor *executor) |
37 | : ModuleToFunctionConverter(tlctx, executor) { |
38 | } |
39 | using ModuleToFunctionConverter::convert; |
40 | |
41 | FunctionType convert(const std::string &kernel_name, |
42 | const std::vector<LlvmLaunchArgInfo> &args, |
43 | LLVMCompiledKernel data) const override; |
44 | }; |
45 | |
46 | } // namespace taichi::lang |
47 | |