1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_session.h |
22 | * \brief Base RPC session interface. |
23 | */ |
24 | #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ |
25 | #define TVM_RUNTIME_RPC_RPC_SESSION_H_ |
26 | |
27 | #include <tvm/runtime/device_api.h> |
28 | #include <tvm/runtime/packed_func.h> |
29 | |
30 | #include <functional> |
31 | #include <memory> |
32 | #include <string> |
33 | |
34 | #include "../minrpc/rpc_reference.h" |
35 | |
36 | namespace tvm { |
37 | namespace runtime { |
38 | |
39 | /*! |
40 | * \brief The interface of all remote RPC sessions. |
41 | * |
42 | * It contains all the necessary interface to implement |
43 | * remote call and resource management. |
44 | * |
45 | * The interface is designed to allow easy proxy-chaining |
46 | * by forward requests to another RPCSession. |
47 | */ |
48 | class RPCSession { |
49 | public: |
50 | /*! \brief PackedFunc Handle in the remote. */ |
51 | using PackedFuncHandle = void*; |
52 | |
53 | /*! \brief Module handle in the remote. */ |
54 | using ModuleHandle = void*; |
55 | |
56 | /*! \brief NDArray handle in the remote. */ |
57 | using NDArrayHandle = void*; |
58 | |
59 | /*! |
60 | * \brief Callback to send an encoded return values via encode_args. |
61 | * |
62 | * \param encode_args The arguments that we can encode the return values into. |
63 | * |
64 | * Encoding convention (as list of arguments): |
65 | * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. |
66 | * - PackedFunc/Module: [tcode: int, handle: void*] |
67 | * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] |
68 | * DLTensor* contains the meta-data as well as handle into the remote data. |
69 | * nd_handle can be used for deletion. |
70 | */ |
71 | using FEncodeReturn = std::function<void(TVMArgs encoded_args)>; |
72 | |
73 | /*! |
74 | * \brief Callback to send an encoded return values via encode_args. |
75 | * |
76 | * \param status The return status, can be RPCCode::kReturn or RPCCode::kException. |
77 | * \param encode_args The arguments that we can encode the return values into. |
78 | */ |
79 | using FAsyncCallback = std::function<void(RPCCode status, TVMArgs encoded_args)>; |
80 | |
81 | /*! \brief Destructor.*/ |
82 | virtual ~RPCSession() {} |
83 | |
84 | /*! |
85 | * \brief Get function in the session. |
86 | * \param name The name of the function. |
87 | * \return The function handle. |
88 | */ |
89 | virtual PackedFuncHandle GetFunction(const std::string& name) = 0; |
90 | |
91 | /*! |
92 | * \brief Call into a remote Packed function. |
93 | * |
94 | * Calling convention: |
95 | * |
96 | * - type_code is follows the PackedFunc convention. |
97 | * - int/float/string/bytes follows the PackedFunc convention, all data are local. |
98 | * - PackedFunc/Module and future remote objects: pass remote handle instead. |
99 | * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor |
100 | * points to a remote data handle returned by the Device API. |
101 | * The meta-data of the DLTensor sits on local. |
102 | * |
103 | * The caller populates the arguments and manages these arguments. |
104 | * |
105 | * The callee can change the content of arg_values and arg_type_codes |
106 | * if they want to do inplace modify and forward. |
107 | * |
108 | * The callee need to store the return value into ret_value. |
109 | * - PackedFunc/Module are stored as void* |
110 | * - NDArray is stored as local NDArray, whose data field is a remote handle. |
111 | * Notably the NDArray's deleter won't delete remote handle. |
112 | * It is up to the user of the RPCSession to such wrapping. |
113 | * - In short, remote handles are "moved" as return values |
114 | * and the callee needs to explicitly manage them by calling |
115 | * the deleter functions when they are no longer needed. |
116 | * |
117 | * \param func The function handle. |
118 | * \param arg_values The argument values. |
119 | * \param arg_type_codes the type codes of the argument. |
120 | * \param num_args Number of arguments. |
121 | * \param fencode_return The function to set the return value, |
122 | * if not called, return value is null. |
123 | */ |
124 | virtual void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, |
125 | const int* arg_type_codes, int num_args, |
126 | const FEncodeReturn& fencode_return) = 0; |
127 | |
128 | /*! |
129 | * \brief Copy bytes into remote array content. |
130 | * \param local_from_bytes The source host data. |
131 | * \param remote_to The target array. |
132 | * \param nbytes The size of the memory in bytes. |
133 | */ |
134 | virtual void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) = 0; |
135 | /*! |
136 | * \brief Copy bytes from remote array content. |
137 | * \param remote_from The source host data. |
138 | * \param local_to_bytes The target array. |
139 | * \param nbytes The size of the memory in bytes. |
140 | */ |
141 | virtual void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) = 0; |
142 | |
143 | /*! |
144 | * \brief Free a remote function. |
145 | * \param handle The remote handle, can be NDArray/PackedFunc/Module |
146 | * \param type_code The type code of the underlying type. |
147 | */ |
148 | virtual void FreeHandle(void* handle, int type_code) = 0; |
149 | |
150 | /*! |
151 | * \brief Get device API that represents the remote |
152 | * actions that can be taken on the remote. |
153 | * |
154 | * The caller can then call into the Alloc/Free functions |
155 | * to allocate free spaces and taking the pointer as the handle. |
156 | * |
157 | * The device API is guaranteed to be alive during the |
158 | * lifetime of the Session. |
159 | * |
160 | * \param dev The remote device. |
161 | * \param allow_missing Whether can we return nullptr if it is not available. |
162 | * |
163 | * \return The device API. |
164 | */ |
165 | virtual DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing = false) = 0; |
166 | |
167 | /*! |
168 | * \brief Whether the session is a local session and we can directly |
169 | * the data handle returned by the session and treat it as pointer |
170 | * to the local memory. |
171 | * |
172 | * This information is useful for RPC server to directly copy into the |
173 | * local memory without creating a temporary buffer. |
174 | * |
175 | * \return Whether it is a local session. |
176 | */ |
177 | virtual bool IsLocalSession() const = 0; |
178 | |
179 | // Asynchrous variant of API |
180 | // These APIs are used by the RPC server to allow sessions that |
181 | // have special implementations for the async functions. |
182 | // |
183 | // In the async APIs, an exception is returned by the passing |
184 | // async_error=true, encode_args=[error_msg]. |
185 | |
186 | /*! |
187 | * \brief Whether the session is async. |
188 | * |
189 | * If the session is not async, its Aync implementations |
190 | * simply calls into the their synchronize counterparts, |
191 | * and the callback is guaranteed to be called before the async function finishes. |
192 | * |
193 | * \return the async state. |
194 | * |
195 | * \note We can only use async session in an Event driven RPC server. |
196 | */ |
197 | virtual bool IsAsync() const; |
198 | |
199 | /*! |
200 | * \brief Asynchrously call func. |
201 | * \param func The function handle. |
202 | * \param arg_values The argument values. |
203 | * \param arg_type_codes the type codes of the argument. |
204 | * \param num_args Number of arguments. |
205 | * |
206 | * \param callback The callback to pass the return value or exception. |
207 | */ |
208 | virtual void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, |
209 | const int* arg_type_codes, int num_args, FAsyncCallback callback); |
210 | |
211 | /*! |
212 | * \brief Asynchrous version of CopyToRemote. |
213 | * |
214 | * \param local_from_bytes The source host data. |
215 | * \param remote_to The target array. |
216 | * \param nbytes The size of the memory in bytes. |
217 | * \param on_complete The callback to signal copy complete. |
218 | * \note All the allocated memory in local_from, and remote_to |
219 | * must stay alive until on_compelete is called. |
220 | */ |
221 | virtual void AsyncCopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes, |
222 | FAsyncCallback on_complete); |
223 | |
224 | /*! |
225 | * \brief Asynchrous version of CopyFromRemote. |
226 | * |
227 | * \param remote_from The source host data. |
228 | * \param local_to_bytes The target array. |
229 | * \param nbytes The size of the memory in bytes. |
230 | * \param on_complete The callback to signal copy complete. |
231 | * \note All the allocated memory in remote_from, and local_to |
232 | * must stay alive until on_compelete is called. |
233 | */ |
234 | virtual void AsyncCopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes, |
235 | FAsyncCallback on_complete); |
236 | /*! |
237 | * \brief Asynchrously wait for all events in dev, stream compeletes. |
238 | * \param dev The device. |
239 | * \param stream The stream to wait on. |
240 | * \param on_complete The callback to signal copy complete. |
241 | */ |
242 | virtual void AsyncStreamWait(Device dev, TVMStreamHandle stream, FAsyncCallback on_compelte); |
243 | |
244 | /*! |
245 | * \return The session table index of the session. |
246 | */ |
247 | int table_index() const { return table_index_; } |
248 | |
249 | /*! |
250 | * \brief Try get session from the global session table by table index. |
251 | * \param table_index The table index of the session. |
252 | * \return The shared_ptr to the session, can be nullptr. |
253 | */ |
254 | static std::shared_ptr<RPCSession> Get(int table_index); |
255 | |
256 | /*! |
257 | * \brief Shutdown RPC connection. |
258 | */ |
259 | virtual void Shutdown() {} |
260 | |
261 | protected: |
262 | /*! |
263 | * \brief Send an exception to the callback. |
264 | * \param msg The exception message. |
265 | */ |
266 | void SendException(FAsyncCallback callback, const char* msg); |
267 | |
268 | private: |
269 | /*! \brief index of this session in RPC session table */ |
270 | int table_index_{0}; |
271 | /*! \brief Insert the current session to the session table.*/ |
272 | static void InsertToSessionTable(std::shared_ptr<RPCSession> sess); |
273 | // friend declaration |
274 | friend Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess); |
275 | }; |
276 | |
277 | /*! |
278 | * \brief Remote space handle cell used by the RPC runtime API. |
279 | * |
280 | * When we allocate space using a rpc device, the data pointer |
281 | * points to an allocated RemoteSpace. |
282 | */ |
283 | struct RemoteSpace { |
284 | /*! \brief The remote data handle. */ |
285 | void* data; |
286 | /*! \brief Reference to the underlying RPC session. */ |
287 | std::shared_ptr<RPCSession> sess; |
288 | }; |
289 | |
/*!
 * \brief Create a global RPC module that refers to the session.
 * \param sess The RPC session of the global module.
 * \return The created module.
 */
Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
296 | |
/*!
 * \brief Get the RPC session from an RPC session module.
 * \param mod The input module (must be an RPCModule).
 * \return The internal RPCSession.
 */
std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod);
303 | |
304 | } // namespace runtime |
305 | } // namespace tvm |
306 | #endif // TVM_RUNTIME_RPC_RPC_SESSION_H_ |
307 | |