1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/c10d/Backend.hpp> |
4 | #include <condition_variable> |
5 | #include <memory> |
6 | #include <mutex> |
7 | #include <stdexcept> |
8 | #include <unordered_map> |
9 | #include <utility> |
10 | #include <vector> |
11 | |
12 | #include <ATen/ATen.h> |
13 | #include <ATen/core/dispatch/Dispatcher.h> |
14 | #include <c10/macros/Macros.h> |
15 | |
16 | #include <torch/csrc/distributed/c10d/Work.hpp> |
17 | // ************************************************************************* |
18 | // PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN |
19 | // versions 1.7 and 1.8. |
20 | // PLEASE DO NOT ADD ANY DEPENDENCIES. |
21 | // SEE RFC: https://github.com/pytorch/pytorch/issues/39662 |
22 | // ************************************************************************* |
23 | |
24 | constexpr auto kProcessGroupDefaultTimeout = |
25 | std::chrono::milliseconds(30 * 60 * 1000); |
26 | |
27 | namespace c10d { |
28 | |
// ProcessGroup is a base class that captures collective and point-to-point
// communication across a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). Each such function
// returns an object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer than or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed, and instantiation
// and initialization must start from scratch. Members of the process
// group find each other through the Store passed at construction (a
// step referred to as rendezvous from here on).
48 | // |
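// Example (sketch, not part of the API): typical asynchronous usage, assuming
// `pg` is an already-initialized c10::intrusive_ptr<ProcessGroup>:
//
//   std::vector<at::Tensor> tensors{at::ones({2, 2})};
//   c10::intrusive_ptr<Work> work = pg->allreduce(tensors);
//   work->wait(); // blocks until the collective completes or errors
//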
49 | class TORCH_API ProcessGroup : public torch::CustomClassHolder { |
50 | public: |
  // ProcessGroup Options is a base struct that defines the basic options
  // when constructing a ProcessGroup. Each ProcessGroup subclass should
  // extend this struct and define its own options if it wants to provide
  // more config options (beyond the basic ones defined here) to end users.
55 | struct TORCH_API Options : torch::CustomClassHolder { |
56 | explicit Options( |
57 | std::string backend, |
58 | std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout) |
59 | : timeout(timeout), backend(std::move(backend)) {} |
60 | ~Options() override = default; |
61 | |
62 | std::chrono::milliseconds timeout; |
63 | |
64 | // backend name |
65 | const std::string backend; |
66 | }; |
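
  // Example (sketch): constructing the base Options with a custom timeout.
  // Concrete ProcessGroup subclasses typically define their own Options
  // subtype instead:
  //
  //   auto opts = c10::make_intrusive<ProcessGroup::Options>(
  //       /*backend=*/"gloo", std::chrono::minutes(5));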
67 | |
68 | enum BackendType { |
69 | UNDEFINED = 0, |
70 | GLOO = 1, |
71 | NCCL = 2, |
72 | UCC = 3, |
73 | MPI = 4, |
74 | CUSTOM = 5, |
75 | }; |
76 | |
  // Not used directly; kept for backwards compatibility and only used for the
  // TypeDef in Ops.cpp.
79 | explicit ProcessGroup(int rank, int size); |
80 | |
81 | explicit ProcessGroup( |
82 | const c10::intrusive_ptr<::c10d::Store>& store, |
83 | int rank, |
84 | int size, |
85 | c10::intrusive_ptr<Options> options); |
86 | ~ProcessGroup() override; |
87 | |
88 | int getRank() const { |
89 | return rank_; |
90 | } |
91 | |
92 | int getSize() const { |
93 | return size_; |
94 | } |
95 | |
  virtual const std::string getBackendName() const {
    return options_->backend;
  }

  BackendType getBackendType() const {
    return backendType_;
  }
103 | |
104 | virtual void startCoalescing(c10::DeviceType deviceType) { |
    // only NCCL has implemented startCoalescing, so only execute it for
    // NCCL backends
107 | if (getBackendType() == BackendType::NCCL) { |
108 | getBackend(deviceType)->startCoalescing(); |
109 | } |
110 | } |
111 | |
112 | virtual void endCoalescing( |
113 | c10::DeviceType deviceType, |
114 | std::vector<c10::intrusive_ptr<Work>>& reqs) { |
    // only NCCL has implemented endCoalescing, so only execute it for
    // NCCL backends
117 | if (getBackendType() == BackendType::NCCL) { |
118 | getBackend(deviceType)->endCoalescing(reqs); |
119 | } |
120 | } |
121 | |
122 | virtual c10::intrusive_ptr<Work> broadcast( |
123 | std::vector<at::Tensor>& tensors, |
124 | const BroadcastOptions& opts = BroadcastOptions()) { |
125 | static auto op = |
126 | c10::Dispatcher::singleton() |
            .findSchemaOrThrow("c10d::broadcast_", "")
128 | .typed< |
129 | std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( |
130 | at::TensorList, |
131 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
132 | int64_t, |
133 | int64_t, |
134 | int64_t)>(); |
    // It's awkward to unbox the opts here and box them again in the custom C++
    // op. But it's also complicated to make opts a CustomClassHolder. Leave
    // it as it is for now.
138 | return std::get<1>(op.call( |
139 | tensors, |
140 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
141 | opts.rootRank, |
142 | opts.rootTensor, |
143 | opts.timeout.count())); |
144 | } |
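
  // Example (sketch): broadcast from rank 0, assuming `pg` is an initialized
  // ProcessGroup; every rank passes same-shaped tensors, which are
  // overwritten in place on non-root ranks:
  //
  //   std::vector<at::Tensor> tensors{at::ones({4})};
  //   BroadcastOptions opts;
  //   opts.rootRank = 0;
  //   pg->broadcast(tensors, opts)->wait();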
145 | |
146 | virtual c10::intrusive_ptr<Work> allreduce( |
147 | std::vector<at::Tensor>& tensors, |
148 | const AllreduceOptions& opts = AllreduceOptions()) { |
149 | static auto op = |
150 | c10::Dispatcher::singleton() |
            .findSchemaOrThrow("c10d::allreduce_", "")
152 | .typed< |
153 | std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( |
154 | at::TensorList, |
155 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
156 | const c10::intrusive_ptr<::c10d::ReduceOp>&, |
157 | int64_t)>(); |
158 | |
159 | return std::get<1>(op.call( |
160 | tensors, |
161 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
162 | c10::make_intrusive<ReduceOp>(opts.reduceOp), |
163 | opts.timeout.count())); |
164 | } |
165 | |
166 | virtual c10::intrusive_ptr<Work> allreduce_coalesced( |
167 | std::vector<at::Tensor>& tensors, |
168 | const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) { |
169 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::allreduce_coalesced_", "")
171 | .typed<c10::intrusive_ptr<::c10d::Work>( |
172 | at::TensorList, |
173 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
174 | const c10::intrusive_ptr<::c10d::ReduceOp>&, |
175 | int64_t)>(); |
176 | |
177 | return op.call( |
178 | tensors, |
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
182 | opts.timeout.count()); |
183 | } |
184 | |
185 | virtual c10::intrusive_ptr<Work> reduce( |
186 | std::vector<at::Tensor>& tensors, |
187 | const ReduceOptions& opts = ReduceOptions()) { |
188 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::reduce_", "")
190 | .typed<c10::intrusive_ptr<::c10d::Work>( |
191 | at::TensorList, |
192 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
193 | const c10::intrusive_ptr<::c10d::ReduceOp>&, |
194 | int64_t, |
195 | int64_t, |
196 | int64_t)>(); |
197 | return op.call( |
198 | tensors, |
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
202 | opts.rootRank, |
203 | opts.rootTensor, |
204 | opts.timeout.count()); |
205 | } |
206 | |
207 | virtual c10::intrusive_ptr<Work> allgather( |
208 | std::vector<std::vector<at::Tensor>>& outputTensors, |
209 | std::vector<at::Tensor>& inputTensors, |
210 | const AllgatherOptions& opts = AllgatherOptions()) { |
211 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::allgather_", "")
213 | .typed<std::tuple< |
214 | std::vector<std::vector<at::Tensor>>, |
215 | c10::intrusive_ptr<Work>>( |
216 | const std::vector<std::vector<at::Tensor>>&, |
217 | at::TensorList, |
218 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
219 | int64_t)>(); |
220 | |
221 | return std::get<1>(op.call( |
222 | outputTensors, |
223 | inputTensors, |
224 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
225 | opts.timeout.count())); |
226 | } |
227 | |
  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
232 | virtual c10::intrusive_ptr<Work> _allgather_base( |
233 | at::Tensor& outputBuffer, |
234 | at::Tensor& inputBuffer, |
235 | const AllgatherOptions& opts = AllgatherOptions()) { |
236 | static auto op = |
237 | c10::Dispatcher::singleton() |
            .findSchemaOrThrow("c10d::_allgather_base_", "")
239 | .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>( |
240 | at::Tensor&, |
241 | at::Tensor&, |
242 | const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); |
243 | |
244 | return std::get<1>(op.call( |
245 | outputBuffer, |
246 | inputBuffer, |
247 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this))); |
248 | } |
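
  // Example (sketch): the flat output buffer must hold one input-sized chunk
  // per rank, assuming `pg` is an initialized ProcessGroup:
  //
  //   at::Tensor input = at::ones({8});
  //   at::Tensor output = at::empty({8 * pg->getSize()});
  //   pg->_allgather_base(output, input)->wait();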
249 | |
  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup; implement _allgather_base
  //   instead.
254 | virtual c10::intrusive_ptr<Work> allgather_coalesced( |
255 | std::vector<std::vector<at::Tensor>>& outputTensorLists, |
256 | std::vector<at::Tensor>& inputTensors, |
257 | const AllgatherOptions& opts = AllgatherOptions()) { |
258 | static auto op = |
259 | c10::Dispatcher::singleton() |
            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
261 | .typed<c10::intrusive_ptr<Work>( |
262 | const std::vector<std::vector<at::Tensor>>&, |
263 | const at::TensorList&, |
264 | const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); |
265 | |
266 | return op.call( |
267 | outputTensorLists, |
268 | inputTensors, |
269 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this)); |
270 | } |
271 | |
272 | virtual c10::intrusive_ptr<Work> gather( |
273 | std::vector<std::vector<at::Tensor>>& outputTensors, |
274 | std::vector<at::Tensor>& inputTensors, |
275 | const GatherOptions& opts = GatherOptions()) { |
276 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::gather_", "")
278 | .typed<c10::intrusive_ptr<::c10d::Work>( |
279 | const std::vector<std::vector<at::Tensor>>&, |
280 | const at::TensorList&, |
281 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
282 | int64_t, |
283 | int64_t)>(); |
284 | return op.call( |
285 | outputTensors, |
286 | inputTensors, |
287 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
288 | opts.rootRank, |
289 | opts.timeout.count()); |
290 | } |
291 | |
292 | virtual c10::intrusive_ptr<Work> scatter( |
293 | std::vector<at::Tensor>& outputTensors, |
294 | std::vector<std::vector<at::Tensor>>& inputTensors, |
295 | const ScatterOptions& opts = ScatterOptions()) { |
296 | static auto op = |
297 | c10::Dispatcher::singleton() |
            .findSchemaOrThrow("c10d::scatter_", "")
299 | .typed< |
300 | std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( |
301 | const at::TensorList&, |
302 | const std::vector<std::vector<at::Tensor>>&, |
303 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
304 | int64_t, |
305 | int64_t)>(); |
306 | return std::get<1>(op.call( |
307 | outputTensors, |
308 | inputTensors, |
309 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
310 | opts.rootRank, |
311 | opts.timeout.count())); |
312 | } |
313 | |
314 | virtual c10::intrusive_ptr<Work> reduce_scatter( |
315 | std::vector<at::Tensor>& outputTensors, |
316 | std::vector<std::vector<at::Tensor>>& inputTensors, |
317 | const ReduceScatterOptions& opts = ReduceScatterOptions()) { |
318 | static auto op = |
319 | c10::Dispatcher::singleton() |
            .findSchemaOrThrow("c10d::reduce_scatter_", "")
321 | .typed< |
322 | std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( |
323 | const at::TensorList&, |
324 | const std::vector<std::vector<at::Tensor>>&, |
325 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
326 | const c10::intrusive_ptr<::c10d::ReduceOp>&, |
327 | int64_t)>(); |
328 | return std::get<1>(op.call( |
329 | outputTensors, |
330 | inputTensors, |
331 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
332 | c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), |
333 | opts.timeout.count())); |
334 | } |
335 | |
336 | virtual c10::intrusive_ptr<Work> _reduce_scatter_base( |
337 | at::Tensor& outputBuffer, |
338 | at::Tensor& inputBuffer, |
339 | const ReduceScatterOptions& opts = ReduceScatterOptions()) { |
340 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
342 | .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>( |
343 | at::Tensor&, |
344 | at::Tensor&, |
345 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
346 | const c10::intrusive_ptr<::c10d::ReduceOp>&, |
347 | int64_t)>(); |
348 | return std::get<1>(op.call( |
349 | outputBuffer, |
350 | inputBuffer, |
351 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
352 | c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), |
353 | opts.timeout.count())); |
354 | } |
355 | |
356 | virtual c10::intrusive_ptr<Work> alltoall_base( |
357 | at::Tensor& outputBuffer, |
358 | at::Tensor& inputBuffer, |
359 | std::vector<int64_t>& outputSplitSizes, |
360 | std::vector<int64_t>& inputSplitSizes, |
361 | const AllToAllOptions& opts = AllToAllOptions()) { |
362 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::alltoall_base_", "")
364 | .typed<c10::intrusive_ptr<::c10d::Work>( |
365 | at::Tensor&, |
366 | at::Tensor&, |
367 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
368 | std::vector<int64_t>, |
369 | std::vector<int64_t>, |
370 | int64_t)>(); |
371 | return op.call( |
372 | outputBuffer, |
373 | inputBuffer, |
374 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
375 | outputSplitSizes, |
376 | inputSplitSizes, |
377 | opts.timeout.count()); |
378 | } |
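
  // Example (sketch): an even all-to-all exchange, assuming `pg` is an
  // initialized ProcessGroup; empty split-size vectors request an equal
  // partition of dimension 0 across ranks:
  //
  //   std::vector<int64_t> splits; // empty => even split
  //   at::Tensor in = at::arange(pg->getSize() * 4);
  //   at::Tensor out = at::empty_like(in);
  //   pg->alltoall_base(out, in, splits, splits)->wait();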
379 | |
380 | virtual c10::intrusive_ptr<Work> alltoall( |
381 | std::vector<at::Tensor>& outputTensors, |
382 | std::vector<at::Tensor>& inputTensors, |
383 | const AllToAllOptions& opts = AllToAllOptions()) { |
384 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::alltoall_", "")
386 | .typed<std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( |
387 | const at::TensorList&, |
388 | const at::TensorList&, |
389 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
390 | int64_t)>(); |
391 | return std::get<1>(op.call( |
392 | outputTensors, |
393 | inputTensors, |
394 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
395 | opts.timeout.count())); |
396 | } |
397 | |
398 | virtual void monitoredBarrier( |
399 | const BarrierOptions& opts, |
400 | bool wait_all_ranks = false) { |
401 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::monitored_barrier_", "")
403 | .typed<void( |
404 | at::Tensor, |
405 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
406 | const std::vector<int64_t>&, |
407 | int64_t, |
408 | bool)>(); |
    // Default to the CPU implementation; monitored barrier is only supported
    // by the GLOO backend.
410 | at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU)); |
411 | op.call( |
412 | tensor, |
413 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
414 | opts.device_ids, |
415 | opts.timeout.count(), |
416 | wait_all_ranks); |
417 | } |
418 | |
  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to the other ranks using the store. Currently
  // only implemented for the GLOO, NCCL, and UCC backends.
422 | virtual void setSequenceNumberForGroup() { |
423 | auto backendType = getBackendType(); |
    // TODO: HACK to get the sequence number from the backend matching this
    // backend type.
425 | if (backendType == ProcessGroup::BackendType::GLOO || |
426 | backendType == ProcessGroup::BackendType::NCCL || |
427 | backendType == ProcessGroup::BackendType::UCC) { |
428 | getDefaultBackend()->setSequenceNumberForGroup(); |
429 | } else { |
430 | TORCH_CHECK( |
431 | false, |
432 | c10::str( |
433 | "ProcessGroup " , |
434 | getBackendName(), |
435 | " does not yet support sequence numbers." )); |
436 | } |
437 | } |
438 | |
439 | // Retrieves the current sequence number for the whole group, which should be |
440 | // in sync. If the returned number is not consistent across the group, it |
441 | // may indicate that there is some sort of collective desynchronization. |
442 | virtual uint64_t getSequenceNumberForGroup() { |
443 | auto backendType = getBackendType(); |
444 | |
    // TODO: HACK to get the sequence number from the backend matching this
    // backend type.
446 | if (backendType == ProcessGroup::BackendType::GLOO || |
447 | backendType == ProcessGroup::BackendType::NCCL || |
448 | backendType == ProcessGroup::BackendType::UCC) { |
449 | return getDefaultBackend()->getSequenceNumberForGroup(); |
450 | } else { |
451 | TORCH_CHECK( |
452 | false, |
453 | c10::str( |
454 | "ProcessGroup " , |
455 | getBackendName(), |
456 | " does not yet support sequence numbers." )); |
457 | } |
458 | } |
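
  // Example (sketch): checking collective synchronization on a GLOO/NCCL/UCC
  // backed group; the value returned on each rank can be compared out of band:
  //
  //   pg->setSequenceNumberForGroup();
  //   // ... run some collectives ...
  //   uint64_t seq = pg->getSequenceNumberForGroup();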
459 | |
460 | virtual c10::intrusive_ptr<Work> send( |
461 | std::vector<at::Tensor>& tensors, |
462 | int dstRank, |
463 | int tag) { |
464 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::send", "")
466 | .typed<c10::intrusive_ptr<::c10d::Work>( |
467 | at::TensorList, |
468 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
469 | int64_t, |
470 | int64_t)>(); |
471 | return op.call( |
472 | tensors, |
473 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
474 | dstRank, |
475 | tag); |
476 | } |
477 | |
478 | virtual c10::intrusive_ptr<Work> recv( |
479 | std::vector<at::Tensor>& tensors, |
480 | int srcRank, |
481 | int tag) { |
482 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::recv_", "")
484 | .typed<c10::intrusive_ptr<::c10d::Work>( |
485 | at::TensorList, |
486 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
487 | int64_t, |
488 | int64_t)>(); |
489 | return op.call( |
490 | tensors, |
491 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
492 | srcRank, |
493 | tag); |
494 | } |
495 | |
496 | virtual c10::intrusive_ptr<Work> recvAnysource( |
497 | std::vector<at::Tensor>& tensors, |
498 | int tag) { |
499 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::recv_any_source_", "")
501 | .typed<c10::intrusive_ptr<::c10d::Work>( |
502 | at::TensorList, |
503 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
504 | int64_t)>(); |
505 | return op.call( |
506 | tensors, |
507 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
508 | tag); |
509 | } |
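
  // Example (sketch): a simple point-to-point exchange between ranks 0 and 1,
  // assuming `pg` is an initialized ProcessGroup:
  //
  //   std::vector<at::Tensor> buf{at::zeros({4})};
  //   if (pg->getRank() == 0) {
  //     pg->send(buf, /*dstRank=*/1, /*tag=*/0)->wait();
  //   } else if (pg->getRank() == 1) {
  //     pg->recv(buf, /*srcRank=*/0, /*tag=*/0)->wait();
  //   }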
510 | |
511 | virtual c10::intrusive_ptr<Work> barrier( |
512 | const BarrierOptions& opts = BarrierOptions()) { |
513 | static at::Tensor tensor; |
    // TODO: if NCCL was specified, use it
    if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
      // use a CUDA tensor
517 | tensor = at::empty( |
518 | {1}, |
519 | at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte)); |
520 | } else { |
521 | // Default to using cpu implementation |
522 | tensor = at::empty( |
523 | {1}, |
524 | at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte)); |
525 | } |
526 | |
527 | static auto op = c10::Dispatcher::singleton() |
                         .findSchemaOrThrow("c10d::barrier", "")
529 | .typed<c10::intrusive_ptr<::c10d::Work>( |
530 | at::Tensor, |
531 | const c10::intrusive_ptr<::c10d::ProcessGroup>&, |
532 | const std::vector<int64_t>&, |
533 | int64_t)>(); |
534 | |
535 | return op.call( |
536 | tensor, |
537 | c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this), |
538 | opts.device_ids, |
539 | opts.timeout.count()); |
540 | } |
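
  // Example (sketch): block every rank until all ranks arrive, assuming `pg`
  // is an initialized ProcessGroup:
  //
  //   pg->barrier()->wait();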
541 | |
542 | c10::intrusive_ptr<Options> getOptions() { |
543 | return options_; |
544 | } |
545 | |
546 | bool hasBackends() { |
547 | return !deviceTypeToBackendType_.empty(); |
548 | } |
549 | |
550 | void setBackend( |
551 | c10::DeviceType deviceType, |
552 | BackendType backendType, |
553 | const c10::optional<c10::intrusive_ptr<Backend>>& backend) { |
554 | deviceTypeToBackendType_[deviceType] = backendType; |
    // if a backend for this backendType already exists, reuse it for this
    // device
556 | if (backendTypeToBackend_.find(backendType) != |
557 | backendTypeToBackend_.end()) { |
558 | auto existingBackend = backendTypeToBackend_.at(backendType); |
559 | deviceTypeToBackend_[deviceType] = existingBackend; |
560 | } else { |
561 | // check if backend has value |
562 | if (backend.has_value()) { |
563 | deviceTypeToBackend_[deviceType] = backend.value(); |
564 | backendTypeToBackend_[backendType] = backend.value(); |
565 | } |
566 | } |
567 | } |
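
  // Example (sketch): registering a backend instance for CPU tensors;
  // `glooBackend` is a hypothetical, already-constructed
  // c10::intrusive_ptr<Backend>:
  //
  //   pg->setBackend(
  //       c10::DeviceType::CPU, ProcessGroup::BackendType::GLOO, glooBackend);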
568 | |
569 | c10::intrusive_ptr<Backend> getDefaultBackend() const { |
570 | TORCH_CHECK( |
571 | backendTypeToBackend_.find(backendType_) != backendTypeToBackend_.end(), |
572 | "Could not find the default backend type " , |
573 | backendType_, |
574 | " for Process Group with name " , |
575 | getBackendName(), |
576 | "." ); |
577 | return backendTypeToBackend_.at(backendType_); |
578 | } |
579 | |
580 | c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType); |
581 | |
582 | c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const { |
583 | TORCH_CHECK( |
584 | backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(), |
585 | "Could not find backend type " , |
586 | backendType, |
587 | "." ); |
588 | return backendTypeToBackend_.at(backendType); |
589 | } |
590 | |
591 | protected: |
  // Implementations of this interface need to call this to set up
  // appropriate logging, etc.
594 | void init(); |
595 | |
596 | const c10::intrusive_ptr<c10d::Store> store_; |
597 | const int rank_; |
598 | const int size_; |
599 | const c10::intrusive_ptr<Options> options_; |
600 | const BackendType backendType_; |
601 | // Optional sequence number structure for matching collectives. |
602 | c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt; |
603 | |
  // Debug level setting. It is parsed once when the ProcessGroup is
  // constructed and remains the same for the lifetime of this process group.
606 | DebugLevel dist_debug_level_; |
607 | |
608 | // Backend classes for this ProcessGroup |
609 | std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_; |
610 | std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>> |
611 | deviceTypeToBackend_; |
612 | std::unordered_map<BackendType, c10::intrusive_ptr<Backend>> |
613 | backendTypeToBackend_; |
614 | }; |
615 | |
616 | } // namespace c10d |
617 | |