1 | #include <fmt/format.h> |
2 | #include <torch/csrc/distributed/rpc/agent_utils.h> |
3 | |
4 | namespace torch { |
5 | namespace distributed { |
6 | namespace rpc { |
7 | |
8 | std::unordered_map<std::string, worker_id_t> collectNames( |
9 | ::c10d::PrefixStore store, |
10 | const worker_id_t selfId, |
11 | const std::string& selfName, |
12 | const int worldSize) { |
13 | std::vector<uint8_t> selfNameVector( |
14 | (uint8_t*)selfName.c_str(), |
15 | (uint8_t*)selfName.c_str() + selfName.length()); |
16 | store.set(c10::to_string(selfId), selfNameVector); |
17 | |
18 | std::unordered_map<std::string, worker_id_t> nameToId; |
19 | nameToId.reserve(worldSize); |
20 | nameToId.emplace(selfName, selfId); |
21 | for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) { |
22 | if (workerId == selfId) { |
23 | continue; |
24 | } |
25 | std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId)); |
26 | std::string workerName( |
27 | (char*)workerNameVector.data(), workerNameVector.size()); |
28 | |
29 | TORCH_CHECK( |
30 | nameToId.find(workerName) == nameToId.end(), |
31 | "RPC worker name " , |
32 | workerName, |
33 | " is not unique. Workers " , |
34 | nameToId.find(workerName)->second, |
35 | " and " , |
36 | workerId, |
37 | " share the same name." ); |
38 | |
39 | nameToId.emplace(workerName, workerId); |
40 | } |
41 | return nameToId; |
42 | } |
43 | |
// Splits `s` on every occurrence of `delim` and returns the pieces in order.
//
// Adjacent delimiters and a leading or trailing delimiter produce empty
// tokens. The result always has at least one element: it is `{s}` when
// `delim` does not occur.
//
// An empty `delim` also returns `{s}`. Searching for "" matches at every
// position without advancing, so the loop below would otherwise never end.
std::vector<std::string> splitString(
    const std::string& s,
    const std::string& delim) {
  if (delim.empty()) {
    return {s};
  }

  std::vector<std::string> tokens;
  size_t start = 0;
  size_t end = s.find(delim, start);
  while (end != std::string::npos) {
    tokens.emplace_back(s.substr(start, end - start));
    start = end + delim.length();
    end = s.find(delim, start);
  }
  // Whatever follows the last delimiter (possibly empty) is the final token.
  tokens.emplace_back(s.substr(start));
  return tokens;
}
59 | |
// Store key under which the comma-terminated "name-id," list of current
// workers is kept (written by collectCurrentNames, edited by removeCurrentName).
const std::string allWorkerInfosKey{"_ALL_WORKER_INFOS"};
61 | |
62 | std::unordered_map<std::string, worker_id_t> collectCurrentNames( |
63 | ::c10d::PrefixStore store, |
64 | const worker_id_t selfId, |
65 | const std::string& selfName) { |
66 | std::vector<uint8_t> selfNameVector( |
67 | (uint8_t*)selfName.c_str(), |
68 | (uint8_t*)selfName.c_str() + selfName.length()); |
69 | |
70 | // Check that ID does not already exist and set {ID : NAME} |
71 | std::vector<uint8_t> resultVector = store.compareSet( |
72 | c10::to_string(selfId), std::vector<uint8_t>(), selfNameVector); |
73 | TORCH_CHECK( |
74 | resultVector == selfNameVector, |
75 | "RPC worker id " , |
76 | selfId, |
77 | " is not unique. Worker " , |
78 | resultVector, |
79 | " and already has ID and " , |
80 | selfNameVector, |
81 | " cannot be added." ); |
82 | |
83 | store.set(c10::to_string(selfId), selfNameVector); |
84 | |
85 | std::unordered_map<std::string, worker_id_t> nameToId; |
86 | nameToId.emplace(selfName, selfId); |
87 | |
88 | // Check to see if there is list of worker names in the store |
89 | bool worker_names_available = |
90 | store.check(std::vector<std::string>{allWorkerInfosKey}); |
91 | std::string allWorkerInfos; |
92 | if (worker_names_available) { |
93 | // Get the current list of workers |
94 | std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey); |
95 | allWorkerInfos = std::string( |
96 | (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size()); |
97 | // workerInfos are comma separated with a comma at the end (e.g. |
98 | // "Name1-Rank1,Name2-Rank2,Name3-Rank2,") parse list of workers. |
99 | if (!allWorkerInfos.empty()) { |
100 | for (const std::string& workerInfoString : splitString( |
101 | allWorkerInfos.substr(0, allWorkerInfos.size() - 1), "," )) { |
102 | auto workerInfoVec = splitString(workerInfoString, "-" ); |
103 | std::string workerName = workerInfoVec.at(0); |
104 | int workerId = std::stoi(workerInfoVec.at(1)); |
105 | |
106 | TORCH_CHECK( |
107 | nameToId.find(workerName) == nameToId.end(), |
108 | "RPC worker name " , |
109 | workerName, |
110 | " is not unique. Workers " , |
111 | nameToId.find(workerName)->second, |
112 | " and " , |
113 | workerId, |
114 | " share the same name." ); |
115 | |
116 | nameToId.emplace(workerName, workerId); |
117 | } |
118 | } |
119 | } |
120 | // Add own name to worker list |
121 | allWorkerInfos = fmt::format("{}{}-{}," , allWorkerInfos, selfName, selfId); |
122 | std::vector<uint8_t> allWorkerInfosVector( |
123 | (uint8_t*)allWorkerInfos.c_str(), |
124 | (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length()); |
125 | store.set(allWorkerInfosKey, allWorkerInfosVector); |
126 | |
127 | return nameToId; |
128 | } |
129 | |
130 | void removeCurrentName( |
131 | ::c10d::PrefixStore store, |
132 | const worker_id_t selfId, |
133 | const std::string& selfName) { |
134 | // Get current list of names/ranks |
135 | std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey); |
136 | std::string allWorkerInfos = std::string( |
137 | (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size()); |
138 | |
139 | // Remove the current name and rank |
140 | std::string str_to_erase = fmt::format("{}-{}," , selfName, selfId); |
141 | int start_position_to_erase = allWorkerInfos.find(str_to_erase); |
142 | allWorkerInfos.erase(start_position_to_erase, str_to_erase.length()); |
143 | |
144 | // Set the new data |
145 | std::vector<uint8_t> newAllWorkerInfosVector( |
146 | (uint8_t*)allWorkerInfos.c_str(), |
147 | (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length()); |
148 | store.set(allWorkerInfosKey, newAllWorkerInfosVector); |
149 | } |
150 | |
// Building blocks for the per-barrier store keys that getNextKeyIds() formats
// as "<prefix>_ID_<n>". Spelled std::string (like allWorkerInfosKey) so that
// the file does not depend on a using-declaration from elsewhere.
const std::string storeKeyBarrierId = "_ID_";
const std::string storeKeyProcessCount = "PROCESS_COUNT";
const std::string storeKeyActiveCallCount = "ACTIVE_CALLS";
const std::string storeKeyReady = "READY";
// Process-local counter, incremented once per getNextKeyIds() call so that
// each barrier round gets distinct keys.
static std::atomic<int> barrierId(0);
156 | |
157 | std::tuple<std::string, std::string, std::string> getNextKeyIds() { |
158 | barrierId++; |
159 | std::string processCountKey = |
160 | fmt::format("{}{}{}" , storeKeyProcessCount, storeKeyBarrierId, barrierId); |
161 | std::string activeCallCountKey = fmt::format( |
162 | "{}{}{}" , storeKeyActiveCallCount, storeKeyBarrierId, barrierId); |
163 | std::string barrierKey = |
164 | fmt::format("{}{}{}" , storeKeyReady, storeKeyBarrierId, barrierId); |
165 | return std::make_tuple(processCountKey, activeCallCountKey, barrierKey); |
166 | } |
167 | |
168 | // Synchronize process with all other agent processes strictly using store |
169 | // Block until all ``RpcAgent``s reach this method. |
170 | // Returns total number of active calls of all RPC agents in the group |
171 | int syncCallCount( |
172 | ::c10d::PrefixStore store, |
173 | const int worldSize, |
174 | int activeCalls) { |
175 | std::string processCountKey, activeCallCountKey, readyKey; |
176 | std::tie(processCountKey, activeCallCountKey, readyKey) = getNextKeyIds(); |
177 | |
178 | // Add to keys which will record the number of processes and active calls |
179 | store.add(activeCallCountKey, activeCalls); |
180 | int totalProcessCount = store.add(processCountKey, 1); |
181 | |
182 | // The last worker will need to set the ready key |
183 | if (totalProcessCount == worldSize) { |
184 | store.set(readyKey, std::vector<uint8_t>()); |
185 | } |
186 | |
187 | // Wait on the ready key to be set |
188 | store.wait(std::vector<std::string>{readyKey}); |
189 | |
190 | // Read count of active calls which may have changed |
191 | auto activeCallCountData = store.get(activeCallCountKey); |
192 | int totalCallCount = std::stoi( |
193 | std::string(activeCallCountData.begin(), activeCallCountData.end())); |
194 | return totalCallCount; |
195 | } |
196 | |
197 | } // namespace rpc |
198 | } // namespace distributed |
199 | } // namespace torch |
200 | |