1#include <c10/core/impl/GPUTrace.h>
2#include <c10/cuda/CUDAFunctions.h>
3#include <c10/cuda/CUDAGuard.h>
4#include <c10/cuda/CUDAStream.h>
5#include <c10/util/CallOnce.h>
6#include <c10/util/Exception.h>
7#include <c10/util/irange.h>
8
9#include <atomic>
10#include <cstdint>
11#include <mutex>
12#include <vector>
13
14#include <iostream>
15namespace c10 {
16namespace cuda {
17
18namespace {
19
20// Global stream state and constants
21static c10::once_flag init_flag;
22static DeviceIndex num_gpus = -1;
23static constexpr int kStreamsPerPoolBits = 5;
24static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
25static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
26static constexpr int kStreamTypeBits = 3;
27
28// Note: lower numbers are higher priorities, zero is default priority
29static constexpr int kHighPriority = -1;
30static constexpr int kLowPriority = 0;
31
32// Non-default streams
33// Note: the number of CUDA devices is determined at run time,
34// and the low and high priority pools are lazily initialized
35// when the first stream is requested for a device.
36// The device flags track the initialization of each device, while
37// the low and high priority counters track, for each device, the next stream
38// in the pool to be returned when a stream is requested (round-robin fashion
39// , see the note in CUDAStream.h).
40// The streams are "leaked": they are created but never destroyed because the
41// destruction of global variables could happen after the CUDA runtime has
42// already been destroyed and thus invoking cudaStreamDestroy could lead to a
43// crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
44// the destruction.
45static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
46static std::atomic<uint32_t> low_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
47static std::atomic<uint32_t> high_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
48static cudaStream_t low_priority_streams[C10_COMPILE_TIME_MAX_GPUS]
49 [kStreamsPerPool];
50static cudaStream_t high_priority_streams[C10_COMPILE_TIME_MAX_GPUS]
51 [kStreamsPerPool];
52
53// Note [StreamId assignment]
54// ~~~~~~~~~~~~~~~~~~~~~~~~~~
55// How do we assign stream IDs?
56//
57// -- 57 bits -- -- 5 bits ----- -- 3 bits --
58// zeros stream id index StreamIdType
59//
60// Where StreamIdType:
61// 000 = default stream or externally allocated if id[63:3] != 0
62// 001 = low priority stream
63// 010 = high priority stream
64//
65// This is not really for efficiency; it's just easier to write the code
66// to extract the index if we do this with bitmasks :)
67//
68// We are obligated to treat the stream ID 0 as the default stream, per the
69// invariant specified in c10::Stream. However, all other numbers are entirely
70// an internal implementation detail, we reserve the right to renumber streams
71// however we like.
72//
73// Note that it is really important that the MSB is zero; StreamId is a
74// *signed* integer, and unsigned to signed conversion outside of the
75// bounds of signed integer representation is undefined behavior. You
76// could work around this with something like
77// https://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
78// but it seems a bit overkill for this.
79//
// Also, externally managed stream pointers (cudaStream_t) can be stored
// directly in the Id field, so in this case we need to check the stream
// alignment. The IdType uses an additional bit to match the alignment of
// 64-bit addresses, making it easy to identify an external stream: its value X
// is non-zero and (X & 7) == 0 (see streamIdType()).
// Kind of stream a StreamId refers to.
// DEFAULT, LOW and HIGH are the values found in the type bits of an id;
// streamIdType() reports EXT for non-zero ids whose type bits are all zero
// (externally allocated streams whose id is the stream pointer itself).
enum class StreamIdType : uint8_t {
  DEFAULT = 0x0,
  LOW = 0x1,
  HIGH = 0x2,
  EXT = 0x3,
};

// Writes a human-readable form of `s` to `stream` and returns `stream`.
// Known values are written by name ("DEFAULT", "LOW", "HIGH", "EXT"); any
// other value is written as its decimal number.
std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
  switch (s) {
    case StreamIdType::DEFAULT:
      stream << "DEFAULT";
      break;
    case StreamIdType::LOW:
      stream << "LOW";
      break;
    case StreamIdType::HIGH:
      stream << "HIGH";
      break;
    case StreamIdType::EXT:
      stream << "EXT";
      break;
    default:
      // Widen to int before streaming: uint8_t is a character type, so
      // streaming it directly would emit a raw (usually unprintable)
      // character instead of the numeric value.
      stream << static_cast<int>(s);
      break;
  }
  return stream;
}
111
112// StreamId is 64-bit, so we can just rely on regular promotion rules.
113// We rely on streamIdIndex and streamIdType being non-negative;
114// see Note [Hazard when concatenating signed integers]
115
116static inline StreamIdType streamIdType(StreamId s) {
117 int mask_for_type = (1 << kStreamTypeBits) - 1;
118 if (s && ((s & mask_for_type) == 0)) {
119 // Externally allocated streams have their id being the cudaStream_ptr
120 // so the bits corresponding to the type will be 0 and will collide with
121 // the default stream.
122 return StreamIdType::EXT;
123 }
124 return static_cast<StreamIdType>(s & mask_for_type);
125}
126
127static inline size_t streamIdIndex(StreamId s) {
128 return static_cast<size_t>(
129 (s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1));
130}
131
132StreamId makeStreamId(StreamIdType st, size_t si) {
133 return (static_cast<StreamId>(si) << kStreamTypeBits) |
134 static_cast<StreamId>(st);
135}
136
137// Thread-local current streams
138static thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
139
140// Populates global values.
141// Warning: this function must only be called once!
142static void initGlobalStreamState() {
143 num_gpus = device_count();
144 // Check if the number of GPUs matches the expected compile-time max number
145 // of GPUs.
146 TORCH_CHECK(
147 num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
148 "Number of CUDA devices on the machine is larger than the compiled "
149 "max number of gpus expected (",
150 C10_COMPILE_TIME_MAX_GPUS,
151 "). Increase that and recompile.");
152}
153
154// Creates the low and high priority stream pools for the specified device
155// Warning: only call once per device!
156static void initDeviceStreamState(DeviceIndex device_index) {
157 // Switches to the requested device so streams are properly associated
158 // with it.
159 CUDAGuard device_guard{device_index};
160
161 for (const auto i : c10::irange(kStreamsPerPool)) {
162 auto& lowpri_stream = low_priority_streams[device_index][i];
163 auto& hipri_stream = high_priority_streams[device_index][i];
164
165 C10_CUDA_CHECK(cudaStreamCreateWithPriority(
166 &lowpri_stream, kDefaultFlags, kLowPriority));
167 C10_CUDA_CHECK(cudaStreamCreateWithPriority(
168 &hipri_stream, kDefaultFlags, kHighPriority));
169
170 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
171 if (C10_UNLIKELY(interp)) {
172 (*interp)->trace_gpu_stream_creation(
173 reinterpret_cast<uintptr_t>(lowpri_stream));
174 (*interp)->trace_gpu_stream_creation(
175 reinterpret_cast<uintptr_t>(hipri_stream));
176 }
177 }
178
179 low_priority_counters[device_index] = 0;
180 high_priority_counters[device_index] = 0;
181}
182
183// Init front-end to ensure initialization only occurs once
184static void initCUDAStreamsOnce() {
185 // Inits default streams (once, globally)
186 c10::call_once(init_flag, initGlobalStreamState);
187
188 if (current_streams) {
189 return;
190 }
191
192 // Inits current streams (thread local) to default streams
193 current_streams = std::make_unique<StreamId[]>(num_gpus);
194 for (const auto i : c10::irange(num_gpus)) {
195 current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0);
196 }
197}
198
199// Helper to verify the GPU index is valid
200static inline void check_gpu(DeviceIndex device_index) {
201 TORCH_INTERNAL_ASSERT(device_index >= 0 && device_index < num_gpus);
202}
203
204// Helper to determine the index of the stream to return
205// Note: Streams are returned round-robin (see note in CUDAStream.h)
206static uint32_t get_idx(std::atomic<uint32_t>& counter) {
207 auto raw_idx = counter++;
208 return raw_idx % kStreamsPerPool;
209}
210
211CUDAStream CUDAStreamForId(DeviceIndex device_index, StreamId stream_id) {
212 return CUDAStream(
213 CUDAStream::UNCHECKED,
214 Stream(
215 Stream::UNSAFE,
216 c10::Device(DeviceType::CUDA, device_index),
217 stream_id));
218}
219
220} // anonymous namespace
221
222// See Note [StreamId assignment]
223cudaStream_t CUDAStream::stream() const {
224 c10::DeviceIndex device_index = stream_.device_index();
225 StreamId stream_id = stream_.id();
226 StreamIdType st = streamIdType(stream_id);
227 size_t si = streamIdIndex(stream_id);
228 switch (st) {
229 case StreamIdType::DEFAULT:
230 TORCH_INTERNAL_ASSERT(
231 si == 0,
232 "Unrecognized stream ",
233 stream_,
234 " (I think this should be the default stream, but I got a non-zero index ",
235 si,
236 ").",
237 " Did you manufacture the StreamId yourself? Don't do that; use the",
238 " official API like c10::cuda::getStreamFromPool() to get a new stream.");
239 return nullptr;
240 case StreamIdType::LOW:
241 return low_priority_streams[device_index][si];
242 case StreamIdType::HIGH:
243 return high_priority_streams[device_index][si];
244 case StreamIdType::EXT:
245 return reinterpret_cast<cudaStream_t>(stream_id);
246 default:
247 TORCH_INTERNAL_ASSERT(
248 0,
249 "Unrecognized stream ",
250 stream_,
251 " (I didn't recognize the stream type, ",
252 st,
253 ")");
254 }
255}
256
257// Returns a stream from the requested pool
258// Note: when called the first time on a device, this will create the
259// stream pools for that device.
260CUDAStream getStreamFromPool(
261 const bool isHighPriority,
262 DeviceIndex device_index) {
263 initCUDAStreamsOnce();
264 if (device_index == -1)
265 device_index = current_device();
266 check_gpu(device_index);
267
268 // Initializes the stream pools (once)
269 c10::call_once(
270 device_flags[device_index], initDeviceStreamState, device_index);
271
272 if (isHighPriority) {
273 const auto idx = get_idx(high_priority_counters[device_index]);
274 return CUDAStreamForId(device_index, makeStreamId(StreamIdType::HIGH, idx));
275 }
276
277 const auto idx = get_idx(low_priority_counters[device_index]);
278 return CUDAStreamForId(device_index, makeStreamId(StreamIdType::LOW, idx));
279}
280
281CUDAStream getStreamFromExternal(
282 cudaStream_t ext_stream,
283 DeviceIndex device_index) {
284 // The stream pointer will be the actual id
285 return CUDAStreamForId(device_index, reinterpret_cast<int64_t>(ext_stream));
286}
287
288CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
289 initCUDAStreamsOnce();
290 if (device_index == -1) {
291 device_index = current_device();
292 }
293 check_gpu(device_index);
294 return CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
295}
296
297CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
298 initCUDAStreamsOnce();
299 if (device_index == -1) {
300 device_index = current_device();
301 }
302 check_gpu(device_index);
303 return CUDAStreamForId(device_index, current_streams[device_index]);
304}
305
306void setCurrentCUDAStream(CUDAStream stream) {
307 initCUDAStreamsOnce();
308 current_streams[stream.device_index()] = stream.id();
309}
310
311std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
312 return stream << s.unwrap();
313}
314
315} // namespace cuda
316} // namespace c10
317