1#include <torch/csrc/Generator.h>
2
3#include <ATen/ATen.h>
4#include <ATen/CPUGeneratorImpl.h>
5#include <structmember.h>
6
7#include <torch/csrc/Device.h>
8#include <torch/csrc/Exceptions.h>
9#include <torch/csrc/THP.h>
10#include <torch/csrc/autograd/generated/VariableType.h>
11#include <torch/csrc/autograd/generated/variable_factories.h>
12#include <torch/csrc/autograd/python_variable.h>
13#include <torch/csrc/utils/python_arg_parser.h>
14#include <torch/csrc/utils/tensor_types.h>
15
16#ifdef USE_CUDA
17#include <ATen/cuda/CUDAGeneratorImpl.h>
18#endif
19
20#ifdef USE_MPS
21#include <ATen/mps/MPSGeneratorImpl.h>
22#endif
23
24using namespace at;
25using namespace torch;
26
27PyObject* THPGeneratorClass = nullptr;
28
29PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) {
30 auto type = (PyTypeObject*)THPGeneratorClass;
31 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
32 if (!self)
33 throw python_error();
34 auto self_ = reinterpret_cast<THPGenerator*>(self.get());
35 self_->cdata = cdata;
36 return self.release();
37}
38
39static void THPGenerator_dealloc(PyObject* _self) {
40 auto self = reinterpret_cast<THPGenerator*>(_self);
41 if (self->cdata.defined()) {
42 self->cdata.set_pyobj(nullptr);
43 self->cdata.~Generator();
44 }
45 Py_TYPE(_self)->tp_free(_self);
46}
47
48static PyObject* THPGenerator_pynew(
49 PyTypeObject* type,
50 PyObject* args,
51 PyObject* kwargs) {
52 HANDLE_TH_ERRORS
53 static torch::PythonArgParser parser({"Generator(Device device=None)"});
54 torch::ParsedArgs<1> parsed_args;
55 auto r = parser.parse(args, kwargs, parsed_args);
56 auto device = r.deviceWithDefault(0, at::Device(at::kCPU));
57
58 THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));
59 if (device.type() == at::kCPU) {
60 self->cdata = make_generator<CPUGeneratorImpl>();
61 }
62#ifdef USE_CUDA
63 else if (device.type() == at::kCUDA) {
64 self->cdata = make_generator<CUDAGeneratorImpl>(device.index());
65 }
66#elif USE_MPS
67 else if (device.type() == at::kMPS) {
68 self->cdata = make_generator<MPSGeneratorImpl>();
69 }
70#endif
71 else {
72 AT_ERROR(
73 "Device type ",
74 c10::DeviceTypeName(device.type()),
75 " is not supported for torch.Generator() api.");
76 }
77 return (PyObject*)self.release();
78 END_HANDLE_TH_ERRORS
79}
80
81static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) {
82 using namespace torch::autograd;
83 HANDLE_TH_ERRORS
84 auto& gen = ((THPGenerator*)_self)->cdata;
85
86 // See Note [Acquire lock when using random generators]
87 std::lock_guard<std::mutex> lock(gen.mutex());
88 auto state_tensor = gen.get_state();
89
90 return THPVariable_Wrap(std::move(state_tensor));
91 END_HANDLE_TH_ERRORS
92}
93
94static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
95 using namespace torch::autograd;
96
97 HANDLE_TH_ERRORS
98 if (!THPVariable_Check(_new_state)) {
99 throw torch::TypeError(
100 "expected a torch.ByteTensor, but got %s",
101 Py_TYPE(_new_state)->tp_name);
102 }
103 auto self = (THPGenerator*)_self;
104 auto& gen = self->cdata;
105 const auto& new_state_tensor = THPVariable_Unpack(_new_state);
106
107 // See Note [Acquire lock when using random generators]
108 std::lock_guard<std::mutex> lock(gen.mutex());
109 gen.set_state(new_state_tensor);
110
111 Py_INCREF(self);
112 return (PyObject*)self;
113 END_HANDLE_TH_ERRORS
114}
115
116static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
117 HANDLE_TH_ERRORS
118 auto self = (THPGenerator*)_self;
119 auto generator = self->cdata;
120 THPUtils_assert(
121 THPUtils_checkLong(seed),
122 "manual_seed expected a long, "
123 "but got %s",
124 THPUtils_typename(seed));
125 // See Note [Acquire lock when using random generators]
126 std::lock_guard<std::mutex> lock(generator.mutex());
127 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
128 uint64_t seed_unpacked;
129 try {
130 // First try to interpret as unsigned long
131 seed_unpacked = THPUtils_unpackUInt64(seed);
132 } catch (...) {
133 if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
134 // If an overflow happened, then the seed could be negative,
135 // so try to interpret it as signed long
136 PyErr_Clear();
137 int64_t seed_unpacked_signed = THPUtils_unpackLong(seed);
138 seed_unpacked = *(reinterpret_cast<uint64_t*>(&seed_unpacked_signed));
139 } else {
140 // If any other type of exception happened, rethrow it
141 throw;
142 }
143 }
144 generator.set_current_seed(seed_unpacked);
145 Py_INCREF(self);
146 return (PyObject*)self;
147 END_HANDLE_TH_ERRORS
148}
149
150static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {
151 HANDLE_TH_ERRORS
152 // See Note [Acquire lock when using random generators]
153 auto self = (THPGenerator*)_self;
154 std::lock_guard<std::mutex> lock(self->cdata.mutex());
155 uint64_t seed_val = self->cdata.seed();
156 return THPUtils_packUInt64(seed_val);
157 END_HANDLE_TH_ERRORS
158}
159
160static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) {
161 HANDLE_TH_ERRORS
162 auto self = (THPGenerator*)_self;
163 return THPUtils_packUInt64(self->cdata.current_seed());
164 END_HANDLE_TH_ERRORS
165}
166
167static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) {
168 HANDLE_TH_ERRORS
169 return THPDevice_New(self->cdata.device());
170 END_HANDLE_TH_ERRORS
171}
172
173// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
174static struct PyGetSetDef THPGenerator_properties[] = {
175 {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},
176 {nullptr}};
177
178// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
179static PyMethodDef THPGenerator_methods[] = {
180 {"get_state", THPGenerator_getState, METH_NOARGS, nullptr},
181 {"set_state", THPGenerator_setState, METH_O, nullptr},
182 {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr},
183 {"seed", THPGenerator_seed, METH_NOARGS, nullptr},
184 {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr},
185 {nullptr}};
186
187// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
188static struct PyMemberDef THPGenerator_members[] = {
189 {(char*)"_cdata",
190 T_ULONGLONG,
191 offsetof(THPGenerator, cdata),
192 READONLY,
193 nullptr},
194 {nullptr}};
195
// Python type object for `torch._C.Generator`.
//
// The initializer is positional: entries must stay in PyTypeObject's slot
// order (the trailing /* tp_... */ comments name each slot), and slots left as
// nullptr/0 are unused. The type may be subclassed (Py_TPFLAGS_BASETYPE).
// Instances are created by THPGenerator_pynew and torn down by
// THPGenerator_dealloc, and they expose the method, member, and getset tables
// defined above.
// NOTE(review): tp_alloc is null here, yet this file calls `type->tp_alloc`.
// Presumably it is filled in when THPGenerator_init runs PyType_Ready —
// confirm.
PyTypeObject THPGeneratorType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.Generator", /* tp_name */
    sizeof(THPGenerator), /* tp_basicsize */
    0, /* tp_itemsize */
    THPGenerator_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPGenerator_methods, /* tp_methods */
    THPGenerator_members, /* tp_members */
    THPGenerator_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPGenerator_pynew, /* tp_new */
};
235
236bool THPGenerator_init(PyObject* module) {
237 THPGeneratorClass = (PyObject*)&THPGeneratorType;
238 if (PyType_Ready(&THPGeneratorType) < 0)
239 return false;
240 Py_INCREF(&THPGeneratorType);
241 PyModule_AddObject(module, "Generator", (PyObject*)&THPGeneratorType);
242 return true;
243}
244
245void set_pyobj(const Generator& self, PyObject* pyobj) {
246 TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined generator");
247 self.set_pyobj(pyobj);
248}
249
250PyObject* pyobj(const Generator& self) {
251 TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined generator");
252 return self.pyobj();
253}
254
255PyObject* THPGenerator_Wrap(Generator gen) {
256 if (!gen.defined()) {
257 Py_RETURN_NONE;
258 }
259
260 if (auto obj = pyobj(gen)) {
261 Py_INCREF(obj);
262 return obj;
263 }
264
265 return THPGenerator_NewWithVar(
266 (PyTypeObject*)THPGeneratorClass, std::move(gen));
267}
268
269// Creates a new Python object for a Generator. The Generator must not already
270// have a PyObject* associated with it.
271PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {
272 PyObject* obj = type->tp_alloc(type, 0);
273 if (obj) {
274 auto g = (THPGenerator*)obj;
275 new (&g->cdata) Generator(std::move(gen));
276 set_pyobj(g->cdata, obj);
277 }
278 return obj;
279}
280