1 | #include <torch/csrc/distributed/rpc/script_resp.h> |
---|---|
2 | |
3 | #include <c10/util/C++17.h> |
4 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
5 | #include <torch/csrc/jit/serialization/pickle.h> |
6 | #include <torch/csrc/jit/serialization/unpickler.h> |
7 | |
8 | namespace torch { |
9 | namespace distributed { |
10 | namespace rpc { |
11 | |
12 | ScriptResp::ScriptResp(at::IValue&& value) : value_(value) {} |
13 | |
14 | const at::IValue& ScriptResp::value() { |
15 | return value_; |
16 | } |
17 | |
18 | c10::intrusive_ptr<Message> ScriptResp::toMessageImpl() && { |
19 | std::vector<torch::Tensor> tensor_table; |
20 | auto payload = jit::pickle(value_, &tensor_table); |
21 | return c10::make_intrusive<Message>( |
22 | std::move(payload), std::move(tensor_table), MessageType::SCRIPT_RET); |
23 | } |
24 | |
25 | std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) { |
26 | auto payload = static_cast<const char*>(message.payload().data()); |
27 | auto payload_size = message.payload().size(); |
28 | auto value = jit::unpickle( |
29 | payload, |
30 | payload_size, |
31 | *RpcAgent::getCurrentRpcAgent()->getTypeResolver(), |
32 | message.tensors()); |
33 | return std::make_unique<ScriptResp>(std::move(value)); |
34 | } |
35 | |
36 | } // namespace rpc |
37 | } // namespace distributed |
38 | } // namespace torch |
39 |