1#include <c10/util/DeadlockDetection.h>
2#include <torch/csrc/distributed/rpc/rpc_agent.h>
3
4namespace torch {
5namespace distributed {
6namespace rpc {
7
// Registers the WorkerInfo custom class under the "dist_rpc" namespace,
// exposing a constructor that takes (std::string, int64_t).
RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
  // WorkerInfo needs to be registered exactly once. Since the op registration
  // happens in libtorch_python we wrap the class registration in a helper to
  // make sure that if there are multiple copies of Python, such as used in
  // torch::deploy, we only ever register it once.
  //
  // The registration object is a function-local static, so its initializer
  // runs only the first time this constructor is executed, no matter how many
  // RegisterWorkerInfoOnce instances are created afterwards.
  static auto workerInfo = torch::class_<WorkerInfo>("dist_rpc", "WorkerInfo")
                               .def(torch::init<std::string, int64_t>());
}
16
// Out-of-class definition of the static constexpr member. Before C++17 this is
// required whenever the constant is ODR-used (e.g. bound to a reference, as
// happens when it is forwarded to the error-message formatting below).
constexpr size_t WorkerInfo::MAX_NAME_LEN;
18
19WorkerInfo::WorkerInfo(std::string name, int64_t id)
20 : WorkerInfo(std::move(name), (worker_id_t)id) {
21 TORCH_CHECK(
22 id <= std::numeric_limits<worker_id_t>::max(),
23 "RPC worker id ",
24 id,
25 " out of bound of int16_t.");
26}
27
28WorkerInfo::WorkerInfo(std::string name, worker_id_t id)
29 : name_(std::move(name)), id_(id) {
30 bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
31 bool validChar =
32 std::find_if(name_.begin(), name_.end(), [](char c) {
33 return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
34 }) == name_.end();
35 TORCH_CHECK(
36 validSize && validChar,
37 "Worker name must match ^[A-Za-z0-9-_:]*$, "
38 "and must be non-empty and shorter than ",
39 MAX_NAME_LEN,
40 " chars, "
41 "but got ",
42 name_);
43}
44
// Large time duration used as the deadline when waiting on the condition
// variable until the retry map is populated. Cannot use
// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known
// overflow-related bug.
constexpr auto kLargeTimeDuration = std::chrono::hours(10000);
50
// Constructs an agent that is not yet running.
//
// workerId   - identity of this worker; moved into workerInfo_.
// cb         - request callback; ownership is transferred to the agent.
// rpcTimeout - stored as the agent's RPC timeout.
//
// GIL profiling starts disabled and the running flag starts false; start()
// is what flips the running flag.
RpcAgent::RpcAgent(
    WorkerInfo workerId,
    std::unique_ptr<RequestCallback> cb,
    std::chrono::milliseconds rpcTimeout)
    : workerInfo_(std::move(workerId)),
      cb_(std::move(cb)),
      rpcTimeout_(rpcTimeout),
      profilingEnabled_(false),
      rpcAgentRunning_(false) {}
60
61RpcAgent::~RpcAgent() {
62 if (rpcAgentRunning_.load()) {
63 shutdown();
64 }
65}
66
67void RpcAgent::start() {
68 rpcAgentRunning_.store(true);
69 rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
70 startImpl();
71}
72
73void RpcAgent::shutdown() {
74 TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
75 std::unique_lock<std::mutex> lock(rpcRetryMutex_);
76 rpcAgentRunning_.store(false);
77 lock.unlock();
78 rpcRetryMapCV_.notify_one();
79 if (rpcRetryThread_.joinable()) {
80 rpcRetryThread_.join();
81 }
82 // NOLINTNEXTLINE(clang-analyzer-cplusplus.PureVirtualCall)
83 shutdownImpl();
84}
85
86c10::intrusive_ptr<JitFuture> RpcAgent::sendWithRetries(
87 const WorkerInfo& to,
88 c10::intrusive_ptr<Message> message,
89 RpcRetryOptions retryOptions) {
90 TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative.");
91 TORCH_CHECK(
92 retryOptions.retryBackoff >= 1,
93 "maxRetries cannot be exponentially decaying.");
94 TORCH_CHECK(
95 retryOptions.rpcRetryDuration.count() >= 0,
96 "rpcRetryDuration cannot be negative.");
97
98 auto originalFuture =
99 c10::make_intrusive<JitFuture>(at::AnyClassType::get(), getDevices());
100 steady_clock_time_point newTime =
101 computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
102 auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
103 to,
104 message,
105 originalFuture,
106 /* retryCount */ 0,
107 retryOptions);
108 auto jitFuture = send(to, std::move(message));
109 jitFuture->addCallback([this, newTime, firstRetryRpc](JitFuture& future) {
110 rpcRetryCallback(future, newTime, firstRetryRpc);
111 });
112
113 return originalFuture;
114}
115
116void RpcAgent::retryExpiredRpcs() {
117 // Stores the retried futures so callbacks can be added outside the lock.
118 std::vector<
119 std::pair<c10::intrusive_ptr<JitFuture>, std::shared_ptr<RpcRetryInfo>>>
120 futures;
121 // Stores futures and exception messages for non-retriable error-ed futures.
122 std::vector<std::pair<c10::intrusive_ptr<JitFuture>, std::string>>
123 errorFutures;
124
125 while (rpcAgentRunning_.load()) {
126 std::unique_lock<std::mutex> lock(rpcRetryMutex_);
127
128 // We must continue sleeping as long as the RPC Agent is running and when
129 // either the Retry Map is empty, or when the Retry Map's earliest expiring
130 // RPC is set to be retried in the future.
131 steady_clock_time_point earliestTimeout =
132 std::chrono::steady_clock::now() + kLargeTimeDuration;
133
134 for (;;) {
135 if (!rpcAgentRunning_.load())
136 return;
137 if (std::chrono::steady_clock::now() >= earliestTimeout)
138 break;
139 if (!rpcRetryMap_.empty()) {
140 earliestTimeout = rpcRetryMap_.begin()->first;
141 }
142 rpcRetryMapCV_.wait_until(lock, earliestTimeout);
143 }
144
145 // Updating these since something may have been added to the map while this
146 // thread was sleeping.
147 earliestTimeout = rpcRetryMap_.begin()->first;
148 auto& earliestRpcList = rpcRetryMap_.begin()->second;
149
150 // We iterate through all the RPC's set to be retried at the current
151 // timepoint, resend those RPC's, and add the RPC's and their futures to
152 // a list to later attach callbacks. These callbacks either schedule
153 // the RPC for a future retry or marks it with success/error depending on
154 // the outcome of the current send. Then, we clean up the rpcRetryMap_.
155 for (auto it = earliestRpcList.begin(); it != earliestRpcList.end();
156 /* no increment */) {
157 auto& earliestRpc = *it;
158 c10::intrusive_ptr<JitFuture> jitFuture;
159
160 // send() will throw an exception if an RPC is retried while the agent is
161 // shutdown. We must catch this exception and mark the original future
162 // with an error, since this RPC never succeeded and can no longer be
163 // retried.
164 try {
165 jitFuture = send(earliestRpc->to_, earliestRpc->message_);
166 futures.emplace_back(jitFuture, earliestRpc);
167 } catch (std::exception& e) {
168 // We must store the futures and exception messages here and only mark
169 // the futures with an error after releasing the lock.
170 errorFutures.emplace_back(earliestRpc->originalFuture_, e.what());
171 }
172
173 // A callback will be attached to all futures for the retries in this
174 // list. Thus they will either be rescheduled for future retries or they
175 // will be marked as complete. We can safely delete them from the retry
176 // Map for the current timepoint.
177 it = earliestRpcList.erase(it);
178 }
179
180 // If there are no more RPC's set to be retried at the current timepoint,
181 // we can remove the corresponsing unordered_set from the retry map.
182 if (earliestRpcList.empty()) {
183 rpcRetryMap_.erase(earliestTimeout);
184 }
185
186 lock.unlock();
187 // We attach callbacks to the futures outside of the lock to prevent
188 // potential deadlocks.
189 for (const auto& it : futures) {
190 auto jitFuture = it.first;
191 auto earliestRpc = it.second;
192 steady_clock_time_point newTime = computeNewRpcRetryTime(
193 earliestRpc->options_, earliestRpc->retryCount_);
194 earliestRpc->retryCount_++;
195
196 jitFuture->addCallback([this, newTime, earliestRpc](JitFuture& future) {
197 rpcRetryCallback(future, newTime, earliestRpc);
198 });
199 }
200 futures.clear();
201
202 // For exceptions caught while retrying RPC's above, we set those futures
203 // with errors now that we have released the lock.
204 for (const auto& it : errorFutures) {
205 auto errorFuture = it.first;
206 auto errorMsg = it.second;
207 errorFuture->setError(
208 std::make_exception_ptr(std::runtime_error(errorMsg)));
209 }
210 errorFutures.clear();
211 }
212}
213
214void RpcAgent::rpcRetryCallback(
215 JitFuture& jitFuture,
216 steady_clock_time_point newTime,
217 std::shared_ptr<RpcRetryInfo> earliestRpc) {
218 if (jitFuture.hasError()) {
219 // Adding one since we want to include the original send as well and not
220 // just the retry count.
221 LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed";
222 if (!rpcAgentRunning_.load()) {
223 // If the RPC Agent has shutdown, we cannot retry messages. Thus we mark
224 // the future with an error since the RPC was never completed
225 // successfully.
226 std::string errorMessage = c10::str(
227 "RPC Agent is no longer running on Node ",
228 RpcAgent::getWorkerInfo().id_,
229 ". Cannot retry message.");
230 earliestRpc->originalFuture_->setError(jitFuture.exception_ptr());
231 } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) {
232 // If the previous future completed with an error and we haven't
233 // completed maxRetries send attempts, we move the earliestRpc
234 // struct to a new time point in the retry map (effectively
235 // scheduling it for a future retry.)
236 {
237 std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
238 rpcRetryMap_[newTime].emplace(std::move(earliestRpc));
239 }
240 // The retry thread waits for the map to be populated. Thus we notify
241 // once an item has been added.
242 rpcRetryMapCV_.notify_one();
243 } else {
244 // We have completed maxRetries send attempts. We're now marking
245 // the future with an error.
246 std::string errorMessage = c10::str(
247 "The RPC has not succeeded after the specified number of max retries (",
248 earliestRpc->options_.maxRetries,
249 ").");
250 earliestRpc->originalFuture_->setError(
251 std::make_exception_ptr(std::runtime_error(errorMessage)));
252 }
253 } else {
254 // This try succeeded, so we can make the original future as complete.
255 earliestRpc->originalFuture_->markCompleted(
256 jitFuture.value(), jitFuture.storages());
257 }
258}
259
// Returns a reference to this agent's own WorkerInfo (the one supplied at
// construction). The reference is to a member of the agent.
const WorkerInfo& RpcAgent::getWorkerInfo() const {
  return workerInfo_;
}
263
// Storage for the process-wide "current" agent; starts unset. It is read and
// written through the std::atomic_* free functions for shared_ptr in the
// accessors below.
std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;
265
266bool RpcAgent::isCurrentRpcAgentSet() {
267 return std::atomic_load(&currentRpcAgent_) != nullptr;
268}
269
270std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
271 std::shared_ptr<RpcAgent> agent = std::atomic_load(&currentRpcAgent_);
272 TORCH_CHECK(
273 agent,
274 "Current RPC agent is not set! Did you initialize the RPC "
275 "framework (e.g. by calling `rpc.init_rpc`)?");
276 return agent;
277}
278
279void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
280 if (rpcAgent) {
281 std::shared_ptr<RpcAgent> previousAgent;
282 // Use compare_exchange so that we don't actually perform the exchange if
283 // that would trigger the assert just below. See:
284 // https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
285 std::atomic_compare_exchange_strong(
286 &currentRpcAgent_, &previousAgent, std::move(rpcAgent));
287 TORCH_INTERNAL_ASSERT(
288 previousAgent == nullptr, "Current RPC agent is set!");
289 } else {
290 // We can't use compare_exchange (we don't know what value to expect) but we
291 // don't need to, as the only case that would trigger the assert is if we
292 // replaced nullptr with nullptr, which we can just do as it has no effect.
293 std::shared_ptr<RpcAgent> previousAgent =
294 std::atomic_exchange(&currentRpcAgent_, std::move(rpcAgent));
295 TORCH_INTERNAL_ASSERT(
296 previousAgent != nullptr, "Current RPC agent is not set!");
297 }
298}
299
// Stores the given type resolver on the agent, replacing any previous one.
void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) {
  typeResolver_ = std::move(typeResolver);
}
303
// Returns the type resolver previously stored via setTypeResolver().
// Asserts (TORCH_INTERNAL_ASSERT) that a resolver has been set.
std::shared_ptr<TypeResolver> RpcAgent::getTypeResolver() {
  TORCH_INTERNAL_ASSERT(typeResolver_, "Type resolver is not set!");
  return typeResolver_;
}
308
309void RpcAgent::enableGILProfiling(bool flag) {
310 profilingEnabled_ = flag;
311}
312
// Returns the current value of the GIL profiling flag, as last set by
// enableGILProfiling() (false if never set).
bool RpcAgent::isGILProfilingEnabled() {
  return profilingEnabled_.load();
}
316
317DeviceMap RpcAgent::getDeviceMap(const WorkerInfo& /* unused */) const {
318 // Default implementation has no device map.
319 return {};
320}
321
322const std::vector<c10::Device>& RpcAgent::getDevices() const {
323 // By default the agent is CPU-only.
324 static const std::vector<c10::Device> noDevices = {};
325 return noDevices;
326}
327
// Returns debug information about the agent as string key/value pairs.
std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() {
  /* This may later include more info than just metrics, e.g. stack traces
     for the threads owned by the agent. */
  // Default implementation: return getMetrics().
  return getMetrics();
}
334
335std::ostream& operator<<(std::ostream& os, const WorkerInfo& workerInfo) {
336 return os << "WorkerInfo(id=" << workerInfo.id_
337 << ", name=" << workerInfo.name_ << ")";
338}
339
340} // namespace rpc
341} // namespace distributed
342} // namespace torch
343