1 | #pragma once |
2 | |
3 | #include <c10/util/irange.h> |
4 | |
5 | // define constants like M_PI and C keywords for MSVC |
6 | #ifdef _MSC_VER |
7 | #ifndef _USE_MATH_DEFINES |
8 | #define _USE_MATH_DEFINES |
9 | #endif |
10 | #include <math.h> |
11 | #endif |
12 | |
13 | #include <array> |
14 | #include <cmath> |
15 | #include <cstdint> |
16 | |
17 | namespace at { |
18 | |
// Parameters of the MT19937 recurrence used by mt19937_engine below.

// Number of 32-bit words held in the engine's state array.
constexpr int MERSENNE_STATE_N = 624;
// Offset between the two state words combined when the state is regenerated.
constexpr int MERSENNE_STATE_M = 397;
// Constant XORed into a twisted word when the low bit of its second input is set.
constexpr uint32_t MATRIX_A = 0x9908B0DFU;
// Selects the most significant bit of a state word.
constexpr uint32_t UMASK = 0x80000000U;
// Selects the lower 31 bits of a state word.
constexpr uint32_t LMASK = 0x7FFFFFFFU;
24 | |
25 | /** |
26 | * Note [Mt19937 Engine implementation] |
27 | * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
28 | * Originally implemented in: |
29 | * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c |
30 | * and modified with C++ constructs. Moreover the state array of the engine |
31 | * has been modified to hold 32 bit uints instead of 64 bits. |
32 | * |
33 | * Note that we reimplemented mt19937 instead of using std::mt19937 because |
34 | * at::mt19937 turns out to be faster in the PyTorch codebase. PyTorch builds with -O2 |
35 | * by default, and the following are the benchmark numbers (benchmark code can be found at |
36 | * https://github.com/syed-ahmed/benchmark-rngs): |
37 | * |
38 | * with -O2 |
39 | * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s |
40 | * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s |
41 | * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s |
42 | * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s |
43 | * |
44 | * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution, |
45 | * however we can't use std::uniform_real_distribution because of this bug: |
46 | * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used |
47 | * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm |
48 | * than what's in pytorch currently and that messes up the tests in tests_distributions.py. |
49 | * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower |
50 | * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter. |
51 | * |
52 | * Copyright notice: |
53 | * A C-program for MT19937, with initialization improved 2002/2/10. |
54 | * Coded by Takuji Nishimura and Makoto Matsumoto. |
55 | * This is a faster version by taking Shawn Cokus's optimization, |
56 | * Matthe Bellew's simplification, Isaku Wada's real version. |
57 | * |
58 | * Before using, initialize the state by using init_genrand(seed) |
59 | * or init_by_array(init_key, key_length). |
60 | * |
61 | * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, |
62 | * All rights reserved. |
63 | * |
64 | * Redistribution and use in source and binary forms, with or without |
65 | * modification, are permitted provided that the following conditions |
66 | * are met: |
67 | * |
68 | * 1. Redistributions of source code must retain the above copyright |
69 | * notice, this list of conditions and the following disclaimer. |
70 | * |
71 | * 2. Redistributions in binary form must reproduce the above copyright |
72 | * notice, this list of conditions and the following disclaimer in the |
73 | * documentation and/or other materials provided with the distribution. |
74 | * |
75 | * 3. The names of its contributors may not be used to endorse or promote |
76 | * products derived from this software without specific prior written |
77 | * permission. |
78 | * |
79 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
80 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
81 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
82 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR |
83 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, |
84 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, |
85 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR |
86 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF |
87 | * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING |
88 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
89 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
90 | * |
91 | * |
92 | * Any feedback is very welcome. |
93 | * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html |
94 | * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space) |
95 | */ |
96 | |
97 | /** |
98 | * mt19937_data_pod is used to get POD data in and out |
99 | * of mt19937_engine. Used in torch.get_rng_state and |
100 | * torch.set_rng_state functions. |
101 | */ |
102 | struct mt19937_data_pod { |
103 | uint64_t seed_; |
104 | int left_; |
105 | bool seeded_; |
106 | uint32_t next_; |
107 | std::array<uint32_t, MERSENNE_STATE_N> state_; |
108 | }; |
109 | |
110 | class mt19937_engine { |
111 | public: |
112 | |
113 | inline explicit mt19937_engine(uint64_t seed = 5489) { |
114 | init_with_uint32(seed); |
115 | } |
116 | |
117 | inline mt19937_data_pod data() const { |
118 | return data_; |
119 | } |
120 | |
121 | inline void set_data(const mt19937_data_pod& data) { |
122 | data_ = data; |
123 | } |
124 | |
125 | inline uint64_t seed() const { |
126 | return data_.seed_; |
127 | } |
128 | |
129 | inline bool is_valid() { |
130 | if ((data_.seeded_ == true) |
131 | && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N) |
132 | && (data_.next_ <= MERSENNE_STATE_N)) { |
133 | return true; |
134 | } |
135 | return false; |
136 | } |
137 | |
138 | inline uint32_t operator()() { |
139 | uint32_t y; |
140 | |
141 | if (--(data_.left_) == 0) { |
142 | next_state(); |
143 | } |
144 | y = *(data_.state_.data() + data_.next_++); |
145 | y ^= (y >> 11); |
146 | y ^= (y << 7) & 0x9d2c5680; |
147 | y ^= (y << 15) & 0xefc60000; |
148 | y ^= (y >> 18); |
149 | |
150 | return y; |
151 | } |
152 | |
153 | private: |
154 | mt19937_data_pod data_; |
155 | |
156 | inline void init_with_uint32(uint64_t seed) { |
157 | data_.seed_ = seed; |
158 | data_.seeded_ = true; |
159 | data_.state_[0] = seed & 0xffffffff; |
160 | for (const auto j : c10::irange(1, MERSENNE_STATE_N)) { |
161 | data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j); |
162 | } |
163 | data_.left_ = 1; |
164 | data_.next_ = 0; |
165 | } |
166 | |
167 | inline uint32_t mix_bits(uint32_t u, uint32_t v) { |
168 | return (u & UMASK) | (v & LMASK); |
169 | } |
170 | |
171 | inline uint32_t twist(uint32_t u, uint32_t v) { |
172 | return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0); |
173 | } |
174 | |
175 | inline void next_state() { |
176 | uint32_t* p = data_.state_.data(); |
177 | data_.left_ = MERSENNE_STATE_N; |
178 | data_.next_ = 0; |
179 | |
180 | for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) { |
181 | *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]); |
182 | } |
183 | |
184 | for(int j = MERSENNE_STATE_M; --j; p++) { |
185 | *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]); |
186 | } |
187 | |
188 | *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]); |
189 | } |
190 | |
191 | }; |
192 | |
193 | typedef mt19937_engine mt19937; |
194 | |
195 | } // namespace at |
196 | |