1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file Use external cudnn utils function |
22 | */ |
23 | |
24 | #include "cudnn_utils.h" |
25 | |
26 | #include <dmlc/thread_local.h> |
27 | #include <tvm/runtime/data_type.h> |
28 | #include <tvm/runtime/registry.h> |
29 | |
30 | #include <string> |
31 | #include <vector> |
32 | |
33 | namespace tvm { |
34 | namespace contrib { |
35 | |
36 | // CuDNN Data Type |
37 | cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) { |
38 | switch (dtype.code) { |
39 | case kDLInt: |
40 | if (dtype.bits == 8 && dtype.lanes == 1) |
41 | return CUDNN_DATA_INT8; |
42 | else if (dtype.bits == 32 && dtype.lanes == 1) |
43 | return CUDNN_DATA_INT32; |
44 | else if (dtype.bits == 8 && dtype.lanes == 4) |
45 | return CUDNN_DATA_INT8x4; |
46 | else |
47 | LOG(FATAL) << "Unsupported type" ; |
48 | break; |
49 | case kDLUInt: |
50 | LOG(FATAL) << "Unsupported type" ; |
51 | break; |
52 | case kDLFloat: |
53 | if (dtype.bits == 32 && dtype.lanes == 1) |
54 | return CUDNN_DATA_FLOAT; |
55 | else if (dtype.bits == 64 && dtype.lanes == 1) |
56 | return CUDNN_DATA_DOUBLE; |
57 | else if (dtype.bits == 16 && dtype.lanes == 1) |
58 | return CUDNN_DATA_HALF; |
59 | else |
60 | LOG(FATAL) << "Unsupported type" ; |
61 | break; |
62 | } |
63 | return CUDNN_DATA_FLOAT; |
64 | } |
65 | |
66 | template <> |
67 | const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { |
68 | static const int int_v = 0; |
69 | static const float float_v = 0; |
70 | static const double double_v = 0; |
71 | if (type == CUDNN_DATA_FLOAT || type == CUDNN_DATA_HALF) { |
72 | return static_cast<const void*>(&float_v); |
73 | } |
74 | if (type == CUDNN_DATA_DOUBLE) { |
75 | return static_cast<const void*>(&double_v); |
76 | } |
77 | if (type == CUDNN_DATA_INT8 || type == CUDNN_DATA_INT32 || type == CUDNN_DATA_INT8x4) { |
78 | return static_cast<const void*>(&int_v); |
79 | } |
80 | return nullptr; |
81 | } |
82 | |
83 | template <> |
84 | const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { |
85 | static const int int_v = 1; |
86 | static const float float_v = 1.f; |
87 | static const double double_v = 1.f; |
88 | if (type == CUDNN_DATA_FLOAT || type == CUDNN_DATA_HALF) { |
89 | return static_cast<const void*>(&float_v); |
90 | } |
91 | if (type == CUDNN_DATA_DOUBLE) { |
92 | return static_cast<const void*>(&double_v); |
93 | } |
94 | if (type == CUDNN_DATA_INT8 || type == CUDNN_DATA_INT32 || type == CUDNN_DATA_INT8x4) { |
95 | return static_cast<const void*>(&int_v); |
96 | } |
97 | return nullptr; |
98 | } |
99 | |
100 | // CuDNNThreadEntry |
101 | |
102 | CuDNNThreadEntry::CuDNNThreadEntry() { |
103 | auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; |
104 | auto func = runtime::Registry::Get("device_api.cuda" ); |
105 | void* ret = (*func)(); |
106 | cuda_api = static_cast<runtime::DeviceAPI*>(ret); |
107 | |
108 | // If no CuDNN-capable device is present, allow the CuDNNThreadEntry |
109 | // object to be created. This is needed for |
110 | // CuDNNThreadEntry::exists. |
111 | { |
112 | cudnnStatus_t create_res = cudnnCreate(&handle); |
113 | if (create_res == CUDNN_STATUS_NOT_INITIALIZED) { |
114 | return; |
115 | } |
116 | CUDNN_CALL(create_res); |
117 | } |
118 | |
119 | CUDNN_CALL(cudnnSetStream(handle, stream)); |
120 | conv_entry.cuda_api = cuda_api; |
121 | } |
122 | |
123 | CuDNNThreadEntry::~CuDNNThreadEntry() {} |
124 | |
125 | typedef dmlc::ThreadLocalStore<CuDNNThreadEntry> CuDNNThreadStore; |
126 | |
127 | CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(bool check_exists) { |
128 | auto* res = CuDNNThreadStore::Get(); |
129 | if (check_exists) { |
130 | ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED" ; |
131 | } |
132 | |
133 | return res; |
134 | } |
135 | |
// ConvEntry

// Allocate all cuDNN descriptors used by a convolution call up front so the
// per-call setup code only has to fill them in. Each Create has a matching
// Destroy in ~ConvEntry.
ConvEntry::ConvEntry() {
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc));
  CUDNN_CALL(cudnnCreateActivationDescriptor(&activation_desc));
}
146 | |
// Release every descriptor created in the constructor, then free any
// workspace buffer still held. NOTE(review): CUDNN_CALL aborts on failure,
// so a failing Destroy terminates from a destructor — presumably acceptable
// here since these only fail on programmer error.
ConvEntry::~ConvEntry() {
  CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc));
  CUDNN_CALL(cudnnDestroyActivationDescriptor(activation_desc));
  CleanWorkspace();
}
156 | |
157 | void ConvEntry::UpdateWorkspace(const size_t wsize) { |
158 | if (workspace_size < wsize) { |
159 | if (workspace != nullptr) { |
160 | CleanWorkspace(); |
161 | } |
162 | workspace_size = wsize; |
163 | workspace = cuda_api->AllocWorkspace(device, workspace_size); |
164 | } |
165 | } |
166 | |
167 | void ConvEntry::CleanWorkspace() { |
168 | if (workspace) cuda_api->FreeWorkspace(device, workspace); |
169 | workspace_size = 0; |
170 | } |
171 | |
172 | void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups, |
173 | const int pad[], const int stride[], const int dilation[], int64_t x_dim[], |
174 | int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype, |
175 | const std::string& conv_dtype) { |
176 | // Set Format |
177 | entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format); |
178 | // Set Data Type |
179 | entry_ptr->conv_entry.data_type = |
180 | CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype)); |
181 | |
182 | cudnnDataType_t cudnn_data_type = CuDNNDataType::DLTypeToCuDNNType(data_dtype); |
183 | |
184 | // Dims includes N and C |
185 | int full_dims = dims + 2; |
186 | |
187 | std::vector<int> dim(full_dims); |
188 | std::vector<int> tensor_stride(full_dims); |
189 | |
190 | // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error |
191 | // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int |
192 | |
193 | CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); |
194 | if (dims == 2) { |
195 | // Set Desc |
196 | CUDNN_CALL(cudnnSetConvolution2dDescriptor( |
197 | entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], |
198 | dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); |
199 | int ni, ci, hi, wi; |
200 | if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { |
201 | ni = 0; |
202 | ci = 3; |
203 | hi = 1; |
204 | wi = 2; |
205 | } else { |
206 | ni = 0; |
207 | ci = 1; |
208 | hi = 2; |
209 | wi = 3; |
210 | } |
211 | |
212 | // Set Input |
213 | CUDNN_CALL(cudnnSetTensor4dDescriptor( |
214 | entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type, |
215 | static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]), |
216 | static_cast<int>(x_dim[wi]))); |
217 | // Set Filter |
218 | CUDNN_CALL(cudnnSetFilter4dDescriptor( |
219 | entry_ptr->conv_entry.filter_desc, cudnn_data_type, entry_ptr->conv_entry.tensor_format, |
220 | static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]), |
221 | static_cast<int>(w_dim[wi]))); |
222 | // Set Output |
223 | CUDNN_CALL(cudnnSetTensor4dDescriptor( |
224 | entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type, |
225 | static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]), |
226 | static_cast<int>(y_dim[wi]))); |
227 | } else { |
228 | ICHECK_EQ(format, 0) << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors." ; |
229 | |
230 | CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, |
231 | dilation, entry_ptr->conv_entry.mode, |
232 | entry_ptr->conv_entry.data_type)); |
233 | |
234 | // Set Filter |
235 | for (int i = 0; i < full_dims; i++) { |
236 | dim[i] = static_cast<int>(w_dim[i]); |
237 | } |
238 | CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, cudnn_data_type, |
239 | entry_ptr->conv_entry.tensor_format, full_dims, |
240 | dim.data())); |
241 | // Set Input |
242 | for (int i = 0; i < full_dims; i++) { |
243 | dim[i] = static_cast<int>(x_dim[i]); |
244 | } |
245 | GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); |
246 | CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, cudnn_data_type, |
247 | full_dims, dim.data(), tensor_stride.data())); |
248 | // Set Output |
249 | for (int i = 0; i < full_dims; i++) { |
250 | dim[i] = static_cast<int>(y_dim[i]); |
251 | } |
252 | GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); |
253 | CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, cudnn_data_type, |
254 | full_dims, dim.data(), tensor_stride.data())); |
255 | } |
256 | |
257 | if (cudnnGetVersion() > 7000) { |
258 | CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) |
259 | } |
260 | } |
261 | |
262 | // SoftmaxEntry |
263 | |
264 | SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); } |
265 | |
266 | SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } |
267 | |
// Expose a packed function reporting whether a usable cuDNN handle could be
// created on the calling thread (i.e. a cuDNN-capable device is present).
// Passes check_exists=false so the query itself never aborts.
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool {
  return CuDNNThreadEntry::ThreadLocal(false)->exists();
});
271 | |
272 | } // namespace contrib |
273 | } // namespace tvm |
274 | |