1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_reference.h |
22 | * \brief Common header defining the communication code used in the RPC protocol. |
23 | */ |
24 | #ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ |
25 | #define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ |
26 | |
27 | namespace tvm { |
28 | namespace runtime { |
29 | |
/*! \brief Version string of the RPC protocol described by this header. */
constexpr const char* kRPCProtocolVer = "0.8.0";

/*!
 * \brief Default maximum transfer size, in bytes.
 *
 * Applies when the global function tvm.rpc.server.GetCRTMaxPacketSize
 * is not registered.
 */
constexpr uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;
35 | |
/*!
 * \brief The RPC code.
 *
 * Enumerators take consecutive values starting at 0. The values are
 * written out explicitly so the numbering is visible at a glance.
 */
enum class RPCCode : int {
  kNone = 0,
  kShutdown = 1,
  kInitServer = 2,
  kCallFunc = 3,
  kReturn = 4,
  kException = 5,
  kCopyFromRemote = 6,
  kCopyToRemote = 7,
  kCopyAck = 8,
  // Everything from here on is a syscall code that can be sent over CallRemote.
  kSyscallCodeStart = 9,
  // Deliberately shares its value with kSyscallCodeStart.
  kGetGlobalFunc = kSyscallCodeStart,
  kFreeHandle = 10,
  kDevSetDevice = 11,
  kDevGetAttr = 12,
  kDevAllocData = 13,
  kDevFreeData = 14,
  kDevStreamSync = 15,
  kCopyAmongRemote = 16,
  kDevAllocDataWithScope = 17,
  kDevCreateStream = 18,
  kDevFreeStream = 19,
  kDevSetStream = 20,
};
62 | |
/*!
 * \brief Status codes for errors that may occur during rpc communication.
 *
 * kSuccess is 0 and the remaining enumerators follow consecutively; the
 * values are spelled out so the numbering is visible at a glance.
 */
enum class RPCServerStatus : int {
  kSuccess = 0,
  kInvalidTypeCodeObject = 1,
  kInvalidTypeCodeNDArray = 2,
  kInvalidDLTensorFieldStride = 3,
  kInvalidDLTensorFieldByteOffset = 4,
  kUnknownTypeCode = 5,
  kUnknownRPCCode = 6,
  kRPCCodeNotSupported = 7,
  kUnknownRPCSyscall = 8,
  kCheckError = 9,
  kReadError = 10,
  kWriteError = 11,
  kAllocError = 12
};
81 | |
82 | inline const char* RPCCodeToString(RPCCode code) { |
83 | switch (code) { |
84 | case RPCCode::kShutdown: |
85 | return "kShutdown" ; |
86 | case RPCCode::kInitServer: |
87 | return "kInitServer" ; |
88 | case RPCCode::kCallFunc: |
89 | return "kCallFunc" ; |
90 | case RPCCode::kReturn: |
91 | return "kReturn" ; |
92 | case RPCCode::kException: |
93 | return "kException" ; |
94 | case RPCCode::kCopyFromRemote: |
95 | return "kCopyFromRemote" ; |
96 | case RPCCode::kCopyToRemote: |
97 | return "kCopyToRemote" ; |
98 | case RPCCode::kCopyAck: |
99 | return "kCopyAck" ; |
100 | // The following are syscall code that can send over CallRemote |
101 | case RPCCode::kGetGlobalFunc: |
102 | return "kGetGlobalFunc" ; |
103 | case RPCCode::kFreeHandle: |
104 | return "kFreeHandle" ; |
105 | case RPCCode::kDevSetDevice: |
106 | return "kDevSetDevice" ; |
107 | case RPCCode::kDevGetAttr: |
108 | return "kDevGetAttr" ; |
109 | case RPCCode::kDevAllocData: |
110 | return "kDevAllocData" ; |
111 | case RPCCode::kDevFreeData: |
112 | return "kDevFreeData" ; |
113 | case RPCCode::kDevCreateStream: |
114 | return "kDevCreateStream" ; |
115 | case RPCCode::kDevFreeStream: |
116 | return "kDevFreeStream" ; |
117 | case RPCCode::kDevStreamSync: |
118 | return "kDevStreamSync" ; |
119 | case RPCCode::kDevSetStream: |
120 | return "kDevSetStream" ; |
121 | case RPCCode::kCopyAmongRemote: |
122 | return "kCopyAmongRemote" ; |
123 | case RPCCode::kDevAllocDataWithScope: |
124 | return "kDevAllocDataWithScope" ; |
125 | default: |
126 | return "" ; |
127 | } |
128 | } |
129 | |
130 | /*! |
131 | * \brief Convert RPC server status to string. |
132 | * \param status The status. |
133 | * \return The corresponding string. |
134 | */ |
135 | inline const char* RPCServerStatusToString(RPCServerStatus status) { |
136 | switch (status) { |
137 | case RPCServerStatus::kSuccess: |
138 | return "kSuccess" ; |
139 | case RPCServerStatus::kInvalidTypeCodeObject: |
140 | return "kInvalidTypeCodeObject" ; |
141 | case RPCServerStatus::kInvalidTypeCodeNDArray: |
142 | return "kInvalidTypeCodeNDArray" ; |
143 | case RPCServerStatus::kInvalidDLTensorFieldStride: |
144 | return "kInvalidDLTensorFieldStride" ; |
145 | case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { |
146 | return "kInvalidDLTensorFieldByteOffset" ; |
147 | } |
148 | case RPCServerStatus::kUnknownTypeCode: |
149 | return "kUnknownTypeCode" ; |
150 | case RPCServerStatus::kUnknownRPCCode: |
151 | return "kUnknownRPCCode" ; |
152 | case RPCServerStatus::kRPCCodeNotSupported: |
153 | return "RPCCodeNotSupported" ; |
154 | case RPCServerStatus::kUnknownRPCSyscall: |
155 | return "kUnknownRPCSyscall" ; |
156 | case RPCServerStatus::kCheckError: |
157 | return "kCheckError" ; |
158 | case RPCServerStatus::kReadError: |
159 | return "kReadError" ; |
160 | case RPCServerStatus::kWriteError: |
161 | return "kWriteError" ; |
162 | case RPCServerStatus::kAllocError: |
163 | return "kAllocError" ; |
164 | default: |
165 | return "" ; |
166 | } |
167 | } |
168 | |
169 | /*! |
170 | * \brief Reference implementation of the communication protocol. |
171 | * |
172 | * \note The implementation is intentionally written via template |
173 | * so it can be used in a dependency free setting. |
174 | * |
175 | * \sa src/runtime/rpc/device/min_rpc_server.h |
176 | */ |
177 | struct RPCReference { |
178 | /*! |
179 | * \brief Auxiliary class to get the packed sequence. |
180 | * \tparam TChannel The channel to throw errror. |
181 | */ |
182 | template <typename TChannel> |
183 | struct PackedSeqNumBytesGetter { |
184 | public: |
185 | explicit PackedSeqNumBytesGetter(TChannel* channel) : channel_(channel) {} |
186 | |
187 | template <typename T> |
188 | void Write(const T& value) { |
189 | num_bytes_ += sizeof(T); |
190 | } |
191 | |
192 | template <typename T> |
193 | void WriteArray(const T* value, size_t num) { |
194 | num_bytes_ += sizeof(T) * num; |
195 | } |
196 | |
197 | void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); } |
198 | |
199 | uint64_t num_bytes() const { return num_bytes_; } |
200 | |
201 | private: |
202 | TChannel* channel_; |
203 | uint64_t num_bytes_{0}; |
204 | }; |
205 | |
/*!
 * \brief Compute the length of a null-terminated string.
 *
 * \param str The string. A null pointer is treated as an empty string.
 * \return The number of characters before the terminating '\0'.
 */
static uint64_t StrLength(const char* str) {
  // Guard against a null pointer instead of dereferencing it.
  if (str == nullptr) return 0;
  uint64_t len = 0;
  while (str[len] != '\0') ++len;
  return len;
}
216 | |
217 | /*! |
218 | * \brief Get the total nbytes to be sent in the packed sequence. |
219 | * |
220 | * \param arg_values The values to be sent over. |
221 | * \param type_codes The type codes to be sent over. |
222 | * \param num_args Number of argument. |
223 | * \param client_mode Whether it is a client to server call. |
224 | * \param channel The communication channel handler. |
225 | * \tparam TChannel The type of the communication channel. |
226 | * \return The total number of bytes. |
227 | */ |
228 | template <typename TChannel> |
229 | static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, |
230 | int num_args, bool client_mode, TChannel* channel) { |
231 | PackedSeqNumBytesGetter<TChannel> getter(channel); |
232 | SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter); |
233 | return getter.num_bytes(); |
234 | } |
235 | |
236 | template <typename TChannelPtr> |
237 | static void SendDLTensor(TChannelPtr channel, DLTensor* arr) { |
238 | DLDevice dev; |
239 | uint64_t data; |
240 | // When we return NDArray, we directly return |
241 | // the space and the context |
242 | // The client will be further wrapping |
243 | dev = arr->device; |
244 | data = reinterpret_cast<uint64_t>(arr->data); |
245 | channel->Write(data); |
246 | channel->Write(dev); |
247 | channel->Write(arr->ndim); |
248 | channel->Write(arr->dtype); |
249 | channel->WriteArray(arr->shape, arr->ndim); |
250 | if (arr->strides != nullptr) { |
251 | channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride); |
252 | } |
253 | channel->Write(arr->byte_offset); |
254 | return; |
255 | } |
256 | |
257 | template <typename TChannelPtr> |
258 | static DLTensor* ReceiveDLTensor(TChannelPtr channel) { |
259 | uint64_t handle; |
260 | channel->Read(&handle); |
261 | DLTensor* arr = channel->template ArenaAlloc<DLTensor>(1); |
262 | DLTensor& tensor = *arr; |
263 | tensor.data = reinterpret_cast<void*>(handle); |
264 | channel->Read(&(tensor.device)); |
265 | channel->Read(&(tensor.ndim)); |
266 | channel->Read(&(tensor.dtype)); |
267 | tensor.shape = channel->template ArenaAlloc<int64_t>(tensor.ndim); |
268 | channel->ReadArray(tensor.shape, tensor.ndim); |
269 | tensor.strides = nullptr; |
270 | channel->Read(&(tensor.byte_offset)); |
271 | return arr; |
272 | } |
273 | |
274 | /*! |
275 | * \brief Send packed argument sequnce to the other peer. |
276 | * |
277 | * This function serves as the foundational communication primitive between peers. |
278 | * |
279 | * TVMValue sequence encoding protocol(according to the type): |
280 | * |
281 | * - int/float/uint/bytes/str: Serialize all content. |
282 | * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t) |
283 | * - OpaqueHandle: send as uint64_t |
284 | * - ModuleHandle, PackedFuncHandle: send as uint64_t, |
285 | * The support to Module/PackedFuncHandle are reserved for arguments |
286 | * in the CallFunc from a client to server only. |
287 | * Note that we cannot simply take these argument out(as the handle) |
288 | * refers to a value on the remote(instead of local). |
289 | * |
290 | * \param arg_values The values to be sent over. |
291 | * \param type_codes The type codes to be sent over. |
292 | * \param num_args Number of argument. |
293 | * \param client_mode Whether it is a client to server call. |
294 | * \param channel The communication channel handler. |
295 | * \tparam TChannel The type of the communication channel. |
296 | */ |
297 | template <typename TChannel> |
298 | static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, |
299 | bool client_mode, TChannel* channel) { |
300 | channel->Write(num_args); |
301 | channel->WriteArray(type_codes, num_args); |
302 | |
303 | // Argument packing. |
304 | for (int i = 0; i < num_args; ++i) { |
305 | int tcode = type_codes[i]; |
306 | TVMValue value = arg_values[i]; |
307 | switch (tcode) { |
308 | case kDLInt: |
309 | case kDLUInt: |
310 | case kDLFloat: { |
311 | channel->template Write<int64_t>(value.v_int64); |
312 | break; |
313 | } |
314 | case kTVMDataType: { |
315 | channel->Write(value.v_type); |
316 | // padding |
317 | int32_t padding = 0; |
318 | channel->template Write<int32_t>(padding); |
319 | break; |
320 | } |
321 | case kDLDevice: { |
322 | channel->Write(value.v_device); |
323 | break; |
324 | } |
325 | |
326 | case kTVMPackedFuncHandle: |
327 | case kTVMModuleHandle: { |
328 | if (!client_mode) { |
329 | channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject); |
330 | } |
331 | // always send handle in 64 bit. |
332 | uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle); |
333 | channel->Write(handle); |
334 | break; |
335 | } |
336 | case kTVMOpaqueHandle: { |
337 | // always send handle in 64 bit. |
338 | uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle); |
339 | channel->Write(handle); |
340 | break; |
341 | } |
342 | case kTVMNDArrayHandle: { |
343 | channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray); |
344 | break; |
345 | } |
346 | case kTVMDLTensorHandle: { |
347 | DLTensor* arr = static_cast<DLTensor*>(value.v_handle); |
348 | SendDLTensor(channel, arr); |
349 | break; |
350 | } |
351 | case kTVMNullptr: |
352 | break; |
353 | case kTVMStr: { |
354 | const char* s = value.v_str; |
355 | uint64_t len = StrLength(s); |
356 | channel->Write(len); |
357 | channel->WriteArray(s, len); |
358 | break; |
359 | } |
360 | case kTVMBytes: { |
361 | TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle); |
362 | uint64_t len = bytes->size; |
363 | channel->Write(len); |
364 | channel->WriteArray(bytes->data, len); |
365 | break; |
366 | } |
367 | default: { |
368 | channel->ThrowError(RPCServerStatus::kUnknownTypeCode); |
369 | break; |
370 | } |
371 | } |
372 | } |
373 | } |
374 | |
375 | /*! |
376 | * \brief Receive packed seq from the channel. |
377 | * |
378 | * \param out_arg_values The values to be received. |
379 | * \param out_tcodes The type codes to be received. |
380 | * \param out_num_args Number of argument. |
381 | * \param channel The communication channel handler. |
382 | * \tparam TChannel The type of the communication channel. |
383 | * \note The temporary space are populated via an arena inside channel. |
384 | */ |
385 | template <typename TChannel> |
386 | static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args, |
387 | TChannel* channel) { |
388 | // receive number of args |
389 | int num_args; |
390 | channel->Read(&num_args); |
391 | *out_num_args = num_args; |
392 | |
393 | if (num_args == 0) { |
394 | *out_values = nullptr; |
395 | *out_tcodes = nullptr; |
396 | return; |
397 | } |
398 | |
399 | TVMValue* values = channel->template ArenaAlloc<TVMValue>(num_args); |
400 | int* tcodes = channel->template ArenaAlloc<int>(num_args); |
401 | *out_values = values; |
402 | *out_tcodes = tcodes; |
403 | |
404 | // receive type code. |
405 | channel->ReadArray(tcodes, num_args); |
406 | |
407 | // receive arguments |
408 | for (int i = 0; i < num_args; ++i) { |
409 | auto& value = values[i]; |
410 | switch (tcodes[i]) { |
411 | case kDLInt: |
412 | case kDLUInt: |
413 | case kDLFloat: { |
414 | channel->template Read<int64_t>(&(value.v_int64)); |
415 | break; |
416 | } |
417 | case kTVMDataType: { |
418 | channel->Read(&(value.v_type)); |
419 | int32_t padding = 0; |
420 | channel->template Read<int32_t>(&padding); |
421 | break; |
422 | } |
423 | case kDLDevice: { |
424 | channel->Read(&(value.v_device)); |
425 | break; |
426 | } |
427 | case kTVMPackedFuncHandle: |
428 | case kTVMModuleHandle: |
429 | case kTVMOpaqueHandle: { |
430 | // always send handle in 64 bit. |
431 | uint64_t handle; |
432 | channel->Read(&handle); |
433 | value.v_handle = reinterpret_cast<void*>(handle); |
434 | break; |
435 | } |
436 | case kTVMNullptr: { |
437 | value.v_handle = nullptr; |
438 | break; |
439 | } |
440 | case kTVMStr: { |
441 | uint64_t len; |
442 | channel->Read(&len); |
443 | char* str = channel->template ArenaAlloc<char>(len + 1); |
444 | str[len] = '\0'; |
445 | channel->ReadArray(str, len); |
446 | value.v_str = str; |
447 | break; |
448 | } |
449 | case kTVMBytes: { |
450 | uint64_t len; |
451 | channel->Read(&len); |
452 | TVMByteArray* arr = channel->template ArenaAlloc<TVMByteArray>(1); |
453 | char* data = channel->template ArenaAlloc<char>(len); |
454 | arr->size = len; |
455 | arr->data = data; |
456 | channel->ReadArray(data, len); |
457 | value.v_handle = arr; |
458 | break; |
459 | } |
460 | case kTVMDLTensorHandle: { |
461 | value.v_handle = ReceiveDLTensor(channel); |
462 | break; |
463 | } |
464 | default: { |
465 | channel->ThrowError(RPCServerStatus::kUnknownTypeCode); |
466 | break; |
467 | } |
468 | } |
469 | } |
470 | } |
471 | |
472 | /*! |
473 | * \brief Return an exception packet. |
474 | * |
475 | * \param msg The error message. |
476 | * \param channel The communication channel handler. |
477 | * \tparam TChannel The type of the communication channel. |
478 | */ |
479 | template <typename TChannel> |
480 | static void ReturnException(const char* msg, TChannel* channel) { |
481 | RPCCode code = RPCCode::kException; |
482 | int32_t num_args = 1; |
483 | int32_t tcode = kTVMStr; |
484 | uint64_t len = StrLength(msg); |
485 | |
486 | uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len; |
487 | |
488 | channel->MessageStart(packet_nbytes); |
489 | channel->Write(packet_nbytes); |
490 | channel->Write(code); |
491 | channel->Write(num_args); |
492 | channel->Write(tcode); |
493 | channel->Write(len); |
494 | channel->WriteArray(msg, len); |
495 | channel->MessageDone(); |
496 | } |
497 | |
498 | /*! |
499 | * \brief Return a normal packed sequence packet. |
500 | * |
501 | * \param msg The error message. |
502 | * \param channel The communication channel handler. |
503 | * \tparam TChannel The type of the communication channel. |
504 | */ |
505 | template <typename TChannel> |
506 | static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, |
507 | TChannel* channel) { |
508 | RPCCode code = RPCCode::kReturn; |
509 | |
510 | uint64_t packet_nbytes = |
511 | sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel); |
512 | |
513 | channel->MessageStart(packet_nbytes); |
514 | channel->Write(packet_nbytes); |
515 | channel->Write(code); |
516 | SendPackedSeq(arg_values, type_codes, num_args, false, channel); |
517 | channel->MessageDone(); |
518 | } |
519 | |
520 | /*! |
521 | * \brief Return a null(void) packet. |
522 | * |
523 | * \param channel The communication channel handler. |
524 | * \tparam TChannel The type of the communication channel. |
525 | */ |
526 | template <typename TChannel> |
527 | static void ReturnVoid(TChannel* channel) { |
528 | int32_t num_args = 1; |
529 | int32_t tcode = kTVMNullptr; |
530 | RPCCode code = RPCCode::kReturn; |
531 | |
532 | uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); |
533 | |
534 | channel->MessageStart(packet_nbytes); |
535 | channel->Write(packet_nbytes); |
536 | channel->Write(code); |
537 | channel->Write(num_args); |
538 | channel->Write(tcode); |
539 | channel->MessageDone(); |
540 | } |
541 | }; |
542 | |
543 | } // namespace runtime |
544 | } // namespace tvm |
545 | |
546 | #endif // TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ |
547 | |