1 | #pragma once |
2 | |
3 | #include <ATen/detail/CUDAHooksInterface.h> |
4 | |
5 | #include <ATen/Generator.h> |
6 | #include <c10/util/Optional.h> |
7 | |
8 | // TODO: No need to have this whole header, we can just put it all in |
9 | // the cpp file |
10 | |
11 | namespace at { namespace cuda { namespace detail { |
12 | |
13 | // Set the callback to initialize Magma, which is set by |
14 | // torch_cuda_cu. This indirection is required so magma_init is called |
15 | // in the same library where Magma will be used. |
16 | TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); |
17 | |
18 | TORCH_CUDA_CPP_API bool hasPrimaryContext(int64_t device_index); |
19 | TORCH_CUDA_CPP_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext(); |
20 | |
21 | // The real implementation of CUDAHooksInterface |
22 | struct CUDAHooks : public at::CUDAHooksInterface { |
23 | CUDAHooks(at::CUDAHooksArgs) {} |
24 | void initCUDA() const override; |
25 | Device getDeviceFromPtr(void* data) const override; |
26 | bool isPinnedPtr(void* data) const override; |
27 | const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; |
28 | bool hasCUDA() const override; |
29 | bool hasMAGMA() const override; |
30 | bool hasCuDNN() const override; |
31 | bool hasCuSOLVER() const override; |
32 | bool hasROCM() const override; |
33 | const at::cuda::NVRTC& nvrtc() const override; |
34 | int64_t current_device() const override; |
35 | bool hasPrimaryContext(int64_t device_index) const override; |
36 | Allocator* getCUDADeviceAllocator() const override; |
37 | Allocator* getPinnedMemoryAllocator() const override; |
38 | bool compiledWithCuDNN() const override; |
39 | bool compiledWithMIOpen() const override; |
40 | bool supportsDilatedConvolutionWithCuDNN() const override; |
41 | bool supportsDepthwiseConvolutionWithCuDNN() const override; |
42 | bool supportsBFloat16ConvolutionWithCuDNNv8() const override; |
43 | bool hasCUDART() const override; |
44 | long versionCUDART() const override; |
45 | long versionCuDNN() const override; |
46 | std::string showConfig() const override; |
47 | double batchnormMinEpsilonCuDNN() const override; |
48 | int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override; |
49 | void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override; |
50 | int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override; |
51 | void cuFFTClearPlanCache(int64_t device_index) const override; |
52 | int getNumGPUs() const override; |
53 | void deviceSynchronize(int64_t device_index) const override; |
54 | }; |
55 | |
56 | }}} // at::cuda::detail |
57 | |