/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <algorithm>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace {

// Returns in 'col_data', image patches in storage order (height, width, depth)
// extracted from image at 'input_data', which is required to be in storage
// order (batch, height, width, depth).
// Implementation written by Yangqing Jia (jiayq).
template <typename T>
void Im2col(const T* input_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* col_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;

  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            memcpy(col_data, input_data + (ih * width + iw) * depth,
                   sizeof(T) * depth);
          } else {
            // This should be simply padded with zero.
            memset(col_data, 0, sizeof(T) * depth);
          }
          col_data += depth;
        }
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& input,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* filter_backprop, TensorFormat data_format) {
    std::vector<int32> dilations(4, 1);
    dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
    dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;

    std::vector<int32> strides(4, 1);
    strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
    strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
    TensorShape filter_shape = filter_backprop->shape();

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        ctx, ConvBackpropComputeDimensionsV2(
                 "Conv2DBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
                 filter_shape, out_backprop.shape(), dilations, strides,
                 padding, explicit_paddings, data_format, &dims));

    int64_t padding_top = -1, padding_bottom = -1;
    int64_t padding_left = -1, padding_right = -1;
    if (padding == EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                               &padding_top, &padding_bottom);
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                               &padding_left, &padding_right);
    }
    int64_t expected_out_rows, expected_out_cols;
    // The function is guaranteed to succeed because we checked the output and
    // padding was valid earlier.
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
        row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
        &padding_bottom));
    DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
        col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
        &padding_right));
    DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

    const CPUDevice& d = ctx->eigen_device<CPUDevice>();

    // WARNING: Need to swap row/col, padding_top/padding_left, and
    // padding_bottom/padding_right when calling Eigen. Eigen expects tensors
    // in NWHC format, but Tensorflow uses NHWC.

    auto filter_backprop_t = filter_backprop->tensor<T, 4>();
    auto input_t = input.tensor<T, 4>();
    auto out_backprop_t = out_backprop.tensor<T, 4>();

    if (padding != EXPLICIT) {
      // If padding was not explicitly defined, Eigen spatial convolution
      // backward filter will infer correct forward paddings from input tensors.
      filter_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardKernel(
          input_t, out_backprop_t, filter_backprop_t.dimension(1),
          filter_backprop_t.dimension(0), col_stride, row_stride, col_dilation,
          row_dilation);

    } else {
      // Otherwise we have to explicitly pad the input, before passing it to
      // spatial convolution backward filter.
      Eigen::array<std::pair<int, int>, 4> paddings;
      paddings[0] = {0, 0};
      paddings[1] = {padding_top, padding_bottom};
      paddings[2] = {padding_left, padding_right};
      paddings[3] = {0, 0};

      auto padded_t = input_t.pad(paddings, T(0));

      // TODO(ezhulenev): Pass explicit paddings to Eigen spatial backward
      // convolution and do not rely on tensor padding expression.
      filter_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardKernel(
          padded_t, out_backprop_t, filter_backprop_t.dimension(1),
          filter_backprop_t.dimension(0), col_stride, row_stride, col_dilation,
          row_dilation);
    }
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, class T>
struct LaunchXsmmBackwardFilter {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::ConstTensor input_backward,
                  typename TTypes<T, 4>::Tensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardFilter<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor input,
                  typename TTypes<float, 4>::Tensor filter,
                  typename TTypes<float, 4>::ConstTensor output, int input_rows,
                  int input_cols, int row_stride, int col_stride, int pad_h,
                  int pad_w, TensorFormat data_format) const {
    auto batch = input.dimension(0);
    auto in_depth = input.dimension(3);
    auto out_depth = output.dimension(3);
    auto filter_rows = filter.dimension(0);
    auto filter_cols = filter.dimension(1);

    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;  // pad_rows;  // ignored by libxsmm for now.
    desc.pad_w_in = 0;  // pad_cols;  // ignored by libxsmm for now.
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
    if (!CanUseXsmmConv2D(desc, data_format)) {
      return false;
    }

    auto input_ptr = input.data();
    auto filter_ptr = filter.data();
    auto output_ptr = output.data();
    bool success = functor::XsmmBkwFilterConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename Device, class T>
class Conv2DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilated rates should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    if (std::is_same<Device, CPUDevice>::value) {
      OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                  errors::InvalidArgument("Conv2DBackpropFilterOp [CPU] "
                                          "only supports NHWC data format."));

      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(
          context, (dilation_h == 1 && dilation_w == 1),
          errors::InvalidArgument("Conv2DBackpropFilterOp [CPU] does not yet "
                                  "support dilation rates larger than 1."));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }
    // If input is empty, set gradients to zero.
    if (input.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, T> f;
      f(context->eigen_device<Device>(), filter_backprop->flat<T>());
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    VLOG(2) << "Conv2DBackpropFilter:"
            << " input: " << input.shape().DebugString()
            << " filter:" << filter_shape.DebugString()
            << " out_backprop: " << out_backprop.shape().DebugString()
            << " strides: [" << stride_rows << ", " << stride_cols << "]"
            << " dilations: [" << dilation_rows << ", " << dilation_cols << "]";

    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
              dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
              explicit_paddings_, filter_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  bool use_cudnn_;
  TensorFormat data_format_;
  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropFilterOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
template <typename Device, class T>
class Conv2DCustomBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropFilterOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    if (std::is_same<Device, CPUDevice>::value ||
        std::is_same<Device, GPUDevice>::value) {
      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                  errors::InvalidArgument(
                      "Current libxsmm and customized CPU implementations do "
                      "not yet support dilation rates larger than 1."));
      dilations_ = {1, 1, 1, 1};
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, "
            "not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        context,
        ConvBackpropComputeDimensionsV2(
            "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
            filter_shape, out_backprop.shape(), dilations_, strides_, padding_,
            explicit_paddings_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }

    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
    if (padding_ == Padding::EXPLICIT) {
      pad_top = explicit_paddings_[2];
      pad_bottom = explicit_paddings_[3];
      pad_left = explicit_paddings_[4];
      pad_right = explicit_paddings_[5];
    }
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
    defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardFilter<Device, T>()(
              context, context->eigen_device<Device>(), input.tensor<T, 4>(),
              filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
              dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#endif

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    OP_REQUIRES(
        context,
        filter_total_size * dims.out_depth == filter_backprop->NumElements(),
        errors::InvalidArgument(
            "filter_size does not have enough elements, requested ",
            filter_total_size * dims.out_depth, ", got ",
            filter_backprop->NumElements()));

    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // Shard 'batch' images into 'shard_size' groups of images to be fed
    // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
    // size ('target_working_set_size') by the matmul size of an individual
    // image ('work_unit_size').

    // TODO(andydavis)
    // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
    // *) Consider reducing 'target_working_set_size' if L3 is shared by
    //    other concurrently running tensorflow ops.
    const size_t target_working_set_size = (30LL << 20) / sizeof(T);

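    // The filter gradient below is computed as a matrix contraction per image:
    //   A = im2col patches of the input, [output_image_size, filter_total_size]
    //   B = out_backprop slice,          [output_image_size, out_depth]
    //   C = A^T * B (accumulated),       [filter_total_size, out_depth]
    // 'size_A/B/C' are the element counts of these operands for one image.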
    const size_t size_A = output_image_size * filter_total_size;

    const size_t size_B = output_image_size * dims.out_depth;

    const size_t size_C = filter_total_size * dims.out_depth;

    const size_t work_unit_size = size_A + size_B + size_C;

    OP_REQUIRES(
        context, work_unit_size != 0,
        errors::InvalidArgument(
            "Work size for convolution would be 0, which is not acceptable"));

    const size_t shard_size =
        (target_working_set_size + work_unit_size - 1) / work_unit_size;

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64_t>(shard_size),
                                    static_cast<int64_t>(output_image_size),
                                    static_cast<int64_t>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* input_data = input.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();
    T* filter_backprop_data = filter_backprop->template flat<T>().data();

    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        TensorMap;
    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        ConstTensorMap;

    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    C.setZero();

    // Initialize contraction dims (we need to transpose 'A' below).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    contract_dims[0].first = 0;
    contract_dims[0].second = 0;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
      const int shard_limit =
          std::min(static_cast<int>(shard_size),
                   static_cast<int>(dims.batch_size) - image_id);

      auto shard = [&input_data, &col_buffer_data, &dims, &pad_top, &pad_left,
                    &pad_bottom, &pad_right, &input_offset,
                    &size_A](int64_t start, int64_t limit) {
        for (int shard_id = start; shard_id < limit; ++shard_id) {
          const T* input_data_shard = input_data + shard_id * input_offset;
          T* col_data_shard = col_buffer_data + shard_id * size_A;

          // When we compute the gradient with respect to the filters, we need
          // to do im2col to allow gemm-type computation.
          Im2col<T>(
              input_data_shard, dims.in_depth, dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
              dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
              pad_right, dims.spatial_dims[0].stride,
              dims.spatial_dims[1].stride, col_data_shard);
        }
      };
      Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
            size_A, shard);

      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
                       filter_total_size);
      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
                       dims.out_depth);

      // Gradient with respect to filter.
      C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

      input_data += input_offset * shard_limit;
      out_backprop_data += output_offset * shard_limit;
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
};

#define REGISTER_CPU_KERNELS(T)                                               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv2DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DCustomBackpropFilterOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T")                         \
                              .AttrConstraint("data_format", "NHWC"),         \
                          Conv2DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T")                         \
                              .AttrConstraint("data_format", "NHWC"),         \
                          Conv2DBackpropFilterOp<CPUDevice, T>);

TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::bfloat16>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group forward backward filter autotune results together.
struct ConvBackwardFilterAutotuneGroup {
  static string name() { return "ConvBwdFilter"; }
};

typedef AutotuneSingleton<ConvBackwardFilterAutotuneGroup, ConvParameters,
                          AutotuneEntry<se::dnn::ConvOp>>
    AutotuneConvBwdFilter;

template <typename T>
void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& input, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64_t>& explicit_paddings, Tensor* filter_backprop,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;

  std::vector<int32> dilations(4, 1);
  dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
  dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;

  std::vector<int32> strides(4, 1);
  strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
  strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
  TensorShape filter_shape = filter_backprop->shape();

  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(
      ctx, ConvBackpropComputeDimensionsV2(
               "Conv2DBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
               filter_shape, out_backprop.shape(), dilations, strides, padding,
               explicit_paddings, data_format, &dims));

  int64_t padding_top = -1, padding_bottom = -1;
  int64_t padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64_t expected_out_rows, expected_out_cols;
  // The function is guaranteed to succeed because we checked the output and
  // padding was valid earlier.
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
      row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
      &padding_bottom));
  DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
      col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
      &padding_right));
  DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackprop for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
  // input depth, it's a depthwise convolution. More generally, if the filter
  // in-depth divides but is smaller than the input depth, it is a grouped
  // convolution.
  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
  bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
  if (!cudnn_disable_conv_1x1_optimization_ &&
      dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
    const uint64 m = dims.in_depth;
    const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 n = dims.out_depth;

    // The shape of output backprop is
    //   [batch, out_rows, out_cols, out_depth]
    //   From cublas's perspective, it is: n x k
    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());

    // The shape of input is
    //   [batch, in_rows, in_cols, in_depth],
    //   From cublas's perspective, it is: m x k
    auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());

    // the shape of the filter backprop from the conv_2d should be
    //   [1, 1, in_depth, out_depth]
    //   From cublas's perspective, it is: n x m
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    OP_REQUIRES_OK(
        ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                                  se::blas::Transpose::kTranspose, n, m, k,
                                  a_ptr, n, b_ptr, m, &c_ptr, n,
                                  se::blas::kDefaultComputePrecision));
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             !is_grouped_convolution && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, and we are not
    // using grouped convolution, so call cublas directly.
    const uint64 m = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;
    const uint64 k = dims.batch_size;
    const uint64 n = dims.out_depth;

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    OP_REQUIRES_OK(
        ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                                  se::blas::Transpose::kTranspose, n, m, k,
                                  b_ptr, n, a_ptr, m, &c_ptr, n,
                                  se::blas::kDefaultComputePrecision));
    return;
  }

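  // cuDNN/MIOpen only accept a single, symmetric padding value per spatial
  // dimension, so take the common (symmetric) part here; any remaining
  // asymmetry is folded into an explicitly padded copy of the input below.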
  const int64_t common_padding_rows = std::min(padding_top, padding_bottom);
  const int64_t common_padding_cols = std::min(padding_left, padding_right);
  Tensor compatible_input;
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // Pad the input in the same way we did during the forward pass, so that
    // cuDNN or MIOpen receives the same input during the backward pass function
    // as it did during the forward pass function.
    const int64_t padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64_t padding_cols_diff = std::abs(padding_right - padding_left);
    const int64_t new_in_rows =
        dims.spatial_dims[0].input_size + padding_rows_diff;
    const int64_t new_in_cols =
        dims.spatial_dims[1].input_size + padding_cols_diff;
    const int64_t input_pad_top = padding_top - common_padding_rows;
    const int64_t input_pad_bottom = padding_bottom - common_padding_rows;
    const int64_t input_pad_left = padding_left - common_padding_cols;
    const int64_t input_pad_right = padding_right - common_padding_cols;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(data_format, dims.batch_size, new_in_rows,
                                 new_in_cols, dims.in_depth),
                 &compatible_input));

    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()),
        {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
        {{static_cast<int>(input_pad_bottom),
          static_cast<int>(input_pad_right)}},
        To32Bit(compatible_input.tensor<T, 4>()), data_format, T{});
  } else {
    compatible_input = input;
  }

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  // The Tensor Core in NVIDIA Volta+ GPUs supports efficient convolution with
  // fp16 in NHWC data layout. In all other configurations it's more efficient
  // to run computation in NCHW data format.
  const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
                               stream->GetCudaComputeCapability().IsAtLeast(
                                   se::CudaComputeCapability::VOLTA);

  // We only do one directional conversion: NHWC->NCHW. We never convert in the
  // other direction. Grappler layout optimizer selects the preferred layout and
  // adds necessary annotations to the graph.
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2DBackpropFilter with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(filter_shape.dim_size(2))
      .set_output_feature_map_count(filter_shape.dim_size(3))
      .set_layout(filter_layout);
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(dims.in_depth / filter_shape.dim_size(2));

  // Tensorflow filter format: HWIO
  // cuDNN filter formats: (data format) -> (filter format)
  //   (1) NCHW -> OIHW
  //   (2) NHWC -> OHWI
  //
  // We compute filter backprop into temporary tensor, and then convert it to
  // the HWIO data format at the end.

  Tensor pre_transformed_filter_backprop;
  OP_REQUIRES_OK(
      ctx,
      ctx->allocate_temp(
          DataTypeToEnum<T>::value,
          TensorShape({filter_shape.dim_size(3), filter_shape.dim_size(2),
                       filter_shape.dim_size(0), filter_shape.dim_size(1)}),
          &pre_transformed_filter_backprop));

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `out_backprop` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, compute_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, compute_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor transformed_input;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `input` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, GetTensorDim(compatible_input, data_format, 'N'),
        GetTensorDim(compatible_input, data_format, 'H'),
        GetTensorDim(compatible_input, data_format, 'W'),
        GetTensorDim(compatible_input, data_format, 'C'));
    if (compute_shape.dim_size(1) > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, compute_shape,
                                        &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_input.CopyFrom(compatible_input, compute_shape));
    }
  } else {
    transformed_input = compatible_input;
  }

  se::DeviceMemory<T> out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  se::DeviceMemory<T> filter_backprop_ptr =
      AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
                     pre_transformed_filter_backprop.template flat<T>().size());
  auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
                                  transformed_input.template flat<T>().size());

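  // Upper bound on the cuDNN workspace (scratch) size for this convolution;
  // computed once and passed to both the autotuner and the launch below.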
  static int64_t ConvolveBackwardFilterScratchSize =
      GetDnnWorkspaceLimitOrDefault();
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                     // batch
      dims.in_depth,                       // in_depths
      {{input_desc.height(),               // in_rows
        input_desc.width()}},              // in_cols
      compute_data_format,                 // compute_data_format
      dims.out_depth,                      // out_depths
      {{dims.spatial_dims[0].filter_size,  // filter_rows
        dims.spatial_dims[1].filter_size,  // filter_cols
        filter_shape.dim_size(2)}},        // filter_depth
      {{dims.spatial_dims[0].dilation,     // dilation_rows
        dims.spatial_dims[1].dilation}},   // dilation_cols
      {{dims.spatial_dims[0].stride,       // stride_rows
        dims.spatial_dims[1].stride}},     // stride_cols
      {{common_padding_rows,               // padding_rows
        common_padding_cols}},             // padding_cols
      dtype,                               // tensor datatype
      device_id,                           // device_id
      conv_desc.group_count()              // group_count
  };

  auto entry_or = AutotuneUnfusedConv(
      cudnn_use_autotune, AutotuneConvBwdFilter::GetInstance(), conv_parameters,
      ctx, se::dnn::ConvolutionKind::BACKWARD_FILTER, input_desc, input_ptr,
      filter_desc, filter_backprop_ptr, conv_desc, output_desc,
      out_backprop_ptr, ConvolveBackwardFilterScratchSize);
  OP_REQUIRES_OK(ctx, entry_or.status());
  auto autotune_entry = std::move(entry_or).value();

  DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, ctx);
  Status cudnn_launch_status = LaunchAutotunedConv(
      autotune_entry, &scratch_allocator,
      se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, input_ptr,
      filter_desc, filter_backprop_ptr, conv_desc, output_desc,
      out_backprop_ptr);
  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
    return;
  }

  FilterTensorFormat src_filter_format =
      compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;

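  // Convert the computed filter gradient from the compute layout (OIHW for
  // NCHW, OHWI for NHWC) back to TensorFlow's HWIO filter layout.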
  auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
  functor::ReverseTransformFilter<GPUDevice, T, 4>()(
      ctx->eigen_device<GPUDevice>(), src_filter_format,
      toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
      filter_backprop->tensor<T, 4>());
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                              \
  template <>                                                            \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,          \
      typename TTypes<T, 4, int>::ConstTensor in,                        \
      typename TTypes<T, 4, int>::Tensor out);                           \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
  template <>                                                            \
  void PadInput<GPUDevice, T, int, 4>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const std::array<int, 2>& padding_left,                            \
      const std::array<int, 2>& padding_right,                           \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format,  \
      const T& padding_value);                                           \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DBackpropFilterOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DBackpropFilterOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DBackpropFilterOp<GPUDevice, Eigen::half>);

// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow