1 | #include <ATen/ThreadLocalState.h> |
2 | #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
3 | |
4 | #include <c10/util/Logging.h> |
5 | #include <fmt/format.h> |
6 | |
7 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
8 | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
9 | #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp> |
10 | #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> |
11 | #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp> |
12 | #include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp> |
13 | |
14 | namespace c10d { |
15 | |
16 | ProcessGroup::BackendType strToBackendType(std::string backend) { |
17 | if (backend == "undefined" ) { |
18 | return ProcessGroup::BackendType::UNDEFINED; |
19 | } else if (backend == "gloo" ) { |
20 | return ProcessGroup::BackendType::GLOO; |
21 | } else if (backend == "nccl" ) { |
22 | return ProcessGroup::BackendType::NCCL; |
23 | } else if (backend == "ucc" ) { |
24 | return ProcessGroup::BackendType::UCC; |
25 | } else if (backend == "mpi" ) { |
26 | return ProcessGroup::BackendType::MPI; |
27 | } else { |
28 | return ProcessGroup::BackendType::CUSTOM; |
29 | } |
30 | } |
31 | |
32 | std::string backendTypeToStr(ProcessGroup::BackendType backendType) { |
33 | switch (backendType) { |
34 | case ProcessGroup::BackendType::UNDEFINED: |
35 | return "undefined" ; |
36 | case ProcessGroup::BackendType::GLOO: |
37 | return "gloo" ; |
38 | case ProcessGroup::BackendType::NCCL: |
39 | return "nccl" ; |
40 | case ProcessGroup::BackendType::UCC: |
41 | return "ucc" ; |
42 | case ProcessGroup::BackendType::MPI: |
43 | return "mpi" ; |
44 | case ProcessGroup::BackendType::CUSTOM: |
45 | return "custom" ; |
46 | default: |
47 | TORCH_INTERNAL_ASSERT(false, "Unknown backend type" ); |
48 | } |
49 | } |
50 | |
51 | std::string opTypeToString(OpType opType) { |
52 | switch (opType) { |
53 | case OpType::BROADCAST: |
54 | return "BROADCAST" ; |
55 | case OpType::ALLREDUCE: |
56 | return "ALLREDUCE" ; |
57 | case OpType::ALLREDUCE_COALESCED: |
58 | return "ALLREDUCE_COALESCED" ; |
59 | case OpType::REDUCE: |
60 | return "REDUCE" ; |
61 | case OpType::ALLGATHER: |
62 | return "ALLGATHER" ; |
63 | case OpType::_ALLGATHER_BASE: |
64 | return "_ALLGATHER_BASE" ; |
65 | case OpType::ALLGATHER_COALESCED: |
66 | return "ALLGATHER_COALESCED" ; |
67 | case OpType::GATHER: |
68 | return "GATHER" ; |
69 | case OpType::SCATTER: |
70 | return "SCATTER" ; |
71 | case OpType::REDUCE_SCATTER: |
72 | return "REDUCE_SCATTER" ; |
73 | case OpType::ALLTOALL_BASE: |
74 | return "ALLTOALL_BASE" ; |
75 | case OpType::ALLTOALL: |
76 | return "ALLTOALL" ; |
77 | case OpType::SEND: |
78 | return "SEND" ; |
79 | case OpType::RECV: |
80 | return "RECV" ; |
81 | case OpType::RECVANYSOURCE: |
82 | return "RECVANYSOURCE" ; |
83 | case OpType::BARRIER: |
84 | return "BARRIER" ; |
85 | case OpType::UNKNOWN: |
86 | return "UNKNOWN" ; |
87 | case OpType::_REDUCE_SCATTER_BASE: |
88 | return "_REDUCE_SCATTER_BASE" ; |
89 | default: |
90 | TORCH_INTERNAL_ASSERT(false, "Unknown op type!" ); |
91 | } |
92 | return "UNKNOWN" ; |
93 | } |
94 | |
95 | bool isP2POp(OpType opType, bool batchP2P /*= false*/) { |
96 | if (batchP2P) |
97 | return false; |
98 | return opType == OpType::SEND || opType == OpType::RECV || |
99 | opType == OpType::RECVANYSOURCE; |
100 | } |
101 | |
102 | c10::intrusive_ptr<Backend> ProcessGroup::getBackend( |
103 | c10::DeviceType deviceType) { |
104 | // If there is a backend associated with this device type then return it |
105 | if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) { |
106 | return deviceTypeToBackend_.at(deviceType); |
107 | } |
108 | |
109 | // Get the backend type associated with the device |
110 | ProcessGroup::BackendType backendType; |
111 | try { |
112 | backendType = deviceTypeToBackendType_.at(deviceType); |
113 | } catch (const std::out_of_range& e) { |
114 | TORCH_CHECK( |
115 | false, "No backend type associated with device type " , deviceType); |
116 | } |
117 | |
118 | // Check if the backend has already been initialized |
119 | if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) { |
120 | auto backend = backendTypeToBackend_.at(backendType); |
121 | deviceTypeToBackend_[deviceType] = backend; |
122 | return backend; |
123 | } |
124 | |
125 | TORCH_CHECK( |
126 | false, |
127 | "Could not retrieve or create the backend " , |
128 | backendType, |
129 | " for device type " , |
130 | deviceType); |
131 | } |
132 | |
// Constructs a ProcessGroup for `rank` within a group of `size` ranks, backed
// by `store` and configured by `options`.
//
// backendType_ is derived from `options->backend` via strToBackendType, and the
// distributed debug level is captured once from debug_level() at construction.
// Emits a one-time "c10d.process_group" API-usage log entry.
//
// NOTE(review): `options` is dereferenced without a null check; callers are
// presumably required to pass non-null options — confirm.
// NOTE(review): backendType_ reads the by-value parameter `options` (not the
// member options_), which keeps this initializer correct regardless of member
// declaration order in the header. Do not switch options_ to
// std::move(options) without first checking that order.
ProcessGroup::ProcessGroup(
    const c10::intrusive_ptr<::c10d::Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : store_(store),
      rank_(rank),
      size_(size),
      options_(options),
      backendType_(strToBackendType(options->backend)),
      dist_debug_level_(debug_level()) {
  C10_LOG_API_USAGE_ONCE("c10d.process_group");
}
146 | |
147 | ProcessGroup::ProcessGroup(int rank, int size) |
148 | : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {} |
149 | |
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept out of the header so member/vtable emission is
// anchored here (or so member types only complete in this TU can be
// destroyed) — confirm before moving it inline.
ProcessGroup::~ProcessGroup() = default;
151 | |
// Records an API-usage event tagged with this group's backend name, i.e.
// "c10d.process_group_<getBackendName()>".
//
// NOTE(review): C10_LOG_API_USAGE_ONCE presumably fires only once per call
// site, evaluating its argument on the first call only. If so, only the
// backend of the first ProcessGroup to call init() is recorded; later groups
// using a different backend are not logged. Confirm against
// c10/util/Logging.h before relying on per-backend usage counts. Hoisting the
// fmt::format call out of the macro would change how often getBackendName()
// is evaluated.
void ProcessGroup::init() {
  C10_LOG_API_USAGE_ONCE(
      fmt::format("c10d.process_group_{}", getBackendName()));
}
156 | } // namespace c10d |
157 | |