1#pragma once
2
3#include <ATen/cuda/CUDAGeneratorImpl.h>
4#include <ATen/cuda/CUDAEvent.h>
5#include <ATen/cuda/detail/UnpackRaw.cuh>
6#include <ATen/cuda/detail/CUDAHooks.h>
7#include <ATen/detail/CUDAHooksInterface.h>
8#include <c10/core/StreamGuard.h>
9#include <c10/cuda/CUDAGraphsC10Utils.h>
10#include <c10/cuda/CUDAGuard.h>
11
// Utilities shared by c10 and ATen live in c10/cuda/CUDAGraphsC10Utils.h.
// This header adds CUDA-graph helpers that only ATen uses.
14
15namespace at {
16namespace cuda {
17
18using CaptureId_t = c10::cuda::CaptureId_t;
19using CaptureStatus = c10::cuda::CaptureStatus;
20
21// Use this version where you don't want to create a CUDA context if none exists.
22inline CaptureStatus currentStreamCaptureStatus() {
23#if !defined(USE_ROCM)
24 // don't create a context if we don't have to
25 if (at::cuda::detail::hasPrimaryContext(c10::cuda::current_device())) {
26 return c10::cuda::currentStreamCaptureStatusMayInitCtx();
27 } else {
28 return CaptureStatus::None;
29 }
30#else
31 return CaptureStatus::None;
32#endif
33}
34
35inline void assertNotCapturing(std::string attempt) {
36 auto status = currentStreamCaptureStatus();
37 TORCH_CHECK(status == CaptureStatus::None,
38 attempt,
39 " during CUDA graph capture. If you need this call to be captured, "
40 "please file an issue. "
41 "Current cudaStreamCaptureStatus: ",
42 status);
43}
44
45inline void errorIfCapturingCudnnBenchmark(std::string version_specific) {
46 auto status = currentStreamCaptureStatus();
47 TORCH_CHECK(status == CaptureStatus::None,
48 "Current cudaStreamCaptureStatus: ",
49 status,
50 "\nCapturing ",
51 version_specific,
52 "is prohibited. Possible causes of this error:\n"
53 "1. No warmup iterations occurred before capture.\n"
54 "2. The convolutions you're trying to capture use dynamic shapes, "
55 "in which case capturing them is generally prohibited.");
56}
57
58} // namespace cuda
59} // namespace at
60