1 | #include <torch/csrc/Generator.h> |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/CPUGeneratorImpl.h> |
5 | #include <structmember.h> |
6 | |
7 | #include <torch/csrc/Device.h> |
8 | #include <torch/csrc/Exceptions.h> |
9 | #include <torch/csrc/THP.h> |
10 | #include <torch/csrc/autograd/generated/VariableType.h> |
11 | #include <torch/csrc/autograd/generated/variable_factories.h> |
12 | #include <torch/csrc/autograd/python_variable.h> |
13 | #include <torch/csrc/utils/python_arg_parser.h> |
14 | #include <torch/csrc/utils/tensor_types.h> |
15 | |
16 | #ifdef USE_CUDA |
17 | #include <ATen/cuda/CUDAGeneratorImpl.h> |
18 | #endif |
19 | |
20 | #ifdef USE_MPS |
21 | #include <ATen/mps/MPSGeneratorImpl.h> |
22 | #endif |
23 | |
24 | using namespace at; |
25 | using namespace torch; |
26 | |
27 | PyObject* THPGeneratorClass = nullptr; |
28 | |
29 | PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) { |
30 | auto type = (PyTypeObject*)THPGeneratorClass; |
31 | auto self = THPObjectPtr{type->tp_alloc(type, 0)}; |
32 | if (!self) |
33 | throw python_error(); |
34 | auto self_ = reinterpret_cast<THPGenerator*>(self.get()); |
35 | self_->cdata = cdata; |
36 | return self.release(); |
37 | } |
38 | |
39 | static void THPGenerator_dealloc(PyObject* _self) { |
40 | auto self = reinterpret_cast<THPGenerator*>(_self); |
41 | if (self->cdata.defined()) { |
42 | self->cdata.set_pyobj(nullptr); |
43 | self->cdata.~Generator(); |
44 | } |
45 | Py_TYPE(_self)->tp_free(_self); |
46 | } |
47 | |
48 | static PyObject* THPGenerator_pynew( |
49 | PyTypeObject* type, |
50 | PyObject* args, |
51 | PyObject* kwargs) { |
52 | HANDLE_TH_ERRORS |
53 | static torch::PythonArgParser parser({"Generator(Device device=None)" }); |
54 | torch::ParsedArgs<1> parsed_args; |
55 | auto r = parser.parse(args, kwargs, parsed_args); |
56 | auto device = r.deviceWithDefault(0, at::Device(at::kCPU)); |
57 | |
58 | THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0)); |
59 | if (device.type() == at::kCPU) { |
60 | self->cdata = make_generator<CPUGeneratorImpl>(); |
61 | } |
62 | #ifdef USE_CUDA |
63 | else if (device.type() == at::kCUDA) { |
64 | self->cdata = make_generator<CUDAGeneratorImpl>(device.index()); |
65 | } |
66 | #elif USE_MPS |
67 | else if (device.type() == at::kMPS) { |
68 | self->cdata = make_generator<MPSGeneratorImpl>(); |
69 | } |
70 | #endif |
71 | else { |
72 | AT_ERROR( |
73 | "Device type " , |
74 | c10::DeviceTypeName(device.type()), |
75 | " is not supported for torch.Generator() api." ); |
76 | } |
77 | return (PyObject*)self.release(); |
78 | END_HANDLE_TH_ERRORS |
79 | } |
80 | |
81 | static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) { |
82 | using namespace torch::autograd; |
83 | HANDLE_TH_ERRORS |
84 | auto& gen = ((THPGenerator*)_self)->cdata; |
85 | |
86 | // See Note [Acquire lock when using random generators] |
87 | std::lock_guard<std::mutex> lock(gen.mutex()); |
88 | auto state_tensor = gen.get_state(); |
89 | |
90 | return THPVariable_Wrap(std::move(state_tensor)); |
91 | END_HANDLE_TH_ERRORS |
92 | } |
93 | |
94 | static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) { |
95 | using namespace torch::autograd; |
96 | |
97 | HANDLE_TH_ERRORS |
98 | if (!THPVariable_Check(_new_state)) { |
99 | throw torch::TypeError( |
100 | "expected a torch.ByteTensor, but got %s" , |
101 | Py_TYPE(_new_state)->tp_name); |
102 | } |
103 | auto self = (THPGenerator*)_self; |
104 | auto& gen = self->cdata; |
105 | const auto& new_state_tensor = THPVariable_Unpack(_new_state); |
106 | |
107 | // See Note [Acquire lock when using random generators] |
108 | std::lock_guard<std::mutex> lock(gen.mutex()); |
109 | gen.set_state(new_state_tensor); |
110 | |
111 | Py_INCREF(self); |
112 | return (PyObject*)self; |
113 | END_HANDLE_TH_ERRORS |
114 | } |
115 | |
116 | static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) { |
117 | HANDLE_TH_ERRORS |
118 | auto self = (THPGenerator*)_self; |
119 | auto generator = self->cdata; |
120 | THPUtils_assert( |
121 | THPUtils_checkLong(seed), |
122 | "manual_seed expected a long, " |
123 | "but got %s" , |
124 | THPUtils_typename(seed)); |
125 | // See Note [Acquire lock when using random generators] |
126 | std::lock_guard<std::mutex> lock(generator.mutex()); |
127 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
128 | uint64_t seed_unpacked; |
129 | try { |
130 | // First try to interpret as unsigned long |
131 | seed_unpacked = THPUtils_unpackUInt64(seed); |
132 | } catch (...) { |
133 | if (PyErr_ExceptionMatches(PyExc_OverflowError)) { |
134 | // If an overflow happened, then the seed could be negative, |
135 | // so try to interpret it as signed long |
136 | PyErr_Clear(); |
137 | int64_t seed_unpacked_signed = THPUtils_unpackLong(seed); |
138 | seed_unpacked = *(reinterpret_cast<uint64_t*>(&seed_unpacked_signed)); |
139 | } else { |
140 | // If any other type of exception happened, rethrow it |
141 | throw; |
142 | } |
143 | } |
144 | generator.set_current_seed(seed_unpacked); |
145 | Py_INCREF(self); |
146 | return (PyObject*)self; |
147 | END_HANDLE_TH_ERRORS |
148 | } |
149 | |
150 | static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) { |
151 | HANDLE_TH_ERRORS |
152 | // See Note [Acquire lock when using random generators] |
153 | auto self = (THPGenerator*)_self; |
154 | std::lock_guard<std::mutex> lock(self->cdata.mutex()); |
155 | uint64_t seed_val = self->cdata.seed(); |
156 | return THPUtils_packUInt64(seed_val); |
157 | END_HANDLE_TH_ERRORS |
158 | } |
159 | |
160 | static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) { |
161 | HANDLE_TH_ERRORS |
162 | auto self = (THPGenerator*)_self; |
163 | return THPUtils_packUInt64(self->cdata.current_seed()); |
164 | END_HANDLE_TH_ERRORS |
165 | } |
166 | |
167 | static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) { |
168 | HANDLE_TH_ERRORS |
169 | return THPDevice_New(self->cdata.device()); |
170 | END_HANDLE_TH_ERRORS |
171 | } |
172 | |
173 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
174 | static struct PyGetSetDef THPGenerator_properties[] = { |
175 | {"device" , (getter)THPGenerator_get_device, nullptr, nullptr, nullptr}, |
176 | {nullptr}}; |
177 | |
178 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
179 | static PyMethodDef THPGenerator_methods[] = { |
180 | {"get_state" , THPGenerator_getState, METH_NOARGS, nullptr}, |
181 | {"set_state" , THPGenerator_setState, METH_O, nullptr}, |
182 | {"manual_seed" , THPGenerator_manualSeed, METH_O, nullptr}, |
183 | {"seed" , THPGenerator_seed, METH_NOARGS, nullptr}, |
184 | {"initial_seed" , THPGenerator_initialSeed, METH_NOARGS, nullptr}, |
185 | {nullptr}}; |
186 | |
187 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
188 | static struct PyMemberDef THPGenerator_members[] = { |
189 | {(char*)"_cdata" , |
190 | T_ULONGLONG, |
191 | offsetof(THPGenerator, cdata), |
192 | READONLY, |
193 | nullptr}, |
194 | {nullptr}}; |
195 | |
// Static Python type object backing torch._C.Generator.
//
// The slots are filled positionally (aggregate initialization of the C
// struct), so the entries below must stay in exact PyTypeObject layout order.
// The trailing /* ... */ comment on each line names its slot. Slots after
// tp_new are omitted and therefore zero-initialized.
//
// Py_TPFLAGS_BASETYPE lets Python code subclass Generator. Construction goes
// through THPGenerator_pynew and destruction through THPGenerator_dealloc.
// tp_alloc is left null here; presumably PyType_Ready (called in
// THPGenerator_init) inherits it from the base type. Confirm against the
// CPython type-slot docs.
PyTypeObject THPGeneratorType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.Generator", /* tp_name */
    sizeof(THPGenerator), /* tp_basicsize */
    0, /* tp_itemsize */
    THPGenerator_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPGenerator_methods, /* tp_methods */
    THPGenerator_members, /* tp_members */
    THPGenerator_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPGenerator_pynew, /* tp_new */
};
235 | |
236 | bool THPGenerator_init(PyObject* module) { |
237 | THPGeneratorClass = (PyObject*)&THPGeneratorType; |
238 | if (PyType_Ready(&THPGeneratorType) < 0) |
239 | return false; |
240 | Py_INCREF(&THPGeneratorType); |
241 | PyModule_AddObject(module, "Generator" , (PyObject*)&THPGeneratorType); |
242 | return true; |
243 | } |
244 | |
245 | void set_pyobj(const Generator& self, PyObject* pyobj) { |
246 | TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined generator" ); |
247 | self.set_pyobj(pyobj); |
248 | } |
249 | |
250 | PyObject* pyobj(const Generator& self) { |
251 | TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined generator" ); |
252 | return self.pyobj(); |
253 | } |
254 | |
255 | PyObject* THPGenerator_Wrap(Generator gen) { |
256 | if (!gen.defined()) { |
257 | Py_RETURN_NONE; |
258 | } |
259 | |
260 | if (auto obj = pyobj(gen)) { |
261 | Py_INCREF(obj); |
262 | return obj; |
263 | } |
264 | |
265 | return THPGenerator_NewWithVar( |
266 | (PyTypeObject*)THPGeneratorClass, std::move(gen)); |
267 | } |
268 | |
269 | // Creates a new Python object for a Generator. The Generator must not already |
270 | // have a PyObject* associated with it. |
271 | PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) { |
272 | PyObject* obj = type->tp_alloc(type, 0); |
273 | if (obj) { |
274 | auto g = (THPGenerator*)obj; |
275 | new (&g->cdata) Generator(std::move(gen)); |
276 | set_pyobj(g->cdata, obj); |
277 | } |
278 | return obj; |
279 | } |
280 | |