1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file minrpc_server.h
22 * \brief Minimum RPC server implementation,
23 * redirects all the calls to C runtime API.
24 *
 * \note This file does not depend on c++ std or c std,
 * and only depends on TVM's C runtime API.
27 */
28#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
29#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
30
31#ifndef DMLC_LITTLE_ENDIAN
32#define DMLC_LITTLE_ENDIAN 1
33#endif
34
35#include <string.h>
36#include <tvm/runtime/c_runtime_api.h>
37
38#include <memory>
39#include <utility>
40
41#include "../../support/generic_arena.h"
42#include "minrpc_interfaces.h"
43#include "rpc_reference.h"
44
#ifndef MINRPC_CHECK
/*!
 * \brief Report RPCServerStatus::kCheckError through this->ThrowError when \p cond is false.
 *
 * Must be expanded inside a member function of a class that provides
 * ThrowError(RPCServerStatus). The body is wrapped in do/while(0) so that
 * `MINRPC_CHECK(x);` is exactly one statement and stays correct inside an
 * unbraced if/else.
 */
#define MINRPC_CHECK(cond)                                       \
  do {                                                           \
    if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); \
  } while (0)
#endif
49
50namespace tvm {
51namespace runtime {
52
namespace detail {
// Forward declaration only: PageAllocator is named as the default `Allocator`
// template argument of MinRPCServer before its definition appears in this header.
template <class TIOHandler>
class PageAllocator;
}  // namespace detail
57
58/*!
59 * \brief Responses to a minimum RPC command.
60 *
61 * \tparam TIOHandler IO provider to provide io handling.
62 */
63template <typename TIOHandler>
64class MinRPCReturns : public MinRPCReturnInterface {
65 public:
66 /*!
67 * \brief Constructor.
68 * \param io The IO handler.
69 */
70 explicit MinRPCReturns(TIOHandler* io) : io_(io) {}
71
72 void ReturnVoid() {
73 int32_t num_args = 1;
74 int32_t tcode = kTVMNullptr;
75 RPCCode code = RPCCode::kReturn;
76
77 uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
78
79 io_->MessageStart(packet_nbytes);
80 Write(packet_nbytes);
81 Write(code);
82 Write(num_args);
83 Write(tcode);
84 io_->MessageDone();
85 }
86
87 void ReturnHandle(void* handle) {
88 int32_t num_args = 1;
89 int32_t tcode = kTVMOpaqueHandle;
90 RPCCode code = RPCCode::kReturn;
91 uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
92 uint64_t packet_nbytes =
93 sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle);
94
95 io_->MessageStart(packet_nbytes);
96 Write(packet_nbytes);
97 Write(code);
98 Write(num_args);
99 Write(tcode);
100 Write(encode_handle);
101 io_->MessageDone();
102 }
103
104 void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); }
105
106 void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {
107 RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
108 }
109
110 void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {
111 RPCCode code = RPCCode::kCopyAck;
112 uint64_t packet_nbytes = sizeof(code) + num_bytes;
113
114 io_->MessageStart(packet_nbytes);
115 Write(packet_nbytes);
116 Write(code);
117 WriteArray(data_ptr, num_bytes);
118 io_->MessageDone();
119 }
120
121 void ReturnLastTVMError() {
122 const char* err = TVMGetLastError();
123 ReturnException(err);
124 }
125
126 void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); }
127
128 void MessageDone() { io_->MessageDone(); }
129
130 void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
131 io_->Exit(static_cast<int>(code));
132 }
133
134 template <typename T>
135 void Write(const T& data) {
136 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
137 "need to be trival");
138 return WriteRawBytes(&data, sizeof(T));
139 }
140
141 template <typename T>
142 void WriteArray(T* data, size_t count) {
143 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
144 "need to be trival");
145 return WriteRawBytes(data, sizeof(T) * count);
146 }
147
148 private:
149 void WriteRawBytes(const void* data, size_t size) {
150 const uint8_t* buf = static_cast<const uint8_t*>(data);
151 size_t ndone = 0;
152 while (ndone < size) {
153 ssize_t ret = io_->PosixWrite(buf, size - ndone);
154 if (ret <= 0) {
155 this->ThrowError(RPCServerStatus::kWriteError);
156 }
157 buf += ret;
158 ndone += ret;
159 }
160 }
161
162 TIOHandler* io_;
163};
164
165/*!
166 * \brief Executing a minimum RPC command.
167 *
168 * \tparam TIOHandler IO provider to provide io handling.
169 * \tparam MinRPCReturnInterface* handles response generatation and transmission.
170 */
171template <typename TIOHandler>
172class MinRPCExecute : public MinRPCExecInterface {
173 public:
174 MinRPCExecute(TIOHandler* io, MinRPCReturnInterface* ret_handler)
175 : io_(io), ret_handler_(ret_handler) {}
176
177 void InitServer(int num_args) {
178 MINRPC_CHECK(num_args == 0);
179 ret_handler_->ReturnVoid();
180 }
181
182 void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) {
183 TVMValue ret_value[3];
184 int ret_tcode[3];
185
186 int call_ecode = TVMFuncCall(reinterpret_cast<void*>(call_handle), values, tcodes, num_args,
187 &(ret_value[1]), &(ret_tcode[1]));
188
189 if (call_ecode == 0) {
190 // Return value encoding as in LocalSession
191 int rv_tcode = ret_tcode[1];
192 ret_tcode[0] = kDLInt;
193 ret_value[0].v_int64 = rv_tcode;
194 if (rv_tcode == kTVMNDArrayHandle) {
195 ret_tcode[1] = kTVMDLTensorHandle;
196 ret_value[2].v_handle = ret_value[1].v_handle;
197 ret_tcode[2] = kTVMOpaqueHandle;
198 ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 3);
199 } else if (rv_tcode == kTVMBytes) {
200 ret_tcode[1] = kTVMBytes;
201 ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
202 TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle)); // NOLINT(*)
203 } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
204 ret_tcode[1] = kTVMOpaqueHandle;
205 ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
206 } else {
207 ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
208 }
209 } else {
210 ret_handler_->ReturnLastTVMError();
211 }
212 }
213
214 void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
215 int call_ecode = 0;
216 if (arr->device.device_type != kDLCPU) {
217 DLTensor temp;
218 temp.data = static_cast<void*>(data_ptr);
219 temp.device = DLDevice{kDLCPU, 0};
220 temp.ndim = arr->ndim;
221 temp.dtype = arr->dtype;
222 temp.shape = arr->shape;
223 temp.strides = nullptr;
224 temp.byte_offset = 0;
225 call_ecode = TVMDeviceCopyDataFromTo(arr, &temp, nullptr);
226 // need sync to make sure that the copy is completed.
227 if (call_ecode == 0) {
228 call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr);
229 }
230 }
231
232 if (call_ecode == 0) {
233 ret_handler_->ReturnCopyFromRemote(data_ptr, num_bytes);
234 } else {
235 ret_handler_->ReturnLastTVMError();
236 }
237 }
238
239 int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
240 int call_ecode = 0;
241
242 int ret = ReadArray(data_ptr, num_bytes);
243 if (ret <= 0) return ret;
244
245 if (arr->device.device_type != kDLCPU) {
246 DLTensor temp;
247 temp.data = data_ptr;
248 temp.device = DLDevice{kDLCPU, 0};
249 temp.ndim = arr->ndim;
250 temp.dtype = arr->dtype;
251 temp.shape = arr->shape;
252 temp.strides = nullptr;
253 temp.byte_offset = 0;
254 call_ecode = TVMDeviceCopyDataFromTo(&temp, arr, nullptr);
255 // need sync to make sure that the copy is completed.
256 if (call_ecode == 0) {
257 call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr);
258 }
259 }
260
261 if (call_ecode == 0) {
262 ret_handler_->ReturnVoid();
263 } else {
264 ret_handler_->ReturnLastTVMError();
265 }
266
267 return 1;
268 }
269
270 void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {
271 switch (code) {
272 case RPCCode::kFreeHandle: {
273 SyscallFreeHandle(values, tcodes, num_args);
274 break;
275 }
276 case RPCCode::kGetGlobalFunc: {
277 SyscallGetGlobalFunc(values, tcodes, num_args);
278 break;
279 }
280 case RPCCode::kDevSetDevice: {
281 ret_handler_->ReturnException("SetDevice not supported");
282 break;
283 }
284 case RPCCode::kDevGetAttr: {
285 ret_handler_->ReturnException("GetAttr not supported");
286 break;
287 }
288 case RPCCode::kDevAllocData: {
289 SyscallDevAllocData(values, tcodes, num_args);
290 break;
291 }
292 case RPCCode::kDevAllocDataWithScope: {
293 SyscallDevAllocDataWithScope(values, tcodes, num_args);
294 break;
295 }
296 case RPCCode::kDevFreeData: {
297 SyscallDevFreeData(values, tcodes, num_args);
298 break;
299 }
300 case RPCCode::kDevCreateStream: {
301 SyscallDevCreateStream(values, tcodes, num_args);
302 break;
303 }
304 case RPCCode::kDevFreeStream: {
305 SyscallDevFreeStream(values, tcodes, num_args);
306 break;
307 }
308 case RPCCode::kDevStreamSync: {
309 SyscallDevStreamSync(values, tcodes, num_args);
310 break;
311 }
312 case RPCCode::kDevSetStream: {
313 SyscallDevSetStream(values, tcodes, num_args);
314 break;
315 }
316 case RPCCode::kCopyAmongRemote: {
317 SyscallCopyAmongRemote(values, tcodes, num_args);
318 break;
319 }
320 default: {
321 ret_handler_->ReturnException("Syscall not recognized");
322 break;
323 }
324 }
325 }
326
327 void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) {
328 MINRPC_CHECK(num_args == 2);
329 MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle);
330 MINRPC_CHECK(tcodes[1] == kDLInt);
331
332 void* handle = values[0].v_handle;
333 int64_t type_code = values[1].v_int64;
334 int call_ecode;
335
336 if (type_code == kTVMNDArrayHandle) {
337 call_ecode = TVMArrayFree(static_cast<TVMArrayHandle>(handle));
338 } else if (type_code == kTVMPackedFuncHandle) {
339 call_ecode = TVMFuncFree(handle);
340 } else {
341 MINRPC_CHECK(type_code == kTVMModuleHandle);
342 call_ecode = TVMModFree(handle);
343 }
344
345 if (call_ecode == 0) {
346 ret_handler_->ReturnVoid();
347 } else {
348 ret_handler_->ReturnLastTVMError();
349 }
350 }
351
352 void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) {
353 MINRPC_CHECK(num_args == 1);
354 MINRPC_CHECK(tcodes[0] == kTVMStr);
355 void* handle;
356 int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle);
357
358 if (call_ecode == 0) {
359 ret_handler_->ReturnHandle(handle);
360 } else {
361 ret_handler_->ReturnLastTVMError();
362 }
363 }
364
365 void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) {
366 MINRPC_CHECK(num_args == 3);
367 // from dltensor
368 MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle);
369 // to dltensor
370 MINRPC_CHECK(tcodes[1] == kTVMDLTensorHandle);
371 // stream
372 MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle);
373
374 void* from = values[0].v_handle;
375 void* to = values[1].v_handle;
376 TVMStreamHandle stream = values[2].v_handle;
377
378 int call_ecode = TVMDeviceCopyDataFromTo(reinterpret_cast<DLTensor*>(from),
379 reinterpret_cast<DLTensor*>(to), stream);
380
381 if (call_ecode == 0) {
382 ret_handler_->ReturnVoid();
383 } else {
384 ret_handler_->ReturnLastTVMError();
385 }
386 }
387
388 void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) {
389 MINRPC_CHECK(num_args == 4);
390 MINRPC_CHECK(tcodes[0] == kDLDevice);
391 MINRPC_CHECK(tcodes[1] == kDLInt);
392 MINRPC_CHECK(tcodes[2] == kDLInt);
393 MINRPC_CHECK(tcodes[3] == kTVMDataType);
394
395 DLDevice dev = values[0].v_device;
396 int64_t nbytes = values[1].v_int64;
397 int64_t alignment = values[2].v_int64;
398 DLDataType type_hint = values[3].v_type;
399
400 void* handle;
401 int call_ecode = TVMDeviceAllocDataSpace(dev, nbytes, alignment, type_hint, &handle);
402
403 if (call_ecode == 0) {
404 ret_handler_->ReturnHandle(handle);
405 } else {
406 ret_handler_->ReturnLastTVMError();
407 }
408 }
409
410 void SyscallDevAllocDataWithScope(TVMValue* values, int* tcodes, int num_args) {
411 MINRPC_CHECK(num_args == 2);
412 MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle);
413 MINRPC_CHECK(tcodes[1] == kTVMNullptr || tcodes[1] == kTVMStr);
414
415 DLTensor* arr = static_cast<DLTensor*>(values[0].v_handle);
416 const char* mem_scope = (tcodes[1] == kTVMNullptr ? nullptr : values[1].v_str);
417 void* handle;
418 int call_ecode = TVMDeviceAllocDataSpaceWithScope(arr->device, arr->ndim, arr->shape,
419 arr->dtype, mem_scope, &handle);
420 if (call_ecode == 0) {
421 ret_handler_->ReturnHandle(handle);
422 } else {
423 ret_handler_->ReturnLastTVMError();
424 }
425 }
426
427 void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) {
428 MINRPC_CHECK(num_args == 2);
429 MINRPC_CHECK(tcodes[0] == kDLDevice);
430 MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
431
432 DLDevice dev = values[0].v_device;
433 void* handle = values[1].v_handle;
434
435 int call_ecode = TVMDeviceFreeDataSpace(dev, handle);
436
437 if (call_ecode == 0) {
438 ret_handler_->ReturnVoid();
439 } else {
440 ret_handler_->ReturnLastTVMError();
441 }
442 }
443
444 void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) {
445 MINRPC_CHECK(num_args == 1);
446 MINRPC_CHECK(tcodes[0] == kDLDevice);
447
448 DLDevice dev = values[0].v_device;
449 void* handle;
450
451 int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle);
452
453 if (call_ecode == 0) {
454 ret_handler_->ReturnHandle(handle);
455 } else {
456 ret_handler_->ReturnLastTVMError();
457 }
458 }
459
460 void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) {
461 MINRPC_CHECK(num_args == 2);
462 MINRPC_CHECK(tcodes[0] == kDLDevice);
463 MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
464
465 DLDevice dev = values[0].v_device;
466 void* handle = values[1].v_handle;
467
468 int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle);
469
470 if (call_ecode == 0) {
471 ret_handler_->ReturnVoid();
472 } else {
473 ret_handler_->ReturnLastTVMError();
474 }
475 }
476
477 void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) {
478 MINRPC_CHECK(num_args == 2);
479 MINRPC_CHECK(tcodes[0] == kDLDevice);
480 MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
481
482 DLDevice dev = values[0].v_device;
483 void* handle = values[1].v_handle;
484
485 int call_ecode = TVMSynchronize(dev.device_type, dev.device_id, handle);
486
487 if (call_ecode == 0) {
488 ret_handler_->ReturnVoid();
489 } else {
490 ret_handler_->ReturnLastTVMError();
491 }
492 }
493
494 void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) {
495 MINRPC_CHECK(num_args == 2);
496 MINRPC_CHECK(tcodes[0] == kDLDevice);
497 MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
498
499 DLDevice dev = values[0].v_device;
500 void* handle = values[1].v_handle;
501
502 int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle);
503
504 if (call_ecode == 0) {
505 ret_handler_->ReturnVoid();
506 } else {
507 ret_handler_->ReturnLastTVMError();
508 }
509 }
510
511 void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
512 ret_handler_->ThrowError(code, info);
513 }
514
515 MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; }
516
517 private:
518 template <typename T>
519 int ReadArray(T* data, size_t count) {
520 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
521 "need to be trival");
522 return ReadRawBytes(data, sizeof(T) * count);
523 }
524
525 int ReadRawBytes(void* data, size_t size) {
526 uint8_t* buf = static_cast<uint8_t*>(data);
527 size_t ndone = 0;
528 while (ndone < size) {
529 ssize_t ret = io_->PosixRead(buf, size - ndone);
530 if (ret <= 0) return ret;
531 ndone += ret;
532 buf += ret;
533 }
534 return 1;
535 }
536
537 TIOHandler* io_;
538 MinRPCReturnInterface* ret_handler_;
539};
540
541/*!
542 * \brief A minimum RPC server that only depends on the tvm C runtime..
543 *
544 * All the dependencies are provided by the io arguments.
545 *
546 * \tparam TIOHandler IO provider to provide io handling.
547 * An IOHandler needs to provide the following functions:
548 * - PosixWrite, PosixRead, Close: posix style, read, write, close API.
549 * - MessageStart(num_bytes), MessageDone(): framing APIs.
550 * - Exit: exit with status code.
551 */
552template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator>
553class MinRPCServer {
554 public:
555 using PageAllocator = Allocator<TIOHandler>;
556
557 /*!
558 * \brief Constructor.
559 * \param io The IO handler.
560 */
561 MinRPCServer(TIOHandler* io, std::unique_ptr<MinRPCExecInterface>&& exec_handler)
562 : io_(io), arena_(PageAllocator(io_)), exec_handler_(std::move(exec_handler)) {}
563
564 explicit MinRPCServer(TIOHandler* io)
565 : io_(io),
566 arena_(PageAllocator(io)),
567 ret_handler_(new MinRPCReturns<TIOHandler>(io_)),
568 exec_handler_(std::unique_ptr<MinRPCExecInterface>(
569 new MinRPCExecute<TIOHandler>(io_, ret_handler_))) {}
570
571 ~MinRPCServer() {
572 if (ret_handler_ != nullptr) {
573 delete ret_handler_;
574 }
575 }
576
577 /*! \brief Process a single request.
578 *
579 * \return true when the server should continue processing requests. false when it should be
580 * shutdown.
581 */
582 bool ProcessOnePacket() {
583 RPCCode code;
584 uint64_t packet_len;
585
586 arena_.RecycleAll();
587 allow_clean_shutdown_ = true;
588
589 Read(&packet_len);
590 if (packet_len == 0) return true;
591 Read(&code);
592 allow_clean_shutdown_ = false;
593
594 if (code >= RPCCode::kSyscallCodeStart) {
595 HandleSyscallFunc(code);
596 } else {
597 switch (code) {
598 case RPCCode::kCallFunc: {
599 HandleNormalCallFunc();
600 break;
601 }
602 case RPCCode::kInitServer: {
603 HandleInitServer();
604 break;
605 }
606 case RPCCode::kCopyFromRemote: {
607 HandleCopyFromRemote();
608 break;
609 }
610 case RPCCode::kCopyToRemote: {
611 HandleCopyToRemote();
612 break;
613 }
614 case RPCCode::kShutdown: {
615 Shutdown();
616 return false;
617 }
618 default: {
619 this->ThrowError(RPCServerStatus::kUnknownRPCCode);
620 break;
621 }
622 }
623 }
624
625 return true;
626 }
627
628 void HandleInitServer() {
629 uint64_t len;
630 Read(&len);
631 char* proto_ver = ArenaAlloc<char>(len + 1);
632 ReadArray(proto_ver, len);
633 TVMValue* values;
634 int* tcodes;
635 int num_args;
636 RecvPackedSeq(&values, &tcodes, &num_args);
637 exec_handler_->InitServer(num_args);
638 }
639
640 void Shutdown() {
641 arena_.FreeAll();
642 io_->Close();
643 }
644
645 void HandleNormalCallFunc() {
646 uint64_t call_handle;
647 TVMValue* values;
648 int* tcodes;
649 int num_args;
650
651 Read(&call_handle);
652 RecvPackedSeq(&values, &tcodes, &num_args);
653 exec_handler_->NormalCallFunc(call_handle, values, tcodes, num_args);
654 }
655
656 void HandleCopyFromRemote() {
657 DLTensor* arr = ArenaAlloc<DLTensor>(1);
658 uint64_t data_handle;
659 Read(&data_handle);
660 arr->data = reinterpret_cast<void*>(data_handle);
661 Read(&(arr->device));
662 Read(&(arr->ndim));
663 Read(&(arr->dtype));
664 arr->shape = ArenaAlloc<int64_t>(arr->ndim);
665 ReadArray(arr->shape, arr->ndim);
666 arr->strides = nullptr;
667 Read(&(arr->byte_offset));
668
669 uint64_t num_bytes;
670 Read(&num_bytes);
671
672 uint8_t* data_ptr;
673 if (arr->device.device_type == kDLCPU) {
674 data_ptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
675 } else {
676 data_ptr = ArenaAlloc<uint8_t>(num_bytes);
677 }
678
679 exec_handler_->CopyFromRemote(arr, num_bytes, data_ptr);
680 }
681
682 void HandleCopyToRemote() {
683 DLTensor* arr = ArenaAlloc<DLTensor>(1);
684 uint64_t data_handle;
685 Read(&data_handle);
686 arr->data = reinterpret_cast<void*>(data_handle);
687 Read(&(arr->device));
688 Read(&(arr->ndim));
689 Read(&(arr->dtype));
690 arr->shape = ArenaAlloc<int64_t>(arr->ndim);
691 ReadArray(arr->shape, arr->ndim);
692 arr->strides = nullptr;
693 Read(&(arr->byte_offset));
694 uint64_t num_bytes;
695 Read(&num_bytes);
696 int ret;
697 if (arr->device.device_type == kDLCPU) {
698 uint8_t* dptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
699 ret = exec_handler_->CopyToRemote(arr, num_bytes, dptr);
700 } else {
701 uint8_t* temp_data = ArenaAlloc<uint8_t>(num_bytes);
702 ret = exec_handler_->CopyToRemote(arr, num_bytes, temp_data);
703 }
704 if (ret == 0) {
705 if (allow_clean_shutdown_) {
706 Shutdown();
707 io_->Exit(0);
708 } else {
709 this->ThrowError(RPCServerStatus::kReadError);
710 }
711 }
712 if (ret == -1) {
713 this->ThrowError(RPCServerStatus::kReadError);
714 }
715 }
716
717 void HandleSyscallFunc(RPCCode code) {
718 TVMValue* values;
719 int* tcodes;
720 int num_args;
721 RecvPackedSeq(&values, &tcodes, &num_args);
722
723 exec_handler_->SysCallFunc(code, values, tcodes, num_args);
724 }
725
726 void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
727 io_->Exit(static_cast<int>(code));
728 }
729
730 template <typename T>
731 T* ArenaAlloc(int count) {
732 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
733 "need to be trival");
734 return arena_.template allocate_<T>(count);
735 }
736
737 template <typename T>
738 void Read(T* data) {
739 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
740 "need to be trival");
741 ReadRawBytes(data, sizeof(T));
742 }
743
744 template <typename T>
745 void ReadArray(T* data, size_t count) {
746 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
747 "need to be trival");
748 return ReadRawBytes(data, sizeof(T) * count);
749 }
750
751 private:
752 void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
753 RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
754 }
755
756 void ReadRawBytes(void* data, size_t size) {
757 uint8_t* buf = static_cast<uint8_t*>(data);
758 size_t ndone = 0;
759 while (ndone < size) {
760 ssize_t ret = io_->PosixRead(buf, size - ndone);
761 if (ret == 0) {
762 if (allow_clean_shutdown_) {
763 Shutdown();
764 io_->Exit(0);
765 } else {
766 this->ThrowError(RPCServerStatus::kReadError);
767 }
768 }
769 if (ret == -1) {
770 this->ThrowError(RPCServerStatus::kReadError);
771 }
772 ndone += ret;
773 buf += ret;
774 }
775 }
776
777 /*! \brief IO handler. */
778 TIOHandler* io_;
779 /*! \brief internal arena. */
780 support::GenericArena<PageAllocator> arena_;
781 MinRPCReturns<TIOHandler>* ret_handler_ = nullptr;
782 std::unique_ptr<MinRPCExecInterface> exec_handler_;
783 /*! \brief Whether we are in a state that allows clean shutdown. */
784 bool allow_clean_shutdown_{true};
785 static_assert(DMLC_LITTLE_ENDIAN == 1, "MinRPC only works on little endian.");
786};
787
788namespace detail {
789// Internal allocator that redirects alloc to TVM's C API.
790template <typename TIOHandler>
791class PageAllocator {
792 public:
793 using ArenaPageHeader = tvm::support::ArenaPageHeader;
794
795 explicit PageAllocator(TIOHandler* io) : io_(io) {}
796
797 ArenaPageHeader* allocate(size_t min_size) {
798 size_t npages = ((min_size + kPageSize - 1) / kPageSize);
799 void* data;
800
801 if (TVMDeviceAllocDataSpace(DLDevice{kDLCPU, 0}, npages * kPageSize, kPageAlign,
802 DLDataType{kDLInt, 1, 1}, &data) != 0) {
803 io_->Exit(static_cast<int>(RPCServerStatus::kAllocError));
804 }
805
806 ArenaPageHeader* header = static_cast<ArenaPageHeader*>(data);
807 header->size = npages * kPageSize;
808 header->offset = sizeof(ArenaPageHeader);
809 return header;
810 }
811
812 void deallocate(ArenaPageHeader* page) {
813 if (TVMDeviceFreeDataSpace(DLDevice{kDLCPU, 0}, page) != 0) {
814 io_->Exit(static_cast<int>(RPCServerStatus::kAllocError));
815 }
816 }
817
818 static const constexpr int kPageSize = 2 << 10;
819 static const constexpr int kPageAlign = 8;
820
821 private:
822 TIOHandler* io_;
823};
824} // namespace detail
825
826} // namespace runtime
827} // namespace tvm
828#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
829