1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_endpoint.h |
22 | * \brief Communication endpoints to connect local and remote RPC sessions. |
23 | */ |
24 | #ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ |
25 | #define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ |
26 | |
27 | #include <tvm/runtime/packed_func.h> |
28 | |
29 | #include <memory> |
30 | #include <mutex> |
31 | #include <string> |
32 | #include <utility> |
33 | |
34 | #include "../../support/ring_buffer.h" |
35 | #include "../minrpc/rpc_reference.h" |
36 | #include "rpc_channel.h" |
37 | #include "rpc_channel_logger.h" |
38 | #include "rpc_session.h" |
39 | |
40 | namespace tvm { |
41 | namespace runtime { |
42 | |
/*! \brief Magic header for the RPC data plane. */
constexpr int kRPCMagic = 0xff271;
/*! \brief Magic header for the RPC tracker (control plane). */
constexpr int kRPCTrackerMagic = 0x2f271;
/*! \brief Success response code (equal to the data-plane magic). */
constexpr int kRPCSuccess = kRPCMagic + 0;
/*! \brief Response code: no matching key could be found in the server. */
constexpr int kRPCMismatch = kRPCMagic + 2;
51 | |
/*!
 * \brief Enumeration code for the RPC tracker.
 *
 * Each enumerator is pinned to an explicit integer value.
 * NOTE(review): these values are presumably exchanged with the tracker as raw
 * integers, so they should stay stable -- confirm against the tracker code.
 */
enum class TrackerCode : int {
  kFail = -1,
  kSuccess = 0,
  kPing = 1,
  kStop = 2,
  kPut = 3,
  kRequest = 4,
  kUpdateInfo = 5,
  kSummary = 6,
  kGetPendingMatchKeys = 7
};
64 | |
65 | /*! |
66 | * \brief Communication endpoints to connect local and remote RPC sessions. |
67 | * An endpoint can either be a client or a server. |
68 | */ |
69 | class RPCEndpoint { |
70 | public: |
71 | /*! \brief virtual destructor |
72 | * Closes the connection if the connection hasn't already been closed. |
73 | */ |
74 | ~RPCEndpoint(); |
75 | |
76 | /*! |
77 | * \brief Shutdown RPC connection. |
78 | * |
79 | * Shutdown has no effect if the connection has already been shut down. |
80 | * Shutdown will wait for all output currently queued from the RPC connection (i.e. The user |
81 | * doesn't need to wait for completion before calling Shutdown.) Any further use of objects that |
82 | * depended on the endpoint (e.g. A tvm.nd.array allocated on the remote RPC session) may throw an |
83 | * exception when used. |
84 | */ |
85 | void Shutdown(); |
86 | |
87 | /*! |
88 | * \brief The server loop that server runs to handle RPC calls. |
89 | */ |
90 | void ServerLoop(); |
91 | /*! |
92 | * \brief Message handling function for an async IO event driven server. |
93 | * |
94 | * Called when the server receives a message or an IO event update. |
95 | * Event driven handler will never call recv on the channel |
96 | * and always relies on the ServerIOEventHandler to receive the data. |
97 | * |
98 | * \param in_bytes The incoming bytes. |
99 | * \param event_flag 1: read_available, 2: write_avaiable. |
100 | * \return State flag. |
101 | * 1: continue running, no need to write, |
102 | * 2: need to write |
103 | * 0: shutdown |
104 | */ |
105 | int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag); |
106 | |
107 | /*! |
108 | * \brief Initalize the session on the remote that will be used to back all the RPC requests. |
109 | * |
110 | * If no session constructor arguments is passed, LocalSession will be used in the remote. |
111 | * Otherwise the remote serving session will be constructed using the arguments |
112 | * specified in the session_constructor_args. |
113 | * |
114 | * The construction rule can be summarized as follows: |
115 | * |
116 | * \code |
117 | * |
118 | * auto args = session_constructor_args; |
119 | * int n = args.size(); |
120 | * if (n != 0) { |
121 | * std::string constructor = args[0]; |
122 | * server.serving_session_ = GetGlobalFunc(constructor)( |
123 | * args[1], args[2] ... args[n - 1]) |
124 | * } else { |
125 | * server.serving_session_ = LocalSession(); |
126 | * } |
127 | * \endcode |
128 | * |
129 | * \param session_constructor_args Optional sequence of the remote sesssion constructor. |
130 | */ |
131 | void InitRemoteSession(TVMArgs session_constructor_args); |
132 | |
133 | /*! |
134 | * \brief Call into remote function |
135 | * \param handle The function handle |
136 | * \param arg_values The argument values. |
137 | * \param arg_type_codes the type codes of the argument. |
138 | * \param num_args Number of arguments. |
139 | * \param fencode_return The function to receive return value encodings. |
140 | */ |
141 | void CallFunc(RPCSession::PackedFuncHandle handle, const TVMValue* arg_values, |
142 | const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return); |
143 | /*! |
144 | * \brief Copy bytes into remote array content. |
145 | * \param from The source host data. |
146 | * \param from_offset The byte offeset in the from. |
147 | * \param to The target array. |
148 | * \param to_offset The byte offset in the to. |
149 | * \param nbytes The size of the memory in bytes. |
150 | * \param dev_to The target device. |
151 | * \param type_hint Hint of content data type. |
152 | */ |
153 | void CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes); |
154 | /*! |
155 | * \brief Copy bytes from remote array content. |
156 | * \param from The source host data. |
157 | * \param from_offset The byte offeset in the from. |
158 | * \param to The target array. |
159 | * \param to_offset The byte offset in the to. |
160 | * \param nbytes The size of the memory in bytes. |
161 | * \param dev_from The source device. |
162 | * \param type_hint Hint of content data type. |
163 | */ |
164 | void CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes); |
165 | |
166 | /*! |
167 | * \brief Call a remote defined system function with arguments. |
168 | * \param fcode The function code. |
169 | * \param args The arguments |
170 | * \return The returned remote value. |
171 | */ |
172 | template <typename... Args> |
173 | inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args); |
174 | /*! |
175 | * \brief Create a RPC session with given channel. |
176 | * \param channel The communication channel. |
177 | * \param name The local name of the session, used for debug |
178 | * \param remote_key The remote key of the session |
179 | * if remote_key equals "%toinit", we need to re-intialize |
180 | * it by event handler. |
181 | * \param fcleanup The cleanup Packed function. |
182 | */ |
183 | static std::shared_ptr<RPCEndpoint> Create(std::unique_ptr<RPCChannel> channel, std::string name, |
184 | std::string remote_key, |
185 | TypedPackedFunc<void()> fcleanup = nullptr); |
186 | |
187 | private: |
188 | class EventHandler; |
189 | // Handle events until receives a return |
190 | // Also flushes channels so that the function advances. |
191 | RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn); |
192 | // Initalization |
193 | void Init(); |
194 | // Internal channel. |
195 | std::unique_ptr<RPCChannel> channel_; |
196 | |
197 | // Internal mutex |
198 | std::mutex mutex_; |
199 | // Internal ring buffer. |
200 | support::RingBuffer reader_, writer_; |
201 | // Event handler. |
202 | std::shared_ptr<EventHandler> handler_; |
203 | // syscall remote with specified function code. |
204 | PackedFunc syscall_remote_; |
205 | // The name of the session. |
206 | std::string name_; |
207 | // The remote key |
208 | std::string remote_key_; |
209 | // Invoked when the RPC session is terminated |
210 | TypedPackedFunc<void()> fcleanup_; |
211 | }; |
212 | |
213 | /*! |
214 | * \brief Create an RPC client session from an RPC client endpoint. |
215 | * \param endpoint The endpoint. |
216 | * \return The created session. |
217 | */ |
218 | std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint); |
219 | |
220 | // implementation of inline functions |
221 | template <typename... Args> |
222 | inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { |
223 | return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...); |
224 | } |
225 | |
226 | /*! |
227 | * \brief Calculates overhead size of a CopyToRemote packet. |
228 | * \param to DLTensor to copy. |
229 | * \param code RPCCode for this transfer. |
230 | * \param nbytes Number of bytes to transfer. |
231 | * \return The remote-copy packet overhead size. |
232 | */ |
233 | uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes); |
234 | |
235 | } // namespace runtime |
236 | } // namespace tvm |
237 | #endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ |
238 | |