/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/avgpooling_op.h"

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

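// Forward average-pooling op. The default (CPU) implementation only supports
// the NHWC data format and computes the pooled output with SpatialAvgPool.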
template <typename Device, typename T>
class AvgPoolingOp : public UnaryOp<T> {
 public:
  explicit AvgPoolingOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument("Default AvgPoolingOp only supports NHWC ",
                                "on device type ",
                                DeviceTypeString(context->device_type())));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    for (int i = 0; i < ksize_.size(); ++i) {
      OP_REQUIRES(context, ksize_[i] > 0,
                  errors::InvalidArgument(
                      "ksize must be a positive int32 value, got:", ksize_[i]));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    PoolParameters params{context,
                          ksize_,
                          stride_,
                          padding_,
                          /*explicit_paddings=*/{},
                          data_format_,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, params.depth_window == 1,
                errors::Unimplemented("Non-spatial pooling is not "
                                      "yet supported. Volunteers? :)"));

    // For avgpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, params.forward_output_shape(), &output));

    SpatialAvgPool<Device, T>(context, output, tensor_in, params, padding_);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<double>("T"),
    AvgPoolingOp<CPUDevice, double>);
REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    AvgPoolingOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
    AvgPoolingOp<CPUDevice, Eigen::half>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
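// GPU specialization of the forward op. When built against cuDNN >= 7.3 it
// always dispatches to DnnPoolingOp; with older cuDNN versions it uses the
// Eigen SpatialAvgPooling functor for NHWC inputs and cuDNN only for NCHW.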
template <typename T>
class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit AvgPoolingOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    for (int i = 0; i < ksize_.size(); ++i) {
      OP_REQUIRES(context, ksize_[i] > 0,
                  errors::InvalidArgument(
                      "ksize must be a positive int32 value, got:", ksize_[i]));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    for (int i = 0; i < ksize_.size(); ++i) {
      OP_REQUIRES(context, ksize_[i] != 0,
                  errors::InvalidArgument("ksize cannot be zero"));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    PoolParameters params{context,
                          ksize_,
                          stride_,
                          padding_,
                          /*explicit_paddings=*/{},
                          data_format_,
                          tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, params.depth_window == 1,
                errors::Unimplemented("Non-spatial pooling is not "
                                      "yet supported. Volunteers? :)"));

    // For avgpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));

    TensorShape output_shape = params.forward_output_shape();
    if (output_shape.num_elements() == 0) {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));
      return;
    }

#if CUDNN_VERSION >= 7300
    DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
                             stride_, padding_, /*explicit_paddings=*/{},
                             data_format_, tensor_in, output_shape,
                             /*propagate_nans=*/false);
#else
    if (data_format_ == FORMAT_NCHW) {
      DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
                               stride_, padding_, /*explicit_paddings=*/{},
                               data_format_, tensor_in, output_shape,
                               /*propagate_nans=*/false);
    } else {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));
      Eigen::PaddingType pt = BrainPadding2EigenPadding(padding_);
      functor::SpatialAvgPooling<Device, T>()(
          context->eigen_device<Device>(), output->tensor<T, 4>(),
          tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
          params.row_stride, params.col_stride, pt);
    }
#endif  // CUDNN_VERSION >= 7300
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                       \
  template <>                                                     \
  void SpatialAvgPooling<GPUDevice, T>::operator()(               \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,   \
      typename TTypes<T, 4>::ConstTensor input, int window_rows,  \
      int window_cols, int row_stride, int col_stride,            \
      const Eigen::PaddingType& padding);                         \
  extern template struct SpatialAvgPooling<GPUDevice, T>;

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    AvgPoolingOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    AvgPoolingOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    AvgPoolingOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The operation to compute AvgPool gradients.
// It takes two inputs:
//   - The original input tensor shape
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
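//
// Each element of out_backprop is distributed evenly over the input cells of
// the pooling window that produced it: every contributing input cell receives
// out_backprop(b, r, c, d) / (rows_in_window * cols_in_window), where the
// window extent is clipped at the input borders for SAME padding.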
template <typename Device, class T>
class AvgPoolingGradOp : public OpKernel {
 public:
  explicit AvgPoolingGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument("Default AvgPoolingGradOp only supports NHWC ",
                                "on device type ",
                                DeviceTypeString(context->device_type())));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("orig_input_shape must be 1-dimensional and 4 "
                                "elements"));
    // For avgpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));
    const int64_t out_backprop_batch = out_backprop.dim_size(0);
    const int64_t out_backprop_rows = out_backprop.dim_size(1);
    const int64_t out_backprop_cols = out_backprop.dim_size(2);
    const int64_t out_backprop_depth = out_backprop.dim_size(3);

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64_t i = 0; i < tensor_in_shape.NumElements(); ++i) {
      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(shape_vec(i)));
    }
    const int64_t in_rows = output_shape.dim_size(1);
    const int64_t in_cols = output_shape.dim_size(2);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    output->flat<T>().setZero();

    if (output_shape.num_elements() == 0) {
      return;
    }
    const int window_rows = ksize_[1];
    const int window_cols = ksize_[2];
    const int depth_window = ksize_[3];

    const int row_stride = stride_[1];
    const int col_stride = stride_[2];

    // We (will) use different code for spatial pooling and
    // non-spatial pooling.
    //
    // Spatial pooling is when depth_window = 1
    OP_REQUIRES(context, depth_window == 1,
                errors::Unimplemented("Non-spatial pooling is not "
                                      "yet supported. Volunteers? :)"));

    int64_t out_height, out_width, pad_rows, pad_cols;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(in_rows, window_rows, row_stride,
                                         padding_, &out_height, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(in_cols, window_cols, col_stride,
                                         padding_, &out_width, &pad_cols));

    const T* out_backprop_ptr = out_backprop.flat<T>().data();
    T* input_backprop_ptr = output->flat<T>().data();

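    // Shard over the batch dimension: each shard handles batches in
    // [start, limit). For every out_backprop cell it recomputes the clipped
    // pooling window via GetBroadcastSize and scatters the gradient, scaled by
    // 1 / (rsize * csize), into the corresponding input_backprop cells.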
    auto shard = [context, out_backprop_ptr, input_backprop_ptr,
                  out_backprop_rows, out_backprop_cols, out_backprop_depth,
                  in_rows, in_cols, window_rows, window_cols, row_stride,
                  col_stride, pad_rows,
                  pad_cols](int64_t start, int64_t limit) {
      for (int64_t b = start; b < limit; ++b) {
        for (int64_t r = 0; r < out_backprop_rows; ++r) {
          // Calculates row broadcast size. For SAME padding, current
          // index could be in the padding area, and r*row_stride +
          // window_rows could be beyond the input tensor's boundary. In
          // such cases, change the starting index and reduce the
          // broadcast size.
          int rindex, rsize;
          OP_REQUIRES_OK(context,
                         GetBroadcastSize(r, in_rows, window_rows, row_stride,
                                          pad_rows, &rindex, &rsize));
          for (int64_t c = 0; c < out_backprop_cols; ++c) {
            // Calculates col broadcast size. For SAME padding, current
            // index could be in the padding area, and c*col_stride +
            // window_cols could be beyond the input tensor's boundary. In
            // such cases, change the starting index and reduce the
            // broadcast size.
            int cindex, csize;
            OP_REQUIRES_OK(context,
                           GetBroadcastSize(c, in_cols, window_cols, col_stride,
                                            pad_cols, &cindex, &csize));

            T divide_coeff(1.0 / (rsize * csize));
            int64_t output_index =
                (b * out_backprop_rows + r) * out_backprop_cols + c;
            for (int64_t r_dst = rindex; r_dst < rindex + rsize; ++r_dst) {
              for (int64_t c_dst = cindex; c_dst < cindex + csize; ++c_dst) {
                int64_t input_index = (b * in_rows + r_dst) * in_cols + c_dst;
                const T* output_offset =
                    out_backprop_ptr + output_index * out_backprop_depth;
                T* input_offset =
                    input_backprop_ptr + input_index * out_backprop_depth;
                for (int64_t d = 0; d < out_backprop_depth; ++d) {
                  *input_offset += *output_offset * divide_coeff;
                  ++output_offset;
                  ++input_offset;
                }
              }
            }
          }
        }
      }
    };

    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    const int64_t shard_cost =
        window_rows * window_cols * depth_window * in_rows * in_rows * in_cols;
    Shard(worker_threads.num_threads, worker_threads.workers,
          out_backprop_batch, shard_cost, shard);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

#define REGISTER_CPU_KERNEL(T)                                 \
  REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")                  \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T")          \
                              .HostMemory("orig_input_shape"), \
                          AvgPoolingGradOp<CPUDevice, T>);

TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_half(REGISTER_CPU_KERNEL);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A CUDNN based AvgPoolingGrad implementation. It includes the padding as the
// candidates for the pooling operation.
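// This kernel is registered with the "cudnn" label below, so it is only
// selected when that label is explicitly requested; the label-less default
// GPU registration further down uses AvgPoolingGradOpCustomGPUKernel.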
template <class T>
class AvgPoolingGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;

  explicit AvgPoolingGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("orig_input_shape must be 1-dimensional and 4 "
                                "elements"));
    // For avgpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64_t i = 0; i < tensor_in_shape.NumElements(); ++i) {
      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(shape_vec(i)));
    }

    if (output_shape.num_elements() == 0) {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));
      return;
    }

    DnnPoolingGradOp<T>::Compute(
        context, se::dnn::PoolingMode::kAverage, ksize_, stride_, padding_,
        /*explicit_paddings=*/{}, data_format_, nullptr, nullptr, out_backprop,
        output_shape, /*propagate_nans=*/false);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("orig_input_shape")
                            .Label("cudnn"),
                        AvgPoolingGradOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("orig_input_shape")
                            .Label("cudnn"),
                        AvgPoolingGradOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("orig_input_shape")
                            .Label("cudnn"),
                        AvgPoolingGradOp<GPUDevice, Eigen::half>);

// A custom GPU kernel based AvgPoolingGrad implementation. It includes the
// padding as the candidates for the pooling operation.
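// Registered without a label, so this is the default AvgPoolGrad GPU kernel.
// With cuDNN >= 7.3 it simply forwards to DnnPoolingGradOp; on older versions
// it runs RunAvePoolBackwardNHWC for NHWC inputs and uses cuDNN only for NCHW.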
template <class T>
class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
 public:
  typedef GPUDevice Device;

  explicit AvgPoolingGradOpCustomGPUKernel(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    const int32_t ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32_t stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("orig_input_shape must be 1-dimensional and 4 "
                                "elements"));
    // For avgpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));
    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64_t i = 0; i < tensor_in_shape.NumElements(); ++i) {
      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(shape_vec(i)));
    }
    if (output_shape.num_elements() == 0) {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));
      return;
    }

#if CUDNN_VERSION >= 7300
    DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
                                 ksize_, stride_, padding_,
                                 /*explicit_paddings=*/{}, data_format_,
                                 nullptr, nullptr, out_backprop, output_shape,
                                 /*propagate_nans=*/false);
#else
    if (data_format_ == FORMAT_NHWC) {
      const int64_t out_backprop_batch = out_backprop.dim_size(0);
      const int64_t out_backprop_rows = out_backprop.dim_size(1);
      const int64_t out_backprop_cols = out_backprop.dim_size(2);
      const int64_t out_backprop_depth = out_backprop.dim_size(3);

      const int64_t in_rows = output_shape.dim_size(1);
      const int64_t in_cols = output_shape.dim_size(2);
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));

      const int window_rows = ksize_[1];
      const int window_cols = ksize_[2];
      const int depth_window = ksize_[3];

      const int row_stride = stride_[1];
      const int col_stride = stride_[2];

      // We (will) use different code for spatial pooling and
      // non-spatial pooling.
      //
      // Spatial pooling is when depth_window = 1
      OP_REQUIRES(context, depth_window == 1,
                  errors::Unimplemented("Non-spatial pooling is not "
                                        "yet supported. Volunteers? :)"));

      int64_t out_height, out_width, pad_rows, pad_cols;
      OP_REQUIRES_OK(context,
                     GetWindowedOutputSize(in_rows, window_rows, row_stride,
                                           padding_, &out_height, &pad_rows));
      OP_REQUIRES_OK(context,
                     GetWindowedOutputSize(in_cols, window_cols, col_stride,
                                           padding_, &out_width, &pad_cols));

      RunAvePoolBackwardNHWC<T>(out_backprop.flat<T>().data(),  // top_diff
                                out_backprop_batch,             // num
                                in_rows,                        // height
                                in_cols,                        // width
                                out_backprop_depth,             // channels
                                out_backprop_rows,              // pooled_height
                                out_backprop_cols,              // pooled_width
                                window_rows,                    // kernel_h
                                window_cols,                    // kernel_w
                                row_stride,                     // stride_h
                                col_stride,                     // stride_w
                                pad_rows,                       // pad_t
                                pad_cols,                       // pad_l
                                output->flat<T>().data(),       // bottom_diff
                                context->eigen_gpu_device());   // d
    } else {
      DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
                                   ksize_, stride_, padding_,
                                   /*explicit_paddings=*/{}, data_format_,
                                   nullptr, nullptr, out_backprop, output_shape,
                                   /*propagate_nans=*/false);
    }
#endif  // CUDNN_VERSION >= 7300
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("orig_input_shape"),
                        AvgPoolingGradOpCustomGPUKernel<float>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("orig_input_shape"),
                        AvgPoolingGradOpCustomGPUKernel<double>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("orig_input_shape"),
                        AvgPoolingGradOpCustomGPUKernel<Eigen::half>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow