1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file minrpc_server.h |
22 | * \brief Minimum RPC server implementation, |
23 | * redirects all the calls to C runtime API. |
24 | * |
25 | * \note This file does not depend on c++ std or c std, |
26 | * and only depends on TVM's C runtime API. |
27 | */ |
28 | #ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ |
29 | #define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ |
30 | |
31 | #ifndef DMLC_LITTLE_ENDIAN |
32 | #define DMLC_LITTLE_ENDIAN 1 |
33 | #endif |
34 | |
35 | #include <string.h> |
36 | #include <tvm/runtime/c_runtime_api.h> |
37 | |
38 | #include <memory> |
39 | #include <utility> |
40 | |
41 | #include "../../support/generic_arena.h" |
42 | #include "minrpc_interfaces.h" |
43 | #include "rpc_reference.h" |
44 | |
#ifndef MINRPC_CHECK
/*!
 * \brief Report RPCServerStatus::kCheckError via this->ThrowError when cond is false.
 *
 * Must be expanded inside a member function of a class that provides
 * ThrowError(RPCServerStatus). The body is wrapped in do { } while (0) so the
 * expansion is exactly one statement: it composes with an un-braced if/else at
 * the call site and requires the usual trailing semicolon.
 *
 * \param cond Expression that is expected to be true.
 */
#define MINRPC_CHECK(cond)                              \
  do {                                                  \
    if (!(cond)) {                                      \
      this->ThrowError(RPCServerStatus::kCheckError);   \
    }                                                   \
  } while (0)
