1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/tensor_slice.h" |
17 | |
18 | #include <limits> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/strings/numbers.h" |
23 | #include "tensorflow/core/lib/strings/str_util.h" |
24 | #include "tensorflow/core/lib/strings/strcat.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | |
27 | namespace tensorflow { |
28 | |
// Creates a slice of rank `dim` that covers every dimension fully.
TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
30 | |
31 | TensorSlice::TensorSlice(const TensorSliceProto& proto) { |
32 | starts_.reserve(proto.extent_size()); |
33 | lengths_.reserve(proto.extent_size()); |
34 | for (const auto& e : proto.extent()) { |
35 | starts_.push_back(e.start()); |
36 | lengths_.push_back(GetExtentLength(e)); |
37 | } |
38 | } |
39 | |
40 | TensorSlice::TensorSlice( |
41 | std::initializer_list<std::pair<int64_t, int64_t>> extents) { |
42 | starts_.reserve(extents.size()); |
43 | lengths_.reserve(extents.size()); |
44 | for (const auto& e : extents) { |
45 | starts_.push_back(e.first); |
46 | lengths_.push_back(e.second); |
47 | } |
48 | } |
49 | |
// Like the proto constructor, but validates every extent and reports bad
// input through a Status instead of silently accepting it. On success,
// `*output` holds one (start, length) pair per proto extent.
Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto,
                                     TensorSlice* output) {
  output->Clear();
  output->starts_.reserve(proto.extent_size());
  output->lengths_.reserve(proto.extent_size());
  for (const auto& e : proto.extent()) {
    // l is -1 (kFullExtent) when the extent has no explicit length.
    int64_t l = GetExtentLength(e);
    // A (start == 0, length == kFullExtent) pair encodes a full dimension
    // and is always valid; everything else must be checked.
    if (e.start() != 0 || l != kFullExtent) {
      if (e.start() < 0 || l <= 0) {
        return errors::InvalidArgument(
            "Expected non-negative start and positive length but got start = ",
            e.start(), ", length = ", l, ": extent = ", e.ShortDebugString());
      }
      // Calculating the extent end must not cause signed integer overflow.
      // The sum is done in uint64 so it cannot wrap before the comparison.
      if (static_cast<uint64_t>(e.start()) + static_cast<uint64_t>(e.length()) >
          std::numeric_limits<int64_t>::max()) {
        return errors::InvalidArgument(
            "Extent end exceeds the maximum possible size: extent = ",
            e.ShortDebugString());
      }
    }
    output->starts_.push_back(e.start());
    output->lengths_.push_back(l);
  }

  return OkStatus();
}
77 | |
78 | Status TensorSlice::Parse(const string& str, TensorSlice* slice) { |
79 | std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty()); |
80 | slice->starts_.reserve(items.size()); |
81 | slice->lengths_.reserve(items.size()); |
82 | for (const string& x : items) { |
83 | int64_t s, l; |
84 | if (x == "-" ) { |
85 | // "everything" |
86 | s = 0; |
87 | l = kFullExtent; |
88 | } else { |
89 | std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty()); |
90 | if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) || |
91 | !strings::safe_strto64(sl[1], &l)) { |
92 | return errors::InvalidArgument( |
93 | "Expected a pair of numbers or '-' " |
94 | "but got '" , |
95 | x, "': string = " , str); |
96 | } |
97 | if (s < 0 || l <= 0) { |
98 | return errors::InvalidArgument( |
99 | "Expected non-negative start and " |
100 | "positive length but got start = " , |
101 | s, ", length = " , l, ": string = " , str); |
102 | } |
103 | } |
104 | slice->starts_.push_back(s); |
105 | slice->lengths_.push_back(l); |
106 | } |
107 | |
108 | return OkStatus(); |
109 | } |
110 | |
111 | void TensorSlice::Clear() { |
112 | starts_.clear(); |
113 | lengths_.clear(); |
114 | } |
115 | |
116 | bool TensorSlice::IsFull() const { |
117 | for (int d = 0; d < dims(); ++d) { |
118 | if (!IsFullAt(d)) return false; |
119 | } |
120 | return true; |
121 | } |
122 | |
123 | void TensorSlice::SetFullSlice(int dim) { |
124 | Clear(); |
125 | starts_.reserve(dim); |
126 | lengths_.reserve(dim); |
127 | for (int d = 0; d < dim; ++d) { |
128 | starts_.push_back(0); |
129 | lengths_.push_back(kFullExtent); |
130 | } |
131 | } |
132 | |
133 | void TensorSlice::Extend(int dim) { |
134 | int old_dim = dims(); |
135 | DCHECK_LE(old_dim, dim); |
136 | starts_.resize(dim); |
137 | lengths_.resize(dim); |
138 | for (int d = old_dim; d < dim; ++d) { |
139 | starts_[d] = 0; |
140 | lengths_[d] = kFullExtent; |
141 | } |
142 | } |
143 | |
144 | void TensorSlice::AsProto(TensorSliceProto* proto) const { |
145 | for (int d = 0; d < dims(); ++d) { |
146 | TensorSliceProto::Extent* e = proto->add_extent(); |
147 | // We only need to record the explicit slice for non-full slices |
148 | if (!IsFullAt(d)) { |
149 | e->set_start(starts_[d]); |
150 | e->set_length(lengths_[d]); |
151 | } |
152 | } |
153 | } |
154 | |
155 | string TensorSlice::DebugString() const { |
156 | string buffer; |
157 | bool first = true; |
158 | for (int d = 0; d < dims(); ++d) { |
159 | if (!first) { |
160 | buffer.append(":" ); |
161 | } |
162 | if (IsFullAt(d)) { |
163 | buffer.append("-" ); |
164 | } else { |
165 | strings::StrAppend(&buffer, starts_[d], "," , lengths_[d]); |
166 | } |
167 | first = false; |
168 | } |
169 | return buffer; |
170 | } |
171 | |
// Computes the intersection of *this and `other`. Returns true iff they
// overlap in every dimension; when they do and `result` is non-null,
// `*result` is set to the intersection. `result` may be null if the caller
// only wants the boolean answer. On a miss, `*result` is cleared to rank 0.
bool TensorSlice::Intersect(const TensorSlice& other,
                            TensorSlice* result) const {
  // First, if two slices have different ranks, they obviously don't overlap
  // -- in fact they are not compatible.
  if (dims() != other.dims()) {
    return false;
  }

  // Setting the result to the right dimension
  if (result) {
    result->SetFullSlice(dims());
  }
  // The two slices overlap if they overlap in all dimensions.
  for (int d = 0; d < dims(); ++d) {
    if (IsFullAt(d)) {
      // A full dimension intersected with anything is that other extent.
      if (result) {
        result->set_start(d, other.start(d));
        result->set_length(d, other.length(d));
      }
    } else if (other.IsFullAt(d)) {
      if (result) {
        result->set_start(d, start(d));
        result->set_length(d, length(d));
      }
    } else {
      // If we have an intersection here, it should have a start that is the
      // max of the two starts and an end that is the min of the two ends.
      int64_t s = std::max(start(d), other.start(d));
      int64_t l = std::min(end(d), other.end(d)) - s;
      if (l > 0) {
        // We have a real intersection
        if (result) {
          result->set_start(d, s);
          result->set_length(d, l);
        }
      } else {
        // We don't have an intersection for this dimension -- thus we don't
        // have any intersection at all.
        if (result) {
          result->Clear();
        }
        return false;
      }
    }
  }
  // If we are here, we know there is overlap in every dimension.
  return true;
}
220 | |
221 | bool TensorSlice::operator==(const TensorSlice& other) const { |
222 | return dims() == other.dims() && starts_ == other.starts_ && |
223 | lengths_ == other.lengths_; |
224 | } |
225 | |
226 | void TensorSlice::ComputeRelative(const TensorSlice& sub, |
227 | TensorSlice* relative) const { |
228 | DCHECK_EQ(dims(), sub.dims()); |
229 | relative->SetFullSlice(dims()); |
230 | for (int d = 0; d < dims(); ++d) { |
231 | if (IsFullAt(d)) { |
232 | relative->set_start(d, sub.start(d)); |
233 | relative->set_length(d, sub.length(d)); |
234 | } else { |
235 | // Otherwise the relative start is the difference between the start of |
236 | // sub and the start of base |
237 | relative->set_start(d, sub.start(d) - start(d)); |
238 | relative->set_length(d, sub.length(d)); |
239 | } |
240 | } |
241 | } |
242 | |
// Grows *this in place, dimension by dimension, to the smallest slice that
// covers both *this and `other`. The two slices must have the same rank.
void TensorSlice::UpdateToCover(const TensorSlice& other) {
  DCHECK_EQ(dims(), other.dims());
  for (int d = 0; d < dims(); ++d) {
    // A dimension that is already full covers everything; skip it.
    if (!IsFullAt(d)) {
      if (other.IsFullAt(d)) {
        starts_[d] = 0;
        lengths_[d] = kFullExtent;
      } else {
        // Order matters: compute the covering end from the *old* extents
        // before set_start() mutates start(d), since the new length is
        // measured from the updated start.
        const auto new_end = std::max(end(d), other.end(d));
        set_start(d, std::min(start(d), other.start(d)));
        set_length(d, new_end - start(d));
      }
    }
  }
}
258 | |
// static
// True iff `extent` explicitly carries a length. The accessor name suggests
// the length lives in a proto `oneof` named `has_length`, whose generated
// case getter is has_length_case() and whose set-case constant is kLength
// — NOTE(review): confirm against tensor_slice.proto.
bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
  return extent.has_length_case() == TensorSliceProto::Extent::kLength;
}
263 | |
264 | // static |
265 | int64_t TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { |
266 | if (!HasExtentLength(extent)) return -1; |
267 | return extent.length(); |
268 | } |
269 | |
// Applies this slice to `shape`, producing in `*result_shape` the shape of
// the sliced tensor. Fails (and clears `*result_shape`) if the ranks
// mismatch or any non-full extent reaches past the end of its dimension.
Status TensorSlice::SliceTensorShape(const TensorShape& shape,
                                     TensorShape* result_shape) const {
  result_shape->Clear();
  // Mismatching ranks: we can't apply the slice at all.
  if (shape.dims() != dims()) {
    return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
                            ", slice = ", DebugString());
  }
  for (int d = 0; d < dims(); ++d) {
    if (IsFullAt(d)) {
      // A full slice keeps the dimension's original size.
      result_shape->AddDim(shape.dim_size(d));
    } else {
      // Check if the extent applies to the dimension
      if (end(d) <= shape.dim_size(d)) {
        // Yes: the end is within the range of the dim -- we adjust the result
        // shape so that its size along this dimension is the length of the
        // slice.
        result_shape->AddDim(length(d));
      } else {
        // The extent doesn't apply to the dimension
        result_shape->Clear();
        return errors::Internal("Extent in dimension ", d,
                                " out of bounds: shape = ", shape.DebugString(),
                                ", slice = ", DebugString());
      }
    }
  }
  // If we are here, we have successfully applied the shape.
  return OkStatus();
}
300 | |
// Sentinel length meaning "the entire dimension" (always paired with a
// start of 0).
const int64_t TensorSlice::kFullExtent = -1;
302 | |
303 | } // namespace tensorflow |
304 | |