1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file rpc_endpoint.cc
 * \brief RPC endpoint implementing the protocol for remote function calls.
 */
24 | #include "rpc_endpoint.h" |
25 | |
26 | #include <tvm/runtime/c_runtime_api.h> |
27 | #include <tvm/runtime/device_api.h> |
28 | #include <tvm/runtime/packed_func.h> |
29 | #include <tvm/runtime/registry.h> |
30 | #include <tvm/runtime/serializer.h> |
31 | |
32 | #include <algorithm> |
33 | #include <array> |
34 | #include <chrono> |
35 | #include <cmath> |
36 | #include <memory> |
37 | #include <string> |
38 | #include <utility> |
39 | #include <vector> |
40 | |
41 | #include "../../support/arena.h" |
42 | #include "../../support/ring_buffer.h" |
43 | #include "../object_internal.h" |
44 | #include "rpc_local_session.h" |
45 | |
46 | namespace tvm { |
47 | namespace runtime { |
48 | |
49 | /*! |
50 | * Event-driven state-machine based handlers for RPCEndpoint. |
51 | * |
52 | * Key functions: |
53 | * |
54 | * - SendPackedSeq: send the arguments over to the peer |
55 | * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol). |
56 | */ |
57 | class RPCEndpoint::EventHandler : public dmlc::Stream { |
58 | public: |
59 | EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name, |
60 | std::string* remote_key, std::function<void()> flush_writer) |
61 | : reader_(reader), |
62 | writer_(writer), |
63 | name_(name), |
64 | remote_key_(remote_key), |
65 | flush_writer_(flush_writer) { |
66 | this->Clear(); |
67 | |
68 | if (*remote_key == "%toinit" ) { |
69 | state_ = kInitHeader; |
70 | remote_key_->resize(0); |
71 | pending_request_bytes_ = sizeof(int32_t); |
72 | } |
73 | } |
74 | |
75 | /*! |
76 | * \brief Bytes needed to fulfill current request |
77 | */ |
78 | size_t BytesNeeded() const { |
79 | if (reader_->bytes_available() < pending_request_bytes_) { |
80 | return pending_request_bytes_ - reader_->bytes_available(); |
81 | } else { |
82 | return 0; |
83 | } |
84 | } |
85 | |
86 | /*! |
87 | * \brief Request number of bytes from the reader. |
88 | * \param nbytes The number of bytes |
89 | */ |
90 | void RequestBytes(size_t nbytes) { |
91 | pending_request_bytes_ += nbytes; |
92 | reader_->Reserve(pending_request_bytes_); |
93 | } |
94 | |
95 | /*! \return Whether we are ready to handle next request. */ |
96 | bool Ready() const { return reader_->bytes_available() >= pending_request_bytes_; } |
97 | |
98 | /*! \return Whether we can perform a clean shutdown */ |
99 | bool CanCleanShutdown() const { return state_ == kRecvPacketNumBytes; } |
100 | |
101 | /*! \brief Finish the copy ack stage. */ |
102 | void FinishCopyAck() { this->SwitchToState(kRecvPacketNumBytes); } |
103 | |
104 | /*! |
105 | * \brief Enter the io loop until the next event. |
106 | * \param client_mode Whether we are in the client. |
107 | * \param async_server_mode Whether we are in the async server mode. |
108 | * \param setreturn The function to set the return value encoding. |
109 | * \return The function to set return values when there is a return event. |
110 | */ |
111 | RPCCode HandleNextEvent(bool client_mode, bool async_server_mode, |
112 | RPCSession::FEncodeReturn setreturn) { |
113 | std::swap(client_mode_, client_mode); |
114 | std::swap(async_server_mode_, async_server_mode); |
115 | |
116 | RPCCode status = RPCCode::kNone; |
117 | |
118 | while (status == RPCCode::kNone && state_ != kWaitForAsyncCallback && this->Ready()) { |
119 | switch (state_) { |
120 | case kInitHeader: |
121 | HandleInitHeader(); |
122 | break; |
123 | case kRecvPacketNumBytes: { |
124 | uint64_t packet_nbytes; |
125 | ICHECK(this->Read(&packet_nbytes)); |
126 | if (packet_nbytes != 0) { |
127 | this->SwitchToState(kProcessPacket); |
128 | this->RequestBytes(packet_nbytes); |
129 | } else { |
130 | this->SwitchToState(kRecvPacketNumBytes); |
131 | } |
132 | break; |
133 | } |
134 | case kProcessPacket: { |
135 | this->HandleProcessPacket(setreturn); |
136 | break; |
137 | } |
138 | case kWaitForAsyncCallback: { |
139 | break; |
140 | } |
141 | case kReturnReceived: { |
142 | this->SwitchToState(kRecvPacketNumBytes); |
143 | status = RPCCode::kReturn; |
144 | break; |
145 | } |
146 | case kCopyAckReceived: { |
147 | status = RPCCode::kCopyAck; |
148 | break; |
149 | } |
150 | case kShutdownReceived: { |
151 | status = RPCCode::kShutdown; |
152 | } |
153 | } |
154 | } |
155 | |
156 | std::swap(async_server_mode_, async_server_mode); |
157 | std::swap(client_mode_, client_mode); |
158 | return status; |
159 | } |
160 | |
161 | /*! \brief Clear all the states in the Handler.*/ |
162 | void Clear() { |
163 | state_ = kRecvPacketNumBytes; |
164 | pending_request_bytes_ = sizeof(uint64_t); |
165 | } |
166 | |
167 | /*! |
168 | * \brief Validate that the arguments can be sent through RPC. |
169 | * \param arg_values The argument values. |
170 | * \param type_codes The type codes. |
171 | */ |
172 | void ValidateArguments(const TVMValue* arg_values, const int* type_codes, int num_args) { |
173 | TVMArgs args(arg_values, type_codes, num_args); |
174 | for (int i = 0; i < num_args; ++i) { |
175 | int tcode = type_codes[i]; |
176 | if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { |
177 | LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " |
178 | << args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is not supported by RPC" ; |
179 | } else if (tcode == kDLDevice) { |
180 | DLDevice dev = args[i]; |
181 | ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel" ; |
182 | } |
183 | } |
184 | } |
185 | |
186 | void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { |
187 | LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); |
188 | } |
189 | |
190 | uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args, |
191 | bool client_mode) { |
192 | return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this); |
193 | } |
194 | |
195 | void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, |
196 | bool client_mode) { |
197 | RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this); |
198 | } |
199 | |
200 | // Endian aware IO handling |
201 | using Stream::Read; |
202 | using Stream::ReadArray; |
203 | using Stream::Write; |
204 | using Stream::WriteArray; |
205 | |
206 | void MessageStart(uint64_t packet_nbytes) { |
207 | // Unused here, implemented for microTVM framing layer. |
208 | } |
209 | |
210 | bool Read(RPCCode* code) { |
211 | int32_t cdata; |
212 | if (!this->Read(&cdata)) return false; |
213 | *code = static_cast<RPCCode>(cdata); |
214 | return true; |
215 | } |
216 | void Write(RPCCode code) { |
217 | int32_t cdata = static_cast<int>(code); |
218 | this->Write(cdata); |
219 | } |
220 | |
221 | void MessageDone() { |
222 | // Unused here, implemented for microTVM framing layer. |
223 | } |
224 | |
225 | template <typename T> |
226 | T* ArenaAlloc(int count) { |
227 | static_assert(std::is_pod<T>::value, "need to be trival" ); |
228 | return arena_.template allocate_<T>(count); |
229 | } |
230 | |
231 | protected: |
232 | enum State { |
233 | kInitHeader, |
234 | kRecvPacketNumBytes, |
235 | kProcessPacket, |
236 | kWaitForAsyncCallback, |
237 | kReturnReceived, |
238 | kCopyAckReceived, |
239 | kShutdownReceived |
240 | }; |
241 | // Current state; |
242 | State state_; |
243 | // Initialize remote header |
244 | int init_header_step_{0}; |
245 | // Whether current handler is client or server mode. |
246 | bool client_mode_{false}; |
247 | // Whether current handler is in the async server mode. |
248 | bool async_server_mode_{false}; |
249 | // Internal arena |
250 | support::Arena arena_; |
251 | |
252 | // State switcher |
253 | void SwitchToState(State state) { |
254 | // invariant |
255 | if (state != kCopyAckReceived) { |
256 | ICHECK_EQ(pending_request_bytes_, 0U) << "state=" << state; |
257 | } |
258 | // need to actively flush the writer |
259 | // so the data get pushed out. |
260 | if (state_ == kWaitForAsyncCallback) { |
261 | flush_writer_(); |
262 | } |
263 | state_ = state; |
264 | ICHECK(state != kInitHeader) << "cannot switch to init header" ; |
265 | if (state == kRecvPacketNumBytes) { |
266 | this->RequestBytes(sizeof(uint64_t)); |
267 | // recycle arena for the next session. |
268 | arena_.RecycleAll(); |
269 | } |
270 | } |
271 | |
272 | // handler for initial header read |
273 | void HandleInitHeader() { |
274 | if (init_header_step_ == 0) { |
275 | int32_t len; |
276 | this->Read(&len); |
277 | remote_key_->resize(len); |
278 | init_header_step_ = 1; |
279 | this->RequestBytes(len); |
280 | return; |
281 | } else { |
282 | ICHECK_EQ(init_header_step_, 1); |
283 | this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); |
284 | this->SwitchToState(kRecvPacketNumBytes); |
285 | } |
286 | } |
287 | |
288 | // Handler for read code. |
289 | void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) { |
290 | RPCCode code = RPCCode::kNone; |
291 | this->Read(&code); |
292 | |
293 | if (code >= RPCCode::kSyscallCodeStart) { |
294 | this->HandleSyscall(code); |
295 | } else { |
296 | switch (code) { |
297 | case RPCCode::kInitServer: { |
298 | this->HandleInitServer(); |
299 | break; |
300 | } |
301 | case RPCCode::kCallFunc: { |
302 | this->HandleNormalCallFunc(); |
303 | break; |
304 | } |
305 | case RPCCode::kCopyFromRemote: { |
306 | this->HandleCopyFromRemote(); |
307 | break; |
308 | } |
309 | case RPCCode::kCopyToRemote: { |
310 | this->HandleCopyToRemote(); |
311 | break; |
312 | } |
313 | case RPCCode::kException: |
314 | case RPCCode::kReturn: { |
315 | this->HandleReturn(code, setreturn); |
316 | break; |
317 | } |
318 | case RPCCode::kCopyAck: { |
319 | this->SwitchToState(kCopyAckReceived); |
320 | break; |
321 | } |
322 | case RPCCode::kShutdown: { |
323 | this->SwitchToState(kShutdownReceived); |
324 | break; |
325 | } |
326 | default: |
327 | LOG(FATAL) << "Unknown event " << static_cast<int>(code); |
328 | } |
329 | } |
330 | } |
331 | |
332 | /*! |
333 | * \brief Receive incoming packed seq from the stream. |
334 | * \return The received argments. |
335 | * \note The TVMArgs is available until we switchstate. |
336 | */ |
337 | TVMArgs RecvPackedSeq() { |
338 | TVMValue* values; |
339 | int* tcodes; |
340 | int num_args; |
341 | RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this); |
342 | return TVMArgs(values, tcodes, num_args); |
343 | } |
344 | |
345 | /*! |
346 | * \brief Return exception to the remote. |
347 | * \param err_msg The error message. |
348 | */ |
349 | void ReturnException(const char* err_msg) { RPCReference::ReturnException(err_msg, this); } |
350 | |
351 | /*! |
352 | * \brief Return nullptr to the remote. |
353 | * \param err_msg The error message. |
354 | */ |
355 | void ReturnVoid() { RPCReference::ReturnVoid(this); } |
356 | |
357 | /*! |
358 | * \brief Return a packed sequence to the remote. |
359 | * \param args The arguments. |
360 | */ |
361 | void ReturnPackedSeq(TVMArgs args) { |
362 | RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this); |
363 | } |
364 | |
365 | /*! |
366 | * \brief Handle the case when return/exception value is received. |
367 | * \param code The RPC code. |
368 | * \param setreturn The function to encode return. |
369 | */ |
370 | void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) { |
371 | TVMArgs args = RecvPackedSeq(); |
372 | if (code == RPCCode::kException) { |
373 | // switch to the state before sending exception. |
374 | this->SwitchToState(kRecvPacketNumBytes); |
375 | std::string msg = args[0]; |
376 | LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; |
377 | } |
378 | |
379 | ICHECK(setreturn != nullptr) << "fsetreturn not available" ; |
380 | setreturn(args); |
381 | |
382 | this->SwitchToState(kReturnReceived); |
383 | } |
384 | |
385 | void HandleSyscall(RPCCode code); |
386 | |
387 | void HandleCopyFromRemote() { |
388 | DLTensor* arr = RPCReference::ReceiveDLTensor(this); |
389 | uint64_t data_bytes; |
390 | this->Read(&data_bytes); |
391 | size_t elem_bytes = (arr->dtype.bits * arr->dtype.lanes + 7) / 8; |
392 | auto* sess = GetServingSession(); |
393 | // Return Copy Ack with the given data |
394 | auto fcopyack = [this](char* dptr, size_t num_bytes) { |
395 | RPCCode code = RPCCode::kCopyAck; |
396 | uint64_t packet_nbytes = sizeof(code) + num_bytes; |
397 | |
398 | this->Write(packet_nbytes); |
399 | this->Write(code); |
400 | this->WriteArray(dptr, num_bytes); |
401 | this->SwitchToState(kRecvPacketNumBytes); |
402 | }; |
403 | |
404 | // When session is local, we can directly treat handle |
405 | // as the cpu pointer without allocating a temp space. |
406 | if (arr->device.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) { |
407 | char* data_ptr = reinterpret_cast<char*>(arr->data) + arr->byte_offset; |
408 | fcopyack(data_ptr, data_bytes); |
409 | } else { |
410 | char* temp_data = this->ArenaAlloc<char>(data_bytes); |
411 | auto on_copy_complete = [this, elem_bytes, data_bytes, temp_data, fcopyack](RPCCode status, |
412 | TVMArgs args) { |
413 | if (status == RPCCode::kException) { |
414 | this->ReturnException(args.values[0].v_str); |
415 | this->SwitchToState(kRecvPacketNumBytes); |
416 | } else { |
417 | // endian aware handling |
418 | if (!DMLC_IO_NO_ENDIAN_SWAP) { |
419 | dmlc::ByteSwap(temp_data, elem_bytes, data_bytes / elem_bytes); |
420 | } |
421 | fcopyack(temp_data, data_bytes); |
422 | } |
423 | }; |
424 | |
425 | this->SwitchToState(kWaitForAsyncCallback); |
426 | sess->AsyncCopyFromRemote(arr, static_cast<void*>(temp_data), data_bytes, on_copy_complete); |
427 | } |
428 | } |
429 | |
430 | void HandleCopyToRemote() { |
431 | DLTensor* arr = RPCReference::ReceiveDLTensor(this); |
432 | uint64_t data_bytes; |
433 | this->Read(&data_bytes); |
434 | size_t elem_bytes = (arr->dtype.bits * arr->dtype.lanes + 7) / 8; |
435 | auto* sess = GetServingSession(); |
436 | |
437 | // When session is local, we can directly treat handle |
438 | // as the cpu pointer without allocating a temp space. |
439 | if (arr->device.device_type == kDLCPU && sess->IsLocalSession()) { |
440 | char* dptr = reinterpret_cast<char*>(arr->data) + arr->byte_offset; |
441 | this->ReadArray(dptr, data_bytes); |
442 | |
443 | if (!DMLC_IO_NO_ENDIAN_SWAP) { |
444 | dmlc::ByteSwap(dptr, elem_bytes, data_bytes / elem_bytes); |
445 | } |
446 | this->ReturnVoid(); |
447 | this->SwitchToState(kRecvPacketNumBytes); |
448 | } else { |
449 | char* temp_data = this->ArenaAlloc<char>(data_bytes); |
450 | this->ReadArray(temp_data, data_bytes); |
451 | |
452 | if (!DMLC_IO_NO_ENDIAN_SWAP) { |
453 | dmlc::ByteSwap(temp_data, elem_bytes, data_bytes / elem_bytes); |
454 | } |
455 | |
456 | auto on_copy_complete = [this](RPCCode status, TVMArgs args) { |
457 | if (status == RPCCode::kException) { |
458 | this->ReturnException(args.values[0].v_str); |
459 | this->SwitchToState(kRecvPacketNumBytes); |
460 | } else { |
461 | this->ReturnVoid(); |
462 | this->SwitchToState(kRecvPacketNumBytes); |
463 | } |
464 | }; |
465 | |
466 | this->SwitchToState(kWaitForAsyncCallback); |
467 | sess->AsyncCopyToRemote(static_cast<void*>(temp_data), arr, data_bytes, on_copy_complete); |
468 | } |
469 | } |
470 | |
471 | // Handle for packed call. |
472 | void HandleNormalCallFunc() { |
473 | uint64_t call_handle; |
474 | |
475 | this->Read(&call_handle); |
476 | TVMArgs args = RecvPackedSeq(); |
477 | |
478 | this->SwitchToState(kWaitForAsyncCallback); |
479 | GetServingSession()->AsyncCallFunc( |
480 | reinterpret_cast<void*>(call_handle), args.values, args.type_codes, args.size(), |
481 | [this](RPCCode status, TVMArgs args) { |
482 | if (status == RPCCode::kException) { |
483 | this->ReturnException(args.values[0].v_str); |
484 | } else { |
485 | ValidateArguments(args.values, args.type_codes, args.size()); |
486 | this->ReturnPackedSeq(args); |
487 | } |
488 | this->SwitchToState(kRecvPacketNumBytes); |
489 | }); |
490 | } |
491 | |
492 | void HandleInitServer() { |
493 | std::string client_protocol_ver; |
494 | |
495 | uint64_t len; |
496 | this->Read(&len); |
497 | client_protocol_ver.resize(len); |
498 | this->Read(dmlc::BeginPtr(client_protocol_ver), len); |
499 | |
500 | TVMArgs args = RecvPackedSeq(); |
501 | |
502 | try { |
503 | ICHECK(serving_session_ == nullptr) << "Server has already been initialized" ; |
504 | |
505 | std::string server_protocol_ver = kRPCProtocolVer; |
506 | ICHECK_EQ(client_protocol_ver, server_protocol_ver) |
507 | << "Server[" << name_ << "]: Client protocol version mismatch with the server " |
508 | << " server protocol=" << server_protocol_ver |
509 | << ", client protocol=" << client_protocol_ver; |
510 | |
511 | std::string constructor_name; |
512 | TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0); |
513 | |
514 | if (args.size() == 0) { |
515 | constructor_name = "rpc.LocalSession" ; |
516 | serving_session_ = std::make_shared<LocalSession>(); |
517 | } else { |
518 | constructor_name = args[0].operator std::string(); |
519 | constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1); |
520 | } |
521 | |
522 | auto* fconstructor = Registry::Get(constructor_name); |
523 | ICHECK(fconstructor != nullptr) << " Cannot find session constructor " << constructor_name; |
524 | TVMRetValue con_ret; |
525 | |
526 | try { |
527 | fconstructor->CallPacked(constructor_args, &con_ret); |
528 | } catch (const Error& e) { |
529 | LOG(FATAL) << "Server[" << name_ << "]:" |
530 | << " Error caught from session constructor " << constructor_name << ":\n" |
531 | << e.what(); |
532 | } |
533 | |
534 | ICHECK_EQ(con_ret.type_code(), kTVMModuleHandle) |
535 | << "Server[" << name_ << "]:" |
536 | << " Constructor " << constructor_name << " need to return an RPCModule" ; |
537 | Module mod = con_ret; |
538 | std::string tkey = mod->type_key(); |
539 | ICHECK_EQ(tkey, "rpc" ) << "Constructor " << constructor_name << " to return an RPCModule" ; |
540 | serving_session_ = RPCModuleGetSession(mod); |
541 | this->ReturnVoid(); |
542 | } catch (const std::exception& e) { |
543 | this->ReturnException(e.what()); |
544 | } |
545 | |
546 | this->SwitchToState(kRecvPacketNumBytes); |
547 | } |
548 | |
549 | void HandleSyscallStreamSync() { |
550 | TVMArgs args = RecvPackedSeq(); |
551 | try { |
552 | Device dev = args[0]; |
553 | TVMStreamHandle handle = args[1]; |
554 | |
555 | this->SwitchToState(kWaitForAsyncCallback); |
556 | GetServingSession()->AsyncStreamWait(dev, handle, [this](RPCCode status, TVMArgs args) { |
557 | if (status == RPCCode::kException) { |
558 | this->ReturnException(args.values[0].v_str); |
559 | } else { |
560 | this->ReturnVoid(); |
561 | } |
562 | this->SwitchToState(kRecvPacketNumBytes); |
563 | }); |
564 | } catch (const std::exception& e) { |
565 | this->ReturnException(e.what()); |
566 | this->SwitchToState(kRecvPacketNumBytes); |
567 | } |
568 | } |
569 | |
570 | // Handler for special syscalls that have a specific RPCCode. |
571 | template <typename F> |
572 | void SysCallHandler(F f) { |
573 | TVMArgs args = RecvPackedSeq(); |
574 | try { |
575 | TVMRetValue rv; |
576 | f(GetServingSession(), args, &rv); |
577 | TVMValue ret_value; |
578 | int ret_tcode; |
579 | TVMArgsSetter setter(&ret_value, &ret_tcode); |
580 | setter(0, rv); |
581 | |
582 | this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1)); |
583 | } catch (const std::exception& e) { |
584 | this->ReturnException(e.what()); |
585 | } |
586 | this->SwitchToState(kRecvPacketNumBytes); |
587 | } |
588 | |
589 | private: |
590 | RPCSession* GetServingSession() const { |
591 | ICHECK(serving_session_ != nullptr) |
592 | << "Need to call InitRemoteSession first before any further actions" ; |
593 | ICHECK(!serving_session_->IsAsync() || async_server_mode_) |
594 | << "Cannot host an async session in a non-Event driven server" ; |
595 | |
596 | return serving_session_.get(); |
597 | } |
598 | // Utility functions |
599 | // Internal read function, update pending_request_bytes_ |
600 | size_t Read(void* data, size_t size) final { |
601 | ICHECK_LE(size, pending_request_bytes_); |
602 | reader_->Read(data, size); |
603 | pending_request_bytes_ -= size; |
604 | return size; |
605 | } |
606 | // wriite the data to the channel. |
607 | void Write(const void* data, size_t size) final { writer_->Write(data, size); } |
608 | // Number of pending bytes requests |
609 | size_t pending_request_bytes_{0}; |
610 | // The ring buffer to read data from. |
611 | support::RingBuffer* reader_; |
612 | // The ringr buffer to write reply to. |
613 | support::RingBuffer* writer_; |
614 | // The session used to serve the RPC requests. |
615 | std::shared_ptr<RPCSession> serving_session_; |
616 | // Name of endpoint. |
617 | std::string name_; |
618 | // remote key |
619 | std::string* remote_key_; |
620 | // function to flush the writer. |
621 | std::function<void()> flush_writer_; |
622 | }; |
623 | |
624 | RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { |
625 | RPCCode code = RPCCode::kCallFunc; |
626 | |
627 | CHECK(channel_) << "Expected connection to server " << name_ |
628 | << " to be active, but the connection was previously closed" ; |
629 | while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) { |
630 | while (writer_.bytes_available() != 0) { |
631 | writer_.ReadWithCallback( |
632 | [this](const void* data, size_t size) { return channel_->Send(data, size); }, |
633 | writer_.bytes_available()); |
634 | } |
635 | size_t bytes_needed = handler_->BytesNeeded(); |
636 | if (bytes_needed != 0) { |
637 | size_t n = reader_.WriteWithCallback( |
638 | [this](void* data, size_t size) { return channel_->Recv(data, size); }, bytes_needed); |
639 | if (n == 0) { |
640 | if (handler_->CanCleanShutdown()) { |
641 | return RPCCode::kShutdown; |
642 | } else { |
643 | LOG(FATAL) << "Channel closes before we get needed bytes" ; |
644 | } |
645 | } |
646 | } |
647 | code = handler_->HandleNextEvent(client_mode, false, setreturn); |
648 | } |
649 | return code; |
650 | } |
651 | |
652 | void RPCEndpoint::Init() { |
653 | // callback to flush the writer. |
654 | auto flush_writer = [this]() { |
655 | while (writer_.bytes_available() != 0) { |
656 | size_t n = writer_.ReadWithCallback( |
657 | [this](const void* data, size_t size) { return channel_->Send(data, size); }, |
658 | writer_.bytes_available()); |
659 | if (n == 0) break; |
660 | } |
661 | }; |
662 | |
663 | // Event handler |
664 | handler_ = std::make_shared<EventHandler>(&reader_, &writer_, name_, &remote_key_, flush_writer); |
665 | |
666 | // Quick function to for syscall remote. |
667 | syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { |
668 | std::lock_guard<std::mutex> lock(mutex_); |
669 | RPCCode code = static_cast<RPCCode>(all_args[0].operator int()); |
670 | TVMArgs args(all_args.values + 1, all_args.type_codes + 1, all_args.num_args - 1); |
671 | |
672 | uint64_t packet_nbytes = sizeof(code) + handler_->PackedSeqGetNumBytes( |
673 | args.values, args.type_codes, args.num_args, true); |
674 | |
675 | // All packet begins with packet nbytes |
676 | handler_->Write(packet_nbytes); |
677 | handler_->Write(code); |
678 | handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); |
679 | |
680 | code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { |
681 | ICHECK_EQ(args.size(), 1); |
682 | *rv = args[0]; |
683 | }); |
684 | ICHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); |
685 | }); |
686 | } |
687 | |
688 | /*! |
689 | * \brief Create a new RPCEndpoint instance. |
690 | * \param channel RPCChannel used to communicate. |
691 | * \param name Name of this session, used to identify log messages from this RPCEndpoint instance. |
692 | * \param remote_key The remote key reported during protocol initialization, or "%toinit" if the |
693 | * RPCEndpoint should handle this phase of the protocol for you. Some servers may prefer to access |
694 | * parts of the key to modify their behavior. |
695 | * \param fcleanup The cleanup Packed function. |
696 | */ |
697 | std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(std::unique_ptr<RPCChannel> channel, |
698 | std::string name, std::string remote_key, |
699 | TypedPackedFunc<void()> fcleanup) { |
700 | std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>(); |
701 | endpt->channel_ = std::move(channel); |
702 | endpt->name_ = std::move(name); |
703 | endpt->remote_key_ = std::move(remote_key); |
704 | endpt->fcleanup_ = fcleanup; |
705 | endpt->Init(); |
706 | return endpt; |
707 | } |
708 | |
709 | RPCEndpoint::~RPCEndpoint() { this->Shutdown(); } |
710 | |
711 | void RPCEndpoint::Shutdown() { |
712 | if (channel_ != nullptr) { |
713 | RPCCode code = RPCCode::kShutdown; |
714 | uint64_t packet_nbytes = sizeof(code); |
715 | |
716 | handler_->Write(packet_nbytes); |
717 | handler_->Write(code); |
718 | |
719 | // flush all writing buffer to output channel. |
720 | try { |
721 | while (writer_.bytes_available() != 0) { |
722 | size_t n = writer_.ReadWithCallback( |
723 | [this](const void* data, size_t size) { return channel_->Send(data, size); }, |
724 | writer_.bytes_available()); |
725 | if (n == 0) break; |
726 | } |
727 | } catch (const Error& e) { |
728 | } |
729 | channel_.reset(nullptr); |
730 | } |
731 | } |
732 | |
733 | void RPCEndpoint::ServerLoop() { |
734 | if (const auto* f = Registry::Get("tvm.rpc.server.start" )) { |
735 | (*f)(); |
736 | } |
737 | TVMRetValue rv; |
738 | ICHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); |
739 | if (const auto* f = Registry::Get("tvm.rpc.server.shutdown" )) { |
740 | (*f)(); |
741 | } |
742 | channel_.reset(nullptr); |
743 | if (fcleanup_ != nullptr) fcleanup_(); |
744 | } |
745 | |
746 | int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { |
747 | RPCCode code = RPCCode::kNone; |
748 | if (in_bytes.length() != 0) { |
749 | reader_.Write(in_bytes.c_str(), in_bytes.length()); |
750 | code = handler_->HandleNextEvent(false, true, [](TVMArgs) {}); |
751 | } |
752 | if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { |
753 | writer_.ReadWithCallback( |
754 | [this](const void* data, size_t size) { return channel_->Send(data, size); }, |
755 | writer_.bytes_available()); |
756 | } |
757 | ICHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); |
758 | if (code == RPCCode::kShutdown) return 0; |
759 | if (writer_.bytes_available() != 0) return 2; |
760 | return 1; |
761 | } |
762 | |
763 | void RPCEndpoint::InitRemoteSession(TVMArgs args) { |
764 | std::lock_guard<std::mutex> lock(mutex_); |
765 | RPCCode code = RPCCode::kInitServer; |
766 | std::string protocol_ver = kRPCProtocolVer; |
767 | uint64_t length = protocol_ver.length(); |
768 | |
769 | uint64_t packet_nbytes = |
770 | sizeof(code) + sizeof(length) + length + |
771 | handler_->PackedSeqGetNumBytes(args.values, args.type_codes, args.num_args, true); |
772 | |
773 | // All packet begins with packet nbytes |
774 | handler_->Write(packet_nbytes); |
775 | handler_->Write(code); |
776 | handler_->Write(length); |
777 | handler_->WriteArray(protocol_ver.data(), length); |
778 | handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); |
779 | |
780 | code = HandleUntilReturnEvent(true, [](TVMArgs args) {}); |
781 | ICHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); |
782 | } |
783 | |
784 | // Get remote function with name |
785 | void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_values, |
786 | const int* arg_type_codes, int num_args, |
787 | RPCSession::FEncodeReturn encode_return) { |
788 | std::lock_guard<std::mutex> lock(mutex_); |
789 | |
790 | handler_->ValidateArguments(arg_values, arg_type_codes, num_args); |
791 | RPCCode code = RPCCode::kCallFunc; |
792 | uint64_t handle = reinterpret_cast<uint64_t>(h); |
793 | |
794 | uint64_t packet_nbytes = |
795 | sizeof(code) + sizeof(handle) + |
796 | handler_->PackedSeqGetNumBytes(arg_values, arg_type_codes, num_args, true); |
797 | |
798 | handler_->Write(packet_nbytes); |
799 | handler_->Write(code); |
800 | handler_->Write(handle); |
801 | handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true); |
802 | |
803 | code = HandleUntilReturnEvent(true, encode_return); |
804 | ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code); |
805 | } |
806 | |
807 | void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) { |
808 | std::lock_guard<std::mutex> lock(mutex_); |
809 | RPCCode code = RPCCode::kCopyToRemote; |
810 | |
811 | uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to)); |
812 | ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes) |
813 | << "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset |
814 | << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")" ; |
815 | |
816 | uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes); |
817 | uint64_t packet_nbytes = overhead + nbytes; |
818 | |
819 | handler_->Write(packet_nbytes); |
820 | handler_->Write(code); |
821 | RPCReference::SendDLTensor(handler_, to); |
822 | handler_->Write(nbytes); |
823 | handler_->WriteArray(reinterpret_cast<char*>(from_bytes), nbytes); |
824 | ICHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn); |
825 | } |
826 | |
827 | void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) { |
828 | std::lock_guard<std::mutex> lock(mutex_); |
829 | RPCCode code = RPCCode::kCopyFromRemote; |
830 | |
831 | uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*from)); |
832 | ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes) |
833 | << "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset |
834 | << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")" ; |
835 | |
836 | uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes); |
837 | uint64_t packet_nbytes = overhead; |
838 | |
839 | handler_->Write(packet_nbytes); |
840 | handler_->Write(code); |
841 | RPCReference::SendDLTensor(handler_, from); |
842 | handler_->Write(nbytes); |
843 | ICHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck); |
844 | |
845 | handler_->ReadArray(reinterpret_cast<char*>(to_bytes), nbytes); |
846 | handler_->FinishCopyAck(); |
847 | } |
848 | |
849 | // SysCallEventHandler functions |
850 | void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
851 | std::string name = args[0]; |
852 | *rv = handler->GetFunction(name); |
853 | } |
854 | |
855 | void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
856 | void* handle = args[0]; |
857 | int type_code = args[1]; |
858 | handler->FreeHandle(handle, type_code); |
859 | } |
860 | |
861 | void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
862 | Device dev = args[0]; |
863 | handler->GetDeviceAPI(dev)->SetDevice(dev); |
864 | } |
865 | |
866 | void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
867 | Device dev = args[0]; |
868 | DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int()); |
869 | if (kind == kExist) { |
870 | DeviceAPI* api = handler->GetDeviceAPI(dev, true); |
871 | if (api != nullptr) { |
872 | api->GetAttr(dev, kind, rv); |
873 | } else { |
874 | *rv = 0; |
875 | } |
876 | } else { |
877 | handler->GetDeviceAPI(dev)->GetAttr(dev, static_cast<DeviceAttrKind>(kind), rv); |
878 | } |
879 | } |
880 | |
881 | void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
882 | Device dev = args[0]; |
883 | uint64_t nbytes = args[1]; |
884 | uint64_t alignment = args[2]; |
885 | DLDataType type_hint = args[3]; |
886 | void* data = handler->GetDeviceAPI(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint); |
887 | *rv = data; |
888 | } |
889 | |
890 | void RPCDevAllocDataWithScope(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
891 | DLTensor* arr = args[0]; |
892 | Device dev = arr->device; |
893 | int ndim = arr->ndim; |
894 | int64_t* shape = arr->shape; |
895 | DLDataType dtype = arr->dtype; |
896 | int tcode = args[1].type_code(); |
897 | Optional<String> mem_scope = NullOpt; |
898 | if (tcode == kTVMStr) { |
899 | mem_scope = args[1].operator String(); |
900 | } else { |
901 | ICHECK_EQ(tcode, kTVMNullptr); |
902 | } |
903 | void* data = handler->GetDeviceAPI(dev)->AllocDataSpace(dev, ndim, shape, dtype, mem_scope); |
904 | *rv = data; |
905 | } |
906 | |
907 | void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
908 | Device dev = args[0]; |
909 | void* ptr = args[1]; |
910 | handler->GetDeviceAPI(dev)->FreeDataSpace(dev, ptr); |
911 | } |
912 | |
913 | void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
914 | DLTensor* from = args[0]; |
915 | DLTensor* to = args[1]; |
916 | TVMStreamHandle stream = args[2]; |
917 | |
918 | Device dev = from->device; |
919 | if (dev.device_type == kDLCPU) { |
920 | dev = to->device; |
921 | } else { |
922 | ICHECK(to->device.device_type == kDLCPU || to->device.device_type == from->device.device_type) |
923 | << "Can not copy across different dev types directly" ; |
924 | } |
925 | handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); |
926 | } |
927 | |
928 | void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
929 | Device dev = args[0]; |
930 | void* data = handler->GetDeviceAPI(dev)->CreateStream(dev); |
931 | *rv = data; |
932 | } |
933 | |
934 | void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
935 | Device dev = args[0]; |
936 | TVMStreamHandle stream = args[1]; |
937 | handler->GetDeviceAPI(dev)->FreeStream(dev, stream); |
938 | } |
939 | |
940 | void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { |
941 | Device dev = args[0]; |
942 | TVMStreamHandle stream = args[1]; |
943 | handler->GetDeviceAPI(dev)->SetStream(dev, stream); |
944 | } |
945 | |
946 | void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { |
947 | // Event handler sit at clean state at this point. |
948 | switch (code) { |
949 | // system functions |
950 | case RPCCode::kFreeHandle: |
951 | SysCallHandler(RPCFreeHandle); |
952 | break; |
953 | case RPCCode::kGetGlobalFunc: |
954 | SysCallHandler(RPCGetGlobalFunc); |
955 | break; |
956 | case RPCCode::kDevSetDevice: |
957 | SysCallHandler(RPCDevSetDevice); |
958 | break; |
959 | case RPCCode::kDevGetAttr: |
960 | SysCallHandler(RPCDevGetAttr); |
961 | break; |
962 | case RPCCode::kDevAllocData: |
963 | SysCallHandler(RPCDevAllocData); |
964 | break; |
965 | case RPCCode::kDevAllocDataWithScope: |
966 | SysCallHandler(RPCDevAllocDataWithScope); |
967 | break; |
968 | case RPCCode::kDevFreeData: |
969 | SysCallHandler(RPCDevFreeData); |
970 | break; |
971 | case RPCCode::kDevCreateStream: |
972 | SysCallHandler(RPCDevCreateStream); |
973 | break; |
974 | case RPCCode::kDevFreeStream: |
975 | SysCallHandler(RPCDevFreeStream); |
976 | break; |
977 | case RPCCode::kDevStreamSync: |
978 | this->HandleSyscallStreamSync(); |
979 | break; |
980 | case RPCCode::kDevSetStream: |
981 | SysCallHandler(RPCDevSetStream); |
982 | break; |
983 | case RPCCode::kCopyAmongRemote: |
984 | SysCallHandler(RPCCopyAmongRemote); |
985 | break; |
986 | default: |
987 | LOG(FATAL) << "Unknown event " << static_cast<int>(code); |
988 | } |
989 | |
990 | if (state_ != kWaitForAsyncCallback) { |
991 | ICHECK_EQ(state_, kRecvPacketNumBytes); |
992 | } |
993 | } |
994 | |
995 | /*! |
996 | * \brief RPC client session that proxies all calls to an endpoint. |
997 | */ |
998 | class RPCClientSession : public RPCSession, public DeviceAPI { |
999 | public: |
1000 | /*! |
1001 | * \brief param endpoint The client endpoint of the session. |
1002 | */ |
1003 | explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {} |
1004 | |
1005 | // function overrides |
1006 | PackedFuncHandle GetFunction(const std::string& name) final { |
1007 | return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); |
1008 | } |
1009 | |
1010 | void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, |
1011 | int num_args, const FEncodeReturn& fencode_return) final { |
1012 | endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return); |
1013 | } |
1014 | |
1015 | void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final { |
1016 | RPCCode code = RPCCode::kCopyToRemote; |
1017 | uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes); |
1018 | uint64_t rpc_max_size = GetRPCMaxTransferSize(); |
1019 | ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!" ; |
1020 | const uint64_t block_size = rpc_max_size - overhead; |
1021 | uint64_t block_count = 0; |
1022 | const uint64_t num_blocks = nbytes / block_size; |
1023 | void* from_bytes; |
1024 | |
1025 | for (block_count = 0; block_count < num_blocks; block_count++) { |
1026 | remote_to->byte_offset = block_count * block_size; |
1027 | from_bytes = reinterpret_cast<void*>( |
1028 | (reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size)); |
1029 | endpoint_->CopyToRemote(from_bytes, remote_to, block_size); |
1030 | } |
1031 | |
1032 | const uint64_t remainder_bytes = nbytes % block_size; |
1033 | if (remainder_bytes != 0) { |
1034 | remote_to->byte_offset = block_count * block_size; |
1035 | from_bytes = reinterpret_cast<void*>( |
1036 | (reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size)); |
1037 | endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes); |
1038 | } |
1039 | } |
1040 | |
1041 | void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final { |
1042 | RPCCode code = RPCCode::kCopyFromRemote; |
1043 | uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes); |
1044 | uint64_t rpc_max_size = GetRPCMaxTransferSize(); |
1045 | ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!" ; |
1046 | const uint64_t block_size = rpc_max_size - overhead; |
1047 | uint64_t block_count = 0; |
1048 | const uint64_t num_blocks = nbytes / block_size; |
1049 | void* to_bytes; |
1050 | |
1051 | for (block_count = 0; block_count < num_blocks; block_count++) { |
1052 | remote_from->byte_offset = block_count * block_size; |
1053 | to_bytes = reinterpret_cast<void*>( |
1054 | (reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size)); |
1055 | endpoint_->CopyFromRemote(remote_from, to_bytes, block_size); |
1056 | } |
1057 | |
1058 | const uint64_t remainder_bytes = nbytes % block_size; |
1059 | if (remainder_bytes != 0) { |
1060 | remote_from->byte_offset = block_count * block_size; |
1061 | to_bytes = reinterpret_cast<void*>( |
1062 | (reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size)); |
1063 | endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes); |
1064 | } |
1065 | } |
1066 | |
1067 | void FreeHandle(void* handle, int type_code) final { |
1068 | endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); |
1069 | } |
1070 | |
1071 | void SetDevice(Device dev) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, dev); } |
1072 | |
1073 | void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { |
1074 | if (dev.device_type == kDLCPU && kind == kExist) { |
1075 | // cpu always exists. |
1076 | *rv = 1; |
1077 | } else { |
1078 | *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, dev, static_cast<int>(kind)); |
1079 | } |
1080 | } |
1081 | |
1082 | void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { |
1083 | return endpoint_->SysCallRemote(RPCCode::kDevAllocData, dev, nbytes, alignment, type_hint); |
1084 | } |
1085 | |
1086 | void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, |
1087 | Optional<String> mem_scope) final { |
1088 | DLTensor temp; |
1089 | temp.data = nullptr; |
1090 | temp.device = dev; |
1091 | temp.ndim = ndim; |
1092 | temp.dtype = dtype; |
1093 | temp.shape = const_cast<int64_t*>(shape); |
1094 | temp.strides = nullptr; |
1095 | temp.byte_offset = 0; |
1096 | if (mem_scope.defined()) { |
1097 | return endpoint_->SysCallRemote(RPCCode::kDevAllocDataWithScope, &temp, |
1098 | static_cast<std::string>(mem_scope.value())); |
1099 | } else { |
1100 | return endpoint_->SysCallRemote(RPCCode::kDevAllocDataWithScope, &temp, nullptr); |
1101 | } |
1102 | } |
1103 | |
1104 | void FreeDataSpace(Device dev, void* ptr) final { |
1105 | endpoint_->SysCallRemote(RPCCode::kDevFreeData, dev, ptr); |
1106 | } |
1107 | |
1108 | void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final { |
1109 | endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); |
1110 | } |
1111 | |
1112 | TVMStreamHandle CreateStream(Device dev) final { |
1113 | return endpoint_->SysCallRemote(RPCCode::kDevCreateStream, dev); |
1114 | } |
1115 | |
1116 | void FreeStream(Device dev, TVMStreamHandle stream) final { |
1117 | endpoint_->SysCallRemote(RPCCode::kDevFreeStream, dev, stream); |
1118 | } |
1119 | |
1120 | void StreamSync(Device dev, TVMStreamHandle stream) final { |
1121 | endpoint_->SysCallRemote(RPCCode::kDevStreamSync, dev, stream); |
1122 | } |
1123 | |
1124 | void SetStream(Device dev, TVMStreamHandle stream) final { |
1125 | endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream); |
1126 | } |
1127 | |
1128 | DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; } |
1129 | |
1130 | bool IsLocalSession() const final { return false; } |
1131 | |
1132 | void Shutdown() final { endpoint_->Shutdown(); } |
1133 | |
1134 | private: |
1135 | uint64_t GetRPCMaxTransferSize() { |
1136 | if (rpc_chunk_max_size_bytes_ > 0) { |
1137 | return (uint64_t)rpc_chunk_max_size_bytes_; |
1138 | } |
1139 | |
1140 | PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize" ); |
1141 | if (rpc_func == nullptr) { |
1142 | rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault; |
1143 | } else { |
1144 | CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) { |
1145 | // Use args[1] as return value, args[0] is tcode |
1146 | // Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc |
1147 | rpc_chunk_max_size_bytes_ = (int64_t)args[1]; |
1148 | ICHECK_GT(rpc_chunk_max_size_bytes_, 0) |
1149 | << "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_ |
1150 | << ")" ; |
1151 | }); |
1152 | } |
1153 | return (uint64_t)rpc_chunk_max_size_bytes_; |
1154 | } |
1155 | |
1156 | std::shared_ptr<RPCEndpoint> endpoint_; |
1157 | int64_t rpc_chunk_max_size_bytes_ = -1; |
1158 | }; |
1159 | |
1160 | std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) { |
1161 | return std::make_shared<RPCClientSession>(endpoint); |
1162 | } |
1163 | |
1164 | uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) { |
1165 | uint64_t shape_bytes = tensor->ndim * sizeof(int64_t); |
1166 | uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<uint8_t*>(tensor->data)); |
1167 | uint64_t overhead = sizeof(code) + sizeof(to_data) + sizeof(tensor->device) + |
1168 | sizeof(tensor->ndim) + sizeof(tensor->dtype) + sizeof(tensor->byte_offset) + |
1169 | shape_bytes + sizeof(nbytes); |
1170 | return overhead; |
1171 | } |
1172 | |
1173 | } // namespace runtime |
1174 | } // namespace tvm |
1175 | |