1 | #pragma once |
2 | |
3 | #include <c10/core/Device.h> |
4 | |
5 | namespace c10 { |
6 | |
/// An index representing a specific stream.  A StreamId is not independently
/// meaningful without knowing the Device it is associated with; prefer
/// using Stream over a bare StreamId.
///
/// StreamIds are opaque; they are assigned by some DeviceType-specific
/// numbering system which is not visible to the user.  HOWEVER, we
/// guarantee that StreamId 0 is always a valid stream, and corresponds
/// to some sort of "default" stream.
///
/// The underlying type is a signed 64-bit integer.  Stream::hash() notes
/// that the id might be used to hold a pointer.
using StreamId = int64_t;
16 | |
/// Plain aggregate holding the three components that identify a Stream.
/// It is what Stream::pack3() returns; Stream::unpack3() accepts the same
/// three values as separate arguments.
///
/// The member order matters: Stream::pack3() brace-initializes this struct
/// positionally as {stream id, device index, device type}.
struct C10_API StreamData3 {
  // Backend-specific stream identifier (see StreamId).
  StreamId stream_id;
  // Index of the device the stream is associated with.
  DeviceIndex device_index;
  // Type of the device the stream is associated with.
  DeviceType device_type;
};
22 | |
// NB: StreamId is deliberately NOT called StreamIndex, to avoid confusion
// with DeviceIndex.  This way, on a Stream you access the device index with
// device_index(), and the stream id with id().
26 | |
27 | /** |
28 | * A stream is a software mechanism used to synchronize launched kernels |
29 | * without requiring explicit synchronizations between kernels. The basic |
30 | * model is that every kernel launch is associated with a stream: every |
31 | * kernel on the same stream is implicitly synchronized so that if I launch |
32 | * kernels A and B on the same stream, A is guaranteed to finish before B |
33 | * launches. If I want B to run concurrently with A, I must schedule |
34 | * it on a different stream. |
35 | * |
 * The Stream class is a backend-agnostic value class representing a stream
 * on which I may schedule a kernel.  Every stream is associated with a
 * device; the device is recorded in the Stream itself, to avoid confusion
 * about which device a stream refers to.
40 | * |
 * Streams are explicitly thread-safe, in the sense that it is OK to pass
 * a Stream from one thread to another, and kernels queued from two different
 * threads will still get serialized appropriately.  (Of course, the
 * time at which the kernels get queued is undetermined unless you
 * synchronize on the host side.)
46 | * |
47 | * Stream does NOT have a default constructor. Streams are for expert |
48 | * users; if you want to use Streams, we're going to assume you know |
49 | * how to deal with C++ template error messages if you try to |
50 | * resize() a vector of Streams. |
51 | * |
52 | * Known instances of streams in backends: |
53 | * |
54 | * - cudaStream_t (CUDA) |
55 | * - hipStream_t (HIP) |
56 | * - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration |
57 | * does NOT support command queues.) |
58 | * |
59 | * Because this class is device agnostic, it cannot provide backend-specific |
60 | * functionality (e.g., get the cudaStream_t of a CUDA stream.) There are |
61 | * wrapper classes which provide this functionality, e.g., CUDAStream. |
62 | */ |
63 | class C10_API Stream final { |
64 | private: |
65 | Device device_; |
66 | StreamId id_; |
67 | |
68 | public: |
69 | enum Unsafe { UNSAFE }; |
70 | enum Default { DEFAULT }; |
71 | |
72 | /// Unsafely construct a stream from a Device and a StreamId. In |
73 | /// general, only specific implementations of streams for a |
74 | /// backend should manufacture Stream directly in this way; other users |
75 | /// should use the provided APIs to get a stream. In particular, |
76 | /// we don't require backends to give any guarantees about non-zero |
77 | /// StreamIds; they are welcome to allocate in whatever way they like. |
78 | explicit Stream(Unsafe, Device device, StreamId id) |
79 | : device_(device), id_(id) {} |
80 | |
81 | /// Construct the default stream of a Device. The default stream is |
82 | /// NOT the same as the current stream; default stream is a fixed stream |
83 | /// that never changes, whereas the current stream may be changed by |
84 | /// StreamGuard. |
85 | explicit Stream(Default, Device device) : device_(device), id_(0) {} |
86 | |
87 | bool operator==(const Stream& other) const noexcept { |
88 | return this->device_ == other.device_ && this->id_ == other.id_; |
89 | } |
90 | bool operator!=(const Stream& other) const noexcept { |
91 | return !(*this == other); |
92 | } |
93 | |
94 | Device device() const noexcept { |
95 | return device_; |
96 | } |
97 | DeviceType device_type() const noexcept { |
98 | return device_.type(); |
99 | } |
100 | DeviceIndex device_index() const noexcept { |
101 | return device_.index(); |
102 | } |
103 | StreamId id() const noexcept { |
104 | return id_; |
105 | } |
106 | |
107 | // Enqueues a wait instruction in the stream's work queue. |
108 | // This instruction is a no-op unless the event is marked |
109 | // for recording. In that case the stream stops processing |
110 | // until the event is recorded. |
111 | template <typename T> |
112 | void wait(const T& event) const { |
113 | event.block(*this); |
114 | } |
115 | |
116 | // Return whether all asynchronous work previously enqueued on this stream |
117 | // has completed running on the device. |
118 | bool query() const; |
119 | |
120 | // Wait (by blocking the calling thread) until all asynchronous work enqueued |
121 | // on this stream has completed running on the device. |
122 | void synchronize() const; |
123 | |
124 | // The purpose of this function is to more conveniently permit binding |
125 | // of Stream to and from Python. Without packing, I have to setup a whole |
126 | // class with two fields (device and stream id); with packing I can just |
127 | // store a single uint64_t. |
128 | // |
129 | // The particular way we pack streams into a uint64_t is considered an |
130 | // implementation detail and should not be relied upon. |
131 | uint64_t hash() const noexcept { |
132 | // Concat these together into a 64-bit integer |
133 | uint64_t bits = static_cast<uint64_t>(device_type()) << 56 | |
134 | static_cast<uint64_t>(device_index()) << 48 | |
135 | // Remove the sign extension part of the 64-bit address because |
136 | // the id might be used to hold a pointer. |
137 | (static_cast<uint64_t>(id()) & ((1ull << 48) - 1)); |
138 | return bits; |
139 | } |
140 | |
141 | struct StreamData3 pack3() const { |
142 | return {id(), device_index(), device_type()}; |
143 | } |
144 | |
145 | static Stream unpack3( |
146 | StreamId stream_id, |
147 | DeviceIndex device_index, |
148 | DeviceType device_type) { |
149 | TORCH_CHECK(isValidDeviceType(device_type)); |
150 | return Stream(UNSAFE, Device(device_type, device_index), stream_id); |
151 | } |
152 | |
153 | // I decided NOT to provide setters on this class, because really, |
154 | // why would you change the device of a stream? Just construct |
155 | // it correctly from the beginning dude. |
156 | }; |
157 | |
// Stream insertion operator for Stream; only declared here, the definition
// (and therefore the exact text written) lives out of line.
C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
159 | |
160 | } // namespace c10 |
161 | |
162 | namespace std { |
163 | template <> |
164 | struct hash<c10::Stream> { |
165 | size_t operator()(c10::Stream s) const noexcept { |
166 | return std::hash<uint64_t>{}(s.hash()); |
167 | } |
168 | }; |
169 | } // namespace std |
170 | |