1#include <torch/csrc/distributed/rpc/python_remote_call.h>
2#include <torch/csrc/distributed/rpc/rpc_agent.h>
3#include <torch/csrc/jit/serialization/pickle.h>
4
5namespace torch {
6namespace distributed {
7namespace rpc {
8
9PythonRemoteCall::PythonRemoteCall(
10 SerializedPyObj&& serializedPyObj,
11 at::IValue retRRefId,
12 at::IValue retForkId,
13 const bool isAsyncExecution)
14 : serializedPyObj_(std::move(serializedPyObj)),
15 retRRefId_(std::move(retRRefId)),
16 retForkId_(std::move(retForkId)),
17 isAsyncExecution_(isAsyncExecution) {}
18
19c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && {
20 std::vector<IValue> ivalues = std::move(serializedPyObj_).toIValues();
21 ivalues.emplace_back(retRRefId_);
22 ivalues.emplace_back(retForkId_);
23 ivalues.emplace_back(isAsyncExecution_);
24
25 std::vector<torch::Tensor> tensor_table;
26 auto payload =
27 jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
28
29 return c10::make_intrusive<Message>(
30 std::move(payload),
31 std::move(tensor_table),
32 MessageType::PYTHON_REMOTE_CALL);
33}
34
35std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
36 const Message& message) {
37 auto payload = static_cast<const char*>(message.payload().data());
38 auto payload_size = message.payload().size();
39
40 auto value = jit::unpickle(
41 payload,
42 payload_size,
43 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
44 message.tensors());
45 auto values = value.toTupleRef().elements().vec();
46
47 // remove the last elements from values and convert it back to an RRef
48 TORCH_INTERNAL_ASSERT(
49 values.size() >= 3,
50 "Expect at least 3 elements in the unpickled values, but got ",
51 values.size());
52 bool isAsyncExecution = values.back().toBool();
53 values.pop_back();
54 auto retForkId = std::move(values.back());
55 values.pop_back();
56 auto retRRefId = std::move(values.back());
57 values.pop_back();
58 auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values));
59
60 return std::make_unique<PythonRemoteCall>(
61 std::move(serializedPyObj),
62 std::move(retRRefId),
63 std::move(retForkId),
64 isAsyncExecution);
65}
66
67} // namespace rpc
68} // namespace distributed
69} // namespace torch
70