#pragma once

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::variable_list;

/// C++ API of Distributed Autograd that kicks off the distributed backward
/// pass using the provided roots. This currently implements the
/// :ref:`fast-mode-algorithm` which assumes all RPC messages sent in the same
/// distributed autograd context across workers would be part of the autograd
/// graph during the backward pass.
///
/// We use the provided roots to discover the autograd graph and compute
/// appropriate dependencies. This method blocks until the entire
/// autograd computation is done.
/// This function accumulates gradients in the leaves - you might need to zero
/// them before calling it.
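///
/// A minimal usage sketch (not part of this header): it assumes RPC and the
/// DistAutogradContainer have already been initialized for this worker, and
/// that the caller has done `using namespace torch::distributed::autograd;`.
/// newContext(), contextId() and releaseContext() come from the included
/// container.h.
///
/// \code
///   auto& container = DistAutogradContainer::getInstance();
///   auto context = container.newContext();
///   // Hypothetical local forward pass producing a scalar root.
///   auto loss = torch::randn({2, 2}, torch::requires_grad()).sum();
///   // Kick off the distributed backward pass; gradients are accumulated
///   // in the autograd context identified by context->contextId().
///   backward(context->contextId(), {loss});
///   container.releaseContext(context->contextId());
/// \endcode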
23///
24/// \param context_id The autograd context id for which we should retrieve the
25/// gradients.
26/// \param roots Tensors which represent the roots of the autograd computation.
27/// All the tensors should be scalars.
28/// \param retain_graph If `false`, the graph used to compute the grad will be
29/// freed. Note that in nearly all cases setting this
30/// option to `true` is not needed and often can be worked
31/// around in a much more efficient way. Usually, you need
32/// to set this to `true` to run backward multiple times.
33TORCH_API void backward(
34 int64_t context_id,
35 const variable_list& roots,
36 bool retain_graph = false);
37
38} // namespace autograd
39} // namespace distributed
40} // namespace torch
41