#pragma once

#include <memory>
#include <stdexcept>

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/python/pybind_utils.h>
6 | |
7 | namespace c10d { |
8 | |
9 | // PyProcessGroup is a pybind11 trampoline class to allow a Python |
10 | // class to inherit from torch.distributed.ProcessGroup |
11 | class PyProcessGroup : public ProcessGroup { |
12 | public: |
13 | // PyWork is a pybind11 trampoline class to allow a Python |
14 | // class to inherit from torch.distributed.Work |
15 | class PyWork : public Work { |
16 | public: |
17 | PyWork() = default; |
18 | |
19 | bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { |
20 | PYBIND11_OVERRIDE( |
21 | bool, /* Return type */ |
22 | Work, /* Parent class */ |
23 | wait, /* Name of function in C++ */ |
24 | timeout); |
25 | } |
26 | |
27 | c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { |
28 | // We cannot use PYBIND11_OVERRIDE because: |
29 | // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and |
30 | // 2. The python name is get_future |
31 | pybind11::gil_scoped_acquire gil; |
32 | auto override = pybind11::get_override(static_cast<const Work *>(this), "get_future" ); |
33 | |
34 | if (override) { |
35 | py::object o = override(); |
36 | auto futWrapper = o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>(); |
37 | return futWrapper->fut; |
38 | } |
39 | |
40 | return Work::getFuture(); |
41 | } |
42 | }; |
43 | |
44 | using ProcessGroup::ProcessGroup; |
45 | |
46 | const std::string getBackendName() const override { |
47 | PYBIND11_OVERRIDE_PURE( |
48 | std::string, /* Return type */ |
49 | ProcessGroup, /* Parent class */ |
50 | getBackendName, /* Name of function in C++ */ |
51 | ); |
52 | } |
53 | |
54 | c10::intrusive_ptr<Work> allgather( |
55 | std::vector<std::vector<at::Tensor>>& outputTensors, |
56 | std::vector<at::Tensor>& inputTensors, |
57 | const AllgatherOptions& opts = AllgatherOptions()) override { |
58 | PYBIND11_OVERRIDE( |
59 | c10::intrusive_ptr<Work>, /* Return type */ |
60 | ProcessGroup, /* Parent class */ |
61 | allgather, /* Name of function in C++ */ |
62 | outputTensors, |
63 | inputTensors, |
64 | opts); |
65 | } |
66 | |
67 | c10::intrusive_ptr<Work> allreduce( |
68 | std::vector<at::Tensor>& tensors, |
69 | const AllreduceOptions& opts = AllreduceOptions()) override { |
70 | PYBIND11_OVERRIDE( |
71 | c10::intrusive_ptr<Work>, /* Return type */ |
72 | ProcessGroup, /* Parent class */ |
73 | allreduce, /* Name of function in C++ */ |
74 | tensors, |
75 | opts); |
76 | } |
77 | |
78 | c10::intrusive_ptr<Work> barrier( |
79 | const BarrierOptions& opts = BarrierOptions()) override { |
80 | PYBIND11_OVERRIDE( |
81 | c10::intrusive_ptr<Work>, /* Return type */ |
82 | ProcessGroup, /* Parent class */ |
83 | barrier, /* Name of function in C++ */ |
84 | opts); |
85 | } |
86 | |
87 | c10::intrusive_ptr<Work> broadcast( |
88 | std::vector<at::Tensor>& tensors, |
89 | const BroadcastOptions& opts = BroadcastOptions()) override { |
90 | PYBIND11_OVERRIDE( |
91 | c10::intrusive_ptr<Work>, /* Return type */ |
92 | ProcessGroup, /* Parent class */ |
93 | broadcast, /* Name of function in C++ */ |
94 | tensors, |
95 | opts); |
96 | } |
97 | |
98 | c10::intrusive_ptr<Work> reduce_scatter( |
99 | std::vector<at::Tensor>& outputTensors, |
100 | std::vector<std::vector<at::Tensor>>& inputTensors, |
101 | const ReduceScatterOptions& opts = ReduceScatterOptions()) override { |
102 | PYBIND11_OVERRIDE( |
103 | c10::intrusive_ptr<Work>, /* Return type */ |
104 | ProcessGroup, /* Parent class */ |
105 | reduce_scatter, /* Name of function in C++ */ |
106 | outputTensors, |
107 | inputTensors, |
108 | opts); |
109 | } |
110 | |
111 | c10::intrusive_ptr<Work> send( |
112 | std::vector<at::Tensor>& tensors, |
113 | int dstRank, |
114 | int tag) override { |
115 | PYBIND11_OVERRIDE( |
116 | c10::intrusive_ptr<Work>, /* Return type */ |
117 | ProcessGroup, /* Parent class */ |
118 | send, /* Name of function in C++ */ |
119 | tensors, |
120 | dstRank, |
121 | tag); |
122 | } |
123 | |
124 | c10::intrusive_ptr<Work> recv( |
125 | std::vector<at::Tensor>& tensors, |
126 | int srcRank, |
127 | int tag) override { |
128 | PYBIND11_OVERRIDE( |
129 | c10::intrusive_ptr<Work>, /* Return type */ |
130 | ProcessGroup, /* Parent class */ |
131 | recv, /* Name of function in C++ */ |
132 | tensors, |
133 | srcRank, |
134 | tag); |
135 | } |
136 | }; |
137 | |
138 | } // namespace c10d |
139 | |