1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ |
21 | #define TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ |
22 | |
23 | #include <memory> |
24 | #include <utility> |
25 | |
26 | #include "minrpc_logger.h" |
27 | #include "minrpc_server.h" |
28 | |
29 | namespace tvm { |
30 | namespace runtime { |
31 | |
32 | /*! |
33 | * \brief A minimum RPC server that logs the received commands. |
34 | * |
35 | * \tparam TIOHandler IO provider to provide io handling. |
36 | */ |
37 | template <typename TIOHandler> |
38 | class MinRPCServerWithLog { |
39 | public: |
40 | explicit MinRPCServerWithLog(TIOHandler* io) |
41 | : ret_handler_(io), |
42 | ret_handler_wlog_(&ret_handler_, &logger_), |
43 | exec_handler_(io, &ret_handler_wlog_), |
44 | exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)), |
45 | next_(io, std::move(exec_handler_ptr_)) {} |
46 | |
47 | bool ProcessOnePacket() { return next_.ProcessOnePacket(); } |
48 | |
49 | private: |
50 | Logger logger_; |
51 | MinRPCReturns<TIOHandler> ret_handler_; |
52 | MinRPCExecute<TIOHandler> exec_handler_; |
53 | MinRPCReturnsWithLog ret_handler_wlog_; |
54 | std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_; |
55 | MinRPCServer<TIOHandler> next_; |
56 | }; |
57 | |
58 | /*! |
59 | * \brief A minimum RPC server that only logs the outgoing commands and received responses. |
60 | * (Does not process the packets or respond to them.) |
61 | * |
62 | * \tparam TIOHandler IO provider to provide io handling. |
63 | */ |
64 | template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator> |
65 | class MinRPCSniffer { |
66 | public: |
67 | using PageAllocator = Allocator<TIOHandler>; |
68 | explicit MinRPCSniffer(TIOHandler* io) |
69 | : io_(io), |
70 | arena_(PageAllocator(io_)), |
71 | ret_handler_(io_), |
72 | ret_handler_wlog_(&ret_handler_, &logger_), |
73 | exec_handler_(&ret_handler_wlog_), |
74 | exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)), |
75 | next_(io_, std::move(exec_handler_ptr_)) {} |
76 | |
77 | bool ProcessOnePacket() { return next_.ProcessOnePacket(); } |
78 | |
79 | void ProcessOneResponse() { |
80 | RPCCode code; |
81 | uint64_t packet_len = 0; |
82 | |
83 | if (!Read(&packet_len)) return; |
84 | if (packet_len == 0) { |
85 | OutputLog(); |
86 | return; |
87 | } |
88 | if (!Read(&code)) return; |
89 | switch (code) { |
90 | case RPCCode::kReturn: { |
91 | int32_t num_args; |
92 | int* type_codes; |
93 | TVMValue* values; |
94 | RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); |
95 | ret_handler_wlog_.ReturnPackedSeq(values, type_codes, num_args); |
96 | break; |
97 | } |
98 | case RPCCode::kException: { |
99 | ret_handler_wlog_.ReturnException("" ); |
100 | break; |
101 | } |
102 | default: { |
103 | OutputLog(); |
104 | break; |
105 | } |
106 | } |
107 | } |
108 | |
109 | void OutputLog() { logger_.OutputLog(); } |
110 | |
111 | void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { |
112 | logger_.Log("-> " ); |
113 | logger_.Log(RPCServerStatusToString(code)); |
114 | OutputLog(); |
115 | } |
116 | |
117 | template <typename T> |
118 | T* ArenaAlloc(int count) { |
119 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
120 | "need to be trival" ); |
121 | return arena_.template allocate_<T>(count); |
122 | } |
123 | |
124 | template <typename T> |
125 | bool Read(T* data) { |
126 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
127 | "need to be trival" ); |
128 | return ReadRawBytes(data, sizeof(T)); |
129 | } |
130 | |
131 | template <typename T> |
132 | bool ReadArray(T* data, size_t count) { |
133 | static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value, |
134 | "need to be trival" ); |
135 | return ReadRawBytes(data, sizeof(T) * count); |
136 | } |
137 | |
138 | private: |
139 | bool ReadRawBytes(void* data, size_t size) { |
140 | uint8_t* buf = reinterpret_cast<uint8_t*>(data); |
141 | size_t ndone = 0; |
142 | while (ndone < size) { |
143 | ssize_t ret = io_->PosixRead(buf, size - ndone); |
144 | if (ret <= 0) { |
145 | this->ThrowError(RPCServerStatus::kReadError); |
146 | return false; |
147 | } |
148 | ndone += ret; |
149 | buf += ret; |
150 | } |
151 | return true; |
152 | } |
153 | |
154 | Logger logger_; |
155 | TIOHandler* io_; |
156 | support::GenericArena<PageAllocator> arena_; |
157 | MinRPCReturnsNoOp<TIOHandler> ret_handler_; |
158 | MinRPCReturnsWithLog ret_handler_wlog_; |
159 | MinRPCExecuteNoOp exec_handler_; |
160 | std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_; |
161 | MinRPCServer<TIOHandler> next_; |
162 | }; |
163 | |
164 | } // namespace runtime |
165 | } // namespace tvm |
166 | #endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ |
167 | |