/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda_device_api.cc
 * \brief GPU-specific device API
 */
#include <cuda.h>
#include <cuda_runtime.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>

#include <cstring>
#include <sstream>
#include <string>

#include "cuda_common.h"

namespace tvm {
namespace runtime {

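/*!
 * \brief Device API for CUDA GPUs, implemented on top of the CUDA runtime
 *  (the device name is the one exception, queried through the driver API).
 *  Callers normally reach this class through DeviceAPI::Get or the
 *  "device_api.cuda" registry entry rather than instantiating it directly.
 */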
class CUDADeviceAPI final : public DeviceAPI {
 public:
  void SetDevice(Device dev) final { CUDA_CALL(cudaSetDevice(dev.device_id)); }
  void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final {
    int value = 0;
    switch (kind) {
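      // Probe device existence by attempting an arbitrary attribute query: a
      // cudaSuccess result means the device is present and usable. The probe
      // deliberately reuses `value` as scratch space before overwriting it
      // with the boolean result.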
      case kExist:
        value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id) ==
                 cudaSuccess);
        break;
      case kMaxThreadsPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id));
        break;
      }
      case kWarpSize: {
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, dev.device_id));
        break;
      }
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(
            cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, dev.device_id));
        break;
      }
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMajor, dev.device_id));
        os << value << ".";
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMinor, dev.device_id));
        os << value;
        *rv = os.str();
        return;
      }
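      // The device name comes from the driver API; the string is
      // over-allocated and then trimmed back to the embedded NUL terminator.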
      case kDeviceName: {
        std::string name(256, 0);
        CUDA_DRIVER_CALL(cuDeviceGetName(&name[0], name.size(), dev.device_id));
        name.resize(strlen(name.c_str()));
        *rv = std::move(name);
        return;
      }
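      // cudaDevAttrClockRate reports the peak clock frequency in kilohertz.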
      case kMaxClockRate: {
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, dev.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMultiProcessorCount, dev.device_id));
        break;
      }
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, dev.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, dev.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, dev.device_id));

        // Use a JSON-style string to return multiple integer values.
        std::stringstream ss;
        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
      case kMaxRegistersPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, dev.device_id));
        break;
      }
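      // kGcnArch only applies to AMD GCN targets (ROCm); leave the return
      // value unset for CUDA devices.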
      case kGcnArch:
        return;
      case kApiVersion: {
        *rv = CUDA_VERSION;
        return;
      }
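      // The driver version is likewise not reported through this API.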
      case kDriverVersion:
        return;
    }
    *rv = value;
  }
  void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
    ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
    void* ret;
    if (dev.device_type == kDLCUDAHost) {
      VLOG(1) << "allocating " << nbytes << " bytes on host";
      CUDA_CALL(cudaMallocHost(&ret, nbytes));
    } else {
      CUDA_CALL(cudaSetDevice(dev.device_id));
      size_t free_mem, total_mem;
      CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
      VLOG(1) << "allocating " << nbytes << " bytes on device, with " << free_mem
              << " bytes currently free out of " << total_mem << " bytes available";
      CUDA_CALL(cudaMalloc(&ret, nbytes));
    }
    return ret;
  }

  void FreeDataSpace(Device dev, void* ptr) final {
    if (dev.device_type == kDLCUDAHost) {
      VLOG(1) << "freeing host memory";
      CUDA_CALL(cudaFreeHost(ptr));
    } else {
      CUDA_CALL(cudaSetDevice(dev.device_id));
      VLOG(1) << "freeing device memory";
      CUDA_CALL(cudaFree(ptr));
    }
  }

 protected:
  void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
                      Device dev_from, Device dev_to, DLDataType type_hint,
                      TVMStreamHandle stream) final {
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;

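    // Pinned host memory (kDLCUDAHost) is directly addressable from the CPU,
    // so fold it into the plain-CPU cases when routing the copy.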
    if (dev_from.device_type == kDLCUDAHost) {
      dev_from.device_type = kDLCPU;
    }

    if (dev_to.device_type == kDLCUDAHost) {
      dev_to.device_type = kDLCPU;
    }

    // In case there is a copy from host memory to host memory.
    if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU) {
      memcpy(to, from, size);
      return;
    }

    if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCUDA) {
      CUDA_CALL(cudaSetDevice(dev_from.device_id));
      if (dev_from.device_id == dev_to.device_id) {
        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
      } else {
        CUDA_CALL(
            cudaMemcpyPeerAsync(to, dev_to.device_id, from, dev_from.device_id, size, cu_stream));
      }
    } else if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCPU) {
      CUDA_CALL(cudaSetDevice(dev_from.device_id));
      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
    } else if (dev_from.device_type == kDLCPU && dev_to.device_type == kDLCUDA) {
      CUDA_CALL(cudaSetDevice(dev_to.device_id));
      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPUs";
    }
  }

 public:
  TVMStreamHandle CreateStream(Device dev) {
    CUDA_CALL(cudaSetDevice(dev.device_id));
    cudaStream_t retval;
    CUDA_CALL(cudaStreamCreate(&retval));
    return static_cast<TVMStreamHandle>(retval);
  }

  void FreeStream(Device dev, TVMStreamHandle stream) {
    CUDA_CALL(cudaSetDevice(dev.device_id));
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    CUDA_CALL(cudaStreamDestroy(cu_stream));
  }

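  // Make `event_dst` wait for all work already queued on `event_src` by
  // recording an event on the source stream and having the destination
  // stream wait on it; the host is never blocked.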
  void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
    CUDA_CALL(cudaSetDevice(dev.device_id));
    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
    cudaEvent_t evt;
    CUDA_CALL(cudaEventCreate(&evt));
    CUDA_CALL(cudaEventRecord(evt, src_stream));
    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(cudaEventDestroy(evt));
  }

  void StreamSync(Device dev, TVMStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(dev.device_id));
    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  }

  void SetStream(Device dev, TVMStreamHandle stream) final {
    CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
  }

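  // Workspace allocations are served from a per-thread memory pool so that
  // short-lived scratch buffers avoid repeated cudaMalloc/cudaFree calls.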
  void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
  }

  void FreeWorkspace(Device dev, void* data) final {
    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
  }

  static CUDADeviceAPI* Global() {
    // NOTE: explicitly use new to avoid exit-time destruction of global state
    // Global state will be recycled by OS as the process exits.
    static auto* inst = new CUDADeviceAPI();
    return inst;
  }

 private:
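  // Dispatch to the asynchronous copy when a stream is supplied; otherwise
  // fall back to the synchronous cudaMemcpy, which blocks until completion.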
  static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind,
                      cudaStream_t stream) {
    if (stream != nullptr) {
      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
    } else {
      CUDA_CALL(cudaMemcpy(to, from, size, kind));
    }
  }
};

typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;

CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {}

CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
  DeviceAPI* ptr = CUDADeviceAPI::Global();
  *rv = static_cast<void*>(ptr);
});

TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body([](TVMArgs args, TVMRetValue* rv) {
  DeviceAPI* ptr = CUDADeviceAPI::Global();
  *rv = static_cast<void*>(ptr);
});

class CUDATimerNode : public TimerNode {
 public:
  virtual void Start() {
    // This initial cudaEventRecord is sometimes pretty slow (~100us). Does
    // cudaEventRecord do some stream synchronization?
    CUDA_CALL(cudaEventRecord(start_, CUDAThreadEntry::ThreadLocal()->stream));
  }
  virtual void Stop() { CUDA_CALL(cudaEventRecord(stop_, CUDAThreadEntry::ThreadLocal()->stream)); }
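  // cudaEventElapsedTime reports a float in milliseconds; synchronize on the
  // stop event first, then convert to integer nanoseconds.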
  virtual int64_t SyncAndGetElapsedNanos() {
    CUDA_CALL(cudaEventSynchronize(stop_));
    float milliseconds = 0;
    CUDA_CALL(cudaEventElapsedTime(&milliseconds, start_, stop_));
    return milliseconds * 1e6;
  }
  virtual ~CUDATimerNode() {
    CUDA_CALL(cudaEventDestroy(start_));
    CUDA_CALL(cudaEventDestroy(stop_));
  }
  CUDATimerNode() {
    CUDA_CALL(cudaEventCreate(&start_));
    CUDA_CALL(cudaEventCreate(&stop_));
  }

  static constexpr const char* _type_key = "CUDATimerNode";
  TVM_DECLARE_FINAL_OBJECT_INFO(CUDATimerNode, TimerNode);

 private:
  cudaEvent_t start_;
  cudaEvent_t stop_;
};

TVM_REGISTER_OBJECT_TYPE(CUDATimerNode);

TVM_REGISTER_GLOBAL("profiling.timer.cuda").set_body_typed([](Device dev) {
  return Timer(make_object<CUDATimerNode>());
});

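// Human-readable snapshot of cudaMemGetInfo for the current device, exposed
// through the "runtime.GetCudaFreeMemory" registry entry below.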
TVM_DLL String GetCudaFreeMemory() {
  size_t free_mem, total_mem;
  CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
  std::stringstream ss;
  ss << "Current CUDA memory is " << free_mem << " bytes free out of " << total_mem
     << " bytes on device";
  return ss.str();
}

TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory);

}  // namespace runtime
}  // namespace tvm