1#pragma once
2
3#include <c10/util/irange.h>
4#include <torch/csrc/distributed/c10d/Store.hpp>
5#include <torch/csrc/distributed/c10d/Types.hpp>
6
7#include <sys/types.h>
8
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <map>
#include <string>
#include <system_error>
#include <unordered_map>
#include <utility>
#include <vector>
13
14namespace c10d {
// Builds the store key under which `rank` of process group `pgName`
// publishes its most recently started collective:
// "<pgName>_<rank>_trace_start".
inline std::string getTraceStartKey(const std::string& pgName, int rank) {
  std::string key(pgName);
  key += '_';
  key += std::to_string(rank);
  key += "_trace_start";
  return key;
}
18
// Builds the store key under which `rank` of process group `pgName`
// publishes its most recently finished collective:
// "<pgName>_<rank>_trace_end".
inline std::string getTraceEndKey(const std::string& pgName, int rank) {
  std::string key(pgName);
  key += '_';
  key += std::to_string(rank);
  key += "_trace_end";
  return key;
}
22
23inline bool traceUpdate(
24 c10::intrusive_ptr<Store>& store,
25 const std::string& key,
26 uint64_t seq,
27 const std::string& col) {
28 std::vector<uint8_t> value(col.size() + sizeof(seq) + 1);
29 memcpy(value.data(), &seq, sizeof(seq));
30 memcpy(value.data() + sizeof(seq), col.data(), col.size());
31 try {
32 store->set(key, value);
33 return true;
34 } catch (...) {
35 LOG(ERROR) << "Store is down while updating #" << seq << " with key "
36 << key;
37 return false;
38 }
39 return true;
40}
41
// Latest trace event recorded for a rank at a given collective sequence
// number: it either started the collective or finished it.
enum TraceDebugEvent {
  kEventStart = 0, // recorded from the rank's "_trace_start" key
  kEventEnd = 1, // recorded from the rank's "_trace_end" key
};
// <seq, <rank, <col, start/end>>>
// Outer key: collective sequence number. Inner key: rank. Value: the
// collective name and whether the rank's latest event at that sequence
// number is a start or an end. Both levels are std::map, so iteration is
// ordered by ascending seq and then ascending rank.
using TraceMap =
    std::map<uint64_t, std::map<int, std::pair<std::string, TraceDebugEvent>>>;
49
// Formats ranks as a comma-and-space separated list, e.g. {0, 2} -> "0, 2".
// An empty input yields an empty string.
inline std::string ranksToString(const std::vector<int>& ranks) {
  std::string joined;
  const char* separator = "";
  for (const int r : ranks) {
    joined += separator;
    joined += std::to_string(r);
    separator = ", ";
  }
  return joined;
}
61
// Formats the rank (the `first` member) of each (rank, name) pair as a
// comma-and-space separated list. The names are ignored. An empty input
// yields an empty string.
inline std::string ranksFromTrace(
    const std::vector<std::pair<int, std::string>>& items) {
  std::string out;
  for (std::size_t idx = 0; idx < items.size(); ++idx) {
    if (idx != 0) {
      out += ", ";
    }
    out += std::to_string(items[idx].first);
  }
  return out;
}
74
75inline std::string analyzeMissingRanks(const std::vector<int>& missingRanks) {
76 return c10::str(
77 "\n\t - To our best knowledge, ranks [",
78 ranksToString(missingRanks),
79 "] are the lagging ranks that caused this timeout. "
80 "They never joined any collectives");
81}
82
83inline std::string analyzeLaggingRanks(const TraceMap& traceMap) {
84 uint64_t lagSeq = traceMap.begin()->first;
85 std::vector<int> startRanks;
86 std::vector<int> endRanks;
87 for (auto& p : traceMap.begin()->second) {
88 if (p.second.second == kEventStart) {
89 startRanks.push_back(p.first);
90 } else {
91 endRanks.push_back(p.first);
92 }
93 }
94 std::string report =
95 "\n\t - To our best knowledge, the lagging/dead/mismatched ranks "
96 "that caused the desync are:";
97 if (startRanks.size()) {
98 report += c10::str(
99 "\n\t - [",
100 ranksToString(startRanks),
101 "] joined but didn't finish collective #",
102 lagSeq,
103 " (count from 1)");
104 }
105 if (endRanks.size()) {
106 report += c10::str(
107 "\n\t [",
108 ranksToString(endRanks),
109 "] finished collective #",
110 lagSeq,
111 ", but didn't join collective #",
112 lagSeq + 1,
113 " (count from 1)");
114 }
115 return report;
116}
117
118inline std::string dumpSnapshot(TraceMap& traceMap) {
119 std::string report = "\n\t - Snapshot of ranks' latest states:";
120 for (auto& tracePair : traceMap) {
121 uint64_t seq = tracePair.first;
122 std::map<int, std::pair<std::string, TraceDebugEvent>>& subMap =
123 tracePair.second;
124
125 std::unordered_map<std::string, std::vector<int>> collectivesStart;
126 std::unordered_map<std::string, std::vector<int>> collectivesEnd;
127 for (auto& p : subMap) {
128 int rank = p.first;
129 const std::string& col = p.second.first;
130 if (p.second.second == kEventStart) {
131 collectivesStart[col].push_back(rank);
132 } else {
133 collectivesEnd[col].push_back(rank);
134 }
135 }
136
137 if (collectivesStart.size()) {
138 report += c10::str("\n\t #", seq, " started ranks:");
139 for (auto& mapPair : collectivesStart) {
140 report += c10::str(
141 "\n\t [",
142 ranksToString(mapPair.second),
143 "] started ",
144 mapPair.first);
145 }
146 }
147 if (collectivesEnd.size()) {
148 report += c10::str("\n\t #", seq, " finished ranks:");
149 for (auto& mapPair : collectivesEnd) {
150 report += c10::str(
151 "\n\t [",
152 ranksToString(mapPair.second),
153 "] finished ",
154 mapPair.first);
155 }
156 }
157 }
158 return report;
159}
160
161inline bool parseTraceValue(
162 c10::intrusive_ptr<Store>& store,
163 const std::string& key,
164 uint64_t& seq,
165 std::string& col) {
166 try {
167 std::vector<uint8_t> traceValue = store->get(key);
168 memcpy(&seq, traceValue.data(), sizeof(seq));
169 std::string colName((char*)traceValue.data() + sizeof(seq));
170 col = colName;
171 return true;
172 } catch (...) {
173 LOG(ERROR) << "Store is down while getting key " << key;
174 return false;
175 }
176 return true;
177}
178
179inline std::string retrieveDesyncReport(
180 c10::intrusive_ptr<Store>& store,
181 const std::string& pgName,
182 int myRank,
183 int worldSize) {
184 std::string report;
185
186 uint64_t thisSeq;
187 std::string thisCol;
188
189 std::vector<int> missingRanks;
190 TraceMap traceMap;
191
192 for (const auto rank : c10::irange(worldSize)) {
193 // Build traceMapStart.
194 uint64_t seqStart;
195 {
196 std::string traceKeyStart = getTraceStartKey(pgName, rank);
197 if (!store->check({traceKeyStart})) {
198 missingRanks.push_back(rank);
199 continue;
200 }
201 std::string col;
202 if (!parseTraceValue(store, traceKeyStart, seqStart, col)) {
203 return report;
204 }
205 traceMap[seqStart].emplace(rank, std::make_pair(col, kEventStart));
206 if (rank == myRank) {
207 thisSeq = seqStart;
208 thisCol = std::move(col);
209 }
210 }
211
212 // Build traceMapEnd.
213 {
214 std::string traceKeyEnd = getTraceEndKey(pgName, rank);
215 if (!store->check({traceKeyEnd})) {
216 continue;
217 }
218 uint64_t seq;
219 std::string col;
220 if (!parseTraceValue(store, traceKeyEnd, seq, col)) {
221 return report;
222 }
223 if (seq == seqStart) {
224 traceMap[seq][rank].second = kEventEnd;
225 }
226 }
227 }
228
229 TORCH_INTERNAL_ASSERT(
230 !missingRanks.empty() || !traceMap.empty(),
231 "Trace shouldn't be empty while enabled GLOO_ASYNC_TIMEOUT_DEBUG");
232 TORCH_INTERNAL_ASSERT(
233 !thisCol.empty(),
234 "Timeout rank [",
235 myRank,
236 "] must have collective tracking iteam in c10::Store trace");
237 TORCH_INTERNAL_ASSERT(
238 traceMap[thisSeq][myRank].second == kEventStart,
239 "Timeout rank [",
240 myRank,
241 "] last trace item must be kEventStart. thisSeq = ",
242 thisSeq,
243 ", col = ",
244 thisCol);
245
246 report += c10::str(
247 "\n\t - [", myRank, "] Timeout at collective: ", thisCol, ", #", thisSeq);
248
249 if (!missingRanks.empty()) {
250 report += analyzeMissingRanks(missingRanks);
251 } else {
252 report += analyzeLaggingRanks(traceMap);
253 report += dumpSnapshot(traceMap);
254 }
255
256 return report;
257}
258
259} // namespace c10d
260