#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
8 | |
9 | namespace torch { |
10 | namespace autograd { |
11 | |
12 | using edge_list = std::vector<Edge>; |
13 | struct ReadyQueue; |
14 | |
15 | static constexpr int NO_DEVICE = -2; |
16 | static constexpr int CPU_DEVICE = -1; |
17 | |
// Global counter from which each GraphTask::id_ is drawn (see the GraphTask
// constructor below).
//
// Previously this lived in an unnamed namespace, which in a *header* gives
// the variable internal linkage: every translation unit that includes this
// file gets its own independent counter, so GraphTasks created from
// different TUs could receive the same id. Worse, the counter is odr-used
// inside the inline GraphTask constructor defined in this header, so the
// per-TU copies made the inline constructor's definitions differ across
// TUs — an ODR violation. A C++17 `inline` variable has exactly one
// definition program-wide, keeping ids unique across the whole process.
inline std::atomic<uint64_t> graph_task_id{0};
21 | |
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
  // Count of tasks still pending for this backward pass; presumably
  // incremented when work is queued and decremented as NodeTasks finish —
  // the bookkeeping lives in the engine .cpp (see completed() below).
  std::atomic<uint64_t> outstanding_tasks_{0};
  // Indicates if an error occurred while executing any task. When this is
  // true, it signals all threads to stop executing.
  std::atomic_bool has_error_{false};
  // Set once 'future_result_' has been marked completed (see
  // mark_as_completed_and_run_post_processing below).
  std::atomic_bool future_completed_{false};
  // Whether the autograd graph should be kept alive after this execution.
  // It is safe to read keep_graph_ without synchronization.
  bool keep_graph_;

  // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
  // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
  std::mutex mutex_;
  // Per-node InputBuffers for nodes that cannot run yet — presumably nodes
  // whose incoming gradients have only partially arrived; confirm against the
  // engine implementation.
  std::unordered_map<Node*, InputBuffer> not_ready_;
  // Per-node count of not-yet-satisfied dependencies — presumably a node is
  // ready once its count reaches zero; confirm against the engine
  // implementation.
  std::unordered_map<Node*, int> dependencies_;

  // Records the nodes that are in the graph
  std::unordered_set<Node*> nodes_in_graph_;
  // Root nodes this task was created to execute.
  c10::SmallVector<Node*, 4> graph_roots_;
  // Note [Exec info]
  // Exec info is created for each GraphTask, which allows filtering paths on
  // the graph that are not needed. Its semantics are a bit complicated. If
  // it's empty, it means the task is run in a "default" mode, which means
  // that all next_edges we encounter should get executed. If it's not empty,
  // only functions that have an entry and this entry has needed == True
  // should be executed. exec_info is only empty when the graph is executed
  // via .backward() and the inputs parameter is not passed. Otherwise, when
  // executed through .grad(), or when inputs arg is specified for
  // .backward(), exec_info will be non-empty.
  //
  struct ExecInfo {
    struct Capture {
      // Move-only: Captures hold uniquely-owned hooks (hooks_ below).
      Capture(const Capture&) = delete;
      Capture(Capture&&) = default;

      Capture(int input_idx, int output_idx)
          : input_idx_(input_idx), output_idx_(output_idx) {}
      int input_idx_; // within Node inputs
      int output_idx_; // within the output vector of a GraphTask

      // This hook will be executed after a grad is captured. The captured
      // grad will be replaced by the return value of the hook.
      struct GradCaptureHook {
        virtual ~GradCaptureHook() = default;
        virtual at::Tensor operator()(const at::Tensor& grad) = 0;
      };
      // NOTE [Deprecated capture hooks]
      //
      // The current status of capture hooks is that we continue to support
      // the single usage of it by distributed in the dist_engine. If anyone
      // else needs to use it for other purposes, they should file an issue.
      //
      // Capture hooks were originally created because there did not exist
      // any way to register pre/post hooks to grad_fn in a way such that it
      // would still be executed even if that is the grad_fn of a Tensor
      // passed as input= of .grad. As far as I know, only dist_engine uses
      // this hook.
      //
      // However, there are other alternatives today like tensor hooks that
      // can replace the usage that originally motivated its creation. Also,
      // capture hooks are an outlier in terms of the types of hook that
      // autograd offers in how it is registered and behaves, e.g. it is a
      // hook registered not to the graph, but to a particular graph_task!
      // This makes it a burden to maintain.
      //
      // It would be very nice to clean up/do a migration from pre/post
      // hooks used in distributed to use tensor hooks, but for now we just
      // mark this method as deprecated to prevent additional usage.
      //
      // If you still think you really need to capture hooks, please file an
      // issue (and tag autograd).
      const std::vector<std::unique_ptr<GradCaptureHook>>&
      DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
        return hooks_;
      }
      // See NOTE [Deprecated capture hooks]
      void DO_NOT_USE_DEPRECATED_register_capture_hook(
          std::unique_ptr<GradCaptureHook> hook) {
        hooks_.push_back(std::move(hook));
      }

     private:
      // The hooks will be called one by one in the order as they were added.
      // The input grad of a hook will be the output of its preceding hook.
      // The first hook will take the captured grad as the input. The output
      // of the last hook will replace the captured grad.
      std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
    };

    // A node must run if it is on a needed path or if any gradient flowing
    // through it has to be captured.
    bool should_execute() const {
      return needed_ || captures_;
    }

    // True when this node is needed to reach one of the requested outputs
    // (see Note [Exec info] above).
    bool needed_ = false;
    // Non-null when gradients arriving at this node must be captured into
    // captured_vars_; allocated lazily (null means "no captures here").
    std::unique_ptr<std::vector<Capture>> captures_;
  };
  // exec_info_ is safe to read without synchronization
  std::unordered_map<Node*, ExecInfo> exec_info_;
  // Captured variables are grads captured that we return to the user. After
  // execution of the GraphTask is completed, the captured_vars_ are moved
  // out of the GraphTask and are no longer valid.
  std::vector<Variable> captured_vars_;

  // Note: this field is not ready to be used until the proper
  // `thread_locals_.set_grad_mode()` call in the constructor.
  at::ThreadLocalState thread_locals_ = at::ThreadLocalState();

  // Streams touched by leaf work during this backward; synced with
  // caller_current_streams_ in exec_post_processing (see below).
  std::unordered_set<c10::Stream> leaf_streams;

  // Per-device current streams of the execute() that called this GraphTask.
  // These will be synced with leaf_streams in exec_post_processing.
  std::vector<c10::optional<c10::Stream>> caller_current_streams_;

  // Collects caller_current_streams_
  void stash_current_streams();

  // Populates exec_info_ so that only paths from graph_root to `outputs` are
  // executed (see Note [Exec info] above); defined in the engine .cpp.
  void init_to_execute(
      Node& graph_root,
      const edge_list& outputs,
      bool accumulate_grad,
      uint64_t min_topo_nr);

  // The value of worker_device in the thread that created this task.
  // See Note [Reentrant backwards]
  // Safe to read owner_ and reentrant_depth_ without synchronization.
  int owner_;
  // The number of parent graph tasks for this graph task
  const int reentrant_depth_;

  // Checkpointing is only allowed when this task runs the whole graph, i.e.
  // exec_info_ is empty ("default" mode, see Note [Exec info]).
  bool can_checkpoint() const {
    return exec_info_.empty();
  }

  // check if the GraphTask is completed or not
  bool completed();
  // mark the graph task as completed and trigger post processing
  void mark_as_completed_and_run_post_processing();

  // Set an appropriate exception on this graph_task which was encountered
  // while running the provided function.
  void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);

  // Set an appropriate exception on this graph_task which was encountered
  // while running the provided function. But doesn't signal completion on
  // 'future_result_' right away. The user needs to explicitly mark
  // 'future_result_' completed with an appropriate exception.
  void set_exception_without_signal(const std::shared_ptr<Node>& fn);

  // Whether or not to stop execution for this GraphTask when an error is
  // encountered. When set to true, this would cause Engine::execute() to
  // throw an exception as soon as the autograd engine receives an exception.
  bool exit_on_error_;

  // CPU threads are dedicated to processing CPU work for the backward they
  // invoked. So any given graph task maintains its own cpu_ready_queue_
  // where you should send work for it to be done. We memoize the
  // cpu_ready_queue_ per GraphTask so that we know which ready queue we
  // should push to if we are on a device thread (i.e. GPU) but the next
  // NodeTask should be run on CPU.
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;

  // Future representing the completion of the graph task. Notified when all
  // tasks are done.
  c10::intrusive_ptr<at::ivalue::Future> future_result_;

  // Final callbacks installed during execution of this GraphTask
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_. Intentionally not
  // reusing mutex_ as the two are protecting different data structures.
  std::mutex final_callbacks_lock_;

  // Handler that defers warnings raised during execution (see
  // torch/csrc/autograd/utils/warnings.h for its semantics).
  utils::DelayWarningHandler warning_handler_;

  // Id assigned from the global graph_task_id counter at construction.
  uint64_t id_;

  GraphTask(
      bool keep_graph,
      bool grad_mode,
      int reentrant_depth,
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      c10::SmallVector<Node*, 4> graph_roots,
      bool exit_on_error = false)
      : keep_graph_(keep_graph),
        graph_roots_(std::move(graph_roots)),
        owner_(NO_DEVICE), // no owner thread yet; see Note [Reentrant backwards]
        reentrant_depth_(reentrant_depth),
        exit_on_error_(exit_on_error),
        cpu_ready_queue_(std::move(cpu_ready_queue)),
        future_result_(c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()))),
        id_(graph_task_id.fetch_add(1, std::memory_order_relaxed)) {
    // Record the caller's grad mode into the stashed thread-local state;
    // thread_locals_ is not usable before this call (see its comment above).
    thread_locals_.set_grad_mode(grad_mode);
  }

 private:
  // run GraphTask post processing
  void exec_post_processing();
};
219 | |
// The guard that sets and restores current_graph_task.
// RAII helper: the constructor installs `graph_task` as the thread-local
// current graph task, and the destructor restores the previous one.
// Definitions live in the engine .cpp.
class GraphTaskGuard {
 public:
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
  ~GraphTaskGuard();

  // Restores the previously-current graph task; presumably also invoked by
  // the destructor — confirm at the definition.
  void restore_current_graph_task();

 private:
  // The graph task that was current when this guard was constructed.
  std::shared_ptr<GraphTask> last_graph_task_;
};
231 | |
// Accessors for the calling thread's current GraphTask (installed by
// GraphTaskGuard). Behavior when no task is current is defined in the
// engine .cpp — presumably nullptr / a sentinel; confirm at the definitions.
TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
// NOTE(review): GraphTask::id_ is uint64_t but this returns int, so large
// ids would be truncated on conversion — confirm whether callers depend on
// the narrower type before widening it.
TORCH_API int get_current_graph_task_id();
// NOTE(review): unlike its neighbors this is not TORCH_API — confirm it is
// intentionally not exported (library-internal use only).
void add_node_to_current_graph_task_exec_info(Node* fn);
240 | |
241 | } // namespace autograd |
242 | } // namespace torch |
243 | |