1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/message.h> |
4 | |
5 | namespace torch { |
6 | namespace distributed { |
7 | namespace rpc { |
8 | |
9 | // Functor which is invoked to process an RPC message. This is an abstract class |
10 | // with some common functionality across all request handlers. Users need to |
11 | // implement this interface to perform the actual business logic. |
12 | class TORCH_API RequestCallback { |
13 | public: |
14 | // Invoke the callback. |
15 | c10::intrusive_ptr<JitFuture> operator()( |
16 | Message& request, |
17 | std::vector<c10::Stream> streams) const; |
18 | |
19 | virtual ~RequestCallback() = default; |
20 | |
21 | protected: |
22 | // RpcAgent implementation should invoke ``RequestCallback`` to process |
23 | // received requests. There is no restriction on the implementation's |
24 | // threading model. This function takes an rvalue reference of the Message |
25 | // object. It is expected to return the future to a response message or |
26 | // message containing an exception. Different rpc agent implementations are |
27 | // expected to ensure delivery of the response/exception based on their |
28 | // implementation specific mechanisms. |
29 | virtual c10::intrusive_ptr<JitFuture> processMessage( |
30 | Message& request, |
31 | std::vector<c10::Stream> streams) const = 0; |
32 | }; |
33 | |
34 | } // namespace rpc |
35 | } // namespace distributed |
36 | } // namespace torch |
37 | |