#pragma once

#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/sequence_num.hpp>

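// Default value for Backend::Options::timeout: 30 minutes.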
constexpr auto kBackendDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

class TORCH_API Backend : public torch::CustomClassHolder {
 public:

  // Backend Options is a base struct that defines the basic options
  // when constructing a Backend. Each Backend subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond the basic ones defined here) to the end user.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kBackendDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    const std::string backend;
  };
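  // As an illustration only (the names below are made up, not part of the
  // API), a backend-specific subclass might extend Options like this:
  //
  //   struct MyBackendOptions : Backend::Options {
  //     MyBackendOptions() : Backend::Options("my_backend") {}
  //     // extra knobs specific to the hypothetical backend
  //     int numWorkerThreads = 1;
  //   };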

  explicit Backend(int rank, int size);
  ~Backend() override = 0;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

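  // startCoalescing()/endCoalescing() bracket a coalescing region: backends
  // that support it may batch the collectives issued in between into fewer
  // underlying operations. The defaults below are no-ops.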
  virtual void startCoalescing() {
    // no-op for backends that have not implemented startCoalescing
  }

  virtual void endCoalescing(
      std::vector<c10::intrusive_ptr<Work>>& /* reqs */) {
    // no-op for backends that have not implemented endCoalescing
  }

  // Subclasses must override this method to return the backend name.
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  }
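
  // Collective operations. Each default implementation below simply throws;
  // a concrete backend overrides the collectives it actually supports.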

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer
  // that is interpreted as a contiguous collection of size
  // inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support _allgather_base"));
  }

  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

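  // Point-to-point communication. A send is matched with a recv (or
  // recvAnysource) on the peer rank by the integer tag.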
  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support recvAnysource"));
  }

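  // Synchronizes all processes in the group; the returned Work completes once
  // every rank has entered the barrier.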
  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support barrier"));
  }

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  // Optional sequence number structure for matching collectives.
  c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
  const int rank_;
  const int size_;
  // Debug level setting. It is parsed once when ProcessGroup is constructed
  // and remains the same across use of this process group.
  DebugLevel dist_debug_level_;
};
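
// Illustrative sketch only (MyBackend is a made-up name, not part of this
// header): a concrete backend subclasses Backend, forwards rank/size to the
// base constructor, and overrides the collectives it supports; anything left
// un-overridden keeps the default implementation above, which throws.
//
//   class MyBackend : public Backend {
//    public:
//     MyBackend(int rank, int size) : Backend(rank, size) {
//       init();
//     }
//
//     const std::string getBackendName() const override {
//       return "my_backend";
//     }
//
//     c10::intrusive_ptr<Work> allreduce(
//         std::vector<at::Tensor>& tensors,
//         const AllreduceOptions& opts = AllreduceOptions()) override {
//       // ... run the reduction and return a Work that completes when done ...
//     }
//   };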

} // namespace c10d