/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <utility>

#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
using stream_executor::dnn::DimIndex;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchConvOp;

template <typename T>
struct LaunchConvOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, bool cudnn_use_autotune,
                     const Tensor& input, const Tensor& filter,
                     const std::array<int64, 3>& dilations,
                     const std::array<int64, 3>& strides, const Padding padding,
                     TensorFormat data_format, Tensor* output) {
    OP_REQUIRES(context, data_format == FORMAT_NHWC,
                errors::InvalidArgument("CPU implementation of Conv3D "
                                        "currently only supports the NHWC "
                                        "tensor format."));
    OP_REQUIRES(context,
                dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument("CPU implementation of Conv3D "
                                        "currently only supports dilated rates "
                                        "of 1."));
    OP_REQUIRES(context, filter.dim_size(3) == input.dim_size(input.dims() - 1),
                errors::InvalidArgument(
                    "Number of channels in filter (", filter.dim_size(3),
                    ") must match last dimension of input (",
                    input.dim_size(input.dims() - 1), ")"));
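    // `strides` and `dilations` are ordered {plane, row, col} (the "z, y, x"
    // order assembled in Conv3DOp::Compute() below); note that the
    // CuboidConvolution call below consumes the stride entries in reverse
    // index order.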
    functor::CuboidConvolution<CPUDevice, T>()(
        context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
        strides[0], BrainPadding2EigenPadding(padding));
  }
};

template <typename Device, typename T>
class Conv3DOp : public BinaryOp<T> {
 public:
  explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
         GetTensorDim(stride_, data_format_, 'C') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, '0') > 0 &&
         GetTensorDim(stride_, data_format_, '1') > 0 &&
         GetTensorDim(stride_, data_format_, '2') > 0),
        errors::InvalidArgument("Spatial strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'N') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'C') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
         GetTensorDim(dilation_, data_format_, '1') > 0 &&
         GetTensorDim(dilation_, data_format_, '2') > 0),
        errors::InvalidArgument("Dilated rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_z, in_y, in_x, in_channels ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
    const Tensor& filter = context->input(1);

    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
    // kept consistent between input/filter/output.
    OP_REQUIRES(context, input.dims() == 5,
                errors::InvalidArgument("input must be 5-dimensional"));
    OP_REQUIRES(context, filter.dims() == 5,
                errors::InvalidArgument("filter must be 5-dimensional"));

    const int64_t in_depth = GetTensorDim(input, data_format_, 'C');
    const int64_t in_batch = GetTensorDim(input, data_format_, 'N');

    const int64_t filter_depth = filter.dim_size(3);
    const int64_t out_depth = filter.dim_size(4);

    OP_REQUIRES(context, filter_depth != 0,
                errors::InvalidArgument("filter_depth must be non-zero"));
    OP_REQUIRES(context, in_depth % filter_depth == 0,
                errors::InvalidArgument(
                    "Input depth must be evenly divisible by filter depth: ",
                    in_depth, " vs ", filter_depth));
    OP_REQUIRES(
        context, filter.NumElements() > 0,
        errors::InvalidArgument("filter must not have zero elements "
                                "(i.e. all dimensions must be non-zero)"));

    // Dimension order for these arrays is: z, y, x.
    std::array<int64_t, 3> input_size = {
        {GetTensorDim(input, data_format_, '0'),
         GetTensorDim(input, data_format_, '1'),
         GetTensorDim(input, data_format_, '2')}};
    std::array<int64_t, 3> filter_size = {
        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
    std::array<int64_t, 3> dilations = {
        {GetTensorDim(dilation_, data_format_, '0'),
         GetTensorDim(dilation_, data_format_, '1'),
         GetTensorDim(dilation_, data_format_, '2')}};
    std::array<int64_t, 3> strides = {
        {GetTensorDim(stride_, data_format_, '0'),
         GetTensorDim(stride_, data_format_, '1'),
         GetTensorDim(stride_, data_format_, '2')}};
    std::array<int64_t, 3> out, padding;

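    // Get3dOutputSizeV2 computes, for each spatial dimension i:
    //   SAME:  out[i] = ceil(input_size[i] / strides[i])
    //   VALID: out[i] = ceil((input_size[i] -
    //                         (filter_size[i] - 1) * dilations[i]) / strides[i])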
    OP_REQUIRES_OK(
        context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,
                                   padding_, &out, &padding));
    TensorShape out_shape = ShapeFromFormat(
        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // Return early if nothing to do.
    if (out_shape.num_elements() == 0) return;

    LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter,
                                    dilations, strides, padding_, data_format_,
                                    output);
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool cudnn_use_autotune_;
};

#define REGISTER_CPU_KERNEL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A dummy type to group forward convolution autotune results together.
struct Conv3dAutotuneGroup {
  static string name() { return "Conv3d"; }
};

typedef AutotuneSingleton<Conv3dAutotuneGroup, ConvParameters,
                          AutotuneEntry<se::dnn::ConvOp>>
    AutotuneConv3d;

// TODO(mjanusz): Share logic with 2d implementation as much as possible.
template <typename T>
struct LaunchConvOp<GPUDevice, T> {
  static void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
                     const Tensor& input_param, const Tensor& filter,
                     const std::array<int64, 3>& dilations,
                     const std::array<int64, 3>& strides, const Padding padding,
                     TensorFormat data_format, Tensor* output) {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    Tensor input = input_param;

    const int64_t in_batch = GetTensorDim(input, data_format, 'N');
    int64_t in_planes = GetTensorDim(input, data_format, '0');
    int64_t in_rows = GetTensorDim(input, data_format, '1');
    int64_t in_cols = GetTensorDim(input, data_format, '2');
    const int64_t in_depth = GetTensorDim(input, data_format, 'C');

    const int64_t filter_planes = filter.dim_size(0);
    const int64_t filter_rows = filter.dim_size(1);
    const int64_t filter_cols = filter.dim_size(2);
    const int64_t filter_depth = filter.dim_size(3);
    const int64_t out_depth = filter.dim_size(4);

    int64_t pad_planes = 0, pad_rows = 0, pad_cols = 0;
    int64_t out_planes = GetTensorDim(*output, data_format, '0');
    int64_t out_rows = GetTensorDim(*output, data_format, '1');
    int64_t out_cols = GetTensorDim(*output, data_format, '2');

    if (padding == Padding::SAME) {
      pad_planes = std::max<int64_t>(
          0, (out_planes - 1) * strides[0] + filter_planes - in_planes);
      pad_rows = std::max<int64_t>(
          0, (out_rows - 1) * strides[1] + filter_rows - in_rows);
      pad_cols = std::max<int64_t>(
          0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
    }
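    // Total SAME padding per dimension is the amount needed for the strided
    // windows to cover the input exactly. For example, with in_rows = 5,
    // filter_rows = 3 and strides[1] = 2: out_rows = ceil(5 / 2) = 3 and
    // pad_rows = max(0, (3 - 1) * 2 + 3 - 5) = 2, i.e. one row on each side.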

    bool is_grouped_convolution = filter_depth != in_depth;

    // NOTE: This only works in NHWC.
    if (!is_grouped_convolution && filter_planes == 1 && filter_rows == 1 &&
        filter_cols == 1 && dilations[0] == 1 && dilations[1] == 1 &&
        dilations[2] == 1 && strides[0] == 1 && strides[1] == 1 &&
        strides[2] == 1 && data_format == FORMAT_NHWC) {
      // 1x1 filter, so call cublas directly.
      const uint64 m = in_batch * in_planes * in_rows * in_cols;
      const uint64 k = in_depth;
      const uint64 n = out_depth;
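      // With a 1x1x1 filter and unit strides/dilations, every output voxel is
      // just its length-k input channel vector times the [k, n] filter matrix,
      // so the whole op is one [m, k] x [k, n] GEMM over the
      // m = batch * planes * rows * cols spatial positions.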

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                  output->template flat<T>().size());

      auto no_transpose = se::blas::Transpose::kNoTranspose;
      OP_REQUIRES_OK(
          ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr,
                                    n, a_ptr, k, &c_ptr, n,
                                    se::blas::kDefaultComputePrecision));
      return;
    } else if (!is_grouped_convolution && filter_planes == in_planes &&
               filter_rows == in_rows && filter_cols == in_cols &&
               padding == Padding::VALID && data_format == FORMAT_NHWC) {
      // The input data and filter have the same planes/height/width, so call
      // cublas directly.
      const uint64 m = in_batch;
      const uint64 k = in_planes * in_rows * in_cols * in_depth;
      const uint64 n = out_depth;
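      // Here the filter covers the entire (VALID) input volume, so each batch
      // element yields a single output position: an [m, k] x [k, n] GEMM with
      // m = batch and k = the full per-example input volume.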

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                  output->template flat<T>().size());

      auto no_transpose = se::blas::Transpose::kNoTranspose;
      OP_REQUIRES_OK(
          ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr,
                                    n, a_ptr, k, &c_ptr, n,
                                    se::blas::kDefaultComputePrecision));
      return;
    }

    if (padding == Padding::SAME) {
      const bool rows_odd = (pad_rows % 2 != 0);
      const bool cols_odd = (pad_cols % 2 != 0);
      const bool planes_odd = (pad_planes % 2 != 0);

      // Necessary because cuDNN only supports symmetric padding.
      // TODO(mjanusz): Consider making this optional? This would save some
      // overhead and would work as long as an op trained this way is only
      // used on GPU.
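      // For example, in_rows = 4, filter_rows = 2, stride 1 gives out_rows = 4
      // and pad_rows = 1: the input is explicitly padded by one trailing row
      // here, and pad_rows / 2 == 0 is what gets passed to cuDNN below.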
      if (rows_odd || cols_odd || planes_odd) {
        const int64_t new_in_rows = in_rows + rows_odd;
        const int64_t new_in_cols = in_cols + cols_odd;
        const int64_t new_in_planes = in_planes + planes_odd;

        Tensor transformed_input;
        TensorShape transformed_shape = ShapeFromFormat(
            data_format, in_batch, {{new_in_planes, new_in_rows, new_in_cols}},
            in_depth);
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, transformed_shape,
                                    &transformed_input));

        functor::PadInput<GPUDevice, T, int, 5>()(
            ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 5>()),
            {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}},
            To32Bit(transformed_input.tensor<T, 5>()), data_format, T{});
        input = transformed_input;
        in_rows = new_in_rows;
        in_cols = new_in_cols;
        in_planes = new_in_planes;
      }
    }

#if GOOGLE_CUDA
    const bool compute_in_nhwc =
        CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
#else
    // fast NHWC implementation is a CUDA only feature
    const bool compute_in_nhwc = false;
#endif
    const TensorFormat compute_data_format =
        (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                        : FORMAT_NCHW;

    VLOG(3) << "Compute Conv3D with cuDNN:"
            << " data_format=" << ToString(data_format)
            << " compute_data_format=" << ToString(compute_data_format);

    if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
      VLOG(4) << "Convert the input tensor from NDHWC to NCDHW.";
      const TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth);
      if (in_depth > 1) {
        Tensor transformed_input;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                               nchw_shape, &transformed_input));
        // input: [b, x, y, z, d]
        // t_input: [b, d, x, y, z]
        // NCDHW is the only format universally supported by cuDNN.
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            ctx->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
        input = transformed_input;
      } else {
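        // With a single channel, NDHWC and NCDHW describe the same memory
        // layout, so reinterpreting the buffer with the NCDHW shape suffices.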
        CHECK(input.CopyFrom(input, nchw_shape));
      }
    } else {
      CHECK(data_format == compute_data_format)  // Crash OK
          << "Illegal data and compute format pair:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);
    }

    constexpr auto kComputeInNHWC =
        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                        se::dnn::FilterLayout::kOutputYXInput);
    constexpr auto kComputeInNCHW =
        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                        se::dnn::FilterLayout::kOutputInputYX);

    se::dnn::DataLayout compute_data_layout;
    se::dnn::FilterLayout filter_layout;

    std::tie(compute_data_layout, filter_layout) =
        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

    CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
        << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
        << pad_planes << ")";
    se::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(in_batch)
        .set_feature_map_count(in_depth)
        .set_spatial_dim(DimIndex::X, in_cols)
        .set_spatial_dim(DimIndex::Y, in_rows)
        .set_spatial_dim(DimIndex::Z, in_planes)
        .set_layout(compute_data_layout);
    se::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(in_batch)
        .set_spatial_dim(DimIndex::X, out_cols)
        .set_spatial_dim(DimIndex::Y, out_rows)
        .set_spatial_dim(DimIndex::Z, out_planes)
        .set_feature_map_count(out_depth)
        .set_layout(compute_data_layout);
    se::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
        .set_spatial_dim(DimIndex::Y, filter_rows)
        .set_spatial_dim(DimIndex::Z, filter_planes)
        .set_input_feature_map_count(filter_depth)
        .set_output_feature_map_count(out_depth)
        .set_layout(filter_layout);
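    // Any odd SAME padding was already folded into the explicit PadInput
    // above, so the symmetric pad / 2 values handed to cuDNN here are exact.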
    se::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
        .set_dilation_rate(DimIndex::Y, dilations[1])
        .set_dilation_rate(DimIndex::Z, dilations[0])
        .set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, pad_cols / 2)
        .set_zero_padding(DimIndex::Y, pad_rows / 2)
        .set_zero_padding(DimIndex::Z, pad_planes / 2)
        .set_group_count(in_depth / filter_depth);

    Tensor transformed_filter;
    auto dst_format =
        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
            << " to " << ToString(dst_format);
    TensorShape dst_shape =
        dst_format == FORMAT_OIHW
            ? TensorShape({filter.dim_size(4), filter.dim_size(3),
                           filter.dim_size(0), filter.dim_size(1),
                           filter.dim_size(2)})
            : TensorShape({filter.dim_size(4), filter.dim_size(0),
                           filter.dim_size(1), filter.dim_size(2),
                           filter.dim_size(3)});
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
                                           &transformed_filter));
    // filter: [x, y, z, in, out]
    // t_filter: [out, in, x, y, z] (NCDHW) or
    // t_filter: [out, x, y, z, in] (NDHWC)
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        ctx->eigen_device<GPUDevice>(), dst_format,
        To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    Tensor transformed_output;
    if (data_format != compute_data_format) {
      VLOG(4) << "Allocate temporary memory for output in compute data format";
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              ShapeFromFormat(FORMAT_NCHW, in_batch,
                              {{out_planes, out_rows, out_cols}}, out_depth),
              &transformed_output));
    } else {
      transformed_output = *output;
    }

    auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                    input.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto output_ptr =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    static int64_t ConvolveScratchSize = GetDnnWorkspaceLimitOrDefault();

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    ConvParameters conv_parameters = {
        in_batch,
        in_depth,
        {{in_planes, in_rows, in_cols}},
        compute_data_format,
        out_depth,
        {{filter_planes, filter_rows, filter_cols}},
        {{dilations[0], dilations[1], dilations[2]}},
        {{strides[0], strides[1], strides[2]}},
        {{pad_planes, pad_rows, pad_cols}},
        dtype,
        device_id,
        conv_desc.group_count()};
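    // These parameters form the autotune cache key: convolutions with the same
    // shapes, strides, padding, dtype and device reuse the algorithm selected
    // by a previous autotuning pass instead of re-profiling.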

    using se::dnn::AlgorithmConfig;
    using se::dnn::AlgorithmDesc;
    using se::dnn::ProfileResult;

    auto config_or = AutotuneUnfusedConv(
        cudnn_use_autotune, AutotuneConv3d::GetInstance(), conv_parameters, ctx,
        se::dnn::ConvolutionKind::FORWARD, input_desc, input_ptr, filter_desc,
        filter_ptr, conv_desc, output_desc, output_ptr, ConvolveScratchSize);
    OP_REQUIRES_OK(ctx, config_or.status());
    auto autotune_entry = std::move(config_or).value();

    DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
    Status cudnn_launch_status = LaunchAutotunedConv(
        autotune_entry, &scratch_allocator, se::dnn::ConvolutionKind::FORWARD,
        stream, input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
        output_desc, output_ptr);
    if (!cudnn_launch_status.ok()) {
      ctx->SetStatus(cudnn_launch_status);
      return;
    }

    if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
      VLOG(4) << "Convert the output tensor back from NCDHW to NDHWC.";
      // t_output: [b, out, x, y, z]
      // output: [b, x, y, z, out]
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
          output->tensor<T, 5>());
    }
  }
};

// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
      typename TTypes<T, 5, int>::ConstTensor in,                     \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, FilterTensorFormat src_filter_format,       \
      typename TTypes<T, 5>::ConstTensor in,                          \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format,    \
      const T& padding_value);                                        \
  template <>                                                         \
  void NHWCToNCHW<GPUDevice, T, 5>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void NCHWToNHWC<GPUDevice, T, 5>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    Conv3DOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    Conv3DOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    Conv3DOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow