// The CUDA backend

#pragma once

#include "taichi/codegen/codegen.h"
#include "taichi/codegen/llvm/codegen_llvm.h"

namespace taichi::lang {

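// Code generator that lowers Taichi kernels to LLVM IR for the CUDA backend.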
class KernelCodeGenCUDA : public KernelCodeGen {
 public:
  explicit KernelCodeGenCUDA(const CompileConfig &compile_config,
                             Kernel *kernel,
                             TaichiLLVMContext &tlctx)
      : KernelCodeGen(compile_config, kernel, tlctx) {
  }

// TODO: Stop defining these macro guards in the headers
#ifdef TI_WITH_LLVM
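  // Compiles a single offloaded task of the kernel into an LLVM module.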
  LLVMCompiledTask compile_task(
      const CompileConfig &config,
      std::unique_ptr<llvm::Module> &&module = nullptr,
      OffloadedStmt *stmt = nullptr) override;
#endif  // TI_WITH_LLVM

  bool supports_offline_cache() const override {
    return true;
  }

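  // Compiles the kernel and returns a host-side function that launches it.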
  FunctionType compile_to_function() override;
};

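// Converts compiled LLVM modules into host-callable functions that launch the
// corresponding CUDA kernels through the LLVM runtime executor.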
class CUDAModuleToFunctionConverter : public ModuleToFunctionConverter {
 public:
  explicit CUDAModuleToFunctionConverter(TaichiLLVMContext *tlctx,
                                         LlvmRuntimeExecutor *executor)
      : ModuleToFunctionConverter(tlctx, executor) {
  }
  using ModuleToFunctionConverter::convert;

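  // Builds a host-callable launcher for the given compiled kernel.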
  FunctionType convert(const std::string &kernel_name,
                       const std::vector<LlvmLaunchArgInfo> &args,
                       LLVMCompiledKernel data) const override;
};

}  // namespace taichi::lang