1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
#include "tensorflow/core/framework/kernel_shape_util.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>

#include "tensorflow/core/lib/core/errors.h"
18 | |
19 | namespace tensorflow { |
20 | Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, |
21 | int64_t dilation_rate, int64_t stride, |
22 | Padding padding_type, |
23 | int64_t* output_size, |
24 | int64_t* padding_before, |
25 | int64_t* padding_after) { |
26 | if (stride <= 0) { |
27 | return errors::InvalidArgument("Stride must be > 0, but got " , stride); |
28 | } |
29 | if (dilation_rate < 1) { |
30 | return errors::InvalidArgument("Dilation rate must be >= 1, but got " , |
31 | dilation_rate); |
32 | } |
33 | |
34 | // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2. |
35 | int64_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; |
36 | switch (padding_type) { |
37 | case Padding::VALID: |
38 | *output_size = (input_size - effective_filter_size + stride) / stride; |
39 | *padding_before = *padding_after = 0; |
40 | break; |
41 | case Padding::EXPLICIT: |
42 | *output_size = (input_size + *padding_before + *padding_after - |
43 | effective_filter_size + stride) / |
44 | stride; |
45 | break; |
46 | case Padding::SAME: |
47 | *output_size = (input_size + stride - 1) / stride; |
48 | const int64_t padding_needed = |
49 | std::max(int64_t{0}, (*output_size - 1) * stride + |
50 | effective_filter_size - input_size); |
51 | // For odd values of total padding, add more padding at the 'right' |
52 | // side of the given dimension. |
53 | *padding_before = padding_needed / 2; |
54 | *padding_after = padding_needed - *padding_before; |
55 | break; |
56 | } |
57 | if (*output_size < 0) { |
58 | return errors::InvalidArgument( |
59 | "Computed output size would be negative: " , *output_size, |
60 | " [input_size: " , input_size, |
61 | ", effective_filter_size: " , effective_filter_size, |
62 | ", stride: " , stride, "]" ); |
63 | } |
64 | return OkStatus(); |
65 | } |
66 | |
67 | Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, |
68 | int64_t stride, Padding padding_type, |
69 | int64_t* output_size, |
70 | int64_t* padding_before, |
71 | int64_t* padding_after) { |
72 | return GetWindowedOutputSizeVerboseV2(input_size, filter_size, |
73 | /*dilation_rate=*/1, stride, |
74 | padding_type, output_size, |
75 | padding_before, padding_after); |
76 | } |
77 | |
78 | Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, |
79 | int64_t stride, Padding padding_type, |
80 | int64_t* output_size, int64_t* padding_size) { |
81 | if (padding_type == Padding::EXPLICIT) { |
82 | return errors::Internal( |
83 | "GetWindowedOutputSize does not handle EXPLICIT padding; call " |
84 | "GetWindowedOutputSizeVerbose instead" ); |
85 | } |
86 | int64_t padding_after_unused; |
87 | return GetWindowedOutputSizeVerbose(input_size, filter_size, stride, |
88 | padding_type, output_size, padding_size, |
89 | &padding_after_unused); |
90 | } |
91 | |
92 | Status GetWindowedOutputSizeV2(int64_t input_size, int64_t filter_size, |
93 | int64_t dilation_rate, int64_t stride, |
94 | Padding padding_type, int64_t* output_size, |
95 | int64_t* padding_size) { |
96 | if (padding_type == Padding::EXPLICIT) { |
97 | return errors::Internal( |
98 | "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call " |
99 | "GetWindowedOutputSizeVerboseV2 instead" ); |
100 | } |
101 | int64_t padding_after_unused; |
102 | return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, |
103 | stride, padding_type, output_size, |
104 | padding_size, &padding_after_unused); |
105 | } |
106 | |
107 | Status Get3dOutputSize(const std::array<int64_t, 3>& input, |
108 | const std::array<int64_t, 3>& window, |
109 | const std::array<int64_t, 3>& strides, |
110 | Padding padding_type, std::array<int64_t, 3>* output_ptr, |
111 | std::array<int64_t, 3>* padding_ptr) { |
112 | for (size_t i = 0; i < input.size(); ++i) { |
113 | TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i], |
114 | padding_type, &(*output_ptr)[i], |
115 | &(*padding_ptr)[i])); |
116 | } |
117 | return OkStatus(); |
118 | } |
119 | |
120 | Status Get3dOutputSizeV2(const std::array<int64_t, 3>& input, |
121 | const std::array<int64_t, 3>& window, |
122 | const std::array<int64_t, 3>& dilations, |
123 | const std::array<int64_t, 3>& strides, |
124 | Padding padding_type, |
125 | std::array<int64_t, 3>* output_ptr, |
126 | std::array<int64_t, 3>* padding_ptr) { |
127 | for (size_t i = 0; i < input.size(); ++i) { |
128 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( |
129 | input[i], window[i], dilations[i], strides[i], padding_type, |
130 | &(*output_ptr)[i], &(*padding_ptr)[i])); |
131 | } |
132 | return OkStatus(); |
133 | } |
134 | } // namespace tensorflow |
135 | |