1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_session.h
22 * \brief Base RPC session interface.
23 */
24#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
25#define TVM_RUNTIME_RPC_RPC_SESSION_H_
26
27#include <tvm/runtime/device_api.h>
28#include <tvm/runtime/packed_func.h>
29
30#include <functional>
31#include <memory>
32#include <string>
33
34#include "../minrpc/rpc_reference.h"
35
36namespace tvm {
37namespace runtime {
38
39/*!
40 * \brief The interface of all remote RPC sessions.
41 *
42 * It contains all the necessary interface to implement
43 * remote call and resource management.
44 *
45 * The interface is designed to allow easy proxy-chaining
46 * by forward requests to another RPCSession.
47 */
48class RPCSession {
49 public:
50 /*! \brief PackedFunc Handle in the remote. */
51 using PackedFuncHandle = void*;
52
53 /*! \brief Module handle in the remote. */
54 using ModuleHandle = void*;
55
56 /*! \brief NDArray handle in the remote. */
57 using NDArrayHandle = void*;
58
59 /*!
60 * \brief Callback to send an encoded return values via encode_args.
61 *
62 * \param encode_args The arguments that we can encode the return values into.
63 *
64 * Encoding convention (as list of arguments):
65 * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention.
66 * - PackedFunc/Module: [tcode: int, handle: void*]
67 * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*]
68 * DLTensor* contains the meta-data as well as handle into the remote data.
69 * nd_handle can be used for deletion.
70 */
71 using FEncodeReturn = std::function<void(TVMArgs encoded_args)>;
72
73 /*!
74 * \brief Callback to send an encoded return values via encode_args.
75 *
76 * \param status The return status, can be RPCCode::kReturn or RPCCode::kException.
77 * \param encode_args The arguments that we can encode the return values into.
78 */
79 using FAsyncCallback = std::function<void(RPCCode status, TVMArgs encoded_args)>;
80
81 /*! \brief Destructor.*/
82 virtual ~RPCSession() {}
83
84 /*!
85 * \brief Get function in the session.
86 * \param name The name of the function.
87 * \return The function handle.
88 */
89 virtual PackedFuncHandle GetFunction(const std::string& name) = 0;
90
91 /*!
92 * \brief Call into a remote Packed function.
93 *
94 * Calling convention:
95 *
96 * - type_code is follows the PackedFunc convention.
97 * - int/float/string/bytes follows the PackedFunc convention, all data are local.
98 * - PackedFunc/Module and future remote objects: pass remote handle instead.
99 * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor
100 * points to a remote data handle returned by the Device API.
101 * The meta-data of the DLTensor sits on local.
102 *
103 * The caller populates the arguments and manages these arguments.
104 *
105 * The callee can change the content of arg_values and arg_type_codes
106 * if they want to do inplace modify and forward.
107 *
108 * The callee need to store the return value into ret_value.
109 * - PackedFunc/Module are stored as void*
110 * - NDArray is stored as local NDArray, whose data field is a remote handle.
111 * Notably the NDArray's deleter won't delete remote handle.
112 * It is up to the user of the RPCSession to such wrapping.
113 * - In short, remote handles are "moved" as return values
114 * and the callee needs to explicitly manage them by calling
115 * the deleter functions when they are no longer needed.
116 *
117 * \param func The function handle.
118 * \param arg_values The argument values.
119 * \param arg_type_codes the type codes of the argument.
120 * \param num_args Number of arguments.
121 * \param fencode_return The function to set the return value,
122 * if not called, return value is null.
123 */
124 virtual void CallFunc(PackedFuncHandle func, const TVMValue* arg_values,
125 const int* arg_type_codes, int num_args,
126 const FEncodeReturn& fencode_return) = 0;
127
128 /*!
129 * \brief Copy bytes into remote array content.
130 * \param local_from_bytes The source host data.
131 * \param remote_to The target array.
132 * \param nbytes The size of the memory in bytes.
133 */
134 virtual void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) = 0;
135 /*!
136 * \brief Copy bytes from remote array content.
137 * \param remote_from The source host data.
138 * \param local_to_bytes The target array.
139 * \param nbytes The size of the memory in bytes.
140 */
141 virtual void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) = 0;
142
143 /*!
144 * \brief Free a remote function.
145 * \param handle The remote handle, can be NDArray/PackedFunc/Module
146 * \param type_code The type code of the underlying type.
147 */
148 virtual void FreeHandle(void* handle, int type_code) = 0;
149
150 /*!
151 * \brief Get device API that represents the remote
152 * actions that can be taken on the remote.
153 *
154 * The caller can then call into the Alloc/Free functions
155 * to allocate free spaces and taking the pointer as the handle.
156 *
157 * The device API is guaranteed to be alive during the
158 * lifetime of the Session.
159 *
160 * \param dev The remote device.
161 * \param allow_missing Whether can we return nullptr if it is not available.
162 *
163 * \return The device API.
164 */
165 virtual DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing = false) = 0;
166
167 /*!
168 * \brief Whether the session is a local session and we can directly
169 * the data handle returned by the session and treat it as pointer
170 * to the local memory.
171 *
172 * This information is useful for RPC server to directly copy into the
173 * local memory without creating a temporary buffer.
174 *
175 * \return Whether it is a local session.
176 */
177 virtual bool IsLocalSession() const = 0;
178
179 // Asynchrous variant of API
180 // These APIs are used by the RPC server to allow sessions that
181 // have special implementations for the async functions.
182 //
183 // In the async APIs, an exception is returned by the passing
184 // async_error=true, encode_args=[error_msg].
185
186 /*!
187 * \brief Whether the session is async.
188 *
189 * If the session is not async, its Aync implementations
190 * simply calls into the their synchronize counterparts,
191 * and the callback is guaranteed to be called before the async function finishes.
192 *
193 * \return the async state.
194 *
195 * \note We can only use async session in an Event driven RPC server.
196 */
197 virtual bool IsAsync() const;
198
199 /*!
200 * \brief Asynchrously call func.
201 * \param func The function handle.
202 * \param arg_values The argument values.
203 * \param arg_type_codes the type codes of the argument.
204 * \param num_args Number of arguments.
205 *
206 * \param callback The callback to pass the return value or exception.
207 */
208 virtual void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values,
209 const int* arg_type_codes, int num_args, FAsyncCallback callback);
210
211 /*!
212 * \brief Asynchrous version of CopyToRemote.
213 *
214 * \param local_from_bytes The source host data.
215 * \param remote_to The target array.
216 * \param nbytes The size of the memory in bytes.
217 * \param on_complete The callback to signal copy complete.
218 * \note All the allocated memory in local_from, and remote_to
219 * must stay alive until on_compelete is called.
220 */
221 virtual void AsyncCopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes,
222 FAsyncCallback on_complete);
223
224 /*!
225 * \brief Asynchrous version of CopyFromRemote.
226 *
227 * \param remote_from The source host data.
228 * \param local_to_bytes The target array.
229 * \param nbytes The size of the memory in bytes.
230 * \param on_complete The callback to signal copy complete.
231 * \note All the allocated memory in remote_from, and local_to
232 * must stay alive until on_compelete is called.
233 */
234 virtual void AsyncCopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes,
235 FAsyncCallback on_complete);
236 /*!
237 * \brief Asynchrously wait for all events in dev, stream compeletes.
238 * \param dev The device.
239 * \param stream The stream to wait on.
240 * \param on_complete The callback to signal copy complete.
241 */
242 virtual void AsyncStreamWait(Device dev, TVMStreamHandle stream, FAsyncCallback on_compelte);
243
244 /*!
245 * \return The session table index of the session.
246 */
247 int table_index() const { return table_index_; }
248
249 /*!
250 * \brief Try get session from the global session table by table index.
251 * \param table_index The table index of the session.
252 * \return The shared_ptr to the session, can be nullptr.
253 */
254 static std::shared_ptr<RPCSession> Get(int table_index);
255
256 /*!
257 * \brief Shutdown RPC connection.
258 */
259 virtual void Shutdown() {}
260
261 protected:
262 /*!
263 * \brief Send an exception to the callback.
264 * \param msg The exception message.
265 */
266 void SendException(FAsyncCallback callback, const char* msg);
267
268 private:
269 /*! \brief index of this session in RPC session table */
270 int table_index_{0};
271 /*! \brief Insert the current session to the session table.*/
272 static void InsertToSessionTable(std::shared_ptr<RPCSession> sess);
273 // friend declaration
274 friend Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
275};
276
277/*!
278 * \brief Remote space handle cell used by the RPC runtime API.
279 *
280 * When we allocate space using a rpc device, the data pointer
281 * points to an allocated RemoteSpace.
282 */
283struct RemoteSpace {
284 /*! \brief The remote data handle. */
285 void* data;
286 /*! \brief Reference to the underlying RPC session. */
287 std::shared_ptr<RPCSession> sess;
288};
289
290/*!
291 * \brief Create a Global RPC module that refers to the session.
292 * \param sess The RPC session of the global module.
293 * \return The created module.
294 */
295Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
296
297/*!
298 * \brief Get the session module from a RPC session Module.
299 * \param mod The input module(must be an RPCModule).
300 * \return The internal RPCSession.
301 */
302std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod);
303
304} // namespace runtime
305} // namespace tvm
306#endif // TVM_RUNTIME_RPC_RPC_SESSION_H_
307