1#pragma once
2
3#include <c10/core/DeviceGuard.h>
4#include <c10/core/impl/DeviceGuardImplInterface.h>
5#include <c10/core/impl/GPUTrace.h>
6#include <c10/macros/Macros.h>
7#include <c10/util/Exception.h>
8
9#include <c10/cuda/CUDACachingAllocator.h>
10#include <c10/cuda/CUDAException.h>
11#include <c10/cuda/CUDAFunctions.h>
12#include <c10/cuda/CUDAStream.h>
13
14#include <cuda_runtime_api.h>
15
16namespace c10 {
17namespace cuda {
18namespace impl {
19
// CUDA backend for the generic c10 device-guard machinery.
// Registered (elsewhere) as the DeviceGuardImplInterface for DeviceType::CUDA,
// so generic DeviceGuard/StreamGuard/Event code can manage CUDA devices,
// streams, and events without compile-time CUDA dependencies.
//
// Conventions visible in this file:
//  - Methods that may be called from destructors (uncheckedSetDevice,
//    destroyEvent) are noexcept and report failures via C10_CUDA_CHECK_WARN
//    instead of throwing.
//  - Event-related entry points temporarily switch the current device to the
//    event/stream's device and restore it before returning.
//  - When a Python GPU trace hook is installed (GPUTrace), event
//    creation/deletion/record/wait are reported to it.
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  // The DeviceType argument exists to satisfy the generic guard template;
  // only DeviceType::CUDA is valid.
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
  // Sets the current device to d (if different) and returns the previous
  // current device. Throws on CUDA API failure.
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    Device old_device = getDevice();
    // Skip the runtime call when already on the target device.
    if (old_device.index() != d.index()) {
      C10_CUDA_CHECK(cudaSetDevice(d.index()));
    }
    return old_device;
  }
  // Returns the current CUDA device. Throws on CUDA API failure.
  Device getDevice() const override {
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  // Non-throwing variant of getDevice: warns and returns nullopt on failure.
  // NOTE(review): not marked `override` — presumably not part of the base
  // interface, unlike the other accessors here.
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    int device;
    // C10_CUDA_ERROR_HANDLED captures the error code instead of throwing so
    // we can inspect it below.
    const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return c10::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  // Sets the current device to d. Throws on CUDA API failure.
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    Device current_device = getDevice();
    // Avoid a redundant cudaSetDevice when already current.
    if (current_device != d) {
      C10_CUDA_CHECK(cudaSetDevice(d.index()));
    }
  }
  // Non-throwing setDevice for use on destructor paths; failures only warn.
  void uncheckedSetDevice(Device d) const noexcept override {
    auto current_device = uncheckedGetDevice();
    // If we couldn't query the current device, attempt the set anyway.
    if (!current_device.has_value() || current_device.value() != d) {
      C10_CUDA_CHECK_WARN(cudaSetDevice(d.index()));
    }
  }
  // Returns the current stream for device d (does not change the device).
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  // Hands out a stream from the global pool; isHighPriority selects the
  // high-priority pool.
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  // Makes s the current stream for its device and returns the previous
  // current stream for that device.
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions

  // Allocates a new CUDA event with flags derived from the generic EventFlag.
  // Throws on unknown flags or CUDA API failure. Notifies the GPU trace hook
  // (if any) of the creation.
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      // PyTorch's default is a timing-disabled event (cheaper to record).
      case EventFlag::PYTORCH_DEFAULT:
      case EventFlag::CUDA_EVENT_DISABLE_TIMING:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
      case EventFlag::CUDA_EVENT_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
  }

  // Destroys the event (no-op for null). noexcept: all failures only warn.
  // Temporarily switches to the event's device for the destroy call, then
  // restores the original device.
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    int orig_device;
    C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(cudaSetDevice(device_index));
    // Trace the deletion before the event handle becomes invalid.
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device));
  }

  // Records *event on the given stream, creating the event lazily on first
  // record (writing the new handle back through *event). device_index of -1
  // means "unspecified"; otherwise it must match the stream's device.
  // Throws on mismatch or CUDA API failure.
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  // Makes the given stream wait (device-side) until the event completes.
  // No-op for a null (never-recorded) event. Temporarily switches to the
  // stream's device. Throws on CUDA API failure.
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  // Returns true iff the event has completed (a null event counts as
  // complete). cudaErrorNotReady is expected and cleared so it does not
  // poison subsequent CUDA calls; any other error throws.
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }

  // Stream-related functions

  // Returns true iff all work submitted to the stream has completed.
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }

  // Blocks the calling host thread until the stream's work completes.
  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }

  // Informs the caching allocator that data_ptr is in use on this stream,
  // so its block is not handed out until the stream's pending work finishes.
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }
};
217
218} // namespace impl
219} // namespace cuda
220} // namespace c10
221