1#include <torch/csrc/distributed/rpc/python_call.h>
2
3#include <c10/util/C++17.h>
4
5namespace torch {
6namespace distributed {
7namespace rpc {
8
9PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution)
10 : serializedPyObj_(std::move(serializedPyObj)),
11 isAsyncExecution_(isAsyncExecution) {}
12
13c10::intrusive_ptr<Message> PythonCall::toMessageImpl() && {
14 std::vector<char> payload;
15 payload.reserve(serializedPyObj_.payload_.length() + 1);
16 payload.push_back(isAsyncExecution_ ? 1 : 0);
17 payload.insert(
18 payload.end(),
19 serializedPyObj_.payload_.begin(),
20 serializedPyObj_.payload_.end());
21
22 return c10::make_intrusive<Message>(
23 std::move(payload),
24 std::move(serializedPyObj_.tensors_),
25 MessageType::PYTHON_CALL);
26}
27
28std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) {
29 TORCH_INTERNAL_ASSERT(
30 !message.payload().empty(),
31 "Failed to convert an RPC message to PythonCall, the payload should at "
32 "least contain one byte indicating whether this is an async function, "
33 "but got payload of size ",
34 message.payload().size());
35 const char& c = message.payload()[0];
36 TORCH_INTERNAL_ASSERT(c == 0 || c == 1);
37 bool isAsyncExecution = (c == 1);
38 std::string payload(message.payload().begin() + 1, message.payload().end());
39 std::vector<Tensor> tensors = message.tensors();
40 SerializedPyObj serializedPyObj(std::move(payload), std::move(tensors));
41 return std::make_unique<PythonCall>(
42 std::move(serializedPyObj), isAsyncExecution);
43}
44
45const SerializedPyObj& PythonCall::serializedPyObj() const {
46 return serializedPyObj_;
47}
48
49} // namespace rpc
50} // namespace distributed
51} // namespace torch
52