1#include <torch/csrc/Device.h>
2
3#include <torch/csrc/Exceptions.h>
4#include <torch/csrc/utils/object_ptr.h>
5#include <torch/csrc/utils/pybind.h>
6#include <torch/csrc/utils/python_arg_parser.h>
7#include <torch/csrc/utils/python_numbers.h>
8#include <torch/csrc/utils/python_strings.h>
9
10#include <ATen/Device.h>
11#include <c10/util/Exception.h>
12
13#include <structmember.h>
14#include <cstring>
15#include <limits>
16#include <sstream>
17
18PyObject* THPDevice_New(const at::Device& device) {
19 auto type = (PyTypeObject*)&THPDeviceType;
20 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
21 if (!self)
22 throw python_error();
23 auto self_ = reinterpret_cast<THPDevice*>(self.get());
24 self_->device = device;
25 return self.release();
26}
27
28PyObject* THPDevice_repr(THPDevice* self) {
29 std::ostringstream oss;
30 oss << "device(type=\'" << self->device.type() << "\'";
31 if (self->device.has_index()) {
32 // `self->device.index()` returns uint8_t which is treated as ascii while
33 // printing, hence casting it to uint16_t.
34 // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
35 oss << ", index=" << static_cast<uint16_t>(self->device.index());
36 }
37 oss << ")";
38 return THPUtils_packString(oss.str().c_str());
39}
40
41PyObject* THPDevice_str(THPDevice* self) {
42 std::ostringstream oss;
43 oss << self->device;
44 return THPUtils_packString(oss.str().c_str());
45}
46
47PyObject* THPDevice_pynew(
48 PyTypeObject* type,
49 PyObject* args,
50 PyObject* kwargs) {
51 HANDLE_TH_ERRORS
52 static torch::PythonArgParser parser(
53 {"Device(Device device)",
54 "Device(c10::string_view type, int64_t? index=-1)"});
55 torch::ParsedArgs<2> parsed_args;
56 auto r = parser.parse(args, kwargs, parsed_args);
57 if (r.idx == 0) {
58 auto device = r.device(0);
59 return THPDevice_New(device);
60 } else if (r.idx == 1) {
61 auto as_device = r.device(0); // this works, because device can take strings
62 auto device_type = r.string(0);
63 if (as_device.has_index()) {
64 throw std::runtime_error(
65 "type (string) must not include an index because index "
66 "was passed explicitly: " +
67 device_type);
68 }
69 int32_t device_index = -1;
70 if (!r.isNone(1)) {
71 device_index = r.toInt64(1);
72 // -1 is allowed in ATen/C++, to mean the default device, but not in
73 // Python.
74 TORCH_CHECK(device_index >= 0, "Device index must not be negative");
75 }
76 at::Device device(as_device.type(), device_index);
77 return THPDevice_New(device);
78 }
79 Py_RETURN_NONE;
80 END_HANDLE_TH_ERRORS
81}
82
83PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) {
84 HANDLE_TH_ERRORS
85 std::ostringstream oss;
86 oss << self->device.type();
87 return THPUtils_packString(oss.str().c_str());
88 Py_RETURN_NONE;
89 END_HANDLE_TH_ERRORS
90}
91
92PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) {
93 HANDLE_TH_ERRORS
94 if (self->device.has_index()) {
95 return THPUtils_packInt64(self->device.index());
96 } else {
97 Py_RETURN_NONE;
98 }
99 END_HANDLE_TH_ERRORS
100}
101
102static Py_ssize_t THPDevice_hash(THPDevice* self) {
103 HANDLE_TH_ERRORS
104 return static_cast<Py_ssize_t>(
105 std::hash<at::Device>{}(self->device) %
106 std::numeric_limits<Py_ssize_t>::max());
107 END_HANDLE_TH_ERRORS_RET(-1)
108}
109
110PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
111 HANDLE_TH_ERRORS
112 if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
113 // Py_RETURN_NOTIMPLEMENTED not in python 2.
114 Py_INCREF(Py_NotImplemented);
115 return Py_NotImplemented;
116 }
117 THPDevice* da = reinterpret_cast<THPDevice*>(a);
118 THPDevice* db = reinterpret_cast<THPDevice*>(b);
119
120 switch (op) {
121 case Py_EQ:
122 if (da->device == db->device) {
123 Py_RETURN_TRUE;
124 } else {
125 Py_RETURN_FALSE;
126 }
127 case Py_NE:
128 if (da->device == db->device) {
129 Py_RETURN_FALSE;
130 } else {
131 Py_RETURN_TRUE;
132 }
133 case Py_LT:
134 case Py_LE:
135 case Py_GT:
136 case Py_GE:
137 throw torch::TypeError("comparison not implemented");
138 default:
139 throw torch::TypeError("unexpected comparison op");
140 }
141 END_HANDLE_TH_ERRORS
142}
143
144PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
145 HANDLE_TH_ERRORS
146 auto self = (THPDevice*)_self;
147 auto ret = THPObjectPtr{PyTuple_New(2)};
148 if (!ret)
149 throw python_error();
150
151 py::object torch_module = py::module::import("torch");
152 py::object torch_device = torch_module.attr("device");
153 PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());
154
155 THPObjectPtr args;
156 std::ostringstream oss;
157 oss << self->device.type();
158 if (self->device.has_index()) {
159 args = THPObjectPtr{
160 Py_BuildValue("(si)", oss.str().c_str(), self->device.index())};
161 } else {
162 args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())};
163 }
164 if (!args)
165 throw python_error();
166 PyTuple_SET_ITEM(ret.get(), 1, args.release());
167
168 return ret.release();
169 END_HANDLE_TH_ERRORS
170}
171
172PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
173 HANDLE_TH_ERRORS
174 py::object mode = py::module::import("torch.utils._device")
175 .attr("DeviceContext")(py::handle(self));
176 at::impl::PythonTorchFunctionTLS::push_onto_stack(
177 std::make_shared<c10::SafePyObject>(
178 mode.release().ptr(), getPyInterpreter()));
179 // So that with torch.device('cuda') as dev: works
180 Py_INCREF(self);
181 return self;
182 END_HANDLE_TH_ERRORS
183}
184
185PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
186 HANDLE_TH_ERRORS
187 at::impl::PythonTorchFunctionTLS::pop_stack();
188 Py_RETURN_NONE;
189 END_HANDLE_TH_ERRORS
190}
191
192PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
193 HANDLE_TH_ERRORS
194 py::object deco =
195 py::module::import("torch.utils._device").attr("device_decorator");
196 return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
197 .release()
198 .ptr();
199 END_HANDLE_TH_ERRORS
200}
201
202typedef PyObject* (*getter)(PyObject*, void*);
203
204// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
205
206// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
207static struct PyGetSetDef THPDevice_properties[] = {
208 {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
209 {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
210 {nullptr}};
211
212// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
213static PyMethodDef THPDevice_methods[] = {
214 {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
215 {"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
216 {"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
217 {nullptr} /* Sentinel */
218};
219
// Static CPython type object backing `torch.device`. This is a positional
// aggregate initializer: entries must stay in PyTypeObject's declaration
// order, and each trailing /* tp_... */ comment names the slot it fills.
PyTypeObject THPDeviceType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
    sizeof(THPDevice), /* tp_basicsize */
    0, /* tp_itemsize */
    // NOTE(review): no custom dealloc, so THPDevice's `device` member is never
    // explicitly destroyed; this relies on at::Device needing no destructor --
    // confirm if that type ever gains owned state.
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPDevice_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    (hashfunc)THPDevice_hash, /* tp_hash */
    // TODO: We're not sure if this is a good idea or not, because making
    // torch.device callable means that it will start returning true
    // for callable() queries, and that is unexpected. We can always add
    // this later, so for now, don't actually implement this
    // THPDevice_call, /* tp_call */
    nullptr, /* tp_call */
    (reprfunc)THPDevice_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // Default flags only (no Py_TPFLAGS_BASETYPE), so Python code cannot
    // subclass torch.device.
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    (richcmpfunc)THPDevice_rc, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDevice_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDevice_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    // Left null here; THPDevice_New calls type->tp_alloc, which is presumably
    // filled in by PyType_Ready in THPDevice_init -- confirm.
    nullptr, /* tp_alloc */
    THPDevice_pynew, /* tp_new */
};
264
265void THPDevice_init(PyObject* module) {
266 if (PyType_Ready(&THPDeviceType) < 0) {
267 throw python_error();
268 }
269 Py_INCREF(&THPDeviceType);
270 if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
271 throw python_error();
272 }
273}
274