1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
17#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
18
19#include <string>
20
21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22#include "tensorflow/core/framework/types.pb.h"
23#include "tensorflow/core/lib/gtl/array_slice.h"
24#include "tensorflow/core/lib/gtl/inlined_vector.h"
25#include "tensorflow/core/lib/strings/str_util.h"
26#include "tensorflow/core/platform/errors.h"
27#include "tensorflow/core/platform/logging.h"
28#include "tensorflow/core/platform/macros.h"
29#include "tensorflow/core/platform/statusor.h"
30
31namespace tensorflow {
32
// START_SKIP_DOXYGEN
// Forward declarations. The shape classes below refer to each other (and to
// the TensorShapeProto message, which is only named in declarations in this
// header) before their full definitions are available.
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN
40
/// Internal representation for both TensorShape and PartialTensorShape.
///
/// All state lives in a 16-byte buffer (`u_.buf`) plus a cached element count
/// (`num_elements_`):
///   * bytes [0..11]: dimension data whose layout depends on the RepTag
///     (six uint16 values for REP16, three uint32 values for REP32, or a
///     pointer to a heap-allocated vector for REP_OUT_OF_LINE);
///   * byte 13: a DataType value that `Tensor` may stash here;
///   * byte 14: the number of dimensions (kUnknownRank == 255 means unknown);
///   * byte 15: the RepTag.
class TensorShapeRep {
 public:
  /// Calls DestructorOutOfLine() when the shape uses the out-of-line
  /// representation; inline representations need no cleanup.
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape. After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  /// Maximum number of dimensions in a tensor.
  /// It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully
  /// defined.
  int64_t num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape. Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];  // 12 bytes: occupies buf[0..11].
  };
  struct Rep32 {
    uint32 dims_[3];  // 12 bytes: occupies buf[0..11].
  };
  struct Rep64 {
    gtl::InlinedVector<int64_t, 4>* dims_;  // Heap storage, see tag().
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static constexpr int64_t kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64_t kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Views of the shared buffer as each representation. Only the view matching
  // tag() holds meaningful data.
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage. This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..11] hold the representation-specific dimension data and byte 13
  // holds the DataType (see set_data_type above).
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64_t n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64_t num_elements_;
};
158
159/// Base class for TensorShape and PartialTensorShape.
160/// The class is templatized by either TensorShape or PartialTensorShape to
161/// allow skipping known/unknown checks in the TensorShape case, but the
162/// representation is shared exactly for fast conversion.
163template <class Shape>
164class TensorShapeBase : public TensorShapeRep {
165 public:
166 /// \brief Construct a `TensorShapeBase` from the provided sizes.
167 /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
168 explicit TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes);
169 TensorShapeBase(std::initializer_list<int64_t> dim_sizes)
170 : TensorShapeBase(gtl::ArraySlice<int64_t>(dim_sizes)) {}
171
172 /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
173 TensorShapeBase();
174
175 // Cannot be made explicit because we rely on conversion between proto and
176 // `TensorShapeBase` throughtout the codebase (needs bigger cleanup)
177 TensorShapeBase(const TensorShapeProto& proto);
178
179 // These factory methods should be used instead of the constructors that take
180 // an array of sizes if calling code cannot validate that the sizes specify a
181 // valid `TensorShape`.
182 // The value in `*out` is valid iff the returned value is `Status::OK`.
183 static Status BuildTensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes,
184 TensorShapeBase* out);
185 static Status BuildTensorShapeBase(std::initializer_list<int64_t> dim_sizes,
186 TensorShapeBase* out) {
187 return BuildTensorShapeBase(gtl::ArraySlice<int64_t>(dim_sizes), out);
188 }
189 static Status BuildTensorShapeBase(const TensorShapeProto& proto,
190 TensorShapeBase* out);
191
192 /// Returns `true` iff `proto` is a valid tensor shape.
193 // For TensorShape, the proto shape must be fully defined.
194 static bool IsValid(const TensorShapeProto& proto);
195
196 /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
197 /// status otherwise.
198 static Status IsValidShape(const TensorShapeProto& proto);
199
200 /// Returns `true` iff this is a valid tensor shape.
201 bool IsValid();
202
203 /// \brief Add a dimension to the end ("inner-most").
204 /// REQUIRES: `size >= 0`
205 void AddDim(int64_t size);
206
207 /// Same as `AddDim` but returns a `Status`.
208 /// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes.
209 Status AddDimWithStatus(int64_t size);
210
211 /// Appends all the dimensions from `shape`.
212 void AppendShape(const TensorShapeBase& shape);
213
214 /// Same as `RemoveDim` but returns a `Status`.
215 /// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
216 Status AppendShapeWithStatus(const TensorShapeBase& shape);
217
218 /// \brief Insert a dimension somewhere in the `TensorShape`.
219 /// REQUIRES: `0 <= d <= dims()`
220 /// REQUIRES: `size >= 0`
221 void InsertDim(int d, int64_t size);
222
223 /// Same as `InsertDim` but returns a `Status`.
224 /// Use if unsure if requirements in `InsertDim` are satistified, to prevent
225 /// `CHECK`-fail crashes.
226 Status InsertDimWithStatus(int d, int64_t size);
227
228 /// \brief Modifies the size of the dimension `d` to be `size`
229 /// REQUIRES: `0 <= d < dims()`
230 /// REQUIRES: `size >= 0`
231 void set_dim(int d, int64_t size);
232
233 /// Same as `set_dim` but returns a `Status`.
234 /// Use if unsure if requirements in `set_dim` are satistified, to prevent
235 /// `CHECK`-fail crashes.
236 Status SetDimWithStatus(int d, int64_t size);
237
238 /// \brief Removes dimension `d` from the `TensorShape`.
239 /// REQUIRES: `0 <= d < dims()`
240 void RemoveDim(int d) {
241 CHECK_GE(d, 0);
242 RemoveDimRange(d, d + 1);
243 }
244
245 /// Same as `RemoveDim` but returns a `Status`.
246 /// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes.
247 Status RemoveDimWithStatus(int64_t d) {
248 if (TF_PREDICT_FALSE(d < 0)) {
249 return errors::Internal(
250 "Expected dimension index to be non-negative, got ", d);
251 }
252 return RemoveDimRangeWithStatus(d, d + 1);
253 }
254
255 /// \brief Removes last `n` dimensions from the `TensorShape`.
256 /// REQUIRES: `0 <= n <= dims()`
257 void RemoveLastDims(int n) {
258 CHECK_LE(n, dims());
259 RemoveDimRange(dims() - n, dims());
260 }
261
262 /// Same as `RemoveLastDims` but returns a `Status`.
263 /// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes.
264 Status RemoveLastDimsWithStatus(int64_t n) {
265 if (TF_PREDICT_FALSE(n < dims())) {
266 return errors::Internal("Expected dimension index to be at most ", dims(),
267 " got ", n);
268 }
269 return RemoveDimRangeWithStatus(dims() - n, dims());
270 }
271
272 /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
273 /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
274 /// Python). The same is true for negative values of `begin`.
275 /// REQUIRES: `-(dims()+1) <= begin <= dims()`
276 /// REQUIRES: `-(dims()+1) <= end <= dims()`
277 void RemoveDimRange(int begin, int end);
278
279 /// Same as `RemoveDimRange` but returns a `Status`.
280 /// Use if unsure if requirements in `RemoveDimRange` are satistified, to
281 /// prevent `CHECK`-fail crashes.
282 Status RemoveDimRangeWithStatus(int begin, int end);
283
284 /// Return whether the rank is unknown
285 bool unknown_rank() const {
286 return kIsPartial && ndims_byte() == kUnknownRank;
287 }
288
289 /// Return the number of dimensions in the tensor.
290 /// Can be -1 meaning unknown rank for PartialTensorShape.
291 int dims() const {
292 uint8 dims = ndims_byte();
293 return kIsPartial && dims == kUnknownRank ? -1 : dims;
294 }
295
296 /// \brief Returns the number of elements in dimension `d`.
297 /// REQUIRES: `0 <= d < dims()`
298 // TODO(touts): Rename to `dimension()` to match
299 // `Eigen::Tensor::dimension()`?
300 int64_t dim_size(int d) const;
301
302 /// Returns sizes of all dimensions.
303 // Returns an empty list for unknown rank PartialTensorShape.
304 gtl::InlinedVector<int64_t, 4> dim_sizes() const;
305
306 /// Return true iff the rank and all of the dimensions are well defined
307 // TODO(irving): Rename to is_fully_defined now that it's fast.
308 bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }
309
310 /// Fill `*proto` from `*this`.
311 void AsProto(TensorShapeProto* proto) const;
312 TensorShapeProto AsProto() const;
313
314 /// For iterating through the dimensions.
315 TensorShapeIter<Shape> begin() const;
316 TensorShapeIter<Shape> end() const;
317
318 protected:
319 // Optimized constructor for a shape representing an empty vector.
320 //
321 // This constructor is provided to optimize the default constructor for
322 // `Tensor`.
323 explicit TensorShapeBase(DataType dt);
324
325 private:
326 Status RecomputeNumElements();
327 Status InitDims(gtl::ArraySlice<int64_t> dim_sizes);
328
329 // True for PartialTensorShape, false for TensorShape
330 static constexpr bool kIsPartial =
331 std::is_same<Shape, PartialTensorShape>::value;
332 static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
333 "Shape is neither TensorShape nor PartialTensorShape");
334
335 // Used by AddDim and MakeShapeHelper. Does no error checking.
336 void UnsafeAddDim(int64_t size, int64_t new_num_elements);
337
338 // For use by TensorShapeUtils::MakeShape
339 template <class T, class S>
340 friend Status MakeShapeHelper(const T*, int64_t, S*);
341};
342
343/// Outputs `TensorShapeBase` to `std::ostream`.
344template <typename Shape>
345std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
346 return os << tsb.DebugString();
347}
348
349/// Represents the shape of a Tensor.
350///
351/// A tensor's shape is denoted by its number of dimensions and a size for each
352/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
353/// a shape of 2-D, [3,4].
354///
355/// If you know the exact shape of your Tensor when you create the TensorShape
356/// object, you can specify it then, or you can create a TensorShape with
357/// zero dimensions and one element, and call AddDim() to add dimensions later.
358class TensorShape : public TensorShapeBase<TensorShape> {
359 public:
360 using TensorShapeBase<TensorShape>::TensorShapeBase;
361
362 // These factory methods should be used instead of the constructors that take
363 // an array of sizes if calling code cannot validate that the sizes specify a
364 // valid `TensorShape`.
365 // The value in `*out` is valid iff the returned value is `Status::OK`.
366 static Status BuildTensorShape(gtl::ArraySlice<int64_t> dim_sizes,
367 TensorShape* out) {
368 return BuildTensorShapeBase(dim_sizes, out);
369 }
370 static Status BuildTensorShape(std::initializer_list<int64_t> dim_sizes,
371 TensorShape* out) {
372 return BuildTensorShape(gtl::ArraySlice<int64_t>(dim_sizes), out);
373 }
374 static Status BuildTensorShape(const TensorShapeProto& proto,
375 TensorShape* out) {
376 return BuildTensorShapeBase(proto, out);
377 }
378
379 static StatusOr<TensorShape> BuildTensorShape(const TensorShapeProto& proto) {
380 TensorShape out;
381 TF_RETURN_IF_ERROR(BuildTensorShape(proto, &out));
382 return out;
383 }
384
385 /// Allow a TensorShape to be used as a PartialTensorShape without copying
386 operator const PartialTensorShape&() const; // NOLINT(runtime/explicit)
387
388 /// Returns true if `*this` and `b` have the same sizes. Ignores
389 /// dimension names.
390 bool IsSameSize(const TensorShape& b) const;
391 bool operator==(const TensorShape& b) const { return IsSameSize(b); }
392 bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }
393
394 /// Fill `*dsizes` from `*this`.
395 /// Notice: Using IndexType=int32 in combination with To32Bit() can
396 /// significantly improve performance on GPU.
397 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
398 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
399
400 // Same as `AsEigenDSizes()` but returns a `Status` instead.
401 // Use this method to surface error to user instead of crashing if `NDMIS` is
402 // not equal to `dims()`.
403 // Caller must take ownership of `out`.
404 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
405 Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;
406
407 /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
408 /// which case we pad the rest of the sizes with 1.
409 /// Notice: Using IndexType=int32 in combination with To32Bit() can
410 /// significantly improve performance on GPU.
411 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
412 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
413
414 // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
415 // Use this method to surface error to user instead of crashing if `NDMIS` is
416 // not equal to `dims()`.
417 // Caller must take ownership of `out`.
418 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
419 Status AsEigenDSizesWithPaddingWithStatus(
420 Eigen::DSizes<IndexType, NDIMS>* out) const;
421
422 private:
423 // These CHECK fail to ease debugging.
424 // REQUIRES: dims() == NDIMS
425 void CheckDimsEqual(int NDIMS) const;
426 // REQUIRES: dims() <= NDIMS
427 void CheckDimsAtMost(int NDIMS) const;
428
429 // Fill output from `*this`.
430 // Helper method for common code between `AsEigenDSize()` and
431 // `AsEigenDSizeWithStatus()`.
432 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
433 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;
434
435 // Fill output from `*this`.
436 // Helper method for common code between `AsEigenDSizesWithPadding()` and
437 // `AsEigenDSizeWithPaddingWithStatus()`.
438 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
439 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;
440
441 // For access to TensorShapeBase(DataType).
442 friend class Tensor;
443};
444
445/// Outputs `TensorShapeBase` to `std::ostream`.
446inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
447 return os << ts.DebugString();
448}
449
/// Holds the size of a single dimension of a shape; this is the value
/// produced when a TensorShapeIter is dereferenced.
struct TensorShapeDim {
  int64_t size;

