1 | #pragma once |
2 | |
3 | #include <c10/core/impl/InlineStreamGuard.h> |
4 | |
5 | namespace c10 { |
6 | |
7 | /** |
8 | * A StreamGuard is an RAII class that changes the current device |
9 | * to the device corresponding to some stream, and changes the |
10 | * default stream on that device to be this stream. |
11 | * |
12 | * Use of StreamGuard is HIGHLY discouraged in operator definitions. In |
13 | * a single operator, you probably don't know enough about the global |
14 | * state of the world to profitably decide how to set streams. Let |
15 | * the caller handle this appropriately, and just use the current stream |
16 | * in your operator code. |
17 | * |
18 | * This StreamGuard does NOT have an uninitialized state; it is guaranteed |
19 | * to reset the stream and device on exit. If you are in a situation |
20 | * where you *might* want to setup a stream guard, see OptionalStreamGuard. |
21 | */ |
22 | struct StreamGuard { |
23 | /// No default constructor, see Note [Omitted default constructor from RAII] |
24 | explicit StreamGuard() = delete; |
25 | |
26 | /// Set the current device to the device associated with the passed stream, |
27 | /// and set the current stream on that device to the passed stream. |
28 | explicit StreamGuard(Stream stream) : guard_(stream) {} |
29 | |
30 | /// Copy is disallowed |
31 | StreamGuard(const StreamGuard&) = delete; |
32 | StreamGuard& operator=(const StreamGuard&) = delete; |
33 | |
34 | /// Move is disallowed, as StreamGuard does not have an uninitialized state, |
35 | /// which is required for moves on types with nontrivial destructors. |
36 | StreamGuard(StreamGuard&& other) = delete; |
37 | StreamGuard& operator=(StreamGuard&& other) = delete; |
38 | |
39 | /// Resets the currently set stream to the original stream and |
40 | /// the currently set device to the original device. Then, |
41 | /// set the current device to the device associated with the passed stream, |
42 | /// and set the current stream on that device to the passed stream. |
43 | /// |
44 | /// NOTE: this implementation may skip some stream/device setting if |
45 | /// it can prove that it is unnecessary. |
46 | /// |
47 | /// WARNING: reset_stream does NOT preserve previously set streams on |
48 | /// different devices. If you need to set streams on multiple devices |
49 | /// on , use MultiStreamGuard instead. |
50 | void reset_stream(Stream stream) { |
51 | guard_.reset_stream(stream); |
52 | } |
53 | |
54 | /// Returns the stream that was set at the time the guard was constructed. |
55 | Stream original_stream() const { |
56 | return guard_.original_stream(); |
57 | } |
58 | |
59 | /// Returns the most recent stream that was set using this device guard, |
60 | /// either from construction, or via set_stream. |
61 | Stream current_stream() const { |
62 | return guard_.current_stream(); |
63 | } |
64 | |
65 | /// Returns the most recent device that was set using this device guard, |
66 | /// either from construction, or via set_device/reset_device/set_index. |
67 | Device current_device() const { |
68 | return guard_.current_device(); |
69 | } |
70 | |
71 | /// Returns the device that was set at the most recent reset_stream(), |
72 | /// or otherwise the device at construction time. |
73 | Device original_device() const { |
74 | return guard_.original_device(); |
75 | } |
76 | |
77 | private: |
78 | c10::impl::InlineStreamGuard<impl::VirtualGuardImpl> guard_; |
79 | }; |
80 | |
81 | /** |
82 | * An OptionalStreamGuard is an RAII class that sets a device to some value on |
83 | * initialization, and resets the device to its original value on destruction. |
84 | * See OptionalDeviceGuard for more guidance on how to use this class. |
85 | */ |
86 | struct OptionalStreamGuard { |
87 | /// Create an uninitialized guard. |
88 | explicit OptionalStreamGuard() = default; |
89 | |
90 | /// Set the current device to the device associated with the passed stream, |
91 | /// and set the current stream on that device to the passed stream. |
92 | explicit OptionalStreamGuard(Stream stream) : guard_(stream) {} |
93 | |
94 | /// Set the current device to the device associated with the passed stream, |
95 | /// and set the current stream on that device to the passed stream, |
96 | /// if the passed stream is not nullopt. |
97 | explicit OptionalStreamGuard(optional<Stream> stream_opt) |
98 | : guard_(stream_opt) {} |
99 | |
100 | /// Copy is disallowed |
101 | OptionalStreamGuard(const OptionalStreamGuard&) = delete; |
102 | OptionalStreamGuard& operator=(const OptionalStreamGuard&) = delete; |
103 | |
104 | // See Note [Move construction for RAII guards is tricky] |
105 | OptionalStreamGuard(OptionalStreamGuard&& other) = delete; |
106 | |
107 | // See Note [Move assignment for RAII guards is tricky] |
108 | OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete; |
109 | |
110 | /// Resets the currently set stream to the original stream and |
111 | /// the currently set device to the original device. Then, |
112 | /// set the current device to the device associated with the passed stream, |
113 | /// and set the current stream on that device to the passed stream. |
114 | /// Initializes the guard if it was not previously initialized. |
115 | void reset_stream(Stream stream) { |
116 | guard_.reset_stream(stream); |
117 | } |
118 | |
119 | /// Returns the stream that was set at the time the guard was most recently |
120 | /// initialized, or nullopt if the guard is uninitialized. |
121 | optional<Stream> original_stream() const { |
122 | return guard_.original_stream(); |
123 | } |
124 | |
125 | /// Returns the most recent stream that was set using this stream guard, |
126 | /// either from construction, or via reset_stream, if the guard is |
127 | /// initialized, or nullopt if the guard is uninitialized. |
128 | optional<Stream> current_stream() const { |
129 | return guard_.current_stream(); |
130 | } |
131 | |
132 | /// Restore the original device and stream, resetting this guard to |
133 | /// uninitialized state. |
134 | void reset() { |
135 | guard_.reset(); |
136 | } |
137 | |
138 | private: |
139 | c10::impl::InlineOptionalStreamGuard<impl::VirtualGuardImpl> guard_{}; |
140 | }; |
141 | |
142 | /** |
143 | * A MultiStreamGuard is an RAII class that sets the current streams of a set of |
144 | * devices all at once, and resets them to their original values on destruction. |
145 | */ |
146 | struct MultiStreamGuard { |
147 | /// Set the current streams to the passed streams on each of their respective |
148 | /// devices. |
149 | explicit MultiStreamGuard(ArrayRef<Stream> streams) : guard_(streams) {} |
150 | |
151 | /// Copy is disallowed |
152 | MultiStreamGuard(const MultiStreamGuard&) = delete; |
153 | MultiStreamGuard& operator=(const MultiStreamGuard&) = delete; |
154 | |
155 | // See Note [Move construction for RAII guards is tricky] |
156 | MultiStreamGuard(MultiStreamGuard&& other) = delete; |
157 | |
158 | // See Note [Move assignment for RAII guards is tricky] |
159 | MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete; |
160 | |
161 | private: |
162 | c10::impl::InlineMultiStreamGuard<impl::VirtualGuardImpl> guard_; |
163 | }; |
164 | |
165 | } // namespace c10 |
166 | |