/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_
#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <algorithm>
#include <limits>
#include <vector>

#include "absl/base/dynamic_annotations.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/fill_functor.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// Returns in 'im_data' (assumed to be zero-initialized) the image, in storage
// order (height, width, depth), reconstructed from patches in 'col_data',
// which is required to be in storage order (out_height * out_width,
// filter_height, filter_width, in_depth). Implementation by Yangqing Jia
// (jiayq).
template <typename T>
void Col2im(const T* col_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* __restrict im_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            for (int i = 0; i < depth; ++i) {
              im_patch_data[i] += col_data[i];
            }
          }
          im_patch_data += depth;
          col_data += depth;
        }
        // Skip the image pixels in this row that lie outside the current
        // filter window.
        im_patch_data += depth * (width - filter_w);
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}
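
// Minimal usage sketch of Col2im (illustrative only; the shapes and values
// below are assumptions for exposition, not a test that ships with this file):
//
//   std::vector<float> col(4 * 4, 1.0f);  // 2x2 output positions, 2x2x1 patch
//   std::vector<float> im(3 * 3, 0.0f);   // zero-initialized 3x3x1 image
//   Col2im<float>(col.data(), /*depth=*/1, /*height=*/3, /*width=*/3,
//                 /*filter_h=*/2, /*filter_w=*/2, /*pad_t=*/0, /*pad_l=*/0,
//                 /*pad_b=*/0, /*pad_r=*/0, /*stride_h=*/1, /*stride_w=*/1,
//                 im.data());
//   // Each image pixel now holds the number of filter windows covering it
//   // (1 at the corners, 2 on the edges, 4 in the center), i.e. Col2im is
//   // the accumulating transpose of the forward im2col lowering.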

// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU
// and GPU (for int32 only).
template <typename Device, typename T>
struct LaunchConv2DBackpropInputOpImpl {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    std::vector<int32> strides(4, 1);
    std::vector<int32> dilations(4, 1);

    auto input_h = GetTensorDimIndex(data_format, 'H');
    auto input_w = GetTensorDimIndex(data_format, 'W');
    strides[input_h] = row_stride;
    strides[input_w] = col_stride;
    dilations[input_h] = row_dilation;
    dilations[input_w] = col_dilation;

    const TensorShape& input_shape = in_backprop->shape();
    const TensorShape& filter_shape = filter.shape();

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        ctx, ConvBackpropComputeDimensionsV2(
                 "Conv2DBackpropInput", /*num_spatial_dims=*/2, input_shape,
                 filter_shape, out_backprop.shape(), dilations, strides,
                 padding, explicit_paddings, data_format, &dims));

    int64_t padding_top = -1, padding_bottom = -1;
    int64_t padding_left = -1, padding_right = -1;
    if (padding == EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                               &padding_top, &padding_bottom);
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                               &padding_left, &padding_right);
    }

    int64_t expected_out_rows, expected_out_cols;
    // These calls are guaranteed to succeed because the output shape and
    // padding were validated earlier.
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
        row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
        &padding_bottom));
    DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);

    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
        col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
        &padding_right));
    DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

    if (std::is_same<Device, GPUDevice>::value) {
      int64_t size = 1;
#define REQUIRES_32BIT(x)                                                   \
  size *= x;                                                                \
  OP_REQUIRES(ctx,                                                          \
              FastBoundsCheck(x, std::numeric_limits<int32>::max()) &&      \
                  FastBoundsCheck(size, std::numeric_limits<int32>::max()), \
              errors::InvalidArgument("Tensor too large"))

      REQUIRES_32BIT(in_backprop->dim_size(0));
      REQUIRES_32BIT(in_backprop->dim_size(1) + padding_top + padding_bottom);
      REQUIRES_32BIT(in_backprop->dim_size(2) + padding_left + padding_right);
      REQUIRES_32BIT(in_backprop->dim_size(3));
#undef REQUIRES_32BIT
    }

    auto in_backprop_t = in_backprop->tensor<T, 4>();
    auto out_backprop_t = out_backprop.tensor<T, 4>();
    auto filter_t = filter.tensor<T, 4>();

    // WARNING: Need to swap row/col, padding_top/padding_left, and
    // padding_bottom/padding_right when calling Eigen. Eigen expects tensors
    // in NWHC format, but TensorFlow uses NHWC.

    if (padding != EXPLICIT) {
      // If padding was not explicitly defined, Eigen spatial convolution
      // backward input will infer correct forward paddings from input tensors.
      functor::SpatialConvolutionBackwardInputFunc<Device, T>()(
          ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
          col_stride, row_stride, col_dilation, row_dilation);
    } else {
      functor::SpatialConvolutionBackwardInputWithExplicitPaddingFunc<Device,
                                                                      T>()(
          ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
          in_backprop_t.dimension(2) + (padding_left + padding_right),
          in_backprop_t.dimension(1) + (padding_top + padding_bottom),
          col_stride, row_stride, col_dilation, row_dilation, padding_top,
          padding_left);
    }
  }
};
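
// For reference (illustrative, not exercised by this header): with
// data_format == FORMAT_NHWC, GetTensorDimIndex(data_format, 'H') == 1 and
// GetTensorDimIndex(data_format, 'W') == 2, so the strides/dilations vectors
// built above become {1, row_stride, col_stride, 1} and
// {1, row_dilation, col_dilation, 1}; with FORMAT_NCHW the spatial entries
// land at indices 2 and 3 instead.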

// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU.
template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    LaunchConv2DBackpropInputOpImpl<CPUDevice, T> launcher;
    launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
             row_dilation, col_dilation, row_stride, col_stride, padding,
             explicit_paddings, in_backprop, data_format);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, class T>
struct LaunchXsmmBackwardInputConvolution {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::Tensor input_backward,
                  typename TTypes<T, 4>::ConstTensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::Tensor input_backward,
                  typename TTypes<float, 4>::ConstTensor kernel,
                  typename TTypes<float, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    auto batch = input_backward.dimension(0);
    auto in_depth = input_backward.dimension(3);
    auto out_depth = output_backward.dimension(3);
    auto filter_rows = kernel.dimension(0);
    auto filter_cols = kernel.dimension(1);
    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format =
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
    auto input_ptr = input_backward.data();
    auto filter_ptr = kernel.data();
    auto output_ptr = output_backward.data();

    bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename T>
struct Conv2DCustomBackpropInputMatMulFunctor {
  using MatrixMap = Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
  using ConstMatrixMap = Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Compute gradient into 'im2col_buf'.
    MatrixMap C(im2col_buf, output_image_size, filter_total_size);

    ConstMatrixMap A(out_data, output_image_size, dims_out_depth);
    ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth);

    C.noalias() = A * B.transpose();
  }
};
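
// Shape sanity check for the contraction above (illustrative only):
//   A: [output_image_size x dims_out_depth]   (out_backprop for one image)
//   B: [filter_total_size x dims_out_depth]   (filter reshaped to 2-D)
//   C = A * B^T: [output_image_size x filter_total_size]
// i.e. every output spatial position is expanded into a full
// filter_rows * filter_cols * in_depth patch, which Col2im then scatters
// back (with accumulation) into the input gradient.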

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
template <>
struct Conv2DCustomBackpropInputMatMulFunctor<float> {
  using T = float;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Inputs are in RowMajor order.
    // im2col = out_data * filter_data^T
    // [ois x fts] = [ois x dod] * [fts x dod]^T
    //
    // Dimension names:
    // out_image_size -> ois
    // filter_total_size -> fts
    // dims_out_depth -> dod

    const int m = output_image_size;
    const int n = filter_total_size;
    const int k = dims_out_depth;  // contraction dim

    const char transposeA = 'N';  // sgemm(A) == out_data
    const char transposeB = 'T';  // sgemm(B) == filter_data

    const int ldA = dims_out_depth;
    const int ldB = dims_out_depth;
    const int ldC = filter_total_size;

    const float alpha = 1.0;
    const float beta = 0.0;

    // dnnl_sgemm code can't be instrumented with msan.
    ANNOTATE_MEMORY_IS_INITIALIZED(
        im2col_buf, filter_total_size * output_image_size * sizeof(T));

    dnnl_status_t st =
        dnnl_sgemm(transposeA, transposeB, m, n, k, alpha, out_data, ldA,
                   filter_data, ldB, beta, im2col_buf, ldC);

    OP_REQUIRES(
        ctx, st == 0,
        errors::Internal("Failed to call dnnl_sgemm. Error code: ", st));
  }
};
#endif

template <typename Device, class T>
class Conv2DBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(
        context, (dilation_n == 1 && dilation_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilation rates should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    if (std::is_same<Device, CPUDevice>::value ||
        std::is_same<T, int32>::value) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Conv2DBackpropInputOp [CPU or GPU(int32)] "
                                  "only supports NHWC data format."));

      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(
          context, (dilation_h == 1 && dilation_w == 1),
          errors::InvalidArgument(
              "Conv2DBackpropInputOp [CPU or GPU(int32)] does not yet support "
              "dilation rates larger than 1."));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    OP_REQUIRES(
        context, out_backprop.dims() == 4,
        errors::InvalidArgument("out_backprop must be 4-dimensional, got: ",
                                out_backprop.dims()));

    TensorShape input_shape;
    OP_REQUIRES_OK(context,
                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
                                                   out_backprop.shape(),
                                                   data_format_, &input_shape));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // If shapes are valid but `out_backprop` is empty, in_backprop should be
    // set to all zeros. Otherwise, cudnn/dnnl fail with an empty input.
    if (out_backprop.NumElements() == 0) {
      functor::SetZeroFunctor<Device, T> set_zero;
      set_zero(context->eigen_device<Device>(),
               in_backprop->template flat<T>());
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    VLOG(2) << "Conv2DBackpropInput:"
            << " input: " << input_shape.DebugString()
            << " filter:" << filter.shape().DebugString()
            << " out_backprop: " << out_backprop.shape().DebugString()
            << " strides: [" << stride_rows << ", " << stride_cols << "]"
            << " dilations: [" << dilation_rows << ", " << dilation_cols << "]";

    LaunchConv2DBackpropInputOp<Device, T> launch;
    launch(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
           dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
           explicit_paddings_, in_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  TensorFormat data_format_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;

  bool use_cudnn_ = false;
  bool cudnn_use_autotune_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
template <typename Device, class T>
class Conv2DCustomBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (dilations_[0] == 1 && dilations_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, out_backprop.dims() == 4,
        errors::InvalidArgument("out_backprop must be 4-dimensional, got: ",
                                out_backprop.dims()));

    TensorShape input_shape;
    OP_REQUIRES_OK(context,
                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
                                                   out_backprop.shape(),
                                                   data_format_, &input_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensionsV2(
                       "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                       input_shape, filter.shape(), out_backprop.shape(),
                       /*dilations=*/{1, 1, 1, 1}, strides_, padding_,
                       explicit_paddings_, data_format_, &dims));

    OP_REQUIRES(context, dims.in_depth == filter.shape().dim_size(2),
                errors::InvalidArgument(
                    "Gradients for grouped convolutions are not "
                    "supported on CPU. Please file a feature request if you "
                    "run into this issue. Computed input depth ",
                    dims.in_depth, " doesn't match filter input depth ",
                    filter.shape().dim_size(2)));
    OP_REQUIRES(
        context, dims.out_depth == filter.shape().dim_size(3),
        errors::InvalidArgument("Computed output depth ", dims.out_depth,
                                " doesn't match filter output depth ",
                                filter.shape().dim_size(3)));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // If shapes are valid but `out_backprop` is empty, in_backprop should be
    // set to all zeros. Otherwise, cudnn/dnnl fail with an empty input.
    if (out_backprop.NumElements() == 0) {
      functor::SetZeroFunctor<Device, T> set_zero;
      set_zero(context->eigen_device<Device>(),
               in_backprop->template flat<T>());
      return;
    }

// TODO(ezhulenev): Remove custom kernel and move XSMM support to
// LaunchConv2DBackpropInputOp functor.
#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
    defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardInputConvolution<Device, T>()(
              context, context->eigen_device<Device>(),
              in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#else
    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
#endif
    if (padding_ == Padding::EXPLICIT) {
      pad_top = explicit_paddings_[2];
      pad_bottom = explicit_paddings_[3];
      pad_left = explicit_paddings_[4];
      pad_right = explicit_paddings_[5];
    }
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // TODO(andydavis) Get L2/L3 cache sizes from device.
    const size_t l2_cache_size = 256LL << 10;
    const size_t l3_cache_size = 30LL << 20;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const size_t size_A = output_image_size * dims.out_depth;

    const size_t size_B = filter_total_size * dims.out_depth;

    const size_t size_C = output_image_size * filter_total_size;

    const size_t work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Calculate per-thread work unit size.
    const size_t thread_work_unit_size =
        work_unit_size / worker_threads.num_threads;

    // Set minimum per-thread work unit size to size of L2 cache.
    const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);

    // Use parallel tensor contractions if there is no batching, or if the
    // minimum per-thread work unit size threshold has been exceeded.
    // Otherwise, revert to multiple single-threaded matmul ops running in
    // parallel to keep all threads busy.
    // TODO(andydavis) Explore alternatives to branching the code in this way
    // (i.e. run multiple, parallel tensor contractions in another thread pool).
    const bool use_parallel_contraction =
        dims.batch_size == 1 ||
        thread_work_unit_size >= min_thread_work_unit_size;

    OP_REQUIRES(
        context, work_unit_size > 0,
        errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
                                "must all have at least 1 element"));

    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;
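
    // Illustrative arithmetic (assumed shapes, not taken from a benchmark):
    // with T = float, a 3x3x3x64 filter and a 112x112 output, size_A is
    // 112*112*64, size_B is 3*3*3*64 and size_C is 112*112*3*3*3 elements,
    // so work_unit_size is roughly 1.14M floats. The L3-derived target working
    // set is (30 MB / 4 bytes) ~= 7.86M floats, so in the
    // non-parallel-contraction branch shard_size comes out to
    // ceil(7.86M / 1.14M) = 7 images per shard.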

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64_t>(shard_size),
                                    static_cast<int64_t>(output_image_size),
                                    static_cast<int64_t>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(
            col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
            dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
            pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
            input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom,
                      &pad_right, &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64_t start, int64_t limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            Conv2DCustomBackpropInputMatMulFunctor<T>()(
                context, out_data, filter_data, filter_total_size,
                output_image_size, dims.out_depth, im2col_buf);

            Col2im<T>(im2col_buf, dims.in_depth,
                      dims.spatial_dims[0].input_size,
                      dims.spatial_dims[1].input_size,
                      dims.spatial_dims[0].filter_size,
                      dims.spatial_dims[1].filter_size, pad_top, pad_left,
                      pad_bottom, pad_right, dims.spatial_dims[0].stride,
                      dims.spatial_dims[1].stride, input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
};

// TODO(ezhulenev): Add a cost model to switch between custom/Eigen ops.
#define DEFAULT_CONV_2D_BACKPROP_CPU_OP Conv2DCustomBackpropInputOp

#define REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(T)                             \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DEFAULT_CONV_2D_BACKPROP_CPU_OP<CPUDevice, T>);                        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("custom")                               \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DCustomBackpropInputOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("eigen_tensor")                         \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DBackpropInputOp<CPUDevice, T>);
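
// Note on the registrations above (background, not something this header
// depends on): the unlabeled registration maps "Conv2DBackpropInput" on CPU to
// DEFAULT_CONV_2D_BACKPROP_CPU_OP (currently the custom im2col/GEMM kernel),
// while the "custom" and "eigen_tensor" labels keep both implementations
// selectable; a graph node can opt into a labeled kernel via its "_kernel"
// attribute.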

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_