1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/reduction_ops_common.h" |
17 | |
18 | #include "tensorflow/core/lib/strings/str_util.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | TensorShape ReductionHelper::out_reshape() const { |
23 | TensorShape shape; |
24 | for (auto size : out_reshape_) shape.AddDim(size); |
25 | return shape; |
26 | } |
27 | |
28 | // The final output shape must be allocated with this shape. |
29 | TensorShape ReductionHelper::out_shape() const { |
30 | TensorShape shape; |
31 | for (auto size : out_shape_) shape.AddDim(size); |
32 | return shape; |
33 | } |
34 | |
35 | TensorShape ReductionHelper::shuffled_shape() { |
36 | const int dims = data_reshape_.size(); |
37 | TensorShape shape; |
38 | for (int i = reduce_first_axis_; i < dims; i += 2) { |
39 | shape.AddDim(data_reshape_[i]); |
40 | } |
41 | for (int i = !reduce_first_axis_; i < dims; i += 2) { |
42 | shape.AddDim(data_reshape_[i]); |
43 | } |
44 | return shape; |
45 | } |
46 | |
47 | gtl::InlinedVector<int32, 8> ReductionHelper::permutation() { |
48 | const int dims = data_reshape_.size(); |
49 | const int unreduced_dims = (dims + !reduce_first_axis_) / 2; |
50 | gtl::InlinedVector<int32, 8> perm(dims); |
51 | for (int i = 0; i < unreduced_dims; i++) { |
52 | perm[i] = 2 * i + reduce_first_axis_; |
53 | } |
54 | for (int i = unreduced_dims; i < dims; i++) { |
55 | perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_; |
56 | } |
57 | return perm; |
58 | } |
59 | |
60 | template <typename Tperm> |
61 | Status SimplifyHelper(const Tensor& data, const Tensor& axis, |
62 | gtl::InlinedVector<bool, 4>& bitmap) { |
63 | auto axis_vec = axis.flat<Tperm>(); |
64 | for (int64_t i = 0; i < axis.NumElements(); ++i) { |
65 | Tperm index = axis_vec(i); |
66 | if (index < -data.dims() || index >= data.dims()) { |
67 | return errors::InvalidArgument("Invalid reduction dimension (" , index, |
68 | " for input with " , data.dims(), |
69 | " dimension(s)" ); |
70 | } |
71 | index = (index + data.dims()) % data.dims(); |
72 | if (bitmap[index]) { |
73 | return errors::InvalidArgument( |
74 | "Invalid reduction arguments: Axes contains duplicate dimension: " , |
75 | index); |
76 | } |
77 | bitmap[index] = true; |
78 | } |
79 | return OkStatus(); |
80 | } |
81 | |
82 | Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, |
83 | const bool keep_dims) { |
84 | // bitmap[i] indicates whether to reduce data along i-th axis. |
85 | gtl::InlinedVector<bool, 4> bitmap(data.dims(), false); |
86 | if (axis.dtype() == DT_INT32) { |
87 | TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap)); |
88 | } else { |
89 | TF_RETURN_IF_ERROR(SimplifyHelper<int64_t>(data, axis, bitmap)); |
90 | } |
91 | // Output tensor's dim sizes. |
92 | out_shape_.clear(); |
93 | for (int i = 0; i < data.dims(); ++i) { |
94 | if (!bitmap[i]) { |
95 | // If we are not reducing along dimension i. |
96 | out_shape_.push_back(data.dim_size(i)); |
97 | } else if (keep_dims) { |
98 | // We are reducing along dimension i, but we want to keep the |
99 | // same number of dimensions, so we set the dimension of i to |
100 | // '1'. |
101 | out_shape_.push_back(1); |
102 | } |
103 | } |
104 | |
105 | // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of |
106 | // the input data before doing the reduction on the resulting |
107 | // tensor. The shape of the reduction is a reshape of the final |
108 | // output. |
109 | |
110 | // We'll skip the leading 1s. |
111 | int dim_index = 0; |
112 | for (; dim_index < data.dims(); ++dim_index) { |
113 | if (data.dim_size(dim_index) != 1) break; |
114 | } |
115 | if (dim_index >= data.dims()) { |
116 | // Special case. The input is essentially a scalar. |
117 | reduce_first_axis_ = true; |
118 | } else { |
119 | // Starting from the (dim_index)-th dimension, dimensions |
120 | // alternates between runs that need to be reduced and runs that |
121 | // don't. |
122 | // |
123 | // NOTE: If a dimension has size 1, we group it as the current |
124 | // run so that we can minimize the number of runs. |
125 | // |
126 | // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1, |
127 | // 5] by axes = [1, 4], we should treat the tensor as a [6, 5] |
128 | // and reduce by axes = [1] (i.e., the output is shape [6]). |
129 | reduce_first_axis_ = bitmap[dim_index]; |
130 | data_reshape_.push_back(data.dim_size(dim_index)); |
131 | ++dim_index; |
132 | for (; dim_index < data.dims(); ++dim_index) { |
133 | const auto size = data.dim_size(dim_index); |
134 | if (size == 1) { |
135 | bitmap[dim_index] = bitmap[dim_index - 1]; |
136 | } |
137 | if (bitmap[dim_index - 1] != bitmap[dim_index]) { |
138 | // Starts a new run of reduce or !reduce. |
139 | data_reshape_.push_back(size); |
140 | } else { |
141 | // Continue a run of reduce or !reduce. |
142 | data_reshape_.back() *= size; |
143 | } |
144 | } |
145 | // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc |
146 | // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_, |
147 | // otherwise, data_reshape_[0, 2, 4, ...] is. |
148 | for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size(); |
149 | i += 2) { |
150 | out_reshape_.push_back(data_reshape_[i]); |
151 | } |
152 | } |
153 | |
154 | VLOG(1) << "data reshape: " << absl::StrJoin(data_reshape_, "," ); |
155 | VLOG(1) << "out reshape: " << absl::StrJoin(out_reshape_, "," ); |
156 | VLOG(1) << "out shape: " << absl::StrJoin(out_shape_, "," ); |
157 | return OkStatus(); |
158 | } |
159 | |
160 | } // namespace tensorflow |
161 | |