1 | #include <torch/csrc/Device.h> |
2 | |
3 | #include <torch/csrc/Exceptions.h> |
4 | #include <torch/csrc/utils/object_ptr.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | #include <torch/csrc/utils/python_arg_parser.h> |
7 | #include <torch/csrc/utils/python_numbers.h> |
8 | #include <torch/csrc/utils/python_strings.h> |
9 | |
10 | #include <ATen/Device.h> |
11 | #include <c10/util/Exception.h> |
12 | |
13 | #include <structmember.h> |
14 | #include <cstring> |
15 | #include <limits> |
16 | #include <sstream> |
17 | |
18 | PyObject* THPDevice_New(const at::Device& device) { |
19 | auto type = (PyTypeObject*)&THPDeviceType; |
20 | auto self = THPObjectPtr{type->tp_alloc(type, 0)}; |
21 | if (!self) |
22 | throw python_error(); |
23 | auto self_ = reinterpret_cast<THPDevice*>(self.get()); |
24 | self_->device = device; |
25 | return self.release(); |
26 | } |
27 | |
28 | PyObject* THPDevice_repr(THPDevice* self) { |
29 | std::ostringstream oss; |
30 | oss << "device(type=\'" << self->device.type() << "\'" ; |
31 | if (self->device.has_index()) { |
32 | // `self->device.index()` returns uint8_t which is treated as ascii while |
33 | // printing, hence casting it to uint16_t. |
34 | // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout |
35 | oss << ", index=" << static_cast<uint16_t>(self->device.index()); |
36 | } |
37 | oss << ")" ; |
38 | return THPUtils_packString(oss.str().c_str()); |
39 | } |
40 | |
41 | PyObject* THPDevice_str(THPDevice* self) { |
42 | std::ostringstream oss; |
43 | oss << self->device; |
44 | return THPUtils_packString(oss.str().c_str()); |
45 | } |
46 | |
47 | PyObject* THPDevice_pynew( |
48 | PyTypeObject* type, |
49 | PyObject* args, |
50 | PyObject* kwargs) { |
51 | HANDLE_TH_ERRORS |
52 | static torch::PythonArgParser parser( |
53 | {"Device(Device device)" , |
54 | "Device(c10::string_view type, int64_t? index=-1)" }); |
55 | torch::ParsedArgs<2> parsed_args; |
56 | auto r = parser.parse(args, kwargs, parsed_args); |
57 | if (r.idx == 0) { |
58 | auto device = r.device(0); |
59 | return THPDevice_New(device); |
60 | } else if (r.idx == 1) { |
61 | auto as_device = r.device(0); // this works, because device can take strings |
62 | auto device_type = r.string(0); |
63 | if (as_device.has_index()) { |
64 | throw std::runtime_error( |
65 | "type (string) must not include an index because index " |
66 | "was passed explicitly: " + |
67 | device_type); |
68 | } |
69 | int32_t device_index = -1; |
70 | if (!r.isNone(1)) { |
71 | device_index = r.toInt64(1); |
72 | // -1 is allowed in ATen/C++, to mean the default device, but not in |
73 | // Python. |
74 | TORCH_CHECK(device_index >= 0, "Device index must not be negative" ); |
75 | } |
76 | at::Device device(as_device.type(), device_index); |
77 | return THPDevice_New(device); |
78 | } |
79 | Py_RETURN_NONE; |
80 | END_HANDLE_TH_ERRORS |
81 | } |
82 | |
83 | PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) { |
84 | HANDLE_TH_ERRORS |
85 | std::ostringstream oss; |
86 | oss << self->device.type(); |
87 | return THPUtils_packString(oss.str().c_str()); |
88 | Py_RETURN_NONE; |
89 | END_HANDLE_TH_ERRORS |
90 | } |
91 | |
92 | PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) { |
93 | HANDLE_TH_ERRORS |
94 | if (self->device.has_index()) { |
95 | return THPUtils_packInt64(self->device.index()); |
96 | } else { |
97 | Py_RETURN_NONE; |
98 | } |
99 | END_HANDLE_TH_ERRORS |
100 | } |
101 | |
102 | static Py_ssize_t THPDevice_hash(THPDevice* self) { |
103 | HANDLE_TH_ERRORS |
104 | return static_cast<Py_ssize_t>( |
105 | std::hash<at::Device>{}(self->device) % |
106 | std::numeric_limits<Py_ssize_t>::max()); |
107 | END_HANDLE_TH_ERRORS_RET(-1) |
108 | } |
109 | |
110 | PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) { |
111 | HANDLE_TH_ERRORS |
112 | if (!THPDevice_Check(a) || !THPDevice_Check(b)) { |
113 | // Py_RETURN_NOTIMPLEMENTED not in python 2. |
114 | Py_INCREF(Py_NotImplemented); |
115 | return Py_NotImplemented; |
116 | } |
117 | THPDevice* da = reinterpret_cast<THPDevice*>(a); |
118 | THPDevice* db = reinterpret_cast<THPDevice*>(b); |
119 | |
120 | switch (op) { |
121 | case Py_EQ: |
122 | if (da->device == db->device) { |
123 | Py_RETURN_TRUE; |
124 | } else { |
125 | Py_RETURN_FALSE; |
126 | } |
127 | case Py_NE: |
128 | if (da->device == db->device) { |
129 | Py_RETURN_FALSE; |
130 | } else { |
131 | Py_RETURN_TRUE; |
132 | } |
133 | case Py_LT: |
134 | case Py_LE: |
135 | case Py_GT: |
136 | case Py_GE: |
137 | throw torch::TypeError("comparison not implemented" ); |
138 | default: |
139 | throw torch::TypeError("unexpected comparison op" ); |
140 | } |
141 | END_HANDLE_TH_ERRORS |
142 | } |
143 | |
144 | PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) { |
145 | HANDLE_TH_ERRORS |
146 | auto self = (THPDevice*)_self; |
147 | auto ret = THPObjectPtr{PyTuple_New(2)}; |
148 | if (!ret) |
149 | throw python_error(); |
150 | |
151 | py::object torch_module = py::module::import("torch" ); |
152 | py::object torch_device = torch_module.attr("device" ); |
153 | PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr()); |
154 | |
155 | THPObjectPtr args; |
156 | std::ostringstream oss; |
157 | oss << self->device.type(); |
158 | if (self->device.has_index()) { |
159 | args = THPObjectPtr{ |
160 | Py_BuildValue("(si)" , oss.str().c_str(), self->device.index())}; |
161 | } else { |
162 | args = THPObjectPtr{Py_BuildValue("(s)" , oss.str().c_str())}; |
163 | } |
164 | if (!args) |
165 | throw python_error(); |
166 | PyTuple_SET_ITEM(ret.get(), 1, args.release()); |
167 | |
168 | return ret.release(); |
169 | END_HANDLE_TH_ERRORS |
170 | } |
171 | |
172 | PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) { |
173 | HANDLE_TH_ERRORS |
174 | py::object mode = py::module::import("torch.utils._device" ) |
175 | .attr("DeviceContext" )(py::handle(self)); |
176 | at::impl::PythonTorchFunctionTLS::push_onto_stack( |
177 | std::make_shared<c10::SafePyObject>( |
178 | mode.release().ptr(), getPyInterpreter())); |
179 | // So that with torch.device('cuda') as dev: works |
180 | Py_INCREF(self); |
181 | return self; |
182 | END_HANDLE_TH_ERRORS |
183 | } |
184 | |
185 | PyObject* THPDevice_exit(PyObject* self, PyObject* unused) { |
186 | HANDLE_TH_ERRORS |
187 | at::impl::PythonTorchFunctionTLS::pop_stack(); |
188 | Py_RETURN_NONE; |
189 | END_HANDLE_TH_ERRORS |
190 | } |
191 | |
192 | PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) { |
193 | HANDLE_TH_ERRORS |
194 | py::object deco = |
195 | py::module::import("torch.utils._device" ).attr("device_decorator" ); |
196 | return deco(py::handle(self), *py::handle(args), **py::handle(kwargs)) |
197 | .release() |
198 | .ptr(); |
199 | END_HANDLE_TH_ERRORS |
200 | } |
201 | |
202 | typedef PyObject* (*getter)(PyObject*, void*); |
203 | |
204 | // NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in |
205 | |
206 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
207 | static struct PyGetSetDef THPDevice_properties[] = { |
208 | {"type" , (getter)THPDevice_type, nullptr, nullptr, nullptr}, |
209 | {"index" , (getter)THPDevice_index, nullptr, nullptr, nullptr}, |
210 | {nullptr}}; |
211 | |
212 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
213 | static PyMethodDef THPDevice_methods[] = { |
214 | {"__reduce__" , THPDevice_reduce, METH_NOARGS, nullptr}, |
215 | {"__enter__" , THPDevice_enter, METH_NOARGS, nullptr}, |
216 | {"__exit__" , THPDevice_exit, METH_VARARGS, nullptr}, |
217 | {nullptr} /* Sentinel */ |
218 | }; |
219 | |
// Static Python type object backing torch.device.
// The initializer is positional: every entry must stay in PyTypeObject's field
// order, and the trailing /* tp_... */ comments name the slot each value fills.
// Slots after tp_new are not listed and are therefore zero-initialized.
// PyType_Ready (called from THPDevice_init) finalizes the type before use.
PyTypeObject THPDeviceType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
    sizeof(THPDevice), /* tp_basicsize */
    0, /* tp_itemsize */
    // NOTE(review): no deallocator is supplied; presumably PyType_Ready
    // inherits the base object's, which is only adequate if THPDevice needs
    // no C++ destructor call -- confirm.
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPDevice_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    (hashfunc)THPDevice_hash, /* tp_hash */
    // TODO: We're not sure if this is a good idea or not, because making
    // torch.device callable means that it will start returning true
    // for callable() queries, and that is unexpected. We can always add
    // this later, so for now, don't actually implement this
    // THPDevice_call, /* tp_call */
    nullptr, /* tp_call */
    (reprfunc)THPDevice_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // Py_TPFLAGS_BASETYPE is not set, so (per CPython semantics) torch.device
    // cannot be subclassed from Python.
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    (richcmpfunc)THPDevice_rc, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDevice_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDevice_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    // NOTE(review): left null here, yet THPDevice_New calls type->tp_alloc;
    // presumably PyType_Ready inherits a default allocator -- confirm.
    nullptr, /* tp_alloc */
    THPDevice_pynew, /* tp_new */
};
264 | |
265 | void THPDevice_init(PyObject* module) { |
266 | if (PyType_Ready(&THPDeviceType) < 0) { |
267 | throw python_error(); |
268 | } |
269 | Py_INCREF(&THPDeviceType); |
270 | if (PyModule_AddObject(module, "device" , (PyObject*)&THPDeviceType) != 0) { |
271 | throw python_error(); |
272 | } |
273 | } |
274 | |