1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file local_session.cc |
22 | * \brief Local session that directs requests to local API. |
23 | */ |
24 | #include "rpc_local_session.h" |
25 | |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/runtime/registry.h> |
28 | |
29 | #include <memory> |
30 | |
31 | namespace tvm { |
32 | namespace runtime { |
33 | |
34 | RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) { |
35 | if (auto* fp = tvm::runtime::Registry::Get(name)) { |
36 | // return raw handle because the remote need to explicitly manage it. |
37 | tvm::runtime::TVMRetValue ret; |
38 | ret = *fp; |
39 | TVMValue val; |
40 | int type_code; |
41 | ret.MoveToCHost(&val, &type_code); |
42 | return val.v_handle; |
43 | } else { |
44 | return nullptr; |
45 | } |
46 | } |
47 | |
48 | void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) { |
49 | int rv_tcode = rv.type_code(); |
50 | |
51 | // return value encoding. |
52 | TVMValue ret_value_pack[3]; |
53 | int ret_tcode_pack[3]; |
54 | TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack); |
55 | // first location always encode type code. |
56 | set_arg(0, rv_tcode); |
57 | |
58 | if (rv_tcode == kTVMNDArrayHandle) { |
59 | // We follow a special protocol to return NDArray to client side |
60 | // The first pack value is the NDArray handle as DLTensor |
61 | // The second pack value is a customized deleter that deletes the NDArray. |
62 | rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); |
63 | ret_tcode_pack[1] = kTVMDLTensorHandle; |
64 | ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; |
65 | ret_tcode_pack[2] = kTVMOpaqueHandle; |
66 | encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); |
67 | } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { |
68 | // MoveToCHost means rv no longer manages the object. |
69 | // return handle instead. |
70 | rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); |
71 | ret_tcode_pack[1] = kTVMOpaqueHandle; |
72 | encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); |
73 | } else if (rv_tcode == kTVMBytes) { |
74 | TVMByteArray byte_arr; |
75 | auto* sptr = rv.ptr<std::string>(); |
76 | byte_arr.data = sptr->data(); |
77 | byte_arr.size = sptr->length(); |
78 | set_arg(1, byte_arr); |
79 | encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); |
80 | } else { |
81 | set_arg(1, rv); |
82 | encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); |
83 | } |
84 | } |
85 | |
86 | void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values, |
87 | const int* arg_type_codes, int num_args, |
88 | const FEncodeReturn& encode_return) { |
89 | PackedFuncObj* pf = static_cast<PackedFuncObj*>(func); |
90 | TVMRetValue rv; |
91 | pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); |
92 | this->EncodeReturn(std::move(rv), encode_return); |
93 | } |
94 | |
95 | void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) { |
96 | ICHECK_EQ(nbytes, GetDataSize(*to)); |
97 | DLTensor from; |
98 | from.data = from_bytes; |
99 | from.device = {kDLCPU, 0}; |
100 | from.ndim = to->ndim; |
101 | from.shape = to->shape; |
102 | from.dtype = to->dtype; |
103 | from.strides = nullptr; |
104 | from.byte_offset = 0; |
105 | Device dev_to = to->device; |
106 | this->GetDeviceAPI(dev_to)->CopyDataFromTo(&from, to, nullptr); |
107 | // Copy can happen asynchrously |
108 | // synchronize to make sure that copy is completed |
109 | this->GetDeviceAPI(dev_to)->StreamSync(dev_to, nullptr); |
110 | } |
111 | |
112 | void LocalSession::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) { |
113 | ICHECK_EQ(nbytes, GetDataSize(*from)); |
114 | DLTensor to; |
115 | to.data = to_bytes; |
116 | to.device = {kDLCPU, 0}; |
117 | to.ndim = from->ndim; |
118 | to.shape = from->shape; |
119 | to.dtype = from->dtype; |
120 | to.strides = nullptr; |
121 | to.byte_offset = 0; |
122 | |
123 | Device dev_from = from->device; |
124 | this->GetDeviceAPI(dev_from)->CopyDataFromTo(from, &to, nullptr); |
125 | // Copy can happen asynchrously |
126 | // synchronize to make sure that copy is completed |
127 | this->GetDeviceAPI(dev_from)->StreamSync(dev_from, nullptr); |
128 | } |
129 | |
130 | void LocalSession::FreeHandle(void* handle, int type_code) { |
131 | TVMValue value; |
132 | value.v_handle = handle; |
133 | // will trigger deleter once the rv goes out of the scope. |
134 | TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code); |
135 | } |
136 | |
137 | DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) { |
138 | return DeviceAPI::Get(dev, allow_missing); |
139 | } |
140 | |
141 | TVM_REGISTER_GLOBAL("rpc.LocalSession" ).set_body_typed([]() { |
142 | return CreateRPCSessionModule(std::make_shared<LocalSession>()); |
143 | }); |
144 | |
145 | } // namespace runtime |
146 | } // namespace tvm |
147 | |