1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/tensor_shape.h"
17
18#include "tensorflow/core/framework/bounds_check.h"
19#include "tensorflow/core/framework/tensor_shape.pb.h"
20#include "tensorflow/core/lib/strings/str_util.h"
21#include "tensorflow/core/lib/strings/strcat.h"
22#include "tensorflow/core/platform/errors.h"
23#include "tensorflow/core/platform/logging.h"
24#include "tensorflow/core/platform/macros.h"
25#include "tensorflow/core/util/overflow.h"
26
27namespace tensorflow {
28
// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep. In particular, their sizes should be the same.
// Code in this file relies on that: it static_casts between TensorShapeRep,
// TensorShape and PartialTensorShape (e.g. TensorShapeRep::DebugString and
// PartialTensorShape::AsTensorShape), which is only sound if neither subclass
// adds state of its own.
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");
35
36template <class Shape>
37static void AppendTo(const TensorShapeBase<Shape>& s,
38 gtl::InlinedVector<int64, 8>* vals) {
39 for (auto dim : s) {
40 vals->push_back(dim.size);
41 }
42}
43
44void TensorShape::CheckDimsEqual(int NDIMS) const {
45 CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
46 << " from a tensor of " << dims() << " dimensions";
47}
48
49void TensorShape::CheckDimsAtMost(int NDIMS) const {
50 CHECK_GE(NDIMS, dims()) << "Asking for tensor of at most " << NDIMS
51 << " dimensions from a tensor of " << dims()
52 << " dimensions";
53}
54
55// TODO(slebedev): Consider merging IsValid implementations.
56template <class Shape>
57bool TensorShapeBase<Shape>::IsValid() {
58 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
59 // unknown_shape() set, and it seems hard to remove this without backwards
60 // compatibility issues.
61 if (kIsPartial && unknown_rank()) return dims() == 0;
62 int64_t num_elements = 1;
63 if (dims() > MaxDimensions()) return false;
64 for (auto d : dim_sizes()) {
65 if (d < (kIsPartial ? -1 : 0)) return false;
66 if (d == -1) {
67 num_elements = -1;
68 } else if (!kIsPartial || num_elements >= 0) {
69 num_elements = MultiplyWithoutOverflow(num_elements, d);
70 if (num_elements < 0) return false;
71 }
72 }
73 return true;
74}
75
76template <class Shape>
77bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
78 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
79 // unknown_shape() set, and it seems hard to remove this without backwards
80 // compatibility issues.
81 if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
82 int64_t num_elements = 1;
83 if (proto.dim().size() > MaxDimensions()) return false;
84 for (const auto& d : proto.dim()) {
85 if (d.size() < (kIsPartial ? -1 : 0)) return false;
86 if (d.size() == -1) {
87 num_elements = -1;
88 } else if (!kIsPartial || num_elements >= 0) {
89 num_elements = MultiplyWithoutOverflow(num_elements, d.size());
90 if (num_elements < 0) return false;
91 }
92 }
93 return true;
94}
95
96template <class Shape>
97Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
98 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
99 // unknown_shape() set, and it seems hard to remove this without backwards
100 // compatibility issues.
101 if (kIsPartial && proto.unknown_rank()) {
102 if (proto.dim_size() > 0) {
103 return errors::InvalidArgument(
104 "An unknown shape must not have any dimensions set.");
105 }
106 return OkStatus();
107 }
108 int64_t num_elements = 1;
109 if (proto.dim().size() > MaxDimensions()) {
110 return errors::InvalidArgument("Shape ", DebugString(proto),
111 " has too many dimensions");
112 }
113 for (const auto& d : proto.dim()) {
114 if (d.size() < (kIsPartial ? -1 : 0)) {
115 if (kIsPartial) {
116 return errors::InvalidArgument(
117 "Shape ", DebugString(proto),
118 " has dimensions with values below -1 (where -1 means unknown)");
119 } else {
120 return errors::InvalidArgument("Shape ", DebugString(proto),
121 " is not fully defined");
122 }
123 }
124 if (d.size() == -1) {
125 num_elements = -1;
126 } else if (!kIsPartial || num_elements >= 0) {
127 num_elements = MultiplyWithoutOverflow(num_elements, d.size());
128 if (num_elements < 0) {
129 return errors::InvalidArgument(
130 "Shape ", DebugString(proto),
131 " is too large (more than 2**63 - 1 entries)");
132 }
133 }
134 }
135 return OkStatus();
136}
137
138template <class Shape>
139TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
140 set_tag(REP16);
141 set_data_type(DT_INVALID);
142 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
143 // unknown_shape() set, and it seems hard to remove this without backwards
144 // compatibility issues.
145 if (kIsPartial && proto.unknown_rank()) {
146 set_ndims_byte(kUnknownRank);
147 set_num_elements(-1);
148 } else {
149 set_ndims_byte(0);
150 set_num_elements(1);
151 for (const auto& d : proto.dim()) {
152 AddDim(d.size());
153 }
154 }
155}
156
157template <class Shape>
158Status TensorShapeBase<Shape>::BuildTensorShapeBase(
159 const TensorShapeProto& proto, TensorShapeBase* out) {
160 out->set_tag(REP16);
161 out->set_data_type(DT_INVALID);
162 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
163 // unknown_shape() set, and it seems hard to remove this without backwards
164 // compatibility issues.
165 if (kIsPartial && proto.unknown_rank()) {
166 out->set_ndims_byte(kUnknownRank);
167 out->set_num_elements(-1);
168 } else {
169 out->set_ndims_byte(0);
170 out->set_num_elements(1);
171 Status s = OkStatus();
172 for (const auto& d : proto.dim()) {
173 s = out->AddDimWithStatus(d.size());
174 if (!s.ok()) {
175 return s;
176 }
177 }
178 }
179 return OkStatus();
180}
181
// Constructs a shape from `dim_sizes`. CHECK-fails (via TF_CHECK_OK) if
// InitDims rejects them, e.g. for a negative size in a TensorShape or an
// element-count overflow. BuildTensorShapeBase(dim_sizes, out) is the
// Status-returning alternative.
template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  TF_CHECK_OK(InitDims(dim_sizes));
}
188
189template <class Shape>
190Status TensorShapeBase<Shape>::BuildTensorShapeBase(
191 gtl::ArraySlice<int64_t> dim_sizes, TensorShapeBase* out) {
192 out->set_tag(REP16);
193 out->set_data_type(DT_INVALID);
194 return out->InitDims(dim_sizes);
195}
196
197// Returns true iff partial is true and val is < 0.
198// REQUIRES: val < kMaxRep16
199// REQUIRES: partial || val >= 0
200static inline bool Set16(bool partial, uint16* dst, int dim, int64_t val) {
201 if (partial) {
202 if (val < 0) {
203 dst[dim] = std::numeric_limits<uint16>::max();
204 return true;
205 }
206 }
207 dst[dim] = val;
208 return false;
209}
210
// Fills in the dimensions and element count of this shape from `dim_sizes`.
// Expects the shape to already be tagged REP16 (DCHECKed in debug builds).
//
// Fast path: with 1 to 4 dimensions, none larger than kMaxSmall, the sizes
// are written straight into the inline 16-bit array. The element count is a
// plain product, which cannot overflow because of the kMaxSmall bound. For a
// partial shape, Set16 stores the uint16 maximum for any negative size, and
// the element count becomes -1.
//
// Slow path (0 or more than 4 dimensions, or any size above kMaxSmall): the
// shape is reset to rank 0 and each size is appended with AddDimWithStatus.
// That call checks the sign, the rank limit and element-count overflow.
//
// Returns InvalidArgument for a negative size in a fully defined shape, or
// the first error reported by AddDimWithStatus. On a slow-path error the
// shape may be left holding only the dimensions appended so far.
//
// NOTE(review): on the partial fast path, sizes below -1 are silently treated
// like -1 (Set16 only checks `val < 0`) instead of being rejected. Confirm
// that this is intended.
template <class Shape>
Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64_t> dim_sizes) {
  DCHECK_EQ(tag(), REP16);

  // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
  // below cannot overflow.
  static const int64_t kMaxSmall = 0xd744;
  static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
                "bad overflow check");
  // Decide whether every size is small enough for the fast path.
  bool large_size = false;
  for (auto s : dim_sizes) {
    if (s > kMaxSmall) {
      large_size = true;
      break;
    }
  }

  // The fast path below does not check signs itself, so a fully defined
  // shape rejects negative sizes here. (On the slow path, AddDimWithStatus
  // performs this check.)
  if (!kIsPartial && !large_size) {
    for (auto s : dim_sizes) {
      if (TF_PREDICT_FALSE(s < 0)) {
        return errors::InvalidArgument(
            "Expected shape dimensions to be non-negative, got ", s);
      }
    }
  }

  if (!large_size) {
    // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
    uint16* dst = as16()->dims_;
    switch (dim_sizes.size()) {
      case 1: {
        set_ndims_byte(1);
        const int64_t size = dim_sizes[0];
        const bool neg = Set16(kIsPartial, dst, 0, size);
        set_num_elements(neg ? -1 : size);
        return OkStatus();
      }
      case 2: {
        set_ndims_byte(2);
        const int64_t size0 = dim_sizes[0];
        const int64_t size1 = dim_sizes[1];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        set_num_elements(neg ? -1 : (size0 * size1));
        return OkStatus();
      }
      case 3: {
        set_ndims_byte(3);
        const int64_t size0 = dim_sizes[0];
        const int64_t size1 = dim_sizes[1];
        const int64_t size2 = dim_sizes[2];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        set_num_elements(neg ? -1 : (size0 * size1 * size2));
        return OkStatus();
      }
      case 4: {
        set_ndims_byte(4);
        const int64_t size0 = dim_sizes[0];
        const int64_t size1 = dim_sizes[1];
        const int64_t size2 = dim_sizes[2];
        const int64_t size3 = dim_sizes[3];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        neg |= Set16(kIsPartial, dst, 3, size3);
        set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
        return OkStatus();
      }
    }
  }

  // Slow path: rank 0 or > 4, or at least one large size. Build the shape
  // one checked dimension at a time. SubtleMustCopy presumably forces a
  // single read of each input element (see bounds_check.h).
  set_ndims_byte(0);
  set_num_elements(1);
  Status status = OkStatus();
  for (int64_t s : dim_sizes) {
    status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
    if (!status.ok()) {
      return status;
    }
  }

  return status;
}
296
297template <class Shape>
298TensorShapeBase<Shape>::TensorShapeBase() {
299 set_tag(REP16);
300 set_data_type(DT_INVALID);
301 if (kIsPartial) {
302 set_ndims_byte(kUnknownRank);
303 set_num_elements(-1);
304 } else {
305 set_ndims_byte(0);
306 set_num_elements(1);
307 }
308}
309
310void TensorShapeRep::DestructorOutOfLine() {
311 DCHECK(tag() == REP_OUT_OF_LINE);
312 delete as64()->dims_;
313}
314
// Copies the dimension storage, rank byte and data type of `b` into this
// object. It handles every combination in which either side uses the
// heap-allocated (REP_OUT_OF_LINE) representation.
//
// This function never calls set_num_elements. NOTE(review): presumably the
// caller (the inline copy path in the header) copies the element count
// separately — confirm.
void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    // Source is stored inline: release our heap vector (if any), then copy
    // the whole inline buffer.
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    // Source is out of line: copy rank and data type, then either reuse our
    // existing heap vector or allocate a fresh copy of b's.
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64_t, 4>(*(b.as64()->dims_));
    }
  }
}
337
338template <class Shape>
339int64_t TensorShapeBase<Shape>::dim_size(int d) const {
340 if (unknown_rank()) return -1;
341 DCHECK_GE(d, 0);
342 DCHECK_LT(d, dims());
343 if (tag() == REP16) {
344 uint16 dim = as16()->dims_[d];
345 if (kIsPartial && dim == kUnknownRep16) return -1;
346 return dim;
347 } else if (tag() == REP32) {
348 uint32 dim = as32()->dims_[d];
349 if (kIsPartial && dim == kUnknownRep32) return -1;
350 return dim;
351 } else {
352 return (*as64()->dims_)[d];
353 }
354}
355
356void TensorShapeRep::Clear() {
357 ClearAllButDataType();
358 set_data_type(DT_INVALID);
359}
360
361void TensorShapeRep::ClearAllButDataType() {
362 if (tag() == REP_OUT_OF_LINE) {
363 delete as64()->dims_;
364 }
365 set_tag(REP16);
366 set_ndims_byte(0);
367 // Leaves data_type alone
368 set_num_elements(1);
369}
370
371template <class Shape>
372Status TensorShapeBase<Shape>::RecomputeNumElements() {
373 if (unknown_rank()) {
374 set_num_elements(-1);
375 return OkStatus();
376 }
377 int64_t n = 1;
378 for (auto dim : *this) {
379 if (kIsPartial && dim.size < 0) {
380 n = -1;
381 break;
382 }
383 n = MultiplyWithoutOverflow(n, dim.size);
384 if (TF_PREDICT_FALSE(n < 0)) {
385 return errors::InvalidArgument(
386 "Shape ", this->DebugString(),
387 " results in overflow when computing number of elements");
388 }
389 }
390 set_num_elements(n);
391 return OkStatus();
392}
393
394template <class Shape>
395void TensorShapeBase<Shape>::AddDim(int64_t size) {
396 if (!kIsPartial) CHECK_GE(size, 0);
397 if (unknown_rank()) return;
398 CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
399 int64_t new_num_elements;
400 if (kIsPartial && (num_elements() < 0 || size < 0)) {
401 new_num_elements = -1;
402 } else {
403 new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
404 CHECK_LE(0, new_num_elements);
405 }
406 UnsafeAddDim(size, new_num_elements);
407}
408
409template <class Shape>
410Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
411 if (!kIsPartial) {
412 if (TF_PREDICT_FALSE(size < 0)) {
413 return errors::InvalidArgument("Expected a non-negative size, got ",
414 size);
415 }
416 }
417
418 if (unknown_rank()) {
419 return OkStatus();
420 }
421
422 if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
423 return errors::InvalidArgument("Too many dimensions in tensor");
424 }
425
426 int64_t new_num_elements;
427 if (kIsPartial && (num_elements() < 0 || size < 0)) {
428 new_num_elements = -1;
429 } else {
430 new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
431 if (TF_PREDICT_FALSE(new_num_elements < 0)) {
432 return errors::InvalidArgument("Encountered overflow when multiplying ",
433 num_elements(), " with ", size,
434 ", result: ", new_num_elements);
435 }
436 }
437
438 UnsafeAddDim(size, new_num_elements);
439 return OkStatus();
440}
441
// Appends a dimension of `size` and sets the cached element count to
// `new_num_elements`, without validating either. Callers in this file
// (AddDim, AddDimWithStatus, MakeShapeHelper) check the rank limit, sign and
// overflow before calling.
//
// The representation is kept as compact as possible:
//   * stay REP16 while there are fewer than 6 existing dims and
//     size < kMaxRep16;
//   * stay REP32 while there are fewer than 3 existing dims and
//     size < kMaxRep32;
//   * if already REP_OUT_OF_LINE, just push_back onto the heap vector.
// Otherwise the existing sizes are collected and the shape is upgraded. It
// becomes REP32 if the result has at most 3 dims, each < kMaxRep32, and
// REP_OUT_OF_LINE otherwise. In the inline representations, a negative size
// of a partial shape is stored as kUnknownRep16 / kUnknownRep32.
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64_t size,
                                          int64_t new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    gtl::InlinedVector<int64_t, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16. See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64_t, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}
