1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/tensor_shape.h" |
17 | |
18 | #include "tensorflow/core/framework/bounds_check.h" |
19 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
20 | #include "tensorflow/core/lib/strings/str_util.h" |
21 | #include "tensorflow/core/lib/strings/strcat.h" |
22 | #include "tensorflow/core/platform/errors.h" |
23 | #include "tensorflow/core/platform/logging.h" |
24 | #include "tensorflow/core/platform/macros.h" |
25 | #include "tensorflow/core/util/overflow.h" |
26 | |
27 | namespace tensorflow { |
28 | |
// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep. In particular, their sizes should be the same.
// (Code in this file casts a TensorShapeRep to these subclasses, e.g. in
// TensorShapeRep::DebugString, so they must be layout-compatible.)
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");
35 | |
36 | template <class Shape> |
37 | static void AppendTo(const TensorShapeBase<Shape>& s, |
38 | gtl::InlinedVector<int64, 8>* vals) { |
39 | for (auto dim : s) { |
40 | vals->push_back(dim.size); |
41 | } |
42 | } |
43 | |
// CHECK-fails unless this shape has exactly NDIMS dimensions.  Used when a
// caller asks for a fixed-rank view of the tensor.
void TensorShape::CheckDimsEqual(int NDIMS) const {
  CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
                          << " from a tensor of " << dims() << " dimensions";
}
48 | |
// CHECK-fails unless this shape has at most NDIMS dimensions.
void TensorShape::CheckDimsAtMost(int NDIMS) const {
  CHECK_GE(NDIMS, dims()) << "Asking for tensor of at most " << NDIMS
                          << " dimensions from a tensor of " << dims()
                          << " dimensions";
}
54 | |
55 | // TODO(slebedev): Consider merging IsValid implementations. |
// Returns true iff the current shape is well-formed: rank within bounds,
// every dimension size in range (-1 allowed only for partial shapes), and
// the total element count representable in int64.
template <class Shape>
bool TensorShapeBase<Shape>::IsValid() {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && unknown_rank()) return dims() == 0;
  int64_t num_elements = 1;
  if (dims() > MaxDimensions()) return false;
  for (auto d : dim_sizes()) {
    if (d < (kIsPartial ? -1 : 0)) return false;
    if (d == -1) {
      // Unknown dimension poisons the running element count.
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      // Keep multiplying only while the running product is still known.
      num_elements = MultiplyWithoutOverflow(num_elements, d);
      if (num_elements < 0) return false;  // Overflowed int64.
    }
  }
  return true;
}
75 | |
// Proto variant of IsValid(): same rank/range/overflow checks, applied to a
// TensorShapeProto before it is turned into a shape object.
template <class Shape>
bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
  int64_t num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) return false;
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) return false;
    if (d.size() == -1) {
      // Unknown dimension poisons the running element count.
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) return false;  // Overflowed int64.
    }
  }
  return true;
}
95 | |
// Like IsValid(proto) but returns a descriptive error Status instead of a
// bool, so callers can surface the reason a proto was rejected.
template <class Shape>
Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    if (proto.dim_size() > 0) {
      return errors::InvalidArgument(
          "An unknown shape must not have any dimensions set.");
    }
    return OkStatus();
  }
  int64_t num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) {
    return errors::InvalidArgument("Shape ", DebugString(proto),
                                   " has too many dimensions");
  }
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) {
      // The lower bound differs by shape kind, so the error message does too.
      if (kIsPartial) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " has dimensions with values below -1 (where -1 means unknown)");
      } else {
        return errors::InvalidArgument("Shape ", DebugString(proto),
                                       " is not fully defined");
      }
    }
    if (d.size() == -1) {
      // Unknown dimension poisons the running element count.
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " is too large (more than 2**63 - 1 entries)");
      }
    }
  }
  return OkStatus();
}
137 | |
// Constructs a shape from a proto.  CHECK-fails (inside AddDim) on invalid
// dimension sizes; use BuildTensorShapeBase for a non-crashing alternative.
template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    // Start as a scalar and append each proto dimension in order.
    set_ndims_byte(0);
    set_num_elements(1);
    for (const auto& d : proto.dim()) {
      AddDim(d.size());
    }
  }
}
156 | |
// Status-returning equivalent of the proto constructor: populates `*out`
// from `proto`, or returns the first error encountered while adding dims.
template <class Shape>
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
    const TensorShapeProto& proto, TensorShapeBase* out) {
  out->set_tag(REP16);
  out->set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    out->set_ndims_byte(kUnknownRank);
    out->set_num_elements(-1);
  } else {
    out->set_ndims_byte(0);
    out->set_num_elements(1);
    Status s = OkStatus();
    for (const auto& d : proto.dim()) {
      s = out->AddDimWithStatus(d.size());
      if (!s.ok()) {
        return s;
      }
    }
  }
  return OkStatus();
}
181 | |
// Constructs a shape from explicit dimension sizes.  CHECK-fails on invalid
// sizes; use BuildTensorShapeBase for a non-crashing alternative.
template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  TF_CHECK_OK(InitDims(dim_sizes));
}
188 | |
// Status-returning equivalent of the ArraySlice constructor.
template <class Shape>
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
    gtl::ArraySlice<int64_t> dim_sizes, TensorShapeBase* out) {
  out->set_tag(REP16);
  out->set_data_type(DT_INVALID);
  return out->InitDims(dim_sizes);
}
196 | |
197 | // Returns true iff partial is true and val is < 0. |
198 | // REQUIRES: val < kMaxRep16 |
199 | // REQUIRES: partial || val >= 0 |
200 | static inline bool Set16(bool partial, uint16* dst, int dim, int64_t val) { |
201 | if (partial) { |
202 | if (val < 0) { |
203 | dst[dim] = std::numeric_limits<uint16>::max(); |
204 | return true; |
205 | } |
206 | } |
207 | dst[dim] = val; |
208 | return false; |
209 | } |
210 | |
// Initializes this shape from `dim_sizes`.  Uses hand-unrolled fast paths
// for ranks 1-4 when every size fits in the 16-bit representation; falls
// back to AddDimWithStatus otherwise.
// REQUIRES: the shape was just reset to the (empty) REP16 representation.
template <class Shape>
Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64_t> dim_sizes) {
  DCHECK_EQ(tag(), REP16);

  // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
  // below cannot overflow.
  static const int64_t kMaxSmall = 0xd744;
  static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
                "bad overflow check");
  bool large_size = false;
  for (auto s : dim_sizes) {
    if (s > kMaxSmall) {
      large_size = true;
      break;
    }
  }

  // For fully-defined shapes on the fast path, reject negative sizes up
  // front (the fast path below does not validate them).
  if (!kIsPartial && !large_size) {
    for (auto s : dim_sizes) {
      if (TF_PREDICT_FALSE(s < 0)) {
        return errors::InvalidArgument(
            "Expected shape dimensions to be non-negative, got ", s);
      }
    }
  }

  if (!large_size) {
    // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
    uint16* dst = as16()->dims_;
    switch (dim_sizes.size()) {
      case 1: {
        set_ndims_byte(1);
        const int64_t size = dim_sizes[0];
        const bool neg = Set16(kIsPartial, dst, 0, size);
        set_num_elements(neg ? -1 : size);
        return OkStatus();
      }
      case 2: {
        set_ndims_byte(2);
        const int64_t size0 = dim_sizes[0];
        const int64_t size1 = dim_sizes[1];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        // Products cannot overflow: each factor is <= kMaxSmall.
        set_num_elements(neg ? -1 : (size0 * size1));
        return OkStatus();
      }
      case 3: {
        set_ndims_byte(3);
        const int64_t size0 = dim_sizes[0];
        const int64_t size1 = dim_sizes[1];
        const int64_t size2 = dim_sizes[2];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        set_num_elements(neg ? -1 : (size0 * size1 * size2));
        return OkStatus();
      }
      case 4: {
        set_ndims_byte(4);
        const int64_t size0 = dim_sizes[0];
        const int64_t size1 = dim_sizes[1];
        const int64_t size2 = dim_sizes[2];
        const int64_t size3 = dim_sizes[3];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        neg |= Set16(kIsPartial, dst, 3, size3);
        set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
        return OkStatus();
      }
    }
  }

  // Slow path: rank 0, rank > 4, or at least one large dimension.
  set_ndims_byte(0);
  set_num_elements(1);
  Status status = OkStatus();
  for (int64_t s : dim_sizes) {
    status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
    if (!status.ok()) {
      return status;
    }
  }

  return status;
}
296 | |
297 | template <class Shape> |
298 | TensorShapeBase<Shape>::TensorShapeBase() { |
299 | set_tag(REP16); |
300 | set_data_type(DT_INVALID); |
301 | if (kIsPartial) { |
302 | set_ndims_byte(kUnknownRank); |
303 | set_num_elements(-1); |
304 | } else { |
305 | set_ndims_byte(0); |
306 | set_num_elements(1); |
307 | } |
308 | } |
309 | |
// Frees the heap-allocated dimension vector used by the out-of-line
// representation.  Only reached when tag() == REP_OUT_OF_LINE.
void TensorShapeRep::DestructorOutOfLine() {
  DCHECK(tag() == REP_OUT_OF_LINE);
  delete as64()->dims_;
}
314 | |
// Copies `b` into *this when at least one side uses the out-of-line
// representation (the trivial memcpy case is handled by the caller's fast
// path).  Takes care to free or reuse any existing heap vector.
void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    // Source is inline: drop our heap vector (if any) and copy the raw
    // bytes, which carries the tag/ndims/data_type along with the dims.
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated; reuse it instead of reallocating.
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64_t, 4>(*(b.as64()->dims_));
    }
  }
}
337 | |
// Returns the size of dimension `d`, or -1 for an unknown dimension (or any
// dimension of an unknown-rank partial shape).  Decodes whichever inline or
// out-of-line representation is in use.
template <class Shape>
int64_t TensorShapeBase<Shape>::dim_size(int d) const {
  if (unknown_rank()) return -1;
  DCHECK_GE(d, 0);
  DCHECK_LT(d, dims());
  if (tag() == REP16) {
    uint16 dim = as16()->dims_[d];
    // The all-ones pattern encodes "unknown" in partial shapes.
    if (kIsPartial && dim == kUnknownRep16) return -1;
    return dim;
  } else if (tag() == REP32) {
    uint32 dim = as32()->dims_[d];
    if (kIsPartial && dim == kUnknownRep32) return -1;
    return dim;
  } else {
    return (*as64()->dims_)[d];
  }
}
355 | |
// Resets to an empty (scalar) shape and also invalidates the data type.
void TensorShapeRep::Clear() {
  ClearAllButDataType();
  set_data_type(DT_INVALID);
}
360 | |
// Resets to an empty (scalar) shape, releasing any out-of-line storage,
// while preserving the stored data type.
void TensorShapeRep::ClearAllButDataType() {
  if (tag() == REP_OUT_OF_LINE) {
    delete as64()->dims_;
  }
  set_tag(REP16);
  set_ndims_byte(0);
  // Leaves data_type alone
  set_num_elements(1);
}
370 | |
// Recomputes the cached element count from the current dimensions.  Yields
// -1 for unknown rank or any unknown dimension; returns an error if the
// product of known dimensions overflows int64.
template <class Shape>
Status TensorShapeBase<Shape>::RecomputeNumElements() {
  if (unknown_rank()) {
    set_num_elements(-1);
    return OkStatus();
  }
  int64_t n = 1;
  for (auto dim : *this) {
    if (kIsPartial && dim.size < 0) {
      // One unknown dimension makes the whole count unknown.
      n = -1;
      break;
    }
    n = MultiplyWithoutOverflow(n, dim.size);
    if (TF_PREDICT_FALSE(n < 0)) {
      return errors::InvalidArgument(
          "Shape ", this->DebugString(),
          " results in overflow when computing number of elements");
    }
  }
  set_num_elements(n);
  return OkStatus();
}
393 | |
// Appends a dimension of the given size.  CHECK-fails on invalid input or
// overflow; AddDimWithStatus is the non-crashing alternative.  A no-op for
// unknown-rank partial shapes.
template <class Shape>
void TensorShapeBase<Shape>::AddDim(int64_t size) {
  if (!kIsPartial) CHECK_GE(size, 0);
  if (unknown_rank()) return;
  CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
  int64_t new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    // Unknown count stays unknown; an unknown dim makes it unknown.
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    CHECK_LE(0, new_num_elements);
  }
  UnsafeAddDim(size, new_num_elements);
}
408 | |
// Status-returning equivalent of AddDim(): appends a dimension, rejecting
// negative sizes (for fully-defined shapes), rank overflow, and element
// count overflow.  A no-op for unknown-rank partial shapes.
template <class Shape>
Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
  if (!kIsPartial) {
    if (TF_PREDICT_FALSE(size < 0)) {
      return errors::InvalidArgument("Expected a non-negative size, got ",
                                     size);
    }
  }

  if (unknown_rank()) {
    return OkStatus();
  }

  if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
    return errors::InvalidArgument("Too many dimensions in tensor");
  }

  int64_t new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    // Unknown count stays unknown; an unknown dim makes it unknown.
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    if (TF_PREDICT_FALSE(new_num_elements < 0)) {
      return errors::InvalidArgument("Encountered overflow when multiplying ",
                                     num_elements(), " with ", size,
                                     ", result: ", new_num_elements);
    }
  }

  UnsafeAddDim(size, new_num_elements);
  return OkStatus();
}
441 | |
// Appends a dimension without validation, upgrading the internal
// representation (REP16 -> REP32 -> REP_OUT_OF_LINE) when the new dimension
// does not fit the current one.  `new_num_elements` is the precomputed
// element count after the append.
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64_t size,
                                          int64_t new_num_elements) {
  const int nd = ndims_byte();
  // Fast paths: the new dim fits in the current inline representation
  // (REP16 holds up to 6 dims, REP32 up to 3).
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    gtl::InlinedVector<int64_t, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16.  See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64_t, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}
487 | |
// Appends every dimension of `shape` to this shape, in order.
// CHECK-fails on overflow (via AddDim).
template <class Shape>
void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
  for (auto d : shape) AddDim(d.size);
}
492 | |
493 | template <class Shape> |
494 | Status TensorShapeBase<Shape>::AppendShapeWithStatus( |
495 | const TensorShapeBase& shape) { |
496 | Status s = OkStatus(); |
497 | for (auto d : shape) { |
498 | s.Update(AddDimWithStatus(d.size)); |
499 | if (!s.ok()) { |
500 | return s; |
501 | } |
502 | } |
503 | return s; |
504 | } |
505 | |
// Inserts a dimension of the given size at index `d`, shifting later
// dimensions up.  CHECK-fails on invalid input; InsertDimWithStatus is the
// non-crashing alternative.
template <class Shape>
void TensorShapeBase<Shape>::InsertDim(int d, int64_t size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  if (!kIsPartial) CHECK_GE(size, 0);
  CHECK_LT(dims(), MaxDimensions());
  // Rebuild the whole shape: snapshot dims, insert, then re-add each one so
  // the representation and element count stay consistent.
  gtl::InlinedVector<int64_t, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
}
520 | |
// Status-returning equivalent of InsertDim(): validates the size, the
// insertion index, and the rank bound before rebuilding the shape.
template <class Shape>
Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64_t size) {
  if (!kIsPartial) {
    if (TF_PREDICT_FALSE(size < 0)) {
      return errors::InvalidArgument("Expected a non-negative size, got ",
                                     size);
    }
  }

  if (TF_PREDICT_FALSE(d < 0)) {
    return errors::Internal("The insertion index must be non-negative, got ",
                            d);
  }
  if (TF_PREDICT_FALSE(d > dims())) {
    return errors::Internal("The insertion index must be at most ", dims(),
                            " got ", d);
  }
  if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
    return errors::Internal("Shape has ", dims(),
                            " dimensions which is the maximum allowed");
  }

  // Rebuild the whole shape: snapshot dims, insert, then re-add each one so
  // the representation and element count stay consistent.
  gtl::InlinedVector<int64_t, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();

  Status s = OkStatus();
  for (auto dval : vals) {
    s.Update(AddDimWithStatus(dval));
    if (!s.ok()) {
      return s;
    }
  }
  return s;
}
557 | |
558 | template <class Shape> |
559 | gtl::InlinedVector<int64_t, 4> TensorShapeBase<Shape>::dim_sizes() const { |
560 | gtl::InlinedVector<int64_t, 4> result; |
561 | for (auto dim : *this) { |
562 | result.push_back(dim.size); |
563 | } |
564 | return result; |
565 | } |
566 | |
// Replaces the size of dimension `d` in place when the new size fits the
// current representation, otherwise rebuilds the shape.  CHECK-fails on
// invalid input; SetDimWithStatus is the non-crashing alternative.
template <class Shape>
void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  if (!kIsPartial) {
    CHECK_GE(size, 0);
  }
  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade: snapshot dims, substitute, and re-add everything.
    gtl::InlinedVector<int64_t, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();
    for (auto dval : vals) {
      AddDim(dval);
    }
  }
  TF_CHECK_OK(RecomputeNumElements());
}
594 | |
// Status-returning equivalent of set_dim(): validates the index and size,
// then updates in place or rebuilds the shape as needed.
template <class Shape>
Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
  if (TF_PREDICT_FALSE(d < 0)) {
    return errors::InvalidArgument("Index must be non-negative, got ", d);
  }
  if (TF_PREDICT_FALSE(d >= dims())) {
    return errors::InvalidArgument("Index must be less than ", dims(), ", got ",
                                   d);
  }
  if (TF_PREDICT_FALSE(!kIsPartial && size < 0)) {
    return errors::InvalidArgument("Expected a non-negative size, got ", size);
  }

  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade: snapshot dims, substitute, and re-add everything.
    gtl::InlinedVector<int64_t, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();

    Status s = OkStatus();
    for (auto dval : vals) {
      s.Update(AddDimWithStatus(dval));
      if (!s.ok()) {
        return s;
      }
    }
  }

  return RecomputeNumElements();
}
634 | |
// Removes dimensions [begin, end).  Negative indices count from the end,
// with -1 meaning dims() (i.e. index i maps to dims() + i + 1).  A no-op
// for unknown-rank shapes or an empty range.  CHECK-fails on out-of-range
// indices; RemoveDimRangeWithStatus is the non-crashing alternative.
template <class Shape>
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
  if (unknown_rank()) return;
  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;
  CHECK_GE(begin, 0);
  CHECK_LE(begin, dims());
  CHECK_GE(end, 0);
  CHECK_LE(end, dims());
  if (begin >= end) return;
  // Rebuild the shape from the surviving dimensions.
  gtl::InlinedVector<int64_t, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
  TF_CHECK_OK(RecomputeNumElements());
}
654 | |
// Status-returning equivalent of RemoveDimRange(): same negative-index
// convention, but out-of-range indices yield an error instead of crashing.
template <class Shape>
Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
  if (unknown_rank()) {
    return OkStatus();
  }

  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;

  if (TF_PREDICT_FALSE(begin < 0)) {
    return errors::Internal("Start index must be non-negative, got ", begin);
  }
  if (TF_PREDICT_FALSE(begin > dims())) {
    return errors::Internal("Start index must be less than ", dims(), ", got ",
                            begin);
  }
  if (TF_PREDICT_FALSE(end < 0)) {
    return errors::Internal("End index must be non-negative, got ", end);
  }
  if (TF_PREDICT_FALSE(end > dims())) {
    return errors::Internal("End index must be less than ", dims(), ", got ",
                            end);
  }

  if (begin >= end) {
    return OkStatus();
  }

  // Rebuild the shape from the surviving dimensions.
  gtl::InlinedVector<int64_t, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();

  Status s = OkStatus();
  for (auto dval : vals) {
    s.Update(AddDimWithStatus(dval));
    if (!s.ok()) {
      return s;
    }
  }

  return RecomputeNumElements();
}
698 | |
699 | bool TensorShape::IsSameSize(const TensorShape& b) const { |
700 | if (b.dims() != dims()) return false; |
701 | for (int d = 0; d < dims(); d++) { |
702 | if (dim_size(d) != b.dim_size(d)) return false; |
703 | } |
704 | return true; |
705 | } |
706 | |
707 | template <class Shape> |
708 | void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const { |
709 | proto->Clear(); |
710 | if (unknown_rank()) { |
711 | proto->set_unknown_rank(true); |
712 | } else { |
713 | for (int i = 0; i < dims(); i++) { |
714 | proto->add_dim()->set_size(dim_size(i)); |
715 | } |
716 | } |
717 | } |
718 | |
719 | template <class Shape> |
720 | TensorShapeProto TensorShapeBase<Shape>::AsProto() const { |
721 | TensorShapeProto out; |
722 | AsProto(&out); |
723 | return out; |
724 | } |
725 | |
// Returns an iterator positioned at the first dimension.
template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
}
730 | |
// Returns the past-the-end iterator.  For unknown rank this uses -1 as the
// sentinel index; presumably callers do not iterate unknown-rank shapes --
// TODO(review): confirm against TensorShapeIter's comparison semantics.
template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
  const int max_dim = unknown_rank() ? -1 : dims();
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
}
736 | |
// Human-readable form, e.g. "[2,?,3]" with "?" for unknown dimensions and
// "<unknown>" for unknown rank.  Safe to view any rep as PartialTensorShape
// here (see the static_asserts at the top of this file).
string TensorShapeRep::DebugString() const {
  const auto& shape = *static_cast<const PartialTensorShape*>(this);
  if (shape.unknown_rank()) return "<unknown>";
  string s = "[";
  for (int i = 0; i < shape.dims(); i++) {
    if (i > 0) strings::StrAppend(&s, ",");
    int64_t dim = shape.dim_size(i);
    if (dim < 0) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, dim);
    }
  }
  strings::StrAppend(&s, "]");
  return s;
}
753 | |
// Human-readable form of a shape proto, same notation as DebugString().
// An unknown-rank proto that nevertheless has dims prints both the
// "<unknown>" marker and the dim list.
string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
  string s;
  if (proto.unknown_rank()) {
    strings::StrAppend(&s, "<unknown>");
    if (proto.dim_size() == 0) return s;
  }
  strings::StrAppend(&s, "[");
  bool first = true;
  for (const auto& d : proto.dim()) {
    if (!first) strings::StrAppend(&s, ",");
    if (d.size() == -1) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, d.size());
    }
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}
774 | |
775 | bool TensorShapeUtils::StartsWith(const TensorShape& shape, |
776 | const TensorShape& prefix) { |
777 | if (shape.dims() < prefix.dims()) return false; |
778 | for (int i = 0; i < prefix.dims(); ++i) { |
779 | if (shape.dim_size(i) != prefix.dim_size(i)) return false; |
780 | } |
781 | return true; |
782 | } |
783 | |
784 | bool TensorShapeUtils::EndsWith(const TensorShape& shape, |
785 | const TensorShape& suffix) { |
786 | const int suffix_size = suffix.dims(); |
787 | if (shape.dims() < suffix_size) return false; |
788 | for (int i = 0; i < suffix_size; ++i) { |
789 | if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) { |
790 | return false; |
791 | } |
792 | } |
793 | return true; |
794 | } |
795 | |
// Builds `*out` (a TensorShape or PartialTensorShape) from `n` raw
// dimension values, validating each dim and the running element count.
// Shared implementation behind the TensorShapeUtils::MakeShape overloads.
template <typename T, class Shape>
Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) {
  out->Clear();
  if (n > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Too many dimensions");
  }
  if (n < 0) {
    return errors::InvalidArgument("Negative number of dimensions ", n);
  }
  for (int64_t i = 0; i < n; ++i) {
    T dim = internal::SubtleMustCopy(dims[i]);
    int64_t new_num_elements;
    if (dim < 0) {
      // Negative dims are only legal for partial shapes, and only -1.
      if (!out->kIsPartial) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
      }
      if (dim < -1) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
      }
      dim = -1;
      new_num_elements = -1;
    } else if (out->num_elements() < 0) {
      // Element count already unknown; stays unknown.
      new_num_elements = -1;
    } else {
      new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
      if (TF_PREDICT_FALSE(new_num_elements < 0)) {
        // Rebuild a proto just to format a readable error message.
        TensorShapeProto proto;
        for (int64_t j = 0; j < n; ++j) {
          proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
        }
        return errors::InvalidArgument(
            "Shape ", TensorShape::DebugString(proto),
            " would have more than 2**63 - 1 elements");
      }
    }
    out->UnsafeAddDim(dim, new_num_elements);
  }
  return OkStatus();
}
835 | |
// Stamps out the pointer+length and ArraySlice MakeShape overloads for each
// (element type, shape type) pair, all delegating to MakeShapeHelper.
#define MAKE_SHAPE(T, Shape)                                                  \
  Status TensorShapeUtils::MakeShape(const T* dims, int64_t n, Shape* out) {  \
    return MakeShapeHelper(dims, n, out);                                     \
  }                                                                           \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) {  \
    return MakeShapeHelper(shape.data(), shape.size(), out);                  \
  }
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64_t, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64_t, PartialTensorShape)
#undef MAKE_SHAPE
848 | |
849 | string TensorShapeUtils::ShapeListString( |
850 | const gtl::ArraySlice<TensorShape>& shapes) { |
851 | string result = "[" ; |
852 | bool first = true; |
853 | for (const TensorShape& shape : shapes) { |
854 | strings::StrAppend(&result, (first ? "" : ", " ), shape.DebugString()); |
855 | first = false; |
856 | } |
857 | strings::StrAppend(&result, "]" ); |
858 | return result; |
859 | } |
860 | |
861 | PartialTensorShape PartialTensorShape::Concatenate(int64_t size) const { |
862 | PartialTensorShape out = *this; |
863 | out.AddDim(size); |
864 | return out; |
865 | } |
866 | |
867 | Status PartialTensorShape::ConcatenateWithStatus( |
868 | int64_t size, PartialTensorShape* out) const { |
869 | out = const_cast<PartialTensorShape*>(this); |
870 | return out->AddDimWithStatus(size); |
871 | } |
872 | |
873 | PartialTensorShape PartialTensorShape::Concatenate( |
874 | const PartialTensorShape& shape) const { |
875 | if (unknown_rank() || shape.unknown_rank()) { |
876 | return PartialTensorShape(); |
877 | } |
878 | PartialTensorShape out = *this; |
879 | for (auto dim : shape) out.AddDim(dim.size); |
880 | return out; |
881 | } |
882 | |
883 | Status PartialTensorShape::ConcatenateWithStatus( |
884 | const PartialTensorShape& shape, PartialTensorShape* out) const { |
885 | if (unknown_rank() || shape.unknown_rank()) { |
886 | *out = PartialTensorShape(); |
887 | return OkStatus(); |
888 | } |
889 | out = const_cast<PartialTensorShape*>(this); |
890 | for (auto dim : shape) { |
891 | Status s = out->AddDimWithStatus(dim.size); |
892 | if (!s.ok()) return s; |
893 | } |
894 | |
895 | return OkStatus(); |
896 | } |
897 | |
// Merges this shape with `shape` into `*result`: unknown rank defers to the
// other side; otherwise ranks must match, and each output dim takes the
// known value (erroring if both are known but disagree).
// REQUIRES: result != this (result is cleared before merging).
Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
                                     PartialTensorShape* result) const {
  if (unknown_rank()) {
    *result = shape;
    return OkStatus();
  }
  if (shape.unknown_rank()) {
    *result = *this;
    return OkStatus();
  }
  const int dims_ = dims();
  if (dims_ != shape.dims()) {
    return errors::InvalidArgument(
        "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
        shape.dims());
  }

  // result->Clear() below would destroy *this if they alias.
  if (result == this) {
    return errors::Internal(
        "PartialTensorShape::MergeWith: cannot merge shape with itself");
  }

  result->Clear();
  Status s = OkStatus();
  for (int i = 0; i < dims_; ++i) {
    const int64_t dim0 = dim_size(i);
    const int64_t dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
      return errors::InvalidArgument(
          "PartialTensorShape: Incompatible shapes during merge: ",
          DebugString(), " vs. ", shape.DebugString());
    }
    // Prefer whichever side knows the dimension (dim1 may itself be -1).
    s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
    if (!s.ok()) {
      return s;
    }
  }
  return OkStatus();
}
937 | |
// If this partial shape is fully defined, copies it into `*shape` (as a
// TensorShape) and returns true; otherwise leaves `*shape` untouched and
// returns false.  The rep-level copy is safe because both types are
// layout-compatible with TensorShapeRep (see static_asserts above... the
// top-of-file asserts guarantee identical size).
bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
  if (IsFullyDefined()) {
    const TensorShapeRep* rep = this;
    *shape = *static_cast<const TensorShape*>(rep);
    return true;
  }
  return false;
}
946 | |
947 | bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const { |
948 | if (unknown_rank() || shape.unknown_rank()) { |
949 | return unknown_rank() == shape.unknown_rank(); |
950 | } |
951 | if (dims() != shape.dims()) return false; |
952 | for (int i = 0; i < dims(); i++) { |
953 | if (dim_size(i) != shape.dim_size(i)) return false; |
954 | } |
955 | return true; |
956 | } |
957 | |
958 | bool PartialTensorShape::IsCompatibleWith( |
959 | const PartialTensorShape& shape) const { |
960 | if (unknown_rank() || shape.unknown_rank()) return true; |
961 | if (dims() != shape.dims()) return false; |
962 | for (int i = 0; i < dims(); i++) { |
963 | const int64_t dim0 = dim_size(i); |
964 | const int64_t dim1 = shape.dim_size(i); |
965 | if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false; |
966 | } |
967 | return true; |
968 | } |
969 | |
970 | string PartialTensorShapeUtils::PartialShapeListString( |
971 | const gtl::ArraySlice<PartialTensorShape>& shapes) { |
972 | string result = "[" ; |
973 | bool first = true; |
974 | for (const PartialTensorShape& shape : shapes) { |
975 | strings::StrAppend(&result, (first ? "" : ", " ), shape.DebugString()); |
976 | first = false; |
977 | } |
978 | strings::StrAppend(&result, "]" ); |
979 | return result; |
980 | } |
981 | |
982 | bool PartialTensorShapeUtils::AreCompatible( |
983 | const gtl::ArraySlice<PartialTensorShape>& shapes0, |
984 | const gtl::ArraySlice<PartialTensorShape>& shapes1) { |
985 | if (shapes0.size() == shapes1.size()) { |
986 | for (size_t i = 0; i < shapes0.size(); ++i) { |
987 | if (!shapes0[i].IsCompatibleWith(shapes1[i])) { |
988 | return false; |
989 | } |
990 | } |
991 | return true; |
992 | } else { |
993 | return false; |
994 | } |
995 | } |
996 | |
997 | bool PartialTensorShapeUtils::AreIdentical( |
998 | const gtl::ArraySlice<PartialTensorShape>& shapes0, |
999 | const gtl::ArraySlice<PartialTensorShape>& shapes1) { |
1000 | if (shapes0.size() == shapes1.size()) { |
1001 | for (size_t i = 0; i < shapes0.size(); ++i) { |
1002 | if (!shapes0[i].IsIdenticalTo(shapes1[i])) { |
1003 | return false; |
1004 | } |
1005 | } |
1006 | return true; |
1007 | } else { |
1008 | return false; |
1009 | } |
1010 | } |
1011 | |
// Computes the product of `shape`'s entries into `*num_elements`, returning
// an error if the product overflows int64.  (Negative entries are not
// special-cased here; callers pass fully-defined shapes.)
Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64_t> shape,
                                     int64_t* num_elements) {
  int64_t n = 1;
  for (auto dim : shape) {
    n = MultiplyWithoutOverflow(n, dim);
    if (n < 0) {
      return errors::InvalidArgument("Can't compute total size of shape [",
                                     absl::StrJoin(shape, ","),
                                     "]; product would overflow int64");
    }
  }
  *num_elements = n;
  return OkStatus();
}
1026 | |
// Explicitly instantiate the two supported shape flavors so their template
// definitions above are emitted in this translation unit.
template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;
1029 | |
1030 | } // namespace tensorflow |
1031 | |