1 | #pragma once |
2 | |
3 | #include <c10/util/irange.h> |
4 | |
5 | #include <ATen/core/boxing/KernelFunction.h> |
6 | #include <ATen/core/dispatch/Dispatcher.h> |
7 | |
8 | #include <torch/csrc/autograd/edge.h> |
9 | #include <torch/csrc/autograd/function.h> |
10 | #include <torch/csrc/autograd/functions/basic_ops.h> |
11 | #include <torch/csrc/autograd/functions/tensor.h> |
12 | #include <torch/csrc/autograd/grad_mode.h> |
13 | #include <torch/csrc/autograd/saved_variable.h> |
14 | #include <torch/csrc/autograd/variable.h> |
15 | |
16 | #include <torch/csrc/autograd/functions/utils.h> |
17 | #include <torch/csrc/autograd/jit_decomp_interface.h> |
18 | #include <torch/csrc/utils/variadic.h> |
19 | |
20 | #include <array> |
21 | #include <cstddef> |
22 | #include <functional> |
23 | #include <initializer_list> |
24 | #include <memory> |
25 | #include <stdexcept> |
26 | #include <string> |
27 | #include <tuple> |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | #ifdef _MSC_VER |
32 | #ifdef Type |
33 | #undef Type |
34 | #endif |
35 | #endif |
36 | |
37 | namespace torch { |
38 | namespace autograd { |
39 | |
// The requires_grad argument is used to know if the inplace operation needs
// gradient to be setup for it.
// In particular, we can have tensor.requires_grad() != requires_grad when
// writing a Tensor that requires gradients inplace into a Tensor that does not
// require gradients:
//   a = torch.rand(2)
//   b = torch.rand(2, requires_grad=True)
//   a.copy_(b)
46 | inline void check_inplace(const at::Tensor& tensor, bool requires_grad) { |
47 | if (requires_grad && GradMode::is_enabled()) { |
48 | auto diff_view_meta = impl::get_view_autograd_meta(tensor); |
49 | if (diff_view_meta && diff_view_meta->has_bw_view()) { |
50 | // This can throw or warn |
51 | handle_view_on_rebase(diff_view_meta); |
52 | if (tensor.requires_grad() && tensor._base().is_leaf()) { |
53 | AT_ERROR( |
54 | "a view of a leaf Variable that requires grad is being used in an in-place operation." ); |
55 | } |
56 | } |
57 | if (tensor.requires_grad() && tensor.is_leaf()) { |
58 | AT_ERROR( |
59 | "a leaf Variable that requires grad is being used in an in-place operation." ); |
60 | } |
61 | } |
62 | } |
63 | |
64 | inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) { |
65 | for (const auto& tensor : tensors) { |
66 | check_inplace(tensor, requires_grad); |
67 | } |
68 | } |
69 | |
// Unconditionally throws; called when an `op(..., out=...)` overload receives
// an argument that requires grad (out= variants do not support autograd).
// `name` is the operator name spliced into the error message.
inline void throw_error_out_requires_grad(const char* name) {
  AT_ERROR(
      name,
      "(): functions with out=... arguments don't support automatic differentiation, "
      "but one of the arguments requires grad.");
}
76 | |
77 | inline void throw_error_for_complex_autograd( |
78 | const at::Tensor& tensor, |
79 | const char* name) { |
80 | if (tensor.requires_grad()) { |
81 | TORCH_CHECK( |
82 | !tensor.is_complex(), |
83 | name, |
84 | " does not support automatic differentiation for outputs with complex dtype." ); |
85 | } |
86 | } |
87 | |
88 | inline void throw_error_if_base_and_tensor_are_same( |
89 | const at::Tensor& base, |
90 | const at::Tensor& tensor) { |
91 | TORCH_CHECK( |
92 | base.unsafeGetTensorImpl() != tensor.unsafeGetTensorImpl(), |
93 | "View operation returned a tensor that is the same as the input base tensor. This " |
94 | "is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). " |
95 | "As a user, you could have made a mistake implementing __torch_dispatch__ or a Python " |
96 | "operator decomposition or meta registration; if that's not the case, please " |
97 | "report a bug to PyTorch or the backend you are using." ); |
98 | } |
99 | |
100 | inline void throw_error_for_complex_autograd( |
101 | at::ITensorListRef tensorlist, |
102 | const char* name) { |
103 | for (const auto& tensor : tensorlist) { |
104 | throw_error_for_complex_autograd(tensor, name); |
105 | } |
106 | } |
107 | |
// TODO: these helpers mutate through bare (non-const) references; consider a
// signature that makes the mutation explicit to callers.
109 | |
110 | inline void rebase_history(Variable& var, std::shared_ptr<Node> grad_fn) { |
111 | if (grad_fn && var.defined()) { |
112 | grad_fn->add_input_metadata(var); |
113 | impl::rebase_history(var, {std::move(grad_fn), 0}); |
114 | } |
115 | } |
116 | |
117 | inline void rebase_history( |
118 | std::vector<Variable>&& vars, |
119 | std::shared_ptr<Node> grad_fn) { |
120 | if (grad_fn) { |
121 | for (auto& var : vars) { |
122 | if (var.defined()) { |
123 | auto output_nr = grad_fn->add_input_metadata(var); |
124 | impl::rebase_history(var, {grad_fn, output_nr}); |
125 | } else { |
126 | grad_fn->add_input_metadata(Node::undefined_input()); |
127 | } |
128 | } |
129 | } |
130 | } |
131 | |
// Bumps `t`'s version counter to record that it was modified in-place
// (thin wrapper over impl::bump_version).
inline void increment_version(const at::Tensor& t) {
  impl::bump_version(t);
}
135 | |
136 | struct Flatten : IterArgs<Flatten> { |
137 | Flatten(variable_list& out) : out(out) {} |
138 | variable_list& out; |
139 | void operator()(const at::Tensor& x) { |
140 | out.emplace_back(x); |
141 | } |
142 | void operator()(const c10::optional<at::Tensor>& x) { |
143 | if (x.has_value()) |
144 | out.emplace_back(x.value()); |
145 | } |
146 | void operator()(at::ArrayRef<at::Tensor> xs) { |
147 | out.insert(out.end(), xs.begin(), xs.end()); |
148 | } |
149 | }; |
150 | |
// Collects every Tensor found in `args` (tensors, optional tensors, tensor
// lists) into a single variable_list, pre-sized via count_tensors.
// NOTE(review): `args` is forwarded twice; this assumes count_tensors only
// inspects (never consumes) rvalue arguments — TODO confirm.
template <typename... Args>
inline variable_list flatten_tensor_args(Args&&... args) {
  variable_list out;
  out.reserve(count_tensors(std::forward<Args>(args)...));
  Flatten(out).apply(std::forward<Args>(args)...);
  return out; // RVO
}
158 | |
// See NOTE [ Autograd View Variables ] for details.
//
// Wraps `tensor` (a freshly created view of `base`) with the autograd view
// metadata implied by the differentiability flags. `view_func`, when
// provided, is stored in the ViewInfo chain so the view relationship can be
// replayed from a base tensor. Returns `tensor` untouched for inference
// tensors.
inline at::Tensor as_view(
    const at::Tensor& base,
    const at::Tensor& tensor,
    bool is_bw_differentiable,
    bool is_fw_differentiable,
    std::function<at::Tensor(const at::Tensor&)> view_func = nullptr,
    CreationMeta creation_meta = CreationMeta::DEFAULT,
    bool allow_tensor_metadata_change = true) {
  // Note [View of inference tensor]
  // For inference tensor this code can only be hit outside InferenceMode
  // since ADInplaceOrView is in the default_included_set.
  // If Inplace and View were separate dispatch keys we can just put Inplace
  // in the default_included_set, so that view ops on inference tensor doesn't
  // have to go through as_view even outside InferenceMode.
  if (base.is_inference())
    return tensor;

  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);

  // To speed up the most common case, we specially handle when both the forward
  // and backward view infos are the same, and so a single shared ViewInfo can
  // be used for both of them.
  if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
      is_bw_differentiable && is_fw_differentiable) {
    throw_error_if_base_and_tensor_are_same(base, tensor);
    if (diff_view_meta) {
      // `base` is itself a (shared-info) view: merge its creation meta with
      // ours and extend its backward ViewInfo chain.
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
      return make_variable_differentiable_view(
          tensor,
          diff_view_meta->get_backward_view().chain(
              base, tensor, std::move(view_func)),
          c10::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    } else {
      // `base` is not a view: start a fresh shared ViewInfo rooted at it.
      return make_variable_differentiable_view(
          tensor,
          ViewInfo(base, std::move(view_func)),
          c10::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    }
  }

  // If they cannot be shared, create the required view infos
  c10::optional<ViewInfo> new_bw_info;
  c10::optional<ViewInfo> new_fw_info;

  if (is_bw_differentiable) {
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      // NB: `view_func` is deliberately copied (not moved) here — it may
      // still be needed below to build the forward ViewInfo.
      new_bw_info = base_bw_info.chain(base, tensor, view_func);
    } else {
      new_bw_info = ViewInfo(base, view_func);
    }
  } else {
    TORCH_CHECK(
        creation_meta == CreationMeta::DEFAULT,
        "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }

  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    if (diff_view_meta && diff_view_meta->has_fw_view()) {
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      // Last use of view_func on this path, so it can be moved.
      new_fw_info = base_fw_info.chain(base, tensor, std::move(view_func));
    } else {
      new_fw_info = ViewInfo(base, std::move(view_func));
    }
  }

  if (is_fw_differentiable || is_bw_differentiable) {
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
    }
    throw_error_if_base_and_tensor_are_same(base, tensor);
    return make_variable_differentiable_view(
        tensor,
        std::move(new_bw_info),
        std::move(new_fw_info),
        /*shared_view_info*/ false,
        creation_meta,
        allow_tensor_metadata_change);
  } else {
    // Neither direction is differentiable: record a plain
    // non-differentiable view of `base`.
    return make_variable_non_differentiable_view(
        base, tensor, allow_tensor_metadata_change);
  }
}
252 | |
// See NOTE [ Autograd View Variables ] for details.
// Multi-output variant of as_view: wraps each tensor in `tensors` (all views
// of `base`) with view metadata in place, then returns the vector. Unlike
// the single-tensor overload, no view_func is supported, and in-place ops on
// the resulting (multi-output) views are disallowed.
inline std::vector<at::Tensor> as_view(
    const at::Tensor& base,
    std::vector<at::Tensor>& tensors,
    bool is_bw_differentiable,
    bool is_fw_differentiable,
    CreationMeta creation_meta = CreationMeta::DEFAULT) {
  // See Note [View of inference tensor]
  if (base.is_inference())
    return tensors;

  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);

  // Special case when view info can be shared for forward and backward
  // differentiable views
  if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
      is_bw_differentiable && is_fw_differentiable) {
    c10::optional<ViewInfo> new_shared_info;
    if (diff_view_meta) {
      // TODO: fix fb internal use-case so that it doesn't trigger this internal
      // assert when the base is not a view. For now, we only do that same
      // (wrong) thing as the old code which is to only check when the inputs is
      // a backward differentiable view
      if (diff_view_meta->has_bw_view()) {
        TORCH_INTERNAL_ASSERT(
            creation_meta == CreationMeta::NO_GRAD_MODE ||
                creation_meta == CreationMeta::INFERENCE_MODE ||
                creation_meta == CreationMeta::MULTI_OUTPUT_NODE,
            "Functions that result multiple view must have a creation meta reflecting this behavior or more restrictive.");
      }
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      // Only the base_ of the ViewInfo is correct here; that is acceptable
      // because in-place ops on multi-output views are not allowed.
      new_shared_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr);
    } else {
      new_shared_info = ViewInfo(base, /* view_func */ nullptr);
    }

    for (at::Tensor& tensor : tensors) {
      // NB: in this branch both flags are true, so the else arm is
      // unreachable; kept for symmetry with the loop below.
      if (is_fw_differentiable || is_bw_differentiable) {
        tensor = make_variable_differentiable_view(
            tensor,
            new_shared_info,
            c10::nullopt,
            /*shared_view_info*/ true,
            creation_meta);
      } else {
        tensor = make_variable_non_differentiable_view(base, tensor);
      }
    }
    return tensors;
  }

  c10::optional<ViewInfo> new_bw_info = c10::nullopt;
  c10::optional<ViewInfo> new_fw_info = c10::nullopt;

  if (is_bw_differentiable) {
    auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      // TODO: fix fb internal use-case so that it doesn't trigger this internal
      // assert when the base is not a view. In this code, the assert should be
      // outside of the if statement.
      TORCH_INTERNAL_ASSERT(
          creation_meta == CreationMeta::NO_GRAD_MODE ||
              creation_meta == CreationMeta::INFERENCE_MODE ||
              creation_meta == CreationMeta::MULTI_OUTPUT_NODE,
          "Functions that result multiple view must have a creation meta reflecting this behavior or more restrictive.");
      // It is ok to create a ViewInfo where only the base is correct in this
      // case as inplace operations on such views are not allowed
      new_bw_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr);
    } else {
      new_bw_info = ViewInfo(base, /* view_func */ nullptr);
    }
  } else {
    TORCH_CHECK(
        creation_meta == CreationMeta::DEFAULT,
        "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }
  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);
    if (diff_view_meta && diff_view_meta->has_fw_view()) {
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      TORCH_INTERNAL_ASSERT(
          creation_meta == CreationMeta::NO_GRAD_MODE ||
              creation_meta == CreationMeta::INFERENCE_MODE ||
              creation_meta == CreationMeta::MULTI_OUTPUT_NODE,
          "Functions that result multiple view must have a creation meta reflecting this behavior or more restrictive.");
      // It is ok to create a ViewInfo where only the base is correct in this
      // case as inplace operations on such views are not allowed
      new_fw_info = ViewInfo(base_fw_info.base_, /* view_func */ nullptr);
    } else {
      new_fw_info = ViewInfo(base, /* view_func */ nullptr);
    }
  }

  if ((is_fw_differentiable || is_bw_differentiable) && base.is_view()) {
    // is_view() => diff_view_meta
    auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);
    creation_meta = propagate_creation_meta(
        diff_view_meta->get_creation_meta(), creation_meta);
  }

  for (at::Tensor& tensor : tensors) {
    if (is_fw_differentiable || is_bw_differentiable) {
      tensor = make_variable_differentiable_view(
          tensor,
          new_bw_info,
          new_fw_info,
          /*shared_view_info*/ false,
          creation_meta);
    } else {
      tensor = make_variable_non_differentiable_view(base, tensor);
    }
  }
  return tensors;
}
371 | |
372 | inline void check_no_requires_grad( |
373 | const at::Tensor& tensor, |
374 | const char* name, |
375 | const char* fn_name = "" , |
376 | bool check_grad_mode = true) { |
377 | TORCH_CHECK( |
378 | !(tensor.defined() && tensor.requires_grad()) || |
379 | !(check_grad_mode && GradMode::is_enabled()), |
380 | "The function '" , |
381 | fn_name, |
382 | "' is not differentiable with respect to argument '" , |
383 | name, |
384 | "'. This input cannot have requires_grad True." ); |
385 | } |
386 | |
387 | inline void check_no_requires_grad( |
388 | const c10::optional<at::Tensor>& tensor, |
389 | const char* name, |
390 | const char* fn_name = "" ) { |
391 | if (tensor.has_value()) { |
392 | check_no_requires_grad(*tensor, name, fn_name); |
393 | } |
394 | } |
395 | |
396 | inline void check_no_requires_grad( |
397 | at::ITensorListRef tensors, |
398 | const char* name, |
399 | const char* fn_name = "" ) { |
400 | // GradMode check is expensive, so check it only once for TensorLists |
401 | if (!GradMode::is_enabled()) { |
402 | return; |
403 | } |
404 | for (auto& tensor : tensors) { |
405 | check_no_requires_grad(tensor, name, fn_name, /*check_grad_mode*/ false); |
406 | } |
407 | } |
408 | |
409 | inline void check_no_requires_grad( |
410 | const c10::List<c10::optional<at::Tensor>>& tensors, |
411 | const char* name, |
412 | const char* fn_name = "" ) { |
413 | // GradMode check is expensive, so check it only once for TensorLists |
414 | if (!GradMode::is_enabled()) { |
415 | return; |
416 | } |
417 | for (c10::optional<at::Tensor> tensor : tensors) { |
418 | if (tensor.has_value()) { |
419 | check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false); |
420 | } |
421 | } |
422 | } |
423 | |
424 | // Assumed that saved tensor lists are never inplace outputs |
425 | inline std::vector<SavedVariable> make_saved_variable_list( |
426 | at::ITensorListRef tensors) { |
427 | return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable { |
428 | return SavedVariable{tensor, false /* is output */}; |
429 | }); |
430 | } |
431 | |
432 | // Assumed that saved tensor lists are never inplace outputs |
433 | inline std::vector<SavedVariable> make_saved_variable_list( |
434 | const c10::List<c10::optional<at::Tensor>>& tensors) { |
435 | return fmap( |
436 | tensors, [](const c10::optional<at::Tensor>& tensor) -> SavedVariable { |
437 | if (tensor.has_value()) { |
438 | return SavedVariable{*tensor, false /* is output */}; |
439 | } else { |
440 | return SavedVariable{at::Tensor(), false /* is output */}; |
441 | } |
442 | }); |
443 | } |
444 | |
445 | inline std::vector<std::vector<int64_t>> to_args_sizes( |
446 | at::ITensorListRef tensors) { |
447 | std::vector<std::vector<int64_t>> args_sizes(tensors.size()); |
448 | size_t i = 0; |
449 | for (const auto& t : tensors) { |
450 | args_sizes[i++] = t.sizes().vec(); |
451 | } |
452 | return args_sizes; |
453 | } |
454 | |
455 | inline std::vector<std::vector<c10::SymInt>> to_args_sizes_symint( |
456 | at::ITensorListRef tensors) { |
457 | std::vector<std::vector<c10::SymInt>> args_sizes(tensors.size()); |
458 | size_t i = 0; |
459 | for (const auto& t : tensors) { |
460 | args_sizes[i++] = t.sym_sizes().vec(); |
461 | } |
462 | return args_sizes; |
463 | } |
464 | |
465 | inline std::vector<c10::ScalarType> to_args_scalartypes( |
466 | at::ITensorListRef tensors) { |
467 | std::vector<c10::ScalarType> args_scalartypes(tensors.size()); |
468 | size_t i = 0; |
469 | for (const auto& t : tensors) { |
470 | args_scalartypes[i++] = t.scalar_type(); |
471 | } |
472 | return args_scalartypes; |
473 | } |
474 | |
475 | namespace impl { |
476 | |
477 | namespace { |
478 | |
// If run_jit_decomposition were not a member function, we would be able
// to pass it as a template parameter to c10::BoxedKernel::makeFromFunction.
// However, member functions cannot be passed this way - instead we wrap our
// call in this functor so it can be passed to c10::BoxedKernel::makeFromFunctor.
class WrapperFunctor final : public c10::OperatorKernel {
 public:
  // Non-owning: `impl` must outlive this functor. In practice it is the
  // process-wide interface returned by getJitDecompImpl().
  WrapperFunctor(JitDecompInterface* impl) : impl_(impl){};

  // Boxed-kernel call signature; `ks` is required by the calling convention
  // but unused here.
  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet ks,
      torch::jit::Stack* stack) {
    impl_->run_jit_decomposition(op, stack);
  }
  JitDecompInterface* impl_;
};
495 | |
496 | } // namespace |
497 | |
// Dispatches `opHandle` through its registered JIT decomposition so that
// forward-mode AD (jvp) can be applied to an op lacking a native forward
// derivative. Throws NotImplementedError when no decomposition is available.
// `name` is only used to build the error message.
template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
    c10::string_view name,
    const c10::OperatorHandle& opHandle,
    c10::DispatchKeySet dispatchKeySet,
    Args&&... args) {
  // see NOTE: [Jit Decomposition Interface]
  JitDecompInterface* impl = getJitDecompImpl();

  TORCH_CHECK_NOT_IMPLEMENTED(
      impl && impl->has_jit_decomposition(opHandle.schema()),
      "Trying to use forward AD with ",
      name,
      " that does not support it because it has not been implemented yet.\nPlease file an issue "
      "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
      "so that we can prioritize its implementation.\n"
      "Note that forward AD support for some operators require PyTorch to be built with "
      "TorchScript and for JIT to be enabled. "
      "If the environment var PYTORCH_JIT=0 is set or if the library is not built with TorchScript, "
      "some operators may no longer be used with forward AD.");

  // Build a one-off KernelFunction around the boxed decomposition runner so
  // it can be invoked with this op's unboxed signature.
  return c10::KernelFunction::makeFromBoxedKernel(
             c10::BoxedKernel::makeFromFunctor(
                 std::make_unique<WrapperFunctor>(impl)))
      .call<Return, Args...>(
          opHandle, dispatchKeySet, std::forward<Args>(args)...);
}
525 | |
526 | } // namespace impl |
527 | |
528 | } // namespace autograd |
529 | } // namespace torch |
530 | |