1 | #pragma once |
2 | |
3 | #include <ATen/core/Generator.h> |
4 | #include <ATen/cuda/detail/PhiloxCudaStateRaw.cuh> |
5 | #include <ATen/Context.h> |
6 | #include <limits> |
7 | #include <atomic> |
8 | |
9 | namespace at { |
10 | /** |
11 | * Note [CUDA Graph-safe RNG states] |
12 | * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
13 | * |
14 | * Strategy: |
15 | * ~~~~~~~~~ |
16 | * (It helps to look at |
17 | * cuda/detail/PhiloxCudaStateRaw.cuh and |
18 | * cuda/detail/UnpackRaw.cuh |
19 | * while you read this.) |
20 | * |
21 | * A CUDA graph containing multiple RNG ops behaves like a |
22 | * single giant kernel from the perspective of ops external |
23 | * to the graph. During graph capture, logic in CUDAGeneratorImpl |
24 | * records the total of all offset increments that occur in the |
25 | * graphed region, and records the final total as the offset for |
26 | * the entire graph. |
27 | * |
28 | * When the graph reruns, the logic that reruns it |
29 | * increments this device's CUDA generator's offset |
30 | * by that total. |
31 | * |
32 | * Meanwhile, within the graph, at capture time, instead of |
33 | * populating PhiloxCudaStates with the uint64_t offset pulled |
34 | * directly from the global state, PhiloxCudaState uses a pointer |
35 | * to a one-element stream-local int64_t device tensor |
36 | * holding an initial offset value, and a uint64_t holding an |
37 | * intra-graph offset. (The intra-graph offset starts from zero |
38 | * when capture begins.) In each consumer kernel, |
39 | * at::cuda::philox::unpack computes the offset to use for this kernel |
40 | * as intra-graph offset + *initial offset. |
41 | * |
42 | * When the graph reruns, the logic that reruns it first |
43 | * fill_s the initial offset tensor with this device's |
44 | * CUDA generator's current offset. |
45 | * |
46 | * The control flow above ensures graphed execution is bitwise |
47 | * identical to eager execution as long as RNG ops are enqueued |
48 | * from a single thread, even if RNG ops and graphs containing |
49 | * RNG ops are enqueued and run simultaneously on multiple streams. |
50 | * |
51 | * Usage: |
52 | * ~~~~~~ |
53 | * PhiloxCudaState in this file, and unpack() in |
54 | * cuda/CUDAGraphsUtils.cuh allow non-divergent use of |
55 | * CUDAGeneratorImpl whether graph capture is underway or not. |
56 | * |
57 | * Each PhiloxCudaState instance should be used for one and only one |
58 | * consumer kernel. |
59 | * |
60 | * Example (see e.g. native/cuda/Dropout.cu): |
61 | * |
62 | * #include <ATen/cuda/CUDAGeneratorImpl.h> |
63 | * #include <ATen/cuda/CUDAGraphsUtils.cuh> |
64 | * |
65 | * __global__ void kernel(..., PhiloxCudaState philox_args) { |
66 | * auto seeds = at::cuda::philox::unpack(philox_args); |
67 | * IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; |
68 | * curandStatePhilox4_32_10_t state; |
69 | * curand_init(std::get<0>(seeds), // seed |
70 | * idx, // per-thread subsequence |
71 | * std::get<1>(seeds), // offset in subsequence |
72 | * &state); |
73 | * ... |
74 | * } |
75 | * |
76 | * host_caller(...) { |
77 | * PhiloxCudaState rng_engine_inputs; |
78 | * { |
79 | * // See Note [Acquire lock when using random generators] |
80 | * std::lock_guard<std::mutex> lock(gen->mutex_); |
81 | * |
82 | * // gen could be HostState or DevState here! No divergent code needed! |
83 | * rng_engine_inputs = gen->philox_cuda_state(offset_increment); |
84 | * } |
85 | * kernel<<<...>>>(..., rng_engine_inputs); |
86 | * } |
87 | * |
88 | */ |
89 | |
90 | struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { |
91 | // Constructors |
92 | CUDAGeneratorImpl(DeviceIndex device_index = -1); |
93 | ~CUDAGeneratorImpl() override = default; |
94 | |
95 | // CUDAGeneratorImpl methods |
96 | std::shared_ptr<CUDAGeneratorImpl> clone() const; |
97 | void set_current_seed(uint64_t seed) override; |
98 | uint64_t current_seed() const override; |
99 | uint64_t seed() override; |
100 | void set_state(const c10::TensorImpl& new_state) override; |
101 | c10::intrusive_ptr<c10::TensorImpl> get_state() const override; |
102 | void set_philox_offset_per_thread(uint64_t offset); |
103 | uint64_t philox_offset_per_thread() const; |
104 | void capture_prologue(int64_t* , int64_t* ); |
105 | uint64_t capture_epilogue(); |
106 | PhiloxCudaState philox_cuda_state(uint64_t increment); |
107 | |
108 | bool reset_rnn_state() { |
109 | return !no_reset_rnn_state_.test_and_set(); |
110 | } |
111 | |
112 | // Temporarily accommodates call sites that use philox_engine_inputs. |
113 | // Allows incremental refactor of call sites to use philox_cuda_state. |
114 | std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment); |
115 | |
116 | static DeviceType device_type(); |
117 | |
118 | private: |
119 | CUDAGeneratorImpl* clone_impl() const override; |
120 | uint64_t seed_ = default_rng_seed_val; |
121 | uint64_t philox_offset_per_thread_ = 0; |
122 | int64_t* {}; |
123 | int64_t* {}; |
124 | uint32_t offset_intragraph_ = 0; |
125 | bool graph_expects_this_gen_ = false; |
126 | std::atomic_flag no_reset_rnn_state_; |
127 | }; |
128 | |
namespace cuda {
namespace detail {

// Returns the default CUDA Generator for `device_index`.
// NOTE(review): -1 presumably means "current device", matching the
// CUDAGeneratorImpl constructor default above — confirm against the
// implementation.
TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
    DeviceIndex device_index = -1);
// Creates a new CUDA Generator for `device_index` (same -1 convention).
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace detail
} // namespace cuda
138 | } // namespace at |
139 | |