1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file cuDNN kernel calls for backward algorithms. |
22 | */ |
23 | #include <tvm/runtime/data_type.h> |
24 | #include <tvm/runtime/device_api.h> |
25 | #include <tvm/runtime/registry.h> |
26 | |
27 | #include "cudnn_utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace contrib { |
31 | |
32 | using namespace runtime; |
33 | |
/*!
 * \brief Compute the gradient of a convolution w.r.t. its input (dx) using cuDNN.
 *
 * Descriptor roles follow the *forward* convolution's naming: input_desc
 * describes dx, output_desc describes dy (see the SetConvDescriptors call,
 * which receives dx->shape / w->shape / dy->shape in that order).
 *
 * \param mode cudnnConvolutionMode_t value (convolution vs cross-correlation).
 * \param format tensor layout tag forwarded to SetConvDescriptors.
 * \param algo cudnnConvolutionBwdDataAlgo_t value selecting the kernel.
 * \param dims number of spatial dimensions.
 * \param groups number of convolution groups.
 * \param pad per-spatial-dimension padding, length dims.
 * \param stride per-spatial-dimension stride, length dims.
 * \param dilation per-spatial-dimension dilation, length dims.
 * \param dy gradient of the convolution output (read).
 * \param w filter tensor of the forward convolution (read).
 * \param dx gradient of the convolution input (written).
 * \param conv_dtype accumulation dtype for the convolution descriptor.
 */
void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
                             const int stride[], const int dilation[], DLTensor* dy, DLTensor* w,
                             DLTensor* dx, const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set Mode (must precede descriptor setup so descriptors see the right mode).
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape,
                     dy->shape, dy->dtype, conv_dtype);
  // Set Device
  entry_ptr->conv_entry.device = dy->device;
  // Set Algo
  entry_ptr->conv_entry.bwd_data_algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(algo);

  // Query the scratch size this algo needs and (re)allocate the cached
  // workspace buffer before launching.
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
      entry_ptr->conv_entry.bwd_data_algo, &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  // dx = 1 * conv_bwd_data(w, dy) + 0 * dx  (alpha = GetConst<1>, beta = GetConst<0>).
  CUDNN_CALL(cudnnConvolutionBackwardData(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc,
      dx->data));
}
62 | |
63 | void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], |
64 | const int dilation[], const int dy_dim[], const int w_dim[], |
65 | const int dx_dim[], const std::string& data_dtype, |
66 | const std::string& conv_dtype, TVMRetValue* ret) { |
67 | CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); |
68 | const int full_dims = dims + 2; |
69 | std::vector<int64_t> dy_dim_int64(full_dims); |
70 | std::vector<int64_t> w_dim_int64(full_dims); |
71 | std::vector<int64_t> dx_dim_int64(full_dims); |
72 | for (int i = 0; i < full_dims; ++i) { |
73 | dy_dim_int64[i] = dy_dim[i]; |
74 | w_dim_int64[i] = w_dim[i]; |
75 | dx_dim_int64[i] = dx_dim[i]; |
76 | } |
77 | SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(), |
78 | w_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype), |
79 | conv_dtype); |
80 | |
81 | int returned_algo_count = 0; |
82 | |
83 | cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT]; |
84 | CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm( |
85 | entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc, |
86 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, |
87 | CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results)); |
88 | |
89 | const std::vector<std::string> bwd_data_algo_names{ |
90 | "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0" , // non-deterministic |
91 | "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1" , |
92 | "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT" , |
93 | "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING" , |
94 | "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD" , |
95 | "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED" }; |
96 | |
97 | auto best_algo = perf_results[0].algo; |
98 | LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing " |
99 | << bwd_data_algo_names[best_algo]; |
100 | for (int i = 0; i < returned_algo_count; ++i) { |
101 | LOG(INFO) << "\t\t" << i << ") " << bwd_data_algo_names[perf_results[i].algo] |
102 | << " - time: " << perf_results[i].time << " ms" |
103 | << ", Memory: " << perf_results[i].memory; |
104 | } |
105 | |
106 | ret[0] = best_algo; |
107 | } |
108 | |
/*!
 * \brief Compute the gradient of a convolution w.r.t. its filter (dw) using cuDNN.
 *
 * Descriptor roles follow the *forward* convolution's naming: input_desc
 * describes x, output_desc describes dy, filter_desc describes dw (see the
 * SetConvDescriptors call, which receives x->shape / dw->shape / dy->shape).
 *
 * \param mode cudnnConvolutionMode_t value (convolution vs cross-correlation).
 * \param format tensor layout tag forwarded to SetConvDescriptors.
 * \param algo cudnnConvolutionBwdFilterAlgo_t value selecting the kernel.
 * \param dims number of spatial dimensions.
 * \param groups number of convolution groups.
 * \param pad per-spatial-dimension padding, length dims.
 * \param stride per-spatial-dimension stride, length dims.
 * \param dilation per-spatial-dimension dilation, length dims.
 * \param dy gradient of the convolution output (read).
 * \param x input tensor of the forward convolution (read).
 * \param dw gradient of the filter (written).
 * \param conv_dtype accumulation dtype for the convolution descriptor.
 */
