1 | #pragma once |
2 | |
3 | #ifdef USE_TENSORPIPE |
4 | |
5 | #include <atomic> |
6 | #include <thread> |
7 | |
8 | #include <c10/core/thread_pool.h> |
9 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
10 | #include <torch/csrc/distributed/c10d/Store.hpp> |
11 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
12 | |
// TensorPipe types are only forward-declared here. Including TensorPipe's own
// headers from PyTorch's headers would make TensorPipe a public dependency.

namespace tensorpipe {

// Contexts living in TensorPipe's transport and channel sub-namespaces.
namespace transport { class Context; } // namespace transport
namespace channel { class Context; } // namespace channel

// Top-level TensorPipe classes.
class Pipe;
class Message;
class Listener;
class Error;
class Context;

} // namespace tensorpipe
33 | |
34 | namespace torch { |
35 | namespace distributed { |
36 | namespace rpc { |
37 | |
// TensorPipe uses these priorities to choose a transport/channel during the
// handshake: a higher value wins over a lower one. Pipes are bootstrapped with
// the transport that has the lowest priority.

// --- Transports ---
constexpr int64_t kShmTransportPriority{200};
constexpr int64_t kIbvTransportPriority{100};
// UV is plain TCP and should work everywhere, so it is ranked last.
constexpr int64_t kUvTransportPriority{0};

// --- CPU channels ---
constexpr int64_t kCmaChannelPriority{1200};
constexpr int64_t kMultiplexedUvChannelPriority{1100};
// The basic channel reuses a transport as a channel, making it the fallback.
constexpr int64_t kBasicChannelPriority{1000};

// --- CUDA channels ---
// CPU channels outrank CUDA channels. A CUDA channel may also be able to handle
// CPU-to-CPU transfers, but it will always be less efficient at it than the
// CPU-only channels.
constexpr int64_t kCudaXthChannelPriority{400};
constexpr int64_t kCudaIpcChannelPriority{300};
constexpr int64_t kCudaGdrChannelPriority{200};
constexpr int64_t kCudaBasicChannelPriority{0};
59 | |
// Point in time on the monotonic clock; used for RPC expiration bookkeeping.
using steady_clock_time_point = std::chrono::steady_clock::time_point;
62 | |
63 | struct TORCH_API TransportRegistration { |
64 | std::shared_ptr<tensorpipe::transport::Context> transport; |
65 | int64_t priority; |
66 | std::string address; |
67 | }; |
68 | |
69 | C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration); |
70 | |
71 | struct TORCH_API ChannelRegistration { |
72 | std::shared_ptr<tensorpipe::channel::Context> channel; |
73 | int64_t priority; |
74 | }; |
75 | |
76 | C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration); |
77 | |
// Default size of the agent's worker thread pool.
constexpr int kDefaultNumWorkerThreads{16};
79 | |
80 | struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions { |
81 | TensorPipeRpcBackendOptions( |
82 | int numWorkerThreads, |
83 | optional<std::vector<std::string>> transports, |
84 | optional<std::vector<std::string>> channels, |
85 | float rpc_timeout, |
86 | std::string init_method, |
87 | std::unordered_map<std::string, DeviceMap> device_maps = {}, |
88 | std::vector<c10::Device> devices = {}) |
89 | : RpcBackendOptions(rpc_timeout, init_method), |
90 | numWorkerThreads(numWorkerThreads), |
91 | transports(std::move(transports)), |
92 | channels(std::move(channels)), |
93 | deviceMaps(std::move(device_maps)), |
94 | devices(std::move(devices)) { |
95 | TORCH_CHECK( |
96 | numWorkerThreads > 0, |
97 | "num_worker_threads must be positive, got " , |
98 | numWorkerThreads); |
99 | |
100 | if (this->transports.has_value()) { |
101 | for (const std::string& transportName : this->transports.value()) { |
102 | TORCH_CHECK( |
103 | TensorPipeTransportRegistry()->Has(transportName), |
104 | "Unknown transport: " , |
105 | transportName); |
106 | } |
107 | } |
108 | |
109 | if (this->channels.has_value()) { |
110 | for (const std::string& channelName : this->channels.value()) { |
111 | TORCH_CHECK( |
112 | TensorPipeChannelRegistry()->Has(channelName), |
113 | "Unknown channel: " , |
114 | channelName); |
115 | } |
116 | } |
117 | } |
118 | |
119 | void setDeviceMap(const std::string& workerName, const DeviceMap& deviceMap) { |
120 | auto iter = deviceMaps.find(workerName); |
121 | if (iter == deviceMaps.end()) { |
122 | deviceMaps[workerName] = deviceMap; |
123 | } else { |
124 | for (auto& entry : deviceMap) { |
125 | // c10::Device has no default constructor, hence map[device] dosn't work |
126 | // In C++-17 we can use insert_or_assign. |
127 | auto entryIter = iter->second.find(entry.first); |
128 | if (entryIter == iter->second.end()) { |
129 | iter->second.emplace(entry.first, entry.second); |
130 | } else { |
131 | entryIter->second = entry.second; |
132 | } |
133 | } |
134 | } |
135 | } |
136 | |
137 | int numWorkerThreads; |
138 | const optional<std::vector<std::string>> transports; |
139 | const optional<std::vector<std::string>> channels; |
140 | std::unordered_map<std::string, DeviceMap> deviceMaps; |
141 | std::vector<c10::Device> devices; |
142 | }; |
143 | |
144 | // Struct to track the network source metrics |
145 | struct TORCH_API NetworkSourceInfo { |
146 | worker_id_t srcRank; |
147 | std::vector<uint8_t> srcMachineAddr; |
148 | }; |
149 | |
150 | // Struct to track aggregated network metrics |
151 | struct TORCH_API AggregatedNetworkData { |
152 | uint64_t numCalls{0}; |
153 | uint64_t totalSentBytes{0}; |
154 | uint64_t totalRecvBytes{0}; |
155 | uint64_t totalErrors{0}; |
156 | }; |
157 | |
158 | // TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe) |
159 | // to transparently move tensors and payloads through the fastest available |
160 | // transport or channel. It acts like a hybrid RPC transport, providing shared |
161 | // memory (linux) and TCP (linux & mac) support. CUDA support is in progress. |
162 | class TORCH_API TensorPipeAgent : public RpcAgent { |
163 | public: |
164 | TensorPipeAgent( |
165 | const c10::intrusive_ptr<::c10d::Store>& store, |
166 | std::string selfName, |
167 | worker_id_t selfId, |
168 | optional<int> worldSize, |
169 | TensorPipeRpcBackendOptions opts, |
170 | std::unordered_map<std::string, DeviceMap> reverseDeviceMaps, |
171 | std::vector<c10::Device> devices, |
172 | std::unique_ptr<RequestCallback> cb); |
173 | |
174 | TensorPipeAgent(const TensorPipeAgent&) = delete; |
175 | TensorPipeAgent& operator=(const TensorPipeAgent&) = delete; |
176 | |
177 | c10::intrusive_ptr<JitFuture> send( |
178 | const WorkerInfo& to, |
179 | c10::intrusive_ptr<Message> message, |
180 | const float rpcTimeoutSeconds = kUnsetRpcTimeout, |
181 | const DeviceMap& deviceMap = {}) override; |
182 | |
183 | // join() and sync() would be deprecated - |
184 | // https://github.com/pytorch/pytorch/issues/27647 |
185 | void join(bool shutdown = false, float timeout = 0) override; |
186 | void sync() override{}; |
187 | void startImpl() override; |
188 | void shutdownImpl() override; |
189 | |
190 | ~TensorPipeAgent() override; |
191 | |
192 | const WorkerInfo& getWorkerInfo(const std::string& workerName) const override; |
193 | const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override; |
194 | std::vector<WorkerInfo> getWorkerInfos() const override; |
195 | void updateGroupMembership( |
196 | const WorkerInfo& workerInfo, |
197 | const std::vector<c10::Device> devices, |
198 | const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps, |
199 | bool isJoin); |
200 | |
201 | std::unordered_map<std::string, std::string> getMetrics() override; |
202 | |
203 | void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override; |
204 | |
205 | TensorPipeRpcBackendOptions getBackendOptions() const; |
206 | |
207 | const c10::intrusive_ptr<::c10d::Store> getStore() const; |
208 | |
209 | DeviceMap getDeviceMap(const WorkerInfo& dest) const override; |
210 | |
211 | const std::vector<c10::Device>& getDevices() const override; |
212 | |
213 | using NetworkDataDict = |
214 | std::unordered_map<std::string, AggregatedNetworkData>; |
215 | |
216 | // Returns metrics tracked by the NetworkDataDict |
217 | NetworkDataDict getNetworkData(); |
218 | // Returns NetworkSourceInfo struct |
219 | NetworkSourceInfo getNetworkSourceInfo(); |
220 | |
221 | static const std::string& guessAddress(); |
222 | |
223 | // For testing purposes. |
224 | size_t timeoutMapSize(); |
225 | size_t numPendingResponses(); |
226 | size_t messageIdToTimeoutMapSize(); |
227 | |
228 | const bool isStaticGroup_; |
229 | |
230 | protected: |
231 | // TensorPipe write function that could be used to write response |
232 | // messages by server, and write request messages by client. This |
233 | // is a protected method since it is overwritten by FaultyTensorPipeAgent |
234 | virtual void pipeWrite( |
235 | const std::shared_ptr<tensorpipe::Pipe>&, |
236 | c10::intrusive_ptr<Message> message, |
237 | std::vector<c10::Device>&& devices, |
238 | std::vector<c10::Stream> streams, |
239 | std::function<void(const tensorpipe::Error&)>) noexcept; |
240 | |
241 | private: |
242 | // Removes the given messageId with the given expirationTime from the |
243 | // timeoutMap_. |
244 | void removeFromTimeoutMap(uint64_t messageId); |
245 | |
246 | // Populates workerIdToInfo_ and workerNameToInfo_ using addressStore_ |
247 | void prepareNames(bool isStaticGroup); |
248 | |
249 | // Check the static group attribute with the value set in store |
250 | void checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store>& store); |
251 | |
252 | const std::string& findWorkerURL(const WorkerInfo& worker) const; |
253 | |
254 | // Only use for Dynamic RPC groups, method to have worker leave group |
255 | void leaveGroup(); |
256 | |
257 | // TensorPipe read function that could be used to read response messages |
258 | // by client, and read request messages by server. |
259 | void pipeRead( |
260 | const std::shared_ptr<tensorpipe::Pipe>&, |
261 | std::function<void( |
262 | const tensorpipe::Error&, |
263 | c10::intrusive_ptr<Message>, |
264 | std::vector<c10::Stream>)>) noexcept; |
265 | |
266 | // Callback of listener accept() |
267 | void onListenerAccepted( |
268 | const tensorpipe::Error& error, |
269 | std::shared_ptr<tensorpipe::Pipe>& pipe); |
270 | |
271 | // Respond to a call from a peer |
272 | void respond(std::shared_ptr<tensorpipe::Pipe>& pipe); |
273 | |
274 | void sendCompletedResponseMessage( |
275 | std::shared_ptr<tensorpipe::Pipe>& pipe, |
276 | JitFuture& futureResponseMessage, |
277 | uint64_t messageId, |
278 | std::vector<c10::Stream> stream); |
279 | |
280 | // Collects metrics from successful RPC calls |
281 | void trackNetworkData( |
282 | uint64_t requestSize, |
283 | uint64_t responseSize, |
284 | const std::string& destWorkerName); |
285 | |
286 | // Collects metrics from failed RPC calls |
287 | void trackNetworkError( |
288 | uint64_t requestSize, |
289 | const std::string& destWorkerName); |
290 | |
291 | inline std::vector<c10::Device> getDevicesForRemote( |
292 | const std::string& remoteName, |
293 | const Message& message) const; |
294 | |
295 | // When a request+response completes, we need to mark the future message as |
296 | // complete. However, if its timeout has already expired, it already has an |
297 | // error set. There is no atomic "test-and-set" way to mark a future complete |
298 | // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but, even |
299 | // then, it ends up printing a log message, which may worry the user. To solve |
300 | // both issues we use a separate atomic flag to know the status of the future. |
301 | struct AtomicJitFuture { |
302 | explicit AtomicJitFuture(const std::vector<c10::Device>& devices) { |
303 | jitFuture = c10::make_intrusive<at::ivalue::Future>( |
304 | at::AnyClassType::get(), devices); |
305 | } |
306 | |
307 | std::atomic_flag isComplete = ATOMIC_FLAG_INIT; |
308 | c10::intrusive_ptr<JitFuture> jitFuture; |
309 | }; |
310 | |
311 | // Maintains state per client pipe to track pending response messages and |
312 | // error states. pendingResponseMessage_ should be protected by a mutex since |
313 | // it can be raced with user send() call. |
314 | // TODO: To achieve better performance we can have a pipe pool per |
315 | // client that can be configured using RpcBackendOptions. |
316 | struct ClientPipe { |
317 | // NOLINTNEXTLINE(modernize-pass-by-value) |
318 | explicit ClientPipe(std::shared_ptr<tensorpipe::Pipe> pipe) : pipe_(pipe) {} |
319 | std::shared_ptr<tensorpipe::Pipe> pipe_; |
320 | mutable std::mutex mutex_; |
321 | bool inError_{false}; |
322 | // Map from Message Request ID's to corresponding futures. |
323 | std::unordered_map<uint64_t, std::shared_ptr<AtomicJitFuture>> |
324 | pendingResponseMessage_; |
325 | }; |
326 | |
327 | const c10::intrusive_ptr<::c10d::Store> store_; |
328 | |
329 | const TensorPipeRpcBackendOptions opts_; |
330 | // For dynamic RPC, the reverse device maps are updated whenever a new rank |
331 | // joins or leaves the group |
332 | std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_; |
333 | // Local devices used by this agent. If application didn't specify this |
334 | // field, it will be initialized using corresponding local devices in |
335 | // opts_.deviceMaps and reverseDeviceMaps_; |
336 | std::vector<c10::Device> devices_; |
337 | |
338 | ThreadPool threadPool_; |
339 | std::shared_ptr<tensorpipe::Context> context_; |
340 | std::shared_ptr<tensorpipe::Listener> listener_; |
341 | |
342 | mutable std::mutex connectedPipesMutex_; |
343 | std::unordered_map<worker_id_t, ClientPipe> connectedPipes_; |
344 | |
345 | // Maps keyed on name and id for easy WorkerInfo lookup. |
346 | std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_; |
347 | std::unordered_map<std::string, WorkerInfo> workerNameToInfo_; |
348 | std::unordered_map<std::string, std::string> workerNameToURL_; |
349 | |
350 | ::c10d::PrefixStore rankToNameStore_; |
351 | ::c10d::PrefixStore nameToAddressStore_; |
352 | // Store keys that will used to count joined processes and active calls during |
353 | // the shutdown process |
354 | ::c10d::PrefixStore shutdownStore_; |
355 | int worldSize_ = 0; |
356 | std::atomic<uint64_t> nextMessageID_{0}; |
357 | |
358 | // Metadata used for tracking of whether certain RPCs have timed out or not. |
359 | struct TimeoutMessageMetadata { |
360 | TimeoutMessageMetadata( |
361 | uint64_t messageId_, |
362 | // NOLINTNEXTLINE(modernize-pass-by-value) |
363 | std::shared_ptr<AtomicJitFuture> responseFuture_, |
364 | std::chrono::milliseconds timeout_) |
365 | : messageId(messageId_), |
366 | responseFuture(responseFuture_), |
367 | timeout(timeout_) {} |
368 | uint64_t messageId; |
369 | std::shared_ptr<AtomicJitFuture> responseFuture; |
370 | std::chrono::milliseconds timeout; |
371 | }; |
372 | |
373 | // Map to store the expiration times for each message. |
374 | std::map<steady_clock_time_point, std::vector<TimeoutMessageMetadata>> |
375 | timeoutMap_; |
376 | |
377 | // Map to store the messageId to expiry time. |
378 | std::unordered_map<uint64_t, steady_clock_time_point> messageIdToTimeout_; |
379 | |
380 | // Thread that will poll the timeoutMap_ for timed out messages and mark them |
381 | // with an error accordingly |
382 | std::thread timeoutThread_; |
383 | |
384 | // Function run by the timeoutThread_ to check for timed out RPCs |
385 | void pollTimeoutRpcs(); |
386 | |
387 | // Mutex to guard the timeoutMap_ |
388 | std::mutex timeoutMapMutex_; |
389 | |
390 | // Condition Variable to signal population of the timeoutMap_ |
391 | std::condition_variable timeoutThreadCV_; |
392 | |
393 | // Returns the expiration time for an RPC by adding the current time to the |
394 | // passed in timeout. |
395 | inline steady_clock_time_point computeRpcMessageExpiryTime( |
396 | std::chrono::milliseconds timeout) const { |
397 | return std::chrono::time_point_cast<std::chrono::milliseconds>( |
398 | std::chrono::steady_clock::now() + timeout); |
399 | } |
400 | |
401 | // Handle error on an outgoing pipe |
402 | void handleClientError( |
403 | ClientPipe& clientPipe, |
404 | const tensorpipe::Error& error); |
405 | |
406 | // This is a generic struct for capturing Time-Series Metrics. It keeps a |
407 | // running sum and count of data points (observations), and can return an |
408 | // average of the data points seen so far. This is currently only used for |
409 | // tracking the GIL Wait Time in RPC Agents, but can be used for other metrics |
410 | // as well. |
411 | struct TimeSeriesMetricsTracker { |
412 | // Running sum of the data points seen so far |
413 | uint64_t currentSum_; |
414 | // Running count of the data points seen so far |
415 | uint64_t currentCount_; |
416 | |
417 | explicit TimeSeriesMetricsTracker( |
418 | uint64_t currentSum = 0, |
419 | uint64_t currentCount = 0); |
420 | |
421 | // Adds a data point (which is basically one observation for the metric |
422 | // being tracked) to the running sum and count. |
423 | void addData(uint64_t dataPoint); |
424 | // Returns the average of all the data points seen so far. |
425 | float computeAverage() const; |
426 | }; |
427 | |
428 | // Map of Time-Series metrics tracked by the RPC Agent |
429 | std::unordered_map<std::string, TimeSeriesMetricsTracker> timeSeriesMetrics_; |
430 | // Mutex to guard timeSeriesMetrics_ |
431 | std::mutex metricsMutex_; |
432 | |
433 | // Custom lock guard used to check if the RPC group is dynamic and lock the |
434 | // mutex if so |
435 | struct GroupMembershipLockGuard { |
436 | GroupMembershipLockGuard(std::mutex& mutex, bool isStaticGroup) |
437 | : ref_(mutex), isStaticGroup_(isStaticGroup) { |
438 | if (isStaticGroup_) { |
439 | ref_.lock(); |
440 | } |
441 | } |
442 | |
443 | ~GroupMembershipLockGuard() { |
444 | if (isStaticGroup_) { |
445 | ref_.unlock(); |
446 | } |
447 | } |
448 | |
449 | GroupMembershipLockGuard(const GroupMembershipLockGuard&) = delete; |
450 | |
451 | private: |
452 | std::mutex& ref_; |
453 | bool isStaticGroup_; |
454 | }; |
455 | // Mutex to guard access to group membership data |
456 | // e.g. updates to (workerIdToInfo_, workerNameToInfo_, workerNameToURL_) |
457 | mutable std::mutex groupMembershipMutex_; |
458 | |
459 | // Map to Track Network Data |
460 | NetworkDataDict networkData_; |
461 | // Mutex to guard networkData_ |
462 | std::mutex networkDataMutex_; |
463 | |
464 | // A mutex and a cv to guard access to the call counts and watch for changes. |
465 | std::mutex callCountMutex_; |
466 | std::condition_variable callCountCV_; |
467 | // Running total of un-processed, un-errored RPC calls sent |
468 | int32_t clientActiveCalls_{0}; |
469 | // Running total of un-processed RPC requests received |
470 | int32_t serverActiveCalls_{0}; |
471 | // Running total of RPC requests that will be completed asynchronously |
472 | int32_t serverActiveAsyncCalls_{0}; |
473 | |
474 | // Whether a global graceful shutdown has begun, in which case we'll silence |
475 | // error messages due to remote workers closing their pipes. |
476 | std::atomic<bool> shuttingDown_{false}; |
477 | |
478 | // Helpers to modify the counts while correctly dealing with the mutex and cv. |
479 | void increaseCallCount(int32_t& count); |
480 | void decreaseCallCount(int32_t& count); |
481 | |
482 | // Helpers to set the state of the requests. |
483 | void markFutureAsComplete( |
484 | std::shared_ptr<AtomicJitFuture> atomicFuture, |
485 | c10::intrusive_ptr<Message> message, |
486 | std::vector<c10::Stream> streams); |
487 | void markFutureWithError( |
488 | std::shared_ptr<AtomicJitFuture> atomicFuture, |
489 | std::string errorMsg); |
490 | }; |
491 | |
492 | } // namespace rpc |
493 | } // namespace distributed |
494 | } // namespace torch |
495 | |
496 | #endif // USE_TENSORPIPE |
497 | |