1 | #pragma once |
2 | |
3 | #include <torch/csrc/autograd/anomaly_mode.h> |
4 | #include <torch/csrc/autograd/edge.h> |
5 | #include <torch/csrc/autograd/grad_mode.h> |
6 | #include <torch/csrc/autograd/graph_task.h> |
7 | #include <torch/csrc/autograd/input_metadata.h> |
8 | #include <torch/csrc/autograd/saved_variable.h> |
9 | #include <torch/csrc/autograd/variable.h> |
10 | #include <torch/csrc/utils/python_stub.h> |
11 | #include <torch/csrc/utils/variadic.h> |
12 | |
13 | #include <ATen/SequenceNumber.h> |
14 | #include <ATen/core/Tensor.h> |
15 | #include <ATen/record_function.h> |
16 | #include <c10/util/Exception.h> |
17 | #include <c10/util/irange.h> |
18 | |
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
26 | |
27 | C10_CLANG_DIAGNOSTIC_PUSH() |
28 | #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") |
29 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32" ) |
30 | #endif |
31 | |
32 | namespace torch { |
33 | namespace autograd { |
34 | |
35 | struct Edge; |
36 | struct FunctionPostHook; |
37 | struct FunctionPreHook; |
38 | |
// Convenience aliases for the containers used throughout this header.
using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
// A pair of output-edge indices. `Node::should_compute_output` and
// `Node::task_should_compute_output` walk it as
// `c10::irange(range.first, range.second)`.
using IndexRange = std::pair<size_t, size_t>;
44 | |
45 | // Custom deleter to prevent stack overflows. |
46 | TORCH_API void deleteNode(Node* function); |
47 | |
48 | // Guard that sets and restores the evaluating node |
49 | class NodeGuard { |
50 | public: |
51 | explicit NodeGuard(std::shared_ptr<Node> node); |
52 | ~NodeGuard(); |
53 | |
54 | private: |
55 | std::shared_ptr<Node> last_evaluating_node_; |
56 | }; |
57 | |
58 | // Return the Node currently being evaluated (if any) |
59 | // This is only set during the backward pass while a Node is being |
60 | // executed. |
61 | TORCH_API std::shared_ptr<Node> get_current_node(); |
62 | |
63 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
64 | // Node |
65 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
66 | // A `Node` is an abstract class that represents an operation taking zero |
67 | // or more input `Variable`s and producing zero or more output `Variable`s. All |
68 | // functions in PyTorch's autograd machinery derive from this class and |
// override its `apply` method. Instances of such subclasses will then be
// invocable via the call operator.
71 | // |
72 | // Nodes in the Autograd Graph |
73 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
74 | // When viewing the autograd system as a graph, `Node`s are the vertices or |
75 | // nodes, connected to each other via (directed) `Edge`s, which themselves are |
76 | // represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to |
77 | // and inputs of `Node`s, and travel between these edges during execution |
78 | // of the graph. When two or more `Edge`s (from different sources) point at the |
79 | // same input to a `Node`, the values produced along all of these edges are |
80 | // implicitly summed prior to being forwarded to the target `Node`. |
81 | // |
82 | // Hierarchy |
83 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
84 | // Subclasses usually represent differentiable functions as well as their |
85 | // gradient operators. Note, however, that due to the very general definition |
86 | // of a `Node` taking *zero* or more inputs and producing *zero* or more |
87 | // outputs, uses of `Node`s are flexible and extend beyond purely |
88 | // mathematical operations. For example, the `AccumulateGrad` function is a |
89 | // *sink*: it takes one input, but produces no outputs, instead accumulating |
90 | // the input as a side effect. At the other extreme, the `GraphRoot` function |
91 | // receives no inputs from other functions, but produces multiple outputs. |
92 | // |
93 | // Interface |
94 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
95 | // The most important method on `Node` is the call operator, which takes in |
96 | // a list of variables and produces a list of variables. The precise size of |
97 | // these lists can be determined with `num_inputs()` and `num_outputs()`. |
// `Node`s are stitched together via their `next_edge` interface, which lets
// you manipulate the set of outgoing edges of a `Node`. You can add an
100 | // edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and |
101 | // iterate over them via the `next_edges()` method. Other methods exist for |
102 | // integration with the JIT and other parts of PyTorch. Every `Node` has a |
103 | // *sequence number* that increases monotonically in the order of `Node` |
104 | // construction. It can be retrieved via the `sequence_nr()` method. Note that |
105 | // this sequence number is *thread local*. This means that when `Node`s |
106 | // `A`, `B` and `C` are created consecutively in the same thread, their |
107 | // sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` |
108 | // are created in one thread and `C` is created in a new thread, there are *no |
109 | // guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. |
110 | // See NOTE [ Sequence Number] for more details on the usages of sequence |
111 | // number. |
112 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
113 | struct TORCH_API Node : std::enable_shared_from_this<Node> { |
114 | public: |
115 | /// Construct a new `Node` with the given `next_edges` |
116 | explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list()) |
117 | : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) { |
118 | for (const Edge& edge : next_edges_) { |
119 | update_topological_nr(edge); |
120 | } |
121 | |
122 | if (AnomalyMode::is_enabled()) { |
123 | metadata()->store_stack(); |
124 | |
125 | // If anomaly mode is enabled and graph is constructed, then assign the |
126 | // currently evaluating node as the parent of this node. |
127 | // A parent is a Node where this Node is created. |
128 | // We are tracking the parents to track multiple backward operations. |
129 | assign_parent(); |
130 | } |
131 | |
132 | // Store the thread_id of the forward operator. |
133 | // See NOTE [ Sequence Numbers ] |
134 | thread_id_ = at::RecordFunction::currentThreadId(); |
135 | } |
136 | |
137 | explicit Node(edge_list&& next_edges = edge_list()) |
138 | : Node( |
139 | /*sequence_nr=*/at::sequence_number::get_and_increment(), |
140 | std::move(next_edges)) {} |
141 | |
/// Nodes are neither copyable nor moveable: all four copy/move operations
/// are deleted.
Node(const Node& other) = delete;
Node(Node&& other) = delete;
Node& operator=(const Node& other) = delete;
Node& operator=(Node&& other) = delete;
/// Virtual so that subclasses can be destroyed through a `Node` pointer.
virtual ~Node() = default;
148 | |
/// Returns a `std::shared_ptr` to this `Node`, obtained through
/// `shared_from_this()`.
/// NOTE(review): `shared_from_this()` requires that this `Node` is already
/// owned by a `std::shared_ptr`; otherwise it throws `std::bad_weak_ptr`
/// (since C++17).
std::shared_ptr<Node> getptr() {
  return shared_from_this();
}
152 | /// Evaluates the function on the given inputs and returns the result of the |
153 | /// function call. |
154 | variable_list operator()(variable_list&& inputs) { |
155 | // In the first iteration of named tensors, autograd ignores names and |
156 | // operates on unnamed tensors. In the long term, autograd should |
157 | // probably operate with names. |
158 | at::NoNamesGuard no_names_guard; |
159 | |
160 | #ifdef USE_ROCM |
161 | // Keep track of backward pass for rocblas. |
162 | at::ROCmBackwardPassGuard in_backward; |
163 | #endif |
164 | |
165 | auto step_callbacks = |
166 | at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); |
167 | if (C10_UNLIKELY(step_callbacks.has_value())) { |
168 | at::RecordFunction guard(std::move(*step_callbacks)); |
169 | // Using sequence number and thread id to correlate with |
170 | // the forward pass function |
171 | guard.setForwardThreadId(thread_id_); |
172 | if (guard.needsInputs()) { |
173 | std::vector<c10::IValue> inputs_vec(inputs.begin(), inputs.end()); |
174 | guard.before( |
175 | name(), |
176 | c10::ArrayRef<const c10::IValue>( |
177 | inputs_vec.data(), inputs_vec.size()), |
178 | sequence_nr()); |
179 | } else { |
180 | guard.before(name(), sequence_nr()); |
181 | } |
182 | return apply(std::move(inputs)); |
183 | } else { |
184 | return apply(std::move(inputs)); |
185 | } |
186 | } |
187 | |
188 | // Graph Connectivity API |
189 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
190 | |
191 | // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the |
192 | // forward function. |
193 | |
194 | // Marker for expected undefined input |
195 | struct undefined_input {}; |
196 | |
197 | /// Adds the type and shape metadata for a new input. Returns the index of |
198 | /// of the new input. |
199 | uint32_t add_input_metadata( |
200 | const at::TensorOptions& options, |
201 | c10::SymIntArrayRef shape, |
202 | bool is_tensor_subclass) noexcept { |
203 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
204 | uint32_t input_nr = input_metadata_.size(); |
205 | auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape}; |
206 | input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass); |
207 | return input_nr; |
208 | } |
209 | |
210 | uint32_t add_input_metadata(const at::Tensor& t) noexcept { |
211 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
212 | uint32_t input_nr = input_metadata_.size(); |
213 | input_metadata_.emplace_back(t); |
214 | return input_nr; |
215 | } |
216 | |
217 | /// Adds a placeholder for an input that will not be used. |
218 | uint32_t add_input_metadata(undefined_input u) noexcept { |
219 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
220 | uint32_t input_nr = input_metadata_.size(); |
221 | input_metadata_.emplace_back(); |
222 | return input_nr; |
223 | } |
224 | |
225 | uint32_t num_inputs() const noexcept { |
226 | return input_metadata_.size(); |
227 | } |
228 | |
229 | const InputMetadata& input_metadata(size_t index) const { |
230 | return input_metadata_[index]; |
231 | } |
232 | |
233 | /** |
234 | * Note: Function Streams |
235 | * A function's stream (for a given device type) is the stream of the first |
236 | * element of its input buffer on a device of that type. |
237 | * |
238 | * If all elements are on the same device they MUST share a stream. If |
239 | * elements are on different devices (across multiple GPUs, for example) |
240 | * they may have different streams. |
241 | */ |
242 | c10::optional<c10::Stream> stream(const c10::DeviceType device_type) { |
243 | for (const auto& metadata : input_metadata_) { |
244 | if (metadata.device().type() == device_type) |
245 | return metadata.stream(); |
246 | } |
247 | |
248 | return c10::nullopt; |
249 | } |
250 | |
/// Drops every stored input-metadata entry; `num_inputs()` reports zero
/// afterwards.
void clear_input_metadata() {
  input_metadata_.clear();
}
254 | |
255 | // Outputs ("Next Edges") |
256 | |
257 | void update_topological_nr(const Edge& edge) { |
258 | TORCH_INTERNAL_ASSERT( |
259 | !has_parent_, |
260 | "Cannot update a node's topological_nr after it already has a parent." |
261 | " If we allow this, we can no longer guarantee that a parent's" |
262 | " topo_nr is always greater than those of all its children" ) |
263 | Node* node = edge.function.get(); |
264 | if (node) { |
265 | auto topo_nr = node->topological_nr(); |
266 | if (topological_nr_ <= topo_nr) { |
267 | topological_nr_ = topo_nr + 1; |
268 | } |
269 | } |
270 | } |
271 | |
272 | void set_next_edge(size_t index, Edge edge) { |
273 | update_topological_nr(edge); |
274 | next_edges_[index] = std::move(edge); |
275 | } |
276 | |
277 | void add_next_edge(Edge edge) { |
278 | update_topological_nr(edge); |
279 | next_edges_.emplace_back(std::move(edge)); |
280 | } |
281 | |
282 | void set_next_edges(edge_list&& next_edges) { |
283 | next_edges_ = std::move(next_edges); |
284 | for (const auto& next_edge : next_edges_) { |
285 | update_topological_nr(next_edge); |
286 | } |
287 | } |
288 | |
289 | const Edge& next_edge(size_t index) const noexcept { |
290 | return next_edges_[index]; |
291 | } |
292 | |
293 | const edge_list& next_edges() const noexcept { |
294 | return next_edges_; |
295 | } |
296 | |
297 | edge_list& next_edges() noexcept { |
298 | return next_edges_; |
299 | } |
300 | |
301 | uint32_t num_outputs() const noexcept { |
302 | return next_edges_.size(); |
303 | } |
304 | |
305 | // Miscellaneous Methods |
306 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
307 | |
/// NOTE [ Sequence Number]
///
/// The sequence_nr has two main usages in autograd:
///
/// 1) Helps determine the node's execution priority in the engine.
///    All else being equal, nodes with higher priority numbers are executed
///    first. Thus, nodes corresponding to ops executed later are the first to
///    be executed in the backward pass. One caveat is that we prioritize
///    AccumulateGrad nodes by explicitly setting its sequence_nr to be
///    UINT64_MAX.
/// 2) The sequence number of this `Node` is paired with the thread_id it was
///    created in as a unique identifier by the profiler to annotate recorded
///    events. The purpose of this is to help users (and possibly programs)
///    interpreting the profiler's output to correlate backward nodes with its
///    forward ops. We need both sequence_nr and thread_id to identify a node
///    because sequence_nr is thread_local, i.e., starts counting up from zero
///    in a new thread.
///
/// Returns the sequence number this `Node` was constructed with.
uint64_t sequence_nr() const noexcept {
  return sequence_nr_;
}
329 | |
// NOTE [ Topological Number ]
//
// topological_nr is used to prune branches in the DAG during autograd
// discovery, as maintaining topological_nr helps us check in O(1) if there
// does NOT exist a directed path between two nodes.
//
// The topological order number of this `Node` represents the length of the
// longest possible path from this Node to any leaf node. If you are a leaf
// node, aka AccumulateGrad, this will be zero. This value has the property
// that for every pair of nodes X, Y in G, existence of a directed path from X
// to Y implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so
// we cannot prove existence of a path from X to Y, only non-existence.
//
// One assumption we make when using topo_nr is that once a node
// has been used, i.e., has a parent node, its own topo_nr does not change.
// We have added some checks with the `has_parent_` field to enforce this.
//
// What NOT to do (nodes are labeled with their topo_nr):
//
//   1)  2 -> 1 -> 0      We have two simple graphs that can each arise from
//       2 -> 1 -> 0      `t.exp().exp()`, for example.
//
//   2)       2 -> 1 -> 0
//           /
//       2 -> 1 -> 0      We add 2 as a next edge to 1 even though 1 already
//                        has a parent.
//
//   3)       2 -> 1 -> 0
//           /
//       2 -> 3 -> 0      2 < 3, yet there exists a path from 2 to 3!
//
// Reading the number through this accessor also sets `has_parent_` (a
// `mutable` member, hence the `const` qualifier); from then on
// `update_topological_nr` asserts if it is called on this node.
uint64_t topological_nr() const noexcept {
  has_parent_ = true;
  return topological_nr_;
}
367 | |
// Assigns a node as the parent of this node. Defined out of line; the
// constructor calls it when anomaly mode is enabled so that the currently
// evaluating node becomes the parent.
void assign_parent();

