#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>

#ifdef USE_TENSORPIPE

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <sstream>
#include <tuple>
#include <utility>

#include <fmt/format.h>
#include <tensorpipe/tensorpipe.h>

#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>

#include <c10/core/StreamGuard.h>
#include <c10/util/irange.h>
18 | |
19 | namespace torch { |
20 | namespace distributed { |
21 | namespace rpc { |
22 | |
23 | namespace { |
24 | |
// An environment variable along the lines of GLOO_ and NCCL_SOCKET_IFNAME that
// allows the user to specify a device to bind to, instead of binding to the
// address that the hostname resolves to.
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
// Address used as a last resort when neither the interface nor the hostname
// lookup yields a usable address (see TensorPipeAgent::guessAddress).
const std::string kDefaultUvAddress = "127.0.0.1";

// Keys under which this agent's metrics are reported. kGilAverageWaitTime is
// seeded in the constructor; the others are presumably consumed by the
// metrics-reporting code elsewhere in this file — confirm against getMetrics().
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
37 | |
38 | std::vector<c10::Device> getDevicesForTensors( |
39 | const std::vector<torch::Tensor>& tensors, |
40 | const DeviceMap& deviceMap, |
41 | const std::string& remoteName) { |
42 | // If the deviceMap is overridden, use that instead. |
43 | const auto errStr = c10::str( |
44 | "TensorPipe RPC backend only supports CPU tensors by default, please " |
45 | "move your tensors to CPU before sending them over RPC, or call " |
46 | "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly " |
47 | "configure device mapping. " , |
48 | "Request device mapping is not available for destination " , |
49 | remoteName); |
50 | std::vector<c10::Device> devices; |
51 | devices.reserve(tensors.size()); |
52 | bool hasMappedDevice = false; |
53 | for (const auto& t : tensors) { |
54 | if (t.device().is_cpu()) { |
55 | const auto deviceIter = deviceMap.find(c10::kCPU); |
56 | if (deviceIter == deviceMap.end()) { |
57 | devices.emplace_back(c10::kCPU); |
58 | } else { |
59 | devices.emplace_back(deviceIter->second); |
60 | hasMappedDevice = true; |
61 | } |
62 | } else { |
63 | const auto deviceIter = deviceMap.find(t.device()); |
64 | TORCH_CHECK( |
65 | deviceIter != deviceMap.end(), |
66 | errStr, |
67 | " for device " , |
68 | t.device(), |
69 | " but received a tensor on that device." ); |
70 | devices.push_back(deviceIter->second); |
71 | hasMappedDevice = true; |
72 | } |
73 | } |
74 | if (!hasMappedDevice) { |
75 | devices.clear(); |
76 | } |
77 | return devices; |
78 | } |
79 | |
80 | std::vector<c10::Stream> getStreamsFromPoolForDevices( |
81 | const std::vector<c10::Device>& devices) { |
82 | if (devices.empty()) { |
83 | return {}; |
84 | } |
85 | c10::impl::VirtualGuardImpl impl(devices[0].type()); |
86 | std::vector<c10::Stream> streams; |
87 | streams.reserve(devices.size()); |
88 | for (const c10::Device& device : devices) { |
89 | TORCH_INTERNAL_ASSERT(device.type() == impl.type()); |
90 | streams.push_back(impl.getStreamFromGlobalPool(device)); |
91 | } |
92 | return streams; |
93 | } |
94 | |
95 | std::vector<c10::Stream> getCurrentStreamsForDevices( |
96 | const std::vector<c10::Device>& devices) { |
97 | if (devices.empty()) { |
98 | return {}; |
99 | } |
100 | c10::impl::VirtualGuardImpl impl(devices[0].type()); |
101 | std::vector<c10::Stream> streams; |
102 | streams.reserve(devices.size()); |
103 | for (const c10::Device& device : devices) { |
104 | TORCH_INTERNAL_ASSERT(device.type() == impl.type()); |
105 | streams.push_back(impl.getStream(device)); |
106 | } |
107 | return streams; |
108 | } |
109 | |
110 | std::vector<c10::Device> getDevicesOfTensors( |
111 | const std::vector<torch::Tensor>& tensors) { |
112 | c10::optional<c10::impl::VirtualGuardImpl> impl; |
113 | size_t deviceCount = 0; |
114 | std::vector<bool> indexBitset; |
115 | for (const torch::Tensor& tensor : tensors) { |
116 | if (!tensor.is_cpu()) { |
117 | c10::Device device = tensor.device(); |
118 | if (!impl.has_value()) { |
119 | impl.emplace(device.type()); |
120 | indexBitset.resize(impl->deviceCount()); |
121 | } |
122 | TORCH_INTERNAL_ASSERT(device.type() == impl->type()); |
123 | TORCH_INTERNAL_ASSERT(device.has_index()); |
124 | if (!indexBitset[device.index()]) { |
125 | deviceCount++; |
126 | indexBitset[device.index()] = true; |
127 | } |
128 | } |
129 | } |
130 | std::vector<c10::Device> devices; |
131 | devices.reserve(deviceCount); |
132 | for (const auto idx : c10::irange(indexBitset.size())) { |
133 | if (indexBitset[idx]) { |
134 | devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx)); |
135 | } |
136 | } |
137 | return devices; |
138 | } |
139 | |
140 | void makeStreamsWaitOnOthers( |
141 | const std::vector<c10::Stream>& consumers, |
142 | const std::vector<c10::Stream>& producers) { |
143 | for (const c10::Stream& producer : producers) { |
144 | const c10::Stream& consumer = |
145 | getStreamForDevice(consumers, producer.device()); |
146 | c10::Event event(producer.device_type()); |
147 | event.record(producer); |
148 | event.block(consumer); |
149 | } |
150 | } |
151 | |
152 | } // namespace |
153 | |
// Registries mapping transport/channel names (e.g. "uv", "shm", "basic") to
// factory functions; concrete entries are added below via
// C10_REGISTER_CREATOR and instantiated in TensorPipeAgent::startImpl().
C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeTransportRegistry,
    TransportRegistration);

C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeChannelRegistry,
    ChannelRegistration);
161 | |
162 | const std::string& TensorPipeAgent::guessAddress() { |
163 | static const std::string uvAddress = []() { |
164 | tensorpipe::Error error; |
165 | std::string result; |
166 | char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str()); |
167 | if (ifnameEnv != nullptr) { |
168 | std::tie(error, result) = |
169 | tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv); |
170 | if (error) { |
171 | LOG(WARNING) << "Failed to look up the IP address for interface " |
172 | << ifnameEnv << " (" << error.what() << "), defaulting to " |
173 | << kDefaultUvAddress; |
174 | return kDefaultUvAddress; |
175 | } |
176 | } else { |
177 | std::tie(error, result) = |
178 | tensorpipe::transport::uv::lookupAddrForHostname(); |
179 | if (error) { |
180 | LOG(WARNING) << "Failed to look up the IP address for the hostname (" |
181 | << error.what() << "), defaulting to " |
182 | << kDefaultUvAddress; |
183 | return kDefaultUvAddress; |
184 | } |
185 | } |
186 | return result; |
187 | }(); |
188 | return uvAddress; |
189 | } |
190 | |
191 | namespace { |
192 | |
193 | std::unique_ptr<TransportRegistration> makeUvTransport() { |
194 | auto context = tensorpipe::transport::uv::create(); |
195 | std::string address = TensorPipeAgent::guessAddress(); |
196 | return std::make_unique<TransportRegistration>(TransportRegistration{ |
197 | std::move(context), kUvTransportPriority, std::move(address)}); |
198 | } |
199 | |
200 | // The UV transport is implemented using standard TCP connections. It leverages |
201 | // libuv (https://github.com/libuv/libuv) in order to be cross-platform. |
202 | C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport); |
203 | |
#if TENSORPIPE_HAS_SHM_TRANSPORT

