1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_device_api.cc
22 */
23#include <tvm/runtime/device_api.h>
24#include <tvm/runtime/logging.h>
25#include <tvm/runtime/registry.h>
26
27#include <utility>
28
29#include "rpc_session.h"
30
31namespace tvm {
32namespace runtime {
33
34class RPCDeviceAPI final : public DeviceAPI {
35 public:
36 void SetDevice(Device dev) final {
37 auto remote_dev = RemoveRPCSessionMask(dev);
38 GetSess(dev)->GetDeviceAPI(remote_dev)->SetDevice(remote_dev);
39 }
40
41 void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final {
42 auto remote_dev = RemoveRPCSessionMask(dev);
43 GetSess(dev)->GetDeviceAPI(remote_dev)->GetAttr(remote_dev, kind, rv);
44 }
45
46 void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
47 Optional<String> mem_scope) final {
48 auto sess = GetSess(dev);
49 auto remote_dev = RemoveRPCSessionMask(dev);
50 void* data =
51 sess->GetDeviceAPI(remote_dev)->AllocDataSpace(remote_dev, ndim, shape, dtype, mem_scope);
52 RemoteSpace* space = new RemoteSpace();
53 space->data = data;
54 space->sess = std::move(sess);
55 return space;
56 }
57
58 void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
59 auto sess = GetSess(dev);
60 auto remote_dev = RemoveRPCSessionMask(dev);
61 void* data =
62 sess->GetDeviceAPI(remote_dev)->AllocDataSpace(remote_dev, nbytes, alignment, type_hint);
63
64 RemoteSpace* space = new RemoteSpace();
65 space->data = data;
66 space->sess = std::move(sess);
67 return space;
68 }
69 void FreeDataSpace(Device dev, void* ptr) final {
70 RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
71 auto remote_dev = RemoveRPCSessionMask(dev);
72 try {
73 GetSess(dev)->GetDeviceAPI(remote_dev)->FreeDataSpace(remote_dev, space->data);
74 } catch (const Error& e) {
75 // fault tolerance to remote close.
76 }
77 delete space;
78 }
79
80 void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final {
81 DLDevice dev_from = from->device;
82 DLDevice dev_to = to->device;
83 if (IsRPCSessionDevice(dev_from) && IsRPCSessionDevice(dev_to)) {
84 ICHECK(dev_from.device_type == dev_to.device_type)
85 << "Cannot copy across two different remote session";
86 DLTensor from_tensor = *from;
87 from_tensor.device = RemoveRPCSessionMask(dev_from);
88 from_tensor.data = static_cast<const RemoteSpace*>(from->data)->data;
89 DLTensor to_tensor = *to;
90 to_tensor.device = RemoveRPCSessionMask(dev_to);
91 to_tensor.data = static_cast<const RemoteSpace*>(to->data)->data;
92 auto remote_dev = from_tensor.device;
93 if (remote_dev.device_type == kDLCPU) remote_dev = to_tensor.device;
94 GetSess(dev_from)->GetDeviceAPI(remote_dev)->CopyDataFromTo(&from_tensor, &to_tensor, stream);
95 } else if (IsRPCSessionDevice(dev_from) && dev_to.device_type == kDLCPU) {
96 DLTensor from_tensor = *from;
97 from_tensor.device = RemoveRPCSessionMask(dev_from);
98 from_tensor.data = static_cast<const RemoteSpace*>(from->data)->data;
99 void* to_bytes = static_cast<char*>(to->data) + to->byte_offset;
100 size_t nbytes = GetDataSize(*to);
101 GetSess(dev_from)->CopyFromRemote(&from_tensor, to_bytes, nbytes);
102 } else if (dev_from.device_type == kDLCPU && IsRPCSessionDevice(dev_to)) {
103 DLTensor to_tensor = *to;
104 to_tensor.device = RemoveRPCSessionMask(dev_to);
105 to_tensor.data = static_cast<const RemoteSpace*>(to->data)->data;
106 void* from_bytes = static_cast<char*>(from->data) + from->byte_offset;
107 size_t nbytes = GetDataSize(*from);
108 GetSess(dev_to)->CopyToRemote(from_bytes, &to_tensor, nbytes);
109 } else {
110 LOG(FATAL) << "expect copy from/to remote or between remote";
111 }
112 }
113
114 TVMStreamHandle CreateStream(Device dev) {
115 auto remote_dev = RemoveRPCSessionMask(dev);
116 return GetSess(dev)->GetDeviceAPI(remote_dev)->CreateStream(remote_dev);
117 }
118
119 void FreeStream(Device dev, TVMStreamHandle stream) {
120 auto remote_dev = RemoveRPCSessionMask(dev);
121 GetSess(dev)->GetDeviceAPI(remote_dev)->FreeStream(remote_dev, stream);
122 }
123
124 void StreamSync(Device dev, TVMStreamHandle stream) final {
125 auto remote_dev = RemoveRPCSessionMask(dev);
126 GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream);
127 }
128
129 void SetStream(Device dev, TVMStreamHandle stream) {
130 auto remote_dev = RemoveRPCSessionMask(dev);
131 GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream);
132 }
133
134 protected:
135 void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
136 size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint,
137 TVMStreamHandle stream) final {
138 LOG(FATAL) << "Not implemented.";
139 }
140
141 private:
142 std::shared_ptr<RPCSession> GetSess(Device dev) {
143 int tbl_index = GetRPCSessionIndex(dev);
144 return RPCSession::Get(tbl_index);
145 }
146};
147
148TVM_REGISTER_GLOBAL("device_api.rpc").set_body([](TVMArgs args, TVMRetValue* rv) {
149 static RPCDeviceAPI inst;
150 DeviceAPI* ptr = &inst;
151 *rv = static_cast<void*>(ptr);
152});
153} // namespace runtime
154} // namespace tvm
155