1 | #pragma once |
2 | |
3 | #include <ATen/cuda/Exceptions.h> |
4 | |
5 | #include <cuda.h> |
6 | #include <cuda_runtime.h> |
7 | |
8 | namespace at { |
9 | namespace cuda { |
10 | |
11 | inline Device getDeviceFromPtr(void* ptr) { |
12 | cudaPointerAttributes attr{}; |
13 | |
14 | AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); |
15 | |
16 | #if !defined(USE_ROCM) |
17 | TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered, |
18 | "The specified pointer resides on host memory and is not registered with any CUDA device." ); |
19 | #endif |
20 | |
21 | return {DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)}; |
22 | } |
23 | |
24 | }} // namespace at::cuda |
25 | |