/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/maxpooling_op.h"

#include <type_traits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

const int kInvalidMaxPoolingIndex = -1;
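
// Computes the max-pooled output and the flattened argmax for tensor_in.
// When input_backprop is non-null, this helper also scatters out_backprop
// through the argmax to produce the MaxPoolGrad result; that mode requires
// include_batch_in_index to be true and Targmax to be int64, as enforced
// below.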
template <typename Device, typename T, typename Targmax>
static void SpatialMaxPoolWithArgMaxHelper(
    OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
    Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
    const PoolParameters& params, const bool include_batch_in_index) {
  if (input_backprop != nullptr) {
    OP_REQUIRES(
        context, include_batch_in_index,
        errors::Internal(
            "SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index "
            "to be True when input_backprop != nullptr"));
    OP_REQUIRES(
        context, (std::is_same<Targmax, int64_t>::value),
        errors::Internal("SpatialMaxPoolWithArgMaxHelper requires Targmax "
                         "to be int64 when input_backprop != nullptr"));
  }
  if (tensor_in.NumElements() == 0 || output->NumElements() == 0) return;

  typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      ConstEigenMatrixMap;
  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      EigenMatrixMap;
  typedef Eigen::Map<Eigen::Matrix<Targmax, Eigen::Dynamic, Eigen::Dynamic>>
      EigenIndexMatrixMap;

  ConstEigenMatrixMap in_mat(
      tensor_in.flat<T>().data(), params.depth,
      params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
  EigenMatrixMap out_mat(
      output->flat<T>().data(), params.depth,
      params.out_width * params.out_height * params.tensor_in_batch);
  EigenIndexMatrixMap out_arg_max_mat(
      output_arg_max->flat<Targmax>().data(), params.depth,
      params.out_width * params.out_height * params.tensor_in_batch);

  const DeviceBase::CpuWorkerThreads& worker_threads =
      *(context->device()->tensorflow_cpu_worker_threads());

  // At a high level, this code does two things:
  // 1. Flattens the input and output tensors into two dimensional arrays.
  //    tensor_in_as_matrix:
  //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
  //    output_as_matrix:
  //      depth by (out_width * out_height * tensor_in_batch)
  //
  // 2. Walks through the set of columns in the flattened tensor_in_as_matrix,
  //    and updates the corresponding column(s) in output_as_matrix with the
  //    max value.
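  //
  // The argmax written for each output element is a flat index into the
  // input: ((b * in_rows + h) * in_cols + w) * depth + d when
  // include_batch_in_index is true, and (h * in_cols + w) * depth + d
  // otherwise, matching the assignments in the inner loop below.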
  auto shard = [&params, &in_mat, &out_mat, &out_arg_max_mat, &input_backprop,
                &output_arg_max, &out_backprop,
                include_batch_in_index](int64_t start, int64_t limit) {
    const int32_t depth = params.depth;
    const int32_t in_rows = params.tensor_in_rows;
    const int32_t in_cols = params.tensor_in_cols;
    const int32_t pad_top = params.pad_top;
    const int32_t pad_left = params.pad_left;
    const int32_t window_rows = params.window_rows;
    const int32_t window_cols = params.window_cols;
    const int32_t row_stride = params.row_stride;
    const int32_t col_stride = params.col_stride;
    const int32_t out_height = params.out_height;
    const int32_t out_width = params.out_width;

    {
      // Initializes the output tensor with MIN<T>.
      const int32_t output_image_size = out_height * out_width * depth;
      EigenMatrixMap out_shard(out_mat.data() + start * output_image_size, 1,
                               (limit - start) * output_image_size);
      out_shard.setConstant(Eigen::NumTraits<T>::lowest());
      EigenIndexMatrixMap out_arg_max_shard(
          out_arg_max_mat.data() + start * output_image_size, 1,
          (limit - start) * output_image_size);
      out_arg_max_shard.setConstant(kInvalidMaxPoolingIndex);
    }

    for (int64_t b = start; b < limit; ++b) {
      for (int h = 0; h < in_rows; ++h) {
        for (int w = 0; w < in_cols; ++w) {
          // (h_start, h_end) * (w_start, w_end) is the range that the input
          // vector projects to.
          const int hpad = h + pad_top;
          const int wpad = w + pad_left;
          const int h_start =
              (hpad < window_rows) ? 0 : (hpad - window_rows) / row_stride + 1;
          const int h_end = std::min(hpad / row_stride + 1, out_height);
          const int w_start =
              (wpad < window_cols) ? 0 : (wpad - window_cols) / col_stride + 1;
          const int w_end = std::min(wpad / col_stride + 1, out_width);
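          // For example, with window_rows = 3, row_stride = 2 and pad_top = 0,
          // input row h = 4 gives hpad = 4, h_start = 1 and h_end = 3 (when
          // out_height >= 3), i.e. this row feeds output rows ph = 1 and 2.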
          // compute elementwise max
          const int64_t in_index = (b * in_rows + h) * in_cols + w;
          for (int ph = h_start; ph < h_end; ++ph) {
            const int64_t out_index_base = (b * out_height + ph) * out_width;
            for (int pw = w_start; pw < w_end; ++pw) {
              const int64_t out_index = out_index_base + pw;
              // NOTE(zhengxq): not using the Eigen matrix operations for now.
              for (int d = 0; d < depth; ++d) {
                const T& input_ref = in_mat.coeffRef(d, in_index);
                T& output_ref = out_mat.coeffRef(d, out_index);
                Targmax& out_arg_max_ref =
                    out_arg_max_mat.coeffRef(d, out_index);
                if (output_ref < input_ref ||
                    out_arg_max_ref == kInvalidMaxPoolingIndex) {
                  output_ref = input_ref;
                  if (include_batch_in_index) {
                    out_arg_max_ref = in_index * depth + d;
                  } else {
                    out_arg_max_ref = (h * in_cols + w) * depth + d;
                  }
                }
              }
            }
          }
        }
      }
    }

    if (input_backprop != nullptr) {
      auto input_backprop_flat = input_backprop->flat<T>();
      auto out_arg_max_flat = output_arg_max->flat<int64_t>();
      auto out_backprop_flat = out_backprop.flat<T>();

      // Initialize output to 0.
      const int64_t in_size = in_rows * in_cols * depth;
      const int64_t in_start = start * in_size;
      const int64_t in_end = limit * in_size;
      EigenMatrixMap in_shard(input_backprop_flat.data() + in_start, 1,
                              in_end - in_start);
      in_shard.setConstant(T(0));

      // Backpropagate.
      const int out_size = out_height * out_width * depth;
      const int out_start = start * out_size;
      const int out_end = limit * out_size;
      for (int index = out_start; index < out_end; ++index) {
        int input_backprop_index = out_arg_max_flat(index);
        // Although this check runs in the inner loop, it is worth the cost:
        // it keeps us from corrupting memory, and our benchmarks show the
        // performance impact is quite small.
        // CHECK(input_backprop_index >= in_start && input_backprop_index <
        // in_end)
        FastBoundsCheck(input_backprop_index - in_start, in_end - in_start);
        if (index < out_backprop.NumElements()) {
          input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
        }
      }
    }
  };

  const int64_t shard_cost = params.tensor_in_rows * params.tensor_in_cols *
                             params.depth * params.window_rows *
                             params.window_cols;
  Shard(worker_threads.num_threads, worker_threads.workers,
        params.tensor_in_batch, shard_cost, shard);
}

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
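//
// Shapes (NHWC): the produced input backprop has the shape of the original
// input; both the original output and the output backprop must match the
// forward output shape computed from ksize/strides/padding.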
template <class Device, class T>
class MaxPoolingGradOp : public OpKernel {
 public:
  explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument("Default MaxPoolingGradOp only supports NHWC ",
                                "on device type ",
                                DeviceTypeString(context->device_type())));

    if (context->num_inputs() == 3) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window strides field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
      OP_REQUIRES(
          context, ksize_[3] == 1 && stride_[3] == 1,
          errors::Unimplemented(
              "MaxPoolingGrad is not yet supported on the depth dimension."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));

    if (padding_ == Padding::EXPLICIT) {
      OP_REQUIRES_OK(
          context, context->GetAttr("explicit_paddings", &explicit_paddings_));
      OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                                /*num_dims=*/4, data_format_));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_backprop = context->input(2);

    // For maxpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 4,
                errors::InvalidArgument("tensor_out must be 4-dimensional"));
    // For maxpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));

    const TensorShape& output_shape = tensor_in.shape();

    Tensor tensor_out_dup;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_temp(
                                {1}, DataTypeToEnum<T>::v(), tensor_out.shape(),
                                &tensor_out_dup));
    Tensor tensor_out_arg_max;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64_t>::v(),
                                                   tensor_out.shape(),
                                                   &tensor_out_arg_max));
    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;
    if (context->num_inputs() == 5) {
      const Tensor& tensor_ksize = context->input(3);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(4);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }

    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(
        context, ksize[3] == 1 && stride[3] == 1,
        errors::Unimplemented(
            "MaxPoolingGrad is not yet supported on the depth dimension."));

    PoolParameters params{context,
                          ksize,
                          stride,
                          padding_,
                          explicit_paddings_,
                          FORMAT_NHWC,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected orig_output shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", tensor_out.shape()));
    OP_REQUIRES(context, out_backprop.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected grad shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", out_backprop.shape()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, output_shape, &output));

    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, int64_t>(
        context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
        out_backprop, params, true);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <class T>
class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
 public:
  typedef Eigen::GpuDevice Device;

  explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->num_inputs() == 3) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window strides field must "
                                          "specify 4 dimensions"));
      const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
      const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    if (padding_ == Padding::EXPLICIT) {
      OP_REQUIRES_OK(
          context, context->GetAttr("explicit_paddings", &explicit_paddings_));
      OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                                /*num_dims=*/4, data_format_));
    }
    TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                   &propagate_nans_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_backprop = context->input(2);

    // For maxpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 4,
                errors::InvalidArgument("tensor_out must be 4-dimensional"));
    // For maxpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));

    TensorShape output_shape = tensor_in.shape();

    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;
    if (context->num_inputs() == 5) {
      const Tensor& tensor_ksize = context->input(3);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(4);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }
    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int32_t ksize_n = GetTensorDim(ksize, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    int64_t pad_top, pad_bottom, pad_left, pad_right;
    if (padding_ == Padding::EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H',
                               /*pad_top=*/&pad_top,
                               /*pad_bottom=*/&pad_bottom);
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W',
                               /*pad_left=*/&pad_left,
                               /*pad_right=*/&pad_right);
    }
    DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
                                 stride, padding_, explicit_paddings_,
                                 data_format_, &tensor_in, &tensor_out,
                                 out_backprop, output_shape, propagate_nans_);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;
  bool propagate_nans_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The operation to compute gradient of MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output gradients
// It produces one output: a backprop tensor shaped like the original output.
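//
// Shapes (NHWC): the incoming backprop must match the original input shape,
// and the produced tensor matches the forward output shape.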
template <class Device, class T>
class MaxPoolingGradGradOp : public OpKernel {
 public:
  explicit MaxPoolingGradGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument(
            "Default MaxPoolingGradGradOp only supports NHWC ",
            "on device type ", DeviceTypeString(context->device_type())));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));

    if (context->num_inputs() == 3) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window strides field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
      OP_REQUIRES(context, ksize_[3] == 1 && stride_[3] == 1,
                  errors::Unimplemented("MaxPoolingGradGrad is not yet "
                                        "supported on the depth dimension."));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_grad_backprop = context->input(2);

    // For maxpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 4,
                errors::InvalidArgument("tensor_out must be 4-dimensional"));
    // For maxpooling, out_grad_backprop should have 4 dimensions.
    OP_REQUIRES(
        context, out_grad_backprop.dims() == 4,
        errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));

    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;
    if (context->num_inputs() == 5) {
      const Tensor& tensor_ksize = context->input(3);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(4);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }

    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(
        context, ksize[3] == 1 && stride[3] == 1,
        errors::Unimplemented(
            "MaxPoolingGrad is not yet supported on the depth dimension."));

    PoolParameters params{context,
                          ksize,
                          stride,
                          padding_,
                          /*explicit_paddings=*/{},
                          FORMAT_NHWC,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected orig_output shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", tensor_out.shape()));
    OP_REQUIRES(
        context, out_grad_backprop.shape() == tensor_in.shape(),
        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
                                ", but got ", out_grad_backprop.shape()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 0, tensor_out.shape(), &output));

    SpatialMaxPoolGradGrad(context, output, tensor_in, tensor_out,
                           out_grad_backprop, params, padding_);
  }

 private:
  void SpatialMaxPoolGradGrad(OpKernelContext* context, Tensor* bottom_diff,
                              const Tensor& tensor_in, const Tensor& tensor_out,
                              const Tensor& top_diff,
                              const PoolParameters& params,
                              const Padding& padding) {
    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;

    ConstEigenMatrixMap in_mat(
        tensor_in.flat<T>().data(), params.depth,
        params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
    ConstEigenMatrixMap out_mat(
        tensor_out.flat<T>().data(), params.depth,
        params.out_width * params.out_height * params.tensor_in_batch);
    ConstEigenMatrixMap top_diff_mat(
        top_diff.flat<T>().data(), params.depth,
        params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
    EigenMatrixMap bottom_diff_mat(
        bottom_diff->flat<T>().data(), params.depth,
        params.out_width * params.out_height * params.tensor_in_batch);

    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());

    // At a high level, this code does two things:
    // 1. Flattens the input, output, top_diff and bottom_diff tensors into
    //    two dimensional arrays.
    //    tensor_in_as_matrix:
    //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
    //    tensor_out_as_matrix:
    //      depth by (out_width * out_height * tensor_in_batch)
    //    top_diff_as_matrix:
    //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
    //    bottom_diff_as_matrix:
    //      depth by (out_width * out_height * tensor_in_batch)
    //
    // 2. Walks through the set of columns in the flattened
    //    tensor_in_as_matrix, tensor_out_as_matrix, top_diff_as_matrix
    //    and updates the column(s) corresponding to the maximum values in
    //    tensor_out_as_matrix with the corresponding values in
    //    top_diff_as_matrix.
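    //
    // In other words, for each output position (b, ph, pw, d) the loop below
    // scans the pooling window, finds the first input element equal to the
    // pooled maximum, and copies the matching top_diff entry into bottom_diff;
    // ties are resolved in favor of the first match in row-major order.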
    auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
                     int64_t start, int64_t limit) {
      const int32_t depth = params.depth;
      const int32_t in_rows = params.tensor_in_rows;
      const int32_t in_cols = params.tensor_in_cols;
      const int32_t pad_top = params.pad_top;
      const int32_t pad_left = params.pad_left;
      const int32_t window_rows = params.window_rows;
      const int32_t window_cols = params.window_cols;
      const int32_t row_stride = params.row_stride;
      const int32_t col_stride = params.col_stride;
      const int32_t out_height = params.out_height;
      const int32_t out_width = params.out_width;

      {
        // Initializes the output grad backprop tensor with 0.
        const int32_t output_image_size = out_height * out_width * params.depth;
        EigenMatrixMap bottom_diff_shard(
            bottom_diff_mat.data() + start * output_image_size, 1,
            (limit - start) * output_image_size);
        bottom_diff_shard.setZero();
      }

      for (int b = start; b < limit; ++b) {
        for (int ph = 0; ph < out_height; ++ph) {
          for (int pw = 0; pw < out_width; ++pw) {
            // (h_start, h_end) * (w_start, w_end) is the range that the input
            // vector projects to.
            int h_start = ph * row_stride - pad_top;
            const int h_end = std::min(h_start + window_rows, in_rows);
            int w_start = pw * col_stride - pad_left;
            const int w_end = std::min(w_start + window_cols, in_cols);
            h_start = std::max(h_start, 0);
            w_start = std::max(w_start, 0);
            const int out_index = (b * out_height + ph) * out_width + pw;
            // Find value corresponding to the input maximum in top_diff.
            for (int d = 0; d < depth; ++d) {
              const T& output_ref = out_mat.coeffRef(d, out_index);
              bool should_stop = false;
              for (int h = h_start; h < h_end && !should_stop; ++h) {
                for (int w = w_start; w < w_end && !should_stop; ++w) {
                  const int in_index = (b * in_rows + h) * in_cols + w;
                  const T& input_ref = in_mat.coeffRef(d, in_index);
                  if (output_ref == input_ref) {
                    T& bottom_diff_ref = bottom_diff_mat.coeffRef(d, out_index);
                    bottom_diff_ref = top_diff_mat.coeffRef(d, in_index);
                    should_stop = true;
                  }
                }
              }
            }
          }
        }
      }
    };

    const int64_t shard_cost = params.out_width * params.out_height *
                               params.depth * params.window_rows *
                               params.window_cols;
    Shard(worker_threads.num_threads, worker_threads.workers,
          params.tensor_in_batch, shard_cost, shard);
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <class T>
class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
 public:
  typedef Eigen::GpuDevice Device;

  explicit MaxPoolingGradGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->num_inputs() == 3) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window strides field must "
                                          "specify 4 dimensions"));
      const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
      const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_grad_backprop = context->input(2);

    // For maxpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 4,
                errors::InvalidArgument("tensor_out must be 4-dimensional"));
    // For maxpooling, out_grad_backprop should have 4 dimensions.
    OP_REQUIRES(
        context, out_grad_backprop.dims() == 4,
        errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, tensor_out.shape(), &output));

    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;
    if (context->num_inputs() == 5) {
      const Tensor& tensor_ksize = context->input(3);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(4);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }

    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int32_t ksize_n = GetTensorDim(ksize, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    PoolParameters params{context,
                          ksize,
                          stride,
                          padding_,
                          /*explicit_paddings=*/{},
                          data_format_,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected orig_output shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", tensor_out.shape()));
    OP_REQUIRES(
        context, out_grad_backprop.shape() == tensor_in.shape(),
        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
                                ", but got ", out_grad_backprop.shape()));

    functor::MaxPoolGradBackwardNoMask<T>()(
        data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
        params.tensor_in_batch, params.out_height, params.out_width,
        params.depth, params.tensor_in_rows, params.tensor_in_cols,
        params.window_rows, params.window_cols, params.row_stride,
        params.col_stride, params.pad_top, params.pad_left,
        out_grad_backprop.flat<T>().data(), output->flat<T>().data(),
        context->eigen_device<Eigen::GpuDevice>());
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool use_dnn_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
struct LaunchMaxPoolingNoMask;

template <typename Device, typename T>
class MaxPoolingNoMaskOp : public OpKernel {
 public:
  explicit MaxPoolingNoMaskOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument(
            "Default MaxPoolingNoMaskOp only supports NHWC on device type ",
            DeviceTypeString(context->device_type())));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(
        context, padding_ != EXPLICIT,
        errors::Unimplemented(
            "Explicit padding is not supported for MaxPoolingNoMaskOp."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    PoolParameters params{context,
                          ksize_,
                          stride_,
                          padding_,
                          /*explicit_paddings=*/{},
                          data_format_,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    TensorShape out_shape({params.tensor_in_batch, params.out_height,
                           params.out_width, params.depth});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
                                              output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename Device, typename T>
class MaxPoolingNoMaskV2Op : public OpKernel {
 public:
  explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument(
            "Default MaxPoolingNoMaskOp only supports NHWC on device type ",
            DeviceTypeString(context->device_type())));
    if (context->num_inputs() == 1) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;

    if (context->num_inputs() != 1) {
      const Tensor& tensor_ksize = context->input(1);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(2);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }
    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    PoolParameters params{context,
                          ksize,
                          stride,
                          padding_,
                          /*explicit_paddings=*/{},
                          data_format_,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    TensorShape out_shape({params.tensor_in_batch, params.out_height,
                           params.out_width, params.depth});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
                                              output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename Device, typename T, typename Targmax>
struct LaunchMaxPoolingWithArgmax;

template <typename T, typename Targmax>
struct LaunchMaxPoolingWithArgmax<CPUDevice, T, Targmax> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& input, Tensor* output, Tensor* argmax,
                     bool propagate_nans, bool include_batch_in_index) {
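    // With input_backprop == nullptr the helper runs in forward-only mode and
    // never reads its out_backprop argument, so an empty Tensor can be passed.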
    Tensor unused;
    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, Targmax>(
        context, output, argmax, /*input_backprop=*/nullptr, input, unused,
        params, include_batch_in_index);
  }
};

template <typename Device, typename T, typename Targmax>
class MaxPoolingWithArgmaxOp : public OpKernel {
 public:
  explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES_OK(context, context->GetAttr("include_batch_in_index",
                                             &include_batch_in_index_));
    TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                   &propagate_nans_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));
    OP_REQUIRES(context, tensor_in.NumElements() > 0,
                errors::InvalidArgument("tensor_in must not be empty"));

    PoolParameters params{context,
                          ksize_,
                          stride_,
                          padding_,
                          /*explicit_paddings=*/{},
                          FORMAT_NHWC,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    TensorShape out_shape({params.tensor_in_batch, params.out_height,
                           params.out_width, params.depth});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    Tensor* argmax = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));

    LaunchMaxPoolingWithArgmax<Device, T, Targmax>::launch(
        context, params, tensor_in, output, argmax, propagate_nans_,
        include_batch_in_index_);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  bool propagate_nans_;
  bool include_batch_in_index_;
};

template <typename Device, typename T>
struct LaunchMaxPoolingGradWithArgmax;

template <typename T>
struct LaunchMaxPoolingGradWithArgmax<CPUDevice, T> {
  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      EigenMatrixMap;

  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& grad_in, const Tensor& argmax,
                     Tensor* grad_out, const bool include_batch_in_index) {
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());

    auto shard = [&grad_in, &argmax, &grad_out, include_batch_in_index](
                     int64_t start, int64_t limit) {
      const int64_t batch_size =
          GetTensorDim(grad_out->shape(), FORMAT_NHWC, 'N');
      const int64_t output_size_per_batch =
          grad_out->NumElements() / batch_size;
      const int64_t input_size_per_batch = grad_in.NumElements() / batch_size;

      {
        auto grad_out_flat = grad_out->flat<T>();
        auto argmax_flat = argmax.flat<int64_t>();
        auto grad_in_flat = grad_in.flat<T>();

        const int64_t output_start = start * output_size_per_batch;
        const int64_t output_end = limit * output_size_per_batch;
        EigenMatrixMap inputShard(grad_out_flat.data() + output_start, 1,
                                  output_end - output_start);
        inputShard.setConstant(T(0));

        const int input_start = start * input_size_per_batch;
        const int input_end = limit * input_size_per_batch;
        for (int64_t index = input_start; index < input_end; index++) {
          if (index >= argmax.NumElements()) {
            break;
          }
          int64_t grad_out_index = argmax_flat(index);
          if (!include_batch_in_index) {
            const int64_t cur_batch = index / input_size_per_batch;
            grad_out_index += cur_batch * output_size_per_batch;
          }
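          // e.g. with output_size_per_batch = 1000 and a per-image argmax of
          // 37 in batch 2, the adjusted flat index is 2 * 1000 + 37 = 2037.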
          CHECK(grad_out_index >= output_start && grad_out_index < output_end)
              << "Invalid output gradient index: " << grad_out_index << ", "
              << output_start << ", " << output_end;
          grad_out_flat(grad_out_index) += grad_in_flat(index);
        }
      }
    };

    const int64_t batch_size =
        GetTensorDim(grad_out->shape(), FORMAT_NHWC, 'N');
    const int64_t shard_cost = grad_out->NumElements() / batch_size;
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          shard_cost, shard);
  }
};

// TODO(b/175733711): Support int32 argmax type in MaxPoolGradWithArgmax op.
template <typename Device, typename T>
class MaxPoolingGradWithArgmaxOp : public OpKernel {
 public:
  explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str;
    if (std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(context, !tensorflow::OpDeterminismRequired(),
                  errors::Unimplemented("Determinism is not yet supported "
                                        "for MaxPoolGradWithArgmax."));
    }
    auto status = context->GetAttr("data_format", &data_format_str);
    if (status.ok()) {
      OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }

    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES_OK(context, context->GetAttr("include_batch_in_index",
                                             &include_batch_in_index_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& grad_in = context->input(1);
    const Tensor& argmax = context->input(2);

    PoolParameters params{context,
                          ksize_,
                          stride_,
                          padding_,
                          /*explicit_paddings=*/{},
                          FORMAT_NHWC,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, grad_in.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected grad shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", grad_in.shape()));
    OP_REQUIRES(context, argmax.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected argmax shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", argmax.shape()));

    TensorShape out_shape({params.tensor_in_batch, params.tensor_in_rows,
                           params.tensor_in_cols, params.depth});
    Tensor* grad_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, out_shape, &grad_out));

    if (out_shape.num_elements() == 0) return;  // nothing to be done

    LaunchMaxPoolingGradWithArgmax<Device, T>::launch(
        context, params, grad_in, argmax, grad_out, include_batch_in_index_);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool include_batch_in_index_;
};

template <typename Device, typename T>
struct LaunchMaxPoolingGradGradWithArgmax;

template <typename Device, typename T>
class MaxPoolingGradGradWithArgmaxOp : public OpKernel {
 public:
  explicit MaxPoolingGradGradWithArgmaxOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES_OK(context, context->GetAttr("include_batch_in_index",
                                             &include_batch_in_index_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& grad_in = context->input(1);
    const Tensor& argmax = context->input(2);

    PoolParameters params{context,
                          ksize_,
                          stride_,
                          padding_,
                          /*explicit_paddings=*/{},
                          FORMAT_NHWC,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(
        context, grad_in.shape() == tensor_in.shape(),
        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
                                ", but got ", grad_in.shape()));
    OP_REQUIRES(context, argmax.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected argmax shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", argmax.shape()));

    TensorShape out_shape({params.tensor_in_batch, params.out_height,
                           params.out_width, params.depth});

    Tensor* grad_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, out_shape, &grad_out));

    LaunchMaxPoolingGradGradWithArgmax<Device, T>::launch(
        context, params, grad_in, argmax, grad_out, include_batch_in_index_);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  bool include_batch_in_index_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit MaxPoolingNoMaskOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                   &propagate_nans_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    PoolParameters params{
        context,      ksize_, stride_, padding_, explicit_paddings_,
        data_format_, tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    TensorShape out_shape =
        ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
                        params.out_width, params.depth);

    // Degenerate pooling output should return an empty tensor.
    if (out_shape.num_elements() == 0) {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
      return;
    }

    // Assuming qint8 <--> NCHW_VECT_C (int8x4) here.
    constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
    OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
                errors::InvalidArgument(
                    "qint8 should be used with data_format NCHW_VECT_C."));

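    // When building against cuDNN 7.3 or newer, all supported formats are
    // routed through DnnPoolingOp below; the per-format dispatch only applies
    // to older builds.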
#if CUDNN_VERSION >= 7300
    DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
                             stride_, padding_, explicit_paddings_,
                             data_format_, tensor_in, out_shape,
                             propagate_nans_);
#else
    // These is_int8x4 checks avoid linker errors for missing qint8 kernels.
    if (!is_int8x4 && data_format_ == FORMAT_NCHW) {
      DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
                               stride_, padding_, explicit_paddings_,
                               data_format_, tensor_in, out_shape,
                               propagate_nans_);
    } else {
#if !defined(TENSORFLOW_USE_ROCM)
      OP_REQUIRES(context, padding_ != EXPLICIT,
                  errors::Unimplemented("Explicit padding is not supported ",
                                        "when CUDNN is not enabled."));
#endif
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
      if (is_int8x4) {
        LaunchMaxPoolingNoMask_NCHW_VECT_C<Device>::launch(context, params,
                                                           tensor_in, output);
      } else if (data_format_ == FORMAT_NHWC) {
        LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
                                                  output, propagate_nans_);
      } else {
        LOG(FATAL) << "MaxPool currently only supports the following (layout, "
                      "type) combinations: (NHWC, non-qint8), "
                      "(NCHW, non-qint8) or (NCHW_VECT_C, qint8). The "
                      "requested combination ("
                   << ToString(data_format_) << ", "
                   << DataTypeString(DataTypeToEnum<T>::v())
                   << ") is not supported.";
      }
    }
#endif
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;
  bool propagate_nans_;
};

template <typename T>
class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->num_inputs() == 1) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify 4 dimensions"));
      const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
      const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                   &propagate_nans_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;

    if (context->num_inputs() != 1) {
      const Tensor& tensor_ksize = context->input(1);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(2);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }
    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    const int32_t ksize_n = GetTensorDim(ksize, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    PoolParameters params{context,
                          ksize,
                          stride,
                          padding_,
                          /*explicit_paddings=*/{},
                          data_format_,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    TensorShape out_shape =
        ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
                        params.out_width, params.depth);
    if (data_format_ == FORMAT_NCHW) {
      DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
                               stride, padding_, explicit_paddings_,
                               data_format_, tensor_in, out_shape,
                               propagate_nans_);
    } else {
      CHECK(data_format_ == FORMAT_NHWC)
          << "MaxPool only supports NCHW or NHWC format";
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
      LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
                                                output, propagate_nans_);
    }
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;
  bool propagate_nans_;
};

template <typename T>
struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& input, Tensor* output, bool propagate_nans) {
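    // A null argmax pointer asks MaxPoolForwardWithOptionalArgmax to skip
    // producing a mask; plain MaxPool only needs the pooled values.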
1433 bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
1434 input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
1435 params.tensor_in_cols, params.depth, params.out_height,
1436 params.out_width, params.window_rows, params.window_cols,
1437 params.row_stride, params.col_stride, params.pad_top, params.pad_left,
1438 output->flat<T>().data(), nullptr, context->eigen_gpu_device(),
1439 propagate_nans, false);
1440 if (!status) {
1441 context->SetStatus(
1442 errors::Internal("Failed launching MaxPoolForwardNoMask"));
1443 }
1444 }
1445};
1446
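// GPU launcher for the forward pass that also records, for each output
// element, the flattened index of the input element that produced the
// maximum (the argmax tensor consumed by the gradient kernels below).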
template <typename T>
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T, int64_t> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& input, Tensor* output, Tensor* argmax,
                     bool propagate_nans, bool include_batch_in_index) {
    bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
        input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
        params.tensor_in_cols, params.depth, params.out_height,
        params.out_width, params.window_rows, params.window_cols,
        params.row_stride, params.col_stride, params.pad_top, params.pad_left,
        output->flat<T>().data(),
        reinterpret_cast<int64_t*>(argmax->flat<int64_t>().data()),
        context->eigen_gpu_device(), propagate_nans, include_batch_in_index);
    if (!status) {
      context->SetStatus(
          errors::Internal("Failed launching MaxPoolForwardWithArgmax"));
    }
  }
};

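// GPU launcher for MaxPoolGradWithArgmax: scatters each incoming gradient
// value to the input position recorded in the argmax tensor. top_offset and
// bottom_offset are the per-batch element counts of the output and input
// tensors, used to rebase the indices when they do not include the batch
// dimension.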
template <typename T>
struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& grad_in, const Tensor& argmax,
                     Tensor* grad_out, const bool include_batch_in_index) {
    const int input_size = params.tensor_in_batch * params.tensor_in_rows *
                           params.tensor_in_cols * params.depth;
    const int output_size = params.tensor_in_batch * params.out_height *
                            params.out_width * params.depth;
    const int top_offset = params.out_height * params.out_width * params.depth;
    const int bottom_offset =
        params.tensor_in_rows * params.tensor_in_cols * params.depth;
    bool status = functor::MaxPoolBackwardWithArgmax<T>()(
        output_size, input_size, grad_in.flat<T>().data(),
        reinterpret_cast<const int64_t*>(argmax.flat<int64_t>().data()),
        top_offset, bottom_offset, grad_out->flat<T>().data(),
        context->eigen_gpu_device(), include_batch_in_index);
    if (!status) {
      context->SetStatus(
          errors::Internal("Failed launching MaxPoolBackwardWithArgmax"));
    }
  }
};

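// GPU launcher for MaxPoolGradGradWithArgmax (the gradient of MaxPoolGrad):
// gathers grad_in at the argmax positions into grad_out. Note that the
// top/bottom offsets are swapped relative to the first-order case because
// here grad_in has the input shape and grad_out has the output shape.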
template <typename T>
struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& grad_in, const Tensor& argmax,
                     Tensor* grad_out, const bool include_batch_in_index) {
    const int input_size = params.tensor_in_batch * params.tensor_in_rows *
                           params.tensor_in_cols * params.depth;
    const int output_size = params.tensor_in_batch * params.out_height *
                            params.out_width * params.depth;
    const int top_offset =
        params.tensor_in_rows * params.tensor_in_cols * params.depth;
    const int bottom_offset =
        params.out_width * params.out_height * params.depth;
    bool status = functor::MaxPoolGradBackwardWithArgmax<T>()(
        output_size, input_size, grad_in.flat<T>().data(),
        reinterpret_cast<const int64_t*>(argmax.flat<int64_t>().data()),
        top_offset, bottom_offset, grad_out->flat<T>().data(),
        context->eigen_gpu_device(), include_batch_in_index);
    if (!status) {
      context->SetStatus(
          errors::Internal("Failed launching MaxPoolGradBackwardWithArgmax"));
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

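// Registers the max-pooling gradient, second-order gradient, and argmax
// kernels for device D and element type T. The V2 gradient variants take
// ksize and strides as host-memory inputs instead of attrs.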
#define REGISTER_MAX_POOL_KERNELS(D, T)                                  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("MaxPoolGrad").Device(DEVICE_##D).TypeConstraint<T>("T"),     \
      MaxPoolingGradOp<D##Device, T>);                                   \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("MaxPoolGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      MaxPoolingGradGradOp<D##Device, T>);                               \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradV2")                          \
                              .Device(DEVICE_##D)                        \
                              .HostMemory("ksize")                       \
                              .HostMemory("strides")                     \
                              .TypeConstraint<T>("T"),                   \
                          MaxPoolingGradOp<D##Device, T>);               \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradV2")                      \
                              .Device(DEVICE_##D)                        \
                              .HostMemory("ksize")                       \
                              .HostMemory("strides")                     \
                              .TypeConstraint<T>("T"),                   \
                          MaxPoolingGradGradOp<D##Device, T>);           \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                      \
                              .Device(DEVICE_##D)                        \
                              .TypeConstraint<int64_t>("Targmax")        \
                              .TypeConstraint<T>("T"),                   \
                          MaxPoolingWithArgmaxOp<D##Device, T, int64_t>); \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")                  \
                              .Device(DEVICE_##D)                        \
                              .TypeConstraint<T>("T")                    \
                              .TypeConstraint<int64_t>("Targmax"),       \
                          MaxPoolingGradWithArgmaxOp<D##Device, T>);

// The kernels below are implemented only for the CPU device.
#define REGISTER_CPU_ONLY_POOL_KERNELS(T)                          \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      MaxPoolingOp<CPUDevice, T>);                                 \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MaxPoolingV2Op<CPUDevice, T>);                               \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<int32>("Targmax")    \
                              .TypeConstraint<T>("T"),             \
                          MaxPoolingWithArgmaxOp<CPUDevice, T, int32>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
#undef REGISTER_CPU_ONLY_POOL_KERNELS

#define REGISTER_CPU_MAX_POOL_KERNELS(T) REGISTER_MAX_POOL_KERNELS(CPU, T);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MAX_POOL_KERNELS);
#undef REGISTER_CPU_MAX_POOL_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations for the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  void SpatialMaxPooling<Eigen::GpuDevice, T>::operator()(             \
      const Eigen::GpuDevice& d, typename TTypes<T, 4>::Tensor output, \
      typename TTypes<T, 4>::ConstTensor input, int window_rows,       \
      int window_cols, int row_stride, int col_stride,                 \
      const Eigen::PaddingType& padding);                              \
  extern template struct SpatialMaxPooling<Eigen::GpuDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_MAX_POOL_KERNELS(T) REGISTER_MAX_POOL_KERNELS(GPU, T)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
#undef REGISTER_GPU_MAX_POOL_KERNELS

// The kernels below are currently implemented only for the GPU device.
// Note(jiayq): At the moment the Caffe-style custom implementation is faster
// than the default Eigen implementation, so the custom kernel is registered
// as the default. The Eigen version can still be selected explicitly through
// the kernel label map.
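// For example (a sketch only; Graph._kernel_label_map is an internal Python
// API and its exact spelling here is an assumption), constructing the graph
// under
//   g._kernel_label_map({"MaxPool": "eigen_tensor"})
// would pick the registrations labeled "eigen_tensor" below instead of the
// default custom kernel.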
#define REGISTER_GPU_ONLY_POOL_KERNELS(T)                           \
  REGISTER_KERNEL_BUILDER(Name("MaxPool")                           \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label("eigen_tensor"),               \
                          MaxPoolingOp<GPUDevice, T>);              \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")                         \
                              .Device(DEVICE_GPU)                   \
                              .HostMemory("ksize")                  \
                              .HostMemory("strides")                \
                              .TypeConstraint<T>("T")               \
                              .Label("eigen_tensor"),               \
                          MaxPoolingV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MaxPoolingNoMaskOp<GPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")                         \
                              .Device(DEVICE_GPU)                   \
                              .HostMemory("ksize")                  \
                              .HostMemory("strides")                \
                              .TypeConstraint<T>("T"),              \
                          MaxPoolingNoMaskV2Op<GPUDevice, T>);      \
  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax")         \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int64_t>("Targmax"),  \
                          MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);

// TODO(b/65847473): Re-enable once the underlying build error is fixed.
#if !defined(PLATFORM_WINDOWS)
REGISTER_KERNEL_BUILDER(
    Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
    MaxPoolingNoMaskOp<GPUDevice, qint8>);

REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("ksize")
                            .HostMemory("strides")
                            .TypeConstraint<qint8>("T"),
                        MaxPoolingV2Op<GPUDevice, qint8>);

REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("ksize")
                            .HostMemory("strides")
                            .TypeConstraint<qint8>("T")
                            .Label("eigen_tensor"),
                        MaxPoolingV2Op<GPUDevice, qint8>);
#endif  // !defined(PLATFORM_WINDOWS)

#undef REGISTER_GPU_ONLY_POOL_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_MAX_POOL_KERNELS

}  // namespace tensorflow