1#include <ATen/cuda/CUDAContext.h>
2#include <c10/cuda/CUDACachingAllocator.h>
3#include <c10/util/CallOnce.h>
4
5#include <ATen/cuda/CUDAConfig.h>
6#include <mutex>
7#include <deque>
8#include <vector>
9
10namespace at { namespace cuda {
11
12namespace {
13
14DeviceIndex num_gpus = -1;
15c10::once_flag init_flag;
16std::deque<c10::once_flag> device_flags;
17std::vector<cudaDeviceProp> device_properties;
18
19void initCUDAContextVectors() {
20 num_gpus = c10::cuda::device_count();
21 device_flags.resize(num_gpus);
22 device_properties.resize(num_gpus);
23}
24
25void initDeviceProperty(DeviceIndex device_index) {
26 cudaDeviceProp device_prop;
27 AT_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_index));
28 device_properties[device_index] = device_prop;
29}
30
31} // anonymous namespace
32
33// We need this function to force the linking against torch_cuda(_cpp) on Windows.
34// If you need to modify this function, please specify a new function and apply
35// the changes according to https://github.com/pytorch/pytorch/pull/34288.
36// Related issue: https://github.com/pytorch/pytorch/issues/31611.
37/* Device info */
38int warp_size() {
39 return getCurrentDeviceProperties()->warpSize;
40}
41
42cudaDeviceProp* getCurrentDeviceProperties() {
43 auto device = c10::cuda::current_device();
44 return getDeviceProperties(device);
45}
46
47cudaDeviceProp* getDeviceProperties(int64_t device) {
48 c10::call_once(init_flag, initCUDAContextVectors);
49 if (device == -1) device = c10::cuda::current_device();
50 AT_ASSERT(device >= 0 && device < num_gpus);
51 c10::call_once(device_flags[device], initDeviceProperty, device);
52 return &device_properties[device];
53}
54
55bool canDeviceAccessPeer(int64_t device, int64_t peer_device) {
56 c10::call_once(init_flag, initCUDAContextVectors);
57 if (device == -1) device = c10::cuda::current_device();
58 AT_ASSERT(device >= 0 && device < num_gpus);
59 AT_ASSERT(peer_device >= 0 && peer_device < num_gpus);
60 int can_access = 0;
61 AT_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device, peer_device));
62 return can_access != 0;
63}
64
65Allocator* getCUDADeviceAllocator() {
66 return c10::cuda::CUDACachingAllocator::get();
67}
68
69} // namespace cuda
70
71} // namespace at
72