#pragma once

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/SequenceNumber.h>
#include <ATen/core/Tensor.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif

namespace torch {
namespace autograd {

struct Edge;
struct FunctionPostHook;
struct FunctionPreHook;

using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using IndexRange = std::pair<size_t, size_t>;

// Custom deleter to prevent stack overflows.
TORCH_API void deleteNode(Node* function);

// Guard that sets and restores the evaluating node
class NodeGuard {
 public:
  explicit NodeGuard(std::shared_ptr<Node> node);
  ~NodeGuard();

 private:
  std::shared_ptr<Node> last_evaluating_node_;
};

// Return the Node currently being evaluated (if any)
// This is only set during the backward pass while a Node is being
// executed.
TORCH_API std::shared_ptr<Node> get_current_node();
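
// A minimal sketch of how this can be used (an assumption, not part of this
// header: the code runs inside a hook or custom Node during the backward
// pass, where a current node is set):
//
//   if (std::shared_ptr<Node> node = get_current_node()) {
//     TORCH_WARN("currently evaluating: ", node->name());
//   }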

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Node
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A `Node` is an abstract class that represents an operation taking zero
// or more input `Variable`s and producing zero or more output `Variable`s. All
// functions in PyTorch's autograd machinery derive from this class and
// override its `apply` method. Instances of such subclasses will then be
// invocable via the call operator.
//
// Nodes in the Autograd Graph
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When viewing the autograd system as a graph, `Node`s are the vertices or
// nodes, connected to each other via (directed) `Edge`s, which themselves are
// represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to
// and inputs of `Node`s, and travel along these edges during execution
// of the graph. When two or more `Edge`s (from different sources) point at the
// same input to a `Node`, the values produced along all of these edges are
// implicitly summed prior to being forwarded to the target `Node`.
//
// Hierarchy
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Subclasses usually represent differentiable functions as well as their
// gradient operators. Note, however, that due to the very general definition
// of a `Node` taking *zero* or more inputs and producing *zero* or more
// outputs, uses of `Node`s are flexible and extend beyond purely
// mathematical operations. For example, the `AccumulateGrad` function is a
// *sink*: it takes one input, but produces no outputs, instead accumulating
// the input as a side effect. At the other extreme, the `GraphRoot` function
// receives no inputs from other functions, but produces multiple outputs.
//
// Interface
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The most important method on `Node` is the call operator, which takes in
// a list of variables and produces a list of variables. The precise size of
// these lists can be determined with `num_inputs()` and `num_outputs()`.
// `Node`s are stitched together via their `next_edge` interface, which lets
// you manipulate the set of outgoing edges of a `Node`. You can add an
// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and
// iterate over them via the `next_edges()` method. Other methods exist for
// integration with the JIT and other parts of PyTorch. Every `Node` has a
// *sequence number* that increases monotonically in the order of `Node`
// construction. It can be retrieved via the `sequence_nr()` method. Note that
// this sequence number is *thread local*. This means that when `Node`s
// `A`, `B` and `C` are created consecutively in the same thread, their
// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B`
// are created in one thread and `C` is created in a new thread, there are *no
// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
// See NOTE [ Sequence Number ] for more details on the usages of sequence
// numbers.
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
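// Example
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A hedged sketch of the connectivity API described above (assumes `fn` and
// `producer` are shared_ptrs to concrete Node subclasses, which are not
// defined in this header):
//
//   fn->add_next_edge(Edge(producer, /*input_nr=*/0));
//   for (const Edge& edge : fn->next_edges()) {
//     if (edge.is_valid()) {
//       // edge.function is the next Node; edge.input_nr is its input slot
//     }
//   }
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~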
struct TORCH_API Node : std::enable_shared_from_this<Node> {
 public:
  /// Construct a new `Node` with the given `next_edges`
  explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list())
      : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) {
    for (const Edge& edge : next_edges_) {
      update_topological_nr(edge);
    }

    if (AnomalyMode::is_enabled()) {
      metadata()->store_stack();

      // If anomaly mode is enabled and the graph is being constructed, assign
      // the currently evaluating node as the parent of this node.
      // A parent is the Node during whose evaluation this Node was created.
      // We track parents so that multiple backward operations can be traced.
      assign_parent();
    }

    // Store the thread_id of the forward operator.
    // See NOTE [ Sequence Number ]
    thread_id_ = at::RecordFunction::currentThreadId();
  }

  explicit Node(edge_list&& next_edges = edge_list())
      : Node(
            /*sequence_nr=*/at::sequence_number::get_and_increment(),
            std::move(next_edges)) {}

  /// Nodes are neither copyable nor moveable.
  Node(const Node& other) = delete;
  Node(Node&& other) = delete;
  Node& operator=(const Node& other) = delete;
  Node& operator=(Node&& other) = delete;
  virtual ~Node() = default;

  std::shared_ptr<Node> getptr() {
    return shared_from_this();
  }
  /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {
    // In the first iteration of named tensors, autograd ignores names and
    // operates on unnamed tensors. In the long term, autograd should
    // probably operate with names.
    at::NoNamesGuard no_names_guard;

#ifdef USE_ROCM
    // Keep track of backward pass for rocblas.
    at::ROCmBackwardPassGuard in_backward;
#endif

    auto step_callbacks =
        at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
    if (C10_UNLIKELY(step_callbacks.has_value())) {
      at::RecordFunction guard(std::move(*step_callbacks));
      // Using sequence number and thread id to correlate with
      // the forward pass function
      guard.setForwardThreadId(thread_id_);
      if (guard.needsInputs()) {
        std::vector<c10::IValue> inputs_vec(inputs.begin(), inputs.end());
        guard.before(
            name(),
            c10::ArrayRef<const c10::IValue>(
                inputs_vec.data(), inputs_vec.size()),
            sequence_nr());
      } else {
        guard.before(name(), sequence_nr());
      }
      return apply(std::move(inputs));
    } else {
      return apply(std::move(inputs));
    }
  }

  // Graph Connectivity API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
  // forward function.

  // Marker for expected undefined input
  struct undefined_input {};

  /// Adds the type and shape metadata for a new input. Returns the index of
  /// the new input.
  uint32_t add_input_metadata(
      const at::TensorOptions& options,
      c10::SymIntArrayRef shape,
      bool is_tensor_subclass) noexcept {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    uint32_t input_nr = input_metadata_.size();
    auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape};
    input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
    return input_nr;
  }

  uint32_t add_input_metadata(const at::Tensor& t) noexcept {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    uint32_t input_nr = input_metadata_.size();
    input_metadata_.emplace_back(t);
    return input_nr;
  }

  /// Adds a placeholder for an input that will not be used.
  uint32_t add_input_metadata(undefined_input u) noexcept {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    uint32_t input_nr = input_metadata_.size();
    input_metadata_.emplace_back();
    return input_nr;
  }

  uint32_t num_inputs() const noexcept {
    return input_metadata_.size();
  }

  const InputMetadata& input_metadata(size_t index) const {
    return input_metadata_[index];
  }

  /**
   * Note: Function Streams
   * A function's stream (for a given device type) is the stream of the first
   * element of its input buffer on a device of that type.
   *
   * If all elements are on the same device they MUST share a stream. If
   * elements are on different devices (across multiple GPUs, for example)
   * they may have different streams.
   */
  c10::optional<c10::Stream> stream(const c10::DeviceType device_type) {
    for (const auto& metadata : input_metadata_) {
      if (metadata.device().type() == device_type)
        return metadata.stream();
    }

    return c10::nullopt;
  }
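
  // A usage sketch (assumes `node` is a Node pointer whose inputs include at
  // least one CUDA tensor; the names here are illustrative only):
  //
  //   c10::optional<c10::Stream> s = node->stream(c10::DeviceType::CUDA);
  //   if (s.has_value()) {
  //     // record events on or synchronize with *s here
  //   }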

  void clear_input_metadata() {
    input_metadata_.clear();
  }

  // Outputs ("Next Edges")

  void update_topological_nr(const Edge& edge) {
    TORCH_INTERNAL_ASSERT(
        !has_parent_,
        "Cannot update a node's topological_nr after it already has a parent."
        " If we allow this, we can no longer guarantee that a parent's"
        " topo_nr is always greater than those of all its children")
    Node* node = edge.function.get();
    if (node) {
      auto topo_nr = node->topological_nr();
      if (topological_nr_ <= topo_nr) {
        topological_nr_ = topo_nr + 1;
      }
    }
  }

  void set_next_edge(size_t index, Edge edge) {
    update_topological_nr(edge);
    next_edges_[index] = std::move(edge);
  }

  void add_next_edge(Edge edge) {
    update_topological_nr(edge);
    next_edges_.emplace_back(std::move(edge));
  }

  void set_next_edges(edge_list&& next_edges) {
    next_edges_ = std::move(next_edges);
    for (const auto& next_edge : next_edges_) {
      update_topological_nr(next_edge);
    }
  }

  const Edge& next_edge(size_t index) const noexcept {
    return next_edges_[index];
  }

  const edge_list& next_edges() const noexcept {
    return next_edges_;
  }

  edge_list& next_edges() noexcept {
    return next_edges_;
  }

  uint32_t num_outputs() const noexcept {
    return next_edges_.size();
  }

  // Miscellaneous Methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// NOTE [ Sequence Number ]
  ///
  /// The sequence_nr has two main usages in autograd:
  ///
  /// 1) Helps determine the node's execution priority in the engine.
  ///    All else being equal, nodes with higher priority numbers are executed
  ///    first. Thus, nodes corresponding to ops executed later are the first
  ///    to be executed in the backward pass. One caveat is that we prioritize
  ///    AccumulateGrad nodes by explicitly setting their sequence_nr to be
  ///    UINT64_MAX.
  /// 2) The sequence number of this `Node` is paired with the thread_id it
  ///    was created in, as a unique identifier used by the profiler to
  ///    annotate recorded events. The purpose of this is to help users (and
  ///    possibly programs) interpreting the profiler's output to correlate
  ///    backward nodes with their forward ops. We need both sequence_nr and
  ///    thread_id to identify a node because sequence_nr is thread_local,
  ///    i.e., it starts counting up from zero in every new thread.
  uint64_t sequence_nr() const noexcept {
    return sequence_nr_;
  }

  // NOTE [ Topological Number ]
  //
  // topological_nr is used to prune branches in the DAG during autograd
  // discovery, as maintaining topological_nr helps us check in O(1) if there
  // does NOT exist a directed path between two nodes.
  //
  // The topological order number of this `Node` represents the length of the
  // longest possible path from this Node to any leaf node. If this is a leaf
  // node, i.e., AccumulateGrad, its topological_nr is zero. This value has
  // the property that for every pair of nodes X, Y in G, the existence of a
  // directed path from X to Y implies topo_nr(X) > topo_nr(Y). The converse
  // is not true, however, so we cannot prove the existence of a path from X
  // to Y, only its non-existence.
  //
  // One assumption we make when using topo_nr is that once a node has been
  // used, i.e., has a parent node, its own topo_nr does not change. We have
  // added some checks with the `has_parent_` field to enforce this.
  //
  // What NOT to do:
  //
  //   1) 2 -> 1 -> 0        In this diagram, we label nodes with their
  //      2 -> 1 -> 0        topo_nr. We have two simple graphs that can each
  //                         arise from `t.exp().exp()`, for example.
  //
  //   2) 2 -> 1 -> 0
  //           /
  //      2 -> 1 -> 0        We add 2 as a next edge to 1 even though 1
  //                         already has a parent.
  //
  //   3) 2 -> 1 -> 0
  //           /
  //      2 -> 3 -> 0        2 < 3, yet there exists a path from 2 to 3!
  //
  uint64_t topological_nr() const noexcept {
    has_parent_ = true;
    return topological_nr_;
  }
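
  // A sketch of the O(1) pruning check this enables (assumes `X` and `Y` are
  // Node pointers already wired into the same graph):
  //
  //   if (X->topological_nr() <= Y->topological_nr()) {
  //     // There is definitely no directed path from X to Y, so the branch
  //     // rooted at X can be skipped when searching for Y.
  //   }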

  // Assigns the currently evaluating node as the parent of this node.
  void assign_parent();

  /// Id of the thread that created this Node.
  uint64_t thread_id() const noexcept {
    return thread_id_;
  }

  /// Returns the name of the dynamic type of the function, for debugging.
  virtual std::string name() const;

  /// The difference between functions `should_compute_output` and
  /// `task_should_compute_output`:
  /// - `should_compute_output` should only be used during graph construction
  ///   and takes into account only requires_grad information
  /// - `task_should_compute_output` should only be called during the backward
  ///   pass (unless called directly through grad_fn) and takes into account
  ///   the current graph task. Specifically, the autograd engine trims
  ///   unnecessary edges when `inputs` are specified, and during backward
  ///   untrimmed nodes left on the graph can/should check
  ///   `task_should_compute_output` to see if any outgoing edges have been
  ///   trimmed by the engine. If that is the case, gradient computation wrt
  ///   those edges can be omitted.
  ///
  /// Returns true if the particular output edge is active, and that
  /// particular output of this function should be computed.
  bool should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    return next_edges_[output_edge_index].is_valid();
  }

  /// Returns true if any of the output edges in any of the ranges are active.
  bool should_compute_output(std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (const auto i : c10::irange(range.first, range.second)) {
        if (should_compute_output(i))
          return true;
      }
      return false;
    });
  }

  /// Same as the above `should_compute_output` function but will also
  /// check whether this edge is needed within the current graph task.
  bool task_should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    const auto& next = next_edges_[output_edge_index];
    if (next.is_valid()) {
      const auto exec_info = get_current_graph_task_exec_info();
      if (exec_info && !exec_info->empty()) {
        auto it = exec_info->find(next.function.get());
        if (it == exec_info->end() || !it->second.should_execute()) {
          return false; // this edge is not needed for the current graph_task
        }
      }
      return true;
    }
    return false;
  }

  /// Returns true if any of the output edges in any of the ranges are active
  /// and should be computed in the current graph task.
  bool task_should_compute_output(
      std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (const auto i : c10::irange(range.first, range.second)) {
        if (task_should_compute_output(i))
          return true;
      }
      return false;
    });
  }
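
  // A sketch of typical use inside a subclass's `apply()` (assumes output
  // edge 0 corresponds to the gradient w.r.t. the forward input `self`):
  //
  //   variable_list grad_inputs(num_outputs());
  //   if (task_should_compute_output(0)) {
  //     grad_inputs[0] = /* gradient w.r.t. self */;
  //   }
  //   return grad_inputs;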

  /// Returns the `PyObject` stored for this `Node` (for Python
  /// interaction).
  PyObject* pyobj() const noexcept {
    return pyobj_;
  }

  /// Sets the `PyObject` stored for this `Node` (for Python interaction).
  void set_pyobj(PyObject* pyobj) noexcept {
    pyobj_ = pyobj;
  }

  /// Returns the anomaly metadata stored for this `Node`.
  /// If none exists, creates a new empty one.
  AnomalyMetadata* metadata() noexcept;

  // Hook API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
    post_hooks_.emplace_back(std::move(post_hook));
    // Use the raw pointer as the unique key to identify this hook. This key
    // can then be used in del_post_hook(key) to remove this hook.
    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
  }

  const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks()
      const noexcept {
    return post_hooks_;
  }

  // Deletes the post hook matching the given key.
  bool del_post_hook(const uintptr_t& key) {
    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        post_hooks_.erase(it);
        return true;
      }
    }
    return false;
  }
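
  // A usage sketch (assumes `LoggingPostHook` is a hypothetical
  // FunctionPostHook subclass defined elsewhere):
  //
  //   uintptr_t key =
  //       node->add_post_hook(std::make_unique<LoggingPostHook>());
  //   // ... later, to unregister it:
  //   bool removed = node->del_post_hook(key); // true if the hook was found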

  std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
    return post_hooks_;
  }

  void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    pre_hooks_.emplace_back(std::move(pre_hook));
  }

  void add_tensor_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    tensor_pre_hooks_.emplace_back(std::move(pre_hook));
  }

  void add_retains_grad_hook(
      std::unique_ptr<FunctionPreHook>&& pre_hook,
      int output_idx) {
    retains_grad_hooks_[output_idx] = std::move(pre_hook);
  }

  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
    auto ret = std::move(retains_grad_hooks_[output_idx]);
    retains_grad_hooks_.erase(output_idx);
    return ret;
  }

  const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
      const noexcept {
    return pre_hooks_;
  }

  std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
    return pre_hooks_;
  }

  virtual std::vector<std::unique_ptr<FunctionPreHook>>&
  tensor_pre_hooks() noexcept {
    return tensor_pre_hooks_;
  }

  std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
  retains_grad_hooks() noexcept {
    return retains_grad_hooks_;
  }

  // Customization Points for Subclasses
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Releases saved variables if the operation won't be reused.
  virtual void release_variables() {}

  /// Called before an apply if `release_variables()` is going to be called.
  /// Allows larger ops like `InterpreterAutogradFunction` to incrementally
  /// release variables as they run.
  virtual void will_release_variables() {}

  /// Returns true if this function is traceable. An op is traceable if all
  /// operations happening within `apply()` are performed on autograd
  /// `Variables` (i.e. apply mostly instantiates and applies other functions).
  virtual bool is_traceable() {
    return false;
  }

  /// A `Node` is said to pass state transparently to backward if the state
  /// consists only of (Saved)Variables and non-variable objects that
  /// parameterize the operation in some way that defines the graph
  /// structure AND the backward function is traceable. In particular,
  /// parametrization MUST NOT depend on the data of any `Variable`.
  /// TODO: it might be possible to handle cases where backward is
  /// non-traceable but state passing could be considered transparent. This
  /// will probably depend on saved_variable_list being mutable.
  /// NOTE: this value matters only if is_traceable() returns false.
  virtual bool passes_state_transparently() {
    return false;
  }
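
  // A minimal sketch of a concrete subclass (`NegBackward0` here is a
  // hypothetical node whose forward op computed `-x`, so its backward
  // negates the incoming gradient):
  //
  //   struct NegBackward0 : public Node {
  //     variable_list apply(variable_list&& grads) override {
  //       variable_list grad_inputs(1);
  //       if (task_should_compute_output(0) && grads[0].defined()) {
  //         grad_inputs[0] = -grads[0];
  //       }
  //       return grad_inputs;
  //     }
  //     std::string name() const override {
  //       return "NegBackward0";
  //     }
  //   };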

 protected:
  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;

  /// Calls `apply()`, but instruments it with tracing machinery.
  variable_list traced_apply(variable_list inputs);

  // Sequence number used to correlate backward nodes with forward ops in the
  // profiler and provide determinism in the engine.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const uint64_t sequence_nr_;

  // See NOTE [ Topological Number ]
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  uint64_t topological_nr_ = 0;

  // Tracks whether this node has been added as the next_edge of another node
  // via set_next_edge(s), which always calls topological_nr() of all its
  // children. See NOTE [ Topological Number ] for why we need this.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  mutable bool has_parent_ = false;

  // Id of the thread that created the instance
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  uint64_t thread_id_ = 0;

  // Note [Thread Safety on Autograd Node]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // The autograd engine lets the owning thread (the one that calls
  // Engine::execute) drive the GraphTask execution. There might be cases
  // where part of the GraphTask is shared across different `backward()` or
  // `grad()` calls, e.g. when new threads are forked in the middle of the
  // forward pass and `backward()` is then called separately from the
  // different threads. We need to protect thread safety on NodeTask to
  // prevent data races on shared variable reads/writes.
  //
  // NB: This is only needed for autograd Nodes that run on CPU; technically
  // "CUDA" and "XLA" nodes don't need locking because device threads are
  // always single threaded.
  //
  // Here we add a mutex to help protect the Node's thread safety, so that
  // different threads cannot race on the shared data when executing the same
  // NodeTask from multiple CPU threads. It IS the user/developer's
  // responsibility to take advantage of this mutex to protect the thread
  // safety of their autograd Node. The general strategy of thread safety on
  // an autograd Node:
  //
  // 1. Users should lock the mutex during Node::release_variables() if the
  //    Node needs to release the variables on the fly. This ensures that when
  //    we release saved_variables from one thread, no other threads can
  //    release the saved variables concurrently.
  // 2. Users should lock the mutex during Node::apply(). This ensures that
  //    writes to shared variables are not racing across threads (e.g.
  //    AccumulateGrad, or custom C++ autograd Nodes that write to shared
  //    variables).
  // 3. Items 1 and 2 should work together so that when we release saved
  //    variables from one thread, no other threads can call Node::apply().
  //    This ensures the variable references from other threads aren't
  //    dangling.
  // 4. If the Node doesn't release any variables and has no shared data
  //    reads/writes, i.e., it is purely functional, the user doesn't need to
  //    lock the mutex.
  //
  // This way we can protect thread safety on the autograd Node itself, but we
  // still cannot protect thread safety on the Node's pre/post C++ hooks
  // (Python hooks are automatically thread safe); we rely on the user to
  // write thread safe C++ hooks if they want the hooks to be correctly
  // applied in a multithreaded environment.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex mutex_;
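
  // A sketch of using this mutex from a custom Node (assumes the node both
  // releases saved variables and reads/writes shared state; `saved_self_` is
  // a hypothetical SavedVariable member):
  //
  //   void release_variables() override {
  //     std::lock_guard<std::mutex> lock(mutex_);
  //     saved_self_.reset_data();
  //   }
  //
  //   variable_list apply(variable_list&& grads) override {
  //     std::lock_guard<std::mutex> lock(mutex_);
  //     // ... unpack saved_self_ and compute gradients under the lock ...
  //   }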

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  edge_list next_edges_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PyObject* pyobj_ = nullptr; // weak reference
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;

  // NOTE [Hooks ordering]
  // We have 3 separate fields for pre hooks registered to the autograd nodes
  // because the conditions under which they execute are different, and we
  // want more fine-grained control over the order in which different types
  // of hooks are executed.
  // - pre_hooks are only executed when the node itself is executed
  // - tensor_pre_hooks are executed as long as the engine traverses over the
  //   node, even if that node won't actually be executed
  // - retains_grad_hooks are like tensor_pre_hooks, except they are always
  //   ordered after all other tensor pre hooks
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  at::SmallVector<InputMetadata, 2> input_metadata_;
};

/// See Node::is_traceable() for definition.
struct TraceableFunction : public Node {
  using Node::Node;
  bool is_traceable() final {
    return true;
  }
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Associated Free Nodes
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
// Implementation of `collect_next_edges` (see below).
struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
  edge_list next_edges;
  using IterArgs<MakeNextFunctionList>::operator();
  void operator()(const Variable& variable) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (variable.defined()) {
      next_edges.emplace_back(impl::gradient_edge(variable));
    } else {
      next_edges.emplace_back();
    }
  }
  void operator()(const Variable* variable) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (variable->defined()) {
      next_edges.emplace_back(impl::gradient_edge(*variable));
    } else {
      next_edges.emplace_back();
    }
  }
  void operator()(const c10::optional<Variable>& variable) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (variable.has_value() && variable->defined()) {
      next_edges.emplace_back(impl::gradient_edge(*variable));
    } else {
      next_edges.emplace_back();
    }
  }
};
} // namespace detail

/// Create an `Edge` between the given `variable` and the `function`, which is
/// assumed to be the gradient function of this variable (i.e. the function
/// through which this variable is backpropagated during the backward pass).
/// This sets the `grad_fn` property of the `variable`. This function assumes
/// that the `Variable` is a new input to the gradient function and its
/// `input_nr` is thus equal to `function->num_inputs()`. Additionally, it
/// increments the `Node`'s number of inputs by one. Approximately
/// equivalent to `variable.set_gradient_edge(function,
/// function->add_input_metadata(variable.dispatch_type(), variable.sizes()))`.
/// If you don't want the `Node`'s `num_inputs` to be incremented, use
/// `set_gradient_edge` directly.
inline void create_gradient_edge(
    Variable& variable,
    std::shared_ptr<Node> function) {
  // Copy before move.
  const auto input_nr = function->add_input_metadata(variable);
  impl::set_gradient_edge(variable, {std::move(function), input_nr});
}

/// Return true if any of the variables in the list require a gradient.
inline bool any_variable_requires_grad(const variable_list& variables) {
  return std::any_of(
      variables.begin(), variables.end(), [](const Variable& variable) {
        return variable.defined() && variable.requires_grad();
      });
}

/// Return the next edges of all the given variables, or tuples of variables.
template <typename... Variables>
edge_list collect_next_edges(Variables&&... variables) {
  detail::MakeNextFunctionList make;
  make.apply(std::forward<Variables>(variables)...);
  return std::move(make.next_edges);
}
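
// A sketch of how these helpers are typically combined when wiring a backward
// node into the graph (assumes `MyBackward` is a hypothetical Node subclass;
// `self` and `other` are forward inputs, `result` is the forward output):
//
//   auto grad_fn = std::shared_ptr<MyBackward>(new MyBackward(), deleteNode);
//   grad_fn->set_next_edges(collect_next_edges(self, other));
//   create_gradient_edge(result, grad_fn);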
} // namespace autograd
} // namespace torch

C10_CLANG_DIAGNOSTIC_POP()