#include <c10/core/GeneratorImpl.h>

#include <cerrno>
#include <chrono>
#include <cstdint>
#include <random>

#if defined(__SGX_ENABLED__)
#include <sgx_trts.h>
#endif

#ifndef _WIN32
#include <fcntl.h>
#include <unistd.h>
#endif
13 | |
14 | namespace c10 { |
15 | |
16 | /** |
17 | * GeneratorImpl class implementation |
18 | */ |
19 | GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set) |
20 | : device_{device_in}, key_set_(key_set) {} |
21 | |
22 | /** |
23 | * Clone this generator. Note that clone() is the only |
24 | * method for copying for Generators in ATen. |
25 | */ |
26 | c10::intrusive_ptr<GeneratorImpl> GeneratorImpl::clone() const { |
27 | auto res = this->clone_impl(); |
28 | c10::raw::intrusive_ptr::incref(res); |
29 | c10::raw::weak_intrusive_ptr::incref(res); |
30 | return c10::intrusive_ptr<GeneratorImpl>::reclaim(res); |
31 | } |
32 | |
33 | /** |
34 | * Gets the device of a generator. |
35 | */ |
36 | Device GeneratorImpl::device() const { |
37 | return device_; |
38 | } |
39 | |
40 | namespace detail { |
41 | |
42 | /** |
43 | * Gets a random number for /dev/urandom |
44 | * Note this is a legacy method (from THRandom.cpp) |
45 | * FIXME: use std::random_device with entropy information |
46 | */ |
47 | #if !defined(_WIN32) |
48 | static uint64_t readURandomLong() { |
49 | int randDev = open("/dev/urandom" , O_RDONLY); |
50 | TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom" ); |
51 | uint64_t randValue{}; |
52 | ssize_t readBytes = read(randDev, &randValue, sizeof(randValue)); |
53 | close(randDev); |
54 | TORCH_CHECK( |
55 | readBytes >= (ssize_t)sizeof(randValue), |
56 | "Unable to read from /dev/urandom" ); |
57 | return randValue; |
58 | } |
59 | #endif // _WIN32 |
60 | |
61 | /** |
62 | * Gets a non deterministic random number number from either the |
63 | * /dev/urandom or the current time. For CUDA, gets random from |
64 | * std::random_device and adds a transformation on it. For Intel SGX |
65 | * platform use sgx_read_rand as reading from /dev/urandom is |
66 | * prohibited on that platfrom. |
67 | * |
68 | * FIXME: The behavior in this function is from legacy code |
69 | * (THRandom_seed/THCRandom_seed) and is probably not the right thing to do, |
70 | * even though our tests pass. Figure out if tests get perturbed |
71 | * - when the same algorithm is used for all backends. Note that the current |
72 | * behavior is different for CPU, CUDA and Windows CPU. |
73 | * - when using C++11 std objects, such as std::random_device |
74 | * - when constructing a 64 bit seed properly, rather than static casting |
75 | * a 32 bit number to 64 bit. |
76 | */ |
77 | uint64_t getNonDeterministicRandom(bool is_cuda) { |
78 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
79 | uint64_t s; |
80 | if (!is_cuda) { |
81 | #ifdef _WIN32 |
82 | s = (uint64_t)std::chrono::high_resolution_clock::now() |
83 | .time_since_epoch() |
84 | .count(); |
85 | #elif defined(__SGX_ENABLED__) |
86 | TORCH_CHECK( |
87 | sgx_read_rand(reinterpret_cast<uint8_t*>(&s), sizeof(s)) == SGX_SUCCESS, |
88 | "Could not generate random number with sgx_read_rand." ); |
89 | #else |
90 | s = readURandomLong(); |
91 | #endif |
92 | } else { |
93 | std::random_device rd; |
94 | // limit to 53 bits to ensure unique representation in double |
95 | s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF; |
96 | } |
97 | return s; |
98 | } |
99 | |
100 | } // namespace detail |
101 | } // namespace c10 |
102 | |