1#pragma once
2
3#include <c10/util/irange.h>
4
5// define constants like M_PI and C keywords for MSVC
6#ifdef _MSC_VER
7#ifndef _USE_MATH_DEFINES
8#define _USE_MATH_DEFINES
9#endif
10#include <math.h>
11#endif
12
13#include <array>
14#include <cmath>
15#include <cstdint>
16
17namespace at {
18
// MT19937 recurrence parameters: length of the state array (in 32-bit words)
// and the offset of the "middle" word mixed into each regenerated word.
constexpr int MERSENNE_STATE_N = 624;
constexpr int MERSENNE_STATE_M = 397;

// Value XOR-ed into a twisted word when the low bit of its successor is set.
constexpr uint32_t MATRIX_A = 0x9908b0dfU;

// Split masks for combining two state words: the top bit of one word joined
// with the low 31 bits of the next.
constexpr uint32_t UMASK = uint32_t{1} << 31;
constexpr uint32_t LMASK = static_cast<uint32_t>(~UMASK);
24
25/**
26 * Note [Mt19937 Engine implementation]
27 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28 * Originally implemented in:
29 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
 * and modified with C++ constructs. Moreover, the state array of the engine
 * has been modified to hold 32-bit uints instead of 64-bit ones.
32 *
 * Note that we reimplemented mt19937 instead of using std::mt19937 because
 * at::mt19937 turns out to be faster in the PyTorch codebase. PyTorch builds with -O2
35 * by default and following are the benchmark numbers (benchmark code can be found at
36 * https://github.com/syed-ahmed/benchmark-rngs):
37 *
38 * with -O2
39 * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
40 * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
41 * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
42 * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
43 *
 * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution;
 * however, we can't use std::uniform_real_distribution because of this bug:
 * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
 * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
 * than what's in PyTorch currently, and that messes up the tests in test_distributions.py.
 * The other option, using std::mt19937 with at::uniform_real_distribution, is slightly slower
 * than at::mt19937 with at::uniform_real_distribution, and hence we went with the latter.
51 *
52 * Copyright notice:
53 * A C-program for MT19937, with initialization improved 2002/2/10.
54 * Coded by Takuji Nishimura and Makoto Matsumoto.
55 * This is a faster version by taking Shawn Cokus's optimization,
56 * Matthe Bellew's simplification, Isaku Wada's real version.
57 *
58 * Before using, initialize the state by using init_genrand(seed)
59 * or init_by_array(init_key, key_length).
60 *
61 * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
62 * All rights reserved.
63 *
64 * Redistribution and use in source and binary forms, with or without
65 * modification, are permitted provided that the following conditions
66 * are met:
67 *
68 * 1. Redistributions of source code must retain the above copyright
69 * notice, this list of conditions and the following disclaimer.
70 *
71 * 2. Redistributions in binary form must reproduce the above copyright
72 * notice, this list of conditions and the following disclaimer in the
73 * documentation and/or other materials provided with the distribution.
74 *
75 * 3. The names of its contributors may not be used to endorse or promote
76 * products derived from this software without specific prior written
77 * permission.
78 *
79 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
80 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
81 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
82 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
83 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
84 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
85 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
86 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
87 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
88 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
89 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
90 *
91 *
92 * Any feedback is very welcome.
93 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
94 * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
95 */
96
97/**
98 * mt19937_data_pod is used to get POD data in and out
99 * of mt19937_engine. Used in torch.get_rng_state and
100 * torch.set_rng_state functions.
101 */
102struct mt19937_data_pod {
103 uint64_t seed_;
104 int left_;
105 bool seeded_;
106 uint32_t next_;
107 std::array<uint32_t, MERSENNE_STATE_N> state_;
108};
109
110class mt19937_engine {
111public:
112
113 inline explicit mt19937_engine(uint64_t seed = 5489) {
114 init_with_uint32(seed);
115 }
116
117 inline mt19937_data_pod data() const {
118 return data_;
119 }
120
121 inline void set_data(const mt19937_data_pod& data) {
122 data_ = data;
123 }
124
125 inline uint64_t seed() const {
126 return data_.seed_;
127 }
128
129 inline bool is_valid() {
130 if ((data_.seeded_ == true)
131 && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)
132 && (data_.next_ <= MERSENNE_STATE_N)) {
133 return true;
134 }
135 return false;
136 }
137
138 inline uint32_t operator()() {
139 uint32_t y;
140
141 if (--(data_.left_) == 0) {
142 next_state();
143 }
144 y = *(data_.state_.data() + data_.next_++);
145 y ^= (y >> 11);
146 y ^= (y << 7) & 0x9d2c5680;
147 y ^= (y << 15) & 0xefc60000;
148 y ^= (y >> 18);
149
150 return y;
151 }
152
153private:
154 mt19937_data_pod data_;
155
156 inline void init_with_uint32(uint64_t seed) {
157 data_.seed_ = seed;
158 data_.seeded_ = true;
159 data_.state_[0] = seed & 0xffffffff;
160 for (const auto j : c10::irange(1, MERSENNE_STATE_N)) {
161 data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j);
162 }
163 data_.left_ = 1;
164 data_.next_ = 0;
165 }
166
167 inline uint32_t mix_bits(uint32_t u, uint32_t v) {
168 return (u & UMASK) | (v & LMASK);
169 }
170
171 inline uint32_t twist(uint32_t u, uint32_t v) {
172 return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0);
173 }
174
175 inline void next_state() {
176 uint32_t* p = data_.state_.data();
177 data_.left_ = MERSENNE_STATE_N;
178 data_.next_ = 0;
179
180 for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {
181 *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);
182 }
183
184 for(int j = MERSENNE_STATE_M; --j; p++) {
185 *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);
186 }
187
188 *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);
189 }
190
191};
192
193typedef mt19937_engine mt19937;
194
195} // namespace at
196