#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {

namespace {
// look what you made me do >.<
// Divergent paths for per-impl stream recording that leak implementation
// details of the impls should not be needed here.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up once https://github.com/pytorch/pytorch/issues/60306 is
// resolved.
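// Records every storage backing `var` on `stream`, so that the CUDA caching
// allocator does not hand the memory back out before work already queued on
// `stream` has finished using it.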
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
  const auto guard = c10::impl::VirtualGuardImpl(c10::DeviceType::CUDA);

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}

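// True if `v` is a plain, dense, non-overlapping tensor to which we hold the
// only reference (both the TensorImpl and its storage), i.e. it is safe to
// reuse `v` and accumulate into it in place.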
bool can_accumulate_inplace(const Variable& v) {
  return (
      // `v` is a "vanilla" Tensor
      !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) &&

      // with a favorable memory layout
      v.is_non_overlapping_and_dense() &&

      // and we hold the last reference
      v.use_count() == 1 && v.has_storage() && v.storage().use_count() == 1);
}
} // anonymous namespace

static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage we will try to
  // repurpose it to store the output. (Or, if `old_var` is sparse then `var`
  // becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled since Autograd has special handling for inplace
  //     mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count` and `.storage().use_count()` are one)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.

  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}

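// Adds `var` to slot `pos`, accumulating with any gradient already stored
// there. For CUDA gradients, the producer stream (where the gradient was
// computed) and the consumer stream (where the next node will consume it)
// determine the stream used for accumulation and the synchronization
// performed; the cases are enumerated below.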
void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const c10::optional<c10::Stream>& opt_producer_stream,
    const c10::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  if (!var.defined()) {
    return;
  }

  // Switches to the accumulation device.
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA variable. Accumulation happens on var's device.
  //  (2) var is a CUDA variable and it, the consumer, and the producer share
  //      the same device:
  //       (2a) Uses the consumer's stream as the accumulation stream.
  //       (2b) Syncs the accumulation stream with the producer's stream (if
  //            different).
  //       (2c) Accumulates.
  //  (3) var is a CUDA variable and it shares a device with the consumer but
  //      not the producer:
  //       (3a) Uses the consumer's stream as the accumulation stream.
  //       (3b) Syncs the accumulation stream with the consumer device's
  //            default stream.
  //       (3c) Accumulates.
  //  (4) var is a CUDA variable and it shares a device with the producer but
  //      not the consumer:
  //       (4a) Uses the producer device's default stream as the accumulation
  //            stream.
  //       (4b) Syncs the accumulation stream with the producer's stream.
  //       (4c) Accumulates.
  //  (5) var is a CUDA variable and it does not share a device with the
  //      consumer or producer.
  //      Accumulation happens on the var device's default stream.

  TORCH_INTERNAL_ASSERT(device_of(var));
  c10::optional<c10::Stream> opt_accumulate_stream = c10::nullopt;
  if (device_of(var)->is_cuda()) {
    const auto on_producer =
        opt_producer_stream && device_of(var) == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device_of(var) == opt_consumer_stream->device();

    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    } else {
      c10::optional<c10::Stream> opt_sync_stream = c10::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }

  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
    if (opt_accumulate_stream) {
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA variable
      //     Accumulation happens on variable's device
      c10::OptionalDeviceGuard device_guard{device_of(var)};
      accumulate(buffer, pos, std::move(var));
    }
  }
}

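// Returns the device this buffer's gradients should be processed on: the
// device of the first defined non-CPU tensor in the buffer, or CPU if there
// is none.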
auto InputBuffer::device() const -> at::Device {
  // Since we pick the first non-CPU tensor, this won't work with
  // mixed device-type operations (e.g., an op that is both CUDA
  // and XLA). This is *incredibly* unlikely, so we don't worry
  // about it.
  for (auto& var : buffer) {
    if (var.defined()) {
      auto device = var.device();
      if (device.type() != at::kCPU) {
        return device;
      }
    }
  }
  // Only report to the CPU thread if there really were no tensors
  // from other devices.
  return at::kCPU;
}

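// Consumes the InputBuffer and returns the accumulated gradients, leaving the
// buffer empty.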
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}
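
// Illustrative usage sketch (comment only, nothing here is compiled): roughly
// how the autograd engine is expected to drive this class. The constructor and
// member signatures are assumed to match torch/csrc/autograd/input_buffer.h;
// `fn`, `pos`, `grad`, and the streams below are placeholders.
//
//   InputBuffer grads(fn->num_inputs());
//   // Called once per incoming gradient edge, possibly from different
//   // producer streams; stream synchronization is handled inside add().
//   grads.add(pos, std::move(grad), opt_producer_stream, opt_consumer_stream);
//   ...
//   // Once every edge has contributed, hand the results to the next Node:
//   std::vector<Variable> inputs = InputBuffer::variables(std::move(grads));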

} // namespace autograd
} // namespace torch