/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/depthwise_conv_op.h"

#include <algorithm>
#include <cmath>
#include <type_traits>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif

#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

// In depthwise convolution, each input channel is convolved into
// depth_multiplier output channels, and the outputs are not reduced (summed)
// across channels the way they are in regular convolution.
// However, filters are applied to the input in exactly the same way as in
// regular convolution. Please refer to the regular convolution kernels for
// more details.
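//
// As a concrete shape illustration (NHWC layout; the sizes below are
// hypothetical): an input of shape [batch, in_rows, in_cols, 8] convolved
// with a depthwise filter of shape [3, 3, 8, 2] (depth_multiplier == 2)
// produces an output whose channel dimension is 8 * 2 == 16, i.e. an output
// of shape [batch, out_rows, out_cols, 16].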

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// result in 'output' at location specified by 'out_r' and 'out_c'.
//
// EX:
//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
//   Both 'input_buffer' and 'filter' are padded to register-width boundaries.
//
//   input_buffer [rows, cols, in_depth, depth_multiplier]
//     [a0, a0, a1, a1] [a2, a2, 0, 0] [b0, b0, b1, b1] [b2, b2, 0, 0]
//     [e0, e0, e1, e1] [e2, e2, 0, 0] [f0, f0, f1, f1] [f2, f2, 0, 0]
//
//   filter [rows, cols, in_depth, depth_multiplier]
//     [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
//     [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
//
//   First output register [in_depth, depth_multiplier]
//     [q0, q1, q2, q3] = ([a0, a0, a1, a1] x [u0, v0, w0, x0]) +
//                        ([b0, b0, b1, b1] x [u1, v1, w1, x1]) +
//                        ([e0, e0, e1, e1] x [u2, v2, w2, x2]) +
//                        ([f0, f0, f1, f1] x [u3, v3, w3, x3])
//
// TODO(andydavis) Experiment with processing multiple inputs per input buffer.
template <typename T>
struct DepthwiseConv2DKernel {
  static void Run(const DepthwiseArgs& args,
                  const int64_t padded_filter_inner_dim_size,
                  const int64_t out_r, const int64_t out_c, const T* filter,
                  const T* input_buffer, T* output, TensorFormat data_format) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t out_depth = args.out_depth;
    const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
    const int64_t output_scalar_size = out_depth % kPacketSize;
    const int64_t output_vectorized_size =
        (out_depth / kPacketSize) * kPacketSize;
    const int64_t base_output_index =
        (out_r * args.out_cols + out_c) * out_depth;

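    // Illustration of the split handled by the two loops below (hypothetical
    // values; the actual kPacketSize comes from Eigen's packet_traits<T> and
    // the SIMD width this binary is built with): with kPacketSize == 4 and
    // out_depth == 6, output_vectorized_size == 4 is covered by the packet
    // loop and output_scalar_size == 2 is covered by the scalar tail below.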
    for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
      // Reset accumulator.
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        // Calculate index.
        const int64_t index = i + j * padded_filter_inner_dim_size;
        // Load filter.
        // TODO(andydavis) Unroll 'out_c' loop in caller so we can load
        // multiple inputs here to amortize the cost of each filter block load.
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        // Load input.
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        // Vector multiply-add.
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Store vector accumulator to output.
      Eigen::internal::pstoreu<T>(output + base_output_index + i, vaccum);
    }

    if (output_scalar_size > 0) {
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        const int64_t index =
            output_vectorized_size + j * padded_filter_inner_dim_size;
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Load accumulator into an array and loop through output.
      T out_buf[kPacketSize];
      Eigen::internal::pstoreu<T>(out_buf, vaccum);
      const int64_t last_output_index =
          base_output_index + output_vectorized_size;
      for (int j = 0; j < output_scalar_size; ++j) {
        output[last_output_index + j] = out_buf[j];
      }
    }
  }
};

// Computes the depthwise conv2d of 'input' by 'depthwise_filter' and stores
// the result in 'output'. This implementation trades off copying small patches
// of the input to achieve better data alignment, which enables vectorized
// load/store and multiply-add operations (see comments at DepthwiseInputCopyOp
// and DepthwiseConv2DKernel for details).
//
// TODO(andydavis) Evaluate the performance of processing multiple input
// patches in the inner loop.
// TODO(andydavis) Consider a zero-copy implementation for the case when
// 'in_depth' is a multiple of register width, and 'depth_multiplier' is one.
// TODO(andydavis) Evaluate the performance of alternative implementations.
template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
                  const T* input, const T* depthwise_filter, T* output,
                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
            "Depthwise convolution on CPU is only supported for NHWC format"));
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    // Pad 'depthwise_filter' to vector register width (if needed).
    const bool pad_filter = (args.out_depth % kPacketSize) != 0;
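    // Illustrative padding example (hypothetical values; the actual
    // kPacketSize depends on Eigen's packet_traits<T> and the SIMD width the
    // binary targets): with kPacketSize == 8 (e.g. float with AVX) and
    // args.out_depth == 10, pad_filter is true and the filter's inner
    // dimension is rounded up from 10 to 16 elements per spatial position.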
    Tensor padded_filter;
    if (pad_filter) {
      // Allocate space for padded filter.
      const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64_t padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &padded_filter));
      // Write out padded filter.
      functor::DepthwiseFilterPadOp<T>()(
          args, depthwise_filter, padded_filter.template flat<T>().data());
    }
    const T* filter_data =
        pad_filter ? padded_filter.template flat<T>().data() : depthwise_filter;

    // Computes one shard of depthwise conv2d output.
    auto shard = [&ctx, &args, &input, &filter_data, &output, data_format](
                     int64_t start, int64_t limit) {
      static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
      const int64_t input_image_size =
          args.in_rows * args.in_cols * args.in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * args.out_depth;
      const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64_t padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;

      // Allocate buffer for local input regions.
      Tensor input_buffer;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &input_buffer));
      T* input_buffer_data = input_buffer.template flat<T>().data();

      for (int64_t i = start; i < limit; ++i) {
        const int64_t b = i / args.out_rows;
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        const int64_t out_r = i % args.out_rows;

        for (int64_t out_c = 0; out_c < args.out_cols; ++out_c) {
          // Populate 'input_buffer_data' with data from local input region.
          functor::DepthwiseInputCopyOp<T>()(args, padded_filter_inner_dim_size,
                                             out_r, out_c, input + in_base,
                                             input_buffer_data);

          // Process buffered input across all filters and store to output.
          DepthwiseConv2DKernel<T>::Run(
              args, padded_filter_inner_dim_size, out_r, out_c, filter_data,
              input_buffer_data, output + out_base, data_format);
        }
      }
    };

    const int64_t total_shards = args.batch * args.out_rows;
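    // Each work item i in [0, total_shards) corresponds to one output row of
    // one batch image: b = i / args.out_rows, out_r = i % args.out_rows (see
    // the shard lambda above). For example, with args.batch == 2 and
    // args.out_rows == 3 there are 6 work items covering the (b, out_r) pairs
    // (0, 0) through (1, 2).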

    // Empirically tested to give reasonable performance boosts at batch size 1
    // without reducing throughput at batch size 32.
    const float kCostMultiplier = 2.5f;

    // TODO(andydavis): Estimate shard cost (in cycles) based on the number of
    // flops/loads/stores required to compute one shard.
    const int64_t shard_cost = kCostMultiplier * args.out_cols * args.out_depth;

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<CPUDevice, bfloat16>;
extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
extern template struct LaunchConv2DOp<CPUDevice, double>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<GPUDevice, float>;
extern template struct LaunchConv2DOp<GPUDevice, double>;

// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 public:
  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    stride_ = GetTensorDim(strides_, data_format_, 'H');
    const int64_t stride_w = GetTensorDim(strides_, data_format_, 'W');
    const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C');

    OP_REQUIRES(context, stride_ == stride_w,
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    // CPU/GPU kernel currently ignores dilations, so all must be 1.
    std::vector<int32_t> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    bool unit_dilations = true;
    for (int32_t dilation : dilations) {
      if (dilation != 1) {
        unit_dilations = false;
      }
    }
    OP_REQUIRES(context, unit_dilations,
                errors::Unimplemented(
                    "Current kernel implementation does not support "
                    "dilations, received [",
                    Eigen::Map<Eigen::Matrix<int32_t, 1, Eigen::Dynamic>>(
                        dilations.data(), dilations.size()),
                    "]"));

    cudnn_use_autotune_ = CudnnUseAutotune();
    dtype_ = DataTypeToEnum<T>::value;
#if CUDNN_VERSION >= 8000
    // From the cuDNN release note 8.0: We’ve extended the fprop and dgrad
    // NHWC depthwise kernels to support more combinations (filter
    // sizes/strides) such as 5x5/1x1, 5x5/2x2, 7x7/1x1, 7x7/2x2 (in addition
    // to what we already have, 1x1/1x1, 3x3/1x1, 3x3/2x2), which provides
    // good performance. (https://docs.nvidia.com/deeplearning/sdk/cudnn-
    // release-notes/rel_8.html#rel_8)
    use_cudnn_grouped_conv_ =
        dtype_ == DT_HALF &&
        (data_format_ == FORMAT_NCHW ||
         (data_format_ == FORMAT_NHWC && stride_ == stride_w &&
          (stride_ == 1 || stride_ == 2)));
#elif CUDNN_VERSION >= 7603
    // Use CuDNN grouped conv only when input/output is NCHW and float16(half).
    // See cudnn release note 7.6.3. (https://docs.nvidia.com/deeplearning/sdk/c
    // udnn-release-notes/rel_763.html#rel_763)
    use_cudnn_grouped_conv_ = dtype_ == DT_HALF && data_format_ == FORMAT_NCHW;
#else
    use_cudnn_grouped_conv_ = false;
#endif
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, depth_multiplier]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // in_depth for input and filter must match.
    const int64_t in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is depth multiplier.
    const int32_t depth_multiplier = filter.dim_size(3);

    // The output depth is input depth x depth multiplier.
    const int32_t out_depth = in_depth * depth_multiplier;

    const int64_t input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int32_t input_rows = static_cast<int32>(input_rows_raw);
    const int32_t filter_rows = filter.dim_size(0);

    const int64_t input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int32_t input_cols = static_cast<int32>(input_cols_raw);
    const int32_t filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int32_t batch = input.dim_size(0);

    int64_t out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0,
            pad_left = 0, pad_right = 0;
    if (padding_ == Padding::EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top,
                               &pad_bottom);
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left,
                               &pad_right);
    }
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_rows, filter_rows, stride_, padding_,
                                &out_rows, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_cols, filter_cols, stride_, padding_,
                                &out_cols, &pad_left, &pad_right));
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(
        context,
        (!std::is_same<Device, GPUDevice>::value ||
         FastBoundsCheck(out_shape.num_elements(),
                         std::numeric_limits<int32>::max())),
        errors::InvalidArgument("Output elements too large for GPU kernel"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    // TODO(csigg): Have autotune decide if native is faster than cuDNN.
    // If in_depth==1, this operation is just a standard convolution.
    // Depthwise convolution is a special case of cuDNN's grouped convolution.
    bool use_cudnn =
        std::is_same<Device, GPUDevice>::value &&
        (in_depth == 1 || (use_cudnn_grouped_conv_ &&
                           ShouldCudnnGroupedConvolutionBeUsed(
                               filter_rows, filter_cols, in_depth, out_depth)));

    VLOG(2) << "DepthwiseConv2dNative: "
            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
            << filter_cols << ", " << in_depth << ", " << depth_multiplier
            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
            << ", " << out_depth << "], stride = " << stride_
            << ", pad_top = " << pad_top << ", pad_left = " << pad_left
            << ", Use cuDNN: " << use_cudnn;

    if (use_cudnn) {
      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
      //
      //                  | TensorFlow       | cuDNN
      // --------------------------------------------------------------------
      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
      // filter_in_depth  | in_depth         | in_depth / group_count
      //
      // For depthwise convolution, we have group_count == in_depth.
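      //
      // Worked example (hypothetical sizes): a TF depthwise filter of shape
      // [3, 3, in_depth = 8, depth_multiplier = 2] is reinterpreted as a
      // cuDNN grouped-convolution filter of shape [3, 3, 1, 16], i.e.
      // filter_in_depth == 1 and filter_out_depth == out_depth == 8 * 2,
      // with group_count == in_depth == 8.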
      int32_t filter_in_depth = 1;
      TensorShape shape =
          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
      Tensor reshaped_filter(/*type=*/dtype_);
      OP_REQUIRES(
          context, reshaped_filter.CopyFrom(filter, shape),
          errors::Internal(
              "Failed to reshape filter tensor for grouped convolution."));
      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
      // conv is supported.
      launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, input,
                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
                stride_, stride_, padding_, explicit_paddings_, output,
                data_format_);
      return;
    }

    DepthwiseArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.depth_multiplier = depth_multiplier;
    args.stride = stride_;
    args.pad_rows = pad_top;
    args.pad_cols = pad_left;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<T>().data();
    auto filter_ptr = filter.template flat<T>().data();
    auto output_ptr = output->template flat<T>().data();
    LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
                                       output_ptr, data_format_);
  }

 protected:
  bool use_cudnn_grouped_conv_;

 private:
  std::vector<int32_t> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;

  int64_t stride_;  // in height/width dimension.

  // For in_depth == 1 and grouped convolutions.
  LaunchConv2DOp<Device, T> launcher_;
  bool cudnn_use_autotune_;
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<CPUDevice, T>)

TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<GPUDevice, T>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);

#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvOp
    : public DepthwiseConv2dNativeOp<GPUDevice, T> {
 public:
  DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
      : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
    this->use_cudnn_grouped_conv_ = true;
  }
};

#define REGISTER_GROUPED_CONV_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")            \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cudnn_grouped_convolution"), \
                          DepthwiseConv2dGroupedConvOp<T>)

TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#endif  // CUDNN_VERSION
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow