1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include "tensorflow/core/framework/kernel_shape_util.h"

#include <algorithm>
#include <limits>

#include "tensorflow/core/lib/core/errors.h"
18
19namespace tensorflow {
20Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size,
21 int64_t dilation_rate, int64_t stride,
22 Padding padding_type,
23 int64_t* output_size,
24 int64_t* padding_before,
25 int64_t* padding_after) {
26 if (stride <= 0) {
27 return errors::InvalidArgument("Stride must be > 0, but got ", stride);
28 }
29 if (dilation_rate < 1) {
30 return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
31 dilation_rate);
32 }
33
34 // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
35 int64_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
36 switch (padding_type) {
37 case Padding::VALID:
38 *output_size = (input_size - effective_filter_size + stride) / stride;
39 *padding_before = *padding_after = 0;
40 break;
41 case Padding::EXPLICIT:
42 *output_size = (input_size + *padding_before + *padding_after -
43 effective_filter_size + stride) /
44 stride;
45 break;
46 case Padding::SAME:
47 *output_size = (input_size + stride - 1) / stride;
48 const int64_t padding_needed =
49 std::max(int64_t{0}, (*output_size - 1) * stride +
50 effective_filter_size - input_size);
51 // For odd values of total padding, add more padding at the 'right'
52 // side of the given dimension.
53 *padding_before = padding_needed / 2;
54 *padding_after = padding_needed - *padding_before;
55 break;
56 }
57 if (*output_size < 0) {
58 return errors::InvalidArgument(
59 "Computed output size would be negative: ", *output_size,
60 " [input_size: ", input_size,
61 ", effective_filter_size: ", effective_filter_size,
62 ", stride: ", stride, "]");
63 }
64 return OkStatus();
65}
66
67Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size,
68 int64_t stride, Padding padding_type,
69 int64_t* output_size,
70 int64_t* padding_before,
71 int64_t* padding_after) {
72 return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
73 /*dilation_rate=*/1, stride,
74 padding_type, output_size,
75 padding_before, padding_after);
76}
77
78Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size,
79 int64_t stride, Padding padding_type,
80 int64_t* output_size, int64_t* padding_size) {
81 if (padding_type == Padding::EXPLICIT) {
82 return errors::Internal(
83 "GetWindowedOutputSize does not handle EXPLICIT padding; call "
84 "GetWindowedOutputSizeVerbose instead");
85 }
86 int64_t padding_after_unused;
87 return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
88 padding_type, output_size, padding_size,
89 &padding_after_unused);
90}
91
92Status GetWindowedOutputSizeV2(int64_t input_size, int64_t filter_size,
93 int64_t dilation_rate, int64_t stride,
94 Padding padding_type, int64_t* output_size,
95 int64_t* padding_size) {
96 if (padding_type == Padding::EXPLICIT) {
97 return errors::Internal(
98 "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call "
99 "GetWindowedOutputSizeVerboseV2 instead");
100 }
101 int64_t padding_after_unused;
102 return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
103 stride, padding_type, output_size,
104 padding_size, &padding_after_unused);
105}
106
107Status Get3dOutputSize(const std::array<int64_t, 3>& input,
108 const std::array<int64_t, 3>& window,
109 const std::array<int64_t, 3>& strides,
110 Padding padding_type, std::array<int64_t, 3>* output_ptr,
111 std::array<int64_t, 3>* padding_ptr) {
112 for (size_t i = 0; i < input.size(); ++i) {
113 TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
114 padding_type, &(*output_ptr)[i],
115 &(*padding_ptr)[i]));
116 }
117 return OkStatus();
118}
119
120Status Get3dOutputSizeV2(const std::array<int64_t, 3>& input,
121 const std::array<int64_t, 3>& window,
122 const std::array<int64_t, 3>& dilations,
123 const std::array<int64_t, 3>& strides,
124 Padding padding_type,
125 std::array<int64_t, 3>* output_ptr,
126 std::array<int64_t, 3>* padding_ptr) {
127 for (size_t i = 0; i < input.size(); ++i) {
128 TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
129 input[i], window[i], dilations[i], strides[i], padding_type,
130 &(*output_ptr)[i], &(*padding_ptr)[i]));
131 }
132 return OkStatus();
133}
134} // namespace tensorflow
135