#include <ATen/RedispatchFunctions.h>
#include <ATen/TracerMode.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/utils/memory.h>
#include <torch/library.h>

#include <utility>

using namespace at;
using namespace torch::autograd::generated;
using torch::autograd::as_view;
using torch::autograd::CreationMeta;

namespace torch {
namespace autograd {
namespace VariableType {

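// Collects a DeprecatedTypeProperties pointer for every ScalarType of each
// requested backend; used by allCPUTypes()/allCUDATypes() below.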
std::vector<at::DeprecatedTypeProperties*> allTypesForBackends(
    at::ArrayRef<at::Backend> backends) {
  std::vector<DeprecatedTypeProperties*> res;
  res.reserve(backends.size());
  for (auto p : backends) {
    for (const auto s :
         c10::irange(static_cast<int64_t>(ScalarType::NumOptions))) {
      auto& type = getDeprecatedTypeProperties(
          static_cast<Backend>(p), static_cast<ScalarType>(s));
      res.emplace_back(&type);
    }
  }
  return res;
}

C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCPUTypes() {
  return allTypesForBackends({Backend::CPU, Backend::SparseCPU});
}

C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
  at::globalContext().lazyInitCUDA();
  return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA});
}

namespace {
const Variable& checked_cast_variable(
    const Tensor& t,
    const char* name,
    int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}

Variable& checked_cast_variable(Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}
} // namespace

const Tensor& unpack(const Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor& unpack(Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor unpack_opt(const Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    return Tensor();
  }
  return unpack(t, name, pos);
}

std::vector<at::Tensor> unpack(
    at::ITensorListRef tl,
    const char* name,
    int pos) {
  std::vector<at::Tensor> ret;
  ret.reserve(tl.size());
  for (const auto& t : tl) {
    ret.push_back(t.defined() ? static_cast<const Variable&>(t) : Variable{});
  }
  return ret;
}

namespace {

// Taken from codegened version
Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<Identity> grad_fn;
  if (compute_requires_grad(self)) {
    grad_fn = std::make_shared<Identity>();
    grad_fn->set_next_edges(collect_next_edges(self));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_fw_primal(
        ks & c10::after_autograd_keyset, self_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }
  if (isFwGradDefined(self)) {
    // Modified from original codegen
    // We explicitly want to ignore the forward grad at the given level
    TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
    // End modified from original codegen
  }
  return result;
}
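
// NB: _fw_primal produces a result that carries no forward grad at the given
// level; Python's torch.autograd.forward_ad.unpack_dual() relies on it (and
// on _fw_grad) to split a dual Tensor back into primal and tangent.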

// NB: We need a manual variable type kernel so that set_fw_grad properly
// detects that _make_dual is not a forward-differentiable view
//
// This function can be used to create a dual Tensor that holds a tangent to
// compute forward mode gradients. Note that the dual Tensor's primal is a view
// of the given primal and the given tangent is used as-is. This function is
// backward differentiable.
Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  TORCH_CHECK(
      !primal._fw_grad(level).defined(),
      "Making a dual Tensor based on a Tensor that "
      "already has a forward gradient at the same level ",
      level,
      " is not supported.");
  auto& primal_ = unpack(primal, "primal", 0);
  auto& tangent_ = unpack(tangent, "tangent", 1);
  std::shared_ptr<ViewBackward0> grad_fn;
  if (compute_requires_grad(primal_)) {
    grad_fn = std::make_shared<ViewBackward0>();
    grad_fn->self_sym_sizes = primal_.sym_sizes().vec();
    grad_fn->set_next_edges(collect_next_edges(primal_));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_make_dual(
        ks & c10::after_autograd_keyset, primal_, tangent_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }

  TORCH_CHECK(level == 0, "Invalid level given to _make_dual");
  result._set_fw_grad(tangent_, level, /* is_inplace_op */ false);
  return result;
}
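
// For reference, a rough sketch of how these two kernels are reached from
// Python through torch.autograd.forward_ad (illustrative only; see the
// forward-mode AD docs for the authoritative API):
//
//   import torch.autograd.forward_ad as fwAD
//   with fwAD.dual_level():
//       dual = fwAD.make_dual(primal, tangent)  # dispatches to _make_dual
//       out = torch.sin(dual)
//       p, t = fwAD.unpack_dual(out)            # primal recovered via _fw_primal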

// We don't have an outplace copy, so this can't be generated automatically
Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  // TODO: once copy is exposed in Declarations.yaml we may be able to bind
  // it automatically
  auto& self_ = unpack(self, "self", 0);
  auto& src_ = unpack(src, "src", 1);
  std::shared_ptr<CopyBackwards> grad_fn;
  auto requires_grad = compute_requires_grad(self, src);
  requires_grad &= isDifferentiableType(self.scalar_type());
  check_inplace(self, requires_grad);
  if (requires_grad) {
    grad_fn = std::make_shared<CopyBackwards>();
    grad_fn->set_next_edges(collect_next_edges(self, src));
    grad_fn->src_options = src.options();
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::copy_(
        ks & c10::after_autograd_keyset, self_, src_, non_blocking);
  }
  rebase_history(self, std::move(grad_fn));

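  // Forward-grad handling: after self.copy_(src), self's tangent must mirror
  // src's tangent (broadcast to self's shape), or be zeroed if src carries no
  // tangent at level 0.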
  if (isDifferentiableType(self.scalar_type()) &&
      (isFwGradDefined(self) || isFwGradDefined(src))) {
    auto self_fw_grad = generated::details::toNonOptFwGrad(self);
    auto src_fw_grad = generated::details::toNonOptFwGrad(src);
    Tensor new_fw_grad;
    if (self_fw_grad.defined()) {
      if (src_fw_grad.defined()) {
        new_fw_grad = self_fw_grad.copy_(src_fw_grad);
      } else {
        new_fw_grad = self_fw_grad.fill_(0);
      }
    } else {
      if (!self.is_same_size(src_fw_grad)) {
        new_fw_grad = src_fw_grad.broadcast_to(self.sizes());
      } else {
        new_fw_grad = src_fw_grad.clone();
      }
    }
    self._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
  }

  return self;
}

const Tensor& resize_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    SymIntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize__symint(
        ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
  }

  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }

  return self;
}

const Tensor& resize_as_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    const Tensor& the_template,
    c10::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  auto& the_template_ = unpack(the_template, "the_template", 1);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize_as_(
        ks & c10::after_autograd_keyset,
        self_,
        the_template_,
        optional_memory_format);
  }

  // Handle fw grad
  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }
  return self;
}

Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto& self_ = unpack(self, "self", 0);
  RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::detach(ks & c10::after_autograd_keyset, self_);
  })();
  namedinference::propagate_names(result, self);

  // Detach the forward grads by not setting anything on the result

  return result;
}

Tensor& detach_(c10::DispatchKeySet ks, Tensor& self) {
  RECORD_FUNCTION("detach_", std::vector<c10::IValue>({self}));
  if (self.is_view()) {
    // See NOTE [ View + Inplace detection ]
    AT_ERROR(
        "Can't detach views in-place. Use detach() instead. "
        "If you are using DistributedDataParallel (DDP) for training, "
        "and gradient_as_bucket_view is set to True, gradients are "
        "views of DDP buckets, and hence detach_() cannot be called "
        "on these gradients. To fix this error, please refer to the "
        "Optimizer.zero_grad() function in torch/optim/optimizer.py "
        "as the solution.");
  }
  // I think the choice here is conservative. In principle, doing
  // an in-place detach should give us the ability to just clear
  // the autograd meta. But this function ONLY resets requires_grad,
  // grad_fn and output_nr; there's other metadata like debug name
  // and hooks which aren't cleared. Is this function supposed to
  // clear those too? I'm not too sure, so I'm leaving it be for now.
  auto autograd_meta = impl::materialize_autograd_meta(self);
  autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl());
  autograd_meta->grad_fn_.reset();
  autograd_meta->output_nr_ = 0;
  autograd_meta->fw_grad_.reset();

  return self;
}

// Ops in the following registration list are registered as
//   (1) CompositeImplicitAutograd kernels
//   (2) Autograd kernels
//   (3) CompositeExplicitAutograd kernels and additionally Autograd kernels
// The reason for (3) is that ops that also use dispatch (e.g. register
// CPU/CUDA/QuantizedCPU kernels) will skip picking up CompositeImplicitAutograd
// kernels for Autograd, so we register them to both CompositeExplicitAutograd
// and Autograd instead. See
// https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword
// for more details.
// Invariant:
// - Ops registered to CompositeImplicitAutograd or CompositeExplicitAutograd
//   below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py
//   and must have manual_kernel_registration=True in native_functions.yaml.
// - Ops registered to DispatchKey::Autograd below must be included in
//   `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py

TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  m.impl(
      "resize_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_)));
  m.impl(
      "resize_as_",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_)));
  m.impl(
      "detach",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach)));
  m.impl(
      "detach_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_)));
  m.impl(
      "copy_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_make_dual)));
}

} // namespace
} // namespace VariableType
} // namespace autograd

namespace ADInplaceOrView {
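// CreationMeta recorded on views created by the kernels below: INFERENCE_MODE
// when inference mode is active, otherwise DEFAULT when grad mode is enabled
// and NO_GRAD_MODE when it is not.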
#define CREATION_META_DEFINITION                            \
  InferenceMode::is_enabled()                               \
      ? CreationMeta::INFERENCE_MODE                        \
      : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT \
                                    : CreationMeta::NO_GRAD_MODE)

Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::copy_(
        ks & c10::after_ADInplaceOrView_keyset, self, src, non_blocking);
  }
  torch::autograd::increment_version(self);
  return self;
}

Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto out = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::detach::redispatch(
        ks & c10::after_ADInplaceOrView_keyset, self);
  })();
  // NB: we can't make detach() a normal view operator because the codegen
  // generates allow_tensor_metadata_change = True for them. In the future we
  // should have an option for this in the codegen.
  std::function<at::Tensor(const at::Tensor&)> func = nullptr;
  auto result = as_view(
      /* base */ self,
      /* output */ out,
      /* is_bw_differentiable */ false,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* creation_meta */ CreationMeta::DEFAULT,
      /*allow_tensor_metadata_change=*/false);

  return result;
}

Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(self);
  })();
  std::function<at::Tensor(const at::Tensor&)> func = nullptr;
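  // When the base cannot be reconstructed with as_strided, record an explicit
  // view_func so autograd can replay this view (same pattern as the generated
  // ADInplaceOrView kernels).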
  if (!self.unsafeGetTensorImpl()->support_as_strided()) {
    auto size_vec = self.sizes().vec();
    func = [=](const at::Tensor& input_base) {
      return input_base.view(size_vec);
    };
  }
  auto result = as_view(
      /* base */ self,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

// NB: This does not redispatch any further
Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(primal);
  })();
  std::function<at::Tensor(const at::Tensor&)> func = nullptr;
  if (!primal.unsafeGetTensorImpl()->support_as_strided()) {
    auto size_vec = primal.sizes().vec();
    func = [=](const at::Tensor& input_base) {
      return input_base.view(size_vec);
    };
  }
  auto result = as_view(
      /* base */ primal,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

namespace {
TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
  m.impl(
      "copy_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::copy_)));
  m.impl(
      "detach",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::detach)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_make_dual)));
}
} // namespace
} // namespace ADInplaceOrView
} // namespace torch