1 | #include <torch/csrc/autograd/engine.h> |
2 | |
3 | #include <torch/csrc/autograd/anomaly_mode.h> |
4 | #include <torch/csrc/autograd/autograd.h> |
5 | #include <torch/csrc/autograd/function.h> |
6 | #include <torch/csrc/autograd/functions/basic_ops.h> |
7 | #include <torch/csrc/autograd/grad_mode.h> |
8 | #include <torch/csrc/autograd/variable.h> |
9 | #include <torch/csrc/utils/memory.h> |
10 | |
11 | #include <ATen/DeviceGuard.h> |
12 | #include <ATen/ExpandUtils.h> |
13 | #include <ATen/Parallel.h> |
14 | #include <ATen/SparseCsrTensorUtils.h> |
15 | #include <ATen/detail/CUDAHooksInterface.h> |
16 | |
17 | #ifndef AT_PER_OPERATOR_HEADERS |
18 | #include <ATen/Functions.h> |
19 | #else |
20 | #include <ATen/ops/isnan.h> |
21 | #endif |
22 | |
23 | #include <c10/core/DeviceGuard.h> |
24 | #include <c10/core/Event.h> |
25 | #include <c10/core/Stream.h> |
26 | #include <c10/core/StreamGuard.h> |
27 | #include <c10/util/Exception.h> |
28 | #include <c10/util/Optional.h> |
29 | #include <c10/util/ThreadLocal.h> |
30 | #include <c10/util/irange.h> |
31 | |
32 | #include <atomic> |
33 | #include <chrono> |
34 | #include <condition_variable> |
35 | #include <cstdint> |
36 | #include <functional> |
37 | #include <iostream> |
38 | #include <memory> |
39 | #include <mutex> |
40 | #include <queue> |
41 | #include <set> |
42 | #include <sstream> |
43 | #include <string> |
44 | #include <thread> |
45 | #include <typeinfo> |
46 | #include <unordered_set> |
47 | #include <utility> |
48 | |
49 | namespace torch { |
50 | namespace autograd { |
51 | |
52 | namespace { |
53 | static bool in_bad_autograd_fork = |
54 | false; // True for children forked after engine's thread pool init |
55 | |
56 | // Called in the forked child if engine's thread pool has already been |
57 | // initialized |
58 | static void forked_autograd_child() { |
59 | in_bad_autograd_fork = true; |
60 | } |
61 | |
// Should be called before any call that is not fork-safe (i.e. before the
// engine's thread pool is initialized)
63 | static void track_bad_autograd_forks() { |
64 | #if !defined(WIN32) |
65 | static c10::once_flag flag; |
66 | c10::call_once( |
67 | flag, [&] { pthread_atfork(nullptr, nullptr, forked_autograd_child); }); |
68 | #endif |
69 | } |
70 | |
71 | inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) { |
72 | if (device == c10::kCPU || device == c10::kMeta || device == c10::kLazy) { |
73 | return true; |
74 | } else { |
75 | return false; |
76 | } |
77 | } |
78 | } // namespace |
79 | |
80 | // Threads spawned by the engine are assigned a 'worker_device' specifying |
81 | // what device they process work for. This variable is initialized at: |
82 | // 1. thread creation time for CUDA, XLA device threads, as they are |
// spinning threads waiting for work on their device.
84 | // 2. before the graph task execution for CPU threads, as for each |
85 | // backward call we use the caller thread to drive engine execution. |
86 | // This is used when handling reentrant backwards calls; |
87 | // See Note [Reentrant backwards] |
88 | static thread_local int worker_device = NO_DEVICE; |
89 | |
90 | // This variable is true if ALL invocations in the stack of re-entrant engine |
91 | // invocations are imperative backwards. This special variable is needed for the |
92 | // gradient checkpointing feature only. |
93 | static thread_local bool checkpoint_valid = true; |
94 | |
95 | // Number of nested reentrant backwards calls currently on this thread |
96 | static thread_local int current_depth = 0; |
97 | |
// For all device threads (i.e. CUDA, XLA), total_depth represents the
// total nested reentrant backwards depths over all device threads.
// For CPU devices, it is the total depth associated with the original
// backward call.
103 | static thread_local int total_depth = 0; |
104 | |
105 | // The current GraphTask being executed by this thread. This helps |
106 | // queue_callback() to find the target GraphTask to append final callbacks. |
107 | C10_DEFINE_TLS_static(std::shared_ptr<GraphTask>, tls_current_graph_task); |
108 | #define current_graph_task (tls_current_graph_task.get()) |
109 | |
110 | // Every autograd worker thread is associated with a ready queue, which |
// specifies the stream of work for this thread to do. This shared_ptr is a
112 | // thread_local pointer to each thread's ready_queue, and it should be |
113 | // initialized via the Engine::init_local_ready_queue() call in each |
114 | // corresponding thread before execution. |
115 | // |
116 | // The CUDA, XLA threads are shared among all invocations of backwards via |
117 | // device_ready_queues_, while the caller thread is dedicated to processing work |
118 | // for devices returning true in should_run_in_cpu_ready_queue (most notably the |
119 | // CPU device). So any given graph task maintains its own cpu_ready_queue_ where |
120 | // you should send work for it to be done. |
121 | // |
// For reentrant backward calls, if we spawn a new thread from the current
// thread because we reached the maximum depth, the new thread will just reuse
// the same ReadyQueue as the parent thread for better performance.
// See Note [Reentrant backwards] for more details.
126 | C10_DEFINE_TLS_static(std::shared_ptr<ReadyQueue>, tls_local_ready_queue); |
127 | #define local_ready_queue (tls_local_ready_queue.get()) |
128 | |
129 | // Note [Reentrant backwards] |
130 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
131 | // To understand the reentrant backwards problem, we have to notice two |
132 | // aspects of how the autograd engine is implemented today: |
133 | // |
134 | // 1. When you call Engine::execute(), you want to block until |
135 | // differentiation finishes so that you can get the final result variables |
136 | // of the backwards pass. |
137 | // |
138 | // 2. The engine operates by having a single worker thread per work queue, |
139 | // and every work queue is pinned to a specific device where the |
140 | // operation is executed. |
141 | // |
142 | // The problem is, suppose that you call backward() inside of a worker |
143 | // thread. By property (1), we're supposed to block until the nested task |
144 | // finishes. However, by property (2), this worker thread is on the |
145 | // hook for processing the tasks assigned to it; we better not block, |
146 | // because then all of our backward executions (including the one we |
147 | // just started) will deadlock! |
148 | // |
// We maintain a pool of threads waiting for work to do.
// When a reentrant backwards call occurs, the current thread blocks
// and a thread from the pool is woken up to complete the blocking task and
// any other tasks that would have been assigned to that worker. If there are
// no threads available, a new thread is spawned. The new thread will continue
// processing tasks from the same ReadyQueue as the parent worker.
155 | // |
156 | // When the GraphTask is finished, the parent worker thread that is waiting on |
157 | // the task is notified and the current thread returns to the pool. |
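//
// For illustration, a reentrant backwards call is produced by any Node whose
// backward itself invokes the engine. A minimal sketch (hypothetical user
// code, in the spirit of gradient checkpointing) using the C++ custom
// autograd function API (torch/csrc/autograd/custom_function.h):
//
//   struct Square : public torch::autograd::Function<Square> {
//     static at::Tensor forward(
//         torch::autograd::AutogradContext* ctx, const at::Tensor& x) {
//       ctx->save_for_backward({x});
//       return x * x;
//     }
//     static torch::autograd::variable_list backward(
//         torch::autograd::AutogradContext* ctx,
//         torch::autograd::variable_list grad_out) {
//       at::AutoGradMode enable_grad(true);  // recompute with grad enabled
//       auto x = ctx->get_saved_variables()[0].detach().requires_grad_(true);
//       auto y = x * x;
//       // This nested grad() call re-enters the engine from a worker thread,
//       // i.e. a reentrant backwards call (see execute_with_graph_task and
//       // its current_depth / max_recursion_depth_ handling).
//       return torch::autograd::grad({y}, {x}, {grad_out[0]});
//     }
//   };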
158 | |
159 | // Note [Streaming backwards] |
160 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
// On CUDA devices the autograd engine's device operations are run on the
// same streams that their corresponding forward operations ran on. This
// requires automatically syncing the streams so that function A finishes
// producing its output before function B consumes it.
165 | // |
166 | // This synchronization occurs when outputs are placed into input buffers. |
167 | // The functions corresponding to input buffer positions have metadata |
168 | // recording their streams from forward, and during backward this |
169 | // data is used to sync the producer's stream with the consumer's. |
170 | // |
// When a CUDA function is run, either all its inputs were accumulated on the
172 | // stream used to run the function OR the inputs are on different devices |
173 | // and the function is responsible for properly acquiring them. |
174 | // |
175 | // User-facing stream semantics of a backward() (or torch.autograd.grad()) |
176 | // call with respect to surrounding ops are the same as for any other call. |
177 | // See "Stream semantics of backward passes" on |
178 | // https://pytorch.org/docs/stable/notes/cuda.html |
179 | // |
180 | // Internally, backward() runs ops (including leaf nodes) on side threads. |
181 | // And streams are thread local. So GraphTask achieves the above semantics by |
182 | // 1. remembering the current streams on all active CUDA devices |
183 | // in the user-facing thread (aka, the thread that called execute() to |
184 | // launch the GraphTask) |
185 | // 2. remembering the "leaf streams" (streams each backward leaf node ran on) |
186 | // 3. during exec_post_processing, for each leaf stream, sync the remembered |
187 | // current streams (on the leaf stream's device) with that |
188 | // leaf stream. |
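//
// For illustration, step 3 boils down to an event-based sync per leaf stream
// (this is what exec_post_processing does below), roughly:
//
//   c10::Event event{c10::DeviceType::CUDA};
//   event.record(leaf_stream);           // producer: the backward leaf work
//   caller_current_stream.wait(event);   // consumer: the user-facing stream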
189 | |
190 | int NodeTask::getReentrantDepth() const { |
191 | std::shared_ptr<GraphTask> graph_task = base_.lock(); |
192 | if (graph_task) { |
193 | return graph_task->reentrant_depth_; |
194 | } else { |
195 | // The graph task is no longer valid indicating an error. As a result, we |
196 | // try to move this to the front of the queue to ensure the autograd |
197 | // engine threads pick up this error soon. |
198 | return std::numeric_limits<int>::max(); |
199 | } |
200 | } |
201 | |
202 | CheckpointValidGuard::CheckpointValidGuard( |
203 | const std::shared_ptr<const GraphTask>& graph_task) { |
204 | prev_checkpoint_valid_state = checkpoint_valid; |
205 | checkpoint_valid = |
206 | graph_task->can_checkpoint() && prev_checkpoint_valid_state; |
207 | } |
208 | |
209 | CheckpointValidGuard::~CheckpointValidGuard() { |
210 | checkpoint_valid = prev_checkpoint_valid_state; |
211 | } |
212 | |
213 | auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void { |
214 | { |
215 | // Lock mutex for writing to heap_ |
216 | std::lock_guard<std::mutex> lock(mutex_); |
217 | if (incrementOutstandingTasks) { |
218 | std::shared_ptr<GraphTask> graph_task = item.base_.lock(); |
219 | TORCH_INTERNAL_ASSERT(graph_task, "GraphTask is no longer valid!" ); |
220 | ++graph_task->outstanding_tasks_; |
221 | } |
222 | heap_.push(std::move(item)); |
223 | } |
224 | not_empty_.notify_one(); |
225 | } |
226 | |
227 | auto ReadyQueue::pushShutdownTask() -> void { |
228 | { |
229 | std::lock_guard<std::mutex> lock(mutex_); |
230 | heap_.push(NodeTask({}, nullptr, InputBuffer(0), true)); |
231 | } |
232 | not_empty_.notify_one(); |
233 | } |
234 | |
235 | size_t ReadyQueue::size() const { |
236 | // Lock mutex for accesses to heap_ |
237 | std::unique_lock<std::mutex> lock(mutex_); |
238 | return heap_.size(); |
239 | } |
240 | |
241 | auto ReadyQueue::pop() -> NodeTask { |
242 | // Lock mutex for accesses to heap_ |
243 | std::unique_lock<std::mutex> lock(mutex_); |
244 | not_empty_.wait(lock, [this] { return !heap_.empty(); }); |
245 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
246 | auto task = std::move(const_cast<NodeTask&>(heap_.top())); |
247 | heap_.pop(); |
248 | return task; |
249 | } |
250 | |
251 | bool ReadyQueue::empty() const { |
252 | // Lock mutex for accesses to heap_ |
253 | std::unique_lock<std::mutex> lock(mutex_); |
254 | return heap_.empty(); |
255 | } |
256 | |
257 | Engine::Engine() |
258 | : max_recursion_depth_(MAX_DEPTH), non_reentrant_device_thread_count_(0) {} |
259 | |
260 | Engine::~Engine() { |
261 | stop(); |
262 | } |
263 | |
// Send shutdown tasks to all device_ready_queues_ if no backward tasks are
// running. Even though the ready queues should be empty, shutdown tasks have
// the highest priority.
267 | void Engine::stop() { |
268 | if (stopped_) { |
269 | return; |
270 | } |
271 | stopped_ = true; |
// Under some conditions, autograd threads can hang on shutdown.
// Do not wait for them to shut down indefinitely; rely on a timeout instead.
auto wait_duration_str = getenv("TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT");
auto wait_duration = wait_duration_str ? std::atof(wait_duration_str) : 10.0;
276 | bool noBackward = true; |
277 | for (auto& queue : device_ready_queues_) { |
278 | noBackward = noBackward && queue->empty(); |
279 | } |
280 | if (noBackward && wait_duration > 0.0f) { |
281 | for (auto& queue : device_ready_queues_) { |
282 | queue->pushShutdownTask(); |
283 | } |
// Do not wait for termination of global threads on Windows, because the CRT
// terminates DLL threads before calling global object destructors.
287 | #if !defined(_WIN32) || defined(C10_USE_MSVC_STATIC_RUNTIME) |
288 | |
289 | using namespace std::chrono_literals; |
// Set a deadline for how long it is OK to wait for device threads to shut down
291 | auto wait_deadline = |
292 | std::chrono::steady_clock::now() + wait_duration * 1.0s; |
293 | std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_); |
294 | while (non_reentrant_device_thread_count_.load() != 0) { |
295 | if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) == |
296 | std::cv_status::timeout) { |
297 | break; |
298 | } |
299 | } |
300 | #endif |
301 | } |
302 | // Otherwise threads are leaked |
303 | } |
304 | |
305 | void Engine::release_workers() { |
306 | std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_); |
307 | non_reentrant_device_thread_count_.store(0); |
308 | non_reentrant_device_thread_condvar_.notify_one(); |
309 | } |
310 | |
311 | void Engine::increment_non_reentrant_thread_count() { |
312 | std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_); |
313 | non_reentrant_device_thread_count_.fetch_add(1); |
314 | non_reentrant_device_thread_condvar_.notify_one(); |
315 | } |
316 | |
317 | void Engine::decrement_non_reentrant_thread_count() { |
318 | std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_); |
319 | non_reentrant_device_thread_count_.fetch_sub(1); |
320 | non_reentrant_device_thread_condvar_.notify_one(); |
321 | } |
322 | |
323 | void Engine::thread_init( |
324 | int device, |
325 | const std::shared_ptr<ReadyQueue>& ready_queue, |
326 | bool should_increment) { |
327 | if (should_increment) { |
328 | increment_non_reentrant_thread_count(); |
329 | } |
330 | |
331 | at::init_num_threads(); |
332 | |
333 | // Note [Allocating GPUs to autograd threads] |
334 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
335 | // What's our strategy here? Originally, the autograd engine was written |
336 | // with only CUDA in mind. We allocate one thread to handle all CPU |
337 | // operations, and a thread per CUDA device. |
338 | // |
339 | // But what if we have OTHER devices? There are two plausible |
340 | // strategies: |
341 | // |
342 | // - We can allocate threads equal to max(num_cuda_devices, num_xla_devices, |
343 | // ...) and colocate cuda device 0 with xla device 0 |
344 | // - We can allocate threads equal to sum(num_cuda_devices, num_xla_devices, |
345 | // ...) keeping everyone separate. |
346 | // |
347 | // We don't have any good reason to prefer one or the other, so we've |
348 | // arbitrarily picked to colocate devices. Maybe the other approach is |
349 | // better. |
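//
// For example (hypothetical device counts): with 2 CUDA devices and 2 XLA
// devices we spawn max(2, 2) = 2 device threads, and cuda:0 shares its
// worker thread / ReadyQueue with xla:0, since device_ready_queues_ is
// indexed only by device index, not by device type.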
350 | |
351 | #if defined(USE_CUDA) |
352 | if (at::detail::getCUDAHooks().hasPrimaryContext(device)) { |
353 | set_device(device); |
354 | } |
355 | #else |
356 | set_device(device); |
357 | #endif |
358 | |
359 | // initialize each device thread's thread local ready queue with the ready |
360 | // queue that is created before the thread initialization |
361 | init_local_ready_queue(ready_queue); |
362 | |
363 | std::shared_ptr<GraphTask> graph_task = nullptr; |
364 | thread_main(graph_task); |
365 | if (should_increment) { |
366 | // Decrement the count during shutdown if we incremented earlier. |
367 | decrement_non_reentrant_thread_count(); |
368 | } |
369 | } |
370 | |
371 | GraphTaskGuard::GraphTaskGuard(std::shared_ptr<GraphTask> graph_task) { |
372 | last_graph_task_ = std::move(current_graph_task); |
373 | current_graph_task = std::move(graph_task); |
374 | } |
375 | GraphTaskGuard::~GraphTaskGuard() { |
376 | restore_current_graph_task(); |
377 | } |
378 | |
379 | void GraphTaskGuard::restore_current_graph_task() { |
380 | current_graph_task = std::move(last_graph_task_); |
381 | } |
382 | |
// The current graph task's exec_info is used to trim unnecessary edges during
// node evaluation; see the `Node.task_should_compute_output()` function.
385 | const std::unordered_map<Node*, GraphTask::ExecInfo>* |
386 | get_current_graph_task_exec_info() { |
return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
388 | } |
389 | |
390 | const std::unordered_set<Node*>* get_current_graph_task_nodes_in_graph() { |
return current_graph_task ? &current_graph_task->nodes_in_graph_ : nullptr;
392 | } |
393 | |
394 | int get_current_graph_task_id() { |
395 | return current_graph_task ? current_graph_task->id_ : -1; |
396 | } |
397 | |
398 | bool get_current_graph_task_keep_graph() { |
399 | return current_graph_task ? current_graph_task->keep_graph_ : true; |
400 | } |
401 | |
402 | void add_node_to_current_graph_task_exec_info(Node* fn) { |
403 | current_graph_task->exec_info_[fn].needed_ = true; |
404 | } |
405 | |
406 | // NB: The engine itself does not use the outputs of this function. |
407 | std::vector<Node*> get_current_graph_task_execution_order() { |
408 | std::shared_ptr<GraphTask> task = current_graph_task; |
409 | if (!task) { |
410 | return {}; |
411 | } |
412 | |
// We could potentially check if there is only a single device here
// but explicitly requiring this context doesn't seem bad either
415 | TORCH_CHECK( |
416 | !c10::AutogradState::get_tls_state().get_multithreading_enabled(), |
417 | "get_current_graph_task_execution_order expects the current backward to be " |
418 | "executed with multithreading disabled, e.g. by running:\n\n" |
419 | ">>> with torch.autograd.set_multithreading_enabled(False):\n" |
420 | "... torch.autograd.grad(...)\n" ); |
421 | |
422 | const bool check_exec_info = !task->exec_info_.empty(); |
423 | std::vector<Node*> out{}; |
424 | std::unordered_set<Node*> seen{}; |
425 | |
426 | auto compare_seq_nr = [](Node* n1, Node* n2) { |
427 | return n1->sequence_nr() < n2->sequence_nr(); |
428 | }; |
429 | std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap( |
430 | compare_seq_nr); |
431 | |
432 | for (Node* ptr : task->graph_roots_) { |
433 | heap.push(ptr); |
434 | } |
435 | |
436 | // Implementation notes: |
437 | // - Don't need to count dependencies because we have sequence_nr |
438 | // - Don't need to check topological_nr because we have exec_info |
439 | while (!heap.empty()) { |
440 | Node* fn = heap.top(); |
441 | heap.pop(); |
442 | |
443 | const bool was_inserted = seen.insert(fn).second; |
444 | if (!was_inserted) { |
445 | continue; |
446 | } |
447 | |
448 | out.push_back(fn); |
449 | for (const auto& edge : fn->next_edges()) { |
450 | Node* next_ptr = edge.function.get(); |
451 | if (!next_ptr) { |
452 | continue; |
453 | } |
454 | if (check_exec_info) { |
455 | auto it = task->exec_info_.find(next_ptr); |
456 | if (it == task->exec_info_.end() || !it->second.should_execute()) { |
457 | continue; |
458 | } |
459 | } |
460 | heap.push(next_ptr); |
461 | } |
462 | } |
463 | return out; |
464 | } |
465 | |
466 | // NOTE: graph_tasks do not necessarily form a stack. Imagine this |
467 | // case: |
468 | // |
469 | // +----> Eval1 |
470 | // Root |
471 | // +----> Eval2 |
472 | // |
473 | // Once Root is executed, both Eval1 and Eval2 are added to the ready queue. |
474 | // Next, Eval1 is run and this causes the worker to enter thread_main again. |
475 | // Then, it pops the next task from the queue, but at this point it is Eval2. |
476 | // It enters thread_main once again, but now with graph_task of Eval2, which is |
477 | // completely unrelated to that of Eval1 (it's not a recursive call). |
478 | // It's all ok and is handled right now, but it should be accounted for |
479 | // in case this code is to be changed. |
480 | // |
481 | // thread_main is used by: |
482 | // 1). autograd threads for devices (i.e. CUDA, XLA) |
483 | // 2). the caller/owning thread of the backward call on CPU (sync mode) |
// 3). reentrant backward calls invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running thread_main on device autograd
// threads throughout the Engine lifetime; thread_main will get
// terminated during Engine destruction by pushing shutdown tasks.
// For 2), the owning thread of the backward call drives thread_main
// synchronously until the graph_task of that owning thread is
// completed, and then exits thread_main to continue executing the
// caller's code.
// For 3), the reentrant backward call that invokes thread_main, either
// from 1) or 2), will not spin and will exit once its graph_task is
// completed, notifying the owning thread as needed.
497 | auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void { |
498 | // When graph_task is nullptr, this is a long running thread that processes |
499 | // tasks (ex: device threads). When graph_task is non-null (ex: reentrant |
// backwards, user thread), this function is expected to exit once that
// graph_task completes.
502 | |
// local_ready_queue should already have been initialized when we get into
// thread_main
505 | TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr); |
506 | while (graph_task == nullptr || !graph_task->future_result_->completed()) { |
507 | // local_graph_task represents the graph_task we retrieve from the queue. |
508 | // The outer graph_task represents the overall graph_task we need to execute |
509 | // for reentrant execution. |
510 | std::shared_ptr<GraphTask> local_graph_task; |
511 | { |
512 | // Scope this block of execution since NodeTask is not needed after this |
513 | // block and can be deallocated (release any references to grad tensors |
514 | // as part of inputs_). |
515 | NodeTask task = local_ready_queue->pop(); |
// This will only work if the worker is running a non-backward task
// TODO: this needs to be fixed to work in all cases
518 | if (task.isShutdownTask_) { |
519 | C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown" ); |
520 | break; |
521 | } |
522 | |
523 | if (!(local_graph_task = task.base_.lock())) { |
524 | // GraphTask for function is no longer valid, skipping further |
525 | // execution. |
526 | continue; |
527 | } |
528 | |
529 | if (task.fn_ && !local_graph_task->has_error_.load()) { |
530 | // Set the ThreadLocalState before calling the function. |
531 | // NB: The ThreadLocalStateGuard doesn't set the grad_mode because |
532 | // GraphTask always saves ThreadLocalState without grad_mode. |
533 | at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); |
534 | c10::WarningUtils::WarningHandlerGuard warnings_guard( |
535 | &local_graph_task->warning_handler_); |
536 | |
537 | try { |
538 | // The guard sets the thread_local current_graph_task on construction |
539 | // and restores it on exit. The current_graph_task variable helps |
540 | // queue_callback() to find the target GraphTask to append final |
541 | // callbacks. |
542 | GraphTaskGuard guard(local_graph_task); |
543 | NodeGuard ndguard(task.fn_); |
544 | { |
545 | RECORD_FUNCTION( |
546 | c10::str( |
547 | "autograd::engine::evaluate_function: " , |
548 | task.fn_.get()->name()), |
549 | c10::ArrayRef<const c10::IValue>()); |
550 | evaluate_function( |
551 | local_graph_task, |
552 | task.fn_.get(), |
553 | task.inputs_, |
554 | local_graph_task->cpu_ready_queue_); |
555 | } |
556 | } catch (std::exception& e) { |
557 | thread_on_exception(local_graph_task, task.fn_, e); |
558 | } |
559 | } |
560 | } |
561 | |
562 | // Decrement the outstanding tasks. |
563 | --local_graph_task->outstanding_tasks_; |
564 | |
565 | // Check if we've completed execution. |
566 | if (local_graph_task->completed()) { |
567 | local_graph_task->mark_as_completed_and_run_post_processing(); |
568 | |
569 | auto base_owner = local_graph_task->owner_; |
// The current worker thread finished the graph_task, but the owning thread
571 | // of the graph_task might be sleeping on pop() if it does not have work. |
572 | // So we need to send a dummy function task to the owning thread just to |
573 | // ensure that it's not sleeping, so that we can exit the thread_main. |
574 | // If it has work, it might see that graph_task->outstanding_tasks_ == 0 |
575 | // before it gets to the task, but it's a no-op anyway. |
576 | // |
577 | // NB: This is not necessary if the current thread is the owning thread. |
578 | if (worker_device != base_owner) { |
579 | // Synchronize outstanding_tasks_ with queue mutex |
580 | std::atomic_thread_fence(std::memory_order_release); |
581 | ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner) |
582 | ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0))); |
583 | } |
584 | } |
585 | } |
586 | } |
587 | |
// A reentrant call will reuse the graph_task owner thread's ready_queue for
// queueing tasks (NOTE: this is not true in the async_mode of the engine).
// While we could create a separate ready queue for each new reentrant
// thread, sharing the same cpu_ready_queue with the parent thread is a
// performance improvement, and the CUDA threads still have to do the same.
593 | void Engine::reentrant_thread_init() { |
594 | at::init_num_threads(); |
595 | auto tp_shared = thread_pool_shared_; |
596 | while (true) { |
597 | std::unique_lock<std::mutex> lk(tp_shared->mutex_); |
598 | ++thread_pool_shared_->num_workers_; |
599 | tp_shared->work_.wait( |
600 | lk, [&tp_shared] { return !tp_shared->graphtasks_queue_.empty(); }); |
601 | --thread_pool_shared_->num_workers_; |
602 | auto task = tp_shared->graphtasks_queue_.front(); |
603 | tp_shared->graphtasks_queue_.pop(); |
604 | lk.unlock(); |
605 | std::shared_ptr<GraphTask> graph_task; |
606 | if (!(graph_task = task.lock())) { |
607 | LOG(INFO) << "GraphTask has expired, skipping reentrant execution" ; |
608 | continue; |
609 | } |
610 | set_device(graph_task->owner_); |
611 | // set the local_ready_queue to the ready queue on the graph_task->owner_ |
612 | // device |
613 | local_ready_queue = |
614 | ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_); |
615 | total_depth = graph_task->reentrant_depth_; |
616 | thread_main(graph_task); |
617 | } |
618 | } |
619 | |
620 | void Engine::thread_on_exception( |
621 | std::shared_ptr<GraphTask> graph_task, |
622 | const std::shared_ptr<Node>& fn, |
623 | std::exception& e) { |
624 | graph_task->set_exception(std::current_exception(), fn); |
625 | } |
626 | |
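// outstanding_tasks_ bookkeeping: it is incremented whenever a NodeTask for
// this GraphTask is pushed onto a ReadyQueue (see ReadyQueue::push) and
// decremented by the worker in thread_main once the task has been processed,
// so a count of zero means every queued task has been consumed.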
627 | bool GraphTask::completed() { |
628 | return outstanding_tasks_.load() == 0 || |
629 | (exit_on_error_ && has_error_.load()); |
630 | } |
631 | |
632 | void GraphTask::mark_as_completed_and_run_post_processing() { |
633 | // Allow only one thread one attempt to process this logic. |
634 | if (future_completed_.exchange(true)) { |
635 | // Future is already marked complete, or being marked as such. |
636 | // In case the marking complete is only in progress, we add a |
637 | // wait() to guarantee the future is marked complete on exit. |
638 | future_result_->wait(); |
639 | return; |
640 | } |
641 | |
642 | try { |
643 | // Run post processing, before marking the future as complete. |
644 | // Drop lock prior to completing, to avoid holding across callbacks. |
645 | std::unique_lock<std::mutex> lock(mutex_); |
646 | |
647 | exec_post_processing(); |
648 | std::vector<Variable> vars = std::move(captured_vars_); |
649 | |
650 | // Need to unlock before we call markCompleted to avoid holding locks |
651 | // when the callbacks are called. |
652 | lock.unlock(); |
653 | future_result_->markCompleted(std::move(vars)); |
654 | } catch (std::exception& e) { |
655 | future_result_->setErrorIfNeeded(std::current_exception()); |
656 | } |
657 | } |
658 | |
659 | void GraphTask::exec_post_processing() { |
660 | if (!not_ready_.empty()) { |
661 | throw std::runtime_error("could not compute gradients for some functions" ); |
662 | } |
663 | |
664 | // set the thread_local current_graph_task_ as more callbacks can be installed |
665 | // by existing final callbacks. |
666 | GraphTaskGuard guard(shared_from_this()); |
667 | // Lock mutex during each iteration for accessing final_callbacks.size() |
668 | // Unlocking is necessary, because the callback can register |
669 | // more callbacks (or they can be registered from other threads |
// while it's waiting).
671 | std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_); |
672 | |
673 | // caller_current_streams_ with nullopt entries removed |
674 | std::vector<c10::Stream> caller_current_streams_filtered; |
675 | |
676 | // See Note [Streaming backwards]. |
677 | // Syncs caller_current_stream with leaf streams, so final_callbacks may use |
678 | // any grad on its device's current stream. |
679 | if (!leaf_streams.empty()) { |
680 | for (const auto& leaf_stream : leaf_streams) { |
681 | // stash_current_streams() stashed streams for all device IDs that already |
682 | // had a CUDA context before the GraphTask executed. For inactive devices, |
683 | // it stashed a c10::nullopt. I don't expect GraphTask's backward pass ran |
684 | // leaf nodes on any new devices, so the stashed streams should be enough. |
685 | // If leaf_stream.device_index() happens to be for a new device, |
686 | // operator* on the c10::nullopt should throw an error. |
687 | const auto caller_current_stream = |
688 | *caller_current_streams_[leaf_stream.device_index()]; |
689 | |
690 | if (caller_current_stream != leaf_stream) { |
691 | auto event = c10::Event{c10::DeviceType::CUDA}; |
692 | event.record(leaf_stream); |
693 | caller_current_stream.wait(event); |
694 | } |
695 | } |
696 | |
697 | caller_current_streams_filtered.reserve(caller_current_streams_.size()); |
698 | for (const auto& opt_stream : caller_current_streams_) { |
699 | if (opt_stream.has_value()) { |
700 | caller_current_streams_filtered.push_back(*opt_stream); |
701 | } |
702 | } |
703 | } |
704 | |
705 | { |
// final_callbacks run on the per-device caller_current_streams (the ambient
// streams surrounding the user's call to backward()). This has two benefits:
// 1. caller_current_streams have been synced with leaf_streams, so
//    callbacks may safely access any grad.
// 2. The callback's results can safely be used on (user-facing)
//    caller_current_streams after backward().
715 | c10::MultiStreamGuard g(caller_current_streams_filtered); |
716 | |
717 | // Set the ThreadLocalState before calling the function. |
718 | // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask |
719 | // always saves ThreadLocalState without grad_mode. |
720 | at::ThreadLocalStateGuard tls_guard(this->thread_locals_); |
721 | |
722 | // WARNING: Don't use a range-for loop here because more callbacks may be |
723 | // added in between callback calls, so iterators may become invalidated. |
724 | // NOLINTNEXTLINE(modernize-loop-convert) |
725 | for (size_t i = 0; i < final_callbacks_.size(); ++i) { |
726 | cb_lock.unlock(); |
727 | final_callbacks_[i](); |
728 | cb_lock.lock(); |
729 | } |
730 | } |
731 | } |
732 | |
733 | void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) { |
734 | if (!has_error_.exchange(true)) { |
735 | if (AnomalyMode::is_enabled() && fn) { |
736 | fn->metadata()->print_stack(fn->name()); |
737 | } |
738 | } |
739 | } |
740 | |
741 | void GraphTask::set_exception( |
742 | std::exception_ptr eptr, |
743 | const std::shared_ptr<Node>& fn) { |
744 | set_exception_without_signal(fn); |
745 | if (!future_completed_.exchange(true)) { |
746 | future_result_->setError(std::move(eptr)); |
747 | } |
748 | } |
749 | |
750 | static variable_list call_pre_hooks(Node& fn, variable_list inputs) { |
751 | for (const auto& hook : fn.pre_hooks()) { |
752 | inputs = (*hook)(inputs); |
753 | } |
754 | return inputs; |
755 | } |
756 | |
757 | static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) { |
758 | for (const auto& hook : fn.tensor_pre_hooks()) { |
759 | inputs = (*hook)(inputs); |
760 | } |
761 | for (const auto& pair : fn.retains_grad_hooks()) { |
762 | inputs = (*pair.second)(inputs); |
763 | } |
764 | return inputs; |
765 | } |
766 | |
767 | static variable_list call_post_hooks( |
768 | Node& fn, |
769 | variable_list outputs, |
770 | const variable_list& inputs) { |
771 | for (const auto& hook : fn.post_hooks()) { |
772 | outputs = (*hook)(outputs, inputs); |
773 | } |
774 | return outputs; |
775 | } |
776 | |
777 | void set_device(int device) { |
778 | // NB: We MUST NOT construct the guard for device CPU, |
779 | // as in some settings we compile with cuda, but |
780 | // have lazy stubs for CUDA functionality (so actually |
781 | // attempting to setup a guard(CPU_DEVICE) will cause an |
782 | // error, because it will still query cudaGetDevice). |
783 | // |
784 | // Don't use DeviceGuard here because its destructor may be called before the |
785 | // device is reset. This is fine because the device is thread local. |
786 | if (device != CPU_DEVICE) { |
787 | for (const auto i : c10::irange(static_cast<size_t>( |
788 | c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) { |
789 | auto* impl = c10::impl::device_guard_impl_registry[i].load(); |
790 | if (impl && device < impl->deviceCount()) { |
791 | impl->setDevice(at::Device(static_cast<c10::DeviceType>(i), device)); |
792 | } |
793 | } |
794 | } |
795 | worker_device = device; |
796 | } |
797 | |
798 | void validate_outputs( |
799 | const edge_list& edges, |
800 | variable_list& grads, |
801 | const std::function<std::string(const std::string&)>& format_error) { |
802 | if (grads.size() != edges.size()) { |
803 | std::stringstream ss; |
804 | ss << "invalid number of gradients - expected " ; |
805 | ss << edges.size() << ", but got " << grads.size(); |
806 | AT_ERROR(format_error(ss.str())); |
807 | } |
808 | for (const auto i : c10::irange(grads.size())) { |
809 | const auto& edge = edges[i]; |
810 | if (!edge.is_valid()) |
811 | continue; |
812 | |
813 | const auto& metadata = edge.function->input_metadata(edge.input_nr); |
814 | auto& grad = grads[i]; |
815 | if (!grad.defined()) { |
816 | // FIXME: TestJit.test_ge_optimized fails this assertion. |
817 | // std::stringstream ss; |
818 | // ss << "undefined gradient at index " << i; |
819 | // AT_ERROR(format_error(ss.str())); |
820 | continue; |
821 | } |
822 | |
823 | if (!metadata.is_same_shape(grad)) { |
824 | if (metadata.is_expandable_to_shape(grad)) { |
825 | grad = metadata.reduce_grad(grad); |
826 | } else { |
827 | const auto message = metadata.incompatible_shape_error_message(i, grad); |
828 | AT_ERROR(format_error(message.str())); |
829 | } |
830 | } |
831 | |
832 | bool input_is_complex = |
833 | isComplexType(c10::typeMetaToScalarType(metadata.options().dtype())); |
834 | bool grad_is_complex = isComplexType(grad.scalar_type()); |
835 | |
836 | TORCH_CHECK( |
837 | isFloatingType(grad.scalar_type()) || |
838 | (input_is_complex == grad_is_complex)); |
839 | if (c10::typeMetaToScalarType(metadata.options().dtype()) != |
840 | grad.scalar_type()) { |
841 | grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype())); |
842 | } |
843 | if (grad.dtype() != metadata.dtype()) { |
844 | std::stringstream ss; |
845 | ss << "invalid gradient at index " << i << " - expected dtype " ; |
846 | ss << metadata.dtype() << " but got " << grad.dtype(); |
847 | AT_ERROR(format_error(ss.str())); |
848 | } |
849 | if (grad.layout() != metadata.layout()) { |
// TODO: Currently we only support the (*, Sparse) combination for
// (tensor.layout(), tensor.grad.layout()). In the future, there will be an
// opportunity to support more combinations of layouts if they are
// composable (e.g., operations like addition are well defined between
// tensors of different layouts), as long as all parts of autograd, like
// AccumulateGrad, correctly handle this. We allow grad to be Strided
// when metadata is SparseCsr.
857 | if (!grad.is_sparse() && |
858 | !(grad.layout() == at::kStrided && |
859 | (at::sparse_csr::is_sparse_compressed(metadata.layout()) || |
860 | metadata.layout() == at::kSparse))) { |
861 | std::stringstream ss; |
862 | ss << "invalid gradient at index " << i << " - expected layout " ; |
863 | ss << metadata.layout() << " but got " << grad.layout(); |
864 | AT_ERROR(format_error(ss.str())); |
865 | } |
866 | } |
867 | |
868 | if (grad.device() != metadata.device()) { |
869 | // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but |
870 | // should be eventually removed |
871 | if (!(metadata.is_tensor_subclass() || |
872 | grad.unsafeGetTensorImpl()->is_python_dispatch())) { |
873 | if (grad.dim() == 0) { |
874 | grad = grad.to(metadata.device()); |
875 | } else { |
876 | std::stringstream ss; |
877 | ss << "invalid gradient at index " << i << " - expected device " ; |
878 | ss << metadata.device() << " but got " << grad.device(); |
879 | AT_ERROR(format_error(ss.str())); |
880 | } |
881 | } |
882 | } |
883 | // We should not build graph for Tensors that are not differentiable |
884 | TORCH_INTERNAL_ASSERT(isDifferentiableType(grad.scalar_type())); |
885 | } |
886 | } |
887 | |
888 | static variable_list call_function( |
889 | std::shared_ptr<GraphTask>& graph_task, |
890 | Node* func, |
891 | InputBuffer& inputBuffer) { |
892 | CheckpointValidGuard cpvguard(graph_task); |
893 | auto& fn = *func; |
894 | auto inputs = |
895 | call_tensor_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer))); |
896 | inputs = call_pre_hooks(fn, std::move(inputs)); |
897 | if (!graph_task->keep_graph_) { |
898 | fn.will_release_variables(); |
899 | } |
900 | |
901 | const auto has_post_hooks = !fn.post_hooks().empty(); |
902 | variable_list outputs; |
903 | |
904 | if (has_post_hooks) { |
905 | // In functions/accumulate_grad.cpp, there is some logic to check the |
906 | // conditions under which the incoming gradient can be stolen directly |
907 | // (which elides a deep copy) instead of cloned. One of these conditions |
908 | // is that the incoming gradient's refcount must be 1 (nothing else is |
909 | // referencing the same data). Stashing inputs_copy here bumps the |
910 | // refcount, so if post hooks are employed, it's actually still ok for |
911 | // accumulate_grad.cpp to steal the gradient if the refcount is 2. |
912 | // |
913 | // "new_grad.use_count() <= 1 + !post_hooks().empty()" in |
914 | // accumulate_grad.cpp accounts for this, but also creates a silent |
915 | // dependency between engine.cpp (ie, this particular engine |
916 | // implementation) and accumulate_grad.cpp. |
917 | // |
918 | // If you change the logic here, make sure it's compatible with |
919 | // accumulate_grad.cpp. |
920 | auto inputs_copy = inputs; |
921 | outputs = fn(std::move(inputs_copy)); |
922 | } else { |
923 | outputs = fn(std::move(inputs)); |
924 | } |
925 | |
926 | validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) { |
927 | std::ostringstream ss; |
928 | ss << "Function " << fn.name() << " returned an " << msg; |
929 | return ss.str(); |
930 | }); |
931 | |
932 | if (has_post_hooks) { |
933 | // NOLINTNEXTLINE(bugprone-use-after-move) |
934 | return call_post_hooks(fn, std::move(outputs), inputs); |
935 | } |
936 | return outputs; |
937 | } |
938 | |
939 | void Engine::evaluate_function( |
940 | std::shared_ptr<GraphTask>& graph_task, |
941 | Node* func, |
942 | InputBuffer& inputs, |
943 | const std::shared_ptr<ReadyQueue>& cpu_ready_queue) { |
944 | // The InputBuffer::adds that supplied incoming grads took pains to |
945 | // ensure they're safe to consume in the context of the present |
946 | // func's stream (if applicable). So we guard onto that stream |
947 | // before working with the grads in any capacity. |
948 | const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA); |
949 | c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream}; |
950 | |
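// NB: exec_info_ is only populated when Engine::execute received a non-empty
// `outputs` list (the grad() / backward(inputs=...) case); init_to_execute()
// then marks which nodes are needed and which grads must be captured.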
951 | // If exec_info_ is not empty, we have to instrument the execution |
952 | auto& exec_info_ = graph_task->exec_info_; |
953 | if (!exec_info_.empty()) { |
954 | auto& fn_info = exec_info_.at(func); |
955 | variable_list new_inputs = inputs.buffer; |
956 | if (!fn_info.needed_) { |
// We always want to call tensor pre-hooks, but want to avoid calling them
// twice. needed_ = True indicates that we will call tensor pre-hooks
959 | // later. |
960 | // |
961 | // See NOTE [Hooks ordering] for more context. |
962 | new_inputs = call_tensor_pre_hooks( |
963 | *func, InputBuffer::variables(std::move(inputs))); |
964 | } |
965 | if (auto* capture_vec = fn_info.captures_.get()) { |
966 | const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA); |
967 | // Lock mutex for writing to graph_task->captured_vars_. |
968 | std::lock_guard<std::mutex> lock(graph_task->mutex_); |
969 | for (const auto& capture : *capture_vec) { |
970 | auto& captured_grad = graph_task->captured_vars_[capture.output_idx_]; |
971 | captured_grad = new_inputs[capture.input_idx_]; |
972 | // NOTE [Deprecated capture hooks] |
973 | for (const auto& hook : |
974 | capture.DO_NOT_USE_DEPRECATED_get_capture_hooks()) { |
975 | captured_grad = (*hook)(captured_grad); |
976 | } |
977 | if (opt_parent_stream) { |
978 | // No need to take graph_task->mutex_ here, we already hold it |
979 | graph_task->leaf_streams.emplace(*opt_parent_stream); |
980 | } |
981 | } |
982 | } |
983 | if (!fn_info.needed_) { |
984 | // Skip execution if we don't need to execute the function. |
985 | return; |
986 | } |
987 | } |
988 | |
989 | auto outputs = call_function(graph_task, func, inputs); |
990 | |
991 | auto& fn = *func; |
992 | if (!graph_task->keep_graph_) { |
993 | fn.release_variables(); |
994 | } |
995 | |
996 | int num_outputs = outputs.size(); |
997 | if (num_outputs == 0) { // Note: doesn't acquire the mutex |
998 | // Records leaf stream (if applicable) |
999 | // See Note [Streaming backwards] |
1000 | if (opt_parent_stream) { |
1001 | std::lock_guard<std::mutex> lock(graph_task->mutex_); |
1002 | graph_task->leaf_streams.emplace(*opt_parent_stream); |
1003 | } |
1004 | return; |
1005 | } |
1006 | |
1007 | if (AnomalyMode::is_enabled() && AnomalyMode::should_check_nan()) { |
1008 | AutoGradMode grad_mode(false); |
1009 | for (const auto i : c10::irange(num_outputs)) { |
1010 | auto& output = outputs[i]; |
1011 | at::OptionalDeviceGuard guard(device_of(output)); |
1012 | if (output.defined() && isnan(output)._is_any_true().item<bool>()) { |
1013 | std::stringstream ss; |
1014 | ss << "Function '" << fn.name() << "' returned nan values in its " << i |
1015 | << "th output." ; |
1016 | throw std::runtime_error(ss.str()); |
1017 | } |
1018 | } |
1019 | } |
1020 | |
1021 | // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and |
1022 | // cpu_ready_queue_ below |
1023 | std::lock_guard<std::mutex> lock(graph_task->mutex_); |
1024 | for (const auto i : c10::irange(num_outputs)) { |
1025 | auto& output = outputs[i]; |
1026 | const auto& next = fn.next_edge(i); |
1027 | |
1028 | if (!next.is_valid()) |
1029 | continue; |
1030 | |
1031 | // Check if the next function is ready to be computed |
1032 | bool is_ready = false; |
1033 | auto& dependencies = graph_task->dependencies_; |
1034 | auto it = dependencies.find(next.function.get()); |
1035 | |
1036 | if (it == dependencies.end()) { |
1037 | auto name = next.function->name(); |
1038 | throw std::runtime_error(std::string("dependency not found for " ) + name); |
1039 | } else if (--it->second == 0) { |
1040 | dependencies.erase(it); |
1041 | is_ready = true; |
1042 | } |
1043 | |
1044 | auto& not_ready = graph_task->not_ready_; |
1045 | auto not_ready_it = not_ready.find(next.function.get()); |
1046 | if (not_ready_it == not_ready.end()) { |
1047 | // Skip functions that aren't supposed to be executed |
1048 | if (!exec_info_.empty()) { |
1049 | auto it = exec_info_.find(next.function.get()); |
1050 | if (it == exec_info_.end() || !it->second.should_execute()) { |
1051 | continue; |
1052 | } |
1053 | } |
1054 | // No buffers have been allocated for the function |
1055 | InputBuffer input_buffer(next.function->num_inputs()); |
1056 | |
1057 | // Accumulates into buffer |
1058 | const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA); |
1059 | input_buffer.add( |
1060 | next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); |
1061 | |
1062 | if (is_ready) { |
1063 | auto queue = ready_queue(cpu_ready_queue, input_buffer.device()); |
1064 | queue->push( |
1065 | NodeTask(graph_task, next.function, std::move(input_buffer))); |
1066 | } else { |
1067 | not_ready.emplace(next.function.get(), std::move(input_buffer)); |
1068 | } |
1069 | } else { |
1070 | // The function already has a buffer |
1071 | auto& input_buffer = not_ready_it->second; |
1072 | |
1073 | // Accumulates into buffer |
1074 | const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA); |
1075 | input_buffer.add( |
1076 | next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); |
1077 | if (is_ready) { |
1078 | auto queue = ready_queue(cpu_ready_queue, input_buffer.device()); |
1079 | queue->push( |
1080 | NodeTask(graph_task, next.function, std::move(input_buffer))); |
1081 | not_ready.erase(not_ready_it); |
1082 | } |
1083 | } |
1084 | } |
1085 | } |
1086 | |
1087 | inline static uint64_t compute_min_topological_nr(const edge_list& outputs) { |
// Computes the minimum topological number among all the outputs
1089 | if (outputs.empty()) { |
1090 | return 0; |
1091 | } |
1092 | auto min_topo_nr = std::numeric_limits<uint64_t>::max(); |
1093 | for (auto& output_edge : outputs) { |
1094 | auto topo_nr = output_edge.function.get()->topological_nr(); |
1095 | min_topo_nr = (min_topo_nr < topo_nr) ? min_topo_nr : topo_nr; |
1096 | } |
1097 | return min_topo_nr; |
1098 | } |
1099 | |
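// NB: topological_nr is the length of the longest path from a node to any
// leaf, and it strictly decreases along next_edges. So a node whose
// topological_nr is smaller than min_topo_nr cannot reach any of the
// requested outputs, and compute_dependencies can safely skip it below.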
1100 | auto Engine::compute_dependencies( |
1101 | Node* root, |
1102 | GraphTask& task, |
1103 | uint64_t min_topo_nr) -> void { |
1104 | // Computes the number of dependencies for each function which requires grad |
1105 | std::vector<Node*> queue{root}; |
1106 | bool might_use_cuda = at::globalContext().hasCUDA(); |
1107 | bool will_use_cuda = false; |
1108 | |
1109 | // Queue contains all nodes that will start propagating gradients. |
1110 | // We no longer have to expand functions that don't require grad. |
1111 | auto& dependencies = task.dependencies_; |
1112 | while (!queue.empty()) { |
1113 | auto fn = queue.back(); |
1114 | queue.pop_back(); |
1115 | if (fn->topological_nr() < min_topo_nr) { |
1116 | continue; |
1117 | } |
1118 | if (might_use_cuda && !will_use_cuda) { |
1119 | will_use_cuda = fn->stream(c10::DeviceType::CUDA).has_value(); |
1120 | } |
1121 | for (const auto& edge : fn->next_edges()) { |
1122 | if (auto next_ptr = edge.function.get()) { |
1123 | dependencies[next_ptr] += 1; |
1124 | const bool was_inserted = task.nodes_in_graph_.insert(next_ptr).second; |
1125 | if (was_inserted) |
1126 | queue.push_back(next_ptr); |
1127 | } |
1128 | } |
1129 | } |
1130 | |
1131 | if (will_use_cuda) { |
1132 | // Collects current streams for devices where this process has a context, |
1133 | // so GraphTask::exec_post_processing can sync them with leaf_streams. |
1134 | task.stash_current_streams(); |
1135 | } |
1136 | } |
1137 | |
1138 | auto Engine::execute( |
1139 | const edge_list& root_edges, |
1140 | const variable_list& inputs, |
1141 | bool keep_graph, |
1142 | bool create_graph, |
1143 | bool accumulate_grad, |
1144 | const edge_list& outputs) -> variable_list { |
1145 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
1146 | validate_outputs( |
1147 | root_edges, |
1148 | const_cast<variable_list&>(inputs), |
1149 | [](const std::string& msg) { return msg; }); |
1150 | if (accumulate_grad && create_graph) { |
1151 | TORCH_WARN_ONCE( |
1152 | "Using backward() with create_graph=True will create a reference cycle " |
1153 | "between the parameter and its gradient which can cause a memory leak. " |
1154 | "We recommend using autograd.grad when creating the graph to avoid this. " |
1155 | "If you have to use this function, make sure to reset the .grad fields of " |
1156 | "your parameters to None after use to break the cycle and avoid the leak." ); |
1157 | } |
1158 | |
// accumulate_grad is true if and only if the frontend call was to
// backward(), not grad(). grad() returns the sum of the gradients
// w.r.t. the inputs and thus needs the inputs to be present.
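// NB: the `outputs` edge_list parameter of this function corresponds to the
// `inputs` argument the user passed to grad() (or to backward(inputs=...)),
// while `root_edges` are the gradient edges of the tensors being
// differentiated.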
1162 | TORCH_CHECK_VALUE( |
1163 | accumulate_grad || !outputs.empty(), "grad requires non-empty inputs." ); |
1164 | |
1165 | // A fresh first time Engine::execute call should start on the CPU device, |
1166 | // initialize a new thread local ready queue on CPU or reuse the existing one |
1167 | // (if there is one allocated already, i.e. consecutive backward calls, |
1168 | // re-entrant backward calls), then memoize the local_ready_queue in GraphTask |
1169 | init_local_ready_queue(); |
1170 | bool not_reentrant_backward_call = worker_device == NO_DEVICE; |
1171 | |
1172 | // Store root nodes so we can traverse through the graph later |
1173 | // e.g., for get_current_graph_task_execution_order |
1174 | c10::SmallVector<Node*, 4> temp_roots{root_edges.size()}; |
1175 | for (const auto i : c10::irange(root_edges.size())) { |
1176 | temp_roots[i] = root_edges[i].function.get(); |
1177 | } |
1178 | |
1179 | auto graph_task = std::make_shared<GraphTask>( |
1180 | /* keep_graph */ keep_graph, |
1181 | /* create_graph */ create_graph, |
1182 | /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1, |
1183 | /* cpu_ready_queue */ local_ready_queue, |
1184 | /* graph_roots */ std::move(temp_roots)); |
1185 | |
1186 | // If we receive a single root, skip creating extra root node |
1187 | bool skip_dummy_node = root_edges.size() == 1; |
1188 | auto graph_root = skip_dummy_node |
1189 | ? root_edges.at(0).function |
1190 | : std::make_shared<GraphRoot>(root_edges, inputs); |
1191 | |
1192 | auto min_topo_nr = compute_min_topological_nr(outputs); |
1193 | // Now compute the dependencies for all executable functions |
1194 | compute_dependencies(graph_root.get(), *graph_task, min_topo_nr); |
1195 | |
1196 | if (!outputs.empty()) { |
1197 | graph_task->init_to_execute( |
1198 | *graph_root, outputs, accumulate_grad, min_topo_nr); |
1199 | } |
1200 | |
1201 | // Queue the root |
1202 | if (skip_dummy_node) { |
1203 | InputBuffer input_buffer(root_edges.at(0).function->num_inputs()); |
1204 | auto input = inputs.at(0); |
1205 | |
1206 | const auto input_stream = InputMetadata(input).stream(); |
1207 | const auto opt_next_stream = |
1208 | root_edges.at(0).function->stream(c10::DeviceType::CUDA); |
1209 | input_buffer.add( |
1210 | root_edges.at(0).input_nr, |
1211 | std::move(input), |
1212 | input_stream, |
1213 | opt_next_stream); |
1214 | |
1215 | execute_with_graph_task( |
1216 | graph_task, std::move(graph_root), std::move(input_buffer)); |
1217 | } else { |
1218 | execute_with_graph_task( |
1219 | graph_task, std::move(graph_root), InputBuffer(variable_list())); |
1220 | } |
1221 | // Avoid a refcount bump for the Future, since we check for refcount in |
1222 | // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1) |
1223 | // in dist_engine.cpp). |
1224 | auto& fut = graph_task->future_result_; |
1225 | fut->wait(); |
1226 | graph_task->warning_handler_.replay_warnings(); |
1227 | return fut->value().toTensorVector(); |
1228 | } |
1229 | |
1230 | void Engine::initialize_device_threads_pool() { |
1231 | TORCH_CHECK( |
1232 | !in_bad_autograd_fork, |
1233 | "Unable to handle autograd's threading in combination with fork-based multiprocessing. " |
1234 | "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork" ); |
1235 | c10::call_once( |
1236 | start_device_threads_flag_, &Engine::start_device_threads, this); |
1237 | } |
1238 | |
1239 | c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task( |
1240 | const std::shared_ptr<GraphTask>& graph_task, |
1241 | std::shared_ptr<Node> graph_root, |
1242 | InputBuffer&& input_buffer) { |
1243 | initialize_device_threads_pool(); |
1244 | // Lock mutex for GraphTask. |
1245 | std::unique_lock<std::mutex> lock(graph_task->mutex_); |
1246 | |
1247 | auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device()); |
1248 | |
// If worker_device == NO_DEVICE, this is a CPU thread trying to drive the
// autograd engine with the corresponding GraphTask, and it is NOT a
// re-entrant call
1251 | if (worker_device == NO_DEVICE) { |
1252 | // We set the worker_device to CPU_DEVICE only if worker_device was |
// previously NO_DEVICE. Setting it to CPU afterwards allows us to detect
1254 | // whether this is a re-entrant call or not. |
1255 | set_device(CPU_DEVICE); |
1256 | |
1257 | // set the graph_task owner to the current device |
1258 | graph_task->owner_ = worker_device; |
1259 | |
1260 | // Now that all the non-thread safe fields of the graph_task have been |
1261 | // populated, we can enqueue it. |
1262 | queue->push( |
1263 | NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); |
1264 | |
// The owning thread starts to drive the engine execution for any CPU task
// that was just pushed or will be added later from other worker threads.
1267 | lock.unlock(); |
1268 | thread_main(graph_task); |
1269 | TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed()); |
// Reset the worker_device after the completion of the graph_task, so that
// the initial state of the engine remains the same across every backward()
// or grad() call. We don't need to reset local_ready_queue as we could
// possibly reuse it for new backward calls.
1274 | worker_device = NO_DEVICE; |
1275 | } else { |
1276 | // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant |
1277 | // backward call from that device. |
1278 | graph_task->owner_ = worker_device; |
1279 | |
1280 | // Now that all the non-thread safe fields of the graph_task have been |
1281 | // populated, we can enqueue it. |
1282 | queue->push( |
1283 | NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); |
1284 | |
1285 | if (current_depth >= max_recursion_depth_) { |
1286 | // See Note [Reentrant backwards] |
1287 | // If reached the max depth, switch to a different thread |
1288 | add_thread_pool_task(graph_task); |
1289 | } else { |
1290 | // Total depth needs to be updated only in this codepath, since it is |
1291 | // not used in the block above (when we call add_thread_pool_task). |
1292 | // In the codepath above, GraphTask.reentrant_depth_ is used to |
1293 | // bootstrap total_depth in the other thread. |
1294 | ++total_depth; |
1295 | |
1296 | // Get back to work while we wait for our new graph_task to |
1297 | // complete! |
1298 | ++current_depth; |
1299 | lock.unlock(); |
1300 | thread_main(graph_task); |
1301 | --current_depth; |
1302 | --total_depth; |
1303 | |
1304 | // The graph task should have completed and the associated future should |
1305 | // be marked completed as well since 'thread_main' above is a call |
1306 | // blocking an autograd engine thread. |
1307 | TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed()); |
1308 | } |
1309 | } |
1310 | // graph_task_exec_post_processing is done when the Future is marked as |
1311 | // completed in mark_as_completed_and_run_post_processing. |
1312 | return graph_task->future_result_; |
1313 | } |
1314 | |
// Note that when Python is present, this base engine will be overridden
// with a PythonEngine. Because this typically happens before
// get_default_engine is called, this base engine will never be created.
1318 | Engine& Engine::get_base_engine() { |
1319 | static Engine engine; |
1320 | return engine; |
1321 | } |
1322 | |
1323 | std::atomic<EngineStub> engine_stub(Engine::get_base_engine); |
1324 | |
1325 | void set_default_engine_stub(EngineStub stub) { |
1326 | engine_stub.store(stub); |
1327 | } |
1328 | |
1329 | Engine& Engine::get_default_engine() { |
1330 | return engine_stub.load()(); |
1331 | } |
1332 | |
1333 | void Engine::queue_callback(std::function<void()> callback) { |
TORCH_CHECK(
current_graph_task,
"Final callbacks can only be installed during backward pass.");
1337 | |
1338 | std::lock_guard<std::mutex> lock(current_graph_task->final_callbacks_lock_); |
1339 | current_graph_task->final_callbacks_.emplace_back(std::move(callback)); |
1340 | } |
1341 | |
1342 | bool Engine::is_checkpoint_valid() { |
1343 | return checkpoint_valid; |
1344 | } |
1345 | |
1346 | void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) { |
1347 | if (ready_queue) { |
// if a ready_queue is provided by the caller, use it to initialize
// local_ready_queue
1350 | local_ready_queue = std::move(ready_queue); |
1351 | } else if (!local_ready_queue) { |
// otherwise, if local_ready_queue is not yet allocated, allocate a new one
1353 | local_ready_queue = std::make_shared<ReadyQueue>(); |
1354 | } |
1355 | } |
1356 | |
1357 | // CPU ready queue is per GraphTask, but CUDA device ready queues are shared |
1358 | // across all graph tasks |
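// (e.g. a NodeTask whose inputs live on cuda:1 is routed to
// device_ready_queues_[1], while CPU/Meta/Lazy work, or all work when
// multithreading is disabled, stays on the per-GraphTask cpu_ready_queue
// passed in by the caller)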
1359 | auto Engine::ready_queue( |
1360 | std::shared_ptr<ReadyQueue> cpu_ready_queue, |
1361 | at::Device device) -> std::shared_ptr<ReadyQueue> { |
1362 | bool multithreading_disabled = |
1363 | !c10::AutogradState::get_tls_state().get_multithreading_enabled(); |
1364 | if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) { |
1365 | // return the cpu ready queue passed in |
1366 | TORCH_INTERNAL_ASSERT(cpu_ready_queue); |
1367 | return cpu_ready_queue; |
1368 | } else { |
1369 | TORCH_INTERNAL_ASSERT( |
1370 | 0 <= device.index() && |
1371 | device.index() < |
1372 | static_cast<c10::DeviceIndex>(device_ready_queues_.size())); |
1373 | // See Note [Allocating GPUs to autograd threads] |
1374 | return device_ready_queues_.at(device.index()); |
1375 | } |
1376 | } |
1377 | |
1378 | auto Engine::ready_queue_by_index( |
1379 | std::shared_ptr<ReadyQueue> cpu_ready_queue, |
1380 | int device_index) -> std::shared_ptr<ReadyQueue> { |
1381 | if (device_index == CPU_DEVICE) { |
1382 | // return the cpu ready queue passed in |
1383 | TORCH_INTERNAL_ASSERT(cpu_ready_queue); |
1384 | return cpu_ready_queue; |
1385 | } else { |
1386 | TORCH_INTERNAL_ASSERT( |
1387 | 0 <= device_index && |
1388 | device_index < |
1389 | static_cast<c10::DeviceIndex>(device_ready_queues_.size())); |
1390 | // See Note [Allocating GPUs to autograd threads] |
1391 | // NB: This function would become obsolete if we truly allocated a CPU |
1392 | // thread per device, rather than colocate. |
1393 | return device_ready_queues_.at(device_index); |
1394 | } |
1395 | } |
1396 | |
1397 | auto Engine::start_device_threads() -> void { |
1398 | // First always initialize the thread pool for re-entrant threads |
1399 | thread_pool_shared_ = std::make_shared<ThreadPoolShared>(); |
1400 | |
1401 | // Second, create special threads for each non-CPU device |
1402 | // See Note [Allocating GPUs to autograd threads] |
1403 | c10::DeviceIndex num_devices = 0; |
1404 | for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) { |
1405 | auto* impl = impl_atomic.load(); |
// Only record the number of devices for device types that don't run on
// the cpu ready queue.
1408 | if (impl && !should_run_in_cpu_ready_queue(impl->type())) { |
1409 | num_devices = std::max(num_devices, impl->deviceCount()); |
1410 | } |
1411 | } |
1412 | |
// If there are no devices other than the CPU, there is no need to create
// worker threads
1414 | if (num_devices == 0) { |
1415 | return; |
1416 | } |
1417 | |
1418 | // Since we're about to create threads, forking is not possible anymore |
1419 | track_bad_autograd_forks(); |
1420 | |
// Allocate one thread for every GPU device (but colocate GPUs of different
// types), and pre-allocate device_ready_queues_ so it can be read safely.
1424 | device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices); |
1425 | for (auto& queue : device_ready_queues_) { |
1426 | queue = std::make_shared<ReadyQueue>(); |
1427 | } |
1428 | |
1429 | for (const auto i : c10::irange(num_devices)) { |
1430 | std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true); |
1431 | t.detach(); |
1432 | } |
1433 | // Wait for the threads to start |
1434 | { |
1435 | std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_); |
1436 | while (non_reentrant_device_thread_count_.load() != |
1437 | static_cast<uint32_t>(num_devices)) { |
1438 | non_reentrant_device_thread_condvar_.wait(lk); |
1439 | } |
1440 | } |
1441 | } |
1442 | |
1443 | void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) { |
1444 | std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_); |
// There may already be some items on the graphtasks_queue_ added by other
// threads, but not enough workers to get to the new task that is about to
// be added
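// (for example, with num_workers_ == 2 and two tasks already queued, the
// check num_workers_ <= graphtasks_queue_.size() holds, so enqueuing the
// third task also spawns a new worker thread)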
1448 | bool create_thread = |
1449 | (thread_pool_shared_->num_workers_ <= |
1450 | thread_pool_shared_->graphtasks_queue_.size()); |
1451 | thread_pool_shared_->graphtasks_queue_.push(graph_task); |
1452 | // Don't need to be holding the lock while actually creating the thread |
1453 | lck.unlock(); |
1454 | if (create_thread) { |
1455 | // If we're creating a new thread, forking is not allowed anymore |
1456 | track_bad_autograd_forks(); |
1457 | std::thread t(&Engine::reentrant_thread_init, this); |
1458 | t.detach(); |
1459 | } |
// This works even if a new thread is created, because wait() will test the
// predicate before waiting
1462 | thread_pool_shared_->work_.notify_one(); |
1463 | } |
1464 | |
1465 | // Remembers current streams on all devices where a context has been created. |
1466 | // Only called if Engine::execute detects at least one node runs on a cuda |
1467 | // stream. |
1468 | void GraphTask::stash_current_streams() { |
1469 | const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA}; |
1470 | auto num_gpus = guard.deviceCount(); |
1471 | caller_current_streams_.resize(num_gpus); |
1472 | if (num_gpus > 0) { |
1473 | for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) { |
1474 | #if defined(USE_ROCM) && (ROCM_VERSION < 50000) |
1475 | // If the build targets ROCM, stash streams for all visible devices |
1476 | // unconditionally, to work around |
1477 | // https://github.com/pytorch/pytorch/issues/59750. |
1478 | // TODO: Remove ROCM-specific behavior when |
1479 | // https://github.com/pytorch/pytorch/issues/59750 is fixed. |
1480 | if (true) { |
1481 | #else |
1482 | if (at::detail::getCUDAHooks().hasPrimaryContext(idx)) { |
1483 | #endif |
1484 | caller_current_streams_[idx] = |
1485 | guard.getStream({c10::DeviceType::CUDA, idx}); |
1486 | } else { |
1487 | caller_current_streams_[idx] = c10::nullopt; |
1488 | } |
1489 | } |
1490 | } |
1491 | } |
1492 | |
1493 | void GraphTask::init_to_execute( |
1494 | Node& graph_root, |
1495 | const edge_list& outputs, |
1496 | bool accumulate_grad, |
1497 | uint64_t min_topo_nr) { |
// Populates exec_info so nodes that should be executed have
// `exec_info[node].needed_ = true`. Only nodes that have a path to any edge
// in `outputs` should be executed. The code below populates exec_info using
// recursion, but the actual code does this iteratively. Refer to the
// numbering to see how the actual code corresponds. A difference to note is
// that in the iterative version, when you are working with the current Node,
// you are responsible for updating your parent's is_needed after all your
// children have been updated.
1506 | // |
1507 | // is_needed = {fn: True for fn in outputs} # (0) |
1508 | // seen = {} |
1509 | // def compute_is_needed(fn): |
1510 | // for next_edge in fn.next_edges: |
1511 | // child_fn = next_edge.fn |
1512 | // if child_fn in seen and is_needed[child_fn]: # (1) |
1513 | // is_needed[fn] = true |
1514 | // else: |
1515 | // seen.add(child_fn) |
1516 | // if compute_is_needed(child_fn): |
1517 | // is_needed[fn] = true # (2) |
1518 | // # (3) exit for-loop |
1519 | // return is_needed[fn] |
1520 | // compute_is_needed(graph_root) |
1521 | // |
1522 | // NB: you might be wondering why we don't populate `seen` with outputs. We |
1523 | // cannot because in the case where two outputs lie on the same path, we still |
1524 | // need to explore past the first output or we would miss the nodes that are |
1525 | // required to compute the second output. |
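// As a small hypothetical example: with edges GR -> A, GR -> B, A -> O and
// B -> C, and `outputs` containing a single edge into O, step (0) gives O an
// entry, and the traversal then marks A and GR as needed (O lies below
// them), while B and C never get a should-execute entry and are skipped when
// the graph actually runs.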
1526 | int output_idx = 0; |
1527 | for (auto& output_edge : outputs) { |
1528 | // (0) `is_needed` above corresponds to `exec_info_[fn].needed_` |
1529 | Node* output = output_edge.function.get(); |
1530 | auto& info = exec_info_[output]; |
1531 | if (accumulate_grad) { |
1532 | // if called through `.backward()` we directly set `needed_` for all the |
1533 | // outputs to true |
1534 | info.needed_ = true; |
1535 | } else { |
// otherwise it is `.grad()` and we set exec_info[fn].captures_ instead.
// In terms of populating the rest of exec_info though, you can basically
// think of this as the same as setting `needed_` to true directly.
1539 | if (!info.captures_) { |
1540 | info.captures_ = make_unique<std::vector<ExecInfo::Capture>>(); |
1541 | } |
1542 | info.captures_->emplace_back(output_edge.input_nr, output_idx++); |
1543 | } |
1544 | } |
1545 | captured_vars_.resize(output_idx); |
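// Each Frame below records one node on the iterative DFS stack plus the
// index of the next outgoing edge to visit; get_next_fn() skips null edges
// and returns nullptr once all of fn_'s children have been handed out.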
1546 | |
1547 | struct Frame { |
1548 | Frame(Node* fn) : fn_(fn) {} |
1549 | Node* fn_{}; |
1550 | size_t next_next_fn_{}; |
1551 | |
1552 | Node* get_next_fn() { |
1553 | const auto& next = fn_->next_edges(); |
1554 | auto num_next = next.size(); |
1555 | while (next_next_fn_ < num_next) { |
1556 | auto fn = next[next_next_fn_++].function.get(); |
1557 | if (fn) |
1558 | return fn; |
1559 | } |
1560 | return nullptr; |
1561 | } |
1562 | }; |
1563 | |
1564 | auto nodeShouldExecute = [this](Node* fn) { |
1565 | auto it = exec_info_.find(fn); |
1566 | return it != exec_info_.end() && it->second.should_execute(); |
1567 | }; |
1568 | |
1569 | std::vector<Frame> stack; |
1570 | std::unordered_set<Node*> seen; |
1571 | stack.emplace_back(&graph_root); |
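// Give the root a default ExecInfo entry (needed_ starts out false); the
// loop below flips it to true as soon as any of its children turns out to
// be needed.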
1572 | exec_info_.emplace(stack.back().fn_, ExecInfo()); |
1573 | |
1574 | while (!stack.empty()) { |
1575 | auto& frame = stack.back(); |
1576 | const auto fn = frame.fn_; |
1577 | |
1578 | Node* child_fn = nullptr; |
1579 | while ((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) { |
1580 | // (1) next child exists AND has already been seen |
1581 | if (nodeShouldExecute(child_fn)) { |
1582 | exec_info_[fn].needed_ = true; |
1583 | } |
1584 | } |
1585 | |
1586 | if (child_fn) { |
1587 | // (2) next child exists but has not been seen |
1588 | if (child_fn->topological_nr() < min_topo_nr) { |
// a child created before the first output means this child cannot have
// an edge to any output
1591 | continue; |
1592 | } |
1593 | stack.emplace_back(child_fn); |
1594 | } else { |
// (3) no next child exists for `fn`, which means its `needed` has already
// been finalized. Pop the stack and update the parent.
1597 | stack.pop_back(); |
1598 | if (nodeShouldExecute(fn) && !stack.empty()) { |
1599 | exec_info_[stack.back().fn_].needed_ = true; |
1600 | } |
1601 | } |
1602 | } |
1603 | } |
1604 | |
1605 | } // namespace autograd |
1606 | } // namespace torch |
1607 | |