1 | #include <torch/csrc/python_headers.h> |
2 | |
3 | #include <c10/util/intrusive_ptr.h> |
4 | #include <c10/util/string_view.h> |
5 | #include <torch/csrc/distributed/c10d/FileStore.hpp> |
6 | #include <torch/csrc/distributed/c10d/TCPStore.hpp> |
7 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
8 | #ifndef _WIN32 |
9 | #include <torch/csrc/distributed/c10d/HashStore.hpp> |
10 | #include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp> |
11 | #endif |
12 | #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
13 | #include <torch/csrc/distributed/c10d/PyProcessGroup.hpp> |
14 | |
15 | #ifdef USE_C10D_GLOO |
16 | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
17 | #include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp> |
18 | #endif |
19 | |
20 | #ifdef USE_C10D_NCCL |
21 | #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> |
22 | #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> |
23 | #endif |
24 | |
25 | #ifdef USE_C10D_MPI |
26 | #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp> |
27 | #endif |
28 | |
29 | #ifdef USE_C10D_UCC |
30 | #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp> |
31 | #endif |
32 | |
33 | #include <fmt/format.h> |
34 | #include <pybind11/chrono.h> |
35 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
36 | |
37 | #include <torch/csrc/distributed/c10d/comm.hpp> |
38 | #include <torch/csrc/distributed/c10d/debug.h> |
39 | #include <torch/csrc/distributed/c10d/logger.hpp> |
40 | #include <torch/csrc/distributed/c10d/reducer.hpp> |
41 | |
42 | #include <torch/csrc/Exceptions.h> |
43 | #include <torch/csrc/distributed/c10d/python_comm_hook.h> |
44 | #include <torch/csrc/jit/python/pybind_utils.h> |
45 | #include <torch/csrc/utils/object_ptr.h> |
46 | #include <torch/csrc/utils/pybind.h> |
47 | |
48 | #include <torch/custom_class.h> |
49 | |
50 | namespace { |
51 | |
52 | // Wrapper to ensure GIL is released before destructing ProcessGroupGloo |
53 | // TODO: move this somewhere more generally useful |
54 | template <typename T> |
55 | class IntrusivePtrNoGilDestructor { |
56 | c10::intrusive_ptr<T> impl_; |
57 | |
58 | public: |
59 | IntrusivePtrNoGilDestructor() = default; |
60 | IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default; |
61 | IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) = default; |
62 | IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) = |
63 | default; |
64 | IntrusivePtrNoGilDestructor& operator=(IntrusivePtrNoGilDestructor&&) = |
65 | default; |
66 | /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr<T> impl) |
67 | : impl_(std::move(impl)) {} |
68 | // This ctor is very important; see |
69 | // https://github.com/pybind/pybind11/issues/2957 |
70 | explicit IntrusivePtrNoGilDestructor(T* impl) |
71 | : impl_(c10::intrusive_ptr<T>::unsafe_steal_from_new(impl)) {} |
72 | ~IntrusivePtrNoGilDestructor() { |
73 | if (impl_) { |
74 | if (PyGILState_Check()) { |
75 | pybind11::gil_scoped_release release; |
76 | impl_.reset(); |
77 | } else { |
78 | impl_.reset(); |
79 | } |
80 | } |
81 | } |
82 | T& operator*() const noexcept { |
83 | return *impl_; |
84 | } |
85 | T* operator->() const noexcept { |
86 | return impl_.get(); |
87 | } |
88 | C10_NODISCARD T* get() const noexcept { |
89 | return impl_.get(); |
90 | } |
91 | void reset() noexcept { |
92 | impl_.reset(); |
93 | } |
94 | operator bool() const noexcept { |
95 | return impl_; |
96 | } |
97 | }; |
98 | |
99 | } // anonymous namespace |
100 | |
// Register IntrusivePtrNoGilDestructor<T> as a pybind11 holder type so the
// bindings below can use py::class_<T, IntrusivePtrNoGilDestructor<T>>.
// The trailing `true` marks the holder as constructible from a raw pointer.
PYBIND11_DECLARE_HOLDER_TYPE(T, IntrusivePtrNoGilDestructor<T>, true);
102 | |
103 | namespace torch { |
104 | namespace distributed { |
105 | namespace c10d { |
106 | |
107 | namespace { |
108 | |
// Binds T using std::shared_ptr<T> as the pybind11 holder type.
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

// fmt-style message template for deprecated c10d APIs; "{}" is replaced with
// the name of the deprecated API at the warning site.
constexpr auto kDeprecationWarning =
    "{} API is being deprecated, please ping "
    "https://github.com/pytorch/pytorch/issues/46291 "
    "if you see this warning" ;

// Binds T using c10::intrusive_ptr<T> as the pybind11 holder type.
template <typename T>
using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>;

// Binds T using the GIL-releasing holder declared above, so the last Python
// reference never destroys T while the GIL is held.
template <typename T>
using intrusive_ptr_no_gil_destructor_class_ =
    py::class_<T, IntrusivePtrNoGilDestructor<T>>;
122 | |
// PythonStore is a pybind11 trampoline class to allow a Python
// class to inherit from c10d.Store and implement its interface.
// Each override dispatches to the Python subclass; set/get/compare_set are
// dispatched manually so binary values can cross the boundary as py::bytes
// rather than std::vector<uint8_t>.
class PythonStore : public ::c10d::Store {
 public:
  using ::c10d::Store::Store;

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void set(const std::string& key, const std::vector<uint8_t>& value) override {
    // Calling back into Python requires holding the GIL.
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set" );
    // `set` is pure virtual on the Python side; a missing override is a bug.
    TORCH_INTERNAL_ASSERT(fn);
    // Call function with a py::bytes object for the value.
    fn(key,
       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> get(const std::string& key) override {
    // Calling back into Python requires holding the GIL.
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "get" );
    TORCH_INTERNAL_ASSERT(fn);
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str = pybind11::cast<py::bytes>(fn(key));
    return std::vector<uint8_t>(str.begin(), str.end());
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override {
    // Calling back into Python requires holding the GIL.
    pybind11::gil_scoped_acquire gil;
    // The Python-side method uses the snake_case name "compare_set".
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "compare_set" );
    TORCH_INTERNAL_ASSERT(fn);
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str =
        pybind11::cast<py::bytes>(fn(key, expectedValue, desiredValue));
    return std::vector<uint8_t>(str.begin(), str.end());
  }

  // The remaining overrides use plain stdlib types in both directions, so
  // the standard pybind11 pure-virtual dispatch macro suffices.
  int64_t add(const std::string& key, int64_t value) override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value);
  }

  int64_t getNumKeys() override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys);
  }

  bool deleteKey(const std::string& key) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key);
  }

  bool check(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys);
  }

  void wait(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys);
  }

  // Overload of wait with an explicit timeout.
  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout);
  }
};
207 | |
208 | // Called from DDP's Python API to create a c10d Python comm hook object. |
209 | // The input state and callable comm_hook are Python objects. It later calls |
210 | // register_comm_hook function of the reducer input to register the hook. |
211 | void _register_comm_hook( |
212 | ::c10d::Reducer& reducer, |
213 | py::object state, |
214 | py::object comm_hook) { |
215 | reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>( |
216 | std::move(state), std::move(comm_hook))); |
217 | } |
218 | |
// Called from DDP's Python API to create a c10d C++ comm hook.
// The input is an enum hook type. It later calls register_builtin_comm_hook
// function of the reducer input to set the hook type.
// See BuiltinCommHookType bindings below (ALLREDUCE, FP16_COMPRESS).
void _register_builtin_comm_hook(
    ::c10d::Reducer& reducer,
    ::c10d::BuiltinCommHookType comm_hook_type) {
  reducer.register_builtin_comm_hook(comm_hook_type);
}
227 | |
228 | // Customize the metaclass of ::c10d::ReduceOp for the backward compatibility. |
229 | // https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to |
230 | // struct from enum, sacrificing some of the Python built-in function supports |
231 | // such as `isinstance` (see https://github.com/pytorch/pytorch/issues/87191) |
232 | // and `copy` (see |
233 | // https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700). Below, |
234 | // we define a custom `isinstance` in CPython/pybind11 |
235 | // (`reduceopmeta___instancecheck__`) and modify the default metaclass of |
236 | // pybind11 (`GetReduceOpMetaclass`) so that |
237 | // `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)` |
238 | // returns :obj:`True` as if `ReduceOp` is enum. |
239 | // Ref: |
240 | // - https://docs.python.org/3/extending/newtypes_tutorial.html |
241 | // - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods |
242 | // - https://github.com/pybind/pybind11/issues/2696 |
243 | static PyObject* reduceopmeta___instancecheck__( |
244 | PyObject* self, |
245 | PyObject* args) { |
246 | if (Py_TYPE(self) == Py_TYPE(args)) { |
247 | Py_RETURN_TRUE; |
248 | } |
249 | if (c10::string_view(args->ob_type->tp_name).find("RedOpType" ) != |
250 | c10::string_view::npos) { |
251 | Py_RETURN_TRUE; |
252 | } |
253 | Py_RETURN_FALSE; |
254 | } |
// Method table installed on the ReduceOp metaclass via Py_tp_methods in
// GetReduceOpMetaclass(); only overrides `__instancecheck__`. The final
// {nullptr, nullptr} entry is the required sentinel terminator.
static PyMethodDef reduceopmeta_methods[] = {
    {"__instancecheck__" ,
     (PyCFunction)reduceopmeta___instancecheck__,
     METH_O,
     "Custom `__instancecheck__` for ReduceOp" },
    {nullptr, nullptr}};
