1 | // Generated from "/code/pytorch/third_party/nvfuser/../../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh" |
2 | // 2023-02-12 08:01:26 |
3 | |
namespace nvfuser_resources {

// Source text of the `at::PhiloxCudaState` struct, stored as a raw string
// literal.
//
// This file appears to be generated from ATen's PhiloxCudaStateRaw.cuh (see
// the "Generated from" header at the top of the file). Changes presumably
// belong in that source header, and this file should be regenerated rather
// than hand-edited. NOTE(review): confirm this with the generator script.
//
// Everything between `R"(` and `)"` is data, not compiled code here: the
// embedded comments and declarations are part of the string's value. The
// embedded comment says jit codegen copies this text, so changing any byte
// inside the literal changes the source that codegen receives.
//
// Observations about the embedded struct (documented only, not modified):
//  - `captured_` records which `Payload` union member was written.
//    - The (uint64_t, uint64_t) constructor writes `seed_.val` / `offset_.val`
//      and leaves `captured_` at its default of false.
//    - The (int64_t*, int64_t*, uint32_t) constructor writes `seed_.ptr` /
//      `offset_.ptr`, stores `offset_intragraph_`, and sets `captured_` to
//      true.
//  - The defaulted constructor gives `seed_` and `offset_` no initializer.
//    NOTE(review): a default-initialized object therefore leaves them
//    uninitialized (value-initialization would zero them). Any fix belongs in
//    the source header, not here.
//  - The text uses uint64_t / int64_t / uint32_t without any #include.
//    Presumably the jit prelude that receives it defines them — TODO confirm.
constexpr const char* PhiloxCudaStateRaw_cu = R"(
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly, instead,
// they should #include <ATen/cuda/CUDAGeneratorImpl.h>, which has a #pragma once.

// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
namespace at {

struct PhiloxCudaState {
PhiloxCudaState() = default;
// Called if graph capture is not underway
PhiloxCudaState(uint64_t seed,
uint64_t offset) {
seed_.val = seed;
offset_.val = offset;
}
// Called if graph capture is underway
PhiloxCudaState(int64_t* seed,
int64_t* offset_extragraph,
uint32_t offset_intragraph) {
seed_.ptr = seed;
offset_.ptr = offset_extragraph;
offset_intragraph_ = offset_intragraph;
captured_ = true;
}

// Public members, directly accessible by at::cuda::philox::unpack.
// If we made them private with getters/setters, the getters/setters
// would have to be __device__, and we can't declare __device__ in ATen.
union Payload {
uint64_t val;
int64_t* ptr;
};

Payload seed_;
Payload offset_;
uint32_t offset_intragraph_ = 0;
bool captured_ = false;
};

} // namespace at
)" ;

} // namespace nvfuser_resources
53 | |