1 | #pragma once |
2 | |
3 | #include <mutex> |
4 | |
5 | #include "taichi/system/dynamic_loader.h" |
6 | #include "taichi/rhi/cuda/cuda_types.h" |
7 | |
#if (0)
// Optional compile-time compatibility checks; change the condition above to
// enable them. Each assertion verifies that a CUDA driver type has the same
// size as the plain type used in its place by this header.
namespace taichi {

// Types expected to be the same size as uint32.
static_assert(sizeof(CUresult) == sizeof(uint32));
static_assert(sizeof(CUmem_advise) == sizeof(uint32));
static_assert(sizeof(CUdevice) == sizeof(uint32));
static_assert(sizeof(CUdevice_attribute) == sizeof(uint32));
static_assert(sizeof(CUjit_option) == sizeof(uint32));

// Types expected to be the same size as a pointer.
static_assert(sizeof(CUfunction) == sizeof(void *));
static_assert(sizeof(CUmodule) == sizeof(void *));
static_assert(sizeof(CUstream) == sizeof(void *));
static_assert(sizeof(CUevent) == sizeof(void *));

}  // namespace taichi
#endif
22 | |
23 | namespace taichi::lang { |
24 | |
25 | // Driver constants from cuda.h |
26 | |
27 | constexpr uint32 CU_EVENT_DEFAULT = 0x0; |
28 | constexpr uint32 CU_STREAM_DEFAULT = 0x0; |
29 | constexpr uint32 CU_STREAM_NON_BLOCKING = 0x1; |
30 | constexpr uint32 CU_MEM_ATTACH_GLOBAL = 0x1; |
31 | constexpr uint32 CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3; |
32 | constexpr uint32 CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2; |
33 | constexpr uint32 CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106; |
34 | constexpr uint32 CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16; |
35 | constexpr uint32 CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75; |
36 | constexpr uint32 CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76; |
37 | constexpr uint32 CUDA_ERROR_ASSERT = 710; |
38 | constexpr uint32 CU_JIT_MAX_REGISTERS = 0; |
39 | constexpr uint32 CU_POINTER_ATTRIBUTE_MEMORY_TYPE = 2; |
40 | constexpr uint32 CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = 41; |
41 | constexpr uint32 CUDA_SUCCESS = 0; |
42 | constexpr uint32 CU_MEMORYTYPE_DEVICE = 2; |
43 | constexpr uint32 CU_LIMIT_STACK_SIZE = 0; |
44 | |
45 | std::string get_cuda_error_message(uint32 err); |
46 | |
47 | template <typename... Args> |
48 | class CUDADriverFunction { |
49 | public: |
50 | CUDADriverFunction() { |
51 | function_ = nullptr; |
52 | } |
53 | |
54 | void set(void *func_ptr) { |
55 | function_ = (func_type *)func_ptr; |
56 | } |
57 | |
58 | uint32 call(Args... args) { |
59 | TI_ASSERT(function_ != nullptr); |
60 | TI_ASSERT(driver_lock_ != nullptr); |
61 | std::lock_guard<std::mutex> _(*driver_lock_); |
62 | return (uint32)function_(args...); |
63 | } |
64 | |
65 | void set_names(const std::string &name, const std::string &symbol_name) { |
66 | name_ = name; |
67 | symbol_name_ = symbol_name; |
68 | } |
69 | |
70 | void set_lock(std::mutex *lock) { |
71 | driver_lock_ = lock; |
72 | } |
73 | |
74 | std::string get_error_message(uint32 err) { |
75 | return get_cuda_error_message(err) + |
76 | fmt::format(" while calling {} ({})" , name_, symbol_name_); |
77 | } |
78 | |
79 | uint32 call_with_warning(Args... args) { |
80 | auto err = call(args...); |
81 | TI_WARN_IF(err, "{}" , get_error_message(err)); |
82 | return err; |
83 | } |
84 | |
85 | // Note: CUDA driver API passes everything as value |
86 | void operator()(Args... args) { |
87 | auto err = call(args...); |
88 | TI_ERROR_IF(err, get_error_message(err)); |
89 | } |
90 | |
91 | private: |
92 | using func_type = uint32_t(Args...); |
93 | |
94 | func_type *function_{nullptr}; |
95 | std::string name_, symbol_name_; |
96 | std::mutex *driver_lock_{nullptr}; |
97 | }; |
98 | |
99 | class CUDADriverBase { |
100 | public: |
101 | ~CUDADriverBase() = default; |
102 | |
103 | protected: |
104 | std::unique_ptr<DynamicLoader> loader_; |
105 | CUDADriverBase(); |
106 | |
107 | bool load_lib(std::string lib_linux, std::string lib_windows); |
108 | |
109 | bool disabled_by_env_{false}; |
110 | }; |
111 | |
112 | class CUDADriver : protected CUDADriverBase { |
113 | public: |
114 | #define PER_CUDA_FUNCTION(name, symbol_name, ...) \ |
115 | CUDADriverFunction<__VA_ARGS__> name; |
116 | #include "taichi/rhi/cuda/cuda_driver_functions.inc.h" |
117 | #undef PER_CUDA_FUNCTION |
118 | |
119 | void (*get_error_name)(uint32, const char **); |
120 | |
121 | void (*get_error_string)(uint32, const char **); |
122 | |
123 | void (*driver_get_version)(int *); |
124 | |
125 | bool detected(); |
126 | |
127 | static CUDADriver &get_instance(); |
128 | |
129 | static CUDADriver &get_instance_without_context(); |
130 | |
131 | private: |
132 | CUDADriver(); |
133 | |
134 | std::mutex lock_; |
135 | |
136 | bool cuda_version_valid_{false}; |
137 | }; |
138 | |
139 | class CUSPARSEDriver : protected CUDADriverBase { |
140 | public: |
141 | static CUSPARSEDriver &get_instance(); |
142 | |
143 | #define PER_CUSPARSE_FUNCTION(name, symbol_name, ...) \ |
144 | CUDADriverFunction<__VA_ARGS__> name; |
145 | #include "taichi/rhi/cuda/cusparse_functions.inc.h" |
146 | #undef PER_CUSPARSE_FUNCTION |
147 | |
148 | bool load_cusparse(); |
149 | |
150 | inline bool is_loaded() { |
151 | return cusparse_loaded_; |
152 | } |
153 | |
154 | private: |
155 | CUSPARSEDriver(); |
156 | std::mutex lock_; |
157 | bool cusparse_loaded_{false}; |
158 | }; |
159 | |
160 | class CUSOLVERDriver : protected CUDADriverBase { |
161 | public: |
162 | // TODO: Add cusolver function APIs |
163 | static CUSOLVERDriver &get_instance(); |
164 | |
165 | #define PER_CUSOLVER_FUNCTION(name, symbol_name, ...) \ |
166 | CUDADriverFunction<__VA_ARGS__> name; |
167 | #include "taichi/rhi/cuda/cusolver_functions.inc.h" |
168 | #undef PER_CUSOLVER_FUNCTION |
169 | |
170 | bool load_cusolver(); |
171 | |
172 | inline bool is_loaded() { |
173 | return cusolver_loaded_; |
174 | } |
175 | |
176 | private: |
177 | CUSOLVERDriver(); |
178 | std::mutex lock_; |
179 | bool cusolver_loaded_{false}; |
180 | }; |
181 | |
182 | } // namespace taichi::lang |
183 | |