1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_pipe_impl.cc
22 * \brief Pipe-based RPC channel.
23 */
// Linux-only for now, since Linux is the most common use case.
25#if defined(__linux__) || defined(__ANDROID__)
26
27#include <errno.h>
28#include <signal.h>
29#include <sys/types.h>
30#include <tvm/runtime/registry.h>
31#include <unistd.h>
32
33#include <cstdlib>
34#include <memory>
35
36#include "../../support/pipe.h"
37#include "rpc_endpoint.h"
38#include "rpc_local_session.h"
39
40namespace tvm {
41namespace runtime {
42
43class PipeChannel final : public RPCChannel {
44 public:
45 explicit PipeChannel(int readfd, int writefd, pid_t child_pid)
46 : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {}
47
48 ~PipeChannel() { Close(); }
49
50 size_t Send(const void* data, size_t size) final {
51 ssize_t n = write(writefd_, data, size);
52 if (n == -1) {
53 LOG(FATAL) << "Pipe write error";
54 }
55 return static_cast<size_t>(n);
56 }
57
58 size_t Recv(void* data, size_t size) final {
59 ssize_t n = read(readfd_, data, size);
60 if (n == -1) {
61 LOG(FATAL) << "Pipe read error";
62 }
63 return static_cast<size_t>(n);
64 }
65
66 void Close() {
67 close(readfd_);
68 close(writefd_);
69 kill(child_pid_, SIGKILL);
70 }
71
72 private:
73 int readfd_;
74 int writefd_;
75 pid_t child_pid_;
76};
77
78Module CreatePipeClient(std::vector<std::string> cmd) {
79 int parent2child[2];
80 int child2parent[2];
81 ICHECK_EQ(pipe(parent2child), 0);
82 ICHECK_EQ(pipe(child2parent), 0);
83
84 int parent_read = child2parent[0];
85 int parent_write = parent2child[1];
86 int child_read = parent2child[0];
87 int child_write = child2parent[1];
88
89 pid_t pid = fork();
90 if (pid == 0) {
91 // child process
92 close(parent_read);
93 close(parent_write);
94 std::string sread_pipe = std::to_string(child_read);
95 std::string swrite_pipe = std::to_string(child_write);
96 std::vector<char*> argv;
97 for (auto& str : cmd) {
98 argv.push_back(dmlc::BeginPtr(str));
99 }
100 argv.push_back(dmlc::BeginPtr(sread_pipe));
101 argv.push_back(dmlc::BeginPtr(swrite_pipe));
102 argv.push_back(nullptr);
103 execvp(argv[0], &argv[0]);
104 }
105 // parent process
106 close(child_read);
107 close(child_write);
108
109 auto endpt = RPCEndpoint::Create(std::make_unique<PipeChannel>(parent_read, parent_write, pid),
110 "pipe", "pipe");
111 endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0));
112 return CreateRPCSessionModule(CreateClientSession(endpt));
113}
114
115TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body([](TVMArgs args, TVMRetValue* rv) {
116 std::vector<std::string> cmd;
117 for (int i = 0; i < args.size(); ++i) {
118 cmd.push_back(args[i].operator std::string());
119 }
120 *rv = CreatePipeClient(cmd);
121});
122
123} // namespace runtime
124} // namespace tvm
125#endif
126