#pragma once

#ifdef USE_TENSORPIPE

#include <atomic>
#include <thread>

#include <c10/core/thread_pool.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

// Forward-declare the TensorPipe classes we need, so that we don't pull
// TensorPipe's headers into PyTorch's own and thereby make it a public
// dependency.

namespace tensorpipe {

class Context;
class Error;
class Listener;
class Message;
class Pipe;

namespace transport {
class Context;
} // namespace transport

namespace channel {
class Context;
} // namespace channel

} // namespace tensorpipe

namespace torch {
namespace distributed {
namespace rpc {

// These priorities instruct TensorPipe on which transport/channel to pick
// during handshake. Higher priorities take precedence over lower ones. The
// transport with the lowest priority is the one used to bootstrap pipes.

constexpr int64_t kShmTransportPriority = 200;
constexpr int64_t kIbvTransportPriority = 100;
// The UV transport just uses TCP and should work everywhere, thus keep it last.
constexpr int64_t kUvTransportPriority = 0;

constexpr int64_t kCmaChannelPriority = 1200;
constexpr int64_t kMultiplexedUvChannelPriority = 1100;
// The basic channel reuses a transport as a channel, and is thus our fallback.
constexpr int64_t kBasicChannelPriority = 1000;

// CPU channels have higher priority than CUDA channels, since the latter can
// also handle CPU-to-CPU transfers, but will always be less efficient than
// their CPU-only counterparts.
constexpr int64_t kCudaIpcChannelPriority = 300;
constexpr int64_t kCudaGdrChannelPriority = 200;
constexpr int64_t kCudaXthChannelPriority = 400;
constexpr int64_t kCudaBasicChannelPriority = 0;

using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;

struct TORCH_API TransportRegistration {
  std::shared_ptr<tensorpipe::transport::Context> transport;
  int64_t priority;
  std::string address;
};

C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);

struct TORCH_API ChannelRegistration {
  std::shared_ptr<tensorpipe::channel::Context> channel;
  int64_t priority;
};

C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration);
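
// Transports and channels become available by registering a creator function
// in the corresponding registry. A rough sketch of how the accompanying .cpp
// file registers the built-in UV (TCP) transport, with context/address
// details elided:
//
//   std::unique_ptr<TransportRegistration> makeUvTransport() {
//     auto context = tensorpipe::transport::uv::create();
//     return std::make_unique<TransportRegistration>(TransportRegistration{
//         std::move(context),
//         kUvTransportPriority,
//         TensorPipeAgent::guessAddress()});
//   }
//   C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);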

constexpr auto kDefaultNumWorkerThreads = 16;

struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions {
  TensorPipeRpcBackendOptions(
      int numWorkerThreads,
      optional<std::vector<std::string>> transports,
      optional<std::vector<std::string>> channels,
      float rpc_timeout,
      std::string init_method,
      std::unordered_map<std::string, DeviceMap> device_maps = {},
      std::vector<c10::Device> devices = {})
      : RpcBackendOptions(rpc_timeout, init_method),
        numWorkerThreads(numWorkerThreads),
        transports(std::move(transports)),
        channels(std::move(channels)),
        deviceMaps(std::move(device_maps)),
        devices(std::move(devices)) {
    TORCH_CHECK(
        numWorkerThreads > 0,
        "num_worker_threads must be positive, got ",
        numWorkerThreads);

    if (this->transports.has_value()) {
      for (const std::string& transportName : this->transports.value()) {
        TORCH_CHECK(
            TensorPipeTransportRegistry()->Has(transportName),
            "Unknown transport: ",
            transportName);
      }
    }

    if (this->channels.has_value()) {
      for (const std::string& channelName : this->channels.value()) {
        TORCH_CHECK(
            TensorPipeChannelRegistry()->Has(channelName),
            "Unknown channel: ",
            channelName);
      }
    }
  }

  void setDeviceMap(const std::string& workerName, const DeviceMap& deviceMap) {
    auto iter = deviceMaps.find(workerName);
    if (iter == deviceMaps.end()) {
      deviceMaps[workerName] = deviceMap;
    } else {
      for (auto& entry : deviceMap) {
        // c10::Device has no default constructor, hence map[device] doesn't
        // work. In C++17 we can use insert_or_assign (see the note after this
        // function).
        auto entryIter = iter->second.find(entry.first);
        if (entryIter == iter->second.end()) {
          iter->second.emplace(entry.first, entry.second);
        } else {
          entryIter->second = entry.second;
        }
      }
    }
  }
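
  // Note: with C++17 the find/emplace dance above collapses into a single
  // call; a sketch of the equivalent loop body:
  //
  //   iter->second.insert_or_assign(entry.first, entry.second);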

  int numWorkerThreads;
  const optional<std::vector<std::string>> transports;
  const optional<std::vector<std::string>> channels;
  std::unordered_map<std::string, DeviceMap> deviceMaps;
  std::vector<c10::Device> devices;
};
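
// An illustrative construction (the values here are hypothetical, not
// defaults taken from PyTorch):
//
//   TensorPipeRpcBackendOptions opts(
//       /*numWorkerThreads=*/kDefaultNumWorkerThreads,
//       /*transports=*/std::vector<std::string>{"uv"},
//       /*channels=*/nullopt,
//       /*rpc_timeout=*/60.0f,
//       /*init_method=*/"tcp://127.0.0.1:29500");
//   opts.setDeviceMap(
//       "worker1", {{c10::Device("cuda:0"), c10::Device("cuda:1")}});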

// Struct to track the network source metrics
struct TORCH_API NetworkSourceInfo {
  worker_id_t srcRank;
  std::vector<uint8_t> srcMachineAddr;
};

// Struct to track aggregated network metrics
struct TORCH_API AggregatedNetworkData {
  uint64_t numCalls{0};
  uint64_t totalSentBytes{0};
  uint64_t totalRecvBytes{0};
  uint64_t totalErrors{0};
};

// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
// memory (Linux) and TCP (Linux & macOS) support. CUDA support is in progress.
class TORCH_API TensorPipeAgent : public RpcAgent {
 public:
  TensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      optional<int> worldSize,
      TensorPipeRpcBackendOptions opts,
      std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
      std::vector<c10::Device> devices,
      std::unique_ptr<RequestCallback> cb);

  TensorPipeAgent(const TensorPipeAgent&) = delete;
  TensorPipeAgent& operator=(const TensorPipeAgent&) = delete;

  c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
      const DeviceMap& deviceMap = {}) override;

  // join() and sync() are slated for deprecation -
  // https://github.com/pytorch/pytorch/issues/27647
  void join(bool shutdown = false, float timeout = 0) override;
  void sync() override {}
  void startImpl() override;
  void shutdownImpl() override;

  ~TensorPipeAgent() override;

  const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
  const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override;
  std::vector<WorkerInfo> getWorkerInfos() const override;
  void updateGroupMembership(
      const WorkerInfo& workerInfo,
      const std::vector<c10::Device> devices,
      const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
      bool isJoin);

  std::unordered_map<std::string, std::string> getMetrics() override;

  void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;

  TensorPipeRpcBackendOptions getBackendOptions() const;

  const c10::intrusive_ptr<::c10d::Store> getStore() const;

  DeviceMap getDeviceMap(const WorkerInfo& dest) const override;

  const std::vector<c10::Device>& getDevices() const override;

  using NetworkDataDict =
      std::unordered_map<std::string, AggregatedNetworkData>;

  // Returns metrics tracked by the NetworkDataDict
  NetworkDataDict getNetworkData();
  // Returns the NetworkSourceInfo struct
  NetworkSourceInfo getNetworkSourceInfo();

  static const std::string& guessAddress();

  // For testing purposes.
  size_t timeoutMapSize();
  size_t numPendingResponses();
  size_t messageIdToTimeoutMapSize();

  const bool isStaticGroup_;

 protected:
  // TensorPipe write function, used by the server to write response messages
  // and by the client to write request messages. This method is protected
  // because it is overridden by FaultyTensorPipeAgent.
  virtual void pipeWrite(
      const std::shared_ptr<tensorpipe::Pipe>&,
      c10::intrusive_ptr<Message> message,
      std::vector<c10::Device>&& devices,
      std::vector<c10::Stream> streams,
      std::function<void(const tensorpipe::Error&)>) noexcept;

 private:
  // Removes the entry for the given messageId (keyed by its expiration time)
  // from the timeoutMap_.
  void removeFromTimeoutMap(uint64_t messageId);

  // Populates workerIdToInfo_ and workerNameToInfo_ using addressStore_
  void prepareNames(bool isStaticGroup);

  // Check the static group attribute against the value set in the store
  void checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store>& store);

  const std::string& findWorkerURL(const WorkerInfo& worker) const;

  // Only used for dynamic RPC groups: asks this worker to leave the group.
  void leaveGroup();

  // TensorPipe read function, used by the client to read response messages
  // and by the server to read request messages.
  void pipeRead(
      const std::shared_ptr<tensorpipe::Pipe>&,
      std::function<void(
          const tensorpipe::Error&,
          c10::intrusive_ptr<Message>,
          std::vector<c10::Stream>)>) noexcept;

  // Callback of listener accept()
  void onListenerAccepted(
      const tensorpipe::Error& error,
      std::shared_ptr<tensorpipe::Pipe>& pipe);

  // Respond to a call from a peer
  void respond(std::shared_ptr<tensorpipe::Pipe>& pipe);

  void sendCompletedResponseMessage(
      std::shared_ptr<tensorpipe::Pipe>& pipe,
      JitFuture& futureResponseMessage,
      uint64_t messageId,
      std::vector<c10::Stream> streams);

  // Collects metrics from successful RPC calls
  void trackNetworkData(
      uint64_t requestSize,
      uint64_t responseSize,
      const std::string& destWorkerName);

  // Collects metrics from failed RPC calls
  void trackNetworkError(
      uint64_t requestSize,
      const std::string& destWorkerName);

  inline std::vector<c10::Device> getDevicesForRemote(
      const std::string& remoteName,
      const Message& message) const;

  // When a request+response completes, we need to mark the future message as
  // complete. However, if its timeout has already expired, it already has an
  // error set. There is no atomic "test-and-set" way to mark a future complete
  // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but, even
  // then, it ends up printing a log message, which may worry the user. To solve
  // both issues we use a separate atomic flag to know the status of the future.
  struct AtomicJitFuture {
    explicit AtomicJitFuture(const std::vector<c10::Device>& devices) {
      jitFuture = c10::make_intrusive<at::ivalue::Future>(
          at::AnyClassType::get(), devices);
    }

    std::atomic_flag isComplete = ATOMIC_FLAG_INIT;
    c10::intrusive_ptr<JitFuture> jitFuture;
  };
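
  // The flag provides the missing test-and-set: whichever path (completion or
  // timeout) flips it first gets to settle the future. A sketch of the
  // intended call-site pattern (the real call sites live in the .cpp file):
  //
  //   if (!atomicFuture->isComplete.test_and_set()) {
  //     atomicFuture->jitFuture->markCompleted(std::move(value));
  //   } // else: the timeout path already set an error, so drop the result.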

  // Maintains state per client pipe to track pending response messages and
  // error states. pendingResponseMessage_ should be protected by a mutex,
  // since it can race with a user's send() call.
  // TODO: To achieve better performance we can have a pipe pool per
  // client that can be configured using RpcBackendOptions.
  struct ClientPipe {
    // NOLINTNEXTLINE(modernize-pass-by-value)
    explicit ClientPipe(std::shared_ptr<tensorpipe::Pipe> pipe) : pipe_(pipe) {}
    std::shared_ptr<tensorpipe::Pipe> pipe_;
    mutable std::mutex mutex_;
    bool inError_{false};
    // Map from Message Request IDs to the corresponding futures.
    std::unordered_map<uint64_t, std::shared_ptr<AtomicJitFuture>>
        pendingResponseMessage_;
  };

  const c10::intrusive_ptr<::c10d::Store> store_;

  const TensorPipeRpcBackendOptions opts_;
  // For dynamic RPC, the reverse device maps are updated whenever a new rank
  // joins or leaves the group
  std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
  // Local devices used by this agent. If the application didn't specify this
  // field, it will be initialized using the corresponding local devices in
  // opts_.deviceMaps and reverseDeviceMaps_.
  std::vector<c10::Device> devices_;

  ThreadPool threadPool_;
  std::shared_ptr<tensorpipe::Context> context_;
  std::shared_ptr<tensorpipe::Listener> listener_;

  mutable std::mutex connectedPipesMutex_;
  std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;

  // Maps keyed on name and id for easy WorkerInfo lookup.
  std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
  std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
  std::unordered_map<std::string, std::string> workerNameToURL_;

  ::c10d::PrefixStore rankToNameStore_;
  ::c10d::PrefixStore nameToAddressStore_;
  // Store keys that will be used to count joined processes and active calls
  // during the shutdown process
  ::c10d::PrefixStore shutdownStore_;
  int worldSize_ = 0;
  std::atomic<uint64_t> nextMessageID_{0};

  // Metadata used for tracking whether certain RPCs have timed out or not.
  struct TimeoutMessageMetadata {
    TimeoutMessageMetadata(
        uint64_t messageId_,
        // NOLINTNEXTLINE(modernize-pass-by-value)
        std::shared_ptr<AtomicJitFuture> responseFuture_,
        std::chrono::milliseconds timeout_)
        : messageId(messageId_),
          responseFuture(responseFuture_),
          timeout(timeout_) {}
    uint64_t messageId;
    std::shared_ptr<AtomicJitFuture> responseFuture;
    std::chrono::milliseconds timeout;
  };

  // Map from expiration time to the metadata of every message that expires at
  // that time.
  std::map<steady_clock_time_point, std::vector<TimeoutMessageMetadata>>
      timeoutMap_;

  // Map from messageId to expiry time.
  std::unordered_map<uint64_t, steady_clock_time_point> messageIdToTimeout_;
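
  // The two maps cooperate: messageIdToTimeout_ answers "when does this
  // message expire?", and that expiry is the key into timeoutMap_, which
  // batches every message sharing that expiration. Roughly (a sketch that
  // ignores locking):
  //
  //   auto expiry = messageIdToTimeout_.at(messageId);
  //   auto& expiring = timeoutMap_[expiry]; // all RPCs timing out at expiry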

  // Thread that will poll the timeoutMap_ for timed out messages and mark them
  // with an error accordingly
  std::thread timeoutThread_;

  // Function run by the timeoutThread_ to check for timed out RPCs
  void pollTimeoutRpcs();

  // Mutex to guard the timeoutMap_
  std::mutex timeoutMapMutex_;

  // Condition Variable to signal population of the timeoutMap_
  std::condition_variable timeoutThreadCV_;

  // Returns the expiration time for an RPC by adding the current time to the
  // passed in timeout.
  inline steady_clock_time_point computeRpcMessageExpiryTime(
      std::chrono::milliseconds timeout) const {
    return std::chrono::time_point_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() + timeout);
  }

  // Handle error on an outgoing pipe
  void handleClientError(
      ClientPipe& clientPipe,
      const tensorpipe::Error& error);

  // This is a generic struct for capturing Time-Series Metrics. It keeps a
  // running sum and count of data points (observations), and can return an
  // average of the data points seen so far. This is currently only used for
  // tracking the GIL Wait Time in RPC Agents, but can be used for other metrics
  // as well.
  struct TimeSeriesMetricsTracker {
    // Running sum of the data points seen so far
    uint64_t currentSum_;
    // Running count of the data points seen so far
    uint64_t currentCount_;

    explicit TimeSeriesMetricsTracker(
        uint64_t currentSum = 0,
        uint64_t currentCount = 0);

    // Adds a data point (which is basically one observation for the metric
    // being tracked) to the running sum and count.
    void addData(uint64_t dataPoint);
    // Returns the average of all the data points seen so far.
    float computeAverage() const;
  };
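
  // A minimal sketch of the expected semantics (the actual definitions live
  // in the .cpp file; the zero-count guard is an assumption):
  //
  //   void addData(uint64_t d) { currentSum_ += d; ++currentCount_; }
  //   float computeAverage() const {
  //     return currentCount_ == 0 ? 0.0f : currentSum_ / float(currentCount_);
  //   }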

  // Map of Time-Series metrics tracked by the RPC Agent
  std::unordered_map<std::string, TimeSeriesMetricsTracker> timeSeriesMetrics_;
  // Mutex to guard timeSeriesMetrics_
  std::mutex metricsMutex_;

  // Custom lock guard that locks the given mutex only when the RPC group is
  // dynamic; static groups never update membership state, so they can skip
  // locking.
  struct GroupMembershipLockGuard {
    GroupMembershipLockGuard(std::mutex& mutex, bool isStaticGroup)
        : ref_(mutex), isStaticGroup_(isStaticGroup) {
      if (!isStaticGroup_) {
        ref_.lock();
      }
    }

    ~GroupMembershipLockGuard() {
      if (!isStaticGroup_) {
        ref_.unlock();
      }
    }

    GroupMembershipLockGuard(const GroupMembershipLockGuard&) = delete;

   private:
    std::mutex& ref_;
    bool isStaticGroup_;
  };
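
  // Intended usage (a sketch): construct the guard on the stack before
  // touching membership state, so dynamic groups serialize access while
  // static groups pay nothing:
  //
  //   GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  //   // ... read or update workerIdToInfo_, workerNameToInfo_, ...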
  // Mutex to guard access to group membership data,
  // e.g. updates to (workerIdToInfo_, workerNameToInfo_, workerNameToURL_)
  mutable std::mutex groupMembershipMutex_;

  // Map to track network data
  NetworkDataDict networkData_;
  // Mutex to guard networkData_
  std::mutex networkDataMutex_;

  // A mutex and a cv to guard access to the call counts and watch for changes.
  std::mutex callCountMutex_;
  std::condition_variable callCountCV_;
  // Running total of un-processed, un-errored RPC calls sent
  int32_t clientActiveCalls_{0};
  // Running total of un-processed RPC requests received
  int32_t serverActiveCalls_{0};
  // Running total of RPC requests that will be completed asynchronously
  int32_t serverActiveAsyncCalls_{0};

  // Whether a global graceful shutdown has begun, in which case we'll silence
  // error messages due to remote workers closing their pipes.
  std::atomic<bool> shuttingDown_{false};

  // Helpers to modify the counts while correctly dealing with the mutex and cv.
  void increaseCallCount(int32_t& count);
  void decreaseCallCount(int32_t& count);

  // Helpers to set the state of the requests.
  void markFutureAsComplete(
      std::shared_ptr<AtomicJitFuture> atomicFuture,
      c10::intrusive_ptr<Message> message,
      std::vector<c10::Stream> streams);
  void markFutureWithError(
      std::shared_ptr<AtomicJitFuture> atomicFuture,
      std::string errorMsg);
};

} // namespace rpc
} // namespace distributed
} // namespace torch

#endif // USE_TENSORPIPE