1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_pipe_impl.cc |
22 | * \brief Pipe-based RPC channel. |
23 | */ |
// Linux only for now, as Linux is the most common use case.
25 | #if defined(__linux__) || defined(__ANDROID__) |
26 | |
#include <errno.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <tvm/runtime/registry.h>
#include <unistd.h>

#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "../../support/pipe.h"
#include "rpc_endpoint.h"
#include "rpc_local_session.h"
39 | |
40 | namespace tvm { |
41 | namespace runtime { |
42 | |
43 | class PipeChannel final : public RPCChannel { |
44 | public: |
45 | explicit PipeChannel(int readfd, int writefd, pid_t child_pid) |
46 | : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {} |
47 | |
48 | ~PipeChannel() { Close(); } |
49 | |
50 | size_t Send(const void* data, size_t size) final { |
51 | ssize_t n = write(writefd_, data, size); |
52 | if (n == -1) { |
53 | LOG(FATAL) << "Pipe write error" ; |
54 | } |
55 | return static_cast<size_t>(n); |
56 | } |
57 | |
58 | size_t Recv(void* data, size_t size) final { |
59 | ssize_t n = read(readfd_, data, size); |
60 | if (n == -1) { |
61 | LOG(FATAL) << "Pipe read error" ; |
62 | } |
63 | return static_cast<size_t>(n); |
64 | } |
65 | |
66 | void Close() { |
67 | close(readfd_); |
68 | close(writefd_); |
69 | kill(child_pid_, SIGKILL); |
70 | } |
71 | |
72 | private: |
73 | int readfd_; |
74 | int writefd_; |
75 | pid_t child_pid_; |
76 | }; |
77 | |
78 | Module CreatePipeClient(std::vector<std::string> cmd) { |
79 | int parent2child[2]; |
80 | int child2parent[2]; |
81 | ICHECK_EQ(pipe(parent2child), 0); |
82 | ICHECK_EQ(pipe(child2parent), 0); |
83 | |
84 | int parent_read = child2parent[0]; |
85 | int parent_write = parent2child[1]; |
86 | int child_read = parent2child[0]; |
87 | int child_write = child2parent[1]; |
88 | |
89 | pid_t pid = fork(); |
90 | if (pid == 0) { |
91 | // child process |
92 | close(parent_read); |
93 | close(parent_write); |
94 | std::string sread_pipe = std::to_string(child_read); |
95 | std::string swrite_pipe = std::to_string(child_write); |
96 | std::vector<char*> argv; |
97 | for (auto& str : cmd) { |
98 | argv.push_back(dmlc::BeginPtr(str)); |
99 | } |
100 | argv.push_back(dmlc::BeginPtr(sread_pipe)); |
101 | argv.push_back(dmlc::BeginPtr(swrite_pipe)); |
102 | argv.push_back(nullptr); |
103 | execvp(argv[0], &argv[0]); |
104 | } |
105 | // parent process |
106 | close(child_read); |
107 | close(child_write); |
108 | |
109 | auto endpt = RPCEndpoint::Create(std::make_unique<PipeChannel>(parent_read, parent_write, pid), |
110 | "pipe" , "pipe" ); |
111 | endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0)); |
112 | return CreateRPCSessionModule(CreateClientSession(endpt)); |
113 | } |
114 | |
115 | TVM_REGISTER_GLOBAL("rpc.CreatePipeClient" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
116 | std::vector<std::string> cmd; |
117 | for (int i = 0; i < args.size(); ++i) { |
118 | cmd.push_back(args[i].operator std::string()); |
119 | } |
120 | *rv = CreatePipeClient(cmd); |
121 | }); |
122 | |
123 | } // namespace runtime |
124 | } // namespace tvm |
125 | #endif |
126 | |