1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/autograd/context/container.h> |
4 | #include <torch/csrc/distributed/autograd/engine/dist_engine.h> |
5 | |
6 | namespace torch { |
7 | namespace distributed { |
8 | namespace autograd { |
9 | |
10 | using torch::autograd::variable_list; |
11 | |
12 | /// C++ API of Distributed Autograd that kicks off the distributed backward pass |
13 | /// using the provided roots. This currently implements the |
/// :ref:`fast-mode-algorithm`, which assumes that all RPC messages sent in the
/// same distributed autograd context across workers will be part of the
/// autograd graph during the backward pass.
17 | /// |
/// We use the provided roots to discover the autograd graph and compute
/// appropriate dependencies. This function blocks until the entire
/// distributed autograd computation is done.
/// It accumulates gradients in the leaves - you might need to zero them
/// before calling it.
23 | /// |
/// \param context_id The autograd context id for which the gradients should be
///                   computed.
26 | /// \param roots Tensors which represent the roots of the autograd computation. |
27 | /// All the tensors should be scalars. |
28 | /// \param retain_graph If `false`, the graph used to compute the grad will be |
29 | /// freed. Note that in nearly all cases setting this |
30 | /// option to `true` is not needed and often can be worked |
31 | /// around in a much more efficient way. Usually, you need |
32 | /// to set this to `true` to run backward multiple times. |
33 | TORCH_API void backward( |
34 | int64_t context_id, |
35 | const variable_list& roots, |
36 | bool retain_graph = false); |
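
// A minimal usage sketch of the API above, kept as a comment so it does not
// affect this header. It assumes the process has already initialized an RPC
// agent (which sets up the autograd container), that `DistAutogradContainer`
// from container.h exposes `getInstance()`, `newContext()` and
// `releaseContext()`, and that `loss` is a scalar tensor produced by a
// forward pass issued inside the open context:
//
//   using namespace torch::distributed::autograd;
//
//   // Open a fresh distributed autograd context for this iteration.
//   auto& container = DistAutogradContainer::getInstance();
//   auto context = container.newContext();
//   const int64_t context_id = context->contextId();
//
//   // ... forward pass over RPC producing a scalar `loss` ...
//
//   // Kick off the distributed backward pass from `loss`; this call blocks
//   // until all workers participating in this context have finished.
//   backward(context_id, {loss});
//
//   // Release the context once the gradients for this pass have been
//   // consumed.
//   container.releaseContext(context_id);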
37 | |
38 | } // namespace autograd |
39 | } // namespace distributed |
40 | } // namespace torch |
41 | |