1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
21#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
22
23#include <memory>
24#include <utility>
25
26#include "minrpc_logger.h"
27#include "minrpc_server.h"
28
29namespace tvm {
30namespace runtime {
31
32/*!
33 * \brief A minimum RPC server that logs the received commands.
34 *
35 * \tparam TIOHandler IO provider to provide io handling.
36 */
37template <typename TIOHandler>
38class MinRPCServerWithLog {
39 public:
40 explicit MinRPCServerWithLog(TIOHandler* io)
41 : ret_handler_(io),
42 ret_handler_wlog_(&ret_handler_, &logger_),
43 exec_handler_(io, &ret_handler_wlog_),
44 exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)),
45 next_(io, std::move(exec_handler_ptr_)) {}
46
47 bool ProcessOnePacket() { return next_.ProcessOnePacket(); }
48
49 private:
50 Logger logger_;
51 MinRPCReturns<TIOHandler> ret_handler_;
52 MinRPCExecute<TIOHandler> exec_handler_;
53 MinRPCReturnsWithLog ret_handler_wlog_;
54 std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_;
55 MinRPCServer<TIOHandler> next_;
56};
57
58/*!
59 * \brief A minimum RPC server that only logs the outgoing commands and received responses.
60 * (Does not process the packets or respond to them.)
61 *
62 * \tparam TIOHandler IO provider to provide io handling.
63 */
64template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator>
65class MinRPCSniffer {
66 public:
67 using PageAllocator = Allocator<TIOHandler>;
68 explicit MinRPCSniffer(TIOHandler* io)
69 : io_(io),
70 arena_(PageAllocator(io_)),
71 ret_handler_(io_),
72 ret_handler_wlog_(&ret_handler_, &logger_),
73 exec_handler_(&ret_handler_wlog_),
74 exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)),
75 next_(io_, std::move(exec_handler_ptr_)) {}
76
77 bool ProcessOnePacket() { return next_.ProcessOnePacket(); }
78
79 void ProcessOneResponse() {
80 RPCCode code;
81 uint64_t packet_len = 0;
82
83 if (!Read(&packet_len)) return;
84 if (packet_len == 0) {
85 OutputLog();
86 return;
87 }
88 if (!Read(&code)) return;
89 switch (code) {
90 case RPCCode::kReturn: {
91 int32_t num_args;
92 int* type_codes;
93 TVMValue* values;
94 RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
95 ret_handler_wlog_.ReturnPackedSeq(values, type_codes, num_args);
96 break;
97 }
98 case RPCCode::kException: {
99 ret_handler_wlog_.ReturnException("");
100 break;
101 }
102 default: {
103 OutputLog();
104 break;
105 }
106 }
107 }
108
109 void OutputLog() { logger_.OutputLog(); }
110
111 void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
112 logger_.Log("-> ");
113 logger_.Log(RPCServerStatusToString(code));
114 OutputLog();
115 }
116
117 template <typename T>
118 T* ArenaAlloc(int count) {
119 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
120 "need to be trival");
121 return arena_.template allocate_<T>(count);
122 }
123
124 template <typename T>
125 bool Read(T* data) {
126 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
127 "need to be trival");
128 return ReadRawBytes(data, sizeof(T));
129 }
130
131 template <typename T>
132 bool ReadArray(T* data, size_t count) {
133 static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
134 "need to be trival");
135 return ReadRawBytes(data, sizeof(T) * count);
136 }
137
138 private:
139 bool ReadRawBytes(void* data, size_t size) {
140 uint8_t* buf = reinterpret_cast<uint8_t*>(data);
141 size_t ndone = 0;
142 while (ndone < size) {
143 ssize_t ret = io_->PosixRead(buf, size - ndone);
144 if (ret <= 0) {
145 this->ThrowError(RPCServerStatus::kReadError);
146 return false;
147 }
148 ndone += ret;
149 buf += ret;
150 }
151 return true;
152 }
153
154 Logger logger_;
155 TIOHandler* io_;
156 support::GenericArena<PageAllocator> arena_;
157 MinRPCReturnsNoOp<TIOHandler> ret_handler_;
158 MinRPCReturnsWithLog ret_handler_wlog_;
159 MinRPCExecuteNoOp exec_handler_;
160 std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_;
161 MinRPCServer<TIOHandler> next_;
162};
163
164} // namespace runtime
165} // namespace tvm
166#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
167