#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

namespace torch {
namespace distributed {
namespace autograd {

namespace {

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

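// Module initialization routine, exposed to Python as
// torch._C._dist_autograd_init(). It creates the
// torch._C._distributed_autograd submodule and registers the
// DistAutogradContext bindings and module-level helpers below.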
PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
  auto autograd_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
  if (!autograd_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = torch_C_m.def_submodule(
      "_distributed_autograd", "distributed autograd bindings");

  auto module = py::handle(m).cast<py::module>();

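  // Bind DistAutogradContext, exposing its context id, the send/recv autograd
  // functions recorded for the distributed backward pass, and the worker ids
  // known to this context. Each binding releases the GIL while calling into
  // C++ and reacquires it only where Python objects are created.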
  auto distAutogradContext =
      shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
          .def(
              "_context_id",
              &DistAutogradContext::contextId,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_recv_functions",
              [](const DistAutogradContext& ctx) {
                std::map<int64_t, py::object> funcs;
                auto recvFunctions = ctx.recvFunctions();

                // Acquire GIL only when necessary to avoid deadlocks.
                pybind11::gil_scoped_acquire ag;
                for (const auto& map_entry : recvFunctions) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_send_functions",
              [](const ContextPtr& ctx) {
                std::map<int64_t, py::object> funcs;
                auto sendFunctions = ctx->sendFunctions();

                // Acquire GIL only when necessary to avoid deadlocks.
                pybind11::gil_scoped_acquire ag;
                for (const auto& map_entry : sendFunctions) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_known_worker_ids",
              &DistAutogradContext::getKnownWorkerIds,
              py::call_guard<py::gil_scoped_release>());

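  // The module-level helpers below manage autograd contexts through the
  // process-wide DistAutogradContainer singleton.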
  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_release_context",
      [](int64_t context_id) {
        return DistAutogradContainer::getInstance().releaseContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());

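  // Internal helper exposing the highest context id the container has seen so
  // far.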
  module.def(
      "_get_max_id",
      []() { return DistAutogradContainer::getInstance().getMaxId(); },
      py::call_guard<py::gil_scoped_release>());

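  // Validates that an autograd context exists for the given context id; the
  // container surfaces an error for unknown ids instead of returning a value.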
  module.def(
      "_is_valid_context",
      [](int64_t context_id) {
        DistAutogradContainer::getInstance().isValidContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());

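  // Look up an existing autograd context by id and hand the shared context
  // object back to Python (rather than just its id).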
  module.def(
      "_retrieve_context",
      [](int64_t context_id) -> const ContextPtr {
        return DistAutogradContainer::getInstance().retrieveContext(context_id);
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

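  // Return the autograd context associated with the calling thread; the
  // container reports an error if no context is currently active.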
  module.def(
      "_current_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().currentContext();
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

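  // Initialize the DistAutogradContainer with this process's worker id.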
  module.def(
      "_init",
      [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
      py::call_guard<py::gil_scoped_release>());

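  // Debugging/metrics information collected by the distributed autograd
  // engine.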
  module.def(
      "_get_debug_info",
      []() { return DistEngine::getInstance().getDebugInfo(); },
      py::call_guard<py::gil_scoped_release>());

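  // The docstrings below embed their own signature line, so suppress
  // pybind11's auto-generated function signatures.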
  py::options options;
  options.disable_function_signatures();

  module.def(
      "backward",
      backward,
      R"(
backward(context_id: int, roots: List[Tensor], retain_graph: bool = False) -> None

Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which
assumes all RPC messages sent in the same distributed autograd context
across workers would be part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire
autograd computation is done.

We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The autograd
context to be used is looked up given the ``context_id`` that is passed in when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context corresponding to the given ID, we throw an error. You can
retrieve the accumulated gradients using the
:meth:`~torch.distributed.autograd.get_gradients` API.

Arguments:
    context_id (int): The autograd context id for which we should compute the
        gradients.
    roots (list): Tensors which represent the roots of the autograd
        computation. All the tensors should be scalars.
    retain_graph (bool, optional): If False, the graph used to compute the grad
        will be freed. Note that in nearly all cases setting this
        option to True is not needed and often can be worked around
        in a much more efficient way. Usually, you need to set this
        to True to run backward multiple times.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     pred = model.forward()
    >>>     loss = loss_func(pred, target)
    >>>     dist_autograd.backward(context_id, [loss])
)",
      py::arg("context_id"),
      py::arg("roots"),
      py::arg("retain_graph") = false,
      py::call_guard<py::gil_scoped_release>());

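  // Converts the accumulated gradients of the given context into a Python
  // dict; the GIL is reacquired only for the IValue -> PyObject conversion.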
  module.def(
      "get_gradients",
      [](int64_t contextId) -> py::dict {
        const auto& autogradContext =
            DistAutogradContainer::getInstance().retrieveContext(contextId);
        auto ival = IValue(autogradContext->getGradients());

        // Acquire GIL only for pyobject conversion.
        pybind11::gil_scoped_acquire ag;
        return torch::jit::toPyObject(ival);
      },
      R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the provided context corresponding to the given ``context_id``
as part of the distributed autograd backward pass.

Arguments:
    context_id (int): The autograd context id for which we should retrieve the
        gradients.

Returns:
    A map where the key is the Tensor and the value is the associated gradient
    for that Tensor.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     t1 = torch.rand((3, 3), requires_grad=True)
    >>>     t2 = torch.rand((3, 3), requires_grad=True)
    >>>     loss = t1 + t2
    >>>     dist_autograd.backward(context_id, [loss.sum()])
    >>>     grads = dist_autograd.get_gradients(context_id)
    >>>     print(grads[t1])
    >>>     print(grads[t2])
)",
      py::arg("context_id"),
      py::call_guard<py::gil_scoped_release>());

  Py_RETURN_TRUE;
}
} // namespace

static PyMethodDef methods[] = { // NOLINT
    {"_dist_autograd_init", dist_autograd_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace autograd
} // namespace distributed
} // namespace torch