1 | #include <torch/csrc/autograd/python_cpp_function.h> |
2 | #include <torch/csrc/distributed/autograd/autograd.h> |
3 | #include <torch/csrc/jit/python/pybind_utils.h> |
4 | #include <torch/csrc/python_headers.h> |
5 | #include <torch/csrc/utils/object_ptr.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | #include <torch/types.h> |
8 | |
9 | namespace torch { |
10 | namespace distributed { |
11 | namespace autograd { |
12 | |
13 | namespace { |
14 | |
15 | template <typename T> |
16 | using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>; |
17 | |
18 | PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) { |
19 | auto autograd_module = |
20 | THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd" )); |
21 | if (!autograd_module) { |
22 | throw python_error(); |
23 | } |
24 | |
25 | auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C" )); |
26 | if (!torch_C_module) { |
27 | throw python_error(); |
28 | } |
29 | |
30 | auto torch_C_m = py::handle(torch_C_module).cast<py::module>(); |
31 | auto m = torch_C_m.def_submodule( |
32 | "_distributed_autograd" , "distributed autograd bindings" ); |
33 | |
34 | auto module = py::handle(m).cast<py::module>(); |
35 | |
36 | auto distAutogradContext = |
37 | shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext" ) |
38 | .def( |
39 | "_context_id" , |
40 | &DistAutogradContext::contextId, |
41 | py::call_guard<py::gil_scoped_release>()) |
42 | .def( |
43 | "_recv_functions" , |
44 | [](const DistAutogradContext& ctx) { |
45 | std::map<int64_t, py::object> funcs; |
46 | auto recvFunctions = ctx.recvFunctions(); |
47 | |
48 | // Acquire GIL only when necessary to avoid deadlocks. |
49 | pybind11::gil_scoped_acquire ag; |
50 | for (const auto& map_entry : recvFunctions) { |
51 | funcs.emplace( |
52 | map_entry.first, |
53 | py::reinterpret_steal<py::object>( |
54 | torch::autograd::functionToPyObject( |
55 | map_entry.second))); |
56 | } |
57 | return funcs; |
58 | }, |
59 | py::call_guard<py::gil_scoped_release>()) |
60 | .def( |
61 | "_send_functions" , |
62 | [](const ContextPtr& ctx) { |
63 | std::map<int64_t, py::object> funcs; |
64 | auto sendFunctions = ctx->sendFunctions(); |
65 | |
66 | // Acquire GIL only when necessary to avoid deadlocks. |
67 | pybind11::gil_scoped_acquire ag; |
68 | for (const auto& map_entry : sendFunctions) { |
69 | funcs.emplace( |
70 | map_entry.first, |
71 | py::reinterpret_steal<py::object>( |
72 | torch::autograd::functionToPyObject( |
73 | map_entry.second))); |
74 | } |
75 | return funcs; |
76 | }, |
77 | py::call_guard<py::gil_scoped_release>()) |
78 | .def( |
79 | "_known_worker_ids" , |
80 | &DistAutogradContext::getKnownWorkerIds, |
81 | py::call_guard<py::gil_scoped_release>()); |
82 | |
83 | module.def( |
84 | "_new_context" , |
85 | []() -> const ContextPtr { |
86 | return DistAutogradContainer::getInstance().newContext(); |
87 | }, |
88 | py::return_value_policy::reference, |
89 | py::call_guard<py::gil_scoped_release>()); |
90 | |
91 | module.def( |
92 | "_release_context" , |
93 | [](int64_t context_id) { |
94 | return DistAutogradContainer::getInstance().releaseContext(context_id); |
95 | }, |
96 | py::call_guard<py::gil_scoped_release>()); |
97 | |
98 | module.def( |
99 | "_get_max_id" , |
100 | []() { return DistAutogradContainer::getInstance().getMaxId(); }, |
101 | py::call_guard<py::gil_scoped_release>()); |
102 | |
103 | module.def( |
104 | "_is_valid_context" , |
105 | [](int64_t worker_id) { |
106 | DistAutogradContainer::getInstance().isValidContext(worker_id); |
107 | }, |
108 | py::call_guard<py::gil_scoped_release>()); |
109 | |
110 | module.def( |
111 | "_retrieve_context" , |
112 | [](int64_t context_id) -> const ContextPtr { |
113 | return DistAutogradContainer::getInstance().retrieveContext(context_id); |
114 | }, |
115 | py::return_value_policy::reference, |
116 | py::call_guard<py::gil_scoped_release>()); |
117 | |
118 | module.def( |
119 | "_current_context" , |
120 | []() -> const ContextPtr { |
121 | return DistAutogradContainer::getInstance().currentContext(); |
122 | }, |
123 | py::return_value_policy::reference, |
124 | py::call_guard<py::gil_scoped_release>()); |
125 | |
126 | module.def( |
127 | "_init" , |
128 | [](int64_t worker_id) { DistAutogradContainer::init(worker_id); }, |
129 | py::call_guard<py::gil_scoped_release>()); |
130 | |
131 | module.def( |
132 | "_get_debug_info" , |
133 | []() { return DistEngine::getInstance().getDebugInfo(); }, |
134 | py::call_guard<py::gil_scoped_release>()); |
135 | |
136 | py::options options; |
137 | options.disable_function_signatures(); |
138 | |
139 | module.def( |
140 | "backward" , |
141 | backward, |
142 | R"( |
143 | backward(context_id: int, roots: List[Tensor], retain_graph = False) -> None |
144 | |
145 | Kicks off the distributed backward pass using the provided roots. This |
146 | currently implements the :ref:`fast-mode-algorithm` which |
147 | assumes all RPC messages sent in the same distributed autograd context |
148 | across workers would be part of the autograd graph during the backward pass. |
149 | |
150 | We use the provided roots to discover the autograd graph and compute |
151 | appropriate dependencies. This method blocks until the entire |
152 | autograd computation is done. |
153 | |
154 | We accumulate the gradients in the appropriate |
155 | :class:`torch.distributed.autograd.context` on each of the nodes. The autograd |
156 | context to be used is looked up given the ``context_id`` that is passed in when |
157 | :meth:`torch.distributed.autograd.backward` is called. If there is no valid |
158 | autograd context corresponding to the given ID, we throw an error. You can |
159 | retrieve the accumulated gradients using the |
160 | :meth:`~torch.distributed.autograd.get_gradients` API. |
161 | |
162 | Arguments: |
163 | context_id (int): The autograd context id for which we should retrieve the gradients. |
164 | roots (list): Tensors which represent the roots of the autograd |
165 | computation. All the tensors should be scalars. |
166 | retain_graph(bool, optional): If False, the graph used to compute the grad |
167 | will be freed. Note that in nearly all cases setting this |
168 | option to True is not needed and often can be worked around |
169 | in a much more efficient way. Usually, you need to set this |
170 | to True to run backward multiple times. |
171 | |
172 | Example:: |
173 | >>> import torch.distributed.autograd as dist_autograd |
174 | >>> with dist_autograd.context() as context_id: |
175 | >>> pred = model.forward() |
176 | >>> loss = loss_func(pred, loss) |
177 | >>> dist_autograd.backward(context_id, loss) |
178 | )" , |
179 | py::arg("contextId" ), |
180 | py::arg("roots" ), |
181 | py::arg("retain_graph" ) = false, |
182 | py::call_guard<py::gil_scoped_release>()); |
183 | |
184 | module.def( |
185 | "get_gradients" , |
186 | [](int64_t contextId) -> py::dict { |
187 | const auto& autogradContext = |
188 | DistAutogradContainer::getInstance().retrieveContext(contextId); |
189 | auto ival = IValue(autogradContext->getGradients()); |
190 | |
191 | // Acquire GIL only for pyobject conversion. |
192 | pybind11::gil_scoped_acquire ag; |
193 | return torch::jit::toPyObject(ival); |
194 | }, |
195 | R"( |
196 | get_gradients(context_id: int) -> Dict[Tensor, Tensor] |
197 | |
198 | Retrieves a map from Tensor to the appropriate gradient for that Tensor |
199 | accumulated in the provided context corresponding to the given ``context_id`` |
200 | as part of the distributed autograd backward pass. |
201 | |
202 | Arguments: |
203 | context_id(int): The autograd context id for which we should retrieve the |
204 | gradients. |
205 | |
206 | Returns: |
207 | A map where the key is the Tensor and the value is the associated gradient |
208 | for that Tensor. |
209 | |
210 | Example:: |
211 | >>> import torch.distributed.autograd as dist_autograd |
212 | >>> with dist_autograd.context() as context_id: |
213 | >>> t1 = torch.rand((3, 3), requires_grad=True) |
214 | >>> t2 = torch.rand((3, 3), requires_grad=True) |
215 | >>> loss = t1 + t2 |
216 | >>> dist_autograd.backward(context_id, [loss.sum()]) |
217 | >>> grads = dist_autograd.get_gradients(context_id) |
218 | >>> print(grads[t1]) |
219 | >>> print(grads[t2]) |
220 | )" , |
221 | py::arg("context_id" ), |
222 | py::call_guard<py::gil_scoped_release>()); |
223 | |
224 | Py_RETURN_TRUE; |
225 | } |
226 | } // namespace |
227 | |
228 | static PyMethodDef methods[] = { // NOLINT |
229 | {"_dist_autograd_init" , dist_autograd_init, METH_NOARGS, nullptr}, |
230 | {nullptr, nullptr, 0, nullptr}}; |
231 | |
232 | PyMethodDef* python_functions() { |
233 | return methods; |
234 | } |
235 | |
236 | } // namespace autograd |
237 | } // namespace distributed |
238 | } // namespace torch |
239 | |