1#pragma once
2
3#include <ATen/cuda/ATenCUDAGeneral.h>
4#include <ATen/cuda/CUDAContext.h>
5#include <c10/core/impl/GPUTrace.h>
6#include <c10/cuda/CUDAStream.h>
7#include <c10/cuda/CUDAGuard.h>
8#include <ATen/cuda/Exceptions.h>
9#include <c10/util/Exception.h>
10
11#include <cuda_runtime_api.h>
12
13#include <cstdint>
14#include <utility>
15
16namespace at { namespace cuda {
17
/*
* CUDAEvents are movable, non-copyable wrappers around CUDA's events.
*
* A CUDAEvent is constructed lazily when it is first recorded, unless it is
* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
* device is acquired from the first recording stream. However, if the event is
* reconstructed from a handle, the device must be specified explicitly; and if
* ipc_handle() is called before the event is ever recorded, the event uses the
* current device. Later streams that record the event must match this device.
*/
28struct TORCH_CUDA_CPP_API CUDAEvent {
29 // Constructors
30 // Default value for `flags` is specified below - it's cudaEventDisableTiming
31 CUDAEvent() noexcept = default;
32 CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
33
34 CUDAEvent(
35 DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
36 device_index_ = device_index;
37 CUDAGuard guard(device_index_);
38
39 AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
40 is_created_ = true;
41 }
42
43 // Note: event destruction done on creating device to avoid creating a
44 // CUDA context on other devices.
45 ~CUDAEvent() {
46 try {
47 if (is_created_) {
48 CUDAGuard guard(device_index_);
49 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
50 if (C10_UNLIKELY(interp)) {
51 (*interp)->trace_gpu_event_deletion(reinterpret_cast<uintptr_t>(event_));
52 }
53 cudaEventDestroy(event_);
54 }
55 } catch (...) { /* No throw */ }
56 }
57
58 CUDAEvent(const CUDAEvent&) = delete;
59 CUDAEvent& operator=(const CUDAEvent&) = delete;
60
61 CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
62 CUDAEvent& operator=(CUDAEvent&& other) noexcept {
63 if (this != &other) {
64 moveHelper(std::move(other));
65 }
66 return *this;
67 }
68
69 operator cudaEvent_t() const { return event(); }
70
71 // Less than operator (to allow use in sets)
72 friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
73 return left.event_ < right.event_;
74 }
75
76 optional<at::Device> device() const {
77 if (is_created_) {
78 return at::Device(at::kCUDA, device_index_);
79 } else {
80 return {};
81 }
82 }
83
84 bool isCreated() const { return is_created_; }
85 DeviceIndex device_index() const {return device_index_;}
86 cudaEvent_t event() const { return event_; }
87
88 // Note: cudaEventQuery can be safely called from any device
89 bool query() const {
90 if (!is_created_) {
91 return true;
92 }
93
94 cudaError_t err = cudaEventQuery(event_);
95 if (err == cudaSuccess) {
96 return true;
97 } else if (err != cudaErrorNotReady) {
98 C10_CUDA_CHECK(err);
99 } else {
100 // ignore and clear the error if not ready
101 cudaGetLastError();
102 }
103
104 return false;
105 }
106
107 void record() { record(getCurrentCUDAStream()); }
108
109 void recordOnce(const CUDAStream& stream) {
110 if (!was_recorded_) record(stream);
111 }
112
113 // Note: cudaEventRecord must be called on the same device as the event.
114 void record(const CUDAStream& stream) {
115 if (!is_created_) {
116 createEvent(stream.device_index());
117 }
118
119 TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
120 " does not match recording stream's device ", stream.device_index(), ".");
121 CUDAGuard guard(device_index_);
122 AT_CUDA_CHECK(cudaEventRecord(event_, stream));
123 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
124 if (C10_UNLIKELY(interp)) {
125 (*interp)->trace_gpu_event_record(
126 reinterpret_cast<uintptr_t>(event_),
127 reinterpret_cast<uintptr_t>(stream.stream())
128 );
129 }
130 was_recorded_ = true;
131 }
132
133 // Note: cudaStreamWaitEvent must be called on the same device as the stream.
134 // The event has no actual GPU resources associated with it.
135 void block(const CUDAStream& stream) {
136 if (is_created_) {
137 CUDAGuard guard(stream.device_index());
138 AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
139 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
140 if (C10_UNLIKELY(interp)) {
141 (*interp)->trace_gpu_event_wait(
142 reinterpret_cast<uintptr_t>(event_),
143 reinterpret_cast<uintptr_t>(stream.stream())
144 );
145 }
146 }
147 }
148
149 // Note: cudaEventElapsedTime can be safely called from any device
150 float elapsed_time(const CUDAEvent& other) const {
151 TORCH_CHECK(is_created_ && other.isCreated(),
152 "Both events must be recorded before calculating elapsed time.");
153 float time_ms = 0;
154 // raise cudaErrorNotReady if either event is recorded but not yet completed
155 AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
156 return time_ms;
157 }
158
159 // Note: cudaEventSynchronize can be safely called from any device
160 void synchronize() const {
161 if (is_created_) {
162 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
163 if (C10_UNLIKELY(interp)) {
164 (*interp)->trace_gpu_event_synchronization(reinterpret_cast<uintptr_t>(event_));
165 }
166 AT_CUDA_CHECK(cudaEventSynchronize(event_));
167 }
168 }
169
170 // Note: cudaIpcGetEventHandle must be called on the same device as the event
171 void ipc_handle(cudaIpcEventHandle_t * handle) {
172 if (!is_created_) {
173 // this CUDAEvent object was initially constructed from flags but event_
174 // is not created yet.
175 createEvent(getCurrentCUDAStream().device_index());
176 }
177 CUDAGuard guard(device_index_);
178 AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
179 }
180
181private:
182 unsigned int flags_ = cudaEventDisableTiming;
183 bool is_created_ = false;
184 bool was_recorded_ = false;
185 DeviceIndex device_index_ = -1;
186 cudaEvent_t event_{};
187
188 void createEvent(DeviceIndex device_index) {
189 device_index_ = device_index;
190 CUDAGuard guard(device_index_);
191 AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
192 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
193 if (C10_UNLIKELY(interp)) {
194 (*interp)->trace_gpu_event_creation(reinterpret_cast<uintptr_t>(event_));
195 }
196 is_created_ = true;
197 }
198
199 void moveHelper(CUDAEvent&& other) {
200 std::swap(flags_, other.flags_);
201 std::swap(is_created_, other.is_created_);
202 std::swap(was_recorded_, other.was_recorded_);
203 std::swap(device_index_, other.device_index_);
204 std::swap(event_, other.event_);
205 }
206};
207
208} // namespace cuda
209} // namespace at
210