1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/tensor_slice.h"
17
18#include <limits>
19#include <vector>
20
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/strings/numbers.h"
23#include "tensorflow/core/lib/strings/str_util.h"
24#include "tensorflow/core/lib/strings/strcat.h"
25#include "tensorflow/core/platform/logging.h"
26
27namespace tensorflow {
28
29TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
30
31TensorSlice::TensorSlice(const TensorSliceProto& proto) {
32 starts_.reserve(proto.extent_size());
33 lengths_.reserve(proto.extent_size());
34 for (const auto& e : proto.extent()) {
35 starts_.push_back(e.start());
36 lengths_.push_back(GetExtentLength(e));
37 }
38}
39
40TensorSlice::TensorSlice(
41 std::initializer_list<std::pair<int64_t, int64_t>> extents) {
42 starts_.reserve(extents.size());
43 lengths_.reserve(extents.size());
44 for (const auto& e : extents) {
45 starts_.push_back(e.first);
46 lengths_.push_back(e.second);
47 }
48}
49
50Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto,
51 TensorSlice* output) {
52 output->Clear();
53 output->starts_.reserve(proto.extent_size());
54 output->lengths_.reserve(proto.extent_size());
55 for (const auto& e : proto.extent()) {
56 int64_t l = GetExtentLength(e);
57 if (e.start() != 0 || l != kFullExtent) {
58 if (e.start() < 0 || l <= 0) {
59 return errors::InvalidArgument(
60 "Expected non-negative start and positive length but got start = ",
61 e.start(), ", length = ", l, ": extent = ", e.ShortDebugString());
62 }
63 // Calculating the extent end must not cause signed integer overflow.
64 if (static_cast<uint64_t>(e.start()) + static_cast<uint64_t>(e.length()) >
65 std::numeric_limits<int64_t>::max()) {
66 return errors::InvalidArgument(
67 "Extent end exceeds the maximum possible size: extent = ",
68 e.ShortDebugString());
69 }
70 }
71 output->starts_.push_back(e.start());
72 output->lengths_.push_back(l);
73 }
74
75 return OkStatus();
76}
77
78Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
79 std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
80 slice->starts_.reserve(items.size());
81 slice->lengths_.reserve(items.size());
82 for (const string& x : items) {
83 int64_t s, l;
84 if (x == "-") {
85 // "everything"
86 s = 0;
87 l = kFullExtent;
88 } else {
89 std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
90 if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
91 !strings::safe_strto64(sl[1], &l)) {
92 return errors::InvalidArgument(
93 "Expected a pair of numbers or '-' "
94 "but got '",
95 x, "': string = ", str);
96 }
97 if (s < 0 || l <= 0) {
98 return errors::InvalidArgument(
99 "Expected non-negative start and "
100 "positive length but got start = ",
101 s, ", length = ", l, ": string = ", str);
102 }
103 }
104 slice->starts_.push_back(s);
105 slice->lengths_.push_back(l);
106 }
107
108 return OkStatus();
109}
110
111void TensorSlice::Clear() {
112 starts_.clear();
113 lengths_.clear();
114}
115
116bool TensorSlice::IsFull() const {
117 for (int d = 0; d < dims(); ++d) {
118 if (!IsFullAt(d)) return false;
119 }
120 return true;
121}
122
123void TensorSlice::SetFullSlice(int dim) {
124 Clear();
125 starts_.reserve(dim);
126 lengths_.reserve(dim);
127 for (int d = 0; d < dim; ++d) {
128 starts_.push_back(0);
129 lengths_.push_back(kFullExtent);
130 }
131}
132
133void TensorSlice::Extend(int dim) {
134 int old_dim = dims();
135 DCHECK_LE(old_dim, dim);
136 starts_.resize(dim);
137 lengths_.resize(dim);
138 for (int d = old_dim; d < dim; ++d) {
139 starts_[d] = 0;
140 lengths_[d] = kFullExtent;
141 }
142}
143
144void TensorSlice::AsProto(TensorSliceProto* proto) const {
145 for (int d = 0; d < dims(); ++d) {
146 TensorSliceProto::Extent* e = proto->add_extent();
147 // We only need to record the explicit slice for non-full slices
148 if (!IsFullAt(d)) {
149 e->set_start(starts_[d]);
150 e->set_length(lengths_[d]);
151 }
152 }
153}
154
155string TensorSlice::DebugString() const {
156 string buffer;
157 bool first = true;
158 for (int d = 0; d < dims(); ++d) {
159 if (!first) {
160 buffer.append(":");
161 }
162 if (IsFullAt(d)) {
163 buffer.append("-");
164 } else {
165 strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
166 }
167 first = false;
168 }
169 return buffer;
170}
171
172bool TensorSlice::Intersect(const TensorSlice& other,
173 TensorSlice* result) const {
174 // First, if two slices have different ranks, they obviously don't overlap
175 // -- in fact they are not compatible.
176 if (dims() != other.dims()) {
177 return false;
178 }
179
180 // Setting the result to the right dimension
181 if (result) {
182 result->SetFullSlice(dims());
183 }
184 // The two slices overlap if they overlap in all dimensions.
185 for (int d = 0; d < dims(); ++d) {
186 if (IsFullAt(d)) {
187 if (result) {
188 result->set_start(d, other.start(d));
189 result->set_length(d, other.length(d));
190 }
191 } else if (other.IsFullAt(d)) {
192 if (result) {
193 result->set_start(d, start(d));
194 result->set_length(d, length(d));
195 }
196 } else {
197 // If we have an intersection here, it should have a start that is the
198 // max of the two starts and an end that is the min of the two ends.
199 int64_t s = std::max(start(d), other.start(d));
200 int64_t l = std::min(end(d), other.end(d)) - s;
201 if (l > 0) {
202 // We have a real intersection
203 if (result) {
204 result->set_start(d, s);
205 result->set_length(d, l);
206 }
207 } else {
208 // We don't have an intersection for this dimension -- thus we don't
209 // have any intersection at all.
210 if (result) {
211 result->Clear();
212 }
213 return false;
214 }
215 }
216 }
217 // If we are here, we know there is overlap in every dimension.
218 return true;
219}
220
221bool TensorSlice::operator==(const TensorSlice& other) const {
222 return dims() == other.dims() && starts_ == other.starts_ &&
223 lengths_ == other.lengths_;
224}
225
226void TensorSlice::ComputeRelative(const TensorSlice& sub,
227 TensorSlice* relative) const {
228 DCHECK_EQ(dims(), sub.dims());
229 relative->SetFullSlice(dims());
230 for (int d = 0; d < dims(); ++d) {
231 if (IsFullAt(d)) {
232 relative->set_start(d, sub.start(d));
233 relative->set_length(d, sub.length(d));
234 } else {
235 // Otherwise the relative start is the difference between the start of
236 // sub and the start of base
237 relative->set_start(d, sub.start(d) - start(d));
238 relative->set_length(d, sub.length(d));
239 }
240 }
241}
242
243void TensorSlice::UpdateToCover(const TensorSlice& other) {
244 DCHECK_EQ(dims(), other.dims());
245 for (int d = 0; d < dims(); ++d) {
246 if (!IsFullAt(d)) {
247 if (other.IsFullAt(d)) {
248 starts_[d] = 0;
249 lengths_[d] = kFullExtent;
250 } else {
251 const auto new_end = std::max(end(d), other.end(d));
252 set_start(d, std::min(start(d), other.start(d)));
253 set_length(d, new_end - start(d));
254 }
255 }
256 }
257}
258
259// static
260bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
261 return extent.has_length_case() == TensorSliceProto::Extent::kLength;
262}
263
264// static
265int64_t TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
266 if (!HasExtentLength(extent)) return -1;
267 return extent.length();
268}
269
270Status TensorSlice::SliceTensorShape(const TensorShape& shape,
271 TensorShape* result_shape) const {
272 result_shape->Clear();
273 // Mismatching ranks: we can't apply the slice at all.
274 if (shape.dims() != dims()) {
275 return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
276 ", slice = ", DebugString());
277 }
278 for (int d = 0; d < dims(); ++d) {
279 if (IsFullAt(d)) {
280 result_shape->AddDim(shape.dim_size(d));
281 } else {
282 // Check if the extent applies to the dimension
283 if (end(d) <= shape.dim_size(d)) {
284 // Yes: the end is within the range of the dim -- we adjust the result
285 // shape so that its size along this dimension is the length of the
286 // slice.
287 result_shape->AddDim(length(d));
288 } else {
289 // The extent doesn't apply to the dimension
290 result_shape->Clear();
291 return errors::Internal("Extent in dimension ", d,
292 " out of bounds: shape = ", shape.DebugString(),
293 ", slice = ", DebugString());
294 }
295 }
296 }
297 // If we are here, we have successfully applied the shape.
298 return OkStatus();
299}
300
// Sentinel length marking a dimension as full ("take everything"). It matches
// the -1 that GetExtentLength() returns for an extent with no length set.
const int64_t TensorSlice::kFullExtent = -1;
302
303} // namespace tensorflow
304