261 | PyTypeObject* GetReduceOpMetaclass() { |
262 | static auto* metaclass = [] { |
263 | PyTypeObject* base_metaclass = |
264 | pybind11::detail::get_internals().default_metaclass; |
265 | PyType_Slot slots[] = { |
266 | {Py_tp_base, base_metaclass}, |
267 | {Py_tp_methods, reduceopmeta_methods}, |
268 | {0}, |
269 | }; |
270 | PyType_Spec spec = {}; |
271 | spec.name = "torch._C._distributed_c10d._ReduceOpMeta" ; |
272 | spec.basicsize = base_metaclass->tp_basicsize; |
273 | spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; |
274 | spec.slots = slots; |
275 | PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec); |
276 | if (!metaclass) |
277 | throw py::error_already_set(); |
278 | return metaclass; |
279 | }(); |
280 | return metaclass; |
281 | } |
282 | |
283 | PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { |
284 | C10_LOG_API_USAGE_ONCE("c10d.python.import" ); |
285 | |
286 | auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed" )); |
287 | if (!c10d_module) { |
288 | throw python_error(); |
289 | } |
290 | |
291 | auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C" )); |
292 | if (!torch_C_module) { |
293 | throw python_error(); |
294 | } |
295 | |
296 | auto torch_C_m = py::handle(torch_C_module).cast<py::module>(); |
297 | auto m = |
298 | torch_C_m.def_submodule("_distributed_c10d" , "distributed c10d bindings" ); |
299 | |
300 | auto module = py::handle(m).cast<py::module>(); |
301 | |
302 | module |
303 | .def( |
304 | "_register_comm_hook" , |
305 | &_register_comm_hook, |
306 | py::arg("reducer" ), |
307 | py::arg("state" ), |
308 | py::arg("comm_hook" ), |
309 | py::call_guard<py::gil_scoped_release>()) |
310 | .def( |
311 | "_register_builtin_comm_hook" , |
312 | &_register_builtin_comm_hook, |
313 | py::arg("reducer" ), |
314 | py::arg("comm_hook_type" )); |
315 | |
316 | shared_ptr_class_<::c10d::GradBucket>( |
317 | module, |
318 | "GradBucket" , |
319 | R"( |
320 | This class mainly passes a flattened gradient tensor |
321 | (returned by :meth:`~torch.distributed.GradBucket.buffer`) |
322 | to DDP communication hook. |
323 | This tensor can be further decomposed into a list of per-parameter tensors within this bucket |
324 | (returned by :meth:`~torch.distributed.GradBucket.get_per_parameter_tensors`) |
325 | to apply layer-wise operations. |
326 | )" ) |
327 | .def( |
328 | "index" , |
329 | &::c10d::GradBucket::getIndex, |
330 | py::call_guard<py::gil_scoped_release>(), |
331 | R"( |
332 | .. warning:: |
333 | Since the buckets are rebuilt after the first iteration, should not rely on the indices at the beginning of training. |
334 | |
335 | Returns: |
336 | The index of a bucket that stores gradients of a few contiguous layers. |
337 | All the gradients are bucketized. |
338 | )" ) |
339 | .def( |
340 | "buffer" , |
341 | &::c10d::GradBucket::getBuffer, |
342 | py::call_guard<py::gil_scoped_release>(), |
343 | R"( |
344 | Returns: |
345 | A flattened 1D ``torch.Tensor`` buffer, |
346 | which can be further decomposed into a list of per-parameter tensors within this bucket. |
347 | )" ) |
348 | .def( |
349 | "gradients" , |
350 | &::c10d::GradBucket::getGradients, |
351 | py::call_guard<py::gil_scoped_release>(), |
352 | R"( |
353 | Returns: |
354 | A list of ``torch.Tensor``. Each tensor in the list corresponds to a gradient. |
355 | )" ) |
356 | .def( |
357 | "parameters" , |
358 | &::c10d::GradBucket::getParameters, |
359 | py::call_guard<py::gil_scoped_release>(), |
360 | R"( |
361 | Returns: |
362 | A list of ``torch.Tensor``. Each tensor in the list corresponds to a model |
363 | parameter. |
364 | )" ) |
365 | .def( |
366 | "is_last" , |
367 | &::c10d::GradBucket::isLast, |
368 | py::call_guard<py::gil_scoped_release>(), |
369 | R"( |
370 | Returns: |
371 | Whether this bucket is the last bucket to allreduce in an iteration. |
372 | This also means that this bucket corresponds to the first few layers in the forward pass. |
373 | )" ) |
374 | .def( |
375 | "set_buffer" , |
376 | &::c10d::GradBucket::setBuffer, |
377 | py::arg("buffer" ), |
378 | py::call_guard<py::gil_scoped_release>(), |
379 | R"( |
380 | Replaces the tensor in the bucket with the input tensor buffer. |
381 | )" ); |
382 | |
383 | py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType" , R"( |
384 | An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)" ) |
385 | .value("ALLREDUCE" , ::c10d::BuiltinCommHookType::ALLREDUCE) |
386 | .value("FP16_COMPRESS" , ::c10d::BuiltinCommHookType::FP16_COMPRESS); |
387 | |
388 | shared_ptr_class_<::c10d::Reducer>(module, "Reducer" ) |
389 | .def( |
390 | py::init< |
391 | std::vector<at::Tensor>, |
392 | std::vector<std::vector<size_t>>, |
393 | std::vector<size_t>, |
394 | c10::intrusive_ptr<::c10d::ProcessGroup>, |
395 | std::vector<bool>, |
396 | int64_t, |
397 | bool, |
398 | bool, |
399 | std::unordered_map<size_t, std::string>, |
400 | int64_t>(), |
401 | py::arg("params" ), |
402 | py::arg("bucket_indices" ), |
403 | py::arg("per_bucket_size_limits" ), |
404 | py::arg("process_group" ), |
405 | py::arg("expect_sparse_gradients" ) = std::vector<bool>(), |
406 | py::arg("bucket_bytes_cap" ) = ::c10d::kDefaultBucketBytesCap, |
407 | py::arg("find_unused_parameters" ) = false, |
408 | py::arg("gradient_as_bucket_view" ) = false, |
409 | py::arg("param_to_name_mapping" ) = |
410 | std::unordered_map<size_t, std::string>(), |
411 | py::arg("first_bucket_bytes_cap" ) = ::c10d::kDefaultFirstBucketBytes, |
412 | py::call_guard<py::gil_scoped_release>()) |
413 | .def( |
414 | "prepare_for_forward" , |
415 | &::c10d::Reducer::prepare_for_forward, |
416 | py::call_guard<py::gil_scoped_release>()) |
417 | .def( |
418 | "prepare_for_backward" , |
419 | &::c10d::Reducer::prepare_for_backward, |
420 | py::call_guard<py::gil_scoped_release>()) |
421 | .def( |
422 | "prepare_for_backward" , |
423 | [](::c10d::Reducer& reducer, const at::Tensor& output) -> void { |
424 | reducer.prepare_for_backward({output}); |
425 | }, |
426 | py::call_guard<py::gil_scoped_release>()) |
427 | .def("get_backward_stats" , &::c10d::Reducer::get_backward_stats) |
428 | .def( |
429 | "_install_post_backward_futures" , |
430 | [](::c10d::Reducer& reducer, |
431 | const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& |
432 | futs) { |
433 | c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures( |
434 | c10::FutureType::create(c10::TensorType::get())); |
435 | for (const auto& fut : futs) { |
436 | futures.push_back(fut->fut); |
437 | } |
438 | reducer.install_futures(std::move(futures)); |
439 | }, |
440 | py::call_guard<py::gil_scoped_release>()) |
441 | .def( |
442 | "_rebuild_buckets" , |
443 | &::c10d::Reducer::rebuild_buckets, |
444 | py::call_guard<py::gil_scoped_release>()) |
445 | .def( |
446 | "_get_zeros_like_grad_buckets" , |
447 | [](::c10d::Reducer& reducer) { |
448 | return reducer.get_grad_buckets(/* return_zero_tensors */ true); |
449 | }, |
450 | py::call_guard<py::gil_scoped_release>()) |
451 | .def( |
452 | "_set_grads_to_none" , |
453 | [](::c10d::Reducer& reducer) { reducer.set_grads_to_none(true); }, |
454 | py::call_guard<py::gil_scoped_release>()) |
455 | .def( |
456 | "_push_all_rebuilt_params" , |
457 | &::c10d::Reducer::push_rebuilt_params_for_all_indices, |
458 | py::call_guard<py::gil_scoped_release>()) |
459 | .def( |
460 | "_set_forward_pass_work_handle" , |
461 | &::c10d::Reducer::set_forward_pass_work_handle, |
462 | py::call_guard<py::gil_scoped_release>()) |
463 | .def( |
464 | "_get_local_used_map" , &::c10d::Reducer::get_local_used_map_on_device) |
465 | .def( |
466 | "_set_ddp_runtime_logging_sample_rate" , |
467 | &::c10d::Reducer::set_ddp_runtime_logging_sample_rate, |
468 | py::arg("sample_rate" ), |
469 | py::call_guard<py::gil_scoped_release>()) |
470 | .def( |
471 | "_set_static_graph" , |
472 | &::c10d::Reducer::set_static_graph, |
473 | py::call_guard<py::gil_scoped_release>()) |
474 | .def( |
475 | "_ddp_graph_static" , |
476 | &::c10d::Reducer::ddp_graph_static, |
477 | py::call_guard<py::gil_scoped_release>()) |
478 | .def( |
479 | "_delay_all_reduce" , |
480 | &::c10d::Reducer::delay_all_reduce, |
481 | py::call_guard<py::gil_scoped_release>()) |
482 | .def( |
483 | "_run_comm_hook" , |
484 | [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket) |
485 | -> std::shared_ptr<jit::PythonFutureWrapper> { |
486 | c10::intrusive_ptr<c10::ivalue::Future> fut = |
487 | reducer.run_comm_hook(bucket); |
488 | return std::make_shared<jit::PythonFutureWrapper>(fut); |
489 | }, |
490 | py::call_guard<py::gil_scoped_release>()) |
491 | .def( |
492 | "_run_allreduce_hook" , |
493 | [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket) |
494 | -> std::shared_ptr<jit::PythonFutureWrapper> { |
495 | c10::intrusive_ptr<c10::ivalue::Future> fut = |
496 | reducer.run_allreduce_hook(bucket); |
497 | return std::make_shared<jit::PythonFutureWrapper>(fut); |
498 | }, |
499 | py::call_guard<py::gil_scoped_release>()) |
500 | .def( |
501 | "set_logger" , |
502 | [](::c10d::Reducer& reducer, |
503 | const std::shared_ptr<::c10d::Logger> logger) { |
504 | std::weak_ptr<::c10d::Logger> logger_weakref = logger; |
505 | reducer.set_logger(logger_weakref); |
506 | }); |
507 | |
508 | shared_ptr_class_<::c10d::Logger>(module, "Logger" ) |
509 | .def( |
510 | py::init<std::shared_ptr<::c10d::Reducer>>(), |
511 | py::arg("reducer" ), |
512 | py::call_guard<py::gil_scoped_release>()) |
513 | .def( |
514 | "set_construction_data_and_log" , |
515 | &::c10d::Logger::set_construction_data_and_log, |
516 | py::arg("module_name" ), |
517 | py::arg("device_ids" ), |
518 | py::arg("output_device" ), |
519 | py::arg("broadcast_buffers" ), |
520 | py::arg("has_sync_bn" ), |
521 | py::arg("static_graph" ), |
522 | py::call_guard<py::gil_scoped_release>()) |
523 | .def( |
524 | "set_runtime_stats_and_log" , |
525 | &::c10d::Logger::set_runtime_stats_and_log, |
526 | py::call_guard<py::gil_scoped_release>()) |
527 | .def( |
528 | "set_error_and_log" , |
529 | [](::c10d::Logger& logger, const std::string& error) { |
530 | logger.set_error_and_log(error); |
531 | }, |
532 | py::call_guard<py::gil_scoped_release>()) |
533 | .def( |
534 | "_get_ddp_logging_data" , |
535 | &::c10d::Logger::get_ddp_logging_data, |
536 | py::call_guard<py::gil_scoped_release>()) |
537 | .def( |
538 | "_set_comm_hook_name" , |
539 | &::c10d::Logger::set_comm_hook, |
540 | py::arg("comm_hook" ), |
541 | py::call_guard<py::gil_scoped_release>()) |
542 | .def( |
543 | "_set_uneven_input_join" , |
544 | &::c10d::Logger::set_uneven_input_join, |
545 | py::call_guard<py::gil_scoped_release>()) |
546 | .def( |
547 | "_set_static_graph" , |
548 | &::c10d::Logger::set_static_graph, |
549 | py::call_guard<py::gil_scoped_release>()); |
550 | |
551 | py::enum_<::c10d::DebugLevel>(module, "DebugLevel" , R"( |
552 | An enum whose values correspond to different debug levels of the |
553 | torch.distributed package. Currently supporting OFF, INFO, and DETAIL, |
554 | which can be set via the TORCH_DISTRIBUTED_DEBUG environment variable |
555 | or via ``set_debug_level()`` function. |
556 | )" ) |
557 | .value("OFF" , ::c10d::DebugLevel::Off) |
558 | .value("INFO" , ::c10d::DebugLevel::Info) |
559 | .value("DETAIL" , ::c10d::DebugLevel::Detail); |
560 | |
561 | module |
562 | .def( |
563 | "get_debug_level" , |
564 | ::c10d::debug_level, |
565 | R"(Gets the debug level of the torch.distributed package.)" ) |
566 | .def( |
567 | "set_debug_level" , |
568 | ::c10d::setDebugLevel, |
569 | R"(Sets the debug level of the torch.distributed package.)" ) |
570 | .def( |
571 | "set_debug_level_from_env" , |
572 | ::c10d::setDebugLevelFromEnvironment, |
573 | R"(Sets the debug level of the torch.distributed package from the |
574 | ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)" ); |
575 | |
576 | // TODO(crcrpar): Hardening `ReduceOp`. |
577 | // While keeping most op types as enum value, |
578 | // making `PREMUL_SUM` callable, i.e., allowing for |
579 | // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol. |
580 | // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types |
581 | py::class_<::c10d::ReduceOp> reduce_op( |
582 | module, "ReduceOp" , py::metaclass((PyObject*)GetReduceOpMetaclass()), R"( |
583 | An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, |
584 | ``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``. |
585 | |
586 | ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when |
587 | using the ``NCCL`` backend. |
588 | |
589 | ``AVG`` divides values by the world size before summing across ranks. |
590 | ``AVG`` is only available with the ``NCCL`` backend, |
591 | and only for NCCL versions 2.10 or later. |
592 | |
593 | ``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction. |
594 | ``PREMUL_SUM`` is only available with the ``NCCL`` backend, |
595 | and only available for NCCL versions 2.11 or later. Users are supposed to |
596 | use ``torch.distributed._make_nccl_premul_sum``. |
597 | |
598 | Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. |
599 | |
600 | The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. |
601 | They are used in specifying strategies for reduction collectives, e.g., |
602 | :func:`reduce`, :func:`all_reduce_multigpu`, etc. |
603 | |
604 | This class does not support ``__members__`` property.)" ); |
605 | |
606 | reduce_op.def(py::init<::c10d::ReduceOp::RedOpType>()) |
607 | .def_readwrite("op" , &::c10d::ReduceOp::op_); |
608 | // The following are for some kind of backward compatibility. |
609 | // Since c10d::ReduceOp had been an `enum class`, users can do comparison and |
610 | // take hash of `::c10d::ReduceOp`. To avoid losing these functionality, here |
611 | // I define some member methods. |
612 | reduce_op |
613 | // todo(crcrpar): Support `RedOpType == ReduceOp`. |
614 | .def( |
615 | // This calls `operator==(const ReduceOp::RedOpType)` |
616 | "__eq__" , |
617 | [](const ::c10d::ReduceOp& self, |
618 | const ::c10d::ReduceOp::RedOpType& other) { |
619 | return self == other; |
620 | }) |
621 | .def( |
622 | // This calls `operator==(const ReduceOp)` for the future support of |
623 | // `PREMUL_SUM` comparison |
624 | "__eq__" , |
625 | [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) { |
626 | return self == other; |
627 | }) |
628 | .def( |
629 | // With the above custom `__eq__`'s, I have to manually support the |
630 | // other types. |
631 | "__eq__" , |
632 | [](const ::c10d::ReduceOp& self, py::object) { return false; }) |
633 | .def( |
634 | "__hash__" , |
635 | [](const ::c10d::ReduceOp& self) { |
636 | return static_cast<uint8_t>(self.op_); |
637 | }) |
638 | .def( |
639 | "__copy__" , |
640 | [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); }) |
641 | .def( |
642 | "__deepcopy__" , |
643 | [](const ::c10d::ReduceOp& self, const py::dict& memo) { |
644 | return ::c10d::ReduceOp(self); |
645 | }) |
646 | .def(py::pickle( |
647 | [](const ::c10d::ReduceOp& r) { |
648 | // __getstate__ |
649 | if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { |
650 | return py::make_tuple(r.op_, py::none()); |
651 | } |
652 | TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp" ); |
653 | const auto* preMulSupplement = |
654 | reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>( |
655 | r.supplement_.get()); |
656 | if (!preMulSupplement->tensor_factor.defined()) { |
657 | return py::make_tuple(r.op_, preMulSupplement->double_factor); |
658 | } else { |
659 | return py::make_tuple(r.op_, preMulSupplement->tensor_factor); |
660 | } |
661 | }, |
662 | [](const py::tuple t) { |
663 | // __setstate__ |
664 | TORCH_CHECK(t.size() == 2, "Invalid state" ); |
665 | const auto op = |
666 | static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>()); |
667 | if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { |
668 | return ::c10d::ReduceOp(op); |
669 | } |
670 | const auto preMulSupplement_factor = t[1]; |
671 | if (py::isinstance<py::float_>(preMulSupplement_factor)) { |
672 | return ::c10d::makeNCCLPreMulSum(t[1].cast<double>()); |
673 | } else { |
674 | return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>()); |
675 | } |
676 | })); |
677 | |
678 | py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType" ) |
679 | .value("SUM" , ::c10d::ReduceOp::RedOpType::SUM) |
680 | .value("AVG" , ::c10d::ReduceOp::RedOpType::AVG) |
681 | .value("PRODUCT" , ::c10d::ReduceOp::RedOpType::PRODUCT) |
682 | .value("MIN" , ::c10d::ReduceOp::RedOpType::MIN) |
683 | .value("MAX" , ::c10d::ReduceOp::RedOpType::MAX) |
684 | .value("BAND" , ::c10d::ReduceOp::RedOpType::BAND) |
685 | .value("BOR" , ::c10d::ReduceOp::RedOpType::BOR) |
686 | .value("BXOR" , ::c10d::ReduceOp::RedOpType::BXOR) |
687 | .value("PREMUL_SUM" , ::c10d::ReduceOp::RedOpType::PREMUL_SUM) |
688 | .export_values(); |
689 | |
690 | // note(crcrpar): This could be removed because users will not pass |
691 | // `RedOpType` to reduce collective ops Ref: [Implicit |
692 | // conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions) |
693 | // Let us skip the explicit construction of `c10d::ReduceOp` from |
694 | // `c10d::ReduceOp::RedOpType` in Python. |
695 | py::implicitly_convertible<::c10d::ReduceOp::RedOpType, ::c10d::ReduceOp>(); |
696 | |
697 | module |
698 | .def( |
699 | "_make_nccl_premul_sum" , |
700 | &::c10d::makeNCCLPreMulSum<double>, |
701 | py::arg("factor" ).noconvert(), |
702 | py::return_value_policy::copy, // seems safest |
703 | py::call_guard<py::gil_scoped_release>()) |
704 | .def( |
705 | "_make_nccl_premul_sum" , |
706 | &::c10d::makeNCCLPreMulSum<at::Tensor>, |
707 | py::arg("factor" ).noconvert(), |
708 | py::return_value_policy::copy, // seems safest |
709 | py::call_guard<py::gil_scoped_release>()); |
710 | |
711 | py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions" ) |
712 | .def(py::init<>()) |
713 | .def_readwrite("rootRank" , &::c10d::BroadcastOptions::rootRank) |
714 | .def_readwrite("rootTensor" , &::c10d::BroadcastOptions::rootTensor) |
715 | .def_readwrite("timeout" , &::c10d::BroadcastOptions::timeout); |
716 | |
717 | py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions" ) |
718 | .def(py::init<>()) |
719 | .def_readwrite("reduceOp" , &::c10d::AllreduceOptions::reduceOp) |
720 | .def_readwrite("timeout" , &::c10d::AllreduceOptions::timeout); |
721 | |
722 | py::class_<::c10d::AllreduceCoalescedOptions>( |
723 | module, "AllreduceCoalescedOptions" ) |
724 | .def(py::init<>()) |
725 | .def_readwrite("reduceOp" , &::c10d::AllreduceCoalescedOptions::reduceOp) |
726 | .def_readwrite("timeout" , &::c10d::AllreduceCoalescedOptions::timeout); |
727 | |
728 | py::class_<::c10d::ReduceOptions>(module, "ReduceOptions" ) |
729 | .def(py::init<>()) |
730 | .def_readwrite("reduceOp" , &::c10d::ReduceOptions::reduceOp) |
731 | .def_readwrite("rootRank" , &::c10d::ReduceOptions::rootRank) |
732 | .def_readwrite("rootTensor" , &::c10d::ReduceOptions::rootTensor) |
733 | .def_readwrite("timeout" , &::c10d::ReduceOptions::timeout); |
734 | |
735 | py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions" ) |
736 | .def(py::init<>()) |
737 | .def_readwrite("timeout" , &::c10d::AllgatherOptions::timeout); |
738 | |
739 | py::class_<::c10d::GatherOptions>(module, "GatherOptions" ) |
740 | .def(py::init<>()) |
741 | .def_readwrite("rootRank" , &::c10d::GatherOptions::rootRank) |
742 | .def_readwrite("timeout" , &::c10d::GatherOptions::timeout); |
743 | |
744 | py::class_<::c10d::ScatterOptions>(module, "ScatterOptions" ) |
745 | .def(py::init<>()) |
746 | .def_readwrite("rootRank" , &::c10d::ScatterOptions::rootRank) |
747 | .def_readwrite("timeout" , &::c10d::ScatterOptions::timeout); |
748 | |
749 | py::class_<::c10d::ReduceScatterOptions>(module, "ReduceScatterOptions" ) |
750 | .def(py::init<>()) |
751 | .def_readwrite("reduceOp" , &::c10d::ReduceScatterOptions::reduceOp) |
752 | .def_readwrite("timeout" , &::c10d::ReduceScatterOptions::timeout); |
753 | |
754 | py::class_<::c10d::BarrierOptions>(module, "BarrierOptions" ) |
755 | .def(py::init<>()) |
756 | .def_readwrite("device_ids" , &::c10d::BarrierOptions::device_ids) |
757 | .def_readwrite("timeout" , &::c10d::BarrierOptions::timeout); |
758 | |
759 | py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions" ) |
760 | .def(py::init<>()) |
761 | .def_readwrite("timeout" , &::c10d::AllToAllOptions::timeout); |
762 | |
  // Private (underscore-prefixed) options bundle exposing the store, this
  // process's rank/size within the group, timeout, group id, and the global
  // ranks that make up the group. Presumably consumed by custom backend
  // construction on the Python side -- confirm against callers.
  py::class_<::c10d::DistributedBackendOptions>(
      module, "_DistributedBackendOptions")
      .def(py::init<>())
      .def_readwrite("store", &::c10d::DistributedBackendOptions::store)
      .def_readwrite(
          "group_rank", &::c10d::DistributedBackendOptions::group_rank)
      .def_readwrite(
          "group_size", &::c10d::DistributedBackendOptions::group_size)
      .def_readwrite("timeout", &::c10d::DistributedBackendOptions::timeout)
      .def_readwrite("group_id", &::c10d::DistributedBackendOptions::group_id)
      .def_readwrite(
          "global_ranks_in_group",
          &::c10d::DistributedBackendOptions::global_ranks_in_group);
