1 | #pragma once |
2 | |
3 | #include <c10/core/impl/InlineDeviceGuard.h> |
4 | |
5 | namespace c10 { |
6 | |
7 | /// RAII guard that sets a certain default device in its constructor, and |
8 | /// changes it back to the device that was originally active upon destruction. |
9 | /// |
/// The device is always reset to the one that was active at the time of
/// construction of the guard. Even if you call `reset_device` or `set_index`
/// after construction, the destructor will still reset the device to the one
/// that was active at construction time.
14 | /// |
15 | /// This device guard does NOT have an uninitialized state; it is guaranteed |
16 | /// to reset a device on exit. If you are in a situation where you *might* |
17 | /// want to setup a guard (i.e., are looking for the moral equivalent |
18 | /// of optional<DeviceGuard>), see OptionalDeviceGuard. |
19 | class DeviceGuard { |
20 | public: |
21 | /// No default constructor; see Note [Omitted default constructor from RAII] |
22 | explicit DeviceGuard() = delete; |
23 | |
24 | /// Set the current device to the passed Device. |
25 | explicit DeviceGuard(Device device) : guard_(device) {} |
26 | |
27 | /// This constructor is for testing only. |
28 | explicit DeviceGuard( |
29 | Device device, |
30 | const impl::DeviceGuardImplInterface* impl) |
31 | : guard_(device, impl) {} |
32 | |
33 | /// Copy is disallowed |
34 | DeviceGuard(const DeviceGuard&) = delete; |
35 | DeviceGuard& operator=(const DeviceGuard&) = delete; |
36 | |
37 | /// Move is disallowed, as DeviceGuard does not have an uninitialized state, |
38 | /// which is required for moves on types with nontrivial destructors. |
39 | DeviceGuard(DeviceGuard&& other) = delete; |
40 | DeviceGuard& operator=(DeviceGuard&& other) = delete; |
41 | |
42 | /// Sets the device to the given one. The specified device must be consistent |
43 | /// with the device type originally specified during guard construction. |
44 | /// |
45 | /// TODO: The consistency check here is inconsistent with StreamGuard's |
46 | /// behavior with set_stream, where a stream on a different device than |
47 | /// the original one isn't an error; we just reset the stream and then |
48 | /// switch devices. |
49 | void reset_device(at::Device device) { |
50 | guard_.reset_device(device); |
51 | } |
52 | |
53 | /// This method is for testing only. |
54 | void reset_device( |
55 | at::Device device, |
56 | const impl::DeviceGuardImplInterface* impl) { |
57 | guard_.reset_device(device, impl); |
58 | } |
59 | |
60 | /// Sets the device index to the given one. The device type is inferred |
61 | /// from the original device type the guard was constructed with. |
62 | void set_index(DeviceIndex index) { |
63 | guard_.set_index(index); |
64 | } |
65 | |
66 | /// Returns the device that was set at the time the guard was constructed. |
67 | Device original_device() const { |
68 | return guard_.original_device(); |
69 | } |
70 | |
71 | /// Returns the most recent device that was set using this device guard, |
72 | /// either from construction, or via set_device. |
73 | Device current_device() const { |
74 | return guard_.current_device(); |
75 | } |
76 | |
77 | private: |
78 | impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_; |
79 | }; |
80 | |
81 | /** |
 * An OptionalDeviceGuard is an RAII class that sets a device to some value on
 * initialization, and resets the device to its original value on destruction.
 * Morally, an OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
 * with extra constructors and methods as appropriate.
86 | * |
87 | * Besides its obvious use (optionally applying a DeviceGuard), |
88 | * OptionalDeviceGuard is often also used for the following idiom: |
89 | * |
 *   OptionalDeviceGuard g;
 *   for (const auto& t : tensors) {
 *     g.reset_device(t.device());
 *     do_something_with(t);
 *   }
95 | * |
96 | * This usage is marginally more efficient than constructing a DeviceGuard every |
97 | * iteration of the for loop, as it avoids an unnecessary device reset. |
98 | * |
 * Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized. This occurs
 * when you use the nullary constructor, or pass a nullopt to the constructor.
101 | * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the |
102 | * original device was and they do not reset on destruction. This is why |
103 | * original_device() and current_device() return optional<Device> rather than |
104 | * Device (as they do in DeviceGuard), and also is why we didn't just |
105 | * provide OptionalDeviceGuard by default and hide DeviceGuard from users. |
106 | * |
107 | * The semantics of an OptionalDeviceGuard are exactly explained by thinking |
108 | * of it as an optional<DeviceGuard>. In particular, an initialized |
109 | * OptionalDeviceGuard doesn't restore device to its value at construction; it |
110 | * restores device to its value *at initialization*. So if you have the |
111 | * program: |
112 | * |
113 | * setDevice(1); |
114 | * OptionalDeviceGuard g; |
115 | * setDevice(2); |
116 | * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes! |
117 | * |
118 | * On destruction, g will reset device to 2, rather than 1. |
119 | * |
 * An uninitialized OptionalDeviceGuard is distinct from an (initialized)
 * DeviceGuard whose original_device_ and current_device_ match, since the
 * DeviceGuard will still reset the device to original_device_.
123 | */ |
124 | class OptionalDeviceGuard { |
125 | public: |
126 | /// Create an uninitialized guard. Set the guard later using reset_device. |
127 | explicit OptionalDeviceGuard() = default; |
128 | |
129 | /// Initialize the guard, setting the current device to the passed Device. |
130 | explicit OptionalDeviceGuard(Device device) : guard_(device) {} |
131 | |
132 | /// Initialize the guard if a Device is passed; otherwise leave the |
133 | /// guard uninitialized. |
134 | explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {} |
135 | |
136 | /// Constructor for testing only. |
137 | explicit OptionalDeviceGuard( |
138 | Device device, |
139 | const impl::DeviceGuardImplInterface* impl) |
140 | : guard_(device, impl) {} |
141 | |
142 | /// Copy is disallowed |
143 | OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; |
144 | OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; |
145 | |
146 | /// Move is disallowed |
147 | /// See Note [Explicit initialization of optional fields] |
148 | /// and // Note [Move construction for RAII guards is tricky] |
149 | /// for rationale. |
150 | OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; |
151 | OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; |
152 | |
153 | /// Sets the device to the given one. The specified device must be consistent |
154 | /// with the device type originally specified during guard construction. |
155 | void reset_device(at::Device device) { |
156 | guard_.reset_device(device); |
157 | } |
158 | |
159 | /// For testing only |
160 | void reset_device( |
161 | at::Device device, |
162 | const impl::DeviceGuardImplInterface* impl) { |
163 | guard_.reset_device(device, impl); |
164 | } |
165 | |
166 | /// Returns the device that was set at the time the guard was constructed. |
167 | optional<Device> original_device() const { |
168 | return guard_.original_device(); |
169 | } |
170 | |
171 | /// Returns the most recent device that was set using this device guard, |
172 | /// either from construction, or via reset_device. |
173 | optional<Device> current_device() const { |
174 | return guard_.current_device(); |
175 | } |
176 | |
177 | private: |
178 | impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{}; |
179 | }; |
180 | |
181 | // Note [Whither the DeviceGuard boilerplate] |
182 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
183 | // Design note: in principle, we could avoid these wrappers using: |
184 | // |
185 | // using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>; |
186 | // using OptionalDeviceGuard = |
187 | // impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>; |
188 | // |
189 | // But the error messages are worse, and our users can't just look at the |
190 | // header file to find out what's going on. Furthermore, for specializations |
191 | // like CUDAStreamGuard, it can be profitable to replace some interfaces with |
192 | // refined types (e.g., return CUDAStream instead of Stream). So, we eat |
193 | // the boilerplate and write out the API explicitly. |
194 | |
195 | } // namespace c10 |
196 | |