1 | #include <c10/util/DeadlockDetection.h> |
2 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
3 | |
4 | namespace torch { |
5 | namespace distributed { |
6 | namespace rpc { |
7 | |
8 | RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() { |
9 | // WorkerInfo needs to be registered exactly once. Since the op registration |
10 | // happens in libtorch_python we wrap the class registration in a helper to |
11 | // make sure that if there's multiple copies of Python such as used in |
12 | // torch::deploy we only ever register it once. |
13 | static auto workerInfo = torch::class_<WorkerInfo>("dist_rpc" , "WorkerInfo" ) |
14 | .def(torch::init<std::string, int64_t>()); |
15 | } |
16 | |
// Out-of-class definition of the static constexpr data member. Before C++17 a
// definition like this is required whenever the constant is ODR-used (for
// example when it is bound to a reference); from C++17 on such members are
// implicitly inline and this line is redundant but harmless.
constexpr size_t WorkerInfo::MAX_NAME_LEN;
18 | |
19 | WorkerInfo::WorkerInfo(std::string name, int64_t id) |
20 | : WorkerInfo(std::move(name), (worker_id_t)id) { |
21 | TORCH_CHECK( |
22 | id <= std::numeric_limits<worker_id_t>::max(), |
23 | "RPC worker id " , |
24 | id, |
25 | " out of bound of int16_t." ); |
26 | } |
27 | |
28 | WorkerInfo::WorkerInfo(std::string name, worker_id_t id) |
29 | : name_(std::move(name)), id_(id) { |
30 | bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0; |
31 | bool validChar = |
32 | std::find_if(name_.begin(), name_.end(), [](char c) { |
33 | return !(std::isalnum(c) || c == '-' || c == '_' || c == ':'); |
34 | }) == name_.end(); |
35 | TORCH_CHECK( |
36 | validSize && validChar, |
37 | "Worker name must match ^[A-Za-z0-9-_:]*$, " |
38 | "and must be non-empty and shorter than " , |
39 | MAX_NAME_LEN, |
40 | " chars, " |
41 | "but got " , |
42 | name_); |
43 | } |
44 | |
// Large time duration used to build the fallback deadline when waiting on the
// condition variable until the retry map is populated. Cannot use
// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known
// overflow-related bug.
constexpr auto kLargeTimeDuration = std::chrono::hours(10000);
50 | |
51 | RpcAgent::RpcAgent( |
52 | WorkerInfo workerId, |
53 | std::unique_ptr<RequestCallback> cb, |
54 | std::chrono::milliseconds rpcTimeout) |
55 | : workerInfo_(std::move(workerId)), |
56 | cb_(std::move(cb)), |
57 | rpcTimeout_(rpcTimeout), |
58 | profilingEnabled_(false), |
59 | rpcAgentRunning_(false) {} |
60 | |
61 | RpcAgent::~RpcAgent() { |
62 | if (rpcAgentRunning_.load()) { |
63 | shutdown(); |
64 | } |
65 | } |
66 | |
67 | void RpcAgent::start() { |
68 | rpcAgentRunning_.store(true); |
69 | rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this); |
70 | startImpl(); |
71 | } |
72 | |
73 | void RpcAgent::shutdown() { |
74 | TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP(); |
75 | std::unique_lock<std::mutex> lock(rpcRetryMutex_); |
76 | rpcAgentRunning_.store(false); |
77 | lock.unlock(); |
78 | rpcRetryMapCV_.notify_one(); |
79 | if (rpcRetryThread_.joinable()) { |
80 | rpcRetryThread_.join(); |
81 | } |
82 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.PureVirtualCall) |
83 | shutdownImpl(); |
84 | } |
85 | |
86 | c10::intrusive_ptr<JitFuture> RpcAgent::sendWithRetries( |
87 | const WorkerInfo& to, |
88 | c10::intrusive_ptr<Message> message, |
89 | RpcRetryOptions retryOptions) { |
90 | TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative." ); |
91 | TORCH_CHECK( |
92 | retryOptions.retryBackoff >= 1, |
93 | "maxRetries cannot be exponentially decaying." ); |
94 | TORCH_CHECK( |
95 | retryOptions.rpcRetryDuration.count() >= 0, |
96 | "rpcRetryDuration cannot be negative." ); |
97 | |
98 | auto originalFuture = |
99 | c10::make_intrusive<JitFuture>(at::AnyClassType::get(), getDevices()); |
100 | steady_clock_time_point newTime = |
101 | computeNewRpcRetryTime(retryOptions, /* retryCount */ 0); |
102 | auto firstRetryRpc = std::make_shared<RpcRetryInfo>( |
103 | to, |
104 | message, |
105 | originalFuture, |
106 | /* retryCount */ 0, |
107 | retryOptions); |
108 | auto jitFuture = send(to, std::move(message)); |
109 | jitFuture->addCallback([this, newTime, firstRetryRpc](JitFuture& future) { |
110 | rpcRetryCallback(future, newTime, firstRetryRpc); |
111 | }); |
112 | |
113 | return originalFuture; |
114 | } |
115 | |
116 | void RpcAgent::retryExpiredRpcs() { |
117 | // Stores the retried futures so callbacks can be added outside the lock. |
118 | std::vector< |
119 | std::pair<c10::intrusive_ptr<JitFuture>, std::shared_ptr<RpcRetryInfo>>> |
120 | futures; |
121 | // Stores futures and exception messages for non-retriable error-ed futures. |
122 | std::vector<std::pair<c10::intrusive_ptr<JitFuture>, std::string>> |
123 | errorFutures; |
124 | |
125 | while (rpcAgentRunning_.load()) { |
126 | std::unique_lock<std::mutex> lock(rpcRetryMutex_); |
127 | |
128 | // We must continue sleeping as long as the RPC Agent is running and when |
129 | // either the Retry Map is empty, or when the Retry Map's earliest expiring |
130 | // RPC is set to be retried in the future. |
131 | steady_clock_time_point earliestTimeout = |
132 | std::chrono::steady_clock::now() + kLargeTimeDuration; |
133 | |
134 | for (;;) { |
135 | if (!rpcAgentRunning_.load()) |
136 | return; |
137 | if (std::chrono::steady_clock::now() >= earliestTimeout) |
138 | break; |
139 | if (!rpcRetryMap_.empty()) { |
140 | earliestTimeout = rpcRetryMap_.begin()->first; |
141 | } |
142 | rpcRetryMapCV_.wait_until(lock, earliestTimeout); |
143 | } |
144 | |
145 | // Updating these since something may have been added to the map while this |
146 | // thread was sleeping. |
147 | earliestTimeout = rpcRetryMap_.begin()->first; |
148 | auto& earliestRpcList = rpcRetryMap_.begin()->second; |
149 | |
150 | // We iterate through all the RPC's set to be retried at the current |
151 | // timepoint, resend those RPC's, and add the RPC's and their futures to |
152 | // a list to later attach callbacks. These callbacks either schedule |
153 | // the RPC for a future retry or marks it with success/error depending on |
154 | // the outcome of the current send. Then, we clean up the rpcRetryMap_. |
155 | for (auto it = earliestRpcList.begin(); it != earliestRpcList.end(); |
156 | /* no increment */) { |
157 | auto& earliestRpc = *it; |
158 | c10::intrusive_ptr<JitFuture> jitFuture; |
159 | |
160 | // send() will throw an exception if an RPC is retried while the agent is |
161 | // shutdown. We must catch this exception and mark the original future |
162 | // with an error, since this RPC never succeeded and can no longer be |
163 | // retried. |
164 | try { |
165 | jitFuture = send(earliestRpc->to_, earliestRpc->message_); |
166 | futures.emplace_back(jitFuture, earliestRpc); |
167 | } catch (std::exception& e) { |
168 | // We must store the futures and exception messages here and only mark |
169 | // the futures with an error after releasing the lock. |
170 | errorFutures.emplace_back(earliestRpc->originalFuture_, e.what()); |
171 | } |
172 | |
173 | // A callback will be attached to all futures for the retries in this |
174 | // list. Thus they will either be rescheduled for future retries or they |
175 | // will be marked as complete. We can safely delete them from the retry |
176 | // Map for the current timepoint. |
177 | it = earliestRpcList.erase(it); |
178 | } |
179 | |
180 | // If there are no more RPC's set to be retried at the current timepoint, |
181 | // we can remove the corresponsing unordered_set from the retry map. |
182 | if (earliestRpcList.empty()) { |
183 | rpcRetryMap_.erase(earliestTimeout); |
184 | } |
185 | |
186 | lock.unlock(); |
187 | // We attach callbacks to the futures outside of the lock to prevent |
188 | // potential deadlocks. |
189 | for (const auto& it : futures) { |
190 | auto jitFuture = it.first; |
191 | auto earliestRpc = it.second; |
192 | steady_clock_time_point newTime = computeNewRpcRetryTime( |
193 | earliestRpc->options_, earliestRpc->retryCount_); |
194 | earliestRpc->retryCount_++; |
195 | |
196 | jitFuture->addCallback([this, newTime, earliestRpc](JitFuture& future) { |
197 | rpcRetryCallback(future, newTime, earliestRpc); |
198 | }); |
199 | } |
200 | futures.clear(); |
201 | |
202 | // For exceptions caught while retrying RPC's above, we set those futures |
203 | // with errors now that we have released the lock. |
204 | for (const auto& it : errorFutures) { |
205 | auto errorFuture = it.first; |
206 | auto errorMsg = it.second; |
207 | errorFuture->setError( |
208 | std::make_exception_ptr(std::runtime_error(errorMsg))); |
209 | } |
210 | errorFutures.clear(); |
211 | } |
212 | } |
213 | |
214 | void RpcAgent::rpcRetryCallback( |
215 | JitFuture& jitFuture, |
216 | steady_clock_time_point newTime, |
217 | std::shared_ptr<RpcRetryInfo> earliestRpc) { |
218 | if (jitFuture.hasError()) { |
219 | // Adding one since we want to include the original send as well and not |
220 | // just the retry count. |
221 | LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed" ; |
222 | if (!rpcAgentRunning_.load()) { |
223 | // If the RPC Agent has shutdown, we cannot retry messages. Thus we mark |
224 | // the future with an error since the RPC was never completed |
225 | // successfully. |
226 | std::string errorMessage = c10::str( |
227 | "RPC Agent is no longer running on Node " , |
228 | RpcAgent::getWorkerInfo().id_, |
229 | ". Cannot retry message." ); |
230 | earliestRpc->originalFuture_->setError(jitFuture.exception_ptr()); |
231 | } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) { |
232 | // If the previous future completed with an error and we haven't |
233 | // completed maxRetries send attempts, we move the earliestRpc |
234 | // struct to a new time point in the retry map (effectively |
235 | // scheduling it for a future retry.) |
236 | { |
237 | std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_); |
238 | rpcRetryMap_[newTime].emplace(std::move(earliestRpc)); |
239 | } |
240 | // The retry thread waits for the map to be populated. Thus we notify |
241 | // once an item has been added. |
242 | rpcRetryMapCV_.notify_one(); |
243 | } else { |
244 | // We have completed maxRetries send attempts. We're now marking |
245 | // the future with an error. |
246 | std::string errorMessage = c10::str( |
247 | "The RPC has not succeeded after the specified number of max retries (" , |
248 | earliestRpc->options_.maxRetries, |
249 | ")." ); |
250 | earliestRpc->originalFuture_->setError( |
251 | std::make_exception_ptr(std::runtime_error(errorMessage))); |
252 | } |
253 | } else { |
254 | // This try succeeded, so we can make the original future as complete. |
255 | earliestRpc->originalFuture_->markCompleted( |
256 | jitFuture.value(), jitFuture.storages()); |
257 | } |
258 | } |
259 | |
260 | const WorkerInfo& RpcAgent::getWorkerInfo() const { |
261 | return workerInfo_; |
262 | } |
263 | |
// Definition of the static member holding the current RPC agent; it starts
// out null. The accessors in this file read and write it through the
// std::atomic_* free functions for shared_ptr.
std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;
265 | |
266 | bool RpcAgent::isCurrentRpcAgentSet() { |
267 | return std::atomic_load(¤tRpcAgent_) != nullptr; |
268 | } |
269 | |
270 | std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() { |
271 | std::shared_ptr<RpcAgent> agent = std::atomic_load(¤tRpcAgent_); |
272 | TORCH_CHECK( |
273 | agent, |
274 | "Current RPC agent is not set! Did you initialize the RPC " |
275 | "framework (e.g. by calling `rpc.init_rpc`)?" ); |
276 | return agent; |
277 | } |
278 | |
279 | void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) { |
280 | if (rpcAgent) { |
281 | std::shared_ptr<RpcAgent> previousAgent; |
282 | // Use compare_exchange so that we don't actually perform the exchange if |
283 | // that would trigger the assert just below. See: |
284 | // https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange |
285 | std::atomic_compare_exchange_strong( |
286 | ¤tRpcAgent_, &previousAgent, std::move(rpcAgent)); |
287 | TORCH_INTERNAL_ASSERT( |
288 | previousAgent == nullptr, "Current RPC agent is set!" ); |
289 | } else { |
290 | // We can't use compare_exchange (we don't know what value to expect) but we |
291 | // don't need to, as the only case that would trigger the assert is if we |
292 | // replaced nullptr with nullptr, which we can just do as it has no effect. |
293 | std::shared_ptr<RpcAgent> previousAgent = |
294 | std::atomic_exchange(¤tRpcAgent_, std::move(rpcAgent)); |
295 | TORCH_INTERNAL_ASSERT( |
296 | previousAgent != nullptr, "Current RPC agent is not set!" ); |
297 | } |
298 | } |
299 | |
300 | void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) { |
301 | typeResolver_ = std::move(typeResolver); |
302 | } |
303 | |
304 | std::shared_ptr<TypeResolver> RpcAgent::getTypeResolver() { |
305 | TORCH_INTERNAL_ASSERT(typeResolver_, "Type resolver is not set!" ); |
306 | return typeResolver_; |
307 | } |
308 | |
309 | void RpcAgent::enableGILProfiling(bool flag) { |
310 | profilingEnabled_ = flag; |
311 | } |
312 | |
313 | bool RpcAgent::isGILProfilingEnabled() { |
314 | return profilingEnabled_.load(); |
315 | } |
316 | |
317 | DeviceMap RpcAgent::getDeviceMap(const WorkerInfo& /* unused */) const { |
318 | // Default implementation has no device map. |
319 | return {}; |
320 | } |
321 | |
322 | const std::vector<c10::Device>& RpcAgent::getDevices() const { |
323 | // By default the agent is CPU-only. |
324 | static const std::vector<c10::Device> noDevices = {}; |
325 | return noDevices; |
326 | } |
327 | |
328 | std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() { |
329 | /* This would later include more info other than metrics for eg: may include |
330 | stack traces for the threads owned by the agent */ |
331 | // Default implementation: return getMetrics(). |
332 | return getMetrics(); |
333 | } |
334 | |
335 | std::ostream& operator<<(std::ostream& os, const WorkerInfo& workerInfo) { |
336 | return os << "WorkerInfo(id=" << workerInfo.id_ |
337 | << ", name=" << workerInfo.name_ << ")" ; |
338 | } |
339 | |
340 | } // namespace rpc |
341 | } // namespace distributed |
342 | } // namespace torch |
343 | |