// Build the registration entry for the shared-memory transport; it needs no
// address since connections bootstrap over UNIX domain sockets.
std::unique_ptr<TransportRegistration> makeShmTransport() {
  return std::make_unique<TransportRegistration>(TransportRegistration{
      tensorpipe::transport::shm::create(), kShmTransportPriority, ""});
}

// The SHM implements connections using ringbuffers residing in anonymous shared
// memory (plus UNIX domain sockets to bootstrap the connection and exchange
// file descriptors). It is Linux-only due to some advanced features (O_TMPFILE,
// eventfd, ...).
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport);

#endif // TENSORPIPE_HAS_SHM_TRANSPORT
219 | |
#if TENSORPIPE_HAS_IBV_TRANSPORT

// Build the registration entry for the InfiniBand transport; like UV it
// bootstraps over TCP, so it binds the same guessed address.
std::unique_ptr<TransportRegistration> makeIbvTransport() {
  return std::make_unique<TransportRegistration>(TransportRegistration{
      tensorpipe::transport::ibv::create(),
      kIbvTransportPriority,
      TensorPipeAgent::guessAddress()});
}

// The IBV transport sends data across using an InfiniBand queue pair, locally
// copying data to and from a staging buffer (registered with libibverbs) and
// issuing a RDMA write for transferring data across machines (plus a send for
// acknowledging it). It bootstraps using a standard TCP connection to exchange
// setup information. It is Linux-only.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport);

#endif // TENSORPIPE_HAS_IBV_TRANSPORT
237 | |
238 | std::unique_ptr<ChannelRegistration> makeBasicChannel() { |
239 | auto context = tensorpipe::channel::basic::create(); |
240 | return std::make_unique<ChannelRegistration>( |
241 | ChannelRegistration{std::move(context), kBasicChannelPriority}); |
242 | } |
243 | |
244 | // The basic channel is just a straightforward adapter wrapper that allows any |
245 | // transport to be used as a channel. |
246 | C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel); |
247 | |
#if TENSORPIPE_HAS_CMA_CHANNEL

// Build the registration entry for the cross-memory-attach channel.
std::unique_ptr<ChannelRegistration> makeCmaChannel() {
  return std::make_unique<ChannelRegistration>(ChannelRegistration{
      tensorpipe::channel::cma::create(), kCmaChannelPriority});
}

// The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
// and _writev), which allow one process to access the private memory of another
// process (as long as they belong to the same user and other security
// constraints are satisfied). It does, more or less, what GDB does when it's
// attached to a running process.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel);

#endif // TENSORPIPE_HAS_CMA_CHANNEL
264 | |
265 | constexpr static int kNumUvThreads = 16; |
266 | |
267 | std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel() { |
268 | std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts; |
269 | std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners; |
270 | for (const auto laneIdx C10_UNUSED : c10::irange(kNumUvThreads)) { |
271 | auto context = tensorpipe::transport::uv::create(); |
272 | std::string address = TensorPipeAgent::guessAddress(); |
273 | contexts.push_back(std::move(context)); |
274 | listeners.push_back(contexts.back()->listen(address)); |
275 | } |
276 | auto context = tensorpipe::channel::mpt::create( |
277 | std::move(contexts), std::move(listeners)); |
278 | return std::make_unique<ChannelRegistration>( |
279 | ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority}); |
280 | } |
281 | |
282 | // The multiplexed UV channel encapsulates multiple UV transports (each with its |
283 | // own event loop thread). Each channel will, in turn, contain multiple UV |
284 | // connections, one for each of those contexts. When sending a tensor, its data |
285 | // is split in equal chunks and each chunks is sent on a different connection |
286 | // and thus driven by a different thread. This is needed to reach very high |
287 | // bandwidths. |
288 | C10_REGISTER_CREATOR( |
289 | TensorPipeChannelRegistry, |
290 | mpt_uv, |
291 | makeMultiplexedUvChannel); |
292 | |
293 | } // namespace |
294 | |
295 | ////////////////////////// MetricsTracker ///////////////////////////////// |
296 | |
// Initialize the running sum and count from which the average metric value
// is later computed (see computeAverage()).
TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
    uint64_t currentSum,
    uint64_t currentCount)
    : currentSum_(currentSum), currentCount_(currentCount) {}
