1#pragma once
2
3#include <c10/cuda/CUDAStream.h>
4#include <utility>
5
6// CUDA Graphs utils used by c10 and aten.
7// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
8
9namespace c10 {
10namespace cuda {
11
// Integer identifier type; MempoolId_t below is built from two of these.
typedef unsigned long long CaptureId_t;

// Memory-pool identifier, a pair of CaptureId_t values:
//  - `first` is set when the instance is created by CUDAGraph::capture_begin;
//  - `second` is set when the instance is created by
//    at::cuda::graph_pool_handle.
typedef std::pair<CaptureId_t, CaptureId_t> MempoolId_t;
17
// RAII guard for "cudaStreamCaptureMode", a thread-local value that controls
// the error-checking strictness of a stream capture. The constructor installs
// the desired mode and remembers the previous one; the destructor puts the
// previous mode back.
//
// TODO: ideally the condition below would be something like
// !defined(TORCH_HIP_VERSION), since support for CUDA <= 10 was dropped. The
// CUDA_VERSION check is really only a workaround for TORCH_HIP_VERSION not
// being a sufficient guard to prevent ROCm build breakage.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  // Intentionally not `explicit`, to stay source-compatible with callers.
  CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
      : strictness_(desired) {
    // cudaThreadExchangeStreamCaptureMode swaps: it installs *(&strictness_)
    // for this thread and writes the previous mode back into strictness_,
    // which the destructor later uses to restore it.
    C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

  // The guard owns exactly one "restore the previous mode" action.
  // Copy/move assignment would overwrite the saved mode (restoring the wrong
  // one), and a copy would perform an extra restore at an unrelated time.
  CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
  CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
  CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
      delete;
  CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;

  ~CUDAStreamCaptureModeGuard() {
    // Swap the saved previous mode back in. Uses the _WARN check variant,
    // presumably so a failure here is reported without throwing from a
    // destructor.
    C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

 private:
  // Holds the desired mode before the constructor's exchange, and the mode
  // to restore after it.
  cudaStreamCaptureMode strictness_;
};
#endif
39
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// CaptureStatus is defined in terms of the integer values of the CUDA enum
// cudaStreamCaptureStatus; fail the build if those values ever change.
// Messages are supplied because some compilers reject message-less
// static_assert.
static_assert(
    static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) ==
        0,
    "unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(
    static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) ==
        1,
    "unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(
    static_cast<int>(
        cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
    "unexpected int(cudaStreamCaptureStatusInvalidated) value");
#endif
53
// c10-side mirror of cudaStreamCaptureStatus, with matching integer values.
// Builds without CUDA >= 11 only get the `None` enumerator.
enum class CaptureStatus : int {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  None = static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  Active =
      static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
  Invalidated = static_cast<int>(
      cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated),
#else
  None = 0,
#endif
};
63
64inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
65 switch (status) {
66 case CaptureStatus::None:
67 os << "cudaStreamCaptureStatusNone";
68 break;
69#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
70 case CaptureStatus::Active:
71 os << "cudaStreamCaptureStatusActive";
72 break;
73 case CaptureStatus::Invalidated:
74 os << "cudaStreamCaptureStatusInvalidated";
75 break;
76#endif
77 default:
78 TORCH_INTERNAL_ASSERT(
79 false, "Unknown CUDA graph CaptureStatus", int(status));
80 }
81 return os;
82}
83
84// Use this version where you're sure a CUDA context exists already.
85inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
86#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
87 cudaStreamCaptureStatus is_capturing;
88 C10_CUDA_CHECK(
89 cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
90 return CaptureStatus(is_capturing);
91#else
92 return CaptureStatus::None;
93#endif
94}
95
96} // namespace cuda
97} // namespace c10
98