/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/dilation_ops.h"

#include <cfloat>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

void ParseAttributes(OpKernelConstruction* context, std::vector<int32>* strides,
                     std::vector<int32>* rates, Padding* padding) {
  OP_REQUIRES_OK(context, context->GetAttr("strides", strides));
  OP_REQUIRES(context, strides->size() == 4,
              errors::InvalidArgument("Sliding window stride field must "
                                      "specify 4 dimensions"));
  OP_REQUIRES(context, (*strides)[0] == 1 && (*strides)[3] == 1,
              errors::Unimplemented(
                  "Stride is only supported across spatial dimensions."));

  OP_REQUIRES_OK(context, context->GetAttr("rates", rates));
  OP_REQUIRES(context, rates->size() == 4,
              errors::InvalidArgument("Input stride (atrous rate) field "
                                      "must specify 4 dimensions"));
  OP_REQUIRES(context, (*rates)[0] == 1 && (*rates)[3] == 1,
              errors::Unimplemented(
                  "Rate is only supported across spatial dimensions."));

  OP_REQUIRES_OK(context, context->GetAttr("padding", padding));
}

void ParseSizes(OpKernelContext* context, const std::vector<int32>& strides,
                const std::vector<int32>& rates, const Padding& padding,
                int* stride_rows, int* stride_cols, int* rate_rows,
                int* rate_cols, int64_t* pad_top, int64_t* pad_left,
                int64_t* out_rows, int64_t* out_cols) {
  // Input tensor is of the following dimensions:
  // [ batch, input_rows, input_cols, depth ]
  const Tensor& input = context->input(0);
  OP_REQUIRES(context, input.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional: ",
                                      input.shape().DebugString()));
  const int input_rows = input.dim_size(1);
  const int input_cols = input.dim_size(2);
  const int depth = input.dim_size(3);

  // For now we take the stride and rate from the second and third dimensions
  // only (we do not support striding on the batch or depth dimension).
  *stride_rows = strides[1];
  *stride_cols = strides[2];
  *rate_rows = rates[1];
  *rate_cols = rates[2];

  // Input filter is of the following dimensions:
  // [ filter_rows, filter_cols, depth ]
  const Tensor& filter = context->input(1);
  OP_REQUIRES(context, filter.dims() == 3,
              errors::InvalidArgument("filter must be 3-dimensional: ",
                                      filter.shape().DebugString()));
  const int filter_rows = filter.dim_size(0);
  const int filter_cols = filter.dim_size(1);
  OP_REQUIRES(context, depth == filter.dim_size(2),
              errors::InvalidArgument(
                  "input and filter must have the same depth: ", depth, " vs ",
                  filter.dim_size(2)));

  // Effective filter size, after introducing rate - 1 zeros between each
  // non-zero filter element.
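  // For example, a 3x3 filter with rate 2 has an effective extent of
  // 3 + (3 - 1) * (2 - 1) = 5, i.e. it acts on a 5x5 window of the input.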
  const int filter_rows_eff =
      filter_rows + (filter_rows - 1) * (*rate_rows - 1);
  const int filter_cols_eff =
      filter_cols + (filter_cols - 1) * (*rate_cols - 1);

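  // GetWindowedOutputSize follows the usual TensorFlow convention:
  //   VALID: out = ceil((in - filter_eff + 1) / stride), with zero padding;
  //   SAME:  out = ceil(in / stride), with the total padding split roughly
  //          evenly (pad_top / pad_left receive the smaller half).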
  OP_REQUIRES_OK(
      context, GetWindowedOutputSize(input_rows, filter_rows_eff, *stride_rows,
                                     padding, out_rows, pad_top));
  OP_REQUIRES_OK(
      context, GetWindowedOutputSize(input_cols, filter_cols_eff, *stride_cols,
                                     padding, out_cols, pad_left));
}

template <typename Device, typename T>
class DilationOp : public OpKernel {
 public:
  explicit DilationOp(OpKernelConstruction* context) : OpKernel(context) {
    ParseAttributes(context, &strides_, &rates_, &padding_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);

    // Determine relevant sizes from input and filters.
    int stride_rows = 0, stride_cols = 0;
    int rate_rows = 0, rate_cols = 0;
    int64_t pad_top = 0, pad_left = 0;
    int64_t out_rows = 0, out_cols = 0;
    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
               &out_cols);
    if (!context->status().ok()) return;

    // Output tensor is of the following dimensions:
    // [ batch, out_rows, out_cols, depth ]
    const int batch = input.dim_size(0);
    const int depth = input.dim_size(3);
    const std::vector<int64_t> out_sizes = {batch, out_rows, out_cols, depth};
    TensorShape out_shape(out_sizes);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    functor::Dilation<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(),
        filter.tensor<T, 3>(), stride_rows, stride_cols, rate_rows, rate_cols,
        pad_top, pad_left, output->tensor<T, 4>());
  }

  std::vector<int32> strides_;
  std::vector<int32> rates_;
  Padding padding_;
};

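// For orientation, a minimal sketch of how these kernels are typically
// reached from Python (assuming the tf.nn.dilation2d wrapper; see
// ../ops/nn_ops.cc for the authoritative op definition):
//   y = tf.nn.dilation2d(x, filters, strides=[1, 1, 1, 1], padding="SAME",
//                        data_format="NHWC", dilations=[1, 1, 1, 1])
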
// Partial specialization of Dilation functor for a CPUDevice.
namespace functor {
template <typename T>
struct Dilation<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T, 3>::ConstTensor filter, int stride_rows,
                  int stride_cols, int rate_rows, int rate_cols, int pad_top,
                  int pad_left, typename TTypes<T, 4>::Tensor output) {
    const int batch = input.dimension(0);
    const int input_rows = input.dimension(1);
    const int input_cols = input.dimension(2);
    const int depth = input.dimension(3);

    const int filter_rows = filter.dimension(0);
    const int filter_cols = filter.dimension(1);

    const int output_rows = output.dimension(1);
    const int output_cols = output.dimension(2);

    // This is a reference implementation, likely to be slow.
    // TODO(gpapan): Write multi-threaded implementation.
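    // Grayscale morphological dilation in max-plus form:
    //   output(b, h_out, w_out, d) =
    //       max_{h, w} { input(b, h_out * stride_rows + h * rate_rows - pad_top,
    //                          w_out * stride_cols + w * rate_cols - pad_left, d)
    //                    + filter(h, w, d) },
    // where the max runs only over filter taps whose input coordinates fall
    // inside the image (taps landing in the padding are skipped).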
    for (int b = 0; b < batch; ++b) {
      for (int h_out = 0; h_out < output_rows; ++h_out) {
        int h_beg = h_out * stride_rows - pad_top;
        for (int w_out = 0; w_out < output_cols; ++w_out) {
          int w_beg = w_out * stride_cols - pad_left;
          for (int d = 0; d < depth; ++d) {
            T cur_val = Eigen::NumTraits<T>::lowest();
            for (int h = 0; h < filter_rows; ++h) {
              const int h_in = h_beg + h * rate_rows;
              if (h_in >= 0 && h_in < input_rows) {
                for (int w = 0; w < filter_cols; ++w) {
                  const int w_in = w_beg + w * rate_cols;
                  if (w_in >= 0 && w_in < input_cols) {
                    const T val = input(b, h_in, w_in, d) + filter(h, w, d);
                    if (val > cur_val) {
                      cur_val = val;
                    }
                  }
                }
              }
            }
            output(b, h_out, w_out, d) = cur_val;
          }
        }
      }
    }
  }
};
}  // namespace functor

template <typename Device, typename T>
class DilationBackpropInputOp : public OpKernel {
 public:
  explicit DilationBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    ParseAttributes(context, &strides_, &rates_, &padding_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    if (std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(context, !tensorflow::OpDeterminismRequired(),
                  errors::Unimplemented("Determinism is not yet supported "
                                        "for Dilation2DBackpropInput."));
    }
    // Determine relevant sizes from input and filters.
    int stride_rows = 0, stride_cols = 0;
    int rate_rows = 0, rate_cols = 0;
    int64_t pad_top = 0, pad_left = 0;
    int64_t out_rows = 0, out_cols = 0;
    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
               &out_cols);
    if (!context->status().ok()) return;

    // Verify that the incoming gradient tensor has the expected size
    // [ batch, out_rows, out_cols, depth ]
    const int batch = input.dim_size(0);
    const int depth = input.dim_size(3);
    OP_REQUIRES(context,
                batch == out_backprop.dim_size(0) &&
                    out_rows == out_backprop.dim_size(1) &&
                    out_cols == out_backprop.dim_size(2) &&
                    depth == out_backprop.dim_size(3),
                errors::InvalidArgument("out_backprop has incompatible size."));

    // The computed in_backprop has the same dimensions as the input:
    // [ batch, input_rows, input_cols, depth ]
    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &in_backprop));

    // If there is nothing to compute, return.
    if (input.shape().num_elements() == 0) {
      return;
    }

    functor::DilationBackpropInput<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(),
        filter.tensor<T, 3>(), out_backprop.tensor<T, 4>(), stride_rows,
        stride_cols, rate_rows, rate_cols, pad_top, pad_left,
        in_backprop->tensor<T, 4>());
  }

  std::vector<int32> strides_;
  std::vector<int32> rates_;
  Padding padding_;
};

// Partial specialization of DilationBackpropInput functor for a CPUDevice.
namespace functor {
template <typename T>
struct DilationBackpropInput<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T, 3>::ConstTensor filter,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  int stride_rows, int stride_cols, int rate_rows,
                  int rate_cols, int pad_top, int pad_left,
                  typename TTypes<T, 4>::Tensor in_backprop) {
    const int batch = input.dimension(0);
    const int input_rows = input.dimension(1);
    const int input_cols = input.dimension(2);
    const int depth = input.dimension(3);

    const int filter_rows = filter.dimension(0);
    const int filter_cols = filter.dimension(1);

    const int output_rows = out_backprop.dimension(1);
    const int output_cols = out_backprop.dimension(2);

    // Initialize gradient with all zeros.
    in_backprop.setZero();

    // This is a reference implementation, likely to be slow.
    // TODO(gpapan): Write multi-threaded implementation.
    // In the case of multiple argmax branches, we only back-propagate along
    // the first branch encountered (the comparison below is strict), i.e., the
    // one with the smallest value of `h * filter_cols + w`, similarly to the
    // max-pooling backward routines.
    for (int b = 0; b < batch; ++b) {
      for (int h_out = 0; h_out < output_rows; ++h_out) {
        int h_beg = h_out * stride_rows - pad_top;
        for (int w_out = 0; w_out < output_cols; ++w_out) {
          int w_beg = w_out * stride_cols - pad_left;
          for (int d = 0; d < depth; ++d) {
            T cur_val = Eigen::NumTraits<T>::lowest();
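            // Default the argmax to the (clamped) window origin so the
            // scatter into in_backprop below stays in bounds even if no
            // filter tap lands inside the input.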
            int h_in_max = (h_beg < 0) ? 0 : h_beg;
            int w_in_max = (w_beg < 0) ? 0 : w_beg;
            for (int h = 0; h < filter_rows; ++h) {
              const int h_in = h_beg + h * rate_rows;
              if (h_in >= 0 && h_in < input_rows) {
                for (int w = 0; w < filter_cols; ++w) {
                  const int w_in = w_beg + w * rate_cols;
                  if (w_in >= 0 && w_in < input_cols) {
                    const T val = input(b, h_in, w_in, d) + filter(h, w, d);
                    if (val > cur_val) {
                      cur_val = val;
                      h_in_max = h_in;
                      w_in_max = w_in;
                    }
                  }
                }
              }
            }
            if (h_in_max < input_rows && w_in_max < input_cols) {
              in_backprop(b, h_in_max, w_in_max, d) +=
                  out_backprop(b, h_out, w_out, d);
            }
          }
        }
      }
    }
  }
};
}  // namespace functor

template <typename Device, typename T>
class DilationBackpropFilterOp : public OpKernel {
 public:
  explicit DilationBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    ParseAttributes(context, &strides_, &rates_, &padding_);
  }

  void Compute(OpKernelContext* context) override {
    if (std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(context, !tensorflow::OpDeterminismRequired(),
                  errors::Unimplemented("Determinism is not yet supported "
                                        "for Dilation2DBackpropFilter."));
    }
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    // Determine relevant sizes from input and filters.
    int stride_rows = 0, stride_cols = 0;
    int rate_rows = 0, rate_cols = 0;
    int64_t pad_top = 0, pad_left = 0;
    int64_t out_rows = 0, out_cols = 0;
    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
               &out_cols);
    if (!context->status().ok()) return;

    // Verify that the incoming gradient tensor has the expected size
    // [ batch, out_rows, out_cols, depth ]
    const int batch = input.dim_size(0);
    const int depth = input.dim_size(3);
    OP_REQUIRES(context,
                batch == out_backprop.dim_size(0) &&
                    out_rows == out_backprop.dim_size(1) &&
                    out_cols == out_backprop.dim_size(2) &&
                    depth == out_backprop.dim_size(3),
                errors::InvalidArgument("out_backprop has incompatible size."));

    // The computed filter_backprop has the same dimensions as the filter:
    // [ filter_rows, filter_cols, depth ]
    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, filter.shape(), &filter_backprop));

    // If there is nothing to compute, return.
    if (filter.shape().num_elements() == 0) {
      return;
    }

    functor::DilationBackpropFilter<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(),
        filter.tensor<T, 3>(), out_backprop.tensor<T, 4>(), stride_rows,
        stride_cols, rate_rows, rate_cols, pad_top, pad_left,
        filter_backprop->tensor<T, 3>());
  }

  std::vector<int32> strides_;
  std::vector<int32> rates_;
  Padding padding_;
};

// Partial specialization of DilationBackpropFilter functor for a CPUDevice.
namespace functor {
template <typename T>
struct DilationBackpropFilter<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T, 3>::ConstTensor filter,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  int stride_rows, int stride_cols, int rate_rows,
                  int rate_cols, int pad_top, int pad_left,
                  typename TTypes<T, 3>::Tensor filter_backprop) {
    const int batch = input.dimension(0);
    const int input_rows = input.dimension(1);
    const int input_cols = input.dimension(2);
    const int depth = input.dimension(3);

    const int filter_rows = filter.dimension(0);
    const int filter_cols = filter.dimension(1);

    const int output_rows = out_backprop.dimension(1);
    const int output_cols = out_backprop.dimension(2);

    // Initialize gradient with all zeros.
    filter_backprop.setZero();

    // This is a reference implementation, likely to be slow.
    // TODO(gpapan): Write multi-threaded implementation.
    // In the case of multiple argmax branches, we only back-propagate along
    // the first branch encountered (the comparison below is strict), i.e., the
    // one with the smallest value of `h * filter_cols + w`, similarly to the
    // max-pooling backward routines.
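    // The filter gradient accumulates out_backprop(b, h_out, w_out, d) at the
    // filter tap (h_max, w_max) that attained the maximum, summed over all
    // batches and output positions.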
    for (int b = 0; b < batch; ++b) {
      for (int h_out = 0; h_out < output_rows; ++h_out) {
        int h_beg = h_out * stride_rows - pad_top;
        for (int w_out = 0; w_out < output_cols; ++w_out) {
          int w_beg = w_out * stride_cols - pad_left;
          for (int d = 0; d < depth; ++d) {
            T cur_val = Eigen::NumTraits<T>::lowest();
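            // Default the argmax tap to (0, 0) so the accumulation into
            // filter_backprop below always stays within the filter extent,
            // even if no tap lands inside the input.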
            int h_max = 0;
            int w_max = 0;
            for (int h = 0; h < filter_rows; ++h) {
              const int h_in = h_beg + h * rate_rows;
              if (h_in >= 0 && h_in < input_rows) {
                for (int w = 0; w < filter_cols; ++w) {
                  const int w_in = w_beg + w * rate_cols;
                  if (w_in >= 0 && w_in < input_cols) {
                    const T val = input(b, h_in, w_in, d) + filter(h, w, d);
                    if (val > cur_val) {
                      cur_val = val;
                      h_max = h;
                      w_max = w;
                    }
                  }
                }
              }
            }
            if (h_max < filter_rows && w_max < filter_cols) {
              filter_backprop(h_max, w_max, d) +=
                  out_backprop(b, h_out, w_out, d);
            }
          }
        }
      }
    }
  }
};
}  // namespace functor

#define REGISTER(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("Dilation2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      DilationOp<CPUDevice, T>);                                      \
                                                                      \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput")             \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T"),                \
                          DilationBackpropInputOp<CPUDevice, T>);     \
                                                                      \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter")            \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T"),                \
                          DilationBackpropFilterOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("Dilation2D").Device(DEVICE_GPU).TypeConstraint<T>("T"),   \
      DilationOp<GPUDevice, T>);                                      \
                                                                      \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput")             \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<T>("T"),                \
                          DilationBackpropInputOp<GPUDevice, T>);     \
                                                                      \
  REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter")            \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<T>("T"),                \
                          DilationBackpropFilterOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow