1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_socket_impl.cc
22 * \brief Socket based RPC implementation.
23 */
24#include <tvm/runtime/registry.h>
25
26#include <memory>
27
28#include "../../support/socket.h"
29#include "rpc_endpoint.h"
30#include "rpc_local_session.h"
31#include "rpc_session.h"
32
33namespace tvm {
34namespace runtime {
35
36class SockChannel final : public RPCChannel {
37 public:
38 explicit SockChannel(support::TCPSocket sock) : sock_(sock) {}
39 ~SockChannel() {
40 try {
41 // BadSocket can throw
42 if (!sock_.BadSocket()) {
43 sock_.Close();
44 }
45 } catch (...) {
46 }
47 }
48 size_t Send(const void* data, size_t size) final {
49 ssize_t n = sock_.Send(data, size);
50 if (n == -1) {
51 support::Socket::Error("SockChannel::Send");
52 }
53 return static_cast<size_t>(n);
54 }
55 size_t Recv(void* data, size_t size) final {
56 ssize_t n = sock_.Recv(data, size);
57 if (n == -1) {
58 support::Socket::Error("SockChannel::Recv");
59 }
60 return static_cast<size_t>(n);
61 }
62
63 private:
64 support::TCPSocket sock_;
65};
66
67std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string key,
68 bool enable_logging, TVMArgs init_seq) {
69 support::TCPSocket sock;
70 support::SockAddr addr(url.c_str(), port);
71 sock.Create(addr.ss_family());
72 ICHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed";
73 // hand shake
74 std::ostringstream os;
75 int code = kRPCMagic;
76 int keylen = static_cast<int>(key.length());
77 ICHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
78 ICHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
79 if (keylen != 0) {
80 ICHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
81 }
82 ICHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
83 if (code == kRPCMagic + 2) {
84 sock.Close();
85 LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key;
86 } else if (code == kRPCMagic + 1) {
87 sock.Close();
88 LOG(FATAL) << "URL " << url << ":" << port << " server already have key=" << key;
89 } else if (code != kRPCMagic) {
90 sock.Close();
91 LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
92 }
93 ICHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
94 std::string remote_key;
95 if (keylen != 0) {
96 remote_key.resize(keylen);
97 ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
98 }
99
100 std::unique_ptr<RPCChannel> channel = std::make_unique<SockChannel>(sock);
101 if (enable_logging) {
102 channel.reset(new RPCChannelLogging(std::move(channel)));
103 }
104 auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key);
105
106 endpt->InitRemoteSession(init_seq);
107 return endpt;
108}
109
110Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging,
111 TVMArgs init_seq) {
112 auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq);
113 return CreateRPCSessionModule(CreateClientSession(endpt));
114}
115
116// TVM_DLL needed for MSVC
117TVM_DLL void RPCServerLoop(int sockfd) {
118 support::TCPSocket sock(static_cast<support::TCPSocket::SockType>(sockfd));
119 RPCEndpoint::Create(std::make_unique<SockChannel>(sock), "SockServerLoop", "")->ServerLoop();
120}
121
122void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
123 RPCEndpoint::Create(std::make_unique<CallbackChannel>(fsend, frecv), "SockServerLoop", "")
124 ->ServerLoop();
125}
126
127TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) {
128 std::string url = args[0];
129 int port = args[1];
130 std::string key = args[2];
131 bool enable_logging = args[3];
132 *rv = RPCClientConnect(url, port, key, enable_logging,
133 TVMArgs(args.values + 4, args.type_codes + 4, args.size() - 4));
134});
135
136TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) {
137 if (args[0].type_code() == kDLInt) {
138 RPCServerLoop(args[0]);
139 } else {
140 RPCServerLoop(args[0].operator tvm::runtime::PackedFunc(),
141 args[1].operator tvm::runtime::PackedFunc());
142 }
143});
144
145} // namespace runtime
146} // namespace tvm
147