1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#define EIGEN_USE_THREADS
16
17#include <algorithm>
18#include <cmath>
19#include <random>
20#include <vector>
21
22#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23#include "tensorflow/core/framework/numeric_op.h"
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/op_requires.h"
26#include "tensorflow/core/kernels/fractional_pool_common.h"
27#include "tensorflow/core/lib/random/random.h"
28#include "tensorflow/core/platform/errors.h"
29#include "tensorflow/core/platform/logging.h"
30#include "tensorflow/core/platform/mutex.h"
31#include "tensorflow/core/util/guarded_philox_random.h"
32
33namespace tensorflow {
34typedef Eigen::ThreadPoolDevice CPUDevice;
35
36template <typename T>
37class FractionalMaxPoolOp : public OpKernel {
38 public:
39 explicit FractionalMaxPoolOp(OpKernelConstruction* context)
40 : OpKernel(context) {
41 OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
42 OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
43 OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
44
45 OP_REQUIRES(context, pooling_ratio_.size() == 4,
46 errors::InvalidArgument("pooling_ratio field must "
47 "specify 4 dimensions"));
48
49 OP_REQUIRES(
50 context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
51 errors::Unimplemented("Fractional max pooling is not yet "
52 "supported on the batch nor channel dimension."));
53
54 OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
55 OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
56 OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
57 if (deterministic_) {
58 // If both seeds are not set when deterministic_ is true, force set seeds.
59 if ((seed_ == 0) && (seed2_ == 0)) {
60 seed_ = random::New64();
61 seed2_ = random::New64();
62 }
63 } else {
64 OP_REQUIRES(
65 context, (seed_ == 0) && (seed2_ == 0),
66 errors::InvalidArgument(
67 "Both seed and seed2 should be 0 if deterministic is false."));
68 }
69 }
70
71 void Compute(OpKernelContext* context) override {
72 typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
73 ConstEigenMatrixMap;
74 typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
75 EigenMatrixMap;
76
77 constexpr int tensor_in_and_out_dims = 4;
78
79 const Tensor& tensor_in = context->input(0);
80 OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
81 errors::InvalidArgument("tensor_in must be 4-dimensional"));
82
83 std::vector<int> input_size(tensor_in_and_out_dims);
84 std::vector<int> output_size(tensor_in_and_out_dims);
85 for (int i = 0; i < tensor_in_and_out_dims; ++i) {
86 input_size[i] = tensor_in.dim_size(i);
87
88 OP_REQUIRES(
89 context, input_size[i] >= pooling_ratio_[i],
90 errors::InvalidArgument("Pooling ratio is higher than input "
91 "dimension size for dimension ",
92 i, ". Input dim size: ", input_size[i],
93 " pooling ratio: ", pooling_ratio_[i]));
94 }
95 // Output size.
96 for (int i = 0; i < tensor_in_and_out_dims; ++i) {
97 // This must match the same logic in the shape function in
98 // core/ops/nn_ops.cc.
99 output_size[i] =
100 static_cast<int>(std::floor(input_size[i] / pooling_ratio_[i]));
101 DCHECK_GT(output_size[i], 0);
102 }
103
104 // Generate pooling sequence.
105 std::vector<int64_t> height_cum_seq;
106 std::vector<int64_t> width_cum_seq;
107 GuardedPhiloxRandom generator;
108 generator.Init(seed_, seed2_);
109 height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
110 &generator, pseudo_random_);
111 width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
112 &generator, pseudo_random_);
113
114 // Prepare output.
115 Tensor* output_tensor = nullptr;
116 OP_REQUIRES_OK(context, context->allocate_output(
117 0,
118 TensorShape({output_size[0], output_size[1],
119 output_size[2], output_size[3]}),
120 &output_tensor));
121 Tensor* output_height_seq_tensor = nullptr;
122 OP_REQUIRES_OK(
123 context,
124 context->allocate_output(
125 1, TensorShape({static_cast<int64_t>(height_cum_seq.size())}),
126 &output_height_seq_tensor));
127 Tensor* output_width_seq_tensor = nullptr;
128 OP_REQUIRES_OK(
129 context,
130 context->allocate_output(
131 2, TensorShape({static_cast<int64_t>(width_cum_seq.size())}),
132 &output_width_seq_tensor));
133
134 ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
135 input_size[2] * input_size[1] * input_size[0]);
136
137 EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
138 output_size[2] * output_size[1] * output_size[0]);
139
140 // Initializes the output tensor with MIN<T>.
141 output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
142
143 auto output_height_seq_flat = output_height_seq_tensor->flat<int64_t>();
144 auto output_width_seq_flat = output_width_seq_tensor->flat<int64_t>();
145
146 // Set output tensors.
147 for (int i = 0; i < height_cum_seq.size(); ++i) {
148 output_height_seq_flat(i) = height_cum_seq[i];
149 }
150
151 for (int i = 0; i < width_cum_seq.size(); ++i) {
152 output_width_seq_flat(i) = width_cum_seq[i];
153 }
154
155 // For both input and output,
156 // 0: batch
157 // 1: height / row
158 // 2: width / col
159 // 3: depth / channel
160 const int64_t height_max = input_size[1] - 1;
161 const int64_t width_max = input_size[2] - 1;
162 for (int64_t b = 0; b < input_size[0]; ++b) {
163 // height sequence.
164 for (int64_t hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
165 // height start and end.
166 const int64_t height_start = height_cum_seq[hs];
167 int64_t height_end =
168 overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
169 height_end = std::min(height_end, height_max);
170
171 // width sequence.
172 for (int64_t ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
173 const int64_t out_offset =
174 (b * output_size[1] + hs) * output_size[2] + ws;
175 // width start and end.
176 const int64_t width_start = width_cum_seq[ws];
177 int64_t width_end =
178 overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
179 width_end = std::min(width_end, width_max);
180 for (int64_t h = height_start; h <= height_end; ++h) {
181 for (int64_t w = width_start; w <= width_end; ++w) {
182 const int64_t in_offset =
183 (b * input_size[1] + h) * input_size[2] + w;
184 out_mat.col(out_offset) =
185 out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
186 }
187 }
188 }
189 }
190 }
191 }
192
193 private:
194 bool deterministic_;
195 int64_t seed_;
196 int64_t seed2_;
197 std::vector<float> pooling_ratio_;
198 bool pseudo_random_;
199 bool overlapping_;
200};
201
202#define REGISTER_FRACTIONALMAXPOOL(type) \
203 REGISTER_KERNEL_BUILDER( \
204 Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
205 FractionalMaxPoolOp<type>)
206
207REGISTER_FRACTIONALMAXPOOL(int32);
208REGISTER_FRACTIONALMAXPOOL(int64_t);
209REGISTER_FRACTIONALMAXPOOL(float);
210REGISTER_FRACTIONALMAXPOOL(double);
211
212#undef REGISTER_FRACTIONALMAXPOOL
213
// Sentinel arg-max value meaning "no input element has been assigned yet".
constexpr int kInvalidMaxPoolingIndex = -1;
215
216template <class T>
217class FractionalMaxPoolGradOp : public OpKernel {
218 public:
219 explicit FractionalMaxPoolGradOp(OpKernelConstruction* context)
220 : OpKernel(context) {
221 OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
222 }
223
224 void Compute(OpKernelContext* context) override {
225 // There are two steps when calculating gradient for FractionalMaxPool.
226 // 1) Walk through the process of calculating fractional pooling given
227 // pooling region; however, in the process, keep track of where the max
228 // element comes from. (arg_max)
229 // 2) Populate the value of out_backprop to where arg_max indicates. If
230 // we support overlapping, it is likely to have multiple out_backprop[i]
231 // propagates back to the same arg_max value.
232 typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
233 ConstEigenMatrixMap;
234 typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
235 EigenMatrixMap;
236 typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
237 EigenIndexMatrixMap;
238
239 const Tensor& tensor_in = context->input(0);
240 const Tensor& tensor_out = context->input(1);
241 const Tensor& out_backprop = context->input(2);
242 const Tensor& height_seq_tensor = context->input(3);
243 const Tensor& width_seq_tensor = context->input(4);
244
245 // Just to make it similar to FractionalMaxPoolOp.
246 constexpr int tensor_in_and_out_dims = 4;
247 OP_REQUIRES(
248 context, tensor_in.dims() == tensor_in_and_out_dims,
249 errors::InvalidArgument("orig_input should be a tensor of rank 4, got ",
250 tensor_in.DebugString()));
251 OP_REQUIRES(context, tensor_in.NumElements() > 0,
252 errors::InvalidArgument("orig_input must not be empty, got ",
253 tensor_in.DebugString()));
254 OP_REQUIRES(context, tensor_out.dims() == tensor_in_and_out_dims,
255 errors::InvalidArgument(
256 "orig_output should be a tensor of rank 4, got ",
257 tensor_out.DebugString()));
258 OP_REQUIRES(context, tensor_out.NumElements() > 0,
259 errors::InvalidArgument("orig_output must not be empty, got ",
260 tensor_out.DebugString()));
261 std::vector<int64_t> input_size(tensor_in_and_out_dims);
262 std::vector<int64_t> output_size(tensor_in_and_out_dims);
263 for (int i = 0; i < tensor_in_and_out_dims; ++i) {
264 input_size[i] = tensor_in.dim_size(i);
265 }
266 for (int i = 0; i < tensor_in_and_out_dims; ++i) {
267 output_size[i] = tensor_out.dim_size(i);
268 }
269
270 // ---------
271 // Step 1
272 // ---------
273 Tensor tensor_out_dup;
274 OP_REQUIRES_OK(context, context->forward_input_or_allocate_temp(
275 {1}, DataTypeToEnum<T>::v(), tensor_out.shape(),
276 &tensor_out_dup));
277 Tensor tensor_out_arg_max;
278 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64_t>::v(),
279 tensor_out.shape(),
280 &tensor_out_arg_max));
281 // Find arg_max for each tensor_out
282 ConstEigenMatrixMap tensor_in_mat(
283 tensor_in.flat<T>().data(), input_size[3],
284 input_size[2] * input_size[1] * input_size[0]);
285 EigenMatrixMap tensor_out_dup_mat(
286 tensor_out_dup.flat<T>().data(), output_size[3],
287 output_size[2] * output_size[1] * output_size[0]);
288 EigenIndexMatrixMap tensor_out_arg_max_mat(
289 tensor_out_arg_max.flat<int64_t>().data(), output_size[3],
290 output_size[2] * output_size[1] * output_size[0]);
291
292 tensor_out_arg_max.flat<int64_t>().setConstant(kInvalidMaxPoolingIndex);
293 // Initializes the duplicate output tensor with MIN<T>.
294 tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
295
296 auto height_seq_tensor_flat = height_seq_tensor.flat<int64_t>();
297 auto width_seq_tensor_flat = width_seq_tensor.flat<int64_t>();
298
299 // Now walk through the process of fractional max pooling again.
300 // For both input and output,
301 // 0: batch
302 // 1: height / row
303 // 2: width / col
304 // 3: depth / channel
305 const int64_t height_max = input_size[1] - 1;
306 const int64_t width_max = input_size[2] - 1;
307 for (int64_t b = 0; b < input_size[0]; ++b) {
308 // height sequence.
309 for (int64_t hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) {
310 // height start and end.
311 const int64_t height_start = height_seq_tensor_flat(hs);
312 int64_t height_end = overlapping_ ? height_seq_tensor_flat(hs + 1)
313 : height_seq_tensor_flat(hs + 1) - 1;
314 height_end = std::min(height_end, height_max);
315
316 // width sequence.
317 for (int64_t ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) {
318 const int64_t out_index =
319 (b * output_size[1] + hs) * output_size[2] + ws;
320 // width start and end.
321 const int64_t width_start = width_seq_tensor_flat(ws);
322 int64_t width_end = overlapping_ ? width_seq_tensor_flat(ws + 1)
323 : width_seq_tensor_flat(ws + 1) - 1;
324 width_end = std::min(width_end, width_max);
325 for (int64_t h = height_start; h <= height_end; ++h) {
326 for (int64_t w = width_start; w <= width_end; ++w) {
327 const int64_t in_index =
328 (b * input_size[1] + h) * input_size[2] + w;
329 // Walk through each channel (depth).
330 for (int64_t d = 0; d < input_size[3]; ++d) {
331 const T& input_ref = tensor_in_mat.coeffRef(d, in_index);
332 T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index);
333 int64_t& out_arg_max_ref =
334 tensor_out_arg_max_mat.coeffRef(d, out_index);
335 if (output_ref < input_ref ||
336 out_arg_max_ref == kInvalidMaxPoolingIndex) {
337 output_ref = input_ref;
338 int input_offset = in_index * input_size[3] + d;
339 out_arg_max_ref = input_offset;
340 }
341 }
342 }
343 }
344 }
345 }
346 }
347
348 // Check tensor_out_dup is the same as tensor_out.
349 ConstEigenMatrixMap tensor_out_mat(
350 tensor_out.flat<T>().data(), output_size[3],
351 output_size[2] * output_size[1] * output_size[0]);
352 const int64_t num_reshaped_cols =
353 output_size[2] * output_size[1] * output_size[0];
354 for (int64_t i = 0; i < num_reshaped_cols; ++i) {
355 for (int64_t j = 0; j < output_size[3]; ++j) {
356 OP_REQUIRES(context, tensor_out_dup_mat(j, i) == tensor_out_mat(j, i),
357 errors::InvalidArgument(
358 "tensor_out_dup is not the same as tensor_out"));
359 }
360 }
361
362 Tensor* output = nullptr;
363 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
364 {0}, 0, tensor_in.shape(), &output));
365 output->flat<T>().setZero();
366
367 auto out_backprop_flat = out_backprop.flat<T>();
368 auto input_backprop_flat = output->flat<T>();
369 auto out_arg_max_flat = tensor_out_arg_max.flat<int64_t>();
370 int num_total_outputs = out_backprop_flat.size();
371 int num_total_inputs = input_backprop_flat.size();
372
373 for (int index = 0; index < num_total_outputs; ++index) {
374 int input_backprop_index = out_arg_max_flat(index);
375 OP_REQUIRES(
376 context,
377 input_backprop_index >= 0 && input_backprop_index < num_total_inputs,
378 errors::InvalidArgument(
379 "Invalid input backprop index: ", input_backprop_index, ", ",
380 num_total_inputs));
381 input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
382 }
383 }
384
385 private:
386 bool overlapping_;
387};
388
389#define REGISTER_FRACTIONALMAXPOOLGRAD(type) \
390 REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad") \
391 .Device(DEVICE_CPU) \
392 .TypeConstraint<type>("T"), \
393 FractionalMaxPoolGradOp<type>)
394
395REGISTER_FRACTIONALMAXPOOLGRAD(int32);
396REGISTER_FRACTIONALMAXPOOLGRAD(int64_t);
397REGISTER_FRACTIONALMAXPOOLGRAD(float);
398REGISTER_FRACTIONALMAXPOOLGRAD(double);
399
400#undef REGISTER_FRACTIONALMAXPOOLGRAD
401} // namespace tensorflow
402