1#include <torch/csrc/distributed/rpc/rpc_agent.h>
2#include <torch/csrc/distributed/rpc/rref_proto.h>
3#include <torch/csrc/jit/serialization/pickle.h>
4
5#include <limits>
6
7namespace torch {
8namespace distributed {
9namespace rpc {
10
11namespace {
12
13c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) {
14 TORCH_INTERNAL_ASSERT(
15 type == message.type(),
16 "Expecting message of type ",
17 type,
18 ", but got ",
19 message.type());
20 auto payload = static_cast<const char*>(message.payload().data());
21 auto payload_size = message.payload().size();
22
23 auto value = jit::unpickle(
24 payload,
25 payload_size,
26 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
27 message.tensors());
28 return std::move(*std::move(value).toTuple()).elements();
29}
30
31c10::intrusive_ptr<Message> fromIValues(
32 std::vector<IValue> ivalues,
33 MessageType type) {
34 std::vector<torch::Tensor> tensor_table;
35 auto payload = jit::pickle(
36 c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
37 return c10::make_intrusive<Message>(
38 std::move(payload), std::move(tensor_table), type);
39}
40
41} // namespace
42
43/////////////////////////// RRefMessageBase //////////////////////////////////
44
45const RRefId& RRefMessageBase::rrefId() {
46 return rrefId_;
47}
48
49/////////////////////////// ForkMessageBase //////////////////////////////////
50
51const ForkId& ForkMessageBase::forkId() {
52 return forkId_;
53}
54
55c10::intrusive_ptr<Message> ForkMessageBase::toMessageImpl() && {
56 return fromIValues({rrefId_.toIValue(), forkId_.toIValue()}, type_);
57}
58
59std::pair<RRefId, ForkId> ForkMessageBase::fromMessage(
60 const Message& message,
61 MessageType type) {
62 auto ivalues = toIValues(message, type);
63
64 TORCH_INTERNAL_ASSERT(
65 ivalues.size() == 2, "ForkMessageBase expects 2 IValue from message.");
66
67 return std::make_pair(
68 RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1]));
69}
70
71/////////////////////////// RRef Protocol //////////////////////////////////
72
73c10::intrusive_ptr<Message> ScriptRRefFetchCall::toMessageImpl() && {
74 std::vector<at::IValue> ivalues;
75 ivalues.reserve(2);
76 ivalues.emplace_back(rrefId_.toIValue());
77 ivalues.emplace_back(fromWorkerId_);
78 return fromIValues(std::move(ivalues), MessageType::SCRIPT_RREF_FETCH_CALL);
79}
80
81std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
82 const Message& message) {
83 auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_CALL);
84 TORCH_INTERNAL_ASSERT(
85 values.size() == 2, "ScriptRRefFetchCall expects 2 IValues from message");
86 auto id = values[1].toInt();
87 TORCH_INTERNAL_ASSERT(
88 id >= std::numeric_limits<worker_id_t>::min() &&
89 id <= std::numeric_limits<worker_id_t>::max(),
90 "ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
91 return std::make_unique<ScriptRRefFetchCall>(
92 worker_id_t(id), RRefId::fromIValue(values[0]));
93}
94
95c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && {
96 std::vector<at::IValue> ivalues;
97 ivalues.reserve(2);
98 ivalues.emplace_back(rrefId_.toIValue());
99 ivalues.emplace_back(fromWorkerId_);
100 return fromIValues(std::move(ivalues), MessageType::PYTHON_RREF_FETCH_CALL);
101}
102
103std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
104 const Message& message) {
105 auto values = toIValues(message, MessageType::PYTHON_RREF_FETCH_CALL);
106 TORCH_INTERNAL_ASSERT(
107 values.size() == 2, "PythonRRefFetchCall expects 2 IValues from message");
108 auto id = values[1].toInt();
109 TORCH_INTERNAL_ASSERT(
110 id >= std::numeric_limits<worker_id_t>::min() &&
111 id <= std::numeric_limits<worker_id_t>::max(),
112 "PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
113 return std::make_unique<PythonRRefFetchCall>(
114 worker_id_t(id), RRefId::fromIValue(values[0]));
115}
116
117const std::vector<at::IValue>& RRefFetchRet::values() {
118 return values_;
119}
120
121c10::intrusive_ptr<Message> RRefFetchRet::toMessageImpl() && {
122 return fromIValues(values_, type_);
123}
124
125std::unique_ptr<ScriptRRefFetchRet> ScriptRRefFetchRet::fromMessage(
126 const Message& message) {
127 auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_RET);
128 TORCH_INTERNAL_ASSERT(
129 values.size() == 1,
130 "RRef of IValue should contain a single IValue, but got ",
131 values.size());
132 return std::make_unique<ScriptRRefFetchRet>(std::move(values).vec());
133}
134
135std::unique_ptr<PythonRRefFetchRet> PythonRRefFetchRet::fromMessage(
136 const Message& message) {
137 return std::make_unique<PythonRRefFetchRet>(
138 toIValues(message, MessageType::PYTHON_RREF_FETCH_RET).vec());
139}
140
141std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage(
142 const Message& message) {
143 auto pair =
144 ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
145 return std::make_unique<RRefUserDelete>(
146 RRefUserDelete(pair.first, pair.second));
147}
148
149std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
150 auto pair = ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET);
151 return std::make_unique<RemoteRet>(pair.first, pair.second);
152}
153
154const ForkId& RRefChildAccept::forkId() const {
155 return forkId_;
156}
157
158c10::intrusive_ptr<Message> RRefChildAccept::toMessageImpl() && {
159 return fromIValues({forkId_.toIValue()}, MessageType::RREF_CHILD_ACCEPT);
160}
161
162std::unique_ptr<RRefChildAccept> RRefChildAccept::fromMessage(
163 const Message& message) {
164 auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT);
165 TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message.");
166
167 return std::make_unique<RRefChildAccept>(ForkId::fromIValue(values.back()));
168}
169
170std::unique_ptr<RRefForkRequest> RRefForkRequest::fromMessage(
171 const Message& message) {
172 auto pair =
173 ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST);
174 return std::make_unique<RRefForkRequest>(pair.first, pair.second);
175}
176
177c10::intrusive_ptr<Message> RRefAck::toMessageImpl() && {
178 return c10::make_intrusive<Message>(
179 std::vector<char>{}, std::vector<torch::Tensor>{}, MessageType::RREF_ACK);
180}
181
182std::unique_ptr<RRefAck> RRefAck::fromMessage(const Message& message) {
183 TORCH_INTERNAL_ASSERT(
184 message.type() == MessageType::RREF_ACK,
185 "Message type miss match, expect ",
186 MessageType::RREF_ACK,
187 ", but got ",
188 message.type());
189 return std::make_unique<RRefAck>();
190}
191
192} // namespace rpc
193} // namespace distributed
194} // namespace torch
195