1#pragma once
2
#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
4
5namespace torch {
6namespace autograd {
7
// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its internal
// design. But this shared_ptr must be uniquely associated with the object that
// stores it (as of writing, either AutogradMeta or SavedVariable). This object
// is called the "owning object" in the discussions below. This owning object
// must call `ForwardGrad::clear()` when it is destroyed to ensure that the
// ForwardGrad is properly de-allocated: every ForwardADLevel the ForwardGrad is
// registered with holds an owning shared_ptr to it, so without clear() it would
// stay alive for as long as those levels do.
15
16struct ForwardGrad;
17
// This file contains two classes that are used to store forward AD gradients
// and ensure that they are scoped properly. Because forward AD runs
// concurrently with the evaluation of the function, we need a mechanism to
// separate different forward AD invocations and be able to compute the right
// gradients. We model such invocations as levels here. The particular scoping
// issue mentioned above has two main drivers:
// - Ensure that we can conveniently use forward AD within a high level API
//   without leaking the forward AD states outside.
// - Ensure that we can keep the level that we expose to the user API simple
//   (an integer that represents the nesting depth) while avoiding confusion
//   when the level index is re-used.
31
// The important external APIs from this file are:
// - ForwardADLevel::get_next_idx(), which can be used to enter a new level and
//   get its index.
// - ForwardADLevel::release_idx(), which can be used to exit a given level.
// - ForwardGrad, which can be used to store a given forward gradient and
//   handles the level tracking automatically.
39
// The basic implementation strategy is as follows:
// Every tensor has a ForwardGrad, maintaining a map from levels to tangents.
// ForwardGrad is responsible for registering itself to the appropriate
// ForwardADLevel when a new tangent is added to it via ForwardGrad::set_value,
// and for un-registering itself from this same level if that tangent is
// removed via ForwardGrad::reset. The ForwardADLevel is created when a new
// level is entered via ForwardADLevel::get_next_idx. A reference to the new
// ForwardADLevel is stored in a global (for the whole process) vector that
// ensures it can be accessed via ForwardADLevel::get_by_idx. This reference is
// deleted when the user releases the index by calling
// ForwardADLevel::release_idx. When it is destructed, the ForwardADLevel is
// responsible for clearing all the tangents for its level stored in all the
// ForwardGrads that registered with it.
//
// This process-wide level design, compared to a thread-local one, allows us to
// use a very simple user-facing handle for the level (an int) while enabling
// cross-thread forward AD. The only synchronization required from the user is
// when entering and exiting the levels. Some discussion of alternative designs
// is in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453,
// and this design can be refined in the future.
59
// Correctness of concurrency:
// Each class uses its own lock when reading or modifying its internal storage.
// In particular, this allows tangents to be safely removed from a ForwardGrad
// when the ForwardADLevel is being exited. We ensure there is no deadlock by
// ensuring that a method never calls into another class's method while the
// local class's lock is held, except in one single case: calling from
// ForwardADLevel's destructor into ForwardGrad::reset with update_level=false.
67
// The lifetime of these objects is as follows:
// The ForwardADLevel can be in three states:
// - Initialized: one of its references is held by the global vector, and
//   there may be more references held by temporary variables in
//   ForwardGrad's methods.
// - About to be destructed: "release_idx" has been called, and the only
//   reason for the ForwardADLevel not to be destructed right away is that
//   some methods in ForwardGrad have an owning reference to it. This is done
//   so that a ForwardADLevel can never be destructed while a ForwardGrad is
//   registered with it and in the process of adding something to its
//   internal state.
// - Being destructed: the ForwardADLevel is not referenced anymore and can
//   safely reset all of its ForwardGrads. Note that more than one reset can
//   be called here (which is ok), but we are guaranteed that there is at
//   least one.
// The ForwardGrad is simpler, as it has no intermediary state and no special
// destructor. The logic to unregister it from the different ForwardADLevels
// runs when the owning object (AutogradMeta or SavedVariable) is being
// destroyed.
89
// Other designs considered:
// To avoid needing ForwardGrad::clear, we considered storing weak_ptrs inside
// the ForwardADLevel. While this would work, it would mean that the set inside
// the ForwardADLevel would only grow unless we did an expensive linear scan to
// remove all the dangling weak pointers. Hence this approach was not used.
95
96// Data structures in this file are optimized for this maximum number of levels.
97// The number of levels corresponds to the degree of the gradient being
98// computed using forward AD and we don't expect more than second order
99// gradients to be common.
100#define EXPECTED_MAX_LEVEL 2
101
102struct TORCH_API ForwardADLevel {
103 ForwardADLevel(uint64_t idx) : idx_(idx) {}
104 ~ForwardADLevel();
105
106 static uint64_t get_next_idx();
107 static void release_idx(uint64_t idx);
108 static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
109 static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);
110
111 void erase(const std::shared_ptr<ForwardGrad>& grad) {
112 std::lock_guard<std::mutex> lock(mutex_);
113 grads_.erase(grad);
114 }
115
116 void insert(const std::shared_ptr<ForwardGrad>& grad) {
117 std::lock_guard<std::mutex> lock(mutex_);
118 grads_.insert(grad);
119 }
120
121 private:
122 std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
123 std::mutex mutex_;
124 uint64_t idx_;
125};
126
127struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
128 ForwardGrad() = default;
129
130 // This function must only be called when AutogradMeta or SavedVariable is
131 // being destructed as it ensures that:
132 // - The only (potential) other references to this ForwardGrad are the
133 // different level it is registered to
134 // - No other thread will try to call `set_value` or `value` ever from now
135 // on
136 // - Any of the ForwardADLevel that this ForwardGrad is registered with
137 // might
138 // call `reset` at any point during this function
139 void clear() {
140 c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;
141
142 {
143 std::lock_guard<std::mutex> lock(mutex_);
144 for (auto& c : content_) {
145 levels_idx.push_back(c.first);
146 }
147 }
148
149 for (auto l_idx : levels_idx) {
150 // Use "try" version here as another thread might have deleted this
151 // level before we got here
152 // This is an owning reference as we want to keep the level alive
153 // until we successfully unregister ourselves
154 auto level = ForwardADLevel::try_get_by_idx(l_idx);
155 if (level) {
156 level->erase(shared_from_this());
157 }
158 }
159 }
160
161 void set_value(const at::Tensor& value, uint64_t level) {
162 // Owning reference to ensure the forward_level is not destroyed
163 // while we are updating our internal state
164 auto forward_level = ForwardADLevel::get_by_idx(level);
165 forward_level->insert(shared_from_this());
166
167 std::lock_guard<std::mutex> lock(mutex_);
168 content_.insert({level, value});
169 }
170
171 // This function removes the tangent for a given level from this ForwardGrad
172 // Use the update_level flag to disable notifying the level about this reset
173 // This flag is most notably used by the ForwardADLevel destructor.
174 void reset(uint64_t level, bool update_level = true) {
175 if (update_level) {
176 ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
177 }
178
179 std::unique_lock<std::mutex> lock(mutex_);
180 const auto& it = content_.find(level);
181 TORCH_INTERNAL_ASSERT(
182 it != content_.end(), "Resetting a non-existent level.");
183 // Keep the Tensor alive until we have released the lock
184 // This is needed as we can be in a case where this function is called by
185 // ForwardADLevel destructor
186 auto t = (*it).second;
187 content_.erase(level);
188 lock.unlock();
189 }
190
191 const at::Tensor& value(uint64_t level) const;
192
193 bool contains(uint64_t level) {
194 std::lock_guard<std::mutex> lock(mutex_);
195 return content_.count(level) > 0;
196 }
197
198 bool empty() const {
199 return content_.empty();
200 }
201
202 static const at::Tensor& undef_grad();
203
204 private:
205 // TODO(albanD): replace this with a SmallVector
206 std::unordered_map<uint64_t, at::Tensor> content_;
207 mutable std::mutex mutex_;
208};
209
210} // namespace autograd
211} // namespace torch
212