487
488template <class Shape>
489void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
490 for (auto d : shape) AddDim(d.size);
491}
492
493template <class Shape>
494Status TensorShapeBase<Shape>::AppendShapeWithStatus(
495 const TensorShapeBase& shape) {
496 Status s = OkStatus();
497 for (auto d : shape) {
498 s.Update(AddDimWithStatus(d.size));
499 if (!s.ok()) {
500 return s;
501 }
502 }
503 return s;
504}
505
506template <class Shape>
507void TensorShapeBase<Shape>::InsertDim(int d, int64_t size) {
508 CHECK_GE(d, 0);
509 CHECK_LE(d, dims());
510 if (!kIsPartial) CHECK_GE(size, 0);
511 CHECK_LT(dims(), MaxDimensions());
512 gtl::InlinedVector<int64_t, 8> vals;
513 AppendTo(*this, &vals);
514 vals.insert(vals.begin() + d, size);
515 ClearAllButDataType();
516 for (auto dval : vals) {
517 AddDim(dval);
518 }
519}
520
521template <class Shape>
522Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64_t size) {
523 if (!kIsPartial) {
524 if (TF_PREDICT_FALSE(size < 0)) {
525 return errors::InvalidArgument("Expected a non-negative size, got ",
526 size);
527 }
528 }
529
530 if (TF_PREDICT_FALSE(d < 0)) {
531 return errors::Internal("The insertion index must be non-negative, got ",
532 d);
533 }
534 if (TF_PREDICT_FALSE(d > dims())) {
535 return errors::Internal("The insertion index must be at most ", dims(),
536 " got ", d);
537 }
538 if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
539 return errors::Internal("Shape has ", dims(),
540 " dimensions which is the maximum allowed");
541 }
542
543 gtl::InlinedVector<int64_t, 8> vals;
544 AppendTo(*this, &vals);
545 vals.insert(vals.begin() + d, size);
546 ClearAllButDataType();
547
548 Status s = OkStatus();
549 for (auto dval : vals) {
550 s.Update(AddDimWithStatus(dval));
551 if (!s.ok()) {
552 return s;
553 }
554 }
555 return s;
556}
557
558template <class Shape>
559gtl::InlinedVector<int64_t, 4> TensorShapeBase<Shape>::dim_sizes() const {
560 gtl::InlinedVector<int64_t, 4> result;
561 for (auto dim : *this) {
562 result.push_back(dim.size);
563 }
564 return result;
565}
566
567template <class Shape>
568void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
569 CHECK_GE(d, 0);
570 CHECK_LT(d, dims());
571 if (!kIsPartial) {
572 CHECK_GE(size, 0);
573 }
574 if (tag() == REP16 && size < kMaxRep16) {
575 as16()->dims_[d] =
576 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
577 } else if (tag() == REP32 && size < kMaxRep32) {
578 as32()->dims_[d] =
579 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
580 } else if (tag() == REP_OUT_OF_LINE) {
581 (*as64()->dims_)[d] = size;
582 } else {
583 // Must upgrade
584 gtl::InlinedVector<int64_t, 8> vals;
585 AppendTo(*this, &vals);
586 vals[d] = size;
587 ClearAllButDataType();
588 for (auto dval : vals) {
589 AddDim(dval);
590 }
591 }
592 TF_CHECK_OK(RecomputeNumElements());
593}
594
595template <class Shape>
596Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
597 if (TF_PREDICT_FALSE(d < 0)) {
598 return errors::InvalidArgument("Index must be non-negative, got ", d);
599 }
600 if (TF_PREDICT_FALSE(d >= dims())) {
601 return errors::InvalidArgument("Index must be less than ", dims(), ", got ",
602 d);
603 }
604 if (TF_PREDICT_FALSE(!kIsPartial && size < 0)) {
605 return errors::InvalidArgument("Expected a non-negative size, got ", size);
606 }
607
608 if (tag() == REP16 && size < kMaxRep16) {
609 as16()->dims_[d] =
610 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
611 } else if (tag() == REP32 && size < kMaxRep32) {
612 as32()->dims_[d] =
613 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
614 } else if (tag() == REP_OUT_OF_LINE) {
615 (*as64()->dims_)[d] = size;
616 } else {
617 // Must upgrade
618 gtl::InlinedVector<int64_t, 8> vals;
619 AppendTo(*this, &vals);
620 vals[d] = size;
621 ClearAllButDataType();
622
623 Status s = OkStatus();
624 for (auto dval : vals) {
625 s.Update(AddDimWithStatus(dval));
626 if (!s.ok()) {
627 return s;
628 }
629 }
630 }
631
632 return RecomputeNumElements();
633}
634
635template <class Shape>
636void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
637 if (unknown_rank()) return;
638 begin = begin < 0 ? dims() + begin + 1 : begin;
639 end = end < 0 ? dims() + end + 1 : end;
640 CHECK_GE(begin, 0);
641 CHECK_LE(begin, dims());
642 CHECK_GE(end, 0);
643 CHECK_LE(end, dims());
644 if (begin >= end) return;
645 gtl::InlinedVector<int64_t, 8> vals;
646 AppendTo(*this, &vals);
647 vals.erase(vals.begin() + begin, vals.begin() + end);
648 ClearAllButDataType();
649 for (auto dval : vals) {
650 AddDim(dval);
651 }
652 TF_CHECK_OK(RecomputeNumElements());
653}
654
655template <class Shape>
656Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
657 if (unknown_rank()) {
658 return OkStatus();
659 }
660
661 begin = begin < 0 ? dims() + begin + 1 : begin;
662 end = end < 0 ? dims() + end + 1 : end;
663
664 if (TF_PREDICT_FALSE(begin < 0)) {
665 return errors::Internal("Start index must be non-negative, got ", begin);
666 }
667 if (TF_PREDICT_FALSE(begin > dims())) {
668 return errors::Internal("Start index must be less than ", dims(), ", got ",
669 begin);
670 }
671 if (TF_PREDICT_FALSE(end < 0)) {
672 return errors::Internal("End index must be non-negative, got ", end);
673 }
674 if (TF_PREDICT_FALSE(end > dims())) {
675 return errors::Internal("End index must be less than ", dims(), ", got ",
676 end);
677 }
678
679 if (begin >= end) {
680 return OkStatus();
681 }
682
683 gtl::InlinedVector<int64_t, 8> vals;
684 AppendTo(*this, &vals);
685 vals.erase(vals.begin() + begin, vals.begin() + end);
686 ClearAllButDataType();
687
688 Status s = OkStatus();
689 for (auto dval : vals) {
690 s.Update(AddDimWithStatus(dval));
691 if (!s.ok()) {
692 return s;
693 }
694 }
695
696 return RecomputeNumElements();
697}
698
699bool TensorShape::IsSameSize(const TensorShape& b) const {
700 if (b.dims() != dims()) return false;
701 for (int d = 0; d < dims(); d++) {
702 if (dim_size(d) != b.dim_size(d)) return false;
703 }
704 return true;
705}
706
707template <class Shape>
708void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
709 proto->Clear();
710 if (unknown_rank()) {
711 proto->set_unknown_rank(true);
712 } else {
713 for (int i = 0; i < dims(); i++) {
714 proto->add_dim()->set_size(dim_size(i));
715 }
716 }
717}
718
719template <class Shape>
720TensorShapeProto TensorShapeBase<Shape>::AsProto() const {
721 TensorShapeProto out;
722 AsProto(&out);
723 return out;
724}
725
726template <class Shape>
727TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
728 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
729}
730
731template <class Shape>
732TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
733 const int max_dim = unknown_rank() ? -1 : dims();
734 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
735}
736
737string TensorShapeRep::DebugString() const {
738 const auto& shape = *static_cast<const PartialTensorShape*>(this);
739 if (shape.unknown_rank()) return "<unknown>";
740 string s = "[";
741 for (int i = 0; i < shape.dims(); i++) {
742 if (i > 0) strings::StrAppend(&s, ",");
743 int64_t dim = shape.dim_size(i);
744 if (dim < 0) {
745 strings::StrAppend(&s, "?");
746 } else {
747 strings::StrAppend(&s, dim);
748 }
749 }
750 strings::StrAppend(&s, "]");
751 return s;
752}
753
754string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
755 string s;
756 if (proto.unknown_rank()) {
757 strings::StrAppend(&s, "<unknown>");
758 if (proto.dim_size() == 0) return s;
759 }
760 strings::StrAppend(&s, "[");
761 bool first = true;
762 for (const auto& d : proto.dim()) {
763 if (!first) strings::StrAppend(&s, ",");
764 if (d.size() == -1) {
765 strings::StrAppend(&s, "?");
766 } else {
767 strings::StrAppend(&s, d.size());
768 }
769 first = false;
770 }
771 strings::StrAppend(&s, "]");
772 return s;
773}
774
775bool TensorShapeUtils::StartsWith(const TensorShape& shape,
776 const TensorShape& prefix) {
777 if (shape.dims() < prefix.dims()) return false;
778 for (int i = 0; i < prefix.dims(); ++i) {
779 if (shape.dim_size(i) != prefix.dim_size(i)) return false;
780 }
781 return true;
782}
783
784bool TensorShapeUtils::EndsWith(const TensorShape& shape,
785 const TensorShape& suffix) {
786 const int suffix_size = suffix.dims();
787 if (shape.dims() < suffix_size) return false;
788 for (int i = 0; i < suffix_size; ++i) {
789 if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
790 return false;
791 }
792 }
793 return true;
794}
795
796template <typename T, class Shape>
797Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) {
798 out->Clear();
799 if (n > TensorShape::MaxDimensions()) {
800 return errors::InvalidArgument("Too many dimensions");
801 }
802 if (n < 0) {
803 return errors::InvalidArgument("Negative number of dimensions ", n);
804 }
805 for (int64_t i = 0; i < n; ++i) {
806 T dim = internal::SubtleMustCopy(dims[i]);
807 int64_t new_num_elements;
808 if (dim < 0) {
809 if (!out->kIsPartial) {
810 return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
811 }
812 if (dim < -1) {
813 return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
814 }
815 dim = -1;
816 new_num_elements = -1;
817 } else if (out->num_elements() < 0) {
818 new_num_elements = -1;
819 } else {
820 new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
821 if (TF_PREDICT_FALSE(new_num_elements < 0)) {
822 TensorShapeProto proto;
823 for (int64_t j = 0; j < n; ++j) {
824 proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
825 }
826 return errors::InvalidArgument(
827 "Shape ", TensorShape::DebugString(proto),
828 " would have more than 2**63 - 1 elements");
829 }
830 }
831 out->UnsafeAddDim(dim, new_num_elements);
832 }
833 return OkStatus();
834}
835
// Defines the two TensorShapeUtils::MakeShape overloads for dimension element
// type T and output shape type Shape. One takes a raw pointer plus a count,
// the other an ArraySlice. Both forward to MakeShapeHelper.
#define MAKE_SHAPE(T, Shape)                                                 \
  Status TensorShapeUtils::MakeShape(const T* dims, int64_t n, Shape* out) { \
    return MakeShapeHelper(dims, n, out);                                    \
  }                                                                          \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    return MakeShapeHelper(shape.data(), shape.size(), out);                 \
  }
