1 | #pragma once |
2 | |
3 | #include <c10/core/DeviceType.h> |
4 | #include <c10/core/impl/InlineDeviceGuard.h> |
5 | #include <c10/core/impl/InlineStreamGuard.h> |
6 | #include <c10/cuda/CUDAMacros.h> |
7 | #include <c10/cuda/impl/CUDAGuardImpl.h> |
8 | |
9 | #include <cstddef> |
10 | |
11 | namespace c10 { |
12 | namespace cuda { |
13 | |
14 | // This code is kind of boilerplatey. See Note [Whither the DeviceGuard |
15 | // boilerplate] |
16 | |
17 | /// A variant of DeviceGuard that is specialized for CUDA. It accepts |
18 | /// integer indices (interpreting them as CUDA devices) and is a little |
19 | /// more efficient than DeviceGuard (it compiles to straight line |
20 | /// cudaSetDevice/cudaGetDevice calls); however, it can only be used |
21 | /// from code that links against CUDA directly. |
22 | struct CUDAGuard { |
23 | /// No default constructor; see Note [Omitted default constructor from RAII] |
24 | explicit CUDAGuard() = delete; |
25 | |
26 | /// Set the current CUDA device to the passed device index. |
27 | explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {} |
28 | |
29 | /// Sets the current CUDA device to the passed device. Errors if the passed |
30 | /// device is not a CUDA device. |
31 | explicit CUDAGuard(Device device) : guard_(device) {} |
32 | |
33 | // Copy is not allowed |
34 | CUDAGuard(const CUDAGuard&) = delete; |
35 | CUDAGuard& operator=(const CUDAGuard&) = delete; |
36 | |
37 | // Move is not allowed (there is no uninitialized state) |
38 | CUDAGuard(CUDAGuard&& other) = delete; |
39 | CUDAGuard& operator=(CUDAGuard&& other) = delete; |
40 | |
41 | /// Sets the CUDA device to the given device. Errors if the given device |
42 | /// is not a CUDA device. |
43 | void set_device(Device device) { |
44 | guard_.set_device(device); |
45 | } |
46 | |
47 | /// Sets the CUDA device to the given device. Errors if the given device |
48 | /// is not a CUDA device. (This method is provided for uniformity with |
49 | /// DeviceGuard). |
50 | void reset_device(Device device) { |
51 | guard_.reset_device(device); |
52 | } |
53 | |
54 | /// Sets the CUDA device to the given device index. |
55 | void set_index(DeviceIndex device_index) { |
56 | guard_.set_index(device_index); |
57 | } |
58 | |
59 | /// Returns the device that was set upon construction of the guard |
60 | Device original_device() const { |
61 | return guard_.original_device(); |
62 | } |
63 | |
64 | /// Returns the last device that was set via `set_device`, if any, otherwise |
65 | /// the device passed during construction. |
66 | Device current_device() const { |
67 | return guard_.current_device(); |
68 | } |
69 | |
70 | private: |
71 | /// The guard for the current device. |
72 | c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_; |
73 | }; |
74 | |
75 | /// A variant of OptionalDeviceGuard that is specialized for CUDA. See |
76 | /// CUDAGuard for when you can use this. |
77 | struct OptionalCUDAGuard { |
78 | /// Create an uninitialized OptionalCUDAGuard. |
79 | explicit OptionalCUDAGuard() : guard_() {} |
80 | |
81 | /// Set the current CUDA device to the passed Device, if it is not nullopt. |
82 | explicit OptionalCUDAGuard(optional<Device> device_opt) |
83 | : guard_(device_opt) {} |
84 | |
85 | /// Set the current CUDA device to the passed device index, if it is not |
86 | /// nullopt |
87 | explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt) |
88 | : guard_(device_index_opt) {} |
89 | |
90 | // Copy is not allowed |
91 | OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; |
92 | OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete; |
93 | |
94 | // See Note [Move construction for RAII guards is tricky] |
95 | OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete; |
96 | |
97 | // See Note [Move assignment for RAII guards is tricky] |
98 | OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; |
99 | |
100 | /// Sets the CUDA device to the given device, initializing the guard if it |
101 | /// is not already initialized. Errors if the given device is not a CUDA |
102 | /// device. |
103 | void set_device(Device device) { |
104 | guard_.set_device(device); |
105 | } |
106 | |
107 | /// Sets the CUDA device to the given device, initializing the guard if it is |
108 | /// not already initialized. Errors if the given device is not a CUDA device. |
109 | /// (This method is provided for uniformity with OptionalDeviceGuard). |
110 | void reset_device(Device device) { |
111 | guard_.reset_device(device); |
112 | } |
113 | |
114 | /// Sets the CUDA device to the given device index, initializing the guard if |
115 | /// it is not already initialized. |
116 | void set_index(DeviceIndex device_index) { |
117 | guard_.set_index(device_index); |
118 | } |
119 | |
120 | /// Returns the device that was set immediately prior to initialization of the |
121 | /// guard, or nullopt if the guard is uninitialized. |
122 | optional<Device> original_device() const { |
123 | return guard_.original_device(); |
124 | } |
125 | |
126 | /// Returns the most recent device that was set using this device guard, |
127 | /// either from construction, or via set_device, if the guard is initialized, |
128 | /// or nullopt if the guard is uninitialized. |
129 | optional<Device> current_device() const { |
130 | return guard_.current_device(); |
131 | } |
132 | |
133 | /// Restore the original CUDA device, resetting this guard to uninitialized |
134 | /// state. |
135 | void reset() { |
136 | guard_.reset(); |
137 | } |
138 | |
139 | private: |
140 | c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_; |
141 | }; |
142 | |
143 | /// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard |
144 | /// for when you can use this. |
145 | struct CUDAStreamGuard { |
146 | /// No default constructor, see Note [Omitted default constructor from RAII] |
147 | explicit CUDAStreamGuard() = delete; |
148 | |
149 | /// Set the current CUDA device to the device associated with the passed |
150 | /// stream, and set the current CUDA stream on that device to the passed |
151 | /// stream. Errors if the Stream is not a CUDA stream. |
152 | explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} |
153 | |
154 | /// Copy is disallowed |
155 | CUDAStreamGuard(const CUDAStreamGuard&) = delete; |
156 | CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; |
157 | |
158 | /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized |
159 | /// state, which is required for moves on types with nontrivial destructors. |
160 | CUDAStreamGuard(CUDAStreamGuard&& other) = delete; |
161 | CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; |
162 | |
163 | /// Resets the currently set stream to the original stream and |
164 | /// the currently set device to the original device. Then, |
165 | /// set the current device to the device associated with the passed stream, |
166 | /// and set the current stream on that device to the passed stream. |
167 | /// Errors if the stream passed is not a CUDA stream. |
168 | /// |
169 | /// NOTE: this implementation may skip some stream/device setting if |
170 | /// it can prove that it is unnecessary. |
171 | /// |
172 | /// WARNING: reset_stream does NOT preserve previously set streams on |
173 | /// different devices. If you need to set streams on multiple devices |
174 | /// on CUDA, use CUDAMultiStreamGuard instead. |
175 | void reset_stream(Stream stream) { |
176 | guard_.reset_stream(stream); |
177 | } |
178 | |
179 | /// Returns the CUDA stream that was set at the time the guard was |
180 | /// constructed. |
181 | CUDAStream original_stream() const { |
182 | return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream()); |
183 | } |
184 | |
185 | /// Returns the most recent CUDA stream that was set using this device guard, |
186 | /// either from construction, or via set_stream. |
187 | CUDAStream current_stream() const { |
188 | return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream()); |
189 | } |
190 | |
191 | /// Returns the most recent CUDA device that was set using this device guard, |
192 | /// either from construction, or via set_device/reset_device/set_index. |
193 | Device current_device() const { |
194 | return guard_.current_device(); |
195 | } |
196 | |
197 | /// Returns the CUDA device that was set at the most recent reset_stream(), |
198 | /// or otherwise the device at construction time. |
199 | Device original_device() const { |
200 | return guard_.original_device(); |
201 | } |
202 | |
203 | private: |
204 | c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_; |
205 | }; |
206 | |
207 | /// A variant of OptionalStreamGuard that is specialized for CUDA. See |
208 | /// CUDAGuard for when you can use this. |
209 | struct OptionalCUDAStreamGuard { |
210 | /// Create an uninitialized guard. |
211 | explicit OptionalCUDAStreamGuard() : guard_() {} |
212 | |
213 | /// Set the current CUDA device to the device associated with the passed |
214 | /// stream, and set the current CUDA stream on that device to the passed |
215 | /// stream. Errors if the Stream is not a CUDA stream. |
216 | explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {} |
217 | |
218 | /// Set the current device to the device associated with the passed stream, |
219 | /// and set the current stream on that device to the passed stream, |
220 | /// if the passed stream is not nullopt. |
221 | explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt) |
222 | : guard_(stream_opt) {} |
223 | |
224 | /// Copy is disallowed |
225 | OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete; |
226 | OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete; |
227 | |
228 | // See Note [Move construction for RAII guards is tricky] |
229 | OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete; |
230 | |
231 | // See Note [Move assignment for RAII guards is tricky] |
232 | OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete; |
233 | |
234 | /// Resets the currently set CUDA stream to the original stream and |
235 | /// the currently set device to the original device. Then, |
236 | /// set the current device to the device associated with the passed stream, |
237 | /// and set the current stream on that device to the passed stream. |
238 | /// Initializes the guard if it was not previously initialized. |
239 | void reset_stream(Stream stream) { |
240 | guard_.reset_stream(stream); |
241 | } |
242 | |
243 | /// Returns the CUDA stream that was set at the time the guard was most |
244 | /// recently initialized, or nullopt if the guard is uninitialized. |
245 | optional<CUDAStream> original_stream() const { |
246 | auto r = guard_.original_stream(); |
247 | if (r.has_value()) { |
248 | return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); |
249 | } else { |
250 | return nullopt; |
251 | } |
252 | } |
253 | |
254 | /// Returns the most recent CUDA stream that was set using this stream guard, |
255 | /// either from construction, or via reset_stream, if the guard is |
256 | /// initialized, or nullopt if the guard is uninitialized. |
257 | optional<CUDAStream> current_stream() const { |
258 | auto r = guard_.current_stream(); |
259 | if (r.has_value()) { |
260 | return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); |
261 | } else { |
262 | return nullopt; |
263 | } |
264 | } |
265 | |
266 | /// Restore the original CUDA device and stream, resetting this guard to |
267 | /// uninitialized state. |
268 | void reset() { |
269 | guard_.reset(); |
270 | } |
271 | |
272 | private: |
273 | c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_; |
274 | }; |
275 | |
276 | /// A variant of MultiStreamGuard that is specialized for CUDA. |
277 | struct CUDAMultiStreamGuard { |
278 | explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) |
279 | : guard_(unwrapStreams(streams)) {} |
280 | |
281 | /// Copy is disallowed |
282 | CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete; |
283 | CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete; |
284 | |
285 | // See Note [Move construction for RAII guards is tricky] |
286 | CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete; |
287 | |
288 | // See Note [Move assignment for RAII guards is tricky] |
289 | CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete; |
290 | |
291 | private: |
292 | c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_; |
293 | |
294 | static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) { |
295 | std::vector<Stream> streams; |
296 | streams.reserve(cudaStreams.size()); |
297 | for (const CUDAStream& cudaStream : cudaStreams) { |
298 | streams.push_back(cudaStream); |
299 | } |
300 | return streams; |
301 | } |
302 | }; |
303 | |
304 | } // namespace cuda |
305 | } // namespace c10 |
306 | |