1 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
2 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
3 | #include <torch/csrc/jit/serialization/pickle.h> |
4 | |
5 | #include <limits> |
6 | |
7 | namespace torch { |
8 | namespace distributed { |
9 | namespace rpc { |
10 | |
11 | namespace { |
12 | |
13 | c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) { |
14 | TORCH_INTERNAL_ASSERT( |
15 | type == message.type(), |
16 | "Expecting message of type " , |
17 | type, |
18 | ", but got " , |
19 | message.type()); |
20 | auto payload = static_cast<const char*>(message.payload().data()); |
21 | auto payload_size = message.payload().size(); |
22 | |
23 | auto value = jit::unpickle( |
24 | payload, |
25 | payload_size, |
26 | *RpcAgent::getCurrentRpcAgent()->getTypeResolver(), |
27 | message.tensors()); |
28 | return std::move(*std::move(value).toTuple()).elements(); |
29 | } |
30 | |
31 | c10::intrusive_ptr<Message> fromIValues( |
32 | std::vector<IValue> ivalues, |
33 | MessageType type) { |
34 | std::vector<torch::Tensor> tensor_table; |
35 | auto payload = jit::pickle( |
36 | c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table); |
37 | return c10::make_intrusive<Message>( |
38 | std::move(payload), std::move(tensor_table), type); |
39 | } |
40 | |
41 | } // namespace |
42 | |
43 | /////////////////////////// RRefMessageBase ////////////////////////////////// |
44 | |
45 | const RRefId& RRefMessageBase::rrefId() { |
46 | return rrefId_; |
47 | } |
48 | |
49 | /////////////////////////// ForkMessageBase ////////////////////////////////// |
50 | |
51 | const ForkId& ForkMessageBase::forkId() { |
52 | return forkId_; |
53 | } |
54 | |
55 | c10::intrusive_ptr<Message> ForkMessageBase::toMessageImpl() && { |
56 | return fromIValues({rrefId_.toIValue(), forkId_.toIValue()}, type_); |
57 | } |
58 | |
59 | std::pair<RRefId, ForkId> ForkMessageBase::fromMessage( |
60 | const Message& message, |
61 | MessageType type) { |
62 | auto ivalues = toIValues(message, type); |
63 | |
64 | TORCH_INTERNAL_ASSERT( |
65 | ivalues.size() == 2, "ForkMessageBase expects 2 IValue from message." ); |
66 | |
67 | return std::make_pair( |
68 | RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1])); |
69 | } |
70 | |
71 | /////////////////////////// RRef Protocol ////////////////////////////////// |
72 | |
73 | c10::intrusive_ptr<Message> ScriptRRefFetchCall::toMessageImpl() && { |
74 | std::vector<at::IValue> ivalues; |
75 | ivalues.reserve(2); |
76 | ivalues.emplace_back(rrefId_.toIValue()); |
77 | ivalues.emplace_back(fromWorkerId_); |
78 | return fromIValues(std::move(ivalues), MessageType::SCRIPT_RREF_FETCH_CALL); |
79 | } |
80 | |
81 | std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage( |
82 | const Message& message) { |
83 | auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_CALL); |
84 | TORCH_INTERNAL_ASSERT( |
85 | values.size() == 2, "ScriptRRefFetchCall expects 2 IValues from message" ); |
86 | auto id = values[1].toInt(); |
87 | TORCH_INTERNAL_ASSERT( |
88 | id >= std::numeric_limits<worker_id_t>::min() && |
89 | id <= std::numeric_limits<worker_id_t>::max(), |
90 | "ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit." ) |
91 | return std::make_unique<ScriptRRefFetchCall>( |
92 | worker_id_t(id), RRefId::fromIValue(values[0])); |
93 | } |
94 | |
95 | c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && { |
96 | std::vector<at::IValue> ivalues; |
97 | ivalues.reserve(2); |
98 | ivalues.emplace_back(rrefId_.toIValue()); |
99 | ivalues.emplace_back(fromWorkerId_); |
100 | return fromIValues(std::move(ivalues), MessageType::PYTHON_RREF_FETCH_CALL); |
101 | } |
102 | |
103 | std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage( |
104 | const Message& message) { |
105 | auto values = toIValues(message, MessageType::PYTHON_RREF_FETCH_CALL); |
106 | TORCH_INTERNAL_ASSERT( |
107 | values.size() == 2, "PythonRRefFetchCall expects 2 IValues from message" ); |
108 | auto id = values[1].toInt(); |
109 | TORCH_INTERNAL_ASSERT( |
110 | id >= std::numeric_limits<worker_id_t>::min() && |
111 | id <= std::numeric_limits<worker_id_t>::max(), |
112 | "PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit." ) |
113 | return std::make_unique<PythonRRefFetchCall>( |
114 | worker_id_t(id), RRefId::fromIValue(values[0])); |
115 | } |
116 | |
117 | const std::vector<at::IValue>& RRefFetchRet::values() { |
118 | return values_; |
119 | } |
120 | |
121 | c10::intrusive_ptr<Message> RRefFetchRet::toMessageImpl() && { |
122 | return fromIValues(values_, type_); |
123 | } |
124 | |
125 | std::unique_ptr<ScriptRRefFetchRet> ScriptRRefFetchRet::fromMessage( |
126 | const Message& message) { |
127 | auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_RET); |
128 | TORCH_INTERNAL_ASSERT( |
129 | values.size() == 1, |
130 | "RRef of IValue should contain a single IValue, but got " , |
131 | values.size()); |
132 | return std::make_unique<ScriptRRefFetchRet>(std::move(values).vec()); |
133 | } |
134 | |
135 | std::unique_ptr<PythonRRefFetchRet> PythonRRefFetchRet::fromMessage( |
136 | const Message& message) { |
137 | return std::make_unique<PythonRRefFetchRet>( |
138 | toIValues(message, MessageType::PYTHON_RREF_FETCH_RET).vec()); |
139 | } |
140 | |
141 | std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage( |
142 | const Message& message) { |
143 | auto pair = |
144 | ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE); |
145 | return std::make_unique<RRefUserDelete>( |
146 | RRefUserDelete(pair.first, pair.second)); |
147 | } |
148 | |
149 | std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) { |
150 | auto pair = ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET); |
151 | return std::make_unique<RemoteRet>(pair.first, pair.second); |
152 | } |
153 | |
154 | const ForkId& RRefChildAccept::forkId() const { |
155 | return forkId_; |
156 | } |
157 | |
158 | c10::intrusive_ptr<Message> RRefChildAccept::toMessageImpl() && { |
159 | return fromIValues({forkId_.toIValue()}, MessageType::RREF_CHILD_ACCEPT); |
160 | } |
161 | |
162 | std::unique_ptr<RRefChildAccept> RRefChildAccept::fromMessage( |
163 | const Message& message) { |
164 | auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT); |
165 | TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message." ); |
166 | |
167 | return std::make_unique<RRefChildAccept>(ForkId::fromIValue(values.back())); |
168 | } |
169 | |
170 | std::unique_ptr<RRefForkRequest> RRefForkRequest::fromMessage( |
171 | const Message& message) { |
172 | auto pair = |
173 | ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST); |
174 | return std::make_unique<RRefForkRequest>(pair.first, pair.second); |
175 | } |
176 | |
177 | c10::intrusive_ptr<Message> RRefAck::toMessageImpl() && { |
178 | return c10::make_intrusive<Message>( |
179 | std::vector<char>{}, std::vector<torch::Tensor>{}, MessageType::RREF_ACK); |
180 | } |
181 | |
182 | std::unique_ptr<RRefAck> RRefAck::fromMessage(const Message& message) { |
183 | TORCH_INTERNAL_ASSERT( |
184 | message.type() == MessageType::RREF_ACK, |
185 | "Message type miss match, expect " , |
186 | MessageType::RREF_ACK, |
187 | ", but got " , |
188 | message.type()); |
189 | return std::make_unique<RRefAck>(); |
190 | } |
191 | |
192 | } // namespace rpc |
193 | } // namespace distributed |
194 | } // namespace torch |
195 | |