1#pragma once
2
3#include <mutex>
4
5#include "taichi/system/dynamic_loader.h"
6#include "taichi/rhi/cuda/cuda_types.h"
7
#if 0
// Compatibility check (disabled by default): flip the guard above to 1 to
// verify at compile time that the CUDA types have the sizes asserted below.
namespace taichi {
// Types expected to be 32 bits wide.
static_assert(sizeof(uint32) == sizeof(CUresult));
static_assert(sizeof(uint32) == sizeof(CUmem_advise));
static_assert(sizeof(uint32) == sizeof(CUdevice));
static_assert(sizeof(uint32) == sizeof(CUdevice_attribute));
static_assert(sizeof(uint32) == sizeof(CUjit_option));
// Types expected to be pointer-sized.
static_assert(sizeof(void *) == sizeof(CUfunction));
static_assert(sizeof(void *) == sizeof(CUmodule));
static_assert(sizeof(void *) == sizeof(CUstream));
static_assert(sizeof(void *) == sizeof(CUevent));
}  // namespace taichi
#endif
22
23namespace taichi::lang {
24
25// Driver constants from cuda.h
26
27constexpr uint32 CU_EVENT_DEFAULT = 0x0;
28constexpr uint32 CU_STREAM_DEFAULT = 0x0;
29constexpr uint32 CU_STREAM_NON_BLOCKING = 0x1;
30constexpr uint32 CU_MEM_ATTACH_GLOBAL = 0x1;
31constexpr uint32 CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3;
32constexpr uint32 CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2;
33constexpr uint32 CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106;
34constexpr uint32 CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16;
35constexpr uint32 CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75;
36constexpr uint32 CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76;
37constexpr uint32 CUDA_ERROR_ASSERT = 710;
38constexpr uint32 CU_JIT_MAX_REGISTERS = 0;
39constexpr uint32 CU_POINTER_ATTRIBUTE_MEMORY_TYPE = 2;
40constexpr uint32 CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = 41;
41constexpr uint32 CUDA_SUCCESS = 0;
42constexpr uint32 CU_MEMORYTYPE_DEVICE = 2;
43constexpr uint32 CU_LIMIT_STACK_SIZE = 0;
44
45std::string get_cuda_error_message(uint32 err);
46
47template <typename... Args>
48class CUDADriverFunction {
49 public:
50 CUDADriverFunction() {
51 function_ = nullptr;
52 }
53
54 void set(void *func_ptr) {
55 function_ = (func_type *)func_ptr;
56 }
57
58 uint32 call(Args... args) {
59 TI_ASSERT(function_ != nullptr);
60 TI_ASSERT(driver_lock_ != nullptr);
61 std::lock_guard<std::mutex> _(*driver_lock_);
62 return (uint32)function_(args...);
63 }
64
65 void set_names(const std::string &name, const std::string &symbol_name) {
66 name_ = name;
67 symbol_name_ = symbol_name;
68 }
69
70 void set_lock(std::mutex *lock) {
71 driver_lock_ = lock;
72 }
73
74 std::string get_error_message(uint32 err) {
75 return get_cuda_error_message(err) +
76 fmt::format(" while calling {} ({})", name_, symbol_name_);
77 }
78
79 uint32 call_with_warning(Args... args) {
80 auto err = call(args...);
81 TI_WARN_IF(err, "{}", get_error_message(err));
82 return err;
83 }
84
85 // Note: CUDA driver API passes everything as value
86 void operator()(Args... args) {
87 auto err = call(args...);
88 TI_ERROR_IF(err, get_error_message(err));
89 }
90
91 private:
92 using func_type = uint32_t(Args...);
93
94 func_type *function_{nullptr};
95 std::string name_, symbol_name_;
96 std::mutex *driver_lock_{nullptr};
97};
98
99class CUDADriverBase {
100 public:
101 ~CUDADriverBase() = default;
102
103 protected:
104 std::unique_ptr<DynamicLoader> loader_;
105 CUDADriverBase();
106
107 bool load_lib(std::string lib_linux, std::string lib_windows);
108
109 bool disabled_by_env_{false};
110};
111
112class CUDADriver : protected CUDADriverBase {
113 public:
114#define PER_CUDA_FUNCTION(name, symbol_name, ...) \
115 CUDADriverFunction<__VA_ARGS__> name;
116#include "taichi/rhi/cuda/cuda_driver_functions.inc.h"
117#undef PER_CUDA_FUNCTION
118
119 void (*get_error_name)(uint32, const char **);
120
121 void (*get_error_string)(uint32, const char **);
122
123 void (*driver_get_version)(int *);
124
125 bool detected();
126
127 static CUDADriver &get_instance();
128
129 static CUDADriver &get_instance_without_context();
130
131 private:
132 CUDADriver();
133
134 std::mutex lock_;
135
136 bool cuda_version_valid_{false};
137};
138
139class CUSPARSEDriver : protected CUDADriverBase {
140 public:
141 static CUSPARSEDriver &get_instance();
142
143#define PER_CUSPARSE_FUNCTION(name, symbol_name, ...) \
144 CUDADriverFunction<__VA_ARGS__> name;
145#include "taichi/rhi/cuda/cusparse_functions.inc.h"
146#undef PER_CUSPARSE_FUNCTION
147
148 bool load_cusparse();
149
150 inline bool is_loaded() {
151 return cusparse_loaded_;
152 }
153
154 private:
155 CUSPARSEDriver();
156 std::mutex lock_;
157 bool cusparse_loaded_{false};
158};
159
160class CUSOLVERDriver : protected CUDADriverBase {
161 public:
162 // TODO: Add cusolver function APIs
163 static CUSOLVERDriver &get_instance();
164
165#define PER_CUSOLVER_FUNCTION(name, symbol_name, ...) \
166 CUDADriverFunction<__VA_ARGS__> name;
167#include "taichi/rhi/cuda/cusolver_functions.inc.h"
168#undef PER_CUSOLVER_FUNCTION
169
170 bool load_cusolver();
171
172 inline bool is_loaded() {
173 return cusolver_loaded_;
174 }
175
176 private:
177 CUSOLVERDriver();
178 std::mutex lock_;
179 bool cusolver_loaded_{false};
180};
181
182} // namespace taichi::lang
183