1#pragma once
2
3#ifdef USE_C10D_GLOO
4
5#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
6#include <torch/csrc/distributed/c10d/Types.hpp>
7#include <torch/csrc/distributed/c10d/Utils.hpp>
8
9namespace c10d {
10
// A debugging wrapper around an arbitrary c10d Backend. Application
// collectives are dispatched to the wrapped backend, while an auxiliary
// Gloo-based backend is used for internal coordination (monitored barrier,
// sequence-number checking, collective fingerprint collection) so that
// incorrect collective usage can be detected and reported to the user — see
// the private members at the bottom of the class for the division of
// responsibilities.
class TORCH_API ProcessGroupWrapper : public Backend {
 public:
  // `backend` is the process group that application collectives are actually
  // dispatched to; `glooBackend` is a Gloo-based group used only for the
  // wrapper's internal coordination and checks.
  explicit ProcessGroupWrapper(
      c10::intrusive_ptr<Backend> backend,
      c10::intrusive_ptr<Backend> glooBackend);

  const std::string getBackendName() const override;

  // Collective and point-to-point operations overridden from Backend.
  // NOTE(review): these appear intended to validate the call (see
  // runCollectiveChecks below) before forwarding to the wrapped backend —
  // confirm against the definitions in the corresponding .cpp, which is not
  // visible from this header.
  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // This function is deprecated and will be moved out of ProcessGroup to
  // comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
      override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  // Don't implement this in new backends.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  // Just calls the underlying backend's implementation.
  uint64_t getSequenceNumberForGroup() override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts) override;

  // Returns backend_: the underlying process group that application
  // collectives are dispatched to.
  c10::intrusive_ptr<Backend> getWrappedPg() const;

 private:
  // Underlying process group that actual application collectives will be
  // dispatched to.
  c10::intrusive_ptr<Backend> backend_;
  // Gloo process group responsible for internal coordination such as monitored
  // barrier, sequence number checking, collective fingerprint collecting.
  c10::intrusive_ptr<Backend> glooBackend_;
  // Conducts several checks to ensure that the underlying collective is well
  // formed with the goal of notifying the user about incorrect collective use
  // in the application. `op_type` identifies the collective being issued and
  // `tensors` are its inputs; the checks are coordinated over glooBackend_.
  void runCollectiveChecks(
      OpType op_type,
      const std::vector<at::Tensor>& tensors) const;
};
134} // namespace c10d
135
136#endif // USE_C10D_GLOO
137