1 | #pragma once |
2 | |
3 | #include <stdint.h> |
4 | #include <mutex> |
5 | #include <deque> |
6 | #include <atomic> |
7 | #include <typeinfo> |
8 | #include <utility> |
9 | #include <cstddef> |
10 | |
11 | #include <c10/util/Exception.h> |
12 | #include <c10/util/C++17.h> |
13 | #include <c10/util/intrusive_ptr.h> |
14 | #include <c10/core/Device.h> |
15 | #include <c10/core/DispatchKeySet.h> |
16 | |
17 | // For the record I don't think this is a correct pimpl idiom. |
18 | // Including Impl header in interface header defeats the purpose |
19 | // because you can't change Impl private members without forcing |
20 | // everything that included the interface to rebuild. |
21 | // Impl should be forward-declared in the interface header instead. |
22 | #include <c10/core/GeneratorImpl.h> |
23 | |
24 | /** |
25 | * Note [Generator] |
26 | * ~~~~~~~~~~~~~~~~ |
27 | * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to |
 * generate a seemingly random sequence of numbers, that may later be used in creating
 * a random distribution. Such an engine almost always maintains a state and requires a
 * seed to start off the creation of random numbers. Oftentimes, users have
31 | * found it beneficial to be able to explicitly create, retain, and destroy |
32 | * PRNG states and also be able to have control over the seed value. |
33 | * |
34 | * A Generator in ATen gives users the ability to read, write and modify a PRNG engine. |
35 | * For instance, it does so by letting users seed a PRNG engine, fork the state of the |
36 | * engine, etc. |
37 | * |
38 | * By default, there is one generator per device, and a device's generator is |
39 | * lazily created. A user can use the torch.Generator() api to create their own generator. |
40 | */ |
41 | |
42 | /** |
43 | * Note [Acquire lock when using random generators] |
44 | * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
45 | * Generator and its derived classes are NOT thread-safe. Please note that most of the |
46 | * places where we have inserted locking for generators are historically based, and we |
47 | * haven't actually checked that everything is truly thread safe (and it probably isn't). |
48 | * Please use the public mutex_ when using any methods from these classes, except for the |
49 | * read-only methods. You can learn about the usage by looking into the unittests |
50 | * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard. |
51 | * |
52 | * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making |
53 | * them non-thread safe and instead making the generator state splittable, to accommodate |
54 | * forks into other threads). |
55 | */ |
56 | |
57 | namespace at { |
58 | |
59 | class Tensor; |
60 | |
61 | struct TORCH_API Generator { |
62 | Generator() = default; |
63 | |
64 | explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl) |
65 | : impl_(std::move(gen_impl)) { |
66 | if (impl_.get() == nullptr) { |
67 | throw std::runtime_error("GeneratorImpl with nullptr is not supported" ); |
68 | } |
69 | } |
70 | |
71 | bool operator==(const Generator& rhs) const { |
72 | return this->impl_ == rhs.impl_; |
73 | } |
74 | |
75 | bool operator!=(const Generator& rhs) const { |
76 | return !((*this) == rhs); |
77 | } |
78 | |
79 | bool defined() const { |
80 | return static_cast<bool>(impl_); |
81 | } |
82 | |
83 | c10::GeneratorImpl* unsafeGetGeneratorImpl() const { |
84 | return impl_.get(); |
85 | } |
86 | |
87 | c10::GeneratorImpl* unsafeReleaseGeneratorImpl() { |
88 | return impl_.release(); |
89 | } |
90 | |
91 | const c10::intrusive_ptr<c10::GeneratorImpl>& getIntrusivePtr() const { |
92 | return impl_; |
93 | } |
94 | |
95 | void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); } |
96 | |
97 | uint64_t current_seed() const { return impl_->current_seed(); } |
98 | |
99 | uint64_t seed() { return impl_->seed(); } |
100 | |
101 | // Implementation not inlined to prevent cycle reference between |
102 | // `ATen/core/Generator.h` and `ATen/core/Tensor.h` |
103 | void set_state(const at::Tensor& new_state); |
104 | |
105 | at::Tensor get_state() const; |
106 | |
107 | std::mutex& mutex() { |
108 | return impl_->mutex_; |
109 | } |
110 | |
111 | DispatchKeySet key_set() const { |
112 | return impl_->key_set(); |
113 | } |
114 | |
115 | Device device() const { return impl_->device(); } |
116 | |
117 | inline void set_pyobj(PyObject* pyobj) const noexcept { |
118 | impl_->set_pyobj(pyobj); |
119 | } |
120 | |
121 | inline PyObject* pyobj() const noexcept { |
122 | return impl_->pyobj(); |
123 | } |
124 | |
125 | template<typename T> |
126 | T* get() const { return static_cast<T*>(impl_.get()); } |
127 | |
128 | Generator clone() const { |
129 | return Generator(impl_->clone()); |
130 | } |
131 | |
132 | private: |
133 | c10::intrusive_ptr<c10::GeneratorImpl> impl_; |
134 | }; |
135 | |
136 | template<class Impl, class... Args> |
137 | Generator make_generator(Args&&... args) { |
138 | return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...)); |
139 | } |
140 | |
141 | /** |
142 | * Utility function to static cast input Generator* to |
143 | * the backend generator type (CPU/CUDAGeneratorImpl etc.) |
144 | */ |
145 | template <typename T> |
146 | static inline T * check_generator(c10::optional<Generator> gen) { |
147 | TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt" ); |
148 | TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed" ); |
149 | TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '" , T::device_type(), "' device type for generator but found '" , gen->device().type(), "'" ); |
150 | return gen->get<T>(); |
151 | } |
152 | |
153 | /** |
154 | * Utility function used in tensor implementations, which |
155 | * supplies the default generator to tensors, if an input generator |
156 | * is not supplied. The input Generator* is also static casted to |
157 | * the backend generator type (CPU/CUDAGeneratorImpl etc.) |
158 | */ |
159 | template <typename T> |
160 | static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) { |
161 | return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen); |
162 | } |
163 | |
164 | namespace detail { |
165 | |
166 | /** |
167 | * Helper function for checking the validity of new random generator |
168 | * state. Right now following conditions are checked: |
169 | * |
170 | * - The new state tensor must be a torch.ByteTensor |
171 | * - Data of the new state tensor must be contiguous |
172 | */ |
173 | static inline void check_rng_state(const c10::TensorImpl& new_state) { |
174 | TORCH_CHECK_TYPE( |
175 | new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte, |
176 | "RNG state must be a torch.ByteTensor" |
177 | ); |
178 | |
179 | TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous" ); |
180 | } |
181 | |
182 | } // namespace detail |
183 | |
184 | } // namespace at |
185 | |