1#include <torch/csrc/autograd/function.h>
2
3#include <c10/util/ThreadLocal.h>
4#include <torch/csrc/autograd/engine.h>
5#include <torch/csrc/autograd/variable.h>
6
7#include <ATen/ATen.h>
8
9#include <algorithm>
10#include <cstdint>
11#include <memory>
12#include <stdexcept>
13#include <string>
14#include <utility>
15#include <vector>
16
17namespace torch {
18namespace autograd {
19
// The current evaluating node. This is useful to assign the current node as a
// parent of new nodes created during the evaluation of this node in anomaly
// mode.
C10_DEFINE_TLS_static(std::shared_ptr<Node>, tls_current_evaluating_node);
// Convenience alias for the thread-local slot above: expands to a reference
// (the result of .get()), so it can be both read and assigned through.
#define current_evaluating_node (tls_current_evaluating_node.get())
25
26NodeGuard::NodeGuard(std::shared_ptr<Node> node) {
27 last_evaluating_node_ = std::move(current_evaluating_node);
28 current_evaluating_node = std::move(node);
29}
NodeGuard::~NodeGuard() {
  // restore the previous evaluating node
  // (RAII: runs even on exception paths, so the thread-local stack of
  // evaluating nodes is always unwound correctly)
  current_evaluating_node = std::move(last_evaluating_node_);
}
34
35std::shared_ptr<Node> get_current_node() {
36 return current_evaluating_node;
37}
38
void Node::assign_parent() {
  // Record the node currently being evaluated on this thread (possibly null)
  // as this node's parent in its anomaly metadata; used by anomaly mode to
  // reconstruct where a node was created.
  metadata()->assign_parent(current_evaluating_node);
}
42
43auto Node::name() const -> std::string {
44 return c10::demangle(typeid(*this).name());
45}
46
47AnomalyMetadata* Node::metadata() noexcept {
48 if (!anomaly_metadata_) {
49 anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata();
50 }
51 return anomaly_metadata_.get();
52}
53
54static void gatherFunctions(
55 Node* func,
56 std::vector<std::shared_ptr<Node>>& stack) {
57 func->release_variables();
58
59 for (auto& edge : func->next_edges()) {
60 if (edge.function.use_count() == 1) {
61 stack.emplace_back(std::move(edge.function));
62 } else {
63 edge.function.reset();
64 }
65 }
66}
67
68/*
69 * Fix for #5534: prevent stack overflow on deletion of deep computation graph
70 *
71 * Sometimes one can end up with a very big computation graph of Nodes
72 * and Edges. Each std::shared_ptr<Node> contains a list of Edge, and
73 * each Edge contains a std::shared_ptr<Node>. Deleting a
74 * std::shared_ptr<Node> can trigger the recursive deletion of other
75 * std::shared_ptr<Node>'s: this can stack overflow if the graph
76 * is deep enough. Here is an example of such a graph:
77 *
78 * shared_ptr<Node> -> Edge -> shared_ptr<Node> -> Edge -> ... ->
79 * shared_ptr<Node>
80 *
81 * The solution here is to detect when we are decrementing away the last
82 * reference to a Node, and when doing so to buffer up the Node's
83 * that will be recursively decremented. We can then decrement (and free)
 * the original Node without causing a recursive cascade, before
 * draining the buffer and applying the same behavior to each Node we
 * pop from it. This is, in effect, converting recursion to a loop,
 * using a heap buffer in place of the recursive call stack.
88 */
89void deleteNode(Node* function) {
90 // To avoid stack overflow on large computational graphs,
91 // we need to track reference decrementing and freeing
92 // on the heap.
93 function->release_variables();
94 std::vector<std::shared_ptr<Node>> stack;
95 gatherFunctions(function, stack);
96 delete function;
97
98 while (!stack.empty()) {
99 auto func = std::move(stack.back());
100 stack.pop_back();
101 gatherFunctions(func.get(), stack);
102 // Reference count is decremented on the loop backedge.
103 }
104}
105
106} // namespace autograd
107} // namespace torch
108