1 | #pragma once |
2 | |
3 | #include <c10/core/Device.h> |
4 | #include <c10/core/Event.h> |
5 | #include <c10/core/Stream.h> |
6 | #include <torch/csrc/autograd/profiler.h> |
7 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
8 | #include <torch/csrc/jit/serialization/pickle.h> |
9 | #include <torch/csrc/utils/byte_order.h> |
10 | |
// Forward declaration of TensorPipe's message class. Inside
// torch::distributed::rpc the unqualified name `Message` resolves to
// rpc::Message, not to this type.
// NOTE(review): no declaration in this header refers to tensorpipe::Message;
// it may be a leftover - check other includers before removing it.
namespace tensorpipe { class Message; } // namespace tensorpipe
14 | |
15 | namespace torch { |
16 | namespace distributed { |
17 | namespace rpc { |
18 | |
19 | // Parse error message and return RPCErrorType based on the message. |
20 | TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture); |
21 | // Create an error string given the error description and error type |
22 | TORCH_API std::string makeRPCError( |
23 | const std::string& rpcErrorStr, |
24 | RPCErrorType errorType); |
25 | |
26 | // Given an RPC message received as a request over the wire, deserialize it into |
27 | // the appropriate 'RpcCommandBase' type. |
28 | TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest( |
29 | const Message& request); |
30 | |
31 | // Given an RPC message received as a response over the wire, deserialize it |
32 | // into the appropriate 'RpcCommandBase' type, if the response is |
33 | // FORWARD_AUTOGRAD_RESP type, unwrap it, attach recvBackward() functions |
34 | // to received tensors and set the wrappedMsgType to its wrapped message type. |
35 | TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse( |
36 | const Message& response, |
37 | MessageType& wrappedMsgType); |
38 | |
39 | // Given an RPC message received as a response over the wire, deserialize it |
40 | // into the valid IValue if the message is for a script rpc result, |
41 | // otherwise deserialize it into dummy none ivalue that will never be used. |
42 | // In this deserialization, we also attach recv rpc backward functions if |
43 | // needed. |
44 | IValue deserializeResptoIValueInternal( |
45 | RpcCommandBase& rpc, |
46 | MessageType messageType); |
47 | TORCH_API IValue deserializeRespToIValue(const Message& message); |
48 | |
49 | // Note: format is subject to change and intended for RPCs. |
50 | // For saving persistently to disk, use torch::save(). |
51 | TORCH_API std::string wireSerialize( |
52 | const std::vector<char>& payload, |
53 | const std::vector<at::Tensor>& tensors); |
54 | |
55 | TORCH_API std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize( |
56 | const void* data, |
57 | size_t data_size); |
58 | |
// Binary blobs in this header are std::vector<char> because that is the type
// rpc::Message uses for its payload. The drawback is that a vector<char> cannot
// be allocated with uninitialized memory: it is always zero-filled.
62 | |
63 | // Some Tensors are effectively views of larger Tensors, where only a small |
64 | // subset of the Storage data is referenced. This normally is good and avoids |
65 | // copies when kept locally, but if we naively push the whole Storage over the |
66 | // wire, we'll end up with excess network traffic. This change clones tensors if |
67 | // we'd save at least half the data, and over a minimum hurdle. |
68 | TORCH_API c10::List<at::Tensor> cloneSparseTensors( |
69 | const std::vector<at::Tensor>& tensors); |
70 | |
71 | // Combines an original payload and wrapped payload into the original payload. |
72 | // Used to generate the overall payload for the wrapped RPC. |
73 | TORCH_API void writeWrappedPayload( |
74 | std::vector<char>& originalPayload, |
75 | std::vector<char>& additionalPayload); |
76 | |
77 | // Reads the additional, wrapped payload from a wrapped RPC off of the input |
78 | // payload. After this, payload will contain the payload of the original, |
79 | // un-wrapped RPC. |
80 | TORCH_API std::vector<at::IValue> readWrappedPayload( |
81 | std::vector<char>& payload, |
82 | const rpc::Message& message); |
83 | |
84 | // Takes a list of events from autograd profiler and populates them into |
85 | // profiledEvents to be carried over RPC. |
86 | TORCH_API void populateRemoteProfiledEvents( |
87 | std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents, |
88 | const torch::autograd::profiler::ProfilerConfig& profilerConfig, |
89 | const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>& |
90 | eventLists); |
91 | |
92 | } // namespace rpc |
93 | } // namespace distributed |
94 | } // namespace torch |
95 | |