1#pragma once
2
3#include <c10/core/Device.h>
4
5namespace c10 {
6
/// Identifier of one particular stream on a device. On its own a StreamId
/// carries no meaning: it has to be paired with the Device it belongs to,
/// so prefer passing a Stream around rather than a bare StreamId.
///
/// The values are opaque. Each DeviceType assigns them with its own
/// numbering scheme, which is not exposed to the user. The one guarantee
/// we DO make is that StreamId 0 is always a valid stream and denotes some
/// kind of "default" stream.
using StreamId = int64_t;
16
17struct C10_API StreamData3 {
18 StreamId stream_id;
19 DeviceIndex device_index;
20 DeviceType device_type;
21};
22
// NB: I decided not to call StreamId "StreamIndex", to avoid confusion with
// DeviceIndex. This way, you access the device index with device_index() and
// the stream id with id().
26
27/**
28 * A stream is a software mechanism used to synchronize launched kernels
29 * without requiring explicit synchronizations between kernels. The basic
30 * model is that every kernel launch is associated with a stream: every
31 * kernel on the same stream is implicitly synchronized so that if I launch
32 * kernels A and B on the same stream, A is guaranteed to finish before B
33 * launches. If I want B to run concurrently with A, I must schedule
34 * it on a different stream.
35 *
 * The Stream class is a backend agnostic value class representing a stream
 * which I may schedule a kernel on. Every stream is associated with a device,
 * which is recorded in the Stream itself to avoid confusion about which
 * device a stream refers to.
40 *
41 * Streams are explicitly thread-safe, in the sense that it is OK to pass
42 * a Stream from one thread to another, and kernels queued from two different
43 * threads will still get serialized appropriately. (Of course, the
44 * time when the kernels get queued is undetermined unless you synchronize
45 * host side ;)
46 *
47 * Stream does NOT have a default constructor. Streams are for expert
48 * users; if you want to use Streams, we're going to assume you know
49 * how to deal with C++ template error messages if you try to
50 * resize() a vector of Streams.
51 *
52 * Known instances of streams in backends:
53 *
54 * - cudaStream_t (CUDA)
55 * - hipStream_t (HIP)
56 * - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration
57 * does NOT support command queues.)
58 *
59 * Because this class is device agnostic, it cannot provide backend-specific
60 * functionality (e.g., get the cudaStream_t of a CUDA stream.) There are
61 * wrapper classes which provide this functionality, e.g., CUDAStream.
62 */
63class C10_API Stream final {
64 private:
65 Device device_;
66 StreamId id_;
67
68 public:
69 enum Unsafe { UNSAFE };
70 enum Default { DEFAULT };
71
72 /// Unsafely construct a stream from a Device and a StreamId. In
73 /// general, only specific implementations of streams for a
74 /// backend should manufacture Stream directly in this way; other users
75 /// should use the provided APIs to get a stream. In particular,
76 /// we don't require backends to give any guarantees about non-zero
77 /// StreamIds; they are welcome to allocate in whatever way they like.
78 explicit Stream(Unsafe, Device device, StreamId id)
79 : device_(device), id_(id) {}
80
81 /// Construct the default stream of a Device. The default stream is
82 /// NOT the same as the current stream; default stream is a fixed stream
83 /// that never changes, whereas the current stream may be changed by
84 /// StreamGuard.
85 explicit Stream(Default, Device device) : device_(device), id_(0) {}
86
87 bool operator==(const Stream& other) const noexcept {
88 return this->device_ == other.device_ && this->id_ == other.id_;
89 }
90 bool operator!=(const Stream& other) const noexcept {
91 return !(*this == other);
92 }
93
94 Device device() const noexcept {
95 return device_;
96 }
97 DeviceType device_type() const noexcept {
98 return device_.type();
99 }
100 DeviceIndex device_index() const noexcept {
101 return device_.index();
102 }
103 StreamId id() const noexcept {
104 return id_;
105 }
106
107 // Enqueues a wait instruction in the stream's work queue.
108 // This instruction is a no-op unless the event is marked
109 // for recording. In that case the stream stops processing
110 // until the event is recorded.
111 template <typename T>
112 void wait(const T& event) const {
113 event.block(*this);
114 }
115
116 // Return whether all asynchronous work previously enqueued on this stream
117 // has completed running on the device.
118 bool query() const;
119
120 // Wait (by blocking the calling thread) until all asynchronous work enqueued
121 // on this stream has completed running on the device.
122 void synchronize() const;
123
124 // The purpose of this function is to more conveniently permit binding
125 // of Stream to and from Python. Without packing, I have to setup a whole
126 // class with two fields (device and stream id); with packing I can just
127 // store a single uint64_t.
128 //
129 // The particular way we pack streams into a uint64_t is considered an
130 // implementation detail and should not be relied upon.
131 uint64_t hash() const noexcept {
132 // Concat these together into a 64-bit integer
133 uint64_t bits = static_cast<uint64_t>(device_type()) << 56 |
134 static_cast<uint64_t>(device_index()) << 48 |
135 // Remove the sign extension part of the 64-bit address because
136 // the id might be used to hold a pointer.
137 (static_cast<uint64_t>(id()) & ((1ull << 48) - 1));
138 return bits;
139 }
140
141 struct StreamData3 pack3() const {
142 return {id(), device_index(), device_type()};
143 }
144
145 static Stream unpack3(
146 StreamId stream_id,
147 DeviceIndex device_index,
148 DeviceType device_type) {
149 TORCH_CHECK(isValidDeviceType(device_type));
150 return Stream(UNSAFE, Device(device_type, device_index), stream_id);
151 }
152
153 // I decided NOT to provide setters on this class, because really,
154 // why would you change the device of a stream? Just construct
155 // it correctly from the beginning dude.
156};
157
158C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
159
160} // namespace c10
161
162namespace std {
163template <>
164struct hash<c10::Stream> {
165 size_t operator()(c10::Stream s) const noexcept {
166 return std::hash<uint64_t>{}(s.hash());
167 }
168};
169} // namespace std
170