// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
#include <c10/util/variant.h>
#include <unordered_map>
#include <mutex>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>

// Forward declared
namespace c10 { struct AutogradMetaInterface; }

namespace at {
namespace functorch {

// This file contains the implementation of functorch's interpreter stack.
// See NOTE: [functorch interpreter stack] before reading on.
//
// NB: the functorch interpreter stack is also referred to as:
// - the "dynamic layer stack" -- an older name for "interpreter" was
//   "dynamic layer".
// - the "functorch mode stack". You can think of each functorch transform as a
//   "mode" (in the same sense as a torch_dispatch mode or torch_function mode),
//   with functorch being an implementation of a "mode stack" where the modes
//   may be arbitrarily composed.

// DynamicLayer is basically the same thing as an Interpreter.
// It represents a functorch transform and holds an Interpreter,
// which contains metadata related to the transform and instructions on
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// but I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
  explicit DynamicLayer(
      TransformType transform_type,
      int64_t layerId,
      optional<int64_t> batchSize = nullopt,
      optional<RandomnessType> randomness = nullopt,
      optional<bool> prev_grad_mode = nullopt,
      optional<bool> prev_fwd_grad_mode = nullopt,
      optional<bool> functionalize_add_back_views = nullopt);

  TransformType key() const;
  int64_t layerId() const;

  const Interpreter& interpreter() const { return interpreter_; }
  Interpreter& interpreter() { return interpreter_; }

  // Only valid for vmap
  int64_t batchSize() const;
  RandomnessType randomness() const;

 private:
  Interpreter interpreter_;
};
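//
// Example (illustrative; in practice these layers are constructed inside
// functorch's runtime rather than by users): building a DynamicLayer for a
// vmap transform at level 1 with batch size 3 and error-on-randomness
// semantics, then reading its metadata back. The specific values are made up.
//
//   DynamicLayer layer(TransformType::Vmap, /*layerId=*/1,
//                      /*batchSize=*/3, RandomnessType::Error);
//   TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap);
//   TORCH_INTERNAL_ASSERT(layer.batchSize() == 3);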

TORCH_API int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    optional<int64_t> batch_size = nullopt,
    optional<RandomnessType> randomness = nullopt,
    optional<bool> prev_grad_mode = nullopt,
    optional<bool> prev_fwd_grad_mode = nullopt,
    optional<bool> functionalize_add_back_views = nullopt);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
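//
// Example (illustrative sketch of how a transform brackets its scope; the
// real entry points live in functorch's runtime and Python bindings, and
// this snippet is not a supported user API):
//
//   int64_t level = initAndPushDynamicLayer(
//       TransformType::Vmap, /*batch_size=*/3, RandomnessType::Error);
//   // ... run the transformed computation at `level` ...
//   DynamicLayer popped = popDynamicLayerAndDeleteMetadata();
//   TORCH_INTERNAL_ASSERT(popped.layerId() == level);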

// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle", which is a boolean that tells us
// whether the transform at that level is still active.
//
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes the scope of its transform, it needs some way
// to know that it escaped; it can tell by querying the life handle.
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
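//
// Example (illustrative): checking whether the transform at `level` is still
// alive. This mirrors what TensorWrapper does when it is used outside the
// scope of its transform; `level` here is assumed to be a valid level.
//
//   const std::shared_ptr<bool>& alive = getLifeHandleForLevel(level);
//   if (!*alive) {
//     // The transform has exited; the wrapper is "dead".
//   }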

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
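//
// Example (illustrative): looking up the schema of add_.Tensor from the
// dispatcher and checking it. The dispatcher lookup is just one way to obtain
// a FunctionSchema; any schema works.
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("aten::add_", "Tensor");
//   bool inplace = isInplaceOp(op.schema());  // true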

// Given the schema and the index of an immutable (unwrapped) input, returns
// the index of the output, if any, that aliases that input and should
// therefore also remain unwrapped.
TORCH_API c10::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
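//
// Example (illustrative): for a view op such as
//   t(Tensor(a) self) -> Tensor(a)
// findAliasedOutput(schema, /*immutable_input=*/0) returns 0, since output 0
// aliases input 0. For a schema with no input/output aliasing it returns
// c10::nullopt.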

// If `tensor` is a dead TensorWrapper (its transform has already exited),
// returns the value it wraps; otherwise returns `tensor` unchanged.
TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
// Returns true if `tensor` is a TensorWrapper whose life handle says that
// the corresponding transform is no longer alive.
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);

// Pretty printers
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
// is disabled by default. The following two APIs allow enabling it; they are
// not user-facing. We may delete this mechanism in the future, but it is
// useful for debugging when something goes wrong with the
// autograd.Function <> functorch interaction (which uses _SingleLevelFunction),
// because it leads to loud errors if something is incorrect.
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
TORCH_API bool getSingleLevelAutogradFunctionAllowed();
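//
// Example (illustrative) of temporarily re-enabling it while debugging; the
// save/restore pattern here is only a suggestion, not an existing helper:
//
//   bool prev = getSingleLevelAutogradFunctionAllowed();
//   setSingleLevelAutogradFunctionAllowed(true);
//   // ... reproduce the failing autograd.Function <> functorch case ...
//   setSingleLevelAutogradFunctionAllowed(prev);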

// While a functorch grad transform is active, Tensor.requires_grad_() gets
// disabled. These two functions are the mechanism for controlling that.
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
TORCH_API bool getInplaceRequiresGradAllowed();

// Lower-level primitives that push/pop the interpreter stack directly,
// without the per-level metadata bookkeeping done by
// initAndPushDynamicLayer / popDynamicLayerAndDeleteMetadata above.
TORCH_API DynamicLayer popDynamicLayer();
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);

} // namespace functorch
} // namespace at