1// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly; instead,
// they should #include <ATen/cuda/CUDAGeneratorImpl.h>, which has a #pragma once.
4
5// Stores RNG state values. Passed as a kernel argument.
6// See Note [CUDA Graph-safe RNG states].
7//
8// The raw definition lives in its own file so jit codegen can easily copy it.
9namespace at {
10
// Philox RNG state handed to kernels as an argument.
// The seed and offset are stored in one of two forms, selected by captured_:
//   * captured_ == false: seed_.val / offset_.val hold the values directly.
//   * captured_ == true:  seed_.ptr / offset_.ptr hold int64_t pointers, and
//     offset_intragraph_ holds an extra offset supplied at construction.
// Only the union member written by the constructor that was used is active.
struct PhiloxCudaState {
  // Yields a non-captured state whose seed and offset are both 0:
  // seed_ and offset_ are value-initialized by their default member
  // initializers below, so a default-constructed object never carries
  // indeterminate bits into a kernel launch.
  PhiloxCudaState() = default;

  // Called if graph capture is not underway: seed and offset are stored by value.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Called if graph capture is underway: seed and base offset are stored as
  // pointers instead of values, plus an intra-graph offset.
  // NOTE(review): the pointers are presumably device-accessible and read at
  // kernel run time (see Note [CUDA Graph-safe RNG states]) — confirm.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  union Payload {
    uint64_t val;   // active when captured_ == false
    int64_t* ptr;   // active when captured_ == true
  };

  // `{}` value-initializes each union's first member (val) to 0. Default
  // member initializers do not affect trivial copyability, so the struct can
  // still be passed by value as a kernel argument.
  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};
42
43} // namespace at
44