1 | #pragma once |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/distributed/c10d/Store.hpp> |
5 | #include <torch/csrc/distributed/c10d/Types.hpp> |
6 | |
7 | #include <sys/types.h> |
8 | |
9 | #include <cstdlib> |
10 | #include <string> |
11 | #include <system_error> |
12 | #include <vector> |
13 | |
14 | namespace c10d { |
// Builds the store key "<pgName>_<rank>_trace_start", i.e. the key that
// retrieveDesyncReport() looks up to find the start record of `rank`.
inline std::string getTraceStartKey(const std::string& pgName, int rank) {
  std::string key(pgName);
  key += "_";
  key += std::to_string(rank);
  key += "_trace_start";
  return key;
}
18 | |
// Builds the store key "<pgName>_<rank>_trace_end", i.e. the key that
// retrieveDesyncReport() looks up to find the end record of `rank`.
inline std::string getTraceEndKey(const std::string& pgName, int rank) {
  std::string key(pgName);
  key += "_";
  key += std::to_string(rank);
  key += "_trace_end";
  return key;
}
22 | |
23 | inline bool traceUpdate( |
24 | c10::intrusive_ptr<Store>& store, |
25 | const std::string& key, |
26 | uint64_t seq, |
27 | const std::string& col) { |
28 | std::vector<uint8_t> value(col.size() + sizeof(seq) + 1); |
29 | memcpy(value.data(), &seq, sizeof(seq)); |
30 | memcpy(value.data() + sizeof(seq), col.data(), col.size()); |
31 | try { |
32 | store->set(key, value); |
33 | return true; |
34 | } catch (...) { |
35 | LOG(ERROR) << "Store is down while updating #" << seq << " with key " |
36 | << key; |
37 | return false; |
38 | } |
39 | return true; |
40 | } |
41 | |
// Which kind of trace record is the latest one known for a rank.
enum TraceDebugEvent {
  kEventStart = 0, // latest record is a "start" record
  kEventEnd = 1, // latest record is an "end" record
};

// Trace snapshot, ordered by sequence number:
//   seq -> (rank -> (collective name, start/end event))
using TraceMap = std::map<
    uint64_t,
    std::map<int, std::pair<std::string, TraceDebugEvent>>>;
49 | |
// Renders `ranks` as a ", "-separated list in input order, e.g. {0, 2, 5}
// becomes "0, 2, 5". An empty vector yields an empty string.
inline std::string ranksToString(const std::vector<int>& ranks) {
  std::string joined;
  const char* separator = "";
  for (const int rank : ranks) {
    joined += separator;
    joined += std::to_string(rank);
    // Every element after the first one is preceded by the separator.
    separator = ", ";
  }
  return joined;
}
61 | |
// Renders the first member (the rank) of every pair in `items` as a
// ", "-separated list in input order. The string member of each pair is
// ignored. An empty vector yields an empty string.
inline std::string ranksFromTrace(
    const std::vector<std::pair<int, std::string>>& items) {
  std::string joined;
  for (auto it = items.begin(); it != items.end(); ++it) {
    if (it != items.begin()) {
      joined.append(", ");
    }
    joined.append(std::to_string(it->first));
  }
  return joined;
}
74 | |
75 | inline std::string analyzeMissingRanks(const std::vector<int>& missingRanks) { |
76 | return c10::str( |
77 | "\n\t - To our best knowledge, ranks [" , |
78 | ranksToString(missingRanks), |
79 | "] are the lagging ranks that caused this timeout. " |
80 | "They never joined any collectives" ); |
81 | } |
82 | |
83 | inline std::string analyzeLaggingRanks(const TraceMap& traceMap) { |
84 | uint64_t lagSeq = traceMap.begin()->first; |
85 | std::vector<int> startRanks; |
86 | std::vector<int> endRanks; |
87 | for (auto& p : traceMap.begin()->second) { |
88 | if (p.second.second == kEventStart) { |
89 | startRanks.push_back(p.first); |
90 | } else { |
91 | endRanks.push_back(p.first); |
92 | } |
93 | } |
94 | std::string report = |
95 | "\n\t - To our best knowledge, the lagging/dead/mismatched ranks " |
96 | "that caused the desync are:" ; |
97 | if (startRanks.size()) { |
98 | report += c10::str( |
99 | "\n\t - [" , |
100 | ranksToString(startRanks), |
101 | "] joined but didn't finish collective #" , |
102 | lagSeq, |
103 | " (count from 1)" ); |
104 | } |
105 | if (endRanks.size()) { |
106 | report += c10::str( |
107 | "\n\t [" , |
108 | ranksToString(endRanks), |
109 | "] finished collective #" , |
110 | lagSeq, |
111 | ", but didn't join collective #" , |
112 | lagSeq + 1, |
113 | " (count from 1)" ); |
114 | } |
115 | return report; |
116 | } |
117 | |
118 | inline std::string dumpSnapshot(TraceMap& traceMap) { |
119 | std::string report = "\n\t - Snapshot of ranks' latest states:" ; |
120 | for (auto& tracePair : traceMap) { |
121 | uint64_t seq = tracePair.first; |
122 | std::map<int, std::pair<std::string, TraceDebugEvent>>& subMap = |
123 | tracePair.second; |
124 | |
125 | std::unordered_map<std::string, std::vector<int>> collectivesStart; |
126 | std::unordered_map<std::string, std::vector<int>> collectivesEnd; |
127 | for (auto& p : subMap) { |
128 | int rank = p.first; |
129 | const std::string& col = p.second.first; |
130 | if (p.second.second == kEventStart) { |
131 | collectivesStart[col].push_back(rank); |
132 | } else { |
133 | collectivesEnd[col].push_back(rank); |
134 | } |
135 | } |
136 | |
137 | if (collectivesStart.size()) { |
138 | report += c10::str("\n\t #" , seq, " started ranks:" ); |
139 | for (auto& mapPair : collectivesStart) { |
140 | report += c10::str( |
141 | "\n\t [" , |
142 | ranksToString(mapPair.second), |
143 | "] started " , |
144 | mapPair.first); |
145 | } |
146 | } |
147 | if (collectivesEnd.size()) { |
148 | report += c10::str("\n\t #" , seq, " finished ranks:" ); |
149 | for (auto& mapPair : collectivesEnd) { |
150 | report += c10::str( |
151 | "\n\t [" , |
152 | ranksToString(mapPair.second), |
153 | "] finished " , |
154 | mapPair.first); |
155 | } |
156 | } |
157 | } |
158 | return report; |
159 | } |
160 | |
161 | inline bool parseTraceValue( |
162 | c10::intrusive_ptr<Store>& store, |
163 | const std::string& key, |
164 | uint64_t& seq, |
165 | std::string& col) { |
166 | try { |
167 | std::vector<uint8_t> traceValue = store->get(key); |
168 | memcpy(&seq, traceValue.data(), sizeof(seq)); |
169 | std::string colName((char*)traceValue.data() + sizeof(seq)); |
170 | col = colName; |
171 | return true; |
172 | } catch (...) { |
173 | LOG(ERROR) << "Store is down while getting key " << key; |
174 | return false; |
175 | } |
176 | return true; |
177 | } |
178 | |
179 | inline std::string retrieveDesyncReport( |
180 | c10::intrusive_ptr<Store>& store, |
181 | const std::string& pgName, |
182 | int myRank, |
183 | int worldSize) { |
184 | std::string report; |
185 | |
186 | uint64_t thisSeq; |
187 | std::string thisCol; |
188 | |
189 | std::vector<int> missingRanks; |
190 | TraceMap traceMap; |
191 | |
192 | for (const auto rank : c10::irange(worldSize)) { |
193 | // Build traceMapStart. |
194 | uint64_t seqStart; |
195 | { |
196 | std::string traceKeyStart = getTraceStartKey(pgName, rank); |
197 | if (!store->check({traceKeyStart})) { |
198 | missingRanks.push_back(rank); |
199 | continue; |
200 | } |
201 | std::string col; |
202 | if (!parseTraceValue(store, traceKeyStart, seqStart, col)) { |
203 | return report; |
204 | } |
205 | traceMap[seqStart].emplace(rank, std::make_pair(col, kEventStart)); |
206 | if (rank == myRank) { |
207 | thisSeq = seqStart; |
208 | thisCol = std::move(col); |
209 | } |
210 | } |
211 | |
212 | // Build traceMapEnd. |
213 | { |
214 | std::string traceKeyEnd = getTraceEndKey(pgName, rank); |
215 | if (!store->check({traceKeyEnd})) { |
216 | continue; |
217 | } |
218 | uint64_t seq; |
219 | std::string col; |
220 | if (!parseTraceValue(store, traceKeyEnd, seq, col)) { |
221 | return report; |
222 | } |
223 | if (seq == seqStart) { |
224 | traceMap[seq][rank].second = kEventEnd; |
225 | } |
226 | } |
227 | } |
228 | |
229 | TORCH_INTERNAL_ASSERT( |
230 | !missingRanks.empty() || !traceMap.empty(), |
231 | "Trace shouldn't be empty while enabled GLOO_ASYNC_TIMEOUT_DEBUG" ); |
232 | TORCH_INTERNAL_ASSERT( |
233 | !thisCol.empty(), |
234 | "Timeout rank [" , |
235 | myRank, |
236 | "] must have collective tracking iteam in c10::Store trace" ); |
237 | TORCH_INTERNAL_ASSERT( |
238 | traceMap[thisSeq][myRank].second == kEventStart, |
239 | "Timeout rank [" , |
240 | myRank, |
241 | "] last trace item must be kEventStart. thisSeq = " , |
242 | thisSeq, |
243 | ", col = " , |
244 | thisCol); |
245 | |
246 | report += c10::str( |
247 | "\n\t - [" , myRank, "] Timeout at collective: " , thisCol, ", #" , thisSeq); |
248 | |
249 | if (!missingRanks.empty()) { |
250 | report += analyzeMissingRanks(missingRanks); |
251 | } else { |
252 | report += analyzeLaggingRanks(traceMap); |
253 | report += dumpSnapshot(traceMap); |
254 | } |
255 | |
256 | return report; |
257 | } |
258 | |
259 | } // namespace c10d |
260 | |