1#pragma once
2
3#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
4#include <torch/csrc/utils/pybind.h>
5#include <torch/csrc/jit/python/pybind_utils.h>
6
7namespace c10d {
8
9// PyProcessGroup is a pybind11 trampoline class to allow a Python
10// class to inherit from torch.distributed.ProcessGroup
11class PyProcessGroup : public ProcessGroup {
12 public:
13 // PyWork is a pybind11 trampoline class to allow a Python
14 // class to inherit from torch.distributed.Work
15 class PyWork : public Work {
16 public:
17 PyWork() = default;
18
19 bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
20 PYBIND11_OVERRIDE(
21 bool, /* Return type */
22 Work, /* Parent class */
23 wait, /* Name of function in C++ */
24 timeout);
25 }
26
27 c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
28 // We cannot use PYBIND11_OVERRIDE because:
29 // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
30 // 2. The python name is get_future
31 pybind11::gil_scoped_acquire gil;
32 auto override = pybind11::get_override(static_cast<const Work *>(this), "get_future");
33
34 if (override) {
35 py::object o = override();
36 auto futWrapper = o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
37 return futWrapper->fut;
38 }
39
40 return Work::getFuture();
41 }
42 };
43
44 using ProcessGroup::ProcessGroup;
45
46 const std::string getBackendName() const override {
47 PYBIND11_OVERRIDE_PURE(
48 std::string, /* Return type */
49 ProcessGroup, /* Parent class */
50 getBackendName, /* Name of function in C++ */
51 );
52 }
53
54 c10::intrusive_ptr<Work> allgather(
55 std::vector<std::vector<at::Tensor>>& outputTensors,
56 std::vector<at::Tensor>& inputTensors,
57 const AllgatherOptions& opts = AllgatherOptions()) override {
58 PYBIND11_OVERRIDE(
59 c10::intrusive_ptr<Work>, /* Return type */
60 ProcessGroup, /* Parent class */
61 allgather, /* Name of function in C++ */
62 outputTensors,
63 inputTensors,
64 opts);
65 }
66
67 c10::intrusive_ptr<Work> allreduce(
68 std::vector<at::Tensor>& tensors,
69 const AllreduceOptions& opts = AllreduceOptions()) override {
70 PYBIND11_OVERRIDE(
71 c10::intrusive_ptr<Work>, /* Return type */
72 ProcessGroup, /* Parent class */
73 allreduce, /* Name of function in C++ */
74 tensors,
75 opts);
76 }
77
78 c10::intrusive_ptr<Work> barrier(
79 const BarrierOptions& opts = BarrierOptions()) override {
80 PYBIND11_OVERRIDE(
81 c10::intrusive_ptr<Work>, /* Return type */
82 ProcessGroup, /* Parent class */
83 barrier, /* Name of function in C++ */
84 opts);
85 }
86
87 c10::intrusive_ptr<Work> broadcast(
88 std::vector<at::Tensor>& tensors,
89 const BroadcastOptions& opts = BroadcastOptions()) override {
90 PYBIND11_OVERRIDE(
91 c10::intrusive_ptr<Work>, /* Return type */
92 ProcessGroup, /* Parent class */
93 broadcast, /* Name of function in C++ */
94 tensors,
95 opts);
96 }
97
98 c10::intrusive_ptr<Work> reduce_scatter(
99 std::vector<at::Tensor>& outputTensors,
100 std::vector<std::vector<at::Tensor>>& inputTensors,
101 const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
102 PYBIND11_OVERRIDE(
103 c10::intrusive_ptr<Work>, /* Return type */
104 ProcessGroup, /* Parent class */
105 reduce_scatter, /* Name of function in C++ */
106 outputTensors,
107 inputTensors,
108 opts);
109 }
110
111 c10::intrusive_ptr<Work> send(
112 std::vector<at::Tensor>& tensors,
113 int dstRank,
114 int tag) override {
115 PYBIND11_OVERRIDE(
116 c10::intrusive_ptr<Work>, /* Return type */
117 ProcessGroup, /* Parent class */
118 send, /* Name of function in C++ */
119 tensors,
120 dstRank,
121 tag);
122 }
123
124 c10::intrusive_ptr<Work> recv(
125 std::vector<at::Tensor>& tensors,
126 int srcRank,
127 int tag) override {
128 PYBIND11_OVERRIDE(
129 c10::intrusive_ptr<Work>, /* Return type */
130 ProcessGroup, /* Parent class */
131 recv, /* Name of function in C++ */
132 tensors,
133 srcRank,
134 tag);
135 }
136};
137
138} // namespace c10d
139