1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file cuda_common.h
22 * \brief Common utilities for CUDA
23 */
24#ifndef TVM_RUNTIME_CUDA_CUDA_COMMON_H_
25#define TVM_RUNTIME_CUDA_CUDA_COMMON_H_
26
27#include <cuda_runtime.h>
28#include <tvm/runtime/packed_func.h>
29
30#include <string>
31
32#include "../workspace_pool.h"
33
34namespace tvm {
35namespace runtime {
36
/*!
 * \brief Evaluate a CUDA driver API expression and report a failure through LOG(FATAL).
 *
 * The result is treated as a failure unless it is CUDA_SUCCESS or
 * CUDA_ERROR_DEINITIALIZED.
 * NOTE(review): CUDA_ERROR_DEINITIALIZED is presumably tolerated so that calls
 * made while the driver is shutting down do not raise - confirm.
 *
 * The error name pointer starts out as nullptr and is checked before it is
 * streamed, because cuGetErrorName stores NULL for codes it does not recognize
 * and inserting a null `const char*` into a stream is undefined behavior.
 *
 * Expands to a braced block that declares a local named `result`.
 *
 * \param x Expression that yields a CUresult.
 */
#define CUDA_DRIVER_CALL(x)                                                \
  {                                                                        \
    CUresult result = (x);                                                 \
    if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {    \
      const char* msg = nullptr;                                           \
      cuGetErrorName(result, &msg);                                        \
      LOG(FATAL) << "CUDAError: " #x " failed with error: "                \
                 << (msg != nullptr ? msg : "unrecognized CUDA error code"); \
    }                                                                      \
  }
46
/*!
 * \brief Evaluate a CUDA runtime API expression and check its status with ICHECK.
 *
 * The check passes when the status is cudaSuccess or cudaErrorCudartUnloading;
 * for any other status the text from cudaGetErrorString is appended to the
 * check message.
 * NOTE(review): cudaErrorCudartUnloading is presumably accepted so that calls
 * issued during runtime teardown do not raise - confirm.
 *
 * Expands to a braced block that declares a local named `e`.
 *
 * \param func Expression that yields a cudaError_t.
 */
#define CUDA_CALL(func)                                        \
  {                                                            \
    const cudaError_t e = (func);                              \
    ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)  \
        << "CUDA: " << cudaGetErrorString(e);                  \
  }
53
54/*! \brief Thread local workspace */
55class CUDAThreadEntry {
56 public:
57 /*! \brief The cuda stream */
58 cudaStream_t stream{nullptr};
59 /*! \brief thread local pool*/
60 WorkspacePool pool;
61 /*! \brief constructor */
62 CUDAThreadEntry();
63 // get the threadlocal workspace
64 static CUDAThreadEntry* ThreadLocal();
65};
66} // namespace runtime
67} // namespace tvm
68#endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_
69