1// Generated from "/code/pytorch/third_party/nvfuser/../../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh"
2// 2023-02-12 08:01:26
3
4namespace nvfuser_resources {
5
// PhiloxCudaStateRaw_cu: source text, held as a raw string literal, produced
// from the PhiloxCudaStateRaw.cuh file named in the "Generated from" line at
// the top of this file.
//
// The embedded text declares `at::PhiloxCudaState`:
//   - a defaulted default constructor;
//   - a constructor taking `seed` / `offset` by value, stored through the
//     `val` member of the `Payload` union;
//   - a constructor taking `seed` / `offset_extragraph` pointers plus a
//     32-bit `offset_intragraph`, stored through the `ptr` member, which
//     also sets `captured_` to true;
//   - four public data members: `seed_`, `offset_`, `offset_intragraph_`
//     and `captured_`.
//
// Everything between `R"(` and `)"` -- the comments included -- is string
// data, not code compiled as part of this translation unit. Editing any of it
// changes the value of the constant. Because the literal uses the default
// (empty) delimiter, the embedded text must never contain the sequence `)"`,
// which would terminate the literal early.
//
// The embedded text contains no #include directives, so whatever compiles it
// has to supply the fixed-width integer typedefs it uses (uint64_t, int64_t,
// uint32_t).
//
// NOTE(review): this file appears to be generated output; presumably changes
// belong in the originating .cuh followed by regeneration -- confirm against
// the build rules before hand-editing.
6constexpr const char* PhiloxCudaStateRaw_cu = R"(
7// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
8// Eager mode clients should not include this file directly, instead,
9// they should #include <ATen/cuda/CUDAGeneratorImpl.h>, which has a #pragma once.
10
11// Stores RNG state values. Passed as a kernel argument.
12// See Note [CUDA Graph-safe RNG states].
13//
14// The raw definition lives in its own file so jit codegen can easily copy it.
15namespace at {
16
17struct PhiloxCudaState {
18 PhiloxCudaState() = default;
19 // Called if graph capture is not underway
20 PhiloxCudaState(uint64_t seed,
21 uint64_t offset) {
22 seed_.val = seed;
23 offset_.val = offset;
24 }
25 // Called if graph capture is underway
26 PhiloxCudaState(int64_t* seed,
27 int64_t* offset_extragraph,
28 uint32_t offset_intragraph) {
29 seed_.ptr = seed;
30 offset_.ptr = offset_extragraph;
31 offset_intragraph_ = offset_intragraph;
32 captured_ = true;
33 }
34
35 // Public members, directly accessible by at::cuda::philox::unpack.
36 // If we made them private with getters/setters, the getters/setters
37 // would have to be __device__, and we can't declare __device__ in ATen.
38 union Payload {
39 uint64_t val;
40 int64_t* ptr;
41 };
42
43 Payload seed_;
44 Payload offset_;
45 uint32_t offset_intragraph_ = 0;
46 bool captured_ = false;
47};
48
49} // namespace at
50)";
51
52} // namespace nvfuser_resources
53