/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with other kernels baked into the
// processing, to optimize latency and memory usage:
// - Conv2D + BiasAdd + <Activation>
// - Conv2D + FusedBatchNorm + <Activation>
//
// Activation: Relu, Relu6, Elu, etc...
//
// Kernels for convolutions fused with image transformations (resize and mirror
// padding) are defined in `conv_ops_fused_image_transform.cc`.
//
// For the CPU device we implement fusion with an Eigen tensor contraction
// output kernel. For the GPU device we rely on CuDNN primitives.
//
// NOTE: GPU only supports fusion of Conv2D + BiasAdd + <activation>, where the
// activation is one of Relu, Relu6, Elu or LeakyRelu (for int8/qint8, only an
// optional Relu).
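//
// Illustrative example (graph-level view, not an API defined in this file):
// a subgraph
//
//   y = Relu(BiasAdd(Conv2D(x, filter), bias))
//
// is typically rewritten by the grappler remapper into a single `_FusedConv2D`
// node with fused_ops = ["BiasAdd", "Relu"], which is then dispatched to the
// LaunchFusedConv2DOp implementations below.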

#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchFusedConv2DOp {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input,
                  const Tensor& filter, FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output);
};

// This is a CPU-only implementation that uses Eigen contraction output
// kernels.
//
// Dispatch 2D convolution to the appropriate primitive operation:
// (1) MatMul for the case of 1x1 convolution.
// (2) MatMul for the case when the filter size equals the input size.
// (3) General spatial 2D convolution for all other cases.
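//
// For example (case 1, for illustration): a 1x1 filter with unit strides over
// an NHWC input of shape [N, H, W, C_in] collapses into a single
// [N * H * W, C_in] x [C_in, C_out] matrix multiplication, with the fused
// output kernel applied to the contraction result.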
template <typename T>
class LaunchFusedConv2DWithOutputKernel {
 public:
  LaunchFusedConv2DWithOutputKernel(
      int row_stride, int col_stride,      //
      int row_dilation, int col_dilation,  //
      Padding padding, const std::vector<int64_t>& explicit_paddings)
      : row_stride_(row_stride),
        col_stride_(col_stride),
        row_dilation_(row_dilation),
        col_dilation_(col_dilation),
        padding_(padding),
        explicit_paddings_(explicit_paddings) {}

  template <typename OutputKernel>
  void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
                  const Tensor& input, const Tensor& filter, Tensor* output) {
    // Wrap output_kernel into type erased wrapper to reduce the number of
    // unique template instantiations for Eigen Tensor contraction expressions.
    OutputKernelWrapper output_kernel_wrapper(
        [&output_kernel](
            const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
            const Eigen::TensorContractionParams& params, Eigen::Index i,
            Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
          output_kernel(output_mapper, params, i, j, num_rows, num_cols);
        });

    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
        row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) {
      int conv_width = 1;  // Width for the convolution step.
      for (int i = 0; i < 3; ++i) {
        conv_width *= output->dim_size(i);
      }

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
          ctx->eigen_device<CPUDevice>(),
          output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
          input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
          filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
          dim_pair, std::move(output_kernel_wrapper));

    } else if (filter.dim_size(0) == input.dim_size(1) &&
               filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
               col_dilation_ == 1 && padding_ == VALID) {
      // If the input data and filter have the same height/width,
      // reduce the 2D convolution to matrix multiplication.
      const auto k =  // Length of reduction dimension.
          filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
          ctx->eigen_device<CPUDevice>(),
          output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
          input.shaped<T, 2>({input.dim_size(0), k}),
          filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
          std::move(output_kernel_wrapper));

    } else {
      if (padding_ == EXPLICIT) {
        functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
            ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
            col_stride_, row_dilation_, col_dilation_,
            static_cast<int>(explicit_paddings_[2]),
            static_cast<int>(explicit_paddings_[3]),
            static_cast<int>(explicit_paddings_[4]),
            static_cast<int>(explicit_paddings_[5]),
            std::move(output_kernel_wrapper));
      } else {
        functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
            ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
            col_stride_, row_dilation_, col_dilation_,
            BrainPadding2EigenPadding(padding_),
            std::move(output_kernel_wrapper));
      }
    }
  }

 private:
  // Wrap output_kernel into type erased struct to reduce the number of unique
  // template instantiations for Eigen Tensor contraction expressions.
  //
  // We do not pass std::function directly as an output kernel because it blows
  // up the binary size in debug mode with super long symbol names.
  struct OutputKernelWrapper {
    using OutputKernelFn =
        std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
                           const Eigen::TensorContractionParams&, Eigen::Index,
                           Eigen::Index, Eigen::Index, Eigen::Index)>;

    explicit OutputKernelWrapper(OutputKernelFn fn)
        : output_kernel_fn(std::move(fn)) {}

    void operator()(
        const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
        const Eigen::TensorContractionParams& params, Eigen::Index i,
        Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
      output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
    }

    OutputKernelFn output_kernel_fn;
  };

  int row_stride_;
  int col_stride_;
  int row_dilation_;
  int col_dilation_;
  const Padding padding_;
  const std::vector<int64_t>& explicit_paddings_;
};
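
// For example (usage within this file): the CPU specialization of
// LaunchFusedConv2DOp below constructs this launcher and invokes it as
//
//   conv2d(WithBiasAddAndRelu<T>(bias_add_args), context, input, filter,
//          output);
//
// passing one of the fused output kernels from fused_eigen_output_kernels.h.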

template <typename T>
struct LaunchFusedConv2DOp<CPUDevice, T> {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input,
                  const Tensor& filter, const FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output) {
    OP_REQUIRES(context, dimensions.in_depth == filter.dim_size(2),
                errors::Unimplemented("Fused conv implementation does not "
                                      "support grouped convolutions for now."));
    OP_REQUIRES(context, params.data_format == FORMAT_NHWC,
                errors::Unimplemented("Fused conv implementation only supports "
                                      "NHWC tensor format for now."));
    OP_REQUIRES(context, DataTypeToEnum<T>::value != DT_HALF,
                errors::Unimplemented("Fused conv implementation with half "
                                      "precision is not supported on CPU."));

    BiasAddArgs<T> bias_add_args;
    if (BiasAddArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
                                                &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
      }
    }

    FusedBatchNormArgs<T> fused_batch_norm_args;
    if (FusedBatchNormArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) {
        OP_REQUIRES_OK(context,
                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
                                              &fused_batch_norm_args,
                                              &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context,
                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
                                              &fused_batch_norm_args));
      }
    }

    LaunchFusedConv2DWithOutputKernel<T> conv2d(
        dimensions.stride_rows, dimensions.stride_cols,
        dimensions.dilation_rows, dimensions.dilation_cols, params.padding,
        params.explicit_paddings);

    switch (fusion) {
      case FusedComputationType::kUndefined:
        OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
        break;
      case FusedComputationType::kBiasAdd:
        conv2d(WithBiasAdd<T>(bias_add_args), context, input, filter, output);
        break;
      case FusedComputationType::kBiasAddWithRelu:
        conv2d(WithBiasAddAndRelu<T>(bias_add_args), context, input, filter,
               output);
        break;
      case FusedComputationType::kBiasAddWithRelu6:
        conv2d(WithBiasAddAndRelu6<T>(bias_add_args), context, input, filter,
               output);
        break;
      case FusedComputationType::kBiasAddWithLeakyRelu:
        conv2d(WithBiasAddAndLeakyRelu<T>(bias_add_args), context, input,
               filter, output);
        break;
      case FusedComputationType::kBiasAddWithElu:
        conv2d(WithBiasAddAndElu<T>(bias_add_args), context, input, filter,
               output);
        break;
      case FusedComputationType::kFusedBatchNorm:
        conv2d(
            WithFusedBatchNorm<T>(fusion_args.epsilon, fused_batch_norm_args),
            context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithRelu:
        conv2d(WithFusedBatchNormAndRelu<T>(fusion_args.epsilon,
                                            fused_batch_norm_args),
               context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithRelu6:
        conv2d(WithFusedBatchNormAndRelu6<T>(fusion_args.epsilon,
                                             fused_batch_norm_args),
               context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithLeakyRelu:
        conv2d(WithFusedBatchNormAndLeakyRelu<T>(fusion_args.epsilon,
                                                 fused_batch_norm_args),
               context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithElu:
        conv2d(WithFusedBatchNormAndElu<T>(fusion_args.epsilon,
                                           fused_batch_norm_args),
               context, input, filter, output);
        break;
      default:
        OP_REQUIRES_OK(context, errors::Internal("Fusion type is unsupported"));
        break;
    }
  }
};

template <>
struct LaunchFusedConv2DOp<CPUDevice, int8>;

template <>
struct LaunchFusedConv2DOp<CPUDevice, qint8>;

#if GOOGLE_CUDA

inline int64_t ConvolveScratchSize() {
  static int64_t convolve_scratch_size = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );
  return convolve_scratch_size;
}
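
// For example (assuming the usual GetDnnWorkspaceLimit semantics): setting
// TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512 caps the cuDNN scratch space at 512 MiB,
// while leaving the variable unset keeps the 4GB default above.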

template <typename T>
struct LaunchFusedConv2DOp<GPUDevice, T> {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input_param,
                  const Tensor& filter, FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output) {
    OP_REQUIRES(
        context,
        params.data_format == FORMAT_NHWC || params.data_format == FORMAT_NCHW,
        errors::Unimplemented("Fused conv implementation only supports "
                              "NHWC and NCHW tensor formats for now."));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    OP_REQUIRES(
        context, use_cudnn,
        errors::Unimplemented("FusedConv2D for GPU is not currently supported "
                              "without cudnn"));

    bool is_supported_activation =
        fusion == FusedComputationType::kBiasAddWithRelu ||
        fusion == FusedComputationType::kBiasAddWithRelu6 ||
        fusion == FusedComputationType::kBiasAddWithElu ||
        fusion == FusedComputationType::kBiasAddWithLeakyRelu;
    OP_REQUIRES(
        context, is_supported_activation,
        errors::Unimplemented("FusedConv2D implementation only supports "
                              "fusing with `BiasAdd + Relu|Relu6|Elu|LeakyRelu`"
                              " for now."));

    Tensor input = input_param;

    const int64_t in_batch = GetTensorDim(input, params.data_format, 'N');
    int64_t in_rows = GetTensorDim(input, params.data_format, 'H');
    int64_t in_cols = GetTensorDim(input, params.data_format, 'W');
    const int64_t in_depths = GetTensorDim(input, params.data_format, 'C');

    const int64_t patch_rows = filter.dim_size(0);
    const int64_t patch_cols = filter.dim_size(1);
    const int64_t patch_depths = filter.dim_size(2);

    const int64_t out_batch = GetTensorDim(*output, params.data_format, 'N');
    const int64_t out_rows = GetTensorDim(*output, params.data_format, 'H');
    const int64_t out_cols = GetTensorDim(*output, params.data_format, 'W');
    const int64_t out_depths = GetTensorDim(*output, params.data_format, 'C');

    // Bias of the following dimensions: [ output_depth ]
    const Tensor& bias = context->input(2);
    OP_REQUIRES(context, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional",
                                        bias.shape().DebugString()));
    OP_REQUIRES(context, bias.dim_size(0) == out_depths,
                errors::InvalidArgument("bias depth must be equal to out depth",
                                        bias.shape().DebugString()));

    const int64_t common_padding_rows =
        std::min(dimensions.pad_rows_before, dimensions.pad_rows_after);
    const int64_t common_padding_cols =
        std::min(dimensions.pad_cols_before, dimensions.pad_cols_after);
    if (dimensions.pad_rows_before != dimensions.pad_rows_after ||
        dimensions.pad_cols_before != dimensions.pad_cols_after) {
      // cuDNN only supports padding the same amount on the left and right
      // sides, and on the top and bottom sides. So we manually create a new
      // padded input tensor such that we can pass it to cuDNN.

      // TODO(reedwm): In some cases, we can avoid an allocation even if the two
      // padding sides are different. For example, if the input is 2x2, the
      // filter is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the
      // result is equivalent to as if the padding is (1, 1, 1, 1). Changing the
      // padding in such a way would allow us to avoid the allocation.
      Tensor transformed_input;
      const int64_t padding_rows_diff =
          std::abs(dimensions.pad_rows_after - dimensions.pad_rows_before);
      const int64_t padding_cols_diff =
          std::abs(dimensions.pad_cols_after - dimensions.pad_cols_before);
      const int64_t new_in_rows = in_rows + padding_rows_diff;
      const int64_t new_in_cols = in_cols + padding_cols_diff;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(params.data_format, in_batch,
                                         new_in_rows, new_in_cols, in_depths),
                         &transformed_input));
      const int64_t input_pad_top =
          dimensions.pad_rows_before - common_padding_rows;
      const int64_t input_pad_bottom =
          dimensions.pad_rows_after - common_padding_rows;
      const int64_t input_pad_left =
          dimensions.pad_cols_before - common_padding_cols;
      const int64_t input_pad_right =
          dimensions.pad_cols_after - common_padding_cols;
      bool in_bounds =
          FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
          FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
          FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
          FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
      if (!in_bounds) {
        context->SetStatus(errors::InvalidArgument("Padding is too large."));
        return;
      }
      functor::PadInput<GPUDevice, T, int, 4>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(input_param.tensor<T, 4>()),
          {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
          {{static_cast<int>(input_pad_bottom),
            static_cast<int>(input_pad_right)}},
          To32Bit(transformed_input.tensor<T, 4>()), params.data_format, T{});
      input = transformed_input;
      in_rows = new_in_rows;
      in_cols = new_in_cols;
    }

    const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
                                 stream->GetCudaComputeCapability().IsAtLeast(
                                     se::CudaComputeCapability::VOLTA);
    if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) {
      // Convert the input tensor from NHWC to NCHW.
      TensorShape nchw_shape =
          ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
      if (in_depths > 1) {
        Tensor transformed_input;
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 4>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(input).tensor<T, 4>(),
            transformed_input.tensor<T, 4>());
        input = transformed_input;
      } else {
        // If depth <= 1, then just reshape.
        CHECK(input.CopyFrom(input, nchw_shape));  // Crash OK
      }
    }

    CHECK(common_padding_rows >= 0) << "Negative padding rows";  // Crash OK
    CHECK(common_padding_cols >= 0) << "Negative padding cols";  // Crash OK

    se::dnn::ActivationMode dnn_activation_mode;
    switch (fusion) {
      case FusedComputationType::kBiasAddWithRelu:
        dnn_activation_mode = se::dnn::ActivationMode::kRelu;
        break;
      case FusedComputationType::kBiasAddWithRelu6:
        dnn_activation_mode = se::dnn::ActivationMode::kRelu6;
        break;
      case FusedComputationType::kBiasAddWithElu:
        dnn_activation_mode = se::dnn::ActivationMode::kElu;
        break;
      case FusedComputationType::kBiasAddWithLeakyRelu:
        dnn_activation_mode = se::dnn::ActivationMode::kLeakyRelu;
        break;
      default:
        LOG(FATAL) << "Unsupported fusion type";  // Crash OK
    }

    const TensorFormat compute_data_format =
        compute_in_nhwc ? FORMAT_NHWC : FORMAT_NCHW;
    constexpr auto kComputeInNHWC =
        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                        se::dnn::FilterLayout::kOutputYXInput);
    constexpr auto kComputeInNCHW =
        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                        se::dnn::FilterLayout::kOutputInputYX);
    se::dnn::DataLayout compute_data_layout;
    se::dnn::FilterLayout filter_layout;
    std::tie(compute_data_layout, filter_layout) =
        compute_in_nhwc ? kComputeInNHWC : kComputeInNCHW;

    se::dnn::BatchDescriptor input_desc;
    input_desc.set_count(in_batch)
        .set_feature_map_count(in_depths)
        .set_height(in_rows)
        .set_width(in_cols)
        .set_layout(compute_data_layout);
    se::dnn::FilterDescriptor filter_desc;
    filter_desc.set_input_filter_height(patch_rows)
        .set_input_filter_width(patch_cols)
        .set_input_feature_map_count(patch_depths)
        .set_output_feature_map_count(filter.dim_size(3))
        .set_layout(filter_layout);
    se::dnn::BatchDescriptor bias_desc;
    bias_desc.set_count(1)
        .set_height(1)
        .set_width(1)
        .set_feature_map_count(out_depths)
        .set_layout(compute_data_layout);
    se::dnn::ConvolutionDescriptor conv_desc;
    conv_desc.set_vertical_dilation_rate(dimensions.dilation_rows)
        .set_horizontal_dilation_rate(dimensions.dilation_cols)
        .set_vertical_filter_stride(dimensions.stride_rows)
        .set_horizontal_filter_stride(dimensions.stride_cols)
        .set_zero_padding_height(common_padding_rows)
        .set_zero_padding_width(common_padding_cols)
        .set_group_count(in_depths / patch_depths);
    se::dnn::BatchDescriptor output_desc;
    output_desc.set_count(out_batch)
        .set_height(out_rows)
        .set_width(out_cols)
        .set_feature_map_count(out_depths)
        .set_layout(compute_data_layout);

    Tensor transformed_filter;
    const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
      VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
              << " to " << ToString(dst_format);

      TensorShape dst_shape =
          dst_format == FORMAT_OIHW
              ? TensorShape({filter.dim_size(3), filter.dim_size(2),
                             filter.dim_size(0), filter.dim_size(1)})
              : TensorShape({filter.dim_size(3), filter.dim_size(0),
                             filter.dim_size(1), filter.dim_size(2)});

      TF_RETURN_IF_ERROR(context->allocate_temp(
          DataTypeToEnum<T>::value, dst_shape, &transformed_filter));
      functor::TransformFilter<GPUDevice, T, int, 4>()(
          context->eigen_device<GPUDevice>(), dst_format,
          To32Bit(filter.tensor<T, 4>()),
          To32Bit(transformed_filter.tensor<T, 4>()));

      return OkStatus();
    };
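
    // For example (dimension names illustrative): with compute_in_nhwc the
    // HWIO filter [Kh, Kw, Ci, Co] is rearranged to OHWI [Co, Kh, Kw, Ci];
    // otherwise to OIHW [Co, Ci, Kh, Kw], matching the filter_layout selected
    // above.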

    if (compute_in_nhwc) {
      OP_REQUIRES_OK(context, transform_filter(FORMAT_OHWI));
    } else {
      OP_REQUIRES_OK(context, transform_filter(FORMAT_OIHW));
    }

    Tensor transformed_output;
    if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) {
      // Only allocate temporary memory when a layout transformation is needed.
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
                                         out_cols, out_depths),
                         &transformed_output));
    } else {
      transformed_output = *output;
    }

    const auto tensor_on_device = [](const Tensor& t) -> se::DeviceMemory<T> {
      return AsDeviceMemory(t.template flat<T>().data(),
                            t.template flat<T>().size());
    };

    se::DeviceMemory<T> input_ptr = tensor_on_device(input);
    se::DeviceMemory<T> filter_ptr = tensor_on_device(transformed_filter);
    se::DeviceMemory<T> bias_ptr = tensor_on_device(bias);
    se::DeviceMemory<T> output_ptr = tensor_on_device(transformed_output);

    // We do not use side inputs, so we can safely pass nullptr.
    se::DeviceMemory<T> side_input_ptr =
        AsDeviceMemory(static_cast<T*>(nullptr), 0);

    constexpr double kConvScale = 1.0;
    constexpr double kSideInputScale = 0.0;
    double leakyrelu_alpha = fusion_args.leakyrelu_alpha;

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    ConvParameters conv_parameters = {
        in_batch,                      // batch
        in_depths,                     // in_depths
        {{in_rows,                     // in_rows
          in_cols}},                   // in_cols
        compute_data_format,           // compute_data_format
        out_depths,                    // out_depths
        {{patch_rows,                  // filter_rows
          patch_cols,                  // filter_cols
          patch_depths}},              // filter_depths
        {{dimensions.dilation_rows,    // dilation_rows
          dimensions.dilation_cols}},  // dilation_cols
        {{dimensions.stride_rows,      // stride_rows
          dimensions.stride_cols}},    // stride_cols
        {{common_padding_rows,         // padding_rows
          common_padding_cols}},       // padding_cols
        dtype,                         // tensor datatype
        device_id,                     // device_id
        conv_desc.group_count(),
        ConvParameters::FusionInfo{kConvScale, kSideInputScale, leakyrelu_alpha,
                                   dnn_activation_mode,  // activation_mode
                                   /*is_contrib=*/false}};

    se::dnn::DataType element_type = se::dnn::ToDataType<T>::value;

    auto entry_or = AutotuneFusedConv<T>(
        cudnn_use_autotune, FusedConvAutotuneMap::GetInstance(),
        conv_parameters, context, input_desc, filter_desc, bias_desc,
        output_desc, conv_desc, dnn_activation_mode, kConvScale,
        kSideInputScale, leakyrelu_alpha, input_ptr, filter_ptr, output_ptr,
        bias_ptr, side_input_ptr, ConvolveScratchSize());
    OP_REQUIRES_OK(context, entry_or.status());
    auto autotune_entry = std::move(entry_or).value();

    DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
    Status cudnn_launch_status;
    if (!autotune_entry.is_algorithm_config()) {
      auto& runners = autotune_entry.GetOpRunners();
      se::dnn::FusedConvOp::Config config{se::dnn::ConvolutionKind::FORWARD,
                                          element_type,
                                          element_type,
                                          element_type,
                                          kConvScale,
                                          kSideInputScale,
                                          leakyrelu_alpha,
                                          input_desc,
                                          filter_desc,
                                          bias_desc,
                                          output_desc,
                                          conv_desc,
                                          dnn_activation_mode};
      auto primary_or = runners.primary->GetOrCreateRunner(config, stream);
      OP_REQUIRES_OK(context, primary_or.status());
      auto* primary = primary_or.value();

      const se::dnn::FusedConvRunner* no_scratch_fallback = nullptr;
      if (runners.no_scratch_fallback) {
        auto no_scratch_fallback_or =
            runners.no_scratch_fallback->GetOrCreateRunner(config, stream);
        OP_REQUIRES_OK(context, no_scratch_fallback_or.status());
        no_scratch_fallback = no_scratch_fallback_or.value();
      }

      auto runner_and_scratch_or =
          AllocateScratchOrFallback<se::dnn::FusedConvOp::Signature>(
              &scratch_allocator, primary, no_scratch_fallback);
      OP_REQUIRES_OK(context, runner_and_scratch_or.status());
      auto runner_and_scratch = std::move(runner_and_scratch_or).value();
      auto& runner =
          *std::get<const se::dnn::FusedConvRunner*>(runner_and_scratch);
      cudnn_launch_status = runner(
          stream, nullptr, std::get<se::DeviceMemoryBase>(runner_and_scratch),
          input_ptr, filter_ptr, side_input_ptr, bias_ptr, output_ptr);
    } else {
      cudnn_launch_status = stream->FusedConvolveWithAlgorithm(
          input_desc, input_ptr,            // input
          kConvScale,                       // input_scale
          filter_desc, filter_ptr,          // filter
          conv_desc,                        // conv
          side_input_ptr, kSideInputScale,  // side_input
          bias_desc, bias_ptr,              // bias
          dnn_activation_mode,              // activation
          output_desc, &output_ptr,         // output
          &scratch_allocator, autotune_entry.GetAlgorithmConfig(), nullptr);
    }

    OP_REQUIRES_OK(context, cudnn_launch_status);

    // Convert the output tensor back from NCHW to NHWC.
    if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
          output->tensor<T, 4>());
    }
  }
};

template <>
struct LaunchFusedConv2DOp<GPUDevice, int8>;

template <>
struct LaunchFusedConv2DOp<GPUDevice, qint8>;

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class FusedConv2DOp : public OpKernel {
 public:
  explicit FusedConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    using FCT = FusedComputationType;

    std::vector<FusedComputationPattern> patterns;
    if (std::is_same<Device, CPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
          {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
          {FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
          {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
          {FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}},
          {FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}},
          {FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}},
      };
    }

    // NOTE(ezhulenev): CuDNN's `cudnnConvolutionBiasActivationForward`
    // nominally supports an identity activation function, which in theory
    // should allow fusing a convolution with just BiasAdd. In practice it does
    // not work: cuDNN ignores this parameter and always applies Relu.
    if (std::is_same<Device, GPUDevice>::value) {
      if (std::is_same<T, int8>::value || std::is_same<T, qint8>::value) {
        patterns = {{FCT::kBiasAdd, {"BiasAdd"}},
                    {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}};
      } else {
        patterns = {
            {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
            {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
            {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
            {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
        };
      }
    }

    OP_REQUIRES_OK(context, InitializeFusedComputation(
                                context, "Conv2D", patterns,
                                &fused_computation_, &fused_computation_args_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    Conv2DDimensions dimensions;
    OP_REQUIRES_OK(context,
                   ComputeConv2DDimension(params_, input, filter, &dimensions));

    TensorShape out_shape = ShapeFromFormat(
        params_.data_format, dimensions.batch, dimensions.out_rows,
        dimensions.out_cols, dimensions.out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: in_depth = " << dimensions.in_depth
            << ", patch_depth = " << dimensions.patch_depth
            << ", input_cols = " << dimensions.input_cols
            << ", filter_cols = " << dimensions.filter_cols
            << ", input_rows = " << dimensions.input_rows
            << ", filter_rows = " << dimensions.filter_rows
            << ", stride_rows = " << dimensions.stride_rows
            << ", stride_cols = " << dimensions.stride_cols
            << ", dilation_rows = " << dimensions.dilation_rows
            << ", dilation_cols = " << dimensions.dilation_cols
            << ", out_depth = " << dimensions.out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    LaunchFusedConv2DOp<Device, T>()(context, use_cudnn_, cudnn_use_autotune_,
                                     input, filter, fused_computation_,
                                     fused_computation_args_, params_,
                                     dimensions, output);
  }

 private:
  Conv2DParameters params_;
  bool use_cudnn_;
  bool cudnn_use_autotune_;

  FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
  FusedComputationArgs fused_computation_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DOp);
};

// Registration of the CPU implementations.
#define REGISTER_FUSED_CPU_CONV2D(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      FusedConv2DOp<CPUDevice, T>);
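
// Example use of the registration macro above (for illustration only; the
// actual instantiations live in separate per-dtype .cc translation units, not
// in this header):
//
//   REGISTER_FUSED_CPU_CONV2D(float);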

#if GOOGLE_CUDA

#define DECLARE_FUNCTOR_GPU_SPEC(T)                                     \
  template <>                                                           \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
      typename TTypes<T, 4, int>::ConstTensor in,                       \
      typename TTypes<T, 4, int>::Tensor out);                          \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
  template <>                                                           \
  void PadInput<GPUDevice, T, int, 4>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
      const std::array<int, 2>& padding_left,                           \
      const std::array<int, 2>& padding_right,                          \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
      const T& padding_value);                                          \
  extern template struct PadInput<GPUDevice, T, int, 4>

// Registration of the GPU implementations.
#define REGISTER_FUSED_GPU_CONV2D(T)                      \
  REGISTER_KERNEL_BUILDER(Name("_FusedConv2D")            \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<T>("T")     \
                              .HostMemory("host_args"),   \
                          FusedConv2DOp<GPUDevice, T>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_