#include <torch/csrc/autograd/engine.h>

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>

#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/isnan.h>
#endif

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/ThreadLocal.h>
#include <c10/util/irange.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <typeinfo>
#include <unordered_set>
#include <utility>

namespace torch {
namespace autograd {

namespace {
static bool in_bad_autograd_fork =
    false; // True for children forked after engine's thread pool init

// Called in the forked child if engine's thread pool has already been
// initialized
static void forked_autograd_child() {
  in_bad_autograd_fork = true;
}

// Should be called before any call that is not fork-safe (e.g. initializing
// the engine's thread pool)
static void track_bad_autograd_forks() {
#if !defined(WIN32)
  static c10::once_flag flag;
  c10::call_once(
      flag, [&] { pthread_atfork(nullptr, nullptr, forked_autograd_child); });
#endif
}

inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
  if (device == c10::kCPU || device == c10::kMeta || device == c10::kLazy) {
    return true;
  } else {
    return false;
  }
}
} // namespace

// Threads spawned by the engine are assigned a 'worker_device' specifying
// what device they process work for. This variable is initialized at:
// 1. thread creation time for CUDA, XLA device threads, as they are
//    spinning threads waiting for work on their device.
// 2. before the graph task execution for CPU threads, as for each
//    backward call we use the caller thread to drive engine execution.
// This is used when handling reentrant backwards calls;
// See Note [Reentrant backwards]
static thread_local int worker_device = NO_DEVICE;

// This variable is true if ALL invocations in the stack of re-entrant engine
// invocations are imperative backwards. This special variable is needed for
// the gradient checkpointing feature only.
static thread_local bool checkpoint_valid = true;

// Number of nested reentrant backwards calls currently on this thread
static thread_local int current_depth = 0;

// For all device threads (i.e. CUDA, XLA), total_depth represents the total
// nested reentrant backwards depth over all device threads.
// For CPU devices, it is the total depth associated with the original backward
// call.
static thread_local int total_depth = 0;

// The current GraphTask being executed by this thread. This helps
// queue_callback() to find the target GraphTask to append final callbacks.
C10_DEFINE_TLS_static(std::shared_ptr<GraphTask>, tls_current_graph_task);
#define current_graph_task (tls_current_graph_task.get())

// Every autograd worker thread is associated with a ready queue, which
// specifies the stream of work to be done by that thread. This shared_ptr is a
// thread_local pointer to each thread's ready_queue, and it should be
// initialized via the Engine::init_local_ready_queue() call in each
// corresponding thread before execution.
//
// The CUDA, XLA threads are shared among all invocations of backwards via
// device_ready_queues_, while the caller thread is dedicated to processing work
// for devices returning true in should_run_in_cpu_ready_queue (most notably the
// CPU device). So any given graph task maintains its own cpu_ready_queue_ where
// you should send work for it to be done.
//
// For reentrant backward calls, if we spawn a new thread from the current
// thread because we reached the maximum depth, the new thread just reuses the
// parent thread's ReadyQueue for better performance.
// See Note [Reentrant backwards] for more details.
C10_DEFINE_TLS_static(std::shared_ptr<ReadyQueue>, tls_local_ready_queue);
#define local_ready_queue (tls_local_ready_queue.get())

// Note [Reentrant backwards]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// To understand the reentrant backwards problem, we have to notice two
// aspects of how the autograd engine is implemented today:
//
//  1. When you call Engine::execute(), you want to block until
//     differentiation finishes so that you can get the final result variables
//     of the backwards pass.
//
//  2. The engine operates by having a single worker thread per work queue,
//     and every work queue is pinned to a specific device where the
//     operation is executed.
//
// The problem is, suppose that you call backward() inside of a worker
// thread. By property (1), we're supposed to block until the nested task
// finishes. However, by property (2), this worker thread is on the
// hook for processing the tasks assigned to it; we better not block,
// because then all of our backward executions (including the one we
// just started) will deadlock!
//
// We maintain a pool of threads waiting for work to do. When a reentrant
// backwards call occurs, the current thread blocks and a thread from the pool
// is woken up to complete the blocking tasks and any other tasks that would
// have been assigned to that worker. If there are no threads available, a new
// thread is spawned. The new thread will continue processing tasks from the
// same ReadyQueue as the parent worker.
//
// When the GraphTask is finished, the parent worker thread that is waiting on
// the task is notified and the current thread returns to the pool.
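//
// As an illustration only (not part of the engine's API), a reentrant
// backward arises whenever autograd code runs inside a backward call, e.g. a
// hook that itself calls backward():
//
//   >>> def hook(grad):
//   ...     # runs on an engine worker thread -> nested Engine::execute()
//   ...     grad.detach().requires_grad_().exp().sum().backward()
//   >>> t = torch.randn(3, requires_grad=True)
//   >>> t.register_hook(hook)
//   >>> t.sum().backward()
//
// Gradient checkpointing (torch.utils.checkpoint) triggers the same path,
// since it re-runs the forward and calls backward on the recomputed graph
// from within the outer backward pass.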

// Note [Streaming backwards]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// On CUDA devices the autograd engine's device operations are run on the
// same stream that ran them in forward. This requires automatically
// syncing the streams so that function A finishes producing its
// output before function B consumes it.
//
// This synchronization occurs when outputs are placed into input buffers.
// The functions corresponding to input buffer positions have metadata
// recording their streams from forward, and during backward this
// data is used to sync the producer's stream with the consumer's.
//
// When a CUDA function is run either all its inputs were accumulated on the
// stream used to run the function OR the inputs are on different devices
// and the function is responsible for properly acquiring them.
//
// User-facing stream semantics of a backward() (or torch.autograd.grad())
// call with respect to surrounding ops are the same as for any other call.
// See "Stream semantics of backward passes" on
// https://pytorch.org/docs/stable/notes/cuda.html
//
// Internally, backward() runs ops (including leaf nodes) on side threads.
// And streams are thread local. So GraphTask achieves the above semantics by
//  1. remembering the current streams on all active CUDA devices
//     in the user-facing thread (aka, the thread that called execute() to
//     launch the GraphTask)
//  2. remembering the "leaf streams" (streams each backward leaf node ran on)
//  3. during exec_post_processing, for each leaf stream, sync the remembered
//     current streams (on the leaf stream's device) with that
//     leaf stream.
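//
// Illustrative example only (following the CUDA notes linked above; `model`
// and `x` are placeholders): because grads are synced with the streams that
// were current when backward() was called, using them in that same stream
// context needs no extra synchronization:
//
//   >>> s = torch.cuda.Stream()
//   >>> with torch.cuda.stream(s):
//   ...     loss = model(x).sum()
//   ...     loss.backward()
//   ...     # .grad fields are safe to use here, still on stream s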

int NodeTask::getReentrantDepth() const {
  std::shared_ptr<GraphTask> graph_task = base_.lock();
  if (graph_task) {
    return graph_task->reentrant_depth_;
  } else {
    // The graph task is no longer valid indicating an error. As a result, we
    // try to move this to the front of the queue to ensure the autograd
    // engine threads pick up this error soon.
    return std::numeric_limits<int>::max();
  }
}

CheckpointValidGuard::CheckpointValidGuard(
    const std::shared_ptr<const GraphTask>& graph_task) {
  prev_checkpoint_valid_state = checkpoint_valid;
  checkpoint_valid =
      graph_task->can_checkpoint() && prev_checkpoint_valid_state;
}

CheckpointValidGuard::~CheckpointValidGuard() {
  checkpoint_valid = prev_checkpoint_valid_state;
}

auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
  {
    // Lock mutex for writing to heap_
    std::lock_guard<std::mutex> lock(mutex_);
    if (incrementOutstandingTasks) {
      std::shared_ptr<GraphTask> graph_task = item.base_.lock();
      TORCH_INTERNAL_ASSERT(graph_task, "GraphTask is no longer valid!");
      ++graph_task->outstanding_tasks_;
    }
    heap_.push(std::move(item));
  }
  not_empty_.notify_one();
}

auto ReadyQueue::pushShutdownTask() -> void {
  {
    std::lock_guard<std::mutex> lock(mutex_);
    heap_.push(NodeTask({}, nullptr, InputBuffer(0), true));
  }
  not_empty_.notify_one();
}

size_t ReadyQueue::size() const {
  // Lock mutex for accesses to heap_
  std::unique_lock<std::mutex> lock(mutex_);
  return heap_.size();
}

auto ReadyQueue::pop() -> NodeTask {
  // Lock mutex for accesses to heap_
  std::unique_lock<std::mutex> lock(mutex_);
  not_empty_.wait(lock, [this] { return !heap_.empty(); });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  auto task = std::move(const_cast<NodeTask&>(heap_.top()));
  heap_.pop();
  return task;
}

bool ReadyQueue::empty() const {
  // Lock mutex for accesses to heap_
  std::unique_lock<std::mutex> lock(mutex_);
  return heap_.empty();
}

Engine::Engine()
    : max_recursion_depth_(MAX_DEPTH), non_reentrant_device_thread_count_(0) {}

Engine::~Engine() {
  stop();
}

// Send shutdown tasks to all device_ready_queues_ if no backward tasks are
// running. Even though readyQueue should be empty, shutdown tasks have the
// highest priority.
void Engine::stop() {
  if (stopped_) {
    return;
  }
  stopped_ = true;
  // Under some conditions, autograd threads can hang on shutdown
  // Do not wait for them to shut down indefinitely but rely on timeout
  auto wait_duration_str = getenv("TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT");
  auto wait_duration = wait_duration_str ? std::atof(wait_duration_str) : 10.0;
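  // (Note added for clarity: the value is interpreted in seconds; a value
  // <= 0 skips pushing shutdown tasks and waiting entirely, see the
  // `wait_duration > 0.0f` check below.)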
  bool noBackward = true;
  for (auto& queue : device_ready_queues_) {
    noBackward = noBackward && queue->empty();
  }
  if (noBackward && wait_duration > 0.0f) {
    for (auto& queue : device_ready_queues_) {
      queue->pushShutdownTask();
    }
    // Do not wait for termination of global threads on Windows,
    // because the CRT terminates DLL threads before calling
    // global object destructors
#if !defined(_WIN32) || defined(C10_USE_MSVC_STATIC_RUNTIME)

    using namespace std::chrono_literals;
    // Set a deadline for how long it is OK to wait for device threads to shut
    // down
    auto wait_deadline =
        std::chrono::steady_clock::now() + wait_duration * 1.0s;
    std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
    while (non_reentrant_device_thread_count_.load() != 0) {
      if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) ==
          std::cv_status::timeout) {
        break;
      }
    }
#endif
  }
  // Otherwise threads are leaked
}

void Engine::release_workers() {
  std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
  non_reentrant_device_thread_count_.store(0);
  non_reentrant_device_thread_condvar_.notify_one();
}

void Engine::increment_non_reentrant_thread_count() {
  std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
  non_reentrant_device_thread_count_.fetch_add(1);
  non_reentrant_device_thread_condvar_.notify_one();
}

void Engine::decrement_non_reentrant_thread_count() {
  std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
  non_reentrant_device_thread_count_.fetch_sub(1);
  non_reentrant_device_thread_condvar_.notify_one();
}

void Engine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }

  at::init_num_threads();

  // Note [Allocating GPUs to autograd threads]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // What's our strategy here? Originally, the autograd engine was written
  // with only CUDA in mind. We allocate one thread to handle all CPU
  // operations, and a thread per CUDA device.
  //
  // But what if we have OTHER devices? There are two plausible
  // strategies:
  //
  //  - We can allocate threads equal to max(num_cuda_devices, num_xla_devices,
  //    ...) and colocate cuda device 0 with xla device 0
  //  - We can allocate threads equal to sum(num_cuda_devices, num_xla_devices,
  //    ...) keeping everyone separate.
  //
  // We don't have any good reason to prefer one or the other, so we've
  // arbitrarily picked to colocate devices. Maybe the other approach is
  // better.

#if defined(USE_CUDA)
  if (at::detail::getCUDAHooks().hasPrimaryContext(device)) {
    set_device(device);
  }
#else
  set_device(device);
#endif

  // initialize each device thread's thread local ready queue with the ready
  // queue that is created before the thread initialization
  init_local_ready_queue(ready_queue);

  std::shared_ptr<GraphTask> graph_task = nullptr;
  thread_main(graph_task);
  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }
}

GraphTaskGuard::GraphTaskGuard(std::shared_ptr<GraphTask> graph_task) {
  last_graph_task_ = std::move(current_graph_task);
  current_graph_task = std::move(graph_task);
}
GraphTaskGuard::~GraphTaskGuard() {
  restore_current_graph_task();
}

void GraphTaskGuard::restore_current_graph_task() {
  current_graph_task = std::move(last_graph_task_);
}

// The current graph task's exec_info is being used to trim unnecessary edges
// during node evaluation, see `Node.task_should_compute_output()` function.
const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info() {
  return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
}

const std::unordered_set<Node*>* get_current_graph_task_nodes_in_graph() {
  return current_graph_task ? &current_graph_task->nodes_in_graph_ : nullptr;
}

int get_current_graph_task_id() {
  return current_graph_task ? current_graph_task->id_ : -1;
}

bool get_current_graph_task_keep_graph() {
  return current_graph_task ? current_graph_task->keep_graph_ : true;
}

void add_node_to_current_graph_task_exec_info(Node* fn) {
  current_graph_task->exec_info_[fn].needed_ = true;
}

// NB: The engine itself does not use the outputs of this function.
std::vector<Node*> get_current_graph_task_execution_order() {
  std::shared_ptr<GraphTask> task = current_graph_task;
  if (!task) {
    return {};
  }

  // We could potentially check if there is only a single device here
  // but explicitly requiring this context doesn't seem bad either
  TORCH_CHECK(
      !c10::AutogradState::get_tls_state().get_multithreading_enabled(),
      "get_current_graph_task_execution_order expects the current backward to be "
      "executed with multithreading disabled, e.g. by running:\n\n"
      ">>> with torch.autograd.set_multithreading_enabled(False):\n"
      "...     torch.autograd.grad(...)\n");

  const bool check_exec_info = !task->exec_info_.empty();
  std::vector<Node*> out{};
  std::unordered_set<Node*> seen{};

  auto compare_seq_nr = [](Node* n1, Node* n2) {
    return n1->sequence_nr() < n2->sequence_nr();
  };
  std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap(
      compare_seq_nr);

  for (Node* ptr : task->graph_roots_) {
    heap.push(ptr);
  }

  // Implementation notes:
  // - Don't need to count dependencies because we have sequence_nr
  // - Don't need to check topological_nr because we have exec_info
  while (!heap.empty()) {
    Node* fn = heap.top();
    heap.pop();

    const bool was_inserted = seen.insert(fn).second;
    if (!was_inserted) {
      continue;
    }

    out.push_back(fn);
    for (const auto& edge : fn->next_edges()) {
      Node* next_ptr = edge.function.get();
      if (!next_ptr) {
        continue;
      }
      if (check_exec_info) {
        auto it = task->exec_info_.find(next_ptr);
        if (it == task->exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      heap.push(next_ptr);
    }
  }
  return out;
}

// NOTE: graph_tasks do not necessarily form a stack. Imagine this
// case:
//
//          +----> Eval1
//    Root
//          +----> Eval2
//
// Once Root is executed, both Eval1 and Eval2 are added to the ready queue.
// Next, Eval1 is run and this causes the worker to enter thread_main again.
// Then, it pops the next task from the queue, but at this point it is Eval2.
// It enters thread_main once again, but now with graph_task of Eval2, which is
// completely unrelated to that of Eval1 (it's not a recursive call).
// It's all ok and is handled right now, but it should be accounted for
// in case this code is to be changed.
//
// thread_main is used by:
// 1). autograd threads for devices (i.e. CUDA, XLA)
// 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). reentrant backward calls invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running the thread_main on device autograd
//         threads throughout the Engine lifetime; thread_main will get
//         terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives the thread_main
//         synchronously until the graph_task of that owning thread is
//         completed, then exits the thread_main to continue executing the
//         rest of the caller's code.
// For 3), the reentrant backward that invokes
//         thread_main, either from 1) or 2), will not spin and will exit as
//         soon as graph_task is completed, notifying the owning thread as
//         needed.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  // When graph_task is nullptr, this is a long running thread that processes
  // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
  // backwards, user thread), this function is expected to exit once that
  // graph_task completes.

  // local_ready_queue should already be initialized when we get into
  // thread_main
  TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr);
  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    // local_graph_task represents the graph_task we retrieve from the queue.
    // The outer graph_task represents the overall graph_task we need to execute
    // for reentrant execution.
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_).
      NodeTask task = local_ready_queue->pop();
      // This will only work if the worker is running a non backward task
      // TODO: this needs to be fixed to work in all cases
      if (task.isShutdownTask_) {
        C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
        break;
      }

      if (!(local_graph_task = task.base_.lock())) {
        // GraphTask for function is no longer valid, skipping further
        // execution.
        continue;
      }

      if (task.fn_ && !local_graph_task->has_error_.load()) {
        // Set the ThreadLocalState before calling the function.
        // NB: The ThreadLocalStateGuard doesn't set the grad_mode because
        // GraphTask always saves ThreadLocalState without grad_mode.
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
        c10::WarningUtils::WarningHandlerGuard warnings_guard(
            &local_graph_task->warning_handler_);

        try {
          // The guard sets the thread_local current_graph_task on construction
          // and restores it on exit. The current_graph_task variable helps
          // queue_callback() to find the target GraphTask to append final
          // callbacks.
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            RECORD_FUNCTION(
                c10::str(
                    "autograd::engine::evaluate_function: ",
                    task.fn_.get()->name()),
                c10::ArrayRef<const c10::IValue>());
            evaluate_function(
                local_graph_task,
                task.fn_.get(),
                task.inputs_,
                local_graph_task->cpu_ready_queue_);
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }
    }

    // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();

      auto base_owner = local_graph_task->owner_;
      // The current worker thread finishes the graph_task, but the owning
      // thread of the graph_task might be sleeping on pop() if it does not
      // have work. So we need to send a dummy function task to the owning
      // thread just to ensure that it's not sleeping, so that we can exit the
      // thread_main. If it has work, it might see that
      // graph_task->outstanding_tasks_ == 0 before it gets to the task, but
      // it's a no-op anyway.
      //
      // NB: This is not necessary if the current thread is the owning thread.
      if (worker_device != base_owner) {
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }
  }
}

// Reentrant calls re-use the graph_task's owner thread ready_queue for
// queueing tasks (NOTE: this is not true in the async_mode of the engine).
// While we could create a separate ready queue for each new reentrant
// thread, sharing the same cpu_ready_queue with the parent thread is a
// performance improvement, and cuda threads still have to do the same thing.
void Engine::reentrant_thread_init() {
  at::init_num_threads();
  auto tp_shared = thread_pool_shared_;
  while (true) {
    std::unique_lock<std::mutex> lk(tp_shared->mutex_);
    ++thread_pool_shared_->num_workers_;
    tp_shared->work_.wait(
        lk, [&tp_shared] { return !tp_shared->graphtasks_queue_.empty(); });
    --thread_pool_shared_->num_workers_;
    auto task = tp_shared->graphtasks_queue_.front();
    tp_shared->graphtasks_queue_.pop();
    lk.unlock();
    std::shared_ptr<GraphTask> graph_task;
    if (!(graph_task = task.lock())) {
      LOG(INFO) << "GraphTask has expired, skipping reentrant execution";
      continue;
    }
    set_device(graph_task->owner_);
    // set the local_ready_queue to the ready queue on the graph_task->owner_
    // device
    local_ready_queue =
        ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
    total_depth = graph_task->reentrant_depth_;
    thread_main(graph_task);
  }
}

void Engine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  graph_task->set_exception(std::current_exception(), fn);
}

bool GraphTask::completed() {
  return outstanding_tasks_.load() == 0 ||
      (exit_on_error_ && has_error_.load());
}

void GraphTask::mark_as_completed_and_run_post_processing() {
  // Allow only one thread one attempt to process this logic.
  if (future_completed_.exchange(true)) {
    // Future is already marked complete, or being marked as such.
    // In case the marking complete is only in progress, we add a
    // wait() to guarantee the future is marked complete on exit.
    future_result_->wait();
    return;
  }

  try {
    // Run post processing, before marking the future as complete.
    // Drop lock prior to completing, to avoid holding across callbacks.
    std::unique_lock<std::mutex> lock(mutex_);

    exec_post_processing();
    std::vector<Variable> vars = std::move(captured_vars_);

    // Need to unlock before we call markCompleted to avoid holding locks
    // when the callbacks are called.
    lock.unlock();
    future_result_->markCompleted(std::move(vars));
  } catch (std::exception& e) {
    future_result_->setErrorIfNeeded(std::current_exception());
  }
}

void GraphTask::exec_post_processing() {
  if (!not_ready_.empty()) {
    throw std::runtime_error("could not compute gradients for some functions");
  }

  // set the thread_local current_graph_task_ as more callbacks can be installed
  // by existing final callbacks.
  GraphTaskGuard guard(shared_from_this());
  // Lock mutex during each iteration for accessing final_callbacks.size()
  // Unlocking is necessary, because the callback can register
  // more callbacks (or they can be registered from other threads
  // while it's waiting).
  std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);

  // caller_current_streams_ with nullopt entries removed
  std::vector<c10::Stream> caller_current_streams_filtered;

  // See Note [Streaming backwards].
  // Syncs caller_current_stream with leaf streams, so final_callbacks may use
  // any grad on its device's current stream.
  if (!leaf_streams.empty()) {
    for (const auto& leaf_stream : leaf_streams) {
      // stash_current_streams() stashed streams for all device IDs that already
      // had a CUDA context before the GraphTask executed. For inactive devices,
      // it stashed a c10::nullopt. I don't expect GraphTask's backward pass ran
      // leaf nodes on any new devices, so the stashed streams should be enough.
      // If leaf_stream.device_index() happens to be for a new device,
      // operator* on the c10::nullopt should throw an error.
      const auto caller_current_stream =
          *caller_current_streams_[leaf_stream.device_index()];

      if (caller_current_stream != leaf_stream) {
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(leaf_stream);
        caller_current_stream.wait(event);
      }
    }

    caller_current_streams_filtered.reserve(caller_current_streams_.size());
    for (const auto& opt_stream : caller_current_streams_) {
      if (opt_stream.has_value()) {
        caller_current_streams_filtered.push_back(*opt_stream);
      }
    }
  }

  {
    // final_callbacks run on the per-device caller_current_streams (the ambient
    // streams surrounding the user's call to backward()). This has two
    // benefits:
    //  1. caller_current_streams have been synced with leaf_streams, so
    //     callbacks may safely access any grad.
    //  2. The callback's results can safely be used on (user-facing)
    //     caller_current_streams after backward().
    c10::MultiStreamGuard g(caller_current_streams_filtered);

    // Set the ThreadLocalState before calling the function.
    // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
    // always saves ThreadLocalState without grad_mode.
    at::ThreadLocalStateGuard tls_guard(this->thread_locals_);

    // WARNING: Don't use a range-for loop here because more callbacks may be
    // added in between callback calls, so iterators may become invalidated.
    // NOLINTNEXTLINE(modernize-loop-convert)
    for (size_t i = 0; i < final_callbacks_.size(); ++i) {
      cb_lock.unlock();
      final_callbacks_[i]();
      cb_lock.lock();
    }
  }
}

void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
  if (!has_error_.exchange(true)) {
    if (AnomalyMode::is_enabled() && fn) {
      fn->metadata()->print_stack(fn->name());
    }
  }
}

void GraphTask::set_exception(
    std::exception_ptr eptr,
    const std::shared_ptr<Node>& fn) {
  set_exception_without_signal(fn);
  if (!future_completed_.exchange(true)) {
    future_result_->setError(std::move(eptr));
  }
}

static variable_list call_pre_hooks(Node& fn, variable_list inputs) {
  for (const auto& hook : fn.pre_hooks()) {
    inputs = (*hook)(inputs);
  }
  return inputs;
}

static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) {
  for (const auto& hook : fn.tensor_pre_hooks()) {
    inputs = (*hook)(inputs);
  }
  for (const auto& pair : fn.retains_grad_hooks()) {
    inputs = (*pair.second)(inputs);
  }
  return inputs;
}

static variable_list call_post_hooks(
    Node& fn,
    variable_list outputs,
    const variable_list& inputs) {
  for (const auto& hook : fn.post_hooks()) {
    outputs = (*hook)(outputs, inputs);
  }
  return outputs;
}

void set_device(int device) {
  // NB: We MUST NOT construct the guard for device CPU,
  // as in some settings we compile with cuda, but
  // have lazy stubs for CUDA functionality (so actually
  // attempting to setup a guard(CPU_DEVICE) will cause an
  // error, because it will still query cudaGetDevice).
  //
  // Don't use DeviceGuard here because its destructor may be called before the
  // device is reset. This is fine because the device is thread local.
  if (device != CPU_DEVICE) {
    for (const auto i : c10::irange(static_cast<size_t>(
             c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) {
      auto* impl = c10::impl::device_guard_impl_registry[i].load();
      if (impl && device < impl->deviceCount()) {
        impl->setDevice(at::Device(static_cast<c10::DeviceType>(i), device));
      }
    }
  }
  worker_device = device;
}

void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error) {
  if (grads.size() != edges.size()) {
    std::stringstream ss;
    ss << "invalid number of gradients - expected ";
    ss << edges.size() << ", but got " << grads.size();
    AT_ERROR(format_error(ss.str()));
  }
  for (const auto i : c10::irange(grads.size())) {
    const auto& edge = edges[i];
    if (!edge.is_valid())
      continue;

    const auto& metadata = edge.function->input_metadata(edge.input_nr);
    auto& grad = grads[i];
    if (!grad.defined()) {
      // FIXME: TestJit.test_ge_optimized fails this assertion.
      // std::stringstream ss;
      // ss << "undefined gradient at index " << i;
      // AT_ERROR(format_error(ss.str()));
      continue;
    }

    if (!metadata.is_same_shape(grad)) {
      if (metadata.is_expandable_to_shape(grad)) {
        grad = metadata.reduce_grad(grad);
      } else {
        const auto message = metadata.incompatible_shape_error_message(i, grad);
        AT_ERROR(format_error(message.str()));
      }
    }

    bool input_is_complex =
        isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));
    bool grad_is_complex = isComplexType(grad.scalar_type());

    TORCH_CHECK(
        isFloatingType(grad.scalar_type()) ||
        (input_is_complex == grad_is_complex));
    if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
        grad.scalar_type()) {
      grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
    }
    if (grad.dtype() != metadata.dtype()) {
      std::stringstream ss;
      ss << "invalid gradient at index " << i << " - expected dtype ";
      ss << metadata.dtype() << " but got " << grad.dtype();
      AT_ERROR(format_error(ss.str()));
    }
    if (grad.layout() != metadata.layout()) {
      // TODO: Currently we only support the (*, Sparse) combination for
      // (tensor.layout(), tensor.grad.layout()). In the future, there will be
      // an opportunity to support more combinations of layouts if they are
      // composable (e.g., operations like addition are well defined between
      // tensors of different layouts), as well as all parts of autograd like
      // AccumulateGrad correctly handling this. We allow grad to be Strided
      // when metadata is SparseCsr (or another sparse compressed layout) or
      // Sparse.
      if (!grad.is_sparse() &&
          !(grad.layout() == at::kStrided &&
            (at::sparse_csr::is_sparse_compressed(metadata.layout()) ||
             metadata.layout() == at::kSparse))) {
        std::stringstream ss;
        ss << "invalid gradient at index " << i << " - expected layout ";
        ss << metadata.layout() << " but got " << grad.layout();
        AT_ERROR(format_error(ss.str()));
      }
    }

    if (grad.device() != metadata.device()) {
      // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
      // should be eventually removed
      if (!(metadata.is_tensor_subclass() ||
            grad.unsafeGetTensorImpl()->is_python_dispatch())) {
        if (grad.dim() == 0) {
          grad = grad.to(metadata.device());
        } else {
          std::stringstream ss;
          ss << "invalid gradient at index " << i << " - expected device ";
          ss << metadata.device() << " but got " << grad.device();
          AT_ERROR(format_error(ss.str()));
        }
      }
    }
    // We should not build graph for Tensors that are not differentiable
    TORCH_INTERNAL_ASSERT(isDifferentiableType(grad.scalar_type()));
  }
}

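// Overview comment (added for clarity): call_function below invokes a node
// together with its hooks in the order the code implements: tensor pre-hooks
// and retains_grad hooks first (via call_tensor_pre_hooks), then regular
// pre-hooks, then the node itself, and finally post-hooks. See NOTE [Hooks
// ordering], referenced from evaluate_function, for the cross-file context.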
static variable_list call_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputBuffer) {
  CheckpointValidGuard cpvguard(graph_task);
  auto& fn = *func;
  auto inputs =
      call_tensor_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
  inputs = call_pre_hooks(fn, std::move(inputs));
  if (!graph_task->keep_graph_) {
    fn.will_release_variables();
  }

  const auto has_post_hooks = !fn.post_hooks().empty();
  variable_list outputs;

  if (has_post_hooks) {
    // In functions/accumulate_grad.cpp, there is some logic to check the
    // conditions under which the incoming gradient can be stolen directly
    // (which elides a deep copy) instead of cloned. One of these conditions
    // is that the incoming gradient's refcount must be 1 (nothing else is
    // referencing the same data). Stashing inputs_copy here bumps the
    // refcount, so if post hooks are employed, it's actually still ok for
    // accumulate_grad.cpp to steal the gradient if the refcount is 2.
    //
    // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
    // accumulate_grad.cpp accounts for this, but also creates a silent
    // dependency between engine.cpp (ie, this particular engine
    // implementation) and accumulate_grad.cpp.
    //
    // If you change the logic here, make sure it's compatible with
    // accumulate_grad.cpp.
    auto inputs_copy = inputs;
    outputs = fn(std::move(inputs_copy));
  } else {
    outputs = fn(std::move(inputs));
  }

  validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
    std::ostringstream ss;
    ss << "Function " << fn.name() << " returned an " << msg;
    return ss.str();
  });

  if (has_post_hooks) {
    // NOLINTNEXTLINE(bugprone-use-after-move)
    return call_post_hooks(fn, std::move(outputs), inputs);
  }
  return outputs;
}

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // The InputBuffer::adds that supplied incoming grads took pains to
  // ensure they're safe to consume in the context of the present
  // func's stream (if applicable). So we guard onto that stream
  // before working with the grads in any capacity.
  const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
  c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};

  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func);
    variable_list new_inputs = inputs.buffer;
    if (!fn_info.needed_) {
      // We always want to call tensor pre-hooks, but want to avoid calling
      // them twice. needed_ = True indicates that we will call tensor
      // pre-hooks later.
      //
      // See NOTE [Hooks ordering] for more context.
      new_inputs = call_tensor_pre_hooks(
          *func, InputBuffer::variables(std::move(inputs)));
    }
    if (auto* capture_vec = fn_info.captures_.get()) {
      const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        captured_grad = new_inputs[capture.input_idx_];
        // NOTE [Deprecated capture hooks]
        for (const auto& hook :
             capture.DO_NOT_USE_DEPRECATED_get_capture_hooks()) {
          captured_grad = (*hook)(captured_grad);
        }
        if (opt_parent_stream) {
          // No need to take graph_task->mutex_ here, we already hold it
          graph_task->leaf_streams.emplace(*opt_parent_stream);
        }
      }
    }
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

  auto outputs = call_function(graph_task, func, inputs);

  auto& fn = *func;
  if (!graph_task->keep_graph_) {
    fn.release_variables();
  }

  int num_outputs = outputs.size();
  if (num_outputs == 0) { // Note: doesn't acquire the mutex
    // Records leaf stream (if applicable)
    // See Note [Streaming backwards]
    if (opt_parent_stream) {
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      graph_task->leaf_streams.emplace(*opt_parent_stream);
    }
    return;
  }

  if (AnomalyMode::is_enabled() && AnomalyMode::should_check_nan()) {
    AutoGradMode grad_mode(false);
    for (const auto i : c10::irange(num_outputs)) {
      auto& output = outputs[i];
      at::OptionalDeviceGuard guard(device_of(output));
      if (output.defined() && isnan(output)._is_any_true().item<bool>()) {
        std::stringstream ss;
        ss << "Function '" << fn.name() << "' returned nan values in its " << i
           << "th output.";
        throw std::runtime_error(ss.str());
      }
    }
  }

  // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and
  // cpu_ready_queue_ below
  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (const auto i : c10::irange(num_outputs)) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i);

    if (!next.is_valid())
      continue;

    // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get());

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);
      is_ready = true;
    }

    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());
    if (not_ready_it == not_ready.end()) {
      // Skip functions that aren't supposed to be executed
      if (!exec_info_.empty()) {
        auto it = exec_info_.find(next.function.get());
        if (it == exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      // No buffers have been allocated for the function
      InputBuffer input_buffer(next.function->num_inputs());

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(
          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }
    } else {
      // The function already has a buffer
      auto& input_buffer = not_ready_it->second;

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(
          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
        not_ready.erase(not_ready_it);
      }
    }
  }
}

inline static uint64_t compute_min_topological_nr(const edge_list& outputs) {
  // Computes the minimum topological number among all the outputs
  if (outputs.empty()) {
    return 0;
  }
  auto min_topo_nr = std::numeric_limits<uint64_t>::max();
  for (auto& output_edge : outputs) {
    auto topo_nr = output_edge.function.get()->topological_nr();
    min_topo_nr = (min_topo_nr < topo_nr) ? min_topo_nr : topo_nr;
  }
  return min_topo_nr;
}

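// Note (added for clarity): compute_dependencies uses min_topo_nr to prune
// traversal. Since a node's topological_nr() is the length of the longest path
// from it to any leaf, any node that can reach an output node via next_edges
// must have a strictly larger topological_nr than that output. Hence nodes
// with topological_nr() < min_topo_nr cannot reach any requested output and
// can be skipped safely.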
auto Engine::compute_dependencies(
    Node* root,
    GraphTask& task,
    uint64_t min_topo_nr) -> void {
  // Computes the number of dependencies for each function which requires grad
  std::vector<Node*> queue{root};
  bool might_use_cuda = at::globalContext().hasCUDA();
  bool will_use_cuda = false;

  // Queue contains all nodes that will start propagating gradients.
  // We no longer have to expand functions that don't require grad.
  auto& dependencies = task.dependencies_;
  while (!queue.empty()) {
    auto fn = queue.back();
    queue.pop_back();
    if (fn->topological_nr() < min_topo_nr) {
      continue;
    }
    if (might_use_cuda && !will_use_cuda) {
      will_use_cuda = fn->stream(c10::DeviceType::CUDA).has_value();
    }
    for (const auto& edge : fn->next_edges()) {
      if (auto next_ptr = edge.function.get()) {
        dependencies[next_ptr] += 1;
        const bool was_inserted = task.nodes_in_graph_.insert(next_ptr).second;
        if (was_inserted)
          queue.push_back(next_ptr);
      }
    }
  }

  if (will_use_cuda) {
    // Collects current streams for devices where this process has a context,
    // so GraphTask::exec_post_processing can sync them with leaf_streams.
    task.stash_current_streams();
  }
}

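// Orientation note (added; see torch/csrc/autograd/autograd.cpp for the exact
// wiring): Engine::execute is the entry point reached by the frontend APIs.
// torch.autograd.backward() / Tensor.backward() call it with
// accumulate_grad=true, while torch.autograd.grad() calls it with
// accumulate_grad=false and `outputs` set to the edges of the tensors whose
// gradients were requested.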
auto Engine::execute(
    const edge_list& root_edges,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) -> variable_list {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  validate_outputs(
      root_edges,
      const_cast<variable_list&>(inputs),
      [](const std::string& msg) { return msg; });
  if (accumulate_grad && create_graph) {
    TORCH_WARN_ONCE(
        "Using backward() with create_graph=True will create a reference cycle "
        "between the parameter and its gradient which can cause a memory leak. "
        "We recommend using autograd.grad when creating the graph to avoid this. "
        "If you have to use this function, make sure to reset the .grad fields of "
        "your parameters to None after use to break the cycle and avoid the leak.");
  }

  // accumulate_grad is true if and only if the frontend call was to
  // backward(), not grad(). grad() returns the sum of the gradients
  // w.r.t. the inputs and thus needs the inputs to be present.
  TORCH_CHECK_VALUE(
      accumulate_grad || !outputs.empty(), "grad requires non-empty inputs.");

  // A fresh first time Engine::execute call should start on the CPU device,
  // initialize a new thread local ready queue on CPU or reuse the existing one
  // (if there is one allocated already, i.e. consecutive backward calls,
  // re-entrant backward calls), then memoize the local_ready_queue in GraphTask
  init_local_ready_queue();
  bool not_reentrant_backward_call = worker_device == NO_DEVICE;

  // Store root nodes so we can traverse through the graph later
  // e.g., for get_current_graph_task_execution_order
  c10::SmallVector<Node*, 4> temp_roots{root_edges.size()};
  for (const auto i : c10::irange(root_edges.size())) {
    temp_roots[i] = root_edges[i].function.get();
  }

  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue,
      /* graph_roots */ std::move(temp_roots));

  // If we receive a single root, skip creating extra root node
  bool skip_dummy_node = root_edges.size() == 1;
  auto graph_root = skip_dummy_node
      ? root_edges.at(0).function
      : std::make_shared<GraphRoot>(root_edges, inputs);

  auto min_topo_nr = compute_min_topological_nr(outputs);
  // Now compute the dependencies for all executable functions
  compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);

  if (!outputs.empty()) {
    graph_task->init_to_execute(
        *graph_root, outputs, accumulate_grad, min_topo_nr);
  }

  // Queue the root
  if (skip_dummy_node) {
    InputBuffer input_buffer(root_edges.at(0).function->num_inputs());
    auto input = inputs.at(0);

    const auto input_stream = InputMetadata(input).stream();
    const auto opt_next_stream =
        root_edges.at(0).function->stream(c10::DeviceType::CUDA);
    input_buffer.add(
        root_edges.at(0).input_nr,
        std::move(input),
        input_stream,
        opt_next_stream);

    execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } else {
    execute_with_graph_task(
        graph_task, std::move(graph_root), InputBuffer(variable_list()));
  }
  // Avoid a refcount bump for the Future, since we check for refcount in
  // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
  // in dist_engine.cpp).
  auto& fut = graph_task->future_result_;
  fut->wait();
  graph_task->warning_handler_.replay_warnings();
  return fut->value().toTensorVector();
}

void Engine::initialize_device_threads_pool() {
  TORCH_CHECK(
      !in_bad_autograd_fork,
      "Unable to handle autograd's threading in combination with fork-based multiprocessing. "
      "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork");
  c10::call_once(
      start_device_threads_flag_, &Engine::start_device_threads, this);
}

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  initialize_device_threads_pool();
  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  // If worker_device == NO_DEVICE, it's a CPU thread that is trying to drive
  // the autograd engine with the corresponding GraphTask, and it's NOT a
  // re-entrant call
  if (worker_device == NO_DEVICE) {
    // We set the worker_device to CPU_DEVICE only if worker_device was
    // previously NO_DEVICE. Setting it to CPU afterwards allows us to detect
    // whether this is a re-entrant call or not.
    set_device(CPU_DEVICE);

    // set the graph_task owner to the current device
    graph_task->owner_ = worker_device;

    // Now that all the non-thread safe fields of the graph_task have been
    // populated, we can enqueue it.
    queue->push(
        NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));

    // The owning thread starts to drive the engine execution for any CPU task
    // that was just pushed or will be added later from other worker threads
    lock.unlock();
    thread_main(graph_task);
    TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
    // reset the worker_device after the completion of the graph_task, this is
    // so that the initial state of the engine remains the same across every
    // backward() or grad() call, we don't need to reset local_ready_queue as we
    // could possibly reuse it for new backward calls.
    worker_device = NO_DEVICE;
  } else {
    // If worker_device is any device (i.e. CPU, CUDA): this is a re-entrant
    // backward call from that device.
    graph_task->owner_ = worker_device;

    // Now that all the non-thread safe fields of the graph_task have been
    // populated, we can enqueue it.
    queue->push(
        NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));

    if (current_depth >= max_recursion_depth_) {
      // See Note [Reentrant backwards]
      // If reached the max depth, switch to a different thread
      add_thread_pool_task(graph_task);
    } else {
      // Total depth needs to be updated only in this codepath, since it is
      // not used in the block above (when we call add_thread_pool_task).
      // In the codepath above, GraphTask.reentrant_depth_ is used to
      // bootstrap total_depth in the other thread.
      ++total_depth;

      // Get back to work while we wait for our new graph_task to
      // complete!
      ++current_depth;
      lock.unlock();
      thread_main(graph_task);
      --current_depth;
      --total_depth;

      // The graph task should have completed and the associated future should
      // be marked completed as well since 'thread_main' above is a call
      // blocking an autograd engine thread.
      TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
    }
  }
  // graph_task_exec_post_processing is done when the Future is marked as
  // completed in mark_as_completed_and_run_post_processing.
  return graph_task->future_result_;
}

// note that when python is present, this base engine will be overridden
// with a PythonEngine. Because this typically happens before get_default_engine
// is called, this base engine will never be created.
Engine& Engine::get_base_engine() {
  static Engine engine;
  return engine;
}

std::atomic<EngineStub> engine_stub(Engine::get_base_engine);

void set_default_engine_stub(EngineStub stub) {
  engine_stub.store(stub);
}

Engine& Engine::get_default_engine() {
  return engine_stub.load()();
}

void Engine::queue_callback(std::function<void()> callback) {
  TORCH_CHECK(
      current_graph_task,
      "Final callbacks can only be installed during backward pass.");

  std::lock_guard<std::mutex> lock(current_graph_task->final_callbacks_lock_);
  current_graph_task->final_callbacks_.emplace_back(std::move(callback));
}

bool Engine::is_checkpoint_valid() {
  return checkpoint_valid;
}

void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
  if (ready_queue) {
    // if ready_queue provided in the caller, use the caller's ready_queue to
    // initialize local_ready_queue
    local_ready_queue = std::move(ready_queue);
  } else if (!local_ready_queue) {
    // otherwise if local_ready_queue not allocated, allocate a new ready_queue
    local_ready_queue = std::make_shared<ReadyQueue>();
  }
}

// CPU ready queue is per GraphTask, but CUDA device ready queues are shared
// across all graph tasks
auto Engine::ready_queue(
    std::shared_ptr<ReadyQueue> cpu_ready_queue,
    at::Device device) -> std::shared_ptr<ReadyQueue> {
  bool multithreading_disabled =
      !c10::AutogradState::get_tls_state().get_multithreading_enabled();
  if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
    // return the cpu ready queue passed in
    TORCH_INTERNAL_ASSERT(cpu_ready_queue);
    return cpu_ready_queue;
  } else {
    TORCH_INTERNAL_ASSERT(
        0 <= device.index() &&
        device.index() <
            static_cast<c10::DeviceIndex>(device_ready_queues_.size()));
    // See Note [Allocating GPUs to autograd threads]
    return device_ready_queues_.at(device.index());
  }
}

auto Engine::ready_queue_by_index(
    std::shared_ptr<ReadyQueue> cpu_ready_queue,
    int device_index) -> std::shared_ptr<ReadyQueue> {
  if (device_index == CPU_DEVICE) {
    // return the cpu ready queue passed in
    TORCH_INTERNAL_ASSERT(cpu_ready_queue);
    return cpu_ready_queue;
  } else {
    TORCH_INTERNAL_ASSERT(
        0 <= device_index &&
        device_index <
            static_cast<c10::DeviceIndex>(device_ready_queues_.size()));
    // See Note [Allocating GPUs to autograd threads]
    // NB: This function would become obsolete if we truly allocated a CPU
    // thread per device, rather than colocate.
    return device_ready_queues_.at(device_index);
  }
}

auto Engine::start_device_threads() -> void {
  // First always initialize the thread pool for re-entrant threads
  thread_pool_shared_ = std::make_shared<ThreadPoolShared>();

  // Second, create special threads for each non-CPU device
  // See Note [Allocating GPUs to autograd threads]
  c10::DeviceIndex num_devices = 0;
  for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
    auto* impl = impl_atomic.load();
    // Only record the number of devices for device types that don't run on
    // the cpu ready queue.
    if (impl && !should_run_in_cpu_ready_queue(impl->type())) {
      num_devices = std::max(num_devices, impl->deviceCount());
    }
  }

  // If there are no devices other than the CPU, there is no need to create
  // worker threads
  if (num_devices == 0) {
    return;
  }

  // Since we're about to create threads, forking is not possible anymore
  track_bad_autograd_forks();

  // allocate one thread for every GPU device (but colocate GPUs of different
  // types), and pre-allocate the device_ready_queues_ to ensure safe reading on
  // it.
  device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices);
  for (auto& queue : device_ready_queues_) {
    queue = std::make_shared<ReadyQueue>();
  }

  for (const auto i : c10::irange(num_devices)) {
    std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true);
    t.detach();
  }
  // Wait for the threads to start
  {
    std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
    while (non_reentrant_device_thread_count_.load() !=
           static_cast<uint32_t>(num_devices)) {
      non_reentrant_device_thread_condvar_.wait(lk);
    }
  }
}

void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
  std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_);
  // There may already be some items on the graphtasks_queue_ added by other
  // threads but not enough workers to get to the new task that will be
  // added
  bool create_thread =
      (thread_pool_shared_->num_workers_ <=
       thread_pool_shared_->graphtasks_queue_.size());
  thread_pool_shared_->graphtasks_queue_.push(graph_task);
  // Don't need to be holding the lock while actually creating the thread
  lck.unlock();
  if (create_thread) {
    // If we're creating a new thread, forking is not allowed anymore
    track_bad_autograd_forks();
    std::thread t(&Engine::reentrant_thread_init, this);
    t.detach();
  }
  // This works even if new thread is created because wait() will test the
  // predicate before waiting
  thread_pool_shared_->work_.notify_one();
}

// Remembers current streams on all devices where a context has been created.
// Only called if Engine::execute detects at least one node runs on a cuda
// stream.
void GraphTask::stash_current_streams() {
  const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
  auto num_gpus = guard.deviceCount();
  caller_current_streams_.resize(num_gpus);
  if (num_gpus > 0) {
    for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) {
#if defined(USE_ROCM) && (ROCM_VERSION < 50000)
      // If the build targets ROCM, stash streams for all visible devices
      // unconditionally, to work around
      // https://github.com/pytorch/pytorch/issues/59750.
      // TODO: Remove ROCM-specific behavior when
      // https://github.com/pytorch/pytorch/issues/59750 is fixed.
      if (true) {
#else
      if (at::detail::getCUDAHooks().hasPrimaryContext(idx)) {
#endif
        caller_current_streams_[idx] =
            guard.getStream({c10::DeviceType::CUDA, idx});
      } else {
        caller_current_streams_[idx] = c10::nullopt;
      }
    }
  }
}

void GraphTask::init_to_execute(
    Node& graph_root,
    const edge_list& outputs,
    bool accumulate_grad,
    uint64_t min_topo_nr) {
  // Populates exec_info so nodes that should be executed have
  // `exec_info[node].needed_ = true`. Only nodes that have a path to any edge
  // in `outputs` should be executed. The code below populates exec_info using
  // recursion, but the actual code does this iteratively. Refer to the
  // numbering to see how the actual code corresponds. A difference to note is
  // that in the iterative version, when you are working with the current Node,
  // you are responsible for updating your parent's is_needed after all your
  // children have been updated.
  //
  // is_needed = {fn: True for fn in outputs}             # (0)
  // seen = {}
  // def compute_is_needed(fn):
  //   for next_edge in fn.next_edges:
  //     child_fn = next_edge.fn
  //     if child_fn in seen and is_needed[child_fn]:     # (1)
  //       is_needed[fn] = true
  //     else:
  //       seen.add(child_fn)
  //       if compute_is_needed(child_fn):
  //         is_needed[fn] = true                         # (2)
  //       # (3) exit for-loop
  //   return is_needed[fn]
  // compute_is_needed(graph_root)
  //
  // NB: you might be wondering why we don't populate `seen` with outputs. We
  // cannot because in the case where two outputs lie on the same path, we still
  // need to explore past the first output or we would miss the nodes that are
  // required to compute the second output.
  int output_idx = 0;
  for (auto& output_edge : outputs) {
    // (0) `is_needed` above corresponds to `exec_info_[fn].needed_`
    Node* output = output_edge.function.get();
    auto& info = exec_info_[output];
    if (accumulate_grad) {
      // if called through `.backward()` we directly set `needed_` for all the
      // outputs to true
      info.needed_ = true;
    } else {
      // otherwise it is `.grad()` and we set exec_info[fn].captures_ instead
      // In terms of populating the rest of exec_info though, you can basically
      // think of this as the same as setting `needed_` to true directly.
      if (!info.captures_) {
        info.captures_ = make_unique<std::vector<ExecInfo::Capture>>();
      }
      info.captures_->emplace_back(output_edge.input_nr, output_idx++);
    }
  }
  captured_vars_.resize(output_idx);

  struct Frame {
    Frame(Node* fn) : fn_(fn) {}
    Node* fn_{};
    size_t next_next_fn_{};

    Node* get_next_fn() {
      const auto& next = fn_->next_edges();
      auto num_next = next.size();
      while (next_next_fn_ < num_next) {
        auto fn = next[next_next_fn_++].function.get();
        if (fn)
          return fn;
      }
      return nullptr;
    }
  };

  auto nodeShouldExecute = [this](Node* fn) {
    auto it = exec_info_.find(fn);
    return it != exec_info_.end() && it->second.should_execute();
  };

  std::vector<Frame> stack;
  std::unordered_set<Node*> seen;
  stack.emplace_back(&graph_root);
  exec_info_.emplace(stack.back().fn_, ExecInfo());

  while (!stack.empty()) {
    auto& frame = stack.back();
    const auto fn = frame.fn_;

    Node* child_fn = nullptr;
    while ((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) {
      // (1) next child exists AND has already been seen
      if (nodeShouldExecute(child_fn)) {
        exec_info_[fn].needed_ = true;
      }
    }

    if (child_fn) {
      // (2) next child exists but has not been seen
      if (child_fn->topological_nr() < min_topo_nr) {
        // child created before the first output means this child cannot have
        // an edge to output
        continue;
      }
      stack.emplace_back(child_fn);
    } else {
      // (3) no next child exists for `fn` means its `needed` has already been
      // finalized. pop stack and update parent
      stack.pop_back();
      if (nodeShouldExecute(fn) && !stack.empty()) {
        exec_info_[stack.back().fn_].needed_ = true;
      }
    }
  }
}

} // namespace autograd
} // namespace torch