301 | |
302 | void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) { |
303 | currentSum_ += dataPoint; |
304 | ++currentCount_; |
305 | } |
306 | |
307 | float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const { |
308 | return currentCount_ == 0 ? 0 : currentSum_ / (float)currentCount_; |
309 | } |
310 | |
//////////////////////// TensorPipeAgent /////////////////////////////////
312 | |
313 | void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId) { |
314 | // Remove entry from timeoutMap_. |
315 | { |
316 | std::unique_lock<std::mutex> lock(timeoutMapMutex_); |
317 | auto it = messageIdToTimeout_.find(messageId); |
318 | if (it == messageIdToTimeout_.end()) { |
319 | // Already removed from the map by pollTimeoutRpcs(), no need to |
320 | // process further. |
321 | return; |
322 | } |
323 | |
324 | auto& expirationTime = it->second; |
325 | |
326 | auto& timedOutFuturesVector = timeoutMap_[expirationTime]; |
327 | for (auto it = timedOutFuturesVector.begin(); |
328 | it != timedOutFuturesVector.end(); |
329 | it++) { |
330 | if (it->messageId == messageId) { |
331 | it = timedOutFuturesVector.erase(it); |
332 | break; |
333 | } |
334 | } |
335 | |
336 | if (timedOutFuturesVector.empty()) { |
337 | timeoutMap_.erase(expirationTime); |
338 | } |
339 | |
340 | // Remove from messageId to timeout map as well. |
341 | messageIdToTimeout_.erase(messageId); |
342 | } |
343 | } |
344 | |
345 | void TensorPipeAgent::prepareNames(bool isStaticGroup) { |
346 | std::unordered_map<std::string, worker_id_t> nameToId; |
347 | if (isStaticGroup) { |
348 | nameToId = collectNames( |
349 | rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_); |
350 | } else { |
351 | nameToId = collectCurrentNames( |
352 | rankToNameStore_, workerInfo_.id_, workerInfo_.name_); |
353 | } |
354 | |
355 | for (const auto& entry : nameToId) { |
356 | const auto& workerName = entry.first; |
357 | const auto& workerId = entry.second; |
358 | workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId)); |
359 | workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId)); |
360 | } |
361 | } |
362 | |
363 | void TensorPipeAgent::checkAndSetStaticGroup( |
364 | const c10::intrusive_ptr<::c10d::Store>& store) { |
365 | std::string isStaticGroupKey("rpcIsStaticGroup" ); |
366 | |
367 | std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false" ; |
368 | std::vector<uint8_t> isStaticGroupVec( |
369 | (uint8_t*)isStaticGroupStr.c_str(), |
370 | (uint8_t*)isStaticGroupStr.c_str() + isStaticGroupStr.length()); |
371 | std::vector<uint8_t> returnedVec; |
372 | returnedVec = store->compareSet( |
373 | isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec); |
374 | std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end()); |
375 | // In both cases, the returned value should be the value of isStaticGroupStr, |
376 | // otherwise there is a discrepency with initialization among one of the |
377 | // members |
378 | TORCH_CHECK( |
379 | returnedVal == isStaticGroupStr, |
380 | fmt::format( |
381 | "RPC group mixes statically and dynamically initialized members which is not supported. " , |
382 | "Static group property is initialized as {} and is trying to be set as {} " , |
383 | isStaticGroup_, |
384 | returnedVal)); |
385 | } |
386 | |
// Construct the agent: initialize the RpcAgent base (identity, callback,
// default timeout), set up the TensorPipe context and the store-backed
// helpers, then perform group-wide initialization (static-group consistency
// check and worker-name collection). A provided worldSize marks the group as
// static; absence of one marks it as dynamic.
TensorPipeAgent::TensorPipeAgent(
    const c10::intrusive_ptr<::c10d::Store>& store,
    std::string selfName,
    worker_id_t selfId,
    optional<int> worldSize,
    TensorPipeRpcBackendOptions opts,
    std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
    std::vector<c10::Device> devices,
    std::unique_ptr<RequestCallback> cb)
    : RpcAgent(
          WorkerInfo(std::move(selfName), selfId),
          std::move(cb),
          // Convert the float seconds timeout into whole milliseconds.
          std::chrono::milliseconds(
              (long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
      isStaticGroup_(worldSize.has_value()),
      store_(store),
      opts_(std::move(opts)),
      reverseDeviceMaps_(std::move(reverseDeviceMaps)),
      devices_(std::move(devices)),
      threadPool_(opts_.numWorkerThreads),
      context_(std::make_shared<tensorpipe::Context>(
          tensorpipe::ContextOptions().name(workerInfo_.name_))),
      rankToNameStore_("names", store),
      nameToAddressStore_("addrs", store),
      shutdownStore_("shutdown", store) {
  if (isStaticGroup_) {
    worldSize_ = worldSize.value();
  }

  // check the static group attribute against store
  checkAndSetStaticGroup(store);

  // collect worker names
  prepareNames(isStaticGroup_);

  // Initialize the time-series metrics tracking map
  timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
}
425 | |
// Destruction triggers a full shutdown so no pending work outlives the agent.
// NOTE(review): assumes shutdown() is safe to call even if it has already
// been invoked — confirm against its implementation.
TensorPipeAgent::~TensorPipeAgent() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
  shutdown();
}
430 | |
// Bring the agent online: register all viable transports and channels with
// the TensorPipe context, start listening, publish our own address to the
// store, resolve every peer's address, and launch the timeout-polling thread
// and the accept loop.
void TensorPipeAgent::startImpl() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";

  std::vector<std::string> addresses;
  // Track the transport with the numerically lowest priority seen; its URL is
  // the one published to the store below.
  int lowestPriority = std::numeric_limits<int>::max();
  std::string lowestPriorityTransport;

  // Register transports
  for (auto& key : TensorPipeTransportRegistry()->Keys()) {
    int64_t priority = -1;
    // If the user passed an explicit transport list, skip transports not in
    // it and derive priorities from their position in that list.
    if (opts_.transports.has_value()) {
      auto iter =
          std::find(opts_.transports->begin(), opts_.transports->end(), key);
      if (iter == opts_.transports->end()) {
        continue;
      }
      // Assign priorities in reverse order of occurrence in the vector, so that
      // a transport that comes before another receives a higher priority.
      priority =
          opts_.transports->size() - 1 - (iter - opts_.transports->begin());
    }
    std::unique_ptr<TransportRegistration> reg =
        TensorPipeTransportRegistry()->Create(key);
    if (!reg->transport->isViable()) {
      continue;
    }
    if (priority == -1) {
      // No user override: fall back to the transport's registered default.
      priority = reg->priority;
    }
    if (priority < lowestPriority) {
      lowestPriority = priority;
      lowestPriorityTransport = key;
    }
    addresses.push_back(c10::str(key, "://", reg->address));
    context_->registerTransport(
        priority, std::move(key), std::move(reg->transport));
  }

  // Register channels (same selection/priority scheme as transports).
  for (auto& key : TensorPipeChannelRegistry()->Keys()) {
    int64_t priority = -1;
    if (opts_.channels.has_value()) {
      auto iter =
          std::find(opts_.channels->begin(), opts_.channels->end(), key);
      if (iter == opts_.channels->end()) {
        continue;
      }
      // Assign priorities in reverse order of occurrence in the vector, so
      // that a channel that comes before another receives a higher priority.
      priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin());
    }
    std::unique_ptr<ChannelRegistration> reg =
        TensorPipeChannelRegistry()->Create(key);
    if (!reg->channel->isViable()) {
      continue;
    }
    if (priority == -1) {
      priority = reg->priority;
    }
    context_->registerChannel(
        priority, std::move(key), std::move(reg->channel));
  }

  listener_ = context_->listen(addresses);

  // Store our own url.
  const auto address = listener_->url(lowestPriorityTransport);
  nameToAddressStore_.set(workerInfo_.name_, address);

  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
          << address;

  // Resolve and cache the URL of every known peer from the store.
  for (const auto& p : workerNameToInfo_) {
    const auto& name = p.first;
    auto nodeAddrData = nameToAddressStore_.get(name);
    auto nodeAddrStr =
        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
    workerNameToURL_.insert({name, nodeAddrStr});
  }

  // Start the Timeout Thread
  timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);

  // Arm the accept loop; onListenerAccepted re-arms itself after each
  // incoming connection.
  listener_->accept([this](
                        const tensorpipe::Error& error,
                        std::shared_ptr<tensorpipe::Pipe> pipe) {
    onListenerAccepted(error, pipe);
  });
}
520 | |
// Callback invoked by the TensorPipe listener for each incoming connection.
// On success it re-arms the listener for the next connection and starts the
// server-side read loop on the new pipe. A ListenerClosedError while the
// agent is no longer running is the normal shutdown path and is ignored.
void TensorPipeAgent::onListenerAccepted(
    const tensorpipe::Error& error,
    std::shared_ptr<tensorpipe::Pipe>& pipe) {
  if (error) {
    if (error.isOfType<tensorpipe::ListenerClosedError>() &&
        !rpcAgentRunning_.load()) {
      // This is expected.
    } else {
      LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                   << " encountered error when accepting incoming pipe: "
                   << error.what();
    }
    return;
  }

  // Accept the next connection request
  listener_->accept([this](
                        const tensorpipe::Error& error,
                        std::shared_ptr<tensorpipe::Pipe> pipe) {
    onListenerAccepted(error, pipe);
  });

  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " accepted incoming pipe from " << pipe->getRemoteName();

  // Arm for server read
  respond(pipe);
}
549 | |
// Asynchronously read a single message from `pipe`: first read the
// descriptor, allocate receive buffers (on streams drawn from the global
// pool for this agent's devices), then read the payload into them and
// deserialize it before invoking `fn`. On any error, `fn` receives the error
// together with an empty message and no streams.
void TensorPipeAgent::pipeRead(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    std::function<void(
        const tensorpipe::Error&,
        c10::intrusive_ptr<Message>,
        std::vector<c10::Stream>)> fn) noexcept {
  pipe->readDescriptor([this, fn{std::move(fn)}, pipe](
                           const tensorpipe::Error& error,
                           tensorpipe::Descriptor tpDescriptor) mutable {
    if (error) {
      fn(error, c10::intrusive_ptr<Message>(), {});
      return;
    }

    // Take the group-membership lock while reading devices_.
    std::vector<c10::Stream> streams;
    {
      GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
      streams = getStreamsFromPoolForDevices(devices_);
    }
    tensorpipe::Allocation tpAllocation;
    TensorpipeReadBuffers tpBuffers;
    std::tie(tpAllocation, tpBuffers) =
        tensorpipeAllocate(tpDescriptor, streams);

    // The read buffers are moved into a shared_ptr captured by the callback
    // so they stay alive until TensorPipe has finished writing into them.
    pipe->read(
        std::move(tpAllocation),
        [tpDescriptor{std::move(tpDescriptor)},
         tpBuffers{
             std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
         fn{std::move(fn)},
         streams{std::move(streams)}](const tensorpipe::Error& error) mutable {
          if (error) {
            fn(error, c10::intrusive_ptr<Message>(), {});
            return;
          }

          // FIXME This does some unpickling, which could be a bit expensive:
          // perhaps it would be best to perform it inside the worker threads?
          c10::intrusive_ptr<Message> rpcMessage = tensorpipeDeserialize(
              std::move(tpDescriptor), std::move(*tpBuffers));

          fn(error, std::move(rpcMessage), std::move(streams));
        });
  });
}
595 | |
// Serialize `rpcMessage` (staging tensors for the given target devices and
// streams) and asynchronously write it to `pipe`. The serialized buffers are
// held by a shared_ptr captured in the write callback, keeping them alive
// until TensorPipe has finished sending; `fn` is then invoked with the
// result.
void TensorPipeAgent::pipeWrite(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    c10::intrusive_ptr<Message> rpcMessage,
    std::vector<c10::Device>&& devices,
    std::vector<c10::Stream> streams,
    std::function<void(const tensorpipe::Error&)> fn) noexcept {
  tensorpipe::Message tpMessage;
  TensorpipeWriteBuffers tpBuffers;

  std::tie(tpMessage, tpBuffers) =
      tensorpipeSerialize(std::move(rpcMessage), std::move(devices), streams);

  pipe->write(
      std::move(tpMessage),
      [tpBuffers{
           std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
       fn{std::move(fn)},
       streams{std::move(streams)}](const tensorpipe::Error& error) {
        fn(error);
      });
}
617 | |
618 | void TensorPipeAgent::sendCompletedResponseMessage( |
619 | std::shared_ptr<tensorpipe::Pipe>& pipe, |
620 | JitFuture& futureResponseMessage, |
621 | uint64_t messageId, |
622 | std::vector<c10::Stream> streams) { |
623 | if (!rpcAgentRunning_.load()) { |
624 | LOG(WARNING) << "RPC agent for " << workerInfo_.name_ |
625 | << " won't send response to request #" << messageId << " to " |
626 | << pipe->getRemoteName() << ", as the agent is shutting down" ; |
627 | return; |
628 | } |
629 | |
630 | VLOG(1) << "RPC agent for " << workerInfo_.name_ |
631 | << " is sending response to request #" << messageId << " to " |
632 | << pipe->getRemoteName(); |
633 | |
634 | if (!futureResponseMessage.hasError()) { |
635 | c10::intrusive_ptr<Message> responseMessage = |
636 | futureResponseMessage.value().toCustomClass<Message>(); |
637 | responseMessage->setId(messageId); |
638 | |
639 | std::vector<c10::Device> devices; |
640 | try { |
641 | devices = getDevicesForRemote(pipe->getRemoteName(), *responseMessage); |
642 | } catch (const std::exception& e) { |
643 | responseMessage = createExceptionResponse(e.what(), messageId); |
644 | } |
645 | |
646 | for (const auto& tensor : responseMessage->tensors()) { |
647 | const auto device = tensor.device(); |
648 | if (!device.is_cpu()) { |
649 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
650 | if (std::find(devices_.begin(), devices_.end(), device) == |
651 | devices_.end()) { |
652 | std::ostringstream oss; |
653 | std::copy( |
654 | devices_.begin(), |
655 | devices_.end(), |
656 | std::ostream_iterator<c10::Device>(oss, ", " )); |
657 | responseMessage = createExceptionResponse( |
658 | c10::str( |
659 | "RPC detected that a user-function output tensor on device " , |
660 | device, |
661 | ". This device is not one of the input tensor devices: " , |
662 | oss.str(), |
663 | "which is not yet supported. Please file a feature request " |
664 | "issue in PyTorch GitHub repo." ), |
665 | messageId); |
666 | break; |
667 | } |
668 | } |
669 | } |
670 | |
671 | pipeWrite( |
672 | pipe, |
673 | std::move(responseMessage), |
674 | std::move(devices), |
675 | std::move(streams), |
676 | [this, pipe, messageId](const tensorpipe::Error& error) { |
677 | if (error) { |
678 | LOG(WARNING) |
679 | << "RPC agent for " << workerInfo_.name_ |
680 | << " encountered error when sending response to request #" |
681 | << messageId << " to " << pipe->getRemoteName() << ": " |
682 | << error.what(); |
683 | return; |
684 | } |
685 | |
686 | VLOG(1) << "RPC agent for " << workerInfo_.name_ |
687 | << " done sending response to request #" << messageId |
688 | << " to " << pipe->getRemoteName(); |
689 | }); |
690 | } else { |
691 | pipeWrite( |
692 | pipe, |
693 | createExceptionResponse( |
694 | futureResponseMessage.tryRetrieveErrorMessage(), messageId), |
695 | /* devices */ {}, |
696 | std::move(streams), |
697 | [this, pipe, messageId](const tensorpipe::Error& error) { |
698 | if (error) { |
699 | LOG(WARNING) |
700 | << "RPC agent for " << workerInfo_.name_ |
701 | << " encountered error when sending response to request #" |
702 | << messageId << " to " << pipe->getRemoteName() << ": " |
703 | << error.what(); |
704 | return; |
705 | } |
706 | |
707 | VLOG(1) << "RPC agent for " << workerInfo_.name_ |
708 | << " done sending response to request #" << messageId |
709 | << " to " << pipe->getRemoteName(); |
710 | }); |
711 | } |
712 | } |
713 | |
// Server-side read loop: read one request from the pipe, immediately re-arm
// the next read, and hand the request off to the thread pool for execution.
// The response is sent from the callback attached to the future returned by
// the request callback (see sendCompletedResponseMessage).
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          c10::intrusive_ptr<Message> requestMessage,
          std::vector<c10::Stream> streams) mutable {
        if (error) {
          // Read errors during shutdown are the normal teardown path.
          if (shuttingDown_) {
            // This is expected.
          } else {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when reading incoming request from "
                << pipe->getRemoteName() << ": " << error.what();
          }
          return;
        }

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage->id();
        increaseCallCount(serverActiveCalls_);

        VLOG(1) << "RPC agent for " << workerInfo_.name_
                << " received request #" << messageId << " from "
                << pipe->getRemoteName();

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         streams{std::move(streams)}]() mutable {
          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " is running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
            // Instead of creating a MultiStreamGuard here, the ctx is passed
            // to the callback and the MultiStreamGuard is created there,
            // because subsequent processing can switch threads due to 1)
            // waiting for RRef arguments to become ready 2) async_execution.
            // Besides, the `ctx` also needs to be propagated to
            // `process***Call` methods to synchronize CUDA streams there
            // to make sure that we fetch the correct value from `to_here()`
            // call.
            futureResponseMessage =
                cb_->operator()(*requestMessage, std::move(streams));
          } catch (const std::exception& /* unused */) {
            // Convert a synchronous throw into an errored future so that
            // both outcomes flow through the same callback below.
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          // Respond (or ship the error) once the future completes.
          increaseCallCount(serverActiveAsyncCalls_);
          futureResponseMessage->addCallback(
              [this, pipe, messageId](
                  JitFuture& futureResponseMessage) mutable {
                decreaseCallCount(serverActiveCalls_);
                decreaseCallCount(serverActiveAsyncCalls_);
                auto streams = getCurrentStreamsForDevices(
                    futureResponseMessage.devices());
                sendCompletedResponseMessage(
                    pipe, futureResponseMessage, messageId, std::move(streams));
              });

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";
        });
      });
}
789 | |
790 | c10::intrusive_ptr<JitFuture> TensorPipeAgent::send( |
791 | const WorkerInfo& toWorkerInfo, |
792 | c10::intrusive_ptr<Message> requestMessage, |
793 | const float rpcTimeoutSeconds, |
794 | const DeviceMap& deviceMap) { |
795 | TORCH_CHECK( |
796 | requestMessage->isRequest(), |
797 | "TensorPipeAgent::send(..) is only for sending requests." ); |
798 | |
799 | if (!rpcAgentRunning_.load()) { |
800 | auto err = c10::str( |
801 | "Node " , |
802 | RpcAgent::getWorkerInfo().id_, |
803 | "tried to send() a message of type " , |
804 | requestMessage->type(), |
805 | " but RPC is no longer running on this node." ); |
806 | TORCH_CHECK(false, err); |
807 | } |
808 | |
809 | const auto& url = findWorkerURL(toWorkerInfo); |
810 | |
811 | decltype(connectedPipes_)::iterator it; |
812 | { |
813 | std::unique_lock<std::mutex> lock(connectedPipesMutex_); |
814 | |
815 | // See if we already have a connection to this address or not |
816 | it = connectedPipes_.find(toWorkerInfo.id_); |
817 | if (it == connectedPipes_.end()) { |
818 | // An instance of ClientPipe cannot be copied or moved as it contains a |
819 | // mutex, and to force in-place construction in GCC 5 we need piecewise |
820 | // construction in order to work around an issue. |
821 | std::tie(it, std::ignore) = connectedPipes_.emplace( |
822 | std::piecewise_construct, |
823 | std::forward_as_tuple(toWorkerInfo.id_), |
824 | std::forward_as_tuple(context_->connect( |
825 | url, tensorpipe::PipeOptions().remoteName(toWorkerInfo.name_)))); |
826 | } |
827 | } |
828 | ClientPipe& clientPipe = it->second; |
829 | |
830 | std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture> |
831 | futureResponseMessage; |
832 | { |
833 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
834 | futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_); |
835 | } |
836 | uint64_t messageId = nextMessageID_++; |
837 | requestMessage->setId(messageId); |
838 | |
839 | { |
840 | std::unique_lock<std::mutex> lock(clientPipe.mutex_); |
841 | clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage; |
842 | } |
843 | |
844 | // Get devices for tensors in the request message. This can throw if device |
845 | // maps are not configured properly for this request. |
846 | std::vector<c10::Device> devices; |
847 | if (deviceMap.empty()) { |
848 | devices = |
849 | getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage); |
850 | } else { |
851 | // If deviceMap is specified, use that instead. |
852 | devices = getDevicesForTensors( |
853 | requestMessage->tensors(), |
854 | deviceMap, |
855 | clientPipe.pipe_->getRemoteName()); |
856 | } |
857 | |
858 | futureResponseMessage->jitFuture->addCallback( |
859 | [this](JitFuture& /* unused */) { |
860 | TORCH_INTERNAL_ASSERT( |
861 | this->threadPool_.inThreadPool(), |
862 | "Future marked complete from outside the thread pool" ); |
863 | }); |
864 | |
865 | increaseCallCount(clientActiveCalls_); |
866 | // Use the default RPC timeout if no timeout is specified for this send call |
867 | auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout |
868 | ? getRpcTimeout() |
869 | : std::chrono::milliseconds( |
870 | static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion)); |
871 | |
872 | // We only add to the timeoutMap_ if the timeout is not 0. Per our |
873 | // documentation, a user-provided timeout of 0 indicates the RPC should never |
874 | // expire (infinite timeout), so there is no need to track it in the |
875 | // timeoutMap_. |
876 | steady_clock_time_point expirationTime; |
877 | if (timeout.count() != 0) { |
878 | // Compute the expiration time for this message based on the timeout |
879 | expirationTime = computeRpcMessageExpiryTime(timeout); |
880 | |
881 | // Add the Future to the right vector in the timeoutMap_ |
882 | { |
883 | std::unique_lock<std::mutex> lock(timeoutMapMutex_); |
884 | auto& timeoutFuturesVector = timeoutMap_[expirationTime]; |
885 | messageIdToTimeout_.emplace(messageId, expirationTime); |
886 | timeoutFuturesVector.emplace_back( |
887 | messageId, futureResponseMessage, timeout); |
888 | } |
889 | timeoutThreadCV_.notify_one(); |
890 | } |
891 | |
892 | VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #" |
893 | << messageId << " to " << clientPipe.pipe_->getRemoteName(); |
894 | |
895 | std::vector<c10::Stream> streams; |
896 | { |
897 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
898 | streams = getStreamsFromPoolForDevices(devices_); |
899 | } |
900 | makeStreamsWaitOnOthers( |
901 | streams, |
902 | getCurrentStreamsForDevices( |
903 | getDevicesOfTensors(requestMessage->tensors()))); |
904 | pipeWrite( |
905 | clientPipe.pipe_, |
906 | std::move(requestMessage), |
907 | std::move(devices), |
908 | std::move(streams), |
909 | [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable { |
910 | if (error) { |
911 | if (error.isOfType<tensorpipe::PipeClosedError>() && |
912 | !rpcAgentRunning_.load()) { |
913 | // This is expected. |
914 | } else { |
915 | LOG(WARNING) << "RPC agent for " << workerInfo_.name_ |
916 | << " encountered error when sending outgoing request #" |
917 | << messageId << " to " |
918 | << clientPipe.pipe_->getRemoteName() << ": " |
919 | << error.what(); |
920 | } |
921 | handleClientError(clientPipe, error); |
922 | return; |
923 | } |
924 | |
925 | VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #" |
926 | << messageId << " to " << clientPipe.pipe_->getRemoteName(); |
927 | |
928 | pipeRead( |
929 | clientPipe.pipe_, |
930 | [this, &clientPipe]( |
931 | const tensorpipe::Error& error, |
932 | c10::intrusive_ptr<Message> responseMessage, |
933 | std::vector<c10::Stream> streams) { |
934 | if (error) { |
935 | if (error.isOfType<tensorpipe::PipeClosedError>() && |
936 | !rpcAgentRunning_.load()) { |
937 | // This is expected. |
938 | } else { |
939 | LOG(WARNING) |
940 | << "RPC agent for " << workerInfo_.name_ |
941 | << " encountered error when reading incoming response from " |
942 | << clientPipe.pipe_->getRemoteName() << ": " |
943 | << error.what(); |
944 | } |
945 | handleClientError(clientPipe, error); |
946 | return; |
947 | } |
948 | |
949 | // Identify future response message by message ID |
950 | uint64_t messageId = responseMessage->id(); |
951 | |
952 | VLOG(1) << "RPC agent for " << workerInfo_.name_ |
953 | << " received response #" << messageId << " from " |
954 | << clientPipe.pipe_->getRemoteName(); |
955 | |
956 | std::shared_ptr<AtomicJitFuture> futureResponseMessage; |
957 | { |
958 | std::lock_guard<std::mutex> lock(clientPipe.mutex_); |
959 | // A read error will lead all following callbacks to be |
960 | // invoked with error, and shouldn't reach here. |
961 | TORCH_INTERNAL_ASSERT( |
962 | !clientPipe.inError_, "Shouldn't be in error state" ); |
963 | auto it = clientPipe.pendingResponseMessage_.find(messageId); |
964 | TORCH_INTERNAL_ASSERT( |
965 | it != clientPipe.pendingResponseMessage_.end(), |
966 | "message ID " , |
967 | messageId, |
968 | " is not recognized" ); |
969 | futureResponseMessage = std::move(it->second); |
970 | clientPipe.pendingResponseMessage_.erase(it); |
971 | } |
972 | |
973 | // Remove entry from timeoutMap_. |
974 | removeFromTimeoutMap(messageId); |
975 | |
976 | if (responseMessage->type() == MessageType::EXCEPTION) { |
977 | markFutureWithError( |
978 | std::move(futureResponseMessage), |
979 | std::string( |
980 | responseMessage->payload().begin(), |
981 | responseMessage->payload().end())); |
982 | } else { |
983 | markFutureAsComplete( |
984 | std::move(futureResponseMessage), |
985 | std::move(responseMessage), |
986 | std::move(streams)); |
987 | } |
988 | }); |
989 | }); |
990 | |
991 | return futureResponseMessage->jitFuture; |
992 | } |
993 | |
994 | void TensorPipeAgent::handleClientError( |
995 | ClientPipe& clientPipe, |
996 | const tensorpipe::Error& error) { |
997 | // When an error occurs on a pipe all pending operations will be aborted and |
998 | // all callbacks invoked with error, hence we immediately flush all future |
999 | // messages belonging to this pipe. |
1000 | decltype(clientPipe.pendingResponseMessage_) pendingMsgs; |
1001 | { |
1002 | std::lock_guard<std::mutex> lock(clientPipe.mutex_); |
1003 | std::swap(clientPipe.pendingResponseMessage_, pendingMsgs); |
1004 | clientPipe.inError_ = true; |
1005 | } |
1006 | std::string errorMsg = error.what(); |
1007 | for (auto& p : pendingMsgs) { |
1008 | markFutureWithError(std::move(p.second), errorMsg); |
1009 | |
1010 | // Remove entry from timeoutMap_. |
1011 | removeFromTimeoutMap(p.first); |
1012 | } |
1013 | } |
1014 | |
void TensorPipeAgent::pollTimeoutRpcs() {
  // Body of the dedicated timeout thread: repeatedly sleeps until the
  // earliest deadline in timeoutMap_, then errors out every future whose
  // deadline has passed. Returns once rpcAgentRunning_ is cleared
  // (shutdownImpl notifies timeoutThreadCV_ so this check is reached).
  while (rpcAgentRunning_.load()) {
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);

    // We sleep until the earliest expiring RPC in the timeoutMap_. We must
    // also ensure that we sleep while the map is empty, and we exit sleeping
    // if the RPC Agent has been shutdown.
    for (;;) {
      if (!rpcAgentRunning_.load()) {
        return;
      }

      if (!timeoutMap_.empty()) {
        // timeoutMap_ is ordered by expiration time, so begin() is the
        // earliest deadline.
        steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
        if (std::chrono::steady_clock::now() >= earliestTimeout) {
          break;
        }
        // wait_until can wake spuriously or because send() inserted an
        // earlier deadline; the loop re-reads the map either way.
        timeoutThreadCV_.wait_until(lock, earliestTimeout);
      } else {
        timeoutThreadCV_.wait(lock);
      }
    }

    // Move all these futures to a separate vector so we can process them
    // outside the lock.
    std::vector<TimeoutMessageMetadata> timedOutFutures =
        std::move(timeoutMap_.begin()->second);

    // We can safely remove this key from the timeoutMap_ since all these
    // futures will be processed.
    timeoutMap_.erase(timeoutMap_.begin());

    for (auto& timeoutMetadata : timedOutFutures) {
      // Remove from messageIdToTimeout map.
      messageIdToTimeout_.erase(timeoutMetadata.messageId);
    }
    lock.unlock();

    // Set an error on futures added to the timedOutFutures vector. We do this
    // outside the lock to prevent potential lock-order-inversions by callbacks
    // triggered by the setError call.
    for (auto& timeoutMetadata : timedOutFutures) {
      std::string errorMsg =
          fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
      auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
      markFutureWithError(
          std::move(timeoutMetadata.responseFuture), std::move(err));
    }
  }
}
1065 | |
void TensorPipeAgent::leaveGroup() {
  // Dynamic-group teardown: wait for all of this worker's outstanding client
  // calls to drain, then deregister it from the store so peers stop routing
  // to it, and flag the agent as shutting down.
  std::unique_lock<std::mutex> lock(callCountMutex_);
  // local worker ActiveCallCount is 0 at this point and we will shutdown
  // (any future calls will be dropped)
  callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

  // Remove this agent's WorkerInfo from store
  removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);

  // Set internal variable to be used during destructor
  shuttingDown_ = true;
}
1078 | |
// TODO: Remove join()
void TensorPipeAgent::join(bool shutdown, float /* unused */) {
  // Blocks until all workers in the group have no pending client calls.
  // For dynamic groups this degenerates to leaveGroup(); for static groups it
  // runs rounds of "drain locally, then sync counts across workers" until a
  // round observes zero active client calls everywhere.
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
  if (!isStaticGroup_) {
    leaveGroup();
    return;
  }

  // This method behaves like a barrier, as it can only return once all workers
  // have no more requests pending, including "nested" requests (triggered from
  // within the remote code of another call) and "follow-up" requests (triggered
  // from the callback of a future).
  while (true) {
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // It is enough to wait for there to be no more active client calls, since
      // each server call corresponds to a client call for some other worker.
      callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

      // We'd like to immediately proceed with the allreduce, but it's a call
      // that may block for some time, as it waits for other workers to also
      // complete all their active client calls. While we call allreduce we must
      // hold the mutex, or else the count we send to other workers may get
      // stale (e.g., if some nested call happens in the meantime). But we can't
      // hold the lock for an indeterminately long time, as that would block
      // other operations (e.g., send). Thus we must release the lock and only
      // re-acquire it when all workers are ready to proceed with the allreduce.
      // We perform this synchronization using a barrier.
    }
    VLOG(1) << "RPC agent for " << workerInfo_.name_
            << " completed all client calls and is entering a barrier";
    // Barrier round: rendezvous with all other workers.
    syncCallCount(shutdownStore_, worldSize_);
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // At this point, the count may have become non-zero again. We can't wait
      // for those calls to complete as other workers are waiting for us in the
      // allreduce and we would block them. Thus we send our count even if it is
      // non-zero and if anyone (be it us or another worker) has a non-zero
      // count we'll just do another round.
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " exited the barrier and found " << clientActiveCalls_
              << " active client calls";
      int totalClientActiveCalls =
          syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " completed sync call counts and got a total of "
              << totalClientActiveCalls
              << " active client calls across all workers";
      if (totalClientActiveCalls == 0) {
        if (shutdown) {
          shuttingDown_ = true;
          // One final rendezvous so all workers agree shutdown is happening.
          syncCallCount(shutdownStore_, worldSize_);
        }
        break;
      }
    }
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
}
1138 | |
void TensorPipeAgent::shutdownImpl() {
  // Tears down the agent: stops the timeout thread, joins the TensorPipe
  // context (closing pipes/listeners and flushing callbacks with errors), and
  // finally drains the thread pool. The ordering of these steps matters; see
  // the NOTE at the bottom.
  // NOTE(review): this relies on rpcAgentRunning_ already being false when
  // called (presumably cleared by the caller) so the timeout thread exits —
  // confirm against the base-class shutdown path.
  // FIXME Isn't it too verbose for a library to print logs in normal operation?
  LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";

  // Join the Timeout Thread
  timeoutThreadCV_.notify_one();
  if (timeoutThread_.joinable()) {
    timeoutThread_.join();
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for timeout thread to join";

  // This will close all the pipes and listeners, invoke all callbacks with
  // errors, turn down the I/O event loops and wait for everything to terminate.
  context_->join();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for TensorPipe context to join";

  // NOTE: We need to call waitWorkComplete in the end after we have shutdown
  // all listeners for Tensorpipe. This is to drain any already accepted work
  // in the ThreadPool. If this is done before we shutdown the listeners,
  // additional work could be added after this call and before we shutdown
  // listeners. This work would continue executing in the threadpool and might
  // cause issues during shutdown of the system.
  threadPool_.waitWorkComplete();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for thread pool to complete work";
}
1167 | |
1168 | const WorkerInfo& TensorPipeAgent::getWorkerInfo( |
1169 | const std::string& workerName) const { |
1170 | std::unordered_map<std::string, WorkerInfo>::const_iterator it; |
1171 | { |
1172 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
1173 | it = workerNameToInfo_.find(workerName); |
1174 | } |
1175 | TORCH_CHECK( |
1176 | it != workerNameToInfo_.end(), |
1177 | fmt::format( |
1178 | "name:{},rank:{} could not find destination name {}" , |
1179 | workerInfo_.name_, |
1180 | workerInfo_.id_, |
1181 | workerName)); |
1182 | return it->second; |
1183 | } |
1184 | |
1185 | const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const { |
1186 | std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it; |
1187 | { |
1188 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
1189 | it = workerIdToInfo_.find(workerId); |
1190 | } |
1191 | TORCH_CHECK( |
1192 | it != workerIdToInfo_.end(), |
1193 | fmt::format( |
1194 | "name:{},rank:{} could not find destination id {}" , |
1195 | workerInfo_.name_, |
1196 | workerInfo_.id_, |
1197 | workerId)); |
1198 | return it->second; |
1199 | } |
1200 | |
1201 | std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const { |
1202 | std::vector<WorkerInfo> workerInfos; |
1203 | workerInfos.reserve(workerNameToInfo_.size()); |
1204 | for (auto& item : workerNameToInfo_) { |
1205 | workerInfos.emplace_back(item.second); |
1206 | } |
1207 | return workerInfos; |
1208 | } |
1209 | |
1210 | const std::string& TensorPipeAgent::findWorkerURL( |
1211 | const WorkerInfo& worker) const { |
1212 | std::unordered_map<std::string, std::string>::const_iterator it; |
1213 | { |
1214 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
1215 | it = workerNameToURL_.find(worker.name_); |
1216 | } |
1217 | TORCH_CHECK( |
1218 | it != workerNameToURL_.end(), |
1219 | fmt::format( |
1220 | "name:{},rank:{} could not find destination url for name {}" , |
1221 | workerInfo_.name_, |
1222 | workerInfo_.id_, |
1223 | worker.name_)); |
1224 | return it->second; |
1225 | } |
1226 | |
void TensorPipeAgent::updateGroupMembership(
    const WorkerInfo& workerInfo,
    const std::vector<c10::Device> devices,
    const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
    bool isJoin) {
  // Updates this agent's view of a dynamic group when a worker joins
  // (isJoin == true) or leaves (isJoin == false): worker id/name/url maps,
  // the reverse device maps, and the set of devices in use.
  // NOTE(review): only the join branch holds GroupMembershipLockGuard; the
  // leave branch mutates the same maps without it — confirm callers
  // serialize leave updates.
  std::string name = workerInfo.name_;
  worker_id_t id = workerInfo.id_;
  // Rank with workerInfo is joining the group, update internal mappings
  if (isJoin) {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    workerIdToInfo_.emplace(id, workerInfo);
    workerNameToInfo_.emplace(name, workerInfo);

    // TODO: we should get nodeAddrStr in the joining process, then pass in as
    // an argument rather than getting from store each time
    auto nodeAddrData = nameToAddressStore_.get(name);
    auto nodeAddrStr =
        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
    workerNameToURL_.insert({name, nodeAddrStr});

    // Merge in the joiner's reverse device maps, keeping existing entries.
    for (const auto& it : reverseDeviceMaps) {
      if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
        reverseDeviceMaps_[it.first] = it.second;
      }
    }
    // TODO: clean up mutex for devices_ usage
    // Add devices that have not been added yet
    for (const auto& it : devices) {
      if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
        devices_.push_back(it);
      }
    }
  } else {
    // Worker is leaving: forget everything we know about it.
    workerIdToInfo_.erase(id);
    workerNameToInfo_.erase(name);
    workerNameToURL_.erase(name);

    // remove reverse device maps that are no longer used
    for (auto it = reverseDeviceMaps_.begin();
         it != reverseDeviceMaps_.end();) {
      if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
        it = reverseDeviceMaps_.erase(it);
      } else {
        it++;
      }
    }

    // remove devices that are no longer used
    for (auto it = devices_.begin(); it != devices_.end();) {
      if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
        it = devices_.erase(it);
      } else {
        it++;
      }
    }
  }
}
1284 | std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() { |
1285 | std::unordered_map<std::string, std::string> metrics; |
1286 | metrics[kThreadPoolSize] = c10::to_string(threadPool_.size()); |
1287 | metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable()); |
1288 | { |
1289 | std::unique_lock<std::mutex> lock(callCountMutex_); |
1290 | metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_); |
1291 | metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_); |
1292 | metrics[kServerActiveAsyncCalls] = c10::to_string(serverActiveAsyncCalls_); |
1293 | } |
1294 | if (isGILProfilingEnabled()) { |
1295 | { |
1296 | std::unique_lock<std::mutex> lock(metricsMutex_); |
1297 | // Include the averages for each time series metric. This is just the GIL |
1298 | // Wait Time for now. |
1299 | auto averageGilWaitTime = |
1300 | timeSeriesMetrics_[kGilAverageWaitTime].computeAverage(); |
1301 | lock.unlock(); |
1302 | metrics[kGilAverageWaitTime] = c10::to_string(averageGilWaitTime); |
1303 | } |
1304 | } |
1305 | |
1306 | return metrics; |
1307 | } |
1308 | |
1309 | void TensorPipeAgent::addGilWaitTime( |
1310 | const std::chrono::microseconds gilWaitTime) { |
1311 | std::lock_guard<std::mutex> lock(metricsMutex_); |
1312 | timeSeriesMetrics_[kGilAverageWaitTime].addData(gilWaitTime.count()); |
1313 | } |
1314 | |
1315 | TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData() { |
1316 | std::lock_guard<std::mutex> lock(networkDataMutex_); |
1317 | return networkData_; |
1318 | } |
1319 | |
1320 | NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo() { |
1321 | NetworkSourceInfo info = { |
1322 | RpcAgent::getWorkerInfo().id_, |
1323 | nameToAddressStore_.get(RpcAgent::getWorkerInfo().name_)}; |
1324 | |
1325 | return info; |
1326 | } |
1327 | |
1328 | void TensorPipeAgent::trackNetworkData( |
1329 | uint64_t requestSize, |
1330 | uint64_t responseSize, |
1331 | const std::string& destWorkerName) { |
1332 | std::lock_guard<std::mutex> lock(networkDataMutex_); |
1333 | networkData_[destWorkerName].numCalls++; |
1334 | networkData_[destWorkerName].totalSentBytes += requestSize; |
1335 | networkData_[destWorkerName].totalRecvBytes += responseSize; |
1336 | } |
1337 | |
1338 | void TensorPipeAgent::trackNetworkError( |
1339 | uint64_t requestSize, |
1340 | const std::string& destWorkerName) { |
1341 | std::lock_guard<std::mutex> lock(networkDataMutex_); |
1342 | networkData_[destWorkerName].numCalls++; |
1343 | networkData_[destWorkerName].totalSentBytes += requestSize; |
1344 | networkData_[destWorkerName].totalErrors++; |
1345 | } |
1346 | |
1347 | void TensorPipeAgent::increaseCallCount(int32_t& count) { |
1348 | { |
1349 | std::unique_lock<std::mutex> lock(callCountMutex_); |
1350 | ++count; |
1351 | } |
1352 | callCountCV_.notify_all(); |
1353 | } |
1354 | |
1355 | void TensorPipeAgent::decreaseCallCount(int32_t& count) { |
1356 | { |
1357 | std::unique_lock<std::mutex> lock(callCountMutex_); |
1358 | --count; |
1359 | } |
1360 | callCountCV_.notify_all(); |
1361 | } |
1362 | |
1363 | void TensorPipeAgent::markFutureAsComplete( |
1364 | std::shared_ptr<AtomicJitFuture> atomicFuture, |
1365 | c10::intrusive_ptr<Message> message, |
1366 | std::vector<c10::Stream> streams) { |
1367 | if (!atomicFuture->isComplete.test_and_set()) { |
1368 | // Completing the future will run its callbacks, which could execute |
1369 | // arbitrary user code. To prevent blocking or stalling the TensorPipe event |
1370 | // loops, we defer this to a worker thread. |
1371 | threadPool_.run([this, |
1372 | atomicFuture{std::move(atomicFuture)}, |
1373 | message{std::move(message)}, |
1374 | streams{std::move(streams)}]() mutable { |
1375 | c10::MultiStreamGuard guard(streams); |
1376 | std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages = |
1377 | message->getStorages(); |
1378 | atomicFuture->jitFuture->markCompleted( |
1379 | std::move(message), std::move(storages)); |
1380 | // The future's callbacks may schedule further RPCs, increasing the count. |
1381 | // Thus we must decrease it after completing the future, otherwise it may |
1382 | // briefly dip to zero and trick join into thinking all work is done. |
1383 | decreaseCallCount(clientActiveCalls_); |
1384 | }); |
1385 | } |
1386 | } |
1387 | |
1388 | void TensorPipeAgent::markFutureWithError( |
1389 | std::shared_ptr<AtomicJitFuture> atomicFuture, |
1390 | std::string errorMsg) { |
1391 | if (!atomicFuture->isComplete.test_and_set()) { |
1392 | // Completing the future will run its callbacks, which could execute |
1393 | // arbitrary user code. To prevent blocking or stalling the TensorPipe event |
1394 | // loops, we defer this to a worker thread. |
1395 | threadPool_.run([this, |
1396 | atomicFuture{std::move(atomicFuture)}, |
1397 | errorMsg{std::move(errorMsg)}]() mutable { |
1398 | atomicFuture->jitFuture->setError( |
1399 | std::make_exception_ptr(std::runtime_error(errorMsg))); |
1400 | // The future's callbacks may schedule further RPCs, increasing the count. |
1401 | // Thus we must decrease it after completing the future, otherwise it may |
1402 | // briefly dip to zero and trick join into thinking all work is done. |
1403 | decreaseCallCount(clientActiveCalls_); |
1404 | }); |
1405 | } |
1406 | } |
1407 | |
1408 | std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote( |
1409 | const std::string& remoteName, |
1410 | const Message& message) const { |
1411 | std::unordered_map<std::string, DeviceMap> deviceMaps; |
1412 | { |
1413 | GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_); |
1414 | deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_; |
1415 | } |
1416 | |
1417 | const auto errStr = c10::str( |
1418 | "TensorPipe RPC backend only supports CPU tensors by default, please " |
1419 | "move your tensors to CPU before sending them over RPC, or call " |
1420 | "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly " |
1421 | "configure device mapping. " , |
1422 | message.isRequest() ? "Request" : "Response" , |
1423 | " device mapping is not available for destination " , |
1424 | remoteName); |
1425 | |
1426 | const auto& iter = deviceMaps.find(remoteName); |
1427 | if (iter == deviceMaps.end()) { |
1428 | for (const auto& t : message.tensors()) { |
1429 | TORCH_CHECK( |
1430 | t.device().is_cpu(), |
1431 | errStr, |
1432 | ", but found tensor on device: " , |
1433 | t.device()); |
1434 | } |
1435 | return {}; |
1436 | } else { |
1437 | return getDevicesForTensors(message.tensors(), iter->second, errStr); |
1438 | } |
1439 | } |
1440 | |
1441 | DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const { |
1442 | auto it = opts_.deviceMaps.find(dst.name_); |
1443 | if (it == opts_.deviceMaps.end()) { |
1444 | return {}; |
1445 | } |
1446 | return it->second; |
1447 | } |
1448 | |
// Accessor for the c10d store the agent was constructed with.
const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const {
  return store_;
}
1452 | |
// Returns a copy of the backend options the agent was constructed with.
TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
  return opts_;
}
1456 | |
// Accessor for the devices currently in use by the agent.
// NOTE(review): the returned reference outlives the guard, so for dynamic
// groups it may be read while updateGroupMembership mutates devices_ —
// confirm callers tolerate this.
const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  return devices_;
}
1461 | |
1462 | size_t TensorPipeAgent::timeoutMapSize() { |
1463 | std::unique_lock<std::mutex> lock(timeoutMapMutex_); |
1464 | return timeoutMap_.size(); |
1465 | } |
1466 | |
1467 | size_t TensorPipeAgent::numPendingResponses() { |
1468 | std::unique_lock<std::mutex> lock(callCountMutex_); |
1469 | return clientActiveCalls_; |
1470 | } |
1471 | |
1472 | size_t TensorPipeAgent::messageIdToTimeoutMapSize() { |
1473 | std::unique_lock<std::mutex> lock(timeoutMapMutex_); |
1474 | return messageIdToTimeout_.size(); |
1475 | } |
1476 | |
1477 | } // namespace rpc |
1478 | } // namespace distributed |
1479 | } // namespace torch |
1480 | |
1481 | #endif // USE_TENSORPIPE |
1482 | |