#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_has_same_storage_numel.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
#include <ATen/ops/zeros.h>
#endif

namespace torch {
namespace autograd {

using at::Tensor;

// [Forward Grad View/inplace]
// It is important to us to allow view and inplace to work with dual Tensors.
// These operations should either compute the right gradient or raise a
// user-friendly error.

// The basic case where all Tensors are dual Tensors is as follows:
//   # Have:
//   #   foo is a dual Tensor that is not a view
//   #   bar is a dual Tensor of appropriate size (depending on cases) that
//   #   is not a view
//
//   # Case 1: no view
//   foo.copy_(bar)
//
//   # Case 2: with view, propagate from view to base
//   view = foo[0]
//   view.copy_(bar)
//
//   # Case 3: with view, propagate from base to view
//   view = foo[0]
//   foo.copy_(bar)
//
//   # In all three cases, the forward grad of foo must be properly updated.
//   # In cases 2 and 3, the forward grad of view must match the one of foo
//   # for the subset they have in common.
//
// All these cases can be handled by the following layout constraint on the
// forward grad:
//   - A Tensor and its forward grad (for all levels) must have the same
//     metadata (sizes, strides, conj/neg bit and storage offset). Storage
//     offset must be in this metadata because of as_strided. conj/neg bit
//     must be part of this metadata because of ops like `real`.
//   - View operations must create a forward grad that is a view of the
//     base's forward grad.
//   - Inplace operations must modify the input's forward grad inplace.
//
// This layout constraint is ensured in the `set_fw_grad` function below.
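//
// As an illustrative example (a sketch, not code in this file; `foo_fw_grad`
// is a placeholder name for foo's forward grad): if `view = foo.t()`, then
// the forward grad of `view` must be `foo_fw_grad.t()`, i.e. a view of the
// base's forward grad with matching (transposed) sizes, strides and storage
// offset, so that writing through either alias updates both tangents.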

// More complex cases arise when a non-dual Tensor interacts with dual
// Tensors. The two most important cases are:
//
//   # Have:
//   #   foo is a regular Tensor that is not a view
//   #   bar is a dual Tensor of appropriate size (depending on cases) that
//   #   is not a view
//
//   # Case 4: Changes on the view must propagate to its base
//   view = foo[0]
//   # view is still a regular Tensor here
//   view.copy_(bar)
//   # Now both view and foo are dual Tensors with appropriate forward grads
//
//   # Case 5: Changes on the base must propagate to all its views
//   view = foo[0]
//   # view is still a regular Tensor here
//   foo.copy_(bar)
//   # Now both view and foo are dual Tensors with appropriate forward grads
//
//   # NB: there is a case 6 involving changes on a view propagating to other
//   # views, but it is fully described by the two cases above and is skipped
//   # in this discussion.
//
// Case 4 is handled by set_fw_grad by properly setting the forward grad of
// the base if needed. Case 5 is handled in fw_grad by reading the forward
// grad from the base if needed.
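//
// As a user-level sketch of how case 4 can be reached from Python (this is
// illustrative only and relies on the public torch.autograd.forward_ad API,
// not on anything in this file):
//
//   import torch
//   import torch.autograd.forward_ad as fwAD
//
//   with fwAD.dual_level():
//       foo = torch.rand(2, 3)            # regular Tensor, not a view
//       tangent = torch.rand(3)
//       bar = fwAD.make_dual(torch.rand(3), tangent)  # dual Tensor
//       view = foo[0]                     # regular view of foo
//       view.copy_(bar)                   # case 4
//       # foo (the base) now has a forward grad that is zero everywhere
//       # except row 0, which holds `tangent`; fwAD.unpack_dual(foo) and
//       # fwAD.unpack_dual(view) expose consistent tangents.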

namespace utils {

// Enforcing that the metadata of the primal and the tangent are the same has
// two goals:
//   - When properties of the primal are checked in composite ops to determine
//     control flow, the code path taken is also reasonable for the tangent.
//   - Make sure that when the same as_strided is applied to both the primal
//     and the tangent, it behaves similarly.
//
// We do that by checking:
//   1) the storages have the same properties: size and conj/neg-ness
//   2) the same indices refer to the same elements in storage
//      (we are more strict than necessary here to satisfy goal 1)
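//
// For example (illustrative): a contiguous (2, 3) primal paired with a
// (2, 3) tangent built as `t.t().contiguous().t()` has matching sizes but
// different strides, so has_same_meta returns false and set_fw_grad below
// will materialize a fresh tangent with a matching layout instead.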
bool has_same_meta(const Variable& base, const Variable& other) {
  if (!base.defined() || !other.defined()) {
    return false;
  }
  // 1) The storages have the same properties
  if (!at::_has_same_storage_numel(base, other)) {
    return false;
  }
  if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) {
    return false;
  }

  // Technically, dim and sizes belong to (2), so we shouldn't really care if
  // a zero-numel tensor violates them. But since these properties (unlike
  // offset and strides) often determine control flow in composite ops, it is
  // useful to enforce that they match for primal and tangent here so nothing
  // funny happens later (see goal 1).
  if (base.dim() != other.dim()) {
    return false;
  }
  for (const auto i : c10::irange(base.dim())) {
    if (base.sizes()[i] != other.sizes()[i]) {
      return false;
    }
  }

  // The check below will always be vacuously true for 0-element tensors
  if (base.numel() == 0 && other.numel() == 0) {
    return true;
  }

  // 2) The same indices refer to the same elements in storage
  if (base.storage_offset() != other.storage_offset()) {
    return false;
  }

  for (const auto i : c10::irange(base.dim())) {
    if (base.strides()[i] != other.strides()[i] && base.sizes()[i] != 1 &&
        base.sizes()[i] != 0) {
      return false;
    }
  }
  return true;
}

} // namespace utils

// This function will ensure that fw_grad_ is properly set up as a view of the
// base's forward grad for inplace ops on Tensors that do not have a forward
// grad originally.
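// `level` selects the forward AD level the tangent belongs to, and
// `is_inplace_op` signals that this call comes from an inplace operation
// (which is what allows propagating the tangent to the base, see case 4
// above).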
void AutogradMeta::set_fw_grad(
    const at::TensorBase& new_grad_base,
    const at::TensorBase& self_base,
    uint64_t level,
    bool is_inplace_op) {
  TORCH_CHECK(
      !new_grad_base._fw_grad(level).defined(),
      "Setting a forward grad that "
      "itself has a forward gradient at the same level ",
      level,
      " is not supported.");
  TORCH_INTERNAL_ASSERT(
      (new_grad_base.is_floating_point() || new_grad_base.is_complex()) &&
          (self_base.is_floating_point() || self_base.is_complex()),
      "Expected both tensor and its forward grad to be floating point or complex");
  // Lazy initialization
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!fw_grad_) {
      fw_grad_ = std::make_shared<ForwardGrad>();
    }
  }
  if (fw_grad_->contains(level)) {
    // Setting the forward grad again is only allowed if it is a no-op.
    // We do allow this case to simplify writing codegen for inplace ops.
    TORCH_INTERNAL_ASSERT(
        new_grad_base.defined(),
        "Cannot set a forward grad that is an undefined Tensor. Use "
        "_fw_primal(level) to get a new Tensor with this forward grad unset.");

    TORCH_INTERNAL_ASSERT(
        is_inplace_op,
        "Only inplace operations can re-set the forward grad of a Tensor that "
        "already has one.");

    TORCH_INTERNAL_ASSERT(
        fw_grad_->value(level).is_same(new_grad_base),
        "Cannot set a value of a forward grad if it "
        "already exists. Inplace operations should modify it inplace.");
  } else {
    // TODO(alband) remove this spurious version counter bump
    Tensor new_grad(new_grad_base);
    at::OptionalTensorRef self_ref(self_base);
    const Tensor& self = *self_ref;

    TORCH_CHECK(
        self.is_same_size(new_grad),
        "Trying to set a forward gradient that has a different size than that "
        "of the original Tensor, this is not supported. Tensor is of size ",
        self.sizes(),
        " while the given "
        "forward gradient is of size ",
        new_grad.sizes(),
        ".");

    if (is_inplace_op && is_view_) {
      auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);

      // For inplace ops on a Tensor that does not already have a forward grad
      // and is a view, we propagate the tangent to the base and ensure that
      // the new_grad is a view of that base's tangent. This ensures that
      // case 4 from [Forward Grad View/inplace] above works fine.
      // What happens in this long if statement is:
      //   - Check if the base already has a grad
      //   - If not, set a new fw_grad for it full of zeros
      //   - Take a view of the base's forward grad
      //   - Copy the given new_grad into this view
      //   - Use this view as the new new_grad
      if (this_view_meta->has_fw_view()) {
        auto view_info = this_view_meta->get_forward_view();
        auto& base = view_info.base_;

        if (!base._fw_grad(level).defined()) {
          // Enforce same meta here to make sure that the view op below is
          // always valid
          Tensor new_base_fw_grad;
          if (utils::has_same_meta(new_grad, base) &&
              utils::has_same_meta(new_grad, self)) {
            // TODO extend this special case to when the underlying storage of
            // new_grad can be re-used.
            new_base_fw_grad = new_grad;
          } else {
            new_base_fw_grad =
                at::_new_zeros_with_same_feature_meta(new_grad, base);
            new_base_fw_grad._set_conj(base.is_conj());
            new_base_fw_grad._set_neg(base.is_neg());

            // Update new_grad to be a view of the base
            Tensor new_fw_grad_value;
            if (view_info.has_view_fn()) {
              new_fw_grad_value = view_info.view_fn()(new_base_fw_grad);
            } else {
              new_fw_grad_value = new_base_fw_grad.as_strided(
                  self.sizes(), self.strides(), self.storage_offset());
            }

            new_fw_grad_value.copy_(new_grad);
            new_grad = new_fw_grad_value;
          }

          base._set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false);
        }
      }
    }

    // Enforce the basic layout constraint
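    // (i.e. copy new_grad into a freshly allocated Tensor that matches self's
    // sizes, strides, storage offset and conj/neg bits, see
    // [Forward Grad View/inplace] above)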
    if (!utils::has_same_meta(new_grad, self)) {
      if (is_view_) {
        auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
        TORCH_INTERNAL_ASSERT(
            !this_view_meta->has_fw_view(),
            "Expected the tangent of the output of a forward differentiable "
            "view operation to have the same layout as the primal");
      }
      auto res = at::_new_zeros_with_same_feature_meta(new_grad, self);
      res._set_conj(self.is_conj());
      res._set_neg(self.is_neg());
      res.copy_(new_grad);
      new_grad = res;
    }

    fw_grad_->set_value(new_grad, level);
  }
}

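// Returns the forward grad of `self` at the given `level`. For a view whose
// base received a tangent through an inplace op (case 5 above), the view's
// tangent is lazily materialized here as a view of the base's tangent.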
const Variable& AutogradMeta::fw_grad(
    uint64_t level,
    const at::TensorBase& self) const {
  // TLS that disables forward AD.
  if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
    return ForwardGrad::undef_grad();
  }

  // Ensure that concurrent fw_grad() "reads" are thread safe
  std::lock_guard<std::mutex> lock(mutex_);

  const auto& direct_fw_grad =
      fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();

  if (!direct_fw_grad.defined() && is_view_) {
    // For views that don't have a forward grad, check if their base has one
    // that has been defined by an inplace operation.
    // This ensures that case 5 from [Forward Grad View/inplace] above works
    // fine.
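    // (The view of the base's forward grad computed below is cached in
    // fw_grad_ so later reads at this level return the same Tensor.)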
    auto const_view_meta =
        static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
    // This is ok to do as we ONLY modify fw_grad_ and this field is properly
    // locked in all methods
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    auto this_view_meta =
        const_cast<torch::autograd::DifferentiableViewMeta*>(const_view_meta);
    if (this_view_meta->has_fw_view()) {
      const auto& view_info = this_view_meta->get_forward_view();
      const auto& base = view_info.base_;

      const auto& base_val = base._fw_grad(level);
      if (base_val.defined()) {
        // Lazy initialization of fw_grad_
        this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();

        Variable new_val;
        if (view_info.has_view_fn()) {
          new_val = view_info.view_fn()(base_val);
        } else {
          new_val = base_val.as_strided(
              self.sizes(), self.strides(), self.storage_offset());
        }

        this_view_meta->fw_grad_->set_value(new_val, level);
        return this_view_meta->fw_grad_->value(level);
      }
    }
  }
  return direct_fw_grad;
}

} // namespace autograd
} // namespace torch