1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
4 | #include <torch/csrc/distributed/rpc/types.h> |
5 | #include <torch/csrc/distributed/rpc/unpickled_python_call.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | |
8 | namespace torch { |
9 | namespace distributed { |
10 | namespace rpc { |
11 | |
12 | // This class converts the content in a PythonRemoteCall into py::object. This |
13 | // is a helper class to make sure that all arguments deserialization is done |
14 | // before entering RequestCallbackImpl::processRpc(...), so that the |
15 | // deserialization related logic can be carried out in one spot instead of |
16 | // scattered in multiple places for different message types. |
17 | // NB: The reason for not consolidating class into PythonRemoteCall is because |
18 | // PythonRemoteCall is a libtorch type which should not depend on Python types. |
19 | class TORCH_API UnpickledPythonRemoteCall final : public UnpickledPythonCall { |
20 | public: |
21 | explicit UnpickledPythonRemoteCall( |
22 | const SerializedPyObj& serializedPyObj, |
23 | const at::IValue& retRRefId, |
24 | const at::IValue& retForkId, |
25 | const bool isAsyncExecution); |
26 | |
27 | const RRefId& rrefId() const; |
28 | const ForkId& forkId() const; |
29 | |
30 | private: |
31 | RRefId rrefId_; |
32 | ForkId forkId_; |
33 | }; |
34 | |
35 | } // namespace rpc |
36 | } // namespace distributed |
37 | } // namespace torch |
38 | |