1#pragma once
2
3#include <ATen/cuda/Exceptions.h>
4
5#include <cuda.h>
6#include <cuda_runtime.h>
7
namespace at {
namespace cuda {

// Returns the CUDA device that the runtime associates with `ptr`.
//
// The pointer is classified with cudaPointerGetAttributes. A failing runtime
// call is reported through AT_CUDA_CHECK. On non-ROCm builds, a pointer the
// runtime classifies as cudaMemoryTypeUnregistered (plain host memory) is
// rejected via TORCH_CHECK. Otherwise the result is a CUDA Device whose
// index is `cudaPointerAttributes::device` cast to DeviceIndex.
//
// NOTE(review): the unregistered-host check is compiled out under USE_ROCM;
// what the HIP runtime reports for such pointers there should be confirmed.
inline Device getDeviceFromPtr(void* ptr) {
  // Value-initialize so no field is left indeterminate before the query.
  cudaPointerAttributes attr{};
  const cudaError_t query_status = cudaPointerGetAttributes(&attr, ptr);
  AT_CUDA_CHECK(query_status);

#if !defined(USE_ROCM)
  // Per the message below, unregistered host memory is not tied to any CUDA
  // device, so there is no device to report for it.
  TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
              "The specified pointer resides on host memory and is not registered with any CUDA device.");
#endif

  const DeviceIndex owning_index = static_cast<DeviceIndex>(attr.device);
  return Device(DeviceType::CUDA, owning_index);
}

} // namespace cuda
} // namespace at
25