/// Id of the thread that created this Node (captured in the constructor via
/// `at::RecordFunction::currentThreadId()`).
uint64_t thread_id() const noexcept {
  return thread_id_;
}

/// Returns the name of the dynamic type of the function, for debugging.
/// Virtual and defined out of line.
virtual std::string name() const;
378 | |
379 | /// The difference between functions `should_compute_output` and |
380 | /// `task_should_compute_output`: |
381 | /// - `should_compute_output` should only be used during graph construction |
382 | /// and takes into account only requires_grad information |
383 | /// - `task_should_compute_output` should only be called during the backward |
384 | /// pass (unless called directly through grad_fn) and takes into account the |
385 | /// current graph task. Specifically, the autograd engine trims unnecessary |
386 | /// edges when `inputs` are specified, and during backward untrimmed nodes |
387 | /// left on the graph can/should check `task_should_compute_output` to see if |
388 | /// any outgoing edges have been trimmed by the engine. If that is the case, |
389 | /// gradient computation wrt those edges can be omitted. |
390 | /// |
391 | /// Returns true if the particular output edge is active, and that particular |
392 | /// output of this function should be computed. |
393 | bool should_compute_output(size_t output_edge_index) const { |
394 | TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range" ); |
395 | return next_edges_[output_edge_index].is_valid(); |
396 | } |
397 | |
398 | /// Returns true if any of the output edges in any of the ranges are active. |
399 | bool should_compute_output(std::initializer_list<IndexRange> idxs) const { |
400 | return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { |
401 | for (const auto i : c10::irange(range.first, range.second)) { |
402 | if (should_compute_output(i)) |
403 | return true; |
404 | } |
405 | return false; |
406 | }); |
407 | } |
408 | |
409 | /// Same as the above `should_compute_output` function but will also |
410 | /// check whether this edge is needed within the current graph task. |
411 | bool task_should_compute_output(size_t output_edge_index) const { |
412 | TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range" ); |
413 | const auto& next = next_edges_[output_edge_index]; |
414 | if (next.is_valid()) { |
415 | const auto exec_info = get_current_graph_task_exec_info(); |
416 | if (exec_info && !exec_info->empty()) { |
417 | auto it = exec_info->find(next.function.get()); |
418 | if (it == exec_info->end() || !it->second.should_execute()) { |
419 | return false; // this edge is not needed for the current graph_task |
420 | } |
421 | } |
422 | return true; |
423 | } |
424 | return false; |
425 | } |
426 | |
427 | /// Returns true if any of the output edges in any of the ranges are active |
428 | /// and should be computed in the current graph task. |
429 | bool task_should_compute_output( |
430 | std::initializer_list<IndexRange> idxs) const { |
431 | return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { |
432 | for (const auto i : c10::irange(range.first, range.second)) { |
433 | if (task_should_compute_output(i)) |
434 | return true; |
435 | } |
436 | return false; |
437 | }); |
438 | } |
439 | |
/// Returns the `PyObject` stored for this `Node` (for Python
/// interaction). The pointer is held as a weak reference and may be
/// `nullptr` (its initial value).
PyObject* pyobj() const noexcept {
  return pyobj_;
}

