#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
8
9namespace torch {
10namespace autograd {
11
using edge_list = std::vector<Edge>;
struct ReadyQueue;

// Sentinel values for GraphTask::owner_ (the worker_device of the thread that
// created the task). Real accelerator device indices are >= 0.
static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;
17
// Monotonically increasing counter used to assign a unique id to every
// GraphTask (see GraphTask::id_).
//
// This must be a C++17 `inline` variable, NOT a header-level anonymous
// namespace: an unnamed namespace in a header gives every translation unit
// its own private copy of the counter, so (a) GraphTask ids would collide
// across TUs, and (b) the implicitly-inline GraphTask constructor below would
// ODR-use a TU-local entity, which is ill-formed (no diagnostic required).
// `inline` guarantees a single program-wide counter.
inline std::atomic<uint64_t> graph_task_id{0};
21
22// GraphTask holds metadata needed for a single execution of backward()
23struct GraphTask : std::enable_shared_from_this<GraphTask> {
24 std::atomic<uint64_t> outstanding_tasks_{0};
25 // Indicates if an error occurred while executing any task. When this is
26 // true, it signals all threads to stop executing.
27 std::atomic_bool has_error_{false};
28 std::atomic_bool future_completed_{false};
29 // It is safe to read keep_graph_ without synchronization
30 bool keep_graph_;
31
32 // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
33 // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
34 std::mutex mutex_;
35 std::unordered_map<Node*, InputBuffer> not_ready_;
36 std::unordered_map<Node*, int> dependencies_;
37
38 // Records the nodes that are in the graph
39 std::unordered_set<Node*> nodes_in_graph_;
40 c10::SmallVector<Node*, 4> graph_roots_;
41 // Note [Exec info]
42 // Exec info is created for each GraphTask, which allows filtering paths on
43 // the graph that are not needed. It has a bit complicated semantics. If it's
44 // empty, it means the task is run in a "default" mode, which means that all
45 // next_edges we encounter should get executed. If it's not empty, only
46 // functions that have an entry and this entry has needed == True should be
47 // executed. exec_info is only empty when the graph is executed via
48 // .backward() and the inputs parameter is not passed. Otherwise, when
49 // executed through .grad(), or when inputs arg is specified for .backward(),
50 // exec_info will be non-empty.
51 //
52 struct ExecInfo {
53 struct Capture {
54 Capture(const Capture&) = delete;
55 Capture(Capture&&) = default;
56
57 Capture(int input_idx, int output_idx)
58 : input_idx_(input_idx), output_idx_(output_idx) {}
59 int input_idx_; // within Node inputs
60 int output_idx_; // within the output vector of a GraphTask
61
62 // This hook will be executed after a grad is captured. The captured
63 // grad will be replaced by the return value of the hook.
64 struct GradCaptureHook {
65 virtual ~GradCaptureHook() = default;
66 virtual at::Tensor operator()(const at::Tensor& grad) = 0;
67 };
68 // NOTE [Deprecated capture hooks]
69 //
70 // The current status of capture hooks is that we continue to support
71 // the single usage of it by distributed in the dist_engine. If anyone
72 // else needs to use it for other purposes, they should file an issue.
73 //
74 // Capture hooks were originally created because there did not exist
75 // any way to register pre/post hooks to grad_fn in a way such that it
76 // would still be executed even if that is the grad_fn of a Tensor
77 // passed as input= of .grad. As far as I know, only dist_engine uses
78 // this hook.
79 //
80 // However, there are other alternatives today like tensor hooks that can
81 // replace the usage that originally motivated its creation. Also,
82 // Captures hooks are an outlier in terms of the types of hook that
83 // autograd offers in how it is registered and behaves, e.g. it is a hook
84 // registered not to the graph, but to a particular graph_task! This makes
85 // it a burden to maintain.
86 //
87 // It would be very nice to clean up/do a migration from pre/post
88 // hooks used in distributed to use tensor hooks, but for now we just
89 // mark this method as deprecated to prevent additional usage.
90 //
91 // If you still think you really need to capture hooks, please file an
92 // issue (and tag autograd).
93 const std::vector<std::unique_ptr<GradCaptureHook>>&
94 DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
95 return hooks_;
96 }
97 // See NOTE [deprecated capture hooks]
98 void DO_NOT_USE_DEPRECATED_register_capture_hook(
99 std::unique_ptr<GradCaptureHook> hook) {
100 hooks_.push_back(std::move(hook));
101 }
102
103 private:
104 // The hooks will be called one by one in the order as they were added.
105 // The input grad of a hook will be the output of its preceding hook. The
106 // first hook will take the captured grad as the input. The output of the
107 // last hook will replace the captured grad.
108 std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
109 };
110
111 bool should_execute() const {
112 return needed_ || captures_;
113 }
114
115 bool needed_ = false;
116 std::unique_ptr<std::vector<Capture>> captures_;
117 };
118 // exec_info_ is safe to read without synchronization
119 std::unordered_map<Node*, ExecInfo> exec_info_;
120 // Captures variables are grads captured that we return to the user. After
121 // execution of the GraphTask is completed, the captured_vars_ are moved
122 // out of the GraphTask and are no longer valid.
123 std::vector<Variable> captured_vars_;
124
125 // Note: this field is not ready to be used until the proper
126 // `thread_locals_.set_grad_mode()` call in the constructor.
127 at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
128
129 std::unordered_set<c10::Stream> leaf_streams;
130
131 // Per-device current streams of the execute() that called this GraphTask.
132 // These will be synced with leaf_streams in exec_post_processing.
133 std::vector<c10::optional<c10::Stream>> caller_current_streams_;
134
135 // Collects caller_current_streams_
136 void stash_current_streams();
137
138 void init_to_execute(
139 Node& graph_root,
140 const edge_list& outputs,
141 bool accumulate_grad,
142 uint64_t min_topo_nr);
143
144 // The value of worker_device in the thread that created this task.
145 // See Note [Reentrant backwards]
146 // Safe to read owner_ and reentrant_depth_ without synchronizaton
147 int owner_;
148 // The number of parent graph tasks for this graph task
149 const int reentrant_depth_;
150
151 bool can_checkpoint() const {
152 return exec_info_.empty();
153 }
154
155 // check if the GraphTask is completed or not
156 bool completed();
157 // mark the graph task as completed and trigger post processing
158 void mark_as_completed_and_run_post_processing();
159
160 // Set an appropriate exception on this graph_task which was encountered while
161 // running the provided function.
162 void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
163
164 // Set an appropriate exception on this graph_task which was encountered while
165 // running the provided function. But doesn't signal completion on
166 // 'future_result_' right away. The user needs to explicitly mark
167 // 'future_result_' completed with an appropriate exception.
168 void set_exception_without_signal(const std::shared_ptr<Node>& fn);
169
170 // Whether or not to stop execution for this GraphTask when an error is
171 // encountered. When set to true, this would cause Engine::execute() to throw
172 // an exception as soon as the autograd engine receives an exception.
173 bool exit_on_error_;
174
175 // CPU threads are dedicated to processing CPU work for the backward they
176 // invoked. So any given graph task maintains its own cpu_ready_queue_ where
177 // you should send work for it to be done. We memoize the cpu_ready_queue_ per
178 // GraphTask so that we know which ready queue we should push to if we are on
179 // device thread (i.e. GPU) and but next NodeTask should be run on CPU.
180 std::shared_ptr<ReadyQueue> cpu_ready_queue_;
181
182 // Future representing the completion of the graph task. Notified when all
183 // tasks are done.
184 c10::intrusive_ptr<at::ivalue::Future> future_result_;
185
186 // Final callbacks installed during execution of this GraphTask
187 std::vector<std::function<void()>> final_callbacks_;
188 // To protect reads and writes to final_callbacks_. Intentionally no reusing
189 // mutex_ as the two are protecting different data structures.
190 std::mutex final_callbacks_lock_;
191
192 utils::DelayWarningHandler warning_handler_;
193
194 uint64_t id_;
195
196 GraphTask(
197 bool keep_graph,
198 bool grad_mode,
199 int reentrant_depth,
200 std::shared_ptr<ReadyQueue> cpu_ready_queue,
201 c10::SmallVector<Node*, 4> graph_roots,
202 bool exit_on_error = false)
203 : keep_graph_(keep_graph),
204 graph_roots_(std::move(graph_roots)),
205 owner_(NO_DEVICE),
206 reentrant_depth_(reentrant_depth),
207 exit_on_error_(exit_on_error),
208 cpu_ready_queue_(std::move(cpu_ready_queue)),
209 future_result_(c10::make_intrusive<at::ivalue::Future>(
210 c10::ListType::create(c10::TensorType::get()))),
211 id_(graph_task_id.fetch_add(1, std::memory_order_relaxed)) {
212 thread_locals_.set_grad_mode(grad_mode);
213 }
214
215 private:
216 // run GraphTask post processing
217 void exec_post_processing();
218};
219
220// The guard that sets and restores current_graph_task.
// The guard that sets and restores current_graph_task.
// RAII: the constructor remembers the previously-current graph task and
// installs the given one; the destructor restores the previous task.
// (Definitions live in the engine .cpp — behavior inferred from the
// declaration; confirm there.)
class GraphTaskGuard {
 public:
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
  ~GraphTaskGuard();

  // Restores last_graph_task_ as the current graph task (presumably also
  // invoked by the destructor — confirm in the implementation).
  void restore_current_graph_task();

 private:
  // The graph task that was current when this guard was constructed.
  std::shared_ptr<GraphTask> last_graph_task_;
};
231
// Accessors for the graph task currently executing on this thread.
// NOTE(review): behavior when no graph task is active (nullptr / sentinel
// returns) is defined in the engine .cpp — confirm there before relying on it.
TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
// NOTE(review): GraphTask::id_ is uint64_t but this returns int — presumably
// a sentinel (e.g. -1) signals "no active task"; confirm in the .cpp.
TORCH_API int get_current_graph_task_id();
// Deliberately not TORCH_API: internal to the library — confirm intent.
void add_node_to_current_graph_task_exec_info(Node* fn);
240
241} // namespace autograd
242} // namespace torch
243