1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file random_engine.h
21 * \brief Random number generator. It provides a generic interface consistent with
22 * `std::uniform_random_bit_generator`
23 */
24#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
25#define TVM_SUPPORT_RANDOM_ENGINE_H_
26#include <tvm/runtime/logging.h>
27
28#include <cstdint>
29#include <random>
30
31namespace tvm {
32namespace support {
33
/*!
 * \brief This linear congruential engine is a drop-in replacement for std::minstd_rand. It uses the
 * same multiplier, increment and modulus as std::minstd_rand and is designed to produce identical,
 * platform-independent state sequences.
 * \note Our linear congruential engine is a complete implementation of
 * std::uniform_random_bit_generator, so it can be used as the generator for any STL random number
 * distribution. However, some of std::linear_congruential_engine's member functions are omitted
 * for simplicity. For the full set of member functions of std::minstd_rand, see:
 * https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine
 */
43
44class LinearCongruentialEngine {
45 public:
46 using TRandState = int64_t;
47 /*! \brief The result type. */
48 using result_type = uint64_t;
49 /*! \brief The multiplier */
50 static constexpr TRandState multiplier = 48271;
51 /*! \brief The increment */
52 static constexpr TRandState increment = 0;
53 /*! \brief The modulus */
54 static constexpr TRandState modulus = 2147483647;
55 /*! \brief The minimum possible value of random state here. */
56 static constexpr result_type min() { return 0; }
57 /*! \brief The maximum possible value of random state here. */
58 static constexpr result_type max() { return modulus - 1; }
59
60 /*!
61 * \brief Get a device random state
62 * \return The random state
63 */
64 static TRandState DeviceRandom() { return (std::random_device()()) % modulus; }
65
66 /*!
67 * \brief Operator to move the random state to the next and return the new random state. According
68 * to definition of linear congruential engine, the new random state value is computed as
69 * new_random_state = (current_random_state * multiplier + increment) % modulus.
70 * \return The next current random state value in the type of result_type.
71 * \note In order for better efficiency, the implementation here has a few assumptions:
72 * 1. The multiplication and addition won't overflow.
73 * 2. The given random state pointer `rand_state_ptr` is not nullptr.
74 * 3. The given random state `*(rand_state_ptr)` is in the range of [0, modulus - 1].
75 */
76 result_type operator()() {
77 (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus;
78 return *rand_state_ptr_;
79 }
80 /*!
81 * \brief Normalize the random seed to the range of [1, modulus - 1].
82 * \param rand_state The random seed.
83 * \return The normalized random seed.
84 */
85 static TRandState NormalizeSeed(TRandState rand_state) {
86 if (rand_state == -1) {
87 rand_state = DeviceRandom();
88 } else {
89 rand_state %= modulus;
90 }
91 if (rand_state == 0) {
92 rand_state = 1;
93 }
94 if (rand_state < 0) {
95 LOG(FATAL) << "ValueError: Random seed must be non-negative";
96 }
97 return rand_state;
98 }
99 /*!
100 * \brief Change the start random state of RNG with the seed of a new random state value.
101 * \param rand_state The random state given in result_type.
102 */
103 void Seed(TRandState rand_state) {
104 ICHECK(rand_state_ptr_ != nullptr);
105 *rand_state_ptr_ = NormalizeSeed(rand_state);
106 }
107
108 /*!
109 * \brief Fork a new seed for another RNG from current random state.
110 * \return The forked seed.
111 */
112 TRandState ForkSeed() {
113 // In order for reproducibility, we compute the new seed using RNG's random state and a
114 // different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
115 return ((*this)() * 32767) % 1999999973;
116 }
117
118 /*!
119 * \brief Construct a random number generator with a random state pointer.
120 * \param rand_state_ptr The random state pointer given in result_type*.
121 * \note The random state is not checked for whether it's nullptr and whether it's in the range of
122 * [0, modulus-1]. We assume the given random state is valid or the Seed function would be
123 * called right after the constructor before any usage.
124 */
125 explicit LinearCongruentialEngine(TRandState* rand_state_ptr) {
126 rand_state_ptr_ = rand_state_ptr;
127 }
128
129 private:
130 TRandState* rand_state_ptr_;
131};
132
133} // namespace support
134} // namespace tvm
135
136#endif // TVM_SUPPORT_RANDOM_ENGINE_H_
137