1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
17#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
21
22namespace Eigen {
23
24/** SpatialConvolutionBackwardInput
25 * \ingroup CXX11_NeuralNetworks_Module
26 *
27 * \brief Computes the backprop for the input of a 2D convolution.
28 *
29 * The output_backward parameter is expected to be a tensor with a rank of 3 or
30 * more (channels, height, width, and optionally others)
31 * The kernel parameter is expected to be a 4D tensor (filters, channels,
32 * kernel_height, kernel_width)
33 * The output_backward and the kernel must both be in col-major layout. The
34 * result will also be in col-major layout.
35 *
36 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
37 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
38 * pixels.
39 *
40 * The result can be assigned to a tensor of rank equal to the rank of the
41 * output_backward. The dimensions of the result will be filters, height, width
42 * (and others if applicable).
43 *
44 * It is possible to swap the order of the width and height dimensions provided
45 * that the same order is used in the input, the kernel, and the output.
46 *
47 */
48typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1>>
49 ReverseColMajor;
50typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0>>
51 ReverseRowMajor;
52
// Implementation note: the conditional return type below is a hand-written
// mirror of the exact expression tree built at the bottom of this function.
//   col-major branch: kernel.reverse().eval().shuffle().eval().reshape()
//       .contract(output_backward.extract_image_patches().reshape())
//       .reshape()
//   row-major branch: the same computation with the contraction operands
//       swapped (patches.contract(kernel)).
// Any change to the expression in the body must be reflected here.
template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<OutputBackward>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseColMajor, const Kernel>>>>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>>>,
    TensorReshapingOp<

        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseRowMajor, const Kernel>>>>>>>>
SpatialConvolutionBackwardInput(
    const Kernel& kernel, const OutputBackward& output_backward,
    typename internal::traits<OutputBackward>::Index inputRows,
    typename internal::traits<OutputBackward>::Index inputCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<OutputBackward>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  // TensorRefs are used only to query dimensions without forcing evaluation.
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex>>
      kern(kernel);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  // Both operands must use the same storage order.
  EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  static const bool isColMajor =
      (internal::traits<OutputBackward>::Layout == ColMajor);

  static const int NumDims = internal::traits<OutputBackward>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous (dilated) convolution.
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Computing the forward padding: reconstruct the (SAME-style, split in half)
  // padding the forward convolution would have used for these sizes.
  const TensorIndex forward_pad_top = numext::maxi<Index>(
      0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
  const TensorIndex forward_pad_left = numext::maxi<Index>(
      0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
  // Backprop-to-input is a "full" correlation with the reversed kernel, so
  // the leading padding is (effective kernel extent - 1) minus the forward
  // padding.
  const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
  const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;

  // Trailing padding is whatever is left to make the (stride-inflated)
  // output_backward plus padding cover exactly inputRows/inputCols patches.
  const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
                                     2 - padding_top + kernelRowsEff;
  const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
                                    2 - padding_left + kernelColsEff;

  eigen_assert(padding_top >= 0);
  eigen_assert(padding_left >= 0);
  eigen_assert(padding_bottom >= 0);
  eigen_assert(padding_right >= 0);

  // The kernel has dimensions filters X channels X patch_rows X patch_cols
  // We need to reverse the kernel along dimensions corresponding to rows and
  // cols.
  // TODO(yangke): we can make things slightly faster by collapsing the
  // dimensions
  // where we don't reverse. Try that once we have a faster compiler.
  typedef std::conditional_t<isColMajor, ReverseColMajor, ReverseRowMajor>
      Reverse;
  Reverse kernel_reverse;
  // Reorder the dimensions to:
  //   filters x patch_rows x patch_cols x channels
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    //  From: filters x channels x rows x cols
    //  To:   filters x rows x cols x channels
    kernel_shuffle[0] = 0;
    kernel_shuffle[1] = 2;
    kernel_shuffle[2] = 3;
    kernel_shuffle[3] = 1;
  } else {
    //  From: cols x rows x channels x filters
    //  To:   channels x cols x rows x filters
    kernel_shuffle[0] = 2;
    kernel_shuffle[1] = 0;
    kernel_shuffle[2] = 1;
    kernel_shuffle[3] = 3;
  }

  // Collapse the dims: the contraction dimension holds
  // filters * patch_rows * patch_cols, leaving channels as the free dim.
  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[1] = kernelChannels;
  } else {
    kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[0] = kernelChannels;
  }

  // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
  // When we extract the image patches from output_backward, it will have
  // dimensions
  //   out_depth X (patch_rows * patch_cols) X (input_rows * input_cols *
  //   OTHERS)
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[1] = inputRows * inputCols;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= out.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[0] = inputRows * inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= out.dimension(i);
    }
  }

  // We will contract along the collapsed dimension that contains the
  // kernelFilters, the kernelRows and the kernelCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  if (isColMajor) {
    // col-major: kernel.contract(output.patches)
    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
  } else {
    // row-major: output.patches.contract(kernel)
    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
  }

  // Post contraction, the dimensions of the input_backprop is
  //  channels X input_rows X input_cols X OTHERS
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = inputRows;
    post_contract_dims[2] = inputCols;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelChannels;
    post_contract_dims[NumDims - 2] = inputRows;
    post_contract_dims[NumDims - 3] = inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  }

  // NOTE(ezhulenev): We do eval after reverse and shuffle, because tiled
  // evaluation of these ops does not compose. Doing explicit eval is ~8x
  // faster in micro benchmarks.

  // In extract_image_patches below the patch itself moves with stride 1; the
  // atrous in-strides are passed as patch in-strides, and the forward strides
  // are passed as inflate strides (zero-stuffing output_backward), which is
  // what transposed convolution requires.
  return choose(
      Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
      kernel.reverse(kernel_reverse)
          .eval()
          .shuffle(kernel_shuffle)
          .eval()
          .reshape(kernel_dims)
          .contract(
              output_backward
                  .extract_image_patches(
                      kernelRows, kernelCols, 1, 1, row_in_stride,
                      col_in_stride, row_stride, col_stride, padding_top,
                      padding_bottom, padding_left, padding_right, OutScalar(0))
                  .reshape(pre_contract_dims),
              contract_dims)
          .reshape(post_contract_dims),
      output_backward
          .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride,
                                 col_in_stride, row_stride, col_stride,
                                 padding_top, padding_bottom, padding_left,
                                 padding_right, OutScalar(0))
          .reshape(pre_contract_dims)
          .contract(kernel.reverse(kernel_reverse)
                        .eval()
                        .shuffle(kernel_shuffle)
                        .eval()
                        .reshape(kernel_dims),
                    contract_dims)
          .reshape(post_contract_dims));
}
287
288/** SpatialConvolutionBackwardKernel
289 * \ingroup CXX11_NeuralNetworks_Module
290 *
291 * \brief Computes the backprop for the filter of a 2D convolution.
292 *
 * The output_backward parameter is expected to be a 4D tensor (filters,
 * height, width, batch in col-major layout).
 * The input parameter is expected to be a 4D tensor (channels, height,
 * width, batch in col-major layout).
 * The output_backward and the input must both be in col-major layout. The
 * result will also be in col-major layout.
299 *
 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
 * pixels.
303 *
 * The result is the gradient with respect to the filter and can be assigned
 * to a 4D tensor whose dimensions are filters, channels, kernel_height,
 * kernel_width (in col-major layout).
307 *
308 * It is possible to swap the order of the width and height dimensions provided
309 * that the same order is used in the input, the kernel, and the output.
310 *
311 */
312
// Implementation note: the filter gradient is computed as a convolution in
// which the roles are swapped — the (shuffled, reshaped) forward input plays
// the kernel, and patches extracted from the (shuffled, zero-stuffed)
// output_backward play the convolution input. The contraction result is then
// reshaped, shuffled into filter layout, and spatially reversed. The
// conditional return type below hand-mirrors that exact expression tree for
// the col-major and row-major branches (contraction operand order swapped);
// keep it in sync with the expression at the bottom of the function.
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<Input>::Layout == ColMajor,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>>>>>>,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>>>>>>>
SpatialConvolutionBackwardKernel(
    const Input& input, const OutputBackward& output_backward,
    typename internal::traits<Input>::Index kernelRows,
    typename internal::traits<Input>::Index kernelCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  // TensorRefs are used only to query dimensions without forcing evaluation.
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex>>
      in(input);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  // Both operands must use the same storage order.
  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  // stride and in_stride cannot both be larger than 1
  eigen_assert(!(row_stride > 1 && row_in_stride > 1));
  eigen_assert(!(col_stride > 1 && col_in_stride > 1));

  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  static const int NumDims = internal::traits<Input>::NumDimensions;
  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
                          internal::traits<OutputBackward>::NumDimensions,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);
  // This implementation is specialized to exactly 4-D tensors
  // (depth, rows, cols, batch in col-major order).
  EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);

  const TensorIndex inputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];

  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  // Number of batches (and other dimensions) in the input tensor.
  TensorIndex batch = 1;
  for (int d = 3; d < NumDims; ++d) {
    batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1);
  }

  // Computing the forward padding: total (SAME-style) padding the forward
  // convolution would have applied along each spatial dimension.
  const TensorIndex padRows = numext::maxi<Index>(
      0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows);
  const TensorIndex padCols = numext::maxi<Index>(
      0, (outputCols - 1) * col_stride + kernelColsEff - inputCols);

  TensorIndex padding_top = padRows / 2;
  TensorIndex padding_left = padCols / 2;

  // Compute paddings for output_backward before extracting patches.
  // expanded_* is the spatial extent of output_backward after inflating it by
  // the forward stride (zero-stuffing between pixels); padded_* is the extent
  // needed so the patch extraction yields exactly kernelRows x kernelCols
  // patch positions.
  const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1;
  const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1;

  const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1;
  const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1;

  const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top;
  const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left;

  const TensorIndex bottom_pad_rows =
      padded_out_rows - expanded_out_rows - top_pad_rows;
  const TensorIndex right_pad_cols =
      padded_out_cols - expanded_out_cols - left_pad_cols;

  // Reorder output_backward dimensions.
  array<TensorIndex, 4> output_backward_shuffle;
  if (isColMajor) {
    // From: [out_depth, out_rows, out_cols, batch]
    // To:   [batch, out_rows, out_cols, out_depth]
    output_backward_shuffle = {3, 1, 2, 0};
  } else {
    // From: [batch, out_cols, out_rows, out_depth]
    // To:   [out_depth, out_cols, out_rows, batch]
    output_backward_shuffle = {3, 1, 2, 0};
  }

  // Reorder input dimensions.
  array<TensorIndex, 4> input_shuffle;
  if (isColMajor) {
    // From: [in_depth, in_rows, in_cols, batch]
    // To:   [in_depth, batch, in_rows, in_cols]
    input_shuffle = {0, 3, 1, 2};
  } else {
    // From: [batch, in_cols, in_rows, in_depth]
    // To:   [in_cols, in_rows, batch, in_depth]
    input_shuffle = {1, 2, 0, 3};
  }

  // Input is playing the role of a "kernel" in this convolution: collapse it
  // to 2-D with channels as the free dimension and everything else as the
  // contraction dimension.
  DSizes<TensorIndex, 2> input_dims;
  if (isColMajor) {
    input_dims[0] = kernelChannels;
    input_dims[1] = batch * inputRows * inputCols;
  } else {
    input_dims[1] = kernelChannels;
    input_dims[0] = inputCols * inputRows * batch;
  }

  // Molds the output of the patch extraction result into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = batch * inputRows * inputCols;
    pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters;
  } else {
    pre_contract_dims[1] = inputCols * inputRows * batch;
    pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows;
  }

  // We will contract along the collapsed dimension that contains the
  // batch, inputRows and inputCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Dimensions after contraction: the (not yet reordered) filter gradient.
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = kernelRows;
    post_contract_dims[2] = kernelCols;
    post_contract_dims[3] = kernelFilters;
  } else {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = kernelCols;
    post_contract_dims[2] = kernelRows;
    post_contract_dims[3] = kernelChannels;
  }

  // Reorder output of contraction to a valid filter shape.
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    // From: [in_depth, kernel_rows, kernel_cols, out_depth]
    // To:   [out_depth, in_depth, kernel_rows, kernel_cols]
    kernel_shuffle = {3, 0, 1, 2};
  } else {
    // From: [out_depth, kernel_cols, kernel_rows, in_depth]
    // To:   [kernel_cols, kernel_rows, in_depth, out_depth]
    kernel_shuffle = {1, 2, 3, 0};
  }

  // Reverse kernel backprop dimensions: the contraction produces the filter
  // gradient spatially flipped, so flip the two spatial dimensions back.
  array<TensorIndex, 4> kernel_reverse;
  if (isColMajor) {
    kernel_reverse = {false, false, true, true};
  } else {
    kernel_reverse = {true, true, false, false};
  }

  // Create convolution input (aka source of patches) from output backward
  // tensor by shuffling dimensions. (Forced eval: tiled evaluation of
  // shuffle does not compose with the downstream patch extraction.)
  const auto output_backward_shuffled =
      output_backward.shuffle(output_backward_shuffle).eval();

  // Create convolution kernel (aka filter) from input by shuffling and
  // reshaping.
  const auto input_shuffled =
      input.shuffle(input_shuffle).eval().reshape(input_dims);

  // Note the swapped roles in extract_image_patches below: the "patch size"
  // is the input spatial extent, the atrous in_strides move the patch, and
  // the forward strides are passed as inflate strides on output_backward.
  return choose(
      Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
      input_shuffled.contract(
          output_backward_shuffled
              .extract_image_patches(inputRows, inputCols, row_in_stride,
                                     col_in_stride, 1, 1, row_stride,
                                     col_stride, top_pad_rows,
                                     bottom_pad_rows, left_pad_cols,
                                     right_pad_cols, OutScalar(0))
              .reshape(pre_contract_dims),
          contract_dims),
      output_backward_shuffled
          .extract_image_patches(
              inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
              row_stride, col_stride, top_pad_rows, bottom_pad_rows,
              left_pad_cols, right_pad_cols, OutScalar(0))
          .reshape(pre_contract_dims)
          .contract(input_shuffled, contract_dims))
      .reshape(post_contract_dims)
      .shuffle(kernel_shuffle)
      .eval()
      .reverse(kernel_reverse);
}
590
591} // end namespace Eigen
592
593#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
594