1#pragma once
2
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/types.h>

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <mutex>
#include <thread>
13
14namespace torch {
15namespace distributed {
16namespace rpc {
17
18using DeviceMap = std::unordered_map<c10::Device, c10::Device>;
19
// Timeout, in seconds, applied to RPCs when no other timeout is configured.
constexpr float kDefaultRpcTimeoutSeconds = 60.0f;
// Sentinel meaning "no timeout was specified". agent::send() receives this
// value when the user does not pass an explicit timeout, which signals that
// the default RPC timeout must be used instead.
constexpr float kUnsetRpcTimeout = -1.0f;
// Init method used by a default-constructed RpcBackendOptions.
constexpr const char* kDefaultInitMethod = "env://";
// Multiplier for converting a duration in seconds to milliseconds.
constexpr float kSecToMsConversion = 1000.0f;
// Error message template for timed-out RPCs; "{}" is the placeholder for the
// timeout in milliseconds.
constexpr const char* kRpcTimeoutErrorStr =
    "RPC ran for more than set timeout ({} ms) and will now be marked with an error";
30
// Shorthand for a time point measured on std::chrono::steady_clock.
using steady_clock_time_point = std::chrono::steady_clock::time_point;
33// Input is qualified name string, output is JIT StrongTypePtr
34// Same as jit::TypeResolver, did not import jit::TypeResolver to here
35// because it could instroduce cyclic dependencies.
36using TypeResolver =
37 std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
38
39struct RpcBackendOptions {
40 RpcBackendOptions()
41 : RpcBackendOptions(kDefaultRpcTimeoutSeconds, kDefaultInitMethod) {}
42
43 RpcBackendOptions(float rpcTimeoutSeconds, std::string initMethod)
44 : rpcTimeoutSeconds(rpcTimeoutSeconds),
45 initMethod(std::move(initMethod)) {
46 TORCH_CHECK(rpcTimeoutSeconds >= 0, "RPC Timeout must be non-negative");
47 }
48
49 float rpcTimeoutSeconds;
50 std::string initMethod;
51};
52
53// A globally unique ID to identify an RpcAgent
54struct TORCH_API WorkerInfo : torch::CustomClassHolder {
55 WorkerInfo(std::string name, int64_t id);
56
57 WorkerInfo(std::string name, worker_id_t id);
58
59 bool operator==(const WorkerInfo& rhs) {
60 return (id_ == rhs.id_) && (name_ == rhs.name_);
61 }
62
63 static constexpr size_t MAX_NAME_LEN = 128;
64
65 const std::string name_;
66 const worker_id_t id_;
67};
68
69struct TORCH_API RegisterWorkerInfoOnce {
70 RegisterWorkerInfoOnce();
71};
72
73TORCH_API std::ostream& operator<<(
74 std::ostream& os,
75 const WorkerInfo& workerInfo);
76
77// Struct for options to configure the RPC Retry protocol.
78struct TORCH_API RpcRetryOptions {
79 // Using a default constructor like all other Options structs in the RPC
80 // codebase. TORCH_CHECKs for input validation are done in the
81 // sendWithRetries function.
82 RpcRetryOptions() = default;
83 // Maximum number of times we will retry the RPC
84 int maxRetries{5};
85 // Initial duration between consecutive RPC send attempts
86 std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)};
87 // Constant for exponential backoff used while calculating future wait
88 // durations
89 float retryBackoff{1.5};
90};
91
92// Struct that stores all the metadata needed to retry a given RPC.
93struct TORCH_API RpcRetryInfo {
94 RpcRetryInfo(
95 const WorkerInfo& to,
96 c10::intrusive_ptr<Message> message,
97 c10::intrusive_ptr<JitFuture> originalFuture,
98 int retryCount,
99 RpcRetryOptions options)
100 : to_(to),
101 message_(std::move(message)),
102 originalFuture_(std::move(originalFuture)),
103 retryCount_(retryCount),
104 options_(options) {}
105
106 const WorkerInfo& to_;
107 c10::intrusive_ptr<Message> message_;
108 // Future that is returned to the caller of sendWithRetries().
109 c10::intrusive_ptr<JitFuture> originalFuture_;
110 // Number of send attempts completed so far.
111 int retryCount_;
112 RpcRetryOptions options_;
113};
114
115// ``RpcAgent`` is the base class for sending and receiving RPC messages. It
116// provides a unified ``send`` API for both request and response messages, and
117// will invoke the given ``RequestCallback`` to process received requests. It
118// should immediately become ready to serve request and accept response after
119// construction.
120class TORCH_API RpcAgent {
121 public:
122 // `WorkerInfo` is the globally unique identifier for this RpcAgent instance.
123 // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the
124 // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent``
125 // implementation to determine how to resolve names. ``id_`` is the globally
126 // unique ID for this ``RpcAgent``. This should be determined by the
127 // ``RpcAgent`` implementation.
128 // The ``RequestCallback`` will be invoked to handle received requests. This
129 // ``RpcAgent`` base class makes no assumption on the thread-safeness of the
130 // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that
131 // its threading model conform to ``RequestCallback``'s requirement.
132 // NB: RpcAgent implementations should not start serving requests until
133 // ``start()`` is called, as there could be other contexts that have not been
134 // initialized yet at this time.
135 RpcAgent(
136 WorkerInfo id,
137 std::unique_ptr<RequestCallback> cb,
138 std::chrono::milliseconds rpcTimeout);
139
140 virtual ~RpcAgent();
141
142 // Send a message to the ``RpcAgent`` of id ``to`` and returns a
143 // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it
144 // cannot block until it receives the response.
145 //
146 // If ``message.isRequest()`` is true, the ``JitFuture`` will be
147 // completed when the response arrives. For other message types, the Future
148 // should be ignored by the caller.
149 virtual c10::intrusive_ptr<JitFuture> send(
150 const WorkerInfo& to,
151 c10::intrusive_ptr<Message> message,
152 const float rpcTimeoutSeconds = kUnsetRpcTimeout,
153 const DeviceMap& deviceMap = {}) = 0;
154
155 // Retries sending the message up to maxRetries times until an ACK is
156 // receieved. The duration between consecutive sends is increased over
157 // time using an exponential backoff algorithm.
158 //
159 // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a
160 // ``JitFuture`` ptr, just like send(). Caller can specify the maximum
161 // number of retries for this RPC (default is 5), initial duration between
162 // sends (default is 1000ms), and backoff constant (default is 1.5) by
163 // passing in the RpcRetryOptions struct. This API might end up
164 // executing a method twice on the remote end (it does not guarantee
165 // exactly-once semantics). Therefore, the user must ensure their requests
166 // are idempotent.
167 c10::intrusive_ptr<JitFuture> sendWithRetries(
168 const WorkerInfo& to,
169 c10::intrusive_ptr<Message> message,
170 RpcRetryOptions retryOptions = RpcRetryOptions());
171
172 // Return a reference to the ``WorkerInfo`` of this RpcAgent.
173 // NB: not using ``c10::optional<const std::string&>`` here because we might
174 // need to create a separate RPC API lib and avoid forcing all ``RpcAgent``
175 // implementations to depend on libtorch.
176 const WorkerInfo& getWorkerInfo() const;
177
178 // Return a reference to the ``WorkerInfo`` of the given ``workerName``.
179 virtual const WorkerInfo& getWorkerInfo(
180 const std::string& workerName) const = 0;
181
182 virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0;
183
184 virtual std::vector<WorkerInfo> getWorkerInfos() const = 0;
185
186 // Retrieve the timeout for all RPCs.
187 inline std::chrono::milliseconds getRpcTimeout() const {
188 return rpcTimeout_.load();
189 }
190
191 // Set the timeout for all RPCs
192 inline void setRpcTimeout(const std::chrono::milliseconds& rpcTimeout) {
193 rpcTimeout_.store(rpcTimeout);
194 }
195
196 // Call sync and join all internal threads. This method should be called
197 // before every RPC process exits.
198 virtual void join(bool shutdown = false, float timeout = 0) = 0;
199
200 // Synchronize the this process with other ``RpcAgent`` processes. Block until
201 // all ``RpcAgent``s reach this method and send all pending messages.
202 virtual void sync() = 0;
203
204 // Sets up backend-agnostic state for accepting requests. Currently, this
205 // entails setting rpcAgentRunning_ to true, creating the retry thread, and
206 // calling the backend's startImpl.
207 void start();
208
209 // Derived classes must override this function to start accepting requests.
210 // This is used to initialize any backend-specific state. Users must call
211 // start, not startImpl, to initialize the RPC Agent.
212 virtual void startImpl() = 0;
213
214 // Stop accepting requests and shutdown the RPC framework as soon as possible
215 // by terminating all RPC threads.
216 void shutdown();
217
218 // Derived classes must override this function to start accepting requests.
219 // THis is used to clean up any backend-specific state. Users must call
220 // shutdown, not shutdownImpl, to shutdown the RPC Agent.
221 virtual void shutdownImpl() = 0;
222
223 // Check if current RPC agent is set.
224 static bool isCurrentRpcAgentSet();
225
226 // Retrieve the valid current RPC agent.
227 static std::shared_ptr<RpcAgent> getCurrentRpcAgent();
228
229 // Set the current RPC agent.
230 static void setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent);
231
232 // Retrieve metrics as KV map
233 virtual std::unordered_map<std::string, std::string> getMetrics() = 0;
234
235 // Retrive debug info in addition to metrics as KV map
236 virtual std::unordered_map<std::string, std::string> getDebugInfo();
237
238 // Flag to control whether GIL wait times
239 // should be profiled or not.
240 void enableGILProfiling(bool flag);
241
242 // Retrieve wheher we should profile GIL wait times or not.
243 bool isGILProfilingEnabled();
244
245 // Set type resolver that will be passed to JIT pickler to resolver type Ptr
246 // based on type str.
247 void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver);
248
249 // Get the type resolver
250 std::shared_ptr<TypeResolver> getTypeResolver();
251
252 // Retrieves the device map for the provided destination worker.
253 virtual DeviceMap getDeviceMap(const WorkerInfo& dst) const;
254
255 // Retrieve the (non-CPU) devices that are supported by the agent.
256 virtual const std::vector<c10::Device>& getDevices() const;
257
258 protected:
259 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
260 const WorkerInfo workerInfo_;
261 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
262 const std::unique_ptr<RequestCallback> cb_;
263 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
264 std::atomic<std::chrono::milliseconds> rpcTimeout_;
265 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
266 std::atomic<bool> profilingEnabled_;
267 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
268 std::shared_ptr<TypeResolver> typeResolver_;
269 // Atomic boolean indicating whether this agent is running. It controls
270 // whether several background threads should be running. It is set in
271 // RpcAgent::start() and unset in the derived class shutdown().
272 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
273 std::atomic<bool> rpcAgentRunning_;
274
275 private:
276 static std::shared_ptr<RpcAgent> currentRpcAgent_;
277 // Add GIL wait time data point to metrics
278 virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
279 friend class PythonRpcHandler;
280
281 // Map that stores metadata for RPC's that may need to be re-tried as well as
282 // the timepoint at which we should re-try them.
283 std::map<
284 steady_clock_time_point,
285 std::unordered_set<std::shared_ptr<RpcRetryInfo>>>
286 rpcRetryMap_;
287
288 // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until
289 // the next unACKed RPC's timeout has expired.
290 std::thread rpcRetryThread_;
291
292 // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is
293 // running.
294 void retryExpiredRpcs();
295
296 // This is the callback attached to futures corresponding to send retries.
297 // This handles 3 cases: 1). send was completed, 2). send failed with an
298 // error and we've done maxRetries failed send attempts, and 3). send
299 // failed with an error and we have more retries to go. In case 1, we mark
300 // the original future as complete. In case 2, we mark the future with an
301 // error and do not retry again. In case 3, we move the RpcRetryInfo struct
302 // to another time point in the map to schedule the RPC for a future send.
303 void rpcRetryCallback(
304 JitFuture& message,
305 steady_clock_time_point newTime,
306 std::shared_ptr<RpcRetryInfo> earliestRpc);
307
308 // Function that uses the exponential backoff algorithm to compute the next
309 // time point to retry a given RPC.
310 inline steady_clock_time_point computeNewRpcRetryTime(
311 RpcRetryOptions& options,
312 int retryCount) {
313 // The exponential backoff algorithm being used here is:
314 // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)).
315 std::chrono::milliseconds timedelta =
316 std::chrono::duration_cast<std::chrono::milliseconds>(
317 options.rpcRetryDuration * pow(options.retryBackoff, retryCount));
318 return std::chrono::time_point_cast<std::chrono::milliseconds>(
319 std::chrono::steady_clock::now() + timedelta);
320 }
321
322 // Condition Variable to signal when the rpcRetryMap_ has been populated.
323 std::condition_variable rpcRetryMapCV_;
324
325 // Mutex to protect RpcRetryMap_.
326 std::mutex rpcRetryMutex_;
327};
328
329} // namespace rpc
330} // namespace distributed
331} // namespace torch
332
333namespace std {
334template <>
335struct hash<torch::distributed::rpc::WorkerInfo> {
336 std::size_t operator()(
337 const torch::distributed::rpc::WorkerInfo& worker_info) const noexcept {
338 return worker_info.id_;
339 }
340};
341} // namespace std
342