#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <utility>
#include <vector>
namespace torch {
namespace autograd {

namespace {
// look what you made me do >.<
// We should not need divergent code paths here that leak per-Impl
// implementation details just to record streams.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
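//
// Records every storage backing `var` on `stream` so the CUDA caching
// allocator will not free or reuse that memory before the work already queued
// on `stream` has completed. Batched tensors wrap another tensor and sparse
// tensors own several storages, so each case needs its own recording path.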
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
  const auto guard = c10::impl::VirtualGuardImpl(c10::DeviceType::CUDA);

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}

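// True if gradient accumulation may reuse `v`'s storage by accumulating into
// it in place: `v` must be a plain, dense, non-overlapping tensor to which we
// hold the only reference (see the conditions spelled out in `accumulate`).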
bool can_accumulate_inplace(const Variable& v) {
  return (
      // `v` is a "vanilla" Tensor
      !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) &&

      // with a favorable memory layout
      v.is_non_overlapping_and_dense() &&

      // and we hold the last reference
      v.use_count() == 1 && v.has_storage() && v.storage().use_count() == 1);
}
} // anonymous namespace

static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage we will try to
  // repurpose it to store the output. (Or, if `old_var` is sparse then `var`
  // becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled since Autograd has special handling for inplace
  //     mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count` and `.storage().use_count()` are one)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.

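  // The branches below map onto those conditions: with GradMode enabled we
  // always fall back to an out-of-place add (condition 1); when `old_var` is
  // sparse the roles of `var` and `old_var` are swapped so `var` becomes the
  // in-place candidate; and `can_accumulate_inplace` plus the subclass check
  // on the other tensor cover conditions 2-5.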
  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}

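// Accumulates `var` into slot `pos` of the buffer, choosing the device and
// stream for the accumulation according to the rules documented in the
// comment below.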
void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const c10::optional<c10::Stream>& opt_producer_stream,
    const c10::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  if (!var.defined()) {
    return;
  }

  // Switches to accumulate device
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA variable. Accumulation happens on var's device.
  //  (2) var is a CUDA variable and it, the consumer, and the producer share
  //      the same device:
  //      (2a) Uses the consumer's stream as the accumulation stream
  //      (2b) Syncs the accumulation stream with the producer's stream (if
  //           different)
  //      (2c) Accumulates.
  //  (3) var is a CUDA variable and it shares a device with the consumer but
  //      not the producer:
  //      (3a) Uses the consumer's stream as the accumulation stream
  //      (3b) Syncs the accumulation stream with the consumer device's
  //           default stream
  //      (3c) Accumulates.
  //  (4) var is a CUDA variable and it shares a device with the producer but
  //      not the consumer:
  //      (4a) Uses the producer device's default stream as the accumulation
  //           stream
  //      (4b) Syncs the accumulation stream with the producer's stream
  //      (4c) Accumulates.
  //  (5) var is a CUDA variable and it does not share a device with the
  //      consumer or producer.
  //      Accumulation happens on the var device's default stream.

  TORCH_INTERNAL_ASSERT(device_of(var));
  c10::optional<c10::Stream> opt_accumulate_stream = c10::nullopt;
  if (device_of(var)->is_cuda()) {
    const auto on_producer =
        opt_producer_stream && device_of(var) == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device_of(var) == opt_consumer_stream->device();

    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
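        // `var` was produced on the producer's stream; let the caching
        // allocator know it is now also in use on the accumulation stream.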
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    } else {
      c10::optional<c10::Stream> opt_sync_stream = c10::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }

  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
    if (opt_accumulate_stream) {
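      // (2c), (3c), (4c), (5): run the accumulation on the chosen stream.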
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA variable
      // Accumulation happens on variable's device
      c10::OptionalDeviceGuard device_guard{device_of(var)};
      accumulate(buffer, pos, std::move(var));
    }
  }
}

auto InputBuffer::device() const -> at::Device {
  // Since we pick the first non-CPU tensor, this won't work with
  // mixed device-type operations (e.g., an op that is both CUDA
  // and XLA). This is *incredibly* unlikely, so we don't worry
  // about it.
  for (auto& var : buffer) {
    if (var.defined()) {
      auto device = var.device();
      if (device.type() != at::kCPU) {
        return device;
      }
    }
  }
  // Only report to the CPU thread if there really were no tensors
  // from other devices.
  return at::kCPU;
}

auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}

} // namespace autograd
} // namespace torch