1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_device_api.cc |
22 | */ |
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

#include <memory>
#include <utility>

#include "rpc_session.h"
30 | |
31 | namespace tvm { |
32 | namespace runtime { |
33 | |
34 | class RPCDeviceAPI final : public DeviceAPI { |
35 | public: |
36 | void SetDevice(Device dev) final { |
37 | auto remote_dev = RemoveRPCSessionMask(dev); |
38 | GetSess(dev)->GetDeviceAPI(remote_dev)->SetDevice(remote_dev); |
39 | } |
40 | |
41 | void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { |
42 | auto remote_dev = RemoveRPCSessionMask(dev); |
43 | GetSess(dev)->GetDeviceAPI(remote_dev)->GetAttr(remote_dev, kind, rv); |
44 | } |
45 | |
46 | void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, |
47 | Optional<String> mem_scope) final { |
48 | auto sess = GetSess(dev); |
49 | auto remote_dev = RemoveRPCSessionMask(dev); |
50 | void* data = |
51 | sess->GetDeviceAPI(remote_dev)->AllocDataSpace(remote_dev, ndim, shape, dtype, mem_scope); |
52 | RemoteSpace* space = new RemoteSpace(); |
53 | space->data = data; |
54 | space->sess = std::move(sess); |
55 | return space; |
56 | } |
57 | |
58 | void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { |
59 | auto sess = GetSess(dev); |
60 | auto remote_dev = RemoveRPCSessionMask(dev); |
61 | void* data = |
62 | sess->GetDeviceAPI(remote_dev)->AllocDataSpace(remote_dev, nbytes, alignment, type_hint); |
63 | |
64 | RemoteSpace* space = new RemoteSpace(); |
65 | space->data = data; |
66 | space->sess = std::move(sess); |
67 | return space; |
68 | } |
69 | void FreeDataSpace(Device dev, void* ptr) final { |
70 | RemoteSpace* space = static_cast<RemoteSpace*>(ptr); |
71 | auto remote_dev = RemoveRPCSessionMask(dev); |
72 | try { |
73 | GetSess(dev)->GetDeviceAPI(remote_dev)->FreeDataSpace(remote_dev, space->data); |
74 | } catch (const Error& e) { |
75 | // fault tolerance to remote close. |
76 | } |
77 | delete space; |
78 | } |
79 | |
80 | void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final { |
81 | DLDevice dev_from = from->device; |
82 | DLDevice dev_to = to->device; |
83 | if (IsRPCSessionDevice(dev_from) && IsRPCSessionDevice(dev_to)) { |
84 | ICHECK(dev_from.device_type == dev_to.device_type) |
85 | << "Cannot copy across two different remote session" ; |
86 | DLTensor from_tensor = *from; |
87 | from_tensor.device = RemoveRPCSessionMask(dev_from); |
88 | from_tensor.data = static_cast<const RemoteSpace*>(from->data)->data; |
89 | DLTensor to_tensor = *to; |
90 | to_tensor.device = RemoveRPCSessionMask(dev_to); |
91 | to_tensor.data = static_cast<const RemoteSpace*>(to->data)->data; |
92 | auto remote_dev = from_tensor.device; |
93 | if (remote_dev.device_type == kDLCPU) remote_dev = to_tensor.device; |
94 | GetSess(dev_from)->GetDeviceAPI(remote_dev)->CopyDataFromTo(&from_tensor, &to_tensor, stream); |
95 | } else if (IsRPCSessionDevice(dev_from) && dev_to.device_type == kDLCPU) { |
96 | DLTensor from_tensor = *from; |
97 | from_tensor.device = RemoveRPCSessionMask(dev_from); |
98 | from_tensor.data = static_cast<const RemoteSpace*>(from->data)->data; |
99 | void* to_bytes = static_cast<char*>(to->data) + to->byte_offset; |
100 | size_t nbytes = GetDataSize(*to); |
101 | GetSess(dev_from)->CopyFromRemote(&from_tensor, to_bytes, nbytes); |
102 | } else if (dev_from.device_type == kDLCPU && IsRPCSessionDevice(dev_to)) { |
103 | DLTensor to_tensor = *to; |
104 | to_tensor.device = RemoveRPCSessionMask(dev_to); |
105 | to_tensor.data = static_cast<const RemoteSpace*>(to->data)->data; |
106 | void* from_bytes = static_cast<char*>(from->data) + from->byte_offset; |
107 | size_t nbytes = GetDataSize(*from); |
108 | GetSess(dev_to)->CopyToRemote(from_bytes, &to_tensor, nbytes); |
109 | } else { |
110 | LOG(FATAL) << "expect copy from/to remote or between remote" ; |
111 | } |
112 | } |
113 | |
114 | TVMStreamHandle CreateStream(Device dev) { |
115 | auto remote_dev = RemoveRPCSessionMask(dev); |
116 | return GetSess(dev)->GetDeviceAPI(remote_dev)->CreateStream(remote_dev); |
117 | } |
118 | |
119 | void FreeStream(Device dev, TVMStreamHandle stream) { |
120 | auto remote_dev = RemoveRPCSessionMask(dev); |
121 | GetSess(dev)->GetDeviceAPI(remote_dev)->FreeStream(remote_dev, stream); |
122 | } |
123 | |
124 | void StreamSync(Device dev, TVMStreamHandle stream) final { |
125 | auto remote_dev = RemoveRPCSessionMask(dev); |
126 | GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream); |
127 | } |
128 | |
129 | void SetStream(Device dev, TVMStreamHandle stream) { |
130 | auto remote_dev = RemoveRPCSessionMask(dev); |
131 | GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream); |
132 | } |
133 | |
134 | protected: |
135 | void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, |
136 | size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, |
137 | TVMStreamHandle stream) final { |
138 | LOG(FATAL) << "Not implemented." ; |
139 | } |
140 | |
141 | private: |
142 | std::shared_ptr<RPCSession> GetSess(Device dev) { |
143 | int tbl_index = GetRPCSessionIndex(dev); |
144 | return RPCSession::Get(tbl_index); |
145 | } |
146 | }; |
147 | |
148 | TVM_REGISTER_GLOBAL("device_api.rpc" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
149 | static RPCDeviceAPI inst; |
150 | DeviceAPI* ptr = &inst; |
151 | *rv = static_cast<void*>(ptr); |
152 | }); |
153 | } // namespace runtime |
154 | } // namespace tvm |
155 | |