1#include <pybind11/pybind11.h>
2#include <torch/csrc/Device.h>
3#include <torch/csrc/THP.h>
4#include <torch/csrc/cuda/Module.h>
5#include <torch/csrc/cuda/Stream.h>
6#include <torch/csrc/utils/pybind.h>
7#include <torch/csrc/utils/python_numbers.h>
8
9#include <c10/cuda/CUDAGuard.h>
10
11#include <cuda_runtime_api.h>
12#include <structmember.h>
13
// Handle to the CUDA stream base type as a generic PyObject*.
// Stays nullptr until THCPStream_init() points it at THCPStreamType.
PyObject* THCPStreamClass = nullptr;
15
16static PyObject* THCPStream_pynew(
17 PyTypeObject* type,
18 PyObject* args,
19 PyObject* kwargs) {
20 HANDLE_TH_ERRORS
21
22 const auto current_device = c10::cuda::current_device();
23
24 int priority = 0;
25 int64_t stream_id = 0;
26 int64_t device_index = 0;
27 int64_t device_type = 0;
28 uint64_t stream_ptr = 0;
29
30 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
31 constexpr char* kwlist[] = {
32 "priority",
33 "stream_id",
34 "device_index",
35 "device_type",
36 "stream_ptr",
37 nullptr};
38 if (!PyArg_ParseTupleAndKeywords(
39 args,
40 kwargs,
41 "|iLLLK",
42 const_cast<char**>(kwlist),
43 &priority,
44 &stream_id,
45 &device_index,
46 &device_type,
47 &stream_ptr)) {
48 return nullptr;
49 }
50
51 THPObjectPtr ptr(type->tp_alloc(type, 0));
52 if (!ptr) {
53 return nullptr;
54 }
55
56 if (stream_ptr) {
57 TORCH_CHECK(
58 priority == 0, "Priority was explicitly set for a external stream")
59 }
60
61 at::cuda::CUDAStream stream = (stream_id || device_index || device_type)
62 ? at::cuda::CUDAStream::unpack3(
63 stream_id, device_index, static_cast<c10::DeviceType>(device_type))
64 : stream_ptr
65 ? at::cuda::getStreamFromExternal(
66 reinterpret_cast<cudaStream_t>(stream_ptr), current_device)
67 : at::cuda::getStreamFromPool(
68 /* isHighPriority */ priority < 0 ? true : false);
69
70 THCPStream* self = (THCPStream*)ptr.get();
71 self->stream_id = static_cast<int64_t>(stream.id());
72 self->device_index = static_cast<int64_t>(stream.device_index());
73 self->device_type = static_cast<int64_t>(stream.device_type());
74 new (&self->cuda_stream) at::cuda::CUDAStream(stream);
75
76 return (PyObject*)ptr.release();
77 END_HANDLE_TH_ERRORS
78}
79
80static void THCPStream_dealloc(THCPStream* self) {
81 self->cuda_stream.~CUDAStream();
82 Py_TYPE(self)->tp_free((PyObject*)self);
83}
84
85static PyObject* THCPStream_get_device(THCPStream* self, void* unused) {
86 HANDLE_TH_ERRORS
87 return THPDevice_New(self->cuda_stream.device());
88 END_HANDLE_TH_ERRORS
89}
90
91static PyObject* THCPStream_get_cuda_stream(THCPStream* self, void* unused) {
92 HANDLE_TH_ERRORS
93 return PyLong_FromVoidPtr(self->cuda_stream.stream());
94 END_HANDLE_TH_ERRORS
95}
96
97static PyObject* THCPStream_get_priority(THCPStream* self, void* unused) {
98 HANDLE_TH_ERRORS
99 return THPUtils_packInt64(self->cuda_stream.priority());
100 END_HANDLE_TH_ERRORS
101}
102
103static PyObject* THCPStream_priority_range(
104 PyObject* _unused,
105 PyObject* noargs) {
106 HANDLE_TH_ERRORS
107 auto [least_priority, greatest_priority] =
108 at::cuda::CUDAStream::priority_range();
109 return Py_BuildValue("(ii)", least_priority, greatest_priority);
110 END_HANDLE_TH_ERRORS
111}
112
113static PyObject* THCPStream_query(PyObject* _self, PyObject* noargs) {
114 HANDLE_TH_ERRORS
115 auto self = (THCPStream*)_self;
116 return PyBool_FromLong(self->cuda_stream.query());
117 END_HANDLE_TH_ERRORS
118}
119
120static PyObject* THCPStream_synchronize(PyObject* _self, PyObject* noargs) {
121 HANDLE_TH_ERRORS {
122 pybind11::gil_scoped_release no_gil;
123 auto self = (THCPStream*)_self;
124 self->cuda_stream.synchronize();
125 }
126 Py_RETURN_NONE;
127 END_HANDLE_TH_ERRORS
128}
129
130static PyObject* THCPStream_eq(PyObject* _self, PyObject* _other) {
131 HANDLE_TH_ERRORS
132 auto self = (THCPStream*)_self;
133 auto other = (THCPStream*)_other;
134 return PyBool_FromLong(self->cuda_stream == other->cuda_stream);
135 END_HANDLE_TH_ERRORS
136}
137
138// NOLINTNEXTLINE(modernize-avoid-c-arrays,
139// cppcoreguidelines-avoid-non-const-global-variables,
140// cppcoreguidelines-avoid-c-arrays)
141static struct PyMemberDef THCPStream_members[] = {{nullptr}};
142
143// NOLINTNEXTLINE(modernize-avoid-c-arrays,
144// cppcoreguidelines-avoid-non-const-global-variables,
145// cppcoreguidelines-avoid-c-arrays)
146static struct PyGetSetDef THCPStream_properties[] = {
147 {"cuda_stream",
148 (getter)THCPStream_get_cuda_stream,
149 nullptr,
150 nullptr,
151 nullptr},
152 {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr},
153 {nullptr}};
154
155// NOLINTNEXTLINE(modernize-avoid-c-arrays,
156// cppcoreguidelines-avoid-non-const-global-variables,
157// cppcoreguidelines-avoid-c-arrays)
158static PyMethodDef THCPStream_methods[] = {
159 {(char*)"query", THCPStream_query, METH_NOARGS, nullptr},
160 {(char*)"synchronize", THCPStream_synchronize, METH_NOARGS, nullptr},
161 {(char*)"priority_range",
162 THCPStream_priority_range,
163 METH_STATIC | METH_NOARGS,
164 nullptr},
165 {(char*)"__eq__", THCPStream_eq, METH_O, nullptr},
166 {nullptr}};
167
// Static type object behind "torch._C._CudaStreamBase".
//
// The initializer is positional: every entry must stay in the slot order of
// PyTypeObject, which is why unused slots are spelled out as nullptr / 0.
// Slots that are actually populated here:
//   tp_dealloc -> THCPStream_dealloc
//   tp_flags   -> default flags plus Py_TPFLAGS_BASETYPE (type is subclassable)
//   tp_methods / tp_members / tp_getset -> the tables defined in this file
//   tp_new     -> THCPStream_pynew
// tp_base is left null here and assigned in THCPStream_init() before
// PyType_Ready() is called.
PyTypeObject THCPStreamType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._CudaStreamBase", /* tp_name */
    sizeof(THCPStream), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THCPStream_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THCPStream_methods, /* tp_methods */
    THCPStream_members, /* tp_members */
    THCPStream_properties, /* tp_getset */
    nullptr, /* tp_base (assigned in THCPStream_init) */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THCPStream_pynew, /* tp_new */
};
207
208void THCPStream_init(PyObject* module) {
209 Py_INCREF(THPStreamClass);
210 THCPStreamType.tp_base = THPStreamClass;
211 THCPStreamClass = (PyObject*)&THCPStreamType;
212 if (PyType_Ready(&THCPStreamType) < 0) {
213 throw python_error();
214 }
215 Py_INCREF(&THCPStreamType);
216 if (PyModule_AddObject(
217 module, "_CudaStreamBase", (PyObject*)&THCPStreamType) < 0) {
218 throw python_error();
219 }
220}
221