1 | #pragma once |
2 | |
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/types.h>

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <mutex>
#include <thread>
13 | |
14 | namespace torch { |
15 | namespace distributed { |
16 | namespace rpc { |
17 | |
// Associates one c10::Device (key) with another c10::Device (value). Accepted
// by RpcAgent::send() and returned by RpcAgent::getDeviceMap().
// NOTE(review): presumably keys are devices on the sending side and values are
// devices on the destination worker -- confirm against the agent
// implementations.
using DeviceMap = std::unordered_map<c10::Device, c10::Device>;
19 | |
// Timeout, in seconds, used for RPCs when nothing else is configured.
constexpr float kDefaultRpcTimeoutSeconds{60.0f};
// Sentinel for "no timeout was passed". agent::send() sees this value when the
// user does not supply a timeout, which means the default RPC timeout must be
// used instead.
constexpr float kUnsetRpcTimeout{-1.0f};
// Init method used by a default-constructed RpcBackendOptions.
constexpr const char* kDefaultInitMethod{"env://"};
// Factor for converting seconds to milliseconds.
constexpr float kSecToMsConversion{1000.0f};
// Error message template for timed-out RPCs.
// NOTE(review): the "{}" placeholder is presumably filled with the timeout in
// milliseconds by the caller -- confirm at the use sites.
constexpr const char* kRpcTimeoutErrorStr{
    "RPC ran for more than set timeout ({} ms) and will now be marked with an error"};
30 | |
// A point in time measured on std::chrono::steady_clock.
using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;
// Callable that takes a qualified name and returns the matching JIT
// StrongTypePtr. Same as jit::TypeResolver; jit::TypeResolver is not imported
// here because doing so could introduce cyclic dependencies.
using TypeResolver =
    std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
38 | |
39 | struct RpcBackendOptions { |
40 | RpcBackendOptions() |
41 | : RpcBackendOptions(kDefaultRpcTimeoutSeconds, kDefaultInitMethod) {} |
42 | |
43 | RpcBackendOptions(float rpcTimeoutSeconds, std::string initMethod) |
44 | : rpcTimeoutSeconds(rpcTimeoutSeconds), |
45 | initMethod(std::move(initMethod)) { |
46 | TORCH_CHECK(rpcTimeoutSeconds >= 0, "RPC Timeout must be non-negative" ); |
47 | } |
48 | |
49 | float rpcTimeoutSeconds; |
50 | std::string initMethod; |
51 | }; |
52 | |
53 | // A globally unique ID to identify an RpcAgent |
54 | struct TORCH_API WorkerInfo : torch::CustomClassHolder { |
55 | WorkerInfo(std::string name, int64_t id); |
56 | |
57 | WorkerInfo(std::string name, worker_id_t id); |
58 | |
59 | bool operator==(const WorkerInfo& rhs) { |
60 | return (id_ == rhs.id_) && (name_ == rhs.name_); |
61 | } |
62 | |
63 | static constexpr size_t MAX_NAME_LEN = 128; |
64 | |
65 | const std::string name_; |
66 | const worker_id_t id_; |
67 | }; |
68 | |
// Helper type whose only member is a constructor, defined out of line.
// NOTE(review): judging by the name, constructing one presumably registers
// WorkerInfo (a torch::CustomClassHolder) a single time -- confirm in the
// implementation file.
struct TORCH_API RegisterWorkerInfoOnce {
  RegisterWorkerInfoOnce();
};
72 | |
// Stream-insertion operator for WorkerInfo. Declared here and defined out of
// line; the output format is determined by that definition.
TORCH_API std::ostream& operator<<(
    std::ostream& os,
    const WorkerInfo& workerInfo);
76 | |
77 | // Struct for options to configure the RPC Retry protocol. |
78 | struct TORCH_API RpcRetryOptions { |
79 | // Using a default constructor like all other Options structs in the RPC |
80 | // codebase. TORCH_CHECKs for input validation are done in the |
81 | // sendWithRetries function. |
82 | RpcRetryOptions() = default; |
83 | // Maximum number of times we will retry the RPC |
84 | int maxRetries{5}; |
85 | // Initial duration between consecutive RPC send attempts |
86 | std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)}; |
87 | // Constant for exponential backoff used while calculating future wait |
88 | // durations |
89 | float retryBackoff{1.5}; |
90 | }; |
91 | |
92 | // Struct that stores all the metadata needed to retry a given RPC. |
93 | struct TORCH_API RpcRetryInfo { |
94 | RpcRetryInfo( |
95 | const WorkerInfo& to, |
96 | c10::intrusive_ptr<Message> message, |
97 | c10::intrusive_ptr<JitFuture> originalFuture, |
98 | int retryCount, |
99 | RpcRetryOptions options) |
100 | : to_(to), |
101 | message_(std::move(message)), |
102 | originalFuture_(std::move(originalFuture)), |
103 | retryCount_(retryCount), |
104 | options_(options) {} |
105 | |
106 | const WorkerInfo& to_; |
107 | c10::intrusive_ptr<Message> message_; |
108 | // Future that is returned to the caller of sendWithRetries(). |
109 | c10::intrusive_ptr<JitFuture> originalFuture_; |
110 | // Number of send attempts completed so far. |
111 | int retryCount_; |
112 | RpcRetryOptions options_; |
113 | }; |
114 | |
115 | // ``RpcAgent`` is the base class for sending and receiving RPC messages. It |
116 | // provides a unified ``send`` API for both request and response messages, and |
117 | // will invoke the given ``RequestCallback`` to process received requests. It |
118 | // should immediately become ready to serve request and accept response after |
119 | // construction. |
120 | class TORCH_API RpcAgent { |
121 | public: |
122 | // `WorkerInfo` is the globally unique identifier for this RpcAgent instance. |
123 | // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the |
124 | // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent`` |
125 | // implementation to determine how to resolve names. ``id_`` is the globally |
126 | // unique ID for this ``RpcAgent``. This should be determined by the |
127 | // ``RpcAgent`` implementation. |
128 | // The ``RequestCallback`` will be invoked to handle received requests. This |
129 | // ``RpcAgent`` base class makes no assumption on the thread-safeness of the |
130 | // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that |
131 | // its threading model conform to ``RequestCallback``'s requirement. |
132 | // NB: RpcAgent implementations should not start serving requests until |
133 | // ``start()`` is called, as there could be other contexts that have not been |
134 | // initialized yet at this time. |
135 | RpcAgent( |
136 | WorkerInfo id, |
137 | std::unique_ptr<RequestCallback> cb, |
138 | std::chrono::milliseconds rpcTimeout); |
139 | |
140 | virtual ~RpcAgent(); |
141 | |
142 | // Send a message to the ``RpcAgent`` of id ``to`` and returns a |
143 | // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it |
144 | // cannot block until it receives the response. |
145 | // |
146 | // If ``message.isRequest()`` is true, the ``JitFuture`` will be |
147 | // completed when the response arrives. For other message types, the Future |
148 | // should be ignored by the caller. |
149 | virtual c10::intrusive_ptr<JitFuture> send( |
150 | const WorkerInfo& to, |
151 | c10::intrusive_ptr<Message> message, |
152 | const float rpcTimeoutSeconds = kUnsetRpcTimeout, |
153 | const DeviceMap& deviceMap = {}) = 0; |
154 | |
155 | // Retries sending the message up to maxRetries times until an ACK is |
156 | // receieved. The duration between consecutive sends is increased over |
157 | // time using an exponential backoff algorithm. |
158 | // |
159 | // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a |
160 | // ``JitFuture`` ptr, just like send(). Caller can specify the maximum |
161 | // number of retries for this RPC (default is 5), initial duration between |
162 | // sends (default is 1000ms), and backoff constant (default is 1.5) by |
163 | // passing in the RpcRetryOptions struct. This API might end up |
164 | // executing a method twice on the remote end (it does not guarantee |
165 | // exactly-once semantics). Therefore, the user must ensure their requests |
166 | // are idempotent. |
167 | c10::intrusive_ptr<JitFuture> sendWithRetries( |
168 | const WorkerInfo& to, |
169 | c10::intrusive_ptr<Message> message, |
170 | RpcRetryOptions retryOptions = RpcRetryOptions()); |
171 | |
172 | // Return a reference to the ``WorkerInfo`` of this RpcAgent. |
173 | // NB: not using ``c10::optional<const std::string&>`` here because we might |
174 | // need to create a separate RPC API lib and avoid forcing all ``RpcAgent`` |
175 | // implementations to depend on libtorch. |
176 | const WorkerInfo& getWorkerInfo() const; |
177 | |
178 | // Return a reference to the ``WorkerInfo`` of the given ``workerName``. |
179 | virtual const WorkerInfo& getWorkerInfo( |
180 | const std::string& workerName) const = 0; |
181 | |
182 | virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0; |
183 | |
184 | virtual std::vector<WorkerInfo> getWorkerInfos() const = 0; |
185 | |
186 | // Retrieve the timeout for all RPCs. |
187 | inline std::chrono::milliseconds getRpcTimeout() const { |
188 | return rpcTimeout_.load(); |
189 | } |
190 | |
191 | // Set the timeout for all RPCs |
192 | inline void setRpcTimeout(const std::chrono::milliseconds& rpcTimeout) { |
193 | rpcTimeout_.store(rpcTimeout); |
194 | } |
195 | |
196 | // Call sync and join all internal threads. This method should be called |
197 | // before every RPC process exits. |
198 | virtual void join(bool shutdown = false, float timeout = 0) = 0; |
199 | |
200 | // Synchronize the this process with other ``RpcAgent`` processes. Block until |
201 | // all ``RpcAgent``s reach this method and send all pending messages. |
202 | virtual void sync() = 0; |
203 | |
204 | // Sets up backend-agnostic state for accepting requests. Currently, this |
205 | // entails setting rpcAgentRunning_ to true, creating the retry thread, and |
206 | // calling the backend's startImpl. |
207 | void start(); |
208 | |
209 | // Derived classes must override this function to start accepting requests. |
210 | // This is used to initialize any backend-specific state. Users must call |
211 | // start, not startImpl, to initialize the RPC Agent. |
212 | virtual void startImpl() = 0; |
213 | |
214 | // Stop accepting requests and shutdown the RPC framework as soon as possible |
215 | // by terminating all RPC threads. |
216 | void shutdown(); |
217 | |
218 | // Derived classes must override this function to start accepting requests. |
219 | // THis is used to clean up any backend-specific state. Users must call |
220 | // shutdown, not shutdownImpl, to shutdown the RPC Agent. |
221 | virtual void shutdownImpl() = 0; |
222 | |
223 | // Check if current RPC agent is set. |
224 | static bool isCurrentRpcAgentSet(); |
225 | |
226 | // Retrieve the valid current RPC agent. |
227 | static std::shared_ptr<RpcAgent> getCurrentRpcAgent(); |
228 | |
229 | // Set the current RPC agent. |
230 | static void setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent); |
231 | |
232 | // Retrieve metrics as KV map |
233 | virtual std::unordered_map<std::string, std::string> getMetrics() = 0; |
234 | |
235 | // Retrive debug info in addition to metrics as KV map |
236 | virtual std::unordered_map<std::string, std::string> getDebugInfo(); |
237 | |
238 | // Flag to control whether GIL wait times |
239 | // should be profiled or not. |
240 | void enableGILProfiling(bool flag); |
241 | |
242 | // Retrieve wheher we should profile GIL wait times or not. |
243 | bool isGILProfilingEnabled(); |
244 | |
245 | // Set type resolver that will be passed to JIT pickler to resolver type Ptr |
246 | // based on type str. |
247 | void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver); |
248 | |
249 | // Get the type resolver |
250 | std::shared_ptr<TypeResolver> getTypeResolver(); |
251 | |
252 | // Retrieves the device map for the provided destination worker. |
253 | virtual DeviceMap getDeviceMap(const WorkerInfo& dst) const; |
254 | |
255 | // Retrieve the (non-CPU) devices that are supported by the agent. |
256 | virtual const std::vector<c10::Device>& getDevices() const; |
257 | |
258 | protected: |
259 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
260 | const WorkerInfo workerInfo_; |
261 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
262 | const std::unique_ptr<RequestCallback> cb_; |
263 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
264 | std::atomic<std::chrono::milliseconds> rpcTimeout_; |
265 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
266 | std::atomic<bool> profilingEnabled_; |
267 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
268 | std::shared_ptr<TypeResolver> typeResolver_; |
269 | // Atomic boolean indicating whether this agent is running. It controls |
270 | // whether several background threads should be running. It is set in |
271 | // RpcAgent::start() and unset in the derived class shutdown(). |
272 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
273 | std::atomic<bool> rpcAgentRunning_; |
274 | |
275 | private: |
276 | static std::shared_ptr<RpcAgent> currentRpcAgent_; |
277 | // Add GIL wait time data point to metrics |
278 | virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0; |
279 | friend class PythonRpcHandler; |
280 | |
281 | // Map that stores metadata for RPC's that may need to be re-tried as well as |
282 | // the timepoint at which we should re-try them. |
283 | std::map< |
284 | steady_clock_time_point, |
285 | std::unordered_set<std::shared_ptr<RpcRetryInfo>>> |
286 | rpcRetryMap_; |
287 | |
288 | // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until |
289 | // the next unACKed RPC's timeout has expired. |
290 | std::thread rpcRetryThread_; |
291 | |
292 | // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is |
293 | // running. |
294 | void retryExpiredRpcs(); |
295 | |
296 | // This is the callback attached to futures corresponding to send retries. |
297 | // This handles 3 cases: 1). send was completed, 2). send failed with an |
298 | // error and we've done maxRetries failed send attempts, and 3). send |
299 | // failed with an error and we have more retries to go. In case 1, we mark |
300 | // the original future as complete. In case 2, we mark the future with an |
301 | // error and do not retry again. In case 3, we move the RpcRetryInfo struct |
302 | // to another time point in the map to schedule the RPC for a future send. |
303 | void rpcRetryCallback( |
304 | JitFuture& message, |
305 | steady_clock_time_point newTime, |
306 | std::shared_ptr<RpcRetryInfo> earliestRpc); |
307 | |
308 | // Function that uses the exponential backoff algorithm to compute the next |
309 | // time point to retry a given RPC. |
310 | inline steady_clock_time_point computeNewRpcRetryTime( |
311 | RpcRetryOptions& options, |
312 | int retryCount) { |
313 | // The exponential backoff algorithm being used here is: |
314 | // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)). |
315 | std::chrono::milliseconds timedelta = |
316 | std::chrono::duration_cast<std::chrono::milliseconds>( |
317 | options.rpcRetryDuration * pow(options.retryBackoff, retryCount)); |
318 | return std::chrono::time_point_cast<std::chrono::milliseconds>( |
319 | std::chrono::steady_clock::now() + timedelta); |
320 | } |
321 | |
322 | // Condition Variable to signal when the rpcRetryMap_ has been populated. |
323 | std::condition_variable rpcRetryMapCV_; |
324 | |
325 | // Mutex to protect RpcRetryMap_. |
326 | std::mutex rpcRetryMutex_; |
327 | }; |
328 | |
329 | } // namespace rpc |
330 | } // namespace distributed |
331 | } // namespace torch |
332 | |
333 | namespace std { |
334 | template <> |
335 | struct hash<torch::distributed::rpc::WorkerInfo> { |
336 | std::size_t operator()( |
337 | const torch::distributed::rpc::WorkerInfo& worker_info) const noexcept { |
338 | return worker_info.id_; |
339 | } |
340 | }; |
341 | } // namespace std |
342 | |