1 | #pragma once |
2 | |
3 | #include <c10/core/DeviceGuard.h> |
4 | #include <c10/core/impl/DeviceGuardImplInterface.h> |
5 | #include <c10/core/impl/GPUTrace.h> |
6 | #include <c10/macros/Macros.h> |
7 | #include <c10/util/Exception.h> |
8 | |
9 | #include <c10/cuda/CUDACachingAllocator.h> |
10 | #include <c10/cuda/CUDAException.h> |
11 | #include <c10/cuda/CUDAFunctions.h> |
12 | #include <c10/cuda/CUDAStream.h> |
13 | |
14 | #include <cuda_runtime_api.h> |
15 | |
16 | namespace c10 { |
17 | namespace cuda { |
18 | namespace impl { |
19 | |
20 | struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { |
21 | static constexpr DeviceType static_type = DeviceType::CUDA; |
22 | |
23 | CUDAGuardImpl() = default; |
24 | explicit CUDAGuardImpl(DeviceType t) { |
25 | TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA); |
26 | } |
27 | DeviceType type() const override { |
28 | return DeviceType::CUDA; |
29 | } |
30 | Device exchangeDevice(Device d) const override { |
31 | TORCH_INTERNAL_ASSERT(d.is_cuda()); |
32 | Device old_device = getDevice(); |
33 | if (old_device.index() != d.index()) { |
34 | C10_CUDA_CHECK(cudaSetDevice(d.index())); |
35 | } |
36 | return old_device; |
37 | } |
38 | Device getDevice() const override { |
39 | int device; |
40 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
41 | return Device(DeviceType::CUDA, device); |
42 | } |
43 | c10::optional<Device> uncheckedGetDevice() const noexcept { |
44 | int device; |
45 | const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device)); |
46 | C10_CUDA_CHECK_WARN(err); |
47 | if (err != cudaSuccess) { |
48 | return c10::nullopt; |
49 | } |
50 | return Device(DeviceType::CUDA, device); |
51 | } |
52 | void setDevice(Device d) const override { |
53 | TORCH_INTERNAL_ASSERT(d.is_cuda()); |
54 | Device current_device = getDevice(); |
55 | if (current_device != d) { |
56 | C10_CUDA_CHECK(cudaSetDevice(d.index())); |
57 | } |
58 | } |
59 | void uncheckedSetDevice(Device d) const noexcept override { |
60 | auto current_device = uncheckedGetDevice(); |
61 | if (!current_device.has_value() || current_device.value() != d) { |
62 | C10_CUDA_CHECK_WARN(cudaSetDevice(d.index())); |
63 | } |
64 | } |
65 | Stream getStream(Device d) const noexcept override { |
66 | return getCurrentCUDAStream(d.index()).unwrap(); |
67 | } |
68 | Stream getDefaultStream(Device d) const override { |
69 | return getDefaultCUDAStream(d.index()); |
70 | } |
71 | Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) |
72 | const override { |
73 | return getStreamFromPool(isHighPriority, d.index()); |
74 | } |
75 | // NB: These do NOT set the current device |
76 | Stream exchangeStream(Stream s) const noexcept override { |
77 | CUDAStream cs(s); |
78 | auto old_stream = getCurrentCUDAStream(s.device().index()); |
79 | setCurrentCUDAStream(cs); |
80 | return old_stream.unwrap(); |
81 | } |
82 | DeviceIndex deviceCount() const noexcept override { |
83 | return device_count(); |
84 | } |
85 | |
86 | // Event-related functions |
87 | void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const { |
88 | // Maps PyTorch's Event::Flag to CUDA flag |
89 | auto cuda_flag = cudaEventDefault; |
90 | switch (flag) { |
91 | case EventFlag::PYTORCH_DEFAULT: |
92 | case EventFlag::CUDA_EVENT_DISABLE_TIMING: |
93 | cuda_flag = cudaEventDisableTiming; |
94 | break; |
95 | case EventFlag::BACKEND_DEFAULT: |
96 | case EventFlag::CUDA_EVENT_DEFAULT: |
97 | cuda_flag = cudaEventDefault; |
98 | break; |
99 | default: |
100 | TORCH_CHECK(false, "CUDA event received unknown flag" ); |
101 | } |
102 | |
103 | C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); |
104 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
105 | if (C10_UNLIKELY(interp)) { |
106 | (*interp)->trace_gpu_event_creation( |
107 | reinterpret_cast<uintptr_t>(cuda_event)); |
108 | } |
109 | } |
110 | |
111 | void destroyEvent(void* event, const DeviceIndex device_index) |
112 | const noexcept override { |
113 | if (!event) |
114 | return; |
115 | auto cuda_event = static_cast<cudaEvent_t>(event); |
116 | int orig_device; |
117 | C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device)); |
118 | C10_CUDA_CHECK_WARN(cudaSetDevice(device_index)); |
119 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
120 | if (C10_UNLIKELY(interp)) { |
121 | (*interp)->trace_gpu_event_deletion( |
122 | reinterpret_cast<uintptr_t>(cuda_event)); |
123 | } |
124 | C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); |
125 | C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device)); |
126 | } |
127 | |
128 | void record( |
129 | void** event, |
130 | const Stream& stream, |
131 | const DeviceIndex device_index, |
132 | const EventFlag flag) const override { |
133 | TORCH_CHECK( |
134 | device_index == -1 || device_index == stream.device_index(), |
135 | "Event device index " , |
136 | device_index, |
137 | " does not match recording stream's device index " , |
138 | stream.device_index(), |
139 | "." ); |
140 | |
141 | cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event); |
142 | CUDAStream cuda_stream{stream}; |
143 | |
144 | // Moves to stream's device to record |
145 | const auto orig_device = getDevice(); |
146 | setDevice(stream.device()); |
147 | |
148 | // Creates the event (lazily) |
149 | if (!cuda_event) |
150 | createEvent(&cuda_event, flag); |
151 | C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); |
152 | // Makes the void* point to the (possibly just allocated) CUDA event |
153 | *event = cuda_event; |
154 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
155 | if (C10_UNLIKELY(interp)) { |
156 | (*interp)->trace_gpu_event_record( |
157 | reinterpret_cast<uintptr_t>(cuda_event), |
158 | reinterpret_cast<uintptr_t>(cuda_stream.stream())); |
159 | } |
160 | |
161 | // Resets device |
162 | setDevice(orig_device); |
163 | } |
164 | |
165 | void block(void* event, const Stream& stream) const override { |
166 | if (!event) |
167 | return; |
168 | cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); |
169 | CUDAStream cuda_stream{stream}; |
170 | const auto orig_device = getDevice(); |
171 | setDevice(stream.device()); |
172 | C10_CUDA_CHECK(cudaStreamWaitEvent( |
173 | cuda_stream, |
174 | cuda_event, |
175 | /*flags (must be zero)=*/0)); |
176 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
177 | if (C10_UNLIKELY(interp)) { |
178 | (*interp)->trace_gpu_event_wait( |
179 | reinterpret_cast<uintptr_t>(cuda_event), |
180 | reinterpret_cast<uintptr_t>(cuda_stream.stream())); |
181 | } |
182 | setDevice(orig_device); |
183 | } |
184 | |
185 | // May be called from any device |
186 | bool queryEvent(void* event) const override { |
187 | if (!event) |
188 | return true; |
189 | cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); |
190 | const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event)); |
191 | if (err != cudaErrorNotReady) { |
192 | C10_CUDA_CHECK(err); |
193 | } else { |
194 | // ignore and clear the error if not ready |
195 | (void)cudaGetLastError(); |
196 | } |
197 | return (err == cudaSuccess); |
198 | } |
199 | |
200 | // Stream-related functions |
201 | bool queryStream(const Stream& stream) const override { |
202 | CUDAStream cuda_stream{stream}; |
203 | return cuda_stream.query(); |
204 | } |
205 | |
206 | void synchronizeStream(const Stream& stream) const override { |
207 | CUDAStream cuda_stream{stream}; |
208 | cuda_stream.synchronize(); |
209 | } |
210 | |
211 | void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) |
212 | const override { |
213 | CUDAStream cuda_stream{stream}; |
214 | CUDACachingAllocator::recordStream(data_ptr, cuda_stream); |
215 | } |
216 | }; |
217 | |
218 | } // namespace impl |
219 | } // namespace cuda |
220 | } // namespace c10 |
221 | |