1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
4 | #include <torch/csrc/distributed/rpc/types.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | |
7 | namespace torch { |
8 | namespace distributed { |
9 | namespace rpc { |
10 | |
11 | // This class converts the content in a PythonCall into py::object. This is a |
12 | // helper class to make sure that all arguments deserialization is done before |
13 | // entering RequestCallbackImpl::processRpc(...), so that the deserialization |
14 | // related logic can be carried out in one spot instead of scattered in multiple |
15 | // places for different message types. |
16 | // NB: The reason for not consolidating class into PythonCall is because |
17 | // PythonCall is a libtorch type which should not depend on Python types. |
18 | class TORCH_API UnpickledPythonCall : public RpcCommandBase { |
19 | public: |
20 | UnpickledPythonCall( |
21 | const SerializedPyObj& serializedPyObj, |
22 | bool isAsyncExecution); |
23 | ~UnpickledPythonCall() override; |
24 | |
25 | // toMessage() method is not implemented, as objects of this class should |
26 | // never be directly converted into a Message object. |
27 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
28 | const py::object& pythonUdf() const; |
29 | |
30 | inline bool isAsyncExecution() const { |
31 | return isAsyncExecution_; |
32 | } |
33 | |
34 | private: |
35 | py::object pythonUdf_; |
36 | const bool isAsyncExecution_; |
37 | }; |
38 | |
39 | } // namespace rpc |
40 | } // namespace distributed |
41 | } // namespace torch |
42 | |