1 | #pragma once |
2 | |
#include <c10/cuda/CUDAStream.h>
#include <ostream>
#include <utility>
5 | |
6 | // CUDA Graphs utils used by c10 and aten. |
7 | // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only. |
8 | |
9 | namespace c10 { |
10 | namespace cuda { |
11 | |
// Numeric id type (an unsigned long long) used for captures and as each half
// of MempoolId_t.
typedef unsigned long long CaptureId_t;

// Memory-pool id, stored as a pair of CaptureId_t values:
//   .first  is set if the instance is created by CUDAGraph::capture_begin.
//   .second is set if the instance is created by at::cuda::graph_pool_handle.
typedef std::pair<CaptureId_t, CaptureId_t> MempoolId_t;
17 | |
// RAII guard for "cudaStreamCaptureMode", a thread-local value
// that controls the error-checking strictness of a capture.
//
// The constructor exchanges the calling thread's capture mode for `desired`
// via cudaThreadExchangeStreamCaptureMode, which writes the previously active
// mode back into strictness_. The destructor performs the same exchange again,
// putting that previous mode back. Entry errors go through C10_CUDA_CHECK;
// exit errors go through C10_CUDA_CHECK_WARN (presumably so that a failure
// while restoring is reported without throwing from a destructor).

// TODO: ideally we'd replace this with something like
// !defined(TORCH_HIP_VERSION) as CUDA <= 10 support was dropped and really
// this is only a workaround for TORCH_HIP_VERSION not being a sufficient guard
// to prevent ROCM build breakage.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
      : strictness_(desired) {
    // After this call strictness_ holds the mode that was active before the
    // guard was constructed; the destructor swaps it back in.
    C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

  ~CUDAStreamCaptureModeGuard() {
    C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

  // A copy (or moved-from object) would carry the same saved mode and restore
  // it when *it* is destroyed, i.e. before the original guard's scope ends.
  // Each guard must restore exactly once, so copying and moving are disabled.
  CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
  CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
      delete;
  CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
  CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;

 private:
  // Holds `desired` before the constructor's exchange and the previous
  // thread capture mode afterwards.
  cudaStreamCaptureMode strictness_;
};
#endif
39 | |
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// CaptureStatus mirrors the integer values of cudaStreamCaptureStatus, so pin
// those values at compile time in case the CUDA enum is ever renumbered.
// Every assertion carries an explicit message; some compilers reportedly
// reject a message-less static_assert.
static_assert(
    static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
    "unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(
    static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) ==
        1,
    "unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(
    static_cast<int>(
        cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
    "unexpected int(cudaStreamCaptureStatusInvalidated) value");
#endif
53 | |
// Scoped mirror of cudaStreamCaptureStatus. Each enumerator takes the integer
// value of the CUDA enumerator of the same name, so a raw
// cudaStreamCaptureStatus can be cast directly to CaptureStatus.
// Without CUDA_VERSION >= 11000 only None (= 0) exists.
enum class CaptureStatus : int {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  None = static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  Active =
      static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
  Invalidated = static_cast<int>(
      cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated),
#else
  None = 0,
#endif
};
63 | |
64 | inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { |
65 | switch (status) { |
66 | case CaptureStatus::None: |
67 | os << "cudaStreamCaptureStatusNone" ; |
68 | break; |
69 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
70 | case CaptureStatus::Active: |
71 | os << "cudaStreamCaptureStatusActive" ; |
72 | break; |
73 | case CaptureStatus::Invalidated: |
74 | os << "cudaStreamCaptureStatusInvalidated" ; |
75 | break; |
76 | #endif |
77 | default: |
78 | TORCH_INTERNAL_ASSERT( |
79 | false, "Unknown CUDA graph CaptureStatus" , int(status)); |
80 | } |
81 | return os; |
82 | } |
83 | |
84 | // Use this version where you're sure a CUDA context exists already. |
85 | inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { |
86 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
87 | cudaStreamCaptureStatus is_capturing; |
88 | C10_CUDA_CHECK( |
89 | cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); |
90 | return CaptureStatus(is_capturing); |
91 | #else |
92 | return CaptureStatus::None; |
93 | #endif |
94 | } |
95 | |
96 | } // namespace cuda |
97 | } // namespace c10 |
98 | |