1#pragma once
2
3#include <torch/csrc/distributed/rpc/message.h>
4#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
5#include <torch/csrc/distributed/rpc/rpc_command_base.h>
6#include <torch/csrc/jit/python/pybind.h>
7
8namespace torch {
9namespace distributed {
10namespace rpc {
11
12class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
13 public:
14 std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
15 std::unique_ptr<RpcCommandBase> rpc,
16 const MessageType& messageType) const override;
17
18 c10::intrusive_ptr<JitFuture> processPythonCall(
19 RpcCommandBase& rpc,
20 std::vector<c10::Stream> streams) const override;
21
22 c10::intrusive_ptr<JitFuture> processScriptCall(
23 RpcCommandBase& rpc,
24 std::vector<c10::Stream> streams) const override;
25
26 c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
27 RpcCommandBase& rpc,
28 std::vector<c10::Stream> streams) const override;
29
30 c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
31 RpcCommandBase& rpc,
32 std::vector<c10::Stream> streams) const override;
33
34 c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall(
35 RpcCommandBase& rpc) const override;
36
37 void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override;
38
39 c10::intrusive_ptr<JitFuture> processRpcWithErrors(
40 RpcCommandBase& rpc,
41 const MessageType& messageType,
42 std::vector<c10::Stream> streams) const override;
43
44 bool cudaAvailable() const override;
45
46 c10::intrusive_ptr<JitFuture> processRRefBackward(
47 RpcCommandBase& rpc) const override;
48
49 // Helpers to run user-defined functions, operators and other computations.
50
51 c10::intrusive_ptr<JitFuture> runJitFunction(
52 const c10::QualifiedName& name,
53 std::vector<at::IValue>& stack,
54 std::vector<c10::Stream> streams,
55 bool isAsyncExecution) const;
56
57 c10::intrusive_ptr<JitFuture> runPythonFunction(
58 const py::object& function,
59 std::vector<c10::Stream> streams,
60 bool isAsyncExecution) const;
61};
62
63} // namespace rpc
64} // namespace distributed
65} // namespace torch
66