#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>

#include <string>

namespace at {

struct CUDAGeneratorImpl;

namespace cuda {

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin.
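//
// For example (a minimal sketch, not part of this header's API surface;
// graph_a and graph_b stand for hypothetical CUDAGraph instances, and each
// capture is assumed to run on a non-default stream):
//
//   auto handle = at::cuda::graph_pool_handle();
//   graph_a.capture_begin(handle);
//   // ...launch graph_a's work on the capturing stream...
//   graph_a.capture_end();
//   graph_b.capture_begin(handle);  // shares a mempool with graph_a
//   // ...launch graph_b's work...
//   graph_b.capture_end();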
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();

struct TORCH_CUDA_CPP_API CUDAGraph {
  CUDAGraph();
  ~CUDAGraph();

  void capture_begin(MempoolId_t pool={0, 0});
  void capture_end();
  void replay();
  void reset();
  MempoolId_t pool();
  void enable_debug_mode();
  void debug_dump(const std::string& debug_path);
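
  // Typical single-graph lifecycle (a sketch, assuming the caller has switched
  // the current CUDA stream to a non-default stream before capturing; the
  // elided work stands for whatever kernels should be captured):
  //
  //   at::cuda::CUDAGraph graph;
  //   graph.capture_begin();
  //   // ...enqueue the work to be captured on the capturing stream...
  //   graph.capture_end();
  //   graph.replay();  // re-launches the captured work
  //   graph.reset();   // optionally tear down the capture when done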

 protected:
#if !defined(USE_ROCM)
  cudaGraph_t graph_ = nullptr;
  cudaGraphExec_t graph_exec_ = nullptr;
#endif

  // Internal states so reset() can do its best cleanup.
  // has_graph_ is set to true in capture_end if cudaStreamEndCapture succeeded,
  // and set back to false soon after, once graph_ has been consumed by
  // cudaGraphInstantiate to create graph_exec_ and graph_ has been deleted.
  bool has_graph_ = false;
  // Set to true in capture_end if cudaGraphInstantiate succeeded.
  bool has_graph_exec_ = false;

  // uuid of this instance's current capture, retrieved from CUDA
  CaptureId_t id_;

  // uuid used to request a particular private mempool from CUDACachingAllocator.
  // By default, this will be set to {id_, 0}.
  //
  // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
  // will be set to the other graph's mempool_id_, and therefore share a mempool with the
  // other graph.
  //
  // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
  // it will share a mempool with any other captures that used "pool=handle".
  //
  // Sharing a mempool across graphs saves memory, and it's safe if you
  // know you'll replay those graphs in the same order you captured them.
  MempoolId_t mempool_id_;
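
  // For illustration (a sketch; other_graph is a hypothetical, previously
  // captured CUDAGraph whose mempool this capture joins):
  //
  //   this_graph.capture_begin(other_graph.pool());
  //   // ...capture this_graph's work...
  //   this_graph.capture_end();
  //   other_graph.replay();  // safe: replayed in capture order
  //   this_graph.replay();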

  // Stream on which capture began
  at::cuda::CUDAStream capture_stream_;

  // Default generator on device where capture began
  at::CUDAGeneratorImpl* capture_gen_;

  // Device where capture occurred. Right now, for simplicity, we require all ops
  // in a capture to run on the same device, but this is a limitation of this
  // CUDAGraph class, not of CUDA itself. We can straightforwardly modify CUDAGraph
  // to support multi-device captures if needed.
  int capture_dev_;

  // RNG state trackers
  at::Tensor seed_extragraph_;
  at::Tensor offset_extragraph_;
  uint64_t wholegraph_increment_;
};

} // namespace cuda
} // namespace at