1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/kernels/eigen_spatial_convolutions.h" |
21 | |
22 | namespace Eigen { |
23 | |
24 | /** SpatialConvolutionBackwardInput |
25 | * \ingroup CXX11_NeuralNetworks_Module |
26 | * |
27 | * \brief Computes the backprop for the input of a 2D convolution. |
28 | * |
29 | * The output_backward parameter is expected to be a tensor with a rank of 3 or |
30 | * more (channels, height, width, and optionally others) |
31 | * The kernel parameter is expected to be a 4D tensor (filters, channels, |
32 | * kernel_height, kernel_width) |
 * The output_backward and the kernel must both use the same storage layout
 * (col-major or row-major); the result will use that same layout.
35 | * |
36 | * If row_in_stride, col_in_stride > 1, then applies convolution with holes |
37 | * (aka atrous convolution), sampling every row_in_stride, col_in_stride input |
38 | * pixels. |
39 | * |
40 | * The result can be assigned to a tensor of rank equal to the rank of the |
41 | * output_backward. The dimensions of the result will be filters, height, width |
42 | * (and others if applicable). |
43 | * |
44 | * It is possible to swap the order of the width and height dimensions provided |
45 | * that the same order is used in the input, the kernel, and the output. |
46 | * |
47 | */ |
48 | typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1>> |
49 | ReverseColMajor; |
50 | typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0>> |
51 | ReverseRowMajor; |
52 | |
// Builds an (unevaluated) Eigen tensor expression computing the gradient of a
// 2D convolution with respect to its input: image patches extracted from
// output_backward are contracted against the spatially reversed kernel.  The
// std::conditional_t return type spells out the expression tree twice - once
// for the col-major operand order and once for row-major.
template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<OutputBackward>::Layout == ColMajor,
    // Col-major: reverse/shuffle/reshape(kernel)
    //              .contract(reshape(patches(output_backward)))
    //              .reshape(input dims)
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseColMajor, const Kernel>>>>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>>>,
    // Row-major: the same contraction with the operands swapped.
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseRowMajor, const Kernel>>>>>>>>
SpatialConvolutionBackwardInput(
    const Kernel& kernel, const OutputBackward& output_backward,
    typename internal::traits<OutputBackward>::Index inputRows,
    typename internal::traits<OutputBackward>::Index inputCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<OutputBackward>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  // TensorRefs are used only to query the dimensions of the (possibly
  // unevaluated) kernel and output_backward expressions.
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex>>
      kern(kernel);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  // Both arguments must use the same storage order.
  EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  static const bool isColMajor =
      (internal::traits<OutputBackward>::Layout == ColMajor);

  static const int NumDims = internal::traits<OutputBackward>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  // Spatial extent of the kernel (dimension order is reversed in row-major).
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  // This is the effective kernel size, taking into account the (*_in_stride -
  // 1) zero-values
  // inserted between consecutive kernel elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Computing the forward padding
  const TensorIndex forward_pad_top = numext::maxi<Index>(
      0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
  const TensorIndex forward_pad_left = numext::maxi<Index>(
      0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
  // Backward-pass padding: a full (effective kernel size - 1) border minus
  // whatever padding the forward pass already applied on the top/left.
  const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
  const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;

  // Bottom/right padding is whatever remains so that the patch extraction
  // below yields exactly inputRows x inputCols result positions.
  const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
                                     2 - padding_top + kernelRowsEff;
  const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
                                    2 - padding_left + kernelColsEff;

  eigen_assert(padding_top >= 0);
  eigen_assert(padding_left >= 0);
  eigen_assert(padding_bottom >= 0);
  eigen_assert(padding_right >= 0);

  // The kernel has dimensions filters X channels X patch_rows X patch_cols
  // We need to reverse the kernel along dimensions corresponding to rows and
  // cols.
  // TODO(yangke): we can make things slightly faster by collapsing the
  // dimensions
  // where we don't reverse. Try that once we have a faster compiler.
  typedef std::conditional_t<isColMajor, ReverseColMajor, ReverseRowMajor>
      Reverse;
  Reverse kernel_reverse;
  // Reorder the dimensions to:
  //   filters x patch_rows x patch_cols x channels
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    // From: filters x channels x rows x cols
    // To:   filters x rows x cols x channels
    kernel_shuffle[0] = 0;
    kernel_shuffle[1] = 2;
    kernel_shuffle[2] = 3;
    kernel_shuffle[3] = 1;
  } else {
    // From: cols x rows x channels x filters
    // To:   channels x cols x rows x filters
    kernel_shuffle[0] = 2;
    kernel_shuffle[1] = 0;
    kernel_shuffle[2] = 1;
    kernel_shuffle[3] = 3;
  }

  // Collapse the dims: one contraction axis over filters*rows*cols, the other
  // axis holding the channels (which become the output depth of the result).
  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[1] = kernelChannels;
  } else {
    kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[0] = kernelChannels;
  }

  // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
  // When we extract the image patches from output_backward, it will have
  // dimensions
  //   out_depth X (patch_rows * patch_cols) X (input_rows * input_cols *
  //   OTHERS)
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[1] = inputRows * inputCols;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= out.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[0] = inputRows * inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= out.dimension(i);
    }
  }

  // We will contract along the collapsed dimension that contains the
  // kernelFilters, the kernelRows and the kernelCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  if (isColMajor) {
    // col-major: kernel.contract(output.patches)
    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
  } else {
    // row-major: output.patches.contract(kernel)
    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
  }

  // Post contraction, the dimensions of the input_backprop is
  //   channels X input_rows X input_cols X OTHERS
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = inputRows;
    post_contract_dims[2] = inputCols;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelChannels;
    post_contract_dims[NumDims - 2] = inputRows;
    post_contract_dims[NumDims - 3] = inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  }

  // NOTE(ezhulenev): We do eval after reverse and shuffle, because tiled
  // evaluation of these ops does not compose. Doing explicit eval is ~8x
  // faster in micro benchmarks.

  // choose() statically selects the branch matching the layout; the other
  // branch is only used for its type.
  return choose(
      Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
      kernel.reverse(kernel_reverse)
          .eval()
          .shuffle(kernel_shuffle)
          .eval()
          .reshape(kernel_dims)
          .contract(
              output_backward
                  .extract_image_patches(
                      kernelRows, kernelCols, 1, 1, row_in_stride,
                      col_in_stride, row_stride, col_stride, padding_top,
                      padding_bottom, padding_left, padding_right, OutScalar(0))
                  .reshape(pre_contract_dims),
              contract_dims)
          .reshape(post_contract_dims),
      output_backward
          .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride,
                                 col_in_stride, row_stride, col_stride,
                                 padding_top, padding_bottom, padding_left,
                                 padding_right, OutScalar(0))
          .reshape(pre_contract_dims)
          .contract(kernel.reverse(kernel_reverse)
                        .eval()
                        .shuffle(kernel_shuffle)
                        .eval()
                        .reshape(kernel_dims),
                    contract_dims)
          .reshape(post_contract_dims));
}
287 | |
288 | /** SpatialConvolutionBackwardKernel |
289 | * \ingroup CXX11_NeuralNetworks_Module |
290 | * |
291 | * \brief Computes the backprop for the filter of a 2D convolution. |
292 | * |
293 | * The output_backward parameter is expected to be a tensor with a rank of 3 or |
294 | * more (channels, height, width, and optionally others) |
295 | * The kernel parameter is expected to be a 4D tensor (filters, channels, |
296 | * kernel_height, kernel_width) |
 * The output_backward and the kernel must both use the same storage layout
 * (col-major or row-major); the result will use that same layout.
 *
 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
 * pixels.
303 | * |
304 | * The result can be assigned to a tensor of rank equal to the rank of the |
305 | * output_backward. The dimensions of the result will be filters, height, width |
306 | * (and others if applicable). |
307 | * |
308 | * It is possible to swap the order of the width and height dimensions provided |
309 | * that the same order is used in the input, the kernel, and the output. |
310 | * |
311 | */ |
312 | |
// Builds an (unevaluated) Eigen tensor expression computing the gradient of a
// 2D convolution with respect to its filter.  The (shuffled, reshaped) input
// plays the role of the convolution kernel, contracted against image patches
// of the (shuffled, zero-padded) output_backward; the contraction result is
// then shuffled and reversed into a valid filter shape.  The
// std::conditional_t return type spells out the expression tree twice - once
// for the col-major operand order and once for row-major.
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<Input>::Layout == ColMajor,
    // Col-major: reverse(shuffle(reshape(
    //              shuffle(input).contract(patches(shuffle(output_backward))))))
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>>>>>>,
    // Row-major: the same contraction with the operands swapped.
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>>>>>>>
SpatialConvolutionBackwardKernel(
    const Input& input, const OutputBackward& output_backward,
    typename internal::traits<Input>::Index kernelRows,
    typename internal::traits<Input>::Index kernelCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  // TensorRefs are used only to query the dimensions of the (possibly
  // unevaluated) input and output_backward expressions.
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex>>
      in(input);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  // Both arguments must use the same storage order.
  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  // stride and in_stride cannot both be larger than 1
  eigen_assert(!(row_stride > 1 && row_in_stride > 1));
  eigen_assert(!(col_stride > 1 && col_in_stride > 1));

  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  static const int NumDims = internal::traits<Input>::NumDimensions;
  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
                          internal::traits<OutputBackward>::NumDimensions,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);
  // This implementation only supports rank-4 tensors
  // (depth, rows, cols, batch in col-major order).
  EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);

  const TensorIndex inputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];

  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  // Number of batches (and other dimensions) in the input tensor.
  TensorIndex batch = 1;
  for (int d = 3; d < NumDims; ++d) {
    batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1);
  }

  // Computing the forward padding
  const TensorIndex padRows = numext::maxi<Index>(
      0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows);
  const TensorIndex padCols = numext::maxi<Index>(
      0, (outputCols - 1) * col_stride + kernelColsEff - inputCols);

  TensorIndex padding_top = padRows / 2;
  TensorIndex padding_left = padCols / 2;

  // Compute paddings for output_backward before extracting patches.
  // expanded_* is the spatial size of output_backward after inserting
  // (stride - 1) zeros between values (handled by extract_image_patches'
  // in_stride arguments below); padded_* is the spatial size needed so that
  // the patch extraction produces kernelRows x kernelCols positions.
  const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1;
  const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1;

  const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1;
  const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1;

  const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top;
  const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left;

  const TensorIndex bottom_pad_rows =
      padded_out_rows - expanded_out_rows - top_pad_rows;
  const TensorIndex right_pad_cols =
      padded_out_cols - expanded_out_cols - left_pad_cols;

  // Reorder output_backward dimensions.
  array<TensorIndex, 4> output_backward_shuffle;
  if (isColMajor) {
    // From: [out_depth, out_rows, out_cols, batch]
    // To:   [batch, out_rows, out_cols, out_depth]
    output_backward_shuffle = {3, 1, 2, 0};
  } else {
    // From: [batch, out_cols, out_rows, out_depth]
    // To:   [out_depth, out_cols, out_rows, batch]
    output_backward_shuffle = {3, 1, 2, 0};
  }

  // Reorder input dimensions.
  array<TensorIndex, 4> input_shuffle;
  if (isColMajor) {
    // From: [in_depth, in_rows, in_cols, batch]
    // To:   [in_depth, batch, in_rows, in_cols]
    input_shuffle = {0, 3, 1, 2};
  } else {
    // From: [batch, in_cols, in_rows, in_depth]
    // To:   [in_cols, in_rows, batch, in_depth]
    input_shuffle = {1, 2, 0, 3};
  }

  // Input is playing the role of a "kernel" in this convolution: one axis
  // holds the channels, the other collapses batch and the spatial extent.
  DSizes<TensorIndex, 2> input_dims;
  if (isColMajor) {
    input_dims[0] = kernelChannels;
    input_dims[1] = batch * inputRows * inputCols;
  } else {
    input_dims[1] = kernelChannels;
    input_dims[0] = inputCols * inputRows * batch;
  }

  // Molds the output of the patch extraction result into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with the
  //   kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = batch * inputRows * inputCols;
    pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters;
  } else {
    pre_contract_dims[1] = inputCols * inputRows * batch;
    pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows;
  }

  // We will contract along the collapsed dimension that contains the
  // batch, inputRows and inputCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Dimensions after contraction.
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = kernelRows;
    post_contract_dims[2] = kernelCols;
    post_contract_dims[3] = kernelFilters;
  } else {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = kernelCols;
    post_contract_dims[2] = kernelRows;
    post_contract_dims[3] = kernelChannels;
  }

  // Reorder output of contraction to a valid filter shape.
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    // From: [in_depth, kernel_rows, kernel_cols, out_depth]
    // To:   [out_depth, in_depth, kernel_rows, kernel_cols]
    kernel_shuffle = {3, 0, 1, 2};
  } else {
    // From: [out_depth, kernel_cols, kernel_rows, in_depth]
    // To:   [kernel_cols, kernel_rows, in_depth, out_depth]
    kernel_shuffle = {1, 2, 3, 0};
  }

  // Reverse kernel backprop dimensions (the two spatial dimensions).
  array<TensorIndex, 4> kernel_reverse;
  if (isColMajor) {
    kernel_reverse = {false, false, true, true};
  } else {
    kernel_reverse = {true, true, false, false};
  }

  // Create convolution input (aka source of patches) from output backward
  // tensor by shuffling dimensions.
  const auto output_backward_shuffled =
      output_backward.shuffle(output_backward_shuffle).eval();

  // Create convolution kernel (aka filter) from input by shuffling and
  // reshaping.
  const auto input_shuffled =
      input.shuffle(input_shuffle).eval().reshape(input_dims);

  // choose() statically selects the branch matching the layout; the other
  // branch is only used for its type.
  return choose(
             Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
             input_shuffled.contract(
                 output_backward_shuffled
                     .extract_image_patches(inputRows, inputCols, row_in_stride,
                                            col_in_stride, 1, 1, row_stride,
                                            col_stride, top_pad_rows,
                                            bottom_pad_rows, left_pad_cols,
                                            right_pad_cols, OutScalar(0))
                     .reshape(pre_contract_dims),
                 contract_dims),
             output_backward_shuffled
                 .extract_image_patches(
                     inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
                     row_stride, col_stride, top_pad_rows, bottom_pad_rows,
                     left_pad_cols, right_pad_cols, OutScalar(0))
                 .reshape(pre_contract_dims)
                 .contract(input_shuffled, contract_dims))
      .reshape(post_contract_dims)
      .shuffle(kernel_shuffle)
      .eval()
      .reverse(kernel_reverse);
}
590 | |
591 | } // end namespace Eigen |
592 | |
593 | #endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ |
594 | |