1#pragma once
2
3#include <ATen/detail/CUDAHooksInterface.h>
4
5#include <ATen/Generator.h>
6#include <c10/util/Optional.h>
7
// TODO: This header is not strictly needed; everything in it could be moved
// into the cpp file.
10
11namespace at { namespace cuda { namespace detail {
12
13// Set the callback to initialize Magma, which is set by
14// torch_cuda_cu. This indirection is required so magma_init is called
15// in the same library where Magma will be used.
16TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
17
18TORCH_CUDA_CPP_API bool hasPrimaryContext(int64_t device_index);
19TORCH_CUDA_CPP_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
20
21// The real implementation of CUDAHooksInterface
22struct CUDAHooks : public at::CUDAHooksInterface {
23 CUDAHooks(at::CUDAHooksArgs) {}
24 void initCUDA() const override;
25 Device getDeviceFromPtr(void* data) const override;
26 bool isPinnedPtr(void* data) const override;
27 const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
28 bool hasCUDA() const override;
29 bool hasMAGMA() const override;
30 bool hasCuDNN() const override;
31 bool hasCuSOLVER() const override;
32 bool hasROCM() const override;
33 const at::cuda::NVRTC& nvrtc() const override;
34 int64_t current_device() const override;
35 bool hasPrimaryContext(int64_t device_index) const override;
36 Allocator* getCUDADeviceAllocator() const override;
37 Allocator* getPinnedMemoryAllocator() const override;
38 bool compiledWithCuDNN() const override;
39 bool compiledWithMIOpen() const override;
40 bool supportsDilatedConvolutionWithCuDNN() const override;
41 bool supportsDepthwiseConvolutionWithCuDNN() const override;
42 bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
43 bool hasCUDART() const override;
44 long versionCUDART() const override;
45 long versionCuDNN() const override;
46 std::string showConfig() const override;
47 double batchnormMinEpsilonCuDNN() const override;
48 int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override;
49 void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override;
50 int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
51 void cuFFTClearPlanCache(int64_t device_index) const override;
52 int getNumGPUs() const override;
53 void deviceSynchronize(int64_t device_index) const override;
54};
55
56}}} // at::cuda::detail
57