1 | #pragma once |
2 | |
3 | #include <ATen/cuda/CUDAGeneratorImpl.h> |
4 | #include <ATen/cuda/CUDAEvent.h> |
5 | #include <ATen/cuda/detail/UnpackRaw.cuh> |
6 | #include <ATen/cuda/detail/CUDAHooks.h> |
7 | #include <ATen/detail/CUDAHooksInterface.h> |
8 | #include <c10/core/StreamGuard.h> |
9 | #include <c10/cuda/CUDAGraphsC10Utils.h> |
10 | #include <c10/cuda/CUDAGuard.h> |
11 | |
12 | // c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten. |
13 | // This file adds utils used by aten only. |
14 | |
15 | namespace at { |
16 | namespace cuda { |
17 | |
// Aliases that make the c10::cuda capture types available as
// at::cuda::CaptureId_t and at::cuda::CaptureStatus, so code in this
// namespace can name them without the c10::cuda:: qualifier.
18 | using CaptureId_t = c10::cuda::CaptureId_t; |
19 | using CaptureStatus = c10::cuda::CaptureStatus; |
20 | |
21 | // Use this version where you don't want to create a CUDA context if none exists. |
22 | inline CaptureStatus currentStreamCaptureStatus() { |
23 | #if !defined(USE_ROCM) |
24 | // don't create a context if we don't have to |
25 | if (at::cuda::detail::hasPrimaryContext(c10::cuda::current_device())) { |
26 | return c10::cuda::currentStreamCaptureStatusMayInitCtx(); |
27 | } else { |
28 | return CaptureStatus::None; |
29 | } |
30 | #else |
31 | return CaptureStatus::None; |
32 | #endif |
33 | } |
34 | |
35 | inline void assertNotCapturing(std::string attempt) { |
36 | auto status = currentStreamCaptureStatus(); |
37 | TORCH_CHECK(status == CaptureStatus::None, |
38 | attempt, |
39 | " during CUDA graph capture. If you need this call to be captured, " |
40 | "please file an issue. " |
41 | "Current cudaStreamCaptureStatus: " , |
42 | status); |
43 | } |
44 | |
45 | inline void errorIfCapturingCudnnBenchmark(std::string version_specific) { |
46 | auto status = currentStreamCaptureStatus(); |
47 | TORCH_CHECK(status == CaptureStatus::None, |
48 | "Current cudaStreamCaptureStatus: " , |
49 | status, |
50 | "\nCapturing " , |
51 | version_specific, |
52 | "is prohibited. Possible causes of this error:\n" |
53 | "1. No warmup iterations occurred before capture.\n" |
54 | "2. The convolutions you're trying to capture use dynamic shapes, " |
55 | "in which case capturing them is generally prohibited." ); |
56 | } |
57 | |
58 | } // namespace cuda |
59 | } // namespace at |
60 | |