1 | #pragma once |
2 | |
3 | #include <cstdint> |
4 | |
5 | #include <cuda_runtime_api.h> |
6 | #include <cusparse.h> |
7 | #include <cublas_v2.h> |
8 | |
9 | #ifdef CUDART_VERSION |
10 | #include <cusolverDn.h> |
11 | #endif |
12 | |
13 | #include <ATen/core/ATenGeneral.h> |
14 | #include <ATen/Context.h> |
15 | #include <c10/cuda/CUDAStream.h> |
16 | #include <c10/cuda/CUDAFunctions.h> |
17 | #include <ATen/cuda/Exceptions.h> |
18 | |
19 | namespace at { |
20 | namespace cuda { |
21 | |
22 | /* |
23 | A common CUDA interface for ATen. |
24 | |
25 | This interface is distinct from CUDAHooks, which defines an interface that links |
26 | to both CPU-only and CUDA builds. That interface is intended for runtime |
27 | dispatch and should be used from files that are included in both CPU-only and |
28 | CUDA builds. |
29 | |
30 | CUDAContext, on the other hand, should be preferred by files only included in |
31 | CUDA builds. It is intended to expose CUDA functionality in a consistent |
32 | manner. |
33 | |
34 | This means there is some overlap between the CUDAContext and CUDAHooks, but |
35 | the choice of which to use is simple: use CUDAContext when in a CUDA-only file, |
36 | use CUDAHooks otherwise. |
37 | |
38 | Note that CUDAContext simply defines an interface with no associated class. |
39 | It is expected that the modules whose functions compose this interface will |
40 | manage their own state. There is only a single CUDA context/state. |
41 | */ |
42 | |
43 | /** |
44 | * DEPRECATED: use device_count() instead |
45 | */ |
46 | inline int64_t getNumGPUs() { |
47 | return c10::cuda::device_count(); |
48 | } |
49 | |
50 | /** |
51 | * CUDA is available if we compiled with CUDA, and there are one or more |
52 | * devices. If we compiled with CUDA but there is a driver problem, etc., |
53 | * this function will report CUDA is not available (rather than raise an error.) |
54 | */ |
55 | inline bool is_available() { |
56 | return c10::cuda::device_count() > 0; |
57 | } |
58 | |
// Returns the cached cudaDeviceProp for the current device. The pointee is
// owned by the implementation; callers must not free it.
TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();

// Warp size of the current device (32 on NVIDIA hardware; may differ on
// other backends — query rather than hard-code).
TORCH_CUDA_CPP_API int warp_size();

// Returns the cached cudaDeviceProp for the given device index.
TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(int64_t device);

// True if `device` can directly access memory on `peer_device`
// (peer-to-peer access, e.g. over NVLink/PCIe).
TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
    int64_t device,
    int64_t peer_device);

// The allocator used for CUDA device memory in ATen.
TORCH_CUDA_CPP_API Allocator* getCUDADeviceAllocator();

/* Handles */
// Per-device, per-thread library handles bound to the current CUDA stream.
// Handles are managed internally; callers must not destroy them.
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();

// Frees workspace buffers associated with cuBLAS handles (e.g. to reclaim
// device memory).
TORCH_CUDA_CPP_API void clearCublasWorkspaces();

// cuSOLVER is only available when building against the CUDA toolkit proper
// (CUDART_VERSION is defined by cuda_runtime_api.h); guarded so non-CUDA
// toolkit builds (e.g. HIP) skip it.
#ifdef CUDART_VERSION
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
#endif
80 | |
81 | } // namespace cuda |
82 | } // namespace at |
83 | |