#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>

#ifdef USE_TENSORPIPE

#include <limits>
#include <tuple>
#include <utility>

#include <fmt/format.h>
#include <tensorpipe/tensorpipe.h>

#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>

#include <c10/core/StreamGuard.h>
#include <c10/util/irange.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

// An environment variable along the lines of GLOO_ and NCCL_SOCKET_IFNAME that
// allows the user to specify a device to bind to, instead of binding to the
// address that the hostname resolves to.
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
const std::string kDefaultUvAddress = "127.0.0.1";

const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";

std::vector<c10::Device> getDevicesForTensors(
    const std::vector<torch::Tensor>& tensors,
    const DeviceMap& deviceMap,
    const std::string& remoteName) {
  // Maps each tensor's device to a target device using the given device map.
  const auto errStr = c10::str(
      "TensorPipe RPC backend only supports CPU tensors by default, please "
      "move your tensors to CPU before sending them over RPC, or call "
      "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
      "configure device mapping. ",
      "Device mapping is not available for destination ",
      remoteName);
  std::vector<c10::Device> devices;
  devices.reserve(tensors.size());
  bool hasMappedDevice = false;
  for (const auto& t : tensors) {
    if (t.device().is_cpu()) {
      const auto deviceIter = deviceMap.find(c10::kCPU);
      if (deviceIter == deviceMap.end()) {
        devices.emplace_back(c10::kCPU);
      } else {
        devices.emplace_back(deviceIter->second);
        hasMappedDevice = true;
      }
    } else {
      const auto deviceIter = deviceMap.find(t.device());
      TORCH_CHECK(
          deviceIter != deviceMap.end(),
          errStr,
          " for device ",
          t.device(),
          " but received a tensor on that device.");
      devices.push_back(deviceIter->second);
      hasMappedDevice = true;
    }
  }
  if (!hasMappedDevice) {
    devices.clear();
  }
  return devices;
}
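
// For example, with a device map {cuda:0 -> cuda:1}, a tensor on cuda:0 is
// targeted at cuda:1 on the remote, while CPU tensors stay on CPU unless the
// map contains an explicit CPU entry. If no tensor was remapped at all, the
// result is cleared: an empty vector means "everything stays on CPU".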

std::vector<c10::Stream> getStreamsFromPoolForDevices(
    const std::vector<c10::Device>& devices) {
  if (devices.empty()) {
    return {};
  }
  c10::impl::VirtualGuardImpl impl(devices[0].type());
  std::vector<c10::Stream> streams;
  streams.reserve(devices.size());
  for (const c10::Device& device : devices) {
    TORCH_INTERNAL_ASSERT(device.type() == impl.type());
    streams.push_back(impl.getStreamFromGlobalPool(device));
  }
  return streams;
}

std::vector<c10::Stream> getCurrentStreamsForDevices(
    const std::vector<c10::Device>& devices) {
  if (devices.empty()) {
    return {};
  }
  c10::impl::VirtualGuardImpl impl(devices[0].type());
  std::vector<c10::Stream> streams;
  streams.reserve(devices.size());
  for (const c10::Device& device : devices) {
    TORCH_INTERNAL_ASSERT(device.type() == impl.type());
    streams.push_back(impl.getStream(device));
  }
  return streams;
}
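
// NOTE: The two helpers above differ in a single call: getStreamsFromPoolForDevices
// draws fresh streams from the global pool, so communication work does not
// serialize against whatever compute is queued on the current streams, while
// getCurrentStreamsForDevices returns each device's current stream, for work
// that must be ordered after ongoing computation.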

std::vector<c10::Device> getDevicesOfTensors(
    const std::vector<torch::Tensor>& tensors) {
  c10::optional<c10::impl::VirtualGuardImpl> impl;
  size_t deviceCount = 0;
  std::vector<bool> indexBitset;
  for (const torch::Tensor& tensor : tensors) {
    if (!tensor.is_cpu()) {
      c10::Device device = tensor.device();
      if (!impl.has_value()) {
        impl.emplace(device.type());
        indexBitset.resize(impl->deviceCount());
      }
      TORCH_INTERNAL_ASSERT(device.type() == impl->type());
      TORCH_INTERNAL_ASSERT(device.has_index());
      if (!indexBitset[device.index()]) {
        deviceCount++;
        indexBitset[device.index()] = true;
      }
    }
  }
  std::vector<c10::Device> devices;
  devices.reserve(deviceCount);
  for (const auto idx : c10::irange(indexBitset.size())) {
    if (indexBitset[idx]) {
      devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx));
    }
  }
  return devices;
}

void makeStreamsWaitOnOthers(
    const std::vector<c10::Stream>& consumers,
    const std::vector<c10::Stream>& producers) {
  for (const c10::Stream& producer : producers) {
    const c10::Stream& consumer =
        getStreamForDevice(consumers, producer.device());
    c10::Event event(producer.device_type());
    event.record(producer);
    event.block(consumer);
  }
}
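
// This is the usual event-based cross-stream ordering: for each producer
// stream an event is recorded and the consumer stream on the same device is
// made to wait on it, so sends launched on pool streams (see send() below)
// cannot start before the tensors produced on the current streams are ready.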

} // namespace

C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeTransportRegistry,
    TransportRegistration);

C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeChannelRegistry,
    ChannelRegistration);

const std::string& TensorPipeAgent::guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return kDefaultUvAddress;
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return kDefaultUvAddress;
      }
    }
    return result;
  }();
  return uvAddress;
}
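
// For example, to bind to a specific network interface (eth0 here is just a
// placeholder name) instead of the address the hostname resolves to:
//
//   $ export TP_SOCKET_IFNAME=eth0
//
// If the lookup fails, the agent falls back to 127.0.0.1, as logged above.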

namespace {

std::unique_ptr<TransportRegistration> makeUvTransport() {
  auto context = tensorpipe::transport::uv::create();
  std::string address = TensorPipeAgent::guessAddress();
  return std::make_unique<TransportRegistration>(TransportRegistration{
      std::move(context), kUvTransportPriority, std::move(address)});
}

// The UV transport is implemented using standard TCP connections. It leverages
// libuv (https://github.com/libuv/libuv) in order to be cross-platform.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);

#if TENSORPIPE_HAS_SHM_TRANSPORT

std::unique_ptr<TransportRegistration> makeShmTransport() {
  auto context = tensorpipe::transport::shm::create();
  return std::make_unique<TransportRegistration>(
      TransportRegistration{std::move(context), kShmTransportPriority, ""});
}

// The SHM transport implements connections using ringbuffers residing in
// anonymous shared memory (plus UNIX domain sockets to bootstrap the
// connection and exchange file descriptors). It is Linux-only due to some
// advanced features (O_TMPFILE, eventfd, ...).
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport);

#endif // TENSORPIPE_HAS_SHM_TRANSPORT

#if TENSORPIPE_HAS_IBV_TRANSPORT

std::unique_ptr<TransportRegistration> makeIbvTransport() {
  auto context = tensorpipe::transport::ibv::create();
  std::string address = TensorPipeAgent::guessAddress();
  return std::make_unique<TransportRegistration>(TransportRegistration{
      std::move(context), kIbvTransportPriority, std::move(address)});
}

// The IBV transport sends data across using an InfiniBand queue pair, locally
// copying data to and from a staging buffer (registered with libibverbs) and
// issuing a RDMA write for transferring data across machines (plus a send for
// acknowledging it). It bootstraps using a standard TCP connection to exchange
// setup information. It is Linux-only.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport);

#endif // TENSORPIPE_HAS_IBV_TRANSPORT

std::unique_ptr<ChannelRegistration> makeBasicChannel() {
  auto context = tensorpipe::channel::basic::create();
  return std::make_unique<ChannelRegistration>(
      ChannelRegistration{std::move(context), kBasicChannelPriority});
}

// The basic channel is just a straightforward adapter wrapper that allows any
// transport to be used as a channel.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel);

#if TENSORPIPE_HAS_CMA_CHANNEL

std::unique_ptr<ChannelRegistration> makeCmaChannel() {
  auto context = tensorpipe::channel::cma::create();
  return std::make_unique<ChannelRegistration>(
      ChannelRegistration{std::move(context), kCmaChannelPriority});
}

// The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
// and _writev), which allow one process to access the private memory of another
// process (as long as they belong to the same user and other security
// constraints are satisfied). It does, more or less, what GDB does when it's
// attached to a running process.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel);

#endif // TENSORPIPE_HAS_CMA_CHANNEL

constexpr static int kNumUvThreads = 16;

std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel() {
  std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
  std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
  for (const auto laneIdx C10_UNUSED : c10::irange(kNumUvThreads)) {
    auto context = tensorpipe::transport::uv::create();
    std::string address = TensorPipeAgent::guessAddress();
    contexts.push_back(std::move(context));
    listeners.push_back(contexts.back()->listen(address));
  }
  auto context = tensorpipe::channel::mpt::create(
      std::move(contexts), std::move(listeners));
  return std::make_unique<ChannelRegistration>(
      ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority});
}

// The multiplexed UV channel encapsulates multiple UV transports (each with its
// own event loop thread). Each channel will, in turn, contain multiple UV
// connections, one for each of those contexts. When sending a tensor, its data
// is split into equal chunks and each chunk is sent on a different connection,
// and thus driven by a different thread. This is needed to reach very high
// bandwidths.
C10_REGISTER_CREATOR(
    TensorPipeChannelRegistry,
    mpt_uv,
    makeMultiplexedUvChannel);
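
// A minimal sketch of how a user could restrict the agent to a subset of
// these transports/channels through TensorPipeRpcBackendOptions (assuming its
// optional `transports` and `channels` fields, which startImpl() below reads):
//
//   TensorPipeRpcBackendOptions opts;
//   opts.transports = std::vector<std::string>{"shm", "uv"}; // shm preferred
//   opts.channels = std::vector<std::string>{"mpt_uv", "basic"};
//
// Entries earlier in each vector receive higher priority; registered keys not
// listed are skipped entirely.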

} // namespace

////////////////////////// MetricsTracker /////////////////////////////////

TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
    uint64_t currentSum,
    uint64_t currentCount)
    : currentSum_(currentSum), currentCount_(currentCount) {}

void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
  currentSum_ += dataPoint;
  ++currentCount_;
}

float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
  return currentCount_ == 0 ? 0 : currentSum_ / (float)currentCount_;
}
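
// For example, after addData(100) and addData(200) the tracker holds
// currentSum_ == 300 and currentCount_ == 2, so computeAverage() returns
// 150.0f; with no data points it returns 0 rather than dividing by zero.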

//////////////////////// TensorPipeAgent /////////////////////////////////

void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId) {
  // Remove entry from timeoutMap_.
  {
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);
    auto it = messageIdToTimeout_.find(messageId);
    if (it == messageIdToTimeout_.end()) {
      // Already removed from the map by pollTimeoutRpcs(), no need to
      // process further.
      return;
    }

    auto& expirationTime = it->second;

    auto& timedOutFuturesVector = timeoutMap_[expirationTime];
    for (auto vecIt = timedOutFuturesVector.begin();
         vecIt != timedOutFuturesVector.end();
         ++vecIt) {
      if (vecIt->messageId == messageId) {
        timedOutFuturesVector.erase(vecIt);
        break;
      }
    }

    if (timedOutFuturesVector.empty()) {
      timeoutMap_.erase(expirationTime);
    }

    // Remove from messageId to timeout map as well.
    messageIdToTimeout_.erase(messageId);
  }
}

void TensorPipeAgent::prepareNames(bool isStaticGroup) {
  std::unordered_map<std::string, worker_id_t> nameToId;
  if (isStaticGroup) {
    nameToId = collectNames(
        rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_);
  } else {
    nameToId = collectCurrentNames(
        rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
  }

  for (const auto& entry : nameToId) {
    const auto& workerName = entry.first;
    const auto& workerId = entry.second;
    workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId));
    workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId));
  }
}

void TensorPipeAgent::checkAndSetStaticGroup(
    const c10::intrusive_ptr<::c10d::Store>& store) {
  std::string isStaticGroupKey("rpcIsStaticGroup");

  std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false";
  std::vector<uint8_t> isStaticGroupVec(
      (uint8_t*)isStaticGroupStr.c_str(),
      (uint8_t*)isStaticGroupStr.c_str() + isStaticGroupStr.length());
  std::vector<uint8_t> returnedVec;
  returnedVec = store->compareSet(
      isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec);
  std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end());
  // In both cases, the returned value should be the value of isStaticGroupStr,
  // otherwise there is a discrepancy in initialization among the members.
  TORCH_CHECK(
      returnedVal == isStaticGroupStr,
      fmt::format(
          "RPC group mixes statically and dynamically initialized members "
          "which is not supported. Static group property is initialized as {} "
          "and is trying to be set as {}",
          isStaticGroup_,
          returnedVal));
}
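
// NOTE: compareSet(key, expected, desired) atomically writes `desired` only if
// the key's current value matches `expected` (an empty expected value matches
// a key that has not been set yet) and returns the resulting value. So the
// first agent to arrive stores its is-static flag, and every later agent reads
// that stored flag back and must agree with it.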

TensorPipeAgent::TensorPipeAgent(
    const c10::intrusive_ptr<::c10d::Store>& store,
    std::string selfName,
    worker_id_t selfId,
    optional<int> worldSize,
    TensorPipeRpcBackendOptions opts,
    std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
    std::vector<c10::Device> devices,
    std::unique_ptr<RequestCallback> cb)
    : RpcAgent(
          WorkerInfo(std::move(selfName), selfId),
          std::move(cb),
          std::chrono::milliseconds(
              (long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
      isStaticGroup_(worldSize.has_value()),
      store_(store),
      opts_(std::move(opts)),
      reverseDeviceMaps_(std::move(reverseDeviceMaps)),
      devices_(std::move(devices)),
      threadPool_(opts_.numWorkerThreads),
      context_(std::make_shared<tensorpipe::Context>(
          tensorpipe::ContextOptions().name(workerInfo_.name_))),
      rankToNameStore_("names", store),
      nameToAddressStore_("addrs", store),
      shutdownStore_("shutdown", store) {
  if (isStaticGroup_) {
    worldSize_ = worldSize.value();
  }

  // Check the static group attribute against the store.
  checkAndSetStaticGroup(store);

  // Collect worker names.
  prepareNames(isStaticGroup_);

  // Initialize the time-series metrics tracking map.
  timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
}

TensorPipeAgent::~TensorPipeAgent() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
  shutdown();
}

void TensorPipeAgent::startImpl() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";

  std::vector<std::string> addresses;
  int lowestPriority = std::numeric_limits<int>::max();
  std::string lowestPriorityTransport;

  // Register transports
  for (auto& key : TensorPipeTransportRegistry()->Keys()) {
    int64_t priority = -1;
    if (opts_.transports.has_value()) {
      auto iter =
          std::find(opts_.transports->begin(), opts_.transports->end(), key);
      if (iter == opts_.transports->end()) {
        continue;
      }
      // Assign priorities in reverse order of occurrence in the vector, so that
      // a transport that comes before another receives a higher priority.
      priority =
          opts_.transports->size() - 1 - (iter - opts_.transports->begin());
    }
    std::unique_ptr<TransportRegistration> reg =
        TensorPipeTransportRegistry()->Create(key);
    if (!reg->transport->isViable()) {
      continue;
    }
    if (priority == -1) {
      priority = reg->priority;
    }
    if (priority < lowestPriority) {
      lowestPriority = priority;
      lowestPriorityTransport = key;
    }
    addresses.push_back(c10::str(key, "://", reg->address));
    context_->registerTransport(
        priority, std::move(key), std::move(reg->transport));
  }
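
  // For example, if opts_.transports is {"shm", "uv"}, the assignment above
  // gives shm priority 1 and uv priority 0, so shm is preferred whenever it is
  // viable; when no explicit list is given, each transport's built-in priority
  // (e.g. kShmTransportPriority) is used instead.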

  // Register channels
  for (auto& key : TensorPipeChannelRegistry()->Keys()) {
    int64_t priority = -1;
    if (opts_.channels.has_value()) {
      auto iter =
          std::find(opts_.channels->begin(), opts_.channels->end(), key);
      if (iter == opts_.channels->end()) {
        continue;
      }
      // Assign priorities in reverse order of occurrence in the vector, so
      // that a channel that comes before another receives a higher priority.
      priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin());
    }
    std::unique_ptr<ChannelRegistration> reg =
        TensorPipeChannelRegistry()->Create(key);
    if (!reg->channel->isViable()) {
      continue;
    }
    if (priority == -1) {
      priority = reg->priority;
    }
    context_->registerChannel(
        priority, std::move(key), std::move(reg->channel));
  }

  listener_ = context_->listen(addresses);

  // Store our own URL.
  const auto address = listener_->url(lowestPriorityTransport);
  nameToAddressStore_.set(workerInfo_.name_, address);

  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
          << address;

  for (const auto& p : workerNameToInfo_) {
    const auto& name = p.first;
    auto nodeAddrData = nameToAddressStore_.get(name);
    auto nodeAddrStr =
        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
    workerNameToURL_.insert({name, nodeAddrStr});
  }

  // Start the timeout thread.
  timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);

  listener_->accept([this](
                        const tensorpipe::Error& error,
                        std::shared_ptr<tensorpipe::Pipe> pipe) {
    onListenerAccepted(error, pipe);
  });
}

void TensorPipeAgent::onListenerAccepted(
    const tensorpipe::Error& error,
    std::shared_ptr<tensorpipe::Pipe>& pipe) {
  if (error) {
    if (error.isOfType<tensorpipe::ListenerClosedError>() &&
        !rpcAgentRunning_.load()) {
      // This is expected.
    } else {
      LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                   << " encountered error when accepting incoming pipe: "
                   << error.what();
    }
    return;
  }

  // Accept the next connection request
  listener_->accept([this](
                        const tensorpipe::Error& error,
                        std::shared_ptr<tensorpipe::Pipe> pipe) {
    onListenerAccepted(error, pipe);
  });

  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " accepted incoming pipe from " << pipe->getRemoteName();

  // Arm for server read
  respond(pipe);
}
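
// NOTE: accept() hands over a single pipe, so the callback above re-arms the
// listener before serving the new pipe, keeping exactly one accept operation
// outstanding at any time.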

void TensorPipeAgent::pipeRead(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    std::function<void(
        const tensorpipe::Error&,
        c10::intrusive_ptr<Message>,
        std::vector<c10::Stream>)> fn) noexcept {
  pipe->readDescriptor([this, fn{std::move(fn)}, pipe](
                           const tensorpipe::Error& error,
                           tensorpipe::Descriptor tpDescriptor) mutable {
    if (error) {
      fn(error, c10::intrusive_ptr<Message>(), {});
      return;
    }

    std::vector<c10::Stream> streams;
    {
      GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
      streams = getStreamsFromPoolForDevices(devices_);
    }
    tensorpipe::Allocation tpAllocation;
    TensorpipeReadBuffers tpBuffers;
    std::tie(tpAllocation, tpBuffers) =
        tensorpipeAllocate(tpDescriptor, streams);

    pipe->read(
        std::move(tpAllocation),
        [tpDescriptor{std::move(tpDescriptor)},
         tpBuffers{
             std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
         fn{std::move(fn)},
         streams{std::move(streams)}](const tensorpipe::Error& error) mutable {
          if (error) {
            fn(error, c10::intrusive_ptr<Message>(), {});
            return;
          }

          // FIXME This does some unpickling, which could be a bit expensive:
          // perhaps it would be best to perform it inside the worker threads?
          c10::intrusive_ptr<Message> rpcMessage = tensorpipeDeserialize(
              std::move(tpDescriptor), std::move(*tpBuffers));

          fn(error, std::move(rpcMessage), std::move(streams));
        });
  });
}

void TensorPipeAgent::pipeWrite(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    c10::intrusive_ptr<Message> rpcMessage,
    std::vector<c10::Device>&& devices,
    std::vector<c10::Stream> streams,
    std::function<void(const tensorpipe::Error&)> fn) noexcept {
  tensorpipe::Message tpMessage;
  TensorpipeWriteBuffers tpBuffers;

  std::tie(tpMessage, tpBuffers) =
      tensorpipeSerialize(std::move(rpcMessage), std::move(devices), streams);

  pipe->write(
      std::move(tpMessage),
      [tpBuffers{
           std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
       fn{std::move(fn)},
       streams{std::move(streams)}](const tensorpipe::Error& error) {
        fn(error);
      });
}

void TensorPipeAgent::sendCompletedResponseMessage(
    std::shared_ptr<tensorpipe::Pipe>& pipe,
    JitFuture& futureResponseMessage,
    uint64_t messageId,
    std::vector<c10::Stream> streams) {
  if (!rpcAgentRunning_.load()) {
    LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                 << " won't send response to request #" << messageId << " to "
                 << pipe->getRemoteName() << ", as the agent is shutting down";
    return;
  }

  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " is sending response to request #" << messageId << " to "
          << pipe->getRemoteName();

  if (!futureResponseMessage.hasError()) {
    c10::intrusive_ptr<Message> responseMessage =
        futureResponseMessage.value().toCustomClass<Message>();
    responseMessage->setId(messageId);

    std::vector<c10::Device> devices;
    try {
      devices = getDevicesForRemote(pipe->getRemoteName(), *responseMessage);
    } catch (const std::exception& e) {
      responseMessage = createExceptionResponse(e.what(), messageId);
    }

    for (const auto& tensor : responseMessage->tensors()) {
      const auto device = tensor.device();
      if (!device.is_cpu()) {
        GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
        if (std::find(devices_.begin(), devices_.end(), device) ==
            devices_.end()) {
          std::ostringstream oss;
          std::copy(
              devices_.begin(),
              devices_.end(),
              std::ostream_iterator<c10::Device>(oss, ", "));
          responseMessage = createExceptionResponse(
              c10::str(
                  "RPC detected a user-function output tensor on device ",
                  device,
                  ". This device is not one of the input tensor devices: ",
                  oss.str(),
                  "which is not yet supported. Please file a feature request "
                  "issue in the PyTorch GitHub repo."),
              messageId);
          break;
        }
      }
    }

    pipeWrite(
        pipe,
        std::move(responseMessage),
        std::move(devices),
        std::move(streams),
        [this, pipe, messageId](const tensorpipe::Error& error) {
          if (error) {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when sending response to request #"
                << messageId << " to " << pipe->getRemoteName() << ": "
                << error.what();
            return;
          }

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done sending response to request #" << messageId
                  << " to " << pipe->getRemoteName();
        });
  } else {
    pipeWrite(
        pipe,
        createExceptionResponse(
            futureResponseMessage.tryRetrieveErrorMessage(), messageId),
        /* devices */ {},
        std::move(streams),
        [this, pipe, messageId](const tensorpipe::Error& error) {
          if (error) {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when sending response to request #"
                << messageId << " to " << pipe->getRemoteName() << ": "
                << error.what();
            return;
          }

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done sending response to request #" << messageId
                  << " to " << pipe->getRemoteName();
        });
  }
}

void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          c10::intrusive_ptr<Message> requestMessage,
          std::vector<c10::Stream> streams) mutable {
        if (error) {
          if (shuttingDown_) {
            // This is expected.
          } else {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when reading incoming request from "
                << pipe->getRemoteName() << ": " << error.what();
          }
          return;
        }

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage->id();
        increaseCallCount(serverActiveCalls_);

        VLOG(1) << "RPC agent for " << workerInfo_.name_
                << " received request #" << messageId << " from "
                << pipe->getRemoteName();

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         streams{std::move(streams)}]() mutable {
          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " is running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
            // Instead of creating a MultiStreamGuard here, the streams are
            // passed to the callback and the MultiStreamGuard is created
            // there, because subsequent processing can switch threads due to
            // 1) waiting for RRef arguments to become ready 2)
            // async_execution. Besides, the streams also need to be
            // propagated to `process***Call` methods to synchronize CUDA
            // streams there, to make sure that we fetch the correct value
            // from the `to_here()` call.
            futureResponseMessage =
                cb_->operator()(*requestMessage, std::move(streams));
          } catch (const std::exception& /* unused */) {
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          increaseCallCount(serverActiveAsyncCalls_);
          futureResponseMessage->addCallback(
              [this, pipe, messageId](
                  JitFuture& futureResponseMessage) mutable {
                decreaseCallCount(serverActiveCalls_);
                decreaseCallCount(serverActiveAsyncCalls_);
                auto streams = getCurrentStreamsForDevices(
                    futureResponseMessage.devices());
                sendCompletedResponseMessage(
                    pipe, futureResponseMessage, messageId, std::move(streams));
              });

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";
        });
      });
}

c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
    const WorkerInfo& toWorkerInfo,
    c10::intrusive_ptr<Message> requestMessage,
    const float rpcTimeoutSeconds,
    const DeviceMap& deviceMap) {
  TORCH_CHECK(
      requestMessage->isRequest(),
      "TensorPipeAgent::send(..) is only for sending requests.");

  if (!rpcAgentRunning_.load()) {
    auto err = c10::str(
        "Node ",
        RpcAgent::getWorkerInfo().id_,
        " tried to send() a message of type ",
        requestMessage->type(),
        " but RPC is no longer running on this node.");
    TORCH_CHECK(false, err);
  }

  const auto& url = findWorkerURL(toWorkerInfo);

  decltype(connectedPipes_)::iterator it;
  {
    std::unique_lock<std::mutex> lock(connectedPipesMutex_);

    // See if we already have a connection to this address or not
    it = connectedPipes_.find(toWorkerInfo.id_);
    if (it == connectedPipes_.end()) {
      // An instance of ClientPipe cannot be copied or moved as it contains a
      // mutex, so we use piecewise construction to force in-place
      // construction and work around an issue in GCC 5.
      std::tie(it, std::ignore) = connectedPipes_.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(toWorkerInfo.id_),
          std::forward_as_tuple(context_->connect(
              url, tensorpipe::PipeOptions().remoteName(toWorkerInfo.name_))));
    }
  }
  ClientPipe& clientPipe = it->second;

  std::shared_ptr<AtomicJitFuture> futureResponseMessage;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
  }
  uint64_t messageId = nextMessageID_++;
  requestMessage->setId(messageId);

  {
    std::unique_lock<std::mutex> lock(clientPipe.mutex_);
    clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage;
  }

  // Get devices for tensors in the request message. This can throw if device
  // maps are not configured properly for this request.
  std::vector<c10::Device> devices;
  if (deviceMap.empty()) {
    devices =
        getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage);
  } else {
    // If deviceMap is specified, use that instead.
    devices = getDevicesForTensors(
        requestMessage->tensors(),
        deviceMap,
        clientPipe.pipe_->getRemoteName());
  }

  futureResponseMessage->jitFuture->addCallback(
      [this](JitFuture& /* unused */) {
        TORCH_INTERNAL_ASSERT(
            this->threadPool_.inThreadPool(),
            "Future marked complete from outside the thread pool");
      });

  increaseCallCount(clientActiveCalls_);
  // Use the default RPC timeout if no timeout is specified for this send call
  auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
      ? getRpcTimeout()
      : std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
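  // For example (assuming kSecToMsConversion == 1000), an rpcTimeoutSeconds of
  // 2.5f yields a 2500 ms timeout, while kUnsetRpcTimeout falls back to the
  // agent-wide default returned by getRpcTimeout().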

  // We only add to the timeoutMap_ if the timeout is not 0. Per our
  // documentation, a user-provided timeout of 0 indicates the RPC should never
  // expire (infinite timeout), so there is no need to track it in the
  // timeoutMap_.
  steady_clock_time_point expirationTime;
  if (timeout.count() != 0) {
    // Compute the expiration time for this message based on the timeout
    expirationTime = computeRpcMessageExpiryTime(timeout);

    // Add the Future to the right vector in the timeoutMap_
    {
      std::unique_lock<std::mutex> lock(timeoutMapMutex_);
      auto& timeoutFuturesVector = timeoutMap_[expirationTime];
      messageIdToTimeout_.emplace(messageId, expirationTime);
      timeoutFuturesVector.emplace_back(
          messageId, futureResponseMessage, timeout);
    }
    timeoutThreadCV_.notify_one();
  }

  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
          << messageId << " to " << clientPipe.pipe_->getRemoteName();

  std::vector<c10::Stream> streams;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    streams = getStreamsFromPoolForDevices(devices_);
  }
  makeStreamsWaitOnOthers(
      streams,
      getCurrentStreamsForDevices(
          getDevicesOfTensors(requestMessage->tensors())));
  pipeWrite(
      clientPipe.pipe_,
      std::move(requestMessage),
      std::move(devices),
      std::move(streams),
      [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
        if (error) {
          if (error.isOfType<tensorpipe::PipeClosedError>() &&
              !rpcAgentRunning_.load()) {
            // This is expected.
          } else {
            LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                         << " encountered error when sending outgoing request #"
                         << messageId << " to "
                         << clientPipe.pipe_->getRemoteName() << ": "
                         << error.what();
          }
          handleClientError(clientPipe, error);
          return;
        }

        VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #"
                << messageId << " to " << clientPipe.pipe_->getRemoteName();

        pipeRead(
            clientPipe.pipe_,
            [this, &clientPipe](
                const tensorpipe::Error& error,
                c10::intrusive_ptr<Message> responseMessage,
                std::vector<c10::Stream> streams) {
              if (error) {
                if (error.isOfType<tensorpipe::PipeClosedError>() &&
                    !rpcAgentRunning_.load()) {
                  // This is expected.
                } else {
                  LOG(WARNING)
                      << "RPC agent for " << workerInfo_.name_
                      << " encountered error when reading incoming response from "
                      << clientPipe.pipe_->getRemoteName() << ": "
                      << error.what();
                }
                handleClientError(clientPipe, error);
                return;
              }

              // Identify future response message by message ID
              uint64_t messageId = responseMessage->id();

              VLOG(1) << "RPC agent for " << workerInfo_.name_
                      << " received response #" << messageId << " from "
                      << clientPipe.pipe_->getRemoteName();

              std::shared_ptr<AtomicJitFuture> futureResponseMessage;
              {
                std::lock_guard<std::mutex> lock(clientPipe.mutex_);
                // A read error will lead all following callbacks to be
                // invoked with error, and shouldn't reach here.
                TORCH_INTERNAL_ASSERT(
                    !clientPipe.inError_, "Shouldn't be in error state");
                auto it = clientPipe.pendingResponseMessage_.find(messageId);
                TORCH_INTERNAL_ASSERT(
                    it != clientPipe.pendingResponseMessage_.end(),
                    "message ID ",
                    messageId,
                    " is not recognized");
                futureResponseMessage = std::move(it->second);
                clientPipe.pendingResponseMessage_.erase(it);
              }

              // Remove entry from timeoutMap_.
              removeFromTimeoutMap(messageId);

              if (responseMessage->type() == MessageType::EXCEPTION) {
                markFutureWithError(
                    std::move(futureResponseMessage),
                    std::string(
                        responseMessage->payload().begin(),
                        responseMessage->payload().end()));
              } else {
                markFutureAsComplete(
                    std::move(futureResponseMessage),
                    std::move(responseMessage),
                    std::move(streams));
              }
            });
      });

  return futureResponseMessage->jitFuture;
}

void TensorPipeAgent::handleClientError(
    ClientPipe& clientPipe,
    const tensorpipe::Error& error) {
  // When an error occurs on a pipe all pending operations will be aborted and
  // all callbacks invoked with error, hence we immediately flush all future
  // messages belonging to this pipe.
  decltype(clientPipe.pendingResponseMessage_) pendingMsgs;
  {
    std::lock_guard<std::mutex> lock(clientPipe.mutex_);
    std::swap(clientPipe.pendingResponseMessage_, pendingMsgs);
    clientPipe.inError_ = true;
  }
  std::string errorMsg = error.what();
  for (auto& p : pendingMsgs) {
    markFutureWithError(std::move(p.second), errorMsg);

    // Remove entry from timeoutMap_.
    removeFromTimeoutMap(p.first);
  }
}

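// timeoutMap_ is an ordered std::map keyed by expiration time, so begin()
// always yields the earliest-expiring batch of RPCs. The polling loop below
// relies on that ordering: it sleeps until the earliest expiration and then
// drains exactly one entry per iteration.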
void TensorPipeAgent::pollTimeoutRpcs() {
  while (rpcAgentRunning_.load()) {
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);

    // We sleep until the earliest expiring RPC in the timeoutMap_. We must
    // also ensure that we sleep while the map is empty, and we exit sleeping
    // if the RPC agent has been shut down.
    for (;;) {
      if (!rpcAgentRunning_.load()) {
        return;
      }

      if (!timeoutMap_.empty()) {
        steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
        if (std::chrono::steady_clock::now() >= earliestTimeout) {
          break;
        }
        timeoutThreadCV_.wait_until(lock, earliestTimeout);
      } else {
        timeoutThreadCV_.wait(lock);
      }
    }

    // Move all these futures to a separate vector so we can process them
    // outside the lock.
    std::vector<TimeoutMessageMetadata> timedOutFutures =
        std::move(timeoutMap_.begin()->second);

    // We can safely remove this key from the timeoutMap_ since all these
    // futures will be processed.
    timeoutMap_.erase(timeoutMap_.begin());

    for (auto& timeoutMetadata : timedOutFutures) {
      // Remove from messageIdToTimeout map.
      messageIdToTimeout_.erase(timeoutMetadata.messageId);
    }
    lock.unlock();

    // Set an error on futures added to the timedOutFutures vector. We do this
    // outside the lock to prevent potential lock-order-inversions by callbacks
    // triggered by the setError call.
    for (auto& timeoutMetadata : timedOutFutures) {
      std::string errorMsg =
          fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
      auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
      markFutureWithError(
          std::move(timeoutMetadata.responseFuture), std::move(err));
    }
  }
}

void TensorPipeAgent::leaveGroup() {
  std::unique_lock<std::mutex> lock(callCountMutex_);
  // Wait until this worker's active client calls drain to zero; once we shut
  // down, any future calls will be dropped.
  callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

  // Remove this agent's WorkerInfo from store
  removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);

  // Set internal variable to be used during destructor
  shuttingDown_ = true;
}

// TODO: Remove join()
void TensorPipeAgent::join(bool shutdown, float /* unused */) {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
  if (!isStaticGroup_) {
    leaveGroup();
    return;
  }

  // This method behaves like a barrier, as it can only return once all workers
  // have no more requests pending, including "nested" requests (triggered from
  // within the remote code of another call) and "follow-up" requests (triggered
  // from the callback of a future).
  while (true) {
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // It is enough to wait for there to be no more active client calls, since
      // each server call corresponds to a client call for some other worker.
      callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

      // We'd like to immediately proceed with the allreduce, but it's a call
      // that may block for some time, as it waits for other workers to also
      // complete all their active client calls. While we call allreduce we must
      // hold the mutex, or else the count we send to other workers may get
      // stale (e.g., if some nested call happens in the meantime). But we can't
      // hold the lock for an indeterminately long time, as that would block
      // other operations (e.g., send). Thus we must release the lock and only
      // re-acquire it when all workers are ready to proceed with the allreduce.
      // We perform this synchronization using a barrier.
    }
    VLOG(1) << "RPC agent for " << workerInfo_.name_
            << " completed all client calls and is entering a barrier";
    syncCallCount(shutdownStore_, worldSize_);
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // At this point, the count may have become non-zero again. We can't wait
      // for those calls to complete as other workers are waiting for us in the
      // allreduce and we would block them. Thus we send our count even if it is
      // non-zero and if anyone (be it us or another worker) has a non-zero
      // count we'll just do another round.
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " exited the barrier and found " << clientActiveCalls_
              << " active client calls";
      int totalClientActiveCalls =
          syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " completed sync call counts and got a total of "
              << totalClientActiveCalls
              << " active client calls across all workers";
      if (totalClientActiveCalls == 0) {
        if (shutdown) {
          shuttingDown_ = true;
          syncCallCount(shutdownStore_, worldSize_);
        }
        break;
      }
    }
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
}

void TensorPipeAgent::shutdownImpl() {
  // FIXME Isn't it too verbose for a library to print logs in normal operation?
  LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";

  // Join the timeout thread.
  timeoutThreadCV_.notify_one();
  if (timeoutThread_.joinable()) {
    timeoutThread_.join();
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for timeout thread to join";

  // This will close all the pipes and listeners, invoke all callbacks with
  // errors, turn down the I/O event loops and wait for everything to terminate.
  context_->join();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for TensorPipe context to join";

  // NOTE: We need to call waitWorkComplete in the end after we have shutdown
  // all listeners for TensorPipe. This is to drain any already accepted work
  // in the thread pool. If this is done before we shutdown the listeners,
  // additional work could be added after this call and before we shutdown
  // listeners. This work would continue executing in the thread pool and might
  // cause issues during shutdown of the system.
  threadPool_.waitWorkComplete();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for thread pool to complete work";
}

const WorkerInfo& TensorPipeAgent::getWorkerInfo(
    const std::string& workerName) const {
  std::unordered_map<std::string, WorkerInfo>::const_iterator it;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    it = workerNameToInfo_.find(workerName);
  }
  TORCH_CHECK(
      it != workerNameToInfo_.end(),
      fmt::format(
          "name:{},rank:{} could not find destination name {}",
          workerInfo_.name_,
          workerInfo_.id_,
          workerName));
  return it->second;
}

const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const {
  std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    it = workerIdToInfo_.find(workerId);
  }
  TORCH_CHECK(
      it != workerIdToInfo_.end(),
      fmt::format(
          "name:{},rank:{} could not find destination id {}",
          workerInfo_.name_,
          workerInfo_.id_,
          workerId));
  return it->second;
}

std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {
  std::vector<WorkerInfo> workerInfos;
  workerInfos.reserve(workerNameToInfo_.size());
  for (auto& item : workerNameToInfo_) {
    workerInfos.emplace_back(item.second);
  }
  return workerInfos;
}

const std::string& TensorPipeAgent::findWorkerURL(
    const WorkerInfo& worker) const {
  std::unordered_map<std::string, std::string>::const_iterator it;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    it = workerNameToURL_.find(worker.name_);
  }
  TORCH_CHECK(
      it != workerNameToURL_.end(),
      fmt::format(
          "name:{},rank:{} could not find destination url for name {}",
          workerInfo_.name_,
          workerInfo_.id_,
          worker.name_));
  return it->second;
}

void TensorPipeAgent::updateGroupMembership(
    const WorkerInfo& workerInfo,
    const std::vector<c10::Device> devices,
    const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
    bool isJoin) {
  std::string name = workerInfo.name_;
  worker_id_t id = workerInfo.id_;
  // Rank with workerInfo is joining the group, update internal mappings
  if (isJoin) {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    workerIdToInfo_.emplace(id, workerInfo);
    workerNameToInfo_.emplace(name, workerInfo);

    // TODO: we should get nodeAddrStr in the joining process, then pass it in
    // as an argument rather than getting it from the store each time
    auto nodeAddrData = nameToAddressStore_.get(name);
    auto nodeAddrStr =
        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
    workerNameToURL_.insert({name, nodeAddrStr});

    for (const auto& it : reverseDeviceMaps) {
      if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
        reverseDeviceMaps_[it.first] = it.second;
      }
    }
    // TODO: clean up mutex for devices_ usage
    // Add devices that have not been added yet
    for (const auto& it : devices) {
      if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
        devices_.push_back(it);
      }
    }
  } else {
    workerIdToInfo_.erase(id);
    workerNameToInfo_.erase(name);
    workerNameToURL_.erase(name);

    // Remove reverse device maps that are no longer used
    for (auto it = reverseDeviceMaps_.begin();
         it != reverseDeviceMaps_.end();) {
      if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
        it = reverseDeviceMaps_.erase(it);
      } else {
        it++;
      }
    }

    // Remove devices that are no longer used
    for (auto it = devices_.begin(); it != devices_.end();) {
      if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
        it = devices_.erase(it);
      } else {
        it++;
      }
    }
  }
}

std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() {
  std::unordered_map<std::string, std::string> metrics;
  metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
  metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable());
  {
    std::unique_lock<std::mutex> lock(callCountMutex_);
    metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_);
    metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_);
    metrics[kServerActiveAsyncCalls] = c10::to_string(serverActiveAsyncCalls_);
  }
  if (isGILProfilingEnabled()) {
    {
      std::unique_lock<std::mutex> lock(metricsMutex_);
      // Include the averages for each time series metric. This is just the GIL
      // Wait Time for now.
      auto averageGilWaitTime =
          timeSeriesMetrics_[kGilAverageWaitTime].computeAverage();
      lock.unlock();
      metrics[kGilAverageWaitTime] = c10::to_string(averageGilWaitTime);
    }
  }

  return metrics;
}

void TensorPipeAgent::addGilWaitTime(
    const std::chrono::microseconds gilWaitTime) {
  std::lock_guard<std::mutex> lock(metricsMutex_);
  timeSeriesMetrics_[kGilAverageWaitTime].addData(gilWaitTime.count());
}

TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData() {
  std::lock_guard<std::mutex> lock(networkDataMutex_);
  return networkData_;
}

NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo() {
  NetworkSourceInfo info = {
      RpcAgent::getWorkerInfo().id_,
      nameToAddressStore_.get(RpcAgent::getWorkerInfo().name_)};

  return info;
}

void TensorPipeAgent::trackNetworkData(
    uint64_t requestSize,
    uint64_t responseSize,
    const std::string& destWorkerName) {
  std::lock_guard<std::mutex> lock(networkDataMutex_);
  networkData_[destWorkerName].numCalls++;
  networkData_[destWorkerName].totalSentBytes += requestSize;
  networkData_[destWorkerName].totalRecvBytes += responseSize;
}

void TensorPipeAgent::trackNetworkError(
    uint64_t requestSize,
    const std::string& destWorkerName) {
  std::lock_guard<std::mutex> lock(networkDataMutex_);
  networkData_[destWorkerName].numCalls++;
  networkData_[destWorkerName].totalSentBytes += requestSize;
  networkData_[destWorkerName].totalErrors++;
}

void TensorPipeAgent::increaseCallCount(int32_t& count) {
  {
    std::unique_lock<std::mutex> lock(callCountMutex_);
    ++count;
  }
  callCountCV_.notify_all();
}

void TensorPipeAgent::decreaseCallCount(int32_t& count) {
  {
    std::unique_lock<std::mutex> lock(callCountMutex_);
    --count;
  }
  callCountCV_.notify_all();
}

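// The response path (in send()), the error path (handleClientError) and the
// timeout path (pollTimeoutRpcs) all race to complete the same
// AtomicJitFuture; test_and_set on its atomic_flag guarantees that exactly
// one of them wins while the others turn into no-ops.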
void TensorPipeAgent::markFutureAsComplete(
    std::shared_ptr<AtomicJitFuture> atomicFuture,
    c10::intrusive_ptr<Message> message,
    std::vector<c10::Stream> streams) {
  if (!atomicFuture->isComplete.test_and_set()) {
    // Completing the future will run its callbacks, which could execute
    // arbitrary user code. To prevent blocking or stalling the TensorPipe event
    // loops, we defer this to a worker thread.
    threadPool_.run([this,
                     atomicFuture{std::move(atomicFuture)},
                     message{std::move(message)},
                     streams{std::move(streams)}]() mutable {
      c10::MultiStreamGuard guard(streams);
      std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
          message->getStorages();
      atomicFuture->jitFuture->markCompleted(
          std::move(message), std::move(storages));
      // The future's callbacks may schedule further RPCs, increasing the count.
      // Thus we must decrease it after completing the future, otherwise it may
      // briefly dip to zero and trick join into thinking all work is done.
      decreaseCallCount(clientActiveCalls_);
    });
  }
}

void TensorPipeAgent::markFutureWithError(
    std::shared_ptr<AtomicJitFuture> atomicFuture,
    std::string errorMsg) {
  if (!atomicFuture->isComplete.test_and_set()) {
    // Completing the future will run its callbacks, which could execute
    // arbitrary user code. To prevent blocking or stalling the TensorPipe event
    // loops, we defer this to a worker thread.
    threadPool_.run([this,
                     atomicFuture{std::move(atomicFuture)},
                     errorMsg{std::move(errorMsg)}]() mutable {
      atomicFuture->jitFuture->setError(
          std::make_exception_ptr(std::runtime_error(errorMsg)));
      // The future's callbacks may schedule further RPCs, increasing the count.
      // Thus we must decrease it after completing the future, otherwise it may
      // briefly dip to zero and trick join into thinking all work is done.
      decreaseCallCount(clientActiveCalls_);
    });
  }
}

std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
    const std::string& remoteName,
    const Message& message) const {
  std::unordered_map<std::string, DeviceMap> deviceMaps;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
  }

  const auto errStr = c10::str(
      "TensorPipe RPC backend only supports CPU tensors by default, please "
      "move your tensors to CPU before sending them over RPC, or call "
      "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
      "configure device mapping. ",
      message.isRequest() ? "Request" : "Response",
      " device mapping is not available for destination ",
      remoteName);

  const auto& iter = deviceMaps.find(remoteName);
  if (iter == deviceMaps.end()) {
    for (const auto& t : message.tensors()) {
      TORCH_CHECK(
          t.device().is_cpu(),
          errStr,
          ", but found tensor on device: ",
          t.device());
    }
    return {};
  } else {
    // Pass the remote name, not errStr: the helper builds its own error
    // message around the destination's name.
    return getDevicesForTensors(message.tensors(), iter->second, remoteName);
  }
}

DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
  auto it = opts_.deviceMaps.find(dst.name_);
  if (it == opts_.deviceMaps.end()) {
    return {};
  }
  return it->second;
}

const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const {
  return store_;
}

TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
  return opts_;
}

const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  return devices_;
}

size_t TensorPipeAgent::timeoutMapSize() {
  std::unique_lock<std::mutex> lock(timeoutMapMutex_);
  return timeoutMap_.size();
}

size_t TensorPipeAgent::numPendingResponses() {
  std::unique_lock<std::mutex> lock(callCountMutex_);
  return clientActiveCalls_;
}

size_t TensorPipeAgent::messageIdToTimeoutMapSize() {
  std::unique_lock<std::mutex> lock(timeoutMapMutex_);
  return messageIdToTimeout_.size();
}

} // namespace rpc
} // namespace distributed
} // namespace torch

#endif // USE_TENSORPIPE