1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/message.h> |
4 | #include <torch/csrc/distributed/rpc/request_callback.h> |
5 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
6 | #include <torch/csrc/distributed/rpc/rref_impl.h> |
7 | #include <torch/csrc/distributed/rpc/script_call.h> |
8 | #include <torch/csrc/distributed/rpc/script_remote_call.h> |
9 | |
10 | namespace torch { |
11 | namespace distributed { |
12 | namespace rpc { |
13 | |
14 | // RequestCallback implementation with no Python dependencies. |
15 | class TORCH_API RequestCallbackNoPython : public RequestCallback { |
16 | public: |
17 | c10::intrusive_ptr<JitFuture> processMessage( |
18 | Message& request, |
19 | std::vector<c10::Stream> streams) const override; |
20 | |
21 | protected: |
22 | virtual std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand( |
23 | std::unique_ptr<RpcCommandBase> rpc, |
24 | const MessageType& messageType) const; |
25 | |
26 | virtual c10::intrusive_ptr<JitFuture> processScriptCall( |
27 | RpcCommandBase& rpc, |
28 | std::vector<c10::Stream> streams) const; |
29 | |
30 | virtual c10::intrusive_ptr<JitFuture> processPythonCall( |
31 | RpcCommandBase& rpc, |
32 | std::vector<c10::Stream> streams) const; |
33 | |
34 | c10::intrusive_ptr<JitFuture> assignOwnerRRef( |
35 | const RRefId& rrefId, |
36 | const RRefId& forkId, |
37 | c10::intrusive_ptr<JitFuture> valueFuture) const; |
38 | |
39 | virtual c10::intrusive_ptr<JitFuture> processScriptRemoteCall( |
40 | RpcCommandBase& rpc, |
41 | std::vector<c10::Stream> streams) const; |
42 | |
43 | virtual c10::intrusive_ptr<JitFuture> processPythonRemoteCall( |
44 | RpcCommandBase& rpc, |
45 | std::vector<c10::Stream> streams) const; |
46 | |
47 | c10::intrusive_ptr<JitFuture> retrieveOwnerRRef(const RRefId& rrefId) const; |
48 | |
49 | c10::intrusive_ptr<JitFuture> processScriptRRefFetchCall( |
50 | RpcCommandBase& rpc) const; |
51 | |
52 | virtual c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall( |
53 | RpcCommandBase& rpc) const; |
54 | |
55 | c10::intrusive_ptr<JitFuture> processRRefUserDelete( |
56 | RpcCommandBase& rpc) const; |
57 | |
58 | c10::intrusive_ptr<JitFuture> processRRefChildAccept( |
59 | RpcCommandBase& rpc) const; |
60 | |
61 | c10::intrusive_ptr<JitFuture> processRRefForkRequest( |
62 | RpcCommandBase& rpc) const; |
63 | |
64 | c10::intrusive_ptr<JitFuture> processForwardAutogradReq( |
65 | RpcCommandBase& rpc, |
66 | std::vector<c10::Stream> streams) const; |
67 | |
68 | c10::intrusive_ptr<JitFuture> processBackwardAutogradReq( |
69 | RpcCommandBase& rpc, |
70 | std::vector<c10::Stream> streams) const; |
71 | |
72 | c10::intrusive_ptr<JitFuture> processCleanupAutogradContextReq( |
73 | RpcCommandBase& rpc) const; |
74 | |
75 | c10::intrusive_ptr<JitFuture> processRunWithProfilingReq( |
76 | RpcCommandBase& rpc) const; |
77 | |
78 | virtual void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const; |
79 | |
80 | c10::intrusive_ptr<JitFuture> processRpc( |
81 | RpcCommandBase& rpc, |
82 | const MessageType& messageType, |
83 | std::vector<c10::Stream> streams) const; |
84 | |
85 | virtual c10::intrusive_ptr<JitFuture> processRpcWithErrors( |
86 | RpcCommandBase& rpc, |
87 | const MessageType& messageType, |
88 | std::vector<c10::Stream> streams) const; |
89 | |
90 | c10::intrusive_ptr<Message> handleError( |
91 | const std::exception& e, |
92 | const MessageType messageType, |
93 | int64_t messageId) const; |
94 | |
95 | virtual bool cudaAvailable() const; |
96 | |
97 | virtual c10::intrusive_ptr<JitFuture> processRRefBackward( |
98 | RpcCommandBase& rpc) const; |
99 | |
100 | // Helpers to run user-defined functions, operators and other computations. |
101 | |
102 | c10::intrusive_ptr<JitFuture> runJitOperator( |
103 | const jit::Operator& op, |
104 | std::vector<at::IValue>& stack, |
105 | std::vector<c10::Stream> streams) const; |
106 | |
107 | // Helpers to convert various kinds of objects into already-completed futures. |
108 | |
109 | c10::intrusive_ptr<JitFuture> asFuture(IValue value, TypePtr type) const; |
110 | |
111 | c10::intrusive_ptr<JitFuture> asFuture( |
112 | c10::intrusive_ptr<Message> message) const; |
113 | |
114 | c10::intrusive_ptr<JitFuture> asFuture(std::exception_ptr err) const; |
115 | }; |
116 | |
117 | } // namespace rpc |
118 | } // namespace distributed |
119 | } // namespace torch |
120 | |