1 | #pragma once |
2 | |
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>

#include <cstdint>
#include <string>
7 | |
8 | namespace at { |
9 | |
10 | struct CUDAGeneratorImpl; |
11 | |
12 | namespace cuda { |
13 | |
// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin. Several captures passed the same handle will
// share one private mempool (see the comment on CUDAGraph::mempool_id_).
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
17 | |
18 | struct TORCH_CUDA_CPP_API CUDAGraph { |
19 | CUDAGraph(); |
20 | ~CUDAGraph(); |
21 | |
22 | void capture_begin(MempoolId_t pool={0, 0}); |
23 | void capture_end(); |
24 | void replay(); |
25 | void reset(); |
26 | MempoolId_t pool(); |
27 | void enable_debug_mode(); |
28 | void debug_dump(const std::string& debug_path); |
29 | |
30 | protected: |
31 | #if !defined(USE_ROCM) |
32 | cudaGraph_t graph_ = NULL; |
33 | cudaGraphExec_t graph_exec_ = NULL; |
34 | #endif |
35 | |
36 | // internal states so reset() can do its best cleaning up |
37 | // Set to true in capture_end if cudaStreamEndCapture succeeded |
38 | // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate |
39 | // to create graph_exec_, then graph_ is deleted |
40 | bool has_graph_ = false; |
41 | // Set to true in capture_end if cudaGraphInstantiate succeeded |
42 | bool has_graph_exec_ = false; |
43 | |
44 | // uuid of this instance's current capture, retrieved from Cuda |
45 | CaptureId_t id_; |
46 | |
47 | // uuid used to request a particular private mempool from CUDACachingAllocator. |
48 | // By default, this will be set to {id_, 0}. |
49 | // |
50 | // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_ |
51 | // will be set to the other graph's mempool_id_, and therefore share a mempool with the |
52 | // other graph. |
53 | // |
54 | // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(), |
55 | // it will share a mempool with any other captures that used "pool=handle". |
56 | // |
57 | // Sharing a mempool across graphs saves memory, and it's safe if you |
58 | // know you'll replay those graphs in the same order you captured them. |
59 | MempoolId_t mempool_id_; |
60 | |
61 | // Stream on which capture began |
62 | at::cuda::CUDAStream capture_stream_; |
63 | |
64 | // Default generator on device where capture began |
65 | at::CUDAGeneratorImpl* capture_gen_; |
66 | |
67 | // Device where capture occurred. Right now, for simplicity, we require all ops |
68 | // in a capture to run on the same device, but this is a limitation of CUDAGraph, |
69 | // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device |
70 | // captures if needed. |
71 | int capture_dev_; |
72 | |
73 | // RNG state trackers |
74 | at::Tensor ; |
75 | at::Tensor ; |
76 | uint64_t wholegraph_increment_; |
77 | }; |
78 | |
79 | } // namespace cuda |
80 | } // namespace at |
81 | |