/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/conv_ops.h"

#include <string.h>

#include <atomic>
#include <map>
#include <utility>
#include <vector>

#include "absl/synchronization/blocking_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
template <typename Device, typename T>
struct LaunchGeneric {
  void operator()(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int row_stride, int col_stride,
                  int row_dilation, int col_dilation, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings, Tensor* output,
                  TensorFormat data_format) {
    CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                         "supports NHWC tensor format for now.";
    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
        col_stride == 1 && (padding == SAME || padding == VALID)) {
      // For 1x1 kernel, the 2D convolution is reduced to matrix
      // multiplication.
      //
      // TODO(vrv): We should be able to call SpatialConvolution
      // and it will produce the same result, but doing so
      // led to NaNs during training. Using matmul instead for now.
      int conv_width = 1;  // Width for the convolution step.
      for (int i = 0; i < 3; ++i) {
        conv_width *= output->dim_size(i);
      }
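      // Note: with a 1x1 filter and unit strides the output spatial
      // dimensions match the input, so `conv_width` (batch * rows * cols)
      // lets both tensors be viewed as 2D matrices of shape
      // [conv_width, depth] for the GEMM below.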

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<Device, T>()(
          ctx->eigen_device<Device>(),
          output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
          input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
          filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
          dim_pair);
    } else if (filter.dim_size(0) == input.dim_size(1) &&
               filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
               col_dilation == 1 && padding == VALID) {
      // If the input data and filter have the same height/width,
      // the 2D convolution is reduced to matrix multiplication.
      const int k =  // Length of reduction dimension.
          filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<Device, T>()(
          ctx->eigen_device<Device>(),
          output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
          input.shaped<T, 2>({input.dim_size(0), k}),
          filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
    } else {
      if (padding == EXPLICIT) {
        functor::SpatialConvolution<Device, T>()(
            ctx->eigen_device<Device>(), output->tensor<T, 4>(),
            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
            row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]),
            static_cast<int>(explicit_paddings[3]),
            static_cast<int>(explicit_paddings[4]),
            static_cast<int>(explicit_paddings[5]));
      } else {
        functor::SpatialConvolution<Device, T>()(
            ctx->eigen_device<Device>(), output->tensor<T, 4>(),
            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
            row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
      }
    }
  }
};

// Compute grouped 2D convolutions on CPU. Unlike the grouped convolution
// implementation in cuDNN, this is far from optimal and needs more work
// to deliver competitive performance. Currently it exists to close the feature
// parity gap between convolution operations on different devices.
template <typename T>
struct LaunchGrouped {
  void operator()(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int row_stride, int col_stride,
                  int row_dilation, int col_dilation, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings, Tensor* output,
                  TensorFormat data_format) {
    DCHECK(data_format == FORMAT_NHWC)
        << "Grouped conv implementation only "
           "supports NHWC tensor format for now.";

    const int64_t in_depth = input.dim_size(3);
    const int64_t patch_depth = filter.dim_size(2);
    const int64_t num_groups = in_depth / patch_depth;

    // Shuffle input/filter tensors to have group as a leading dimension.
    std::array<int64_t, 5> shuffle({3, 0, 1, 2, 4});

    // Compute pre-shuffle dimensions.
    auto pre_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
      return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2),
              num_groups, tensor.dim_size(3) / num_groups};
    };

    // Compute post-shuffle dimensions.
    auto post_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
      return {num_groups, tensor.dim_size(0), tensor.dim_size(1),
              tensor.dim_size(2), tensor.dim_size(3) / num_groups};
    };
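
    // The NHWC tensors are viewed as 5-D [N, H, W, G, C/G] (pre_shuffle) and
    // permuted with {3, 0, 1, 2, 4} into [G, N, H, W, C/G] (post_shuffle), so
    // that each group can be convolved independently by chipping along the
    // leading dimension.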

    auto& device = ctx->eigen_device<CPUDevice>();

    absl::BlockingCounter shuffles_completed(2);
    auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };

    // Shuffle input into temporary tensor.
    Tensor input_shuffled;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)),
                                &input_shuffled));
    input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
        input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);

    // Shuffle filter into temporary tensor.
    Tensor filter_shuffled;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(),
                                           TensorShape(post_shuffle(filter)),
                                           &filter_shuffled));
    filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
        filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);

    // Wait for the completion of input/filter shuffles.
    shuffles_completed.Wait();

    // Write group convolution results into temporary output tensor.
    Tensor output_shuffled;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(),
                                           TensorShape(post_shuffle(*output)),
                                           &output_shuffled));

    for (int64_t i = 0; i < num_groups; ++i) {
      // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
      // will lead to deadlock, SpatialConvolution has to use async Eigen
      // assignment). This requires small changes to Eigen to support async
      // execution for tensor chipping operation.

      // TODO(ezhulenev): Grouped convolution should also support 1x1 filter
      // optimization.

      auto input_slice = input_shuffled.tensor<T, 5>().template chip<0>(i);
      auto filter_slice = filter_shuffled.tensor<T, 5>().template chip<0>(i);
      auto output_slice = output_shuffled.tensor<T, 5>().template chip<0>(i);

      if (padding == EXPLICIT) {
        functor::SpatialConvolution<CPUDevice, T>()(
            ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
            filter_slice, row_stride, col_stride, row_dilation, col_dilation,
            static_cast<int>(explicit_paddings[2]),
            static_cast<int>(explicit_paddings[3]),
            static_cast<int>(explicit_paddings[4]),
            static_cast<int>(explicit_paddings[5]));
      } else {
        functor::SpatialConvolution<CPUDevice, T>()(
            ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
            filter_slice, row_stride, col_stride, row_dilation, col_dilation,
            BrainPadding2EigenPadding(padding));
      }
    }

    // Shuffle temporary output back into pre-shuffled shape.
    std::array<int64_t, 5> rev_shuffle({1, 2, 3, 0, 4});
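    // {1, 2, 3, 0, 4} is the inverse of the {3, 0, 1, 2, 4} shuffle above: it
    // moves the group dimension back from the front into position 3, restoring
    // the [N, H, W, G, C/G] layout before the result is viewed as NHWC again.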
    output->shaped<T, 5>(pre_shuffle(*output)).device(device) =
        output_shuffled.tensor<T, 5>().shuffle(rev_shuffle);
  }
};

}  // namespace

template <typename T>
struct LaunchConv2DOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& input, const Tensor& filter, int row_dilation,
                  int col_dilation, int row_stride, int col_stride,
                  const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings, Tensor* output,
                  TensorFormat data_format) {
    if (data_format != FORMAT_NHWC) {
      ctx->SetStatus(errors::Unimplemented(
          "The Conv2D op currently only supports the NHWC tensor format on the "
          "CPU. The op was given the format: ",
          ToString(data_format)));
      return;
    }

    for (int64_t explicit_padding : explicit_paddings) {
      if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
        ctx->SetStatus(errors::InvalidArgument("filter too large"));
        return;
      }
    }

    const int64_t in_depth = input.dim_size(3);
    const int64_t out_depth = output->dim_size(3);
    const int64_t patch_depth = filter.dim_size(2);

    if (patch_depth <= 0) {
      ctx->SetStatus(errors::InvalidArgument(
          "filter depth must be strictly positive, got ", patch_depth));
      return;
    }
    if (in_depth % patch_depth != 0) {
      ctx->SetStatus(errors::InvalidArgument(
          "input depth must be evenly divisible by filter depth: ", in_depth,
          " vs ", patch_depth));
      return;
    }
    if (filter.NumElements() <= 0) {
      ctx->SetStatus(
          errors::InvalidArgument("filter must not have zero elements "
                                  "(i.e. all dimensions must be non-zero)"));
      return;
    }

    const int64_t num_groups = in_depth / patch_depth;
    if (num_groups <= 0) {
      ctx->SetStatus(errors::InvalidArgument(
          "number of groups must be strictly positive, got ", num_groups));
      return;
    }
    if (out_depth % num_groups != 0 || out_depth < num_groups) {
      ctx->SetStatus(errors::InvalidArgument(
          "output depth must be evenly divisible by number of groups: ",
          out_depth, " vs ", num_groups));
      return;
    }

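    // Dispatch: when the filter depth divides but does not equal the input
    // depth this is a grouped (or depthwise-like) convolution and goes through
    // the slower LaunchGrouped path; otherwise use the generic implementation.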
    if (in_depth != patch_depth) {
      LaunchGrouped<T>()(ctx, input, filter, row_stride, col_stride,
                         row_dilation, col_dilation, padding, explicit_paddings,
                         output, data_format);
    } else {
      LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
                                    row_dilation, col_dilation, padding,
                                    explicit_paddings, output, data_format);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <>
struct LaunchConv2DOp<GPUDevice, int32> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& input, const Tensor& filter, int row_dilation,
                  int col_dilation, int row_stride, int col_stride,
                  const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings, Tensor* output,
                  TensorFormat data_format) {
    if (data_format != FORMAT_NHWC) {
      ctx->SetStatus(
          errors::Unimplemented("The Conv2D op currently only supports the "
                                "NHWC tensor format for integer types. "
                                "The op was given the format: ",
                                ToString(data_format)));
      return;
    }
    const int64_t in_depth = GetTensorDim(input, data_format, 'C');
    OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
                errors::Unimplemented(
                    "The Conv2D op currently does not support grouped "
                    "convolutions for integer types. A grouped convolution was "
                    "attempted to be run because the input depth of ",
                    in_depth, " does not match the filter input depth of ",
                    filter.dim_size(2)));
    OP_REQUIRES(
        ctx, filter.NumElements() > 0,
        errors::InvalidArgument("filter must not have zero elements "
                                "(i.e. all dimensions must be non-zero)"));

    for (int64_t explicit_padding : explicit_paddings) {
      if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
        ctx->SetStatus(errors::InvalidArgument("filter too large"));
        return;
      }
    }
    LaunchGeneric<GPUDevice, int32>()(
        ctx, input, filter, row_stride, col_stride, row_dilation, col_dilation,
        padding, explicit_paddings, output, data_format);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class LaunchDeepConvOp {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
                  int /*dilation_cols*/, int /*stride_rows*/,
                  int /*stride_cols*/, Tensor* /*output*/,
                  TensorFormat /*data_format*/) {
    return false;
  }
};

// Conditionally launches DeepConv operation based on convolution parameters.
template <>
class LaunchDeepConvOp<CPUDevice, float> {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int dilation_rows,
                  int dilation_cols, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
        dilation_cols != 1 ||
        !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
                          in_depth, out_depth, out_rows, out_cols)) {
      return false;
    }

    Conv2DArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.pad_rows = pad_rows;
    args.pad_cols = pad_cols;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<float>().data();
    auto filter_ptr = filter.template flat<float>().data();
    auto output_ptr = output->template flat<float>().data();

    functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
                                            output_ptr);
    return true;
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, typename T>
class LaunchXsmmConvOp {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int stride_rows, int stride_cols,
                  int dilation_rows, int dilation_cols, Tensor* output,
                  TensorFormat data_format) {
    return false;
  }
};

template <>
class LaunchXsmmConvOp<CPUDevice, float> {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int dilation_rows,
                  int dilation_cols, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    auto num_threads =
        ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = stride_rows;
    desc.v = stride_cols;
    desc.pad_h = pad_rows;
    desc.pad_w = pad_cols;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
    if (dilation_rows != 1 || dilation_cols != 1 ||
        !CanUseXsmmConv2D(desc, data_format)) {
      return false;
    }

    auto input_ptr = input.template flat<float>().data();
    auto filter_ptr = filter.template flat<float>().data();
    auto output_ptr = output->template flat<float>().data();

    bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
        ctx, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

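// TF_REQUIRES evaluates EXP and, if it is false, returns STATUS from the
// enclosing function, so it is only usable inside functions that return
// Status. It is #undef'd after the parameter/dimension helpers below.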
#define TF_REQUIRES(EXP, STATUS)                \
  do {                                          \
    if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
  } while (false)

Status InitConv2DParameters(const OpKernelConstruction* context,
                            Conv2DParameters* params) {
  TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
  TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
  TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
  if (context->HasAttr("explicit_paddings")) {
    TF_RETURN_IF_ERROR(
        context->GetAttr("explicit_paddings", &params->explicit_paddings));
  }
  string data_format_string;
  TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
  TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
              errors::InvalidArgument("Invalid data format"));

  const auto& strides = params->strides;
  const auto& dilations = params->dilations;
  const auto& data_format = params->data_format;

  TF_REQUIRES(dilations.size() == 4,
              errors::InvalidArgument("Sliding window dilations field must "
                                      "specify 4 dimensions"));
  TF_REQUIRES(strides.size() == 4,
              errors::InvalidArgument("Sliding window strides field must "
                                      "specify 4 dimensions"));
  const int64_t stride_n = GetTensorDim(strides, data_format, 'N');
  const int64_t stride_c = GetTensorDim(strides, data_format, 'C');
  const int64_t stride_h = GetTensorDim(strides, data_format, 'H');
  const int64_t stride_w = GetTensorDim(strides, data_format, 'W');
  TF_REQUIRES(
      stride_n == 1 && stride_c == 1,
      errors::Unimplemented("Current implementation does not yet support "
                            "strides in the batch and depth dimensions."));
  TF_REQUIRES(stride_h > 0 && stride_w > 0,
              errors::InvalidArgument(
                  "Row and column strides should be larger than 0."));

  const int64_t dilation_n = GetTensorDim(dilations, data_format, 'N');
  const int64_t dilation_c = GetTensorDim(dilations, data_format, 'C');
  const int64_t dilation_h = GetTensorDim(dilations, data_format, 'H');
  const int64_t dilation_w = GetTensorDim(dilations, data_format, 'W');
  TF_REQUIRES(
      dilation_n == 1 && dilation_c == 1,
      errors::Unimplemented("Current implementation does not yet support "
                            "dilations in the batch and depth dimensions."));
  TF_REQUIRES(
      dilation_h > 0 && dilation_w > 0,
      errors::InvalidArgument("Dilated rates should be larger than 0."));

  int num_dims = data_format == TensorFormat::FORMAT_NCHW_VECT_C ? 5 : 4;
  TF_RETURN_IF_ERROR(CheckValidPadding(
      params->padding, params->explicit_paddings, num_dims, data_format));

  return OkStatus();
}

Status ComputeConv2DDimension(const Conv2DParameters& params,
                              const Tensor& input, const Tensor& filter,
                              Conv2DDimensions* dimensions) {
  int required_dims =
      params.data_format == TensorFormat::FORMAT_NCHW_VECT_C ? 5 : 4;
  // Check that 2D convolution input and filter have exactly required_dims.
  TF_REQUIRES(
      input.dims() == required_dims,
      errors::InvalidArgument("convolution input must be ", required_dims,
                              "-dimensional: ", input.shape().DebugString()));
  TF_REQUIRES(
      filter.dims() == required_dims,
      errors::InvalidArgument("convolution filter must be ", required_dims,
                              "-dimensional: ", filter.shape().DebugString()));
  for (int i = 0; i < required_dims - 1; i++) {
    TF_REQUIRES(
        FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
        errors::InvalidArgument("filter too large"));
  }

  FilterTensorFormat filter_format =
      params.data_format == TensorFormat::FORMAT_NCHW_VECT_C
          ? FilterTensorFormat::FORMAT_OIHW_VECT_I
          : FilterTensorFormat::FORMAT_HWIO;

  // The last dimension for input is in_depth. Check that it is the same as the
  // filter's in_depth or it is evenly divisible by filter's in_depth.
  const int64_t in_depth_raw = GetTensorDim(input, params.data_format, 'C');
  const int64_t patch_depth_raw = GetFilterDim(filter, filter_format, 'I');
  TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
              errors::InvalidArgument("Input depth too large"));
  TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
              errors::InvalidArgument("Patch depth too large"));
  const int in_depth = static_cast<int>(in_depth_raw);
  const int patch_depth = static_cast<int>(patch_depth_raw);
  TF_REQUIRES(patch_depth > 0,
              errors::InvalidArgument(
                  "filter depth must be strictly positive, got ", patch_depth));
  TF_REQUIRES(in_depth % patch_depth == 0,
              errors::InvalidArgument(
                  "input depth must be evenly divisible by filter depth: ",
                  in_depth, " vs ", patch_depth));

  // The last dimension for filter is out_depth.
  const int out_depth =
      static_cast<int>(GetFilterDim(filter, filter_format, 'O'));

  // The second dimension for input is rows/height.
  // The first dimension for filter is rows/height.
  const int64_t input_rows_raw = GetTensorDim(input, params.data_format, 'H');
  TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
              errors::InvalidArgument("Input rows too large"));
  const int input_rows = static_cast<int>(input_rows_raw);
  const int filter_rows =
      static_cast<int>(GetFilterDim(filter, filter_format, 'H'));

  // The third dimension for input is columns/width.
  // The second dimension for filter is columns/width.
  const int64_t input_cols_raw = GetTensorDim(input, params.data_format, 'W');
  TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
              errors::InvalidArgument("Input cols too large"));
  const int input_cols = static_cast<int>(input_cols_raw);
  const int filter_cols =
      static_cast<int>(GetFilterDim(filter, filter_format, 'W'));

  // The first dimension for input is batch.
  const int64_t batch_raw = GetTensorDim(input, params.data_format, 'N');
  TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
              errors::InvalidArgument("batch is too large"));
  const int batch = static_cast<int>(batch_raw);

  // Take the stride and dilation from the second and third dimensions only (we
  // do not support striding or dilation on the batch or depth dimension).
  const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
  const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
  const int dilation_rows =
      GetTensorDim(params.dilations, params.data_format, 'H');
  const int dilation_cols =
      GetTensorDim(params.dilations, params.data_format, 'W');

  int64_t pad_rows_before, pad_rows_after, pad_cols_before, pad_cols_after;
  if (params.padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }

  // Compute windowed output sizes for rows and columns.
  int64_t out_rows = 0, out_cols = 0;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
      input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
      &out_rows, &pad_rows_before, &pad_rows_after));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
      input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
      &out_cols, &pad_cols_before, &pad_cols_after));
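  // For reference (standard TF padding conventions): with
  // effective_filter = (filter - 1) * dilation + 1, VALID gives
  // out = (in - effective_filter) / stride + 1, SAME gives
  // out = ceil(in / stride), and EXPLICIT adds the requested before/after
  // padding to `in` before applying the VALID formula.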

  dimensions->batch = batch;
  dimensions->input_rows = input_rows;
  dimensions->input_cols = input_cols;
  dimensions->in_depth = in_depth;
  dimensions->filter_rows = filter_rows;
  dimensions->filter_cols = filter_cols;
  dimensions->patch_depth = patch_depth;
  dimensions->out_depth = out_depth;
  dimensions->stride_rows = stride_rows;
  dimensions->stride_cols = stride_cols;
  dimensions->dilation_rows = dilation_rows;
  dimensions->dilation_cols = dilation_cols;
  dimensions->out_rows = out_rows;
  dimensions->out_cols = out_cols;
  dimensions->pad_rows_before = pad_rows_before;
  dimensions->pad_rows_after = pad_rows_after;
  dimensions->pad_cols_before = pad_cols_before;
  dimensions->pad_cols_after = pad_cols_after;

  return OkStatus();
}

#undef TF_REQUIRES

template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
 public:
  explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    Conv2DDimensions dimensions;
    OP_REQUIRES_OK(context,
                   ComputeConv2DDimension(params_, input, filter, &dimensions));

    TensorShape out_shape = ShapeFromFormat(
        params_.data_format, dimensions.batch, dimensions.out_rows,
        dimensions.out_cols, dimensions.out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
            << ", patch_depth = " << dimensions.patch_depth
            << ", input_cols = " << dimensions.input_cols
            << ", filter_cols = " << dimensions.filter_cols
            << ", input_rows = " << dimensions.input_rows
            << ", filter_rows = " << dimensions.filter_rows
            << ", stride_rows = " << dimensions.stride_rows
            << ", stride_cols = " << dimensions.stride_cols
            << ", dilation_rows = " << dimensions.dilation_rows
            << ", dilation_cols = " << dimensions.dilation_cols
            << ", out_depth = " << dimensions.out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    // If the input is empty, result can only be due to padding.
    if (input.NumElements() == 0) {
      // Zero-out output and return.
      functor::SetZeroFunctor<Device, T>()(context->eigen_device<Device>(),
                                           output->template flat<T>());

      return;
    }

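    // Try the specialized fast paths first (LaunchXsmmConvOp when libxsmm is
    // enabled, then LaunchDeepConvOp); each Run() returns false when it cannot
    // handle the given parameters, and we fall through to the default launcher
    // below.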
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
    if (params_.padding != EXPLICIT &&
        LaunchXsmmConvOp<Device, T>::Run(
            context, input, filter, dimensions.batch, dimensions.input_rows,
            dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
            dimensions.filter_cols, dimensions.pad_rows_before,
            dimensions.pad_cols_before, dimensions.out_rows,
            dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
            dimensions.dilation_cols, dimensions.stride_rows,
            dimensions.stride_cols, output, params_.data_format)) {
      return;
    }
#endif

    if (params_.padding != EXPLICIT &&
        LaunchDeepConvOp<Device, T>::Run(
            context, input, filter, dimensions.batch, dimensions.input_rows,
            dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
            dimensions.filter_cols, dimensions.pad_rows_before,
            dimensions.pad_cols_before, dimensions.out_rows,
            dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
            dimensions.dilation_cols, dimensions.stride_rows,
            dimensions.stride_cols, output, params_.data_format)) {
      return;
    }

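    // Default path: LaunchConv2DOp for this Device (Eigen-based on CPU,
    // cuDNN-backed on GPU).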
    launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
              dimensions.dilation_rows, dimensions.dilation_cols,
              dimensions.stride_rows, dimensions.stride_cols, params_.padding,
              params_.explicit_paddings, output, params_.data_format);
  }

 private:
  Conv2DParameters params_;
  bool use_cudnn_;
  bool cudnn_use_autotune_;

  LaunchConv2DOp<Device, T> launcher_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DOp<CPUDevice, T>);

// If we're using the alternative GEMM-based implementation of Conv2D for the
// CPU implementation, don't register this EigenTensor-based version.
#if !defined(USE_GEMM_FOR_CONV)
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

// To be used inside depthwise_conv_op.cc.
template struct LaunchConv2DOp<CPUDevice, Eigen::bfloat16>;
template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DOp<CPUDevice, float>;
template struct LaunchConv2DOp<CPUDevice, double>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

int64_t GetDnnWorkspaceLimit(const string& envvar_in_mb,
                             int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    if (strings::safe_strto64(workspace_limit_in_mb_str,
                              &scratch_limit_in_mb)) {
      return scratch_limit_in_mb * (1 << 20);
    } else {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    }
  }
  return default_value_in_bytes;
}

int64_t GetDnnWorkspaceLimitOrDefault() {
  return GetDnnWorkspaceLimit("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
                              1LL << 33);  // 8GB by default
}
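
// For example, running with TF_CUDNN_WORKSPACE_LIMIT_IN_MB=4096 caps the
// cuDNN scratch allocation used below at 4 GiB instead of the 8 GiB default.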

template <typename T>
void LaunchConv2DOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& input_param, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64_t>& explicit_paddings, Tensor* output,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(
        errors::Unimplemented("Conv2D for GPU is not currently supported "
                              "without cudnn"));
    return;
  }

  Tensor input = input_param;
  const int64_t in_batch = GetTensorDim(input, data_format, 'N');
  int64_t in_rows = GetTensorDim(input, data_format, 'H');
  int64_t in_cols = GetTensorDim(input, data_format, 'W');
  const int64_t in_depths = GetTensorDim(input, data_format, 'C');
  const int64_t patch_rows = filter.dim_size(0);
  const int64_t patch_cols = filter.dim_size(1);
  const int64_t patch_depths = filter.dim_size(2);

  OP_REQUIRES(
      ctx, filter.NumElements() > 0,
      errors::InvalidArgument("filter must not have zero elements "
                              "(i.e. all dimensions must be non-zero)"));

  // If the filter in-depth (patch_depths) is 1 and smaller than the input
  // depth, it's a depthwise convolution. More generally, if the filter in-depth
  // divides but is smaller than the input depth, it is a grouped convolution.
  bool is_grouped_convolution = patch_depths != in_depths;
  if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
      row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
      col_stride == 1 && data_format == FORMAT_NHWC &&
      (padding == VALID || padding == SAME)) {
    // 1x1 filter, so call cublas directly.
    const uint64 m = in_batch * in_rows * in_cols;
    const uint64 k = patch_depths;
    const uint64 n = filter.dim_size(3);
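
    // In NHWC with a 1x1 filter and unit strides the convolution is exactly an
    // [m, k] x [k, n] matrix product: every (batch, row, col) position is a
    // GEMM row, k is the input depth and n is the output depth.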

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    auto no_transpose = se::blas::Transpose::kNoTranspose;
    OP_REQUIRES_OK(
        ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
                                  a_ptr, k, &c_ptr, n,
                                  se::blas::kDefaultComputePrecision));
    return;
  } else if (patch_rows == in_rows && patch_cols == in_cols &&
             !is_grouped_convolution && row_dilation == 1 &&
             col_dilation == 1 && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    const uint64 m = in_batch;
    const uint64 k = patch_rows * patch_cols * patch_depths;
    const uint64 n = filter.dim_size(3);

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    auto no_transpose = se::blas::Transpose::kNoTranspose;
    OP_REQUIRES_OK(
        ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
                                  a_ptr, k, &c_ptr, n,
                                  se::blas::kDefaultComputePrecision));
    return;
  }

#if GOOGLE_CUDA
  // Tensor Core (NVIDIA Volta+ GPUs) supports efficient convolution with fp16
  // in NHWC data layout. In all other configurations it's more efficient to
  // run computation in NCHW data format.
  const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
                               stream->GetCudaComputeCapability().IsAtLeast(
                                   se::CudaComputeCapability::VOLTA);
#else
  // The fast NHWC implementation is a CUDA-only feature.
  const bool compute_in_nhwc = false;
#endif

  // We only do one directional conversion: NHWC->NCHW. We never convert in the
  // other direction. Grappler layout optimizer selects preferred layout and
  // adds necessary annotations to the graph.
  // TODO(ezhulenev): Convert in other direction for fp16?
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2D with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  const int64_t out_batch = GetTensorDim(*output, data_format, 'N');
  const int64_t out_rows = GetTensorDim(*output, data_format, 'H');
  const int64_t out_cols = GetTensorDim(*output, data_format, 'W');
  const int64_t out_depths = GetTensorDim(*output, data_format, 'C');
  int64_t padding_top = -1, padding_bottom = -1;
  int64_t padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64_t out_rows_check, out_cols_check;
  Status status = GetWindowedOutputSizeVerboseV2(
      in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check,
      &padding_top, &padding_bottom);
  // The status is guaranteed to be OK because we checked the output and padding
  // was valid earlier.
  TF_CHECK_OK(status);
  DCHECK_EQ(out_rows, out_rows_check);
  status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation,
                                          col_stride, padding, &out_cols_check,
                                          &padding_left, &padding_right);
  TF_CHECK_OK(status);
  DCHECK_EQ(out_cols, out_cols_check);

  const int64_t common_padding_rows = std::min(padding_top, padding_bottom);
  const int64_t common_padding_cols = std::min(padding_left, padding_right);
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // cuDNN only supports padding the same amount on the left and right sides,
    // and on the top and bottom sides. So we manually create a new padded
    // input tensor such that we can pass it to cuDNN.
    VLOG(4) << "Pad input tensor:"
            << " padding_top=" << padding_top
            << " padding_bottom=" << padding_bottom
            << " padding_left=" << padding_left
            << " padding_right=" << padding_right;

    // TODO(reedwm): In some cases, we can avoid an allocation even if the two
    // padding sides are different. For example, if the input is 2x2, the filter
    // is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the result is
    // the same as if the padding were (1, 1, 1, 1). Changing the padding in
    // such a way would allow us to avoid the allocation.
    Tensor transformed_input;
    const int64_t padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64_t padding_cols_diff = std::abs(padding_right - padding_left);
    const int64_t new_in_rows = in_rows + padding_rows_diff;
    const int64_t new_in_cols = in_cols + padding_cols_diff;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            ShapeFromFormat(data_format, in_batch, new_in_rows,
                                            new_in_cols, in_depths),
                            &transformed_input));

    const int64_t input_pad_top = padding_top - common_padding_rows;
    const int64_t input_pad_bottom = padding_bottom - common_padding_rows;
    const int64_t input_pad_left = padding_left - common_padding_cols;
    const int64_t input_pad_right = padding_right - common_padding_cols;
    bool in_bounds =
        FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
        FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
        FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
        FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
    if (!in_bounds) {
      ctx->SetStatus(errors::InvalidArgument("Padding is too large."));
      return;
    }
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(),
        To32Bit(static_cast<const Tensor&>(input).tensor<T, 4>()),
        {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
        {{static_cast<int>(input_pad_bottom),
          static_cast<int>(input_pad_right)}},
        To32Bit(transformed_input.tensor<T, 4>()), data_format, T{});

    input = transformed_input;
    in_rows = new_in_rows;
    in_cols = new_in_cols;
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the input tensor from NHWC to NCHW.";

    TensorShape nchw_shape =
        ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
    if (in_depths > 1) {
      Tensor transformed_input;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             nchw_shape, &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
      input = transformed_input;
    } else {
      // If depth <= 1, then just reshape.
      CHECK(input.CopyFrom(input, nchw_shape));
    }
  } else {
    CHECK(data_format == compute_data_format)  // Crash OK
        << "Illegal data and compute format pair:"
        << " data_format=" << ToString(data_format)
        << " compute_data_format=" << ToString(compute_data_format);
  }

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(in_batch)
      .set_feature_map_count(in_depths)
      .set_height(in_rows)
      .set_width(in_cols)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(out_batch)
      .set_height(out_rows)
      .set_width(out_cols)
      .set_feature_map_count(out_depths)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(patch_rows)
      .set_input_filter_width(patch_cols)
      .set_input_feature_map_count(patch_depths)
      .set_output_feature_map_count(filter.dim_size(3))
      .set_layout(filter_layout);
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(row_dilation)
      .set_horizontal_dilation_rate(col_dilation)
      .set_vertical_filter_stride(row_stride)
      .set_horizontal_filter_stride(col_stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(in_depths / patch_depths);
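  // A group count greater than one (filter depth divides but does not equal
  // the input depth) tells cuDNN to run this as a grouped convolution; for a
  // regular dense convolution in_depths == patch_depths and the count is 1.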

  Tensor transformed_filter;

  const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
            << " to " << ToString(dst_format);

    TensorShape dst_shape =
        dst_format == FORMAT_OIHW
            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
                           filter.dim_size(0), filter.dim_size(1)})
            : TensorShape({filter.dim_size(3), filter.dim_size(0),
                           filter.dim_size(1), filter.dim_size(2)});

    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
                                          &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), dst_format,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    return OkStatus();
  };

  if (compute_data_format == FORMAT_NCHW) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
  } else if (compute_data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
  } else {
    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
                                           ToString(compute_data_format)));
    return;
  }

  Tensor transformed_output;
  if (data_format != compute_data_format) {
    VLOG(4) << "Allocate temporary memory for output in compute data format";
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                ShapeFromFormat(compute_data_format, out_batch,
                                                out_rows, out_cols, out_depths),
                                &transformed_output));
  } else {
    transformed_output = *output;
  }

  auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto output_ptr =
      AsDeviceMemory(transformed_output.template flat<T>().data(),
                     transformed_output.template flat<T>().size());

  static int64_t ConvolveScratchSize = GetDnnWorkspaceLimitOrDefault();

  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {in_batch,             // batch
                                    in_depths,            // in_depths
                                    {{in_rows,            // in_rows
                                      in_cols}},          // in_cols
                                    compute_data_format,  // compute_data_format
                                    out_depths,           // out_depths
                                    {{patch_rows,         // filter_rows
                                      patch_cols,         // filter_cols
                                      patch_depths}},     // filter_depths
                                    {{row_dilation,       // dilation_rows
                                      col_dilation}},     // dilation_cols
                                    {{row_stride,         // stride_rows
                                      col_stride}},       // stride_cols
                                    {{common_padding_rows,    // padding_rows
                                      common_padding_cols}},  // padding_cols
                                    dtype,                // tensor datatype
                                    device_id,            // device_id
                                    conv_desc.group_count()};

  auto entry_or = AutotuneUnfusedConv(
      cudnn_use_autotune, ConvAutotuneMap::GetInstance(), conv_parameters, ctx,
      se::dnn::ConvolutionKind::FORWARD, input_desc, input_ptr, filter_desc,
      filter_ptr, conv_desc, output_desc, output_ptr, ConvolveScratchSize);
  OP_REQUIRES_OK(ctx, entry_or.status());
  auto autotune_entry = std::move(entry_or).value();

  DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  Status cudnn_launch_status = LaunchAutotunedConv(
      autotune_entry, &scratch_allocator, se::dnn::ConvolutionKind::FORWARD,
      stream, input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
      output_desc, output_ptr);
  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
    return;
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
        output->tensor<T, 4>());
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                 \
  template <>                                                               \
  void SpatialConvolution<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
      typename TTypes<T, 4>::ConstTensor input,                             \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
      int col_stride, int row_dilation, int col_dilation,                   \
      const Eigen::PaddingType& padding,                                    \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  template <>                                                               \
  void SpatialConvolution<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
      typename TTypes<T, 4>::ConstTensor input,                             \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
      int col_stride, int row_dilation, int col_dilation, int padding_top,  \
      int padding_bottom, int padding_left, int padding_right,              \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  extern template struct SpatialConvolution<GPUDevice, T>;                  \
  template <>                                                               \
  void MatMulConvFunctor<GPUDevice, T>::operator()(                         \
      const GPUDevice& d, typename TTypes<T, 2>::Tensor out,                \
      typename TTypes<T, 2>::ConstTensor in0,                               \
      typename TTypes<T, 2>::ConstTensor in1,                               \
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  extern template struct MatMulConvFunctor<GPUDevice, T>;                   \
  template <>                                                               \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                   \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,             \
      typename TTypes<T, 4, int>::ConstTensor in,                           \
      typename TTypes<T, 4, int>::Tensor out);                              \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;             \
  template <>                                                               \
  void PadInput<GPUDevice, T, int, 4>::operator()(                          \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,       \
      const std::array<int, 2>& padding_left,                               \
      const std::array<int, 2>& padding_right,                              \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format,     \
      const T& padding_value);                                              \
  extern template struct PadInput<GPUDevice, T, int, 4>

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    Conv2DOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    Conv2DOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    Conv2DOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
    Conv2DOp<GPUDevice, int32>);

// To be used inside depthwise_conv_op.cc.
template struct LaunchConv2DOp<GPUDevice, float>;
template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow