#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_has_same_storage_numel.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
#include <ATen/ops/zeros.h>
#endif

namespace torch {
namespace autograd {

using at::Tensor;

// [Forward Grad View/inplace]
// It is important to us to allow view and inplace to work with dual Tensors.
// These operations should either compute the right gradient or raise a
// user-friendly error.

// The basic case where all Tensors are dual Tensors is as follows:
// # Have:
// # foo is a dual Tensor that is not a view
// # bar is a dual Tensor of appropriate size (depending on cases) that is
// # not a view
//
// # Case 1: no view
// foo.copy_(bar)
//
// # Case 2: with view, propagate from view to base
// view = foo[0]
// view.copy_(bar)
//
// # Case 3: with view, propagate from base to view
// view = foo[0]
// foo.copy_(bar)
//
// # In all three cases, the forward grad of foo must be properly updated.
// # In the second and third cases, the forward grad of view must match
// # the one of foo for the subset they have in common.
//
// All these cases can be handled by the following layout constraint on the
// forward grad:
// - A Tensor and its forward grad (for all levels) must have the same
//   metadata (size, stride, conj/neg bit and storage offset). Storage offset
//   must be in this metadata because of as_strided. conj/neg bit must be part
//   of this metadata because of ops like `real`.
// - View operations must create a forward grad that is a view of the base's
//   forward grad.
// - Inplace operations must modify the input's forward grad inplace.
//
// This layout constraint is ensured in the `set_fw_grad` function below and
// illustrated by the sketch that follows.
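//
// A minimal illustration of this constraint through the public forward AD
// API (Python pseudocode; `fwAD` stands for `torch.autograd.forward_ad` and
// the shapes are only an example):
//
// import torch
// import torch.autograd.forward_ad as fwAD
//
// with fwAD.dual_level():
//     foo = fwAD.make_dual(torch.rand(2, 3), torch.rand(2, 3))
//     view = foo[0]
//     # view's tangent is a view into foo's tangent (shared storage), so an
//     # inplace update on the view is reflected on the base's tangent:
//     view.mul_(2)
//     _, foo_tangent = fwAD.unpack_dual(foo)  # row 0 has been doubled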

// More complex cases arise when non-dual Tensors interact with dual Tensors.
// The two most important cases are:
//
// # Have:
// # foo is a regular Tensor that is not a view
// # bar is a dual Tensor of appropriate size (depending on cases) that is
// # not a view
//
// # Case 4: Changes on the view must propagate to its base
// view = foo[0]
// # view is still a regular Tensor here
// view.copy_(bar)
// # Now both view and foo are dual Tensors with appropriate forward grad
//
// # Case 5: Changes on the base must propagate to all its views
// view = foo[0]
// # view is still a regular Tensor here
// foo.copy_(bar)
// # Now both view and foo are dual Tensors with appropriate forward grad
//
// # NB: there is a case 6 involving changes on a view propagating to other
// # views, but it is fully described by the two above and is skipped in
// # this discussion.
//
// Case 4 is handled by set_fw_grad by properly setting the forward grad of the
// base if needed. Case 5 is handled in fw_grad by reading the forward grad from
// the base if needed. Case 4 is sketched right below; case 5 is sketched inside
// `fw_grad` further down.
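//
// A rough sketch of case 4 through the public forward AD API (Python
// pseudocode; `fwAD` stands for `torch.autograd.forward_ad` and the shapes
// are only an example):
//
// import torch
// import torch.autograd.forward_ad as fwAD
//
// with fwAD.dual_level():
//     foo = torch.rand(2, 3)                  # regular Tensor, not a view
//     bar = fwAD.make_dual(torch.rand(3), torch.ones(3))
//     foo[0].copy_(bar)                       # write to a view of foo
//     _, foo_t = fwAD.unpack_dual(foo)
//     # foo_t is zeros everywhere except row 0, which holds bar's tangent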

namespace utils {

// Enforcing that the metadata between the primal and tangent is the same has
// two goals:
// - When properties of the primal are checked in composite ops to determine
//   control flow, the code path decided upon is also reasonable for the tangent
// - Make sure that when the same as_strided call is applied to both primal
//   and tangent, it behaves similarly.
//
// We do that by checking:
// 1) the storages have the same properties: size and conj/neg-ness
// 2) the same indices refer to the same elements in storage
//    (we are more strict than necessary here to satisfy goal 1)
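//
// For instance (purely illustrative): a square matrix and its transpose share
// a storage of the same size, the same sizes and the same storage offset, but
// their strides differ, so `has_same_meta(t, t.t())` is false for a 2x2
// Tensor `t`.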
bool has_same_meta(const Variable& base, const Variable& other) {
  if (!base.defined() || !other.defined()) {
    return false;
  }
  // 1) The storages have the same properties
  if (!at::_has_same_storage_numel(base, other)) {
    return false;
  }
  if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) {
    return false;
  }

  // Technically, dim and size belong to check (2), so we shouldn't really care
  // if a zero-numel tensor violates them. But since these properties
  // (unlike offset and strides) often determine control flow in composite ops,
  // it is useful to enforce that they match for primal and tangent here so
  // nothing funny happens later (see goal 1).
  if (base.dim() != other.dim()) {
    return false;
  }
  for (const auto i : c10::irange(base.dim())) {
    if (base.sizes()[i] != other.sizes()[i]) {
      return false;
    }
  }

  // The checks below would always be vacuously true for 0-element tensors, so
  // we can return early.
  if (base.numel() == 0 && other.numel() == 0) {
    return true;
  }

  // 2) The same indices refer to the same elements in storage
  if (base.storage_offset() != other.storage_offset()) {
    return false;
  }

  for (const auto i : c10::irange(base.dim())) {
    if (base.strides()[i] != other.strides()[i] && base.sizes()[i] != 1 &&
        base.sizes()[i] != 0) {
      return false;
    }
  }
  return true;
}

} // namespace utils

// This function ensures that fw_grad_ is properly set as a view of the base's
// forward grad for inplace ops on Tensors that do not have a forward grad
// originally.
void AutogradMeta::set_fw_grad(
    const at::TensorBase& new_grad_base,
    const at::TensorBase& self_base,
    uint64_t level,
    bool is_inplace_op) {
  TORCH_CHECK(
      !new_grad_base._fw_grad(level).defined(),
      "Setting a forward grad that "
      "itself has a forward gradient at the same level ",
      level,
      " is not supported.");
  TORCH_INTERNAL_ASSERT(
      (new_grad_base.is_floating_point() || new_grad_base.is_complex()) &&
          (self_base.is_floating_point() || self_base.is_complex()),
      "Expected both tensor and its forward grad to be floating point or complex");
  // Lazy initialization
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!fw_grad_) {
      fw_grad_ = std::make_shared<ForwardGrad>();
    }
  }
  if (fw_grad_->contains(level)) {
    // Setting the forward grad again is only allowed if it is a no-op.
    // We do allow this case to simplify writing codegen for inplace ops.
    TORCH_INTERNAL_ASSERT(
        new_grad_base.defined(),
        "Cannot set a forward grad that is an undefined Tensor. Use "
        "_fw_primal(level) to get a new Tensor with this forward grad unset.");

    TORCH_INTERNAL_ASSERT(
        is_inplace_op,
        "Only inplace operations can re-set the forward grad of a Tensor that "
        "already has one.");

    TORCH_INTERNAL_ASSERT(
        fw_grad_->value(level).is_same(new_grad_base),
        "Cannot set a value of a forward grad if it "
        "already exists. Inplace operations should modify it inplace.");
  } else {
    // TODO(alband) remove this spurious version counter bump
    Tensor new_grad(new_grad_base);
    at::OptionalTensorRef self_ref(self_base);
    const Tensor& self = *self_ref;

    TORCH_CHECK(
        self.is_same_size(new_grad),
        "Trying to set a forward gradient that has a different size than that "
        "of the original Tensor, this is not supported. Tensor is of size ",
        self.sizes(),
        " while the given "
        "forward gradient is of size ",
        new_grad.sizes(),
        ".");

    if (is_inplace_op && is_view_) {
      auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);

      // For inplace ops on a Tensor that does not already have a forward grad
      // and is a view, we propagate the tangent to the base and ensure that
      // new_grad is a view of that base's tangent. This ensures that case 4
      // from [Forward Grad View/inplace] above works fine.
      // What happens in this long if statement is:
      // - Check if the base already has a grad
      // - If not, set a new fw_grad for it full of zeros
      // - Take a view of the base's forward grad
      // - Copy the given new_grad into this view
      // - Use this view as the new new_grad
      if (this_view_meta->has_fw_view()) {
        auto view_info = this_view_meta->get_forward_view();
        auto& base = view_info.base_;

        if (!base._fw_grad(level).defined()) {
          // Enforce same meta here to make sure that the view op below is
          // always valid
          Tensor new_base_fw_grad;
          if (utils::has_same_meta(new_grad, base) &&
              utils::has_same_meta(new_grad, self)) {
            // TODO extend this special case to when the underlying storage of
            // new_grad can be re-used.
            new_base_fw_grad = new_grad;
          } else {
            new_base_fw_grad =
                at::_new_zeros_with_same_feature_meta(new_grad, base);
            new_base_fw_grad._set_conj(base.is_conj());
            new_base_fw_grad._set_neg(base.is_neg());

            // Update new_grad to be a view of the base
            Tensor new_fw_grad_value;
            if (view_info.has_view_fn()) {
              new_fw_grad_value = view_info.view_fn()(new_base_fw_grad);
            } else {
              new_fw_grad_value = new_base_fw_grad.as_strided(
                  self.sizes(), self.strides(), self.storage_offset());
            }

            new_fw_grad_value.copy_(new_grad);
            new_grad = new_fw_grad_value;
          }

          base._set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false);
        }
      }
    }

    // Enforce the basic layout constraint
    if (!utils::has_same_meta(new_grad, self)) {
      if (is_view_) {
        auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
        TORCH_INTERNAL_ASSERT(
            !this_view_meta->has_fw_view(),
            "Expected the tangent of the output of a forward differentiable "
            "view operation to have the same layout as the primal");
      }
      auto res = at::_new_zeros_with_same_feature_meta(new_grad, self);
      res._set_conj(self.is_conj());
      res._set_neg(self.is_neg());
      res.copy_(new_grad);
      new_grad = res;
    }

    fw_grad_->set_value(new_grad, level);
  }
}
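
// NB: A practical consequence of the "basic layout constraint" branch above,
// sketched with the public forward AD API (Python pseudocode; the shapes are
// only an example): if the provided tangent does not share the primal's
// layout, the stored tangent is a fresh Tensor with the primal's layout that
// the provided values are copied into.
//
// import torch
// import torch.autograd.forward_ad as fwAD
//
// with fwAD.dual_level():
//     primal = torch.rand(2, 3)               # contiguous, strides (3, 1)
//     tangent = torch.rand(3, 2).t()          # same sizes, strides (1, 3)
//     dual = fwAD.make_dual(primal, tangent)
//     _, stored = fwAD.unpack_dual(dual)
//     # stored has primal's strides (3, 1) and is not the same object as
//     # tangent; only its values were copied.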

const Variable& AutogradMeta::fw_grad(
    uint64_t level,
    const at::TensorBase& self) const {
  // TLS that disables forward AD.
  if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
    return ForwardGrad::undef_grad();
  }

  // Ensure that concurrent fw_grad() "reads" are thread safe
  std::lock_guard<std::mutex> lock(mutex_);

  const auto& direct_fw_grad =
      fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();

  if (!direct_fw_grad.defined() && is_view_) {
    // For views that don't have a forward grad, check if their base has one
    // that has been defined by an inplace operation.
    // This ensures that case 5 from [Forward Grad View/inplace] above works
    // fine.
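    //
    // A rough sketch of this situation through the public forward AD API
    // (Python pseudocode; `fwAD` stands for `torch.autograd.forward_ad` and
    // the shapes are only an example):
    //
    //   with fwAD.dual_level():
    //       foo = torch.rand(2, 3)
    //       view = foo[0]                   # still a regular Tensor here
    //       foo.copy_(fwAD.make_dual(torch.rand(2, 3), torch.ones(2, 3)))
    //       # Reading view's tangent goes through this code path: it is
    //       # lazily materialized as a view of foo's tangent (all ones here).
    //       _, view_t = fwAD.unpack_dual(view)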
    auto const_view_meta =
        static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
    // This is ok to do as we ONLY modify fw_grad_, and this field is properly
    // locked in all methods.
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    auto this_view_meta =
        const_cast<torch::autograd::DifferentiableViewMeta*>(const_view_meta);
    if (this_view_meta->has_fw_view()) {
      const auto& view_info = this_view_meta->get_forward_view();
      const auto& base = view_info.base_;

      const auto& base_val = base._fw_grad(level);
      if (base_val.defined()) {
        // Lazy initialization of fw_grad_
        this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();

        Variable new_val;
        if (view_info.has_view_fn()) {
          new_val = view_info.view_fn()(base_val);
        } else {
          new_val = base_val.as_strided(
              self.sizes(), self.strides(), self.storage_offset());
        }

        this_view_meta->fw_grad_->set_value(new_val, level);
        return this_view_meta->fw_grad_->value(level);
      }
    }
  }
  return direct_fw_grad;
}

} // namespace autograd
} // namespace torch