// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
#include <c10/util/variant.h>
#include <unordered_map>
#include <mutex>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>

// Forward declared
namespace c10 { struct AutogradMetaInterface; }

namespace at {
namespace functorch {
// This file contains the implementation of functorch's interpreter stack.
// See NOTE: [functorch interpreter stack] first before reading on.
//
// NB: the functorch interpreter stack is also referred to as:
// - the "dynamic layer stack" -- an older name for "interpreter" was
//   "dynamic layer".
// - the "functorch mode stack". You can think of each functorch transform as a
//   "mode" (in the same sense as torch_dispatch mode or torch_function mode),
//   and functorch being an implementation of a "mode stack" where the modes
//   may be arbitrarily composed.
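//
// For example (illustrative; the transforms are usually driven from Python),
// nesting transforms pushes one interpreter per transform: vmap(grad(f))(x)
// runs f under an interpreter stack of [VmapInterpreter, GradInterpreter],
// where the outermost transform (vmap) sits at the bottom of the stack and
// each interpreter is assigned a unique, increasing level.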

// DynamicLayer is basically the same thing as an Interpreter.
// It represents a functorch transform and it holds an Interpreter,
// which contains metadata related to the transform and instructions on
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// but I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
  explicit DynamicLayer(
      TransformType transform_type,
      int64_t layerId,
      optional<int64_t> batchSize = nullopt,
      optional<RandomnessType> randomness = nullopt,
      optional<bool> prev_grad_mode = nullopt,
      optional<bool> prev_fwd_grad_mode = nullopt,
      optional<bool> functionalize_add_back_views = nullopt);

  TransformType key() const;
  int64_t layerId() const;

  const Interpreter& interpreter() const { return interpreter_; }
  Interpreter& interpreter() { return interpreter_; }

  // Only valid for vmap
  int64_t batchSize() const;
  RandomnessType randomness() const;

 private:
  Interpreter interpreter_;
};
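
// For example (illustrative), code can inspect the innermost transform via
// maybeCurrentDynamicLayer (declared below):
//
//   auto layer = maybeCurrentDynamicLayer();
//   if (layer.has_value() && layer->key() == TransformType::Vmap) {
//     int64_t batch_size = layer->batchSize();  // only valid for vmap
//   }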

TORCH_API int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    optional<int64_t> batch_size = nullopt,
    optional<RandomnessType> randomness = nullopt,
    optional<bool> prev_grad_mode = nullopt,
    optional<bool> prev_fwd_grad_mode = nullopt,
    optional<bool> functionalize_add_back_views = nullopt);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
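
// A minimal sketch of the push/pop protocol (illustrative; in practice the
// Python-side transforms drive these calls):
//
//   int64_t level = initAndPushDynamicLayer(
//       TransformType::Vmap, /*batch_size=*/8, RandomnessType::Error);
//   // ... ops dispatched here see the vmap interpreter on top of the stack ...
//   DynamicLayer layer = popDynamicLayerAndDeleteMetadata();
//   TORCH_INTERNAL_ASSERT(layer.layerId() == level);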

// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle", which is a pointer to a boolean
// that tells us if the transform with that level is active or not.
//
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes the scope of its transform, it needs some way
// to know that it escaped; it can tell by querying the life handle.
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
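
// For instance, a wrapper might check liveness like this (a hypothetical
// sketch; TensorWrapper's actual bookkeeping lives elsewhere):
//
//   const std::shared_ptr<bool>& alive = getLifeHandleForLevel(level);
//   if (!*alive) {
//     // the transform at `level` has exited; treat the wrapper as dead
//   }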

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
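
// For example (illustrative), isInplaceOp returns true for the schema of
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// and false for the out-of-place variant
//   add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
// where the first argument is neither written to nor returned.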

// Given the index of an (unwrapped) immutable input and the schema, this
// returns the index of the output (if any) that aliases it and should
// therefore also remain unwrapped.
TORCH_API c10::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
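// For example (illustrative), for the schema
//   transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
// findAliasedOutput(schema, 0) returns 0, because output 0 aliases input 0;
// for a schema with no input-output aliasing it returns nullopt.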

// A TensorWrapper is "dead" once the transform that created it has exited
// (see NOTE: [Life handles and lexically scoped transforms]).
// unwrapIfDead returns the underlying tensor if the wrapper is dead (and the
// tensor unchanged otherwise); isDeadTensorWrapper tests for a dead wrapper.
TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);

// Pretty printers
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
// is disabled by default. The following two APIs allow re-enabling it; they
// are not user-facing. We may delete them in the future, but they are useful
// for debugging the autograd.Function <> functorch interaction (which uses
// _SingleLevelFunction), because keeping it disabled leads to loud errors
// when something is incorrect.
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
TORCH_API bool getSingleLevelAutogradFunctionAllowed();

// While a functorch grad transform is active, Tensor.requires_grad_() gets
// disabled. These two functions are the mechanism for controlling that.
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
TORCH_API bool getInplaceRequiresGradAllowed();
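
// A typical save/restore pattern (an illustrative sketch, not a dedicated
// RAII guard):
//
//   bool prev = getInplaceRequiresGradAllowed();
//   setInplaceRequiresGradAllowed(true);
//   // ... code that needs Tensor.requires_grad_() to work ...
//   setInplaceRequiresGradAllowed(prev);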

TORCH_API DynamicLayer popDynamicLayer();
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
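
// These can be used (illustratively) to temporarily run code underneath the
// current transform by popping the top layer and pushing it back afterwards:
//
//   DynamicLayer layer = popDynamicLayer();
//   // ... ops dispatched here do not see the popped transform ...
//   pushDynamicLayer(std::move(layer));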

} // namespace functorch
} // namespace at