1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_session.cc
22 * \brief RPC session for remote function call.
23 */
24#include "rpc_endpoint.h"
25
26#include <tvm/runtime/c_runtime_api.h>
27#include <tvm/runtime/device_api.h>
28#include <tvm/runtime/packed_func.h>
29#include <tvm/runtime/registry.h>
30#include <tvm/runtime/serializer.h>
31
32#include <algorithm>
33#include <array>
34#include <chrono>
35#include <cmath>
36#include <memory>
37#include <string>
38#include <utility>
39#include <vector>
40
41#include "../../support/arena.h"
42#include "../../support/ring_buffer.h"
43#include "../object_internal.h"
44#include "rpc_local_session.h"
45
46namespace tvm {
47namespace runtime {
48
49/*!
50 * Event-driven state-machine based handlers for RPCEndpoint.
51 *
52 * Key functions:
53 *
54 * - SendPackedSeq: send the arguments over to the peer
55 * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol).
56 */
57class RPCEndpoint::EventHandler : public dmlc::Stream {
58 public:
59 EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name,
60 std::string* remote_key, std::function<void()> flush_writer)
61 : reader_(reader),
62 writer_(writer),
63 name_(name),
64 remote_key_(remote_key),
65 flush_writer_(flush_writer) {
66 this->Clear();
67
68 if (*remote_key == "%toinit") {
69 state_ = kInitHeader;
70 remote_key_->resize(0);
71 pending_request_bytes_ = sizeof(int32_t);
72 }
73 }
74
75 /*!
76 * \brief Bytes needed to fulfill current request
77 */
78 size_t BytesNeeded() const {
79 if (reader_->bytes_available() < pending_request_bytes_) {
80 return pending_request_bytes_ - reader_->bytes_available();
81 } else {
82 return 0;
83 }
84 }
85
86 /*!
87 * \brief Request number of bytes from the reader.
88 * \param nbytes The number of bytes
89 */
90 void RequestBytes(size_t nbytes) {
91 pending_request_bytes_ += nbytes;
92 reader_->Reserve(pending_request_bytes_);
93 }
94
95 /*! \return Whether we are ready to handle next request. */
96 bool Ready() const { return reader_->bytes_available() >= pending_request_bytes_; }
97
98 /*! \return Whether we can perform a clean shutdown */
99 bool CanCleanShutdown() const { return state_ == kRecvPacketNumBytes; }
100
101 /*! \brief Finish the copy ack stage. */
102 void FinishCopyAck() { this->SwitchToState(kRecvPacketNumBytes); }
103
104 /*!
105 * \brief Enter the io loop until the next event.
106 * \param client_mode Whether we are in the client.
107 * \param async_server_mode Whether we are in the async server mode.
108 * \param setreturn The function to set the return value encoding.
109 * \return The function to set return values when there is a return event.
110 */
111 RPCCode HandleNextEvent(bool client_mode, bool async_server_mode,
112 RPCSession::FEncodeReturn setreturn) {
113 std::swap(client_mode_, client_mode);
114 std::swap(async_server_mode_, async_server_mode);
115
116 RPCCode status = RPCCode::kNone;
117
118 while (status == RPCCode::kNone && state_ != kWaitForAsyncCallback && this->Ready()) {
119 switch (state_) {
120 case kInitHeader:
121 HandleInitHeader();
122 break;
123 case kRecvPacketNumBytes: {
124 uint64_t packet_nbytes;
125 ICHECK(this->Read(&packet_nbytes));
126 if (packet_nbytes != 0) {
127 this->SwitchToState(kProcessPacket);
128 this->RequestBytes(packet_nbytes);
129 } else {
130 this->SwitchToState(kRecvPacketNumBytes);
131 }
132 break;
133 }
134 case kProcessPacket: {
135 this->HandleProcessPacket(setreturn);
136 break;
137 }
138 case kWaitForAsyncCallback: {
139 break;
140 }
141 case kReturnReceived: {
142 this->SwitchToState(kRecvPacketNumBytes);
143 status = RPCCode::kReturn;
144 break;
145 }
146 case kCopyAckReceived: {
147 status = RPCCode::kCopyAck;
148 break;
149 }
150 case kShutdownReceived: {
151 status = RPCCode::kShutdown;
152 }
153 }
154 }
155
156 std::swap(async_server_mode_, async_server_mode);
157 std::swap(client_mode_, client_mode);
158 return status;
159 }
160
161 /*! \brief Clear all the states in the Handler.*/
162 void Clear() {
163 state_ = kRecvPacketNumBytes;
164 pending_request_bytes_ = sizeof(uint64_t);
165 }
166
167 /*!
168 * \brief Validate that the arguments can be sent through RPC.
169 * \param arg_values The argument values.
170 * \param type_codes The type codes.
171 */
172 void ValidateArguments(const TVMValue* arg_values, const int* type_codes, int num_args) {
173 TVMArgs args(arg_values, type_codes, num_args);
174 for (int i = 0; i < num_args; ++i) {
175 int tcode = type_codes[i];
176 if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) {
177 LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
178 << args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is not supported by RPC";
179 } else if (tcode == kDLDevice) {
180 DLDevice dev = args[i];
181 ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel";
182 }
183 }
184 }
185
186 void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
187 LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code);
188 }
189
190 uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args,
191 bool client_mode) {
192 return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this);
193 }
194
195 void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
196 bool client_mode) {
197 RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this);
198 }
199
200 // Endian aware IO handling
201 using Stream::Read;
202 using Stream::ReadArray;
203 using Stream::Write;
204 using Stream::WriteArray;
205
206 void MessageStart(uint64_t packet_nbytes) {
207 // Unused here, implemented for microTVM framing layer.
208 }
209
210 bool Read(RPCCode* code) {
211 int32_t cdata;
212 if (!this->Read(&cdata)) return false;
213 *code = static_cast<RPCCode>(cdata);
214 return true;
215 }
216 void Write(RPCCode code) {
217 int32_t cdata = static_cast<int>(code);
218 this->Write(cdata);
219 }
220
221 void MessageDone() {
222 // Unused here, implemented for microTVM framing layer.
223 }
224
225 template <typename T>
226 T* ArenaAlloc(int count) {
227 static_assert(std::is_pod<T>::value, "need to be trival");
228 return arena_.template allocate_<T>(count);
229 }
230
231 protected:
232 enum State {
233 kInitHeader,
234 kRecvPacketNumBytes,
235 kProcessPacket,
236 kWaitForAsyncCallback,
237 kReturnReceived,
238 kCopyAckReceived,
239 kShutdownReceived
240 };
241 // Current state;
242 State state_;
243 // Initialize remote header
244 int init_header_step_{0};
245 // Whether current handler is client or server mode.
246 bool client_mode_{false};
247 // Whether current handler is in the async server mode.
248 bool async_server_mode_{false};
249 // Internal arena
250 support::Arena arena_;
251
252 // State switcher
253 void SwitchToState(State state) {
254 // invariant
255 if (state != kCopyAckReceived) {
256 ICHECK_EQ(pending_request_bytes_, 0U) << "state=" << state;
257 }
258 // need to actively flush the writer
259 // so the data get pushed out.
260 if (state_ == kWaitForAsyncCallback) {
261 flush_writer_();
262 }
263 state_ = state;
264 ICHECK(state != kInitHeader) << "cannot switch to init header";
265 if (state == kRecvPacketNumBytes) {
266 this->RequestBytes(sizeof(uint64_t));
267 // recycle arena for the next session.
268 arena_.RecycleAll();
269 }
270 }
271
272 // handler for initial header read
273 void HandleInitHeader() {
274 if (init_header_step_ == 0) {
275 int32_t len;
276 this->Read(&len);
277 remote_key_->resize(len);
278 init_header_step_ = 1;
279 this->RequestBytes(len);
280 return;
281 } else {
282 ICHECK_EQ(init_header_step_, 1);
283 this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
284 this->SwitchToState(kRecvPacketNumBytes);
285 }
286 }
287
288 // Handler for read code.
289 void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) {
290 RPCCode code = RPCCode::kNone;
291 this->Read(&code);
292
293 if (code >= RPCCode::kSyscallCodeStart) {
294 this->HandleSyscall(code);
295 } else {
296 switch (code) {
297 case RPCCode::kInitServer: {
298 this->HandleInitServer();
299 break;
300 }
301 case RPCCode::kCallFunc: {
302 this->HandleNormalCallFunc();
303 break;
304 }
305 case RPCCode::kCopyFromRemote: {
306 this->HandleCopyFromRemote();
307 break;
308 }
309 case RPCCode::kCopyToRemote: {
310 this->HandleCopyToRemote();
311 break;
312 }
313 case RPCCode::kException:
314 case RPCCode::kReturn: {
315 this->HandleReturn(code, setreturn);
316 break;
317 }
318 case RPCCode::kCopyAck: {
319 this->SwitchToState(kCopyAckReceived);
320 break;
321 }
322 case RPCCode::kShutdown: {
323 this->SwitchToState(kShutdownReceived);
324 break;
325 }
326 default:
327 LOG(FATAL) << "Unknown event " << static_cast<int>(code);
328 }
329 }
330 }
331
332 /*!
333 * \brief Receive incoming packed seq from the stream.
334 * \return The received argments.
335 * \note The TVMArgs is available until we switchstate.
336 */
337 TVMArgs RecvPackedSeq() {
338 TVMValue* values;
339 int* tcodes;
340 int num_args;
341 RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this);
342 return TVMArgs(values, tcodes, num_args);
343 }
344
345 /*!
346 * \brief Return exception to the remote.
347 * \param err_msg The error message.
348 */
349 void ReturnException(const char* err_msg) { RPCReference::ReturnException(err_msg, this); }
350
351 /*!
352 * \brief Return nullptr to the remote.
353 * \param err_msg The error message.
354 */
355 void ReturnVoid() { RPCReference::ReturnVoid(this); }
356
357 /*!
358 * \brief Return a packed sequence to the remote.
359 * \param args The arguments.
360 */
361 void ReturnPackedSeq(TVMArgs args) {
362 RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this);
363 }
364
365 /*!
366 * \brief Handle the case when return/exception value is received.
367 * \param code The RPC code.
368 * \param setreturn The function to encode return.
369 */
370 void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
371 TVMArgs args = RecvPackedSeq();
372 if (code == RPCCode::kException) {
373 // switch to the state before sending exception.
374 this->SwitchToState(kRecvPacketNumBytes);
375 std::string msg = args[0];
376 LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg;
377 }
378
379 ICHECK(setreturn != nullptr) << "fsetreturn not available";
380 setreturn(args);
381
382 this->SwitchToState(kReturnReceived);
383 }
384
385 void HandleSyscall(RPCCode code);
386
387 void HandleCopyFromRemote() {
388 DLTensor* arr = RPCReference::ReceiveDLTensor(this);
389 uint64_t data_bytes;
390 this->Read(&data_bytes);
391 size_t elem_bytes = (arr->dtype.bits * arr->dtype.lanes + 7) / 8;
392 auto* sess = GetServingSession();
393 // Return Copy Ack with the given data
394 auto fcopyack = [this](char* dptr, size_t num_bytes) {
395 RPCCode code = RPCCode::kCopyAck;
396 uint64_t packet_nbytes = sizeof(code) + num_bytes;
397
398 this->Write(packet_nbytes);
399 this->Write(code);
400 this->WriteArray(dptr, num_bytes);
401 this->SwitchToState(kRecvPacketNumBytes);
402 };
403
404 // When session is local, we can directly treat handle
405 // as the cpu pointer without allocating a temp space.
406 if (arr->device.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) {
407 char* data_ptr = reinterpret_cast<char*>(arr->data) + arr->byte_offset;
408 fcopyack(data_ptr, data_bytes);
409 } else {
410 char* temp_data = this->ArenaAlloc<char>(data_bytes);
411 auto on_copy_complete = [this, elem_bytes, data_bytes, temp_data, fcopyack](RPCCode status,
412 TVMArgs args) {
413 if (status == RPCCode::kException) {
414 this->ReturnException(args.values[0].v_str);
415 this->SwitchToState(kRecvPacketNumBytes);
416 } else {
417 // endian aware handling
418 if (!DMLC_IO_NO_ENDIAN_SWAP) {
419 dmlc::ByteSwap(temp_data, elem_bytes, data_bytes / elem_bytes);
420 }
421 fcopyack(temp_data, data_bytes);
422 }
423 };
424
425 this->SwitchToState(kWaitForAsyncCallback);
426 sess->AsyncCopyFromRemote(arr, static_cast<void*>(temp_data), data_bytes, on_copy_complete);
427 }
428 }
429
430 void HandleCopyToRemote() {
431 DLTensor* arr = RPCReference::ReceiveDLTensor(this);
432 uint64_t data_bytes;
433 this->Read(&data_bytes);
434 size_t elem_bytes = (arr->dtype.bits * arr->dtype.lanes + 7) / 8;
435 auto* sess = GetServingSession();
436
437 // When session is local, we can directly treat handle
438 // as the cpu pointer without allocating a temp space.
439 if (arr->device.device_type == kDLCPU && sess->IsLocalSession()) {
440 char* dptr = reinterpret_cast<char*>(arr->data) + arr->byte_offset;
441 this->ReadArray(dptr, data_bytes);
442
443 if (!DMLC_IO_NO_ENDIAN_SWAP) {
444 dmlc::ByteSwap(dptr, elem_bytes, data_bytes / elem_bytes);
445 }
446 this->ReturnVoid();
447 this->SwitchToState(kRecvPacketNumBytes);
448 } else {
449 char* temp_data = this->ArenaAlloc<char>(data_bytes);
450 this->ReadArray(temp_data, data_bytes);
451
452 if (!DMLC_IO_NO_ENDIAN_SWAP) {
453 dmlc::ByteSwap(temp_data, elem_bytes, data_bytes / elem_bytes);
454 }
455
456 auto on_copy_complete = [this](RPCCode status, TVMArgs args) {
457 if (status == RPCCode::kException) {
458 this->ReturnException(args.values[0].v_str);
459 this->SwitchToState(kRecvPacketNumBytes);
460 } else {
461 this->ReturnVoid();
462 this->SwitchToState(kRecvPacketNumBytes);
463 }
464 };
465
466 this->SwitchToState(kWaitForAsyncCallback);
467 sess->AsyncCopyToRemote(static_cast<void*>(temp_data), arr, data_bytes, on_copy_complete);
468 }
469 }
470
471 // Handle for packed call.
472 void HandleNormalCallFunc() {
473 uint64_t call_handle;
474
475 this->Read(&call_handle);
476 TVMArgs args = RecvPackedSeq();
477
478 this->SwitchToState(kWaitForAsyncCallback);
479 GetServingSession()->AsyncCallFunc(
480 reinterpret_cast<void*>(call_handle), args.values, args.type_codes, args.size(),
481 [this](RPCCode status, TVMArgs args) {
482 if (status == RPCCode::kException) {
483 this->ReturnException(args.values[0].v_str);
484 } else {
485 ValidateArguments(args.values, args.type_codes, args.size());
486 this->ReturnPackedSeq(args);
487 }
488 this->SwitchToState(kRecvPacketNumBytes);
489 });
490 }
491
492 void HandleInitServer() {
493 std::string client_protocol_ver;
494
495 uint64_t len;
496 this->Read(&len);
497 client_protocol_ver.resize(len);
498 this->Read(dmlc::BeginPtr(client_protocol_ver), len);
499
500 TVMArgs args = RecvPackedSeq();
501
502 try {
503 ICHECK(serving_session_ == nullptr) << "Server has already been initialized";
504
505 std::string server_protocol_ver = kRPCProtocolVer;
506 ICHECK_EQ(client_protocol_ver, server_protocol_ver)
507 << "Server[" << name_ << "]: Client protocol version mismatch with the server "
508 << " server protocol=" << server_protocol_ver
509 << ", client protocol=" << client_protocol_ver;
510
511 std::string constructor_name;
512 TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0);
513
514 if (args.size() == 0) {
515 constructor_name = "rpc.LocalSession";
516 serving_session_ = std::make_shared<LocalSession>();
517 } else {
518 constructor_name = args[0].operator std::string();
519 constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1);
520 }
521
522 auto* fconstructor = Registry::Get(constructor_name);
523 ICHECK(fconstructor != nullptr) << " Cannot find session constructor " << constructor_name;
524 TVMRetValue con_ret;
525
526 try {
527 fconstructor->CallPacked(constructor_args, &con_ret);
528 } catch (const Error& e) {
529 LOG(FATAL) << "Server[" << name_ << "]:"
530 << " Error caught from session constructor " << constructor_name << ":\n"
531 << e.what();
532 }
533
534 ICHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
535 << "Server[" << name_ << "]:"
536 << " Constructor " << constructor_name << " need to return an RPCModule";
537 Module mod = con_ret;
538 std::string tkey = mod->type_key();
539 ICHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule";
540 serving_session_ = RPCModuleGetSession(mod);
541 this->ReturnVoid();
542 } catch (const std::exception& e) {
543 this->ReturnException(e.what());
544 }
545
546 this->SwitchToState(kRecvPacketNumBytes);
547 }
548
549 void HandleSyscallStreamSync() {
550 TVMArgs args = RecvPackedSeq();
551 try {
552 Device dev = args[0];
553 TVMStreamHandle handle = args[1];
554
555 this->SwitchToState(kWaitForAsyncCallback);
556 GetServingSession()->AsyncStreamWait(dev, handle, [this](RPCCode status, TVMArgs args) {
557 if (status == RPCCode::kException) {
558 this->ReturnException(args.values[0].v_str);
559 } else {
560 this->ReturnVoid();
561 }
562 this->SwitchToState(kRecvPacketNumBytes);
563 });
564 } catch (const std::exception& e) {
565 this->ReturnException(e.what());
566 this->SwitchToState(kRecvPacketNumBytes);
567 }
568 }
569
570 // Handler for special syscalls that have a specific RPCCode.
571 template <typename F>
572 void SysCallHandler(F f) {
573 TVMArgs args = RecvPackedSeq();
574 try {
575 TVMRetValue rv;
576 f(GetServingSession(), args, &rv);
577 TVMValue ret_value;
578 int ret_tcode;
579 TVMArgsSetter setter(&ret_value, &ret_tcode);
580 setter(0, rv);
581
582 this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1));
583 } catch (const std::exception& e) {
584 this->ReturnException(e.what());
585 }
586 this->SwitchToState(kRecvPacketNumBytes);
587 }
588
589 private:
590 RPCSession* GetServingSession() const {
591 ICHECK(serving_session_ != nullptr)
592 << "Need to call InitRemoteSession first before any further actions";
593 ICHECK(!serving_session_->IsAsync() || async_server_mode_)
594 << "Cannot host an async session in a non-Event driven server";
595
596 return serving_session_.get();
597 }
598 // Utility functions
599 // Internal read function, update pending_request_bytes_
600 size_t Read(void* data, size_t size) final {
601 ICHECK_LE(size, pending_request_bytes_);
602 reader_->Read(data, size);
603 pending_request_bytes_ -= size;
604 return size;
605 }
606 // wriite the data to the channel.
607 void Write(const void* data, size_t size) final { writer_->Write(data, size); }
608 // Number of pending bytes requests
609 size_t pending_request_bytes_{0};
610 // The ring buffer to read data from.
611 support::RingBuffer* reader_;
612 // The ringr buffer to write reply to.
613 support::RingBuffer* writer_;
614 // The session used to serve the RPC requests.
615 std::shared_ptr<RPCSession> serving_session_;
616 // Name of endpoint.
617 std::string name_;
618 // remote key
619 std::string* remote_key_;
620 // function to flush the writer.
621 std::function<void()> flush_writer_;
622};
623
624RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) {
625 RPCCode code = RPCCode::kCallFunc;
626
627 CHECK(channel_) << "Expected connection to server " << name_
628 << " to be active, but the connection was previously closed";
629 while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) {
630 while (writer_.bytes_available() != 0) {
631 writer_.ReadWithCallback(
632 [this](const void* data, size_t size) { return channel_->Send(data, size); },
633 writer_.bytes_available());
634 }
635 size_t bytes_needed = handler_->BytesNeeded();
636 if (bytes_needed != 0) {
637 size_t n = reader_.WriteWithCallback(
638 [this](void* data, size_t size) { return channel_->Recv(data, size); }, bytes_needed);
639 if (n == 0) {
640 if (handler_->CanCleanShutdown()) {
641 return RPCCode::kShutdown;
642 } else {
643 LOG(FATAL) << "Channel closes before we get needed bytes";
644 }
645 }
646 }
647 code = handler_->HandleNextEvent(client_mode, false, setreturn);
648 }
649 return code;
650}
651
652void RPCEndpoint::Init() {
653 // callback to flush the writer.
654 auto flush_writer = [this]() {
655 while (writer_.bytes_available() != 0) {
656 size_t n = writer_.ReadWithCallback(
657 [this](const void* data, size_t size) { return channel_->Send(data, size); },
658 writer_.bytes_available());
659 if (n == 0) break;
660 }
661 };
662
663 // Event handler
664 handler_ = std::make_shared<EventHandler>(&reader_, &writer_, name_, &remote_key_, flush_writer);
665
666 // Quick function to for syscall remote.
667 syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) {
668 std::lock_guard<std::mutex> lock(mutex_);
669 RPCCode code = static_cast<RPCCode>(all_args[0].operator int());
670 TVMArgs args(all_args.values + 1, all_args.type_codes + 1, all_args.num_args - 1);
671
672 uint64_t packet_nbytes = sizeof(code) + handler_->PackedSeqGetNumBytes(
673 args.values, args.type_codes, args.num_args, true);
674
675 // All packet begins with packet nbytes
676 handler_->Write(packet_nbytes);
677 handler_->Write(code);
678 handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
679
680 code = HandleUntilReturnEvent(true, [rv](TVMArgs args) {
681 ICHECK_EQ(args.size(), 1);
682 *rv = args[0];
683 });
684 ICHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
685 });
686}
687
688/*!
689 * \brief Create a new RPCEndpoint instance.
690 * \param channel RPCChannel used to communicate.
691 * \param name Name of this session, used to identify log messages from this RPCEndpoint instance.
692 * \param remote_key The remote key reported during protocol initialization, or "%toinit" if the
693 * RPCEndpoint should handle this phase of the protocol for you. Some servers may prefer to access
694 * parts of the key to modify their behavior.
695 * \param fcleanup The cleanup Packed function.
696 */
697std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(std::unique_ptr<RPCChannel> channel,
698 std::string name, std::string remote_key,
699 TypedPackedFunc<void()> fcleanup) {
700 std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>();
701 endpt->channel_ = std::move(channel);
702 endpt->name_ = std::move(name);
703 endpt->remote_key_ = std::move(remote_key);
704 endpt->fcleanup_ = fcleanup;
705 endpt->Init();
706 return endpt;
707}
708
709RPCEndpoint::~RPCEndpoint() { this->Shutdown(); }
710
711void RPCEndpoint::Shutdown() {
712 if (channel_ != nullptr) {
713 RPCCode code = RPCCode::kShutdown;
714 uint64_t packet_nbytes = sizeof(code);
715
716 handler_->Write(packet_nbytes);
717 handler_->Write(code);
718
719 // flush all writing buffer to output channel.
720 try {
721 while (writer_.bytes_available() != 0) {
722 size_t n = writer_.ReadWithCallback(
723 [this](const void* data, size_t size) { return channel_->Send(data, size); },
724 writer_.bytes_available());
725 if (n == 0) break;
726 }
727 } catch (const Error& e) {
728 }
729 channel_.reset(nullptr);
730 }
731}
732
733void RPCEndpoint::ServerLoop() {
734 if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
735 (*f)();
736 }
737 TVMRetValue rv;
738 ICHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown);
739 if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
740 (*f)();
741 }
742 channel_.reset(nullptr);
743 if (fcleanup_ != nullptr) fcleanup_();
744}
745
746int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) {
747 RPCCode code = RPCCode::kNone;
748 if (in_bytes.length() != 0) {
749 reader_.Write(in_bytes.c_str(), in_bytes.length());
750 code = handler_->HandleNextEvent(false, true, [](TVMArgs) {});
751 }
752 if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
753 writer_.ReadWithCallback(
754 [this](const void* data, size_t size) { return channel_->Send(data, size); },
755 writer_.bytes_available());
756 }
757 ICHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
758 if (code == RPCCode::kShutdown) return 0;
759 if (writer_.bytes_available() != 0) return 2;
760 return 1;
761}
762
763void RPCEndpoint::InitRemoteSession(TVMArgs args) {
764 std::lock_guard<std::mutex> lock(mutex_);
765 RPCCode code = RPCCode::kInitServer;
766 std::string protocol_ver = kRPCProtocolVer;
767 uint64_t length = protocol_ver.length();
768
769 uint64_t packet_nbytes =
770 sizeof(code) + sizeof(length) + length +
771 handler_->PackedSeqGetNumBytes(args.values, args.type_codes, args.num_args, true);
772
773 // All packet begins with packet nbytes
774 handler_->Write(packet_nbytes);
775 handler_->Write(code);
776 handler_->Write(length);
777 handler_->WriteArray(protocol_ver.data(), length);
778 handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
779
780 code = HandleUntilReturnEvent(true, [](TVMArgs args) {});
781 ICHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
782}
783
784// Get remote function with name
785void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_values,
786 const int* arg_type_codes, int num_args,
787 RPCSession::FEncodeReturn encode_return) {
788 std::lock_guard<std::mutex> lock(mutex_);
789
790 handler_->ValidateArguments(arg_values, arg_type_codes, num_args);
791 RPCCode code = RPCCode::kCallFunc;
792 uint64_t handle = reinterpret_cast<uint64_t>(h);
793
794 uint64_t packet_nbytes =
795 sizeof(code) + sizeof(handle) +
796 handler_->PackedSeqGetNumBytes(arg_values, arg_type_codes, num_args, true);
797
798 handler_->Write(packet_nbytes);
799 handler_->Write(code);
800 handler_->Write(handle);
801 handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true);
802
803 code = HandleUntilReturnEvent(true, encode_return);
804 ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code);
805}
806
807void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) {
808 std::lock_guard<std::mutex> lock(mutex_);
809 RPCCode code = RPCCode::kCopyToRemote;
810
811 uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
812 ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes)
813 << "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset
814 << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";
815
816 uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes);
817 uint64_t packet_nbytes = overhead + nbytes;
818
819 handler_->Write(packet_nbytes);
820 handler_->Write(code);
821 RPCReference::SendDLTensor(handler_, to);
822 handler_->Write(nbytes);
823 handler_->WriteArray(reinterpret_cast<char*>(from_bytes), nbytes);
824 ICHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn);
825}
826
827void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) {
828 std::lock_guard<std::mutex> lock(mutex_);
829 RPCCode code = RPCCode::kCopyFromRemote;
830
831 uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*from));
832 ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes)
833 << "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset
834 << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";
835
836 uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes);
837 uint64_t packet_nbytes = overhead;
838
839 handler_->Write(packet_nbytes);
840 handler_->Write(code);
841 RPCReference::SendDLTensor(handler_, from);
842 handler_->Write(nbytes);
843 ICHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck);
844
845 handler_->ReadArray(reinterpret_cast<char*>(to_bytes), nbytes);
846 handler_->FinishCopyAck();
847}
848
849// SysCallEventHandler functions
850void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
851 std::string name = args[0];
852 *rv = handler->GetFunction(name);
853}
854
855void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
856 void* handle = args[0];
857 int type_code = args[1];
858 handler->FreeHandle(handle, type_code);
859}
860
861void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
862 Device dev = args[0];
863 handler->GetDeviceAPI(dev)->SetDevice(dev);
864}
865
866void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
867 Device dev = args[0];
868 DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
869 if (kind == kExist) {
870 DeviceAPI* api = handler->GetDeviceAPI(dev, true);
871 if (api != nullptr) {
872 api->GetAttr(dev, kind, rv);
873 } else {
874 *rv = 0;
875 }
876 } else {
877 handler->GetDeviceAPI(dev)->GetAttr(dev, static_cast<DeviceAttrKind>(kind), rv);
878 }
879}
880
881void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
882 Device dev = args[0];
883 uint64_t nbytes = args[1];
884 uint64_t alignment = args[2];
885 DLDataType type_hint = args[3];
886 void* data = handler->GetDeviceAPI(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint);
887 *rv = data;
888}
889
890void RPCDevAllocDataWithScope(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
891 DLTensor* arr = args[0];
892 Device dev = arr->device;
893 int ndim = arr->ndim;
894 int64_t* shape = arr->shape;
895 DLDataType dtype = arr->dtype;
896 int tcode = args[1].type_code();
897 Optional<String> mem_scope = NullOpt;
898 if (tcode == kTVMStr) {
899 mem_scope = args[1].operator String();
900 } else {
901 ICHECK_EQ(tcode, kTVMNullptr);
902 }
903 void* data = handler->GetDeviceAPI(dev)->AllocDataSpace(dev, ndim, shape, dtype, mem_scope);
904 *rv = data;
905}
906
907void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
908 Device dev = args[0];
909 void* ptr = args[1];
910 handler->GetDeviceAPI(dev)->FreeDataSpace(dev, ptr);
911}
912
913void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
914 DLTensor* from = args[0];
915 DLTensor* to = args[1];
916 TVMStreamHandle stream = args[2];
917
918 Device dev = from->device;
919 if (dev.device_type == kDLCPU) {
920 dev = to->device;
921 } else {
922 ICHECK(to->device.device_type == kDLCPU || to->device.device_type == from->device.device_type)
923 << "Can not copy across different dev types directly";
924 }
925 handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream);
926}
927
928void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
929 Device dev = args[0];
930 void* data = handler->GetDeviceAPI(dev)->CreateStream(dev);
931 *rv = data;
932}
933
934void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
935 Device dev = args[0];
936 TVMStreamHandle stream = args[1];
937 handler->GetDeviceAPI(dev)->FreeStream(dev, stream);
938}
939
940void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
941 Device dev = args[0];
942 TVMStreamHandle stream = args[1];
943 handler->GetDeviceAPI(dev)->SetStream(dev, stream);
944}
945
946void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
947 // Event handler sit at clean state at this point.
948 switch (code) {
949 // system functions
950 case RPCCode::kFreeHandle:
951 SysCallHandler(RPCFreeHandle);
952 break;
953 case RPCCode::kGetGlobalFunc:
954 SysCallHandler(RPCGetGlobalFunc);
955 break;
956 case RPCCode::kDevSetDevice:
957 SysCallHandler(RPCDevSetDevice);
958 break;
959 case RPCCode::kDevGetAttr:
960 SysCallHandler(RPCDevGetAttr);
961 break;
962 case RPCCode::kDevAllocData:
963 SysCallHandler(RPCDevAllocData);
964 break;
965 case RPCCode::kDevAllocDataWithScope:
966 SysCallHandler(RPCDevAllocDataWithScope);
967 break;
968 case RPCCode::kDevFreeData:
969 SysCallHandler(RPCDevFreeData);
970 break;
971 case RPCCode::kDevCreateStream:
972 SysCallHandler(RPCDevCreateStream);
973 break;
974 case RPCCode::kDevFreeStream:
975 SysCallHandler(RPCDevFreeStream);
976 break;
977 case RPCCode::kDevStreamSync:
978 this->HandleSyscallStreamSync();
979 break;
980 case RPCCode::kDevSetStream:
981 SysCallHandler(RPCDevSetStream);
982 break;
983 case RPCCode::kCopyAmongRemote:
984 SysCallHandler(RPCCopyAmongRemote);
985 break;
986 default:
987 LOG(FATAL) << "Unknown event " << static_cast<int>(code);
988 }
989
990 if (state_ != kWaitForAsyncCallback) {
991 ICHECK_EQ(state_, kRecvPacketNumBytes);
992 }
993}
994
995/*!
996 * \brief RPC client session that proxies all calls to an endpoint.
997 */
998class RPCClientSession : public RPCSession, public DeviceAPI {
999 public:
1000 /*!
1001 * \brief param endpoint The client endpoint of the session.
1002 */
1003 explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {}
1004
1005 // function overrides
1006 PackedFuncHandle GetFunction(const std::string& name) final {
1007 return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name);
1008 }
1009
1010 void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
1011 int num_args, const FEncodeReturn& fencode_return) final {
1012 endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return);
1013 }
1014
1015 void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
1016 RPCCode code = RPCCode::kCopyToRemote;
1017 uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes);
1018 uint64_t rpc_max_size = GetRPCMaxTransferSize();
1019 ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!";
1020 const uint64_t block_size = rpc_max_size - overhead;
1021 uint64_t block_count = 0;
1022 const uint64_t num_blocks = nbytes / block_size;
1023 void* from_bytes;
1024
1025 for (block_count = 0; block_count < num_blocks; block_count++) {
1026 remote_to->byte_offset = block_count * block_size;
1027 from_bytes = reinterpret_cast<void*>(
1028 (reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
1029 endpoint_->CopyToRemote(from_bytes, remote_to, block_size);
1030 }
1031
1032 const uint64_t remainder_bytes = nbytes % block_size;
1033 if (remainder_bytes != 0) {
1034 remote_to->byte_offset = block_count * block_size;
1035 from_bytes = reinterpret_cast<void*>(
1036 (reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
1037 endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes);
1038 }
1039 }
1040
1041 void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
1042 RPCCode code = RPCCode::kCopyFromRemote;
1043 uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes);
1044 uint64_t rpc_max_size = GetRPCMaxTransferSize();
1045 ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!";
1046 const uint64_t block_size = rpc_max_size - overhead;
1047 uint64_t block_count = 0;
1048 const uint64_t num_blocks = nbytes / block_size;
1049 void* to_bytes;
1050
1051 for (block_count = 0; block_count < num_blocks; block_count++) {
1052 remote_from->byte_offset = block_count * block_size;
1053 to_bytes = reinterpret_cast<void*>(
1054 (reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
1055 endpoint_->CopyFromRemote(remote_from, to_bytes, block_size);
1056 }
1057
1058 const uint64_t remainder_bytes = nbytes % block_size;
1059 if (remainder_bytes != 0) {
1060 remote_from->byte_offset = block_count * block_size;
1061 to_bytes = reinterpret_cast<void*>(
1062 (reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
1063 endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes);
1064 }
1065 }
1066
1067 void FreeHandle(void* handle, int type_code) final {
1068 endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code);
1069 }
1070
1071 void SetDevice(Device dev) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, dev); }
1072
1073 void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final {
1074 if (dev.device_type == kDLCPU && kind == kExist) {
1075 // cpu always exists.
1076 *rv = 1;
1077 } else {
1078 *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, dev, static_cast<int>(kind));
1079 }
1080 }
1081
1082 void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
1083 return endpoint_->SysCallRemote(RPCCode::kDevAllocData, dev, nbytes, alignment, type_hint);
1084 }
1085
1086 void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
1087 Optional<String> mem_scope) final {
1088 DLTensor temp;
1089 temp.data = nullptr;
1090 temp.device = dev;
1091 temp.ndim = ndim;
1092 temp.dtype = dtype;
1093 temp.shape = const_cast<int64_t*>(shape);
1094 temp.strides = nullptr;
1095 temp.byte_offset = 0;
1096 if (mem_scope.defined()) {
1097 return endpoint_->SysCallRemote(RPCCode::kDevAllocDataWithScope, &temp,
1098 static_cast<std::string>(mem_scope.value()));
1099 } else {
1100 return endpoint_->SysCallRemote(RPCCode::kDevAllocDataWithScope, &temp, nullptr);
1101 }
1102 }
1103
1104 void FreeDataSpace(Device dev, void* ptr) final {
1105 endpoint_->SysCallRemote(RPCCode::kDevFreeData, dev, ptr);
1106 }
1107
1108 void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final {
1109 endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream);
1110 }
1111
1112 TVMStreamHandle CreateStream(Device dev) final {
1113 return endpoint_->SysCallRemote(RPCCode::kDevCreateStream, dev);
1114 }
1115
1116 void FreeStream(Device dev, TVMStreamHandle stream) final {
1117 endpoint_->SysCallRemote(RPCCode::kDevFreeStream, dev, stream);
1118 }
1119
1120 void StreamSync(Device dev, TVMStreamHandle stream) final {
1121 endpoint_->SysCallRemote(RPCCode::kDevStreamSync, dev, stream);
1122 }
1123
1124 void SetStream(Device dev, TVMStreamHandle stream) final {
1125 endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream);
1126 }
1127
1128 DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; }
1129
1130 bool IsLocalSession() const final { return false; }
1131
1132 void Shutdown() final { endpoint_->Shutdown(); }
1133
1134 private:
1135 uint64_t GetRPCMaxTransferSize() {
1136 if (rpc_chunk_max_size_bytes_ > 0) {
1137 return (uint64_t)rpc_chunk_max_size_bytes_;
1138 }
1139
1140 PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize");
1141 if (rpc_func == nullptr) {
1142 rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault;
1143 } else {
1144 CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) {
1145 // Use args[1] as return value, args[0] is tcode
1146 // Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
1147 rpc_chunk_max_size_bytes_ = (int64_t)args[1];
1148 ICHECK_GT(rpc_chunk_max_size_bytes_, 0)
1149 << "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_
1150 << ")";
1151 });
1152 }
1153 return (uint64_t)rpc_chunk_max_size_bytes_;
1154 }
1155
1156 std::shared_ptr<RPCEndpoint> endpoint_;
1157 int64_t rpc_chunk_max_size_bytes_ = -1;
1158};
1159
1160std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
1161 return std::make_shared<RPCClientSession>(endpoint);
1162}
1163
1164uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) {
1165 uint64_t shape_bytes = tensor->ndim * sizeof(int64_t);
1166 uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<uint8_t*>(tensor->data));
1167 uint64_t overhead = sizeof(code) + sizeof(to_data) + sizeof(tensor->device) +
1168 sizeof(tensor->ndim) + sizeof(tensor->dtype) + sizeof(tensor->byte_offset) +
1169 shape_bytes + sizeof(nbytes);
1170 return overhead;
1171}
1172
1173} // namespace runtime
1174} // namespace tvm
1175