/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuDNN kernel calls for the forward algorithm.
 */
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cudnn_utils.h"
29 | namespace tvm { |
30 | namespace contrib { |
31 | |
32 | using namespace runtime; |
33 | |
34 | void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], |
35 | const int stride[], const int dilation[], DLTensor* x, DLTensor* w, |
36 | DLTensor* y, const std::string& conv_dtype) { |
37 | CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); |
38 | // Set Mode |
39 | entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode); |
40 | SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape, |
41 | y->shape, x->dtype, conv_dtype); |
42 | // Set Device |
43 | entry_ptr->conv_entry.device = x->device; |
44 | // Set Algo |
45 | entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo); |
46 | |
47 | // Set workspace |
48 | size_t workspace_size = 0; |
49 | CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( |
50 | entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, |
51 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, |
52 | entry_ptr->conv_entry.fwd_algo, &workspace_size)); |
53 | entry_ptr->conv_entry.UpdateWorkspace(workspace_size); |
54 | CUDNN_CALL(cudnnConvolutionForward( |
55 | entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), |
56 | entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data, |
57 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo, |
58 | entry_ptr->conv_entry.workspace, workspace_size, |
59 | CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), |
60 | entry_ptr->conv_entry.output_desc, y->data)); |
61 | } |
62 | |
63 | void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, int groups, int act, |
64 | double coef, const int pad[], const int stride[], |
65 | const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y, |
66 | DLTensor* bias, const std::string& conv_dtype) { |
67 | CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); |
68 | // Set Mode |
69 | entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode); |
70 | CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc, |
71 | static_cast<cudnnActivationMode_t>(act), |
72 | cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef)); |
73 | CUDNN_CALL(cudnnSetTensor4dDescriptor( |
74 | entry_ptr->conv_entry.bias_desc, entry_ptr->conv_entry.tensor_format, |
75 | CuDNNDataType::DLTypeToCuDNNType(bias->dtype), 1, static_cast<int>(w->shape[0]), 1, 1)); |
76 | |
77 | SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape, |
78 | y->shape, x->dtype, conv_dtype); |
79 | // Set Device |
80 | entry_ptr->conv_entry.device = x->device; |
81 | // Set Algo |
82 | entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo); |
83 | |
84 | // Set workspace |
85 | size_t workspace_size = 0; |
86 | CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( |
87 | entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, |
88 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, |
89 | entry_ptr->conv_entry.fwd_algo, &workspace_size)); |
90 | entry_ptr->conv_entry.UpdateWorkspace(workspace_size); |
91 | CUDNN_CALL(cudnnConvolutionBiasActivationForward( |
92 | entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), |
93 | entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data, |
94 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo, |
95 | entry_ptr->conv_entry.workspace, workspace_size, |
96 | CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), |
97 | entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.bias_desc, bias->data, |
98 | entry_ptr->conv_entry.activation_desc, entry_ptr->conv_entry.output_desc, y->data)); |
99 | } |
100 | |
101 | void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], |
102 | const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], |
103 | const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) { |
104 | CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); |
105 | const int full_dims = dims + 2; |
106 | std::vector<int64_t> x_dim_int64(full_dims); |
107 | std::vector<int64_t> w_dim_int64(full_dims); |
108 | std::vector<int64_t> y_dim_int64(full_dims); |
109 | for (int i = 0; i < full_dims; ++i) { |
110 | x_dim_int64[i] = x_dim[i]; |
111 | w_dim_int64[i] = w_dim[i]; |
112 | y_dim_int64[i] = y_dim[i]; |
113 | } |
114 | SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), |
115 | w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype), |
116 | conv_dtype); |
117 | |
118 | int returned_algo_count = 0; |
119 | cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT]; |
120 | CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm( |
121 | entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, |
122 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, |
123 | CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results)); |
124 | |
125 | const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM" , |
126 | "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM" , |
127 | "CUDNN_CONVOLUTION_FWD_ALGO_GEMM" , |
128 | "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT" , |
129 | "CUDNN_CONVOLUTION_FWD_ALGO_FFT" , |
130 | "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING" , |
131 | "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD" , |
132 | "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED" }; |
133 | |
134 | auto best_algo = perf_results[0].algo; |
135 | LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing " |
136 | << fwd_algo_names[best_algo]; |
137 | for (int i = 0; i < returned_algo_count; ++i) { |
138 | LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo] |
139 | << " - time: " << perf_results[i].time << " ms" |
140 | << ", Memory: " << perf_results[i].memory; |
141 | } |
142 | |
143 | ret[0] = best_algo; |
144 | } |
145 | |
146 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward" ) |
147 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
148 | int mode = args[0]; |
149 | int format = args[1]; |
150 | int algo = args[2]; |
151 | int pad_v[2], stride_v[2], dilation_v[2]; |
152 | for (int i = 0; i < 2; i++) { |
153 | pad_v[i] = args[3 + i]; |
154 | stride_v[i] = args[5 + i]; |
155 | dilation_v[i] = args[7 + i]; |
156 | } |
157 | DLTensor* x = args[9]; |
158 | DLTensor* w = args[10]; |
159 | DLTensor* y = args[11]; |
160 | std::string conv_dtype = args[12]; |
161 | int groups = args[13]; |
162 | |
163 | ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y, |
164 | conv_dtype); |
165 | }); |
166 | |
167 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward" ) |
168 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
169 | int mode = args[0]; |
170 | int format = args[1]; |
171 | int algo = args[2]; |
172 | int pad_v[2], stride_v[2], dilation_v[2]; |
173 | for (int i = 0; i < 2; i++) { |
174 | pad_v[i] = args[3 + i]; |
175 | stride_v[i] = args[5 + i]; |
176 | dilation_v[i] = args[7 + i]; |
177 | } |
178 | int act = args[9]; |
179 | double coef = args[10]; |
180 | DLTensor* x = args[11]; |
181 | DLTensor* w = args[12]; |
182 | DLTensor* bias = args[13]; |
183 | DLTensor* y = args[14]; |
184 | std::string conv_dtype = args[15]; |
185 | int groups = args[16]; |
186 | |
187 | ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v, |
188 | dilation_v, x, w, y, bias, conv_dtype); |
189 | }); |
190 | |
191 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward" ) |
192 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
193 | int mode = args[0]; |
194 | int format = args[1]; |
195 | int algo = args[2]; |
196 | int pad_v[3], stride_v[3], dilation_v[3]; |
197 | for (int i = 0; i < 3; i++) { |
198 | pad_v[i] = args[3 + i]; |
199 | stride_v[i] = args[6 + i]; |
200 | dilation_v[i] = args[9 + i]; |
201 | } |
202 | DLTensor* x = args[12]; |
203 | DLTensor* w = args[13]; |
204 | DLTensor* y = args[14]; |
205 | std::string conv_dtype = args[15]; |
206 | int groups = args[16]; |
207 | |
208 | ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y, |
209 | conv_dtype); |
210 | }); |
211 | |
212 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo" ) |
213 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
214 | int format = args[0]; |
215 | int dims = args[1]; |
216 | int* pad = static_cast<int*>(static_cast<void*>(args[2])); |
217 | int* stride = static_cast<int*>(static_cast<void*>(args[3])); |
218 | int* dilation = static_cast<int*>(static_cast<void*>(args[4])); |
219 | int* x_dim = static_cast<int*>(static_cast<void*>(args[5])); |
220 | int* w_dim = static_cast<int*>(static_cast<void*>(args[6])); |
221 | int* y_dim = static_cast<int*>(static_cast<void*>(args[7])); |
222 | std::string data_dtype = args[8]; |
223 | std::string conv_dtype = args[9]; |
224 | int groups = args[10]; |
225 | |
226 | FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, |
227 | conv_dtype, ret); |
228 | }); |
229 | |
230 | } // namespace contrib |
231 | } // namespace tvm |
232 | |