776 | |
777 | auto store = |
778 | py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( |
779 | module, |
780 | "Store" , |
781 | R"( |
782 | Base class for all store implementations, such as the 3 provided by PyTorch |
783 | distributed: (:class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`, |
784 | and :class:`~torch.distributed.HashStore`). |
785 | )" ) |
786 | // Default constructor. |
787 | .def(py::init<>()) |
788 | // Convert from std::string to std::vector<uint8>. |
789 | .def( |
790 | "set" , |
791 | [](::c10d::Store& store, |
792 | const std::string& key, |
793 | const std::string& value) { |
794 | std::vector<uint8_t> value_(value.begin(), value.end()); |
795 | store.set(key, value_); |
796 | }, |
797 | py::call_guard<py::gil_scoped_release>(), |
798 | R"( |
799 | Inserts the key-value pair into the store based on the supplied ``key`` and |
800 | ``value``. If ``key`` already exists in the store, it will overwrite the old |
801 | value with the new supplied ``value``. |
802 | |
803 | Arguments: |
804 | key (str): The key to be added to the store. |
805 | value (str): The value associated with ``key`` to be added to the store. |
806 | |
807 | Example:: |
808 | >>> import torch.distributed as dist |
809 | >>> from datetime import timedelta |
810 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
811 | >>> store.set("first_key", "first_value") |
812 | >>> # Should return "first_value" |
813 | >>> store.get("first_key") |
814 | )" ) |
815 | .def( |
816 | "compare_set" , |
817 | [](::c10d::Store& store, |
818 | const std::string& key, |
819 | const std::string& expected_value, |
820 | const std::string& desired_value) -> py::bytes { |
821 | std::vector<uint8_t> expectedValue_( |
822 | expected_value.begin(), expected_value.end()); |
823 | std::vector<uint8_t> desiredValue_( |
824 | desired_value.begin(), desired_value.end()); |
825 | auto value = |
826 | store.compareSet(key, expectedValue_, desiredValue_); |
827 | return py::bytes( |
828 | reinterpret_cast<char*>(value.data()), value.size()); |
829 | }, |
830 | py::call_guard<py::gil_scoped_release>(), |
831 | R"( |
832 | Inserts the key-value pair into the store based on the supplied ``key`` and |
833 | performs comparison between ``expected_value`` and ``desired_value`` before inserting. ``desired_value`` |
834 | will only be set if ``expected_value`` for the ``key`` already exists in the store or if ``expected_value`` |
835 | is an empty string. |
836 | |
837 | Arguments: |
838 | key (str): The key to be checked in the store. |
839 | expected_value (str): The value associated with ``key`` to be checked before insertion. |
840 | desired_value (str): The value associated with ``key`` to be added to the store. |
841 | |
842 | Example:: |
843 | >>> import torch.distributed as dist |
844 | >>> from datetime import timedelta |
845 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
846 | >>> store.set("key", "first_value") |
847 | >>> store.compare_set("key", "first_value", "second_value") |
848 | >>> # Should return "second_value" |
849 | >>> store.get("key") |
850 | )" ) |
851 | // Convert from std::vector<uint8_t> to py::bytes. |
852 | // The returned value is not guaranteed to be valid UTF-8. |
853 | .def( |
854 | "get" , |
855 | [](::c10d::Store& store, const std::string& key) -> py::bytes { |
856 | auto value = [&]() { |
857 | py::gil_scoped_release guard; |
858 | return store.get(key); |
859 | }(); |
860 | return py::bytes( |
861 | reinterpret_cast<char*>(value.data()), value.size()); |
862 | }, |
863 | R"( |
864 | Retrieves the value associated with the given ``key`` in the store. If ``key`` is not |
865 | present in the store, the function will wait for ``timeout``, which is defined |
866 | when initializing the store, before throwing an exception. |
867 | |
868 | Arguments: |
869 | key (str): The function will return the value associated with this key. |
870 | |
871 | Returns: |
872 | Value associated with ``key`` if ``key`` is in the store. |
873 | |
874 | Example:: |
875 | >>> import torch.distributed as dist |
876 | >>> from datetime import timedelta |
877 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
878 | >>> store.set("first_key", "first_value") |
879 | >>> # Should return "first_value" |
880 | >>> store.get("first_key") |
881 | )" ) |
882 | .def( |
883 | "add" , |
884 | &::c10d::Store::add, |
885 | py::call_guard<py::gil_scoped_release>(), |
886 | R"( |
887 | The first call to add for a given ``key`` creates a counter associated |
888 | with ``key`` in the store, initialized to ``amount``. Subsequent calls to add |
889 | with the same ``key`` increment the counter by the specified ``amount``. |
890 | Calling :meth:`~torch.distributed.store.add` with a key that has already |
891 | been set in the store by :meth:`~torch.distributed.store.set` will result |
892 | in an exception. |
893 | |
894 | Arguments: |
895 | key (str): The key in the store whose counter will be incremented. |
896 | amount (int): The quantity by which the counter will be incremented. |
897 | |
898 | Example:: |
899 | >>> import torch.distributed as dist |
900 | >>> from datetime import timedelta |
901 | >>> # Using TCPStore as an example, other store types can also be used |
902 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
903 | >>> store.add("first_key", 1) |
904 | >>> store.add("first_key", 6) |
905 | >>> # Should return 7 |
906 | >>> store.get("first_key") |
907 | )" ) |
908 | .def( |
909 | "delete_key" , |
910 | &::c10d::Store::deleteKey, |
911 | py::call_guard<py::gil_scoped_release>(), |
912 | R"( |
913 | Deletes the key-value pair associated with ``key`` from the store. Returns |
914 | `true` if the key was successfully deleted, and `false` if it was not. |
915 | |
916 | .. warning:: |
917 | The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API |
918 | with the :class:`~torch.distributed.FileStore` will result in an exception. |
919 | |
920 | Arguments: |
921 | key (str): The key to be deleted from the store |
922 | |
923 | Returns: |
924 | `True` if ``key`` was deleted, otherwise `False`. |
925 | |
926 | Example:: |
927 | >>> import torch.distributed as dist |
928 | >>> from datetime import timedelta |
929 | >>> # Using TCPStore as an example, HashStore can also be used |
930 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
931 | >>> store.set("first_key") |
932 | >>> # This should return true |
933 | >>> store.delete_key("first_key") |
934 | >>> # This should return false |
935 | >>> store.delete_key("bad_key") |
936 | )" ) |
937 | .def( |
938 | "num_keys" , |
939 | &::c10d::Store::getNumKeys, |
940 | py::call_guard<py::gil_scoped_release>(), |
941 | R"( |
942 | Returns the number of keys set in the store. Note that this number will typically |
943 | be one greater than the number of keys added by :meth:`~torch.distributed.store.set` |
944 | and :meth:`~torch.distributed.store.add` since one key is used to coordinate all |
945 | the workers using the store. |
946 | |
947 | .. warning:: |
948 | When used with the :class:`~torch.distributed.TCPStore`, ``num_keys`` returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained. |
949 | |
950 | Returns: |
951 | The number of keys present in the store. |
952 | |
953 | Example:: |
954 | >>> import torch.distributed as dist |
955 | >>> from datetime import timedelta |
956 | >>> # Using TCPStore as an example, other store types can also be used |
957 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
958 | >>> store.set("first_key", "first_value") |
959 | >>> # This should return 2 |
960 | >>> store.num_keys() |
961 | )" ) |
962 | .def( |
963 | "set_timeout" , |
964 | &::c10d::Store::setTimeout, |
965 | py::call_guard<py::gil_scoped_release>(), |
966 | R"( |
967 | Sets the store's default timeout. This timeout is used during initialization and in |
968 | :meth:`~torch.distributed.store.wait` and :meth:`~torch.distributed.store.get`. |
969 | |
970 | Arguments: |
971 | timeout (timedelta): timeout to be set in the store. |
972 | |
973 | Example:: |
974 | >>> import torch.distributed as dist |
975 | >>> from datetime import timedelta |
976 | >>> # Using TCPStore as an example, other store types can also be used |
977 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
978 | >>> store.set_timeout(timedelta(seconds=10)) |
979 | >>> # This will throw an exception after 10 seconds |
980 | >>> store.wait(["bad_key"]) |
981 | )" ) |
982 | .def( |
983 | "wait" , |
984 | [](::c10d::Store& store, const std::vector<std::string>& keys) { |
985 | store.wait(keys); |
986 | }, |
987 | py::call_guard<py::gil_scoped_release>(), |
988 | R"( |
989 | Waits for each key in ``keys`` to be added to the store. If not all keys are |
990 | set before the ``timeout`` (set during store initialization), then ``wait`` |
991 | will throw an exception. |
992 | |
993 | Arguments: |
994 | keys (list): List of keys on which to wait until they are set in the store. |
995 | |
996 | Example:: |
997 | >>> import torch.distributed as dist |
998 | >>> from datetime import timedelta |
999 | >>> # Using TCPStore as an example, other store types can also be used |
1000 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
1001 | >>> # This will throw an exception after 30 seconds |
1002 | >>> store.wait(["bad_key"]) |
1003 | )" ) |
1004 | .def( |
1005 | "wait" , |
1006 | [](::c10d::Store& store, |
1007 | const std::vector<std::string>& keys, |
1008 | const std::chrono::milliseconds& timeout) { |
1009 | store.wait(keys, timeout); |
1010 | }, |
1011 | py::call_guard<py::gil_scoped_release>(), |
1012 | R"( |
1013 | Waits for each key in ``keys`` to be added to the store, and throws an exception |
1014 | if the keys have not been set by the supplied ``timeout``. |
1015 | |
1016 | Arguments: |
1017 | keys (list): List of keys on which to wait until they are set in the store. |
1018 | timeout (timedelta): Time to wait for the keys to be added before throwing an exception. |
1019 | |
1020 | Example:: |
1021 | >>> import torch.distributed as dist |
1022 | >>> from datetime import timedelta |
1023 | >>> # Using TCPStore as an example, other store types can also be used |
1024 | >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
1025 | >>> # This will throw an exception after 10 seconds |
1026 | >>> store.wait(["bad_key"], timedelta(seconds=10)) |
1027 | )" ) |
1028 | .def_property_readonly( |
1029 | "timeout" , |
1030 | &::c10d::Store::getTimeout, |
1031 | R"(Gets the timeout of the store.)" ); |
1032 | |
  // FileStore: file-backed Store. Passing `store` registers c10d::Store as
  // its Python base class so isinstance checks work.
  intrusive_ptr_class_<::c10d::FileStore>(
      module,
      "FileStore",
      store,
      R"(
A store implementation that uses a file to store the underlying key-value pairs.

Arguments:
    file_name (str): path of the file in which to store the key-value pairs
    world_size (int, optional): The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users).

Example::
    >>> import torch.distributed as dist
    >>> store1 = dist.FileStore("/tmp/filestore", 2)
    >>> store2 = dist.FileStore("/tmp/filestore", 2)
    >>> # Use any of the store methods from either the client or server after initialization
    >>> store1.set("first_key", "first_value")
    >>> store2.get("first_key")

)")
      .def(
          py::init<const std::string&, int>(),
          py::arg("file_name"),
          py::arg("world_size") = -1)
      .def_property_readonly(
          "path",
          &::c10d::FileStore::getPath,
          R"(Gets the path of the file used by FileStore to store key-value pairs.)");
1061 | |
#ifndef _WIN32
  // HashStore: in-process, hashmap-backed Store (not available on Windows;
  // see the matching #ifndef around its header include).
  intrusive_ptr_class_<::c10d::HashStore>(
      module,
      "HashStore",
      store,
      R"(
A thread-safe store implementation based on an underlying hashmap. This store can be used
within the same process (for example, by other threads), but cannot be used across processes.

Example::
    >>> import torch.distributed as dist
    >>> store = dist.HashStore()
    >>> # store can be used from other threads
    >>> # Use any of the store methods after initialization
    >>> store.set("first_key", "first_value")
)")
      .def(py::init<>());
#endif
1080 | |
1081 | intrusive_ptr_class_<::c10d::TCPStore>( |
1082 | module, |
1083 | "TCPStore" , |
1084 | store, |
1085 | R"( |
1086 | A TCP-based distributed key-value store implementation. The server store holds |
1087 | the data, while the client stores can connect to the server store over TCP and |
1088 | perform actions such as :meth:`~torch.distributed.store.set` to insert a key-value |
1089 | pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. There |
1090 | should always be one server store initialized because the client store(s) will wait for |
1091 | the server to establish a connection. |
1092 | |
1093 | Arguments: |
1094 | host_name (str): The hostname or IP Address the server store should run on. |
1095 | port (int): The port on which the server store should listen for incoming requests. |
1096 | world_size (int, optional): The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users). |
1097 | is_master (bool, optional): True when initializing the server store and False for client stores. Default is False. |
1098 | timeout (timedelta, optional): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.store.get` and :meth:`~torch.distributed.store.wait`. Default is timedelta(seconds=300) |
1099 | wait_for_worker (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True. |
1100 | |
1101 | Example:: |
1102 | >>> import torch.distributed as dist |
1103 | >>> from datetime import timedelta |
1104 | >>> # Run on process 1 (server) |
1105 | >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30)) |
1106 | >>> # Run on process 2 (client) |
1107 | >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False) |
1108 | >>> # Use any of the store methods from either the client or server after initialization |
1109 | >>> server_store.set("first_key", "first_value") |
1110 | >>> client_store.get("first_key") |
1111 | )" ) |
1112 | .def( |
1113 | py::init([](const std::string& host, |
1114 | uint16_t port, |
1115 | c10::optional<int> worldSize, |
1116 | bool isServer, |
1117 | std::chrono::milliseconds timeout, |
1118 | bool waitWorkers, |
1119 | bool multiTenant) { |
1120 | c10::optional<std::size_t> numWorkers = c10::nullopt; |
1121 | if (worldSize.has_value() && worldSize.value() > -1) { |
1122 | numWorkers = static_cast<std::size_t>(worldSize.value()); |
1123 | } |
1124 | |
1125 | ::c10d::TCPStoreOptions opts{ |
1126 | port, isServer, numWorkers, waitWorkers, timeout, multiTenant}; |
1127 | |
1128 | return c10::make_intrusive<::c10d::TCPStore>(host, opts); |
1129 | }), |
1130 | py::arg("host_name" ), |
1131 | py::arg("port" ), |
1132 | py::arg("world_size" ) = py::none(), |
1133 | // using noconvert() requires this argument to be True or False |
1134 | // prevents accidental implicit conversion to bool |
1135 | py::arg("is_master" ).noconvert() = false, |
1136 | py::arg("timeout" ) = |
1137 | std::chrono::milliseconds(::c10d::Store::kDefaultTimeout), |
1138 | py::arg("wait_for_workers" ) = true, |
1139 | py::arg("multi_tenant" ) = false) |
1140 | .def_property_readonly( |
1141 | "host" , |
1142 | &::c10d::TCPStore::getHost, |
1143 | R"(Gets the hostname on which the store listens for requests.)" ) |
1144 | |
1145 | .def_property_readonly( |
1146 | "port" , |
1147 | &::c10d::TCPStore::getPort, |
1148 | R"(Gets the port number on which the store listens for requests.)" ); |
1149 | |
  // PrefixStore: decorator around another Store that namespaces every key
  // with a fixed prefix.
  intrusive_ptr_class_<::c10d::PrefixStore>(
      module,
      "PrefixStore",
      store,
      R"(
A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`,
:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
that adds a prefix to each key inserted to the store.

Arguments:
    prefix (str): The prefix string that is prepended to each key before being inserted into the store.
    store (torch.distributed.store): A store object that forms the underlying key-value store.
)")
      .def(py::init<const std::string&, c10::intrusive_ptr<::c10d::Store>>())
      .def_property_readonly(
          "underlying_store",
          &::c10d::PrefixStore::getUnderlyingStore,
          R"(Gets the underlying store object that PrefixStore wraps around.)");
1168 | |
  // Bindings for c10d::ProcessGroup. PyProcessGroup is the trampoline that
  // allows Python-defined subclasses. Most collectives are bound twice: once
  // taking tensor lists plus an options struct (mirroring the C++ API), and
  // once as a single-tensor convenience overload that builds the options
  // from scalar arguments. pybind tries overloads in registration order, so
  // the order of the .def calls below is load-bearing. All collective
  // bindings release the GIL while the C++ call runs.
  auto processGroup =
      py::class_<
          ::c10d::ProcessGroup,
          c10::intrusive_ptr<::c10d::ProcessGroup>,
          ::c10d::PyProcessGroup>(module, "ProcessGroup")
          .def(py::init<int, int>())
          .def(
              py::init<
                  const c10::intrusive_ptr<::c10d::Store>&,
                  int,
                  int,
                  c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
              py::call_guard<py::gil_scoped_release>())
          .def("rank", &::c10d::ProcessGroup::getRank)
          .def("size", &::c10d::ProcessGroup::getSize)
          .def("name", &::c10d::ProcessGroup::getBackendName)
          .def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
          .def(
              "broadcast",
              &::c10d::ProcessGroup::broadcast,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::BroadcastOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: broadcast one tensor from `root`.
          .def(
              "broadcast",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& x,
                 int rootRank) {
                ::c10d::BroadcastOptions opts;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> tensors = {x};
                return self->broadcast(tensors, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              &::c10d::ProcessGroup::allreduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: tensor list plus a bare ReduceOp.
          .def(
              "allreduce",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 std::vector<at::Tensor>& xs,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                return self->allreduce(xs, opts);
              },
              py::arg("tensors"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: single tensor plus a bare ReduceOp.
          .def(
              "allreduce",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& x,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                std::vector<at::Tensor> xs = {x};
                return self->allreduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce_coalesced",
              &::c10d::ProcessGroup::allreduce_coalesced,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce",
              &::c10d::ProcessGroup::reduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::ReduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: reduce one tensor to `root` with `op`.
          .def(
              "reduce",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& x,
                 int rootRank,
                 ::c10d::ReduceOp op) {
                ::c10d::ReduceOptions opts;
                opts.reduceOp = op;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return self->reduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              &::c10d::ProcessGroup::allgather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: one input tensor, one output list.
          .def(
              "allgather",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input) {
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->allgather(
                    outputs, inputs, ::c10d::AllgatherOptions());
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_allgather_base",
              &::c10d::ProcessGroup::_allgather_base,
              py::arg("output"),
              py::arg("input"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather_coalesced",
              &::c10d::ProcessGroup::allgather_coalesced,
              py::arg("output_lists"),
              py::arg("input_list"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "gather",
              &::c10d::ProcessGroup::gather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::GatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: gather one tensor per rank to `root`.
          .def(
              "gather",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input,
                 int rootRank) {
                ::c10d::GatherOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->gather(outputs, inputs, opts);
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "scatter",
              &::c10d::ProcessGroup::scatter,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::ScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: scatter a tensor list from `root`.
          .def(
              "scatter",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 int rootRank) {
                ::c10d::ScatterOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> inputs = {input};
                std::vector<at::Tensor> outputs = {output};
                return self->scatter(outputs, inputs, opts);
              },
              py::arg("output_tensor"),
              py::arg("input_tensors"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce_scatter",
              &::c10d::ProcessGroup::reduce_scatter,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::ReduceScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Convenience overload: one output tensor, one input list, ReduceOp.
          .def(
              "reduce_scatter",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 ::c10d::ReduceOp op) {
                std::vector<at::Tensor> outputs = {output};
                std::vector<std::vector<at::Tensor>> inputs = {input};
                ::c10d::ReduceScatterOptions opts;
                opts.reduceOp = op;
                return self->reduce_scatter(outputs, inputs, opts);
              },
              py::arg("output"),
              py::arg("input"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_reduce_scatter_base",
              &::c10d::ProcessGroup::_reduce_scatter_base,
              py::arg("outputTensor"),
              py::arg("inputTensor"),
              py::arg("opts") = ::c10d::ReduceScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "alltoall_base",
              &::c10d::ProcessGroup::alltoall_base,
              py::arg("output"),
              py::arg("input"),
              py::arg("output_split_sizes"),
              py::arg("input_split_sizes"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "alltoall",
              &::c10d::ProcessGroup::alltoall,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Point-to-point communication.
          .def(
              "send",
              &::c10d::ProcessGroup::send,
              py::arg("tensors"),
              py::arg("dstRank"),
              py::arg("tag"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv",
              &::c10d::ProcessGroup::recv,
              py::arg("tensors"),
              py::arg("srcRank"),
              py::arg("tag"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv_anysource",
              &::c10d::ProcessGroup::recvAnysource,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "barrier",
              &::c10d::ProcessGroup::barrier,
              py::arg("opts") = ::c10d::BarrierOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_set_sequence_number_for_group",
              &::c10d::ProcessGroup::setSequenceNumberForGroup,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_sequence_number_for_group",
              &::c10d::ProcessGroup::getSequenceNumberForGroup,
              py::call_guard<py::gil_scoped_release>())
          // NOTE: wait_all_ranks is forwarded as a separate argument to
          // monitoredBarrier; only the timeout travels via BarrierOptions.
          .def(
              "monitored_barrier",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const std::chrono::milliseconds& timeout,
                 bool waitAllRanks) {
                ::c10d::BarrierOptions opts;
                opts.timeout = timeout;
                return self->monitoredBarrier(opts, waitAllRanks);
              },
              py::arg("timeout") = ::c10d::kUnsetTimeout,
              py::arg("wait_all_ranks") = false,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_backend_name",
              &::c10d::ProcessGroup::getBackendName,
              py::call_guard<py::gil_scoped_release>())
          // Coalescing helpers: both take a device and forward only its
          // type to the C++ API.
          .def(
              "_start_coalescing",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device) {
                self->startCoalescing(device.type());
              },
              py::arg("device_type"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_end_coalescing",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device,
                 std::vector<c10::intrusive_ptr<::c10d::Work>>& reqs) {
                self->endCoalescing(device.type(), reqs);
              },
              py::arg("device_type"),
              py::arg("reqs"),
              py::call_guard<py::gil_scoped_release>())
          // Attach/look up the backend implementation for a device type.
          // `backend` defaults to an empty optional.
          .def(
              "_register_backend",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device,
                 const ::c10d::ProcessGroup::BackendType& backendType,
                 const c10::optional<c10::intrusive_ptr<::c10d::Backend>>&
                     backend) {
                self->setBackend(device.type(), backendType, backend);
              },
              py::arg("device"),
              py::arg("backend_type"),
              py::arg("backend") =
                  c10::optional<c10::intrusive_ptr<::c10d::Backend>>(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_backend",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device) {
                return self->getBackend(device.type());
              },
              py::arg("device"),
              py::call_guard<py::gil_scoped_release>());
1481 | |
  // Expose ProcessGroup::BackendType as ProcessGroup.BackendType;
  // export_values() additionally injects the enumerators (GLOO, NCCL, ...)
  // directly into the ProcessGroup scope.
  py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType")
      .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
      .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
      .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
      .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
      .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
      .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
      .export_values();
1490 | |
1491 | // base ProcessGroup::Options binding |
1492 | auto processGroupOptions = |
1493 | intrusive_ptr_class_<::c10d::ProcessGroup::Options>( |
1494 | processGroup, |
1495 | "Options" , |
1496 | R"( |
1497 | Base class for all processs group options implementations, such as the nccl |
1498 | options :class:`~torch.distributed.ProcessGroupNCCL.Options`). |
1499 | )" ) |
1500 | .def( |
1501 | py::init([](const std::string& backend, |
1502 | const std::chrono::milliseconds& timeout) { |
1503 | return c10::make_intrusive<::c10d::ProcessGroup::Options>( |
1504 | backend, timeout); |
1505 | }), |
1506 | py::arg("backend" ), |
1507 | py::arg("timeout" ) = kProcessGroupDefaultTimeout, |
1508 | py::call_guard<py::gil_scoped_release>()) |
1509 | .def_readonly("backend" , &::c10d::ProcessGroup::Options::backend) |
1510 | .def_readwrite("_timeout" , &::c10d::ProcessGroup::Options::timeout); |
1511 | |
#ifndef _WIN32
  // Composes several process groups into one that dispatches over its
  // members round-robin. Rank and size are taken from the first group;
  // presumably all supplied groups share the same rank/size -- this is not
  // validated here.
  module.def(
      "_round_robin_process_groups",
      [](std::vector<c10::intrusive_ptr<::c10d::ProcessGroup>> processGroups)
          -> c10::intrusive_ptr<::c10d::ProcessGroup> {
        if (processGroups.empty()) {
          throw std::invalid_argument("Specify at least 1 process group");
        }
        const auto& first = processGroups.front();
        return c10::make_intrusive<::c10d::ProcessGroupRoundRobin>(
            first->getRank(), first->getSize(), std::move(processGroups));
      },
      py::arg("process_groups"),
      py::call_guard<py::gil_scoped_release>());
#endif
1527 | |
  // TODO: The collection of definitions below handles direct instantiation
  // of ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not
  // supported and should be removed once all tests are transitioned.
  // Bindings for the ::c10d::Backend base class. Collectives are typically
  // exposed twice: once taking the native options struct, and once as a
  // convenience overload taking bare tensors/ranks that builds the options
  // struct internally. Every blocking entry point carries
  // py::call_guard<py::gil_scoped_release> so Python threads can run while a
  // collective is in flight.
  auto backend =
      py::class_<::c10d::Backend, c10::intrusive_ptr<::c10d::Backend>>(
          module, "Backend" )
          .def("rank" , &::c10d::Backend::getRank)
          .def("size" , &::c10d::Backend::getSize)
          .def("name" , &::c10d::Backend::getBackendName)
          // broadcast: options form, then a single-tensor convenience form
          // that wraps the tensor in a one-element vector.
          .def(
              "broadcast" ,
              &::c10d::Backend::broadcast,
              py::arg("tensors" ),
              py::arg("opts" ) = ::c10d::BroadcastOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "broadcast" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& x,
                 int rootRank) {
                ::c10d::BroadcastOptions opts;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return self->broadcast(xs, opts);
              },
              py::arg("tensor" ),
              py::arg("root" ),
              py::call_guard<py::gil_scoped_release>())
          // allreduce: options form plus list-of-tensors and single-tensor
          // convenience overloads (both defaulting to ReduceOp::SUM).
          .def(
              "allreduce" ,
              &::c10d::Backend::allreduce,
              py::arg("tensors" ),
              py::arg("opts" ) = ::c10d::AllreduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 std::vector<at::Tensor>& xs,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                return self->allreduce(xs, opts);
              },
              py::arg("tensors" ),
              py::arg("op" ) = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& x,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                std::vector<at::Tensor> xs = {x};
                return self->allreduce(xs, opts);
              },
              py::arg("tensor" ),
              py::arg("op" ) = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce_coalesced" ,
              &::c10d::Backend::allreduce_coalesced,
              py::arg("tensors" ),
              py::arg("opts" ) = ::c10d::AllreduceCoalescedOptions(),
              py::call_guard<py::gil_scoped_release>())
          // reduce: options form, then single-tensor convenience form.
          .def(
              "reduce" ,
              &::c10d::Backend::reduce,
              py::arg("tensors" ),
              py::arg("opts" ) = ::c10d::ReduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& x,
                 int rootRank,
                 ::c10d::ReduceOp op) {
                ::c10d::ReduceOptions opts;
                opts.reduceOp = op;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return self->reduce(xs, opts);
              },
              py::arg("tensor" ),
              py::arg("root" ),
              py::arg("op" ) = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          // allgather family: list form, flattened "_base" form, and a
          // convenience form taking one input tensor per rank.
          .def(
              "allgather" ,
              &::c10d::Backend::allgather,
              py::arg("output_tensors" ),
              py::arg("input_tensors" ),
              py::arg("opts" ) = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_allgather_base" ,
              &::c10d::Backend::_allgather_base,
              py::arg("output" ),
              py::arg("input" ),
              py::arg("opts" ) = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input) {
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->allgather(
                    outputs, inputs, ::c10d::AllgatherOptions());
              },
              py::arg("output_tensors" ),
              py::arg("input_tensor" ),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather_coalesced" ,
              &::c10d::Backend::allgather_coalesced,
              py::arg("output_lists" ),
              py::arg("input_list" ),
              py::arg("opts" ) = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          // gather: options form, then single-input convenience form.
          .def(
              "gather" ,
              &::c10d::Backend::gather,
              py::arg("output_tensors" ),
              py::arg("input_tensors" ),
              py::arg("opts" ) = ::c10d::GatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "gather" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input,
                 int rootRank) {
                ::c10d::GatherOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->gather(outputs, inputs, opts);
              },
              py::arg("output_tensors" ),
              py::arg("input_tensor" ),
              py::arg("root" ),
              py::call_guard<py::gil_scoped_release>())
          // scatter: options form, then single-output convenience form.
          .def(
              "scatter" ,
              &::c10d::Backend::scatter,
              py::arg("output_tensors" ),
              py::arg("input_tensors" ),
              py::arg("opts" ) = ::c10d::ScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "scatter" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 int rootRank) {
                ::c10d::ScatterOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> inputs = {input};
                std::vector<at::Tensor> outputs = {output};
                return self->scatter(outputs, inputs, opts);
              },
              py::arg("output_tensor" ),
              py::arg("input_tensors" ),
              py::arg("root" ),
              py::call_guard<py::gil_scoped_release>())
          // reduce_scatter: options form, then a convenience form.
          .def(
              "reduce_scatter" ,
              &::c10d::Backend::reduce_scatter,
              py::arg("output_tensors" ),
              py::arg("input_tensors" ),
              py::arg("opts" ) = ::c10d::ReduceScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce_scatter" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 ::c10d::ReduceOp op) {
                std::vector<at::Tensor> outputs = {output};
                std::vector<std::vector<at::Tensor>> inputs = {input};
                ::c10d::ReduceScatterOptions opts;
                opts.reduceOp = op;
                return self->reduce_scatter(outputs, inputs, opts);
              },
              py::arg("output_tensors" ),
              // NOTE(review): the Python arg is named "input_tensor" but this
              // overload expects a *list* of tensors — matches upstream
              // naming; confirm with callers before renaming.
              py::arg("input_tensor" ),
              py::arg("op" ) = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_reduce_scatter_base" ,
              &::c10d::Backend::_reduce_scatter_base,
              py::arg("outputTensor" ),
              py::arg("inputTensor" ),
              py::arg("opts" ) = ::c10d::ReduceScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          // alltoall_base: options form, then a form that defaults the
          // options (note: this overload takes Backend by reference, not by
          // intrusive_ptr — it is invoked on the bound instance directly).
          .def(
              "alltoall_base" ,
              &::c10d::Backend::alltoall_base,
              py::arg("output_tensor" ),
              py::arg("input_tensor" ),
              py::arg("output_split_sizes" ),
              py::arg("input_split_sizes" ),
              py::arg("opts" ) = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "alltoall_base" ,
              [](::c10d::Backend& self,
                 at::Tensor& output,
                 at::Tensor& input,
                 std::vector<int64_t> outputSplitSizes,
                 std::vector<int64_t> inputSplitSizes) {
                return self.alltoall_base(
                    output,
                    input,
                    outputSplitSizes,
                    inputSplitSizes,
                    ::c10d::AllToAllOptions());
              },
              py::arg("output" ),
              py::arg("input" ),
              py::arg("output_split_sizes" ),
              py::arg("input_split_sizes" ),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "alltoall" ,
              &::c10d::Backend::alltoall,
              py::arg("output_tensor" ),
              py::arg("input_tensor" ),
              py::arg("opts" ) = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())
          // Point-to-point operations.
          .def(
              "send" ,
              &::c10d::Backend::send,
              py::arg("tensors" ),
              py::arg("dstRank" ),
              py::arg("tag" ),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv" ,
              &::c10d::Backend::recv,
              py::arg("tensors" ),
              py::arg("srcRank" ),
              py::arg("tag" ),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv_anysource" ,
              &::c10d::Backend::recvAnysource,
              py::call_guard<py::gil_scoped_release>())
          // Synchronization and introspection helpers.
          .def(
              "barrier" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 const ::c10d::BarrierOptions& opts) {
                return self->barrier(opts);
              },
              py::arg("opts" ) = ::c10d::BarrierOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_set_sequence_number_for_group" ,
              &::c10d::Backend::setSequenceNumberForGroup,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_sequence_number_for_group" ,
              &::c10d::Backend::getSequenceNumberForGroup,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "monitored_barrier" ,
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 const std::chrono::milliseconds& timeout,
                 bool waitAllRanks) {
                ::c10d::BarrierOptions opts;
                opts.timeout = timeout;
                return self->monitoredBarrier(opts, waitAllRanks);
              },
              py::arg("timeout" ) = ::c10d::kUnsetTimeout,
              py::arg("wait_all_ranks" ) = false,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_backend_name" ,
              &::c10d::Backend::getBackendName,
              py::call_guard<py::gil_scoped_release>())
          // Coalescing manager hooks used by the Python-side context manager.
          .def(
              "_start_coalescing" ,
              &::c10d::Backend::startCoalescing,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_end_coalescing" ,
              &::c10d::Backend::endCoalescing,
              py::arg("reqs" ),
              py::call_guard<py::gil_scoped_release>());
1819 | |
#ifdef USE_C10D_GLOO
  // Environment variable consulted when constructing a Gloo group without an
  // explicit device list.
  static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME" ;

  auto processGroupGloo =
      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>(
          module, "ProcessGroupGloo" , backend);

  // Opaque handle type; exposed so devices created below can be passed back
  // into _Options from Python.
  shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device" );

  // Gloo-specific options; `processGroupOptions` (the base Options binding)
  // is defined earlier in this file.
  intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
      processGroupGloo, "_Options" , processGroupOptions)
      .def(py::init<>())
      .def_readwrite("_devices" , &::c10d::ProcessGroupGloo::Options::devices)
      .def_readwrite("_threads" , &::c10d::ProcessGroupGloo::Options::threads);

  processGroupGloo
      .def_static(
          "create_device" ,
          // Builds a transport device from a hostname OR an interface name;
          // exactly one of the two must be non-empty (hostname wins if both
          // are given).
          [](const std::string& hostname, const std::string& interface)
              -> std::shared_ptr<::gloo::transport::Device> {
            if (!hostname.empty()) {
              return ::c10d::ProcessGroupGloo::createDeviceForHostname(
                  hostname);
            }
            if (!interface.empty()) {
              return ::c10d::ProcessGroupGloo::createDeviceForInterface(
                  interface);
            }
            throw std::invalid_argument(
                "Specify either `hostname` or `interface` argument." );
          },
          py::arg("hostname" ) = "" ,
          py::arg("interface" ) = "" )
      .def_static(
          "create_default_device" ,
          &::c10d::ProcessGroupGloo::createDefaultDevice);

  processGroupGloo
      .def(
          // Constructor taking a fully-populated Options object.
          py::init<
              const c10::intrusive_ptr<::c10d::Store>&,
              int,
              int,
              c10::intrusive_ptr<::c10d::ProcessGroupGloo::Options>>(),
          py::call_guard<py::gil_scoped_release>())
      .def(
          // Convenience constructor that derives the device list from the
          // environment (or the default device) and only takes a timeout.
          py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
                      int rank,
                      int size,
                      std::chrono::milliseconds timeout) {
            auto options = ::c10d::ProcessGroupGloo::Options::create();

            // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set.
            char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
            // NOTE(review): the `> 1` check means a single-character
            // interface name (e.g. a VM's "e") is silently ignored —
            // presumably meant to skip degenerate values, but confirm this
            // isn't an off-by-one for `> 0`.
            if (ifnameEnv && strlen(ifnameEnv) > 1) {
              for (const auto& iface : ::c10d::split(',', ifnameEnv)) {
                options->devices.push_back(
                    ::c10d::ProcessGroupGloo::createDeviceForInterface(iface));
              }
            } else {
              // If no hostname is specified, this function looks up
              // the machine's hostname and returns a device instance
              // associated with the address that the hostname resolves to.
              options->devices.push_back(
                  ::c10d::ProcessGroupGloo::createDefaultDevice());
            }

            options->timeout = timeout;
            // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
            options->threads = options->devices.size() * 2;
            return c10::make_intrusive<::c10d::ProcessGroupGloo>(
                store, rank, size, options);
          }),
          py::arg("store" ),
          py::arg("rank" ),
          py::arg("size" ),
          py::arg("timeout" ) = kProcessGroupDefaultTimeout,
          py::call_guard<py::gil_scoped_release>())
      .def_property_readonly("options" , &::c10d::ProcessGroupGloo::getOptions);

  // ProcessGroupWrapper is a wrapper pg that includes a helper gloo process
  // group. It can be used to validate collective calls across processes by
  // checking the op type and input tensor shapes.
  auto processGroupWrapper =
      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>(
          module, "_ProcessGroupWrapper" , backend)
          .def(
              py::init(
                  [](const c10::intrusive_ptr<::c10d::Backend>& backend,
                     const c10::intrusive_ptr<::c10d::Backend>& gloo_backend) {
                    return c10::make_intrusive<::c10d::ProcessGroupWrapper>(
                        backend, gloo_backend);
                  }),
              py::arg("backend" ),
              py::arg("gloo_backend" ),
              py::call_guard<py::gil_scoped_release>())
          .def_property_readonly(
              "wrapped_pg" , &::c10d::ProcessGroupWrapper::getWrappedPg);
#endif
1919 | |
#ifdef USE_C10D_NCCL
  auto processGroupNCCL =
      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupNCCL>(
          module, "ProcessGroupNCCL" , backend)
          .def(
              // Constructor taking a fully-populated Options object.
              py::init<
                  const c10::intrusive_ptr<::c10d::Store>&,
                  int,
                  int,
                  c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              // Convenience constructor: builds default Options (normal
              // priority streams) with only the timeout overridden.
              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
                          int rank,
                          int size,
                          const std::chrono::milliseconds& timeout) {
                auto options = ::c10d::ProcessGroupNCCL::Options::create();
                options->is_high_priority_stream = false;
                options->timeout = timeout;
                return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
                    store, rank, size, options);
              }),
              py::arg("store" ),
              py::arg("rank" ),
              py::arg("size" ),
              py::arg("timeout" ) = kProcessGroupDefaultTimeout,
              py::call_guard<py::gil_scoped_release>())
          .def_property_readonly(
              "options" , &::c10d::ProcessGroupNCCL::getOptions)
          .def_property_readonly(
              "is_ucc_available" , &::c10d::ProcessGroupNCCL::isUCCAvailable);

  // NCCL-specific options; `processGroupOptions` (the base Options binding)
  // is defined earlier in this file.
  intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
      processGroupNCCL,
      "Options" ,
      processGroupOptions,
      R"(
ProcessGroup options for the NCCL backend

Arguments:
    is_high_priority_stream (bool, optional): flag to enable/disable process
        group to pick up high priority cuda streams. It lets CUDA driver
        to prioritize NCCL kernels when there are compute kernels waiting.
        Default is False.

Example::
    >>> import torch.distributed as dist
    >>>
    >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
    >>> # initialize a nccl process group with the options just created
    >>> dist.init_process_group("nccl", pg_options=nccl_options)
        )" )
      .def(py::init<bool>(), py::arg("is_high_priority_stream" ) = false)
      .def_readwrite(
          "is_high_priority_stream" ,
          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
  // Static helpers mapping directly onto ncclGroupStart/ncclGroupEnd.
  processGroupNCCL.def_static(
      "_group_start" , []() { ::c10d::ProcessGroupNCCL::groupStart(); });
  processGroupNCCL.def_static(
      "_group_end" , []() { ::c10d::ProcessGroupNCCL::groupEnd(); });
#endif
1981 | |
#ifdef USE_C10D_MPI
  auto processGroupMPI =
      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(
          module, "ProcessGroupMPI" , backend);

  // Define static create function instead of a constructor, because
  // this function may return null. This happens if this process is not
  // part of a sub group that is to be created.
  processGroupMPI.def_static(
      "create" ,
      // `ranks` selects which MPI ranks belong to the new group; the GIL is
      // released because group creation can block on MPI communication.
      [](std::vector<int> ranks) {
        return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
      },
      py::call_guard<py::gil_scoped_release>());
#endif
1997 | |
#ifdef USE_C10D_UCC
  auto processGroupUCC =
      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
          module, "ProcessGroupUCC" , backend)
          .def(
              // UCC has no Options binding here; the constructor takes the
              // store/rank/size triple and a timeout directly.
              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
                          int rank,
                          int size,
                          const std::chrono::milliseconds& timeout) {
                return c10::make_intrusive<::c10d::ProcessGroupUCC>(
                    store, rank, size, timeout);
              }),
              py::arg("store" ),
              py::arg("rank" ),
              py::arg("size" ),
              py::arg("timeout" ) = kProcessGroupDefaultTimeout,
              py::call_guard<py::gil_scoped_release>());
#endif
2016 | |
2017 | py::class_< |
2018 | ::c10d::Work, |
2019 | c10::intrusive_ptr<::c10d::Work>, |
2020 | ::c10d::PyProcessGroup::PyWork>(module, "Work" ) |
2021 | .def(py::init<>()) |
2022 | .def("is_completed" , &::c10d::Work::isCompleted) |
2023 | .def( |
2024 | "is_success" , |
2025 | [](::c10d::Work& work) -> bool { |
2026 | TORCH_WARN_ONCE( |
2027 | fmt::format(kDeprecationWarning, "Work::is_success" )); |
2028 | return work.isSuccess(); |
2029 | }) |
2030 | .def( |
2031 | "exception" , |
2032 | [](::c10d::Work& work) -> std::exception_ptr { |
2033 | TORCH_WARN_ONCE( |
2034 | fmt::format(kDeprecationWarning, "Work::exception" )); |
2035 | return work.exception(); |
2036 | }) |
2037 | .def( |
2038 | "source_rank" , |
2039 | [](::c10d::Work& work) -> int { |
2040 | TORCH_WARN_ONCE( |
2041 | fmt::format(kDeprecationWarning, "Work::source_rank" )); |
2042 | return work.sourceRank(); |
2043 | }) |
2044 | .def("_source_rank" , &::c10d::Work::sourceRank) |
2045 | .def( |
2046 | "result" , |
2047 | [](::c10d::Work& work) -> std::vector<at::Tensor> { |
2048 | return work.result(); |
2049 | }) |
2050 | .def( |
2051 | "synchronize" , |
2052 | [](::c10d::Work& work) -> void { |
2053 | TORCH_WARN_ONCE( |
2054 | fmt::format(kDeprecationWarning, "Work::synchronize" )); |
2055 | work.synchronize(); |
2056 | }) |
2057 | .def( |
2058 | "wait" , |
2059 | &::c10d::Work::wait, |
2060 | py::arg("timeout" ) = kNoTimeout, |
2061 | py::call_guard<py::gil_scoped_release>()) |
2062 | .def( |
2063 | "get_future" , |
2064 | [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> { |
2065 | return std::make_shared<jit::PythonFutureWrapper>(work.getFuture()); |
2066 | }, |
2067 | R"( |
2068 | Returns: |
2069 | A ``torch.futures.Future`` object which is associated with the completion of |
2070 | the ``Work``. As an example, a future object can be retrieved |
2071 | by ``fut = process_group.allreduce(tensors).get_future()``. |
2072 | |
2073 | Example:: |
2074 | Below is an example of a simple allreduce DDP communication hook that uses |
2075 | ``get_future` API to retrieve a Future associated with the completion of |
2076 | ``allreduce``. |
2077 | |
2078 | >>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket): -> torch.futures.Future |
2079 | >>> group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD |
2080 | >>> tensor = bucket.buffer().div_(group_to_use.size()) |
2081 | >>> return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future() |
2082 | >>> ddp_model.register_comm_hook(state=None, hook=allreduce) |
2083 | |
2084 | .. warning :: |
2085 | ``get_future`` API supports NCCL, and partially GLOO and MPI backends |
2086 | (no support for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``. |
2087 | |
2088 | In the example above, ``allreduce`` work will be done on GPU using NCCL backend, |
2089 | ``fut.wait()`` will return after synchronizing the appropriate NCCL streams |
2090 | with PyTorch's current device streams to ensure we can have asynchronous CUDA |
2091 | execution and it does not wait for the entire operation to complete on GPU. Note that |
2092 | ``CUDAFuture`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. |
2093 | In addition, if a callback function was added by ``fut.then()``, it will wait until |
2094 | ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback |
2095 | stream and invoke the callback inline after running the callback on the callback stream. |
2096 | ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the |
2097 | callback and a ``CUDAEvent`` that recorded the callback stream. |
2098 | |
2099 | 1. For CPU work, ``fut.done()`` returns true when work has been complted and value() |
2100 | tensors are ready. |
2101 | 2. For GPU work, ``fut.done()`` returns true only whether the operation has been enqueued. |
2102 | 3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), ``fut.done()`` returns |
2103 | true when tensors have arrived on respective nodes, but not yet necessarily synched on |
2104 | respective GPUs (similarly to GPU work). |
2105 | )" ); |
2106 | |
2107 | py::class_<c10::DDPLoggingData>(module, "DDPLoggingData" ) |
2108 | .def(py::init<>()) |
2109 | .def_readwrite("strs_map" , &c10::DDPLoggingData::strs_map) |
2110 | .def_readwrite("ints_map" , &c10::DDPLoggingData::ints_map); |
2111 | |
  // Computes DDP gradient-bucket assignments. The optional logger is passed
  // down as a weak_ptr so this call does not extend the logger's lifetime.
  module.def(
      "_compute_bucket_assignment_by_size" ,
      [](const std::vector<at::Tensor>& tensors,
         const std::vector<size_t>& bucket_size_limits,
         const std::vector<bool>& expect_sparse_gradient,
         const std::vector<int64_t>& tensor_indices,
         const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
        if (logger.has_value()) {
          std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
          return ::c10d::compute_bucket_assignment_by_size(
              tensors,
              bucket_size_limits,
              expect_sparse_gradient,
              tensor_indices,
              {logger_weakref});
        } else {
          return ::c10d::compute_bucket_assignment_by_size(
              tensors,
              bucket_size_limits,
              expect_sparse_gradient,
              tensor_indices,
              {});
        }
      },
      py::arg("tensors" ),
      // Python-side name differs from the lambda parameter
      // (bucket_size_limits) — kept for backward compatibility.
      py::arg("bucket_size" ),
      py::arg("expect_sparse_gradient" ) = std::vector<bool>(),
      py::arg("tensor_indices" ) = std::vector<int64_t>(),
      py::arg("logger" ) = c10::optional<std::shared_ptr<::c10d::Logger>>{},
      py::call_guard<py::gil_scoped_release>());
2142 | |
  // Verifies model parameters agree across all ranks of the process group.
  // As above, the optional logger is demoted to a weak_ptr before being
  // handed to the C++ implementation.
  module.def(
      "_verify_params_across_processes" ,
      [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
         const std::vector<at::Tensor>& params,
         const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
        if (logger.has_value()) {
          std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
          verify_params_across_processes(
              process_group, params, {logger_weakref});
        } else {
          verify_params_across_processes(process_group, params, {});
        }
      },
      py::arg("process_group" ),
      py::arg("params" ),
      py::arg("logger" ) = c10::optional<std::shared_ptr<::c10d::Logger>>{},
      py::call_guard<py::gil_scoped_release>());
2160 | |
  module.def(
      "_broadcast_coalesced" ,
      // Define a lambda such that the pybind11 prototype can take a std::vector
      // for the tensor list argument, but still pass it to the underlying
      // function as a c10::ArrayRef.
      [](c10::intrusive_ptr<::c10d::ProcessGroup> process_group,
         std::vector<at::Tensor> tensors, // NOLINT
         size_t buffer_size,
         int rank) {
        broadcast_coalesced(
            std::move(process_group), tensors, buffer_size, rank);
      },
      py::arg("process_group" ),
      py::arg("tensors" ),
      py::arg("buffer_size" ),
      // The source of truth rank to broadcast the tensors from.
      py::arg("src" ) = 0,
      py::call_guard<py::gil_scoped_release>());
2179 | |
  module.def(
      "_test_python_store" ,
      // Define a function that takes a c10d store and runs a few tests.
      // This is used by the PythonStore tests, which we cannot test from the
      // Python side of the world. Calling Python functions on a Python object
      // completely bypasses pybind11. We need to test that the overloaded
      // functions call into Python and behave like we expect.
      [](c10::intrusive_ptr<::c10d::Store> store) {
        auto add = [&store](const std::string& key, int64_t value) {
          store->add(key, value);
        };

        auto set = [&store](const std::string& key, const std::string& value) {
          store->set(key, value);
        };

        // get() returns raw bytes; convert to std::string for comparison.
        auto get = [&store](const std::string& key) {
          auto value = store->get(key);
          return std::string(value.begin(), value.end());
        };

        // Interleave add() and set() calls so both code paths are exercised
        // against the same store instance.
        add("key" , 1);
        add("key" , 2);
        add("key" , 3);
        set("key0" , "value0" );
        add("key3" , 1);
        set("key1" , "value1" );
        add("key3" , 2);
        set("key2" , "value2" );
        add("key3" , 3);
        add("key3" , 4);
        add("key3" , 3);
        add("key3" , 2);
        // "key" accumulated 1+2+3 = 6.
        if (get("key" ) != "6" ) {
          TORCH_CHECK(false, "assertion failed" );
        }
        if (get("key0" ) != "value0" ) {
          TORCH_CHECK(false, "assertion failed" );
        }
        if (get("key1" ) != "value1" ) {
          TORCH_CHECK(false, "assertion failed" );
        }
        if (get("key2" ) != "value2" ) {
          TORCH_CHECK(false, "assertion failed" );
        }
        // "key3" accumulated 1+2+3+4+3+2 = 15.
        if (get("key3" ) != "15" ) {
          TORCH_CHECK(false, "assertion failed" );
        }
      },
      py::call_guard<py::gil_scoped_release>());
2230 | |
  // Module-level constants mirrored into Python for the distributed package.
  module.attr("_DEFAULT_FIRST_BUCKET_BYTES" ) = ::c10d::kDefaultFirstBucketBytes;
  module.attr("_DEFAULT_PG_TIMEOUT" ) = py::cast(kProcessGroupDefaultTimeout);
  module.attr("_DEFAULT_NO_TIMEOUT" ) = py::cast(kNoTimeout);
2234 | |
2235 | module.def( |
2236 | "_create_work_from_future" , |
2237 | [](std::shared_ptr<jit::PythonFutureWrapper> future) { |
2238 | return ::c10d::Work::create_from_future(future->fut); |
2239 | }, |
2240 | py::arg("future" ), |
2241 | R"( |
2242 | Arguments: |
2243 | future(str): The future to wrap. |
2244 | Returns: |
2245 | A ``Work`` object which is associated with the completion of |
2246 | the ``torch.futures.Future``. |
2247 | This is the preferred way of constructing Work objects when writing a custom ProcessGroup |
2248 | in python. |
2249 | Example:: |
2250 | >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup): |
2251 | >>> def broadcast(self, tensor_list, opts): |
2252 | >>> fut = torch.futures.Future() |
2253 | >>> fut.set_result(tensor_list) |
2254 | >>> return torch._C._distributed_c10d._create_work_from_future(fut) |
2255 | .. warning :: |
2256 | This API is experimental and subject to change. |
2257 | The returned Work object has multiple limitations: |
2258 | - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization. |
2259 | - wait() ignored timeout argument. |
2260 | - sourceRank() raises. |
2261 | - abort() raises. |
2262 | The provided Future object result must be a Tensor or a list of Tensors. |
2263 | )" ); |
2264 | |
2265 | Py_RETURN_TRUE; |
2266 | } |
2267 | |
2268 | #undef PROCESS_GROUP_DEPRECATION_WARNING |
2269 | |
2270 | } // namespace |
2271 | |
// Method table installed on torch._C; `_c10d_init` runs all of the binding
// code above and returns True on success.
static PyMethodDef methods[] = { // NOLINT
    {"_c10d_init" , c10d_init, METH_NOARGS, nullptr},
    // Sentinel entry terminating the table, as required by CPython.
    {nullptr, nullptr, 0, nullptr}};
2276 | |
2277 | PyMethodDef* python_functions() { |
2278 | return methods; |
2279 | } |
2280 | |
2281 | } // namespace c10d |
2282 | } // namespace distributed |
2283 | } // namespace torch |
2284 | |