/// Sets the `PyObject` stored for this `Node` (for Python interaction).
/// Only the raw pointer is stored.
void set_pyobj(PyObject* pyobj) noexcept {
  pyobj_ = pyobj;
}

/// Returns the anomaly metadata stored for this `Node`.
/// If none exist, creates a new empty one. Defined out of line.
AnomalyMetadata* metadata() noexcept;
454 | |
455 | // Hook API |
456 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
457 | |
458 | uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) { |
459 | post_hooks_.emplace_back(std::move(post_hook)); |
460 | // Use the raw pointer as the unique key to identify this hook. This key |
461 | // can then be used in del_post_hook(key) to remove this hook. |
462 | return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get()); |
463 | } |
464 | |
/// Read-only view of the registered post hooks, in registration order.
const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks()
    const noexcept {
  return post_hooks_;
}
469 | |
470 | // delete a post hook matching the key |
471 | bool del_post_hook(const uintptr_t& key) { |
472 | for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) { |
473 | if (key == reinterpret_cast<std::uintptr_t>(it->get())) { |
474 | post_hooks_.erase(it); |
475 | return true; |
476 | } |
477 | } |
478 | return false; |
479 | } |
480 | |
/// Mutable access to the registered post hooks.
std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
  return post_hooks_;
}
484 | |
485 | void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) { |
486 | pre_hooks_.emplace_back(std::move(pre_hook)); |
487 | } |
488 | |
489 | void add_tensor_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) { |
490 | tensor_pre_hooks_.emplace_back(std::move(pre_hook)); |
491 | } |
492 | |
/// Stores `pre_hook` as the retains-grad hook for `output_idx`. A hook
/// already stored under the same index is replaced (and thereby destroyed).
void add_retains_grad_hook(
    std::unique_ptr<FunctionPreHook>&& pre_hook,
    int output_idx) {
  retains_grad_hooks_[output_idx] = std::move(pre_hook);
}
498 | |
499 | std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) { |
500 | auto ret = std::move(retains_grad_hooks_[output_idx]); |
501 | retains_grad_hooks_.erase(output_idx); |
502 | return ret; |
503 | } |
504 | |
/// Read-only view of the pre hooks, in registration order.
const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
    const noexcept {
  return pre_hooks_;
}

