#pragma once

// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).

#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <c10/util/CallOnce.h>

#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <queue>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {
struct ReadyQueue;
}
} // namespace torch

namespace torch {
namespace autograd {

// Maximum reentrant backward depth before switching to a new thread.
// This limit is based on TSAN's deadlock detector, which fails if a program
// holds more than 65 locks in one thread at once. As we hold a mutex in every
// custom C++ autograd Node, we would like to avoid TSAN complaints about this
// when doing reentrant backwards.
// For reference, see https://github.com/google/sanitizers/issues/950
static constexpr int MAX_DEPTH = 60;
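
// A minimal sketch of how a reentrant backward arises (the custom function
// below is illustrative and not declared in this header): calling backward()
// from inside another Node's backward re-enters the engine on the current
// worker thread, and each nesting level is counted against MAX_DEPTH before
// further work is handed off to the reentrant thread pool.
//
//   struct SquareWithInnerBackward
//       : public torch::autograd::Function<SquareWithInnerBackward> {
//     static at::Tensor forward(
//         torch::autograd::AutogradContext* ctx, const at::Tensor& x) {
//       ctx->save_for_backward({x});
//       return x * x;
//     }
//     static torch::autograd::variable_list backward(
//         torch::autograd::AutogradContext* ctx,
//         torch::autograd::variable_list grads) {
//       auto x = ctx->get_saved_variables()[0];
//       // Reentrant call: this nested backward() runs on the current engine
//       // worker thread and increases the reentrant depth by one.
//       (x * x).sum().backward();
//       return {grads[0] * 2 * x};
//     }
//   };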

void set_device(int device);
void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here. Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When a worker receives a task with isShutdownTask_ = true, it will
  // immediately exit. The engine sends a shutdown task to every queue upon
  // its destruction.
  bool isShutdownTask_;

  int getReentrantDepth() const;

  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(std::move(base)),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};
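
// A minimal sketch of how a NodeTask is typically assembled (graph_task, fn
// and the grad_* variables are illustrative; the exact InputBuffer::add
// signature may differ across versions): gradients from each consumer of `fn`
// are summed into the buffer, and once the last dependency is satisfied the
// task is pushed onto a ReadyQueue.
//
//   InputBuffer buffer(fn->num_inputs());
//   buffer.add(/*pos=*/0, std::move(grad_from_consumer_a),
//              /*opt_producer_stream=*/std::nullopt,
//              /*opt_consumer_stream=*/std::nullopt);
//   buffer.add(/*pos=*/0, std::move(grad_from_consumer_b),
//              /*opt_producer_stream=*/std::nullopt,
//              /*opt_consumer_stream=*/std::nullopt);  // accumulated in place
//   ready_queue->push(NodeTask(graph_task, fn, std::move(buffer)));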

// Guard that sets and restores checkpoint_valid
class CheckpointValidGuard {
 public:
  explicit CheckpointValidGuard(
      const std::shared_ptr<const GraphTask>& graph_task);
  ~CheckpointValidGuard();

 private:
  bool prev_checkpoint_valid_state;
};
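
// A minimal sketch of the intended RAII usage (the call_function call site is
// illustrative): while the guard is alive, is_checkpoint_valid() reflects the
// state recorded in graph_task, and the previous checkpoint_valid state is
// restored when the guard goes out of scope.
//
//   {
//     CheckpointValidGuard guard(graph_task);
//     auto outputs = call_function(graph_task, func, inputs);
//   }  // previous checkpoint_valid state restored here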

struct ReadyQueue {
 private:
  // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
  // Shutdown tasks are first and then empty NodeTasks are next.
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const& t1, NodeTask const& t2) {
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (t2.isShutdownTask_) {
        return true;
      } else if (!t1.fn_ || t1.isShutdownTask_) {
        return false;
      } else if (!t2.fn_) {
        return true;
      } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
        return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
      } else {
        return t1.getReentrantDepth() < t2.getReentrantDepth();
      }
    }
  };

  // To notify threads waiting on the ReadyQueue of available tasks on heap_
  std::condition_variable not_empty_;
  // To protect reads and writes to heap_
  mutable std::mutex mutex_;

  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime>
      heap_;

 public:
  // incrementOutstandingTasks indicates whether or not we should increment
  // 'outstanding_tasks_' for the associated GraphTask. This should almost
  // always be true and is only set false in certain cases (see docs for
  // DistEngine.execute_graph_task_until_ready_queue_empty)
  void push(NodeTask item, bool incrementOutstandingTasks = true);
  void pushShutdownTask();
  NodeTask pop();
  bool empty() const;
  size_t size() const;
};
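
// A minimal sketch of the resulting pop order (graph_task and the two fns are
// illustrative): shutdown tasks come out first; among ordinary tasks a deeper
// reentrant depth wins, and within the same depth the Node with the higher
// sequence_nr (i.e. the one created later in the forward pass) is popped
// first.
//
//   ReadyQueue queue;
//   queue.push(NodeTask(graph_task, fn_created_early, InputBuffer(1)));
//   queue.push(NodeTask(graph_task, fn_created_late, InputBuffer(1)));
//   queue.pushShutdownTask();
//   queue.pop();  // the shutdown task
//   queue.pop();  // fn_created_late (higher sequence_nr)
//   queue.pop();  // fn_created_early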

// A single instance of this struct should be created through the whole process
// lifetime. The worker thread creation logic and Engine's destructor rely on
// this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  static Engine& get_base_engine();

  Engine(const Engine&) = delete;
  Engine(Engine&&) = delete;
  virtual ~Engine();

  // Given a list of (Node, input number) pairs computes the value of the graph
  // by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      bool accumulate_grad,
      const edge_list& outputs = {});
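
  // A minimal usage sketch (model_forward and x are illustrative): seed the
  // backward of a scalar loss with a gradient of ones and ask the engine to
  // accumulate into the leaves' grad fields, which is roughly what
  // torch::autograd::backward() ends up doing.
  //
  //   auto loss = model_forward(x);  // scalar with requires_grad history
  //   edge_list roots{torch::autograd::impl::gradient_edge(loss)};
  //   variable_list grads{torch::ones_like(loss)};
  //   Engine::get_default_engine().execute(
  //       roots,
  //       grads,
  //       /*keep_graph=*/false,
  //       /*create_graph=*/false,
  //       /*accumulate_grad=*/true,
  //       /*outputs=*/{});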

  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
  // for the graph.
  //
  // NB: This API should only be used by internal autograd specific
  // machinery and shouldn't be exposed to users in any way.
  virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer);

  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return std::make_unique<AnomalyMetadata>();
  }

  virtual std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() {
    return nullptr;
  }

  // We pass cpu_ready_queue to evaluate_function, so that it knows
  // the correct ready queue to push to after a NodeTask is ready
  void evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs,
      const std::shared_ptr<ReadyQueue>& cpu_ready_queue);

  void initialize_device_threads_pool();
  virtual void thread_on_exception(
      std::shared_ptr<GraphTask> graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e);

  void queue_callback(std::function<void()> callback);

  bool is_checkpoint_valid();

  // Should be called after fork to notify that worker threads are gone
  void release_workers();

  // Must be called by subclass before destructing to avoid a data-race-on-vptr.
  void stop();
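
  // A minimal sketch of the required pattern in a subclass (DerivedEngine is
  // illustrative): calling stop() from the most-derived destructor shuts the
  // worker threads down while the dynamic type is still the subclass, which
  // is what avoids the data race on the vptr.
  //
  //   struct DerivedEngine : Engine {
  //     ~DerivedEngine() override {
  //       stop();
  //     }
  //   };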

  // Initializes a device thread for the autograd engine.
  virtual void thread_init(
      int device,
      const std::shared_ptr<ReadyQueue>& ready_queue,
      bool should_increment = true);

 protected:
  Engine();
  void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);

  // initialize the thread local ready queue with the ready queue that is
  // created elsewhere (e.g. thread_init, Engine::execute, etc.), or create a
  // new ready queue if ready_queue is not provided.
  void init_local_ready_queue(
      std::shared_ptr<ReadyQueue> ready_queue = nullptr);

  std::shared_ptr<ReadyQueue> ready_queue(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      at::Device device);
  std::shared_ptr<ReadyQueue> ready_queue_by_index(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      int device_index);
  // start device threads (CUDA, XLA, etc.) in Engine,
  // note that it does NOT start the CPU thread.
  void start_device_threads();
  void increment_non_reentrant_thread_count();
  void decrement_non_reentrant_thread_count();
  virtual void thread_main(const std::shared_ptr<GraphTask>& task);
  void reentrant_thread_init();
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);

  // Ensures device_ready_queues_ are initialized only once
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::once_flag start_device_threads_flag_;
  // Safe to read device_ready_queues_ without synchronization after
  // initialization
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex post_callbacks_lock_;

  // How many nested reentrant calls are allowed until a new thread is used
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_{0};
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtask_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute.
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    ThreadPoolShared() = default;
  };

  // Temporary workaround until shutting down threads is done
  // We need shared ownership of all these objects because the threads are
  // leaked when Engine shuts down, so there may be threads waiting on work_
  // for the graphtasks_queue_ to be nonempty.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;

 private:
  // Number of non-reentrant threads
  std::atomic<uint32_t> non_reentrant_device_thread_count_;
  // Destructor will wait for non-reentrant threads to finish
  std::condition_variable non_reentrant_device_thread_condvar_;
  std::mutex non_reentrant_device_thread_mutex_;
  // stop() must be called before the destruction path goes down to the base
  // class, in order to avoid a data-race-on-vptr. Use this boolean to guard
  // whether stop() has already been called, so we can call this in every
  // destructor of the class hierarchy.
  bool stopped_{false};
};

// Allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
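
// A minimal sketch of how an override is installed (MyEngine is illustrative;
// the Python bindings use this hook to substitute their PythonEngine):
//
//   Engine& get_my_engine() {
//     static MyEngine engine;  // MyEngine derives from Engine
//     return engine;
//   }
//   // registered once when the overriding module is loaded
//   torch::autograd::set_default_engine_stub(get_my_engine);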

} // namespace autograd
} // namespace torch