1 | #pragma once |
2 | |
3 | #include <ATen/cuda/ATenCUDAGeneral.h> |
4 | #include <ATen/cuda/CUDAContext.h> |
5 | #include <c10/core/impl/GPUTrace.h> |
6 | #include <c10/cuda/CUDAStream.h> |
7 | #include <c10/cuda/CUDAGuard.h> |
8 | #include <ATen/cuda/Exceptions.h> |
9 | #include <c10/util/Exception.h> |
10 | |
11 | #include <cuda_runtime_api.h> |
12 | |
13 | #include <cstdint> |
14 | #include <utility> |
15 | |
16 | namespace at { namespace cuda { |
17 | |
18 | /* |
19 | * CUDAEvents are movable not copyable wrappers around CUDA's events. |
20 | * |
 * A CUDAEvent is constructed lazily when first recorded, unless it is
 * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
 * device is acquired from the first recording stream. However, if reconstructed
 * from a handle, the device should be explicitly specified; or if ipc_handle() is
 * called before the event is ever recorded, it will use the current device.
 * Later streams that record the event must match this device.
27 | */ |
28 | struct TORCH_CUDA_CPP_API CUDAEvent { |
29 | // Constructors |
30 | // Default value for `flags` is specified below - it's cudaEventDisableTiming |
31 | CUDAEvent() noexcept = default; |
32 | CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} |
33 | |
34 | CUDAEvent( |
35 | DeviceIndex device_index, const cudaIpcEventHandle_t* handle) { |
36 | device_index_ = device_index; |
37 | CUDAGuard guard(device_index_); |
38 | |
39 | AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); |
40 | is_created_ = true; |
41 | } |
42 | |
43 | // Note: event destruction done on creating device to avoid creating a |
44 | // CUDA context on other devices. |
45 | ~CUDAEvent() { |
46 | try { |
47 | if (is_created_) { |
48 | CUDAGuard guard(device_index_); |
49 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
50 | if (C10_UNLIKELY(interp)) { |
51 | (*interp)->trace_gpu_event_deletion(reinterpret_cast<uintptr_t>(event_)); |
52 | } |
53 | cudaEventDestroy(event_); |
54 | } |
55 | } catch (...) { /* No throw */ } |
56 | } |
57 | |
58 | CUDAEvent(const CUDAEvent&) = delete; |
59 | CUDAEvent& operator=(const CUDAEvent&) = delete; |
60 | |
61 | CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); } |
62 | CUDAEvent& operator=(CUDAEvent&& other) noexcept { |
63 | if (this != &other) { |
64 | moveHelper(std::move(other)); |
65 | } |
66 | return *this; |
67 | } |
68 | |
69 | operator cudaEvent_t() const { return event(); } |
70 | |
71 | // Less than operator (to allow use in sets) |
72 | friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { |
73 | return left.event_ < right.event_; |
74 | } |
75 | |
76 | optional<at::Device> device() const { |
77 | if (is_created_) { |
78 | return at::Device(at::kCUDA, device_index_); |
79 | } else { |
80 | return {}; |
81 | } |
82 | } |
83 | |
84 | bool isCreated() const { return is_created_; } |
85 | DeviceIndex device_index() const {return device_index_;} |
86 | cudaEvent_t event() const { return event_; } |
87 | |
88 | // Note: cudaEventQuery can be safely called from any device |
89 | bool query() const { |
90 | if (!is_created_) { |
91 | return true; |
92 | } |
93 | |
94 | cudaError_t err = cudaEventQuery(event_); |
95 | if (err == cudaSuccess) { |
96 | return true; |
97 | } else if (err != cudaErrorNotReady) { |
98 | C10_CUDA_CHECK(err); |
99 | } else { |
100 | // ignore and clear the error if not ready |
101 | cudaGetLastError(); |
102 | } |
103 | |
104 | return false; |
105 | } |
106 | |
107 | void record() { record(getCurrentCUDAStream()); } |
108 | |
109 | void recordOnce(const CUDAStream& stream) { |
110 | if (!was_recorded_) record(stream); |
111 | } |
112 | |
113 | // Note: cudaEventRecord must be called on the same device as the event. |
114 | void record(const CUDAStream& stream) { |
115 | if (!is_created_) { |
116 | createEvent(stream.device_index()); |
117 | } |
118 | |
119 | TORCH_CHECK(device_index_ == stream.device_index(), "Event device " , device_index_, |
120 | " does not match recording stream's device " , stream.device_index(), "." ); |
121 | CUDAGuard guard(device_index_); |
122 | AT_CUDA_CHECK(cudaEventRecord(event_, stream)); |
123 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
124 | if (C10_UNLIKELY(interp)) { |
125 | (*interp)->trace_gpu_event_record( |
126 | reinterpret_cast<uintptr_t>(event_), |
127 | reinterpret_cast<uintptr_t>(stream.stream()) |
128 | ); |
129 | } |
130 | was_recorded_ = true; |
131 | } |
132 | |
133 | // Note: cudaStreamWaitEvent must be called on the same device as the stream. |
134 | // The event has no actual GPU resources associated with it. |
135 | void block(const CUDAStream& stream) { |
136 | if (is_created_) { |
137 | CUDAGuard guard(stream.device_index()); |
138 | AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0)); |
139 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
140 | if (C10_UNLIKELY(interp)) { |
141 | (*interp)->trace_gpu_event_wait( |
142 | reinterpret_cast<uintptr_t>(event_), |
143 | reinterpret_cast<uintptr_t>(stream.stream()) |
144 | ); |
145 | } |
146 | } |
147 | } |
148 | |
149 | // Note: cudaEventElapsedTime can be safely called from any device |
150 | float elapsed_time(const CUDAEvent& other) const { |
151 | TORCH_CHECK(is_created_ && other.isCreated(), |
152 | "Both events must be recorded before calculating elapsed time." ); |
153 | float time_ms = 0; |
154 | // raise cudaErrorNotReady if either event is recorded but not yet completed |
155 | AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); |
156 | return time_ms; |
157 | } |
158 | |
159 | // Note: cudaEventSynchronize can be safely called from any device |
160 | void synchronize() const { |
161 | if (is_created_) { |
162 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
163 | if (C10_UNLIKELY(interp)) { |
164 | (*interp)->trace_gpu_event_synchronization(reinterpret_cast<uintptr_t>(event_)); |
165 | } |
166 | AT_CUDA_CHECK(cudaEventSynchronize(event_)); |
167 | } |
168 | } |
169 | |
170 | // Note: cudaIpcGetEventHandle must be called on the same device as the event |
171 | void ipc_handle(cudaIpcEventHandle_t * handle) { |
172 | if (!is_created_) { |
173 | // this CUDAEvent object was initially constructed from flags but event_ |
174 | // is not created yet. |
175 | createEvent(getCurrentCUDAStream().device_index()); |
176 | } |
177 | CUDAGuard guard(device_index_); |
178 | AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); |
179 | } |
180 | |
181 | private: |
182 | unsigned int flags_ = cudaEventDisableTiming; |
183 | bool is_created_ = false; |
184 | bool was_recorded_ = false; |
185 | DeviceIndex device_index_ = -1; |
186 | cudaEvent_t event_{}; |
187 | |
188 | void createEvent(DeviceIndex device_index) { |
189 | device_index_ = device_index; |
190 | CUDAGuard guard(device_index_); |
191 | AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); |
192 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
193 | if (C10_UNLIKELY(interp)) { |
194 | (*interp)->trace_gpu_event_creation(reinterpret_cast<uintptr_t>(event_)); |
195 | } |
196 | is_created_ = true; |
197 | } |
198 | |
199 | void moveHelper(CUDAEvent&& other) { |
200 | std::swap(flags_, other.flags_); |
201 | std::swap(is_created_, other.is_created_); |
202 | std::swap(was_recorded_, other.was_recorded_); |
203 | std::swap(device_index_, other.device_index_); |
204 | std::swap(event_, other.event_); |
205 | } |
206 | }; |
207 | |
208 | } // namespace cuda |
209 | } // namespace at |
210 | |