1 | #pragma once |
2 | |
3 | #include <c10/core/Allocator.h> |
4 | #include <ATen/core/Generator.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/Optional.h> |
7 | #include <c10/util/Registry.h> |
8 | |
9 | #include <cstddef> |
10 | #include <functional> |
11 | #include <memory> |
12 | |
// Forward-declares at::cuda::NVRTC (the full definition lives in the
// ATen CUDA library); only a reference to it is handed out below, so the
// incomplete type suffices here.
namespace at { namespace cuda {
struct NVRTC;
}} // at::cuda

// Forward-declares at::Context to avoid pulling in the full Context header.
namespace at {
class Context;
}
21 | |
22 | // NB: Class must live in `at` due to limitations of Registry.h. |
23 | namespace at { |
24 | |
// Help text appended to every "CUDA not available" error raised by the
// stub CUDAHooksInterface below. Declared `inline` (C++17) so that all
// translation units including this header share one object; without it,
// a namespace-scope `constexpr` pointer is implicitly `const` and thus
// has internal linkage, duplicating the string in every TU.
#ifdef _MSC_VER
// MSVC flavor of the advice: link.exe's -INCLUDE flag and *_cuda.dll.
inline constexpr const char* CUDA_HELP =
    "PyTorch splits its backend into two shared libraries: a CPU library "
    "and a CUDA library; this error has occurred because you are trying "
    "to use some CUDA functionality, but the CUDA library has not been "
    "loaded by the dynamic linker for some reason. The CUDA library MUST "
    "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
    "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ "
    "in your link arguments; many dynamic linkers will delete dynamic library "
    "dependencies if you don't depend on any of their symbols. You can check "
    "if this has occurred by using link on your binary to see if there is a "
    "dependency on *_cuda.dll library.";
#else
// ELF/Mach-O flavor of the advice: ld's --no-as-needed and *_cuda.so.
inline constexpr const char* CUDA_HELP =
    "PyTorch splits its backend into two shared libraries: a CPU library "
    "and a CUDA library; this error has occurred because you are trying "
    "to use some CUDA functionality, but the CUDA library has not been "
    "loaded by the dynamic linker for some reason. The CUDA library MUST "
    "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
    "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many "
    "dynamic linkers will delete dynamic library dependencies if you don't "
    "depend on any of their symbols. You can check if this has occurred by "
    "using ldd on your binary to see if there is a dependency on *_cuda.so "
    "library.";
#endif
50 | |
51 | // The CUDAHooksInterface is an omnibus interface for any CUDA functionality |
52 | // which we may want to call into from CPU code (and thus must be dynamically |
53 | // dispatched, to allow for separate compilation of CUDA code). How do I |
54 | // decide if a function should live in this class? There are two tests: |
55 | // |
56 | // 1. Does the *implementation* of this function require linking against |
57 | // CUDA libraries? |
58 | // |
59 | // 2. Is this function *called* from non-CUDA ATen code? |
60 | // |
61 | // (2) should filter out many ostensible use-cases, since many times a CUDA |
62 | // function provided by ATen is only really ever used by actual CUDA code. |
63 | // |
64 | // TODO: Consider putting the stub definitions in another class, so that one |
65 | // never forgets to implement each virtual function in the real implementation |
66 | // in CUDAHooks. This probably doesn't buy us much though. |
struct TORCH_API CUDAHooksInterface {
  // Virtual destructor so deleting a concrete hooks object through this
  // base pointer is well-defined; also squelches -Werror=non-virtual-dtor.
  virtual ~CUDAHooksInterface() = default;

  // Initialize THCState and, transitively, the CUDA state.
  // This CPU-only stub always throws.
  virtual void initCUDA() const {
    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. " , CUDA_HELP);
  }

  // Returns the default CUDA generator for `device_index` (-1 presumably
  // selects the current device in the real implementation — confirm in
  // CUDAHooks). Stub always throws.
  virtual const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const {
    (void)device_index; // Suppress unused variable warning
    TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. " , CUDA_HELP);
  }

  // Resolves which CUDA device the given pointer belongs to. Stub throws.
  virtual Device getDeviceFromPtr(void* /*data*/) const {
    TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. " , CUDA_HELP);
  }

  // Whether `data` points into pinned (page-locked) host memory. Without
  // the CUDA library nothing can be pinned, so the stub answers false
  // rather than throwing.
  virtual bool isPinnedPtr(void* /*data*/) const {
    return false;
  }

  // Capability queries. All of the CPU-only stubs below report "not
  // available"; the real CUDAHooks overrides them with runtime checks.
  virtual bool hasCUDA() const {
    return false;
  }

  virtual bool hasCUDART() const {
    return false;
  }

  virtual bool hasMAGMA() const {
    return false;
  }

  virtual bool hasCuDNN() const {
    return false;
  }

  virtual bool hasCuSOLVER() const {
    return false;
  }

  virtual bool hasROCM() const {
    return false;
  }

  // Access to the NVRTC (runtime compilation) entry points. Stub throws.
  virtual const at::cuda::NVRTC& nvrtc() const {
    TORCH_CHECK(false, "NVRTC requires CUDA. " , CUDA_HELP);
  }

  // Whether a CUDA primary context already exists on `device_index`.
  // Unlike the has*() queries above, this throws instead of returning
  // false when the CUDA library is absent.
  virtual bool hasPrimaryContext(int64_t device_index) const {
    TORCH_CHECK(false, "Cannot call hasPrimaryContext(" , device_index, ") without ATen_cuda library. " , CUDA_HELP);
  }

  // Currently active CUDA device; -1 signals "no device" in this stub.
  virtual int64_t current_device() const {
    return -1;
  }

  // Allocator for pinned (page-locked) host memory. Stub throws.
  virtual Allocator* getPinnedMemoryAllocator() const {
    TORCH_CHECK(false, "Pinned memory requires CUDA. " , CUDA_HELP);
  }

  // Allocator for CUDA device memory. Stub throws.
  virtual Allocator* getCUDADeviceAllocator() const {
    TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. " , CUDA_HELP);
  }

  // Build-time configuration queries; false when ATen_cuda is not loaded.
  virtual bool compiledWithCuDNN() const {
    return false;
  }

  virtual bool compiledWithMIOpen() const {
    return false;
  }

  virtual bool supportsDilatedConvolutionWithCuDNN() const {
    return false;
  }

  virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
    return false;
  }

  virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const {
    return false;
  }

  // Version queries; these throw because there is no meaningful "absent"
  // version number to report.
  virtual long versionCuDNN() const {
    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. " , CUDA_HELP);
  }

  virtual long versionCUDART() const {
    TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. " , CUDA_HELP);
  }

  // Human-readable summary of the CUDA build configuration. Stub throws.
  virtual std::string showConfig() const {
    TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. " , CUDA_HELP);
  }

  // Smallest epsilon accepted by cuDNN batch normalization. Stub throws.
  virtual double batchnormMinEpsilonCuDNN() const {
    TORCH_CHECK(false,
        "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. " , CUDA_HELP);
  }

  // cuFFT plan-cache management, keyed by device index. All stubs throw.
  virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. " , CUDA_HELP);
  }

  virtual void cuFFTSetPlanCacheMaxSize(int64_t /*device_index*/, int64_t /*max_size*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. " , CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheSize(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. " , CUDA_HELP);
  }

  virtual void cuFFTClearPlanCache(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. " , CUDA_HELP);
  }

  // Number of visible CUDA devices; 0 when the CUDA library is absent.
  virtual int getNumGPUs() const {
    return 0;
  }

  // Blocks until all work queued on `device_index` has completed. Stub throws.
  virtual void deviceSynchronize(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. " , CUDA_HELP);
  }
};
195 | |
// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
struct TORCH_API CUDAHooksArgs {};

// Registry through which the ATen_cuda library installs its concrete
// CUDAHooksInterface implementation at static-initialization time; the
// CPU library looks it up via detail::getCUDAHooks() below.
C10_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs);
#define REGISTER_CUDA_HOOKS(clsname) \
  C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname)

namespace detail {
// Returns the registered CUDA hooks, or the throwing/false-returning stub
// base implementation when no CUDA library has been loaded.
TORCH_API const CUDAHooksInterface& getCUDAHooks();
} // namespace detail
207 | } // namespace at |
208 | |