1 | #include <ATen/Utils.h> |
2 | #include <ATen/cuda/CUDAGeneratorImpl.h> |
3 | #include <ATen/cuda/CUDAGraphsUtils.cuh> |
4 | #include <c10/core/StreamGuard.h> |
5 | #include <c10/cuda/CUDAFunctions.h> |
6 | #include <c10/util/CallOnce.h> |
7 | #include <ATen/Utils.h> |
8 | |
9 | namespace at { |
10 | namespace cuda { |
11 | namespace detail { |
12 | |
namespace {

// Ensures we only call cudaGetDeviceCount once.
static c10::once_flag num_gpu_init_flag;

// Total number of gpus in the system. Written once by initCUDAGenVector.
static int64_t num_gpus;

// Ensures default_gens_cuda is initialized once (one flag per device).
// NOTE(review): deque rather than vector, presumably because once_flag is
// neither copyable nor movable and deque never relocates elements — confirm.
static std::deque<c10::once_flag> cuda_gens_init_flag;

// Default, global CUDA generators, one per GPU.
static std::vector<Generator> default_gens_cuda;

/*
 * Populates the global variables related to CUDA generators
 * Warning: this function must only be called once!
 * (Enforced by callers via c10::call_once on num_gpu_init_flag.)
 */
static void initCUDAGenVector(){
  num_gpus = c10::cuda::device_count();
  cuda_gens_init_flag.resize(num_gpus);
  default_gens_cuda.resize(num_gpus);
}

} // anonymous namespace
38 | |
39 | /** |
40 | * PyTorch maintains a collection of default generators that get |
41 | * initialized once. The purpose of these default generators is to |
42 | * maintain a global running state of the pseudo random number generation, |
43 | * when a user does not explicitly mention any generator. |
44 | * getDefaultCUDAGenerator gets the default generator for a particular |
45 | * cuda device. |
46 | */ |
47 | const Generator& getDefaultCUDAGenerator(DeviceIndex device_index) { |
48 | c10::call_once(num_gpu_init_flag, initCUDAGenVector); |
49 | DeviceIndex idx = device_index; |
50 | if (idx == -1) { |
51 | idx = c10::cuda::current_device(); |
52 | } else { |
53 | TORCH_CHECK(idx >= 0 && idx < num_gpus); |
54 | } |
55 | c10::call_once(cuda_gens_init_flag[idx], [&] { |
56 | default_gens_cuda[idx] = make_generator<CUDAGeneratorImpl>(idx); |
57 | default_gens_cuda[idx].seed(); |
58 | }); |
59 | return default_gens_cuda[idx]; |
60 | } |
61 | |
62 | /** |
63 | * Utility to create a CUDAGeneratorImpl. Returns a shared_ptr |
64 | */ |
65 | Generator createCUDAGenerator(DeviceIndex device_index) { |
66 | c10::call_once(num_gpu_init_flag, initCUDAGenVector); |
67 | DeviceIndex idx = device_index; |
68 | if (idx == -1) { |
69 | idx = c10::cuda::current_device(); |
70 | } |
71 | TORCH_CHECK(idx >= 0 && idx < num_gpus, "The device_index is invalid." ); |
72 | auto gen = make_generator<CUDAGeneratorImpl>(idx); |
73 | auto cuda_gen = check_generator<CUDAGeneratorImpl>(gen); |
74 | cuda_gen->set_current_seed(default_rng_seed_val); |
75 | cuda_gen->set_philox_offset_per_thread(0); |
76 | return gen; |
77 | } |
78 | |
79 | } // namespace detail |
80 | } // namespace cuda |
81 | |
82 | /** |
83 | * Note [Why enforce RNG offset % 4 == 0?] |
84 | * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
85 | * Curand philox does allow offsets that aren't a multiple of 4. |
86 | * But jit kernels don't use curand, they use a custom "Philox" class (see |
87 | * torch/csrc/jit/tensorexpr/cuda_random.h or |
88 | * torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu). |
89 | * The "Philox" constructor computes offset/4 (a uint64_t division) to locate its |
90 | * internal start in its virtual bitstream viewed as 128-bit chunks, then, when called |
91 | * in a thread, returns one 32-bit chunk at a time from that start in the bitstream. |
92 | * In other words, if the incoming offset is not a multiple of 4, each thread |
93 | * might repeat some previously-generated 32-bit values in the bitstream. See |
94 | * https://github.com/pytorch/pytorch/pull/50169. |
95 | */ |
96 | |
97 | /** |
98 | * CUDAGeneratorImpl class implementation |
99 | */ |
CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
  : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index),
              DispatchKeySet(c10::DispatchKey::CUDA)} {
  // Creating a new generator inside a CUDA graph capture region is disallowed.
  at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl" );
  // Start with an empty set; NOTE(review): no_reset_rnn_state_ presumably
  // tracks cudnn RNN dropout-state validity — confirm against the header.
  no_reset_rnn_state_.clear();
}
106 | |
107 | /** |
108 | * Sets the seed to be used by curandStatePhilox4_32_10 |
109 | * Resets the philox_offset_per_thread_ to 0 |
110 | * |
111 | * See Note [Acquire lock when using random generators] |
112 | */ |
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
  // Re-seeding is not allowed while a CUDA graph capture is underway.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_current_seed" );
  seed_ = seed;
  // A new seed restarts the Philox bitstream from offset 0.
  philox_offset_per_thread_ = 0;
  // Invalidate cached RNN state tracking so it is rebuilt under the new seed.
  no_reset_rnn_state_.clear();
}
119 | |
// Shared error-message suffix for RNG misuse inside CUDA graph capture regions.
#define CAPTURE_DEFAULT_GENS_MSG \
"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \
"generator on the device that's current when capture begins. " \
"If you need a non-default (user-supplied) generator, or a generator on another " \
"device, please file an issue."
125 | |
126 | /** |
127 | * Gets the current seed of CUDAGeneratorImpl. |
128 | */ |
uint64_t CUDAGeneratorImpl::current_seed() const {
  // Debatable if current_seed() should be allowed in captured regions.
  // Conservatively disallow it for now.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed" );
  // Simple accessor; seed_ is only mutated by set_current_seed()/seed().
  return seed_;
}
135 | |
136 | /** |
137 | * Gets a nondeterministic random number from /dev/urandom or time, |
138 | * seeds the CPUGeneratorImpl with it and then returns that number. |
139 | * |
140 | * FIXME: You can move this function to Generator.cpp if the algorithm |
141 | * in getNonDeterministicRandom is unified for both CPU and CUDA |
142 | */ |
143 | uint64_t CUDAGeneratorImpl::seed() { |
144 | at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::seed" ); |
145 | auto random = c10::detail::getNonDeterministicRandom(true); |
146 | this->set_current_seed(random); |
147 | return random; |
148 | } |
149 | |
150 | /** |
151 | * Gets the current internal state of CUDAGeneratorImpl. The internal |
152 | * state is returned as a CPU byte tensor. |
153 | */ |
154 | c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { |
155 | // The RNG state comprises the seed, and an offset used for Philox. |
156 | // The following line is just here for BC reason. sizeof curandStateMtgp32 is 4120. |
157 | // It used to be static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); |
158 | // MAX_NUM_BLOCKS was 200 and sizeof(curandStateMtgp32) is 4120. Hardcoding these numbers here |
159 | // because this is just host side code and we don't want to worry about linking with cuda |
160 | static const size_t states_size = 200 * sizeof(4120); |
161 | static const size_t seed_size = sizeof(uint64_t); |
162 | static const size_t offset_size = sizeof(int64_t); |
163 | static const size_t total_size = states_size + seed_size + offset_size; |
164 | |
165 | auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); |
166 | auto rng_state = state_tensor.data_ptr<uint8_t>(); |
167 | // since curandStateMTGP is not used anymore, fill gen_states of THCGenerator with deterministic garbage value of -1 |
168 | // gen_states in THCGenerator struct was an array of curandStateMtgp32s. |
169 | memset(rng_state, -1, states_size); |
170 | auto current_seed = this->current_seed(); |
171 | auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t> |
172 | memcpy(rng_state + states_size, ¤t_seed, seed_size); |
173 | memcpy(rng_state + states_size + seed_size, &offset, offset_size); |
174 | |
175 | return state_tensor.getIntrusivePtr(); |
176 | } |
177 | |
178 | /** |
179 | * Sets the internal state of CUDAGeneratorImpl. The new internal state |
180 | * must be a strided CPU byte tensor and have appropriate size. See |
181 | * comments of CUDAGeneratorImpl::state for information about the layout |
182 | * and size of the internal state. |
183 | */ |
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  // Layout must mirror get_state():
  //   [states_size bytes of legacy filler][uint64_t seed][int64_t philox offset]
  static const size_t states_size = 200 * sizeof(4120); // this line is just here for BC reason
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = states_size + seed_size + offset_size;

  detail::check_rng_state(new_state);

  // Older serialized states may lack the trailing philox offset; accept
  // both the full size and the offset-less size.
  bool no_philox_seed = false;
  auto new_state_size = new_state.numel();
  if (new_state_size == total_size - offset_size) {
    no_philox_seed = true;
  } else {
    TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size" );
  }

  uint64_t input_seed;
  auto new_rng_state = new_state.data<uint8_t>();
  memcpy(&input_seed, new_rng_state + states_size, seed_size);
  this->set_current_seed(input_seed);
  // Offset defaults to 0 when absent from the serialized state.
  int64_t philox_offset = 0;
  if (!no_philox_seed) {
    memcpy(&philox_offset, new_rng_state + states_size + seed_size, offset_size);
  }
  this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
}
210 | |
211 | /** |
212 | * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10 |
213 | * |
214 | * See Note [Acquire lock when using random generators] |
215 | */ |
void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
  // Mutating the offset is not allowed while a graph capture is underway.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_philox_offset_per_thread" );
  // see Note [Why enforce RNG offset % 4 == 0?]
  TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4" );
  philox_offset_per_thread_ = offset;
}
222 | |
223 | /** |
224 | * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl. |
225 | */ |
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
  // Conservatively disallowed during capture, matching current_seed().
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::philox_offset_per_thread" );
  return philox_offset_per_thread_;
}
230 | |
231 | /** |
232 | * Called by CUDAGraph to prepare this instance for a graph capture region. |
233 | * offset_extragraph is the initial offset at the start of the graphed region. |
234 | * offset_intragraph tracks the offset in the graphed region. |
235 | */ |
236 | void CUDAGeneratorImpl::capture_prologue(int64_t* , int64_t* ) { |
237 | seed_extragraph_ = seed_extragraph; |
238 | offset_extragraph_ = offset_extragraph; |
239 | offset_intragraph_ = 0; |
240 | graph_expects_this_gen_ = true; |
241 | } |
242 | |
243 | /** |
244 | * Called by CUDAGraph to finalize a graph capture region for this instance. |
245 | */ |
uint64_t CUDAGeneratorImpl::capture_epilogue() {
  // Capture is over; ordinary (non-captured) use is legal again.
  graph_expects_this_gen_ = false;
  // Total Philox offset consumed during the capture — presumably used by
  // CUDAGraph to advance the whole-graph offset per replay; confirm there.
  return offset_intragraph_;
}
250 | |
251 | /** |
252 | * Gets the seed and philox offset value to be used in |
253 | * curandStatePhilox4_32_10, in an opaque PhiloxCudaState that's safe |
254 | * and can be used non-divergently in callers whether CUDA graph |
255 | * capture is underway or not. See |
256 | * Note [CUDA Graph-safe RNG states] |
257 | * |
258 | * Each kernel using philox has to sensibly increment offset |
259 | * for future users of philox. So it gets the "old" value for |
260 | * itself (before add), and tells subsequent users which offset |
261 | * they should use, since only the kernel knows how many randoms |
262 | * it intends to generate. |
263 | * |
264 | * Increment should be at least the number of curand() random numbers used in |
265 | * each thread. It is the user's responsibility to make sure the increment |
266 | * for philox is never smaller than the number of curand() calls. Increment |
267 | * value > the number of curand() calls won't harm but anything less would mean |
268 | * that you would be reusing random values from previous calls. |
269 | * |
270 | * See Note [Acquire lock when using random generators] |
271 | */ |
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
  // rounds increment up to the nearest multiple of 4
  // (keeps offsets 128-bit aligned, see Note [Why enforce RNG offset % 4 == 0?])
  increment = ((increment + 3) / 4) * 4;
  if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
    // Capture path: package the extragraph seed/offset pointers set up by
    // capture_prologue() together with this launch's intragraph offset.
    TORCH_CHECK(graph_expects_this_gen_,
                "philox_cuda_state for an unexpected CUDA generator used during capture. "
                CAPTURE_DEFAULT_GENS_MSG);
    // see Note [Why enforce RNG offset % 4 == 0?]
    TORCH_INTERNAL_ASSERT(this->offset_intragraph_ % 4 == 0);
    uint32_t offset = this->offset_intragraph_;
    // Guard against wrapping the 32-bit intragraph offset.
    TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <=
                          std::numeric_limits<uint32_t>::max() - increment);
    this->offset_intragraph_ += increment;
    return PhiloxCudaState(this->seed_extragraph_,
                           this->offset_extragraph_,
                           offset);
  } else {
    // Non-capture path: a generator primed for capture must not be used here.
    TORCH_CHECK(!graph_expects_this_gen_,
                "CUDA generator expects graph capture to be underway, "
                "but the current stream is not capturing." );
    // see Note [Why enforce RNG offset % 4 == 0?]
    TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
    // Hand out the pre-increment offset; advance it for the next caller.
    uint64_t offset = this->philox_offset_per_thread_;
    this->philox_offset_per_thread_ += increment;
    return PhiloxCudaState(this->seed_, offset);
  }
}
299 | |
300 | /** |
301 | * Temporarily accommodates call sites that use philox_engine_inputs. |
302 | * Allows incremental refactor of call sites to use philox_cuda_state. |
303 | */ |
std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(uint64_t increment) {
  // Legacy entry point: not graph-safe, hence disallowed during capture.
  at::cuda::assertNotCapturing("Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. "
                               "Cannot call CUDAGeneratorImpl::philox_engine_inputs" );
  // rounds increment up to the nearest multiple of 4
  increment = ((increment + 3) / 4) * 4;
  // see Note [Why enforce RNG offset % 4 == 0?]
  TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
  // Return (seed, pre-increment offset); advance the offset for future callers.
  uint64_t offset = this->philox_offset_per_thread_;
  this->philox_offset_per_thread_ += increment;
  return std::make_pair(this->seed_, offset);
}
315 | |
316 | /* |
317 | * Gets the DeviceType of CUDAGeneratorImpl. |
318 | * Used for type checking during run time. |
319 | */ |
320 | DeviceType CUDAGeneratorImpl::device_type() { |
321 | return DeviceType::CUDA; |
322 | } |
323 | |
324 | /** |
325 | * Public clone method implementation |
326 | * |
327 | * See Note [Acquire lock when using random generators] |
328 | */ |
329 | std::shared_ptr<CUDAGeneratorImpl> CUDAGeneratorImpl::clone() const { |
330 | return std::shared_ptr<CUDAGeneratorImpl>(this->clone_impl()); |
331 | } |
332 | |
333 | /** |
334 | * Private clone method implementation |
335 | * |
336 | * See Note [Acquire lock when using random generators] |
337 | */ |
338 | CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const { |
339 | at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::clone_impl" ); |
340 | auto gen = new CUDAGeneratorImpl(this->device().index()); |
341 | gen->set_current_seed(this->seed_); |
342 | gen->set_philox_offset_per_thread(this->philox_offset_per_thread_); |
343 | return gen; |
344 | } |
345 | |
346 | } // namespace at |
347 | |