1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ |
18 | |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/lib/core/stringpiece.h" |
23 | #include "tensorflow/core/util/padding.h" |
24 | #include "tensorflow/core/util/tensor_format.h" |
25 | |
26 | namespace tensorflow { |
// Information about a single spatial dimension for a convolution
// backpropagation.
//
// Every field defaults to zero, so a default-constructed instance never holds
// indeterminate values (for example if it is read before the code computing
// the backprop dimensions has filled it in). The type is still an aggregate,
// so brace/aggregate initialization keeps working as before.
struct ConvBackpropSpatialDimension {
  int64_t input_size = 0;
  int64_t filter_size = 0;
  int64_t output_size = 0;
  int64_t stride = 0;
  int64_t dilation = 0;

  // Output size after scaling by the stride.
  int64_t expanded_output_size = 0;

  // Number of padding elements to be added before/after this dimension of
  // the input when computing Conv?DBackpropInput.
  int64_t pad_before = 0;
  int64_t pad_after = 0;
};
43 | |
44 | // Computed dimensions for a backwards convolution. |
45 | struct ConvBackpropDimensions { |
46 | // Information about each spatial dimension. |
47 | gtl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims; |
48 | |
49 | // Batch size. |
50 | int64_t batch_size; |
51 | |
52 | // Input and output feature depth. |
53 | int64_t in_depth, out_depth; |
54 | |
55 | // Convenience access methods for spatial dimensions properties. |
56 | int64_t input_size(int dim) const { return spatial_dims[dim].input_size; } |
57 | int64_t filter_size(int dim) const { return spatial_dims[dim].filter_size; } |
58 | int64_t output_size(int dim) const { return spatial_dims[dim].output_size; } |
59 | int64_t stride(int dim) const { return spatial_dims[dim].stride; } |
60 | int64_t dilation(int dim) const { return spatial_dims[dim].dilation; } |
61 | |
62 | // Compute padding for the given spatial dimension. |
63 | int SpatialPadding(const Padding& padding, int dim) const; |
64 | }; |
65 | |
66 | // Common code between implementations of Conv?DBackpropInput and |
67 | // Conv?DBackpropFilter. Verifies that the dimensions all match, and computes |
68 | // sizes/padding for the spatial dimensions. Does not support explicit padding. |
69 | Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, |
70 | const TensorShape& input_shape, |
71 | const TensorShape& filter_shape, |
72 | const TensorShape& out_backprop_shape, |
73 | const std::vector<int32>& strides, |
74 | Padding padding, TensorFormat data_format, |
75 | ConvBackpropDimensions* dims); |
76 | |
77 | // The V2 version computes the same outputs with arbitrary dilation rate and |
78 | // supports explicit padding. |
79 | // TODO(b/67112639): Merge V2 versions and the original versions eventually. |
80 | Status ConvBackpropComputeDimensionsV2( |
81 | StringPiece label, int num_spatial_dims, const TensorShape& input_shape, |
82 | const TensorShape& filter_shape, const TensorShape& out_backprop_shape, |
83 | const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides, |
84 | Padding padding, absl::Span<const int64_t> explicit_paddings, |
85 | TensorFormat data_format, ConvBackpropDimensions* dims); |
86 | |
87 | // Computes the shape of the in_backprop. |
88 | Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes, |
89 | const TensorShape& filter_shape, |
90 | const TensorShape& out_backprop_shape, |
91 | const TensorFormat& data_format, |
92 | TensorShape* input_shape); |
93 | } // namespace tensorflow |
94 | |
95 | #endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ |
96 | |