#endif
49 | |
50 | namespace tvm { |
51 | namespace runtime { |
52 | |
namespace detail {
// Declared ahead of its definition so that MinRPCServer can name it as the
// default value of its Allocator template parameter.
template <class TIOHandler>
class PageAllocator;
}  // namespace detail
57 | |
58 | /*! |
59 | * \brief Responses to a minimum RPC command. |
60 | * |
61 | * \tparam TIOHandler IO provider to provide io handling. |
62 | */ |
63 | template <typename TIOHandler> |
64 | class MinRPCReturns : public MinRPCReturnInterface { |
65 | public: |
66 | /*! |
67 | * \brief Constructor. |
68 | * \param io The IO handler. |
69 | */ |
70 | explicit MinRPCReturns(TIOHandler* io) : io_(io) {} |
71 | |
72 | void ReturnVoid() { |
73 | int32_t num_args = 1; |
74 | int32_t tcode = kTVMNullptr; |
75 | RPCCode code = RPCCode::kReturn; |
76 | |
77 | uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); |
78 | |
79 | io_->MessageStart(packet_nbytes); |
80 | Write(packet_nbytes); |
81 | Write(code); |
82 | Write(num_args); |
83 | Write(tcode); |
84 | io_->MessageDone(); |
85 | } |
86 | |
87 | void ReturnHandle(void* handle) { |
88 | int32_t num_args = 1; |
89 | int32_t tcode = kTVMOpaqueHandle; |
90 | RPCCode code = RPCCode::kReturn; |
91 | uint64_t encode_handle = reinterpret_cast<uint64_t>(handle); |
92 | uint64_t packet_nbytes = |
93 | sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); |
94 | |
95 | io_->MessageStart(packet_nbytes); |
96 | Write(packet_nbytes); |
97 | Write(code); |
98 | Write(num_args); |
99 | Write(tcode); |
100 | Write(encode_handle); |
101 | io_->MessageDone(); |
102 | } |
103 | |
104 | void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } |
105 | |
106 | void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) { |
107 | RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); |
108 | } |
109 | |
110 | void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) { |
111 | RPCCode code = RPCCode::kCopyAck; |
112 | uint64_t packet_nbytes = sizeof(code) + num_bytes; |
113 | |
114 | io_->MessageStart(packet_nbytes); |
115 | Write(packet_nbytes); |
116 | Write(code); |
117 | WriteArray(data_ptr, num_bytes); |
118 | io_->MessageDone(); |
119 | } |
120 | |
121 | void ReturnLastTVMError() { |
122 | const char* err = TVMGetLastError(); |
123 | ReturnException(err); |
124 | } |
125 | |
126 | void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); } |
127 | |
128 | void MessageDone() { io_->MessageDone(); } |
129 | |
130 | void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { |
131 | io_->Exit(static_cast<int>(code)); |
132 | } |
133 | |
134 | template <typename T> |
135 | void Write(const T& data) { |
136 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
137 | "need to be trival" ); |
138 | return WriteRawBytes(&data, sizeof(T)); |
139 | } |
140 | |
141 | template <typename T> |
142 | void WriteArray(T* data, size_t count) { |
143 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
144 | "need to be trival" ); |
145 | return WriteRawBytes(data, sizeof(T) * count); |
146 | } |
147 | |
148 | private: |
149 | void WriteRawBytes(const void* data, size_t size) { |
150 | const uint8_t* buf = static_cast<const uint8_t*>(data); |
151 | size_t ndone = 0; |
152 | while (ndone < size) { |
153 | ssize_t ret = io_->PosixWrite(buf, size - ndone); |
154 | if (ret <= 0) { |
155 | this->ThrowError(RPCServerStatus::kWriteError); |
156 | } |
157 | buf += ret; |
158 | ndone += ret; |
159 | } |
160 | } |
161 | |
162 | TIOHandler* io_; |
163 | }; |
164 | |
165 | /*! |
166 | * \brief Executing a minimum RPC command. |
167 | * |
168 | * \tparam TIOHandler IO provider to provide io handling. |
169 | * \tparam MinRPCReturnInterface* handles response generatation and transmission. |
170 | */ |
171 | template <typename TIOHandler> |
172 | class MinRPCExecute : public MinRPCExecInterface { |
173 | public: |
174 | MinRPCExecute(TIOHandler* io, MinRPCReturnInterface* ret_handler) |
175 | : io_(io), ret_handler_(ret_handler) {} |
176 | |
177 | void InitServer(int num_args) { |
178 | MINRPC_CHECK(num_args == 0); |
179 | ret_handler_->ReturnVoid(); |
180 | } |
181 | |
182 | void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) { |
183 | TVMValue ret_value[3]; |
184 | int ret_tcode[3]; |
185 | |
186 | int call_ecode = TVMFuncCall(reinterpret_cast<void*>(call_handle), values, tcodes, num_args, |
187 | &(ret_value[1]), &(ret_tcode[1])); |
188 | |
189 | if (call_ecode == 0) { |
190 | // Return value encoding as in LocalSession |
191 | int rv_tcode = ret_tcode[1]; |
192 | ret_tcode[0] = kDLInt; |
193 | ret_value[0].v_int64 = rv_tcode; |
194 | if (rv_tcode == kTVMNDArrayHandle) { |
195 | ret_tcode[1] = kTVMDLTensorHandle; |
196 | ret_value[2].v_handle = ret_value[1].v_handle; |
197 | ret_tcode[2] = kTVMOpaqueHandle; |
198 | ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 3); |
199 | } else if (rv_tcode == kTVMBytes) { |
200 | ret_tcode[1] = kTVMBytes; |
201 | ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); |
202 | TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle)); // NOLINT(*) |
203 | } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { |
204 | ret_tcode[1] = kTVMOpaqueHandle; |
205 | ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); |
206 | } else { |
207 | ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); |
208 | } |
209 | } else { |
210 | ret_handler_->ReturnLastTVMError(); |
211 | } |
212 | } |
213 | |
214 | void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { |
215 | int call_ecode = 0; |
216 | if (arr->device.device_type != kDLCPU) { |
217 | DLTensor temp; |
218 | temp.data = static_cast<void*>(data_ptr); |
219 | temp.device = DLDevice{kDLCPU, 0}; |
220 | temp.ndim = arr->ndim; |
221 | temp.dtype = arr->dtype; |
222 | temp.shape = arr->shape; |
223 | temp.strides = nullptr; |
224 | temp.byte_offset = 0; |
225 | call_ecode = TVMDeviceCopyDataFromTo(arr, &temp, nullptr); |
226 | // need sync to make sure that the copy is completed. |
227 | if (call_ecode == 0) { |
228 | call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); |
229 | } |
230 | } |
231 | |
232 | if (call_ecode == 0) { |
233 | ret_handler_->ReturnCopyFromRemote(data_ptr, num_bytes); |
234 | } else { |
235 | ret_handler_->ReturnLastTVMError(); |
236 | } |
237 | } |
238 | |
239 | int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { |
240 | int call_ecode = 0; |
241 | |
242 | int ret = ReadArray(data_ptr, num_bytes); |
243 | if (ret <= 0) return ret; |
244 | |
245 | if (arr->device.device_type != kDLCPU) { |
246 | DLTensor temp; |
247 | temp.data = data_ptr; |
248 | temp.device = DLDevice{kDLCPU, 0}; |
249 | temp.ndim = arr->ndim; |
250 | temp.dtype = arr->dtype; |
251 | temp.shape = arr->shape; |
252 | temp.strides = nullptr; |
253 | temp.byte_offset = 0; |
254 | call_ecode = TVMDeviceCopyDataFromTo(&temp, arr, nullptr); |
255 | // need sync to make sure that the copy is completed. |
256 | if (call_ecode == 0) { |
257 | call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); |
258 | } |
259 | } |
260 | |
261 | if (call_ecode == 0) { |
262 | ret_handler_->ReturnVoid(); |
263 | } else { |
264 | ret_handler_->ReturnLastTVMError(); |
265 | } |
266 | |
267 | return 1; |
268 | } |
269 | |
270 | void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) { |
271 | switch (code) { |
272 | case RPCCode::kFreeHandle: { |
273 | SyscallFreeHandle(values, tcodes, num_args); |
274 | break; |
275 | } |
276 | case RPCCode::kGetGlobalFunc: { |
277 | SyscallGetGlobalFunc(values, tcodes, num_args); |
278 | break; |
279 | } |
280 | case RPCCode::kDevSetDevice: { |
281 | ret_handler_->ReturnException("SetDevice not supported" ); |
282 | break; |
283 | } |
284 | case RPCCode::kDevGetAttr: { |
285 | ret_handler_->ReturnException("GetAttr not supported" ); |
286 | break; |
287 | } |
288 | case RPCCode::kDevAllocData: { |
289 | SyscallDevAllocData(values, tcodes, num_args); |
290 | break; |
291 | } |
292 | case RPCCode::kDevAllocDataWithScope: { |
293 | SyscallDevAllocDataWithScope(values, tcodes, num_args); |
294 | break; |
295 | } |
296 | case RPCCode::kDevFreeData: { |
297 | SyscallDevFreeData(values, tcodes, num_args); |
298 | break; |
299 | } |
300 | case RPCCode::kDevCreateStream: { |
301 | SyscallDevCreateStream(values, tcodes, num_args); |
302 | break; |
303 | } |
304 | case RPCCode::kDevFreeStream: { |
305 | SyscallDevFreeStream(values, tcodes, num_args); |
306 | break; |
307 | } |
308 | case RPCCode::kDevStreamSync: { |
309 | SyscallDevStreamSync(values, tcodes, num_args); |
310 | break; |
311 | } |
312 | case RPCCode::kDevSetStream: { |
313 | SyscallDevSetStream(values, tcodes, num_args); |
314 | break; |
315 | } |
316 | case RPCCode::kCopyAmongRemote: { |
317 | SyscallCopyAmongRemote(values, tcodes, num_args); |
318 | break; |
319 | } |
320 | default: { |
321 | ret_handler_->ReturnException("Syscall not recognized" ); |
322 | break; |
323 | } |
324 | } |
325 | } |
326 | |
327 | void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { |
328 | MINRPC_CHECK(num_args == 2); |
329 | MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); |
330 | MINRPC_CHECK(tcodes[1] == kDLInt); |
331 | |
332 | void* handle = values[0].v_handle; |
333 | int64_t type_code = values[1].v_int64; |
334 | int call_ecode; |
335 | |
336 | if (type_code == kTVMNDArrayHandle) { |
337 | call_ecode = TVMArrayFree(static_cast<TVMArrayHandle>(handle)); |
338 | } else if (type_code == kTVMPackedFuncHandle) { |
339 | call_ecode = TVMFuncFree(handle); |
340 | } else { |
341 | MINRPC_CHECK(type_code == kTVMModuleHandle); |
342 | call_ecode = TVMModFree(handle); |
343 | } |
344 | |
345 | if (call_ecode == 0) { |
346 | ret_handler_->ReturnVoid(); |
347 | } else { |
348 | ret_handler_->ReturnLastTVMError(); |
349 | } |
350 | } |
351 | |
352 | void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { |
353 | MINRPC_CHECK(num_args == 1); |
354 | MINRPC_CHECK(tcodes[0] == kTVMStr); |
355 | void* handle; |
356 | int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); |
357 | |
358 | if (call_ecode == 0) { |
359 | ret_handler_->ReturnHandle(handle); |
360 | } else { |
361 | ret_handler_->ReturnLastTVMError(); |
362 | } |
363 | } |
364 | |
365 | void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) { |
366 | MINRPC_CHECK(num_args == 3); |
367 | // from dltensor |
368 | MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle); |
369 | // to dltensor |
370 | MINRPC_CHECK(tcodes[1] == kTVMDLTensorHandle); |
371 | // stream |
372 | MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle); |
373 | |
374 | void* from = values[0].v_handle; |
375 | void* to = values[1].v_handle; |
376 | TVMStreamHandle stream = values[2].v_handle; |
377 | |
378 | int call_ecode = TVMDeviceCopyDataFromTo(reinterpret_cast<DLTensor*>(from), |
379 | reinterpret_cast<DLTensor*>(to), stream); |
380 | |
381 | if (call_ecode == 0) { |
382 | ret_handler_->ReturnVoid(); |
383 | } else { |
384 | ret_handler_->ReturnLastTVMError(); |
385 | } |
386 | } |
387 | |
388 | void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) { |
389 | MINRPC_CHECK(num_args == 4); |
390 | MINRPC_CHECK(tcodes[0] == kDLDevice); |
391 | MINRPC_CHECK(tcodes[1] == kDLInt); |
392 | MINRPC_CHECK(tcodes[2] == kDLInt); |
393 | MINRPC_CHECK(tcodes[3] == kTVMDataType); |
394 | |
395 | DLDevice dev = values[0].v_device; |
396 | int64_t nbytes = values[1].v_int64; |
397 | int64_t alignment = values[2].v_int64; |
398 | DLDataType type_hint = values[3].v_type; |
399 | |
400 | void* handle; |
401 | int call_ecode = TVMDeviceAllocDataSpace(dev, nbytes, alignment, type_hint, &handle); |
402 | |
403 | if (call_ecode == 0) { |
404 | ret_handler_->ReturnHandle(handle); |
405 | } else { |
406 | ret_handler_->ReturnLastTVMError(); |
407 | } |
408 | } |
409 | |
410 | void SyscallDevAllocDataWithScope(TVMValue* values, int* tcodes, int num_args) { |
411 | MINRPC_CHECK(num_args == 2); |
412 | MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle); |
413 | MINRPC_CHECK(tcodes[1] == kTVMNullptr || tcodes[1] == kTVMStr); |
414 | |
415 | DLTensor* arr = static_cast<DLTensor*>(values[0].v_handle); |
416 | const char* mem_scope = (tcodes[1] == kTVMNullptr ? nullptr : values[1].v_str); |
417 | void* handle; |
418 | int call_ecode = TVMDeviceAllocDataSpaceWithScope(arr->device, arr->ndim, arr->shape, |
419 | arr->dtype, mem_scope, &handle); |
420 | if (call_ecode == 0) { |
421 | ret_handler_->ReturnHandle(handle); |
422 | } else { |
423 | ret_handler_->ReturnLastTVMError(); |
424 | } |
425 | } |
426 | |
427 | void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) { |
428 | MINRPC_CHECK(num_args == 2); |
429 | MINRPC_CHECK(tcodes[0] == kDLDevice); |
430 | MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); |
431 | |
432 | DLDevice dev = values[0].v_device; |
433 | void* handle = values[1].v_handle; |
434 | |
435 | int call_ecode = TVMDeviceFreeDataSpace(dev, handle); |
436 | |
437 | if (call_ecode == 0) { |
438 | ret_handler_->ReturnVoid(); |
439 | } else { |
440 | ret_handler_->ReturnLastTVMError(); |
441 | } |
442 | } |
443 | |
444 | void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) { |
445 | MINRPC_CHECK(num_args == 1); |
446 | MINRPC_CHECK(tcodes[0] == kDLDevice); |
447 | |
448 | DLDevice dev = values[0].v_device; |
449 | void* handle; |
450 | |
451 | int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); |
452 | |
453 | if (call_ecode == 0) { |
454 | ret_handler_->ReturnHandle(handle); |
455 | } else { |
456 | ret_handler_->ReturnLastTVMError(); |
457 | } |
458 | } |
459 | |
460 | void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) { |
461 | MINRPC_CHECK(num_args == 2); |
462 | MINRPC_CHECK(tcodes[0] == kDLDevice); |
463 | MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); |
464 | |
465 | DLDevice dev = values[0].v_device; |
466 | void* handle = values[1].v_handle; |
467 | |
468 | int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); |
469 | |
470 | if (call_ecode == 0) { |
471 | ret_handler_->ReturnVoid(); |
472 | } else { |
473 | ret_handler_->ReturnLastTVMError(); |
474 | } |
475 | } |
476 | |
477 | void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { |
478 | MINRPC_CHECK(num_args == 2); |
479 | MINRPC_CHECK(tcodes[0] == kDLDevice); |
480 | MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); |
481 | |
482 | DLDevice dev = values[0].v_device; |
483 | void* handle = values[1].v_handle; |
484 | |
485 | int call_ecode = TVMSynchronize(dev.device_type, dev.device_id, handle); |
486 | |
487 | if (call_ecode == 0) { |
488 | ret_handler_->ReturnVoid(); |
489 | } else { |
490 | ret_handler_->ReturnLastTVMError(); |
491 | } |
492 | } |
493 | |
494 | void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) { |
495 | MINRPC_CHECK(num_args == 2); |
496 | MINRPC_CHECK(tcodes[0] == kDLDevice); |
497 | MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); |
498 | |
499 | DLDevice dev = values[0].v_device; |
500 | void* handle = values[1].v_handle; |
501 | |
502 | int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); |
503 | |
504 | if (call_ecode == 0) { |
505 | ret_handler_->ReturnVoid(); |
506 | } else { |
507 | ret_handler_->ReturnLastTVMError(); |
508 | } |
509 | } |
510 | |
511 | void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { |
512 | ret_handler_->ThrowError(code, info); |
513 | } |
514 | |
515 | MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; } |
516 | |
517 | private: |
518 | template <typename T> |
519 | int ReadArray(T* data, size_t count) { |
520 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
521 | "need to be trival" ); |
522 | return ReadRawBytes(data, sizeof(T) * count); |
523 | } |
524 | |
525 | int ReadRawBytes(void* data, size_t size) { |
526 | uint8_t* buf = static_cast<uint8_t*>(data); |
527 | size_t ndone = 0; |
528 | while (ndone < size) { |
529 | ssize_t ret = io_->PosixRead(buf, size - ndone); |
530 | if (ret <= 0) return ret; |
531 | ndone += ret; |
532 | buf += ret; |
533 | } |
534 | return 1; |
535 | } |
536 | |
537 | TIOHandler* io_; |
538 | MinRPCReturnInterface* ret_handler_; |
539 | }; |
540 | |
541 | /*! |
542 | * \brief A minimum RPC server that only depends on the tvm C runtime.. |
543 | * |
544 | * All the dependencies are provided by the io arguments. |
545 | * |
546 | * \tparam TIOHandler IO provider to provide io handling. |
547 | * An IOHandler needs to provide the following functions: |
548 | * - PosixWrite, PosixRead, Close: posix style, read, write, close API. |
549 | * - MessageStart(num_bytes), MessageDone(): framing APIs. |
550 | * - Exit: exit with status code. |
551 | */ |
552 | template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator> |
553 | class MinRPCServer { |
554 | public: |
555 | using PageAllocator = Allocator<TIOHandler>; |
556 | |
557 | /*! |
558 | * \brief Constructor. |
559 | * \param io The IO handler. |
560 | */ |
561 | MinRPCServer(TIOHandler* io, std::unique_ptr<MinRPCExecInterface>&& exec_handler) |
562 | : io_(io), arena_(PageAllocator(io_)), exec_handler_(std::move(exec_handler)) {} |
563 | |
564 | explicit MinRPCServer(TIOHandler* io) |
565 | : io_(io), |
566 | arena_(PageAllocator(io)), |
567 | ret_handler_(new MinRPCReturns<TIOHandler>(io_)), |
568 | exec_handler_(std::unique_ptr<MinRPCExecInterface>( |
569 | new MinRPCExecute<TIOHandler>(io_, ret_handler_))) {} |
570 | |
571 | ~MinRPCServer() { |
572 | if (ret_handler_ != nullptr) { |
573 | delete ret_handler_; |
574 | } |
575 | } |
576 | |
577 | /*! \brief Process a single request. |
578 | * |
579 | * \return true when the server should continue processing requests. false when it should be |
580 | * shutdown. |
581 | */ |
582 | bool ProcessOnePacket() { |
583 | RPCCode code; |
584 | uint64_t packet_len; |
585 | |
586 | arena_.RecycleAll(); |
587 | allow_clean_shutdown_ = true; |
588 | |
589 | Read(&packet_len); |
590 | if (packet_len == 0) return true; |
591 | Read(&code); |
592 | allow_clean_shutdown_ = false; |
593 | |
594 | if (code >= RPCCode::kSyscallCodeStart) { |
595 | HandleSyscallFunc(code); |
596 | } else { |
597 | switch (code) { |
598 | case RPCCode::kCallFunc: { |
599 | HandleNormalCallFunc(); |
600 | break; |
601 | } |
602 | case RPCCode::kInitServer: { |
603 | HandleInitServer(); |
604 | break; |
605 | } |
606 | case RPCCode::kCopyFromRemote: { |
607 | HandleCopyFromRemote(); |
608 | break; |
609 | } |
610 | case RPCCode::kCopyToRemote: { |
611 | HandleCopyToRemote(); |
612 | break; |
613 | } |
614 | case RPCCode::kShutdown: { |
615 | Shutdown(); |
616 | return false; |
617 | } |
618 | default: { |
619 | this->ThrowError(RPCServerStatus::kUnknownRPCCode); |
620 | break; |
621 | } |
622 | } |
623 | } |
624 | |
625 | return true; |
626 | } |
627 | |
628 | void HandleInitServer() { |
629 | uint64_t len; |
630 | Read(&len); |
631 | char* proto_ver = ArenaAlloc<char>(len + 1); |
632 | ReadArray(proto_ver, len); |
633 | TVMValue* values; |
634 | int* tcodes; |
635 | int num_args; |
636 | RecvPackedSeq(&values, &tcodes, &num_args); |
637 | exec_handler_->InitServer(num_args); |
638 | } |
639 | |
640 | void Shutdown() { |
641 | arena_.FreeAll(); |
642 | io_->Close(); |
643 | } |
644 | |
645 | void HandleNormalCallFunc() { |
646 | uint64_t call_handle; |
647 | TVMValue* values; |
648 | int* tcodes; |
649 | int num_args; |
650 | |
651 | Read(&call_handle); |
652 | RecvPackedSeq(&values, &tcodes, &num_args); |
653 | exec_handler_->NormalCallFunc(call_handle, values, tcodes, num_args); |
654 | } |
655 | |
656 | void HandleCopyFromRemote() { |
657 | DLTensor* arr = ArenaAlloc<DLTensor>(1); |
658 | uint64_t data_handle; |
659 | Read(&data_handle); |
660 | arr->data = reinterpret_cast<void*>(data_handle); |
661 | Read(&(arr->device)); |
662 | Read(&(arr->ndim)); |
663 | Read(&(arr->dtype)); |
664 | arr->shape = ArenaAlloc<int64_t>(arr->ndim); |
665 | ReadArray(arr->shape, arr->ndim); |
666 | arr->strides = nullptr; |
667 | Read(&(arr->byte_offset)); |
668 | |
669 | uint64_t num_bytes; |
670 | Read(&num_bytes); |
671 | |
672 | uint8_t* data_ptr; |
673 | if (arr->device.device_type == kDLCPU) { |
674 | data_ptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset; |
675 | } else { |
676 | data_ptr = ArenaAlloc<uint8_t>(num_bytes); |
677 | } |
678 | |
679 | exec_handler_->CopyFromRemote(arr, num_bytes, data_ptr); |
680 | } |
681 | |
682 | void HandleCopyToRemote() { |
683 | DLTensor* arr = ArenaAlloc<DLTensor>(1); |
684 | uint64_t data_handle; |
685 | Read(&data_handle); |
686 | arr->data = reinterpret_cast<void*>(data_handle); |
687 | Read(&(arr->device)); |
688 | Read(&(arr->ndim)); |
689 | Read(&(arr->dtype)); |
690 | arr->shape = ArenaAlloc<int64_t>(arr->ndim); |
691 | ReadArray(arr->shape, arr->ndim); |
692 | arr->strides = nullptr; |
693 | Read(&(arr->byte_offset)); |
694 | uint64_t num_bytes; |
695 | Read(&num_bytes); |
696 | int ret; |
697 | if (arr->device.device_type == kDLCPU) { |
698 | uint8_t* dptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset; |
699 | ret = exec_handler_->CopyToRemote(arr, num_bytes, dptr); |
700 | } else { |
701 | uint8_t* temp_data = ArenaAlloc<uint8_t>(num_bytes); |
702 | ret = exec_handler_->CopyToRemote(arr, num_bytes, temp_data); |
703 | } |
704 | if (ret == 0) { |
705 | if (allow_clean_shutdown_) { |
706 | Shutdown(); |
707 | io_->Exit(0); |
708 | } else { |
709 | this->ThrowError(RPCServerStatus::kReadError); |
710 | } |
711 | } |
712 | if (ret == -1) { |
713 | this->ThrowError(RPCServerStatus::kReadError); |
714 | } |
715 | } |
716 | |
717 | void HandleSyscallFunc(RPCCode code) { |
718 | TVMValue* values; |
719 | int* tcodes; |
720 | int num_args; |
721 | RecvPackedSeq(&values, &tcodes, &num_args); |
722 | |
723 | exec_handler_->SysCallFunc(code, values, tcodes, num_args); |
724 | } |
725 | |
726 | void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { |
727 | io_->Exit(static_cast<int>(code)); |
728 | } |
729 | |
730 | template <typename T> |
731 | T* ArenaAlloc(int count) { |
732 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
733 | "need to be trival" ); |
734 | return arena_.template allocate_<T>(count); |
735 | } |
736 | |
737 | template <typename T> |
738 | void Read(T* data) { |
739 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
740 | "need to be trival" ); |
741 | ReadRawBytes(data, sizeof(T)); |
742 | } |
743 | |
744 | template <typename T> |
745 | void ReadArray(T* data, size_t count) { |
746 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
747 | "need to be trival" ); |
748 | return ReadRawBytes(data, sizeof(T) * count); |
749 | } |
750 | |
751 | private: |
752 | void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { |
753 | RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); |
754 | } |
755 | |
756 | void ReadRawBytes(void* data, size_t size) { |
757 | uint8_t* buf = static_cast<uint8_t*>(data); |
758 | size_t ndone = 0; |
759 | while (ndone < size) { |
760 | ssize_t ret = io_->PosixRead(buf, size - ndone); |
761 | if (ret == 0) { |
762 | if (allow_clean_shutdown_) { |
763 | Shutdown(); |
764 | io_->Exit(0); |
765 | } else { |
766 | this->ThrowError(RPCServerStatus::kReadError); |
767 | } |
768 | } |
769 | if (ret == -1) { |
770 | this->ThrowError(RPCServerStatus::kReadError); |
771 | } |
772 | ndone += ret; |
773 | buf += ret; |
774 | } |
775 | } |
776 | |
777 | /*! \brief IO handler. */ |
778 | TIOHandler* io_; |
779 | /*! \brief internal arena. */ |
780 | support::GenericArena<PageAllocator> arena_; |
781 | MinRPCReturns<TIOHandler>* ret_handler_ = nullptr; |
782 | std::unique_ptr<MinRPCExecInterface> exec_handler_; |
783 | /*! \brief Whether we are in a state that allows clean shutdown. */ |
784 | bool allow_clean_shutdown_{true}; |
785 | static_assert(DMLC_LITTLE_ENDIAN == 1, "MinRPC only works on little endian." ); |
786 | }; |
787 | |
788 | namespace detail { |
789 | // Internal allocator that redirects alloc to TVM's C API. |
790 | template <typename TIOHandler> |
791 | class PageAllocator { |
792 | public: |
793 | using = tvm::support::ArenaPageHeader; |
794 | |
795 | explicit PageAllocator(TIOHandler* io) : io_(io) {} |
796 | |
797 | ArenaPageHeader* allocate(size_t min_size) { |
798 | size_t npages = ((min_size + kPageSize - 1) / kPageSize); |
799 | void* data; |
800 | |
801 | if (TVMDeviceAllocDataSpace(DLDevice{kDLCPU, 0}, npages * kPageSize, kPageAlign, |
802 | DLDataType{kDLInt, 1, 1}, &data) != 0) { |
803 | io_->Exit(static_cast<int>(RPCServerStatus::kAllocError)); |
804 | } |
805 | |
806 | ArenaPageHeader* = static_cast<ArenaPageHeader*>(data); |
807 | header->size = npages * kPageSize; |
808 | header->offset = sizeof(ArenaPageHeader); |
809 | return header; |
810 | } |
811 | |
812 | void (ArenaPageHeader* page) { |
813 | if (TVMDeviceFreeDataSpace(DLDevice{kDLCPU, 0}, page) != 0) { |
814 | io_->Exit(static_cast<int>(RPCServerStatus::kAllocError)); |
815 | } |
816 | } |
817 | |
818 | static const constexpr int kPageSize = 2 << 10; |
819 | static const constexpr int kPageAlign = 8; |
820 | |
821 | private: |
822 | TIOHandler* io_; |
823 | }; |
824 | } // namespace detail |
825 | |
826 | } // namespace runtime |
827 | } // namespace tvm |
828 | #endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ |
829 | |