#include <torch/csrc/python_headers.h>

#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#ifndef _WIN32
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp>
#endif
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/PyProcessGroup.hpp>

#ifdef USE_C10D_GLOO
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
#endif

#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#endif

#ifdef USE_C10D_MPI
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
#endif

#ifdef USE_C10D_UCC
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#endif

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>

#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/csrc/distributed/c10d/reducer.hpp>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/python_comm_hook.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/custom_class.h>

namespace {

// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
// TODO: move this somewhere more generally useful
template <typename T>
class IntrusivePtrNoGilDestructor {
  c10::intrusive_ptr<T> impl_;

 public:
  IntrusivePtrNoGilDestructor() = default;
  IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default;
  IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) = default;
  IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) =
      default;
  IntrusivePtrNoGilDestructor& operator=(IntrusivePtrNoGilDestructor&&) =
      default;
  /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr<T> impl)
      : impl_(std::move(impl)) {}
  // This ctor is very important; see
  // https://github.com/pybind/pybind11/issues/2957
  explicit IntrusivePtrNoGilDestructor(T* impl)
      : impl_(c10::intrusive_ptr<T>::unsafe_steal_from_new(impl)) {}
  ~IntrusivePtrNoGilDestructor() {
    if (impl_) {
      if (PyGILState_Check()) {
        pybind11::gil_scoped_release release;
        impl_.reset();
      } else {
        impl_.reset();
      }
    }
  }
  T& operator*() const noexcept {
    return *impl_;
  }
  T* operator->() const noexcept {
    return impl_.get();
  }
  C10_NODISCARD T* get() const noexcept {
    return impl_.get();
  }
  void reset() noexcept {
    impl_.reset();
  }
  operator bool() const noexcept {
    return impl_;
  }
};

} // anonymous namespace

PYBIND11_DECLARE_HOLDER_TYPE(T, IntrusivePtrNoGilDestructor<T>, true);

namespace torch {
namespace distributed {
namespace c10d {

namespace {

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

constexpr auto kDeprecationWarning =
    "{} API is being deprecated, please ping "
    "https://github.com/pytorch/pytorch/issues/46291 "
    "if you see this warning";
template <typename T>
using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>;

template <typename T>
using intrusive_ptr_no_gil_destructor_class_ =
    py::class_<T, IntrusivePtrNoGilDestructor<T>>;
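
// Illustrative only (not part of the bindings below): a backend class that
// must not block on the GIL while being destructed would typically be bound
// through the holder alias above, e.g. a hypothetical
//
//   intrusive_ptr_no_gil_destructor_class_<::c10d::SomeBackend>(
//       module, "SomeBackend", backend);
//
// so that dropping the last Python reference releases the GIL first.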

// PythonStore is a pybind11 trampoline class to allow a Python
// class to inherit from c10d.Store and implement its interface.
class PythonStore : public ::c10d::Store {
 public:
  using ::c10d::Store::Store;

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void set(const std::string& key, const std::vector<uint8_t>& value) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
    TORCH_INTERNAL_ASSERT(fn);
    // Call function with a py::bytes object for the value.
    fn(key,
       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> get(const std::string& key) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "get");
    TORCH_INTERNAL_ASSERT(fn);
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str = pybind11::cast<py::bytes>(fn(key));
    return std::vector<uint8_t>(str.begin(), str.end());
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "compare_set");
    TORCH_INTERNAL_ASSERT(fn);
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str =
        pybind11::cast<py::bytes>(fn(key, expectedValue, desiredValue));
    return std::vector<uint8_t>(str.begin(), str.end());
  }

  int64_t add(const std::string& key, int64_t value) override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value);
  }

  int64_t getNumKeys() override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys);
  }

  bool deleteKey(const std::string& key) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key);
  }

  bool check(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys);
  }

  void wait(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys);
  }

  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout);
  }
};
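
// Illustrative only: with the trampoline above, a Store subclass can be
// written in Python; a minimal sketch (class and attribute names are
// illustrative, not an existing API):
//
//   >>> import torch.distributed as dist
//   >>> class DictStore(dist.Store):
//   ...     def __init__(self):
//   ...         super().__init__()
//   ...         self._data = {}
//   ...     def set(self, key, value):
//   ...         self._data[key] = value
//   ...     def get(self, key):
//   ...         return self._data[key]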

// Called from DDP's Python API to create a c10d Python comm hook object.
// The input state and callable comm_hook are Python objects. It then calls
// the reducer's register_comm_hook function to register the hook.
void _register_comm_hook(
    ::c10d::Reducer& reducer,
    py::object state,
    py::object comm_hook) {
  reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
      std::move(state), std::move(comm_hook)));
}
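
// Illustrative only: from Python this is normally reached through
// DistributedDataParallel.register_comm_hook rather than called directly.
// A rough sketch of a Python comm hook (hook name is illustrative):
//
//   >>> def passthrough_hook(state, bucket):
//   ...     fut = torch.futures.Future()
//   ...     fut.set_result(bucket.buffer())
//   ...     return fut
//   >>> ddp_model.register_comm_hook(state=None, hook=passthrough_hook)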

// Called from DDP's Python API to create a c10d C++ comm hook.
// The input is an enum hook type. It then calls the reducer's
// register_builtin_comm_hook function to set the hook type.
void _register_builtin_comm_hook(
    ::c10d::Reducer& reducer,
    ::c10d::BuiltinCommHookType comm_hook_type) {
  reducer.register_builtin_comm_hook(comm_hook_type);
}
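
// Illustrative only: the builtin variant is selected by hook type, roughly as
// sketched below (assuming the enum is reached via torch._C._distributed_c10d):
//
//   >>> from torch._C._distributed_c10d import BuiltinCommHookType
//   >>> ddp_model._register_builtin_comm_hook(BuiltinCommHookType.ALLREDUCE)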

// Customize the metaclass of ::c10d::ReduceOp for backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function support
// such as `isinstance` (see https://github.com/pytorch/pytorch/issues/87191)
// and `copy` (see
// https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700). Below,
// we define a custom `isinstance` in CPython/pybind11
// (`reduceopmeta___instancecheck__`) and modify the default metaclass of
// pybind11 (`GetReduceOpMetaclass`) so that
// `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)`
// returns :obj:`True` as if `ReduceOp` were an enum.
// Ref:
// - https://docs.python.org/3/extending/newtypes_tutorial.html
// - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods
// - https://github.com/pybind/pybind11/issues/2696
static PyObject* reduceopmeta___instancecheck__(
    PyObject* self,
    PyObject* args) {
  if (Py_TYPE(self) == Py_TYPE(args)) {
    Py_RETURN_TRUE;
  }
  if (c10::string_view(args->ob_type->tp_name).find("RedOpType") !=
      c10::string_view::npos) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
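
// Illustrative only: with this metaclass installed, the pre-#84243 enum-like
// behavior described above is preserved, e.g.:
//
//   >>> import torch.distributed as dist
//   >>> isinstance(dist.ReduceOp.SUM, dist.ReduceOp)
//   True
//   >>> isinstance(dist.ReduceOp.RedOpType.SUM, dist.ReduceOp)
//   True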
static PyMethodDef reduceopmeta_methods[] = {
    {"__instancecheck__",
     (PyCFunction)reduceopmeta___instancecheck__,
     METH_O,
     "Custom `__instancecheck__` for ReduceOp"},
    {nullptr, nullptr}};
PyTypeObject* GetReduceOpMetaclass() {
  static auto* metaclass = [] {
    PyTypeObject* base_metaclass =
        pybind11::detail::get_internals().default_metaclass;
    PyType_Slot slots[] = {
        {Py_tp_base, base_metaclass},
        {Py_tp_methods, reduceopmeta_methods},
        {0},
    };
    PyType_Spec spec = {};
    spec.name = "torch._C._distributed_c10d._ReduceOpMeta";
    spec.basicsize = base_metaclass->tp_basicsize;
    spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    spec.slots = slots;
    PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec);
    if (!metaclass)
      throw py::error_already_set();
    return metaclass;
  }();
  return metaclass;
}

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
  C10_LOG_API_USAGE_ONCE("c10d.python.import");

  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
  if (!c10d_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings");

  auto module = py::handle(m).cast<py::module>();

  module
      .def(
          "_register_comm_hook",
          &_register_comm_hook,
          py::arg("reducer"),
          py::arg("state"),
          py::arg("comm_hook"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_register_builtin_comm_hook",
          &_register_builtin_comm_hook,
          py::arg("reducer"),
          py::arg("comm_hook_type"));

  shared_ptr_class_<::c10d::GradBucket>(
      module,
      "GradBucket",
      R"(
This class mainly passes a flattened gradient tensor
(returned by :meth:`~torch.distributed.GradBucket.buffer`)
to a DDP communication hook.
This tensor can be further decomposed into a list of per-parameter tensors within this bucket
(returned by :meth:`~torch.distributed.GradBucket.get_per_parameter_tensors`)
to apply layer-wise operations.
)")
      .def(
          "index",
          &::c10d::GradBucket::getIndex,
          py::call_guard<py::gil_scoped_release>(),
          R"(
.. warning::
    Since the buckets are rebuilt after the first iteration, one should not rely on the indices at the beginning of training.

Returns:
    The index of a bucket that stores gradients of a few contiguous layers.
    All the gradients are bucketized.
)")
      .def(
          "buffer",
          &::c10d::GradBucket::getBuffer,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A flattened 1D ``torch.Tensor`` buffer,
    which can be further decomposed into a list of per-parameter tensors within this bucket.
)")
      .def(
          "gradients",
          &::c10d::GradBucket::getGradients,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A list of ``torch.Tensor``. Each tensor in the list corresponds to a gradient.
)")
      .def(
          "parameters",
          &::c10d::GradBucket::getParameters,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A list of ``torch.Tensor``. Each tensor in the list corresponds to a model
    parameter.
)")
      .def(
          "is_last",
          &::c10d::GradBucket::isLast,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    Whether this bucket is the last bucket to allreduce in an iteration.
    This also means that this bucket corresponds to the first few layers in the forward pass.
)")
      .def(
          "set_buffer",
          &::c10d::GradBucket::setBuffer,
          py::arg("buffer"),
          py::call_guard<py::gil_scoped_release>(),
          R"(
Replaces the tensor in the bucket with the input tensor buffer.
)");

  py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"(
An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)")
      .value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE)
      .value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS);

  shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
      .def(
          py::init<
              std::vector<at::Tensor>,
              std::vector<std::vector<size_t>>,
              std::vector<size_t>,
              c10::intrusive_ptr<::c10d::ProcessGroup>,
              std::vector<bool>,
              int64_t,
              bool,
              bool,
              std::unordered_map<size_t, std::string>,
              int64_t>(),
          py::arg("params"),
          py::arg("bucket_indices"),
          py::arg("per_bucket_size_limits"),
          py::arg("process_group"),
          py::arg("expect_sparse_gradients") = std::vector<bool>(),
          py::arg("bucket_bytes_cap") = ::c10d::kDefaultBucketBytesCap,
          py::arg("find_unused_parameters") = false,
          py::arg("gradient_as_bucket_view") = false,
          py::arg("param_to_name_mapping") =
              std::unordered_map<size_t, std::string>(),
          py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_forward",
          &::c10d::Reducer::prepare_for_forward,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
          &::c10d::Reducer::prepare_for_backward,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
          [](::c10d::Reducer& reducer, const at::Tensor& output) -> void {
            reducer.prepare_for_backward({output});
          },
          py::call_guard<py::gil_scoped_release>())
      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
      .def(
          "_install_post_backward_futures",
          [](::c10d::Reducer& reducer,
             const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>&
                 futs) {
            c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures(
                c10::FutureType::create(c10::TensorType::get()));
            for (const auto& fut : futs) {
              futures.push_back(fut->fut);
            }
            reducer.install_futures(std::move(futures));
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_rebuild_buckets",
          &::c10d::Reducer::rebuild_buckets,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_zeros_like_grad_buckets",
          [](::c10d::Reducer& reducer) {
            return reducer.get_grad_buckets(/* return_zero_tensors */ true);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_grads_to_none",
          [](::c10d::Reducer& reducer) { reducer.set_grads_to_none(true); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_push_all_rebuilt_params",
          &::c10d::Reducer::push_rebuilt_params_for_all_indices,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_forward_pass_work_handle",
          &::c10d::Reducer::set_forward_pass_work_handle,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_local_used_map", &::c10d::Reducer::get_local_used_map_on_device)
      .def(
          "_set_ddp_runtime_logging_sample_rate",
          &::c10d::Reducer::set_ddp_runtime_logging_sample_rate,
          py::arg("sample_rate"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_static_graph",
          &::c10d::Reducer::set_static_graph,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_ddp_graph_static",
          &::c10d::Reducer::ddp_graph_static,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_delay_all_reduce",
          &::c10d::Reducer::delay_all_reduce,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_run_comm_hook",
          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            c10::intrusive_ptr<c10::ivalue::Future> fut =
                reducer.run_comm_hook(bucket);
            return std::make_shared<jit::PythonFutureWrapper>(fut);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_run_allreduce_hook",
          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            c10::intrusive_ptr<c10::ivalue::Future> fut =
                reducer.run_allreduce_hook(bucket);
            return std::make_shared<jit::PythonFutureWrapper>(fut);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_logger",
          [](::c10d::Reducer& reducer,
             const std::shared_ptr<::c10d::Logger> logger) {
            std::weak_ptr<::c10d::Logger> logger_weakref = logger;
            reducer.set_logger(logger_weakref);
          });

  shared_ptr_class_<::c10d::Logger>(module, "Logger")
      .def(
          py::init<std::shared_ptr<::c10d::Reducer>>(),
          py::arg("reducer"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_construction_data_and_log",
          &::c10d::Logger::set_construction_data_and_log,
          py::arg("module_name"),
          py::arg("device_ids"),
          py::arg("output_device"),
          py::arg("broadcast_buffers"),
          py::arg("has_sync_bn"),
          py::arg("static_graph"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_runtime_stats_and_log",
          &::c10d::Logger::set_runtime_stats_and_log,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_error_and_log",
          [](::c10d::Logger& logger, const std::string& error) {
            logger.set_error_and_log(error);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_ddp_logging_data",
          &::c10d::Logger::get_ddp_logging_data,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_comm_hook_name",
          &::c10d::Logger::set_comm_hook,
          py::arg("comm_hook"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_uneven_input_join",
          &::c10d::Logger::set_uneven_input_join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_static_graph",
          &::c10d::Logger::set_static_graph,
          py::call_guard<py::gil_scoped_release>());

  py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
      An enum whose values correspond to different debug levels of the
      torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
      which can be set via the ``TORCH_DISTRIBUTED_DEBUG`` environment variable
      or via the ``set_debug_level()`` function.
  )")
      .value("OFF", ::c10d::DebugLevel::Off)
      .value("INFO", ::c10d::DebugLevel::Info)
      .value("DETAIL", ::c10d::DebugLevel::Detail);

  module
      .def(
          "get_debug_level",
          ::c10d::debug_level,
          R"(Gets the debug level of the torch.distributed package.)")
      .def(
          "set_debug_level",
          ::c10d::setDebugLevel,
          R"(Sets the debug level of the torch.distributed package.)")
      .def(
          "set_debug_level_from_env",
          ::c10d::setDebugLevelFromEnvironment,
          R"(Sets the debug level of the torch.distributed package from the
          ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
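
  // Illustrative only: the debug level can be raised either via the
  // environment or at runtime (sketch, assuming these helpers are re-exported
  // under torch.distributed):
  //
  //   $ TORCH_DISTRIBUTED_DEBUG=DETAIL python train.py
  //
  //   >>> import torch.distributed as dist
  //   >>> dist.set_debug_level(dist.DebugLevel.DETAIL)
  //   >>> dist.get_debug_level()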

  // TODO(crcrpar): Hardening `ReduceOp`.
  // While keeping most op types as enum value,
  // making `PREMUL_SUM` callable, i.e., allowing for
  // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol.
  // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  py::class_<::c10d::ReduceOp> reduce_op(
      module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"(
An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``.

``BAND``, ``BOR``, and ``BXOR`` reductions are not available when
using the ``NCCL`` backend.

``AVG`` divides values by the world size before summing across ranks.
``AVG`` is only available with the ``NCCL`` backend,
and only for NCCL versions 2.10 or later.

``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction.
``PREMUL_SUM`` is only available with the ``NCCL`` backend,
and only available for NCCL versions 2.11 or later. Users are supposed to
use ``torch.distributed._make_nccl_premul_sum``.

Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`, :func:`all_reduce_multigpu`, etc.

This class does not support the ``__members__`` property.)");

  reduce_op.def(py::init<::c10d::ReduceOp::RedOpType>())
      .def_readwrite("op", &::c10d::ReduceOp::op_);
  // The following are for some kind of backward compatibility.
  // Since c10d::ReduceOp had been an `enum class`, users could compare and
  // take the hash of a `::c10d::ReduceOp`. To avoid losing this functionality,
  // we define some member methods here.
  reduce_op
      // todo(crcrpar): Support `RedOpType == ReduceOp`.
      .def(
          // This calls `operator==(const ReduceOp::RedOpType)`
          "__eq__",
          [](const ::c10d::ReduceOp& self,
             const ::c10d::ReduceOp::RedOpType& other) {
            return self == other;
          })
      .def(
          // This calls `operator==(const ReduceOp)` for the future support of
          // `PREMUL_SUM` comparison
          "__eq__",
          [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) {
            return self == other;
          })
      .def(
          // With the above custom `__eq__`'s, we have to manually support the
          // other types.
          "__eq__",
          [](const ::c10d::ReduceOp& self, py::object) { return false; })
      .def(
          "__hash__",
          [](const ::c10d::ReduceOp& self) {
            return static_cast<uint8_t>(self.op_);
          })
      .def(
          "__copy__",
          [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); })
      .def(
          "__deepcopy__",
          [](const ::c10d::ReduceOp& self, const py::dict& memo) {
            return ::c10d::ReduceOp(self);
          })
      .def(py::pickle(
          [](const ::c10d::ReduceOp& r) {
            // __getstate__
            if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return py::make_tuple(r.op_, py::none());
            }
            TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
            const auto* preMulSupplement =
                reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
                    r.supplement_.get());
            if (!preMulSupplement->tensor_factor.defined()) {
              return py::make_tuple(r.op_, preMulSupplement->double_factor);
            } else {
              return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
            }
          },
          [](const py::tuple t) {
            // __setstate__
            TORCH_CHECK(t.size() == 2, "Invalid state");
            const auto op =
                static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
            if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return ::c10d::ReduceOp(op);
            }
            const auto preMulSupplement_factor = t[1];
            if (py::isinstance<py::float_>(preMulSupplement_factor)) {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
            } else {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
            }
          }));

  py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
      .value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
      .value("AVG", ::c10d::ReduceOp::RedOpType::AVG)
      .value("PRODUCT", ::c10d::ReduceOp::RedOpType::PRODUCT)
      .value("MIN", ::c10d::ReduceOp::RedOpType::MIN)
      .value("MAX", ::c10d::ReduceOp::RedOpType::MAX)
      .value("BAND", ::c10d::ReduceOp::RedOpType::BAND)
      .value("BOR", ::c10d::ReduceOp::RedOpType::BOR)
      .value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR)
      .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM)
      .export_values();

  // note(crcrpar): This could be removed because users will not pass
  // `RedOpType` to reduce collective ops. Ref: [Implicit
  // conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions)
  // Let us skip the explicit construction of `c10d::ReduceOp` from
  // `c10d::ReduceOp::RedOpType` in Python.
  py::implicitly_convertible<::c10d::ReduceOp::RedOpType, ::c10d::ReduceOp>();

  module
      .def(
          "_make_nccl_premul_sum",
          &::c10d::makeNCCLPreMulSum<double>,
          py::arg("factor").noconvert(),
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_make_nccl_premul_sum",
          &::c10d::makeNCCLPreMulSum<at::Tensor>,
          py::arg("factor").noconvert(),
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>());
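
  // Illustrative only: per the ReduceOp docstring above, PREMUL_SUM is meant
  // to be obtained through this helper rather than constructed directly,
  // e.g. (sketch, NCCL 2.11+ only):
  //
  //   >>> import torch.distributed as dist
  //   >>> op = dist._make_nccl_premul_sum(0.5)  # or a tensor factor
  //   >>> dist.all_reduce(tensor, op=op)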

  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::BroadcastOptions::timeout);

  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);

  py::class_<::c10d::AllreduceCoalescedOptions>(
      module, "AllreduceCoalescedOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);

  py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
      .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::ReduceOptions::timeout);

  py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllgatherOptions::timeout);

  py::class_<::c10d::GatherOptions>(module, "GatherOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
      .def_readwrite("timeout", &::c10d::GatherOptions::timeout);

  py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::ScatterOptions::rootRank)
      .def_readwrite("timeout", &::c10d::ScatterOptions::timeout);

  py::class_<::c10d::ReduceScatterOptions>(module, "ReduceScatterOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceScatterOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::ReduceScatterOptions::timeout);

  py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout);

  py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);

  py::class_<::c10d::DistributedBackendOptions>(
      module, "_DistributedBackendOptions")
      .def(py::init<>())
      .def_readwrite("store", &::c10d::DistributedBackendOptions::store)
      .def_readwrite(
          "group_rank", &::c10d::DistributedBackendOptions::group_rank)
      .def_readwrite(
          "group_size", &::c10d::DistributedBackendOptions::group_size)
      .def_readwrite("timeout", &::c10d::DistributedBackendOptions::timeout)
      .def_readwrite("group_id", &::c10d::DistributedBackendOptions::group_id)
      .def_readwrite(
          "global_ranks_in_group",
          &::c10d::DistributedBackendOptions::global_ranks_in_group);

  auto store =
      py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>(
          module,
          "Store",
          R"(
Base class for all store implementations, such as the 3 provided by PyTorch
distributed (:class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`,
and :class:`~torch.distributed.HashStore`).
)")
          // Default constructor.
          .def(py::init<>())
          // Convert from std::string to std::vector<uint8>.
          .def(
              "set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& value) {
                std::vector<uint8_t> value_(value.begin(), value.end());
                store.set(key, value_);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
``value``. If ``key`` already exists in the store, it will overwrite the old
value with the new supplied ``value``.

Arguments:
    key (str): The key to be added to the store.
    value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # Should return "first_value"
    >>> store.get("first_key")
)")
          .def(
              "compare_set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& expected_value,
                 const std::string& desired_value) -> py::bytes {
                std::vector<uint8_t> expectedValue_(
                    expected_value.begin(), expected_value.end());
                std::vector<uint8_t> desiredValue_(
                    desired_value.begin(), desired_value.end());
                auto value =
                    store.compareSet(key, expectedValue_, desiredValue_);
                return py::bytes(
                    reinterpret_cast<char*>(value.data()), value.size());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
performs comparison between ``expected_value`` and ``desired_value`` before inserting. ``desired_value``
will only be set if ``expected_value`` for the ``key`` already exists in the store or if ``expected_value``
is an empty string.

Arguments:
    key (str): The key to be checked in the store.
    expected_value (str): The value associated with ``key`` to be checked before insertion.
    desired_value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("key", "first_value")
    >>> store.compare_set("key", "first_value", "second_value")
    >>> # Should return "second_value"
    >>> store.get("key")
)")
          // Convert from std::vector<uint8_t> to py::bytes.
          // The returned value is not guaranteed to be valid UTF-8.
          .def(
              "get",
              [](::c10d::Store& store, const std::string& key) -> py::bytes {
                auto value = [&]() {
                  py::gil_scoped_release guard;
                  return store.get(key);
                }();
                return py::bytes(
                    reinterpret_cast<char*>(value.data()), value.size());
              },
              R"(
Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
present in the store, the function will wait for ``timeout``, which is defined
when initializing the store, before throwing an exception.

Arguments:
    key (str): The function will return the value associated with this key.

Returns:
    Value associated with ``key`` if ``key`` is in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # Should return "first_value"
    >>> store.get("first_key")
)")
          .def(
              "add",
              &::c10d::Store::add,
              py::call_guard<py::gil_scoped_release>(),
              R"(
The first call to add for a given ``key`` creates a counter associated
with ``key`` in the store, initialized to ``amount``. Subsequent calls to add
with the same ``key`` increment the counter by the specified ``amount``.
Calling :meth:`~torch.distributed.store.add` with a key that has already
been set in the store by :meth:`~torch.distributed.store.set` will result
in an exception.

Arguments:
    key (str): The key in the store whose counter will be incremented.
    amount (int): The quantity by which the counter will be incremented.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> store.add("first_key", 6)
    >>> # Should return 7
    >>> store.get("first_key")
)")
          .def(
              "delete_key",
              &::c10d::Store::deleteKey,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Deletes the key-value pair associated with ``key`` from the store. Returns
`true` if the key was successfully deleted, and `false` if it was not.

.. warning::
    The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API
    with the :class:`~torch.distributed.FileStore` will result in an exception.

Arguments:
    key (str): The key to be deleted from the store

Returns:
    `True` if ``key`` was deleted, otherwise `False`.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, HashStore can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # This should return true
    >>> store.delete_key("first_key")
    >>> # This should return false
    >>> store.delete_key("bad_key")
)")
          .def(
              "num_keys",
              &::c10d::Store::getNumKeys,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Returns the number of keys set in the store. Note that this number will typically
be one greater than the number of keys added by :meth:`~torch.distributed.store.set`
and :meth:`~torch.distributed.store.add` since one key is used to coordinate all
the workers using the store.

.. warning::
    When used with the :class:`~torch.distributed.FileStore`, ``num_keys`` returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained.

Returns:
    The number of keys present in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # This should return 2
    >>> store.num_keys()
)")
          .def(
              "set_timeout",
              &::c10d::Store::setTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Sets the store's default timeout. This timeout is used during initialization and in
:meth:`~torch.distributed.store.wait` and :meth:`~torch.distributed.store.get`.

Arguments:
    timeout (timedelta): timeout to be set in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set_timeout(timedelta(seconds=10))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"])
)")
          .def(
              "wait",
              [](::c10d::Store& store, const std::vector<std::string>& keys) {
                store.wait(keys);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Waits for each key in ``keys`` to be added to the store. If not all keys are
set before the ``timeout`` (set during store initialization), then ``wait``
will throw an exception.

Arguments:
    keys (list): List of keys on which to wait until they are set in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 30 seconds
    >>> store.wait(["bad_key"])
)")
          .def(
              "wait",
              [](::c10d::Store& store,
                 const std::vector<std::string>& keys,
                 const std::chrono::milliseconds& timeout) {
                store.wait(keys, timeout);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Waits for each key in ``keys`` to be added to the store, and throws an exception
if the keys have not been set by the supplied ``timeout``.

Arguments:
    keys (list): List of keys on which to wait until they are set in the store.
    timeout (timedelta): Time to wait for the keys to be added before throwing an exception.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"], timedelta(seconds=10))
)")
          .def_property_readonly(
              "timeout",
              &::c10d::Store::getTimeout,
              R"(Gets the timeout of the store.)");

  intrusive_ptr_class_<::c10d::FileStore>(
      module,
      "FileStore",
      store,
      R"(
A store implementation that uses a file to store the underlying key-value pairs.

Arguments:
    file_name (str): path of the file in which to store the key-value pairs
    world_size (int, optional): The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users).

Example::
    >>> import torch.distributed as dist
    >>> store1 = dist.FileStore("/tmp/filestore", 2)
    >>> store2 = dist.FileStore("/tmp/filestore", 2)
    >>> # Use any of the store methods from either the client or server after initialization
    >>> store1.set("first_key", "first_value")
    >>> store2.get("first_key")

    )")
      .def(
          py::init<const std::string&, int>(),
          py::arg("file_name"),
          py::arg("world_size") = -1)
      .def_property_readonly(
          "path",
          &::c10d::FileStore::getPath,
          R"(Gets the path of the file used by FileStore to store key-value pairs.)");

#ifndef _WIN32
  intrusive_ptr_class_<::c10d::HashStore>(
      module,
      "HashStore",
      store,
      R"(
A thread-safe store implementation based on an underlying hashmap. This store can be used
within the same process (for example, by other threads), but cannot be used across processes.

Example::
    >>> import torch.distributed as dist
    >>> store = dist.HashStore()
    >>> # store can be used from other threads
    >>> # Use any of the store methods after initialization
    >>> store.set("first_key", "first_value")
    )")
      .def(py::init<>());
#endif

  intrusive_ptr_class_<::c10d::TCPStore>(
      module,
      "TCPStore",
      store,
      R"(
A TCP-based distributed key-value store implementation. The server store holds
the data, while the client stores can connect to the server store over TCP and
perform actions such as :meth:`~torch.distributed.store.set` to insert a key-value
pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. There
should always be one server store initialized because the client store(s) will wait for
the server to establish a connection.

Arguments:
    host_name (str): The hostname or IP Address the server store should run on.
    port (int): The port on which the server store should listen for incoming requests.
    world_size (int, optional): The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users).
    is_master (bool, optional): True when initializing the server store and False for client stores. Default is False.
    timeout (timedelta, optional): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.store.get` and :meth:`~torch.distributed.store.wait`. Default is timedelta(seconds=300)
    wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Run on process 1 (server)
    >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
    >>> # Run on process 2 (client)
    >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
    >>> # Use any of the store methods from either the client or server after initialization
    >>> server_store.set("first_key", "first_value")
    >>> client_store.get("first_key")
    )")
      .def(
          py::init([](const std::string& host,
                      uint16_t port,
                      c10::optional<int> worldSize,
                      bool isServer,
                      std::chrono::milliseconds timeout,
                      bool waitWorkers,
                      bool multiTenant) {
            c10::optional<std::size_t> numWorkers = c10::nullopt;
            if (worldSize.has_value() && worldSize.value() > -1) {
              numWorkers = static_cast<std::size_t>(worldSize.value());
            }

            ::c10d::TCPStoreOptions opts{
                port, isServer, numWorkers, waitWorkers, timeout, multiTenant};

            return c10::make_intrusive<::c10d::TCPStore>(host, opts);
          }),
          py::arg("host_name"),
          py::arg("port"),
          py::arg("world_size") = py::none(),
          // using noconvert() requires this argument to be True or False
          // prevents accidental implicit conversion to bool
          py::arg("is_master").noconvert() = false,
          py::arg("timeout") =
              std::chrono::milliseconds(::c10d::Store::kDefaultTimeout),
          py::arg("wait_for_workers") = true,
          py::arg("multi_tenant") = false)
      .def_property_readonly(
          "host",
          &::c10d::TCPStore::getHost,
          R"(Gets the hostname on which the store listens for requests.)")

      .def_property_readonly(
          "port",
          &::c10d::TCPStore::getPort,
          R"(Gets the port number on which the store listens for requests.)");

  intrusive_ptr_class_<::c10d::PrefixStore>(
      module,
      "PrefixStore",
      store,
      R"(
A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`,
:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
that adds a prefix to each key inserted to the store.

Arguments:
    prefix (str): The prefix string that is prepended to each key before being inserted into the store.
    store (torch.distributed.store): A store object that forms the underlying key-value store.
    )")
      .def(py::init<const std::string&, c10::intrusive_ptr<::c10d::Store>>())
      .def_property_readonly(
          "underlying_store",
          &::c10d::PrefixStore::getUnderlyingStore,
          R"(Gets the underlying store object that PrefixStore wraps around.)");

  auto processGroup =
      py::class_<
          ::c10d::ProcessGroup,
          c10::intrusive_ptr<::c10d::ProcessGroup>,
          ::c10d::PyProcessGroup>(module, "ProcessGroup")
          .def(py::init<int, int>())
          .def(
              py::init<
                  const c10::intrusive_ptr<::c10d::Store>&,
                  int,
                  int,
                  c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
              py::call_guard<py::gil_scoped_release>())
          .def("rank", &::c10d::ProcessGroup::getRank)
          .def("size", &::c10d::ProcessGroup::getSize)
          .def("name", &::c10d::ProcessGroup::getBackendName)
          .def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
          .def(
              "broadcast",
              &::c10d::ProcessGroup::broadcast,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::BroadcastOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "broadcast",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& x,
                 int rootRank) {
                ::c10d::BroadcastOptions opts;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> tensors = {x};
                return self->broadcast(tensors, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              &::c10d::ProcessGroup::allreduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 std::vector<at::Tensor>& xs,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                return self->allreduce(xs, opts);
              },
              py::arg("tensors"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "allreduce",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& x,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                std::vector<at::Tensor> xs = {x};
                return self->allreduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce_coalesced",
              &::c10d::ProcessGroup::allreduce_coalesced,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              &::c10d::ProcessGroup::reduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::ReduceOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& x,
                 int rootRank,
                 ::c10d::ReduceOp op) {
                ::c10d::ReduceOptions opts;
                opts.reduceOp = op;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return self->reduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              &::c10d::ProcessGroup::allgather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input) {
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->allgather(
                    outputs, inputs, ::c10d::AllgatherOptions());
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_allgather_base",
              &::c10d::ProcessGroup::_allgather_base,
              py::arg("output"),
              py::arg("input"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather_coalesced",
              &::c10d::ProcessGroup::allgather_coalesced,
              py::arg("output_lists"),
              py::arg("input_list"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "gather",
              &::c10d::ProcessGroup::gather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::GatherOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "gather",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input,
                 int rootRank) {
                ::c10d::GatherOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->gather(outputs, inputs, opts);
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "scatter",
              &::c10d::ProcessGroup::scatter,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::ScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "scatter",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 int rootRank) {
                ::c10d::ScatterOptions opts;
                opts.rootRank = rootRank;
                std::vector<std::vector<at::Tensor>> inputs = {input};
                std::vector<at::Tensor> outputs = {output};
                return self->scatter(outputs, inputs, opts);
              },
              py::arg("output_tensor"),
              py::arg("input_tensors"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce_scatter",
              &::c10d::ProcessGroup::reduce_scatter,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::ReduceScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce_scatter",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input,
                 ::c10d::ReduceOp op) {
                std::vector<at::Tensor> outputs = {output};
                std::vector<std::vector<at::Tensor>> inputs = {input};
                ::c10d::ReduceScatterOptions opts;
                opts.reduceOp = op;
                return self->reduce_scatter(outputs, inputs, opts);
              },
              py::arg("output"),
              py::arg("input"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_reduce_scatter_base",
              &::c10d::ProcessGroup::_reduce_scatter_base,
              py::arg("outputTensor"),
              py::arg("inputTensor"),
              py::arg("opts") = ::c10d::ReduceScatterOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "alltoall_base",
              &::c10d::ProcessGroup::alltoall_base,
              py::arg("output"),
              py::arg("input"),
              py::arg("output_split_sizes"),
              py::arg("input_split_sizes"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "alltoall",
              &::c10d::ProcessGroup::alltoall,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "send",
              &::c10d::ProcessGroup::send,
              py::arg("tensors"),
              py::arg("dstRank"),
              py::arg("tag"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv",
              &::c10d::ProcessGroup::recv,
              py::arg("tensors"),
              py::arg("srcRank"),
              py::arg("tag"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "recv_anysource",
              &::c10d::ProcessGroup::recvAnysource,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "barrier",
              &::c10d::ProcessGroup::barrier,
              py::arg("opts") = ::c10d::BarrierOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_set_sequence_number_for_group",
              &::c10d::ProcessGroup::setSequenceNumberForGroup,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_sequence_number_for_group",
              &::c10d::ProcessGroup::getSequenceNumberForGroup,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "monitored_barrier",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const std::chrono::milliseconds& timeout,
                 bool waitAllRanks) {
                ::c10d::BarrierOptions opts;
                opts.timeout = timeout;
                return self->monitoredBarrier(opts, waitAllRanks);
              },
              py::arg("timeout") = ::c10d::kUnsetTimeout,
              py::arg("wait_all_ranks") = false,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_backend_name",
              &::c10d::ProcessGroup::getBackendName,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_start_coalescing",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device) {
                self->startCoalescing(device.type());
              },
              py::arg("device_type"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_end_coalescing",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device,
                 std::vector<c10::intrusive_ptr<::c10d::Work>>& reqs) {
                self->endCoalescing(device.type(), reqs);
              },
              py::arg("device_type"),
              py::arg("reqs"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_register_backend",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device,
                 const ::c10d::ProcessGroup::BackendType& backendType,
                 const c10::optional<c10::intrusive_ptr<::c10d::Backend>>&
                     backend) {
                self->setBackend(device.type(), backendType, backend);
              },
              py::arg("device"),
              py::arg("backend_type"),
              py::arg("backend") =
                  c10::optional<c10::intrusive_ptr<::c10d::Backend>>(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_backend",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 const c10::Device& device) {
                return self->getBackend(device.type());
              },
              py::arg("device"),
              py::call_guard<py::gil_scoped_release>());
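
  // Illustrative only: in typical use a concrete ProcessGroup is created by
  // torch.distributed.init_process_group; the sketch below assumes the private
  // helper _get_default_group to fetch it and call a collective directly:
  //
  //   >>> import torch, torch.distributed as dist
  //   >>> dist.init_process_group("gloo", rank=rank, world_size=world_size)
  //   >>> pg = dist.distributed_c10d._get_default_group()
  //   >>> work = pg.allreduce([torch.ones(4)])
  //   >>> work.wait()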

  py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType")
      .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
      .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
      .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
      .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
      .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
      .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
      .export_values();

  // base ProcessGroup::Options binding
  auto processGroupOptions =
      intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
          processGroup,
          "Options",
          R"(
Base class for all process group options implementations, such as the nccl
options (:class:`~torch.distributed.ProcessGroupNCCL.Options`).
)")
          .def(
              py::init([](const std::string& backend,
                          const std::chrono::milliseconds& timeout) {
                return c10::make_intrusive<::c10d::ProcessGroup::Options>(
                    backend, timeout);
              }),
              py::arg("backend"),
              py::arg("timeout") = kProcessGroupDefaultTimeout,
              py::call_guard<py::gil_scoped_release>())
          .def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
          .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);

#ifndef _WIN32
  module.def(
      "_round_robin_process_groups",
      [](std::vector<c10::intrusive_ptr<::c10d::ProcessGroup>> processGroups)
          -> c10::intrusive_ptr<::c10d::ProcessGroup> {
        if (processGroups.empty()) {
          throw std::invalid_argument("Specify at least 1 process group");
        }
        const auto& first = processGroups.front();
        return c10::make_intrusive<::c10d::ProcessGroupRoundRobin>(
            first->getRank(), first->getSize(), std::move(processGroups));
      },
      py::arg("process_groups"),
      py::call_guard<py::gil_scoped_release>());
#endif

  // TODO: The collection definitions below handle direct instantiation of
  // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported
  // and should be removed once all tests are transitioned.
  auto backend =
      py::class_<::c10d::Backend, c10::intrusive_ptr<::c10d::Backend>>(
          module, "Backend")
          .def("rank", &::c10d::Backend::getRank)
          .def("size", &::c10d::Backend::getSize)
          .def("name", &::c10d::Backend::getBackendName)
          .def(
              "broadcast",
              &::c10d::Backend::broadcast,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::BroadcastOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "broadcast",
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& x,
                 int rootRank) {
                ::c10d::BroadcastOptions opts;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return self->broadcast(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              &::c10d::Backend::allreduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 std::vector<at::Tensor>& xs,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                return self->allreduce(xs, opts);
              },
              py::arg("tensors"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce",
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& x,
                 ::c10d::ReduceOp op) {
                ::c10d::AllreduceOptions opts;
                opts.reduceOp = op;
                std::vector<at::Tensor> xs = {x};
                return self->allreduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allreduce_coalesced",
              &::c10d::Backend::allreduce_coalesced,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce",
              &::c10d::Backend::reduce,
              py::arg("tensors"),
              py::arg("opts") = ::c10d::ReduceOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "reduce",
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 at::Tensor& x,
                 int rootRank,
                 ::c10d::ReduceOp op) {
                ::c10d::ReduceOptions opts;
                opts.reduceOp = op;
                opts.rootRank = rootRank;
                std::vector<at::Tensor> xs = {x};
                return self->reduce(xs, opts);
              },
              py::arg("tensor"),
              py::arg("root"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              &::c10d::Backend::allgather,
              py::arg("output_tensors"),
              py::arg("input_tensors"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_allgather_base",
              &::c10d::Backend::_allgather_base,
              py::arg("output"),
              py::arg("input"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather",
              [](const c10::intrusive_ptr<::c10d::Backend>& self,
                 std::vector<at::Tensor>& output,
                 at::Tensor& input) {
                std::vector<std::vector<at::Tensor>> outputs = {output};
                std::vector<at::Tensor> inputs = {input};
                return self->allgather(
                    outputs, inputs, ::c10d::AllgatherOptions());
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "allgather_coalesced",
              &::c10d::Backend::allgather_coalesced,
              py::arg("output_lists"),
              py::arg("input_list"),
              py::arg("opts") = ::c10d::AllgatherOptions(),
              py::call_guard<py::gil_scoped_release>())
1649 .def(
1650 "gather",
1651 &::c10d::Backend::gather,
1652 py::arg("output_tensors"),
1653 py::arg("input_tensors"),
1654 py::arg("opts") = ::c10d::GatherOptions(),
1655 py::call_guard<py::gil_scoped_release>())
1656 .def(
1657 "gather",
1658 [](const c10::intrusive_ptr<::c10d::Backend>& self,
1659 std::vector<at::Tensor>& output,
1660 at::Tensor& input,
1661 int rootRank) {
1662 ::c10d::GatherOptions opts;
1663 opts.rootRank = rootRank;
1664 std::vector<std::vector<at::Tensor>> outputs = {output};
1665 std::vector<at::Tensor> inputs = {input};
1666 return self->gather(outputs, inputs, opts);
1667 },
1668 py::arg("output_tensors"),
1669 py::arg("input_tensor"),
1670 py::arg("root"),
1671 py::call_guard<py::gil_scoped_release>())
1672 .def(
1673 "scatter",
1674 &::c10d::Backend::scatter,
1675 py::arg("output_tensors"),
1676 py::arg("input_tensors"),
1677 py::arg("opts") = ::c10d::ScatterOptions(),
1678 py::call_guard<py::gil_scoped_release>())
1679 .def(
1680 "scatter",
1681 [](const c10::intrusive_ptr<::c10d::Backend>& self,
1682 at::Tensor& output,
1683 std::vector<at::Tensor>& input,
1684 int rootRank) {
1685 ::c10d::ScatterOptions opts;
1686 opts.rootRank = rootRank;
1687 std::vector<std::vector<at::Tensor>> inputs = {input};
1688 std::vector<at::Tensor> outputs = {output};
1689 return self->scatter(outputs, inputs, opts);
1690 },
1691 py::arg("output_tensor"),
1692 py::arg("input_tensors"),
1693 py::arg("root"),
1694 py::call_guard<py::gil_scoped_release>())
1695 .def(
1696 "reduce_scatter",
1697 &::c10d::Backend::reduce_scatter,
1698 py::arg("output_tensors"),
1699 py::arg("input_tensors"),
1700 py::arg("opts") = ::c10d::ReduceScatterOptions(),
1701 py::call_guard<py::gil_scoped_release>())
1702 .def(
1703 "reduce_scatter",
1704 [](const c10::intrusive_ptr<::c10d::Backend>& self,
1705 at::Tensor& output,
1706 std::vector<at::Tensor>& input,
1707 ::c10d::ReduceOp op) {
1708 std::vector<at::Tensor> outputs = {output};
1709 std::vector<std::vector<at::Tensor>> inputs = {input};
1710 ::c10d::ReduceScatterOptions opts;
1711 opts.reduceOp = op;
1712 return self->reduce_scatter(outputs, inputs, opts);
1713 },
1714 py::arg("output_tensors"),
1715 py::arg("input_tensor"),
1716 py::arg("op") = ::c10d::ReduceOp::SUM,
1717 py::call_guard<py::gil_scoped_release>())
1718 .def(
1719 "_reduce_scatter_base",
1720 &::c10d::Backend::_reduce_scatter_base,
1721 py::arg("outputTensor"),
1722 py::arg("inputTensor"),
1723 py::arg("opts") = ::c10d::ReduceScatterOptions(),
1724 py::call_guard<py::gil_scoped_release>())
1725 .def(
1726 "alltoall_base",
1727 &::c10d::Backend::alltoall_base,
1728 py::arg("output_tensor"),
1729 py::arg("input_tensor"),
1730 py::arg("output_split_sizes"),
1731 py::arg("input_split_sizes"),
1732 py::arg("opts") = ::c10d::AllToAllOptions(),
1733 py::call_guard<py::gil_scoped_release>())
1734 .def(
1735 "alltoall_base",
1736 [](::c10d::Backend& self,
1737 at::Tensor& output,
1738 at::Tensor& input,
1739 std::vector<int64_t> outputSplitSizes,
1740 std::vector<int64_t> inputSplitSizes) {
1741 return self.alltoall_base(
1742 output,
1743 input,
1744 outputSplitSizes,
1745 inputSplitSizes,
1746 ::c10d::AllToAllOptions());
1747 },
1748 py::arg("output"),
1749 py::arg("input"),
1750 py::arg("output_split_sizes"),
1751 py::arg("input_split_sizes"),
1752 py::call_guard<py::gil_scoped_release>())
1753 .def(
1754 "alltoall",
1755 &::c10d::Backend::alltoall,
1756 py::arg("output_tensor"),
1757 py::arg("input_tensor"),
1758 py::arg("opts") = ::c10d::AllToAllOptions(),
1759 py::call_guard<py::gil_scoped_release>())
1760 .def(
1761 "send",
1762 &::c10d::Backend::send,
1763 py::arg("tensors"),
1764 py::arg("dstRank"),
1765 py::arg("tag"),
1766 py::call_guard<py::gil_scoped_release>())
1767 .def(
1768 "recv",
1769 &::c10d::Backend::recv,
1770 py::arg("tensors"),
1771 py::arg("srcRank"),
1772 py::arg("tag"),
1773 py::call_guard<py::gil_scoped_release>())
1774 .def(
1775 "recv_anysource",
1776 &::c10d::Backend::recvAnysource,
1777 py::call_guard<py::gil_scoped_release>())
1778 .def(
1779 "barrier",
1780 [](const c10::intrusive_ptr<::c10d::Backend>& self,
1781 const ::c10d::BarrierOptions& opts) {
1782 return self->barrier(opts);
1783 },
1784 py::arg("opts") = ::c10d::BarrierOptions(),
1785 py::call_guard<py::gil_scoped_release>())
1786 .def(
1787 "_set_sequence_number_for_group",
1788 &::c10d::Backend::setSequenceNumberForGroup,
1789 py::call_guard<py::gil_scoped_release>())
1790 .def(
1791 "_get_sequence_number_for_group",
1792 &::c10d::Backend::getSequenceNumberForGroup,
1793 py::call_guard<py::gil_scoped_release>())
1794 .def(
1795 "monitored_barrier",
1796 [](const c10::intrusive_ptr<::c10d::Backend>& self,
1797 const std::chrono::milliseconds& timeout,
1798 bool waitAllRanks) {
1799 ::c10d::BarrierOptions opts;
1800 opts.timeout = timeout;
1801 return self->monitoredBarrier(opts, waitAllRanks);
1802 },
1803 py::arg("timeout") = ::c10d::kUnsetTimeout,
1804 py::arg("wait_all_ranks") = false,
1805 py::call_guard<py::gil_scoped_release>())
1806 .def(
1807 "_get_backend_name",
1808 &::c10d::Backend::getBackendName,
1809 py::call_guard<py::gil_scoped_release>())
1810 .def(
1811 "_start_coalescing",
1812 &::c10d::Backend::startCoalescing,
1813 py::call_guard<py::gil_scoped_release>())
1814 .def(
1815 "_end_coalescing",
1816 &::c10d::Backend::endCoalescing,
1817 py::arg("reqs"),
1818 py::call_guard<py::gil_scoped_release>());
1819
1820#ifdef USE_C10D_GLOO
1821 static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
1822
1823 auto processGroupGloo =
1824 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>(
1825 module, "ProcessGroupGloo", backend);
1826
1827 shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
1828
1829 intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
1830 processGroupGloo, "_Options", processGroupOptions)
1831 .def(py::init<>())
1832 .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
1833 .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
1834
1835 processGroupGloo
1836 .def_static(
1837 "create_device",
1838 [](const std::string& hostname, const std::string& interface)
1839 -> std::shared_ptr<::gloo::transport::Device> {
1840 if (!hostname.empty()) {
1841 return ::c10d::ProcessGroupGloo::createDeviceForHostname(
1842 hostname);
1843 }
1844 if (!interface.empty()) {
1845 return ::c10d::ProcessGroupGloo::createDeviceForInterface(
1846 interface);
1847 }
1848 throw std::invalid_argument(
1849 "Specify either `hostname` or `interface` argument.");
1850 },
1851 py::arg("hostname") = "",
1852 py::arg("interface") = "")
1853 .def_static(
1854 "create_default_device",
1855 &::c10d::ProcessGroupGloo::createDefaultDevice);
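  // Illustrative Python usage of the device factories above (a sketch; "lo" is
  // just a placeholder interface name). If both `hostname` and `interface` are
  // given, `hostname` takes precedence; if neither is given, an error is
  // raised.
  //   >>> import torch.distributed as dist
  //   >>> device = dist.ProcessGroupGloo.create_device(interface="lo")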
1856
1857 processGroupGloo
1858 .def(
1859 py::init<
1860 const c10::intrusive_ptr<::c10d::Store>&,
1861 int,
1862 int,
1863 c10::intrusive_ptr<::c10d::ProcessGroupGloo::Options>>(),
1864 py::call_guard<py::gil_scoped_release>())
1865 .def(
1866 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
1867 int rank,
1868 int size,
1869 std::chrono::milliseconds timeout) {
1870 auto options = ::c10d::ProcessGroupGloo::Options::create();
1871
1872 // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set.
1873 char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
1874 if (ifnameEnv && strlen(ifnameEnv) > 1) {
1875 for (const auto& iface : ::c10d::split(',', ifnameEnv)) {
1876 options->devices.push_back(
1877 ::c10d::ProcessGroupGloo::createDeviceForInterface(iface));
1878 }
1879 } else {
              // Otherwise fall back to the default device: it looks up this
              // machine's hostname and binds to the address that the hostname
              // resolves to.
1883 options->devices.push_back(
1884 ::c10d::ProcessGroupGloo::createDefaultDevice());
1885 }
1886
1887 options->timeout = timeout;
1888 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1889 options->threads = options->devices.size() * 2;
1890 return c10::make_intrusive<::c10d::ProcessGroupGloo>(
1891 store, rank, size, options);
1892 }),
1893 py::arg("store"),
1894 py::arg("rank"),
1895 py::arg("size"),
1896 py::arg("timeout") = kProcessGroupDefaultTimeout,
1897 py::call_guard<py::gil_scoped_release>())
1898 .def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions);
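  // Illustrative direct construction from Python (a sketch, mainly relevant to
  // tests per the TODO above; the store path and world size are placeholders,
  // and each of the two ranks would run this with its own rank value):
  //   >>> import torch
  //   >>> import torch.distributed as dist
  //   >>> store = dist.FileStore("/tmp/gloo_store", 2)
  //   >>> pg = dist.ProcessGroupGloo(store, rank=0, size=2)
  //   >>> pg.allreduce([torch.ones(3)]).wait()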
1899
1900 // ProcessGroupWrapper is a wrapper pg that includes a helper gloo process
1901 // group. It can be used to validate collective calls across processes by
1902 // checking the op type and input tensor shapes.
1903 auto processGroupWrapper =
1904 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>(
1905 module, "_ProcessGroupWrapper", backend)
1906 .def(
1907 py::init(
1908 [](const c10::intrusive_ptr<::c10d::Backend>& backend,
1909 const c10::intrusive_ptr<::c10d::Backend>& gloo_backend) {
1910 return c10::make_intrusive<::c10d::ProcessGroupWrapper>(
1911 backend, gloo_backend);
1912 }),
1913 py::arg("backend"),
1914 py::arg("gloo_backend"),
1915 py::call_guard<py::gil_scoped_release>())
1916 .def_property_readonly(
1917 "wrapped_pg", &::c10d::ProcessGroupWrapper::getWrappedPg);
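  // Illustrative Python usage (a sketch; `pg` and `gloo_pg` are assumed to be
  // existing Backend instances over the same set of ranks, and `some_tensor`
  // is a placeholder):
  //   >>> from torch._C._distributed_c10d import _ProcessGroupWrapper
  //   >>> wrapper = _ProcessGroupWrapper(pg, gloo_pg)
  //   >>> wrapper.allreduce([some_tensor])  # op type/shapes checked via gloo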
1918#endif
1919
1920#ifdef USE_C10D_NCCL
1921 auto processGroupNCCL =
1922 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupNCCL>(
1923 module, "ProcessGroupNCCL", backend)
1924 .def(
1925 py::init<
1926 const c10::intrusive_ptr<::c10d::Store>&,
1927 int,
1928 int,
1929 c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(),
1930 py::call_guard<py::gil_scoped_release>())
1931 .def(
1932 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
1933 int rank,
1934 int size,
1935 const std::chrono::milliseconds& timeout) {
1936 auto options = ::c10d::ProcessGroupNCCL::Options::create();
1937 options->is_high_priority_stream = false;
1938 options->timeout = timeout;
1939 return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
1940 store, rank, size, options);
1941 }),
1942 py::arg("store"),
1943 py::arg("rank"),
1944 py::arg("size"),
1945 py::arg("timeout") = kProcessGroupDefaultTimeout,
1946 py::call_guard<py::gil_scoped_release>())
1947 .def_property_readonly(
1948 "options", &::c10d::ProcessGroupNCCL::getOptions)
1949 .def_property_readonly(
1950 "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
1951
1952 intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
1953 processGroupNCCL,
1954 "Options",
1955 processGroupOptions,
1956 R"(
1957ProcessGroup options for the NCCL backend
1958
1959Arguments:
    is_high_priority_stream (bool, optional): flag to enable or disable the
        process group's use of high-priority CUDA streams. This lets the CUDA
        driver prioritize NCCL kernels when there are compute kernels waiting.
        Default is False.
1964
1965Example::
1966 >>> import torch.distributed as dist
1967 >>>
1968 >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
1969 >>> # initialize a nccl process group with the options just created
1970 >>> dist.init_process_group("nccl", pg_options=nccl_options)
1971 )")
1972 .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
1973 .def_readwrite(
1974 "is_high_priority_stream",
1975 &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
1976 processGroupNCCL.def_static(
1977 "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); });
1978 processGroupNCCL.def_static(
1979 "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); });
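  // _group_start/_group_end bracket a sequence of NCCL collectives so they can
  // be issued as a single grouped launch (mapping onto
  // ncclGroupStart/ncclGroupEnd).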
1980#endif
1981
1982#ifdef USE_C10D_MPI
1983 auto processGroupMPI =
1984 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(
1985 module, "ProcessGroupMPI", backend);
1986
1987 // Define static create function instead of a constructor, because
1988 // this function may return null. This happens if this process is not
1989 // part of a sub group that is to be created.
1990 processGroupMPI.def_static(
1991 "create",
1992 [](std::vector<int> ranks) {
1993 return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
1994 },
1995 py::call_guard<py::gil_scoped_release>());
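  // Illustrative Python usage (a sketch; per the comment above, processes that
  // are not part of the requested sub group receive ``None``):
  //   >>> import torch.distributed as dist
  //   >>> pg = dist.ProcessGroupMPI.create([0, 1])
  //   >>> if pg is not None:
  //   ...     print(pg.rank(), pg.size())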
1996#endif
1997
1998#ifdef USE_C10D_UCC
1999 auto processGroupUCC =
2000 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
2001 module, "ProcessGroupUCC", backend)
2002 .def(
2003 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2004 int rank,
2005 int size,
2006 const std::chrono::milliseconds& timeout) {
2007 return c10::make_intrusive<::c10d::ProcessGroupUCC>(
2008 store, rank, size, timeout);
2009 }),
2010 py::arg("store"),
2011 py::arg("rank"),
2012 py::arg("size"),
2013 py::arg("timeout") = kProcessGroupDefaultTimeout,
2014 py::call_guard<py::gil_scoped_release>());
2015#endif
2016
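  // Bind c10d::Work, the handle returned by asynchronous collective
  // operations; PyProcessGroup::PyWork is the pybind11 trampoline that lets
  // Python-defined process groups provide their own Work subclasses.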
2017 py::class_<
2018 ::c10d::Work,
2019 c10::intrusive_ptr<::c10d::Work>,
2020 ::c10d::PyProcessGroup::PyWork>(module, "Work")
2021 .def(py::init<>())
2022 .def("is_completed", &::c10d::Work::isCompleted)
2023 .def(
2024 "is_success",
2025 [](::c10d::Work& work) -> bool {
2026 TORCH_WARN_ONCE(
2027 fmt::format(kDeprecationWarning, "Work::is_success"));
2028 return work.isSuccess();
2029 })
2030 .def(
2031 "exception",
2032 [](::c10d::Work& work) -> std::exception_ptr {
2033 TORCH_WARN_ONCE(
2034 fmt::format(kDeprecationWarning, "Work::exception"));
2035 return work.exception();
2036 })
2037 .def(
2038 "source_rank",
2039 [](::c10d::Work& work) -> int {
2040 TORCH_WARN_ONCE(
2041 fmt::format(kDeprecationWarning, "Work::source_rank"));
2042 return work.sourceRank();
2043 })
2044 .def("_source_rank", &::c10d::Work::sourceRank)
2045 .def(
2046 "result",
2047 [](::c10d::Work& work) -> std::vector<at::Tensor> {
2048 return work.result();
2049 })
2050 .def(
2051 "synchronize",
2052 [](::c10d::Work& work) -> void {
2053 TORCH_WARN_ONCE(
2054 fmt::format(kDeprecationWarning, "Work::synchronize"));
2055 work.synchronize();
2056 })
2057 .def(
2058 "wait",
2059 &::c10d::Work::wait,
2060 py::arg("timeout") = kNoTimeout,
2061 py::call_guard<py::gil_scoped_release>())
2062 .def(
2063 "get_future",
2064 [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
2065 return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
2066 },
2067 R"(
2068 Returns:
2069 A ``torch.futures.Future`` object which is associated with the completion of
2070 the ``Work``. As an example, a future object can be retrieved
2071 by ``fut = process_group.allreduce(tensors).get_future()``.
2072
2073 Example::
            Below is an example of a simple allreduce DDP communication hook that uses the
            ``get_future`` API to retrieve a Future associated with the completion of
            ``allreduce``.

            >>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future:
2079 >>> group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD
2080 >>> tensor = bucket.buffer().div_(group_to_use.size())
2081 >>> return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
2082 >>> ddp_model.register_comm_hook(state=None, hook=allreduce)
2083
2084 .. warning ::
            The ``get_future`` API supports NCCL, and partially supports the GLOO and MPI
            backends (peer-to-peer operations such as send/recv are not supported); it returns a ``torch.futures.Future``.
2087
            In the example above, the ``allreduce`` work will be done on GPU using the NCCL backend,
            and ``fut.wait()`` will return after synchronizing the appropriate NCCL streams
            with PyTorch's current device streams, preserving asynchronous CUDA execution
            rather than waiting for the entire operation to complete on the GPU. Note that
            ``CUDAFuture`` does not support the ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``.
            In addition, if a callback function was added by ``fut.then()``, it will wait until
            ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback
            stream and invoke the callback inline after running the callback on the callback stream.
            ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the
            callback and a ``CUDAEvent`` that recorded the callback stream.
2098
            1. For CPU work, ``fut.done()`` returns true when the work has been completed and the
               ``value()`` tensors are ready.
            2. For GPU work, ``fut.done()`` returns true only when the operation has been enqueued.
            3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), ``fut.done()`` returns
               true when the tensors have arrived on the respective nodes, but not necessarily when
               they have been synchronized on the respective GPUs (similar to GPU work).
2105 )");
2106
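  // DDPLoggingData exposes DDP's logging fields (string- and int-valued maps)
  // to Python.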
2107 py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
2108 .def(py::init<>())
2109 .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map)
2110 .def_readwrite("ints_map", &c10::DDPLoggingData::ints_map);
2111
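  // Groups gradient tensors into DDP reduction buckets subject to per-bucket
  // byte limits; an optional Logger can be supplied for error reporting.
  // Illustrative Python call (a sketch; `params` stands for a list of
  // parameter tensors and the limit is in bytes):
  //   >>> bucket_indices, per_bucket_limits = (
  //   ...     torch._C._distributed_c10d._compute_bucket_assignment_by_size(
  //   ...         params, [1024 * 1024]))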
2112 module.def(
2113 "_compute_bucket_assignment_by_size",
2114 [](const std::vector<at::Tensor>& tensors,
2115 const std::vector<size_t>& bucket_size_limits,
2116 const std::vector<bool>& expect_sparse_gradient,
2117 const std::vector<int64_t>& tensor_indices,
2118 const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
2119 if (logger.has_value()) {
2120 std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
2121 return ::c10d::compute_bucket_assignment_by_size(
2122 tensors,
2123 bucket_size_limits,
2124 expect_sparse_gradient,
2125 tensor_indices,
2126 {logger_weakref});
2127 } else {
2128 return ::c10d::compute_bucket_assignment_by_size(
2129 tensors,
2130 bucket_size_limits,
2131 expect_sparse_gradient,
2132 tensor_indices,
2133 {});
2134 }
2135 },
2136 py::arg("tensors"),
2137 py::arg("bucket_size"),
2138 py::arg("expect_sparse_gradient") = std::vector<bool>(),
2139 py::arg("tensor_indices") = std::vector<int64_t>(),
2140 py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{},
2141 py::call_guard<py::gil_scoped_release>());
2142
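  // Checks that the model parameters being passed to DDP are consistent (e.g.
  // in shape) across all processes in the group; an optional Logger can be
  // supplied for error reporting.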
2143 module.def(
2144 "_verify_params_across_processes",
2145 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
2146 const std::vector<at::Tensor>& params,
2147 const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
2148 if (logger.has_value()) {
2149 std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
2150 verify_params_across_processes(
2151 process_group, params, {logger_weakref});
2152 } else {
2153 verify_params_across_processes(process_group, params, {});
2154 }
2155 },
2156 py::arg("process_group"),
2157 py::arg("params"),
2158 py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{},
2159 py::call_guard<py::gil_scoped_release>());
2160
2161 module.def(
2162 "_broadcast_coalesced",
2163 // Define a lambda such that the pybind11 prototype can take a std::vector
2164 // for the tensor list argument, but still pass it to the underlying
2165 // function as a c10::ArrayRef.
2166 [](c10::intrusive_ptr<::c10d::ProcessGroup> process_group,
2167 std::vector<at::Tensor> tensors, // NOLINT
2168 size_t buffer_size,
2169 int rank) {
2170 broadcast_coalesced(
2171 std::move(process_group), tensors, buffer_size, rank);
2172 },
2173 py::arg("process_group"),
2174 py::arg("tensors"),
2175 py::arg("buffer_size"),
2176 // The source of truth rank to broadcast the tensors from.
2177 py::arg("src") = 0,
2178 py::call_guard<py::gil_scoped_release>());
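  // Illustrative Python call (a sketch; `pg` and `tensors` are placeholders,
  // and the buffer size is in bytes):
  //   >>> torch._C._distributed_c10d._broadcast_coalesced(
  //   ...     pg, tensors, 256 * 1024 * 1024, 0)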
2179
2180 module.def(
2181 "_test_python_store",
2182 // Define a function that takes a c10d store and runs a few tests.
2183 // This is used by the PythonStore tests, which we cannot test from the
2184 // Python side of the world. Calling Python functions on a Python object
2185 // completely bypasses pybind11. We need to test that the overloaded
2186 // functions call into Python and behave like we expect.
2187 [](c10::intrusive_ptr<::c10d::Store> store) {
2188 auto add = [&store](const std::string& key, int64_t value) {
2189 store->add(key, value);
2190 };
2191
2192 auto set = [&store](const std::string& key, const std::string& value) {
2193 store->set(key, value);
2194 };
2195
2196 auto get = [&store](const std::string& key) {
2197 auto value = store->get(key);
2198 return std::string(value.begin(), value.end());
2199 };
2200
2201 add("key", 1);
2202 add("key", 2);
2203 add("key", 3);
2204 set("key0", "value0");
2205 add("key3", 1);
2206 set("key1", "value1");
2207 add("key3", 2);
2208 set("key2", "value2");
2209 add("key3", 3);
2210 add("key3", 4);
2211 add("key3", 3);
2212 add("key3", 2);
2213 if (get("key") != "6") {
2214 TORCH_CHECK(false, "assertion failed");
2215 }
2216 if (get("key0") != "value0") {
2217 TORCH_CHECK(false, "assertion failed");
2218 }
2219 if (get("key1") != "value1") {
2220 TORCH_CHECK(false, "assertion failed");
2221 }
2222 if (get("key2") != "value2") {
2223 TORCH_CHECK(false, "assertion failed");
2224 }
2225 if (get("key3") != "15") {
2226 TORCH_CHECK(false, "assertion failed");
2227 }
2228 },
2229 py::call_guard<py::gil_scoped_release>());
2230
2231 module.attr("_DEFAULT_FIRST_BUCKET_BYTES") = ::c10d::kDefaultFirstBucketBytes;
2232 module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
2233 module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);
2234
2235 module.def(
2236 "_create_work_from_future",
2237 [](std::shared_ptr<jit::PythonFutureWrapper> future) {
2238 return ::c10d::Work::create_from_future(future->fut);
2239 },
2240 py::arg("future"),
2241 R"(
2242 Arguments:
            future(torch.futures.Future): The future to wrap.
2244 Returns:
2245 A ``Work`` object which is associated with the completion of
2246 the ``torch.futures.Future``.
2247 This is the preferred way of constructing Work objects when writing a custom ProcessGroup
2248 in python.
2249 Example::
2250 >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup):
2251 >>> def broadcast(self, tensor_list, opts):
2252 >>> fut = torch.futures.Future()
2253 >>> fut.set_result(tensor_list)
2254 >>> return torch._C._distributed_c10d._create_work_from_future(fut)
2255 .. warning ::
2256 This API is experimental and subject to change.
2257 The returned Work object has multiple limitations:
2258 - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization.
            - wait() ignores the timeout argument.
2260 - sourceRank() raises.
2261 - abort() raises.
2262 The provided Future object result must be a Tensor or a list of Tensors.
2263 )");
2264
2265 Py_RETURN_TRUE;
2266}
2267
2268#undef PROCESS_GROUP_DEPRECATION_WARNING
2269
2270} // namespace
2271
2272// c10d methods on torch._C
2273static PyMethodDef methods[] = { // NOLINT
2274 {"_c10d_init", c10d_init, METH_NOARGS, nullptr},
2275 {nullptr, nullptr, 0, nullptr}};
2276
2277PyMethodDef* python_functions() {
2278 return methods;
2279}
2280
2281} // namespace c10d
2282} // namespace distributed
2283} // namespace torch
2284