/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuDNN kernel calls for the convolution forward operators.
 */
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

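/*!
 * \brief Run a cuDNN convolution forward pass. A summary of the parameter
 * semantics, as implied by the casts and calls below:
 * - mode is cast to cudnnConvolutionMode_t (convolution vs. cross-correlation).
 * - format and conv_dtype are consumed by SetConvDescriptors; conv_dtype names
 *   the type cuDNN computes the convolution in.
 * - algo is cast to cudnnConvolutionFwdAlgo_t (see FindAlgo below).
 * - dims is the number of spatial dimensions; pad, stride, and dilation each
 *   hold dims entries.
 * - x, w, and y are the input, filter, and output tensors.
 */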
void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[],
                        const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
                        DLTensor* y, const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set mode: convolution or cross-correlation.
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
                     y->shape, x->dtype, conv_dtype);
  // Set device
  entry_ptr->conv_entry.device = x->device;
  // Set algo
  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);

  // Query the workspace the chosen algorithm needs and make sure the cached
  // buffer is at least that large.
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.fwd_algo, &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  // Launch the convolution with scaling constants alpha = 1, beta = 0
  // (GetConst<1> / GetConst<0>).
  CUDNN_CALL(cudnnConvolutionForward(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.output_desc, y->data));
}

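/*!
 * \brief Run a fused convolution + bias + activation forward pass through
 * cudnnConvolutionBiasActivationForward. Beyond the ConvolutionForward
 * parameters: act is cast to cudnnActivationMode_t, coef is the activation
 * coefficient handed to cudnnSetActivationDescriptor (for clipped ReLU it is
 * the clipping threshold), and bias holds one value per output channel
 * (w->shape[0]).
 */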
void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, int groups, int act,
                                      double coef, const int pad[], const int stride[],
                                      const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y,
                                      DLTensor* bias, const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  // Set mode: convolution or cross-correlation.
  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
  CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc,
                                          static_cast<cudnnActivationMode_t>(act),
                                          cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef));
  // The bias is described as a 1 x C x 1 x 1 tensor, with C taken from the
  // filter's output-channel dimension.
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      entry_ptr->conv_entry.bias_desc, entry_ptr->conv_entry.tensor_format,
      CuDNNDataType::DLTypeToCuDNNType(bias->dtype), 1, static_cast<int>(w->shape[0]), 1, 1));

  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
                     y->shape, x->dtype, conv_dtype);
  // Set device
  entry_ptr->conv_entry.device = x->device;
  // Set algo
  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);

  // Query and reserve the workspace, as in ConvolutionForward.
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.fwd_algo, &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  // y is passed both as the residual input z (scaled by alpha2 = 0, so it does
  // not contribute) and as the final output.
  CUDNN_CALL(cudnnConvolutionBiasActivationForward(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.bias_desc, bias->data,
      entry_ptr->conv_entry.activation_desc, entry_ptr->conv_entry.output_desc, y->data));
}

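/*!
 * \brief Benchmark every cuDNN forward algorithm for the given shapes with
 * cudnnFindConvolutionForwardAlgorithm and return the fastest algorithm's
 * enum value through ret. Unlike pad/stride/dilation, the x_dim/w_dim/y_dim
 * arrays carry dims + 2 entries, since they include the batch and channel
 * dimensions.
 */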
void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
              const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
              const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  const int full_dims = dims + 2;
  // Widen the int shape arrays to the int64_t layout SetConvDescriptors expects.
  std::vector<int64_t> x_dim_int64(full_dims);
  std::vector<int64_t> w_dim_int64(full_dims);
  std::vector<int64_t> y_dim_int64(full_dims);
  for (int i = 0; i < full_dims; ++i) {
    x_dim_int64[i] = x_dim[i];
    w_dim_int64[i] = w_dim[i];
    y_dim_int64[i] = y_dim[i];
  }
  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
                     w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
                     conv_dtype);

  int returned_algo_count = 0;
  cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
  CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
      CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));

  // These names must stay in the same order as the cudnnConvolutionFwdAlgo_t
  // enum, since they are indexed by the enum's value.
  const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
                                                "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};

  // Results are sorted by compute time, so the first entry is the fastest.
  auto best_algo = perf_results[0].algo;
  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
            << fwd_algo_names[best_algo];
  for (int i = 0; i < returned_algo_count; ++i) {
    LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
              << " - time: " << perf_results[i].time << " ms"
              << ", Memory: " << perf_results[i].memory << " bytes";
  }

  ret[0] = best_algo;
}

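// Packed argument layout, matching the reads in the lambda below:
//   [0] mode, [1] format, [2] algo, [3-4] pad, [5-6] stride, [7-8] dilation,
//   [9] x, [10] w, [11] y, [12] conv_dtype, [13] groups.
// A minimal usage sketch from the Python side (hypothetical values), assuming
// the standard TVM FFI:
//   f = tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward")
//   f(mode, fmt, algo, ph, pw, sh, sw, dh, dw, x, w, y, "float32", groups)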
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int mode = args[0];
      int format = args[1];
      int algo = args[2];
      int pad_v[2], stride_v[2], dilation_v[2];
      for (int i = 0; i < 2; i++) {
        pad_v[i] = args[3 + i];
        stride_v[i] = args[5 + i];
        dilation_v[i] = args[7 + i];
      }
      DLTensor* x = args[9];
      DLTensor* w = args[10];
      DLTensor* y = args[11];
      std::string conv_dtype = args[12];
      int groups = args[13];

      ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y,
                         conv_dtype);
    });

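// Packed argument layout:
//   [0] mode, [1] format, [2] algo, [3-4] pad, [5-6] stride, [7-8] dilation,
//   [9] act, [10] coef, [11] x, [12] w, [13] bias, [14] y, [15] conv_dtype,
//   [16] groups.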
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int mode = args[0];
      int format = args[1];
      int algo = args[2];
      int pad_v[2], stride_v[2], dilation_v[2];
      for (int i = 0; i < 2; i++) {
        pad_v[i] = args[3 + i];
        stride_v[i] = args[5 + i];
        dilation_v[i] = args[7 + i];
      }
      int act = args[9];
      double coef = args[10];
      DLTensor* x = args[11];
      DLTensor* w = args[12];
      DLTensor* bias = args[13];
      DLTensor* y = args[14];
      std::string conv_dtype = args[15];
      int groups = args[16];

      ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v,
                                       dilation_v, x, w, y, bias, conv_dtype);
    });

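// Same layout as conv2d.forward, but with three entries per spatial parameter:
//   [0] mode, [1] format, [2] algo, [3-5] pad, [6-8] stride, [9-11] dilation,
//   [12] x, [13] w, [14] y, [15] conv_dtype, [16] groups.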
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int mode = args[0];
      int format = args[1];
      int algo = args[2];
      int pad_v[3], stride_v[3], dilation_v[3];
      for (int i = 0; i < 3; i++) {
        pad_v[i] = args[3 + i];
        stride_v[i] = args[6 + i];
        dilation_v[i] = args[9 + i];
      }
      DLTensor* x = args[12];
      DLTensor* w = args[13];
      DLTensor* y = args[14];
      std::string conv_dtype = args[15];
      int groups = args[16];

      ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y,
                         conv_dtype);
    });

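// pad/stride/dilation and the x/w/y shape arrays arrive as raw int pointers
// packed into void* handles, so the caller must keep them alive for the
// duration of the call; dims counts spatial dimensions only.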
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int format = args[0];
      int dims = args[1];
      int* pad = static_cast<int*>(static_cast<void*>(args[2]));
      int* stride = static_cast<int*>(static_cast<void*>(args[3]));
      int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
      int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
      int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
      int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
      std::string data_dtype = args[8];
      std::string conv_dtype = args[9];
      int groups = args[10];

      FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype,
               conv_dtype, ret);
    });

}  // namespace contrib
}  // namespace tvm