#pragma once

#include <cstdint>
#include <ostream>
#include <tuple>
#include <utility>

#include <cuda_runtime_api.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Exception.h>
/*
 * Stream pool note.
 *
 * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
 * are backed by cuStreams, but they use several pools to minimize the costs
 * associated with creating, retaining, and destroying cuStreams.
 *
 * There are three pools per device, and a device's pools are lazily created.
 *
 * The first pool contains only the default stream. When the default stream
 * is requested it's returned.
 *
 * The second pool is the "low priority" or "default priority" streams. In
 * HIP builds there is no distinction between streams in this pool and streams
 * in the third pool (below). There are 32 of these streams per device, and
 * when a stream is requested one of these streams is returned round-robin.
 * That is, the first stream requested is at index 0, the second at index 1...
 * to index 31, then index 0 again.
 *
 * This means that if 33 low priority streams are requested, the first and
 * last streams requested are actually the same stream (under the covers)
 * and kernels enqueued on them cannot run concurrently.
 *
 * The third pool is the "high priority" streams. The third pool acts like
 * the second pool except the streams are created with a higher priority.
 *
 * These pools suggest that stream users should prefer many short-lived streams,
 * as the cost of acquiring and releasing streams is effectively zero. If
 * many longer-lived streams are required in performance critical scenarios
 * then the functionality here may need to be extended to allow, for example,
 * "reserving" a subset of the pool so that other streams do not accidentally
 * overlap the performance critical streams.
 *
 * Note: although the notion of "current stream for device" is thread local
 * (every OS thread has a separate current stream, as one might expect),
 * the stream pool is global across all threads; stream 0 is always stream 0
 * no matter which thread you use it on. Multiple threads can synchronize
 * on the same stream. Although the CUDA documentation is not very clear
 * on the matter, streams are thread safe; e.g., it is safe to enqueue
 * a kernel on the same stream from two different threads.
 */
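
// A minimal sketch of the wrap-around behavior described above (illustrative
// only, not part of this API; assumes a freshly created pool and no
// concurrent requests from other threads):
//
//   CUDAStream first = getStreamFromPool();
//   for (int i = 0; i < 31; ++i) {
//     getStreamFromPool(); // exhaust the remaining 31 slots in the pool
//   }
//   CUDAStream again = getStreamFromPool(); // round-robin wraps to index 0
//   // first == again: both wrap the same cuStream, so kernels enqueued on
//   // them serialize rather than running concurrently.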

namespace c10 {
namespace cuda {

// Value object representing a CUDA stream. This is just a wrapper
// around c10::Stream, but it comes with a little extra CUDA-specific
// functionality (conversion to cudaStream_t), and a guarantee that
// the wrapped c10::Stream really is a CUDA stream.
class C10_CUDA_API CUDAStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Construct a CUDAStream from a Stream. This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a CUDA stream.
  explicit CUDAStream(Stream stream) : stream_(stream) {
    TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
  }

  /// Construct a CUDAStream from a Stream with no error checking.
  /// This constructor uses the "named" constructor idiom, and can
  /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream)
  explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {}

  bool operator==(const CUDAStream& other) const noexcept {
    return unwrap() == other.unwrap();
  }

  bool operator!=(const CUDAStream& other) const noexcept {
    return unwrap() != other.unwrap();
  }

  /// Implicit conversion to cudaStream_t.
  operator cudaStream_t() const {
    return stream();
  }

  /// Implicit conversion to Stream (a.k.a., forget that the stream is a
  /// CUDA stream).
  operator Stream() const {
    return unwrap();
  }
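
  /// Thanks to the implicit conversion above, a CUDAStream can be passed
  /// wherever a cudaStream_t is expected. Illustrative sketch only (`dst`,
  /// `src`, and `kBytes` are hypothetical):
  ///
  ///   CUDAStream s = getStreamFromPool();
  ///   cudaMemcpyAsync(dst, src, kBytes, cudaMemcpyHostToDevice, s);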

  /// Used to avoid baking in device type explicitly to Python-side API.
  DeviceType device_type() const {
    return DeviceType::CUDA;
  }

  /// Get the CUDA device index that this stream is associated with.
  DeviceIndex device_index() const {
    return stream_.device_index();
  }

  /// Get the full Device that this stream is associated with. The Device
  /// is guaranteed to be a CUDA device.
  Device device() const {
    return Device(DeviceType::CUDA, device_index());
  }

  /// Return the stream ID corresponding to this particular stream.
  StreamId id() const {
    return stream_.id();
  }

  /// Return true if all work enqueued on this stream has completed. This
  /// call polls the stream's status and does not block.
  bool query() const {
    DeviceGuard guard{stream_.device()};
    cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream()));

    if (err == cudaSuccess) {
      return true;
    } else if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }

    return false;
  }
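
  /// A sketch of non-blocking polling with query() (illustrative only;
  /// `do_host_work` is a hypothetical placeholder):
  ///
  ///   CUDAStream s = getCurrentCUDAStream();
  ///   while (!s.query()) {
  ///     do_host_work(); // overlap host work with outstanding GPU work
  ///   }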

  /// Block the calling thread until all work enqueued on this stream
  /// has completed.
  void synchronize() const {
    DeviceGuard guard{stream_.device()};
    c10::cuda::stream_synchronize(stream());
  }

  /// Return the priority that this stream was created with.
  int priority() const {
    DeviceGuard guard{stream_.device()};
    int priority = 0;
    C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
    return priority;
  }

  /// Explicit conversion to cudaStream_t.
  cudaStream_t stream() const;

  /// Explicit conversion to Stream.
  Stream unwrap() const {
    return stream_;
  }

  /// Reversibly pack a CUDAStream into a struct representation.
  /// Previously the stream's data was packed into a single int64_t,
  /// as it was assumed the fields would not require more than
  /// 64 bits of storage in total.
  /// See https://github.com/pytorch/pytorch/issues/75854
  /// for more information regarding newer platforms that may violate
  /// this assumption.
  ///
  /// The CUDAStream can be unpacked using unpack3().
  struct c10::StreamData3 pack3() const {
    return stream_.pack3();
  }

  // Unpack a CUDAStream from the 3 fields generated by pack3().
  static CUDAStream unpack3(
      StreamId stream_id,
      DeviceIndex device_index,
      DeviceType device_type) {
    return CUDAStream(Stream::unpack3(stream_id, device_index, device_type));
  }
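
  /// A sketch of a pack3()/unpack3() round trip (illustrative only; the
  /// StreamData3 field names are those declared in c10/core/Stream.h):
  ///
  ///   CUDAStream s = getCurrentCUDAStream();
  ///   c10::StreamData3 d = s.pack3();
  ///   CUDAStream t =
  ///       CUDAStream::unpack3(d.stream_id, d.device_index, d.device_type);
  ///   // t == s: packing and unpacking round-trips losslessly.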

  static std::tuple<int, int> priority_range() {
    // Note: this returns the range of priority **supported by PyTorch**, not
    // the range of priority **supported by CUDA**. The former is a subset of
    // the latter. Currently PyTorch only supports 0 and -1, which are "low"
    // and "high" priority.
    int least_priority, greatest_priority;
    C10_CUDA_CHECK(
        cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
    TORCH_INTERNAL_ASSERT(
        least_priority >= 0, "Unexpected CUDA stream priority range");
    TORCH_INTERNAL_ASSERT(
        greatest_priority <= -1, "Unexpected CUDA stream priority range");
    return std::make_tuple(0, -1);
  }
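
  /// For example (illustrative only; recall that in CUDA a numerically
  /// smaller value means a higher priority):
  ///
  ///   int least, greatest;
  ///   std::tie(least, greatest) = CUDAStream::priority_range();
  ///   // least == 0 ("low" priority), greatest == -1 ("high" priority)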

  // Deleted for now; use CUDAEvent::block instead
  // void synchronize_with(const CUDAEvent& event) const;

 private:
  Stream stream_;
};

/**
 * Get a new stream from the CUDA stream pool. You can think of this
 * as "creating" a new stream, but no such creation actually happens;
 * instead, streams are preallocated from the pool and returned in a
 * round-robin fashion.
 *
 * You can request a stream from the high priority pool by setting
 * isHighPriority to true, or a stream for a specific device by setting
 * device (which defaults to the current device).
 */
C10_API CUDAStream
getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
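
// A sketch of typical pool usage (illustrative only): acquire a high
// priority stream, install it as the current stream, and enqueue work:
//
//   CUDAStream s = getStreamFromPool(true); // isHighPriority = true
//   setCurrentCUDAStream(s);
//   // ... kernels launched now are enqueued on s ...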

/**
 * Get a CUDAStream from an externally allocated one.
 *
 * This is mainly for interoperability with different libraries where we
 * want to operate on a non-torch allocated stream for data exchange or
 * similar purposes.
 */
C10_API CUDAStream
getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);
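
// A sketch of wrapping a raw stream created outside PyTorch (illustrative
// only; error checking of the raw CUDA calls is omitted, and the caller
// remains responsible for the raw stream's lifetime):
//
//   cudaStream_t raw;
//   cudaStreamCreate(&raw);
//   CUDAStream wrapped = getStreamFromExternal(raw, 0); // device index 0
//   // ... use `wrapped` like any other CUDAStream ...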

/**
 * Get the default CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed. The default stream is
 * where most computation occurs when you aren't explicitly using
 * streams.
 */
C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1);

/**
 * Get the current CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed. The current CUDA stream
 * will usually be the default CUDA stream for the device, but it may
 * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
 * or 'CUDAStreamGuard'.
 */
C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);
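
// For example, passing the current stream to a raw kernel launch
// (illustrative sketch; `kernel`, `grid`, `block`, and `args` are
// hypothetical):
//
//   CUDAStream s = getCurrentCUDAStream();
//   kernel<<<grid, block, 0, s.stream()>>>(args);
//   // (the implicit cudaStream_t conversion also allows passing `s` itself)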

/**
 * Set the current stream on the device of the passed in stream to be
 * the passed in stream. Yes, you read that right: this function
 * has *nothing* to do with the current device: it toggles the current
 * stream of the device of the passed stream.
 *
 * Confused? Avoid using this function; prefer 'CUDAStreamGuard' instead
 * (which will switch both your current device and current stream in the
 * way you expect, and reset them back to their original state afterwards).
 */
C10_API void setCurrentCUDAStream(CUDAStream stream);
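
// A sketch of the recommended guard-based pattern (illustrative only;
// CUDAStreamGuard is declared in c10/cuda/CUDAGuard.h):
//
//   {
//     CUDAStreamGuard guard(getStreamFromPool());
//     // the current device and current stream now match the pooled stream
//     // ... enqueue work ...
//   } // both are restored here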

C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);

} // namespace cuda
} // namespace c10

namespace std {
template <>
struct hash<c10::cuda::CUDAStream> {
  size_t operator()(c10::cuda::CUDAStream s) const noexcept {
    return std::hash<c10::Stream>{}(s.unwrap());
  }
};
} // namespace std
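
// The hash specialization above allows CUDAStream to be used directly as a
// key in unordered containers, e.g. (illustrative only):
//
//   std::unordered_map<c10::cuda::CUDAStream, int> launches;
//   launches[c10::cuda::getCurrentCUDAStream()] += 1;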