1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_channel_logger.h
22 * \brief A wrapper for RPCChannel with a NanoRPCListener for logging the commands.
23 */
24#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
25#define TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
26
#include <tvm/runtime/c_runtime_api.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>

#include "../../support/ssize.h"
#include "../minrpc/minrpc_server_logging.h"
#include "rpc_channel.h"
35
/*!
 * \brief Capacity, in bytes, of the receive storage embedded in each NanoRPCListener.
 *  Bytes passed to the listener beyond this capacity (before it is cleared) are dropped.
 */
#define RX_BUFFER_SIZE 65536
37
38namespace tvm {
39namespace runtime {
40
/*!
 * \brief Fixed-capacity byte buffer over caller-provided storage.
 *
 * Write() appends after the bytes already written; Read() consumes them in
 * order through a separate read cursor. Read() never frees space; only Clear()
 * resets the buffer so it can be filled again. The buffer does not own its storage.
 */
class Buffer {
 public:
  /*!
   * \brief Create an empty buffer backed by \p data.
   * \param data Backing storage; must stay valid for as long as the buffer is used.
   * \param data_size_bytes Capacity of \p data in bytes.
   */
  Buffer(uint8_t* data, size_t data_size_bytes)
      : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0}, read_cursor_{0} {}

  /*!
   * \brief Append bytes to the buffer.
   * \param data Source bytes. May be null when \p data_size_bytes is 0.
   * \param data_size_bytes Number of bytes to append.
   * \return Number of bytes actually appended. This is smaller than
   *  \p data_size_bytes when the buffer fills up; the excess is dropped.
   */
  size_t Write(const uint8_t* data, size_t data_size_bytes) {
    const size_t num_bytes_available = capacity_ - num_valid_bytes_;
    const size_t num_bytes_to_copy =
        data_size_bytes < num_bytes_available ? data_size_bytes : num_bytes_available;
    // memcpy requires valid pointers even for a zero-length copy, so skip it
    // when there is nothing to copy (e.g. a null source with size 0).
    if (num_bytes_to_copy != 0) {
      std::memcpy(data_ + num_valid_bytes_, data, num_bytes_to_copy);
      num_valid_bytes_ += num_bytes_to_copy;
    }
    return num_bytes_to_copy;
  }

  /*!
   * \brief Copy not-yet-read bytes out of the buffer and advance the read cursor.
   * \param data Destination. May be null when \p data_size_bytes is 0.
   * \param data_size_bytes Maximum number of bytes to copy.
   * \return Number of bytes copied; 0 once every written byte has been read.
   */
  size_t Read(uint8_t* data, size_t data_size_bytes) {
    const size_t num_bytes_available = num_valid_bytes_ - read_cursor_;
    const size_t num_bytes_to_copy =
        data_size_bytes < num_bytes_available ? data_size_bytes : num_bytes_available;
    // Same zero-length guard as Write(): avoid handing memcpy a null pointer.
    if (num_bytes_to_copy != 0) {
      std::memcpy(data, data_ + read_cursor_, num_bytes_to_copy);
      read_cursor_ += num_bytes_to_copy;
    }
    return num_bytes_to_copy;
  }

  /*! \brief Discard all written data and reset the read cursor. */
  void Clear() {
    num_valid_bytes_ = 0;
    read_cursor_ = 0;
  }

  /*! \brief Number of bytes written since the last Clear(), including bytes already read. */
  size_t Size() const { return num_valid_bytes_; }

 private:
  /*! \brief Pointer to the (non-owned) data buffer. */
  uint8_t* data_;

  /*! \brief The total number of bytes available in data_. */
  size_t capacity_;

  /*! \brief Number of valid (written) bytes in the buffer. */
  size_t num_valid_bytes_;

  /*! \brief Offset of the next byte Read() will return. */
  size_t read_cursor_;
};
90
91/*!
92 * \brief A simple IO handler for MinRPCSniffer.
93 *
94 * \tparam Buffer* buffer to store received data.
95 */
96class SnifferIOHandler {
97 public:
98 explicit SnifferIOHandler(Buffer* receive_buffer) : receive_buffer_(receive_buffer) {}
99
100 void MessageStart(size_t message_size_bytes) {}
101
102 ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) { return 0; }
103
104 void MessageDone() {}
105
106 ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) {
107 return receive_buffer_->Read(buf, buf_size_bytes);
108 }
109
110 void Close() {}
111
112 void Exit(int code) {}
113
114 private:
115 Buffer* receive_buffer_;
116};
117
118/*!
119 * \brief A simple rpc session that logs the received commands.
120 */
121class NanoRPCListener {
122 public:
123 NanoRPCListener()
124 : receive_buffer_(receive_storage_, receive_storage_size_bytes_),
125 io_(&receive_buffer_),
126 rpc_server_(&io_) {}
127
128 void Listen(const uint8_t* data, size_t size) { receive_buffer_.Write(data, size); }
129
130 void ProcessTxPacket() {
131 rpc_server_.ProcessOnePacket();
132 ClearBuffer();
133 }
134
135 void ProcessRxPacket() {
136 rpc_server_.ProcessOneResponse();
137 ClearBuffer();
138 }
139
140 private:
141 void ClearBuffer() { receive_buffer_.Clear(); }
142
143 private:
144 size_t receive_storage_size_bytes_ = RX_BUFFER_SIZE;
145 uint8_t receive_storage_[RX_BUFFER_SIZE];
146 Buffer receive_buffer_;
147 SnifferIOHandler io_;
148 MinRPCSniffer<SnifferIOHandler> rpc_server_;
149
150 void HandleCompleteMessage() { rpc_server_.ProcessOnePacket(); }
151
152 static void HandleCompleteMessageCb(void* context) {
153 static_cast<NanoRPCListener*>(context)->HandleCompleteMessage();
154 }
155};
156
157/*!
158 * \brief A wrapper for RPCChannel, that also logs the commands sent.
159 *
160 * \tparam std::unique_ptr<RPCChannel>&& underlying RPCChannel unique_ptr.
161 */
162class RPCChannelLogging : public RPCChannel {
163 public:
164 explicit RPCChannelLogging(std::unique_ptr<RPCChannel>&& next) { next_ = std::move(next); }
165
166 size_t Send(const void* data, size_t size) {
167 listener_.ProcessRxPacket();
168 listener_.Listen((const uint8_t*)data, size);
169 listener_.ProcessTxPacket();
170 return next_->Send(data, size);
171 }
172
173 size_t Recv(void* data, size_t size) {
174 size_t ret = next_->Recv(data, size);
175 listener_.Listen((const uint8_t*)data, size);
176 return ret;
177 }
178
179 private:
180 std::unique_ptr<RPCChannel> next_;
181 NanoRPCListener listener_;
182};
183
184} // namespace runtime
185} // namespace tvm
186#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
187