1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/device_api.h
22 * \brief Abstract device memory management API
23 */
24#ifndef TVM_RUNTIME_DEVICE_API_H_
25#define TVM_RUNTIME_DEVICE_API_H_
26
27#include <tvm/runtime/c_runtime_api.h>
28#include <tvm/runtime/ndarray.h>
29#include <tvm/runtime/packed_func.h>
30
31#include <string>
32
33namespace tvm {
34namespace runtime {
/*!
 * \brief Attribute selector handed to DeviceAPI::GetAttr as its `kind` argument.
 *
 * Each enumerator is given an explicit integer value and the underlying type
 * is fixed to int.
 * NOTE(review): the explicit numbering suggests these values are relied upon
 * outside this header - confirm before renumbering or reordering.
 */
enum DeviceAttrKind : int {
  kExist = 0,
  kMaxThreadsPerBlock = 1,
  kWarpSize = 2,
  kMaxSharedMemoryPerBlock = 3,
  kComputeVersion = 4,
  kDeviceName = 5,
  kMaxClockRate = 6,
  kMultiProcessorCount = 7,
  kMaxThreadDimensions = 8,
  kMaxRegistersPerBlock = 9,
  kGcnArch = 10,
  kApiVersion = 11,
  kDriverVersion = 12,
};
53
// Allocation alignment defaults to 64 bytes; a build that defines
// TVM_KALLOC_ALIGNMENT overrides both alignment constants with that value.
#ifndef TVM_KALLOC_ALIGNMENT
/*! \brief Alignment, in bytes, that every allocation must satisfy */
constexpr int kAllocAlignment = 64;

/*! \brief Alignment, in bytes, that every temporary allocation must satisfy */
constexpr int kTempAllocaAlignment = 64;
#else
/*! \brief Alignment, in bytes, that every allocation must satisfy */
constexpr int kAllocAlignment = TVM_KALLOC_ALIGNMENT;

/*! \brief Alignment, in bytes, that every temporary allocation must satisfy */
constexpr int kTempAllocaAlignment = TVM_KALLOC_ALIGNMENT;
#endif  // TVM_KALLOC_ALIGNMENT

/*! \brief Largest size that may be allocated on the stack */
constexpr int kMaxStackAlloca = 1024;

/*! \brief Default alignment, in bytes, of allocations made in the workspace buffer that
 *  services intermediate tensors */
constexpr int kDefaultWorkspaceAlignment = 1;
74
75/*!
76 * \brief TVM Runtime Device API, abstracts the device
77 * specific interface for memory management.
78 */
79class TVM_DLL DeviceAPI {
80 public:
81 /*! \brief virtual destructor */
82 virtual ~DeviceAPI() {}
83 /*!
84 * \brief Set the environment device id to device
85 * \param dev The device to be set.
86 */
87 virtual void SetDevice(Device dev) = 0;
88 /*!
89 * \brief Get attribute of specified device.
90 * \param dev The device device
91 * \param kind The result kind
92 * \param rv The return value.
93 * \sa DeviceAttrKind
94 */
95 virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0;
96
97 /*!
98 * \brief Query the device for specified properties.
99 *
100 * This is used to expand "-from_device=N" in the target string to
101 * all properties that can be determined from that device.
102 */
103 virtual void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {}
104
105 /*!
106 * \brief Allocate a data space on device.
107 * \param dev The device device to perform operation.
108 * \param nbytes The number of bytes in memory.
109 * \param alignment The alignment of the memory.
110 * \param type_hint The type of elements. Only needed by certain backends such
111 * as OpenGL, as nbytes & alignment are sufficient for most backends.
112 * \return The allocated device pointer.
113 */
114 virtual void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment,
115 DLDataType type_hint) = 0;
116 /*!
117 * \brief Allocate a data space on device with memory scope support.
118 * \param dev The device device to perform operation.
119 * \param ndim The number of dimension of allocated tensor.
120 * \param shape The shape of allocated tensor.
121 * \param dtype The type of elements.
122 * \param mem_scope The memory scope of allocated tensor.
123 * \return The allocated device pointer.
124 */
125 virtual void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
126 Optional<String> mem_scope = NullOpt);
127 /*!
128 * \brief Free a data space on device.
129 * \param dev The device device to perform operation.
130 * \param ptr The data space.
131 */
132 virtual void FreeDataSpace(Device dev, void* ptr) = 0;
133 /*!
134 * \brief copy data from one place to another
135 * \note This API is designed to support special memory with shape dependent layout.
136 * We pass in DLTensor* with shape information to support these cases.
137 * \param from The source array.
138 * \param to The target array.
139 * \param stream Optional stream object.
140 */
141 virtual void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream);
142 /*!
143 * \brief Create a new stream of execution.
144 *
145 * \param dev The device of allocation.
146 */
147 virtual TVMStreamHandle CreateStream(Device dev);
148
149 /*!
150 * \brief Free a stream of execution
151 *
152 * \param dev The device of the stream
153 * \param stream The pointer to be freed.
154 */
155 virtual void FreeStream(Device dev, TVMStreamHandle stream);
156
157 /*!
158 * \brief Synchronize the stream
159 * \param dev The device to perform operation.
160 * \param stream The stream to be sync.
161 */
162 virtual void StreamSync(Device dev, TVMStreamHandle stream) = 0;
163 /*!
164 * \brief Set the stream
165 * \param dev The device to set stream.
166 * \param stream The stream to be set.
167 */
168 virtual void SetStream(Device dev, TVMStreamHandle stream) {}
169 /*!
170 * \brief Synchronize 2 streams of execution.
171 *
172 * An event is created in event_src stream that the second then
173 * stream waits on. Neither event_src or event_dst need to be of
174 * the same device ID as the device, but they must be of the same
175 * device type.
176 *
177 * \param dev The device of the streams.
178 * \param event_src The source stream to synchronize.
179 * \param event_dst The destination stream to synchronize.
180 */
181 virtual void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst);
182 /*!
183 * \brief Allocate temporal workspace for backend execution.
184 *
185 * \note We have the following assumption about backend temporal
186 * workspace allocation, and backend will optimize for such assumption:
187 *
188 * - Only a few allocation will happen, and space will be released after use.
189 * - The release order is usually in reverse order of allocate (stack style).
190 * - Repeative pattern of same allocations over different runs.
191 * - Workspace should not overlap between different threads(i.e. be threadlocal)
192 *
193 * \param dev The device of allocation.
194 * \param nbytes The size to be allocated.
195 * \param type_hint The type of elements. Only needed by certain backends such
196 * as OpenGL, as nbytes is sufficient for most backends.
197 */
198 virtual void* AllocWorkspace(Device dev, size_t nbytes, DLDataType type_hint = {});
199 /*!
200 * \brief Free temporal workspace in backend execution.
201 *
202 * \param dev The device of allocation.
203 * \param ptr The pointer to be freed.
204 */
205 virtual void FreeWorkspace(Device dev, void* ptr);
206
207 /*!
208 * \brief Get device API based on device.
209 * \param dev The device
210 * \param allow_missing Whether allow missing
211 * \return The corresponding device API.
212 */
213 static DeviceAPI* Get(Device dev, bool allow_missing = false);
214
215 /*!
216 * \brief Whether a certian device type requires set device device
217 * before launching the kernel function.
218 * \param device_type The device type.
219 */
220 static bool NeedSetDevice(int device_type) {
221 return device_type != kDLCPU && device_type != kDLMicroDev;
222 }
223
224 protected:
225 /*!
226 * \brief copy data from one place to another
227 * \param from The source array.
228 * \param from_offset The byte offeset in the from.
229 * \param to The target array.
230 * \param to_offset The byte offset in the to.
231 * \param num_bytes The size of the memory in bytes
232 * \param dev_from The source device
233 * \param dev_to The target device
234 * \param type_hint The type of elements, only neded by certain backends.
235 * can be useful for cross device endian converison.
236 * \param stream Optional stream object.
237 */
238 virtual void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
239 size_t num_bytes, Device dev_from, Device dev_to,
240 DLDataType type_hint, TVMStreamHandle stream);
241};
242
243/*! \brief The device type bigger than this is RPC device */
244constexpr int kRPCSessMask = 128;
245static_assert(kRPCSessMask >= TVMDeviceExtType_End);
246
247/*!
248 * \brief The name of Device API factory.
249 * \param type The device type.
250 * \return the device name.
251 */
252inline const char* DeviceName(int type) {
253 switch (type) {
254 case kDLCPU:
255 return "cpu";
256 case kDLCUDA:
257 return "cuda";
258 case kDLCUDAHost:
259 return "cuda_host";
260 case kDLCUDAManaged:
261 return "cuda_managed";
262 case kDLOpenCL:
263 return "opencl";
264 case kDLSDAccel:
265 return "sdaccel";
266 case kDLAOCL:
267 return "aocl";
268 case kDLVulkan:
269 return "vulkan";
270 case kDLMetal:
271 return "metal";
272 case kDLVPI:
273 return "vpi";
274 case kDLROCM:
275 return "rocm";
276 case kDLROCMHost:
277 return "rocm_host";
278 case kDLExtDev:
279 return "ext_dev";
280 case kDLOneAPI:
281 return "oneapi";
282 case kDLWebGPU:
283 return "webgpu";
284 case kDLHexagon:
285 return "hexagon";
286 case kOpenGL:
287 return "opengl";
288 case kDLMicroDev:
289 return "microdev";
290 default:
291 LOG(FATAL) << "unknown type =" << type;
292 }
293}
294
295/*!
296 * \brief Return true if a Device is owned by an RPC session.
297 */
298inline bool IsRPCSessionDevice(Device dev) { return (dev.device_type / kRPCSessMask) > 0; }
299
300/*!
301 * \brief Return the RPCSessTable index of the RPC Session that owns this device.
302 * \return the table index.
303 */
304inline int GetRPCSessionIndex(Device dev) {
305 ICHECK(IsRPCSessionDevice(dev)) << "GetRPCSessionIndex: dev has no RPC session";
306 return dev.device_type / kRPCSessMask - 1;
307}
308
309/*!
310 * \brief Remove the RPC session mask from a Device.
311 * RPC clients typically do this when encoding a Device for transmission to an RPC remote.
312 * On the wire, RPCdevice are expected to be valid on the server without interpretation.
313 * \param dev A Device with non-zero RPC Session mask, valid on the RPC client.
314 * \return A Device without any RPC Session mask, valid on the RPC server.
315 */
316inline Device RemoveRPCSessionMask(Device dev) {
317 dev.device_type = static_cast<DLDeviceType>(dev.device_type % kRPCSessMask);
318 return dev;
319}
320
321inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
322 if (tvm::runtime::IsRPCSessionDevice(dev)) {
323 os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
324 dev = tvm::runtime::RemoveRPCSessionMask(dev);
325 }
326 os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
327 return os;
328}
329
330/*!
331 * \brief Add a RPC session mask to a Device.
332 * RPC clients typically do this when decoding a Device received from a RPC remote.
333 * \param dev A Device without any RPC Session mask, valid on the RPC server.
334 * \param session_table_index Numeric index of the RPC session in the session table.
335 * \return A Device with RPC session mask added, valid on the RPC client.
336 */
337inline Device AddRPCSessionMask(Device dev, int session_table_index) {
338 CHECK(!IsRPCSessionDevice(dev)) << "AddRPCSessionMask: dev already non-zero RPCSessionIndex: "
339 << dev;
340 dev.device_type =
341 static_cast<DLDeviceType>(dev.device_type | (kRPCSessMask * (session_table_index + 1)));
342 return dev;
343}
344
345} // namespace runtime
346} // namespace tvm
347
348#endif // TVM_RUNTIME_DEVICE_API_H_
349