1#pragma once
2
3#include <c10/core/impl/InlineDeviceGuard.h>
4
5namespace c10 {
6
7/// RAII guard that sets a certain default device in its constructor, and
8/// changes it back to the device that was originally active upon destruction.
9///
10/// The device is always reset to the one that was active at the time of
11/// construction of the guard. Even if you `set_device` after construction, the
12/// destructor will still reset the device to the one that was active at
13/// construction time.
14///
15/// This device guard does NOT have an uninitialized state; it is guaranteed
16/// to reset a device on exit. If you are in a situation where you *might*
17/// want to setup a guard (i.e., are looking for the moral equivalent
18/// of optional<DeviceGuard>), see OptionalDeviceGuard.
19class DeviceGuard {
20 public:
21 /// No default constructor; see Note [Omitted default constructor from RAII]
22 explicit DeviceGuard() = delete;
23
24 /// Set the current device to the passed Device.
25 explicit DeviceGuard(Device device) : guard_(device) {}
26
27 /// This constructor is for testing only.
28 explicit DeviceGuard(
29 Device device,
30 const impl::DeviceGuardImplInterface* impl)
31 : guard_(device, impl) {}
32
33 /// Copy is disallowed
34 DeviceGuard(const DeviceGuard&) = delete;
35 DeviceGuard& operator=(const DeviceGuard&) = delete;
36
37 /// Move is disallowed, as DeviceGuard does not have an uninitialized state,
38 /// which is required for moves on types with nontrivial destructors.
39 DeviceGuard(DeviceGuard&& other) = delete;
40 DeviceGuard& operator=(DeviceGuard&& other) = delete;
41
42 /// Sets the device to the given one. The specified device must be consistent
43 /// with the device type originally specified during guard construction.
44 ///
45 /// TODO: The consistency check here is inconsistent with StreamGuard's
46 /// behavior with set_stream, where a stream on a different device than
47 /// the original one isn't an error; we just reset the stream and then
48 /// switch devices.
49 void reset_device(at::Device device) {
50 guard_.reset_device(device);
51 }
52
53 /// This method is for testing only.
54 void reset_device(
55 at::Device device,
56 const impl::DeviceGuardImplInterface* impl) {
57 guard_.reset_device(device, impl);
58 }
59
60 /// Sets the device index to the given one. The device type is inferred
61 /// from the original device type the guard was constructed with.
62 void set_index(DeviceIndex index) {
63 guard_.set_index(index);
64 }
65
66 /// Returns the device that was set at the time the guard was constructed.
67 Device original_device() const {
68 return guard_.original_device();
69 }
70
71 /// Returns the most recent device that was set using this device guard,
72 /// either from construction, or via set_device.
73 Device current_device() const {
74 return guard_.current_device();
75 }
76
77 private:
78 impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
79};
80
81/**
82 * A OptionalDeviceGuard is an RAII class that sets a device to some value on
83 * initialization, and resets the device to its original value on destruction.
84 * Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
85 * with extra constructors and methods as appropriate.
86 *
87 * Besides its obvious use (optionally applying a DeviceGuard),
88 * OptionalDeviceGuard is often also used for the following idiom:
89 *
90 * OptionalDeviceGuard g;
91 * for (const auto& t : tensors) {
92 * g.set_device(t.device());
93 * do_something_with(t);
94 * }
95 *
96 * This usage is marginally more efficient than constructing a DeviceGuard every
97 * iteration of the for loop, as it avoids an unnecessary device reset.
98 *
99 * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs
100 * when you use the nullary constructor, or pass a nullopt to the constructor.
101 * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
102 * original device was and they do not reset on destruction. This is why
103 * original_device() and current_device() return optional<Device> rather than
104 * Device (as they do in DeviceGuard), and also is why we didn't just
105 * provide OptionalDeviceGuard by default and hide DeviceGuard from users.
106 *
107 * The semantics of an OptionalDeviceGuard are exactly explained by thinking
108 * of it as an optional<DeviceGuard>. In particular, an initialized
109 * OptionalDeviceGuard doesn't restore device to its value at construction; it
110 * restores device to its value *at initialization*. So if you have the
111 * program:
112 *
113 * setDevice(1);
114 * OptionalDeviceGuard g;
115 * setDevice(2);
116 * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
117 *
118 * On destruction, g will reset device to 2, rather than 1.
119 *
120 * An uninitialized OptionalDeviceGuard is distinct from a (initialized)
121 * DeviceGuard whose original_device_ and current_device_ match, since the
122 * DeviceGuard will still reset the device to original_device_.
123 */
124class OptionalDeviceGuard {
125 public:
126 /// Create an uninitialized guard. Set the guard later using reset_device.
127 explicit OptionalDeviceGuard() = default;
128
129 /// Initialize the guard, setting the current device to the passed Device.
130 explicit OptionalDeviceGuard(Device device) : guard_(device) {}
131
132 /// Initialize the guard if a Device is passed; otherwise leave the
133 /// guard uninitialized.
134 explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
135
136 /// Constructor for testing only.
137 explicit OptionalDeviceGuard(
138 Device device,
139 const impl::DeviceGuardImplInterface* impl)
140 : guard_(device, impl) {}
141
142 /// Copy is disallowed
143 OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
144 OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;
145
146 /// Move is disallowed
147 /// See Note [Explicit initialization of optional fields]
148 /// and // Note [Move construction for RAII guards is tricky]
149 /// for rationale.
150 OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
151 OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;
152
153 /// Sets the device to the given one. The specified device must be consistent
154 /// with the device type originally specified during guard construction.
155 void reset_device(at::Device device) {
156 guard_.reset_device(device);
157 }
158
159 /// For testing only
160 void reset_device(
161 at::Device device,
162 const impl::DeviceGuardImplInterface* impl) {
163 guard_.reset_device(device, impl);
164 }
165
166 /// Returns the device that was set at the time the guard was constructed.
167 optional<Device> original_device() const {
168 return guard_.original_device();
169 }
170
171 /// Returns the most recent device that was set using this device guard,
172 /// either from construction, or via reset_device.
173 optional<Device> current_device() const {
174 return guard_.current_device();
175 }
176
177 private:
178 impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{};
179};
180
181// Note [Whither the DeviceGuard boilerplate]
182// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
183// Design note: in principle, we could avoid these wrappers using:
184//
185// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
186// using OptionalDeviceGuard =
187// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
188//
189// But the error messages are worse, and our users can't just look at the
190// header file to find out what's going on. Furthermore, for specializations
191// like CUDAStreamGuard, it can be profitable to replace some interfaces with
192// refined types (e.g., return CUDAStream instead of Stream). So, we eat
193// the boilerplate and write out the API explicitly.
194
195} // namespace c10
196