  explicit TensorShapeDim(int64_t s) : size{s} {}
};
455
456// START_SKIP_DOXYGEN
457template <class Shape>
458class TensorShapeIter {
459 public:
460 TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
461 bool operator==(const TensorShapeIter& rhs) {
462 DCHECK(shape_ == rhs.shape_);
463 return d_ == rhs.d_;
464 }
465 bool operator!=(const TensorShapeIter& rhs) {
466 DCHECK(shape_ == rhs.shape_);
467 return d_ != rhs.d_;
468 }
469 void operator++() { ++d_; }
470 TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
471
472 private:
473 const Shape* shape_;
474 int d_;
475};
476// END_SKIP_DOXYGEN
477
478/// \brief Static helper routines for `TensorShape`. Includes a few common
479/// predicates on a tensor shape.
480class TensorShapeUtils {
481 public:
482 static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
483
484 static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
485
486 static bool IsVectorOrHigher(const TensorShape& shape) {
487 return shape.dims() >= 1;
488 }
489
490 static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
491
492 static bool IsSquareMatrix(const TensorShape& shape) {
493 return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
494 }
495
496 static bool IsMatrixOrHigher(const TensorShape& shape) {
497 return shape.dims() >= 2;
498 }
499
500 /// \brief Returns a `TensorShape` whose dimensions are
501 /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
502 static Status MakeShape(const int32* dims, int64_t n, TensorShape* out);
503 static Status MakeShape(const int64_t* dims, int64_t n, TensorShape* out);
504 static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
505 static Status MakeShape(gtl::ArraySlice<int64_t> shape, TensorShape* out);
506 static Status MakeShape(const int32* dims, int64_t n,
507 PartialTensorShape* out);
508 static Status MakeShape(const int64_t* dims, int64_t n,
509 PartialTensorShape* out);
510 static Status MakeShape(gtl::ArraySlice<int32> shape,
511 PartialTensorShape* out);
512 static Status MakeShape(gtl::ArraySlice<int64_t> shape,
513 PartialTensorShape* out);
514
515 static std::string ShapeListString(
516 const gtl::ArraySlice<TensorShape>& shapes);
517
518 /// \brief Returns true iff `shape` starts with `prefix`.
519 static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);
520
521 /// \brief Returns true iff `shape` ends with `suffix`.
522 static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);
523
524 /// \brief Returns the product of values in an int64 array,
525 /// or a failing Status if the array represents a value larger than
526 /// a `TensorShape` can hold.
527 static Status NumElements(gtl::ArraySlice<int64_t> shape,
528 int64_t* num_elements);
529};
530
531/// Manages the partially known dimensions of a Tensor and their sizes.
532class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
533 public:
534 PartialTensorShape() {}
535 using TensorShapeBase<PartialTensorShape>::TensorShapeBase;
536
537 // These factory methods should be used instead of the constructors that take
538 // an array of sizes if calling code cannot validate that the sizes specify a
539 // valid `PartialTensorShape`.
540 // The value in `*out` is valid iff the returned value is `Status::OK`.
541 static Status BuildPartialTensorShape(gtl::ArraySlice<int64_t> dim_sizes,
542 PartialTensorShape* out) {
543 return BuildTensorShapeBase(dim_sizes, out);
544 }
545 static Status BuildPartialTensorShape(
546 std::initializer_list<int64_t> dim_sizes, PartialTensorShape* out) {
547 return BuildPartialTensorShape(gtl::ArraySlice<int64_t>(dim_sizes), out);
548 }
549 static Status BuildPartialTensorShape(const TensorShapeProto& proto,
550 PartialTensorShape* out) {
551 return BuildTensorShapeBase(proto, out);
552 }
553
554 static StatusOr<PartialTensorShape> BuildPartialTensorShape(
555 const TensorShapeProto& proto) {
556 PartialTensorShape out;
557 TF_RETURN_IF_ERROR(BuildTensorShapeBase(proto, &out));
558 return out;
559 }
560
561 /// Add a dimension to the end ("inner-most"), returns a new
562 /// PartialTensorShape.
563 /// REQUIRES: `size >= -1`, where -1 means unknown
564 PartialTensorShape Concatenate(int64_t size) const;
565
566 /// Similar to `Concatenate` but returning `Status`.
567 /// Use if calling code cannot validate all requirements and if `CHECK`-fails
568 /// are to be avoided.
569 Status ConcatenateWithStatus(int64_t size, PartialTensorShape* out) const;
570
571 /// Appends all the dimensions from `shape`. Returns a new
572 /// PartialTensorShape.
573 PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
574
575 /// Similar to `Concatenate` but returning `Status`.
576 /// Use if calling code cannot validate all requirements and if `CHECK`-fails
577 /// are to be avoided.
578 Status ConcatenateWithStatus(const PartialTensorShape& shape,
579 PartialTensorShape* out) const;
580
581 /// Merges all the dimensions from `shape`. Returns
582 /// `InvalidArgument` error if either `shape` has a different rank
583 /// or if any of the dimensions are incompatible.
584 Status MergeWith(const PartialTensorShape& shape,
585 PartialTensorShape* result) const;
586
587 /// Exact equality test. Returns true iff the ranks match (i.e., both are
588 /// unknown, or both are known and equal), and all dimensions are equal (i.e.,
589 /// both dimensions are known, or both are known and equal). This is a
590 /// stronger condition that IsCompatibleWith.
591 bool IsIdenticalTo(const PartialTensorShape& shape) const;
592
593 /// Return true iff the ranks match, and if the
594 /// dimensions all either match or one is unknown.
595 bool IsCompatibleWith(const PartialTensorShape& shape) const;
596
597 // Fill `*shape` from `*this`.
598 // If `*this` is not fully defined, returns false and
599 // `*shape` is left in an intermediate state. Otherwise
600 // returns true.
601 bool AsTensorShape(TensorShape* shape) const;
602
603 /// \brief Returns a `PartialTensorShape` whose dimensions are
604 /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
605 /// considered "unknown".
606 template <class T>
607 static Status MakePartialShape(const T* dims, int n,
608 PartialTensorShape* out) {
609 return TensorShapeUtils::MakeShape(dims, n, out);
610 }
611};
612
613/// \brief Static helper routines for `PartialTensorShape`. Includes a few
614/// common predicates on a partially known tensor shape.
615class PartialTensorShapeUtils {
616 public:
617 static std::string PartialShapeListString(
618 const gtl::ArraySlice<PartialTensorShape>& shapes);
619
620 static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
621 const gtl::ArraySlice<PartialTensorShape>& shapes1);
622
623 static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
624 const gtl::ArraySlice<PartialTensorShape>& shapes1);
625};
626
627// ----------------------------------------------------------------------------
628// Template method implementation details below
629// ----------------------------------------------------------------------------
630
631template <int NDIMS, typename IndexType>
632Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
633 Eigen::DSizes<IndexType, NDIMS> dsizes;
634 for (int d = 0; d < NDIMS; d++) {
635 dsizes[d] = static_cast<IndexType>(dim_size(d));
636 }
637 return dsizes;
638}
639
640template <int NDIMS, typename IndexType>
641Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
642 static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
643 Eigen::DSizes<IndexType, NDIMS> dsizes;
644 for (int d = 0; d < dims(); d++) {
645 dsizes[d] = static_cast<IndexType>(dim_size(d));
646 }
647 for (int d = dims(); d < NDIMS; d++) {
648 dsizes[d] = 1;
649 }
650 return dsizes;
651}
652
653template <int NDIMS, typename IndexType>
654Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
655 CheckDimsEqual(NDIMS);
656 return AsEigenDSizesCopy<NDIMS, IndexType>();
657}
658
659template <int NDIMS, typename IndexType>
660Status TensorShape::AsEigenDSizesWithStatus(
661 Eigen::DSizes<IndexType, NDIMS>* out) const {
662 if (TF_PREDICT_FALSE(NDIMS != dims())) {
663 return errors::Internal("Asking for tensor of ", NDIMS,
664 " dimensions from a tensor of ", dims(),
665 " dimensions");
666 }
667 *out = AsEigenDSizesCopy<NDIMS, IndexType>();
668}
669
670template <int NDIMS, typename IndexType>
671Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
672 CheckDimsAtMost(NDIMS);
673 return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
674}
675
676template <int NDIMS, typename IndexType>
677Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
678 Eigen::DSizes<IndexType, NDIMS>* out) const {
679 if (TF_PREDICT_FALSE(NDIMS < dims())) {
680 return errors::Internal("Asking for tensor of at least ", NDIMS,
681 " dimensions from a tensor of ", dims(),
682 " dimensions");
683 }
684 *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
685}
686
687// ----------------------------------------------------------------------------
688// Inlining of some performance critical routines
689// ----------------------------------------------------------------------------
690
// Copy constructor. Inline representations (REP16/REP32) are copied with a
// single memcpy of the 16-byte buffer; the out-of-line representation goes
// through SlowCopyFrom().
inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above Implicitly does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
    // (and also copies the data_type() byte)
  } else {
    // `u_` has not been initialized yet in this constructor, so give it an
    // inline tag before SlowCopyFrom inspects this object's current tag.
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}
703
// Move constructor. Takes over `b`'s whole 16-byte buffer (including the
// out-of-line vector pointer, if any), then retags `b` as REP16 so that its
// destructor, which only acts on REP_OUT_OF_LINE, leaves that storage alone.
inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above Implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}
712
713inline TensorShapeRep::~TensorShapeRep() {
714 if (tag() == REP_OUT_OF_LINE) {
715 DestructorOutOfLine();
716 }
717}
718
719inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
720 num_elements_ = b.num_elements_;
721 if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
722 memcpy(buf(), b.buf(), sizeof(u_.buf));
723 // memcpy above implicitly also does:
724 // set_tag(b.tag());
725 // set_ndims_byte(b.ndims_byte());
726 } else {
727 SlowCopyFrom(b);
728 }
729}
730
731inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
732 if (tag() == REP_OUT_OF_LINE) {
733 DestructorOutOfLine();
734 }
735 num_elements_ = b.num_elements_;
736 memcpy(buf(), b.buf(), sizeof(u_.buf));
737 // memcpy above Implicitly does:
738 // set_ndims_byte(b.ndims_byte());
739 // set_tag(b.tag());
740 b.set_tag(REP16); // other shape no longer owns out-of-line data, if any.
741}
742
743inline TensorShape::operator const PartialTensorShape&() const {
744 // Downcast to the shared representation and upcast to PartialTensorShape
745 const TensorShapeRep* rep = this;
746 return *static_cast<const PartialTensorShape*>(rep);
747}
748
749template <class Shape>
750inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
751 set_tag(REP16);
752 set_data_type(dt);
753
754 // Optimized implementation of InitDims() where the shape is statically known
755 // to be {0}.
756 set_ndims_byte(1);
757 uint16* dst = as16()->dims_;
758 *dst = 0;
759 set_num_elements(0);
760}
761
// Declare explicit instantiations in .cc file
// These `extern template` declarations tell every translation unit that
// includes this header not to implicitly instantiate these specializations;
// the definitions come from the explicit instantiations in the .cc file
// (inline members may still be instantiated locally for inlining).
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;
765
766// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
767// often used in shape inference.
768struct DtypeAndPartialTensorShape {
769 DataType dtype;
770 PartialTensorShape shape;
771};
772
773} // namespace tensorflow
774
775#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
776