1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // This is the common header for the input and filter backprop kernels. |
17 | // |
18 | // The operation to compute Conv2D gradients. |
19 | // |
20 | // To compute the gradients for Conv2D, we need three input tensors: |
21 | // input, filter, and backprop for output. |
22 | // And we need to compute two backprops: one for input and one for filter. We |
23 | // compute them in two different kernels. |
24 | // |
25 | // Both backprops can be computed as straightforward conv2d. |
26 | // |
// Consider a case where the input is 3x3 and the filter is 1x2 (rows x cols):
28 | // |
29 | // INPUT = [ A B C ] |
30 | // [ D E F ] |
31 | // [ G H I ] |
32 | // |
33 | // where each "A", "B", etc is batch x in_depth |
34 | // |
35 | // FILTER = [ X Y ] |
36 | // |
37 | // where both "X" and "Y" are in_depth x out_depth |
38 | // |
39 | // With VALID padding, the output is 3x2: |
40 | // |
41 | // OUTPUT = [ a b ] |
42 | // [ c d ] |
43 | // [ e f ] |
44 | // |
45 | // where each "a", "b", etc is batch x out_depth |
46 | // |
47 | // So we have: |
48 | // |
49 | // a = A * X + B * Y |
50 | // b = B * X + C * Y |
51 | // c = D * X + E * Y |
52 | // d = E * X + F * Y |
53 | // e = G * X + H * Y |
54 | // f = H * X + I * Y |
55 | // |
56 | // So when we have backprops for the outputs (we denote them by |
57 | // a', b', ... ): |
58 | // |
59 | // The backprops for the input are: |
60 | // |
61 | // A' = a' * X^t |
62 | // B' = a' * Y^t + b' * X^t |
63 | // C' = b' * Y^t |
64 | // ... |
65 | // |
66 | // This is essentially computing a 2d conv of |
67 | // |
68 | // INPUT = [ 0 a' b' 0 ] |
69 | // [ 0 c' d' 0 ] |
70 | // [ 0 e' f' 0 ] |
71 | // and |
72 | // |
73 | // FILTER = [ Y^t X^t ] |
74 | // |
75 | // The backprops for the filter are: |
76 | // |
77 | // X' = A^t * a' + B^t * b' + D^t * c' + E^t * d' + G^t * e' + H^t * f' |
//   Y' = B^t * a' + C^t * b' + E^t * c' + F^t * d' + H^t * e' + I^t * f'
79 | // |
80 | // This is essentially computing a 2d conv of |
81 | // |
82 | // INPUT = [ A^t B^t C^t ] |
83 | // [ D^t E^t F^t ] |
84 | // [ G^t H^t I^t ] |
85 | // |
86 | // and |
87 | // |
88 | // FILTER = [ a' b' ] |
89 | // [ c' d' ] |
90 | // [ e' f' ] |
91 | // |
92 | // |
93 | ////////////////////////////////////////////////////////// |
94 | // |
// With a stride greater than one, it's a bit more complicated (we need to
// insert holes, i.e. zeros, into the backprop).
97 | // |
98 | // Consider the case where |
99 | // |
100 | // INPUT = [ A B C D E ] |
101 | // [ F G H I J ] |
102 | // [ K L M N O ] |
103 | // and |
104 | // |
105 | // FILTER = [ X Y Z ] |
106 | // |
107 | // with stride 2. |
108 | // |
109 | // The output will be |
110 | // |
111 | // OUTPUT = [ a b ] |
112 | // [ c d ] |
113 | // |
114 | // where: |
115 | // |
116 | // a = A * X + B * Y + C * Z |
117 | // b = C * X + D * Y + E * Z |
118 | // c = K * X + L * Y + M * Z |
119 | // d = M * X + N * Y + O * Z |
120 | // |
121 | // |
122 | // To compute the backprop for INPUT, we need to convolve |
123 | // |
124 | // INPUT = [ 0 0 a' 0 b' 0 0 ] |
125 | // [ 0 0 0 0 0 0 0 ] |
126 | // [ 0 0 c' 0 d' 0 0 ] |
127 | // |
128 | // (notice the holes in INPUT) |
129 | // |
130 | // and |
131 | // |
132 | // FILTER = [ Z^t Y^t X^t ] |
133 | // |
134 | // with stride 1. |
135 | // |
// To compute the backprop for FILTER, we need to convolve
//
139 | // INPUT = [ A^t B^t C^t D^t E^t ] |
140 | // [ F^t G^t H^t I^t J^t ] |
141 | // [ K^t L^t M^t N^t O^t ] |
142 | // and |
143 | // |
144 | // FILTER = [ a' 0 b' ] |
145 | // [ 0 0 0 ] |
146 | // [ c' 0 d' ] |
147 | // |
148 | // (notice the holes in FILTER) |
149 | // |
150 | // |
151 | // with stride 1 |
152 | // |
153 | ////////////////////////////////////////////////////////// |
154 | // |
155 | // |
156 | // The case for SAME padding is in fact very similar to VALID -- we just |
157 | // need to pad the input tensor a bit when computing the filter_backprop. |
158 | |
159 | #ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ |
160 | #define TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ |
161 | |
162 | #include <vector> |
163 | |
164 | #include "tensorflow/core/util/padding.h" |
165 | #include "tensorflow/core/util/tensor_format.h" |
166 | |
167 | namespace tensorflow { |
168 | |
169 | // Forward declaration. |
170 | class OpKernelContext; |
171 | |
172 | template <typename Device, typename T> |
173 | struct LaunchConv2DBackpropInputOp { |
174 | void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, |
175 | const Tensor& out_backprop, const Tensor& filter, |
176 | int row_dilation, int col_dilation, int row_stride, |
177 | int col_stride, const Padding& padding, |
178 | const std::vector<int64_t>& explicit_paddings, |
179 | Tensor* in_backprop, TensorFormat data_format); |
180 | }; |
181 | |
182 | template <typename Device, typename T> |
183 | struct LaunchConv2DBackpropFilterOp { |
184 | void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, |
185 | const Tensor& out_backprop, const Tensor& input, |
186 | int row_dilation, int col_dilation, int row_stride, |
187 | int col_stride, const Padding& padding, |
188 | const std::vector<int64_t>& explicit_paddings, |
189 | Tensor* filter_backprop, TensorFormat data_format); |
190 | }; |
191 | |
192 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
193 | template <typename T> |
194 | struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> { |
195 | void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, |
196 | const Tensor& input, const Tensor& filter, int row_dilation, |
197 | int col_dilation, int row_stride, int col_stride, |
198 | const Padding& padding, |
199 | const std::vector<int64_t>& explicit_paddings, Tensor* output, |
200 | TensorFormat data_format); |
201 | }; |
202 | |
203 | template <typename T> |
204 | struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> { |
205 | void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, |
206 | const Tensor& out_backprop, const Tensor& input, |
207 | int row_dilation, int col_dilation, int row_stride, |
208 | int col_stride, const Padding& padding, |
209 | const std::vector<int64_t>& explicit_paddings, |
210 | Tensor* filter_backprop, TensorFormat data_format); |
211 | }; |
212 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
213 | } // namespace tensorflow |
214 | |
215 | #endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ |
216 | |