1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_session.cc |
22 | * \brief RPC session for remote function call. |
23 | */ |
24 | #include "rpc_session.h" |
25 | |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/runtime/packed_func.h> |
28 | |
29 | #include <array> |
30 | #include <mutex> |
31 | |
32 | namespace tvm { |
33 | namespace runtime { |
34 | |
35 | bool RPCSession::IsAsync() const { return false; } |
36 | |
37 | void RPCSession::SendException(FAsyncCallback callback, const char* msg) { |
38 | TVMValue value; |
39 | value.v_str = msg; |
40 | int32_t tcode = kTVMStr; |
41 | callback(RPCCode::kException, TVMArgs(&value, &tcode, 1)); |
42 | } |
43 | |
44 | void RPCSession::AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, |
45 | const int* arg_type_codes, int num_args, FAsyncCallback callback) { |
46 | try { |
47 | this->CallFunc(func, arg_values, arg_type_codes, num_args, |
48 | [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); }); |
49 | } catch (const std::exception& e) { |
50 | this->SendException(callback, e.what()); |
51 | } |
52 | } |
53 | |
54 | void RPCSession::AsyncCopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes, |
55 | RPCSession::FAsyncCallback callback) { |
56 | TVMValue value; |
57 | int32_t tcode = kTVMNullptr; |
58 | value.v_handle = nullptr; |
59 | |
60 | try { |
61 | this->CopyToRemote(local_from_bytes, remote_to, nbytes); |
62 | callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); |
63 | } catch (const std::exception& e) { |
64 | this->SendException(callback, e.what()); |
65 | } |
66 | } |
67 | |
68 | void RPCSession::AsyncCopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes, |
69 | RPCSession::FAsyncCallback callback) { |
70 | TVMValue value; |
71 | int32_t tcode = kTVMNullptr; |
72 | value.v_handle = nullptr; |
73 | |
74 | try { |
75 | this->CopyFromRemote(remote_from, local_to_bytes, nbytes); |
76 | callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); |
77 | } catch (const std::exception& e) { |
78 | this->SendException(callback, e.what()); |
79 | } |
80 | } |
81 | |
82 | void RPCSession::AsyncStreamWait(Device dev, TVMStreamHandle stream, |
83 | RPCSession::FAsyncCallback callback) { |
84 | TVMValue value; |
85 | int32_t tcode = kTVMNullptr; |
86 | value.v_handle = nullptr; |
87 | |
88 | try { |
89 | this->GetDeviceAPI(dev)->StreamSync(dev, stream); |
90 | callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); |
91 | } catch (const std::exception& e) { |
92 | this->SendException(callback, e.what()); |
93 | } |
94 | } |
95 | |
96 | class RPCSessTable { |
97 | public: |
98 | static constexpr int kMaxRPCSession = 32; |
99 | // Get global singleton |
100 | static RPCSessTable* Global() { |
101 | static RPCSessTable inst; |
102 | return &inst; |
103 | } |
104 | // Get session from table |
105 | std::shared_ptr<RPCSession> Get(int index) { |
106 | ICHECK(index >= 0 && index < kMaxRPCSession); |
107 | return tbl_[index].lock(); |
108 | } |
109 | // Insert session into table. |
110 | int Insert(std::shared_ptr<RPCSession> ptr) { |
111 | std::lock_guard<std::mutex> lock(mutex_); |
112 | for (int i = 0; i < kMaxRPCSession; ++i) { |
113 | if (tbl_[i].lock() == nullptr) { |
114 | tbl_[i] = ptr; |
115 | return i; |
116 | } |
117 | } |
118 | LOG(FATAL) << "maximum number of RPC session reached" ; |
119 | } |
120 | |
121 | private: |
122 | // The mutex |
123 | std::mutex mutex_; |
124 | // Use weak_ptr intentionally |
125 | // If the RPCSession get released, the pointer session will be released |
126 | std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_; |
127 | }; |
128 | |
129 | std::shared_ptr<RPCSession> RPCSession::Get(int table_index) { |
130 | return RPCSessTable::Global()->Get(table_index); |
131 | } |
132 | |
133 | void RPCSession::InsertToSessionTable(std::shared_ptr<RPCSession> sess) { |
134 | ICHECK_EQ(sess->table_index_, 0); |
135 | sess->table_index_ = RPCSessTable::Global()->Insert(sess); |
136 | } |
137 | |
138 | } // namespace runtime |
139 | } // namespace tvm |
140 | |