1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_socket_impl.cc |
22 | * \brief Socket based RPC implementation. |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | |
26 | #include <memory> |
27 | |
28 | #include "../../support/socket.h" |
29 | #include "rpc_endpoint.h" |
30 | #include "rpc_local_session.h" |
31 | #include "rpc_session.h" |
32 | |
33 | namespace tvm { |
34 | namespace runtime { |
35 | |
36 | class SockChannel final : public RPCChannel { |
37 | public: |
38 | explicit SockChannel(support::TCPSocket sock) : sock_(sock) {} |
39 | ~SockChannel() { |
40 | try { |
41 | // BadSocket can throw |
42 | if (!sock_.BadSocket()) { |
43 | sock_.Close(); |
44 | } |
45 | } catch (...) { |
46 | } |
47 | } |
48 | size_t Send(const void* data, size_t size) final { |
49 | ssize_t n = sock_.Send(data, size); |
50 | if (n == -1) { |
51 | support::Socket::Error("SockChannel::Send" ); |
52 | } |
53 | return static_cast<size_t>(n); |
54 | } |
55 | size_t Recv(void* data, size_t size) final { |
56 | ssize_t n = sock_.Recv(data, size); |
57 | if (n == -1) { |
58 | support::Socket::Error("SockChannel::Recv" ); |
59 | } |
60 | return static_cast<size_t>(n); |
61 | } |
62 | |
63 | private: |
64 | support::TCPSocket sock_; |
65 | }; |
66 | |
67 | std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string key, |
68 | bool enable_logging, TVMArgs init_seq) { |
69 | support::TCPSocket sock; |
70 | support::SockAddr addr(url.c_str(), port); |
71 | sock.Create(addr.ss_family()); |
72 | ICHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed" ; |
73 | // hand shake |
74 | std::ostringstream os; |
75 | int code = kRPCMagic; |
76 | int keylen = static_cast<int>(key.length()); |
77 | ICHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code)); |
78 | ICHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); |
79 | if (keylen != 0) { |
80 | ICHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen); |
81 | } |
82 | ICHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code)); |
83 | if (code == kRPCMagic + 2) { |
84 | sock.Close(); |
85 | LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key; |
86 | } else if (code == kRPCMagic + 1) { |
87 | sock.Close(); |
88 | LOG(FATAL) << "URL " << url << ":" << port << " server already have key=" << key; |
89 | } else if (code != kRPCMagic) { |
90 | sock.Close(); |
91 | LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server" ; |
92 | } |
93 | ICHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); |
94 | std::string remote_key; |
95 | if (keylen != 0) { |
96 | remote_key.resize(keylen); |
97 | ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); |
98 | } |
99 | |
100 | std::unique_ptr<RPCChannel> channel = std::make_unique<SockChannel>(sock); |
101 | if (enable_logging) { |
102 | channel.reset(new RPCChannelLogging(std::move(channel))); |
103 | } |
104 | auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key); |
105 | |
106 | endpt->InitRemoteSession(init_seq); |
107 | return endpt; |
108 | } |
109 | |
110 | Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, |
111 | TVMArgs init_seq) { |
112 | auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq); |
113 | return CreateRPCSessionModule(CreateClientSession(endpt)); |
114 | } |
115 | |
116 | // TVM_DLL needed for MSVC |
117 | TVM_DLL void RPCServerLoop(int sockfd) { |
118 | support::TCPSocket sock(static_cast<support::TCPSocket::SockType>(sockfd)); |
119 | RPCEndpoint::Create(std::make_unique<SockChannel>(sock), "SockServerLoop" , "" )->ServerLoop(); |
120 | } |
121 | |
122 | void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { |
123 | RPCEndpoint::Create(std::make_unique<CallbackChannel>(fsend, frecv), "SockServerLoop" , "" ) |
124 | ->ServerLoop(); |
125 | } |
126 | |
127 | TVM_REGISTER_GLOBAL("rpc.Connect" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
128 | std::string url = args[0]; |
129 | int port = args[1]; |
130 | std::string key = args[2]; |
131 | bool enable_logging = args[3]; |
132 | *rv = RPCClientConnect(url, port, key, enable_logging, |
133 | TVMArgs(args.values + 4, args.type_codes + 4, args.size() - 4)); |
134 | }); |
135 | |
136 | TVM_REGISTER_GLOBAL("rpc.ServerLoop" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
137 | if (args[0].type_code() == kDLInt) { |
138 | RPCServerLoop(args[0]); |
139 | } else { |
140 | RPCServerLoop(args[0].operator tvm::runtime::PackedFunc(), |
141 | args[1].operator tvm::runtime::PackedFunc()); |
142 | } |
143 | }); |
144 | |
145 | } // namespace runtime |
146 | } // namespace tvm |
147 | |