1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define USE_EIGEN_TENSOR
17#define EIGEN_USE_THREADS
18
19#include <utility>
20
21#include "tensorflow/core/framework/kernel_shape_util.h"
22#include "tensorflow/core/framework/numeric_op.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_shape.h"
27#include "tensorflow/core/framework/tensor_slice.h"
28#include "tensorflow/core/framework/tensor_util.h"
29#include "tensorflow/core/kernels/conv_2d.h"
30#include "tensorflow/core/kernels/conv_3d.h"
31#include "tensorflow/core/kernels/conv_grad_ops.h"
32#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
33#include "tensorflow/core/kernels/conv_ops_gpu.h"
34#include "tensorflow/core/lib/core/errors.h"
35#include "tensorflow/core/lib/gtl/inlined_vector.h"
36#include "tensorflow/core/profiler/lib/scoped_annotation.h"
37#include "tensorflow/core/util/padding.h"
38#include "tensorflow/core/util/tensor_format.h"
39#include "tensorflow/core/util/use_cudnn.h"
40#include "tensorflow/core/util/work_sharder.h"
41
42#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
43#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
44#endif
45
46#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
47#include "tensorflow/core/platform/stream_executor.h"
48using stream_executor::dnn::DimIndex;
49#include "tensorflow/core/protobuf/autotuning.pb.h"
50#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
51#include "tensorflow/core/util/proto/proto_utils.h"
52#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
53#if GOOGLE_CUDA
54#include "third_party/gpus/cudnn/cudnn.h"
55#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
56#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
57#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
58#endif // GOOGLE_CUDA
59
60namespace {
61
62// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
63// conv_grad_input_ops_3d.cc.
64
65// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
66
67// "Depth" is already used for the channel dimension, so for the third spatial
68// dimension in this file we use "plane", although in NDHWC layout it's
69// indicated with a "D".
70
71// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
72// order (planes, height, width, depth), constructed from patches in 'col_data',
73// which is required to be in storage order (out_planes * out_height *
74// out_width, filter_planes, filter_height, filter_width, in_depth).
75//
76// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
77template <typename T>
78void Col2im(const T* col_data, const int depth, const int planes,
79 const int height, const int width, const int filter_p,
80 const int filter_h, const int filter_w, const int pad_pt,
81 const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
82 const int pad_r, const int stride_p, const int stride_h,
83 const int stride_w, T* im_data) {
84 const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
85 const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
86 const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
87 int p_pad = -pad_pt;
88 for (int p = 0; p < planes_col; ++p) {
89 int h_pad = -pad_t;
90 for (int h = 0; h < height_col; ++h) {
91 int w_pad = -pad_l;
92 for (int w = 0; w < width_col; ++w) {
93 T* im_patch_data =
94 im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
95 for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
96 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
97 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
98 if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
99 iw < width) {
100 for (int i = 0; i < depth; ++i) {
101 im_patch_data[i] += col_data[i];
102 }
103 }
104 im_patch_data += depth;
105 col_data += depth;
106 }
107 // Jump over remaining number of depth.
108 im_patch_data += depth * (width - filter_w);
109 }
110 // Jump over remaining number of (depth * width).
111 im_patch_data += (depth * width) * (height - filter_h);
112 }
113 w_pad += stride_w;
114 }
115 h_pad += stride_h;
116 }
117 p_pad += stride_p;
118 }
119}
120
121// Returns in 'col_data', image patches in storage order (planes, height, width,
122// depth) extracted from image at 'input_data', which is required to be in
123// storage order (batch, planes, height, width, depth).
124//
125// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
126template <typename T>
127void Im2col(const T* input_data, const int depth, const int planes,
128 const int height, const int width, const int filter_p,
129 const int filter_h, const int filter_w, const int pad_pt,
130 const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
131 const int pad_r, const int stride_p, const int stride_h,
132 const int stride_w, T* col_data) {
133 const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
134 const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
135 const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
136
137 int p_pad = -pad_pt;
138 for (int p = 0; p < planes_col; ++p) {
139 int h_pad = -pad_t;
140 for (int h = 0; h < height_col; ++h) {
141 int w_pad = -pad_l;
142 for (int w = 0; w < width_col; ++w) {
143 for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
144 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
145 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
146 if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
147 iw < width) {
148 memcpy(col_data,
149 input_data +
150 (ip * height * width + ih * width + iw) * depth,
151 sizeof(T) * depth);
152 } else {
153 // This should be simply padded with zero.
154 memset(col_data, 0, sizeof(T) * depth);
155 }
156 col_data += depth;
157 }
158 }
159 }
160 w_pad += stride_w;
161 }
162 h_pad += stride_h;
163 }
164 p_pad += stride_p;
165 }
166}
167
168} // namespace
169
170namespace tensorflow {
171
172typedef Eigen::ThreadPoolDevice CPUDevice;
173typedef Eigen::GpuDevice GPUDevice;
174
175// Backprop for input that offloads computation to
176// Eigen::CuboidConvolutionBackwardInput.
177template <typename Device, class T>
178class Conv3DBackpropInputOp : public OpKernel {
179 public:
180 explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
181 : OpKernel(context),
182 data_format_(FORMAT_NHWC),
183 takes_shape_(type_string().find("V2") != std::string::npos) {
184 // data_format is only available in V2.
185 if (takes_shape_) {
186 string data_format;
187 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
188 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
189 errors::InvalidArgument("Invalid data format"));
190 OP_REQUIRES(
191 context, data_format_ == FORMAT_NHWC,
192 errors::InvalidArgument(
193 "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
194 }
195
196 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
197 OP_REQUIRES(context, dilation_.size() == 5,
198 errors::InvalidArgument("Dilation rates field must "
199 "specify 5 dimensions"));
200 OP_REQUIRES(context,
201 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
202 GetTensorDim(dilation_, data_format_, 'N') == 1),
203 errors::InvalidArgument(
204 "Current implementation does not yet support "
205 "dilation rates in the batch and depth dimensions."));
206
207 // TODO(yangzihao): Add CPU version of dilated conv 3D.
208 OP_REQUIRES(context,
209 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
210 GetTensorDim(dilation_, data_format_, '1') == 1 &&
211 GetTensorDim(dilation_, data_format_, '2') == 1),
212 errors::InvalidArgument(
213 "Current CPU implementation does not yet support "
214 "dilation rates larger than 1."));
215
216 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
217 OP_REQUIRES(context, stride_.size() == 5,
218 errors::InvalidArgument("Sliding window strides field must "
219 "specify 5 dimensions"));
220 OP_REQUIRES(
221 context,
222 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
223 GetTensorDim(stride_, data_format_, 'N') == 1),
224 errors::InvalidArgument("Current implementation does not yet support "
225 "strides in the batch and depth dimensions."));
226 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
227 }
228
229 void Compute(OpKernelContext* context) override {
230 const Tensor& filter = context->input(1);
231 const TensorShape& filter_shape = filter.shape();
232
233 const Tensor& out_backprop = context->input(2);
234 const TensorShape& out_backprop_shape = out_backprop.shape();
235
236 TensorShape input_shape;
237 if (takes_shape_) {
238 const Tensor& input_sizes = context->input(0);
239 // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
240 // input_sizes.
241 OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
242 } else {
243 input_shape = context->input(0).shape();
244 }
245
246 OP_REQUIRES(context, input_shape.dims() == 5,
247 errors::InvalidArgument("input tensor must have 5 dimensions"));
248 OP_REQUIRES(
249 context, filter_shape.dims() == 5,
250 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
251 OP_REQUIRES(
252 context, out_backprop_shape.dims() == 5,
253 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
254 OP_REQUIRES(
255 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
256 errors::InvalidArgument("input and filter_sizes must have the same "
257 "number of channels. Got ",
258 input_shape.dim_size(4), " for input and ",
259 filter_shape.dim_size(3), " for filter_sizes"));
260 OP_REQUIRES(
261 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
262 errors::InvalidArgument("out_backprop and filter_sizes must have the "
263 "same number of channels. Got ",
264 out_backprop_shape.dim_size(4),
265 " for out_backprop and ",
266 filter_shape.dim_size(4), " for filter_sizes"));
267
268 ConvBackpropDimensions dims;
269 OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
270 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
271 input_shape, filter_shape, out_backprop_shape,
272 stride_, padding_, data_format_, &dims));
273
274 Tensor* in_backprop;
275 OP_REQUIRES_OK(context,
276 context->allocate_output(0, input_shape, &in_backprop));
277
278 functor::CuboidConvolutionBackwardInput<Device, T>()(
279 context->eigen_device<Device>(),
280 in_backprop->tensor<T, 5>(), // input_backward
281 filter.tensor<T, 5>(), // filter
282 out_backprop.tensor<T, 5>(), // output_backward
283 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
284 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
285 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
286 }
287
288 private:
289 std::vector<int32> dilation_;
290 std::vector<int32> stride_;
291 Padding padding_;
292 TensorFormat data_format_;
293 bool takes_shape_;
294
295 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
296};
297
298// Custom backprop for input that explicitly does the work sharding and calls
299// Eigen only to multiply matrices.
300template <typename Device, class T>
301class Conv3DCustomBackpropInputOp : public OpKernel {
302 // Limit the maximum size of allocated temporary buffer to
303 // kMaxTempAllocationOverhead times the size of the input tensors (input,
304 // filter, out_backprop). If the size of the temporary buffer exceeds this
305 // limit, fallback on Eigen implementation.
306 static constexpr int kMaxTempAllocationOverhead = 25;
307
308 public:
309 explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
310 : OpKernel(context),
311 data_format_(FORMAT_NHWC),
312 takes_shape_(type_string().find("V2") != std::string::npos) {
313 // data_format is only available in V2.
314 if (takes_shape_) {
315 string data_format;
316 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
317 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
318 errors::InvalidArgument("Invalid data format"));
319 OP_REQUIRES(
320 context, data_format_ == FORMAT_NHWC,
321 errors::InvalidArgument(
322 "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
323 }
324
325 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
326 OP_REQUIRES(context, dilation_.size() == 5,
327 errors::InvalidArgument("Dilation rates field must "
328 "specify 5 dimensions"));
329 OP_REQUIRES(context,
330 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
331 GetTensorDim(dilation_, data_format_, 'N') == 1),
332 errors::InvalidArgument(
333 "Current implementation does not yet support "
334 "dilation rates in the batch and depth dimensions."));
335
336 // TODO(yangzihao): Add CPU version of dilated conv 3D.
337 OP_REQUIRES(context,
338 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
339 GetTensorDim(dilation_, data_format_, '1') == 1 &&
340 GetTensorDim(dilation_, data_format_, '2') == 1),
341 errors::InvalidArgument(
342 "Current CPU implementation does not yet support "
343 "dilation rates larger than 1."));
344
345 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
346 OP_REQUIRES(context, stride_.size() == 5,
347 errors::InvalidArgument("Sliding window strides field must "
348 "specify 5 dimensions"));
349 OP_REQUIRES(
350 context,
351 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
352 GetTensorDim(stride_, data_format_, 'N') == 1),
353 errors::InvalidArgument("Current implementation does not yet support "
354 "strides in the batch and depth dimensions."));
355 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
356 }
357
358 void Compute(OpKernelContext* context) override {
359 const Tensor& filter = context->input(1);
360 const TensorShape& filter_shape = filter.shape();
361
362 const Tensor& out_backprop = context->input(2);
363 const TensorShape& out_backprop_shape = out_backprop.shape();
364
365 TensorShape input_shape;
366 if (takes_shape_) {
367 const Tensor& input_sizes = context->input(0);
368 // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
369 // input_sizes.
370 OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
371 } else {
372 input_shape = context->input(0).shape();
373 }
374
375 OP_REQUIRES(context, input_shape.dims() == 5,
376 errors::InvalidArgument("input tensor must have 5 dimensions"));
377 OP_REQUIRES(
378 context, filter_shape.dims() == 5,
379 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
380 OP_REQUIRES(
381 context, out_backprop_shape.dims() == 5,
382 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
383 OP_REQUIRES(
384 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
385 errors::InvalidArgument("input and filter_sizes must have the same "
386 "number of channels. Got ",
387 input_shape.dim_size(4), " for input and ",
388 filter_shape.dim_size(3), " for filter_sizes"));
389 OP_REQUIRES(
390 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
391 errors::InvalidArgument("out_backprop and filter_sizes must have the "
392 "same number of channels. Got ",
393 out_backprop_shape.dim_size(4),
394 " for out_backprop and ",
395 filter_shape.dim_size(4), " for filter_sizes"));
396
397 ConvBackpropDimensions dims;
398 OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
399 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
400 input_shape, filter_shape, out_backprop_shape,
401 stride_, padding_, data_format_, &dims));
402
403 Tensor* in_backprop;
404 OP_REQUIRES_OK(context,
405 context->allocate_output(0, input_shape, &in_backprop));
406
407 int64_t top_pad_planes, bottom_pad_planes;
408 int64_t top_pad_rows, bottom_pad_rows;
409 int64_t left_pad_cols, right_pad_cols;
410
411 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
412 dims.spatial_dims[0].input_size,
413 dims.spatial_dims[0].filter_size,
414 dims.spatial_dims[0].stride, padding_,
415 &dims.spatial_dims[0].output_size,
416 &top_pad_planes, &bottom_pad_planes));
417 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
418 dims.spatial_dims[1].input_size,
419 dims.spatial_dims[1].filter_size,
420 dims.spatial_dims[1].stride, padding_,
421 &dims.spatial_dims[1].output_size,
422 &top_pad_rows, &bottom_pad_rows));
423 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
424 dims.spatial_dims[2].input_size,
425 dims.spatial_dims[2].filter_size,
426 dims.spatial_dims[2].stride, padding_,
427 &dims.spatial_dims[2].output_size,
428 &left_pad_cols, &right_pad_cols));
429
430 // TODO(ezhulenev): Extract work size and shard estimation to shared
431 // functions in conv_grad_ops, and update 2d convolution backprop.
432
433 // The total dimension size of each kernel.
434 const int64_t filter_total_size =
435 dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
436 dims.spatial_dims[2].filter_size * dims.in_depth;
437
438 // The output image size is the spatial size of the output.
439 const int64_t output_image_size = dims.spatial_dims[0].output_size *
440 dims.spatial_dims[1].output_size *
441 dims.spatial_dims[2].output_size;
442
443 const auto cache_sizes = Eigen::internal::CacheSizes();
444 const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
445
446 // Use L3 cache size as target working set size.
447 const size_t target_working_set_size = l3_cache_size / sizeof(T);
448
449 // Calculate size of matrices involved in MatMul: C = A x B.
450 const int64_t size_A = output_image_size * dims.out_depth;
451
452 const int64_t size_B = filter_total_size * dims.out_depth;
453
454 const int64_t size_C = output_image_size * filter_total_size;
455
456 const int64_t work_unit_size = size_A + size_B + size_C;
457
458 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
459
460 // Use parallel tensor contractions if there is no batching.
461 //
462 // Compared to Conv2D code, this version is missing work size estimation. In
463 // benchmarks I didn't find a case when it's beneficial to run parallel
464 // contraction compared to sharding and matmuls.
465 const bool use_parallel_contraction = dims.batch_size == 1;
466
467 OP_REQUIRES(
468 context, work_unit_size > 0,
469 errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
470 "must all have at least 1 element"));
471
472 const size_t shard_size =
473 use_parallel_contraction
474 ? 1
475 : (target_working_set_size + work_unit_size - 1) / work_unit_size;
476
477 // Total number of elements in all the tensors used by this kernel.
478 int64_t total_tensor_elements = input_shape.num_elements() +
479 filter_shape.num_elements() +
480 out_backprop_shape.num_elements();
481
482 // Shape of the temporary workspace buffer.
483 TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size),
484 static_cast<int64_t>(output_image_size),
485 static_cast<int64_t>(filter_total_size)};
486 int64_t col_buffer_elements = col_buffer_shape.num_elements();
487
488 // If the temporary allocation overhead is too large, fallback on Eigen
489 // implementation which requires much less memory.
490 int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
491 if (col_buffer_overhead > kMaxTempAllocationOverhead) {
492 VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
493 "col_buffer_overhead="
494 << col_buffer_overhead;
495
496 functor::CuboidConvolutionBackwardInput<Device, T>()(
497 context->eigen_device<Device>(),
498 in_backprop->tensor<T, 5>(), // input_backward
499 filter.tensor<T, 5>(), // filter
500 out_backprop.tensor<T, 5>(), // output_backward
501 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
502 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
503 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
504
505 return;
506 }
507
508 Tensor col_buffer;
509 OP_REQUIRES_OK(context,
510 context->allocate_temp(DataTypeToEnum<T>::value,
511 col_buffer_shape, &col_buffer));
512
513 // The input offset corresponding to a single input image.
514 const int64_t input_offset =
515 dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
516 dims.spatial_dims[2].input_size * dims.in_depth;
517
518 // The output offset corresponding to a single output image.
519 const int64_t output_offset =
520 dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
521 dims.spatial_dims[2].output_size * dims.out_depth;
522
523 const T* filter_data = filter.template flat<T>().data();
524 T* col_buffer_data = col_buffer.template flat<T>().data();
525 const T* out_backprop_data = out_backprop.template flat<T>().data();
526
527 auto in_backprop_flat = in_backprop->template flat<T>();
528 T* input_backprop_data = in_backprop_flat.data();
529 in_backprop_flat.device(context->eigen_device<Device>()) =
530 in_backprop_flat.constant(T(0));
531
532 if (use_parallel_contraction) {
533 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
534 Eigen::Unaligned>
535 TensorMap;
536 typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
537 Eigen::Unaligned>
538 ConstTensorMap;
539
540 // Initialize contraction dims (we need to transpose 'B' below).
541 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
542 contract_dims[0].first = 1;
543 contract_dims[0].second = 1;
544
545 for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
546 // Compute gradient into col_buffer.
547 TensorMap C(col_buffer_data, output_image_size, filter_total_size);
548
549 ConstTensorMap A(out_backprop_data + output_offset * image_id,
550 output_image_size, dims.out_depth);
551 ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
552
553 C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
554
555 Col2im<T>(col_buffer_data, dims.in_depth,
556 // Input spatial dimensions.
557 dims.spatial_dims[0].input_size, // input planes
558 dims.spatial_dims[1].input_size, // input rows
559 dims.spatial_dims[2].input_size, // input cols
560 // Filter spatial dimensions.
561 dims.spatial_dims[0].filter_size, // filter planes
562 dims.spatial_dims[1].filter_size, // filter rows
563 dims.spatial_dims[2].filter_size, // filter cols
564 // Spatial padding.
565 top_pad_planes, top_pad_rows, left_pad_cols,
566 bottom_pad_planes, bottom_pad_rows, right_pad_cols,
567 // Spatial striding.
568 dims.spatial_dims[0].stride, // stride planes
569 dims.spatial_dims[1].stride, // stride rows
570 dims.spatial_dims[2].stride, // stride cols
571 input_backprop_data);
572
573 input_backprop_data += input_offset;
574 }
575 } else {
576 typedef Eigen::Map<
577 Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
578 MatrixMap;
579 typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
580 Eigen::RowMajor>>
581 ConstMatrixMap;
582
583 for (int image_id = 0; image_id < dims.batch_size;
584 image_id += shard_size) {
585 const int shard_limit =
586 std::min(static_cast<int>(shard_size),
587 static_cast<int>(dims.batch_size) - image_id);
588
589 auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
590 &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
591 &output_image_size, &filter_total_size,
592 &input_backprop_data, &col_buffer_data,
593 &out_backprop_data, &filter_data, &input_offset,
594 &output_offset, &size_C](int64_t start, int64_t limit) {
595 for (int shard_id = start; shard_id < limit; ++shard_id) {
596 T* im2col_buf = col_buffer_data + shard_id * size_C;
597 T* input_data = input_backprop_data + shard_id * input_offset;
598 const T* out_data = out_backprop_data + shard_id * output_offset;
599
600 // Compute gradient into 'im2col_buf'.
601 MatrixMap C(im2col_buf, output_image_size, filter_total_size);
602
603 ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
604 ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
605
606 C.noalias() = A * B.transpose();
607
608 Col2im<T>(im2col_buf, dims.in_depth,
609 // Input spatial dimensions.
610 dims.spatial_dims[0].input_size, // input planes
611 dims.spatial_dims[1].input_size, // input rows
612 dims.spatial_dims[2].input_size, // input cols
613 // Filter spatial dimensions.
614 dims.spatial_dims[0].filter_size, // filter planes
615 dims.spatial_dims[1].filter_size, // filter rows
616 dims.spatial_dims[2].filter_size, // filter cols
617 // Spatial padding.
618 top_pad_planes, top_pad_rows, left_pad_cols,
619 bottom_pad_planes, bottom_pad_rows, right_pad_cols,
620 // Spatial striding.
621 dims.spatial_dims[0].stride, // stride planes
622 dims.spatial_dims[1].stride, // stride rows
623 dims.spatial_dims[2].stride, // stride cols
624 input_data);
625 }
626 };
627 Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
628 work_unit_size, shard);
629
630 input_backprop_data += input_offset * shard_limit;
631 out_backprop_data += output_offset * shard_limit;
632 }
633 }
634 }
635
636 private:
637 std::vector<int32> dilation_;
638 std::vector<int32> stride_;
639 Padding padding_;
640 TensorFormat data_format_;
641 bool takes_shape_;
642
643 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
644};
645
646// Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than
647// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
648
649#define REGISTER_CPU_KERNEL(T) \
650 REGISTER_KERNEL_BUILDER( \
651 Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
652 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
653 REGISTER_KERNEL_BUILDER( \
654 Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
655 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
656 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
657 .Device(DEVICE_CPU) \
658 .Label("custom") \
659 .TypeConstraint<T>("T"), \
660 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
661 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
662 .Device(DEVICE_CPU) \
663 .Label("custom") \
664 .TypeConstraint<T>("T"), \
665 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
666 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
667 .Device(DEVICE_CPU) \
668 .Label("eigen_tensor") \
669 .TypeConstraint<T>("T"), \
670 Conv3DBackpropInputOp<CPUDevice, T>); \
671 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
672 .Device(DEVICE_CPU) \
673 .Label("eigen_tensor") \
674 .TypeConstraint<T>("T"), \
675 Conv3DBackpropInputOp<CPUDevice, T>);
676
677TF_CALL_half(REGISTER_CPU_KERNEL);
678TF_CALL_float(REGISTER_CPU_KERNEL);
679TF_CALL_double(REGISTER_CPU_KERNEL);
680#undef REGISTER_CPU_KERNEL
681
682#define REGISTER_CPU_KERNEL(T) \
683 REGISTER_KERNEL_BUILDER( \
684 Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
685 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
686 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
687 .Device(DEVICE_CPU) \
688 .Label("custom") \
689 .TypeConstraint<T>("T"), \
690 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
691 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
692 .Device(DEVICE_CPU) \
693 .Label("eigen_tensor") \
694 .TypeConstraint<T>("T"), \
695 Conv3DBackpropInputOp<CPUDevice, T>);
696
697TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
698#undef REGISTER_CPU_KERNEL
699
700// Backprop for filter that offloads computation to
701// Eigen::CuboidConvolutionBackwardFilter.
702template <typename Device, class T>
703class Conv3DBackpropFilterOp : public OpKernel {
704 public:
705 explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
706 : OpKernel(context),
707 data_format_(FORMAT_NHWC),
708 takes_shape_(type_string().find("V2") != std::string::npos) {
709 // data_format is only available in V2.
710 if (takes_shape_) {
711 string data_format;
712 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
713 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
714 errors::InvalidArgument("Invalid data format"));
715 OP_REQUIRES(
716 context, data_format_ == FORMAT_NHWC,
717 errors::InvalidArgument(
718 "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
719 }
720
721 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
722 OP_REQUIRES(context, dilation_.size() == 5,
723 errors::InvalidArgument("Dilation rates field must "
724 "specify 5 dimensions"));
725 OP_REQUIRES(context,
726 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
727 GetTensorDim(dilation_, data_format_, 'N') == 1),
728 errors::InvalidArgument(
729 "Current implementation does not yet support "
730 "dilation rates in the batch and depth dimensions."));
731
732 // TODO(yangzihao): Add CPU version of dilated conv 3D.
733 OP_REQUIRES(context,
734 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
735 GetTensorDim(dilation_, data_format_, '1') == 1 &&
736 GetTensorDim(dilation_, data_format_, '2') == 1),
737 errors::InvalidArgument(
738 "Current CPU implementation does not yet support "
739 "dilation rates larger than 1."));
740
741 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
742 OP_REQUIRES(context, stride_.size() == 5,
743 errors::InvalidArgument("Sliding window strides field must "
744 "specify 5 dimensions"));
745 OP_REQUIRES(
746 context,
747 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
748 GetTensorDim(stride_, data_format_, 'N') == 1),
749 errors::InvalidArgument("Current implementation does not yet support "
750 "strides in the batch and depth dimensions."));
751 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
752 }
753
754 void Compute(OpKernelContext* context) override {
755 const Tensor& input = context->input(0);
756 const TensorShape& input_shape = input.shape();
757
758 const Tensor& out_backprop = context->input(2);
759 const TensorShape& out_backprop_shape = out_backprop.shape();
760
761 TensorShape filter_shape;
762 if (takes_shape_) {
763 const Tensor& filter_sizes = context->input(1);
764 OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
765 errors::InvalidArgument(
766 "filter_sizes shape must be rank 1 but is rank ",
767 filter_sizes.shape().dims()));
768 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
769 filter_sizes.vec<int32>(), &filter_shape));
770 } else {
771 filter_shape = context->input(1).shape();
772 }
773
774 OP_REQUIRES(context, input_shape.dims() == 5,
775 errors::InvalidArgument("input tensor must have 5 dimensions"));
776 OP_REQUIRES(
777 context, filter_shape.dims() == 5,
778 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
779 OP_REQUIRES(
780 context, out_backprop_shape.dims() == 5,
781 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
782 OP_REQUIRES(
783 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
784 errors::InvalidArgument("input and filter_sizes must have the same "
785 "number of channels. Got ",
786 input_shape.dim_size(4), " for input and ",
787 filter_shape.dim_size(3), " for filter_sizes"));
788 OP_REQUIRES(
789 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
790 errors::InvalidArgument("out_backprop and filter_sizes must have the "
791 "same number of channels. Got ",
792 out_backprop_shape.dim_size(4),
793 " for out_backprop and ",
794 filter_shape.dim_size(4), " for filter_sizes"));
795
796 ConvBackpropDimensions dims;
797 OP_REQUIRES_OK(context,
798 ConvBackpropComputeDimensions(
799 "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
800 input_shape, filter_shape, out_backprop_shape, stride_,
801 padding_, data_format_, &dims));
802
803 Tensor* filter_backprop;
804 OP_REQUIRES_OK(context,
805 context->allocate_output(0, filter_shape, &filter_backprop));
806
807 if (input_shape.num_elements() == 0) {
808 filter_backprop->template flat<T>().setZero();
809 return;
810 }
811
812 functor::CuboidConvolutionBackwardFilter<Device, T>()(
813 context->eigen_device<Device>(),
814 filter_backprop->tensor<T, 5>(), // filter_backward
815 input.tensor<T, 5>(), // input
816 out_backprop.tensor<T, 5>(), // output_backward
817 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
818 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
819 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
820 }
821
822 private:
823 std::vector<int32> dilation_;
824 std::vector<int32> stride_;
825 Padding padding_;
826 TensorFormat data_format_;
827 bool takes_shape_;
828
829 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
830};
831
832// Custom backprop for filter that explicitly does the work sharding and calls
833// Eigen only to multiply matrices.
834template <typename Device, class T>
835class Conv3DCustomBackpropFilterOp : public OpKernel {
836 // Limit the maximum size of allocated temporary buffer to
837 // kMaxTempAllocationOverhead times the size of the input tensors (input,
838 // filter, out_backprop). If the size of the temporary buffer exceeds this
839 // limit, fallback on Eigen implementation.
840 static constexpr int kMaxTempAllocationOverhead = 25;
841
842 public:
843 explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
844 : OpKernel(context),
845 data_format_(FORMAT_NHWC),
846 takes_shape_(type_string().find("V2") != std::string::npos) {
847 // data_format is only available in V2.
848 if (takes_shape_) {
849 string data_format;
850 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
851 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
852 errors::InvalidArgument("Invalid data format"));
853 OP_REQUIRES(
854 context, data_format_ == FORMAT_NHWC,
855 errors::InvalidArgument(
856 "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
857 }
858
859 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
860 OP_REQUIRES(context, dilation_.size() == 5,
861 errors::InvalidArgument("Dilation rates field must "
862 "specify 5 dimensions"));
863 OP_REQUIRES(context,
864 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
865 GetTensorDim(dilation_, data_format_, 'N') == 1),
866 errors::InvalidArgument(
867 "Current implementation does not yet support "
868 "dilation rates in the batch and depth dimensions."));
869
870 // TODO(yangzihao): Add CPU version of dilated conv 3D.
871 OP_REQUIRES(context,
872 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
873 GetTensorDim(dilation_, data_format_, '1') == 1 &&
874 GetTensorDim(dilation_, data_format_, '2') == 1),
875 errors::InvalidArgument(
876 "Current CPU implementation does not yet support "
877 "dilation rates larger than 1."));
878
879 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
880 OP_REQUIRES(context, stride_.size() == 5,
881 errors::InvalidArgument("Sliding window strides field must "
882 "specify 5 dimensions"));
883 OP_REQUIRES(
884 context,
885 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
886 GetTensorDim(stride_, data_format_, 'N') == 1),
887 errors::InvalidArgument("Current implementation does not yet support "
888 "strides in the batch and depth dimensions."));
889 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
890 }
891
892 void Compute(OpKernelContext* context) override {
893 const Tensor& input = context->input(0);
894 const TensorShape& input_shape = input.shape();
895
896 const Tensor& out_backprop = context->input(2);
897 const TensorShape& out_backprop_shape = out_backprop.shape();
898
899 TensorShape filter_shape;
900 if (takes_shape_) {
901 const Tensor& filter_sizes = context->input(1);
902 OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
903 errors::InvalidArgument(
904 "filter_sizes shape must be rank 1 but is rank ",
905 filter_sizes.shape().dims()));
906 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
907 filter_sizes.vec<int32>(), &filter_shape));
908 } else {
909 filter_shape = context->input(1).shape();
910 }
911
912 OP_REQUIRES(context, input_shape.dims() == 5,
913 errors::InvalidArgument("input tensor must have 5 dimensions"));
914 OP_REQUIRES(
915 context, filter_shape.dims() == 5,
916 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
917 OP_REQUIRES(
918 context, out_backprop_shape.dims() == 5,
919 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
920 OP_REQUIRES(
921 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
922 errors::InvalidArgument("input and filter_sizes must have the same "
923 "number of channels. Got ",
924 input_shape.dim_size(4), " for input and ",
925 filter_shape.dim_size(3), " for filter_sizes"));
926 OP_REQUIRES(
927 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
928 errors::InvalidArgument("out_backprop and filter_sizes must have the "
929 "same number of channels. Got ",
930 out_backprop_shape.dim_size(4),
931 " for out_backprop and ",
932 filter_shape.dim_size(4), " for filter_sizes"));
933
934 ConvBackpropDimensions dims;
935 OP_REQUIRES_OK(context,
936 ConvBackpropComputeDimensions(
937 "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
938 input_shape, filter_shape, out_backprop_shape, stride_,
939 padding_, data_format_, &dims));
940
941 Tensor* filter_backprop;
942 OP_REQUIRES_OK(context,
943 context->allocate_output(0, filter_shape, &filter_backprop));
944
945 if (input_shape.num_elements() == 0) {
946 filter_backprop->template flat<T>().setZero();
947 return;
948 }
949
950 int64_t top_pad_planes, bottom_pad_planes;
951 int64_t top_pad_rows, bottom_pad_rows;
952 int64_t left_pad_cols, right_pad_cols;
953
954 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
955 dims.spatial_dims[0].input_size,
956 dims.spatial_dims[0].filter_size,
957 dims.spatial_dims[0].stride, padding_,
958 &dims.spatial_dims[0].output_size,
959 &top_pad_planes, &bottom_pad_planes));
960 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
961 dims.spatial_dims[1].input_size,
962 dims.spatial_dims[1].filter_size,
963 dims.spatial_dims[1].stride, padding_,
964 &dims.spatial_dims[1].output_size,
965 &top_pad_rows, &bottom_pad_rows));
966 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
967 dims.spatial_dims[2].input_size,
968 dims.spatial_dims[2].filter_size,
969 dims.spatial_dims[2].stride, padding_,
970 &dims.spatial_dims[2].output_size,
971 &left_pad_cols, &right_pad_cols));
972
973 // TODO(ezhulenev): Extract work size and shard estimation to shared
974 // functions in conv_grad_ops, and update 2d convolution backprop.
975
976 // The total dimension size of each kernel.
977 const int64_t filter_total_size =
978 dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
979 dims.spatial_dims[2].filter_size * dims.in_depth;
980 // The output image size is the spatial size of the output.
981 const int64_t output_image_size = dims.spatial_dims[0].output_size *
982 dims.spatial_dims[1].output_size *
983 dims.spatial_dims[2].output_size;
984
985 // Shard 'batch' images (volumes) into 'shard_size' groups of images
986 // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
987 // dividing the L3 cache size ('target_working_set_size') by the matmul size
988 // of an individual image ('work_unit_size').
989
990 const auto cache_sizes = Eigen::internal::CacheSizes();
991 const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
992
993 // TODO(andydavis)
994 // *) Consider reducing 'target_working_set_size' if L3 is shared by
995 // other concurrently running tensorflow ops.
996 const size_t target_working_set_size = l3_cache_size / sizeof(T);
997
998 const int64_t size_A = output_image_size * filter_total_size;
999
1000 const int64_t size_B = output_image_size * dims.out_depth;
1001
1002 const int64_t size_C = filter_total_size * dims.out_depth;
1003
1004 const int64_t work_unit_size = size_A + size_B + size_C;
1005
1006 OP_REQUIRES(
1007 context, work_unit_size > 0,
1008 errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
1009 "must all have at least 1 element"));
1010
1011 const size_t shard_size =
1012 (target_working_set_size + work_unit_size - 1) / work_unit_size;
1013
1014 // Total number of elements in all the tensors used by this kernel.
1015 int64_t total_tensor_elements = input_shape.num_elements() +
1016 filter_shape.num_elements() +
1017 out_backprop_shape.num_elements();
1018
1019 // Shape of the temporary workspace buffer.
1020 TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size),
1021 static_cast<int64_t>(output_image_size),
1022 static_cast<int64_t>(filter_total_size)};
1023 int64_t col_buffer_elements = col_buffer_shape.num_elements();
1024
1025 // If the temporary allocation overhead is too large, fallback on Eigen
1026 // implementation which requires much less memory.
1027 int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
1028 if (col_buffer_overhead > kMaxTempAllocationOverhead) {
1029 VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
1030 "col_buffer_overhead="
1031 << col_buffer_overhead;
1032
1033 functor::CuboidConvolutionBackwardFilter<Device, T>()(
1034 context->eigen_device<Device>(),
1035 filter_backprop->tensor<T, 5>(), // filter_backward
1036 input.tensor<T, 5>(), // input
1037 out_backprop.tensor<T, 5>(), // output_backward
1038 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
1039 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
1040 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
1041
1042 return;
1043 }
1044
1045 Tensor col_buffer;
1046 OP_REQUIRES_OK(context,
1047 context->allocate_temp(DataTypeToEnum<T>::value,
1048 col_buffer_shape, &col_buffer));
1049
1050 // The input offset corresponding to a single input image.
1051 const int64_t input_offset =
1052 dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
1053 dims.spatial_dims[2].input_size * dims.in_depth;
1054 // The output offset corresponding to a single output image.
1055 const int64_t output_offset =
1056 dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
1057 dims.spatial_dims[2].output_size * dims.out_depth;
1058
1059 const T* input_data = input.template flat<T>().data();
1060 T* col_buffer_data = col_buffer.template flat<T>().data();
1061 const T* out_backprop_data = out_backprop.template flat<T>().data();
1062 T* filter_backprop_data = filter_backprop->template flat<T>().data();
1063
1064 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
1065 Eigen::Unaligned>
1066 TensorMap;
1067 typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
1068 Eigen::Unaligned>
1069 ConstTensorMap;
1070
1071 TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
1072 C.setZero();
1073
1074 // Initialize contraction dims (we need to transpose 'A' below).
1075 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
1076 contract_dims[0].first = 0;
1077 contract_dims[0].second = 0;
1078
1079 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
1080
1081 for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
1082 const int shard_limit =
1083 std::min(static_cast<int>(shard_size),
1084 static_cast<int>(dims.batch_size) - image_id);
1085
1086 auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
1087 &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
1088 &bottom_pad_rows, &right_pad_cols, &input_offset,
1089 &size_A](int64_t start, int64_t limit) {
1090 for (int shard_id = start; shard_id < limit; ++shard_id) {
1091 const T* input_data_shard = input_data + shard_id * input_offset;
1092 T* col_data_shard = col_buffer_data + shard_id * size_A;
1093
1094 // When we compute the gradient with respect to the filters, we need
1095 // to do im2col to allow gemm-type computation.
1096 Im2col<T>(input_data_shard, dims.in_depth,
1097 // Input spatial dimensions.
1098 dims.spatial_dims[0].input_size, // input planes
1099 dims.spatial_dims[1].input_size, // input rows
1100 dims.spatial_dims[2].input_size, // input cols
1101 // Filter spatial dimensions.
1102 dims.spatial_dims[0].filter_size, // filter planes
1103 dims.spatial_dims[1].filter_size, // filter rows
1104 dims.spatial_dims[2].filter_size, // filter cols
1105 // Spatial padding.
1106 top_pad_planes, top_pad_rows, left_pad_cols,
1107 bottom_pad_planes, bottom_pad_rows, right_pad_cols,
1108 // Spatial striding.
1109 dims.spatial_dims[0].stride, // stride planes
1110 dims.spatial_dims[1].stride, // stride rows
1111 dims.spatial_dims[2].stride, // stride cols
1112 col_data_shard);
1113 }
1114 };
1115 Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
1116 size_A, shard);
1117
1118 ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
1119 filter_total_size);
1120 ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
1121 dims.out_depth);
1122
1123 // Gradient with respect to filter.
1124 C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
1125
1126 input_data += input_offset * shard_limit;
1127 out_backprop_data += output_offset * shard_limit;
1128 }
1129 }
1130
1131 private:
1132 std::vector<int32> dilation_;
1133 std::vector<int32> stride_;
1134 Padding padding_;
1135 TensorFormat data_format_;
1136 bool takes_shape_;
1137
1138 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
1139};
1140
1141// Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than
1142// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
1143
1144#define REGISTER_CPU_KERNEL(T) \
1145 REGISTER_KERNEL_BUILDER( \
1146 Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1147 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1148 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1149 .Device(DEVICE_CPU) \
1150 .TypeConstraint<T>("T"), \
1151 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1152 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
1153 .Device(DEVICE_CPU) \
1154 .Label("custom") \
1155 .TypeConstraint<T>("T"), \
1156 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1157 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1158 .Device(DEVICE_CPU) \
1159 .Label("custom") \
1160 .TypeConstraint<T>("T"), \
1161 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1162 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
1163 .Device(DEVICE_CPU) \
1164 .Label("eigen_tensor") \
1165 .TypeConstraint<T>("T"), \
1166 Conv3DBackpropFilterOp<CPUDevice, T>); \
1167 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1168 .Device(DEVICE_CPU) \
1169 .Label("eigen_tensor") \
1170 .TypeConstraint<T>("T"), \
1171 Conv3DBackpropFilterOp<CPUDevice, T>);
1172
1173TF_CALL_float(REGISTER_CPU_KERNEL);
1174TF_CALL_double(REGISTER_CPU_KERNEL);
1175#undef REGISTER_CPU_KERNEL
1176
1177#define REGISTER_CPU_KERNEL(T) \
1178 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1179 .Device(DEVICE_CPU) \
1180 .TypeConstraint<T>("T"), \
1181 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1182 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1183 .Device(DEVICE_CPU) \
1184 .Label("custom") \
1185 .TypeConstraint<T>("T"), \
1186 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1187 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1188 .Device(DEVICE_CPU) \
1189 .Label("eigen_tensor") \
1190 .TypeConstraint<T>("T"), \
1191 Conv3DBackpropFilterOp<CPUDevice, T>);
1192
1193TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
1194#undef REGISTER_CPU_KERNEL
1195
1196// WARNING: Eigen::half is not trivially copyable and can't be used in
1197// custom backprop filter kernel because of memcpy and memset in Im2col.
1198#define REGISTER_CPU_KERNEL(T) \
1199 REGISTER_KERNEL_BUILDER( \
1200 Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1201 Conv3DBackpropFilterOp<CPUDevice, T>); \
1202 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1203 .Device(DEVICE_CPU) \
1204 .TypeConstraint<T>("T"), \
1205 Conv3DBackpropFilterOp<CPUDevice, T>);
1206
1207TF_CALL_half(REGISTER_CPU_KERNEL);
1208#undef REGISTER_CPU_KERNEL
1209
1210// GPU definitions of both ops.
1211#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1212// Forward declarations of the functor specializations for GPU.
1213// This ensures that the custom implementation is used instead of the default
1214// Eigen one (which is used for CPU).
1215namespace functor {
1216#define DECLARE_GPU_SPEC(T) \
1217 template <> \
1218 void TransformFilter<GPUDevice, T, int, 5>::operator()( \
1219 const GPUDevice& d, FilterTensorFormat dst_filter_format, \
1220 typename TTypes<T, 5, int>::ConstTensor in, \
1221 typename TTypes<T, 5, int>::Tensor out); \
1222 template <> \
1223 void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
1224 const GPUDevice& d, FilterTensorFormat src_filter_format, \
1225 typename TTypes<T, 5>::ConstTensor in, \
1226 typename TTypes<T, 5>::Tensor out); \
1227 template <> \
1228 void PadInput<GPUDevice, T, int, 5>::operator()( \
1229 const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
1230 const std::array<int, 3>& padding_left, \
1231 const std::array<int, 3>& padding_right, \
1232 typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \
1233 const T& padding_value);
1234
1235DECLARE_GPU_SPEC(Eigen::half);
1236DECLARE_GPU_SPEC(float);
1237DECLARE_GPU_SPEC(double);
1238#undef DECLARE_GPU_SPEC
1239} // namespace functor
1240
1241// A dummy type to group backward data autotune results together.
1242struct Conv3dBackwardDataAutotuneGroup {
1243 static string name() { return "Conv3dBwdData"; }
1244};
1245
1246typedef AutotuneSingleton<Conv3dBackwardDataAutotuneGroup, ConvParameters,
1247 AutotuneEntry<se::dnn::ConvOp>>
1248
1249 AutotuneConv3dBwdData;
1250
1251template <typename T>
1252class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
1253 public:
1254 explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
1255 : OpKernel(context),
1256 data_format_(FORMAT_NHWC),
1257 takes_shape_(type_string().find("V2") != std::string::npos) {
1258 // data_format is only available in V2.
1259 if (takes_shape_) {
1260 string data_format;
1261 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1262 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1263 errors::InvalidArgument("Invalid data format"));
1264 }
1265 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1266 OP_REQUIRES(context, dilation_.size() == 5,
1267 errors::InvalidArgument("Dilation rates field must "
1268 "specify 5 dimensions"));
1269 OP_REQUIRES(context,
1270 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1271 GetTensorDim(dilation_, data_format_, 'N') == 1),
1272 errors::InvalidArgument(
1273 "Current implementation does not yet support "
1274 "dilation rates in the batch and depth dimensions."));
1275 OP_REQUIRES(
1276 context,
1277 (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1278 GetTensorDim(dilation_, data_format_, '1') > 0 &&
1279 GetTensorDim(dilation_, data_format_, '2') > 0),
1280 errors::InvalidArgument("Dilated rates should be larger than 0."));
1281 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1282 OP_REQUIRES(context, stride_.size() == 5,
1283 errors::InvalidArgument("Sliding window strides field must "
1284 "specify 5 dimensions"));
1285 OP_REQUIRES(
1286 context,
1287 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1288 GetTensorDim(stride_, data_format_, 'N') == 1),
1289 errors::InvalidArgument("Current implementation does not yet support "
1290 "strides in the batch and depth dimensions."));
1291 OP_REQUIRES(
1292 context,
1293 (GetTensorDim(stride_, data_format_, '0') > 0 &&
1294 GetTensorDim(stride_, data_format_, '1') > 0 &&
1295 GetTensorDim(stride_, data_format_, '2') > 0),
1296 errors::InvalidArgument("Spatial strides should be larger than 0."));
1297 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1298 cudnn_use_autotune_ = CudnnUseAutotune();
1299 }
1300 void Compute(OpKernelContext* context) override {
1301 const Tensor& filter = context->input(1);
1302 const TensorShape& filter_shape = filter.shape();
1303
1304 const Tensor& out_backprop = context->input(2);
1305 const TensorShape& out_backprop_shape = out_backprop.shape();
1306
1307 TensorShape input_shape;
1308 if (takes_shape_) {
1309 const Tensor& input_sizes = context->input(0);
1310 OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
1311 } else {
1312 input_shape = context->input(0).shape();
1313 }
1314
1315 ConvBackpropDimensions dims;
1316 OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
1317 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
1318 input_shape, filter_shape, out_backprop_shape,
1319 dilation_, stride_, padding_,
1320 /*explicit_paddings=*/{}, data_format_, &dims));
1321
1322 Tensor* in_backprop;
1323 OP_REQUIRES_OK(context,
1324 context->allocate_output(0, input_shape, &in_backprop));
1325
1326 auto* stream = context->op_device_context()->stream();
1327 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1328
1329 bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1330 if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
1331 dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
1332 dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
1333 dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
1334 dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
1335 const uint64 m = dims.batch_size * dims.input_size(0) *
1336 dims.input_size(1) * dims.input_size(2);
1337 const uint64 k = dims.out_depth;
1338 const uint64 n = dims.in_depth;
1339
1340 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1341 out_backprop.template flat<T>().size());
1342 auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1343 filter.template flat<T>().size());
1344 auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1345 in_backprop->template flat<T>().size());
1346
1347 auto transpose = se::blas::Transpose::kTranspose;
1348 auto no_transpose = se::blas::Transpose::kNoTranspose;
1349
1350 OP_REQUIRES_OK(
1351 context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1352 k, a_ptr, k, &c_ptr, n,
1353 se::blas::kDefaultComputePrecision));
1354 return;
1355 } else if (!is_grouped_convolution &&
1356 dims.filter_size(0) == dims.input_size(0) &&
1357 dims.filter_size(1) == dims.input_size(1) &&
1358 dims.filter_size(2) == dims.input_size(2) &&
1359 padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
1360 const uint64 m = dims.batch_size;
1361 const uint64 k = dims.out_depth;
1362 const uint64 n = dims.input_size(0) * dims.input_size(1) *
1363 dims.input_size(2) * dims.in_depth;
1364
1365 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1366 out_backprop.template flat<T>().size());
1367 auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1368 filter.template flat<T>().size());
1369 auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1370 in_backprop->template flat<T>().size());
1371
1372 auto transpose = se::blas::Transpose::kTranspose;
1373 auto no_transpose = se::blas::Transpose::kNoTranspose;
1374
1375 OP_REQUIRES_OK(
1376 context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1377 k, a_ptr, k, &c_ptr, n,
1378 se::blas::kDefaultComputePrecision));
1379 return;
1380 }
1381
1382 int padding_planes = dims.SpatialPadding(padding_, 0);
1383 int padding_rows = dims.SpatialPadding(padding_, 1);
1384 int padding_cols = dims.SpatialPadding(padding_, 2);
1385 const bool planes_odd = (padding_planes % 2 != 0);
1386 const bool rows_odd = (padding_rows % 2 != 0);
1387 const bool cols_odd = (padding_cols % 2 != 0);
1388
1389 TensorShape compatible_input_shape;
1390 if (rows_odd || cols_odd || planes_odd) {
1391 // cuDNN only supports the same amount of padding on both sides.
1392 compatible_input_shape = {
1393 dims.batch_size,
1394 dims.in_depth,
1395 dims.input_size(0) + planes_odd,
1396 dims.input_size(1) + rows_odd,
1397 dims.input_size(2) + cols_odd,
1398 };
1399 } else {
1400 compatible_input_shape = {dims.batch_size, dims.in_depth,
1401 dims.input_size(0), dims.input_size(1),
1402 dims.input_size(2)};
1403 }
1404
1405 CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1406 << "Negative paddings: (" << padding_rows << ", " << padding_cols
1407 << ", " << padding_planes << ")";
1408
1409#if GOOGLE_CUDA
1410 const bool compute_in_nhwc =
1411 CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1412#else
1413 // fast NDHWC implementation is a CUDA only feature
1414 const bool compute_in_nhwc = false;
1415#endif
1416 const TensorFormat compute_data_format =
1417 (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1418 : FORMAT_NCHW;
1419
1420 VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
1421 << " data_format=" << ToString(data_format_)
1422 << " compute_data_format=" << ToString(compute_data_format);
1423
1424 constexpr auto kComputeInNHWC =
1425 std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1426 se::dnn::FilterLayout::kOutputYXInput);
1427 constexpr auto kComputeInNCHW =
1428 std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1429 se::dnn::FilterLayout::kOutputInputYX);
1430
1431 se::dnn::DataLayout compute_data_layout;
1432 se::dnn::FilterLayout filter_layout;
1433
1434 std::tie(compute_data_layout, filter_layout) =
1435 compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1436
1437 se::dnn::BatchDescriptor input_desc(3);
1438 input_desc.set_count(dims.batch_size)
1439 .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
1440 .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
1441 .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
1442 .set_feature_map_count(dims.in_depth)
1443 .set_layout(compute_data_layout);
1444 se::dnn::BatchDescriptor output_desc(3);
1445 output_desc.set_count(dims.batch_size)
1446 .set_spatial_dim(DimIndex::X, dims.output_size(2))
1447 .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1448 .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1449 .set_feature_map_count(dims.out_depth)
1450 .set_layout(compute_data_layout);
1451 se::dnn::FilterDescriptor filter_desc(3);
1452 filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1453 .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1454 .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1455 .set_input_feature_map_count(filter_shape.dim_size(3))
1456 .set_output_feature_map_count(filter_shape.dim_size(4))
1457 .set_layout(filter_layout);
1458 se::dnn::ConvolutionDescriptor conv_desc(3);
1459 conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1460 .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1461 .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1462 .set_filter_stride(DimIndex::X, dims.stride(2))
1463 .set_filter_stride(DimIndex::Y, dims.stride(1))
1464 .set_filter_stride(DimIndex::Z, dims.stride(0))
1465 .set_zero_padding(DimIndex::X, padding_cols / 2)
1466 .set_zero_padding(DimIndex::Y, padding_rows / 2)
1467 .set_zero_padding(DimIndex::Z, padding_planes / 2)
1468 .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1469
1470 // Shape: out, in, z, y, x.
1471 Tensor transformed_filter;
1472 auto dst_format =
1473 compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1474 TensorShape dst_shape =
1475 dst_format == FORMAT_OIHW
1476 ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1477 dims.filter_size(0), dims.filter_size(1),
1478 dims.filter_size(2)})
1479 : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1480 dims.filter_size(1), dims.filter_size(2),
1481 filter_shape.dim_size(3)});
1482 OP_REQUIRES_OK(context,
1483 context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1484 &transformed_filter));
1485
1486 functor::TransformFilter<GPUDevice, T, int, 5>()(
1487 context->eigen_device<GPUDevice>(), dst_format,
1488 To32Bit(filter.tensor<T, 5>()),
1489 To32Bit(transformed_filter.tensor<T, 5>()));
1490
1491 // Shape: batch, filters, z, y, x.
1492 Tensor transformed_out_backprop;
1493 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1494 TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1495 dims.output_size(0), dims.output_size(1),
1496 dims.output_size(2)};
1497 if (dims.out_depth > 1) {
1498 OP_REQUIRES_OK(context, context->allocate_temp(
1499 DataTypeToEnum<T>::value, nchw_shape,
1500 &transformed_out_backprop));
1501 functor::NHWCToNCHW<GPUDevice, T, 5>()(
1502 context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1503 transformed_out_backprop.tensor<T, 5>());
1504 } else {
1505 CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1506 }
1507 } else {
1508 transformed_out_backprop = out_backprop;
1509 }
1510 // Shape: batch, filters, z, y, x.
1511 Tensor pre_transformed_in_backprop;
1512 OP_REQUIRES_OK(context,
1513 context->allocate_temp(
1514 DataTypeToEnum<T>::value,
1515 ShapeFromFormat(compute_data_format,
1516 compatible_input_shape.dim_size(0),
1517 {{compatible_input_shape.dim_size(2),
1518 compatible_input_shape.dim_size(3),
1519 compatible_input_shape.dim_size(4)}},
1520 compatible_input_shape.dim_size(1)),
1521 &pre_transformed_in_backprop));
1522
1523 auto out_backprop_ptr =
1524 AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1525 transformed_out_backprop.template flat<T>().size());
1526 auto filter_ptr =
1527 AsDeviceMemory(transformed_filter.template flat<T>().data(),
1528 transformed_filter.template flat<T>().size());
1529 auto in_backprop_ptr =
1530 AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
1531 pre_transformed_in_backprop.template flat<T>().size());
1532
1533 static int64_t ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
1534 "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 33); // 8GB by default
1535
1536 const int device_id = stream->parent()->device_ordinal();
1537 // To make sure the Conv3DBackpropInputV2 get the correct dtype, we infer
1538 // the dtype from 2nd input, i.e., out_backprop.
1539 DataType dtype = context->input(2).dtype();
1540 const ConvParameters conv_parameters = {
1541 dims.batch_size,
1542 dims.in_depth,
1543 {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1544 compute_data_format,
1545 dims.out_depth,
1546 {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1547 {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1548 {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1549 {{padding_planes, padding_rows, padding_cols}},
1550 dtype,
1551 device_id,
1552 conv_desc.group_count()};
1553
1554 using se::dnn::AlgorithmConfig;
1555 using se::dnn::AlgorithmDesc;
1556 using se::dnn::ProfileResult;
1557
1558 auto entry_or = AutotuneUnfusedConv(
1559 cudnn_use_autotune_, AutotuneConv3dBwdData::GetInstance(),
1560 conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_DATA,
1561 input_desc, in_backprop_ptr, filter_desc, filter_ptr, conv_desc,
1562 output_desc, out_backprop_ptr, ConvolveBackwardDataScratchSize);
1563 OP_REQUIRES_OK(context, entry_or.status());
1564 auto autotune_entry = std::move(entry_or).value();
1565
1566 DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1567 context);
1568 Status cudnn_launch_status = LaunchAutotunedConv(
1569 autotune_entry, &scratch_allocator,
1570 se::dnn::ConvolutionKind::BACKWARD_DATA, stream, input_desc,
1571 in_backprop_ptr, filter_desc, filter_ptr, conv_desc, output_desc,
1572 out_backprop_ptr);
1573 if (!cudnn_launch_status.ok()) {
1574 context->SetStatus(cudnn_launch_status);
1575 return;
1576 }
1577
1578 if (rows_odd || cols_odd || planes_odd) {
1579 Tensor in_backprop_remove_padding;
1580 OP_REQUIRES_OK(
1581 context, context->allocate_temp(
1582 DataTypeToEnum<T>::value,
1583 ShapeFromFormat(compute_data_format, dims.batch_size,
1584 {{dims.input_size(0), dims.input_size(1),
1585 dims.input_size(2)}},
1586 dims.in_depth),
1587 &in_backprop_remove_padding));
1588
1589 // Remove the padding for odd spatial dimensions.
1590 functor::PadInput<GPUDevice, T, int, 5>()(
1591 context->eigen_device<GPUDevice>(),
1592 To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
1593 .tensor<T, 5>()),
1594 {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
1595 To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
1596 compute_data_format, T{});
1597
1598 pre_transformed_in_backprop = in_backprop_remove_padding;
1599 }
1600
1601 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1602 auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1603 functor::NCHWToNHWC<GPUDevice, T, 5>()(
1604 context->eigen_device<GPUDevice>(),
1605 toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
1606 in_backprop->tensor<T, 5>());
1607 } else {
1608 *in_backprop = pre_transformed_in_backprop;
1609 }
1610 }
1611
1612 private:
1613 std::vector<int32> dilation_;
1614 std::vector<int32> stride_;
1615 Padding padding_;
1616 TensorFormat data_format_;
1617 bool takes_shape_;
1618 bool cudnn_use_autotune_;
1619};
1620
1621// A dummy type to group backward filter autotune results together.
1622struct Conv3dBackwardFilterAutotuneGroup {
1623 static string name() { return "Conv3dBwdFilter"; }
1624};
1625
1626typedef AutotuneSingleton<Conv3dBackwardFilterAutotuneGroup, ConvParameters,
1627 AutotuneEntry<se::dnn::ConvOp>>
1628 AutotuneConv3dBwdFilter;
1629
1630template <typename T>
1631class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
1632 public:
1633 explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
1634 : OpKernel(context),
1635 data_format_(FORMAT_NHWC),
1636 takes_shape_(type_string().find("V2") != std::string::npos) {
1637 // data_format is only available in V2.
1638 if (takes_shape_) {
1639 string data_format;
1640 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1641 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1642 errors::InvalidArgument("Invalid data format"));
1643 }
1644 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1645 OP_REQUIRES(context, dilation_.size() == 5,
1646 errors::InvalidArgument("Dilation rates field must "
1647 "specify 5 dimensions"));
1648 OP_REQUIRES(context,
1649 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1650 GetTensorDim(dilation_, data_format_, 'N') == 1),
1651 errors::InvalidArgument(
1652 "Current implementation does not yet support "
1653 "dilation rates in the batch and depth dimensions."));
1654 OP_REQUIRES(
1655 context,
1656 (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1657 GetTensorDim(dilation_, data_format_, '1') > 0 &&
1658 GetTensorDim(dilation_, data_format_, '2') > 0),
1659 errors::InvalidArgument("Dilated rates should be larger than 0."));
1660 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1661 OP_REQUIRES(context, stride_.size() == 5,
1662 errors::InvalidArgument("Sliding window strides field must "
1663 "specify 5 dimensions"));
1664 OP_REQUIRES(
1665 context,
1666 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1667 GetTensorDim(stride_, data_format_, 'N') == 1),
1668 errors::InvalidArgument("Current implementation does not yet support "
1669 "strides in the batch and depth dimensions."));
1670 OP_REQUIRES(
1671 context,
1672 (GetTensorDim(stride_, data_format_, '0') > 0 &&
1673 GetTensorDim(stride_, data_format_, '1') > 0 &&
1674 GetTensorDim(stride_, data_format_, '2') > 0),
1675 errors::InvalidArgument("Spatial strides should be larger than 0."));
1676 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1677 cudnn_use_autotune_ = CudnnUseAutotune();
1678 }
1679
1680 void Compute(OpKernelContext* context) override {
1681 const Tensor& input = context->input(0);
1682 const TensorShape& input_shape = input.shape();
1683
1684 const Tensor& out_backprop = context->input(2);
1685 const TensorShape& out_backprop_shape = out_backprop.shape();
1686
1687 TensorShape filter_shape;
1688 if (takes_shape_) {
1689 const Tensor& filter_sizes = context->input(1);
1690 OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
1691 errors::InvalidArgument(
1692 "filter_sizes shape must be rank 1 but is rank ",
1693 filter_sizes.shape().dims()));
1694 OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape));
1695 } else {
1696 filter_shape = context->input(1).shape();
1697 }
1698
1699 ConvBackpropDimensions dims;
1700 OP_REQUIRES_OK(
1701 context,
1702 ConvBackpropComputeDimensionsV2(
1703 "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
1704 filter_shape, out_backprop_shape, dilation_, stride_, padding_,
1705 /*explicit_paddings=*/{}, data_format_, &dims));
1706
1707 Tensor* filter_backprop;
1708 OP_REQUIRES_OK(context,
1709 context->allocate_output(0, filter_shape, &filter_backprop));
1710
1711 auto* stream = context->op_device_context()->stream();
1712 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1713
1714 bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1715 if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
1716 dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
1717 dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
1718 dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
1719 dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
1720 const uint64 m = dims.in_depth;
1721 const uint64 k = dims.batch_size * dims.input_size(1) *
1722 dims.input_size(2) * dims.input_size(0);
1723 const uint64 n = dims.out_depth;
1724
1725 // The shape of output backprop is
1726 // [batch, out_z, out_y, out_x, out_depth]
1727 // From cublas's perspective, it is: n x k
1728 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1729 out_backprop.template flat<T>().size());
1730
1731 // The shape of input is:
1732 // [batch, in_z, in_y, in_x, in_depth],
1733 // From cublas's perspective, it is: m x k
1734 auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
1735 input.template flat<T>().size());
1736
1737 // The shape of the filter backprop is:
1738 // [1, 1, 1, in_depth, out_depth]
1739 // From cublas's perspective, it is: n x m
1740 auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1741 filter_backprop->template flat<T>().size());
1742
1743 OP_REQUIRES_OK(context,
1744 stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1745 se::blas::Transpose::kTranspose, n, m,
1746 k, a_ptr, n, b_ptr, m, &c_ptr, n,
1747 se::blas::kDefaultComputePrecision));
1748 return;
1749 } else if (!is_grouped_convolution &&
1750 dims.filter_size(0) == dims.input_size(0) &&
1751 dims.filter_size(1) == dims.input_size(1) &&
1752 dims.filter_size(2) == dims.input_size(2) &&
1753 padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
1754 const uint64 m = dims.input_size(0) * dims.input_size(1) *
1755 dims.input_size(2) * dims.in_depth;
1756 const uint64 k = dims.batch_size;
1757 const uint64 n = dims.out_depth;
1758
1759 auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
1760 input.template flat<T>().size());
1761 auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1762 out_backprop.template flat<T>().size());
1763 auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1764 filter_backprop->template flat<T>().size());
1765
1766 OP_REQUIRES_OK(context,
1767 stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1768 se::blas::Transpose::kTranspose, n, m,
1769 k, b_ptr, n, a_ptr, m, &c_ptr, n,
1770 se::blas::kDefaultComputePrecision));
1771 return;
1772 }
1773
1774 int padding_planes = dims.SpatialPadding(padding_, 0);
1775 int padding_rows = dims.SpatialPadding(padding_, 1);
1776 int padding_cols = dims.SpatialPadding(padding_, 2);
1777 const bool planes_odd = (padding_planes % 2 != 0);
1778 const bool rows_odd = (padding_rows % 2 != 0);
1779 const bool cols_odd = (padding_cols % 2 != 0);
1780
1781 Tensor compatible_input;
1782 if (rows_odd || cols_odd || planes_odd) {
1783 OP_REQUIRES_OK(context,
1784 context->allocate_temp(
1785 DataTypeToEnum<T>::value,
1786 ShapeFromFormat(data_format_, dims.batch_size,
1787 {{dims.input_size(0) + planes_odd,
1788 dims.input_size(1) + rows_odd,
1789 dims.input_size(2) + cols_odd}},
1790 dims.in_depth),
1791 &compatible_input));
1792 functor::PadInput<GPUDevice, T, int, 5>()(
1793 context->template eigen_device<GPUDevice>(),
1794 To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
1795 {{planes_odd, rows_odd, cols_odd}},
1796 To32Bit(compatible_input.tensor<T, 5>()), data_format_, T{});
1797 } else {
1798 compatible_input = input;
1799 }
1800
1801 CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1802 << "Negative paddings: (" << padding_rows << ", " << padding_cols
1803 << ", " << padding_planes << ")";
1804
1805#if GOOGLE_CUDA
1806 const bool compute_in_nhwc =
1807 CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1808#else
1809 // fast NDHWC implementation is a CUDA only feature
1810 const bool compute_in_nhwc = false;
1811#endif
1812 const TensorFormat compute_data_format =
1813 (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1814 : FORMAT_NCHW;
1815
1816 VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
1817 << " data_format=" << ToString(data_format_)
1818 << " compute_data_format=" << ToString(compute_data_format);
1819
1820 constexpr auto kComputeInNHWC =
1821 std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1822 se::dnn::FilterLayout::kOutputYXInput);
1823 constexpr auto kComputeInNCHW =
1824 std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1825 se::dnn::FilterLayout::kOutputInputYX);
1826
1827 se::dnn::DataLayout compute_data_layout;
1828 se::dnn::FilterLayout filter_layout;
1829
1830 std::tie(compute_data_layout, filter_layout) =
1831 compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1832
1833 se::dnn::BatchDescriptor input_desc(3);
1834 input_desc.set_count(dims.batch_size)
1835 .set_spatial_dim(DimIndex::X,
1836 GetTensorDim(compatible_input, data_format_, '2'))
1837 .set_spatial_dim(DimIndex::Y,
1838 GetTensorDim(compatible_input, data_format_, '1'))
1839 .set_spatial_dim(DimIndex::Z,
1840 GetTensorDim(compatible_input, data_format_, '0'))
1841 .set_feature_map_count(dims.in_depth)
1842 .set_layout(compute_data_layout);
1843 se::dnn::BatchDescriptor output_desc(3);
1844 output_desc.set_count(dims.batch_size)
1845 .set_spatial_dim(DimIndex::X, dims.output_size(2))
1846 .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1847 .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1848 .set_feature_map_count(dims.out_depth)
1849 .set_layout(compute_data_layout);
1850 se::dnn::FilterDescriptor filter_desc(3);
1851 filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1852 .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1853 .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1854 .set_input_feature_map_count(filter_shape.dim_size(3))
1855 .set_output_feature_map_count(filter_shape.dim_size(4))
1856 .set_layout(filter_layout);
1857 se::dnn::ConvolutionDescriptor conv_desc(3);
1858 conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1859 .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1860 .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1861 .set_filter_stride(DimIndex::X, dims.stride(2))
1862 .set_filter_stride(DimIndex::Y, dims.stride(1))
1863 .set_filter_stride(DimIndex::Z, dims.stride(0))
1864 .set_zero_padding(DimIndex::X, padding_cols / 2)
1865 .set_zero_padding(DimIndex::Y, padding_rows / 2)
1866 .set_zero_padding(DimIndex::Z, padding_planes / 2)
1867 .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1868
1869 Tensor pre_transformed_filter_backprop;
1870 auto dst_format =
1871 compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1872 TensorShape dst_shape =
1873 dst_format == FORMAT_OIHW
1874 ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1875 dims.filter_size(0), dims.filter_size(1),
1876 dims.filter_size(2)})
1877 : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1878 dims.filter_size(1), dims.filter_size(2),
1879 filter_shape.dim_size(3)});
1880 OP_REQUIRES_OK(context,
1881 context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1882 &pre_transformed_filter_backprop));
1883
1884 Tensor transformed_out_backprop;
1885 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1886 VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
1887 TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1888 dims.output_size(0), dims.output_size(1),
1889 dims.output_size(2)};
1890 OP_REQUIRES_OK(
1891 context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
1892 &transformed_out_backprop));
1893 if (dims.out_depth > 1) {
1894 functor::NHWCToNCHW<GPUDevice, T, 5>()(
1895 context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1896 transformed_out_backprop.tensor<T, 5>());
1897 } else {
1898 CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1899 }
1900 } else {
1901 transformed_out_backprop = out_backprop;
1902 }
1903 Tensor transformed_input;
1904 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1905 VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
1906 TensorShape nchw_shape = {
1907 dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
1908 compatible_input.dim_size(2), compatible_input.dim_size(3)};
1909 if (dims.in_depth > 1) {
1910 OP_REQUIRES_OK(context,
1911 context->allocate_temp(DataTypeToEnum<T>::value,
1912 nchw_shape, &transformed_input));
1913 functor::NHWCToNCHW<GPUDevice, T, 5>()(
1914 context->eigen_device<GPUDevice>(),
1915 const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
1916 transformed_input.tensor<T, 5>());
1917 } else {
1918 CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
1919 }
1920 } else {
1921 transformed_input = compatible_input;
1922 }
1923
1924 auto out_backprop_ptr =
1925 AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1926 transformed_out_backprop.template flat<T>().size());
1927 auto filter_backprop_ptr = AsDeviceMemory(
1928 pre_transformed_filter_backprop.template flat<T>().data(),
1929 pre_transformed_filter_backprop.template flat<T>().size());
1930 auto input_ptr =
1931 AsDeviceMemory(transformed_input.template flat<T>().data(),
1932 transformed_input.template flat<T>().size());
1933
1934 static int64_t ConvolveBackwardFilterScratchSize =
1935 GetDnnWorkspaceLimitOrDefault();
1936
1937 const int device_id = stream->parent()->device_ordinal();
1938 DataType dtype = input.dtype();
1939 const ConvParameters conv_parameters = {
1940 dims.batch_size,
1941 dims.in_depth,
1942 {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1943 compute_data_format,
1944 dims.out_depth,
1945 {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1946 {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1947 {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1948 {{padding_planes, padding_rows, padding_cols}},
1949 dtype,
1950 device_id,
1951 conv_desc.group_count()};
1952
1953 using se::dnn::AlgorithmConfig;
1954 using se::dnn::AlgorithmDesc;
1955 using se::dnn::ProfileResult;
1956
1957 auto entry_or = AutotuneUnfusedConv(
1958 cudnn_use_autotune_, AutotuneConv3dBwdFilter::GetInstance(),
1959 conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER,
1960 input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc,
1961 output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize);
1962 OP_REQUIRES_OK(context, entry_or.status());
1963 auto autotune_entry = std::move(entry_or).value();
1964
1965 DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
1966 context);
1967 Status cudnn_launch_status = LaunchAutotunedConv(
1968 autotune_entry, &scratch_allocator,
1969 se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc,
1970 input_ptr, filter_desc, filter_backprop_ptr, conv_desc, output_desc,
1971 out_backprop_ptr);
1972 if (!cudnn_launch_status.ok()) {
1973 context->SetStatus(cudnn_launch_status);
1974 return;
1975 }
1976
1977 auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1978 functor::ReverseTransformFilter<GPUDevice, T, 5>()(
1979 context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
1980 toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
1981 filter_backprop->tensor<T, 5>());
1982 }
1983
1984 private:
1985 std::vector<int32> dilation_;
1986 std::vector<int32> stride_;
1987 Padding padding_;
1988 TensorFormat data_format_;
1989 bool takes_shape_;
1990 bool cudnn_use_autotune_;
1991};
1992
1993#define REGISTER_GPU_KERNEL(T) \
1994 REGISTER_KERNEL_BUILDER( \
1995 Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1996 Conv3DBackpropInputOp<GPUDevice, T>); \
1997 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
1998 .Device(DEVICE_GPU) \
1999 .TypeConstraint<T>("T") \
2000 .HostMemory("input_sizes"), \
2001 Conv3DBackpropInputOp<GPUDevice, T>); \
2002 REGISTER_KERNEL_BUILDER( \
2003 Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2004 Conv3DBackpropFilterOp<GPUDevice, T>); \
2005 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
2006 .Device(DEVICE_GPU) \
2007 .TypeConstraint<T>("T") \
2008 .HostMemory("filter_sizes"), \
2009 Conv3DBackpropFilterOp<GPUDevice, T>);
2010TF_CALL_half(REGISTER_GPU_KERNEL);
2011TF_CALL_float(REGISTER_GPU_KERNEL);
2012TF_CALL_double(REGISTER_GPU_KERNEL);
2013#undef REGISTER_GPU_KERNEL
2014
2015#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2016
2017} // namespace tensorflow
2018