/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cudnn_utils.cc
 * \brief Utility functions shared by the external cuDNN operator wrappers.
 */

#include "cudnn_utils.h"

#include <dmlc/thread_local.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/registry.h>

#include <string>
#include <vector>

namespace tvm {
namespace contrib {

// CuDNN Data Type
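// Maps a DLPack data type to the corresponding cuDNN enum value. Only the
// combinations cuDNN supports are accepted; anything else aborts with
// LOG(FATAL), so the trailing return exists only to silence missing-return
// warnings.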
cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) {
  switch (dtype.code) {
    case kDLInt:
      if (dtype.bits == 8 && dtype.lanes == 1)
        return CUDNN_DATA_INT8;
      else if (dtype.bits == 32 && dtype.lanes == 1)
        return CUDNN_DATA_INT32;
      else if (dtype.bits == 8 && dtype.lanes == 4)
        return CUDNN_DATA_INT8x4;
      else
        LOG(FATAL) << "Unsupported type " << runtime::DLDataType2String(dtype);
      break;
    case kDLUInt:
      LOG(FATAL) << "Unsupported type " << runtime::DLDataType2String(dtype);
      break;
    case kDLFloat:
      if (dtype.bits == 32 && dtype.lanes == 1)
        return CUDNN_DATA_FLOAT;
      else if (dtype.bits == 64 && dtype.lanes == 1)
        return CUDNN_DATA_DOUBLE;
      else if (dtype.bits == 16 && dtype.lanes == 1)
        return CUDNN_DATA_HALF;
      else
        LOG(FATAL) << "Unsupported type " << runtime::DLDataType2String(dtype);
      break;
    default:
      LOG(FATAL) << "Unsupported type code " << static_cast<int>(dtype.code);
      break;
  }
  return CUDNN_DATA_FLOAT;  // unreachable; silences missing-return warnings
}

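// GetConst<0> / GetConst<1> return pointers to statically allocated zero and
// one constants whose width matches the given cuDNN data type. cuDNN's
// alpha/beta scaling parameters must point to a float for HALF/FLOAT tensors
// and to a double for DOUBLE tensors, which is why HALF shares the float
// constant. A typical call site looks roughly like this (sketch):
//   const void* alpha = CuDNNDataType::GetConst<1>(data_type);
//   const void* beta = CuDNNDataType::GetConst<0>(data_type);
//   cudnnConvolutionForward(handle, alpha, /* ... */, beta, /* ... */);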
template <>
const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) {
  static const int int_v = 0;
  static const float float_v = 0;
  static const double double_v = 0;
  if (type == CUDNN_DATA_FLOAT || type == CUDNN_DATA_HALF) {
    return static_cast<const void*>(&float_v);
  }
  if (type == CUDNN_DATA_DOUBLE) {
    return static_cast<const void*>(&double_v);
  }
  if (type == CUDNN_DATA_INT8 || type == CUDNN_DATA_INT32 || type == CUDNN_DATA_INT8x4) {
    return static_cast<const void*>(&int_v);
  }
  return nullptr;
}

template <>
const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
  static const int int_v = 1;
  static const float float_v = 1.f;
  static const double double_v = 1.0;
  if (type == CUDNN_DATA_FLOAT || type == CUDNN_DATA_HALF) {
    return static_cast<const void*>(&float_v);
  }
  if (type == CUDNN_DATA_DOUBLE) {
    return static_cast<const void*>(&double_v);
  }
  if (type == CUDNN_DATA_INT8 || type == CUDNN_DATA_INT32 || type == CUDNN_DATA_INT8x4) {
    return static_cast<const void*>(&int_v);
  }
  return nullptr;
}

// CuDNNThreadEntry

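// One CuDNNThreadEntry is created lazily per thread (see ThreadLocal below);
// it owns a cudnnHandle_t bound to that thread's CUDA stream, since a single
// cuDNN handle must not be used from multiple threads concurrently.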
CuDNNThreadEntry::CuDNNThreadEntry() {
  auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
  auto func = runtime::Registry::Get("device_api.cuda");
  ICHECK(func != nullptr) << "Cannot find the CUDA device API in the registry";
  void* ret = (*func)();
  cuda_api = static_cast<runtime::DeviceAPI*>(ret);

  // If no cuDNN-capable device is present, still allow the CuDNNThreadEntry
  // object to be created so that CuDNNThreadEntry::exists() can report the
  // failure instead of crashing.
  {
    cudnnStatus_t create_res = cudnnCreate(&handle);
    if (create_res == CUDNN_STATUS_NOT_INITIALIZED) {
      return;
    }
    CUDNN_CALL(create_res);
  }

  CUDNN_CALL(cudnnSetStream(handle, stream));
  conv_entry.cuda_api = cuda_api;
}
122
123CuDNNThreadEntry::~CuDNNThreadEntry() {}
124
125typedef dmlc::ThreadLocalStore<CuDNNThreadEntry> CuDNNThreadStore;
126
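// Returns the thread-local entry. When check_exists is true, a missing cuDNN
// installation is a hard error; pass false to merely probe availability, as
// tvm.contrib.cudnn.exists does below.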
CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(bool check_exists) {
  auto* res = CuDNNThreadStore::Get();
  if (check_exists) {
    ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED";
  }

  return res;
}

// ConvEntry

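// ConvEntry bundles every descriptor a convolution call needs, so they are
// created once per thread and reused across invocations instead of being
// re-created on every call.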
ConvEntry::ConvEntry() {
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc));
  CUDNN_CALL(cudnnCreateActivationDescriptor(&activation_desc));
}

ConvEntry::~ConvEntry() {
  CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc));
  CUDNN_CALL(cudnnDestroyActivationDescriptor(activation_desc));
  CleanWorkspace();
}

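// The workspace buffer only ever grows: UpdateWorkspace reallocates when a
// larger request arrives and is a no-op otherwise, which amortizes allocation
// cost across calls with varying workspace requirements.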
void ConvEntry::UpdateWorkspace(const size_t wsize) {
  if (workspace_size < wsize) {
    if (workspace != nullptr) {
      CleanWorkspace();
    }
    workspace_size = wsize;
    workspace = cuda_api->AllocWorkspace(device, workspace_size);
  }
}

void ConvEntry::CleanWorkspace() {
  if (workspace) cuda_api->FreeWorkspace(device, workspace);
  workspace = nullptr;  // avoid a dangling pointer and a potential double free
  workspace_size = 0;
}

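// Fills in the convolution, filter, input, and output descriptors for an N-D
// convolution. `dims` counts only the spatial dimensions; the shapes in
// x_dim/w_dim/y_dim additionally carry N and C, hence full_dims = dims + 2.
// 2-D convolutions go through the 4-D setters because the N-D setters are not
// supported for all data types (see the note inside).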
void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
                        const int pad[], const int stride[], const int dilation[], int64_t x_dim[],
                        int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype,
                        const std::string& conv_dtype) {
  // Set Format
  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // Set Data Type
  entry_ptr->conv_entry.data_type =
      CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype));

  cudnnDataType_t cudnn_data_type = CuDNNDataType::DLTypeToCuDNNType(data_dtype);

  // Dims includes N and C
  int full_dims = dims + 2;

  std::vector<int> dim(full_dims);
  std::vector<int> tensor_stride(full_dims);

  // Note: for 2-D tensors the 4-D setters below must be used; the N-D setters
  // cause a CUDNN_STATUS_NOT_SUPPORTED error in the subsequent
  // cudnnGetConvolutionForwardWorkspaceSize() call when the data type is fp16 or int8.

  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
  if (dims == 2) {
    // Set Desc
    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
    // Indices of N/C/H/W within the shape arrays depend on the layout.
    int ni, ci, hi, wi;
    if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
      ni = 0;
      ci = 3;
      hi = 1;
      wi = 2;
    } else {
      ni = 0;
      ci = 1;
      hi = 2;
      wi = 3;
    }

    // Set Input
    CUDNN_CALL(cudnnSetTensor4dDescriptor(
        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type,
        static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
        static_cast<int>(x_dim[wi])));
    // Set Filter
    CUDNN_CALL(cudnnSetFilter4dDescriptor(
        entry_ptr->conv_entry.filter_desc, cudnn_data_type, entry_ptr->conv_entry.tensor_format,
        static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
        static_cast<int>(w_dim[wi])));
    // Set Output
    CUDNN_CALL(cudnnSetTensor4dDescriptor(
        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type,
        static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
        static_cast<int>(y_dim[wi])));
  } else {
    ICHECK_EQ(format, 0) << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors.";

    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
                                               dilation, entry_ptr->conv_entry.mode,
                                               entry_ptr->conv_entry.data_type));

    // Set Filter
    for (int i = 0; i < full_dims; i++) {
      dim[i] = static_cast<int>(w_dim[i]);
    }
    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, cudnn_data_type,
                                          entry_ptr->conv_entry.tensor_format, full_dims,
                                          dim.data()));
    // Set Input
    for (int i = 0; i < full_dims; i++) {
      dim[i] = static_cast<int>(x_dim[i]);
    }
    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, cudnn_data_type,
                                          full_dims, dim.data(), tensor_stride.data()));
    // Set Output
    for (int i = 0; i < full_dims; i++) {
      dim[i] = static_cast<int>(y_dim[i]);
    }
    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, cudnn_data_type,
                                          full_dims, dim.data(), tensor_stride.data()));
  }

  // Enable Tensor Core math where the cuDNN version supports it.
  if (cudnnGetVersion() > 7000) {
    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH));
  }
}

// SoftmaxEntry

SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); }

SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); }

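// Exposes a packed function so frontends can check for a usable cuDNN runtime
// without triggering a fatal error, e.g. from Python (sketch):
//   tvm.get_global_func("tvm.contrib.cudnn.exists")()  # returns True or False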
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool {
  return CuDNNThreadEntry::ThreadLocal(false)->exists();
});

}  // namespace contrib
}  // namespace tvm