1 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
2 | #include <torch/csrc/distributed/rpc/script_remote_call.h> |
3 | |
4 | #include <c10/util/C++17.h> |
5 | #include <torch/csrc/jit/serialization/pickle.h> |
6 | |
7 | namespace torch { |
8 | namespace distributed { |
9 | namespace rpc { |
10 | |
11 | ScriptRemoteCall::ScriptRemoteCall( |
12 | std::shared_ptr<Operator> op, |
13 | std::vector<at::IValue>&& stack, |
14 | const RRefId& retRRefId, |
15 | const ForkId& retForkId) |
16 | : ScriptCall(std::move(op), std::move(stack)), |
17 | retRRefId_(retRRefId), |
18 | retForkId_(retForkId) {} |
19 | |
20 | ScriptRemoteCall::ScriptRemoteCall( |
21 | const c10::QualifiedName& qualifiedName, |
22 | std::vector<at::IValue>&& stack, |
23 | const RRefId& retRRefId, |
24 | const ForkId& retForkId, |
25 | const bool isAsyncExecution) |
26 | : ScriptCall(qualifiedName, std::move(stack), isAsyncExecution), |
27 | retRRefId_(retRRefId), |
28 | retForkId_(retForkId) {} |
29 | |
30 | std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues( |
31 | std::vector<at::IValue>& ivalues) { |
32 | // remove the last element from values and convert it back to an RRef |
33 | auto retForkId = RRefId::fromIValue(ivalues.back()); |
34 | ivalues.pop_back(); |
35 | auto retRRefId = ForkId::fromIValue(ivalues.back()); |
36 | ivalues.pop_back(); |
37 | |
38 | auto scriptCallPtr = ScriptCall::fromIValues(ivalues); |
39 | |
40 | if (scriptCallPtr->hasOp()) { |
41 | return std::make_unique<ScriptRemoteCall>( |
42 | scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId); |
43 | } else { |
44 | return std::make_unique<ScriptRemoteCall>( |
45 | scriptCallPtr->qualifiedName(), |
46 | std::move(ivalues), |
47 | retRRefId, |
48 | retForkId, |
49 | scriptCallPtr->isAsyncExecution()); |
50 | } |
51 | } |
52 | |
53 | c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && { |
54 | std::vector<IValue> ivalues; |
55 | ScriptCall::toIValues(ivalues); |
56 | ivalues.emplace_back(retRRefId_.toIValue()); |
57 | ivalues.emplace_back(retForkId_.toIValue()); |
58 | |
59 | std::vector<torch::Tensor> tensor_table; |
60 | auto payload = jit::pickle( |
61 | c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table); |
62 | |
63 | return c10::make_intrusive<Message>( |
64 | std::move(payload), |
65 | std::move(tensor_table), |
66 | MessageType::SCRIPT_REMOTE_CALL); |
67 | } |
68 | |
69 | std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage( |
70 | const Message& message) { |
71 | auto payload = static_cast<const char*>(message.payload().data()); |
72 | auto payload_size = message.payload().size(); |
73 | |
74 | auto value = jit::unpickle( |
75 | payload, |
76 | payload_size, |
77 | *RpcAgent::getCurrentRpcAgent()->getTypeResolver(), |
78 | message.tensors()); |
79 | auto values = value.toTupleRef().elements().vec(); |
80 | return fromIValues(values); |
81 | } |
82 | |
83 | } // namespace rpc |
84 | } // namespace distributed |
85 | } // namespace torch |
86 | |