#include <c10/core/GeneratorImpl.h>
#include <cerrno>
#include <chrono>
#include <random>

#if defined(__SGX_ENABLED__)
#include <sgx_trts.h>
#endif

#ifndef _WIN32
#include <fcntl.h>
#include <unistd.h>
#endif
13
14namespace c10 {
15
16/**
17 * GeneratorImpl class implementation
18 */
19GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set)
20 : device_{device_in}, key_set_(key_set) {}
21
22/**
23 * Clone this generator. Note that clone() is the only
24 * method for copying for Generators in ATen.
25 */
26c10::intrusive_ptr<GeneratorImpl> GeneratorImpl::clone() const {
27 auto res = this->clone_impl();
28 c10::raw::intrusive_ptr::incref(res);
29 c10::raw::weak_intrusive_ptr::incref(res);
30 return c10::intrusive_ptr<GeneratorImpl>::reclaim(res);
31}
32
33/**
34 * Gets the device of a generator.
35 */
36Device GeneratorImpl::device() const {
37 return device_;
38}
39
40namespace detail {
41
42/**
43 * Gets a random number for /dev/urandom
44 * Note this is a legacy method (from THRandom.cpp)
45 * FIXME: use std::random_device with entropy information
46 */
47#if !defined(_WIN32)
48static uint64_t readURandomLong() {
49 int randDev = open("/dev/urandom", O_RDONLY);
50 TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom");
51 uint64_t randValue{};
52 ssize_t readBytes = read(randDev, &randValue, sizeof(randValue));
53 close(randDev);
54 TORCH_CHECK(
55 readBytes >= (ssize_t)sizeof(randValue),
56 "Unable to read from /dev/urandom");
57 return randValue;
58}
59#endif // _WIN32
60
61/**
62 * Gets a non deterministic random number number from either the
63 * /dev/urandom or the current time. For CUDA, gets random from
64 * std::random_device and adds a transformation on it. For Intel SGX
65 * platform use sgx_read_rand as reading from /dev/urandom is
66 * prohibited on that platfrom.
67 *
68 * FIXME: The behavior in this function is from legacy code
69 * (THRandom_seed/THCRandom_seed) and is probably not the right thing to do,
70 * even though our tests pass. Figure out if tests get perturbed
71 * - when the same algorithm is used for all backends. Note that the current
72 * behavior is different for CPU, CUDA and Windows CPU.
73 * - when using C++11 std objects, such as std::random_device
74 * - when constructing a 64 bit seed properly, rather than static casting
75 * a 32 bit number to 64 bit.
76 */
77uint64_t getNonDeterministicRandom(bool is_cuda) {
78 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
79 uint64_t s;
80 if (!is_cuda) {
81#ifdef _WIN32
82 s = (uint64_t)std::chrono::high_resolution_clock::now()
83 .time_since_epoch()
84 .count();
85#elif defined(__SGX_ENABLED__)
86 TORCH_CHECK(
87 sgx_read_rand(reinterpret_cast<uint8_t*>(&s), sizeof(s)) == SGX_SUCCESS,
88 "Could not generate random number with sgx_read_rand.");
89#else
90 s = readURandomLong();
91#endif
92 } else {
93 std::random_device rd;
94 // limit to 53 bits to ensure unique representation in double
95 s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
96 }
97 return s;
98}
99
100} // namespace detail
101} // namespace c10
102