#include <ATen/RedispatchFunctions.h>
#include <ATen/TracerMode.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/utils/memory.h>
#include <torch/library.h>

#include <utility>

using namespace at;
using namespace torch::autograd::generated;
using torch::autograd::as_view;
using torch::autograd::CreationMeta;

namespace torch {
namespace autograd {
namespace VariableType {

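// Builds one DeprecatedTypeProperties entry per (backend, scalar type) pair
// for the requested backends; used by the legacy type-introspection API below.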
std::vector<at::DeprecatedTypeProperties*> allTypesForBackends(
    at::ArrayRef<at::Backend> backends) {
  std::vector<DeprecatedTypeProperties*> res;
  res.reserve(backends.size());
  for (auto p : backends) {
    for (const auto s :
         c10::irange(static_cast<int64_t>(ScalarType::NumOptions))) {
      auto& type = getDeprecatedTypeProperties(
          static_cast<Backend>(p), static_cast<ScalarType>(s));
      res.emplace_back(&type);
    }
  }
  return res;
}

C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCPUTypes() {
  return allTypesForBackends({Backend::CPU, Backend::SparseCPU});
}

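// Same as allCPUTypes, except the CUDA context has to exist before CUDA types
// can be enumerated, so it is initialized lazily first.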
C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
  at::globalContext().lazyInitCUDA();
  return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA});
}

namespace {
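// Historically Variable and Tensor were distinct types, so "casting" was a
// real conversion. Variable is now an alias for Tensor, and these helpers
// only reject undefined (None) tensors with a readable error message.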
const Variable& checked_cast_variable(
    const Tensor& t,
    const char* name,
    int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}

Variable& checked_cast_variable(Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}
} // namespace

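// unpack validates that each Tensor argument is defined before a kernel uses
// it; unpack_opt instead maps an undefined optional Tensor to an empty one.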
const Tensor& unpack(const Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor& unpack(Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor unpack_opt(const Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    return Tensor();
  }
  return unpack(t, name, pos);
}

std::vector<at::Tensor> unpack(
    at::ITensorListRef tl,
    const char* name,
    int pos) {
  std::vector<at::Tensor> ret;
  ret.reserve(tl.size());
  for (const auto& t : tl) {
    ret.push_back(t.defined() ? static_cast<const Variable&>(t) : Variable{});
  }
  return ret;
}

namespace {

// Taken from the code-generated version
Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<Identity> grad_fn;
  if (compute_requires_grad(self)) {
    grad_fn = std::make_shared<Identity>();
    grad_fn->set_next_edges(collect_next_edges(self));
  }

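  // Redispatch below the Autograd key so the underlying kernel runs without
  // re-entering this wrapper; history is attached to the result afterwards.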
  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_fw_primal(
        ks & c10::after_autograd_keyset, self_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }
  if (isFwGradDefined(self)) {
    // Modified from original codegen
    // We explicitly want to ignore the forward grad at the given level
    TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
    // End modified from original codegen
  }
  return result;
}

// NB: We need a manual variable type kernel so that set_fw_grad properly
// detects that _make_dual is not a forward-differentiable view
//
// This function can be used to create a dual Tensor that holds a tangent to
// compute forward mode gradients. Note that the dual Tensor's primal is a view
// of the given primal and the given tangent is used as-is. This function is
// backward differentiable.
Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  TORCH_CHECK(
      !primal._fw_grad(level).defined(),
      "Making a dual Tensor based on a Tensor that "
      "already has a forward gradient at the same level ",
      level,
      " is not supported.");
  auto& primal_ = unpack(primal, "primal", 0);
  auto& tangent_ = unpack(tangent, "tangent", 1);
  std::shared_ptr<ViewBackward0> grad_fn;
  if (compute_requires_grad(primal_)) {
    grad_fn = std::make_shared<ViewBackward0>();
    grad_fn->self_sym_sizes = primal_.sym_sizes().vec();
    grad_fn->set_next_edges(collect_next_edges(primal_));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_make_dual(
        ks & c10::after_autograd_keyset, primal_, tangent_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }

  TORCH_CHECK(level == 0, "Invalid level given to _make_dual");
  result._set_fw_grad(tangent_, level, /* is_inplace_op */ false);
  return result;
}
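
// A rough usage sketch (hypothetical caller code, not part of this file;
// assumes the public at::_make_dual op dispatches into the kernel above):
//   at::Tensor primal = torch::randn({3});
//   at::Tensor tangent = torch::ones({3});
//   at::Tensor dual = at::_make_dual(primal, tangent, /*level=*/0);
//   // dual is a view of primal whose forward grad at level 0 is tangent.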

// We don't have an outplace copy, so this can't be generated automatically
Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  // TODO: once copy is exposed in Declarations.yaml we may be able to bind
  // it automatically
  auto& self_ = unpack(self, "self", 0);
  auto& src_ = unpack(src, "src", 1);
  std::shared_ptr<CopyBackwards> grad_fn;
  auto requires_grad = compute_requires_grad(self, src);
  requires_grad &= isDifferentiableType(self.scalar_type());
  check_inplace(self, requires_grad);
  if (requires_grad) {
    grad_fn = std::make_shared<CopyBackwards>();
    grad_fn->set_next_edges(collect_next_edges(self, src));
    grad_fn->src_options = src.options();
  }
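  // Perform the actual copy below the Autograd key, then rewrite self's
  // history so that its gradient flows through CopyBackwards.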
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::copy_(
        ks & c10::after_autograd_keyset, self_, src_, non_blocking);
  }
  rebase_history(self, std::move(grad_fn));

  if (isDifferentiableType(self.scalar_type()) &&
      (isFwGradDefined(self) || isFwGradDefined(src))) {
    auto self_fw_grad = generated::details::toNonOptFwGrad(self);
    auto src_fw_grad = generated::details::toNonOptFwGrad(src);
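    // Mirror the copy on the forward grads: copy src's tangent into self's if
    // both exist, zero out self's tangent if src has none, and otherwise take
    // a (possibly broadcast) copy of src's tangent.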
    Tensor new_fw_grad;
    if (self_fw_grad.defined()) {
      if (src_fw_grad.defined()) {
        new_fw_grad = self_fw_grad.copy_(src_fw_grad);
      } else {
        new_fw_grad = self_fw_grad.fill_(0);
      }
    } else {
      if (!self.is_same_size(src_fw_grad)) {
        new_fw_grad = src_fw_grad.broadcast_to(self.sizes());
      } else {
        new_fw_grad = src_fw_grad.clone();
      }
    }
    self._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
  }

  return self;
}

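// resize_ refuses to run on tensors that require grad or carry a forward
// grad: a reallocation has no meaningful gradient, so it cannot be
// autogenerated like ordinary in-place ops.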
const Tensor& resize_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    SymIntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize__symint(
        ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
  }

  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }

  return self;
}

const Tensor& resize_as_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    const Tensor& the_template,
    c10::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  auto& the_template_ = unpack(the_template, "the_template", 1);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize_as_(
        ks & c10::after_autograd_keyset,
        self_,
        the_template_,
        optional_memory_format);
  }

  // Handle fw grad
  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }
  return self;
}

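// detach returns a shallow copy that shares storage with self but is cut out
// of the autograd graph. Rough usage sketch (hypothetical caller code, not
// part of this file):
//   at::Tensor x = torch::randn({2, 2}, torch::requires_grad());
//   at::Tensor y = x.detach(); // y.requires_grad() == false, same storage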
Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto& self_ = unpack(self, "self", 0);
  RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::detach(ks & c10::after_autograd_keyset, self_);
  })();
  namedinference::propagate_names(result, self);

  // Detach the forward grads by not setting anything on the result

  return result;
}

Tensor& detach_(c10::DispatchKeySet ks, Tensor& self) {
  RECORD_FUNCTION("detach_", std::vector<c10::IValue>({self}));
  if (self.is_view()) {
    // See NOTE [ View + Inplace detection ]
    AT_ERROR(
        "Can't detach views in-place. Use detach() instead. "
        "If you are using DistributedDataParallel (DDP) for training, "
        "and gradient_as_bucket_view is set as True, gradients are "
        "views of DDP buckets, and hence detach_() cannot be called "
        "on these gradients. To fix this error, please refer to the "
        "Optimizer.zero_grad() function in torch/optim/optimizer.py "
        "as the solution.");
  }
  // I think the choice here is conservative. In principle, doing
  // an in-place detach should give us the ability to just clear
  // the autograd meta. But this function ONLY resets requires_grad,
  // grad_fn and output_nr; there's other metadata like debug name
  // and hooks which aren't cleared. Is this function supposed to
  // clear those too? I'm not too sure, so I'm leaving it be for now.
  auto autograd_meta = impl::materialize_autograd_meta(self);
  autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl());
  autograd_meta->grad_fn_.reset();
  autograd_meta->output_nr_ = 0;
  autograd_meta->fw_grad_.reset();

  return self;
}

// Ops in the following registration list are registered as
// (1) CompositeImplicitAutograd kernels
// (2) Autograd kernels
// (3) CompositeExplicitAutograd kernels and additionally Autograd kernels
// The reason for (3) is that ops that also use dispatch (e.g. register
// CPU/CUDA/QuantizedCPU kernels) will skip picking up CompositeImplicitAutograd
// kernels for Autograd, so we register them to both CompositeExplicitAutograd
// and Autograd instead. See
// https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword
// for more details.
// Invariant:
// - Ops registered to CompositeImplicitAutograd or CompositeExplicitAutograd
//   below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py,
//   and they have manual_kernel_registration=True in native_functions.yaml.
// - Ops registered to DispatchKey::Autograd below must be included in
//   `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py

TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  m.impl(
      "resize_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_)));
  m.impl(
      "resize_as_",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_)));
  m.impl(
      "detach",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach)));
  m.impl(
      "detach_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_)));
  m.impl(
      "copy_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_make_dual)));
}

} // namespace
} // namespace VariableType
} // namespace autograd

namespace ADInplaceOrView {
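
// Picks the CreationMeta recorded on the views created below: views made in
// inference mode, under grad mode, and under no-grad mode each need different
// handling if they are later mutated in place.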
#define CREATION_META_DEFINITION                            \
  InferenceMode::is_enabled()                               \
      ? CreationMeta::INFERENCE_MODE                        \
      : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT \
                                    : CreationMeta::NO_GRAD_MODE)

Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::copy_(
        ks & c10::after_ADInplaceOrView_keyset, self, src, non_blocking);
  }
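  // Bump self's version counter so autograd can detect that a tensor saved
  // for backward was modified in place.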
  torch::autograd::increment_version(self);
  return self;
}

Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto out = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::detach::redispatch(
        ks & c10::after_ADInplaceOrView_keyset, self);
  })();
  // NB: we can't make detach() a normal view operator because the codegen
  // generates allow_tensor_metadata_change = True for them. In the future we
  // should have an option for this in the codegen.
  std::function<at::Tensor(const at::Tensor&)> func = nullptr;
  auto result = as_view(
      /* base */ self,
      /* output */ out,
      /* is_bw_differentiable */ false,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* creation_meta */ CreationMeta::DEFAULT,
      /*allow_tensor_metadata_change=*/false);

  return result;
}

Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(self);
  })();
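  // Tensors whose backend cannot replay views via as_strided need an explicit
  // view_func so autograd can reconstruct this view from a new base.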
  std::function<at::Tensor(const at::Tensor&)> func = nullptr;
  if (!self.unsafeGetTensorImpl()->support_as_strided()) {
    auto size_vec = self.sizes().vec();
    func = [=](const at::Tensor& input_base) {
      return input_base.view(size_vec);
    };
  }
  auto result = as_view(
      /* base */ self,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

// NB: This does not redispatch any further
Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(primal);
  })();
  std::function<at::Tensor(const at::Tensor&)> func = nullptr;
  if (!primal.unsafeGetTensorImpl()->support_as_strided()) {
    auto size_vec = primal.sizes().vec();
    func = [=](const at::Tensor& input_base) {
      return input_base.view(size_vec);
    };
  }
  auto result = as_view(
      /* base */ primal,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

namespace {
TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
  m.impl(
      "copy_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::copy_)));
  m.impl(
      "detach",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::detach)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_make_dual)));
}
} // namespace
} // namespace ADInplaceOrView
} // namespace torch