/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/pooling_ops_3d.h"

#include <array>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cudnn_pooling_gpu.h"
#include "tensorflow/core/kernels/pooling_ops_3d_gpu.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

Pool3dParameters::Pool3dParameters(OpKernelContext* context,
                                   const std::vector<int32>& ksize,
                                   const std::vector<int32>& stride,
                                   Padding padding, TensorFormat data_format,
                                   const TensorShape& tensor_in_shape) {
  // For 3d pooling, tensor_in should have 5 dimensions.
  OP_REQUIRES(context, tensor_in_shape.dims() == 5,
              errors::InvalidArgument("tensor_in must be 5-dimensional"));

  this->data_format = data_format;
  depth = GetTensorDim(tensor_in_shape, data_format, 'C');
  tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
  tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
  tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
  tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
  window_planes = GetTensorDim(ksize, data_format, '0');
  window_rows = GetTensorDim(ksize, data_format, '1');
  window_cols = GetTensorDim(ksize, data_format, '2');
  depth_window = GetTensorDim(ksize, data_format, 'C');
  plane_stride = GetTensorDim(stride, data_format, '0');
  row_stride = GetTensorDim(stride, data_format, '1');
  col_stride = GetTensorDim(stride, data_format, '2');
  depth_stride = GetTensorDim(stride, data_format, 'C');

  // We only support 3D pooling across plane/width/height. Depthwise
  // pooling is not supported.
  OP_REQUIRES(
      context, depth_window == 1 && depth_stride == 1,
      errors::Unimplemented(
          "Pooling3d only supports pooling across plane/width/height."));

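  // For reference, GetWindowedOutputSize (used below) implements the
  // standard formulas:
  //   VALID: out = (in - window) / stride + 1   (integer floor division)
  //   SAME:  out = ceil(in / stride)
  // Illustrative example (values assumed): in = 7, window = 3, stride = 2
  // gives out = 3 under VALID padding and out = 4 under SAME padding.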
  OP_REQUIRES_OK(context, GetWindowedOutputSize(tensor_in_planes, window_planes,
                                                plane_stride, padding,
                                                &out_plane, &pad_planes));
  OP_REQUIRES_OK(context,
                 GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride,
                                       padding, &out_height, &pad_rows));
  OP_REQUIRES_OK(context,
                 GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride,
                                       padding, &out_width, &pad_cols));
}

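// Illustrative example (values assumed): with FORMAT_NHWC (interpreted as
// NDHWC for 5D tensors), tensor_in_batch = 2, out_plane = out_height =
// out_width = 4 and depth = 3, the result below is [2, 4, 4, 4, 3].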
TensorShape Pool3dParameters::forward_output_shape() {
  return ShapeFromFormat(data_format, tensor_in_batch,
                         {{out_plane, out_height, out_width}}, depth);
}

template <typename T>
struct LaunchPoolingOp<CPUDevice, T, AVG> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    output->tensor<T, 5>().device(context->eigen_device<CPUDevice>()) =
        Eigen::CuboidAvgPooling(tensor_in.tensor<T, 5>(), window[0], window[1],
                                window[2], stride[0], stride[1], stride[2],
                                BrainPadding2EigenPadding(padding_type));
  }
};

template <typename T>
struct LaunchPoolingOp<CPUDevice, T, MAX> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    output->tensor<T, 5>().device(context->eigen_device<CPUDevice>()) =
        Eigen::CuboidMaxPooling(tensor_in.tensor<T, 5>(), window[0], window[1],
                                window[2], stride[0], stride[1], stride[2],
                                BrainPadding2EigenPadding(padding_type));
  }
};
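// Note: the explicit padding array passed to launch() is unused on the CPU
// paths above; Eigen derives the padding amounts from padding_type inside
// CuboidAvgPooling/CuboidMaxPooling.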

template <typename Device, typename T, PoolingType Type>
class Pooling3DOp : public UnaryOp<T> {
 public:
  explicit Pooling3DOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->device_type() == DEVICE_CPU) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Default Pooling3DOp only supports NDHWC ",
                                  "on device type ",
                                  DeviceTypeString(context->device_type())));
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    bool all_positive =
        std::all_of(ksize_.begin(), ksize_.end(), [](int k) { return k > 0; });
    OP_REQUIRES(context, all_positive,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "have positive dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'N') == 1 &&
                 GetTensorDim(stride_, data_format_, 'N') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'C') == 1 &&
                 GetTensorDim(stride_, data_format_, 'C') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    OP_REQUIRES(context, tensor_in.dims() == 5,
                errors::InvalidArgument("tensor_in must be 5-dimensional"));
    const int64_t depth = GetTensorDim(tensor_in, data_format_, 'C');
    const int64_t in_batch = GetTensorDim(tensor_in, data_format_, 'N');

    // Dimension order for these arrays is: x, y, z.
    std::array<int64_t, 3> input_size{
        {GetTensorDim(tensor_in, data_format_, '2'),
         GetTensorDim(tensor_in, data_format_, '1'),
         GetTensorDim(tensor_in, data_format_, '0')}};
    std::array<int64_t, 3> window{{GetTensorDim(ksize_, data_format_, '2'),
                                   GetTensorDim(ksize_, data_format_, '1'),
                                   GetTensorDim(ksize_, data_format_, '0')}};
    std::array<int64_t, 3> stride{{GetTensorDim(stride_, data_format_, '2'),
                                   GetTensorDim(stride_, data_format_, '1'),
                                   GetTensorDim(stride_, data_format_, '0')}};
    std::array<int64_t, 3> padding, out;

    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                            padding_, &out, &padding));

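    // out is in x, y, z order (matching input_size above), so reverse it
    // back to planes/rows/cols when building the output shape.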
    TensorShape out_shape = ShapeFromFormat(data_format_, in_batch,
                                            {{out[2], out[1], out[0]}}, depth);
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    if (out_shape.num_elements() == 0) return;
    LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
                                             padding, data_format_, padding_,
                                             output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename T>
struct LaunchMaxPooling3dGradOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const Tensor& tensor_out, const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& out,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* output) {
    output->flat<T>().setZero();
    for (int64_t p = 0; p < out_backprop.dim_size(3); ++p) {
      // Calculate broadcast size for planes/rows/cols. For SAME padding,
      // current index could be in the padding area, and
      //   p * stride_planes + window_planes
      // could be beyond the input tensor's boundary. In such cases, change
      // the starting index and reduce the broadcast size.
      //
      // The same procedure is repeated for every spatial dimension in the
      // nested loops below.
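      //
      // Illustrative example (values assumed): with input extent 5,
      // window 3, stride 2 and a SAME left padding of 1, output index p = 0
      // maps to input start -1; GetBroadcastSize clamps the start to 0 and
      // shrinks the broadcast size from 3 to 2.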
      int pindex, psize;
      std::array<int64_t, 3> input_size{{tensor_in.dim_size(3),
                                         tensor_in.dim_size(2),
                                         tensor_in.dim_size(1)}};
      OP_REQUIRES_OK(context,
                     GetBroadcastSize(p, input_size[0], window[0], stride[0],
                                      padding[0], &pindex, &psize));
      for (int64_t r = 0; r < out_backprop.dim_size(2); ++r) {
        int rindex, rsize;
        OP_REQUIRES_OK(context,
                       GetBroadcastSize(r, input_size[1], window[1], stride[1],
                                        padding[1], &rindex, &rsize));
        for (int64_t c = 0; c < out_backprop.dim_size(1); ++c) {
          int cindex, csize;
          OP_REQUIRES_OK(
              context, GetBroadcastSize(c, input_size[2], window[2], stride[2],
                                        padding[2], &cindex, &csize));
          TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}};
          TensorSlice dst{{0, -1},
                          {cindex, csize},
                          {rindex, rsize},
                          {pindex, psize},
                          {0, -1}};
          Eigen::DSizes<Eigen::DenseIndex, 5> src_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> src_sizes;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_sizes;
          src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices,
                                     &src_sizes);
          dst.FillIndicesAndSizes<5>(tensor_in.shape(), &dst_indices,
                                     &dst_sizes);

          Eigen::IndexList<Eigen::type2index<1>, int, int, int,
                           Eigen::type2index<1>>
              bcast;
          bcast.set(1, csize);
          bcast.set(2, rsize);
          bcast.set(3, psize);

          // Slice from tensor_in.
          Eigen::Tensor<T, 5, Eigen::RowMajor> tensor_in_slice(dst_sizes);
          tensor_in_slice.device(context->eigen_cpu_device()) =
              tensor_in.tensor<T, 5>().slice(dst_indices, dst_sizes);

          // Slice from tensor_out.
          Eigen::Tensor<T, 5, Eigen::RowMajor> tensor_out_slice(src_sizes);
          tensor_out_slice.device(context->eigen_cpu_device()) =
              tensor_out.tensor<T, 5>().slice(src_indices, src_sizes);

          // Backprop slice.
          Eigen::Tensor<T, 5, Eigen::RowMajor> out_backprop_slice(src_sizes);
          out_backprop_slice.device(context->eigen_cpu_device()) =
              out_backprop.tensor<T, 5>().slice(src_indices, src_sizes);

          // The true backprop slice: if an element is the max, choose
          // the backprop slice; otherwise set to 0.
          Eigen::Tensor<T, 5, Eigen::RowMajor> select_slice(dst_sizes);
          Eigen::Tensor<T, 5, Eigen::RowMajor> mat0(dst_sizes);
          mat0.setZero();
          select_slice =
              ((tensor_in_slice - tensor_out_slice.broadcast(bcast)).abs() <
               tensor_in_slice.constant(1e-5))
                  .select(out_backprop_slice.broadcast(bcast), mat0);

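          // Illustrative example (values assumed): if a broadcast patch of
          // tensor_in holds {1, 5, 3} with max 5 and incoming gradient g,
          // select_slice is {0, g, 0}; inputs within 1e-5 of the max all
          // receive the gradient.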
          output->tensor<T, 5>()
              .slice(dst_indices, dst_sizes)
              .device(context->eigen_cpu_device()) += select_slice;
        }
      }
    }
  }
};

template <class Device, class T>
class MaxPooling3dGradOp : public OpKernel {
 public:
  explicit MaxPooling3dGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->device_type() == DEVICE_CPU) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Default MaxPooling3dGradOp only supports NDHWC ",
              "on device type ", DeviceTypeString(context->device_type())));
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'N') == 1 &&
                 GetTensorDim(stride_, data_format_, 'N') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'C') == 1 &&
                 GetTensorDim(stride_, data_format_, 'C') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(context, tensor_in.dims() == 5,
                errors::InvalidArgument("tensor_in must be 5-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 5,
                errors::InvalidArgument("tensor_out must be 5-dimensional"));
    OP_REQUIRES(context, out_backprop.dims() == 5,
                errors::InvalidArgument("out_backprop must be 5-dimensional"));

    const TensorShape& output_shape = tensor_in.shape();
    Tensor* input_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &input_backprop));
    std::array<int64_t, 3> input_size{
        {GetTensorDim(output_shape, data_format_, '2'),
         GetTensorDim(output_shape, data_format_, '1'),
         GetTensorDim(output_shape, data_format_, '0')}};
    std::array<int64_t, 3> window{{GetTensorDim(ksize_, data_format_, '2'),
                                   GetTensorDim(ksize_, data_format_, '1'),
                                   GetTensorDim(ksize_, data_format_, '0')}};
    std::array<int64_t, 3> stride{{GetTensorDim(stride_, data_format_, '2'),
                                   GetTensorDim(stride_, data_format_, '1'),
                                   GetTensorDim(stride_, data_format_, '0')}};
    std::array<int64_t, 3> out, padding;

    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                            padding_, &out, &padding));

    const int64_t depth = GetTensorDim(tensor_in, data_format_, 'C');
    const int64_t in_batch = GetTensorDim(tensor_in, data_format_, 'N');
    TensorShape out_shape = ShapeFromFormat(data_format_, in_batch,
                                            {{out[2], out[1], out[0]}}, depth);
    OP_REQUIRES(
        context, tensor_out.shape() == out_shape,
        errors::InvalidArgument("Expected orig_output shape to be ", out_shape,
                                ", but got ", tensor_out.shape()));
    OP_REQUIRES(context, out_backprop.shape() == out_shape,
                errors::InvalidArgument("Expected grad shape to be ", out_shape,
                                        ", but got ", out_backprop.shape()));

    LaunchMaxPooling3dGradOp<Device, T>::launch(
        context, tensor_in, tensor_out, out_backprop, window, stride, out,
        padding, data_format_, input_backprop);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename T>
struct LaunchAvgPooling3dGradOp<CPUDevice, T> {
  static void launch(OpKernelContext* context,
                     const TensorShape& tensor_in_shape,
                     const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& output_shape,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* output) {
    OP_REQUIRES(
        context, tensor_in_shape.dim_size(0) == out_backprop.dim_size(0),
        errors::InvalidArgument(
            "Expected first dimension of tensor_in_shape and "
            "out_backprop to match, got ",
            tensor_in_shape.dim_size(0), " and ", out_backprop.dim_size(0)));
    OP_REQUIRES(
        context, tensor_in_shape.dim_size(4) == out_backprop.dim_size(4),
        errors::InvalidArgument(
            "Expected last dimension of tensor_in_shape and "
            "out_backprop to match, got ",
            tensor_in_shape.dim_size(4), " and ", out_backprop.dim_size(4)));

    output->flat<T>().setZero();
    std::array<int64_t, 3> input_size = {{tensor_in_shape.dim_size(3),
                                          tensor_in_shape.dim_size(2),
                                          tensor_in_shape.dim_size(1)}};
    for (int64_t p = 0; p < out_backprop.dim_size(3); ++p) {
      // Calculate broadcast size for planes/rows/cols. For SAME padding,
      // current index could be in the padding area, and
      //   p * stride_planes + window_planes
      // could be beyond the input tensor's boundary. In such cases, change
      // the starting index and reduce the broadcast size.
      //
      // The same procedure is repeated for every spatial dimension in the
      // nested loops below.
      int pindex, psize;
      OP_REQUIRES_OK(context,
                     GetBroadcastSize(p, input_size[0], window[0], stride[0],
                                      padding[0], &pindex, &psize));
      for (int64_t r = 0; r < out_backprop.dim_size(2); ++r) {
        int rindex, rsize;
        OP_REQUIRES_OK(context,
                       GetBroadcastSize(r, input_size[1], window[1], stride[1],
                                        padding[1], &rindex, &rsize));
        for (int64_t c = 0; c < out_backprop.dim_size(1); ++c) {
          int cindex, csize;
          OP_REQUIRES_OK(
              context, GetBroadcastSize(c, input_size[2], window[2], stride[2],
                                        padding[2], &cindex, &csize));
          TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}};
          TensorSlice dst{{0, -1},
                          {cindex, csize},
                          {rindex, rsize},
                          {pindex, psize},
                          {0, -1}};
          Eigen::DSizes<Eigen::DenseIndex, 5> src_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> src_sizes;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_sizes;
          src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices,
                                     &src_sizes);
          dst.FillIndicesAndSizes<5>(tensor_in_shape, &dst_indices, &dst_sizes);
          Eigen::IndexList<Eigen::type2index<1>, int, int, int,
                           Eigen::type2index<1>>
              bcast;
          bcast.set(1, csize);
          bcast.set(2, rsize);
          bcast.set(3, psize);
          Eigen::Tensor<T, 5, Eigen::RowMajor> slices(src_sizes);
          slices.device(context->eigen_cpu_device()) =
              out_backprop.tensor<T, 5>().slice(src_indices, src_sizes);
          // Divide by the size of the actual patch (psize * rsize * csize).
          float divide_size = rsize * csize * psize * 1.0f;
          slices *= slices.constant(1.0f / divide_size);
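          // Illustrative example (values assumed): a 2x2x2 window fully
          // inside the input gives divide_size = 8; a window clipped to
          // 1x2x2 at a SAME-padding border gives divide_size = 4, mirroring
          // the forward pass, which averages only over valid elements.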

          output->tensor<T, 5>()
              .slice(dst_indices, dst_sizes)
              .device(context->eigen_cpu_device()) += slices.broadcast(bcast);
        }
      }
    }
  }
};

template <class Device, class T>
class AvgPooling3dGradOp : public OpKernel {
 public:
  explicit AvgPooling3dGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->device_type() == DEVICE_CPU) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Default AvgPooling3dGradOp only supports NDHWC ",
              "on device type ", DeviceTypeString(context->device_type())));
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'N') == 1 &&
                 GetTensorDim(stride_, data_format_, 'N') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'C') == 1 &&
                 GetTensorDim(stride_, data_format_, 'C') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 5,
        errors::InvalidArgument("tensor_in_shape must be 1-dimensional "
                                "and contain 5 elements"));
    OP_REQUIRES(context, out_backprop.dims() == 5,
                errors::InvalidArgument("out_backprop must be 5-dimensional"));

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64_t i = 0; i < tensor_in_shape.NumElements(); ++i) {
      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(shape_vec(i)));
    }

    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // Dimension order for these arrays is x, y, z.
    std::array<int64_t, 3> input_size{
        {GetTensorDim(output_shape, data_format_, '2'),
         GetTensorDim(output_shape, data_format_, '1'),
         GetTensorDim(output_shape, data_format_, '0')}};
    std::array<int64_t, 3> window{{GetTensorDim(ksize_, data_format_, '2'),
                                   GetTensorDim(ksize_, data_format_, '1'),
                                   GetTensorDim(ksize_, data_format_, '0')}};
    std::array<int64_t, 3> stride{{GetTensorDim(stride_, data_format_, '2'),
                                   GetTensorDim(stride_, data_format_, '1'),
                                   GetTensorDim(stride_, data_format_, '0')}};
    std::array<int64_t, 3> padding, out;

    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                            padding_, &out, &padding));

    LaunchAvgPooling3dGradOp<Device, T>::launch(
        context, output_shape, out_backprop, window, stride, out, padding,
        data_format_, output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename T>
struct LaunchMaxPooling3dGradGradOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, const Pool3dParameters& params,
                     const Tensor& tensor_in, const Tensor& tensor_out,
                     const Tensor& tensor_top_diff,
                     Tensor* tensor_bottom_diff) {
    OP_REQUIRES(
        context, params.data_format == FORMAT_NHWC,
        errors::InvalidArgument("Default MaxPooling3dGradGradOp only supports ",
                                "NDHWC on CPU device type"));

    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;

    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
                               params.tensor_in_planes * params.tensor_in_cols *
                                   params.tensor_in_rows *
                                   params.tensor_in_batch);
    ConstEigenMatrixMap out_mat(tensor_out.flat<T>().data(), params.depth,
                                params.out_plane * params.out_width *
                                    params.out_height * params.tensor_in_batch);
    ConstEigenMatrixMap top_diff_mat(
        tensor_top_diff.flat<T>().data(), params.depth,
        params.tensor_in_planes * params.tensor_in_cols *
            params.tensor_in_rows * params.tensor_in_batch);
    EigenMatrixMap bottom_diff_mat(
        tensor_bottom_diff->flat<T>().data(), params.depth,
        params.out_plane * params.out_width * params.out_height *
            params.tensor_in_batch);

    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());

    auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
                     int64_t start, int64_t limit) {
      const int32_t depth = params.depth;
      const int32_t in_planes = params.tensor_in_planes;
      const int32_t in_rows = params.tensor_in_rows;
      const int32_t in_cols = params.tensor_in_cols;
      const int32_t pad_planes = params.pad_planes;
      const int32_t pad_rows = params.pad_rows;
      const int32_t pad_cols = params.pad_cols;
      const int32_t window_planes = params.window_planes;
      const int32_t window_rows = params.window_rows;
      const int32_t window_cols = params.window_cols;
      const int32_t plane_stride = params.plane_stride;
      const int32_t row_stride = params.row_stride;
      const int32_t col_stride = params.col_stride;
      const int32_t out_plane = params.out_plane;
      const int32_t out_height = params.out_height;
      const int32_t out_width = params.out_width;

      {
        // Initializes the output grad backprop tensor with 0.
        const int32_t output_image_size =
            out_plane * out_height * out_width * params.depth;
        EigenMatrixMap bottom_diff_shard(
            bottom_diff_mat.data() + start * output_image_size, 1,
            (limit - start) * output_image_size);
        bottom_diff_shard.setZero();
      }

      for (int b = start; b < limit; ++b) {
        for (int pp = 0; pp < out_plane; ++pp) {
          for (int ph = 0; ph < out_height; ++ph) {
            for (int pw = 0; pw < out_width; ++pw) {
              // (p_start, p_end) * (h_start, h_end) * (w_start, w_end) is the
              // range that the input vector projects to.
              int p_start = pp * plane_stride - pad_planes;
              const int p_end = std::min(p_start + window_planes, in_planes);
              int h_start = ph * row_stride - pad_rows;
              const int h_end = std::min(h_start + window_rows, in_rows);
              int w_start = pw * col_stride - pad_cols;
              const int w_end = std::min(w_start + window_cols, in_cols);
              p_start = std::max(p_start, 0);
              h_start = std::max(h_start, 0);
              w_start = std::max(w_start, 0);
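              // Illustrative example (values assumed): pp = 0,
              // plane_stride = 2 and pad_planes = 1 give p_start = -1,
              // which is clamped to 0 so the window covers only valid
              // input planes.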
              const int out_index =
                  ((b * out_plane + pp) * out_height + ph) * out_width + pw;
              // Find value corresponding to the input maximum in top_diff.
              for (int d = 0; d < depth; ++d) {
                const T& output_ref = out_mat.coeffRef(d, out_index);
                bool should_stop = false;
                for (int p = p_start; p < p_end && !should_stop; ++p) {
                  for (int h = h_start; h < h_end && !should_stop; ++h) {
                    for (int w = w_start; w < w_end && !should_stop; ++w) {
                      const int in_index =
                          ((b * in_planes + p) * in_rows + h) * in_cols + w;
                      const T& input_ref = in_mat.coeffRef(d, in_index);
                      if (output_ref == input_ref) {
                        T& bottom_diff_ref =
                            bottom_diff_mat.coeffRef(d, out_index);
                        bottom_diff_ref = top_diff_mat.coeffRef(d, in_index);
                        should_stop = true;
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    };
    const int64_t shard_cost =
        params.out_plane * params.out_height * params.out_width *
        params.depth * params.window_planes * params.window_rows *
        params.window_cols;
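    // Illustrative example (values assumed): an 8x8x8 output with depth 64
    // and a 2x2x2 window gives shard_cost = 8*8*8*64*2*2*2 = 262144 units
    // of work per batch element, which Shard uses to pick a thread count.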
    Shard(worker_threads.num_threads, worker_threads.workers,
          params.tensor_in_batch, shard_cost, shard);
  }
};

template <class Device, class T>
class MaxPooling3dGradGradOp : public OpKernel {
 public:
  explicit MaxPooling3dGradGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    const int32_t ksize_c = GetTensorDim(ksize_, data_format_, 'C');
    const int32_t stride_c = GetTensorDim(stride_, data_format_, 'C');
    OP_REQUIRES(context, ksize_c == 1 && stride_c == 1,
                errors::Unimplemented("MaxPooling3dGradGrad is not yet "
                                      "supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_grad_backprop = context->input(2);

    // For maxpooling3d, tensor_in should have 5 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 5,
                errors::InvalidArgument("tensor_in must be 5-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 5,
                errors::InvalidArgument("tensor_out must be 5-dimensional"));
    // For maxpooling3d, out_grad_backprop should have 5 dimensions.
    OP_REQUIRES(
        context, out_grad_backprop.dims() == 5,
        errors::InvalidArgument("out_grad_backprop must be 5-dimensional"));

    Pool3dParameters params{context, ksize_, stride_,
                            padding_, data_format_, tensor_in.shape()};
    if (!context->status().ok()) return;  // params is invalid
    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
                errors::InvalidArgument("Expected orig_output shape to be ",
                                        params.forward_output_shape(),
                                        ", but got ", tensor_out.shape()));
    OP_REQUIRES(
        context, out_grad_backprop.shape() == tensor_in.shape(),
        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
                                ", but got ", out_grad_backprop.shape()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 0, tensor_out.shape(), &output));
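    // forward_input_or_allocate_output tries to reuse the buffer of input 2
    // (out_grad_backprop) for output 0 when its shape and reference count
    // allow; otherwise it allocates a fresh tensor of tensor_out's shape.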

    // Given access patterns in LaunchMaxPooling3dGradGradOp, these tensors
    // must have elements.
    OP_REQUIRES(context, tensor_in.NumElements() > 0,
                errors::InvalidArgument("received empty tensor tensor_in: ",
                                        tensor_in.DebugString()));
    OP_REQUIRES(context, tensor_out.NumElements() > 0,
                errors::InvalidArgument("received empty tensor tensor_out: ",
                                        tensor_out.DebugString()));
    OP_REQUIRES(
        context, out_grad_backprop.NumElements() > 0,
        errors::InvalidArgument("received empty tensor out_grad_backprop: ",
                                out_grad_backprop.DebugString()));
    OP_REQUIRES(context,
                tensor_in.NumElements() == out_grad_backprop.NumElements(),
                errors::InvalidArgument("tensor_in and out_grad_backprop must "
                                        "have same number of elements, got <",
                                        tensor_in.DebugString(), "> and <",
                                        out_grad_backprop.DebugString(), ">"));
    OP_REQUIRES(
        context, tensor_out.NumElements() == output->NumElements(),
        errors::InvalidArgument(
            "tensor_out and output must have same number of elements, got <",
            tensor_out.DebugString(), "> and <", output->DebugString(), ">"));

    LaunchMaxPooling3dGradGradOp<Device, T>::launch(
        context, params, tensor_in, tensor_out, out_grad_backprop, output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

#define REGISTER_KERNELS(D, T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("MaxPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"),            \
      Pooling3DOp<D##Device, T, MAX>);                                        \
  REGISTER_KERNEL_BUILDER(Name("MaxPool3DGrad")                               \
                              .Device(DEVICE_##D)                             \
                              .TypeConstraint<T>("T")                         \
                              .TypeConstraint<T>("TInput"),                   \
                          MaxPooling3dGradOp<D##Device, T>);                  \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("MaxPool3DGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"),    \
      MaxPooling3dGradGradOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("AvgPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"),            \
      Pooling3DOp<D##Device, T, AVG>);                                        \
  REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")                               \
                              .Device(DEVICE_##D)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("orig_input_shape"),                \
                          AvgPooling3dGradOp<D##Device, T>);

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T)
TF_CALL_float(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
struct LaunchPoolingOp<GPUDevice, T, AVG> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    DnnPooling3dOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, window,
                               stride, padding, data_format, tensor_in, output);
  }
};

template <typename T>
struct LaunchPoolingOp<GPUDevice, T, MAX> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    DnnPooling3dOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, window,
                               stride, padding, data_format, tensor_in, output);
  }
};

template <typename T>
struct LaunchMaxPooling3dGradOp<GPUDevice, T> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const Tensor& tensor_out, const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& out,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* input_backprop) {
    const TensorShape output_shape = tensor_in.shape();
    DnnPooling3dGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum,
                                   window, stride, padding, out, data_format,
                                   out_backprop, output_shape, &tensor_in,
                                   &tensor_out, input_backprop);
  }
};

template <typename T>
struct LaunchAvgPooling3dGradOp<GPUDevice, T> {
  static void launch(OpKernelContext* context,
                     const TensorShape& tensor_in_shape,
                     const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& out,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* output) {
    DnnPooling3dGradOp<T>::Compute(
        context, se::dnn::PoolingMode::kAverage, window, stride, padding, out,
        data_format, out_backprop, tensor_in_shape, nullptr, nullptr, output);
  }
};

template <typename T>
struct LaunchMaxPooling3dGradGradOp<GPUDevice, T> {
  static void launch(OpKernelContext* context, const Pool3dParameters& params,
                     const Tensor& tensor_in, const Tensor& tensor_out,
                     const Tensor& tensor_top_diff,
                     Tensor* tensor_bottom_diff) {
    bool status = functor::MaxPool3dGradBackward<T>()(
        params.data_format, tensor_in.flat<T>().data(),
        tensor_out.flat<T>().data(), params.tensor_in_batch, params.out_plane,
        params.out_height, params.out_width, params.depth,
        params.tensor_in_planes, params.tensor_in_rows, params.tensor_in_cols,
        params.window_planes, params.window_rows, params.window_cols,
        params.plane_stride, params.row_stride, params.col_stride,
        params.pad_planes, params.pad_rows, params.pad_cols,
        tensor_top_diff.flat<T>().data(), tensor_bottom_diff->flat<T>().data(),
        context->eigen_gpu_device());
    if (!status) {
      context->SetStatus(
          errors::Internal("Failed launching MaxPool3dGradBackward"));
    }
  }
};

#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T)
TF_CALL_float(REGISTER_GPU_KERNELS) TF_CALL_half(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_KERNELS

}  // namespace tensorflow