#pragma once

#include <ATen/core/Tensor.h>
#include <c10/util/SmallVector.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace autograd {
// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner
// design. But this shared_ptr must be uniquely associated with the object that
// stores it (as of writing, either AutogradMeta or SavedVariable). This object
// is called the "owning object" in the discussions below. This owning object
// must call `ForwardGrad::clear()` when it is destroyed to ensure that the
// ForwardGrad is properly de-allocated.
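//
// A minimal sketch of this owning-object contract; `Owner` is hypothetical
// and stands in for AutogradMeta or SavedVariable:
//
//   struct Owner {
//     std::shared_ptr<ForwardGrad> fw_grad_ = std::make_shared<ForwardGrad>();
//     ~Owner() {
//       // Required: lets the ForwardGrad un-register itself from all levels
//       fw_grad_->clear();
//     }
//   };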

struct ForwardGrad;

// This file contains two classes that are used to store forward AD gradients
// and ensure that they are scoped properly. Because forward AD runs
// concurrently with the evaluation of the function, we need a mechanism to
// separate different forward AD invocations and be able to compute the right
// gradients. We model such invocations as levels here. The particular scoping
// issue mentioned above has two main drivers:
// - Ensure that we can conveniently use forward AD within a high-level API
//   without leaking the forward AD state outside.
// - Ensure that the level we expose to the user API stays simple (an integer
//   that represents the nesting depth) while avoiding confusion when a level
//   index is re-used.

// The important external APIs from this file are (see the sketch after this
// list for a usage example):
// - ForwardADLevel::get_next_idx(), which can be used to enter a new level
//   and get its index.
// - ForwardADLevel::release_idx(), which can be used to exit a given level.
// - ForwardGrad, which stores a given forward gradient and handles the level
//   tracking automatically.

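// A minimal usage sketch of the level API (illustrative only; assumes a
// single thread enters and exits the level):
//
//   uint64_t level = ForwardADLevel::get_next_idx(); // enter a new level
//   // ... run computations whose tangents are stored at this level ...
//   ForwardADLevel::release_idx(level); // exit: clears this level's tangents
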
// The basic implementation strategy is as follows:
// Every tensor has a ForwardGrad, maintaining a map from levels to tangents.
// ForwardGrad is responsible for registering itself to the appropriate
// ForwardADLevel when a new tangent is added to it via ForwardGrad::set_value
// and for un-registering itself from this same level if that tangent is
// removed via ForwardGrad::reset. The ForwardADLevel is created when a new
// level is entered via ForwardADLevel::get_next_idx. A reference to the new
// ForwardADLevel is stored in a global (process-wide) vector that ensures it
// can be accessed via ForwardADLevel::get_by_idx. This reference is deleted
// when the index is released by the user via ForwardADLevel::release_idx.
// When it is destructed, the ForwardADLevel is responsible for clearing all
// the tangents for its level stored in all the ForwardGrad instances that
// registered with it.
//
// This process-wide level design, compared to a thread-local one, allows us
// to use a very simple user-facing handle for the level (an int) while
// enabling cross-thread forward AD. The only synchronization required from
// the user is when entering and exiting levels. Some discussion of
// alternative designs is in
// https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453 and the
// design can be refined in the future.

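// A minimal sketch of the registration flow described above; `tangent` is a
// hypothetical at::Tensor, and the calls are the ones declared below:
//
//   auto grad = std::make_shared<ForwardGrad>();
//   grad->set_value(tangent, level); // registers `grad` with that level
//   if (grad->contains(level)) {
//     const at::Tensor& t = grad->value(level);
//   }
//   grad->reset(level); // un-registers `grad` from the level
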
// Correctness of concurrency:
// Each class uses its own lock when reading or modifying its internal
// storage. This allows us, in particular, to safely remove tangents from a
// ForwardGrad when the ForwardADLevel is being exited. We ensure no deadlock
// by ensuring that a method never calls into another class's method while the
// local class's lock is held, except in one single case: calling from
// ForwardADLevel's destructor into ForwardGrad::reset with
// update_level=false.

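// For illustration, a hedged sketch of that one allowed case; the real
// destructor is defined in the .cpp file and may differ in detail:
//
//   ForwardADLevel::~ForwardADLevel() {
//     std::lock_guard<std::mutex> lock(mutex_);
//     for (const auto& grad : grads_) {
//       // Cross-class call while holding our lock: allowed only because
//       // update_level=false guarantees reset() does not call back into us
//       grad->reset(idx_, /* update_level */ false);
//     }
//   }
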
// The lifetime of these objects is as follows:
// The ForwardADLevel can be in three states:
// - Initialized: one of its references is held by the global vector and
//   there may be more references held by temporary variables in
//   ForwardGrad's methods.
// - About to be destructed: "release_idx" has been called and the only
//   reason for the ForwardADLevel not to be destructed right away is that
//   some methods in ForwardGrad hold an owning reference to it. This is done
//   so that a ForwardADLevel can never be destructed while a ForwardGrad is
//   registered with it and in the process of adding something to its
//   internal state.
// - Being destructed: here the ForwardADLevel is not referenced anymore and
//   can safely reset all of the ForwardGrad instances registered with it.
//   Note that more than one reset can be called here (which is ok) but we
//   are guaranteed that there is at least one.
// The ForwardGrad is simpler as there is no intermediate state and no special
// destructor. The logic to un-register it from the different ForwardADLevels
// runs when the owning object (AutogradMeta or SavedVariable) is being
// destroyed.

// Other designs considered:
// To avoid needing ForwardGrad::clear, we considered storing weak_ptrs inside
// the ForwardADLevel. While this would work, it would mean that the set
// inside the ForwardADLevel would only grow unless we did an expensive linear
// scan to remove all the dangling weak pointers. Hence this approach was not
// used.

// Data structures in this file are optimized for the maximum number of
// levels below. The number of levels corresponds to the order of the gradient
// being computed with forward AD, and we don't expect gradients beyond second
// order to be common.
#define EXPECTED_MAX_LEVEL 2

struct TORCH_API ForwardADLevel {
  ForwardADLevel(uint64_t idx) : idx_(idx) {}
  ~ForwardADLevel();

  static uint64_t get_next_idx();
  static void release_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);

  // Un-register a ForwardGrad whose tangent for this level was removed
  void erase(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.erase(grad);
  }

  // Register a ForwardGrad that stores a tangent for this level
  void insert(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.insert(grad);
  }

 private:
  std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
  std::mutex mutex_;
  uint64_t idx_;
};

struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
  ForwardGrad() = default;

  // This function must only be called when AutogradMeta or SavedVariable is
  // being destructed, as it ensures that:
  // - The only (potential) other references to this ForwardGrad are the
  //   different levels it is registered with
  // - No other thread will try to call `set_value` or `value` ever from now
  //   on
  // - Any of the ForwardADLevels that this ForwardGrad is registered with
  //   might call `reset` at any point during this function
  void clear() {
    c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;

    {
      std::lock_guard<std::mutex> lock(mutex_);
      for (auto& c : content_) {
        levels_idx.push_back(c.first);
      }
    }

    for (auto l_idx : levels_idx) {
      // Use the "try" version here as another thread might have deleted this
      // level before we got here.
      // This is an owning reference, as we want to keep the level alive
      // until we successfully un-register ourselves.
      auto level = ForwardADLevel::try_get_by_idx(l_idx);
      if (level) {
        level->erase(shared_from_this());
      }
    }
  }

  void set_value(const at::Tensor& value, uint64_t level) {
    // Owning reference to ensure the forward_level is not destroyed
    // while we are updating our internal state
    auto forward_level = ForwardADLevel::get_by_idx(level);
    forward_level->insert(shared_from_this());

    std::lock_guard<std::mutex> lock(mutex_);
    content_.insert({level, value});
  }

  // This function removes the tangent for a given level from this
  // ForwardGrad. Use the update_level flag to disable notifying the level
  // about this reset. This flag is most notably used by the ForwardADLevel
  // destructor.
  void reset(uint64_t level, bool update_level = true) {
    if (update_level) {
      ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
    }

    std::unique_lock<std::mutex> lock(mutex_);
    auto it = content_.find(level);
    TORCH_INTERNAL_ASSERT(
        it != content_.end(), "Resetting a non-existent level.");
    // Keep the Tensor alive until we have released the lock.
    // This is needed as we can be in a case where this function is called by
    // the ForwardADLevel destructor.
    auto t = it->second;
    content_.erase(level);
    lock.unlock();
  }

  const at::Tensor& value(uint64_t level) const;

  bool contains(uint64_t level) {
    std::lock_guard<std::mutex> lock(mutex_);
    return content_.count(level) > 0;
  }

  bool empty() const {
    return content_.empty();
  }

  static const at::Tensor& undef_grad();

 private:
  // TODO(albanD): replace this with a SmallVector
  std::unordered_map<uint64_t, at::Tensor> content_;
  mutable std::mutex mutex_;
};

} // namespace autograd
} // namespace torch