1#include <torch/csrc/distributed/rpc/script_resp.h>
2
3#include <c10/util/C++17.h>
4#include <torch/csrc/distributed/rpc/rpc_agent.h>
5#include <torch/csrc/jit/serialization/pickle.h>
6#include <torch/csrc/jit/serialization/unpickler.h>
7
8namespace torch {
9namespace distributed {
10namespace rpc {
11
12ScriptResp::ScriptResp(at::IValue&& value) : value_(value) {}
13
14const at::IValue& ScriptResp::value() {
15 return value_;
16}
17
18c10::intrusive_ptr<Message> ScriptResp::toMessageImpl() && {
19 std::vector<torch::Tensor> tensor_table;
20 auto payload = jit::pickle(value_, &tensor_table);
21 return c10::make_intrusive<Message>(
22 std::move(payload), std::move(tensor_table), MessageType::SCRIPT_RET);
23}
24
25std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) {
26 auto payload = static_cast<const char*>(message.payload().data());
27 auto payload_size = message.payload().size();
28 auto value = jit::unpickle(
29 payload,
30 payload_size,
31 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
32 message.tensors());
33 return std::make_unique<ScriptResp>(std::move(value));
34}
35
36} // namespace rpc
37} // namespace distributed
38} // namespace torch
39