#include <torch/csrc/autograd/forward_grad.h>

namespace torch {
namespace autograd {

namespace {
// See discussion in forward_grad.h for why these are global variables and not
// thread local.

std::mutex all_forward_levels_mutex_;
std::vector<std::shared_ptr<ForwardADLevel>> all_forward_levels_;

const static at::Tensor singleton_undefined_tensor;
} // namespace
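
// Forward AD levels are tracked in the global stack above: a level is
// created when entering forward-mode AD and must be released in the reverse
// order of creation. A minimal usage sketch from a hypothetical caller (not
// part of this file):
//
//   auto idx = ForwardADLevel::get_next_idx(); // enter a new level
//   // ... create dual Tensors; their ForwardGrads register with the level
//   ForwardADLevel::release_idx(idx); // exit the level, resetting its grads

// Creates a new forward AD level and returns its index. Nested levels are
// not supported yet, so the returned index is currently always 0.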
uint64_t ForwardADLevel::get_next_idx() {
  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
  auto next_idx = all_forward_levels_.size();
  TORCH_CHECK(
      next_idx == 0, "Nested forward mode AD is not supported at the moment");
  all_forward_levels_.push_back(std::make_shared<ForwardADLevel>(next_idx));
  return next_idx;
}
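
// Releases the level at the given index. Levels must be released in the
// reverse order of their creation, so only the most recently created level
// can be released.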
void ForwardADLevel::release_idx(uint64_t idx) {
  std::unique_lock<std::mutex> lock(all_forward_levels_mutex_);
  TORCH_CHECK(
      idx + 1 == all_forward_levels_.size(),
      "Exiting a forward AD level that is not the "
      "last that was created is not supported. Ensure they are released in the reverse "
      "order they were created.");
  TORCH_INTERNAL_ASSERT(!all_forward_levels_.empty());
  // Keep the level alive until we have released the lock: if this is the
  // last reference, the destructor calls back into ForwardGrad and must not
  // run while the global lock is held.
  auto lvl = all_forward_levels_.back();
  all_forward_levels_.pop_back();
  lock.unlock();
}
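
// Returns the level at the given index. Throws if the index was never
// created or the level has already been released.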
std::shared_ptr<ForwardADLevel> ForwardADLevel::get_by_idx(uint64_t idx) {
  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
  TORCH_CHECK(
      idx < all_forward_levels_.size(),
      "Trying to access a forward AD level with an invalid index. "
      "This index was either not created or is already deleted.");
  return all_forward_levels_[idx];
}
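
// Same as get_by_idx, but returns nullptr instead of throwing when the
// index is invalid.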
std::shared_ptr<ForwardADLevel> ForwardADLevel::try_get_by_idx(uint64_t idx) {
  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
  if (idx < all_forward_levels_.size()) {
    return all_forward_levels_[idx];
  } else {
    return nullptr;
  }
}
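
// When a level is destroyed, reset every ForwardGrad still registered with
// it. update_level is false so the grads do not call back into this level,
// which is already going away.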
ForwardADLevel::~ForwardADLevel() {
  std::lock_guard<std::mutex> lock(mutex_);
  auto it = grads_.begin();
  while (it != grads_.end()) {
    // Warning: this will lock the mutex of *it.
    // This is ok as this function is the *only* one to call back into
    // another class's method.
    (*it)->reset(idx_, /* update_level */ false);
    it = grads_.erase(it);
  }
}
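
// Returns the gradient stored for the given level, or the shared undefined
// Tensor if no gradient is set at that level.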
const at::Tensor& ForwardGrad::value(uint64_t level) const {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto& it = content_.find(level);
  return it == content_.end() ? singleton_undefined_tensor : it->second;
}
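
// Returns the singleton undefined Tensor used to signal that no forward
// gradient is defined.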
const at::Tensor& ForwardGrad::undef_grad() {
  return singleton_undefined_tensor;
}

} // namespace autograd
} // namespace torch