1 | #include <torch/csrc/autograd/function.h> |
2 | |
3 | #include <c10/util/ThreadLocal.h> |
4 | #include <torch/csrc/autograd/engine.h> |
5 | #include <torch/csrc/autograd/variable.h> |
6 | |
7 | #include <ATen/ATen.h> |
8 | |
9 | #include <algorithm> |
10 | #include <cstdint> |
11 | #include <memory> |
12 | #include <stdexcept> |
13 | #include <string> |
14 | #include <utility> |
15 | #include <vector> |
16 | |
17 | namespace torch { |
18 | namespace autograd { |
19 | |
20 | // The current evaluating node. This is useful to assign the current node as a |
21 | // parent of new nodes created during the evaluation of this node in anomaly |
22 | // mode. |
23 | C10_DEFINE_TLS_static(std::shared_ptr<Node>, tls_current_evaluating_node); |
24 | #define current_evaluating_node (tls_current_evaluating_node.get()) |
25 | |
26 | NodeGuard::NodeGuard(std::shared_ptr<Node> node) { |
27 | last_evaluating_node_ = std::move(current_evaluating_node); |
28 | current_evaluating_node = std::move(node); |
29 | } |
30 | NodeGuard::~NodeGuard() { |
31 | // restore the previous evaluating node |
32 | current_evaluating_node = std::move(last_evaluating_node_); |
33 | } |
34 | |
35 | std::shared_ptr<Node> get_current_node() { |
36 | return current_evaluating_node; |
37 | } |
38 | |
39 | void Node::assign_parent() { |
40 | metadata()->assign_parent(current_evaluating_node); |
41 | } |
42 | |
43 | auto Node::name() const -> std::string { |
44 | return c10::demangle(typeid(*this).name()); |
45 | } |
46 | |
47 | AnomalyMetadata* Node::metadata() noexcept { |
48 | if (!anomaly_metadata_) { |
49 | anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata(); |
50 | } |
51 | return anomaly_metadata_.get(); |
52 | } |
53 | |
54 | static void gatherFunctions( |
55 | Node* func, |
56 | std::vector<std::shared_ptr<Node>>& stack) { |
57 | func->release_variables(); |
58 | |
59 | for (auto& edge : func->next_edges()) { |
60 | if (edge.function.use_count() == 1) { |
61 | stack.emplace_back(std::move(edge.function)); |
62 | } else { |
63 | edge.function.reset(); |
64 | } |
65 | } |
66 | } |
67 | |
68 | /* |
69 | * Fix for #5534: prevent stack overflow on deletion of deep computation graph |
70 | * |
71 | * Sometimes one can end up with a very big computation graph of Nodes |
72 | * and Edges. Each std::shared_ptr<Node> contains a list of Edge, and |
73 | * each Edge contains a std::shared_ptr<Node>. Deleting a |
74 | * std::shared_ptr<Node> can trigger the recursive deletion of other |
75 | * std::shared_ptr<Node>'s: this can stack overflow if the graph |
76 | * is deep enough. Here is an example of such a graph: |
77 | * |
78 | * shared_ptr<Node> -> Edge -> shared_ptr<Node> -> Edge -> ... -> |
79 | * shared_ptr<Node> |
80 | * |
81 | * The solution here is to detect when we are decrementing away the last |
82 | * reference to a Node, and when doing so to buffer up the Node's |
83 | * that will be recursively decremented. We can then decrement (and free) |
 * the original Node without causing a recursive cascade, before
 * draining the buffer and applying the same behavior to each buffered
 * Node. This is, in effect,
86 | * converting recursion to a loop, using a heap buffer in place of the |
87 | * recursive call stack. |
88 | */ |
89 | void deleteNode(Node* function) { |
90 | // To avoid stack overflow on large computational graphs, |
91 | // we need to track reference decrementing and freeing |
92 | // on the heap. |
93 | function->release_variables(); |
94 | std::vector<std::shared_ptr<Node>> stack; |
95 | gatherFunctions(function, stack); |
96 | delete function; |
97 | |
98 | while (!stack.empty()) { |
99 | auto func = std::move(stack.back()); |
100 | stack.pop_back(); |
101 | gatherFunctions(func.get(), stack); |
102 | // Reference count is decremented on the loop backedge. |
103 | } |
104 | } |
105 | |
106 | } // namespace autograd |
107 | } // namespace torch |
108 | |