#pragma once

#include <c10/util/irange.h>

#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/utils/variadic.h>

#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#ifdef Type
#undef Type
#endif
#endif

namespace torch {
namespace autograd {

// The requires_grad argument tells us whether the in-place operation needs
// gradient tracking to be set up for it.
// In particular, we can have tensor.requires_grad() != requires_grad when
// writing a Tensor that requires gradients in-place into a Tensor that does
// not require gradients:
//   a = torch.rand(2)
//   b = torch.rand(2, requires_grad=True)
//   a.copy_(b)
inline void check_inplace(const at::Tensor& tensor, bool requires_grad) {
  if (requires_grad && GradMode::is_enabled()) {
    auto diff_view_meta = impl::get_view_autograd_meta(tensor);
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      // This can throw or warn
      handle_view_on_rebase(diff_view_meta);
      if (tensor.requires_grad() && tensor._base().is_leaf()) {
        AT_ERROR(
            "a view of a leaf Variable that requires grad is being used in an in-place operation.");
      }
    }
    if (tensor.requires_grad() && tensor.is_leaf()) {
      AT_ERROR(
          "a leaf Variable that requires grad is being used in an in-place operation.");
    }
  }
}

inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) {
  for (const auto& tensor : tensors) {
    check_inplace(tensor, requires_grad);
  }
}
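
// Illustrative sketch (not part of this header): a generated in-place kernel
// would typically validate its mutated argument with check_inplace before
// recording any autograd history, roughly along these lines (the op name is
// hypothetical; compute_requires_grad comes from functions/utils.h):
//
//   at::Tensor& my_op_(at::Tensor& self, const at::Tensor& other) {
//     auto _any_requires_grad = compute_requires_grad(self, other);
//     check_inplace(self, _any_requires_grad);
//     // ... set up grad_fn, run the kernel, rebase history ...
//     return self;
//   }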

inline void throw_error_out_requires_grad(const char* name) {
  AT_ERROR(
      name,
      "(): functions with out=... arguments don't support automatic differentiation, "
      "but one of the arguments requires grad.");
}

inline void throw_error_for_complex_autograd(
    const at::Tensor& tensor,
    const char* name) {
  if (tensor.requires_grad()) {
    TORCH_CHECK(
        !tensor.is_complex(),
        name,
        " does not support automatic differentiation for outputs with complex dtype.");
  }
}

inline void throw_error_if_base_and_tensor_are_same(
    const at::Tensor& base,
    const at::Tensor& tensor) {
  TORCH_CHECK(
      base.unsafeGetTensorImpl() != tensor.unsafeGetTensorImpl(),
      "View operation returned a tensor that is the same as the input base tensor. This "
      "is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). "
      "As a user, you could have made a mistake implementing __torch_dispatch__ or a Python "
      "operator decomposition or meta registration; if that's not the case, please "
      "report a bug to PyTorch or the backend you are using.");
}

inline void throw_error_for_complex_autograd(
    at::ITensorListRef tensorlist,
    const char* name) {
  for (const auto& tensor : tensorlist) {
    throw_error_for_complex_autograd(tensor, name);
  }
}

// TODO: Blegh, bare references

inline void rebase_history(Variable& var, std::shared_ptr<Node> grad_fn) {
  if (grad_fn && var.defined()) {
    grad_fn->add_input_metadata(var);
    impl::rebase_history(var, {std::move(grad_fn), 0});
  }
}

inline void rebase_history(
    std::vector<Variable>&& vars,
    std::shared_ptr<Node> grad_fn) {
  if (grad_fn) {
    for (auto& var : vars) {
      if (var.defined()) {
        auto output_nr = grad_fn->add_input_metadata(var);
        impl::rebase_history(var, {grad_fn, output_nr});
      } else {
        grad_fn->add_input_metadata(Node::undefined_input());
      }
    }
  }
}

inline void increment_version(const at::Tensor& t) {
  impl::bump_version(t);
}
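
// Illustrative sketch (not part of this header): after an in-place op mutates
// `self`, generated autograd code typically bumps the version counter and
// rebases the history onto the newly created backward node, roughly
// (MyOpBackward is a hypothetical node type; collect_next_edges comes from
// function.h):
//
//   std::shared_ptr<Node> grad_fn;
//   if (_any_requires_grad) {
//     grad_fn = std::make_shared<MyOpBackward>();
//     grad_fn->set_next_edges(collect_next_edges(self));
//   }
//   // ... run the actual kernel on `self` ...
//   increment_version(self);
//   rebase_history(self, grad_fn);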

struct Flatten : IterArgs<Flatten> {
  Flatten(variable_list& out) : out(out) {}
  variable_list& out;
  void operator()(const at::Tensor& x) {
    out.emplace_back(x);
  }
  void operator()(const c10::optional<at::Tensor>& x) {
    if (x.has_value())
      out.emplace_back(x.value());
  }
  void operator()(at::ArrayRef<at::Tensor> xs) {
    out.insert(out.end(), xs.begin(), xs.end());
  }
};

template <typename... Args>
inline variable_list flatten_tensor_args(Args&&... args) {
  variable_list out;
  out.reserve(count_tensors(std::forward<Args>(args)...));
  Flatten(out).apply(std::forward<Args>(args)...);
  return out; // RVO
}
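
// Illustrative sketch (not part of this header): flatten_tensor_args accepts
// any mix of Tensors, optional<Tensor>s, and tensor lists and collects them
// into a single variable_list, e.g. when gathering an op's inputs (the
// variable names below are hypothetical):
//
//   c10::optional<at::Tensor> maybe_bias = /* ... */;
//   variable_list all_inputs = flatten_tensor_args(input, weight, maybe_bias);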

// See NOTE [ Autograd View Variables ] for details.
inline at::Tensor as_view(
    const at::Tensor& base,
    const at::Tensor& tensor,
    bool is_bw_differentiable,
    bool is_fw_differentiable,
    std::function<at::Tensor(const at::Tensor&)> view_func = nullptr,
    CreationMeta creation_meta = CreationMeta::DEFAULT,
    bool allow_tensor_metadata_change = true) {
  // Note [View of inference tensor]
  // For inference tensors this code can only be hit outside InferenceMode,
  // since ADInplaceOrView is in the default_included_set.
  // If Inplace and View were separate dispatch keys, we could just put Inplace
  // in the default_included_set, so that view ops on inference tensors would
  // not have to go through as_view even outside InferenceMode.
  if (base.is_inference())
    return tensor;

  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);

  // To speed up the most common case, we specially handle the situation where
  // the forward and backward view infos are the same, so that a single shared
  // ViewInfo can be used for both of them.
  if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
      is_bw_differentiable && is_fw_differentiable) {
    throw_error_if_base_and_tensor_are_same(base, tensor);
    if (diff_view_meta) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
      return make_variable_differentiable_view(
          tensor,
          diff_view_meta->get_backward_view().chain(
              base, tensor, std::move(view_func)),
          c10::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    } else {
      return make_variable_differentiable_view(
          tensor,
          ViewInfo(base, std::move(view_func)),
          c10::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    }
  }

  // If they cannot be shared, create the required view infos
  c10::optional<ViewInfo> new_bw_info;
  c10::optional<ViewInfo> new_fw_info;

  if (is_bw_differentiable) {
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      new_bw_info = base_bw_info.chain(base, tensor, view_func);
    } else {
      new_bw_info = ViewInfo(base, view_func);
    }
  } else {
    TORCH_CHECK(
        creation_meta == CreationMeta::DEFAULT,
        "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }

  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    if (diff_view_meta && diff_view_meta->has_fw_view()) {
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      new_fw_info = base_fw_info.chain(base, tensor, std::move(view_func));
    } else {
      new_fw_info = ViewInfo(base, std::move(view_func));
    }
  }

  if (is_fw_differentiable || is_bw_differentiable) {
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
    }
    throw_error_if_base_and_tensor_are_same(base, tensor);
    return make_variable_differentiable_view(
        tensor,
        std::move(new_bw_info),
        std::move(new_fw_info),
        /*shared_view_info*/ false,
        creation_meta,
        allow_tensor_metadata_change);
  } else {
    return make_variable_non_differentiable_view(
        base, tensor, allow_tensor_metadata_change);
  }
}
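
// Illustrative sketch (not part of this header): the ADInplaceOrView kernel of
// a view op would wrap its output with as_view so autograd can track it as a
// view of `self`. The op name, redispatch call, and lambda below are
// hypothetical and only meant to show the shape of such a call:
//
//   auto result = at::redispatch::my_view_op(ks, self, args...);
//   return as_view(
//       /*base=*/self,
//       /*tensor=*/result,
//       /*is_bw_differentiable=*/true,
//       /*is_fw_differentiable=*/true,
//       /*view_func=*/[=](const at::Tensor& base) {
//         return base.my_view_op(args...);
//       },
//       /*creation_meta=*/at::GradMode::is_enabled()
//           ? CreationMeta::DEFAULT
//           : CreationMeta::NO_GRAD_MODE);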

// See NOTE [ Autograd View Variables ] for details.
inline std::vector<at::Tensor> as_view(
    const at::Tensor& base,
    std::vector<at::Tensor>& tensors,
    bool is_bw_differentiable,
    bool is_fw_differentiable,
    CreationMeta creation_meta = CreationMeta::DEFAULT) {
  // See Note [View of inference tensor]
  if (base.is_inference())
    return tensors;

  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);

  // Special case when view info can be shared for forward and backward
  // differentiable views
  if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
      is_bw_differentiable && is_fw_differentiable) {
    c10::optional<ViewInfo> new_shared_info;
    if (diff_view_meta) {
      // TODO: fix the fb-internal use-case so that it doesn't trigger this
      // internal assert when the base is not a view. For now, we only do the
      // same (wrong) thing as the old code, which is to check only when the
      // input is a backward differentiable view.
      if (diff_view_meta->has_bw_view()) {
        TORCH_INTERNAL_ASSERT(
            creation_meta == CreationMeta::NO_GRAD_MODE ||
                creation_meta == CreationMeta::INFERENCE_MODE ||
                creation_meta == CreationMeta::MULTI_OUTPUT_NODE,
            "Functions that return multiple views must have a creation_meta reflecting this behavior or one that is more restrictive.");
      }
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      new_shared_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr);
    } else {
      new_shared_info = ViewInfo(base, /* view_func */ nullptr);
    }

    for (at::Tensor& tensor : tensors) {
      if (is_fw_differentiable || is_bw_differentiable) {
        tensor = make_variable_differentiable_view(
            tensor,
            new_shared_info,
            c10::nullopt,
            /*shared_view_info*/ true,
            creation_meta);
      } else {
        tensor = make_variable_non_differentiable_view(base, tensor);
      }
    }
    return tensors;
  }

  c10::optional<ViewInfo> new_bw_info = c10::nullopt;
  c10::optional<ViewInfo> new_fw_info = c10::nullopt;

  if (is_bw_differentiable) {
    auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      // TODO: fix the fb-internal use-case so that it doesn't trigger this
      // internal assert when the base is not a view. In this code, the assert
      // should be outside of the if statement.
      TORCH_INTERNAL_ASSERT(
          creation_meta == CreationMeta::NO_GRAD_MODE ||
              creation_meta == CreationMeta::INFERENCE_MODE ||
              creation_meta == CreationMeta::MULTI_OUTPUT_NODE,
          "Functions that return multiple views must have a creation_meta reflecting this behavior or one that is more restrictive.");
      // It is ok to create a ViewInfo where only the base is correct in this
      // case as inplace operations on such views are not allowed
      new_bw_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr);
    } else {
      new_bw_info = ViewInfo(base, /* view_func */ nullptr);
    }
  } else {
    TORCH_CHECK(
        creation_meta == CreationMeta::DEFAULT,
        "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }
  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);
    if (diff_view_meta && diff_view_meta->has_fw_view()) {
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      TORCH_INTERNAL_ASSERT(
          creation_meta == CreationMeta::NO_GRAD_MODE ||
              creation_meta == CreationMeta::INFERENCE_MODE ||
              creation_meta == CreationMeta::MULTI_OUTPUT_NODE,
          "Functions that return multiple views must have a creation_meta reflecting this behavior or one that is more restrictive.");
      // It is ok to create a ViewInfo where only the base is correct in this
      // case as inplace operations on such views are not allowed
      new_fw_info = ViewInfo(base_fw_info.base_, /* view_func */ nullptr);
    } else {
      new_fw_info = ViewInfo(base, /* view_func */ nullptr);
    }
  }

  if ((is_fw_differentiable || is_bw_differentiable) && base.is_view()) {
    // is_view() => diff_view_meta
    auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);
    creation_meta = propagate_creation_meta(
        diff_view_meta->get_creation_meta(), creation_meta);
  }

  for (at::Tensor& tensor : tensors) {
    if (is_fw_differentiable || is_bw_differentiable) {
      tensor = make_variable_differentiable_view(
          tensor,
          new_bw_info,
          new_fw_info,
          /*shared_view_info*/ false,
          creation_meta);
    } else {
      tensor = make_variable_non_differentiable_view(base, tensor);
    }
  }
  return tensors;
}

inline void check_no_requires_grad(
    const at::Tensor& tensor,
    const char* name,
    const char* fn_name = "",
    bool check_grad_mode = true) {
  TORCH_CHECK(
      !(tensor.defined() && tensor.requires_grad()) ||
          !(check_grad_mode && GradMode::is_enabled()),
      "The function '",
      fn_name,
      "' is not differentiable with respect to argument '",
      name,
      "'. This input cannot have requires_grad=True.");
}

inline void check_no_requires_grad(
    const c10::optional<at::Tensor>& tensor,
    const char* name,
    const char* fn_name = "") {
  if (tensor.has_value()) {
    check_no_requires_grad(*tensor, name, fn_name);
  }
}

inline void check_no_requires_grad(
    at::ITensorListRef tensors,
    const char* name,
    const char* fn_name = "") {
  // GradMode check is expensive, so check it only once for TensorLists
  if (!GradMode::is_enabled()) {
    return;
  }
  for (auto& tensor : tensors) {
    check_no_requires_grad(tensor, name, fn_name, /*check_grad_mode*/ false);
  }
}

inline void check_no_requires_grad(
    const c10::List<c10::optional<at::Tensor>>& tensors,
    const char* name,
    const char* fn_name = "") {
  // GradMode check is expensive, so check it only once for TensorLists
  if (!GradMode::is_enabled()) {
    return;
  }
  for (c10::optional<at::Tensor> tensor : tensors) {
    if (tensor.has_value()) {
      check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false);
    }
  }
}
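
// Illustrative sketch (not part of this header): an autograd kernel for an op
// whose `weight` argument is not differentiable might guard it like so (the op
// and argument names are hypothetical):
//
//   check_no_requires_grad(weight, "weight", "my_op");
//   // ... proceed knowing no gradient needs to flow into `weight` ...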

// It is assumed that saved tensor lists are never in-place outputs.
inline std::vector<SavedVariable> make_saved_variable_list(
    at::ITensorListRef tensors) {
  return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable {
    return SavedVariable{tensor, false /* is output */};
  });
}

// It is assumed that saved tensor lists are never in-place outputs.
inline std::vector<SavedVariable> make_saved_variable_list(
    const c10::List<c10::optional<at::Tensor>>& tensors) {
  return fmap(
      tensors, [](const c10::optional<at::Tensor>& tensor) -> SavedVariable {
        if (tensor.has_value()) {
          return SavedVariable{*tensor, false /* is output */};
        } else {
          return SavedVariable{at::Tensor(), false /* is output */};
        }
      });
}
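
// Illustrative sketch (not part of this header): a backward node that needs a
// whole input list for its backward pass would stash it with this helper
// (`grad_fn` and its `tensors_` member are hypothetical):
//
//   grad_fn->tensors_ = make_saved_variable_list(tensors);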

inline std::vector<std::vector<int64_t>> to_args_sizes(
    at::ITensorListRef tensors) {
  std::vector<std::vector<int64_t>> args_sizes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_sizes[i++] = t.sizes().vec();
  }
  return args_sizes;
}

inline std::vector<std::vector<c10::SymInt>> to_args_sizes_symint(
    at::ITensorListRef tensors) {
  std::vector<std::vector<c10::SymInt>> args_sizes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_sizes[i++] = t.sym_sizes().vec();
  }
  return args_sizes;
}

inline std::vector<c10::ScalarType> to_args_scalartypes(
    at::ITensorListRef tensors) {
  std::vector<c10::ScalarType> args_scalartypes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_scalartypes[i++] = t.scalar_type();
  }
  return args_scalartypes;
}
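
// Illustrative sketch (not part of this header): backward nodes that only need
// per-tensor metadata, rather than the tensors themselves, can record it with
// these helpers (`grad_fn` and its members are hypothetical):
//
//   grad_fn->tensors_args_sizes_ = to_args_sizes(tensors);
//   grad_fn->tensors_args_scalartypes_ = to_args_scalartypes(tensors);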

namespace impl {

namespace {

// If run_jit_decomposition were not a member function, we would be able
// to pass it as a template parameter to c10::BoxedKernel::makeFromFunction.
// However, member functions cannot be passed this way - instead we wrap our
// call in this functor so it can be passed to c10::BoxedKernel::makeFromFunctor.
class WrapperFunctor final : public c10::OperatorKernel {
 public:
  WrapperFunctor(JitDecompInterface* impl) : impl_(impl) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet ks,
      torch::jit::Stack* stack) {
    impl_->run_jit_decomposition(op, stack);
  }
  JitDecompInterface* impl_;
};

} // namespace

template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
    c10::string_view name,
    const c10::OperatorHandle& opHandle,
    c10::DispatchKeySet dispatchKeySet,
    Args&&... args) {
  // see NOTE: [Jit Decomposition Interface]
  JitDecompInterface* impl = getJitDecompImpl();

  TORCH_CHECK_NOT_IMPLEMENTED(
      impl && impl->has_jit_decomposition(opHandle.schema()),
      "Trying to use forward AD with ",
      name,
      " that does not support it because it has not been implemented yet.\nPlease file an issue "
      "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
      "so that we can prioritize its implementation.\n"
      "Note that forward AD support for some operators requires PyTorch to be built with "
      "TorchScript and for JIT to be enabled. "
      "If the environment variable PYTORCH_JIT=0 is set or if the library is not built with TorchScript, "
      "some operators may not be usable with forward AD.");

  return c10::KernelFunction::makeFromBoxedKernel(
             c10::BoxedKernel::makeFromFunctor(
                 std::make_unique<WrapperFunctor>(impl)))
      .call<Return, Args...>(
          opHandle, dispatchKeySet, std::forward<Args>(args)...);
}
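
// Illustrative sketch (not part of this header): generated autograd code falls
// back to this helper when an op has no native forward-AD formula but a JIT
// decomposition is available. The op name and argument list below are
// hypothetical:
//
//   if (_any_has_forward_grad_result) {
//     static auto op_handle =
//         c10::Dispatcher::singleton().findSchemaOrThrow("aten::my_op", "");
//     return impl::run_jit_decomposition_with_args_for_jvp<at::Tensor>(
//         "my_op", op_handle, ks, self, other);
//   }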

} // namespace impl

} // namespace autograd
} // namespace torch