1 | #include <torch/csrc/distributed/rpc/python_remote_call.h> |
2 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
3 | #include <torch/csrc/jit/serialization/pickle.h> |
4 | |
5 | namespace torch { |
6 | namespace distributed { |
7 | namespace rpc { |
8 | |
9 | PythonRemoteCall::PythonRemoteCall( |
10 | SerializedPyObj&& serializedPyObj, |
11 | at::IValue retRRefId, |
12 | at::IValue retForkId, |
13 | const bool isAsyncExecution) |
14 | : serializedPyObj_(std::move(serializedPyObj)), |
15 | retRRefId_(std::move(retRRefId)), |
16 | retForkId_(std::move(retForkId)), |
17 | isAsyncExecution_(isAsyncExecution) {} |
18 | |
19 | c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && { |
20 | std::vector<IValue> ivalues = std::move(serializedPyObj_).toIValues(); |
21 | ivalues.emplace_back(retRRefId_); |
22 | ivalues.emplace_back(retForkId_); |
23 | ivalues.emplace_back(isAsyncExecution_); |
24 | |
25 | std::vector<torch::Tensor> tensor_table; |
26 | auto payload = |
27 | jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table); |
28 | |
29 | return c10::make_intrusive<Message>( |
30 | std::move(payload), |
31 | std::move(tensor_table), |
32 | MessageType::PYTHON_REMOTE_CALL); |
33 | } |
34 | |
35 | std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage( |
36 | const Message& message) { |
37 | auto payload = static_cast<const char*>(message.payload().data()); |
38 | auto payload_size = message.payload().size(); |
39 | |
40 | auto value = jit::unpickle( |
41 | payload, |
42 | payload_size, |
43 | *RpcAgent::getCurrentRpcAgent()->getTypeResolver(), |
44 | message.tensors()); |
45 | auto values = value.toTupleRef().elements().vec(); |
46 | |
47 | // remove the last elements from values and convert it back to an RRef |
48 | TORCH_INTERNAL_ASSERT( |
49 | values.size() >= 3, |
50 | "Expect at least 3 elements in the unpickled values, but got " , |
51 | values.size()); |
52 | bool isAsyncExecution = values.back().toBool(); |
53 | values.pop_back(); |
54 | auto retForkId = std::move(values.back()); |
55 | values.pop_back(); |
56 | auto retRRefId = std::move(values.back()); |
57 | values.pop_back(); |
58 | auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values)); |
59 | |
60 | return std::make_unique<PythonRemoteCall>( |
61 | std::move(serializedPyObj), |
62 | std::move(retRRefId), |
63 | std::move(retForkId), |
64 | isAsyncExecution); |
65 | } |
66 | |
67 | } // namespace rpc |
68 | } // namespace distributed |
69 | } // namespace torch |
70 | |