#pragma once

#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/Registry.h>

#include <cstddef>
#include <functional>
#include <memory>

// Forward-declares at::cuda::NVRTC
namespace at { namespace cuda {
struct NVRTC;
}} // at::cuda

namespace at {
class Context;
}

// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

#ifdef _MSC_VER
constexpr const char* CUDA_HELP =
    "PyTorch splits its backend into two shared libraries: a CPU library "
    "and a CUDA library; this error has occurred because you are trying "
    "to use some CUDA functionality, but the CUDA library has not been "
    "loaded by the dynamic linker for some reason. The CUDA library MUST "
    "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
    "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ "
    "in your link arguments; many dynamic linkers will delete dynamic library "
    "dependencies if you don't depend on any of their symbols. You can check "
    "if this has occurred by using link on your binary to see if there is a "
    "dependency on *_cuda.dll library.";
#else
constexpr const char* CUDA_HELP =
    "PyTorch splits its backend into two shared libraries: a CPU library "
    "and a CUDA library; this error has occurred because you are trying "
    "to use some CUDA functionality, but the CUDA library has not been "
    "loaded by the dynamic linker for some reason. The CUDA library MUST "
    "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
    "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many "
    "dynamic linkers will delete dynamic library dependencies if you don't "
    "depend on any of their symbols. You can check if this has occurred by "
    "using ldd on your binary to see if there is a dependency on *_cuda.so "
    "library.";
#endif
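
// As the Linux variant of the message above suggests, the usual fix is to pass
// -Wl,--no-as-needed so the linker keeps the CUDA library as a dependency even
// though no symbols are referenced from it directly. A minimal sketch of such
// a link line (library names are illustrative and depend on your build):
//
//   g++ main.o -Wl,--no-as-needed -ltorch_cuda -Wl,--as-needed \
//       -ltorch_cpu -lc10 -o main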

// The CUDAHooksInterface is an omnibus interface for any CUDA functionality
// which we may want to call into from CPU code (and thus must be dynamically
// dispatched, to allow for separate compilation of CUDA code). How do I
// decide if a function should live in this class? There are two tests:
//
// 1. Does the *implementation* of this function require linking against
//    CUDA libraries?
//
// 2. Is this function *called* from non-CUDA ATen code?
//
// (2) should filter out many ostensible use-cases, since many times a CUDA
// function provided by ATen is only really ever used by actual CUDA code.
//
// TODO: Consider putting the stub definitions in another class, so that one
// never forgets to implement each virtual function in the real implementation
// in CUDAHooks. This probably doesn't buy us much though.
struct TORCH_API CUDAHooksInterface {
  // This should never actually be implemented, but it is used to
  // squelch -Werror=non-virtual-dtor
  virtual ~CUDAHooksInterface() = default;

  // Initialize THCState and, transitively, the CUDA state
  virtual void initCUDA() const {
    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
  }

  virtual const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const {
    (void)device_index; // Suppress unused variable warning
    TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
  }

  virtual Device getDeviceFromPtr(void* /*data*/) const {
    TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
  }

  virtual bool isPinnedPtr(void* /*data*/) const {
    return false;
  }

  virtual bool hasCUDA() const {
    return false;
  }

  virtual bool hasCUDART() const {
    return false;
  }

  virtual bool hasMAGMA() const {
    return false;
  }

  virtual bool hasCuDNN() const {
    return false;
  }

  virtual bool hasCuSOLVER() const {
    return false;
  }

  virtual bool hasROCM() const {
    return false;
  }

  virtual const at::cuda::NVRTC& nvrtc() const {
    TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
  }

  virtual bool hasPrimaryContext(int64_t device_index) const {
    TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t current_device() const {
    return -1;
  }

  virtual Allocator* getPinnedMemoryAllocator() const {
    TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
  }

  virtual Allocator* getCUDADeviceAllocator() const {
    TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. ", CUDA_HELP);
  }

  virtual bool compiledWithCuDNN() const {
    return false;
  }

  virtual bool compiledWithMIOpen() const {
    return false;
  }

  virtual bool supportsDilatedConvolutionWithCuDNN() const {
    return false;
  }

  virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
    return false;
  }

  virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const {
    return false;
  }

  virtual long versionCuDNN() const {
    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual long versionCUDART() const {
    TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual std::string showConfig() const {
    TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual double batchnormMinEpsilonCuDNN() const {
    TORCH_CHECK(false,
        "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual void cuFFTSetPlanCacheMaxSize(int64_t /*device_index*/, int64_t /*max_size*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheSize(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual void cuFFTClearPlanCache(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int getNumGPUs() const {
    return 0;
  }

  virtual void deviceSynchronize(int64_t /*device_index*/) const {
    TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
  }
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
struct TORCH_API CUDAHooksArgs {};

C10_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs);
#define REGISTER_CUDA_HOOKS(clsname) \
  C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname)
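
// A CUDA-enabled build provides a concrete subclass of CUDAHooksInterface and
// registers it with the macro above. A rough sketch (the class name is
// illustrative; the constructor must accept CUDAHooksArgs so the registry can
// construct it):
//
//   struct MyCUDAHooks : public at::CUDAHooksInterface {
//     explicit MyCUDAHooks(at::CUDAHooksArgs) {}
//     bool hasCUDA() const override {
//       return true;
//     }
//     // ... override the remaining entry points ...
//   };
//   REGISTER_CUDA_HOOKS(MyCUDAHooks);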

namespace detail {
TORCH_API const CUDAHooksInterface& getCUDAHooks();
} // namespace detail
} // namespace at