1#include <ATen/ThreadLocalState.h>
2#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
3
4#include <c10/util/Logging.h>
5#include <fmt/format.h>
6
7#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
8#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
9#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
10#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
11#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
12#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
13
14namespace c10d {
15
16ProcessGroup::BackendType strToBackendType(std::string backend) {
17 if (backend == "undefined") {
18 return ProcessGroup::BackendType::UNDEFINED;
19 } else if (backend == "gloo") {
20 return ProcessGroup::BackendType::GLOO;
21 } else if (backend == "nccl") {
22 return ProcessGroup::BackendType::NCCL;
23 } else if (backend == "ucc") {
24 return ProcessGroup::BackendType::UCC;
25 } else if (backend == "mpi") {
26 return ProcessGroup::BackendType::MPI;
27 } else {
28 return ProcessGroup::BackendType::CUSTOM;
29 }
30}
31
32std::string backendTypeToStr(ProcessGroup::BackendType backendType) {
33 switch (backendType) {
34 case ProcessGroup::BackendType::UNDEFINED:
35 return "undefined";
36 case ProcessGroup::BackendType::GLOO:
37 return "gloo";
38 case ProcessGroup::BackendType::NCCL:
39 return "nccl";
40 case ProcessGroup::BackendType::UCC:
41 return "ucc";
42 case ProcessGroup::BackendType::MPI:
43 return "mpi";
44 case ProcessGroup::BackendType::CUSTOM:
45 return "custom";
46 default:
47 TORCH_INTERNAL_ASSERT(false, "Unknown backend type");
48 }
49}
50
51std::string opTypeToString(OpType opType) {
52 switch (opType) {
53 case OpType::BROADCAST:
54 return "BROADCAST";
55 case OpType::ALLREDUCE:
56 return "ALLREDUCE";
57 case OpType::ALLREDUCE_COALESCED:
58 return "ALLREDUCE_COALESCED";
59 case OpType::REDUCE:
60 return "REDUCE";
61 case OpType::ALLGATHER:
62 return "ALLGATHER";
63 case OpType::_ALLGATHER_BASE:
64 return "_ALLGATHER_BASE";
65 case OpType::ALLGATHER_COALESCED:
66 return "ALLGATHER_COALESCED";
67 case OpType::GATHER:
68 return "GATHER";
69 case OpType::SCATTER:
70 return "SCATTER";
71 case OpType::REDUCE_SCATTER:
72 return "REDUCE_SCATTER";
73 case OpType::ALLTOALL_BASE:
74 return "ALLTOALL_BASE";
75 case OpType::ALLTOALL:
76 return "ALLTOALL";
77 case OpType::SEND:
78 return "SEND";
79 case OpType::RECV:
80 return "RECV";
81 case OpType::RECVANYSOURCE:
82 return "RECVANYSOURCE";
83 case OpType::BARRIER:
84 return "BARRIER";
85 case OpType::UNKNOWN:
86 return "UNKNOWN";
87 case OpType::_REDUCE_SCATTER_BASE:
88 return "_REDUCE_SCATTER_BASE";
89 default:
90 TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
91 }
92 return "UNKNOWN";
93}
94
95bool isP2POp(OpType opType, bool batchP2P /*= false*/) {
96 if (batchP2P)
97 return false;
98 return opType == OpType::SEND || opType == OpType::RECV ||
99 opType == OpType::RECVANYSOURCE;
100}
101
102c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
103 c10::DeviceType deviceType) {
104 // If there is a backend associated with this device type then return it
105 if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) {
106 return deviceTypeToBackend_.at(deviceType);
107 }
108
109 // Get the backend type associated with the device
110 ProcessGroup::BackendType backendType;
111 try {
112 backendType = deviceTypeToBackendType_.at(deviceType);
113 } catch (const std::out_of_range& e) {
114 TORCH_CHECK(
115 false, "No backend type associated with device type ", deviceType);
116 }
117
118 // Check if the backend has already been initialized
119 if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) {
120 auto backend = backendTypeToBackend_.at(backendType);
121 deviceTypeToBackend_[deviceType] = backend;
122 return backend;
123 }
124
125 TORCH_CHECK(
126 false,
127 "Could not retrieve or create the backend ",
128 backendType,
129 " for device type ",
130 deviceType);
131}
132
133ProcessGroup::ProcessGroup(
134 const c10::intrusive_ptr<::c10d::Store>& store,
135 int rank,
136 int size,
137 c10::intrusive_ptr<Options> options)
138 : store_(store),
139 rank_(rank),
140 size_(size),
141 options_(options),
142 backendType_(strToBackendType(options->backend)),
143 dist_debug_level_(debug_level()) {
144 C10_LOG_API_USAGE_ONCE("c10d.process_group");
145}
146
147ProcessGroup::ProcessGroup(int rank, int size)
148 : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {}
149
150ProcessGroup::~ProcessGroup() = default;
151
152void ProcessGroup::init() {
153 C10_LOG_API_USAGE_ONCE(
154 fmt::format("c10d.process_group_{}", getBackendName()));
155}
156} // namespace c10d
157