#pragma once

#include <c10/core/Device.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/utils/byte_order.h>

// Forward-declare tensorpipe::Message so this header does not need to pull in
// the TensorPipe headers; only a pointer/reference to it is ever used here.
namespace tensorpipe {
class Message;
} // namespace tensorpipe

15namespace torch {
16namespace distributed {
17namespace rpc {
18
19// Parse error message and return RPCErrorType based on the message.
20TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture);
21// Create an error string given the error description and error type
22TORCH_API std::string makeRPCError(
23 const std::string& rpcErrorStr,
24 RPCErrorType errorType);
25
26// Given an RPC message received as a request over the wire, deserialize it into
27// the appropriate 'RpcCommandBase' type.
28TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest(
29 const Message& request);
30
31// Given an RPC message received as a response over the wire, deserialize it
32// into the appropriate 'RpcCommandBase' type, if the response is
33// FORWARD_AUTOGRAD_RESP type, unwrap it, attach recvBackward() functions
34// to received tensors and set the wrappedMsgType to its wrapped message type.
35TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
36 const Message& response,
37 MessageType& wrappedMsgType);
38
39// Given an RPC message received as a response over the wire, deserialize it
40// into the valid IValue if the message is for a script rpc result,
41// otherwise deserialize it into dummy none ivalue that will never be used.
42// In this deserialization, we also attach recv rpc backward functions if
43// needed.
44IValue deserializeResptoIValueInternal(
45 RpcCommandBase& rpc,
46 MessageType messageType);
47TORCH_API IValue deserializeRespToIValue(const Message& message);
48
49// Note: format is subject to change and intended for RPCs.
50// For saving persistently to disk, use torch::save().
51TORCH_API std::string wireSerialize(
52 const std::vector<char>& payload,
53 const std::vector<at::Tensor>& tensors);
54
55TORCH_API std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
56 const void* data,
57 size_t data_size);
58
59// We use vector<char> as the type of blobs because it's what rpc::Message uses
60// for its payload, even though it has the disadvantage that it cannot be
61// allocated with uninitialized memory: it is always zeroed out.
62
63// Some Tensors are effectively views of larger Tensors, where only a small
64// subset of the Storage data is referenced. This normally is good and avoids
65// copies when kept locally, but if we naively push the whole Storage over the
66// wire, we'll end up with excess network traffic. This change clones tensors if
67// we'd save at least half the data, and over a minimum hurdle.
68TORCH_API c10::List<at::Tensor> cloneSparseTensors(
69 const std::vector<at::Tensor>& tensors);
70
71// Combines an original payload and wrapped payload into the original payload.
72// Used to generate the overall payload for the wrapped RPC.
73TORCH_API void writeWrappedPayload(
74 std::vector<char>& originalPayload,
75 std::vector<char>& additionalPayload);
76
77// Reads the additional, wrapped payload from a wrapped RPC off of the input
78// payload. After this, payload will contain the payload of the original,
79// un-wrapped RPC.
80TORCH_API std::vector<at::IValue> readWrappedPayload(
81 std::vector<char>& payload,
82 const rpc::Message& message);
83
84// Takes a list of events from autograd profiler and populates them into
85// profiledEvents to be carried over RPC.
86TORCH_API void populateRemoteProfiledEvents(
87 std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
88 const torch::autograd::profiler::ProfilerConfig& profilerConfig,
89 const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
90 eventLists);
91
92} // namespace rpc
93} // namespace distributed
94} // namespace torch
95