1#pragma once
2
3#include <cstdint>
4
5#include <cuda_runtime_api.h>
6#include <cusparse.h>
7#include <cublas_v2.h>
8
9#ifdef CUDART_VERSION
10#include <cusolverDn.h>
11#endif
12
13#include <ATen/core/ATenGeneral.h>
14#include <ATen/Context.h>
15#include <c10/cuda/CUDAStream.h>
16#include <c10/cuda/CUDAFunctions.h>
17#include <ATen/cuda/Exceptions.h>
18
19namespace at {
20namespace cuda {
21
22/*
23A common CUDA interface for ATen.
24
25This interface is distinct from CUDAHooks, which defines an interface that links
26to both CPU-only and CUDA builds. That interface is intended for runtime
27dispatch and should be used from files that are included in both CPU-only and
28CUDA builds.
29
30CUDAContext, on the other hand, should be preferred by files only included in
31CUDA builds. It is intended to expose CUDA functionality in a consistent
32manner.
33
34This means there is some overlap between the CUDAContext and CUDAHooks, but
35the choice of which to use is simple: use CUDAContext when in a CUDA-only file,
36use CUDAHooks otherwise.
37
38Note that CUDAContext simply defines an interface with no associated class.
39It is expected that the modules whose functions compose this interface will
40manage their own state. There is only a single CUDA context/state.
41*/
42
43/**
44 * DEPRECATED: use device_count() instead
45 */
46inline int64_t getNumGPUs() {
47 return c10::cuda::device_count();
48}
49
50/**
51 * CUDA is available if we compiled with CUDA, and there are one or more
52 * devices. If we compiled with CUDA but there is a driver problem, etc.,
53 * this function will report CUDA is not available (rather than raise an error.)
54 */
55inline bool is_available() {
56 return c10::cuda::device_count() > 0;
57}
58
59TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
60
61TORCH_CUDA_CPP_API int warp_size();
62
63TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(int64_t device);
64
65TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
66 int64_t device,
67 int64_t peer_device);
68
69TORCH_CUDA_CPP_API Allocator* getCUDADeviceAllocator();
70
71/* Handles */
72TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
73TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
74
75TORCH_CUDA_CPP_API void clearCublasWorkspaces();
76
77#ifdef CUDART_VERSION
78TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
79#endif
80
81} // namespace cuda
82} // namespace at
83