1#pragma once
2
3#include <c10/core/DeviceType.h>
4#include <c10/core/impl/InlineDeviceGuard.h>
5#include <c10/core/impl/InlineStreamGuard.h>
6#include <c10/cuda/CUDAMacros.h>
7#include <c10/cuda/impl/CUDAGuardImpl.h>
8
9#include <cstddef>
10
11namespace c10 {
12namespace cuda {
13
// The guards below are mostly boilerplate that forwards to the generic
// c10::impl inline guards. See Note [Whither the DeviceGuard boilerplate].
16
17/// A variant of DeviceGuard that is specialized for CUDA. It accepts
18/// integer indices (interpreting them as CUDA devices) and is a little
19/// more efficient than DeviceGuard (it compiles to straight line
20/// cudaSetDevice/cudaGetDevice calls); however, it can only be used
21/// from code that links against CUDA directly.
22struct CUDAGuard {
23 /// No default constructor; see Note [Omitted default constructor from RAII]
24 explicit CUDAGuard() = delete;
25
26 /// Set the current CUDA device to the passed device index.
27 explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {}
28
29 /// Sets the current CUDA device to the passed device. Errors if the passed
30 /// device is not a CUDA device.
31 explicit CUDAGuard(Device device) : guard_(device) {}
32
33 // Copy is not allowed
34 CUDAGuard(const CUDAGuard&) = delete;
35 CUDAGuard& operator=(const CUDAGuard&) = delete;
36
37 // Move is not allowed (there is no uninitialized state)
38 CUDAGuard(CUDAGuard&& other) = delete;
39 CUDAGuard& operator=(CUDAGuard&& other) = delete;
40
41 /// Sets the CUDA device to the given device. Errors if the given device
42 /// is not a CUDA device.
43 void set_device(Device device) {
44 guard_.set_device(device);
45 }
46
47 /// Sets the CUDA device to the given device. Errors if the given device
48 /// is not a CUDA device. (This method is provided for uniformity with
49 /// DeviceGuard).
50 void reset_device(Device device) {
51 guard_.reset_device(device);
52 }
53
54 /// Sets the CUDA device to the given device index.
55 void set_index(DeviceIndex device_index) {
56 guard_.set_index(device_index);
57 }
58
59 /// Returns the device that was set upon construction of the guard
60 Device original_device() const {
61 return guard_.original_device();
62 }
63
64 /// Returns the last device that was set via `set_device`, if any, otherwise
65 /// the device passed during construction.
66 Device current_device() const {
67 return guard_.current_device();
68 }
69
70 private:
71 /// The guard for the current device.
72 c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_;
73};
74
75/// A variant of OptionalDeviceGuard that is specialized for CUDA. See
76/// CUDAGuard for when you can use this.
77struct OptionalCUDAGuard {
78 /// Create an uninitialized OptionalCUDAGuard.
79 explicit OptionalCUDAGuard() : guard_() {}
80
81 /// Set the current CUDA device to the passed Device, if it is not nullopt.
82 explicit OptionalCUDAGuard(optional<Device> device_opt)
83 : guard_(device_opt) {}
84
85 /// Set the current CUDA device to the passed device index, if it is not
86 /// nullopt
87 explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt)
88 : guard_(device_index_opt) {}
89
90 // Copy is not allowed
91 OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
92 OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete;
93
94 // See Note [Move construction for RAII guards is tricky]
95 OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;
96
97 // See Note [Move assignment for RAII guards is tricky]
98 OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
99
100 /// Sets the CUDA device to the given device, initializing the guard if it
101 /// is not already initialized. Errors if the given device is not a CUDA
102 /// device.
103 void set_device(Device device) {
104 guard_.set_device(device);
105 }
106
107 /// Sets the CUDA device to the given device, initializing the guard if it is
108 /// not already initialized. Errors if the given device is not a CUDA device.
109 /// (This method is provided for uniformity with OptionalDeviceGuard).
110 void reset_device(Device device) {
111 guard_.reset_device(device);
112 }
113
114 /// Sets the CUDA device to the given device index, initializing the guard if
115 /// it is not already initialized.
116 void set_index(DeviceIndex device_index) {
117 guard_.set_index(device_index);
118 }
119
120 /// Returns the device that was set immediately prior to initialization of the
121 /// guard, or nullopt if the guard is uninitialized.
122 optional<Device> original_device() const {
123 return guard_.original_device();
124 }
125
126 /// Returns the most recent device that was set using this device guard,
127 /// either from construction, or via set_device, if the guard is initialized,
128 /// or nullopt if the guard is uninitialized.
129 optional<Device> current_device() const {
130 return guard_.current_device();
131 }
132
133 /// Restore the original CUDA device, resetting this guard to uninitialized
134 /// state.
135 void reset() {
136 guard_.reset();
137 }
138
139 private:
140 c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
141};
142
143/// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard
144/// for when you can use this.
145struct CUDAStreamGuard {
146 /// No default constructor, see Note [Omitted default constructor from RAII]
147 explicit CUDAStreamGuard() = delete;
148
149 /// Set the current CUDA device to the device associated with the passed
150 /// stream, and set the current CUDA stream on that device to the passed
151 /// stream. Errors if the Stream is not a CUDA stream.
152 explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
153
154 /// Copy is disallowed
155 CUDAStreamGuard(const CUDAStreamGuard&) = delete;
156 CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;
157
158 /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized
159 /// state, which is required for moves on types with nontrivial destructors.
160 CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
161 CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;
162
163 /// Resets the currently set stream to the original stream and
164 /// the currently set device to the original device. Then,
165 /// set the current device to the device associated with the passed stream,
166 /// and set the current stream on that device to the passed stream.
167 /// Errors if the stream passed is not a CUDA stream.
168 ///
169 /// NOTE: this implementation may skip some stream/device setting if
170 /// it can prove that it is unnecessary.
171 ///
172 /// WARNING: reset_stream does NOT preserve previously set streams on
173 /// different devices. If you need to set streams on multiple devices
174 /// on CUDA, use CUDAMultiStreamGuard instead.
175 void reset_stream(Stream stream) {
176 guard_.reset_stream(stream);
177 }
178
179 /// Returns the CUDA stream that was set at the time the guard was
180 /// constructed.
181 CUDAStream original_stream() const {
182 return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
183 }
184
185 /// Returns the most recent CUDA stream that was set using this device guard,
186 /// either from construction, or via set_stream.
187 CUDAStream current_stream() const {
188 return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream());
189 }
190
191 /// Returns the most recent CUDA device that was set using this device guard,
192 /// either from construction, or via set_device/reset_device/set_index.
193 Device current_device() const {
194 return guard_.current_device();
195 }
196
197 /// Returns the CUDA device that was set at the most recent reset_stream(),
198 /// or otherwise the device at construction time.
199 Device original_device() const {
200 return guard_.original_device();
201 }
202
203 private:
204 c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
205};
206
207/// A variant of OptionalStreamGuard that is specialized for CUDA. See
208/// CUDAGuard for when you can use this.
209struct OptionalCUDAStreamGuard {
210 /// Create an uninitialized guard.
211 explicit OptionalCUDAStreamGuard() : guard_() {}
212
213 /// Set the current CUDA device to the device associated with the passed
214 /// stream, and set the current CUDA stream on that device to the passed
215 /// stream. Errors if the Stream is not a CUDA stream.
216 explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}
217
218 /// Set the current device to the device associated with the passed stream,
219 /// and set the current stream on that device to the passed stream,
220 /// if the passed stream is not nullopt.
221 explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt)
222 : guard_(stream_opt) {}
223
224 /// Copy is disallowed
225 OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
226 OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete;
227
228 // See Note [Move construction for RAII guards is tricky]
229 OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete;
230
231 // See Note [Move assignment for RAII guards is tricky]
232 OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete;
233
234 /// Resets the currently set CUDA stream to the original stream and
235 /// the currently set device to the original device. Then,
236 /// set the current device to the device associated with the passed stream,
237 /// and set the current stream on that device to the passed stream.
238 /// Initializes the guard if it was not previously initialized.
239 void reset_stream(Stream stream) {
240 guard_.reset_stream(stream);
241 }
242
243 /// Returns the CUDA stream that was set at the time the guard was most
244 /// recently initialized, or nullopt if the guard is uninitialized.
245 optional<CUDAStream> original_stream() const {
246 auto r = guard_.original_stream();
247 if (r.has_value()) {
248 return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
249 } else {
250 return nullopt;
251 }
252 }
253
254 /// Returns the most recent CUDA stream that was set using this stream guard,
255 /// either from construction, or via reset_stream, if the guard is
256 /// initialized, or nullopt if the guard is uninitialized.
257 optional<CUDAStream> current_stream() const {
258 auto r = guard_.current_stream();
259 if (r.has_value()) {
260 return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
261 } else {
262 return nullopt;
263 }
264 }
265
266 /// Restore the original CUDA device and stream, resetting this guard to
267 /// uninitialized state.
268 void reset() {
269 guard_.reset();
270 }
271
272 private:
273 c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
274};
275
276/// A variant of MultiStreamGuard that is specialized for CUDA.
277struct CUDAMultiStreamGuard {
278 explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
279 : guard_(unwrapStreams(streams)) {}
280
281 /// Copy is disallowed
282 CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
283 CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete;
284
285 // See Note [Move construction for RAII guards is tricky]
286 CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete;
287
288 // See Note [Move assignment for RAII guards is tricky]
289 CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
290
291 private:
292 c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;
293
294 static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {
295 std::vector<Stream> streams;
296 streams.reserve(cudaStreams.size());
297 for (const CUDAStream& cudaStream : cudaStreams) {
298 streams.push_back(cudaStream);
299 }
300 return streams;
301 }
302};
303
304} // namespace cuda
305} // namespace c10
306