1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_reference.h
22 * \brief Common header defining the communication code used in the RPC protocol.
23 */
24#ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
25#define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
26
27namespace tvm {
28namespace runtime {
29
/*! \brief Version string of the RPC protocol implemented by this header. */
constexpr const char* kRPCProtocolVer = "0.8.0";

/*!
 * \brief Default maximum RPC transfer size, in bytes.
 *
 * Applies when the global function tvm.rpc.server.GetCRTMaxPacketSize
 * is not registered.
 */
constexpr uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;
35
/*!
 * \brief Codes identifying RPC packets and remote syscalls.
 *
 * Values are consecutive, starting at zero. kGetGlobalFunc deliberately
 * shares its value with kSyscallCodeStart.
 */
enum class RPCCode : int {
  kNone = 0,
  kShutdown = 1,
  kInitServer = 2,
  kCallFunc = 3,
  kReturn = 4,
  kException = 5,
  kCopyFromRemote = 6,
  kCopyToRemote = 7,
  kCopyAck = 8,
  // Codes from here on are syscall codes that can be sent over CallRemote.
  kSyscallCodeStart = 9,
  kGetGlobalFunc = kSyscallCodeStart,
  kFreeHandle = 10,
  kDevSetDevice = 11,
  kDevGetAttr = 12,
  kDevAllocData = 13,
  kDevFreeData = 14,
  kDevStreamSync = 15,
  kCopyAmongRemote = 16,
  kDevAllocDataWithScope = 17,
  kDevCreateStream = 18,
  kDevFreeStream = 19,
  kDevSetStream = 20,
};
62
/*!
 * \brief Status codes describing errors that can occur during RPC communication.
 *
 * Values are consecutive, starting at zero with kSuccess.
 */
enum class RPCServerStatus : int {
  kSuccess = 0,
  kInvalidTypeCodeObject = 1,
  kInvalidTypeCodeNDArray = 2,
  kInvalidDLTensorFieldStride = 3,
  kInvalidDLTensorFieldByteOffset = 4,
  kUnknownTypeCode = 5,
  kUnknownRPCCode = 6,
  kRPCCodeNotSupported = 7,
  kUnknownRPCSyscall = 8,
  kCheckError = 9,
  kReadError = 10,
  kWriteError = 11,
  kAllocError = 12
};
81
82inline const char* RPCCodeToString(RPCCode code) {
83 switch (code) {
84 case RPCCode::kShutdown:
85 return "kShutdown";
86 case RPCCode::kInitServer:
87 return "kInitServer";
88 case RPCCode::kCallFunc:
89 return "kCallFunc";
90 case RPCCode::kReturn:
91 return "kReturn";
92 case RPCCode::kException:
93 return "kException";
94 case RPCCode::kCopyFromRemote:
95 return "kCopyFromRemote";
96 case RPCCode::kCopyToRemote:
97 return "kCopyToRemote";
98 case RPCCode::kCopyAck:
99 return "kCopyAck";
100 // The following are syscall code that can send over CallRemote
101 case RPCCode::kGetGlobalFunc:
102 return "kGetGlobalFunc";
103 case RPCCode::kFreeHandle:
104 return "kFreeHandle";
105 case RPCCode::kDevSetDevice:
106 return "kDevSetDevice";
107 case RPCCode::kDevGetAttr:
108 return "kDevGetAttr";
109 case RPCCode::kDevAllocData:
110 return "kDevAllocData";
111 case RPCCode::kDevFreeData:
112 return "kDevFreeData";
113 case RPCCode::kDevCreateStream:
114 return "kDevCreateStream";
115 case RPCCode::kDevFreeStream:
116 return "kDevFreeStream";
117 case RPCCode::kDevStreamSync:
118 return "kDevStreamSync";
119 case RPCCode::kDevSetStream:
120 return "kDevSetStream";
121 case RPCCode::kCopyAmongRemote:
122 return "kCopyAmongRemote";
123 case RPCCode::kDevAllocDataWithScope:
124 return "kDevAllocDataWithScope";
125 default:
126 return "";
127 }
128}
129
130/*!
131 * \brief Convert RPC server status to string.
132 * \param status The status.
133 * \return The corresponding string.
134 */
135inline const char* RPCServerStatusToString(RPCServerStatus status) {
136 switch (status) {
137 case RPCServerStatus::kSuccess:
138 return "kSuccess";
139 case RPCServerStatus::kInvalidTypeCodeObject:
140 return "kInvalidTypeCodeObject";
141 case RPCServerStatus::kInvalidTypeCodeNDArray:
142 return "kInvalidTypeCodeNDArray";
143 case RPCServerStatus::kInvalidDLTensorFieldStride:
144 return "kInvalidDLTensorFieldStride";
145 case RPCServerStatus::kInvalidDLTensorFieldByteOffset: {
146 return "kInvalidDLTensorFieldByteOffset";
147 }
148 case RPCServerStatus::kUnknownTypeCode:
149 return "kUnknownTypeCode";
150 case RPCServerStatus::kUnknownRPCCode:
151 return "kUnknownRPCCode";
152 case RPCServerStatus::kRPCCodeNotSupported:
153 return "RPCCodeNotSupported";
154 case RPCServerStatus::kUnknownRPCSyscall:
155 return "kUnknownRPCSyscall";
156 case RPCServerStatus::kCheckError:
157 return "kCheckError";
158 case RPCServerStatus::kReadError:
159 return "kReadError";
160 case RPCServerStatus::kWriteError:
161 return "kWriteError";
162 case RPCServerStatus::kAllocError:
163 return "kAllocError";
164 default:
165 return "";
166 }
167}
168
169/*!
170 * \brief Reference implementation of the communication protocol.
171 *
172 * \note The implementation is intentionally written via template
173 * so it can be used in a dependency free setting.
174 *
175 * \sa src/runtime/rpc/device/min_rpc_server.h
176 */
177struct RPCReference {
178 /*!
179 * \brief Auxiliary class to get the packed sequence.
180 * \tparam TChannel The channel to throw errror.
181 */
182 template <typename TChannel>
183 struct PackedSeqNumBytesGetter {
184 public:
185 explicit PackedSeqNumBytesGetter(TChannel* channel) : channel_(channel) {}
186
187 template <typename T>
188 void Write(const T& value) {
189 num_bytes_ += sizeof(T);
190 }
191
192 template <typename T>
193 void WriteArray(const T* value, size_t num) {
194 num_bytes_ += sizeof(T) * num;
195 }
196
197 void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); }
198
199 uint64_t num_bytes() const { return num_bytes_; }
200
201 private:
202 TChannel* channel_;
203 uint64_t num_bytes_{0};
204 };
205
/*!
 * \brief Compute the length of a null-terminated string.
 * \param str The null-terminated string; it is dereferenced unconditionally,
 *            so it must not be nullptr.
 * \return The number of characters before the terminating '\0'.
 */
static uint64_t StrLength(const char* str) {
  const char* cursor = str;
  while (*cursor != '\0') {
    ++cursor;
  }
  return static_cast<uint64_t>(cursor - str);
}
216
217 /*!
218 * \brief Get the total nbytes to be sent in the packed sequence.
219 *
220 * \param arg_values The values to be sent over.
221 * \param type_codes The type codes to be sent over.
222 * \param num_args Number of argument.
223 * \param client_mode Whether it is a client to server call.
224 * \param channel The communication channel handler.
225 * \tparam TChannel The type of the communication channel.
226 * \return The total number of bytes.
227 */
228 template <typename TChannel>
229 static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes,
230 int num_args, bool client_mode, TChannel* channel) {
231 PackedSeqNumBytesGetter<TChannel> getter(channel);
232 SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter);
233 return getter.num_bytes();
234 }
235
236 template <typename TChannelPtr>
237 static void SendDLTensor(TChannelPtr channel, DLTensor* arr) {
238 DLDevice dev;
239 uint64_t data;
240 // When we return NDArray, we directly return
241 // the space and the context
242 // The client will be further wrapping
243 dev = arr->device;
244 data = reinterpret_cast<uint64_t>(arr->data);
245 channel->Write(data);
246 channel->Write(dev);
247 channel->Write(arr->ndim);
248 channel->Write(arr->dtype);
249 channel->WriteArray(arr->shape, arr->ndim);
250 if (arr->strides != nullptr) {
251 channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
252 }
253 channel->Write(arr->byte_offset);
254 return;
255 }
256
257 template <typename TChannelPtr>
258 static DLTensor* ReceiveDLTensor(TChannelPtr channel) {
259 uint64_t handle;
260 channel->Read(&handle);
261 DLTensor* arr = channel->template ArenaAlloc<DLTensor>(1);
262 DLTensor& tensor = *arr;
263 tensor.data = reinterpret_cast<void*>(handle);
264 channel->Read(&(tensor.device));
265 channel->Read(&(tensor.ndim));
266 channel->Read(&(tensor.dtype));
267 tensor.shape = channel->template ArenaAlloc<int64_t>(tensor.ndim);
268 channel->ReadArray(tensor.shape, tensor.ndim);
269 tensor.strides = nullptr;
270 channel->Read(&(tensor.byte_offset));
271 return arr;
272 }
273
274 /*!
275 * \brief Send packed argument sequnce to the other peer.
276 *
277 * This function serves as the foundational communication primitive between peers.
278 *
279 * TVMValue sequence encoding protocol(according to the type):
280 *
281 * - int/float/uint/bytes/str: Serialize all content.
282 * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t)
283 * - OpaqueHandle: send as uint64_t
284 * - ModuleHandle, PackedFuncHandle: send as uint64_t,
285 * The support to Module/PackedFuncHandle are reserved for arguments
286 * in the CallFunc from a client to server only.
287 * Note that we cannot simply take these argument out(as the handle)
288 * refers to a value on the remote(instead of local).
289 *
290 * \param arg_values The values to be sent over.
291 * \param type_codes The type codes to be sent over.
292 * \param num_args Number of argument.
293 * \param client_mode Whether it is a client to server call.
294 * \param channel The communication channel handler.
295 * \tparam TChannel The type of the communication channel.
296 */
297 template <typename TChannel>
298 static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
299 bool client_mode, TChannel* channel) {
300 channel->Write(num_args);
301 channel->WriteArray(type_codes, num_args);
302
303 // Argument packing.
304 for (int i = 0; i < num_args; ++i) {
305 int tcode = type_codes[i];
306 TVMValue value = arg_values[i];
307 switch (tcode) {
308 case kDLInt:
309 case kDLUInt:
310 case kDLFloat: {
311 channel->template Write<int64_t>(value.v_int64);
312 break;
313 }
314 case kTVMDataType: {
315 channel->Write(value.v_type);
316 // padding
317 int32_t padding = 0;
318 channel->template Write<int32_t>(padding);
319 break;
320 }
321 case kDLDevice: {
322 channel->Write(value.v_device);
323 break;
324 }
325
326 case kTVMPackedFuncHandle:
327 case kTVMModuleHandle: {
328 if (!client_mode) {
329 channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject);
330 }
331 // always send handle in 64 bit.
332 uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
333 channel->Write(handle);
334 break;
335 }
336 case kTVMOpaqueHandle: {
337 // always send handle in 64 bit.
338 uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
339 channel->Write(handle);
340 break;
341 }
342 case kTVMNDArrayHandle: {
343 channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray);
344 break;
345 }
346 case kTVMDLTensorHandle: {
347 DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
348 SendDLTensor(channel, arr);
349 break;
350 }
351 case kTVMNullptr:
352 break;
353 case kTVMStr: {
354 const char* s = value.v_str;
355 uint64_t len = StrLength(s);
356 channel->Write(len);
357 channel->WriteArray(s, len);
358 break;
359 }
360 case kTVMBytes: {
361 TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
362 uint64_t len = bytes->size;
363 channel->Write(len);
364 channel->WriteArray(bytes->data, len);
365 break;
366 }
367 default: {
368 channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
369 break;
370 }
371 }
372 }
373 }
374
375 /*!
376 * \brief Receive packed seq from the channel.
377 *
378 * \param out_arg_values The values to be received.
379 * \param out_tcodes The type codes to be received.
380 * \param out_num_args Number of argument.
381 * \param channel The communication channel handler.
382 * \tparam TChannel The type of the communication channel.
383 * \note The temporary space are populated via an arena inside channel.
384 */
385 template <typename TChannel>
386 static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args,
387 TChannel* channel) {
388 // receive number of args
389 int num_args;
390 channel->Read(&num_args);
391 *out_num_args = num_args;
392
393 if (num_args == 0) {
394 *out_values = nullptr;
395 *out_tcodes = nullptr;
396 return;
397 }
398
399 TVMValue* values = channel->template ArenaAlloc<TVMValue>(num_args);
400 int* tcodes = channel->template ArenaAlloc<int>(num_args);
401 *out_values = values;
402 *out_tcodes = tcodes;
403
404 // receive type code.
405 channel->ReadArray(tcodes, num_args);
406
407 // receive arguments
408 for (int i = 0; i < num_args; ++i) {
409 auto& value = values[i];
410 switch (tcodes[i]) {
411 case kDLInt:
412 case kDLUInt:
413 case kDLFloat: {
414 channel->template Read<int64_t>(&(value.v_int64));
415 break;
416 }
417 case kTVMDataType: {
418 channel->Read(&(value.v_type));
419 int32_t padding = 0;
420 channel->template Read<int32_t>(&padding);
421 break;
422 }
423 case kDLDevice: {
424 channel->Read(&(value.v_device));
425 break;
426 }
427 case kTVMPackedFuncHandle:
428 case kTVMModuleHandle:
429 case kTVMOpaqueHandle: {
430 // always send handle in 64 bit.
431 uint64_t handle;
432 channel->Read(&handle);
433 value.v_handle = reinterpret_cast<void*>(handle);
434 break;
435 }
436 case kTVMNullptr: {
437 value.v_handle = nullptr;
438 break;
439 }
440 case kTVMStr: {
441 uint64_t len;
442 channel->Read(&len);
443 char* str = channel->template ArenaAlloc<char>(len + 1);
444 str[len] = '\0';
445 channel->ReadArray(str, len);
446 value.v_str = str;
447 break;
448 }
449 case kTVMBytes: {
450 uint64_t len;
451 channel->Read(&len);
452 TVMByteArray* arr = channel->template ArenaAlloc<TVMByteArray>(1);
453 char* data = channel->template ArenaAlloc<char>(len);
454 arr->size = len;
455 arr->data = data;
456 channel->ReadArray(data, len);
457 value.v_handle = arr;
458 break;
459 }
460 case kTVMDLTensorHandle: {
461 value.v_handle = ReceiveDLTensor(channel);
462 break;
463 }
464 default: {
465 channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
466 break;
467 }
468 }
469 }
470 }
471
472 /*!
473 * \brief Return an exception packet.
474 *
475 * \param msg The error message.
476 * \param channel The communication channel handler.
477 * \tparam TChannel The type of the communication channel.
478 */
479 template <typename TChannel>
480 static void ReturnException(const char* msg, TChannel* channel) {
481 RPCCode code = RPCCode::kException;
482 int32_t num_args = 1;
483 int32_t tcode = kTVMStr;
484 uint64_t len = StrLength(msg);
485
486 uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len;
487
488 channel->MessageStart(packet_nbytes);
489 channel->Write(packet_nbytes);
490 channel->Write(code);
491 channel->Write(num_args);
492 channel->Write(tcode);
493 channel->Write(len);
494 channel->WriteArray(msg, len);
495 channel->MessageDone();
496 }
497
498 /*!
499 * \brief Return a normal packed sequence packet.
500 *
501 * \param msg The error message.
502 * \param channel The communication channel handler.
503 * \tparam TChannel The type of the communication channel.
504 */
505 template <typename TChannel>
506 static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
507 TChannel* channel) {
508 RPCCode code = RPCCode::kReturn;
509
510 uint64_t packet_nbytes =
511 sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel);
512
513 channel->MessageStart(packet_nbytes);
514 channel->Write(packet_nbytes);
515 channel->Write(code);
516 SendPackedSeq(arg_values, type_codes, num_args, false, channel);
517 channel->MessageDone();
518 }
519
520 /*!
521 * \brief Return a null(void) packet.
522 *
523 * \param channel The communication channel handler.
524 * \tparam TChannel The type of the communication channel.
525 */
526 template <typename TChannel>
527 static void ReturnVoid(TChannel* channel) {
528 int32_t num_args = 1;
529 int32_t tcode = kTVMNullptr;
530 RPCCode code = RPCCode::kReturn;
531
532 uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
533
534 channel->MessageStart(packet_nbytes);
535 channel->Write(packet_nbytes);
536 channel->Write(code);
537 channel->Write(num_args);
538 channel->Write(tcode);
539 channel->MessageDone();
540 }
541};
542
543} // namespace runtime
544} // namespace tvm
545
546#endif // TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
547