1#pragma once
2
3#include <c10/core/impl/InlineStreamGuard.h>
4
5namespace c10 {
6
7/**
8 * A StreamGuard is an RAII class that changes the current device
9 * to the device corresponding to some stream, and changes the
10 * default stream on that device to be this stream.
11 *
12 * Use of StreamGuard is HIGHLY discouraged in operator definitions. In
13 * a single operator, you probably don't know enough about the global
14 * state of the world to profitably decide how to set streams. Let
15 * the caller handle this appropriately, and just use the current stream
16 * in your operator code.
17 *
18 * This StreamGuard does NOT have an uninitialized state; it is guaranteed
19 * to reset the stream and device on exit. If you are in a situation
20 * where you *might* want to setup a stream guard, see OptionalStreamGuard.
21 */
22struct StreamGuard {
23 /// No default constructor, see Note [Omitted default constructor from RAII]
24 explicit StreamGuard() = delete;
25
26 /// Set the current device to the device associated with the passed stream,
27 /// and set the current stream on that device to the passed stream.
28 explicit StreamGuard(Stream stream) : guard_(stream) {}
29
30 /// Copy is disallowed
31 StreamGuard(const StreamGuard&) = delete;
32 StreamGuard& operator=(const StreamGuard&) = delete;
33
34 /// Move is disallowed, as StreamGuard does not have an uninitialized state,
35 /// which is required for moves on types with nontrivial destructors.
36 StreamGuard(StreamGuard&& other) = delete;
37 StreamGuard& operator=(StreamGuard&& other) = delete;
38
39 /// Resets the currently set stream to the original stream and
40 /// the currently set device to the original device. Then,
41 /// set the current device to the device associated with the passed stream,
42 /// and set the current stream on that device to the passed stream.
43 ///
44 /// NOTE: this implementation may skip some stream/device setting if
45 /// it can prove that it is unnecessary.
46 ///
47 /// WARNING: reset_stream does NOT preserve previously set streams on
48 /// different devices. If you need to set streams on multiple devices
49 /// on , use MultiStreamGuard instead.
50 void reset_stream(Stream stream) {
51 guard_.reset_stream(stream);
52 }
53
54 /// Returns the stream that was set at the time the guard was constructed.
55 Stream original_stream() const {
56 return guard_.original_stream();
57 }
58
59 /// Returns the most recent stream that was set using this device guard,
60 /// either from construction, or via set_stream.
61 Stream current_stream() const {
62 return guard_.current_stream();
63 }
64
65 /// Returns the most recent device that was set using this device guard,
66 /// either from construction, or via set_device/reset_device/set_index.
67 Device current_device() const {
68 return guard_.current_device();
69 }
70
71 /// Returns the device that was set at the most recent reset_stream(),
72 /// or otherwise the device at construction time.
73 Device original_device() const {
74 return guard_.original_device();
75 }
76
77 private:
78 c10::impl::InlineStreamGuard<impl::VirtualGuardImpl> guard_;
79};
80
81/**
82 * An OptionalStreamGuard is an RAII class that sets a device to some value on
83 * initialization, and resets the device to its original value on destruction.
84 * See OptionalDeviceGuard for more guidance on how to use this class.
85 */
86struct OptionalStreamGuard {
87 /// Create an uninitialized guard.
88 explicit OptionalStreamGuard() = default;
89
90 /// Set the current device to the device associated with the passed stream,
91 /// and set the current stream on that device to the passed stream.
92 explicit OptionalStreamGuard(Stream stream) : guard_(stream) {}
93
94 /// Set the current device to the device associated with the passed stream,
95 /// and set the current stream on that device to the passed stream,
96 /// if the passed stream is not nullopt.
97 explicit OptionalStreamGuard(optional<Stream> stream_opt)
98 : guard_(stream_opt) {}
99
100 /// Copy is disallowed
101 OptionalStreamGuard(const OptionalStreamGuard&) = delete;
102 OptionalStreamGuard& operator=(const OptionalStreamGuard&) = delete;
103
104 // See Note [Move construction for RAII guards is tricky]
105 OptionalStreamGuard(OptionalStreamGuard&& other) = delete;
106
107 // See Note [Move assignment for RAII guards is tricky]
108 OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete;
109
110 /// Resets the currently set stream to the original stream and
111 /// the currently set device to the original device. Then,
112 /// set the current device to the device associated with the passed stream,
113 /// and set the current stream on that device to the passed stream.
114 /// Initializes the guard if it was not previously initialized.
115 void reset_stream(Stream stream) {
116 guard_.reset_stream(stream);
117 }
118
119 /// Returns the stream that was set at the time the guard was most recently
120 /// initialized, or nullopt if the guard is uninitialized.
121 optional<Stream> original_stream() const {
122 return guard_.original_stream();
123 }
124
125 /// Returns the most recent stream that was set using this stream guard,
126 /// either from construction, or via reset_stream, if the guard is
127 /// initialized, or nullopt if the guard is uninitialized.
128 optional<Stream> current_stream() const {
129 return guard_.current_stream();
130 }
131
132 /// Restore the original device and stream, resetting this guard to
133 /// uninitialized state.
134 void reset() {
135 guard_.reset();
136 }
137
138 private:
139 c10::impl::InlineOptionalStreamGuard<impl::VirtualGuardImpl> guard_{};
140};
141
142/**
143 * A MultiStreamGuard is an RAII class that sets the current streams of a set of
144 * devices all at once, and resets them to their original values on destruction.
145 */
146struct MultiStreamGuard {
147 /// Set the current streams to the passed streams on each of their respective
148 /// devices.
149 explicit MultiStreamGuard(ArrayRef<Stream> streams) : guard_(streams) {}
150
151 /// Copy is disallowed
152 MultiStreamGuard(const MultiStreamGuard&) = delete;
153 MultiStreamGuard& operator=(const MultiStreamGuard&) = delete;
154
155 // See Note [Move construction for RAII guards is tricky]
156 MultiStreamGuard(MultiStreamGuard&& other) = delete;
157
158 // See Note [Move assignment for RAII guards is tricky]
159 MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete;
160
161 private:
162 c10::impl::InlineMultiStreamGuard<impl::VirtualGuardImpl> guard_;
163};
164
165} // namespace c10
166