1/*
2 * Copyright 1993-2020 NVIDIA Corporation. All rights reserved.
3 *
4 * NOTICE TO LICENSEE:
5 *
6 * This source code and/or documentation ("Licensed Deliverables") are
7 * subject to NVIDIA intellectual property rights under U.S. and
8 * international Copyright laws.
9 *
10 * These Licensed Deliverables contained herein is PROPRIETARY and
11 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 * conditions of a form of NVIDIA software license agreement by and
13 * between NVIDIA and Licensee ("License Agreement") or electronically
14 * accepted by Licensee. Notwithstanding any terms or conditions to
15 * the contrary in the License Agreement, reproduction or disclosure
16 * of the Licensed Deliverables to any third party without the express
17 * written consent of NVIDIA is prohibited.
18 *
19 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32 * OF THESE LICENSED DELIVERABLES.
33 *
34 * U.S. Government End Users. These Licensed Deliverables are a
35 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36 * 1995), consisting of "commercial computer software" and "commercial
37 * computer software documentation" as such terms are used in 48
38 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39 * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41 * U.S. Government End Users acquire the Licensed Deliverables with
42 * only those rights set forth herein.
43 *
44 * Any use of the Licensed Deliverables in individual and commercial
45 * software must include, in the user documentation and internal
46 * comments to the code, the above Disclaimer and U.S. Government End
47 * Users Notice.
48 */
49
50/*
 * cudnn_cnn_train : cuDNN's basic definitions and training (backward-pass) CNN functions.
52 */
53
54#pragma once
55#include <cuda_runtime.h>
56#include <stdint.h>
57
58#include "cudnn_version.h"
59#include "cudnn_ops_infer.h"
60#include "cudnn_ops_train.h"
61#include "cudnn_cnn_infer.h"
62
63/* These version numbers are autogenerated, do not edit manually. */
64#define CUDNN_CNN_TRAIN_MAJOR 8
65#define CUDNN_CNN_TRAIN_MINOR 2
66#define CUDNN_CNN_TRAIN_PATCH 4
67
68#if (CUDNN_CNN_TRAIN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_TRAIN_MINOR != CUDNN_MINOR) || \
69 (CUDNN_CNN_TRAIN_PATCH != CUDNN_PATCHLEVEL)
70#error Version mismatch in cuDNN CNN INFER!!!
71#endif
72
/* When compiled as C++, give every declaration below C linkage. */
#ifdef __cplusplus
extern "C" {
#endif
76
/* Helper functions that report the convolution backward-filter algorithms that best fit the requirements */
78
/*
 * One result entry describing a backward-filter convolution algorithm.
 * Arrays of these are filled through the perfResults parameter of
 * cudnnFindConvolutionBackwardFilterAlgorithm, its Ex variant, and
 * cudnnGetConvolutionBackwardFilterAlgorithm_v7.
 * This struct crosses the library boundary by pointer, so field order and
 * sizes are part of the ABI: do not reorder, resize, or remove fields.
 */
typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
    cudnnConvolutionBwdFilterAlgo_t algo; /* algorithm this entry describes */
    cudnnStatus_t status;                 /* status reported for this algorithm */
    float time;                           /* timing value; presumably milliseconds -- TODO confirm */
    size_t memory;                        /* memory requirement; presumably workspace bytes -- TODO confirm */
    cudnnDeterminism_t determinism;       /* determinism of this algorithm */
    cudnnMathType_t mathType;             /* math type associated with this entry */
    int reserved[3];                      /* reserved; keeps the struct layout stable */
} cudnnConvolutionBwdFilterAlgoPerf_t;
88
89cudnnStatus_t CUDNNWINAPI
90cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
91
92cudnnStatus_t CUDNNWINAPI
93cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
94 const cudnnTensorDescriptor_t xDesc,
95 const cudnnTensorDescriptor_t dyDesc,
96 const cudnnConvolutionDescriptor_t convDesc,
97 const cudnnFilterDescriptor_t dwDesc,
98 const int requestedAlgoCount,
99 int *returnedAlgoCount,
100 cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
101
102cudnnStatus_t CUDNNWINAPI
103cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
104 const cudnnTensorDescriptor_t xDesc,
105 const void *x,
106 const cudnnTensorDescriptor_t dyDesc,
107 const void *y,
108 const cudnnConvolutionDescriptor_t convDesc,
109 const cudnnFilterDescriptor_t dwDesc,
110 void *dw,
111 const int requestedAlgoCount,
112 int *returnedAlgoCount,
113 cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
114 void *workSpace,
115 size_t workSpaceSizeInBytes);
116
117cudnnStatus_t CUDNNWINAPI
118cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
119 const cudnnTensorDescriptor_t srcDesc,
120 const cudnnTensorDescriptor_t diffDesc,
121 const cudnnConvolutionDescriptor_t convDesc,
122 const cudnnFilterDescriptor_t gradDesc,
123 const int requestedAlgoCount,
124 int *returnedAlgoCount,
125 cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
126
/*
 * Convolution backward-filter computation (depending on the algorithm, a workspace may be required)
 */
