1#include <cuda.h>
2#include <cuda_runtime.h>
3#include <torch/csrc/utils/pybind.h>
4#if !defined(USE_ROCM)
5#include <cuda_profiler_api.h>
6#else
7#include <hip/hip_runtime_api.h>
8#endif
9
10#include <c10/cuda/CUDAException.h>
11#include <c10/cuda/CUDAGuard.h>
12
namespace torch {
namespace cuda {
namespace shared {

#ifdef USE_ROCM
namespace {
// Stand-in used on ROCm builds: the profiler start/stop bindings below have
// no HIP counterpart wired up here, so they simply report success without
// doing anything.
hipError_t hipReturnSuccess() {
  return hipSuccess;
}
} // namespace
#endif

// Registers the `_cudart` submodule on the given Python module, exposing a
// small subset of the CUDA runtime API to Python: the success error code and
// error-string lookup, profiler start/stop (plus initialize before CUDA 12),
// pinned host-memory (un)registration, raw stream create/destroy, and a
// per-device free/total memory query. Pointer arguments cross the Python
// boundary as integer addresses (uintptr_t) and are cast back to the
// appropriate CUDA types inside each lambda.
void initCudartBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  auto cudart = m.def_submodule("_cudart", "libcudart.so bindings");

  // By splitting the names of these objects into two literals we prevent the
  // HIP rewrite rules from changing these names when building with HIP.

#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000
  // cudaOutputMode_t is used in cudaProfilerInitialize only. The latter is gone
  // in CUDA 12.
  py::enum_<cudaOutputMode_t>(
      cudart,
      "cuda"
      "OutputMode")
      .value("KeyValuePair", cudaKeyValuePair)
      .value("CSV", cudaCSV);
#endif

  // Only `cudaSuccess` is exposed; Python callers compare API return values
  // against it rather than enumerating every error code.
  py::enum_<cudaError_t>(
      cudart,
      "cuda"
      "Error")
      .value("success", cudaSuccess);

  cudart.def(
      "cuda"
      "GetErrorString",
      cudaGetErrorString);
  // Profiler start/stop: real CUDA profiler calls on NVIDIA builds, no-op
  // success stubs on ROCm (see hipReturnSuccess above).
  cudart.def(
      "cuda"
      "ProfilerStart",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStart
#endif
  );
  cudart.def(
      "cuda"
      "ProfilerStop",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStop
#endif
  );
  // Page-locks an existing host allocation of `size` bytes starting at the
  // integer address `ptr`; `flags` is forwarded verbatim to cudaHostRegister.
  cudart.def(
      "cuda"
      "HostRegister",
      [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t {
        return C10_CUDA_ERROR_HANDLED(
            cudaHostRegister((void*)ptr, size, flags));
      });
  cudart.def(
      "cuda"
      "HostUnregister",
      [](uintptr_t ptr) -> cudaError_t {
        return C10_CUDA_ERROR_HANDLED(cudaHostUnregister((void*)ptr));
      });
  // Note the asymmetry between the two stream bindings: StreamCreate expects
  // the address OF a cudaStream_t (an out-parameter written by the runtime),
  // while StreamDestroy expects the stream handle value itself.
  cudart.def(
      "cuda"
      "StreamCreate",
      [](uintptr_t ptr) -> cudaError_t {
        return C10_CUDA_ERROR_HANDLED(cudaStreamCreate((cudaStream_t*)ptr));
      });
  cudart.def(
      "cuda"
      "StreamDestroy",
      [](uintptr_t ptr) -> cudaError_t {
        return C10_CUDA_ERROR_HANDLED(cudaStreamDestroy((cudaStream_t)ptr));
      });
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000
  // cudaProfilerInitialize is no longer needed after CUDA 12:
  // https://forums.developer.nvidia.com/t/cudaprofilerinitialize-is-deprecated-alternative/200776/3
  cudart.def(
      "cuda"
      "ProfilerInitialize",
      cudaProfilerInitialize);
#endif
  // Returns (free, total) bytes for `device`. The CUDAGuard temporarily makes
  // `device` current for this thread, since cudaMemGetInfo reports on the
  // current device; the previous device is restored when the guard leaves
  // scope.
  cudart.def(
      "cuda"
      "MemGetInfo",
      [](int device) -> std::pair<size_t, size_t> {
        c10::cuda::CUDAGuard guard(device);
        size_t device_free = 0;
        size_t device_total = 0;
        C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
        return {device_free, device_total};
      });
}

} // namespace shared
} // namespace cuda
} // namespace torch
120