1 | #include <pybind11/pybind11.h> |
2 | #include <torch/csrc/Device.h> |
3 | #include <torch/csrc/THP.h> |
4 | #include <torch/csrc/cuda/Module.h> |
5 | #include <torch/csrc/cuda/Stream.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | #include <torch/csrc/utils/python_numbers.h> |
8 | |
9 | #include <c10/cuda/CUDAGuard.h> |
10 | |
11 | #include <cuda_runtime_api.h> |
12 | #include <structmember.h> |
13 | |
14 | PyObject* THCPStreamClass = nullptr; |
15 | |
16 | static PyObject* THCPStream_pynew( |
17 | PyTypeObject* type, |
18 | PyObject* args, |
19 | PyObject* kwargs) { |
20 | HANDLE_TH_ERRORS |
21 | |
22 | const auto current_device = c10::cuda::current_device(); |
23 | |
24 | int priority = 0; |
25 | int64_t stream_id = 0; |
26 | int64_t device_index = 0; |
27 | int64_t device_type = 0; |
28 | uint64_t stream_ptr = 0; |
29 | |
30 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
31 | constexpr char* kwlist[] = { |
32 | "priority" , |
33 | "stream_id" , |
34 | "device_index" , |
35 | "device_type" , |
36 | "stream_ptr" , |
37 | nullptr}; |
38 | if (!PyArg_ParseTupleAndKeywords( |
39 | args, |
40 | kwargs, |
41 | "|iLLLK" , |
42 | const_cast<char**>(kwlist), |
43 | &priority, |
44 | &stream_id, |
45 | &device_index, |
46 | &device_type, |
47 | &stream_ptr)) { |
48 | return nullptr; |
49 | } |
50 | |
51 | THPObjectPtr ptr(type->tp_alloc(type, 0)); |
52 | if (!ptr) { |
53 | return nullptr; |
54 | } |
55 | |
56 | if (stream_ptr) { |
57 | TORCH_CHECK( |
58 | priority == 0, "Priority was explicitly set for a external stream" ) |
59 | } |
60 | |
61 | at::cuda::CUDAStream stream = (stream_id || device_index || device_type) |
62 | ? at::cuda::CUDAStream::unpack3( |
63 | stream_id, device_index, static_cast<c10::DeviceType>(device_type)) |
64 | : stream_ptr |
65 | ? at::cuda::getStreamFromExternal( |
66 | reinterpret_cast<cudaStream_t>(stream_ptr), current_device) |
67 | : at::cuda::getStreamFromPool( |
68 | /* isHighPriority */ priority < 0 ? true : false); |
69 | |
70 | THCPStream* self = (THCPStream*)ptr.get(); |
71 | self->stream_id = static_cast<int64_t>(stream.id()); |
72 | self->device_index = static_cast<int64_t>(stream.device_index()); |
73 | self->device_type = static_cast<int64_t>(stream.device_type()); |
74 | new (&self->cuda_stream) at::cuda::CUDAStream(stream); |
75 | |
76 | return (PyObject*)ptr.release(); |
77 | END_HANDLE_TH_ERRORS |
78 | } |
79 | |
80 | static void THCPStream_dealloc(THCPStream* self) { |
81 | self->cuda_stream.~CUDAStream(); |
82 | Py_TYPE(self)->tp_free((PyObject*)self); |
83 | } |
84 | |
85 | static PyObject* THCPStream_get_device(THCPStream* self, void* unused) { |
86 | HANDLE_TH_ERRORS |
87 | return THPDevice_New(self->cuda_stream.device()); |
88 | END_HANDLE_TH_ERRORS |
89 | } |
90 | |
91 | static PyObject* THCPStream_get_cuda_stream(THCPStream* self, void* unused) { |
92 | HANDLE_TH_ERRORS |
93 | return PyLong_FromVoidPtr(self->cuda_stream.stream()); |
94 | END_HANDLE_TH_ERRORS |
95 | } |
96 | |
97 | static PyObject* THCPStream_get_priority(THCPStream* self, void* unused) { |
98 | HANDLE_TH_ERRORS |
99 | return THPUtils_packInt64(self->cuda_stream.priority()); |
100 | END_HANDLE_TH_ERRORS |
101 | } |
102 | |
103 | static PyObject* THCPStream_priority_range( |
104 | PyObject* _unused, |
105 | PyObject* noargs) { |
106 | HANDLE_TH_ERRORS |
107 | auto [least_priority, greatest_priority] = |
108 | at::cuda::CUDAStream::priority_range(); |
109 | return Py_BuildValue("(ii)" , least_priority, greatest_priority); |
110 | END_HANDLE_TH_ERRORS |
111 | } |
112 | |
113 | static PyObject* THCPStream_query(PyObject* _self, PyObject* noargs) { |
114 | HANDLE_TH_ERRORS |
115 | auto self = (THCPStream*)_self; |
116 | return PyBool_FromLong(self->cuda_stream.query()); |
117 | END_HANDLE_TH_ERRORS |
118 | } |
119 | |
120 | static PyObject* THCPStream_synchronize(PyObject* _self, PyObject* noargs) { |
121 | HANDLE_TH_ERRORS { |
122 | pybind11::gil_scoped_release no_gil; |
123 | auto self = (THCPStream*)_self; |
124 | self->cuda_stream.synchronize(); |
125 | } |
126 | Py_RETURN_NONE; |
127 | END_HANDLE_TH_ERRORS |
128 | } |
129 | |
130 | static PyObject* THCPStream_eq(PyObject* _self, PyObject* _other) { |
131 | HANDLE_TH_ERRORS |
132 | auto self = (THCPStream*)_self; |
133 | auto other = (THCPStream*)_other; |
134 | return PyBool_FromLong(self->cuda_stream == other->cuda_stream); |
135 | END_HANDLE_TH_ERRORS |
136 | } |
137 | |
138 | // NOLINTNEXTLINE(modernize-avoid-c-arrays, |
139 | // cppcoreguidelines-avoid-non-const-global-variables, |
140 | // cppcoreguidelines-avoid-c-arrays) |
141 | static struct PyMemberDef THCPStream_members[] = {{nullptr}}; |
142 | |
143 | // NOLINTNEXTLINE(modernize-avoid-c-arrays, |
144 | // cppcoreguidelines-avoid-non-const-global-variables, |
145 | // cppcoreguidelines-avoid-c-arrays) |
146 | static struct PyGetSetDef THCPStream_properties[] = { |
147 | {"cuda_stream" , |
148 | (getter)THCPStream_get_cuda_stream, |
149 | nullptr, |
150 | nullptr, |
151 | nullptr}, |
152 | {"priority" , (getter)THCPStream_get_priority, nullptr, nullptr, nullptr}, |
153 | {nullptr}}; |
154 | |
155 | // NOLINTNEXTLINE(modernize-avoid-c-arrays, |
156 | // cppcoreguidelines-avoid-non-const-global-variables, |
157 | // cppcoreguidelines-avoid-c-arrays) |
158 | static PyMethodDef THCPStream_methods[] = { |
159 | {(char*)"query" , THCPStream_query, METH_NOARGS, nullptr}, |
160 | {(char*)"synchronize" , THCPStream_synchronize, METH_NOARGS, nullptr}, |
161 | {(char*)"priority_range" , |
162 | THCPStream_priority_range, |
163 | METH_STATIC | METH_NOARGS, |
164 | nullptr}, |
165 | {(char*)"__eq__" , THCPStream_eq, METH_O, nullptr}, |
166 | {nullptr}}; |
167 | |
168 | PyTypeObject THCPStreamType = { |
169 | PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._CudaStreamBase" , /* tp_name */ |
170 | sizeof(THCPStream), /* tp_basicsize */ |
171 | 0, /* tp_itemsize */ |
172 | (destructor)THCPStream_dealloc, /* tp_dealloc */ |
173 | 0, /* tp_vectorcall_offset */ |
174 | nullptr, /* tp_getattr */ |
175 | nullptr, /* tp_setattr */ |
176 | nullptr, /* tp_reserved */ |
177 | nullptr, /* tp_repr */ |
178 | nullptr, /* tp_as_number */ |
179 | nullptr, /* tp_as_sequence */ |
180 | nullptr, /* tp_as_mapping */ |
181 | nullptr, /* tp_hash */ |
182 | nullptr, /* tp_call */ |
183 | nullptr, /* tp_str */ |
184 | nullptr, /* tp_getattro */ |
185 | nullptr, /* tp_setattro */ |
186 | nullptr, /* tp_as_buffer */ |
187 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ |
188 | nullptr, /* tp_doc */ |
189 | nullptr, /* tp_traverse */ |
190 | nullptr, /* tp_clear */ |
191 | nullptr, /* tp_richcompare */ |
192 | 0, /* tp_weaklistoffset */ |
193 | nullptr, /* tp_iter */ |
194 | nullptr, /* tp_iternext */ |
195 | THCPStream_methods, /* tp_methods */ |
196 | THCPStream_members, /* tp_members */ |
197 | THCPStream_properties, /* tp_getset */ |
198 | nullptr, /* tp_base */ |
199 | nullptr, /* tp_dict */ |
200 | nullptr, /* tp_descr_get */ |
201 | nullptr, /* tp_descr_set */ |
202 | 0, /* tp_dictoffset */ |
203 | nullptr, /* tp_init */ |
204 | nullptr, /* tp_alloc */ |
205 | THCPStream_pynew, /* tp_new */ |
206 | }; |
207 | |
208 | void THCPStream_init(PyObject* module) { |
209 | Py_INCREF(THPStreamClass); |
210 | THCPStreamType.tp_base = THPStreamClass; |
211 | THCPStreamClass = (PyObject*)&THCPStreamType; |
212 | if (PyType_Ready(&THCPStreamType) < 0) { |
213 | throw python_error(); |
214 | } |
215 | Py_INCREF(&THCPStreamType); |
216 | if (PyModule_AddObject( |
217 | module, "_CudaStreamBase" , (PyObject*)&THCPStreamType) < 0) { |
218 | throw python_error(); |
219 | } |
220 | } |
221 | |