1 | // No "#pragma once" because this is a raw definition that can be copied by jit codegen. |
2 | // Eager mode clients should not include this file directly, instead, |
3 | // they should #include <ATen/cuda/CUDAGeneratorImpl.h>, which has a #pragma once. |
4 | |
5 | // Stores RNG state values. Passed as a kernel argument. |
6 | // See Note [CUDA Graph-safe RNG states]. |
7 | // |
8 | // The raw definition lives in its own file so jit codegen can easily copy it. |
namespace at {

// Plain value type bundling a Philox seed and offset.
//
// The same storage is interpreted in one of two ways, selected by `captured_`:
//   * captured_ == false: `seed_.val` / `offset_.val` hold the values
//     themselves.
//   * captured_ == true:  `seed_.ptr` / `offset_.ptr` hold pointers to the
//     values, and `offset_intragraph_` carries an additional 32-bit offset.
// Readers must check `captured_` before touching either union member.
struct PhiloxCudaState {
  // Leaves the state in non-captured mode with both payloads zeroed
  // (see the default member initializers below).
  PhiloxCudaState() = default;

  // Called if graph capture is not underway.
  // Stores `seed` and `offset` by value; `captured_` stays false.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Called if graph capture is underway.
  // Stores the two pointers (they are not dereferenced here) together with
  // `offset_intragraph`, and marks the state as captured.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  union Payload {
    uint64_t val;  // active when captured_ == false
    int64_t* ptr;  // active when captured_ == true
  };

  // `{}` zero-initializes the first union member (`val`), so a
  // default-constructed state never exposes indeterminate bytes.
  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};

} // namespace at
44 | |