1#include <torch/csrc/distributed/rpc/rpc_agent.h>
2#include <torch/csrc/distributed/rpc/script_remote_call.h>
3
4#include <c10/util/C++17.h>
5#include <torch/csrc/jit/serialization/pickle.h>
6
7namespace torch {
8namespace distributed {
9namespace rpc {
10
11ScriptRemoteCall::ScriptRemoteCall(
12 std::shared_ptr<Operator> op,
13 std::vector<at::IValue>&& stack,
14 const RRefId& retRRefId,
15 const ForkId& retForkId)
16 : ScriptCall(std::move(op), std::move(stack)),
17 retRRefId_(retRRefId),
18 retForkId_(retForkId) {}
19
20ScriptRemoteCall::ScriptRemoteCall(
21 const c10::QualifiedName& qualifiedName,
22 std::vector<at::IValue>&& stack,
23 const RRefId& retRRefId,
24 const ForkId& retForkId,
25 const bool isAsyncExecution)
26 : ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
27 retRRefId_(retRRefId),
28 retForkId_(retForkId) {}
29
30std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
31 std::vector<at::IValue>& ivalues) {
32 // remove the last element from values and convert it back to an RRef
33 auto retForkId = RRefId::fromIValue(ivalues.back());
34 ivalues.pop_back();
35 auto retRRefId = ForkId::fromIValue(ivalues.back());
36 ivalues.pop_back();
37
38 auto scriptCallPtr = ScriptCall::fromIValues(ivalues);
39
40 if (scriptCallPtr->hasOp()) {
41 return std::make_unique<ScriptRemoteCall>(
42 scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
43 } else {
44 return std::make_unique<ScriptRemoteCall>(
45 scriptCallPtr->qualifiedName(),
46 std::move(ivalues),
47 retRRefId,
48 retForkId,
49 scriptCallPtr->isAsyncExecution());
50 }
51}
52
53c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
54 std::vector<IValue> ivalues;
55 ScriptCall::toIValues(ivalues);
56 ivalues.emplace_back(retRRefId_.toIValue());
57 ivalues.emplace_back(retForkId_.toIValue());
58
59 std::vector<torch::Tensor> tensor_table;
60 auto payload = jit::pickle(
61 c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
62
63 return c10::make_intrusive<Message>(
64 std::move(payload),
65 std::move(tensor_table),
66 MessageType::SCRIPT_REMOTE_CALL);
67}
68
69std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
70 const Message& message) {
71 auto payload = static_cast<const char*>(message.payload().data());
72 auto payload_size = message.payload().size();
73
74 auto value = jit::unpickle(
75 payload,
76 payload_size,
77 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
78 message.tensors());
79 auto values = value.toTupleRef().elements().vec();
80 return fromIValues(values);
81}
82
83} // namespace rpc
84} // namespace distributed
85} // namespace torch
86