130
131/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
132cudnnStatus_t CUDNNWINAPI
133cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
134 const cudnnTensorDescriptor_t xDesc,
135 const cudnnTensorDescriptor_t dyDesc,
136 const cudnnConvolutionDescriptor_t convDesc,
137 const cudnnFilterDescriptor_t gradDesc,
138 cudnnConvolutionBwdFilterAlgo_t algo,
139 size_t *sizeInBytes);
140
141cudnnStatus_t CUDNNWINAPI
142cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
143 const void *alpha,
144 const cudnnTensorDescriptor_t xDesc,
145 const void *x,
146 const cudnnTensorDescriptor_t dyDesc,
147 const void *dy,
148 const cudnnConvolutionDescriptor_t convDesc,
149 cudnnConvolutionBwdFilterAlgo_t algo,
150 void *workSpace,
151 size_t workSpaceSizeInBytes,
152 const void *beta,
153 const cudnnFilterDescriptor_t dwDesc,
154 void *dw);
155
156/* Function to compute the bias gradient for batch convolution */
157cudnnStatus_t CUDNNWINAPI
158cudnnConvolutionBackwardBias(cudnnHandle_t handle,
159 const void *alpha,
160 const cudnnTensorDescriptor_t dyDesc,
161 const void *dy,
162 const void *beta,
163 const cudnnTensorDescriptor_t dbDesc,
164 void *db);
165
166cudnnStatus_t CUDNNWINAPI
167cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
168
169cudnnStatus_t CUDNNWINAPI
170cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
171
172cudnnStatus_t CUDNNWINAPI
173cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
174 cudnnFusedOpsConstParamLabel_t paramLabel,
175 const void *param);
176
177cudnnStatus_t CUDNNWINAPI
178cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
179 cudnnFusedOpsConstParamLabel_t paramLabel,
180 void *param,
181 int *isNULL);
182
183cudnnStatus_t CUDNNWINAPI
184cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
185
186cudnnStatus_t CUDNNWINAPI
187cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
188
189cudnnStatus_t CUDNNWINAPI
190cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
191 cudnnFusedOpsVariantParamLabel_t paramLabel,
192 void *ptr);
193
194cudnnStatus_t CUDNNWINAPI
195cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
196 cudnnFusedOpsVariantParamLabel_t paramLabel,
197 void *ptr);
198
199cudnnStatus_t CUDNNWINAPI
200cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
201
202cudnnStatus_t CUDNNWINAPI
203cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
204
205cudnnStatus_t CUDNNWINAPI
206cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
207 cudnnFusedOpsPlan_t plan,
208 const cudnnFusedOpsConstParamPack_t constPack,
209 size_t *workspaceSizeInBytes);
210
211cudnnStatus_t CUDNNWINAPI
212cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
213
214cudnnStatus_t CUDNNWINAPI
215cudnnCnnTrainVersionCheck(void);
216
#ifdef __cplusplus
} /* extern "C" */
#endif
220