1#include "taichi/rhi/cuda/cuda_driver.h"
2
3#include "taichi/system/dynamic_loader.h"
4#include "taichi/rhi/cuda/cuda_context.h"
5#include "taichi/util/environ_config.h"
6
7namespace taichi::lang {
8
9std::string get_cuda_error_message(uint32 err) {
10 const char *err_name_ptr;
11 const char *err_string_ptr;
12 CUDADriver::get_instance_without_context().get_error_name(err, &err_name_ptr);
13 CUDADriver::get_instance_without_context().get_error_string(err,
14 &err_string_ptr);
15 return fmt::format("CUDA Error {}: {}", err_name_ptr, err_string_ptr);
16}
17
18CUDADriverBase::CUDADriverBase() {
19 disabled_by_env_ = (get_environ_config("TI_ENABLE_CUDA", 1) == 0);
20 if (disabled_by_env_) {
21 TI_TRACE("CUDA driver disabled by enviroment variable \"TI_ENABLE_CUDA\".");
22 }
23}
24
25bool CUDADriverBase::load_lib(std::string lib_linux, std::string lib_windows) {
26#if defined(TI_PLATFORM_LINUX)
27 auto lib_name = lib_linux;
28#elif defined(TI_PLATFORM_WINDOWS)
29 auto lib_name = lib_windows;
30#else
31 static_assert(false, "Taichi CUDA driver supports only Windows and Linux.");
32#endif
33
34 loader_ = std::make_unique<DynamicLoader>(lib_name);
35 if (!loader_->loaded()) {
36 TI_WARN("{} lib not found.", lib_name);
37 return false;
38 } else {
39 TI_TRACE("{} loaded!", lib_name);
40 return true;
41 }
42}
43
44bool CUDADriver::detected() {
45 return !disabled_by_env_ && cuda_version_valid_ && loader_->loaded();
46}
47
48CUDADriver::CUDADriver() {
49 if (!load_lib("libcuda.so", "nvcuda.dll"))
50 return;
51
52 loader_->load_function("cuGetErrorName", get_error_name);
53 loader_->load_function("cuGetErrorString", get_error_string);
54 loader_->load_function("cuDriverGetVersion", driver_get_version);
55
56 int version;
57 driver_get_version(&version);
58 TI_TRACE("CUDA driver API (v{}.{}) loaded.", version / 1000,
59 version % 1000 / 10);
60
61 // CUDA versions should >= 10.
62 if (version < 10000) {
63 TI_WARN("The Taichi CUDA backend requires at least CUDA 10.0, got v{}.{}.",
64 version / 1000, version % 1000 / 10);
65 return;
66 }
67
68 cuda_version_valid_ = true;
69#define PER_CUDA_FUNCTION(name, symbol_name, ...) \
70 name.set(loader_->load_function(#symbol_name)); \
71 name.set_lock(&lock_); \
72 name.set_names(#name, #symbol_name);
73#include "taichi/rhi/cuda/cuda_driver_functions.inc.h"
74#undef PER_CUDA_FUNCTION
75}
76
77// This is for initializing the CUDA driver itself
78CUDADriver &CUDADriver::get_instance_without_context() {
79 // Thread safety guaranteed by C++ compiler
80 // Note this is never deleted until the process finishes
81 static CUDADriver *instance = new CUDADriver();
82 return *instance;
83}
84
85CUDADriver &CUDADriver::get_instance() {
86 // initialize the CUDA context so that the driver APIs can be called later
87 CUDAContext::get_instance();
88 return get_instance_without_context();
89}
90
91CUSPARSEDriver::CUSPARSEDriver() {
92}
93
94CUSPARSEDriver &CUSPARSEDriver::get_instance() {
95 static CUSPARSEDriver *instance = new CUSPARSEDriver();
96 return *instance;
97}
98
99bool CUSPARSEDriver::load_cusparse() {
100 cusparse_loaded_ = load_lib("libcusparse.so", "cusparse64_11.dll");
101
102 if (!cusparse_loaded_) {
103 return false;
104 }
105#define PER_CUSPARSE_FUNCTION(name, symbol_name, ...) \
106 name.set(loader_->load_function(#symbol_name)); \
107 name.set_lock(&lock_); \
108 name.set_names(#name, #symbol_name);
109#include "taichi/rhi/cuda/cusparse_functions.inc.h"
110#undef PER_CUSPARSE_FUNCTION
111 return cusparse_loaded_;
112}
113
114CUSOLVERDriver::CUSOLVERDriver() {
115}
116
117CUSOLVERDriver &CUSOLVERDriver::get_instance() {
118 static CUSOLVERDriver *instance = new CUSOLVERDriver();
119 return *instance;
120}
121
122bool CUSOLVERDriver::load_cusolver() {
123 cusolver_loaded_ = load_lib("libcusolver.so", "cusolver64_11.dll");
124 if (!cusolver_loaded_) {
125 return false;
126 }
127#define PER_CUSOLVER_FUNCTION(name, symbol_name, ...) \
128 name.set(loader_->load_function(#symbol_name)); \
129 name.set_lock(&lock_); \
130 name.set_names(#name, #symbol_name);
131#include "taichi/rhi/cuda/cusolver_functions.inc.h"
132#undef PER_CUSOLVER_FUNCTION
133 return cusolver_loaded_;
134}
135} // namespace taichi::lang
136