/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuDNN kernel calls for backward algorithms.
 */
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

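/*!
 * \brief Compute the gradient of a convolution w.r.t. its input (dx) with cuDNN.
 *
 * Descriptor mapping used for the backward-data path in this file: conv_entry.input_desc
 * describes the data-gradient tensor dx, conv_entry.output_desc describes the incoming
 * gradient dy, and conv_entry.filter_desc describes the filter w. The dtype and device of
 * dy are used to configure the convolution entry.
 */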
void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
                             const int stride[], const int dilation[], DLTensor* dy, DLTensor* w,
                             DLTensor* dx, const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set Mode
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape,
                     dy->shape, dy->dtype, conv_dtype);
  // Set Device
  entry_ptr->conv_entry.device = dy->device;
  // Set Algo
  entry_ptr->conv_entry.bwd_data_algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(algo);

  // Set workspace
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
      entry_ptr->conv_entry.bwd_data_algo, &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  CUDNN_CALL(cudnnConvolutionBackwardData(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc,
      dx->data));
}

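/*!
 * \brief Benchmark all cuDNN backward-data algorithms for the given problem using
 *        cudnnFindConvolutionBackwardDataAlgorithm and return the fastest one through `ret`.
 *
 * cuDNN reports the candidates ordered by execution time, so perf_results[0].algo is taken as
 * the best choice; the full list is logged for inspection.
 */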
void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
                          const int dilation[], const int dy_dim[], const int w_dim[],
                          const int dx_dim[], const std::string& data_dtype,
                          const std::string& conv_dtype, TVMRetValue* ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  const int full_dims = dims + 2;
  std::vector<int64_t> dy_dim_int64(full_dims);
  std::vector<int64_t> w_dim_int64(full_dims);
  std::vector<int64_t> dx_dim_int64(full_dims);
  for (int i = 0; i < full_dims; ++i) {
    dy_dim_int64[i] = dy_dim[i];
    w_dim_int64[i] = w_dim[i];
    dx_dim_int64[i] = dx_dim[i];
  }
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(),
                     w_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype),
                     conv_dtype);

  int returned_algo_count = 0;

  cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT];
  CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
      CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));

  const std::vector<std::string> bwd_data_algo_names{
      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",  // non-deterministic
      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"};

  auto best_algo = perf_results[0].algo;
  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
            << bwd_data_algo_names[best_algo];
  for (int i = 0; i < returned_algo_count; ++i) {
    LOG(INFO) << "\t\t" << i << ") " << bwd_data_algo_names[perf_results[i].algo]
              << " - time: " << perf_results[i].time << " ms"
              << ", Memory: " << perf_results[i].memory;
  }

  ret[0] = best_algo;
}

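/*!
 * \brief Compute the gradient of a convolution w.r.t. its filter (dw) with cuDNN.
 *
 * For the backward-filter path, conv_entry.input_desc describes the forward input x,
 * conv_entry.output_desc describes the incoming gradient dy, and conv_entry.filter_desc
 * describes the filter-gradient tensor dw.
 */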
void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups,
                               const int pad[], const int stride[], const int dilation[],
                               DLTensor* dy, DLTensor* x, DLTensor* dw,
                               const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set Mode
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape,
                     dy->shape, x->dtype, conv_dtype);
  // Set Device
  entry_ptr->conv_entry.device = x->device;
  // Set Algo
  entry_ptr->conv_entry.bwd_filter_algo = static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo);

  // Set workspace
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc,
      entry_ptr->conv_entry.bwd_filter_algo, &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  CUDNN_CALL(cudnnConvolutionBackwardFilter(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.output_desc, dy->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_filter_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.filter_desc, dw->data));
}

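/*!
 * \brief Benchmark all cuDNN backward-filter algorithms for the given problem using
 *        cudnnFindConvolutionBackwardFilterAlgorithm and return the fastest one through `ret`.
 */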
void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
                            const int dilation[], const int dy_dim[], const int x_dim[],
                            const int dw_dim[], const std::string& data_dtype,
                            const std::string& conv_dtype, TVMRetValue* ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  const int full_dims = dims + 2;
  std::vector<int64_t> x_dim_int64(full_dims);
  std::vector<int64_t> dy_dim_int64(full_dims);
  std::vector<int64_t> dw_dim_int64(full_dims);
  for (int i = 0; i < full_dims; ++i) {
    x_dim_int64[i] = x_dim[i];
    dy_dim_int64[i] = dy_dim[i];
    dw_dim_int64[i] = dw_dim[i];
  }
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
                     dw_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype),
                     conv_dtype);

  int returned_algo_count = 0;

  cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT];
  CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc,
      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results));

  const std::vector<std::string> bwd_filter_algo_names{
      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0",  // non-deterministic
      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1",
      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT",
      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3",
      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED",
      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING",
  };

  auto best_algo = perf_results[0].algo;
  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing "
            << bwd_filter_algo_names[best_algo];
  for (int i = 0; i < returned_algo_count; ++i) {
    LOG(INFO) << "\t\t" << i << ") " << bwd_filter_algo_names[perf_results[i].algo]
              << " - time: " << perf_results[i].time << " ms"
              << ", Memory: " << perf_results[i].memory;
  }

  ret[0] = best_algo;
}

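// Packed argument layout for the 2-D entry points registered below:
//   [0] mode, [1] format, [2] algo, [3-4] pad, [5-6] stride, [7-8] dilation,
//   [9] dy, [10] w (or x), [11] dx (or dw), [12] conv_dtype, [13] groups.
//
// A minimal call sketch through the global registry (illustrative only; the DLTensor
// arguments and the scalar names such as pad_h/pad_w are assumed to be set up by the caller):
//   const PackedFunc* f = runtime::Registry::Get("tvm.contrib.cudnn.conv2d.backward_data");
//   (*f)(mode, format, algo, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
//        dy, w, dx, conv_dtype, groups);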
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int mode = args[0];
      int format = args[1];
      int algo = args[2];
      int pad_v[2], stride_v[2], dilation_v[2];
      for (int i = 0; i < 2; i++) {
        pad_v[i] = args[3 + i];
        stride_v[i] = args[5 + i];
        dilation_v[i] = args[7 + i];
      }
      DLTensor* dy = args[9];
      DLTensor* w = args[10];
      DLTensor* dx = args[11];
      std::string conv_dtype = args[12];
      int groups = args[13];

      ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx,
                              conv_dtype);
    });

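// The *_find_algo entry points receive pad/stride/dilation and the dy/w(x)/dx(dw) shape arrays
// as raw int* buffers (passed as void* handles and cast back here), followed by the data dtype,
// the compute dtype, and the group count.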
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int format = args[0];
      int dims = args[1];
      int* pad = static_cast<int*>(static_cast<void*>(args[2]));
      int* stride = static_cast<int*>(static_cast<void*>(args[3]));
      int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
      int* dy_dim = static_cast<int*>(static_cast<void*>(args[5]));
      int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
      int* dx_dim = static_cast<int*>(static_cast<void*>(args[7]));
      std::string data_dtype = args[8];
      std::string conv_dtype = args[9];
      int groups = args[10];

      BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim,
                           data_dtype, conv_dtype, ret);
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int mode = args[0];
      int format = args[1];
      int algo = args[2];
      int pad_v[2], stride_v[2], dilation_v[2];
      for (int i = 0; i < 2; i++) {
        pad_v[i] = args[3 + i];
        stride_v[i] = args[5 + i];
        dilation_v[i] = args[7 + i];
      }
      DLTensor* dy = args[9];
      DLTensor* x = args[10];
      DLTensor* dw = args[11];
      std::string conv_dtype = args[12];
      int groups = args[13];

      ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, x,
                                dw, conv_dtype);
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int format = args[0];
      int dims = args[1];
      int* pad = static_cast<int*>(static_cast<void*>(args[2]));
      int* stride = static_cast<int*>(static_cast<void*>(args[3]));
      int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
      int* dy_dim = static_cast<int*>(static_cast<void*>(args[5]));
      int* x_dim = static_cast<int*>(static_cast<void*>(args[6]));
      int* dw_dim = static_cast<int*>(static_cast<void*>(args[7]));
      std::string data_dtype = args[8];
      std::string conv_dtype = args[9];
      int groups = args[10];

      BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim,
                             data_dtype, conv_dtype, ret);
    });

}  // namespace contrib
}  // namespace tvm