#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/MathConstants.h>
#include <algorithm>
#include <cstring>
6
7namespace at {
8
9namespace detail {
10
/**
 * CPUGeneratorImplStateLegacy is a POD class needed for memcpys
 * in torch.get_rng_state() and torch.set_rng_state().
 * It is a legacy class and even though it is replaced with
 * at::CPUGeneratorImpl, we need this class and some of its fields
 * to support backward compatibility on loading checkpoints.
 *
 * The member types and order (and therefore sizeof()) of this struct are the
 * serialized byte layout of an RNG state tensor. set_state() tells this
 * layout apart from CPUGeneratorImplState purely by byte count, so fields
 * must not be added, removed or reordered.
 */
struct CPUGeneratorImplStateLegacy {
  /* The initial seed. */
  // Restored into / saved from the mt19937 engine's seed_.
  uint64_t the_initial_seed;
  // Restored into / saved from the engine's left_.
  int left; /* = 1; */
  // Restored into / saved from the engine's seeded_ (stored here as int).
  int seeded; /* = 0; */
  // Engine's next_; only the low 32 bits are kept when loading (set_state
  // casts it to uint32_t).
  uint64_t next;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  // Each element is narrowed to the engine's 32-bit state word on load.
  uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */

  /********************************/

  /* For normal distribution */
  // Legacy-format uniform used by set_state to rebuild the cached sample;
  // get_state always writes 0.0 here.
  double normal_x;
  // In states written by get_state, this holds the cached double normal
  // sample itself. It is not read when loading a legacy-sized state.
  double normal_y;
  // Legacy-format radius used by set_state to rebuild the cached sample;
  // get_state always writes 0.0 here.
  double normal_rho;
  // Nonzero when a cached double normal sample is present.
  int normal_is_valid; /* = 0; */
};
35
/**
 * CPUGeneratorImplState is a POD class containing
 * new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
 * as a helper for torch.get_rng_state() and torch.set_rng_state()
 * functions.
 *
 * get_state() memcpys this whole struct into the returned byte tensor, so its
 * member order is part of the serialized format. Its sizeof() must differ
 * from sizeof(CPUGeneratorImplStateLegacy) (set_state static_asserts this),
 * because the two formats are distinguished by the tensor's byte count.
 */
struct CPUGeneratorImplState {
  // Engine state plus the cached double normal sample (in normal_y).
  CPUGeneratorImplStateLegacy legacy_pod;
  // Cached float normal sample; meaningful only when the flag below is true.
  float next_float_normal_sample;
  bool is_next_float_normal_sample_valid;
};
47
48/**
49 * PyTorch maintains a collection of default generators that get
50 * initialized once. The purpose of these default generators is to
51 * maintain a global running state of the pseudo random number generation,
52 * when a user does not explicitly mention any generator.
53 * getDefaultCPUGenerator gets the default generator for a particular
54 * device.
55 */
56const Generator& getDefaultCPUGenerator() {
57 static auto default_gen_cpu = createCPUGenerator(c10::detail::getNonDeterministicRandom());
58 return default_gen_cpu;
59}
60
61/**
62 * Utility to create a CPUGeneratorImpl. Returns a shared_ptr
63 */
64Generator createCPUGenerator(uint64_t seed_val) {
65 return make_generator<CPUGeneratorImpl>(seed_val);
66}
67
/**
 * Packs two 32-bit halves into a single 64-bit value: `hi` fills bits 63..32
 * and `lo` fills bits 31..0.
 */
inline uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {
  const uint64_t upper = uint64_t{hi} << 32;
  const uint64_t lower = uint64_t{lo};
  return upper | lower;
}
75
76} // namespace detail
77
78/**
79 * CPUGeneratorImpl class implementation
80 */
81CPUGeneratorImpl::CPUGeneratorImpl(uint64_t seed_in)
82 : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPU)},
83 engine_{seed_in},
84 next_float_normal_sample_{c10::optional<float>()},
85 next_double_normal_sample_{c10::optional<double>()} { }
86
87/**
88 * Manually seeds the engine with the seed input
89 * See Note [Acquire lock when using random generators]
90 */
91void CPUGeneratorImpl::set_current_seed(uint64_t seed) {
92 next_float_normal_sample_.reset();
93 next_double_normal_sample_.reset();
94 engine_ = mt19937(seed);
95}
96
97/**
98 * Gets the current seed of CPUGeneratorImpl.
99 */
100uint64_t CPUGeneratorImpl::current_seed() const {
101 return engine_.seed();
102}
103
104/**
105 * Gets a nondeterministic random number from /dev/urandom or time,
106 * seeds the CPUGeneratorImpl with it and then returns that number.
107 *
108 * FIXME: You can move this function to Generator.cpp if the algorithm
109 * in getNonDeterministicRandom is unified for both CPU and CUDA
110 */
111uint64_t CPUGeneratorImpl::seed() {
112 auto random = c10::detail::getNonDeterministicRandom();
113 this->set_current_seed(random);
114 return random;
115}
116
117/**
118 * Sets the internal state of CPUGeneratorImpl. The new internal state
119 * must be a strided CPU byte tensor and of the same size as either
120 * CPUGeneratorImplStateLegacy (for legacy CPU generator state) or
121 * CPUGeneratorImplState (for new state).
122 *
123 * FIXME: Remove support of the legacy state in the future?
124 */
125void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
126 using detail::CPUGeneratorImplState;
127 using detail::CPUGeneratorImplStateLegacy;
128
129 static_assert(std::is_standard_layout<CPUGeneratorImplStateLegacy>::value, "CPUGeneratorImplStateLegacy is not a PODType");
130 static_assert(std::is_standard_layout<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
131
132 static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
133 static const size_t size_current = sizeof(CPUGeneratorImplState);
134 static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size");
135
136 detail::check_rng_state(new_state);
137
138 at::mt19937 engine;
139 auto float_normal_sample = c10::optional<float>();
140 auto double_normal_sample = c10::optional<double>();
141
142 // Construct the state of at::CPUGeneratorImpl based on input byte tensor size.
143 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
144 CPUGeneratorImplStateLegacy* legacy_pod;
145 auto new_state_size = new_state.numel();
146 if (new_state_size == size_legacy) {
147 legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data();
148 // Note that in CPUGeneratorImplStateLegacy, we didn't have float version
149 // of normal sample and hence we leave the c10::optional<float> as is
150
151 // Update next_double_normal_sample.
152 // Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y)
153 // and a rho value (normal_rho). These three values were redundant and in the new
154 // DistributionsHelper.h, we store the actual extra normal sample, rather than three
155 // intermediate values.
156 if (legacy_pod->normal_is_valid) {
157 auto r = legacy_pod->normal_rho;
158 auto theta = 2.0 * c10::pi<double> * legacy_pod->normal_x;
159 // we return the sin version of the normal sample when in caching mode
160 double_normal_sample = c10::optional<double>(r * ::sin(theta));
161 }
162 } else if (new_state_size == size_current) {
163 auto rng_state = (CPUGeneratorImplState*)new_state.data();
164 legacy_pod = &rng_state->legacy_pod;
165 // update next_float_normal_sample
166 if (rng_state->is_next_float_normal_sample_valid) {
167 float_normal_sample = c10::optional<float>(rng_state->next_float_normal_sample);
168 }
169
170 // Update next_double_normal_sample.
171 // Note that in getRNGState, we now return the actual normal sample in normal_y
172 // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho
173 // are squashed to 0.0.
174 if (legacy_pod->normal_is_valid) {
175 double_normal_sample = c10::optional<double>(legacy_pod->normal_y);
176 }
177 } else {
178 AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy,
179 " or a CPUGeneratorImplState of size ", size_current,
180 " but found the input RNG state size to be ", new_state_size);
181 }
182
183 // construct engine_
184 // Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our
185 // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are
186 // doing a std::copy.
187 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
188 at::mt19937_data_pod rng_data;
189 std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin());
190 rng_data.seed_ = legacy_pod->the_initial_seed;
191 rng_data.left_ = legacy_pod->left;
192 rng_data.seeded_ = legacy_pod->seeded;
193 rng_data.next_ = static_cast<uint32_t>(legacy_pod->next);
194 engine.set_data(rng_data);
195 TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state");
196 this->engine_ = engine;
197 this->next_float_normal_sample_ = float_normal_sample;
198 this->next_double_normal_sample_ = double_normal_sample;
199}
200
201/**
202 * Gets the current internal state of CPUGeneratorImpl. The internal
203 * state is returned as a CPU byte tensor.
204 */
205c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
206 using detail::CPUGeneratorImplState;
207
208 static const size_t size = sizeof(CPUGeneratorImplState);
209 static_assert(std::is_standard_layout<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
210
211 auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
212 auto rng_state = state_tensor.data_ptr();
213
214 // accumulate generator data to be copied into byte tensor
215 auto accum_state = std::make_unique<CPUGeneratorImplState>();
216 auto rng_data = this->engine_.data();
217 accum_state->legacy_pod.the_initial_seed = rng_data.seed_;
218 accum_state->legacy_pod.left = rng_data.left_;
219 accum_state->legacy_pod.seeded = rng_data.seeded_;
220 accum_state->legacy_pod.next = rng_data.next_;
221 std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state));
222 accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy
223 accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy
224 accum_state->legacy_pod.normal_is_valid = false;
225 accum_state->legacy_pod.normal_y = 0.0;
226 accum_state->next_float_normal_sample = 0.0f;
227 accum_state->is_next_float_normal_sample_valid = false;
228 if (this->next_double_normal_sample_) {
229 accum_state->legacy_pod.normal_is_valid = true;
230 accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_);
231 }
232 if (this->next_float_normal_sample_) {
233 accum_state->is_next_float_normal_sample_valid = true;
234 accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
235 }
236
237 memcpy(rng_state, accum_state.get(), size);
238 return state_tensor.getIntrusivePtr();
239}
240
241/**
242 * Gets the DeviceType of CPUGeneratorImpl.
243 * Used for type checking during run time.
244 */
245DeviceType CPUGeneratorImpl::device_type() {
246 return DeviceType::CPU;
247}
248
249/**
250 * Gets a random 32 bit unsigned integer from the engine
251 *
252 * See Note [Acquire lock when using random generators]
253 */
254uint32_t CPUGeneratorImpl::random() {
255 return engine_();
256}
257
258/**
259 * Gets a random 64 bit unsigned integer from the engine
260 *
261 * See Note [Acquire lock when using random generators]
262 */
263uint64_t CPUGeneratorImpl::random64() {
264 uint32_t random1 = engine_();
265 uint32_t random2 = engine_();
266 return detail::make64BitsFrom32Bits(random1, random2);
267}
268
269/**
270 * Get the cached normal random in float
271 */
272c10::optional<float> CPUGeneratorImpl::next_float_normal_sample() {
273 return next_float_normal_sample_;
274}
275
276/**
277 * Get the cached normal random in double
278 */
279c10::optional<double> CPUGeneratorImpl::next_double_normal_sample() {
280 return next_double_normal_sample_;
281}
282
283/**
284 * Cache normal random in float
285 *
286 * See Note [Acquire lock when using random generators]
287 */
288void CPUGeneratorImpl::set_next_float_normal_sample(c10::optional<float> randn) {
289 next_float_normal_sample_ = randn;
290}
291
292/**
293 * Cache normal random in double
294 *
295 * See Note [Acquire lock when using random generators]
296 */
297void CPUGeneratorImpl::set_next_double_normal_sample(c10::optional<double> randn) {
298 next_double_normal_sample_ = randn;
299}
300
301/**
302 * Get the engine of the CPUGeneratorImpl
303 */
304at::mt19937 CPUGeneratorImpl::engine() {
305 return engine_;
306}
307
308/**
309 * Set the engine of the CPUGeneratorImpl
310 *
311 * See Note [Acquire lock when using random generators]
312 */
313void CPUGeneratorImpl::set_engine(at::mt19937 engine) {
314 engine_ = engine;
315}
316
317/**
318 * Public clone method implementation
319 *
320 * See Note [Acquire lock when using random generators]
321 */
322std::shared_ptr<CPUGeneratorImpl> CPUGeneratorImpl::clone() const {
323 return std::shared_ptr<CPUGeneratorImpl>(this->clone_impl());
324}
325
326/**
327 * Private clone method implementation
328 *
329 * See Note [Acquire lock when using random generators]
330 */
331CPUGeneratorImpl* CPUGeneratorImpl::clone_impl() const {
332 auto gen = new CPUGeneratorImpl();
333 gen->set_engine(engine_);
334 gen->set_next_float_normal_sample(next_float_normal_sample_);
335 gen->set_next_double_normal_sample(next_double_normal_sample_);
336 return gen;
337}
338
339} // namespace at
340