1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file cudnn_utils.h
 * \brief Utility declarations shared by the external cuDNN contrib functions.
 */
23 | |
24 | #ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ |
25 | #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ |
26 | |
27 | #include <cudnn.h> |
28 | #include <tvm/runtime/device_api.h> |
29 | #include <tvm/runtime/logging.h> |
30 | |
31 | #include <string> |
32 | |
33 | #include "../../cuda/cuda_common.h" |
34 | |
35 | namespace tvm { |
36 | namespace contrib { |
37 | |
/*!
 * \brief Evaluate a cuDNN call and check that it returned CUDNN_STATUS_SUCCESS.
 *
 * On any other status the ICHECK_EQ fails and the message carries "cuDNN: "
 * followed by cudnnGetErrorString() for that status.
 *
 * The body is wrapped in do { ... } while (0) so the macro expands to exactly
 * one statement and can be used safely in an unbraced if/else. Invocations
 * must therefore be terminated with a semicolon. The local status variable has
 * a deliberately unusual name so that it cannot capture an identifier used in
 * the argument expression.
 *
 * \param func Expression yielding a cudnnStatus_t; it is evaluated exactly once.
 */
#define CUDNN_CALL(func)                                           \
  do {                                                             \
    cudnnStatus_t cudnn_call_status_ = (func);                     \
    ICHECK_EQ(cudnn_call_status_, CUDNN_STATUS_SUCCESS)            \
        << "cuDNN: " << cudnnGetErrorString(cudnn_call_status_);   \
  } while (0)
43 | |
44 | /*! breif Convert DLTensor type to CuDNN type */ |
45 | struct CuDNNDataType { |
46 | static cudnnDataType_t DLTypeToCuDNNType(const DLDataType& dtype); |
47 | template <int v> |
48 | static const void* GetConst(cudnnDataType_t type); |
49 | }; // struct CuDNNDataType |
50 | |
/*!
 * \brief Fill strides with inclusive suffix products of dims.
 *
 * After the call, strides[i] == dims[i] * dims[i + 1] * ... * dims[nbdim - 1].
 * Note that the extent of dimension i itself is part of strides[i]; compare
 * GetCudnnStride, which excludes it. Nothing is written when nbdim <= 0.
 *
 * \param nbdim Number of entries in dims and strides.
 * \param dims Input extents; must hold at least nbdim values.
 * \param strides Output array; must have room for at least nbdim values.
 */
inline void GetStride(int nbdim, const int* dims, int* strides) {
  int running_product = 1;
  int axis = nbdim;
  // Walk from the last dimension to the first, accumulating as we go.
  while (axis-- > 0) {
    running_product *= dims[axis];
    strides[axis] = running_product;
  }
}
58 | |
/*!
 * \brief Fill strides with exclusive suffix products of dims.
 *
 * After the call, strides[nbdim - 1] == 1 and, for every other i,
 * strides[i] == dims[i + 1] * ... * dims[nbdim - 1]. These are the element
 * strides of a densely packed layout whose last dimension varies fastest.
 * Nothing is written when nbdim <= 0.
 *
 * \param nbdim Number of entries in dims and strides.
 * \param dims Input extents; must hold at least nbdim values.
 * \param strides Output array; must have room for at least nbdim values.
 */
inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {
  int step = 1;
  for (int axis = nbdim; axis-- > 0;) {
    // Record the stride first, then fold this dimension's extent into it.
    strides[axis] = step;
    step *= dims[axis];
  }
}
66 | |
67 | struct ConvEntry { |
68 | cudnnConvolutionDescriptor_t conv_desc; |
69 | cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION}; |
70 | cudnnDataType_t data_type; |
71 | cudnnTensorFormat_t tensor_format; |
72 | cudnnTensorDescriptor_t input_desc; |
73 | cudnnFilterDescriptor_t filter_desc; |
74 | cudnnTensorDescriptor_t bias_desc; |
75 | cudnnActivationDescriptor_t activation_desc; |
76 | cudnnTensorDescriptor_t output_desc; |
77 | cudnnConvolutionFwdAlgo_t fwd_algo; |
78 | cudnnConvolutionBwdDataAlgo_t bwd_data_algo; |
79 | cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; |
80 | // cudnnMathType_t math_type; |
81 | Device device; |
82 | runtime::DeviceAPI* cuda_api; |
83 | void* workspace{nullptr}; |
84 | size_t workspace_size{0}; |
85 | ConvEntry(); |
86 | ~ConvEntry(); |
87 | void UpdateWorkspace(const size_t wsize); |
88 | void CleanWorkspace(); |
89 | }; // ConvThreadEntry |
90 | |
91 | struct SoftmaxEntry { |
92 | cudnnSoftmaxMode_t mode; |
93 | cudnnDataType_t data_type; |
94 | cudnnTensorDescriptor_t shape_desc; |
95 | SoftmaxEntry(); |
96 | ~SoftmaxEntry(); |
97 | }; // SoftmaxEntry |
98 | |
99 | struct CuDNNThreadEntry { |
100 | CuDNNThreadEntry(); |
101 | ~CuDNNThreadEntry(); |
102 | |
103 | bool exists() const { return handle; } |
104 | |
105 | cudnnHandle_t handle{nullptr}; |
106 | ConvEntry conv_entry; |
107 | SoftmaxEntry softmax_entry; |
108 | runtime::DeviceAPI* cuda_api{nullptr}; |
109 | static CuDNNThreadEntry* ThreadLocal(bool check_exists = true); |
110 | }; // CuDNNThreadEntry |
111 | |
112 | void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups, |
113 | const int pad[], const int stride[], const int dilation[], int64_t x_dim[], |
114 | int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype, |
115 | const std::string& conv_dtype); |
116 | |
117 | } // namespace contrib |
118 | } // namespace tvm |
119 | |
120 | #endif // TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ |
121 | |