1#include <pybind11/pybind11.h>
2#include <torch/csrc/Device.h>
3#include <torch/csrc/THP.h>
4#include <torch/csrc/utils/pybind.h>
5#include <torch/csrc/utils/python_arg_parser.h>
6
7#include <structmember.h>
8
// Pointer to the Python type object for torch.Stream. It stays null until
// THPStream_init() assigns it the address of THPStreamType.
PyTypeObject* THPStreamClass = nullptr;
10
11static PyObject* THPStream_pynew(
12 PyTypeObject* type,
13 PyObject* args,
14 PyObject* kwargs) {
15 HANDLE_TH_ERRORS
16 int64_t stream_id = 0;
17 int64_t device_index = 0;
18 int64_t device_type = 0;
19 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
20 constexpr char* kwlist[] = {
21 "stream_id", "device_index", "device_type", nullptr};
22 if (!PyArg_ParseTupleAndKeywords(
23 args,
24 kwargs,
25 "|LLL",
26 const_cast<char**>(kwlist),
27 &stream_id,
28 &device_index,
29 &device_type)) {
30 return nullptr;
31 }
32
33 THPObjectPtr ptr(type->tp_alloc(type, 0));
34 if (!ptr) {
35 return nullptr;
36 }
37
38 THPStream* self = (THPStream*)ptr.get();
39 self->stream_id = stream_id;
40 self->device_index = device_index;
41 self->device_type = device_type;
42 return (PyObject*)ptr.release();
43 END_HANDLE_TH_ERRORS
44}
45
46static void THPStream_dealloc(THPStream* self) {
47 Py_TYPE(self)->tp_free((PyObject*)self);
48}
49
50static PyObject* THPStream_get_device(THPStream* self, void* unused) {
51 HANDLE_TH_ERRORS
52 return THPDevice_New(c10::Stream::unpack3(
53 self->stream_id,
54 self->device_index,
55 static_cast<c10::DeviceType>(self->device_type))
56 .device());
57 END_HANDLE_TH_ERRORS
58}
59
60static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
61 HANDLE_TH_ERRORS
62 return PyBool_FromLong(
63 self->stream_id == other->stream_id &&
64 self->device_index == other->device_index &&
65 self->device_type == other->device_type);
66 END_HANDLE_TH_ERRORS
67}
68
69// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
70static struct PyMemberDef THPStream_members[] = {
71 {(char*)"stream_id",
72 T_LONGLONG,
73 offsetof(THPStream, stream_id),
74 READONLY,
75 nullptr},
76 {(char*)"device_index",
77 T_LONGLONG,
78 offsetof(THPStream, device_index),
79 READONLY,
80 nullptr},
81 {(char*)"device_type",
82 T_LONGLONG,
83 offsetof(THPStream, device_type),
84 READONLY,
85 nullptr},
86 {nullptr}};
87
88// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
89static struct PyGetSetDef THPStream_properties[] = {
90 {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr},
91 {nullptr}};
92
93// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
94static PyMethodDef THPStream_methods[] = {
95 {(char*)"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
96 {nullptr}};
97
98PyTypeObject THPStreamType = {
99 PyVarObject_HEAD_INIT(nullptr, 0) "torch.Stream", /* tp_name */
100 sizeof(THPStream), /* tp_basicsize */
101 0, /* tp_itemsize */
102 (destructor)THPStream_dealloc, /* tp_dealloc */
103 0, /* tp_vectorcall_offset */
104 nullptr, /* tp_getattr */
105 nullptr, /* tp_setattr */
106 nullptr, /* tp_reserved */
107 nullptr, /* tp_repr */
108 nullptr, /* tp_as_number */
109 nullptr, /* tp_as_sequence */
110 nullptr, /* tp_as_mapping */
111 nullptr, /* tp_hash */
112 nullptr, /* tp_call */
113 nullptr, /* tp_str */
114 nullptr, /* tp_getattro */
115 nullptr, /* tp_setattro */
116 nullptr, /* tp_as_buffer */
117 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
118 nullptr, /* tp_doc */
119 nullptr, /* tp_traverse */
120 nullptr, /* tp_clear */
121 nullptr, /* tp_richcompare */
122 0, /* tp_weaklistoffset */
123 nullptr, /* tp_iter */
124 nullptr, /* tp_iternext */
125 THPStream_methods, /* tp_methods */
126 THPStream_members, /* tp_members */
127 THPStream_properties, /* tp_getset */
128 nullptr, /* tp_base */
129 nullptr, /* tp_dict */
130 nullptr, /* tp_descr_get */
131 nullptr, /* tp_descr_set */
132 0, /* tp_dictoffset */
133 nullptr, /* tp_init */
134 nullptr, /* tp_alloc */
135 THPStream_pynew, /* tp_new */
136};
137
138void THPStream_init(PyObject* module) {
139 THPStreamClass = &THPStreamType;
140 Py_SET_TYPE(&THPStreamType, &PyType_Type);
141 if (PyType_Ready(&THPStreamType) < 0) {
142 throw python_error();
143 }
144 Py_INCREF(&THPStreamType);
145 if (PyModule_AddObject(module, "Stream", (PyObject*)&THPStreamType) < 0) {
146 throw python_error();
147 }
148}
149