// int32 and int64_t dimension lists, for both fully defined and partial
// shapes.
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64_t, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64_t, PartialTensorShape)
#undef MAKE_SHAPE
848
849string TensorShapeUtils::ShapeListString(
850 const gtl::ArraySlice<TensorShape>& shapes) {
851 string result = "[";
852 bool first = true;
853 for (const TensorShape& shape : shapes) {
854 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
855 first = false;
856 }
857 strings::StrAppend(&result, "]");
858 return result;
859}
860
861PartialTensorShape PartialTensorShape::Concatenate(int64_t size) const {
862 PartialTensorShape out = *this;
863 out.AddDim(size);
864 return out;
865}
866
867Status PartialTensorShape::ConcatenateWithStatus(
868 int64_t size, PartialTensorShape* out) const {
869 out = const_cast<PartialTensorShape*>(this);
870 return out->AddDimWithStatus(size);
871}
872
873PartialTensorShape PartialTensorShape::Concatenate(
874 const PartialTensorShape& shape) const {
875 if (unknown_rank() || shape.unknown_rank()) {
876 return PartialTensorShape();
877 }
878 PartialTensorShape out = *this;
879 for (auto dim : shape) out.AddDim(dim.size);
880 return out;
881}
882
883Status PartialTensorShape::ConcatenateWithStatus(
884 const PartialTensorShape& shape, PartialTensorShape* out) const {
885 if (unknown_rank() || shape.unknown_rank()) {
886 *out = PartialTensorShape();
887 return OkStatus();
888 }
889 out = const_cast<PartialTensorShape*>(this);
890 for (auto dim : shape) {
891 Status s = out->AddDimWithStatus(dim.size);
892 if (!s.ok()) return s;
893 }
894
895 return OkStatus();
896}
897
898Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
899 PartialTensorShape* result) const {
900 if (unknown_rank()) {
901 *result = shape;
902 return OkStatus();
903 }
904 if (shape.unknown_rank()) {
905 *result = *this;
906 return OkStatus();
907 }
908 const int dims_ = dims();
909 if (dims_ != shape.dims()) {
910 return errors::InvalidArgument(
911 "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
912 shape.dims());
913 }
914
915 if (result == this) {
916 return errors::Internal(
917 "PartialTensorShape::MergeWith: cannot merge shape with itself");
918 }
919
920 result->Clear();
921 Status s = OkStatus();
922 for (int i = 0; i < dims_; ++i) {
923 const int64_t dim0 = dim_size(i);
924 const int64_t dim1 = shape.dim_size(i);
925 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
926 return errors::InvalidArgument(
927 "PartialTensorShape: Incompatible shapes during merge: ",
928 DebugString(), " vs. ", shape.DebugString());
929 }
930 s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
931 if (!s.ok()) {
932 return s;
933 }
934 }
935 return OkStatus();
936}
937
938bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
939 if (IsFullyDefined()) {
940 const TensorShapeRep* rep = this;
941 *shape = *static_cast<const TensorShape*>(rep);
942 return true;
943 }
944 return false;
945}
946
947bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
948 if (unknown_rank() || shape.unknown_rank()) {
949 return unknown_rank() == shape.unknown_rank();
950 }
951 if (dims() != shape.dims()) return false;
952 for (int i = 0; i < dims(); i++) {
953 if (dim_size(i) != shape.dim_size(i)) return false;
954 }
955 return true;
956}
957
958bool PartialTensorShape::IsCompatibleWith(
959 const PartialTensorShape& shape) const {
960 if (unknown_rank() || shape.unknown_rank()) return true;
961 if (dims() != shape.dims()) return false;
962 for (int i = 0; i < dims(); i++) {
963 const int64_t dim0 = dim_size(i);
964 const int64_t dim1 = shape.dim_size(i);
965 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
966 }
967 return true;
968}
969
970string PartialTensorShapeUtils::PartialShapeListString(
971 const gtl::ArraySlice<PartialTensorShape>& shapes) {
972 string result = "[";
973 bool first = true;
974 for (const PartialTensorShape& shape : shapes) {
975 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
976 first = false;
977 }
978 strings::StrAppend(&result, "]");
979 return result;
980}
981
982bool PartialTensorShapeUtils::AreCompatible(
983 const gtl::ArraySlice<PartialTensorShape>& shapes0,
984 const gtl::ArraySlice<PartialTensorShape>& shapes1) {
985 if (shapes0.size() == shapes1.size()) {
986 for (size_t i = 0; i < shapes0.size(); ++i) {
987 if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
988 return false;
989 }
990 }
991 return true;
992 } else {
993 return false;
994 }
995}
996
997bool PartialTensorShapeUtils::AreIdentical(
998 const gtl::ArraySlice<PartialTensorShape>& shapes0,
999 const gtl::ArraySlice<PartialTensorShape>& shapes1) {
1000 if (shapes0.size() == shapes1.size()) {
1001 for (size_t i = 0; i < shapes0.size(); ++i) {
1002 if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
1003 return false;
1004 }
1005 }
1006 return true;
1007 } else {
1008 return false;
1009 }
1010}
1011
1012Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64_t> shape,
1013 int64_t* num_elements) {
1014 int64_t n = 1;
1015 for (auto dim : shape) {
1016 n = MultiplyWithoutOverflow(n, dim);
1017 if (n < 0) {
1018 return errors::InvalidArgument("Can't compute total size of shape [",
1019 absl::StrJoin(shape, ","),
1020 "]; product would overflow int64");
1021 }
1022 }
1023 *num_elements = n;
1024 return OkStatus();
1025}
1026
// Explicitly instantiate the shared template for both shape flavors, so the
// TensorShapeBase<Shape> member definitions in this file are emitted here.
template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;
1029
1030} // namespace tensorflow
1031