1#include <fmt/format.h>
2#include <torch/csrc/distributed/rpc/agent_utils.h>
3
4namespace torch {
5namespace distributed {
6namespace rpc {
7
8std::unordered_map<std::string, worker_id_t> collectNames(
9 ::c10d::PrefixStore store,
10 const worker_id_t selfId,
11 const std::string& selfName,
12 const int worldSize) {
13 std::vector<uint8_t> selfNameVector(
14 (uint8_t*)selfName.c_str(),
15 (uint8_t*)selfName.c_str() + selfName.length());
16 store.set(c10::to_string(selfId), selfNameVector);
17
18 std::unordered_map<std::string, worker_id_t> nameToId;
19 nameToId.reserve(worldSize);
20 nameToId.emplace(selfName, selfId);
21 for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
22 if (workerId == selfId) {
23 continue;
24 }
25 std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
26 std::string workerName(
27 (char*)workerNameVector.data(), workerNameVector.size());
28
29 TORCH_CHECK(
30 nameToId.find(workerName) == nameToId.end(),
31 "RPC worker name ",
32 workerName,
33 " is not unique. Workers ",
34 nameToId.find(workerName)->second,
35 " and ",
36 workerId,
37 " share the same name.");
38
39 nameToId.emplace(workerName, workerId);
40 }
41 return nameToId;
42}
43
/// Splits `s` on every occurrence of `delim` and returns the pieces in order.
///
/// Adjacent delimiters, and a delimiter at the very start or end of `s`,
/// yield empty tokens, so the result always contains at least one element
/// (an empty `s` produces a single empty token).
///
/// An empty `delim` cannot split anything: `s` is returned as one token.
/// Searching for an empty delimiter always matches at the current position,
/// so without this guard the scan would never advance.
std::vector<std::string> splitString(
    const std::string& s,
    const std::string& delim) {
  std::vector<std::string> tokens;
  if (delim.empty()) {
    tokens.emplace_back(s);
    return tokens;
  }

  size_t start = 0;
  // Each match closes the token that began at `start`; the next token begins
  // right after the delimiter.
  for (size_t end = s.find(delim, start); end != std::string::npos;
       end = s.find(delim, start)) {
    tokens.emplace_back(s.substr(start, end - start));
    start = end + delim.length();
  }
  // Whatever follows the last delimiter (possibly nothing) is the final token.
  tokens.emplace_back(s.substr(start));
  return tokens;
}
59
// Store key holding the list of currently registered workers, serialized as
// "Name1-Id1,Name2-Id2," (every entry is followed by a comma).
const std::string allWorkerInfosKey("_ALL_WORKER_INFOS");
61
62std::unordered_map<std::string, worker_id_t> collectCurrentNames(
63 ::c10d::PrefixStore store,
64 const worker_id_t selfId,
65 const std::string& selfName) {
66 std::vector<uint8_t> selfNameVector(
67 (uint8_t*)selfName.c_str(),
68 (uint8_t*)selfName.c_str() + selfName.length());
69
70 // Check that ID does not already exist and set {ID : NAME}
71 std::vector<uint8_t> resultVector = store.compareSet(
72 c10::to_string(selfId), std::vector<uint8_t>(), selfNameVector);
73 TORCH_CHECK(
74 resultVector == selfNameVector,
75 "RPC worker id ",
76 selfId,
77 " is not unique. Worker ",
78 resultVector,
79 " and already has ID and ",
80 selfNameVector,
81 " cannot be added.");
82
83 store.set(c10::to_string(selfId), selfNameVector);
84
85 std::unordered_map<std::string, worker_id_t> nameToId;
86 nameToId.emplace(selfName, selfId);
87
88 // Check to see if there is list of worker names in the store
89 bool worker_names_available =
90 store.check(std::vector<std::string>{allWorkerInfosKey});
91 std::string allWorkerInfos;
92 if (worker_names_available) {
93 // Get the current list of workers
94 std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
95 allWorkerInfos = std::string(
96 (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
97 // workerInfos are comma separated with a comma at the end (e.g.
98 // "Name1-Rank1,Name2-Rank2,Name3-Rank2,") parse list of workers.
99 if (!allWorkerInfos.empty()) {
100 for (const std::string& workerInfoString : splitString(
101 allWorkerInfos.substr(0, allWorkerInfos.size() - 1), ",")) {
102 auto workerInfoVec = splitString(workerInfoString, "-");
103 std::string workerName = workerInfoVec.at(0);
104 int workerId = std::stoi(workerInfoVec.at(1));
105
106 TORCH_CHECK(
107 nameToId.find(workerName) == nameToId.end(),
108 "RPC worker name ",
109 workerName,
110 " is not unique. Workers ",
111 nameToId.find(workerName)->second,
112 " and ",
113 workerId,
114 " share the same name.");
115
116 nameToId.emplace(workerName, workerId);
117 }
118 }
119 }
120 // Add own name to worker list
121 allWorkerInfos = fmt::format("{}{}-{},", allWorkerInfos, selfName, selfId);
122 std::vector<uint8_t> allWorkerInfosVector(
123 (uint8_t*)allWorkerInfos.c_str(),
124 (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length());
125 store.set(allWorkerInfosKey, allWorkerInfosVector);
126
127 return nameToId;
128}
129
130void removeCurrentName(
131 ::c10d::PrefixStore store,
132 const worker_id_t selfId,
133 const std::string& selfName) {
134 // Get current list of names/ranks
135 std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
136 std::string allWorkerInfos = std::string(
137 (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
138
139 // Remove the current name and rank
140 std::string str_to_erase = fmt::format("{}-{},", selfName, selfId);
141 int start_position_to_erase = allWorkerInfos.find(str_to_erase);
142 allWorkerInfos.erase(start_position_to_erase, str_to_erase.length());
143
144 // Set the new data
145 std::vector<uint8_t> newAllWorkerInfosVector(
146 (uint8_t*)allWorkerInfos.c_str(),
147 (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length());
148 store.set(allWorkerInfosKey, newAllWorkerInfosVector);
149}
150
// Pieces of the store keys used by the store-based barrier. A full key is
// "<PREFIX><storeKeyBarrierId><n>", e.g. "PROCESS_COUNT_ID_1".
const std::string storeKeyBarrierId = "_ID_";
// Prefix of the key counting how many processes have reached the barrier.
const std::string storeKeyProcessCount = "PROCESS_COUNT";
// Prefix of the key accumulating the active call counts of all processes.
const std::string storeKeyActiveCallCount = "ACTIVE_CALLS";
// Prefix of the key that is set once every process has arrived.
const std::string storeKeyReady = "READY";
// Counter local to this process, incremented by getNextKeyIds() so that each
// barrier round uses a distinct set of key names.
static std::atomic<int> barrierId(0);
156
157std::tuple<std::string, std::string, std::string> getNextKeyIds() {
158 barrierId++;
159 std::string processCountKey =
160 fmt::format("{}{}{}", storeKeyProcessCount, storeKeyBarrierId, barrierId);
161 std::string activeCallCountKey = fmt::format(
162 "{}{}{}", storeKeyActiveCallCount, storeKeyBarrierId, barrierId);
163 std::string barrierKey =
164 fmt::format("{}{}{}", storeKeyReady, storeKeyBarrierId, barrierId);
165 return std::make_tuple(processCountKey, activeCallCountKey, barrierKey);
166}
167
168// Synchronize process with all other agent processes strictly using store
169// Block until all ``RpcAgent``s reach this method.
170// Returns total number of active calls of all RPC agents in the group
171int syncCallCount(
172 ::c10d::PrefixStore store,
173 const int worldSize,
174 int activeCalls) {
175 std::string processCountKey, activeCallCountKey, readyKey;
176 std::tie(processCountKey, activeCallCountKey, readyKey) = getNextKeyIds();
177
178 // Add to keys which will record the number of processes and active calls
179 store.add(activeCallCountKey, activeCalls);
180 int totalProcessCount = store.add(processCountKey, 1);
181
182 // The last worker will need to set the ready key
183 if (totalProcessCount == worldSize) {
184 store.set(readyKey, std::vector<uint8_t>());
185 }
186
187 // Wait on the ready key to be set
188 store.wait(std::vector<std::string>{readyKey});
189
190 // Read count of active calls which may have changed
191 auto activeCallCountData = store.get(activeCallCountKey);
192 int totalCallCount = std::stoi(
193 std::string(activeCallCountData.begin(), activeCallCountData.end()));
194 return totalCallCount;
195}
196
197} // namespace rpc
198} // namespace distributed
199} // namespace torch
200