#pragma once

#include <ATen/core/Generator.h>
#include <ATen/cuda/detail/PhiloxCudaStateRaw.cuh>
#include <ATen/Context.h>
#include <limits>
#include <atomic>

namespace at {
/**
 * Note [CUDA Graph-safe RNG states]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Strategy:
 * ~~~~~~~~~
 * (It helps to look at
 * cuda/detail/PhiloxCudaStateRaw.cuh and
 * cuda/detail/UnpackRaw.cuh
 * while you read this.)
 *
 * A CUDA graph containing multiple RNG ops behaves like a
 * single giant kernel from the perspective of ops external
 * to the graph. During graph capture, logic in CUDAGeneratorImpl
 * records the total of all offset increments that occur in the
 * graphed region, and records the final total as the offset for
 * the entire graph.
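 *
 * During capture, philox_cuda_state(increment) behaves roughly like
 * the following sketch (simplified; the capture check and member uses
 * are illustrative, not the verbatim implementation):
 *
 *   PhiloxCudaState philox_cuda_state(uint64_t increment) {
 *     if (capture_underway) {  // hypothetical flag for this sketch
 *       uint32_t offset = offset_intragraph_;
 *       offset_intragraph_ += increment;  // running total for the graph
 *       return PhiloxCudaState(seed_extragraph_,
 *                              offset_extragraph_,
 *                              offset);
 *     } else {
 *       uint64_t offset = philox_offset_per_thread_;
 *       philox_offset_per_thread_ += increment;
 *       return PhiloxCudaState(seed_, offset);
 *     }
 *   }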
 *
 * When the graph reruns, the logic that reruns it
 * increments this device's CUDA generator's offset
 * by that total.
 *
 * Meanwhile, within the graph, at capture time, instead of
 * populating PhiloxCudaStates with the uint64_t offset pulled
 * directly from the global state, PhiloxCudaState uses a pointer
 * to a one-element stream-local int64_t device tensor
 * holding an initial offset value, and a uint64_t holding an
 * intra-graph offset. (The intra-graph offset starts from zero
 * when capture begins.) In each consumer kernel,
 * at::cuda::philox::unpack computes the offset to use for this kernel
 * as intra-graph offset + *initial offset.
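 *
 * In other words, unpack (see cuda/detail/UnpackRaw.cuh) does roughly
 * the following (a sketch; member names here are illustrative and not
 * guaranteed to match the real struct):
 *
 *   __device__ std::tuple<uint64_t, uint64_t> unpack(PhiloxCudaState arg) {
 *     if (arg.captured_) {
 *       // Graph mode: read the initial offset from device memory and
 *       // add this kernel's intra-graph offset.
 *       return std::make_tuple(
 *           arg.seed_,
 *           static_cast<uint64_t>(*arg.offset_ptr_) + arg.offset_intragraph_);
 *     }
 *     // Eager mode: the offset was baked in on the host.
 *     return std::make_tuple(arg.seed_, arg.offset_val_);
 *   }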
 *
 * When the graph reruns, the logic that reruns it first
 * fill_s the initial offset tensor with this device's
 * CUDA generator's current offset.
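 *
 * In code, replay preparation amounts to something like the sketch
 * below (variable names are illustrative; see CUDAGraph for the real
 * logic):
 *
 *   // Refresh the initial seed and offset the captured kernels read:
 *   seed_extragraph.fill_(int64_t(gen->current_seed()));
 *   offset_extragraph.fill_(int64_t(gen->philox_offset_per_thread()));
 *   // Advance the generator past everything the graph will consume;
 *   // wholegraph_increment is the value capture_epilogue() returned.
 *   gen->set_philox_offset_per_thread(
 *       gen->philox_offset_per_thread() + wholegraph_increment);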
 *
 * The control flow above ensures graphed execution is bitwise
 * identical to eager execution as long as RNG ops are enqueued
 * from a single thread, even if RNG ops and graphs containing
 * RNG ops are enqueued and run simultaneously on multiple streams.
 *
 * Usage:
 * ~~~~~~
 * PhiloxCudaState in this file, and unpack() in
 * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
 * CUDAGeneratorImpl whether graph capture is underway or not.
 *
 * Each PhiloxCudaState instance should be used for one and only one
 * consumer kernel.
 *
 * Example (see e.g. native/cuda/Dropout.cu):
 *
 * #include <ATen/cuda/CUDAGeneratorImpl.h>
 * #include <ATen/cuda/CUDAGraphsUtils.cuh>
 *
 * __global__ void kernel(..., PhiloxCudaState philox_args) {
 *   auto seeds = at::cuda::philox::unpack(philox_args);
 *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 *   curandStatePhilox4_32_10_t state;
 *   curand_init(std::get<0>(seeds), // seed
 *               idx,                // per-thread subsequence
 *               std::get<1>(seeds), // offset in subsequence
 *               &state);
 *   ...
 * }
 * void host_caller(...) {
 *   PhiloxCudaState rng_engine_inputs;
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen->mutex_);
 *
 *     // gen could be HostState or DevState here! No divergent code needed!
 *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
 *   }
 *   kernel<<<...>>>(..., rng_engine_inputs);
 * }
 *
 */

struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  CUDAGeneratorImpl(DeviceIndex device_index = -1);
  ~CUDAGeneratorImpl() override = default;

  // CUDAGeneratorImpl methods
  std::shared_ptr<CUDAGeneratorImpl> clone() const;
  void set_current_seed(uint64_t seed) override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  void set_philox_offset_per_thread(uint64_t offset);
  uint64_t philox_offset_per_thread() const;
  // Called by CUDAGraph when capture begins: subsequent philox_cuda_state
  // calls hand out pointers to these graph-owned device-tensor slots
  // instead of baked-in values.
  void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
  // Called by CUDAGraph when capture ends: returns the total offset
  // increment accumulated over the captured region.
  uint64_t capture_epilogue();
  // Hands out seed/offset state for one consumer kernel, advancing the
  // internal offset by increment.
  PhiloxCudaState philox_cuda_state(uint64_t increment);

  // Returns true once after the flag is cleared (e.g. by a seed change),
  // so consumers such as cuDNN RNN dropout know their state is stale.
  bool reset_rnn_state() {
    return !no_reset_rnn_state_.test_and_set();
  }

  // Temporarily accommodates call sites that use philox_engine_inputs.
  // Allows incremental refactor of call sites to use philox_cuda_state.
  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);

  static DeviceType device_type();

private:
  CUDAGeneratorImpl* clone_impl() const override;
  uint64_t seed_ = default_rng_seed_val;
  uint64_t philox_offset_per_thread_ = 0;
  int64_t* seed_extragraph_{};
  int64_t* offset_extragraph_{};
  uint32_t offset_intragraph_ = 0;
  bool graph_expects_this_gen_ = false;
  std::atomic_flag no_reset_rnn_state_;
};

namespace cuda {
namespace detail {

TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
    DeviceIndex device_index = -1);
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace detail
} // namespace cuda
} // namespace at