void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups,
                               const int pad[], const int stride[], const int dilation[],
                               DLTensor* dy, DLTensor* x, DLTensor* dw,
                               const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set Mode (must precede descriptor setup so descriptors see the right mode).
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape,
                     dy->shape, x->dtype, conv_dtype);
  // Set Device
  entry_ptr->conv_entry.device = x->device;
  // Set Algo
  entry_ptr->conv_entry.bwd_filter_algo = static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo);

  // Query the scratch size this algo needs and (re)allocate the cached
  // workspace buffer before launching.
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc,
      entry_ptr->conv_entry.bwd_filter_algo, &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  // dw = 1 * conv_bwd_filter(x, dy) + 0 * dw  (alpha = GetConst<1>, beta = GetConst<0>).
  CUDNN_CALL(cudnnConvolutionBackwardFilter(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.output_desc, dy->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_filter_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.filter_desc, dw->data));
}
138 | |
139 | void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], |
140 | const int dilation[], const int dy_dim[], const int x_dim[], |
141 | const int dw_dim[], const std::string& data_dtype, |
142 | const std::string& conv_dtype, TVMRetValue* ret) { |
143 | CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); |
144 | const int full_dims = dims + 2; |
145 | std::vector<int64_t> x_dim_int64(full_dims); |
146 | std::vector<int64_t> dy_dim_int64(full_dims); |
147 | std::vector<int64_t> dw_dim_int64(full_dims); |
148 | for (int i = 0; i < full_dims; ++i) { |
149 | x_dim_int64[i] = x_dim[i]; |
150 | dy_dim_int64[i] = dy_dim[i]; |
151 | dw_dim_int64[i] = dw_dim[i]; |
152 | } |
153 | SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), |
154 | dw_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype), |
155 | conv_dtype); |
156 | |
157 | int returned_algo_count = 0; |
158 | |
159 | cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT]; |
160 | CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( |
161 | entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc, |
162 | entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc, |
163 | CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results)); |
164 | |
165 | const std::vector<std::string> bwd_filter_algo_names{ |
166 | "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0" , // non-deterministic |
167 | "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1" , |
168 | "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT" , |
169 | "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3" , |
170 | "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED" , |
171 | "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING" , |
172 | }; |
173 | |
174 | auto best_algo = perf_results[0].algo; |
175 | LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing " |
176 | << bwd_filter_algo_names[best_algo]; |
177 | for (int i = 0; i < returned_algo_count; ++i) { |
178 | LOG(INFO) << "\t\t" << i << ") " << bwd_filter_algo_names[perf_results[i].algo] |
179 | << " - time: " << perf_results[i].time << " ms" |
180 | << ", Memory: " << perf_results[i].memory; |
181 | } |
182 | |
183 | ret[0] = best_algo; |
184 | } |
185 | |
186 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data" ) |
187 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
188 | int mode = args[0]; |
189 | int format = args[1]; |
190 | int algo = args[2]; |
191 | int pad_v[2], stride_v[2], dilation_v[2]; |
192 | for (int i = 0; i < 2; i++) { |
193 | pad_v[i] = args[3 + i]; |
194 | stride_v[i] = args[5 + i]; |
195 | dilation_v[i] = args[7 + i]; |
196 | } |
197 | DLTensor* dy = args[9]; |
198 | DLTensor* w = args[10]; |
199 | DLTensor* dx = args[11]; |
200 | std::string conv_dtype = args[12]; |
201 | int groups = args[13]; |
202 | |
203 | ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx, |
204 | conv_dtype); |
205 | }); |
206 | |
207 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo" ) |
208 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
209 | int format = args[0]; |
210 | int dims = args[1]; |
211 | int* pad = static_cast<int*>(static_cast<void*>(args[2])); |
212 | int* stride = static_cast<int*>(static_cast<void*>(args[3])); |
213 | int* dilation = static_cast<int*>(static_cast<void*>(args[4])); |
214 | int* dy_dim = static_cast<int*>(static_cast<void*>(args[5])); |
215 | int* w_dim = static_cast<int*>(static_cast<void*>(args[6])); |
216 | int* dx_dim = static_cast<int*>(static_cast<void*>(args[7])); |
217 | std::string data_dtype = args[8]; |
218 | std::string conv_dtype = args[9]; |
219 | int groups = args[10]; |
220 | |
221 | BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim, |
222 | data_dtype, conv_dtype, ret); |
223 | }); |
224 | |
225 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter" ) |
226 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
227 | int mode = args[0]; |
228 | int format = args[1]; |
229 | int algo = args[2]; |
230 | int pad_v[2], stride_v[2], dilation_v[2]; |
231 | for (int i = 0; i < 2; i++) { |
232 | pad_v[i] = args[3 + i]; |
233 | stride_v[i] = args[5 + i]; |
234 | dilation_v[i] = args[7 + i]; |
235 | } |
236 | DLTensor* dy = args[9]; |
237 | DLTensor* x = args[10]; |
238 | DLTensor* dw = args[11]; |
239 | std::string conv_dtype = args[12]; |
240 | int groups = args[13]; |
241 | |
242 | ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, x, |
243 | dw, conv_dtype); |
244 | }); |
245 | |
246 | TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo" ) |
247 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
248 | int format = args[0]; |
249 | int dims = args[1]; |
250 | int* pad = static_cast<int*>(static_cast<void*>(args[2])); |
251 | int* stride = static_cast<int*>(static_cast<void*>(args[3])); |
252 | int* dilation = static_cast<int*>(static_cast<void*>(args[4])); |
253 | int* dy_dim = static_cast<int*>(static_cast<void*>(args[5])); |
254 | int* x_dim = static_cast<int*>(static_cast<void*>(args[6])); |
255 | int* dw_dim = static_cast<int*>(static_cast<void*>(args[7])); |
256 | std::string data_dtype = args[8]; |
257 | std::string conv_dtype = args[9]; |
258 | int groups = args[10]; |
259 | |
260 | BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim, |
261 | data_dtype, conv_dtype, ret); |
262 | }); |
263 | |
264 | } // namespace contrib |
265 | } // namespace tvm |
266 | |