1 | #pragma once |
2 | |
3 | #include <torch/types.h> |
4 | #include <vector> |
5 | |
6 | namespace torch { |
7 | namespace distributed { |
8 | namespace rpc { |
9 | |
// Categories of well-known RPC failures, so that callers can handle each one
// specifically.
enum RPCErrorType {
  // The error type could not be parsed.
  UNKNOWN_ERROR = 0,
  // The RPC has timed out.
  TIMEOUT = 1,
  // A deliberate failure, such as those injected by FaultyAgent for testing.
  INTENTIONAL_FAILURE = 2,
};
17 | |
// Flags that are bitwise ORed into a MessageType value to classify it.
// The low byte (0x00-0xff) is reserved for the base message type value, so
// every flag must be a distinct single bit at or above 0x100, i.e. one of
// 0x100, 0x200, 0x400, 0x800, 0x1000, etc.
enum MessageTypeFlags {
  REQUEST_TYPE = 0x100,
  RESPONSE_TYPE = 0x200,
};

// Compile-time guards for the layout described above: a flag that leaked into
// the low byte, or two flags sharing a bit, would make message types
// ambiguous.
static_assert(
    (REQUEST_TYPE & 0xff) == 0 && (RESPONSE_TYPE & 0xff) == 0,
    "MessageTypeFlags must not overlap the low byte used by message types");
static_assert(
    (REQUEST_TYPE & (REQUEST_TYPE - 1)) == 0 &&
        (RESPONSE_TYPE & (RESPONSE_TYPE - 1)) == 0,
    "Each MessageTypeFlags value must be a single bit");
static_assert(
    (REQUEST_TYPE & RESPONSE_TYPE) == 0,
    "MessageTypeFlags values must be disjoint bits");
25 | |
26 | // Message types must have values between 0x00 to 0xff |
27 | enum MessageType { |
28 | // messages for dist.rpc on builtin operators |
29 | SCRIPT_CALL = 0x00 | MessageTypeFlags::REQUEST_TYPE, |
30 | SCRIPT_RET = 0x01 | MessageTypeFlags::RESPONSE_TYPE, |
31 | |
32 | // messages for dist.rpc on Python UDF |
33 | PYTHON_CALL = 0x02 | MessageTypeFlags::REQUEST_TYPE, |
34 | PYTHON_RET = 0x03 | MessageTypeFlags::RESPONSE_TYPE, |
35 | |
36 | // messages for dist.remote on builtin operators and Python UDF |
37 | SCRIPT_REMOTE_CALL = 0x04 | |
38 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator |
39 | PYTHON_REMOTE_CALL = |
40 | 0x05 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF |
41 | REMOTE_RET = |
42 | 0x06 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for |
43 | // UDF, builtin, or script |
44 | |
45 | // RRef related internal messages |
46 | SCRIPT_RREF_FETCH_CALL = |
47 | 0x07 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef<IValue> fetches value |
48 | // from owner |
49 | PYTHON_RREF_FETCH_CALL = |
50 | 0x08 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef<py::object> fetches |
51 | // value from owner |
52 | SCRIPT_RREF_FETCH_RET = 0x09 | |
53 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user |
54 | PYTHON_RREF_FETCH_RET = 0x0a | |
55 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user |
56 | RREF_USER_DELETE = 0x0b | |
57 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref |
58 | RREF_FORK_REQUEST = |
59 | 0x0c | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner |
60 | // about itself |
61 | RREF_CHILD_ACCEPT = |
62 | 0x0d | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent |
63 | // that owner knows it |
64 | RREF_ACK = |
65 | 0x0e | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages |
66 | |
67 | // Messages with autograd info |
68 | FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE, |
69 | FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE, |
70 | |
71 | // Messages to propagate gradients on the backward pass. |
72 | BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE, |
73 | BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE, |
74 | |
75 | // Messages to tell workers to clean up their autograd context. |
76 | CLEANUP_AUTOGRAD_CONTEXT_REQ = 0x13 | MessageTypeFlags::REQUEST_TYPE, |
77 | CLEANUP_AUTOGRAD_CONTEXT_RESP = 0x14 | MessageTypeFlags::RESPONSE_TYPE, |
78 | |
79 | // Messages that tell workers to run requests with profiling enabled. |
80 | RUN_WITH_PROFILING_REQ = 0x15 | MessageTypeFlags::REQUEST_TYPE, |
81 | RUN_WITH_PROFILING_RESP = 0x16 | MessageTypeFlags::RESPONSE_TYPE, |
82 | |
83 | // Messages to support RRef.backward(). |
84 | RREF_BACKWARD_REQ = 0x17 | MessageTypeFlags::REQUEST_TYPE, |
85 | RREF_BACKWARD_RESP = 0x18 | MessageTypeFlags::RESPONSE_TYPE, |
86 | |
87 | // Other internal message types |
88 | EXCEPTION = 0x37 | MessageTypeFlags::RESPONSE_TYPE, |
89 | UNKNOWN = 0x3c |
90 | }; |
91 | |
92 | // A message to be sent/received by an RpcAgent. |
93 | // |
94 | // A Message object contains 4 fields: |
95 | // payload (std::vector<char>): a binary chunk of data. |
96 | // tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not |
97 | // included in the payload, and it is up to the RpcAgent implementation |
98 | // to determine how to serialize them. This design is helpful for |
99 | // communicating super large tensors where serializing all the data at |
100 | // once leads to excessively large memory footprint. An implementation |
101 | // can then serialize and send tensors chunck-by-chunk, in the streaming |
102 | // fashion. |
103 | // type (MessageType): type of the message. |
104 | // id (int64_t): message id, this is used to match request and response. |
105 | // Other implementation can ignore it if they have their own |
106 | // ways to do matching. |
107 | // |
108 | // Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall, |
109 | // and PythonResp into a Message, and it is up to the RpcAgent |
110 | // implementation to determine how to serialize a message. |
111 | class TORCH_API Message final : public torch::CustomClassHolder { |
112 | private: |
113 | // Keep these private in order to force users to go through make_intrusive and |
114 | // thus prevent creating a Message that's not held by an intrusive_ptr. |
115 | Message(); |
116 | |
117 | Message( |
118 | std::vector<char>&& payload, |
119 | std::vector<torch::Tensor>&& tensors, |
120 | MessageType type); |
121 | |
122 | Message( |
123 | std::vector<char>&& payload, |
124 | std::vector<torch::Tensor>&& tensors, |
125 | MessageType type, |
126 | int64_t id); |
127 | |
128 | friend c10::intrusive_ptr<Message>; |
129 | |
130 | public: |
131 | Message(const Message& other) = delete; |
132 | Message(Message&& other) = delete; |
133 | Message& operator=(Message const& rhs) = delete; |
134 | Message& operator=(Message&& rhs) = delete; |
135 | |
136 | // Destructively retrieves the payload. |
137 | std::vector<char>&& movePayload() &&; |
138 | std::vector<torch::Tensor>&& moveTensors() &&; |
139 | |
140 | std::vector<char>& payload(); |
141 | const std::vector<char>& payload() const; |
142 | std::vector<torch::Tensor>& tensors(); |
143 | const std::vector<torch::Tensor>& tensors() const; |
144 | MessageType type() const; |
145 | |
146 | bool isRequest() const; |
147 | bool isResponse() const; |
148 | bool isShutdown() const; |
149 | |
150 | // id is an optional field to match request/response. If an RpcAgent |
151 | // implementation is able to do the matching without using this id, it can be |
152 | // dropped during message serialization. |
153 | int64_t id() const; |
154 | void setId(int64_t id); |
155 | |
156 | std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> getStorages() const; |
157 | |
158 | private: |
159 | std::vector<char> payload_; |
160 | std::vector<torch::Tensor> tensors_; |
161 | MessageType type_ = MessageType::UNKNOWN; |
162 | int64_t id_ = -1; |
163 | }; |
164 | |
165 | // Create a response Message of type Exception. |
166 | // The exception string representation will be used as the message's payload. |
167 | // A message ID corresponding to the request that resulted in this response can |
168 | // be provided for matching requests/responses. |
169 | TORCH_API c10::intrusive_ptr<Message> createExceptionResponse( |
170 | const std::exception& e, |
171 | int64_t id); |
172 | |
173 | // Create a response Message of type Exception. |
174 | // The passed in string representation will be used as the message's payload. |
175 | // A message ID corresponding to the request that resulted in this response can |
176 | // be provided for matching requests/responses. |
177 | TORCH_API c10::intrusive_ptr<Message> createExceptionResponse( |
178 | const std::string& exceptionStr, |
179 | int64_t id); |
180 | |
181 | inline std::tuple< |
182 | c10::intrusive_ptr<Message>, |
183 | std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>> |
184 | withStorages(c10::intrusive_ptr<Message> message) { |
185 | auto storages = message->getStorages(); |
186 | return std::make_tuple(std::move(message), std::move(storages)); |
187 | } |
188 | |
// Shorthand name for c10::ivalue::Future within the rpc namespace.
using JitFuture = c10::ivalue::Future;
190 | |
191 | } // namespace rpc |
192 | } // namespace distributed |
193 | } // namespace torch |
194 | |