1 | #include <torch/csrc/distributed/rpc/python_call.h> |
2 | |
3 | #include <c10/util/C++17.h> |
4 | |
5 | namespace torch { |
6 | namespace distributed { |
7 | namespace rpc { |
8 | |
9 | PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution) |
10 | : serializedPyObj_(std::move(serializedPyObj)), |
11 | isAsyncExecution_(isAsyncExecution) {} |
12 | |
13 | c10::intrusive_ptr<Message> PythonCall::toMessageImpl() && { |
14 | std::vector<char> payload; |
15 | payload.reserve(serializedPyObj_.payload_.length() + 1); |
16 | payload.push_back(isAsyncExecution_ ? 1 : 0); |
17 | payload.insert( |
18 | payload.end(), |
19 | serializedPyObj_.payload_.begin(), |
20 | serializedPyObj_.payload_.end()); |
21 | |
22 | return c10::make_intrusive<Message>( |
23 | std::move(payload), |
24 | std::move(serializedPyObj_.tensors_), |
25 | MessageType::PYTHON_CALL); |
26 | } |
27 | |
28 | std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) { |
29 | TORCH_INTERNAL_ASSERT( |
30 | !message.payload().empty(), |
31 | "Failed to convert an RPC message to PythonCall, the payload should at " |
32 | "least contain one byte indicating whether this is an async function, " |
33 | "but got payload of size " , |
34 | message.payload().size()); |
35 | const char& c = message.payload()[0]; |
36 | TORCH_INTERNAL_ASSERT(c == 0 || c == 1); |
37 | bool isAsyncExecution = (c == 1); |
38 | std::string payload(message.payload().begin() + 1, message.payload().end()); |
39 | std::vector<Tensor> tensors = message.tensors(); |
40 | SerializedPyObj serializedPyObj(std::move(payload), std::move(tensors)); |
41 | return std::make_unique<PythonCall>( |
42 | std::move(serializedPyObj), isAsyncExecution); |
43 | } |
44 | |
45 | const SerializedPyObj& PythonCall::serializedPyObj() const { |
46 | return serializedPyObj_; |
47 | } |
48 | |
49 | } // namespace rpc |
50 | } // namespace distributed |
51 | } // namespace torch |
52 | |