1#pragma once
2
3#include <torch/types.h>
4#include <vector>
5
6namespace torch {
7namespace distributed {
8namespace rpc {
9
// Well-known categories of RPC errors, enumerated so that callers can give
// these cases dedicated error handling.
enum RPCErrorType {
  // The error type could not be parsed.
  UNKNOWN_ERROR = 0,
  // The RPC timed out.
  TIMEOUT = 1,
  // A deliberate failure, e.g. one injected by FaultyAgent for testing.
  INTENTIONAL_FAILURE = 2
};
17
// Flags that are bitwise ORed into the base values of MessageType to mark a
// message as a request or a response.
//
// Each flag must be a distinct single bit at or above 0x100, i.e. one of
// 0x100, 0x200, 0x400, 0x800, 0x1000, etc. This keeps the flags disjoint from
// each other and from the 0x00-0xff range used for base message ids.
enum MessageTypeFlags {
  REQUEST_TYPE = 0x100,
  RESPONSE_TYPE = 0x200,
};

// Compile-time enforcement of the invariants documented above, so that adding
// a malformed flag (e.g. a multi-bit value such as 0xF00) fails to build.
static_assert(
    (MessageTypeFlags::REQUEST_TYPE & 0xff) == 0 &&
        (MessageTypeFlags::RESPONSE_TYPE & 0xff) == 0,
    "MessageTypeFlags must not overlap the 0x00-0xff base message ids");
static_assert(
    (MessageTypeFlags::REQUEST_TYPE & (MessageTypeFlags::REQUEST_TYPE - 1)) ==
            0 &&
        (MessageTypeFlags::RESPONSE_TYPE &
         (MessageTypeFlags::RESPONSE_TYPE - 1)) == 0,
    "Each MessageTypeFlags value must be a single bit");
static_assert(
    (MessageTypeFlags::REQUEST_TYPE & MessageTypeFlags::RESPONSE_TYPE) == 0,
    "MessageTypeFlags values must be disjoint bit flags");
25
26// Message types must have values between 0x00 to 0xff
27enum MessageType {
28 // messages for dist.rpc on builtin operators
29 SCRIPT_CALL = 0x00 | MessageTypeFlags::REQUEST_TYPE,
30 SCRIPT_RET = 0x01 | MessageTypeFlags::RESPONSE_TYPE,
31
32 // messages for dist.rpc on Python UDF
33 PYTHON_CALL = 0x02 | MessageTypeFlags::REQUEST_TYPE,
34 PYTHON_RET = 0x03 | MessageTypeFlags::RESPONSE_TYPE,
35
36 // messages for dist.remote on builtin operators and Python UDF
37 SCRIPT_REMOTE_CALL = 0x04 |
38 MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator
39 PYTHON_REMOTE_CALL =
40 0x05 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF
41 REMOTE_RET =
42 0x06 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for
43 // UDF, builtin, or script
44
45 // RRef related internal messages
46 SCRIPT_RREF_FETCH_CALL =
47 0x07 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef<IValue> fetches value
48 // from owner
49 PYTHON_RREF_FETCH_CALL =
50 0x08 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef<py::object> fetches
51 // value from owner
52 SCRIPT_RREF_FETCH_RET = 0x09 |
53 MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user
54 PYTHON_RREF_FETCH_RET = 0x0a |
55 MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user
56 RREF_USER_DELETE = 0x0b |
57 MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref
58 RREF_FORK_REQUEST =
59 0x0c | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner
60 // about itself
61 RREF_CHILD_ACCEPT =
62 0x0d | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent
63 // that owner knows it
64 RREF_ACK =
65 0x0e | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages
66
67 // Messages with autograd info
68 FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
69 FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,
70
71 // Messages to propagate gradients on the backward pass.
72 BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
73 BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,
74
75 // Messages to tell workers to clean up their autograd context.
76 CLEANUP_AUTOGRAD_CONTEXT_REQ = 0x13 | MessageTypeFlags::REQUEST_TYPE,
77 CLEANUP_AUTOGRAD_CONTEXT_RESP = 0x14 | MessageTypeFlags::RESPONSE_TYPE,
78
79 // Messages that tell workers to run requests with profiling enabled.
80 RUN_WITH_PROFILING_REQ = 0x15 | MessageTypeFlags::REQUEST_TYPE,
81 RUN_WITH_PROFILING_RESP = 0x16 | MessageTypeFlags::RESPONSE_TYPE,
82
83 // Messages to support RRef.backward().
84 RREF_BACKWARD_REQ = 0x17 | MessageTypeFlags::REQUEST_TYPE,
85 RREF_BACKWARD_RESP = 0x18 | MessageTypeFlags::RESPONSE_TYPE,
86
87 // Other internal message types
88 EXCEPTION = 0x37 | MessageTypeFlags::RESPONSE_TYPE,
89 UNKNOWN = 0x3c
90};
91
92// A message to be sent/received by an RpcAgent.
93//
94// A Message object contains 4 fields:
95// payload (std::vector<char>): a binary chunk of data.
96// tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
97// included in the payload, and it is up to the RpcAgent implementation
98// to determine how to serialize them. This design is helpful for
99// communicating super large tensors where serializing all the data at
100// once leads to excessively large memory footprint. An implementation
101// can then serialize and send tensors chunck-by-chunk, in the streaming
102// fashion.
103// type (MessageType): type of the message.
104// id (int64_t): message id, this is used to match request and response.
105// Other implementation can ignore it if they have their own
106// ways to do matching.
107//
108// Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall,
109// and PythonResp into a Message, and it is up to the RpcAgent
110// implementation to determine how to serialize a message.
111class TORCH_API Message final : public torch::CustomClassHolder {
112 private:
113 // Keep these private in order to force users to go through make_intrusive and
114 // thus prevent creating a Message that's not held by an intrusive_ptr.
115 Message();
116
117 Message(
118 std::vector<char>&& payload,
119 std::vector<torch::Tensor>&& tensors,
120 MessageType type);
121
122 Message(
123 std::vector<char>&& payload,
124 std::vector<torch::Tensor>&& tensors,
125 MessageType type,
126 int64_t id);
127
128 friend c10::intrusive_ptr<Message>;
129
130 public:
131 Message(const Message& other) = delete;
132 Message(Message&& other) = delete;
133 Message& operator=(Message const& rhs) = delete;
134 Message& operator=(Message&& rhs) = delete;
135
136 // Destructively retrieves the payload.
137 std::vector<char>&& movePayload() &&;
138 std::vector<torch::Tensor>&& moveTensors() &&;
139
140 std::vector<char>& payload();
141 const std::vector<char>& payload() const;
142 std::vector<torch::Tensor>& tensors();
143 const std::vector<torch::Tensor>& tensors() const;
144 MessageType type() const;
145
146 bool isRequest() const;
147 bool isResponse() const;
148 bool isShutdown() const;
149
150 // id is an optional field to match request/response. If an RpcAgent
151 // implementation is able to do the matching without using this id, it can be
152 // dropped during message serialization.
153 int64_t id() const;
154 void setId(int64_t id);
155
156 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> getStorages() const;
157
158 private:
159 std::vector<char> payload_;
160 std::vector<torch::Tensor> tensors_;
161 MessageType type_ = MessageType::UNKNOWN;
162 int64_t id_ = -1;
163};
164
165// Create a response Message of type Exception.
166// The exception string representation will be used as the message's payload.
167// A message ID corresponding to the request that resulted in this response can
168// be provided for matching requests/responses.
169TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
170 const std::exception& e,
171 int64_t id);
172
173// Create a response Message of type Exception.
174// The passed in string representation will be used as the message's payload.
175// A message ID corresponding to the request that resulted in this response can
176// be provided for matching requests/responses.
177TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
178 const std::string& exceptionStr,
179 int64_t id);
180
181inline std::tuple<
182 c10::intrusive_ptr<Message>,
183 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>>
184withStorages(c10::intrusive_ptr<Message> message) {
185 auto storages = message->getStorages();
186 return std::make_tuple(std::move(message), std::move(storages));
187}
188
189using JitFuture = c10::ivalue::Future;
190
191} // namespace rpc
192} // namespace distributed
193} // namespace torch
194