/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cudnn_utils.h
 * \brief Helper declarations for calling into the external cuDNN library.
 */

#ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_
#define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_

#include <cudnn.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>

#include <string>

#include "../../cuda/cuda_common.h"

namespace tvm {
namespace contrib {

#define CUDNN_CALL(func)                                                         \
  {                                                                              \
    cudnnStatus_t e = (func);                                                    \
    ICHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e);   \
  }
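
// A minimal usage sketch (added here for illustration; not part of the
// original header): CUDNN_CALL wraps any cuDNN API call and aborts with the
// cuDNN error string if the call does not return CUDNN_STATUS_SUCCESS. The
// helper name below is hypothetical.
inline void ExampleCreateTensorDescriptor(cudnnTensorDescriptor_t* desc) {
  CUDNN_CALL(cudnnCreateTensorDescriptor(desc));  // fails loudly on error
}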

/*! \brief Convert DLTensor type to cuDNN type */
struct CuDNNDataType {
  static cudnnDataType_t DLTypeToCuDNNType(const DLDataType& dtype);
  /*! \brief Get a pointer to the constant value `v` materialized in the given
   *  cuDNN data type (commonly used for the alpha/beta scaling factors). */
  template <int v>
  static const void* GetConst(cudnnDataType_t type);
};  // struct CuDNNDataType
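
// Illustrative sketch (added for documentation; the helper below is
// hypothetical): mapping a float32 DLDataType to its cuDNN counterpart.
// The expected result for this input would be CUDNN_DATA_FLOAT.
inline cudnnDataType_t ExampleFloat32ToCuDNN() {
  DLDataType dtype;
  dtype.code = kDLFloat;
  dtype.bits = 32;
  dtype.lanes = 1;
  return CuDNNDataType::DLTypeToCuDNNType(dtype);
}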

/*!
 * \brief Fill `strides` so that strides[i] = dims[i] * dims[i+1] * ... * dims[nbdim-1],
 *        i.e. the number of elements covered from axis i through the last axis.
 */
inline void GetStride(int nbdim, const int* dims, int* strides) {
  int mul = 1;
  for (int i = nbdim - 1; i >= 0; --i) {
    mul *= dims[i];
    strides[i] = mul;
  }
}
58
59inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {
60 int mul = 1;
61 for (int i = nbdim - 1; i >= 0; --i) {
62 strides[i] = mul;
63 mul *= dims[i];
64 }
65}
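
// Worked example (added for illustration; the helper below is hypothetical):
// for dims = {2, 3, 4}, GetStride fills {24, 12, 4} while GetCudnnStride
// fills the packed strides {12, 4, 1}.
inline void ExampleStrideComparison() {
  const int dims[3] = {2, 3, 4};
  int cumulative[3];
  int packed[3];
  GetStride(3, dims, cumulative);   // cumulative = {24, 12, 4}
  GetCudnnStride(3, dims, packed);  // packed     = {12, 4, 1}
}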

struct ConvEntry {
  cudnnConvolutionDescriptor_t conv_desc;
  cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION};
  cudnnDataType_t data_type;
  cudnnTensorFormat_t tensor_format;
  cudnnTensorDescriptor_t input_desc;
  cudnnFilterDescriptor_t filter_desc;
  cudnnTensorDescriptor_t bias_desc;
  cudnnActivationDescriptor_t activation_desc;
  cudnnTensorDescriptor_t output_desc;
  cudnnConvolutionFwdAlgo_t fwd_algo;
  cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
  cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
  // cudnnMathType_t math_type;
  Device device;
  runtime::DeviceAPI* cuda_api;
  void* workspace{nullptr};
  size_t workspace_size{0};
  ConvEntry();
  ~ConvEntry();
  void UpdateWorkspace(const size_t wsize);
  void CleanWorkspace();
};  // ConvEntry

struct SoftmaxEntry {
  cudnnSoftmaxMode_t mode;
  cudnnDataType_t data_type;
  cudnnTensorDescriptor_t shape_desc;
  SoftmaxEntry();
  ~SoftmaxEntry();
};  // SoftmaxEntry

struct CuDNNThreadEntry {
  CuDNNThreadEntry();
  ~CuDNNThreadEntry();

  bool exists() const { return handle != nullptr; }

  cudnnHandle_t handle{nullptr};
  ConvEntry conv_entry;
  SoftmaxEntry softmax_entry;
  runtime::DeviceAPI* cuda_api{nullptr};
  static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
};  // CuDNNThreadEntry
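
// Usage sketch (added for illustration; both helpers below are hypothetical):
// each calling thread owns its own cuDNN handle plus scratch entries, fetched
// through the thread-local singleton.
inline cudnnHandle_t ExampleGetThreadLocalHandle() {
  return CuDNNThreadEntry::ThreadLocal()->handle;
}

// Grow (if needed) the thread-local convolution workspace to `bytes` bytes.
inline void ExampleEnsureConvWorkspace(size_t bytes) {
  CuDNNThreadEntry::ThreadLocal()->conv_entry.UpdateWorkspace(bytes);
}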

void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
                        const int pad[], const int stride[], const int dilation[],
                        int64_t x_dim[], int64_t w_dim[], int64_t y_dim[],
                        DLDataType data_dtype, const std::string& conv_dtype);
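
// Hedged usage sketch (added for illustration; not part of the original
// header). It assumes format 0 selects NCHW, `dims` counts only the spatial
// dimensions, and the shape arrays carry dims + 2 entries; the helper name
// and the concrete shapes are made up for the example.
inline void ExampleDescribeConv2d(CuDNNThreadEntry* entry_ptr) {
  const int pad[2] = {1, 1};
  const int stride[2] = {1, 1};
  const int dilation[2] = {1, 1};
  int64_t x_dim[4] = {1, 64, 56, 56};  // NCHW input
  int64_t w_dim[4] = {64, 64, 3, 3};   // OIHW 3x3 filter
  int64_t y_dim[4] = {1, 64, 56, 56};  // NCHW output (same spatial size with pad 1)
  DLDataType float32;
  float32.code = kDLFloat;
  float32.bits = 32;
  float32.lanes = 1;
  SetConvDescriptors(entry_ptr, /*format=*/0, /*dims=*/2, /*groups=*/1, pad, stride, dilation,
                     x_dim, w_dim, y_dim, float32, "float32");
}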
116
117} // namespace contrib
118} // namespace tvm
119
120#endif // TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_
121