1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file rpc_local_session.cc
22 * \brief Local session that directs requests to local API.
23 */
24#include "rpc_local_session.h"
25
26#include <tvm/runtime/device_api.h>
27#include <tvm/runtime/registry.h>
28
29#include <memory>
30
31namespace tvm {
32namespace runtime {
33
34RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) {
35 if (auto* fp = tvm::runtime::Registry::Get(name)) {
36 // return raw handle because the remote need to explicitly manage it.
37 tvm::runtime::TVMRetValue ret;
38 ret = *fp;
39 TVMValue val;
40 int type_code;
41 ret.MoveToCHost(&val, &type_code);
42 return val.v_handle;
43 } else {
44 return nullptr;
45 }
46}
47
48void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) {
49 int rv_tcode = rv.type_code();
50
51 // return value encoding.
52 TVMValue ret_value_pack[3];
53 int ret_tcode_pack[3];
54 TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack);
55 // first location always encode type code.
56 set_arg(0, rv_tcode);
57
58 if (rv_tcode == kTVMNDArrayHandle) {
59 // We follow a special protocol to return NDArray to client side
60 // The first pack value is the NDArray handle as DLTensor
61 // The second pack value is a customized deleter that deletes the NDArray.
62 rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
63 ret_tcode_pack[1] = kTVMDLTensorHandle;
64 ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
65 ret_tcode_pack[2] = kTVMOpaqueHandle;
66 encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
67 } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
68 // MoveToCHost means rv no longer manages the object.
69 // return handle instead.
70 rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
71 ret_tcode_pack[1] = kTVMOpaqueHandle;
72 encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
73 } else if (rv_tcode == kTVMBytes) {
74 TVMByteArray byte_arr;
75 auto* sptr = rv.ptr<std::string>();
76 byte_arr.data = sptr->data();
77 byte_arr.size = sptr->length();
78 set_arg(1, byte_arr);
79 encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
80 } else {
81 set_arg(1, rv);
82 encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
83 }
84}
85
86void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values,
87 const int* arg_type_codes, int num_args,
88 const FEncodeReturn& encode_return) {
89 PackedFuncObj* pf = static_cast<PackedFuncObj*>(func);
90 TVMRetValue rv;
91 pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
92 this->EncodeReturn(std::move(rv), encode_return);
93}
94
95void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) {
96 ICHECK_EQ(nbytes, GetDataSize(*to));
97 DLTensor from;
98 from.data = from_bytes;
99 from.device = {kDLCPU, 0};
100 from.ndim = to->ndim;
101 from.shape = to->shape;
102 from.dtype = to->dtype;
103 from.strides = nullptr;
104 from.byte_offset = 0;
105 Device dev_to = to->device;
106 this->GetDeviceAPI(dev_to)->CopyDataFromTo(&from, to, nullptr);
107 // Copy can happen asynchrously
108 // synchronize to make sure that copy is completed
109 this->GetDeviceAPI(dev_to)->StreamSync(dev_to, nullptr);
110}
111
112void LocalSession::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) {
113 ICHECK_EQ(nbytes, GetDataSize(*from));
114 DLTensor to;
115 to.data = to_bytes;
116 to.device = {kDLCPU, 0};
117 to.ndim = from->ndim;
118 to.shape = from->shape;
119 to.dtype = from->dtype;
120 to.strides = nullptr;
121 to.byte_offset = 0;
122
123 Device dev_from = from->device;
124 this->GetDeviceAPI(dev_from)->CopyDataFromTo(from, &to, nullptr);
125 // Copy can happen asynchrously
126 // synchronize to make sure that copy is completed
127 this->GetDeviceAPI(dev_from)->StreamSync(dev_from, nullptr);
128}
129
130void LocalSession::FreeHandle(void* handle, int type_code) {
131 TVMValue value;
132 value.v_handle = handle;
133 // will trigger deleter once the rv goes out of the scope.
134 TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code);
135}
136
137DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) {
138 return DeviceAPI::Get(dev, allow_missing);
139}
140
141TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() {
142 return CreateRPCSessionModule(std::make_shared<LocalSession>());
143});
144
145} // namespace runtime
146} // namespace tvm
147