1 | #include <pybind11/pybind11.h> |
2 | #include <torch/csrc/Device.h> |
3 | #include <torch/csrc/THP.h> |
4 | #include <torch/csrc/utils/pybind.h> |
5 | #include <torch/csrc/utils/python_arg_parser.h> |
6 | |
7 | #include <structmember.h> |
8 | |
9 | PyTypeObject* THPStreamClass = nullptr; |
10 | |
11 | static PyObject* THPStream_pynew( |
12 | PyTypeObject* type, |
13 | PyObject* args, |
14 | PyObject* kwargs) { |
15 | HANDLE_TH_ERRORS |
16 | int64_t stream_id = 0; |
17 | int64_t device_index = 0; |
18 | int64_t device_type = 0; |
19 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
20 | constexpr char* kwlist[] = { |
21 | "stream_id" , "device_index" , "device_type" , nullptr}; |
22 | if (!PyArg_ParseTupleAndKeywords( |
23 | args, |
24 | kwargs, |
25 | "|LLL" , |
26 | const_cast<char**>(kwlist), |
27 | &stream_id, |
28 | &device_index, |
29 | &device_type)) { |
30 | return nullptr; |
31 | } |
32 | |
33 | THPObjectPtr ptr(type->tp_alloc(type, 0)); |
34 | if (!ptr) { |
35 | return nullptr; |
36 | } |
37 | |
38 | THPStream* self = (THPStream*)ptr.get(); |
39 | self->stream_id = stream_id; |
40 | self->device_index = device_index; |
41 | self->device_type = device_type; |
42 | return (PyObject*)ptr.release(); |
43 | END_HANDLE_TH_ERRORS |
44 | } |
45 | |
46 | static void THPStream_dealloc(THPStream* self) { |
47 | Py_TYPE(self)->tp_free((PyObject*)self); |
48 | } |
49 | |
50 | static PyObject* THPStream_get_device(THPStream* self, void* unused) { |
51 | HANDLE_TH_ERRORS |
52 | return THPDevice_New(c10::Stream::unpack3( |
53 | self->stream_id, |
54 | self->device_index, |
55 | static_cast<c10::DeviceType>(self->device_type)) |
56 | .device()); |
57 | END_HANDLE_TH_ERRORS |
58 | } |
59 | |
60 | static PyObject* THPStream_eq(THPStream* self, THPStream* other) { |
61 | HANDLE_TH_ERRORS |
62 | return PyBool_FromLong( |
63 | self->stream_id == other->stream_id && |
64 | self->device_index == other->device_index && |
65 | self->device_type == other->device_type); |
66 | END_HANDLE_TH_ERRORS |
67 | } |
68 | |
69 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
70 | static struct PyMemberDef THPStream_members[] = { |
71 | {(char*)"stream_id" , |
72 | T_LONGLONG, |
73 | offsetof(THPStream, stream_id), |
74 | READONLY, |
75 | nullptr}, |
76 | {(char*)"device_index" , |
77 | T_LONGLONG, |
78 | offsetof(THPStream, device_index), |
79 | READONLY, |
80 | nullptr}, |
81 | {(char*)"device_type" , |
82 | T_LONGLONG, |
83 | offsetof(THPStream, device_type), |
84 | READONLY, |
85 | nullptr}, |
86 | {nullptr}}; |
87 | |
88 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
89 | static struct PyGetSetDef THPStream_properties[] = { |
90 | {"device" , (getter)THPStream_get_device, nullptr, nullptr, nullptr}, |
91 | {nullptr}}; |
92 | |
93 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
94 | static PyMethodDef THPStream_methods[] = { |
95 | {(char*)"__eq__" , (PyCFunction)THPStream_eq, METH_O, nullptr}, |
96 | {nullptr}}; |
97 | |
98 | PyTypeObject THPStreamType = { |
99 | PyVarObject_HEAD_INIT(nullptr, 0) "torch.Stream" , /* tp_name */ |
100 | sizeof(THPStream), /* tp_basicsize */ |
101 | 0, /* tp_itemsize */ |
102 | (destructor)THPStream_dealloc, /* tp_dealloc */ |
103 | 0, /* tp_vectorcall_offset */ |
104 | nullptr, /* tp_getattr */ |
105 | nullptr, /* tp_setattr */ |
106 | nullptr, /* tp_reserved */ |
107 | nullptr, /* tp_repr */ |
108 | nullptr, /* tp_as_number */ |
109 | nullptr, /* tp_as_sequence */ |
110 | nullptr, /* tp_as_mapping */ |
111 | nullptr, /* tp_hash */ |
112 | nullptr, /* tp_call */ |
113 | nullptr, /* tp_str */ |
114 | nullptr, /* tp_getattro */ |
115 | nullptr, /* tp_setattro */ |
116 | nullptr, /* tp_as_buffer */ |
117 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ |
118 | nullptr, /* tp_doc */ |
119 | nullptr, /* tp_traverse */ |
120 | nullptr, /* tp_clear */ |
121 | nullptr, /* tp_richcompare */ |
122 | 0, /* tp_weaklistoffset */ |
123 | nullptr, /* tp_iter */ |
124 | nullptr, /* tp_iternext */ |
125 | THPStream_methods, /* tp_methods */ |
126 | THPStream_members, /* tp_members */ |
127 | THPStream_properties, /* tp_getset */ |
128 | nullptr, /* tp_base */ |
129 | nullptr, /* tp_dict */ |
130 | nullptr, /* tp_descr_get */ |
131 | nullptr, /* tp_descr_set */ |
132 | 0, /* tp_dictoffset */ |
133 | nullptr, /* tp_init */ |
134 | nullptr, /* tp_alloc */ |
135 | THPStream_pynew, /* tp_new */ |
136 | }; |
137 | |
138 | void THPStream_init(PyObject* module) { |
139 | THPStreamClass = &THPStreamType; |
140 | Py_SET_TYPE(&THPStreamType, &PyType_Type); |
141 | if (PyType_Ready(&THPStreamType) < 0) { |
142 | throw python_error(); |
143 | } |
144 | Py_INCREF(&THPStreamType); |
145 | if (PyModule_AddObject(module, "Stream" , (PyObject*)&THPStreamType) < 0) { |
146 | throw python_error(); |
147 | } |
148 | } |
149 | |