1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/message.h> |
4 | #include <torch/csrc/distributed/rpc/request_callback_no_python.h> |
5 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
6 | #include <torch/csrc/jit/python/pybind.h> |
7 | |
8 | namespace torch { |
9 | namespace distributed { |
10 | namespace rpc { |
11 | |
12 | class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { |
13 | public: |
14 | std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand( |
15 | std::unique_ptr<RpcCommandBase> rpc, |
16 | const MessageType& messageType) const override; |
17 | |
18 | c10::intrusive_ptr<JitFuture> processPythonCall( |
19 | RpcCommandBase& rpc, |
20 | std::vector<c10::Stream> streams) const override; |
21 | |
22 | c10::intrusive_ptr<JitFuture> processScriptCall( |
23 | RpcCommandBase& rpc, |
24 | std::vector<c10::Stream> streams) const override; |
25 | |
26 | c10::intrusive_ptr<JitFuture> processScriptRemoteCall( |
27 | RpcCommandBase& rpc, |
28 | std::vector<c10::Stream> streams) const override; |
29 | |
30 | c10::intrusive_ptr<JitFuture> processPythonRemoteCall( |
31 | RpcCommandBase& rpc, |
32 | std::vector<c10::Stream> streams) const override; |
33 | |
34 | c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall( |
35 | RpcCommandBase& rpc) const override; |
36 | |
37 | void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override; |
38 | |
39 | c10::intrusive_ptr<JitFuture> processRpcWithErrors( |
40 | RpcCommandBase& rpc, |
41 | const MessageType& messageType, |
42 | std::vector<c10::Stream> streams) const override; |
43 | |
44 | bool cudaAvailable() const override; |
45 | |
46 | c10::intrusive_ptr<JitFuture> processRRefBackward( |
47 | RpcCommandBase& rpc) const override; |
48 | |
49 | // Helpers to run user-defined functions, operators and other computations. |
50 | |
51 | c10::intrusive_ptr<JitFuture> runJitFunction( |
52 | const c10::QualifiedName& name, |
53 | std::vector<at::IValue>& stack, |
54 | std::vector<c10::Stream> streams, |
55 | bool isAsyncExecution) const; |
56 | |
57 | c10::intrusive_ptr<JitFuture> runPythonFunction( |
58 | const py::object& function, |
59 | std::vector<c10::Stream> streams, |
60 | bool isAsyncExecution) const; |
61 | }; |
62 | |
63 | } // namespace rpc |
64 | } // namespace distributed |
65 | } // namespace torch |
66 | |