1#pragma once
2
3#include <torch/csrc/distributed/rpc/message.h>
4#include <torch/csrc/distributed/rpc/request_callback.h>
5#include <torch/csrc/distributed/rpc/rpc_command_base.h>
6#include <torch/csrc/distributed/rpc/rref_impl.h>
7#include <torch/csrc/distributed/rpc/script_call.h>
8#include <torch/csrc/distributed/rpc/script_remote_call.h>
9
10namespace torch {
11namespace distributed {
12namespace rpc {
13
14// RequestCallback implementation with no Python dependencies.
15class TORCH_API RequestCallbackNoPython : public RequestCallback {
16 public:
17 c10::intrusive_ptr<JitFuture> processMessage(
18 Message& request,
19 std::vector<c10::Stream> streams) const override;
20
21 protected:
22 virtual std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
23 std::unique_ptr<RpcCommandBase> rpc,
24 const MessageType& messageType) const;
25
26 virtual c10::intrusive_ptr<JitFuture> processScriptCall(
27 RpcCommandBase& rpc,
28 std::vector<c10::Stream> streams) const;
29
30 virtual c10::intrusive_ptr<JitFuture> processPythonCall(
31 RpcCommandBase& rpc,
32 std::vector<c10::Stream> streams) const;
33
34 c10::intrusive_ptr<JitFuture> assignOwnerRRef(
35 const RRefId& rrefId,
36 const RRefId& forkId,
37 c10::intrusive_ptr<JitFuture> valueFuture) const;
38
39 virtual c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
40 RpcCommandBase& rpc,
41 std::vector<c10::Stream> streams) const;
42
43 virtual c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
44 RpcCommandBase& rpc,
45 std::vector<c10::Stream> streams) const;
46
47 c10::intrusive_ptr<JitFuture> retrieveOwnerRRef(const RRefId& rrefId) const;
48
49 c10::intrusive_ptr<JitFuture> processScriptRRefFetchCall(
50 RpcCommandBase& rpc) const;
51
52 virtual c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall(
53 RpcCommandBase& rpc) const;
54
55 c10::intrusive_ptr<JitFuture> processRRefUserDelete(
56 RpcCommandBase& rpc) const;
57
58 c10::intrusive_ptr<JitFuture> processRRefChildAccept(
59 RpcCommandBase& rpc) const;
60
61 c10::intrusive_ptr<JitFuture> processRRefForkRequest(
62 RpcCommandBase& rpc) const;
63
64 c10::intrusive_ptr<JitFuture> processForwardAutogradReq(
65 RpcCommandBase& rpc,
66 std::vector<c10::Stream> streams) const;
67
68 c10::intrusive_ptr<JitFuture> processBackwardAutogradReq(
69 RpcCommandBase& rpc,
70 std::vector<c10::Stream> streams) const;
71
72 c10::intrusive_ptr<JitFuture> processCleanupAutogradContextReq(
73 RpcCommandBase& rpc) const;
74
75 c10::intrusive_ptr<JitFuture> processRunWithProfilingReq(
76 RpcCommandBase& rpc) const;
77
78 virtual void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const;
79
80 c10::intrusive_ptr<JitFuture> processRpc(
81 RpcCommandBase& rpc,
82 const MessageType& messageType,
83 std::vector<c10::Stream> streams) const;
84
85 virtual c10::intrusive_ptr<JitFuture> processRpcWithErrors(
86 RpcCommandBase& rpc,
87 const MessageType& messageType,
88 std::vector<c10::Stream> streams) const;
89
90 c10::intrusive_ptr<Message> handleError(
91 const std::exception& e,
92 const MessageType messageType,
93 int64_t messageId) const;
94
95 virtual bool cudaAvailable() const;
96
97 virtual c10::intrusive_ptr<JitFuture> processRRefBackward(
98 RpcCommandBase& rpc) const;
99
100 // Helpers to run user-defined functions, operators and other computations.
101
102 c10::intrusive_ptr<JitFuture> runJitOperator(
103 const jit::Operator& op,
104 std::vector<at::IValue>& stack,
105 std::vector<c10::Stream> streams) const;
106
107 // Helpers to convert various kinds of objects into already-completed futures.
108
109 c10::intrusive_ptr<JitFuture> asFuture(IValue value, TypePtr type) const;
110
111 c10::intrusive_ptr<JitFuture> asFuture(
112 c10::intrusive_ptr<Message> message) const;
113
114 c10::intrusive_ptr<JitFuture> asFuture(std::exception_ptr err) const;
115};
116
117} // namespace rpc
118} // namespace distributed
119} // namespace torch
120