1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_endpoint.h
22 * \brief Communication endpoints to connect local and remote RPC sessions.
23 */
24#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
25#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
26
27#include <tvm/runtime/packed_func.h>
28
29#include <memory>
30#include <mutex>
31#include <string>
32#include <utility>
33
34#include "../../support/ring_buffer.h"
35#include "../minrpc/rpc_reference.h"
36#include "rpc_channel.h"
37#include "rpc_channel_logger.h"
38#include "rpc_session.h"
39
40namespace tvm {
41namespace runtime {
42
/*! \brief Magic header for the RPC data plane. */
constexpr int kRPCMagic = 0xff271;
/*! \brief Magic header for the RPC tracker (control plane). */
constexpr int kRPCTrackerMagic = 0x2f271;
/*! \brief Response code: success. */
constexpr int kRPCSuccess = kRPCMagic;
/*!
 * \brief Response code: no matching key could be found on the server.
 * NOTE(review): kRPCMagic + 1 is not defined here; presumably it is reserved by the
 * peer side of the protocol — confirm before reusing it.
 */
constexpr int kRPCMismatch = kRPCMagic + 2;
51
/*!
 * \brief Enumeration code for the RPC tracker.
 *
 * NOTE(review): these numeric values are presumably exchanged with the tracker
 * implementation on the other end of the connection, so they should be treated as
 * stable — confirm against the tracker before renumbering.
 */
enum class TrackerCode : int {
  kFail = -1,
  kSuccess = 0,
  kPing = 1,
  kStop = 2,
  kPut = 3,
  kRequest = 4,
  kUpdateInfo = 5,
  kSummary = 6,
  kGetPendingMatchKeys = 7,
};
64
65/*!
66 * \brief Communication endpoints to connect local and remote RPC sessions.
67 * An endpoint can either be a client or a server.
68 */
69class RPCEndpoint {
70 public:
71 /*! \brief virtual destructor
72 * Closes the connection if the connection hasn't already been closed.
73 */
74 ~RPCEndpoint();
75
76 /*!
77 * \brief Shutdown RPC connection.
78 *
79 * Shutdown has no effect if the connection has already been shut down.
80 * Shutdown will wait for all output currently queued from the RPC connection (i.e. The user
81 * doesn't need to wait for completion before calling Shutdown.) Any further use of objects that
82 * depended on the endpoint (e.g. A tvm.nd.array allocated on the remote RPC session) may throw an
83 * exception when used.
84 */
85 void Shutdown();
86
87 /*!
88 * \brief The server loop that server runs to handle RPC calls.
89 */
90 void ServerLoop();
91 /*!
92 * \brief Message handling function for an async IO event driven server.
93 *
94 * Called when the server receives a message or an IO event update.
95 * Event driven handler will never call recv on the channel
96 * and always relies on the ServerIOEventHandler to receive the data.
97 *
98 * \param in_bytes The incoming bytes.
99 * \param event_flag 1: read_available, 2: write_avaiable.
100 * \return State flag.
101 * 1: continue running, no need to write,
102 * 2: need to write
103 * 0: shutdown
104 */
105 int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag);
106
107 /*!
108 * \brief Initalize the session on the remote that will be used to back all the RPC requests.
109 *
110 * If no session constructor arguments is passed, LocalSession will be used in the remote.
111 * Otherwise the remote serving session will be constructed using the arguments
112 * specified in the session_constructor_args.
113 *
114 * The construction rule can be summarized as follows:
115 *
116 * \code
117 *
118 * auto args = session_constructor_args;
119 * int n = args.size();
120 * if (n != 0) {
121 * std::string constructor = args[0];
122 * server.serving_session_ = GetGlobalFunc(constructor)(
123 * args[1], args[2] ... args[n - 1])
124 * } else {
125 * server.serving_session_ = LocalSession();
126 * }
127 * \endcode
128 *
129 * \param session_constructor_args Optional sequence of the remote sesssion constructor.
130 */
131 void InitRemoteSession(TVMArgs session_constructor_args);
132
133 /*!
134 * \brief Call into remote function
135 * \param handle The function handle
136 * \param arg_values The argument values.
137 * \param arg_type_codes the type codes of the argument.
138 * \param num_args Number of arguments.
139 * \param fencode_return The function to receive return value encodings.
140 */
141 void CallFunc(RPCSession::PackedFuncHandle handle, const TVMValue* arg_values,
142 const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return);
143 /*!
144 * \brief Copy bytes into remote array content.
145 * \param from The source host data.
146 * \param from_offset The byte offeset in the from.
147 * \param to The target array.
148 * \param to_offset The byte offset in the to.
149 * \param nbytes The size of the memory in bytes.
150 * \param dev_to The target device.
151 * \param type_hint Hint of content data type.
152 */
153 void CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes);
154 /*!
155 * \brief Copy bytes from remote array content.
156 * \param from The source host data.
157 * \param from_offset The byte offeset in the from.
158 * \param to The target array.
159 * \param to_offset The byte offset in the to.
160 * \param nbytes The size of the memory in bytes.
161 * \param dev_from The source device.
162 * \param type_hint Hint of content data type.
163 */
164 void CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes);
165
166 /*!
167 * \brief Call a remote defined system function with arguments.
168 * \param fcode The function code.
169 * \param args The arguments
170 * \return The returned remote value.
171 */
172 template <typename... Args>
173 inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args);
174 /*!
175 * \brief Create a RPC session with given channel.
176 * \param channel The communication channel.
177 * \param name The local name of the session, used for debug
178 * \param remote_key The remote key of the session
179 * if remote_key equals "%toinit", we need to re-intialize
180 * it by event handler.
181 * \param fcleanup The cleanup Packed function.
182 */
183 static std::shared_ptr<RPCEndpoint> Create(std::unique_ptr<RPCChannel> channel, std::string name,
184 std::string remote_key,
185 TypedPackedFunc<void()> fcleanup = nullptr);
186
187 private:
188 class EventHandler;
189 // Handle events until receives a return
190 // Also flushes channels so that the function advances.
191 RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn);
192 // Initalization
193 void Init();
194 // Internal channel.
195 std::unique_ptr<RPCChannel> channel_;
196
197 // Internal mutex
198 std::mutex mutex_;
199 // Internal ring buffer.
200 support::RingBuffer reader_, writer_;
201 // Event handler.
202 std::shared_ptr<EventHandler> handler_;
203 // syscall remote with specified function code.
204 PackedFunc syscall_remote_;
205 // The name of the session.
206 std::string name_;
207 // The remote key
208 std::string remote_key_;
209 // Invoked when the RPC session is terminated
210 TypedPackedFunc<void()> fcleanup_;
211};
212
213/*!
214 * \brief Create an RPC client session from an RPC client endpoint.
215 * \param endpoint The endpoint.
216 * \return The created session.
217 */
218std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint);
219
220// implementation of inline functions
221template <typename... Args>
222inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
223 return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
224}
225
226/*!
227 * \brief Calculates overhead size of a CopyToRemote packet.
228 * \param to DLTensor to copy.
229 * \param code RPCCode for this transfer.
230 * \param nbytes Number of bytes to transfer.
231 * \return The remote-copy packet overhead size.
232 */
233uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes);
234
235} // namespace runtime
236} // namespace tvm
237#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
238