// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/WrapDimUtils.h>
#include <torch/python.h>

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/TensorWrapper.h>
#include <c10/core/AutogradState.h>

// This file contains functorch's Python bindings.

namespace torch {
namespace functorch {
namespace impl {

using namespace at::functorch;

static bool has_level(const Tensor& self, int64_t level) {
  const auto* batched = maybeGetBatchedImpl(self);
  if (!batched) {
    return false;
  }
  return batched->level() >= level;
}

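// Wraps `self` in a BatchedTensor that treats `batch_dim` as the batch
// dimension for the vmap transform at `level`.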
Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
  return addBatchDim(self, batch_dim, level);
}

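// Wraps `self` in a FunctionalTensorWrapper and tags the wrapper with the
// functionalize() transform's `level`.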
Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) {
  auto t = at::functionalization::impl::to_functional_tensor(self);
  at::functionalization::impl::unsafeGetFunctionalWrapper(t)->set_level(level);
  return t;
}

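// Asserts that `wrapped` is a FunctionalTensorWrapper whose inner value is
// exactly `unwrapped` (i.e. they share the same TensorImpl).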
void _assert_wrapped_functional(
    const Tensor& unwrapped,
    const Tensor& wrapped) {
  TORCH_INTERNAL_ASSERT(
      at::functionalization::impl::isFunctionalTensor(wrapped));
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(unwrapped));
  auto wrapped_impl =
      at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
  auto& wrapped_inner = wrapped_impl->value();
  TORCH_INTERNAL_ASSERT(
      unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl());
}

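// After functionalization has run, propagate any mutations that were applied
// to the functional `wrapped` tensor back onto the plain `unwrapped` input,
// including size/stride changes.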
void _propagate_functional_input_mutation(
    const Tensor& unwrapped,
    const Tensor& wrapped) {
  TORCH_INTERNAL_ASSERT(
      at::functionalization::impl::isFunctionalTensor(wrapped));
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(unwrapped));
  auto wrapped_impl =
      at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
  // Ensure that the input is up to date by committing any pending updates to
  // the alias.
  wrapped_impl->sync_();
  auto& wrapped_inner = wrapped_impl->value();
  // It would probably be more reasonable to check that the two tensors are
  // aliased, but we can't do that unless we give BatchedTensorImpl a notion of
  // storage.
  if (unwrapped.unsafeGetTensorImpl() != wrapped_inner.unsafeGetTensorImpl()) {
    if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) {
      // Functions might resize zero-sized inputs, which we need to reflect
      // here.
      unwrapped.resize__symint(wrapped_inner.sym_sizes());
    }
    // If the input tensor's metadata was mutated, then use as_strided_()
    // to propagate the metadata change.
    if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) {
      unwrapped.as_strided__symint(
          wrapped_inner.sym_sizes(), wrapped_inner.sym_strides());
    }
    unwrapped.copy_(wrapped_inner);
  }
}

93
94static std::pair<Tensor, int64_t> remove_existing_batch_dim(
95 const BatchedTensorImpl* batched,
96 int64_t level) {
97 TORCH_INTERNAL_ASSERT(batched->level() == level);
98 return std::make_pair(batched->value(), batched->bdim());
99}
100
// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of the other existing dimensions.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
  auto logical_dim = self.dim();
  src = at::maybe_wrap_dim(src, logical_dim);
  dst = at::maybe_wrap_dim(dst, logical_dim);
  if (src == dst) {
    return self;
  }
  VmapDimVector permutation;
  permutation.reserve(logical_dim);
  for (int64_t dim = 0; dim < logical_dim; dim++) {
    if (dim == src) {
      continue;
    }
    permutation.push_back(dim);
  }
  permutation.insert(permutation.begin() + dst, src);
  return self.permute(permutation);
}

// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
//
// If the `level` of the batch dim to remove does not exist in `self`, then we
// add the batch dim in. This can happen if `self` didn't interact with a tensor
// inside the vmap level, for example:
//   self = torch.randn(3)
//   y = torch.randn(5)
//   out = vmap(lambda x: vmap(lambda y: x)(y))(self)
//   assert out.shape == (3, 5)
// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
// corresponding to the *outer* vmap level, and it doesn't have any dimensions
// that correspond to the inner vmap level, so we need to create one for the
// user.
//
// `out_dim` controls where we should put the batch dimension in the output
// tensor.
Tensor _remove_batch_dim(
    const Tensor& self,
    int64_t level,
    int64_t batch_size,
    int64_t out_dim) {
  if (!has_level(self, level)) {
    auto self_sizes = self.sizes();
    VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
    expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
    auto result = self.expand(expanded_sizes);
    return result;
  }

  // `self` must be a BatchedTensor if has_level(self, level) returned true.
  const auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched != nullptr);

  Tensor self_without_bdim;
  int64_t newly_exposed_logical_dim;
  std::tie(self_without_bdim, newly_exposed_logical_dim) =
      remove_existing_batch_dim(batched, level);
  auto result = _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
  return result;
}

Tensor _unwrap_functional_tensor(const Tensor& self, bool add_back_views) {
  // We only ever call this after popping out of a functionalize() call, in
  // which case the current tensors should always be wrapped in a
  // FunctionalTensorWrapper.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto functional =
      at::functionalization::impl::unsafeGetFunctionalWrapper(self);

  // When regenerating the (potentially mutated) input tensors, the
  // functionalization pass regenerates them through a series of view_copy()
  // calls. Functorch wants to turn those back into view ops, though, so make
  // sure the tensor is up to date by committing any pending updates to the
  // alias before unwrapping it.
  at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
      add_back_views);
  bool any_updates = functional->apply_updates();
  if (any_updates) {
    functional->regenerate_from_base();
  }
  return functional->value();
}

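// Wraps `self` in a TensorWrapper (the grad-tracking wrapper) for the
// transform at `level`.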
Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
  // NB: different behavior inside??
  // return self;
  // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self));
  // TORCH_INTERNAL_ASSERT(self.has_storage());
  return makeTensorWrapper(self, level);
}

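// Unwraps one layer of TensorWrapper from `self`, but only if that wrapper
// belongs to `level`; otherwise returns `self` unchanged.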
Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
  auto* result = maybeGetTensorWrapper(self);
  if (!result) {
    return self;
  }
  TORCH_INTERNAL_ASSERT(result->level().has_value());
  if (result->level() == level) {
    return result->value();
  }
  return self;
}

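// Debugging helper: returns the TensorWrapper's level, 0 if `tensor` is not
// wrapped, or -1 if the wrapper is dead.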
int64_t dlevel(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return 0;
  }
  if (!wrapped->is_alive()) {
    return -1;
  }
  return wrapped->level().value();
}

bool dump_tensor(const Tensor& self) {
  dumpTensorCout(self);
  return true;
}

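// Translates the user-facing `randomness` string argument of vmap into the
// RandomnessType enum.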
RandomnessType get_randomness_enum(const std::string& randomness) {
  if (randomness == "error") {
    return RandomnessType::Error;
  } else if (randomness == "same") {
    return RandomnessType::Same;
  } else if (randomness == "different") {
    return RandomnessType::Different;
  } else {
    TORCH_CHECK(
        false, "randomness argument must be error, same, or different.");
  }
}

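// The *_increment_nesting / *_decrement_nesting functions below push and pop
// interpreter layers on functorch's DynamicLayer stack. The returned integer
// is the layer id, which doubles as the transform's level.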
int64_t _grad_increment_nesting() {
  // See NOTE [grad and vjp interaction with no_grad]
  bool prev_grad_mode = c10::GradMode::is_enabled();
  return initAndPushDynamicLayer(
      TransformType::Grad, c10::nullopt, c10::nullopt, prev_grad_mode);
}

int64_t _grad_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad);
  return layer.layerId();
}

int64_t _jvp_increment_nesting() {
  // See NOTE [grad and vjp interaction with no_grad]
  bool prev_fwd_grad_mode =
      c10::AutogradState::get_tls_state().get_fw_grad_mode();
  return initAndPushDynamicLayer(
      TransformType::Jvp,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      prev_fwd_grad_mode);
}

int64_t _jvp_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp);
  return layer.layerId();
}

int64_t _vmap_increment_nesting(
    int64_t batch_size,
    const std::string& randomness) {
  return initAndPushDynamicLayer(
      TransformType::Vmap, batch_size, get_randomness_enum(randomness));
}

int64_t _vmap_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap);
  return layer.layerId();
}
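
// A rough sketch of how the Python side drives a single vmap level through
// these bindings (simplified; the batching rules themselves run in C++ when
// the wrapped function executes):
//
//   level = _vmap_increment_nesting(batch_size, "error")
//   try:
//       batched = [_add_batch_dim(arg, in_dim, level) for arg, in_dim in ...]
//       batched_out = f(*batched)
//       out = _remove_batch_dim(batched_out, level, batch_size, out_dim)
//   finally:
//       _vmap_decrement_nesting()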

int64_t _func_increment_nesting(bool reapply_views) {
  return initAndPushDynamicLayer(
      TransformType::Functionalize,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      /*functionalize_add_back_views=*/reapply_views);
}

int64_t _func_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Functionalize);
  return layer.layerId();
}

static bool is_batchedtensor(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  return batched != nullptr;
}

static bool is_gradtrackingtensor(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  return wrapped != nullptr;
}

static bool is_functionaltensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(
      c10::DispatchKey::Functionalize);
}

static Tensor get_unwrapped(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->value();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    return wrapped->value();
  }
  if (at::functionalization::impl::isFunctionalTensor(tensor)) {
    auto* functional =
        at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
    return functional->value();
  }
  TORCH_CHECK(false, "No wrappers present!");
}

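// Returns the level of the outermost wrapper (BatchedTensor, TensorWrapper,
// or functional tensor) on `tensor`, -1 if there is none, or -2 for a
// TensorWrapper without a level.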
static int64_t maybe_get_level(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->level();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    if (wrapped->level()) {
      return *wrapped->level();
    }
    // TODO: this is a weird special case...
    return -2;
  }
  if (at::functionalization::impl::isFunctionalTensor(tensor)) {
    auto* functional =
        at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
    return functional->level();
  }
  return -1;
}

static int64_t maybe_get_bdim(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->bdim();
  }
  return -1;
}

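// Returns the layer id of the innermost active transform; asserts that at
// least one transform is active.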
static int64_t currentLevel() {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t current_level = maybe_layer->layerId();
  return current_level;
}

static void tls_set_vmap_excluded(bool excluded) {
  c10::impl::tls_set_dispatch_key_excluded(
      c10::DispatchKey::FuncTorchBatched, excluded);
}

static void _set_dynamic_layer_keys_included(bool value) {
  return setDynamicLayerFrontBackKeysIncluded(value);
}

static void dump_dls() {
  std::cout << getDynamicLayerStack() << std::endl;
}

static void dump_local_tls() {
  auto tls = c10::impl::tls_local_dispatch_key_set();
  std::cout << "[Local Include] " << tls.included_ << std::endl;
  std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
}

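// If `tensor` is batched at exactly `level`, returns its unwrapped value and
// batch dim; otherwise returns the tensor unchanged along with nullopt.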
static std::tuple<Tensor, c10::optional<int64_t>> unwrapBatched(
    const Tensor& tensor,
    int64_t level) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (!batched) {
    return std::make_tuple(tensor, c10::nullopt);
  }
  if (batched->level() == level) {
    return std::make_tuple(batched->value(), batched->bdim());
  }
  return std::make_tuple(tensor, c10::nullopt);
}

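// Registers all of the above under the torch._C._functorch submodule.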
void initFuncTorchBindings(PyObject* module) {
  auto _C = py::handle(module).cast<py::module>();
  auto m = _C.def_submodule("_functorch");

  m.def("_add_batch_dim", &_add_batch_dim, "add batch dim");
  m.def("_remove_batch_dim", &_remove_batch_dim, "remove batch dim");
  m.def("_unwrap_batched", &unwrapBatched);
  m.def(
      "_wrap_functional_tensor",
      &_wrap_functional_tensor,
      "add functional tensor");
  m.def(
      "_assert_wrapped_functional",
      &_assert_wrapped_functional,
      "assert wrapped functional");
  m.def(
      "_propagate_functional_input_mutation",
      &_propagate_functional_input_mutation,
      "propagate functional input mutations");
  m.def(
      "_unwrap_functional_tensor",
      &_unwrap_functional_tensor,
      "remove functional tensor");
  m.def("_vmap_increment_nesting", &_vmap_increment_nesting);
  m.def("_vmap_decrement_nesting", &_vmap_decrement_nesting);
  m.def(
      "_func_increment_nesting",
      &_func_increment_nesting,
      "functionalization start");
  m.def(
      "_func_decrement_nesting",
      &_func_decrement_nesting,
      "functionalization end");
  m.def("_grad_increment_nesting", &_grad_increment_nesting);
  m.def("_grad_decrement_nesting", &_grad_decrement_nesting);
  m.def("_jvp_increment_nesting", &_jvp_increment_nesting);
  m.def("_jvp_decrement_nesting", &_jvp_decrement_nesting);
  m.def("_wrap_for_grad", &_wrap_for_grad, "wrap as gradtrackingtensor");
  m.def(
      "_unwrap_for_grad", &_unwrap_for_grad, "unwrap from gradtrackingtensor");
  m.def(
      "_set_vmap_fallback_warning_enabled",
      &at::functorch::setVmapFallbackWarningEnabled,
      "Set vmap fallback warnings");
  m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
  m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
  m.def(
      "set_inplace_requires_grad_allowed",
      &at::functorch::setInplaceRequiresGradAllowed);
  m.def(
      "get_inplace_requires_grad_allowed",
      &at::functorch::getInplaceRequiresGradAllowed);
  m.def(
      "set_single_level_autograd_function_allowed",
      &at::functorch::setSingleLevelAutogradFunctionAllowed);
  m.def(
      "get_single_level_autograd_function_allowed",
      &at::functorch::getSingleLevelAutogradFunctionAllowed);
  m.def("unwrap_if_dead", &unwrapIfDead);
456 m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
457 m.def("dlevel", &dlevel, "dlevel");
458 m.def("dump_tensor", &dump_tensor, "dump_tensor");
459 m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
460 m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof);
461 // various debugging things. Maybe we should offer these as first-class APIs
462 // on Tensors?
463 m.def("is_batchedtensor", &is_batchedtensor);
464 m.def("is_gradtrackingtensor", &is_gradtrackingtensor);
465 m.def("is_functionaltensor", &is_functionaltensor);
466 m.def("get_unwrapped", &get_unwrapped);
467 m.def("maybe_get_level", &maybe_get_level);
468 m.def("maybe_get_bdim", &maybe_get_bdim);
469 m.def("current_level", &currentLevel);
470 m.def("tls_set_vmap_excluded", &tls_set_vmap_excluded);
471 m.def("_set_dynamic_layer_keys_included", &_set_dynamic_layer_keys_included);
472 m.def("dump_dls", &dump_dls);
473 m.def("dump_local_tls", &dump_local_tls);
474 m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) {
475 return maybe_get_level(tensor) != -1;
476 });
477 m.def("peek_interpreter_stack", []() -> c10::optional<Interpreter> {
478 const auto& stack = getDynamicLayerStack();
479 if (stack.empty()) {
480 return c10::nullopt;
481 }
482 auto result = stack.back().interpreter();
483 return result;
484 });
485 m.def("pop_dynamic_layer_stack", &popDynamicLayer);
486 m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
487 return pushDynamicLayer(std::move(layer));
488 });
489 py::class_<DynamicLayer>(m, "DynamicLayer");
490
  py::enum_<TransformType>(m, "TransformType")
      .value("Torch", TransformType::Torch)
      .value("Grad", TransformType::Grad)
      .value("Jvp", TransformType::Jvp)
      .value("Functionalize", TransformType::Functionalize)
      .value("Vmap", TransformType::Vmap);
  py::enum_<RandomnessType>(m, "RandomnessType")
      .value("Error", RandomnessType::Error)
      .value("Same", RandomnessType::Same)
      .value("Different", RandomnessType::Different);
  py::class_<Interpreter>(m, "CInterpreter")
      .def("key", &Interpreter::key)
      .def("level", &Interpreter::level);
  py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &GradInterpreterPtr::key)
      .def("level", &GradInterpreterPtr::level)
      .def("lift", &GradInterpreterPtr::lift)
      .def("prevGradMode", &GradInterpreterPtr::prevGradMode);
  py::class_<JvpInterpreterPtr>(m, "CJvpInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &JvpInterpreterPtr::key)
      .def("level", &JvpInterpreterPtr::level)
      .def("lift", &JvpInterpreterPtr::lift)
      .def("prevFwdGradMode", &JvpInterpreterPtr::prevFwdGradMode);
  py::class_<VmapInterpreterPtr>(m, "CVmapInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &VmapInterpreterPtr::key)
      .def("level", &VmapInterpreterPtr::level)
      .def("batchSize", &VmapInterpreterPtr::batchSize)
      .def("randomness", &VmapInterpreterPtr::randomness);
  py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &FunctionalizeInterpreterPtr::key)
      .def("level", &FunctionalizeInterpreterPtr::level)
      .def(
          "functionalizeAddBackViews",
          &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
}

} // namespace impl
} // namespace functorch
} // namespace torch