/// Mutable access to the pre hooks.
std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
  return pre_hooks_;
}

/// Mutable access to the tensor pre hooks. Virtual, so a subclass may
/// return a different list.
virtual std::vector<std::unique_ptr<FunctionPreHook>>&
tensor_pre_hooks() noexcept {
  return tensor_pre_hooks_;
}

/// Mutable access to the retains-grad hooks, keyed by output index.
std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
retains_grad_hooks() noexcept {
  return retains_grad_hooks_;
}
523 | |
524 | // Customization Points for Subclasses |
525 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
526 | |
/// Releases saved variables if the operation won't be reused.
/// The base implementation does nothing.
virtual void release_variables() {}

/// Called before an apply if `release_variables()` is going to be called.
/// Allows larger ops like `InterpreterAutogradFunction` to incrementally
/// release variables as they run. The base implementation does nothing.
virtual void will_release_variables() {}

/// Returns true if this function is traceable. An op is traceable if all
/// operations happening within `apply()` are performed on autograd
/// `Variables` (i.e. apply mostly instantiates and applies other functions).
/// The base implementation returns false.
virtual bool is_traceable() {
  return false;
}

/// A `Node` is said to pass state transparently to backward, if the
/// state consists only of (Saved)Variables and only non-variable objects
/// that parameterize the operation in some way that defines the graph
/// structure AND the backward function is traceable. In particular,
/// parametrization MUST NOT depend on the data of any `Variable`.
/// TODO: it might be possible to handle cases where backward is
/// non-traceable but state passing could be considered transparent. This
/// will probably depend on saved_variable_list being mutable.
/// NOTE: this value matters only if is_traceable() returns false.
/// The base implementation returns false.
virtual bool passes_state_transparently() {
  return false;
}
554 | |
555 | protected: |
556 | /// Performs the `Node`'s actual operation. |
557 | virtual variable_list apply(variable_list&& inputs) = 0; |
558 | |
559 | /// Calls `apply()`, but instruments it with tracing machinery. |
560 | variable_list traced_apply(variable_list inputs); |
561 | |
// Sequence number used to correlate backward nodes with forward ops in the
// profiler and provide determinism in the engine.
564 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
565 | const uint64_t sequence_nr_; |
566 | |
567 | // See NOTE [ Topological Number ] |
568 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
569 | uint64_t topological_nr_ = 0; |
570 | |
// Tracks whether this node has been added as the next_edge of another node
// via set_next_edge(s), which always calls topological_nr() of all its
// children. See NOTE [ Topological Number ] for why we need this.
574 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
575 | mutable bool has_parent_ = false; |
576 | |
577 | // Id of the thread that created the instance |
578 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
579 | uint64_t thread_id_ = 0; |
580 | |
581 | // Note [Thread Safety on Autograd Node] |
582 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
583 | // Autograd Engine let the owning thread which calls Engine::execute to drive |
584 | // the GraphTask execution, there might be cases that part of the GraphTask is |
585 | // shared across different `backward()` or `grad()` calls, i.e. fork new |
586 | // threads in the middle of the forward and call `backward()` separately from |
587 | // different threads. We need to protect the thread safety on NodeTask to |
588 | // prevent data racing on shared variables read/write. |
589 | // |
590 | // NB: This is only needed for Autograd Nodes that runs on CPU, technically |
591 | // "CUDA", "XLA" nodes don't need locking because device threads are always |
592 | // single threaded. |
593 | // |
594 | // Here we add a thread mutex to help protect the Node's thread safety, so |
595 | // that different threads cannot race the shared data when executing the same |
596 | // NodeTask from multiple CPU threads. It IS the user/developer responsibility |
597 | // to take advantage of this mutex to protect the thread safety of their |
598 | // autograd Node. The general strategy of thread safety on autograd Node: |
599 | // |
// 1. User should lock the mutex during Node::release_variables() if the Node
//    needs to release the variables on the fly. This serves the purpose that
//    when we release saved_variables from one thread, no other threads can
//    release the saved variables concurrently.
// 2. User should lock the mutex during Node::apply(). This is to ensure that
//    Nodes writing to shared variables are not racing across threads (i.e.
//    AccumulateGrad and custom C++ Autograd Nodes that write to shared
//    variables).
// 3. Item 1 and item 2 should work together so that when we release saved
//    variables from one thread, no other threads can call Node::apply(); this
//    ensures the variable references from other threads aren't dangling.
// 4. If the Node doesn't release any variables and there is no shared data
//    read/write in the Node, i.e. it is purely functional, the user doesn't
//    need to lock the mutex.
617 | // |
618 | // This way we could protect the thread safety on Autograd Node, but we could |
619 | // still not protect the thread safety on Node pre/post C++ hooks (python |
620 | // hooks are automatically thread safe), we rely on the user to write thread |
621 | // safe C++ hooks if they want the hook to be correctly applied in |
622 | // multithreading environment. |
623 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
624 | std::mutex mutex_; |
625 | |
626 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
627 | edge_list next_edges_; |
628 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
629 | PyObject* pyobj_ = nullptr; // weak reference |
630 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
631 | std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr; |
632 | |
633 | // NOTE [Hooks ordering] |
634 | // We have 3 separate fields for pre hooks registered to the autograd nodes |
635 | // because the conditions under which they execute are different, and we |
636 | // want more fine-grained control over the order in which different types |
637 | // of hooks are executed. |
638 | // - pre_hooks are only executed when the node itself is executed |
639 | // - tensor_pre_hook is executed as long as the engine traverses over it |
640 | // even if that node won't be executed. |
641 | // - retains_grad_hook are like tensor_pre_hooks except they are always |
642 | // ordered after all other tensor pre hooks |
643 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
644 | std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_; |
645 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
646 | std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_; |
647 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
648 | std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_; |
649 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
650 | std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; |
651 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
652 | at::SmallVector<InputMetadata, 2> input_metadata_; |
653 | }; |
654 | |
/// A `Node` that always reports itself as traceable.
/// See Node::is_traceable() for definition.
struct TraceableFunction : public Node {
  // Inherit Node's constructors unchanged.
  using Node::Node;
  // `final`: subclasses cannot turn traceability back off.
  bool is_traceable() final {
    return true;
  }
};
662 | |
663 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
664 | // Associated Free Nodes |
665 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
666 | |
667 | namespace detail { |
668 | // Implementation of `collect_next_edges` (see below). |
669 | struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> { |
670 | edge_list next_edges; |
671 | using IterArgs<MakeNextFunctionList>::operator(); |
672 | void operator()(const Variable& variable) { |
673 | // NOLINTNEXTLINE(bugprone-branch-clone) |
674 | if (variable.defined()) { |
675 | next_edges.emplace_back(impl::gradient_edge(variable)); |
676 | } else { |
677 | next_edges.emplace_back(); |
678 | } |
679 | } |
680 | void operator()(const Variable* variable) { |
681 | // NOLINTNEXTLINE(bugprone-branch-clone) |
682 | if (variable->defined()) { |
683 | next_edges.emplace_back(impl::gradient_edge(*variable)); |
684 | } else { |
685 | next_edges.emplace_back(); |
686 | } |
687 | } |
688 | void operator()(const c10::optional<Variable>& variable) { |
689 | // NOLINTNEXTLINE(bugprone-branch-clone) |
690 | if (variable.has_value() && variable->defined()) { |
691 | next_edges.emplace_back(impl::gradient_edge(*variable)); |
692 | } else { |
693 | next_edges.emplace_back(); |
694 | } |
695 | } |
696 | }; |
697 | } // namespace detail |
698 | |
699 | /// Create an `Edge` between the given `variable` and the `function`, which is |
700 | /// assumed to be the gradient function of this variable (i.e. the function |
701 | /// through which this variable is backpropagated during the backward pass). |
702 | /// This sets the `grad_fn` property of the `variable`. This function assumes |
703 | /// that the `Variable` is a new input to the gradient function and its |
704 | /// `input_nr` thus equal to `function->num_inputs()`. Additionally, it |
705 | /// increments the `Node`'s number of inputs by one. Approximately |
706 | /// equivalent to `variable.set_gradient_edge(function, |
707 | /// function->add_input_metadata(variable.dispatch_type(), variable.sizes()))`. |
708 | /// If you don't want the `Node`'s `num_inputs` to be incremented, use |
709 | /// `set_gradient_edge` directly. |
710 | inline void create_gradient_edge( |
711 | Variable& variable, |
712 | std::shared_ptr<Node> function) { |
713 | // Copy before move. |
714 | const auto input_nr = function->add_input_metadata(variable); |
715 | impl::set_gradient_edge(variable, {std::move(function), input_nr}); |
716 | } |
717 | |
718 | /// Return true if any of the variables in the list require a gradient. |
719 | inline bool any_variable_requires_grad(const variable_list& variables) { |
720 | return std::any_of( |
721 | variables.begin(), variables.end(), [](const Variable& variable) { |
722 | return variable.defined() && variable.requires_grad(); |
723 | }); |
724 | } |
725 | |
726 | /// Return the next edges of all the given variables, or tuples of variables. |
727 | template <typename... Variables> |
728 | edge_list collect_next_edges(Variables&&... variables) { |
729 | detail::MakeNextFunctionList make; |
730 | make.apply(std::forward<Variables>(variables)...); |
731 | return std::move(make.next_edges); |
732 | } |
733 | } // namespace autograd |
734 | } // namespace torch |
735 | |
736 | C10_CLANG_DIAGNOSTIC_POP() |
737 | |